In [1]:
import numpy as np
from matplotlib import pyplot as plt
from molecular_generation_utils import *
from invert_CM import *
import torch
from Model import Multi_VAE
from torch.distributions.categorical import Categorical
from tqdm import tqdm
import copy
from ase.visualize import view

# Switch intended to select the artefacts used in the paper.
reproduce_paper = False

# Suffix meant to be appended to data file names when reproducing the paper.
paper_path = '_paper' if reproduce_paper else ''

# NOTE(review): the original code called `.format(paper_path)` on the three paths
# below, but none of them contains a `{}` placeholder, so the call returned the
# string unchanged and `reproduce_paper` never affected which files were loaded.
# The no-op calls are removed here (behavior is identical). If paper-specific
# files exist, add the placeholder explicitly, e.g.
# './data/properties_total{}.pt'.format(paper_path) -- TODO confirm file names.
properties = torch.load('./data/properties_total.pt')
p_means = torch.load('./data/properties_means.pt')
p_stds = torch.load('./data/properties_stds.pt')

# Standardize the properties with the stored means and standard deviations.
norm_props = (properties - p_means)/p_stds

# Names of the property columns; `p_arr` is the same list as a NumPy array.
properties_list =  ['eAT', 'eMBD', 'eXX', 'mPOL', 'eNN', 'eNE', 'eEE', 'eKIN', 'DIP', 'HLgap', 'HOMO_0', 'LUMO_0', 'HOMO_1', 'LUMO_1', 'HOMO_2', 'LUMO_2', 'dimension']
p_arr = np.array(properties_list)

PATH = "last.ckpt"
#'./models_saved/masked/epoch=2597-step=145487.ckpt'

# Read the validation tensors once each; only the length of their first row is
# needed to size the model (the original re-loaded properties.pt twice).
_val_CMs = torch.load('./data/data_val/CMs.pt')
_val_properties = torch.load('./data/data_val/properties.pt')
_structures_dim = len(_val_CMs[0,:])
_properties_dim = len(_val_properties[0,:])
del _val_CMs, _val_properties  # the tensors themselves are not needed afterwards

modello = Multi_VAE.load_from_checkpoint(
    PATH,
    map_location=torch.device('cpu'),
    structures_dim = _structures_dim,
    properties_dim = _properties_dim,
    latent_size = 21,
    # NOTE(review): 32 is a hard-coded total width; presumably it must match the
    # value used at training time -- confirm against the training script.
    extra_dim = 32 - _properties_dim,
    initial_lr = 1e-3,
    properties_means = p_means,
    properties_stds = p_stds,
    beta = 4.,
    alpha = 1.,
    decay = .01
)

# if you want the non-masked model uncomment this

# if reproduce_paper == True:
#     PATH='./special/VAE_reduced_21'
#     modello.VAE.load_state_dict(torch.load(PATH,map_location=torch.device('cpu')))
#     PATH='./special/prop_ecoder_reduced_21'
#     modello.property_encoder.load_state_dict(torch.load(PATH,map_location=torch.device('cpu')))

# Inference only: switch to eval mode and freeze the parameters.
modello.eval()
modello.freeze()

Lightning automatically upgraded your loaded checkpoint from v1.5.10 to v2.2.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint last.ckpt`


Fit the dataset property distribution with a mixture of 91 Gaussians (the number of components can be changed).

In [2]:
# Fit a Gaussian mixture to the normalized properties. The component range
# passed here is [91, 92]; the recorded run below reports a single candidate
# and "using 91 components".
n_mixture_components = 91
gm, labels = props_fit_Gaussian_mix(
    norm_props,
    min_components=n_mixture_components,
    max_components=n_mixture_components + 1,
)

100%|██████████| 1/1 [01:02<00:00, 62.96s/it]


using 91 components


In [3]:
# Target property values are given in normalized units (i.e. after subtracting
# the mean and dividing by the standard deviation), not in physical units.
target_properties = {'mPOL': 2., 'eMBD': 2.}

# Parameters of the Gaussian mixture fitted to the normalized properties.
mixture_means = gm.means_
mixture_covariances = gm.covariances_

generated = start_generation(
    modello,
    target_properties,
    p_arr,
    177,
    int(5e3),
    mixture_means,
    mixture_covariances,
    cm_diff=5,
    deltaz=6,
    check_new_comp=False,
    verbose=False,
)

. sampling...


  M[i,j]=(M[i,j]/(Z_w[i]*Z_w[j]))**(-1)


In [None]:
# Display the output of start_generation; the cells below treat each row as a
# Coulomb matrix.
generated

In [None]:
# Load the Coulomb matrices of the dataset.
# NOTE(review): the original called `.format(paper_path)` on this path, but the
# string has no `{}` placeholder, so the call was a no-op and the same file was
# always loaded. The call is dropped here; add a placeholder explicitly if a
# paper-specific file exists -- TODO confirm.
CMs = torch.load('./data/CMs_total.pt')

In [None]:
# Show the dimensions of the generated batch.
generated.shape

In [17]:
# Take one generated sample (index 1) and try to turn it back into a molecule.
CM = generated[1]

from invert_CM import *
from CM_preparation import *
# `Atoms` is used below; import it explicitly so this cell does not depend on a
# star import (or on a later cell having been run first) to provide it.
from ase import Atoms

# Recover the inter-atomic distance matrix and the vector printed below as
# "Master Vector" (passed to Atoms as `symbols`).
distance_mat, master_vec2 = recover_distance_mat(
    CM
)


print("Master Vector:", master_vec2)

# Recover the Cartesian coordinates
cartesian = cartesian_recovery(distance_mat)
# Keep only the first three columns (x, y, z).
# NOTE(review): the original comment said "truncate last two columns", which
# presumes cartesian_recovery returns five columns -- confirm.
cartesian = cartesian[:,0:3]
# Discard any imaginary part of the recovered coordinates.
cartesian = np.real(cartesian)
print("Cartesian coordinates:", cartesian)

# Create the recovered ASE Atoms object
rec_mol = Atoms(symbols=master_vec2, positions=cartesian)
print("Recovered mol:")
view(rec_mol)
# TODO: get the RMSD between the original and recovered molecule (not implemented).

Master Vector: [7, 7, 7, 6, 6, 6, 6]
Cartesian coordinates: [[ 0.          0.          0.        ]
 [-3.01945377  1.33733603  0.55352034]
 [-3.63094339  0.25392509 -0.68692134]
 [-4.42263893 -1.28195918  0.04642233]
 [-3.74078607 -0.21149771  0.52600459]
 [-1.04466364  0.10172513 -0.17455889]
 [-2.07755196  0.67123701 -0.56209478]]
Recovered mol:


<Popen: returncode: None args: ['/Users/jan/miniconda3/bin/python', '-m', 'a...>

In [None]:
# For every generated matrix, find the index of the dataset Coulomb matrix
# that is closest to it in L2 norm.
# (To inspect the distance itself, take .min() of the same norm expression.)
where_closest = [
    torch.argmin(torch.norm(CMs - sample.view(1, -1), dim=1))
    for sample in generated
]

In [20]:
import rmsd
from ase.visualize import view
from ase.optimize import BFGS, FIRE
from ase import Atoms
from ase.io import write, read
from openbabel import openbabel as ob
from ase.constraints import FixAtoms
from ase.calculators.emt import EMT

# Define the EMT calculator
calc = EMT()

# Index of the generated sample to process.
n = 0
pos, comp = get_cartesian(generated[n, :].tolist())
print(len(comp))
atom = Atoms(comp, pos)

# Round-trip through Open Babel (xyz -> mol2 -> xyz) so that connectivity and
# bond orders can be perceived and hydrogens added.
write("./temp.xyz", atom)
obConversion = ob.OBConversion()
obConversion.SetInAndOutFormats("xyz", "mol2")
mol = ob.OBMol()
obConversion.ReadFile(mol, "temp.xyz")
obConversion.WriteFile(mol, "temp.mol2")
obConversion = ob.OBConversion()
obConversion.SetInAndOutFormats("mol2", "xyz")
mol = ob.OBMol()
obConversion.ReadFile(mol, "temp.mol2")
mol.ConnectTheDots()
mol.PerceiveBondOrders()
mol.AddHydrogens()
obConversion.WriteFile(mol, "./temp.xyz")

# Reload the hydrogenated structure as an ASE Atoms object.
atom = read("./temp.xyz")
print(comp, len(comp))

# Freeze every non-hydrogen atom so only the added hydrogens are relaxed.
c = FixAtoms(indices=[atomo.index for atomo in atom if atomo.symbol != "H"])
atom.set_constraint(c)
# `Atoms.set_calculator` is deprecated in ASE (it emits the warning seen in the
# recorded output); assigning the `calc` attribute is the supported form.
atom.calc = calc

# Relax with FIRE; `steps` is only an upper bound on the number of iterations.
opt = FIRE(atom)
opt.run(steps=22222)
# Optional follow-up kept from the original (disabled): release the constraint
# with atom.set_constraint() and run a few more steps, opt.run(steps = 10).
view(atom, viewer="x3d")

6
[7, 7, 6, 6, 6, 6] 6
      Step     Time          Energy          fmax
FIRE:    0 15:15:15       10.371183       12.052721
FIRE:    1 15:15:15        7.395353        5.260548
FIRE:    2 15:15:15        5.989948        1.767563
FIRE:    3 15:15:15        5.655581        0.718729
FIRE:    4 15:15:15        5.653971        0.690765
FIRE:    5 15:15:15        5.651011        0.635164
FIRE:    6 15:15:15        5.647180        0.552697
FIRE:    7 15:15:15        5.643113        0.444841
FIRE:    8 15:15:15        5.639503        0.314254
FIRE:    9 15:15:15        5.636981        0.166274
FIRE:   10 15:15:15        5.635968        0.083559
FIRE:   11 15:15:15        5.636472        0.177173
FIRE:   12 15:15:15        5.636432        0.174656
FIRE:   13 15:15:15        5.636356        0.169676
FIRE:   14 15:15:15        5.636245        0.162337
FIRE:   15 15:15:15        5.636105        0.152802
FIRE:   16 15:15:15        5.635941        0.141294
FIRE:   17 15:15:15        5.635761        

  atom.set_calculator(calc)


FIRE:  111 15:15:15        5.260602        0.087138
FIRE:  112 15:15:15        5.258677        0.087848
FIRE:  113 15:15:15        5.257168        0.105381
FIRE:  114 15:15:15        5.254905        0.113639
FIRE:  115 15:15:15        5.251726        0.091228
FIRE:  116 15:15:15        5.249252        0.093145
FIRE:  117 15:15:15        5.246201        0.096814
FIRE:  118 15:15:15        5.242023        0.101829
FIRE:  119 15:15:15        5.237211        0.106094
FIRE:  120 15:15:15        5.230701        0.111117
FIRE:  121 15:15:15        5.223317        0.113204
FIRE:  122 15:15:15        5.214618        0.106451
FIRE:  123 15:15:15        5.206179        0.117490
FIRE:  124 15:15:15        5.201776        0.157556
FIRE:  125 15:15:15        5.200069        0.079200
FIRE:  126 15:15:15        5.200117        0.108210
FIRE:  127 15:15:15        5.199864        0.086697
FIRE:  128 15:15:15        5.199534        0.048332


In [None]:
# visualize the closest molecules in the dataset by coulomb matrix

import rmsd
from ase.visualize import view
from ase import Atoms

# Rebuild the dataset molecule whose Coulomb matrix is closest to generated
# sample `n`, and show it inline.
closest_index = where_closest[n]
pos, comp = get_cartesian(CMs[closest_index, :].tolist())
print(len(comp))
atom = Atoms(symbols=comp, positions=pos)
view(atom, viewer='x3d')

In [None]:
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib
font = {'family' : "Times New Roman",
        'weight' : 'normal',
        'size'   : 20}

matplotlib.rc('font', **font)
#plt.rcParams["font.family"] = "Times New Roman"

# NOTE(review): absolute, machine-specific paths; they only resolve on the
# original cluster. Point these at local copies to re-run the cell.
T1_SAMPLES_PATH = '/work/projects/tcp/lmedrano/alessio/vae/mol-new/code/fhi/extract/xyzfiles2/12/data-12.dat'
T2_SAMPLES_PATH = '/work/projects/tcp/lmedrano/alessio/vae/mol-new/code/fhi/extract/xyzfiles/set2/data-set2.dat'


def _scatter_samples(rows, color, label):
    """Scatter one point per row, labelling only the first for the legend.

    Each row is expected to hold a single whitespace-separated string in its
    first element; fields 2 and 3 of that string are used as x and y.
    """
    for i, row in enumerate(rows):
        fields = row[0].split()
        # Attach the label to the first point only, so the legend has one entry.
        legend_kwargs = {'label': label} if i == 0 else {}
        plt.scatter(float(fields[2]), float(fields[3]), c = color, s = 50, **legend_kwargs)


# Golden-ratio aspect for the figure.
gold = (1+(5**0.5))/2
plt.figure(figsize=(5*gold,5))

# Background: the whole dataset (property columns 1 and 3).
plt.scatter(properties[:,1], properties[:,3], c = 'lightgrey', label = 'QM7-X')

# Samples for target T_1: only the selected rows of the file are plotted.
index = [0,1,2,4]
mols = pd.read_csv(T1_SAMPLES_PATH, header = None).values[index]
print(mols)
_scatter_samples(mols, 'turquoise', r'samples $T_1$')

# Target T_1, converted from normalized units back to physical units.
plt.scatter(-1*p_stds[1] + p_means[1], 2*p_stds[3] + p_means[3], c = 'navy', marker = '*', s = 70, label = r'$T_1$')

# Samples for target T_2: all rows of the file are plotted.
mols = pd.read_csv(T2_SAMPLES_PATH, header = None).values
print(mols)
_scatter_samples(mols, 'orange', r'samples $T_2$')

# Target T_2, converted from normalized units back to physical units.
plt.scatter(-1.5*p_stds[1] + p_means[1], 1*p_stds[3] + p_means[3], c = 'red', marker = '*', s = 70, label = r'$T_2$')
plt.legend(loc = 'lower right', prop = {'size': 14})
plt.xlim(-.5, -.2)
plt.ylim(80, 120)
plt.locator_params(axis='x', nbins=4)
plt.locator_params(axis='y', nbins=4)
plt.xlabel(r'$E_{MBD}$', fontdict={'size':22})
plt.ylabel(r'$\alpha$', fontdict={'size':22})
plt.tight_layout()
plt.savefig('./new_chem.pdf', dpi = 100)