### Dataset Download

Dataset can be downloaded from yt-hub https://girder.hub.yt/#folder/5e6d2a7168085e00018c9088, and there should be 1594 files in total

### Environment Setup

We will use the ASE package to digest the `.traj` files and here is the installation instruction https://wiki.fysik.dtu.dk/ase/install.html

In [1]:
from ase.io.trajectory import Trajectory
import matplotlib.pyplot as plt
import os
import h5py
import numpy as np

### File Information

In [2]:
def gather_folder_info(root_folder):
    # Create an empty dictionary to store information
    folder_info = {}
    
    # Check if the provided root folder exists
    if not os.path.exists(root_folder):
        print(f"The folder {root_folder} does not exist.")
        return folder_info

    # Walk through the directory
    for dirpath, dirnames, filenames in os.walk(root_folder):
        # Check if the current directory is the root directory
        if dirpath == root_folder:
            continue
        
        # Filter out files that match the 'traj' extension
        traj_files = [f for f in filenames if f.endswith('.traj')]

        # If there are 'traj' files in the current directory, store them in the dictionary
        if traj_files:
            relative_path = os.path.relpath(dirpath, root_folder)
            folder_info[relative_path] = traj_files

    return folder_info

In [3]:
def generate_full_paths(base_folder, folder_dictionary):
    full_paths = []

    for subfolder, files in folder_dictionary.items():
        for file in files:
            # Join the base folder, subfolder, and file name to get the full path
            full_path = os.path.join(base_folder, subfolder, file)
            full_paths.append(full_path)

    return full_paths

In [4]:
full_paths_list = generate_full_paths("files",gather_folder_info("files") )

In [5]:
full_paths_list[:5]

['files/f54-ipc/npt-p175-t1800-b0.dmc_mean.traj',
 'files/f54-ipc/npt-p175-t600-b0.dmc_mean.traj',
 'files/f54-ipc/npt-p200-t1200-b0.dmc_mean.traj',
 'files/f54-ipc/npt-p175-t800-b0.dmc_mean.traj',
 'files/f54-ipc/npt-p175-t1600-b0.dmc_mean.traj']

## Visualization

In [6]:
fn = full_paths_list[0]
traj = Trajectory(fn)
number_atom = len(traj)
traj_positions = []
for atom_id in range(number_atom):
    position = traj[atom_id].get_positions()
    traj_positions.append(position)
threejs_array = np.array(traj_positions)

In [7]:
threejs_array.shape

(13, 96, 3)

In [8]:
threejs_array.min()

-45.22910020991855

In [9]:
threejs_array.max()

52.15400022557109

In [10]:
from pythreejs import *
from IPython.display import display, clear_output
import numpy as np
import ipywidgets as widgets
import time


def center_positions_around_average(arr):
    """Subtracts the average position at each timestamp from every atom's position."""
    avg_positions = np.mean(arr, axis=0)
    centered_positions = arr - avg_positions
    return centered_positions

def visualize_atomic_positions(arr):
    arr = center_positions_around_average(arr)  # Center the positions
    
    # Validate array shape
    if arr.shape[1] != 96 or arr.shape[2] != 3:
        raise ValueError("The array dimensions do not match the expected shape (_, 96, 3).")

    # Create a basic material for the atoms
    atom_material = MeshStandardMaterial(color='blue', roughness=0.8, metalness=0.8)

    # Trajectory material
    trajectory_material = LineBasicMaterial(color='lightgray')  # This will make the trajectory appear red
    trajectories = []
    
    # Create a list to hold all atomic meshes for the first timestamp
    meshes = []
    
    # Calculate cube center and dimensions once
    min_bounds = np.min(arr, axis=(0,1))
    max_bounds = np.max(arr, axis=(0,1))
    bounding_dimensions = max_bounds - min_bounds
    cube_center = (max_bounds + min_bounds) / 2.0

    # Add a bounding cube to the scene
    cube_geometry = BoxGeometry(bounding_dimensions[0], bounding_dimensions[1], bounding_dimensions[2])
    cube_material = LineBasicMaterial(color='gray')
    edges = EdgesGeometry(cube_geometry)
    bounding_cube = LineSegments(edges, cube_material)
    bounding_cube.position = cube_center.tolist()
    meshes.append(bounding_cube)

    for atom_idx in range(arr.shape[0]):
        x, y, z = arr[atom_idx, 0]
        # Create a mesh for the atom and add to the list
        geometry = SphereGeometry(radius=1)
        mesh = Mesh(geometry=geometry, material=atom_material, position=(x, y, z))
        meshes.append(mesh)
    
    # Scene setup
    scene = Scene(children=meshes)


    # Add trajectories for each atom
    for atom_idx in range(arr.shape[0]):
        # For initialization, just use the first position
        trajectory_positions = arr[atom_idx, :1]
        trajectory_geometry = BufferGeometry(attributes={
            'position': BufferAttribute(trajectory_positions, normalized=False)
        })
        trajectory_line = Line(geometry=trajectory_geometry, material=trajectory_material)
        scene.add(trajectory_line)
        trajectories.append(trajectory_line)
    
    # Modify the scene
    point_light = PointLight(position=[cube_center[0], cube_center[1], cube_center[2]], intensity=10)
    ambient_light = AmbientLight(intensity=1.5)
    scene.add(point_light)
    scene.add(ambient_light)
    camera_distance = np.linalg.norm(bounding_dimensions) / np.tan(np.pi/8) 
    camera_position = [cube_center[0], cube_center[1], cube_center[2] + camera_distance]
    camera = PerspectiveCamera(position=camera_position, fov=40)
    camera.lookAt(cube_center.tolist())
    renderer = Renderer(scene=scene, camera=camera, controls=[OrbitControls(controlling=camera)], width=600, height=400, background='black')

    # Slider widget to control timestamp
    timestamp_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=arr.shape[1]-1,
        step=1,
        description='Timestamp:',
        continuous_update=True
    )

    # Autoplay button and its functionality
    autoplay_button = widgets.ToggleButton(value=False, description='Autoplay', button_style='')

    def update(change):
        timestamp = change['new']
        for atom_idx, mesh in enumerate(meshes[1:]): # Excluding the bounding_cube
            mesh.position = tuple(arr[atom_idx, timestamp])

    def on_autoplay_button_click(change):
        if change['new']:  # If button turned on
            for timestamp in range(timestamp_slider.value, arr.shape[1]):
                timestamp_slider.value = timestamp
                
                # Update trajectory lines
                for atom_idx, trajectory_line in enumerate(trajectories):
                    trajectory_positions = arr[atom_idx, :timestamp+1]  # Up to the current timestamp
                    trajectory_line.geometry.attributes['position'].array = trajectory_positions
                    trajectory_line.geometry.attributes['position'].needsUpdate = True
                time.sleep(0.2)  # Interval between updates

            autoplay_button.value = False

    def on_reset_button_click(_):
        # Reset atom positions
        for atom_idx, mesh in enumerate(scene.children[1:arr.shape[0]+1]):  # Adjust indices to point to atom meshes
            mesh.position = tuple(arr[atom_idx, 0])
            
        # Clear trajectories
        for trajectory_line in trajectories:
            trajectory_positions = arr[atom_idx, :1]  # Just the first position
            trajectory_line.geometry.attributes['position'].array = trajectory_positions
            trajectory_line.geometry.attributes['position'].needsUpdate = True

        # Reset autoplay button and timestamp slider
        autoplay_button.value = False
        timestamp_slider.value = 0

    reset_button = widgets.Button(description="Reset")
    reset_button.on_click(on_reset_button_click)
    
    timestamp_slider.observe(update, names='value')
    autoplay_button.observe(on_autoplay_button_click, names='value')
    
    # Display the renderer, the slider, and the autoplay button
    display(widgets.VBox([timestamp_slider, autoplay_button, reset_button, renderer]))

In [11]:
# each timestamp, we will get (13,3)
threejs_array[:, 0, :]
# mean for the array
np.mean(threejs_array[:, 0, :],axis=0)
threejs_array[:, 0, :]-np.mean(threejs_array[:, 0, :],axis=0)
# center_positions_around_average(threejs_array)[:,0,:]

array([[ -4.65558541,   1.41192623, -13.10233006],
       [  7.81929465,  -5.35155379,  20.05178009],
       [ -3.7954154 ,  -1.03208377, -12.53833004],
       [ -2.4028554 ,   6.71027626,  -9.77243005],
       [ -7.15022542,   1.36971623,   1.59637001],
       [ -0.92597539,   4.67604625, -10.68553005],
       [  3.95998463,  -4.34627879,   7.55397003],
       [ -3.5671554 ,  -1.87788378,  -1.62073002],
       [ -3.7046354 ,   3.70284625,  -2.35023001],
       [  6.07659464,  -4.17460479,   9.75797004],
       [  7.64049464,  -3.31533678,  15.37485007],
       [ -3.9888154 ,   5.39493625,  -4.23913002],
       [  4.69429464,  -3.16800578,  -0.02623   ]])

In [12]:
visualize_atomic_positions(threejs_array)



VBox(children=(IntSlider(value=0, description='Timestamp:', max=95), ToggleButton(value=False, description='Au…