In [1]:
import h5py as hf
import numpy as np
#import tensorflow as tf
import pyvista as pv
import pyvistaqt as pvqt
import time
from threading import Thread

In [2]:
# Load the per-species position arrays. The context manager closes the HDF5 file even if
# one of the expected groups/datasets is missing, instead of leaking the open handle.
with hf.File('database.hdf5', 'r') as db:
    data_C = np.array(db['C']['Positions'])
    data_H = np.array(db['H']['Positions'])
    data_N = np.array(db['N']['Positions'])
    data_F = np.array(db['F']['Positions'])

In [3]:
class TrajectoryVisualizer:
    """
    Class for visualizing trajectories of atoms.

    Populate ``data_dictionary`` before calling ``run_visualization``. It maps a
    species label to a dict with the keys:

    * ``'tensor'`` -- positions indexed as ``tensor[atom, frame]``, each entry a
      3-vector (``tensor[:, frame]`` gives every atom's position in one frame).
    * ``'mass'`` -- used directly as the sphere radius.
    * ``'colour'`` -- colour passed to ``add_mesh``.

    The methods add the working keys ``'spheres'`` (one mesh per atom) and
    ``'centres'`` (where each sphere currently sits) to every entry.
    """

    def __init__(self):
        self.data_dictionary = {}
        # Created by _prepare_canvas(); constructing the object opens no window.
        self.plotter = None

    def _prepare_canvas(self):
        """
        Prepare the pyqt canvas.
        """
        self.plotter = pvqt.BackgroundPlotter(notebook=True)

    def _update_positions(self, start: int, stop: int):
        """
        Move every sphere to its position in frame ``stop`` and re-render once.

        Parameters
        ----------
        start : int
            Previous frame index. Kept for interface compatibility; the current
            sphere centres are tracked in each entry's ``'centres'`` instead.
        stop : int
            Frame index to display.
        """
        for entry in self.data_dictionary.values():
            # Spheres are built at frame 0, so that is the fallback current centre.
            old_centres = entry.setdefault('centres', np.array(entry['tensor'][:, 0], dtype=float))
            new_centres = np.array(entry['tensor'][:, stop], dtype=float)
            for sphere, old, new in zip(entry['spheres'], old_centres, new_centres):
                # Translate the whole mesh. Assigning a single 3-vector as the mesh's
                # points would collapse the sphere to one point.
                sphere.points = sphere.points + (new - old)
            entry['centres'] = new_centres
        # One render per frame rather than one per sphere.
        self.plotter.render()

    def _construct_spheres(self):
        """
        Construct one sphere mesh per atom, centred on the atom's frame-0 position.

        The sphere radius is taken from the species' ``'mass'`` value.
        """
        for entry in self.data_dictionary.values():
            radius = entry['mass']
            centres = np.array(entry['tensor'][:, 0], dtype=float)
            entry['spheres'] = [pv.Sphere(radius=radius, center=centre) for centre in centres]
            entry['centres'] = centres

    def _draw_spheres(self):
        """
        Draw the spheres on the canvas in their species colour.
        """
        for entry in self.data_dictionary.values():
            colour = entry['colour']
            for sphere in entry['spheres']:
                self.plotter.add_mesh(sphere, lighting=False, show_edges=False, color=colour)

    def _loop_configurations(self, start=1, frame_delay: float = 0.0):
        """
        Step through the frames from ``start`` to the last frame.

        Parameters
        ----------
        start : int
            First frame to display; values below 1 are raised to 1 because frame 0
            is what ``_construct_spheres`` already drew.
        frame_delay : float
            Seconds to sleep after each frame (0 plays as fast as possible).
        """
        if not self.data_dictionary:
            return
        start = max(start, 1)
        # Frames live on axis 1 (tensor[atom, frame]). Use the shortest trajectory so
        # no species is indexed past its end.
        n_frames = min(entry['tensor'].shape[1] for entry in self.data_dictionary.values())
        for frame in range(start, n_frames):
            self._update_positions(start=frame - 1, stop=frame)
            if frame_delay > 0:
                time.sleep(frame_delay)

    def run_visualization(self, frame_delay: float = 0.0):
        """
        Run a visualization and hold open for inputs.

        Opens the canvas, draws every atom at frame 0, then animates the remaining
        frames on a background thread.

        Parameters
        ----------
        frame_delay : float
            Seconds to pause between frames (forwarded to ``_loop_configurations``).

        Returns
        -------
        threading.Thread
            The started animation thread.
        """
        self._prepare_canvas()
        self._construct_spheres()
        self._draw_spheres()
        # Pass the bound method itself (not its result) so the loop runs on the worker thread.
        # NOTE(review): the worker mutates meshes the Qt thread renders; pyvistaqt's
        # BackgroundPlotter.add_callback would keep updates on the GUI thread -- consider it.
        thread = Thread(target=self._loop_configurations,
                        kwargs={'frame_delay': frame_delay},
                        daemon=True)
        thread.start()
        return thread
        
        

In [4]:
vis = TrajectoryVisualizer()
# One row per species: label, position array, sphere radius (stored under 'mass'), colour.
vis.data_dictionary = {
    label: {'tensor': positions, 'mass': radius, 'colour': colour}
    for label, positions, radius, colour in (
        ('C', data_C, 1.0, 'black'),
        ('N', data_N, 0.7, 'red'),
        ('H', data_H, 0.2, 'white'),
        ('F', data_F, 0.5, 'green'),
    )
}

In [None]:
# Open the BackgroundPlotter canvas, draw one sphere per atom and step through the stored frames.
vis.run_visualization()