In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import os, sys, logging, torch
import pandas as pd
import numpy as np
from torch import nn, tensor
from functools import partial

sys.path.append(os.path.join(os.getcwd(), '..'))
from utils.create_batch import sort_atoms, momentum_transform, geom_transform, get_batch
from utils.DataLoaders_jupyter import Get_Dataset, Create_DataLoaders
from utils.DataLoaders_noSplit import Get_Dataset_noSplit, Create_DataLoaders_noSplit
from utils.Learner import Learner
from utils.Callbacks import TrainCB, DeviceCB, MetricsCB, BatchSchedCB, SaveCB
from utils.utils_ML import Adam, set_seed, LayerNorm, MeanAbsoluteError, EDM_MSELoss, MSELoss, EDM_MSELoss_EvalLoss
from utils.Callbacks import to_device, ProgressCB
from utils.download import download_dataset_weights, load_obj
from torcheval.metrics import MeanSquaredError
from model.XMolNet import XMolNet


# Runtime configuration shared by every cell below.
torch.set_printoptions(sci_mode=False, linewidth=140, precision=5)  # compact fixed-point tensor reprs
torch.manual_seed(1)              # seed torch's RNG (all devices) for repeatable sampling
logging.disable(logging.WARNING)  # drop log records at WARNING level and below
if torch.cuda.is_available():
    def_device = 'cuda'
else:
    def_device = 'cpu'
print('device:', def_device)


%load_ext autoreload
%autoreload 2

# Element symbol -> atomic number for the elements this notebook recognises.
Z_dict = {'H':1, 'C':6, 'N':7, 'O':8, 'F':9, 'Si': 14, 'P':15, 'S': 16, 'Cl':17, 'Br':35, 'I':53}
# Atomic number -> element symbol, derived from Z_dict so the two tables cannot
# drift apart (a hand-written copy was missing Br and I, which made the results
# table raise KeyError for molecules containing them).
Z_dict_inverse = {z: symbol for symbol, z in Z_dict.items()}
# Left margin for the printed results tables (empty = flush left).
indent = ""

device: cuda


In [2]:
# Fetch the dataset and pretrained weights. Presumably this populates
# ../download/ (the dataset and checkpoint paths read in later cells) —
# confirm in utils.download.download_dataset_weights.
download_dataset_weights()

## Set up model

In [3]:
# --- Optimisation / checkpointing -------------------------------------------
epochs, lr = 3, 0.001
beta1, beta2 = 0.9, 0.99                 # Adam betas, passed to the Learner
sv_step_size, sv_prefix = 200, 'model_'  # SaveCB settings

# --- Diffusion settings ------------------------------------------------------
diffu_params = dict(
    y_c=tensor([0., -0., -0.]),
    y_hw=tensor([23.59669, 19.97351, 15.76245]),
    n_diffu=1,        # an earlier value of 8 was left commented out
    P_mean=-1.2,
    P_std=1.2,
    sigma_data=0.25,
)
natts_diffu = 2
num_steps = 15        # NOTE: the learner cell later sets learn.model.num_steps = 5

# --- Embedding sizes ---------------------------------------------------------
z_max, z_emb_dim = 20, 64
q_max, q_emb_dim = 20, 64
pos_out_dim = 64

# --- Attention stack ---------------------------------------------------------
natts = 6
scale = 4
att_dim = int(scale * 128)
nheads = int(scale * 8)

# --- Model behaviour / loss --------------------------------------------------
denoise, sample = True, False
loss_eval_c = 1
err_bin_center = torch.arange(0, 10, 0.05)

# Build the network from the hyperparameters defined above; the first five
# arguments are positional, everything else is passed by keyword.
model = XMolNet(
    z_max, z_emb_dim,
    q_max, q_emb_dim,
    pos_out_dim,
    att_dim=att_dim,
    diffu_params=diffu_params,
    natts=natts,
    nheads=nheads,
    dot_product=True,
    res=True,
    act1=nn.ReLU,
    act2=nn.ReLU,
    norm=LayerNorm,
    attention_type='full',
    lstm=True,
    sumup=False,
    num_steps=num_steps,
    natts_diffu=natts_diffu,
    err_bin_center=err_bin_center,
    denoise=denoise,
    sample=sample,
)



In [4]:
# Train/validation datasets and their DataLoaders (vshuffle=True presumably
# shuffles the validation loader as well).
tls_train, tls_valid = Get_Dataset(path='../download/dataset/dataset_2_7/', rank=0)
dls = Create_DataLoaders(tls_train, tls_valid, batch_size=128, sampler=None, vshuffle=True)

# Checkpoint path handed to the Learner (presumably loaded at construction).
load_path = '../download/model/molexa.chk'

# Use the detected device rather than a hard-coded 'cuda', so the notebook also
# runs on CPU-only machines (the model itself is moved to def_device below).
loss_func = EDM_MSELoss_EvalLoss(diffu_params=diffu_params, loss_eval_c=loss_eval_c,
                                 err_bin_center=err_bin_center, device=def_device)
val_loss_func = MSELoss()
MCB = MetricsCB(mae=MeanAbsoluteError(), mse_eval=MeanSquaredError())

# NOTE(review): SCB is constructed but not added to `cbs`/`xtra`, so it
# presumably never fires; append it to `xtra` to enable checkpoint saving.
SCB = SaveCB(epochs=epochs, step_size=sv_step_size, save_dir='.', prefix=sv_prefix)
cbs = [TrainCB(), DeviceCB(), MCB, ProgressCB(plot=False)]
xtra = []
learn = Learner(model.to(def_device), dls, loss_func=loss_func, val_loss_func=val_loss_func,
                cbs=cbs+xtra, load_path=load_path, opt_func=torch.optim.Adam, lr=lr,
                beta1=beta1, beta2=beta2)

# Sampler settings set directly on the model after construction; num_steps
# here overrides the value 15 passed to the constructor. The names match the
# EDM sampler hyperparameters of Karras et al. (2022) — presumably consumed by
# XMolNet's sampling loop; confirm in model/XMolNet.py.
learn.model.num_steps = 5
learn.model.sigma_min = 0.002
learn.model.sigma_max = 80
learn.model.rho = 1.5        # an earlier value of 7 was left commented out
learn.model.S_churn = 30
learn.model.S_min = 0.01
learn.model.S_max = 1
learn.model.S_noise = 1.1    # an earlier value of 1.007 was left commented out
learn.model.heun = False
learn.model.step_scale = 1

# of valid: 8580 # of train: 60832


## Get model performance on validation data

In [18]:
# Evaluation-only pass: one epoch over the validation loader with training
# disabled, reporting the loss and metrics registered on the Learner.
learn.fit(1, train=False, valid=True)

epoch           loss_v          mae_v           mse_eval_v      lr              beta1           beta2           time            rank           


0               1.04305         0.535756527     0.474880189     0.001000000     0.900           0.990           23:19:16,01-19  0              


## Single molecule prediction

### 1. Input

#### Option A - Use the data in the test dataset

In [5]:
# List the molecules available in the test dataset. The loaded object is a
# nested mapping: molecule -> variation -> charge-state id -> (file, row index).
# NOTE(review): the .pkl extension suggests pickle — only load trusted files.
test_molecule_locations = load_obj('../download/dataset/test_molecule_locations.pkl')
print('Available molecules:', list(test_molecule_locations.keys()), sep='\n')

Available molecules:
['C2F3N', 'H4Si2', 'H2N2O2', 'ClF5', 'CH3F2P', 'H3O3P', 'Cl2H2Si', 'H2N2O', 'ClF3Si', 'CH2Cl2', 'FN', 'HP', 'ClF', 'NO', 'CS2', 'CCl2', 'CClN', 'H2Si', 'C2O', 'H4SSi', 'F3HSi', 'Cl2FN', 'Cl2OS', 'F3S', 'CHN', 'CH2', 'C2H2N2O', 'CH4N', 'CH2Cl3P', 'F3NO2S', 'CH4N2O', 'C2H4O2', 'C2HF5', 'CH4Cl2Si', 'C2H2F2O2', 'C3H2OS2', 'CH2N4O', 'C2H4ClF', 'CH4N2S', 'CH3Cl3Si', 'C4H2O2', 'C2H2Cl2O2', 'F6Si2', 'H4N4', 'C2H2N4', 'C3H4O', 'C2H3ClN2', 'C3H2O2S', 'CH3Cl2OP', 'CH3FO2S', 'CH3F3Si', 'C2H3ClOS', 'C2H2Cl3F', 'C2H2Cl4', 'CH2Cl2O2S', 'H3NO3S', 'CH4O2S', 'C2H5Cl', 'C3H4S', 'CH3F2OP', 'C2H4Cl2', 'H3N5', 'C2H3O3', 'C2H4F2', 'C2H2Cl2OS', 'C2HF3O2', 'C2H5S', 'CH3ClFOP', 'C2H2F4', 'C2H3FO2', 'C2H5O', 'CH5NO', 'C3H2O3', 'C2HN5', 'C2H5N', 'C2H4S2', 'C2H2O4', 'C4H2N2', 'C2H2Cl2F2', 'C2HClF4', 'C3H2S3', 'C2H3N3', 'C2H4OS', 'C3H5', 'CH3NO3', 'C2H3F3', 'C3H4N', 'C2H5F', 'C2H5P', 'CH6Si', 'H3O4P', 'C4H4', 'H6Si2', 'CH3ClO2S', 'C3H3NO', 'C3H3NS', 'F5HSi2', 'C2HF3OS', 'C2H3Cl3', 'CH4ClFSi', '

In [6]:
# Pick a molecule ('CH2Cl2' here) and list the geometry variations stored for it.
mol_name = 'CH2Cl2'
print(f'Available variations for {mol_name}:',
      list(test_molecule_locations[mol_name].keys()),
      sep='\n')

Available molecules for CH2Cl2:
['5_1_10', '5_3_1', '5_1_6', '5_4_2', '5_4_10', '5_4_7', '5_3_3', '5_4_8', '5_1_2', '5_4_9', '5_4_4', '5_2_10', '5_4_5', '5_2_2', '5_2_4', '5_3_9', '5_3_8', '5', '5_3_4', '5_1_5', '5_1_3', '5_2_5', '5_2_7', '5_1_9', '5_1_8', '5_4_3', '5_4_6', '5_2_3', '5_3_10', '5_2_9', '5_2_6', '5_4_1', '5_1_7', '5_1_1', '5_2_8', '5_2_1', '5_3_5', '5_3_2', '5_1_4']


In [7]:
# Choose the variation of this molecule; '5' is the ground state.
# Other variation ids have the form '<a>_<b>_<c>', e.g. '5_2_1':
#   a (5) - the number of atoms,
#   b (2) - the number of atoms that were moved relative to the ground state,
#   c (1) - the dataset index.
# Then list the charge-state identifiers available for that variation.
variation = '5'
print(f'Available charge state identifiers for {mol_name} with variation type {variation}:',
      list(test_molecule_locations[mol_name][variation].keys()),
      sep='\n')

Available charge state identifiers for CH2Cl2 variation type 5:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]


In [15]:
# Choose the charge-state id (40 here) and look up where the sample is stored:
# `location` is the file loaded in the next cell and `df_index` the row
# position (used with .iloc) of this sample in that file's DataFrame.
charge_state_id = 40
location, df_index = test_molecule_locations[mol_name][variation][charge_state_id]
print(f'The molecule {mol_name} with variation type {variation} and charge state id '
      f'{charge_state_id} can be found at {location} with index {df_index}')

The molecule CH2Cl2 with variation type 5 and charge state id 40 can be found at ../download/dataset/dataset_2_7/test/5.pkl with index 40


In [16]:
# Get the atomic number, charge state, ground-truth geometry and momentum of
# every atom in the chosen sample. Atoms are labelled '<element><k>' (e.g.
# 'Cl1', 'Cl2'), where k counts occurrences of that element in column order
# atom_1 .. atom_N.
num_atoms = int(variation.split('_')[0])  # leading field of the variation id
counter = {}     # occurrences seen so far, per element
z_dict = {}      # atomic number
q_dict = {}      # charge state
geom_dict = {}   # geometry
momen_dict = {}  # momentum
df = load_obj(location)
# Fetch the sample's row once instead of re-indexing the DataFrame 8x per atom.
sample_row = df.iloc[df_index]

for i in range(1, num_atoms + 1):
    atom = sample_row[f'atom_{i}']
    counter[atom] = counter.get(atom, 0) + 1
    label = f'{atom}{counter[atom]}'
    z_dict[label] = Z_dict[atom]
    q_dict[label] = int(sample_row[f'q_{i}'])
    geom_dict[label] = np.array([sample_row[f'x_{i}'], sample_row[f'y_{i}'], sample_row[f'z_{i}']])
    momen_dict[label] = np.array([sample_row[f'px_{i}'], sample_row[f'py_{i}'], sample_row[f'pz_{i}']])

In [17]:
# Display the four per-atom dictionaries (keyed by labels such as 'Cl1').
z_dict, q_dict, geom_dict, momen_dict

({'Cl1': 17, 'Cl2': 17, 'C1': 6, 'H1': 1, 'H2': 1},
 {'Cl1': 11, 'Cl2': 9, 'C1': 2, 'H1': 1, 'H2': 1},
 {'Cl1': array([3.23785309e+00, 0.00000000e+00, 1.11022302e-16]),
  'Cl2': array([-1.69812368, -1.63840275,  2.21687154]),
  'C1': array([-0.13731963,  0.14599576, -0.19748977]),
  'H1': array([-7.01144596e-01,  2.11242760e+00,  5.55111512e-17]),
  'H2': array([-0.70126518, -0.6200206 , -2.01938177])},
 {'Cl1': array([ 1.16145918e+03, -2.16546358e-15,  1.74401939e-15]),
  'Cl2': array([-950.8006898 , -366.10833699,  374.44491917]),
  'C1': array([-150.79687258,  251.61855298, -262.5184522 ]),
  'H1': array([-2.78753899e+01,  1.14453722e+02, -3.95021317e-16]),
  'H2': array([ -29.75563561,   -1.25618022, -110.15187724])})

#### Option B - Define the input yourself

In [14]:
# Identifiers that get_batch attaches to the batch. Option A defines them while
# browsing the test set; they are set here too so this option also works on a
# fresh kernel. Change them to describe your own molecule.
mol_name = 'CH2Cl2'
variation = '5'

z_dict = {'Cl1': 17, 'Cl2': 17, 'C1': 6, 'H1': 1, 'H2': 1}  # atomic number

q_dict = {'Cl1': 11, 'Cl2': 9, 'C1': 2, 'H1': 1, 'H2': 1}  # charge state

geom_dict = {'Cl1': np.array([3.23785309e+00, 0.00000000e+00, 1.11022302e-16]),
             'Cl2': np.array([-1.69812368, -1.63840275,  2.21687154]),
             'C1': np.array([-0.13731963,  0.14599576, -0.19748977]),
             'H1': np.array([-7.01144596e-01,  2.11242760e+00,  5.55111512e-17]),
             'H2': np.array([-0.70126518, -0.6200206 , -2.01938177])}  # ground-truth geometry

momen_dict = {'Cl1': np.array([ 1.16145918e+03, -2.16546358e-15,  1.74401939e-15]),
              'Cl2': np.array([-950.8006898 , -366.10833699,  374.44491917]),
              'C1': np.array([-150.79687258,  251.61855298, -262.5184522 ]),
              'H1': np.array([-2.78753899e+01,  1.14453722e+02, -3.95021317e-16]),
              'H2': np.array([ -29.75563561,   -1.25618022, -110.15187724])}  # momentum

### 2. Batch creation
Define the molecular frame, transform the momenta and geometry into this frame, and assemble the molecular data into a batch.

In [18]:
# Order the atoms, then transform the momenta into the molecular frame.
# momentum_transform also returns two atoms (ptc1, ptc2) that geom_transform
# uses to express the geometry in the same frame.
sorted_atom_lst = sort_atoms(z_dict)
momen_dict_td, ptc1, ptc2 = momentum_transform(momen_dict, sorted_atom_lst)
geom_dict_td = geom_transform(geom_dict, ptc1, ptc2)

# The two anchor atoms come first; the rest keep their sorted order.
anchor_atoms = [ptc1, ptc2]
atom_lst = list(anchor_atoms)
atom_lst.extend(name for name in sorted_atom_lst if name not in anchor_atoms)

batch = get_batch(mol_name, variation, atom_lst, z_dict, q_dict, geom_dict_td, momen_dict_td)

In [19]:
# Inspect the assembled batch (a Data object with one entry per atom in y, pos, z, q).
batch

Data(y=[5, 3], pos=[5, 3], mol_name='CH2Cl2', variation='5', z=[5], q=[5], batch=[5], natoms=[1])

### 3. Predict molecular geometry

In [20]:
# Predict the geometry of the single-molecule batch and compare it with ground truth.
batch = batch.to(def_device)
pred = learn.predict_batch(batch).squeeze()  # predicted coordinates, one row per atom
ys = batch.y                                 # ground-truth coordinates
zs = batch.z.tolist()                        # atomic numbers, used as row labels
# predict_batch leaves the per-coordinate error estimates on the model.
err_pred = learn.model.err_pred.squeeze()
# Guard against a silent mismatch (e.g. squeeze() collapsing an extra dimension).
assert pred.shape == ys.shape, f'prediction shape {tuple(pred.shape)} != target shape {tuple(ys.shape)}'

RMSE = ((pred - ys) ** 2).mean().sqrt()
MAE = torch.abs(pred - ys).mean()

print('Molecule:', batch.mol_name, 'Variation type:', batch.variation)

indent = ""
atom_col_width = 5
num_col_width = 17
axis_headers = ['X (a.u.)', 'Y (a.u.)', 'Z (a.u.)']


def _table_row(label, cells):
    """Return one '|'-delimited table row: a left-aligned label column followed
    by one centred column per entry in `cells` (widths from the globals above)."""
    value_cols = ''.join(f'|{cell:^{num_col_width}}' for cell in cells)
    return f'{indent}|{label:<{atom_col_width}}{value_cols}|'


separator = f"{indent}+{'-' * atom_col_width}" + f"+{'-' * num_col_width}" * len(axis_headers) + '+'

print(separator)
print(_table_row('Atom', axis_headers))
print(separator)

# ':.3f' already rounds to three decimals, so no separate round() is needed.
print("Prediction:")
for atomic_num, coords, errs in zip(zs, pred.tolist(), err_pred.tolist()):
    cells = [f'{c:.3f} ± {e:.3f}' for c, e in zip(coords, errs)]
    print(_table_row(Z_dict_inverse[atomic_num], cells))
    print(separator)

print("Ground truth:")
for atomic_num, coords in zip(zs, ys.tolist()):
    print(_table_row(Z_dict_inverse[atomic_num], [f'{c:.3f}' for c in coords]))
    print(separator)

print('RMSE (a.u.):', round(RMSE.item(), 3), 'MAE (a.u.):', round(MAE.item(), 3))





Molecule: CH2Cl2 Variation type: 5
+-----+-----------------+-----------------+-----------------+
|Atom |    X (a.u.)     |    Y (a.u.)     |    Z (a.u.)     |
+-----+-----------------+-----------------+-----------------+
Prediction:
|Cl   |  3.231 ± 0.130  |  0.121 ± 0.102  | -0.022 ± 0.000  |
+-----+-----------------+-----------------+-----------------+
|H    | -0.780 ± 0.218  |  1.988 ± 0.090  | -0.015 ± 0.000  |
+-----+-----------------+-----------------+-----------------+
|Cl   | -1.893 ± 0.265  | -1.584 ± 0.231  |  2.145 ± 0.210  |
+-----+-----------------+-----------------+-----------------+
|C    | -0.148 ± 0.111  |  0.231 ± 0.067  | -0.161 ± 0.118  |
+-----+-----------------+-----------------+-----------------+
|H    | -0.754 ± 0.221  | -0.557 ± 0.148  | -1.944 ± 0.161  |
+-----+-----------------+-----------------+-----------------+
Ground truth:
|Cl   |      3.238      |      0.000      |      0.000      |
+-----+-----------------+-----------------+-----------------+
|H    |  