In [1]:
import warnings
from functools import partial

import ase.io
import nglview as nv
import numpy as np
import torch
from ase import Atom
from ase import Atoms
from ase.io import read
from ase.neighborlist import primitive_neighbor_list
from IPython.display import display
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch_geometric.data import Data
from tqdm.notebook import trange

from graphite.nn.basis import bessel



In [2]:
def write_lammps_data_file(atoms, filename="mos2_0.8.lammps"):
    """Write ``atoms`` as a single-frame LAMMPS text dump.

    Despite the function name, the output uses the LAMMPS *dump* layout
    (``ITEM: TIMESTEP`` / ``ITEM: BOX BOUNDS`` / ``ITEM: ATOMS`` sections),
    not the ``read_data`` input-file layout.

    Atom types are assigned 1, 2, ... in ascending order of atomic number,
    and atom ids are 1-based in the order the atoms appear in ``atoms``.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to write. Only the cell *lengths* are used for the box,
        so the cell is expected to be orthorhombic (diagonal).
    filename : str, optional
        Output path; overwritten if it already exists.

    Warns
    -----
    UserWarning
        If the cell has non-zero off-diagonal components. The box is still
        written from the cell lengths, which does not describe a tilted cell.
    """
    numbers = np.asarray(atoms.numbers)
    positions = np.asarray(atoms.positions)

    # Map each atomic number to a 1-based LAMMPS type, ordered by atomic number.
    unique_atomic_numbers = np.unique(numbers)
    atom_types = np.searchsorted(unique_atomic_numbers, numbers) + 1

    cell_matrix = np.asarray(atoms.cell)
    off_diagonal = cell_matrix - np.diag(np.diag(cell_matrix))
    if not np.allclose(off_diagonal, 0.0):
        # The header below only encodes an axis-aligned box; a tilted cell
        # would need the triclinic "xy xz yz" BOX BOUNDS form instead.
        warnings.warn(
            "write_lammps_data_file assumes an orthorhombic cell; the box "
            "bounds written from the cell lengths do not describe this cell.",
            stacklevel=2,
        )
    box_lengths = atoms.cell.lengths()

    with open(filename, "w") as file:
        # Header
        file.write("ITEM: TIMESTEP\n0\n")
        file.write("ITEM: NUMBER OF ATOMS\n")
        file.write(f"{len(atoms)}\n")
        file.write("ITEM: BOX BOUNDS pp pp pp\n")
        for length in box_lengths:
            file.write(f"0 {length}\n")

        # Per-atom records: type, 1-based id, Cartesian x y z
        file.write("ITEM: ATOMS type id x y z\n")
        for atom_id, (atom_type, (x, y, z)) in enumerate(zip(atom_types, positions), start=1):
            file.write(f"{atom_type} {atom_id} {x} {y} {z}\n")

In [3]:
def ase_graph(data, cutoff):
    """Build a periodic neighbor graph for ``data`` with ASE and attach it in place.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Must provide ``pos`` (tensor of Cartesian positions) plus ``pbc``,
        ``cell`` and ``numbers``; these are forwarded unchanged to
        ``ase.neighborlist.primitive_neighbor_list``.
    cutoff
        Neighbor cutoff, forwarded unchanged to ``primitive_neighbor_list``.

    Returns
    -------
    The same ``data`` object, with
      * ``edge_index`` - long tensor of shape (2, num_edges) holding the
        ASE ``i`` / ``j`` index pairs,
      * ``edge_attr``  - float32 tensor of the ASE ``D`` quantity
        (per-pair distance vectors),
    both placed on the same device as ``data.pos``.
    """
    # ASE works on NumPy arrays: detach and move to host memory first so the
    # function also works on tensors that carry grad or live on a GPU.
    positions = data.pos.detach().cpu().numpy()
    i, j, D = primitive_neighbor_list(
        'ijD',
        cutoff=cutoff,
        pbc=data.pbc,
        cell=data.cell,
        positions=positions,
        numbers=data.numbers,
    )
    # Keep the graph tensors on the same device as the node positions.
    device = data.pos.device
    data.edge_index = torch.tensor(np.stack((i, j)), dtype=torch.long, device=device)
    data.edge_attr = torch.tensor(D, dtype=torch.float, device=device)
    return data

In [4]:
@torch.no_grad()
def denoise_snapshot(atoms, model, scale=1.0, steps=8, cutoff=None):
    """Iteratively denoise an atomic snapshot with ``model`` and record the trajectory.

    The structure is converted to a PyG ``Data`` object, and its positions and
    cell are multiplied by ``scale``. Then, for ``steps`` iterations, the
    neighbor graph is rebuilt with ``ase_graph``, ``model(data)`` is
    evaluated, and its output is subtracted from ``data.pos``.

    Parameters
    ----------
    atoms : ase.Atoms
        Input snapshot. It is not modified: the cell and starting positions
        are copied before any scaling.
    model : callable
        Called as ``model(data)``. Its return value is subtracted in place from
        ``data.pos``, so it must broadcast against the position tensor.
        NOTE(review): ``torch.no_grad`` does not switch the model to eval
        mode; call ``model.eval()`` beforehand if it uses dropout/batch-norm.
    scale : float, optional
        Factor applied to positions and cell before denoising. Recorded
        positions are divided by it again.
    steps : int, optional
        Number of denoising iterations.
    cutoff : float, optional
        Neighbor cutoff passed to ``ase_graph``. When ``None`` (default), the
        notebook-level ``CUTOFF`` constant is used, as before.

    Returns
    -------
    list of numpy.ndarray
        ``steps + 1`` position arrays in the original (unscaled) units: a copy
        of the input positions, followed by the positions after each step.
        The first entry keeps the dtype of ``atoms.positions``. Later entries
        are float32 because ``data.pos`` is cast with ``.float()``.
    """
    if cutoff is None:
        cutoff = CUTOFF

    # NOTE(review): the encoder is fitted per snapshot, so species indices
    # depend on which elements are present in *this* snapshot. A snapshot
    # missing an element seen during training would get shifted indices.
    # Confirm every snapshot contains the full species set.
    species_index = LabelEncoder().fit_transform(atoms.numbers)
    data = Data(
        x       = torch.tensor(species_index).long(),
        pos     = torch.tensor(atoms.positions).float(),
        # Copy the cell: ``data.cell *= scale`` below is in place and would
        # otherwise rescale the caller's ``atoms.cell`` on every call.
        cell    = atoms.cell.copy(),
        pbc     = atoms.pbc,
        numbers = atoms.numbers,
    )

    # Work in scaled coordinates
    data.pos  *= scale
    data.cell *= scale

    # Copy the starting frame so later in-place edits of the caller's
    # position array cannot silently change the recorded trajectory.
    pos_traj = [atoms.positions.copy()]
    for _ in trange(steps):
        data = ase_graph(data, cutoff=cutoff)
        disp = model(data)
        data.pos -= disp
        pos_traj.append(data.pos.clone().numpy() / scale)

    return pos_traj

In [5]:
class InitialEmbedding(nn.Module):
    """Initial encoder: learned per-species node embeddings plus an edge-length basis.

    Parameters
    ----------
    num_species : int
        Number of distinct species indices that can appear in ``data.x``
        (rows of each embedding table).
    cutoff : float
        Passed as ``end`` to ``graphite.nn.basis.bessel``, with ``start=0.0``.
    embed_dim : int, optional
        Width of each node embedding (default 8, the previously hard-coded value).
    num_basis : int, optional
        Number of basis functions for the edge features (default 16, the
        previously hard-coded value).
    """

    def __init__(self, num_species, cutoff, embed_dim=8, num_basis=16):
        super().__init__()
        # Two independent tables over the same species index; they populate
        # the separate ``h_node_x`` and ``h_node_z`` channels.
        self.embed_node_x = nn.Embedding(num_species, embed_dim)
        self.embed_node_z = nn.Embedding(num_species, embed_dim)
        self.embed_edge   = partial(bessel, start=0.0, end=cutoff, num_basis=num_basis)

    def forward(self, data):
        """Attach ``h_node_x``, ``h_node_z`` and ``h_edge`` to ``data`` and return it.

        Reads ``data.x`` (integer species indices) and ``data.edge_attr``.
        Each edge vector is reduced to its Euclidean length along the last
        dimension before the basis expansion.
        """
        # Embed node
        data.h_node_x = self.embed_node_x(data.x)
        data.h_node_z = self.embed_node_z(data.x)

        # Embed edge
        data.h_edge = self.embed_edge(data.edge_attr.norm(dim=-1))

        return data