In [1]:
# --- Imports & global configuration ------------------------------------------
# The warnings filter is installed *before* the heavy third-party imports so
# that import-time warnings are silenced as well.
# NOTE(review): this blanket "ignore" also hides every runtime warning raised
# later in the notebook; consider narrowing it to specific categories/modules.
import warnings
warnings.filterwarnings("ignore")

# Third-party packages.
import numpy as np
import pandas as pd
import torch
from ase import Atoms  # not referenced in the visible cells; kept in case later cells use it
from e3nn.io import CartesianTensor
from lightning import seed_everything
from numpy import array  # not referenced in the visible cells; kept in case later cells use it
from tqdm import tqdm

# Project-local packages.
from anisonet.data import BaseDataset

# Fix the global random seeds for reproducibility (Lightning helper).
seed_everything(1234)

# Make float64 the default floating dtype for every tensor created from here on.
default_dtype = torch.float64
torch.set_default_dtype(default_dtype)

# Rank-2 Cartesian tensor with index symmetry "ij=ji", i.e. a symmetric 3x3
# tensor. It is used as the network's output irreps and to convert predictions
# back to Cartesian form.
ct = CartesianTensor("ij=ji")

Seed set to 1234


In [2]:
# --- Load structures and build the graph dataset ------------------------------
# SECURITY: `pd.read_pickle` unpickles arbitrary objects and can execute code;
# only load pickle files from a source you trust.
df_path = "../dataset/df_with_dim.p"  # replace with your dataset
n_structures = 20                     # rows to keep for inference; None = use every row

df_all = pd.read_pickle(df_path)
# `.head()` returns a slice of `df_all`. Taking an explicit `.copy()` makes
# `df_MP` an independent frame, so the column assignment below is unambiguous.
# Without it, pandas can emit a SettingWithCopy warning, which the global
# warnings filter would hide.
df_MP = (df_all if n_structures is None else df_all.head(n_structures)).copy()
del df_all  # only the selected rows are needed from here on

# BaseDataset needs a `target` column to construct the graphs. Here it is only
# a placeholder: one list of 6 zeros per row. 6 presumably matches the
# component count of the symmetric tensor `ct`; confirm against BaseDataset.
df_MP["target"] = [[0.0] * 6 for _ in range(len(df_MP))]

dataset = BaseDataset(df_MP, cutoff=5)

100%|████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 1609.51it/s]


In [3]:
from anisonet.model import E3nnModel

# Hyper-parameters of the equivariant network. They must match the ones used
# to train the checkpoint loaded below, otherwise loading the weights is
# expected to fail.
net_hparams = dict(
    in_dim=118,                            # dimension of the one-hot atom-type encoding (node input)
    em_dim=48,                             # embedding size of the node input
    in_attr_dim=118,                       # node-attribute input size (presumably the same one-hot encoding; confirm)
    em_attr_dim=48,                        # embedding size of the node attributes
    irreps_out=str(ct),                    # output irreps: the symmetric Cartesian tensor defined above
    layers=2,                              # number of nonlinearities (convolutions = layers + 1)
    mul=48,                                # multiplicity of each irreducible representation
    lmax=3,                                # maximum spherical-harmonic order
    max_radius=dataset.cutoff,             # convolution cutoff radius, taken from the dataset
    number_of_basis=15,                    # size of the radial basis
    num_neighbors=dataset.num_neighbors,   # normalisation by the typical neighbour count
    reduce_output=True,                    # aggregate per-atom features into one output per structure
    same_em_layer=True,                    # presumably shares the embedding layer between input and attributes; confirm
)
net = E3nnModel(**net_hparams)             # the model itself

from anisonet.train import BaseLightning

# Lightning wrapper around `net`; the cells below call its `setup()` and
# `train_dataloader()`. No optimizer or scheduler is configured because no
# training happens in the visible cells.
model = BaseLightning(dataset, net, batch_size=12, optimizer=None, scheduler=None)

# --- Restore trained weights into `net` ---------------------------------------
checkpoint_path = "anisonet-stock.ckpt"     # Change this to the actual path
# map_location="cpu" lets a checkpoint saved from GPU load on any machine; the
# network is moved to its compute device in the inference cell.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint['state_dict']

# The checkpoint was presumably saved from the Lightning wrapper, whose keys
# carry a leading "model." prefix; strip it so the keys match net.state_dict().
# Only a *leading* prefix is removed. The previous str.replace also rewrote any
# later "model." inside a key (e.g. a submodule whose name ends in "model").
wrapper_prefix = "model."
adjusted_state_dict = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in state_dict.items()
}
net.load_state_dict(adjusted_state_dict)
model.setup()

In [4]:
# --- Inference -----------------------------------------------------------------
# Run the trained network over the dataloader, convert each prediction to a
# 3x3 Cartesian tensor, and derive two scalar summaries per structure.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = net.eval().to(device)

# NOTE(review): in the saved run, train_dataloader() yielded 16 predictions for
# the 20 rows loaded above. It therefore appears to cover only a training split,
# and it may shuffle, so predictions are not guaranteed to line up with df_MP rows.
# TODO: use an unshuffled loader over the full dataset for inference.
dataloader = model.train_dataloader()

# Collect per-batch outputs and concatenate once. Calling torch.cat inside the
# loop re-copies the growing buffer on every batch.
chunks = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = batch.to(device)  # use the returned object; `.to` need not be in-place
        chunks.append(net(batch))
# Fall back to an empty (0, 6) buffer when the loader produced nothing.
out = torch.cat(chunks, dim=0) if chunks else torch.empty(0, 6, device=device)

# `out` is already a tensor, so move it to CPU directly. The old torch.tensor(...)
# re-wrap made an extra copy and raised a warning.
cart_pred = ct.to_cartesian(out.detach().cpu())

# Compute the (ascending) eigenvalues once per 3x3 matrix and reuse them.
eigvals = [np.linalg.eigvalsh(x) for x in cart_pred.numpy()]
# Mean eigenvalue, i.e. one third of the trace.
cart_pred_scalar = [ev.mean() for ev in eigvals]
# Largest / smallest eigenvalue; only meaningful if the smallest is > 0.
ar_pred = [ev.max() / ev.min() for ev in eigvals]

100%|████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.27s/it]


In [5]:
# Per-structure ratio of the largest to the smallest eigenvalue of the predicted
# symmetric 3x3 tensor (1.0 means all eigenvalues are equal, i.e. isotropic).
# NOTE(review): the saved output has 16 entries for the 20 rows loaded earlier,
# because predictions come from model.train_dataloader().
ar_pred

[np.float64(1.0368581018918632),
 np.float64(1.1114086654215025),
 np.float64(1.2459092036863955),
 np.float64(1.0000000000047657),
 np.float64(1.0000000000000004),
 np.float64(1.0000001723005107),
 np.float64(1.1536169543788604),
 np.float64(1.0000000000000002),
 np.float64(1.0000000000000004),
 np.float64(1.0000332458823564),
 np.float64(1.1123274135053172),
 np.float64(1.0000000000000002),
 np.float64(1.3368763550954919),
 np.float64(1.0156822157422156),
 np.float64(1.0351968177676922),
 np.float64(1.0000122543412633)]