In [1]:
import os, sys, logging
import torch

from torch import nn, tensor
from functools import partial

sys.path.append(os.path.join(os.getcwd(), '..'))
from utils.DataLoaders_jupyter import Get_Dataset, Create_DataLoaders
from utils.DataLoaders_noSplit import Get_Dataset_noSplit, Create_DataLoaders_noSplit
from utils.Learner import Learner
from utils.Callbacks import TrainCB, DeviceCB, MetricsCB, BatchSchedCB, SaveCB
from utils.utils_ML import Adam, set_seed, LayerNorm, MeanAbsoluteError, EDM_MSELoss, MSELoss, EDM_MSELoss_EvalLoss
from utils.Callbacks import to_device, ProgressCB
from utils.download import download_dataset_weights
from torcheval.metrics import MeanSquaredError
from model.XMolNet import XMolNet


# Global runtime configuration: RNG seed, tensor print format, log level, device.
torch.manual_seed(1)  # seed torch's global RNG so runs are repeatable
torch.set_printoptions(linewidth=140, precision=5, sci_mode=False)
logging.disable(logging.WARNING)  # ignore logging calls at WARNING severity and below

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    def_device = 'cuda'
else:
    def_device = 'cpu'
print('device:', def_device)


# IPython autoreload: with mode 2, changed modules (e.g. the project's utils/ and
# model/ packages imported above) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2


device: cuda


In [2]:
# Fetch the dataset and pretrained model archives and extract them under ../download/
# (see the log below: ../download/dataset and ../download/model).
# NOTE(review): the log shows a fresh download on this run; confirm whether the helper
# skips files that already exist before relying on cheap re-runs.
download_dataset_weights()

Downloading...
From: https://drive.google.com/uc?id=1TR_bS2GRgz-HqnP_NA566y1eogbORw8F
To: /sdf/data/lcls/ds/prj/prjsim00221/results/molexa_git/download/dataset.zip
100%|██████████████████████████████████████| 22.0M/22.0M [00:00<00:00, 65.4MB/s]


File downloaded successfully to ../download/dataset.zip
Contents extracted to ../download/dataset


Downloading...
From (original): https://drive.google.com/uc?id=114-Qe5rrJc4nghPRBZ9_xEqK1XteJ2bb
From (redirected): https://drive.google.com/uc?id=114-Qe5rrJc4nghPRBZ9_xEqK1XteJ2bb&confirm=t&uuid=a234f9a0-e97c-45c6-8502-13fb26590c0f
To: /sdf/data/lcls/ds/prj/prjsim00221/results/molexa_git/download/model.zip
100%|████████████████████████████████████████| 148M/148M [00:01<00:00, 89.5MB/s]


File downloaded successfully to ../download/model.zip
Contents extracted to ../download/model


## Set up model

In [2]:
# --- Optimisation / checkpointing -------------------------------------------
epochs = 3
lr = 0.001
beta1, beta2 = 0.9, 0.99   # handed to the Learner together with torch.optim.Adam

sv_step_size = 200         # step_size for SaveCB
sv_prefix = 'model_'       # prefix for SaveCB

# --- Diffusion settings ------------------------------------------------------
diffu_params = {
    'y_c': tensor([0., -0., -0.]),
    'y_hw': tensor([23.59669, 19.97351, 15.76245]),
    'n_diffu': 1,          # an alternative value of 8 was left commented out originally
    'P_mean': -1.2,
    'P_std': 1.2,
    'sigma_data': 0.25,
}
natts_diffu = 2
num_steps = 15             # given to XMolNet; a later cell reassigns learn.model.num_steps

# --- Input embedding sizes (passed positionally to XMolNet) ------------------
z_max, z_emb_dim = 20, 64
q_max, q_emb_dim = 20, 64
pos_out_dim = 64

# --- Attention stack: width and head count both grow with `scale` ------------
natts = 6
scale = 4
att_dim = int(128 * scale)
nheads = int(8 * scale)

# --- Model mode flags and evaluation-loss settings ---------------------------
denoise = True
sample = False
loss_eval_c = 1
err_bin_center = torch.arange(0, 10, 0.05)

# Instantiate XMolNet from the hyperparameters above. The keyword options are
# collected in a dict only for readability; the call is otherwise unchanged.
xmolnet_options = dict(
    att_dim=att_dim, diffu_params=diffu_params, natts=natts, nheads=nheads,
    dot_product=True, res=True, act1=nn.ReLU, act2=nn.ReLU, norm=LayerNorm,
    attention_type='full', lstm=True, sumup=False,
    num_steps=num_steps, natts_diffu=natts_diffu, err_bin_center=err_bin_center,
    denoise=denoise, sample=sample,
)
model = XMolNet(z_max, z_emb_dim, q_max, q_emb_dim, pos_out_dim, **xmolnet_options)



In [3]:
# Load the train/validation splits and wrap them in DataLoaders.
tls_train, tls_valid = Get_Dataset(path='../download/dataset/dataset_2_7/', rank=0)
dls = Create_DataLoaders(tls_train, tls_valid, batch_size=32, sampler=None, vshuffle=True)

# Pretrained checkpoint extracted by download_dataset_weights().
load_path = '../download/model/molexa.chk'

# Build the loss on the same device the model is moved to below, instead of a
# hard-coded 'cuda' that breaks on CPU-only machines.
loss_func = EDM_MSELoss_EvalLoss(diffu_params=diffu_params, loss_eval_c=loss_eval_c,
                                 err_bin_center=err_bin_center, device=def_device)
val_loss_func = MSELoss()
MCB = MetricsCB(mae=MeanAbsoluteError(), mse_eval=MeanSquaredError())

# Checkpoint-saving callback. It is constructed but NOT included in `cbs`/`xtra`
# below, so it is inactive for this run; put it in `xtra` to enable it.
SCB = SaveCB(epochs=epochs, step_size=sv_step_size, save_dir='.', prefix=sv_prefix)
cbs = [TrainCB(), DeviceCB(), MCB, ProgressCB(plot=False)]
xtra = []  # optional extra callbacks, e.g. [SCB]
learn = Learner(model.to(def_device), dls, loss_func=loss_func, val_loss_func=val_loss_func,
                cbs=cbs + xtra, load_path=load_path, opt_func=torch.optim.Adam, lr=lr,
                beta1=beta1, beta2=beta2)

# Sampler settings applied to the model after construction; `num_steps` here
# overrides the value passed to XMolNet. The names resemble EDM-style (Karras et al.)
# sampler parameters — TODO confirm their meaning against XMolNet.
learn.model.num_steps = 5
learn.model.sigma_min = 0.002
learn.model.sigma_max = 80
learn.model.rho = 1.5          # alternative previously noted: 7
learn.model.S_churn = 30
learn.model.S_min = 0.01
learn.model.S_max = 1
learn.model.S_noise = 1.1      # alternative previously noted: 1.007
learn.model.heun = False
learn.model.step_scale = 1

# of valid: 8580 # of train: 60832


## Get model performance on validation data

In [4]:
# Evaluation-only pass: one epoch with train=False / valid=True, reporting the
# validation loss plus the metrics registered in MetricsCB (mae, mse_eval).
learn.fit(1, train=False, valid=True)

epoch           loss_v          mae_v           mse_eval_v      lr              beta1           beta2           time            rank           


0               1.02814         0.529393024     0.465253979     0.001000000     0.900           0.990           01:06:30,01-11  0              


## Prediction for a single molecule

In [11]:
# Unsplit loader over the held-out test set: one molecule per batch, vshuffle=False.
# Other option: '../download/dataset/dataset_8_9/'
test_data_path = '../download/dataset/dataset_2_7/test/'
tls = Get_Dataset_noSplit(path=test_data_path, rank=0)
# NOTE: this rebinds the global `dls`; the Learner above received the train/valid
# loaders as a constructor argument.
dls = Create_DataLoaders_noSplit(tls, batch_size=1, sampler=None, vshuffle=False)

# of samples: 8762


In [12]:

# Find one molecule in the test loader (by name and type), predict its coordinates
# with the trained model, and print prediction vs. ground truth side by side.
no_filter = False   # True -> take the first sample in the loader, ignoring the filter
mol_tp = '5'
mol_name = 'Cl2H2Si'
# Another example: mol_tp = '8', mol_name = 'C2H3N3'

Z_dict_inverse = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F', 14: 'Si', 15: 'P', 16: 'S', 17: 'Cl'}

indent = " " * 0

# Column widths: 6 for " Atom " and 12 for each " {:>10} " coordinate cell.
_TABLE_BORDER = "+------+------------+------------+------------+"


def _print_coord_table(title, coords, atomic_numbers):
    """Print a titled ASCII table of per-atom Cartesian coordinates.

    Args:
        title: label printed above the table (e.g. 'Prediction:').
        coords: sequence of per-atom rows; each row's ``.tolist()`` must give (x, y, z).
        atomic_numbers: list of atomic numbers Z, aligned with ``coords``.

    Raises:
        ValueError: if ``coords`` and ``atomic_numbers`` differ in length.
    """
    if len(coords) != len(atomic_numbers):
        raise ValueError(f"{len(coords)} coordinate rows but {len(atomic_numbers)} atomic numbers")
    print(title)
    print(f"{indent}{_TABLE_BORDER}")
    print(f"{indent}| {'Atom':<4} | {'X (a.u.)':>10} | {'Y (a.u.)':>10} | {'Z (a.u.)':>10} |")
    print(f"{indent}{_TABLE_BORDER}")
    for z_num, row in zip(atomic_numbers, coords):
        # Fall back to the raw atomic number for elements missing from the lookup.
        atom_name = Z_dict_inverse.get(z_num, str(z_num))
        x, y, z = row.tolist()
        print(f"{indent}| {atom_name:<4} | {x:>10.3f} | {y:>10.3f} | {z:>10.3f} |")
        print(f"{indent}{_TABLE_BORDER}")


for batch in dls.samples:
    if no_filter or (batch.mol_name[0] == mol_name and batch.mol_tp[0] == mol_tp):
        batch = to_device(batch)
        # NOTE(review): squeeze() drops every size-1 dim; presumably it only removes the
        # batch dim here (batch_size=1) — confirm for single-atom molecules.
        pred = learn.predict_batch(batch).squeeze()
        ys = batch.y
        zs = batch.z.tolist()

        # Both errors are averaged over every (atom, coordinate) entry.
        RMSE = ((pred - ys) ** 2).mean().sqrt()
        MAE = (pred - ys).abs().mean()

        print('Molecule:', batch.mol_name[0])
        _print_coord_table('Prediction:', pred, zs)
        _print_coord_table('Ground truth:', ys, zs)
        print('RMSE (a.u.):', round(RMSE.item(), 3), 'MAE (a.u.):', round(MAE.item(), 3))
        break
else:
    # Loop finished without `break`: nothing matched the filter.
    print(f"No sample found with mol_name={mol_name!r}, mol_tp={mol_tp!r}")
        




Molecule: Cl2H2Si
+-----+------------+------------+------------+
| Atom |   X (a.u.) |   Y (a.u.) |   Z (a.u.) |
+-----+------------+------------+------------+
Prediction:
| Cl  |      3.657 |      0.001 |     -0.021 |
+-----+------------+------------+------------+
| H   |     -1.072 |      2.659 |      0.014 |
+-----+------------+------------+------------+
| Cl  |     -1.608 |     -1.897 |     -2.581 |
+-----+------------+------------+------------+
| Si  |     -0.139 |      0.125 |      0.003 |
+-----+------------+------------+------------+
| H   |     -1.071 |     -0.922 |      2.557 |
+-----+------------+------------+------------+
Ground truth:
| Cl  |      3.757 |     -0.000 |     -0.000 |
+-----+------------+------------+------------+
| H   |     -0.938 |      2.794 |      0.000 |
+-----+------------+------------+------------+
| Cl  |     -1.741 |     -1.846 |     -2.770 |
+-----+------------+------------+------------+
| Si  |     -0.140 |      0.220 |      0.131 |
+-----+--------