In [1]:
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
import torch

In [2]:
from modules.nn import SimplePeriodicNetwork

# --- Run-level settings -----------------------------------------------------
batch_size = 10
radial_cutoff = 5
num_epochs = 50
# project_name = 'LiCondEquivariantModel'
project_name = None

# Only the device *name* is selected here; nothing in this cell moves `net`
# onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Model ------------------------------------------------------------------
net = SimplePeriodicNetwork(
    irreps_in="1x1o",
    # Single scalar (L=0 and even parity) to output (for example) energy.
    irreps_out="1x0e",
    # Cutoff radius for convolution.
    max_radius=radial_cutoff,
    # Scaling factor based on the typical number of neighbors.
    num_neighbors=10.0,
    # Nodes are pooled so that one value is predicted per structure.
    pool_nodes=True,
)

# --- Loss / optimiser -------------------------------------------------------
criterion = nn.MSELoss()  # Example: Mean Squared Error
optimizer = optim.Adam(net.parameters(), lr=1e-3)

In [3]:
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from ase.calculators.singlepoint import SinglePointCalculator
from torch_geometric.loader import DataLoader
from torch_geometric.loader.dataloader import Collater

import sevenn
from sevenn.train.dataload import _set_atoms_y
from sevenn.train.dataload import graph_build
from sevenn.train.dataset import AtomGraphDataset

def assign_dummy_y(atoms):
    """Attach NaN placeholder labels to ``atoms`` via a ``SinglePointCalculator``.

    A calculator is built with NaN energy, free energy, per-atom forces
    (``len(atoms)`` x 3) and a 6-component stress, and the object returned by
    ``calc.get_atoms()`` is handed back to the caller.

    Args:
        atoms: Structure supporting ``len()``; passed straight to
            ``SinglePointCalculator``.

    Returns:
        The result of ``calc.get_atoms()`` for the placeholder calculator.
    """
    dummy = {
        'energy': np.nan,
        'free_energy': np.nan,
        'forces': np.full((len(atoms), 3), np.nan),
        'stress': np.full((6,), np.nan),
    }
    calc = SinglePointCalculator(atoms, **dummy)
    # get_atoms() is called exactly once: the previous version called it twice
    # and discarded the first result.
    return calc.get_atoms()


class SevenNetPropertiesPredictor():
    """Runs a pretrained SevenNet checkpoint on batches of structures.

    The checkpoint is resolved by name through ``sevenn.util`` once, and the
    resulting model/config pair is stored on the instance and reused by every
    ``predict`` call.
    """

    def __init__(
            self,
            config_name,
        ):
        """Load the pretrained model and its configuration.

        Args:
            config_name: Name handed to ``sevenn.util.pretrained_name_to_path``.
        """
        checkpoint = sevenn.util.pretrained_name_to_path(config_name)
        sevennet_model, sevennet_config = sevenn.util.model_from_checkpoint(checkpoint)

        self.sevennet_model = sevennet_model
        self.sevennet_config = sevennet_config

    def predict(self, batch: List[Any]) -> Tuple[Dict[str, List[Any]], float, float, float]:
        """Predict forces and total energies for every structure in ``batch``.

        Args:
            batch: Structures to evaluate. Each one is passed to
                ``assign_dummy_y`` and must support ``len()``.

        Returns:
            A 4-tuple ``(properties, create_dataset_time,
            sevennet_model_inference, build_results)`` where

            * ``properties`` is ``{'forces': [...], 'energy': [...]}`` with one
              entry per input structure, in input order;
            * ``create_dataset_time`` is the time (s) spent building the
              ``AtomGraphDataset`` and the loader (``graph_build`` itself runs
              before the timer starts and is not included);
            * ``sevennet_model_inference`` is the time (s) of the model call,
              including detaching and moving the output to CPU;
            * ``build_results`` is the time (s) spent splitting the batched
              output back into per-structure tensors.
        """
        labelled_atoms = []
        atoms_len = []
        for atoms in batch:
            labelled_atoms.append(assign_dummy_y(atoms))
            # Number of atoms per structure, used below to slice the batched forces.
            atoms_len.append(len(atoms))

        labelled_atoms = _set_atoms_y(labelled_atoms)

        cutoff = self.sevennet_config['cutoff']
        sevennet_data_list = graph_build(
            labelled_atoms,
            cutoff,
            num_cores=max(1, self.sevennet_config['_num_workers']),
            y_from_calc=False,
        )

        start_time = time.perf_counter()

        inference_set = AtomGraphDataset(sevennet_data_list, cutoff)
        inference_set.x_to_one_hot_idx(self.sevennet_config['_type_map'])
        # NOTE(review): positions are flagged as requiring grad, presumably
        # because the model derives forces through autograd -- so the forward
        # pass below must not be wrapped in torch.no_grad(). TODO confirm.
        inference_set.toggle_requires_grad_of_data(sevenn._keys.POS, True)
        infer_list = inference_set.to_list()

        # One loader batch holding every graph, so the model is called once.
        loader = DataLoader(infer_list, batch_size=len(infer_list), shuffle=False)

        create_dataset_time = time.perf_counter() - start_time

        # The loader yields exactly one batch; unpacking enforces that.
        (merged_batch,) = loader

        start_time = time.perf_counter()
        sevennet_output = self.sevennet_model(merged_batch).detach().to("cpu")
        sevennet_model_inference = time.perf_counter() - start_time

        start_time = time.perf_counter()

        forces = []
        energies = []
        offset = 0
        for index, n_atoms in enumerate(atoms_len):
            # Rows [offset, offset + n_atoms) of the batched force tensor
            # belong to structure `index`.
            forces.append(sevennet_output.inferred_force[offset:offset + n_atoms, :].clone().detach())
            energies.append(sevennet_output.inferred_total_energy[index].clone().detach())
            offset += n_atoms

        build_results = time.perf_counter() - start_time

        properties = {
            'forces': forces,
            'energy': energies,
        }
        return properties, create_dataset_time, sevennet_model_inference, build_results

In [6]:
from modules.dataset import build_dataset 

# Inputs for the timing benchmark: the dataset built from the CSV of slopes,
# the pretrained SevenNet wrapper, and a loader over that dataset.
checkpoint_name = '7net-0'
batch_size = 150  # replaces the value of 10 assigned in the configuration cell

dataset = build_dataset(csv_path='../data/sevennet_slopes.csv')
SevennetPredictor = SevenNetPropertiesPredictor(checkpoint_name)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
)


100%|██████████| 179/179 [00:00<00:00, 475.67it/s]


In [7]:
from copy import deepcopy
import time
from torch_geometric.data import Batch

from torch_geometric.data import Data
from torch_geometric.loader.dataloader import Collater
from tqdm import tqdm
import ase.io
from pymatgen.io.ase import AseAtomsAdaptor


def set_noise_to_structures(batch, scale=1.0, rng=None):
    """Add zero-mean Gaussian noise to the positions of every structure, in place.

    Each element of ``batch`` is modified through ``set_positions``; pass a
    copy if the originals must stay untouched.

    Args:
        batch: Iterable of objects exposing ``get_positions()`` and
            ``set_positions()``.
        scale: Standard deviation of the noise, in the same units as the
            positions. Defaults to 1.0, the value that was previously hard-coded.
        rng: Optional random source with a ``normal(loc, scale, size)`` method
            (e.g. ``np.random.default_rng(seed)``) for reproducible noise.
            When ``None`` the global ``np.random`` state is used, as before.

    Returns:
        The same ``batch`` object that was passed in.
    """
    sampler = np.random if rng is None else rng
    for atoms in batch:
        positions = atoms.get_positions()
        noise = sampler.normal(loc=0, scale=scale, size=positions.shape)
        atoms.set_positions(positions + noise)
    return batch

# `ase.neighborlist` is used below; import it explicitly instead of relying on
# some other package having loaded the submodule as a side effect.
import ase.neighborlist


def _build_graph(atoms, forces, log_diffusion, cutoff):
    """Turn one structure into a torch_geometric ``Data`` graph.

    Args:
        atoms: Structure providing ``get_positions()`` and ``cell.array``.
        forces: Per-atom forces, stored unchanged as the node features ``x``.
        log_diffusion: Value stored unchanged as ``target``.
        cutoff: Neighbour-list cutoff passed to ``ase.neighborlist.neighbor_list``.

    Returns:
        A ``Data`` object with ``pos``, ``x``, ``lattice``, ``edge_index``,
        ``edge_shift`` and ``target`` set.
    """
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list(
        "ijS", a=atoms, cutoff=cutoff, self_interaction=True
    )
    return Data(
        pos=torch.tensor(atoms.get_positions(), dtype=torch.float32),
        x=forces,
        # Leading axis of size 1 added so each graph carries its own lattice entry.
        lattice=torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0),
        edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=torch.float32),
        target=log_diffusion,
    )


# Per-batch wall-clock timings in seconds (time.perf_counter differences).
graph_building = []
model_evaluation = []
sevennet_property_predictor = []
batch_time = []

# Finer-grained timings reported by SevennetPredictor.predict itself.
create_dataset_sevennet = []
model_inference_sevennet = []
build_results_sevennet = []

for batch in dataloader:

    batch_start = time.perf_counter()

    atoms_batch = batch.x['atoms']
    # Noise is applied to a deep copy, so the structures held by the batch
    # are not modified.
    noise_structures_batch = set_noise_to_structures(deepcopy(atoms_batch))
    log_diffusion_batch = batch.x['log_diffusion']

    # --- SevenNet property prediction ----------------------------------------
    start_time = time.perf_counter()
    properties_batch, create_dataset_time, model_inference_time, build_results_time = SevennetPredictor.predict(noise_structures_batch)
    end_time = time.perf_counter()

    create_dataset_sevennet.append(create_dataset_time)
    model_inference_sevennet.append(model_inference_time)
    build_results_sevennet.append(build_results_time)

    sevennet_property_predictor.append(end_time - start_time)

    # --- Graph construction for the equivariant network ----------------------
    atoms_list = []

    start_time = time.perf_counter()

    for log_diffusion, noise_structures, forces in zip(log_diffusion_batch, noise_structures_batch, properties_batch['forces']):
        # Use the same cutoff as the network's max_radius rather than a
        # separate literal, so the two cannot drift apart.
        data = _build_graph(noise_structures, forces, log_diffusion, cutoff=radial_cutoff)
        atoms_list.append(data)

    # NOTE: from here on `atoms_batch` refers to the collated graph batch,
    # not to the list of structures read at the top of the iteration.
    atoms_batch = Batch.from_data_list(atoms_list)

    end_time = time.perf_counter()

    graph_building.append(end_time - start_time)

    # --- Forward pass of the equivariant network -----------------------------
    start_time = time.perf_counter()
    outputs = net(atoms_batch)
    end_time = time.perf_counter()

    batch_end = time.perf_counter()

    batch_time.append(batch_end - batch_start)
    model_evaluation.append(end_time - start_time)



graph_build (1): 100%|██████████| 150/150 [00:05<00:00, 26.98it/s]
graph_build (1): 100%|██████████| 29/29 [00:01<00:00, 28.11it/s]


In [8]:
# Mean per-batch time (seconds) for each measured stage, printed in a fixed order.
for label, samples in (
    ('graph_building:', graph_building),
    ('model_evaluation:', model_evaluation),
    ('create_dataset_sevennet:', create_dataset_sevennet),
    ('model_inference_sevennet:', model_inference_sevennet),
    ('build_results_sevennet:', build_results_sevennet),
    ('batch_time:', batch_time),
):
    print(label, np.array(samples).mean())

graph_building: 2.6488343398086727
model_evaluation: 18.829915893729776
create_dataset_sevennet: 0.008556838845834136
model_inference_sevennet: 40.66284448839724
build_results_sevennet: 0.0020445329137146473
batch_time: 64.89997762395069
