In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Make the project root (one directory above the notebook's working directory)
# importable; the membership check keeps re-runs from adding duplicate entries.
import os
import sys

module_path = os.path.abspath(os.pardir)
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

#### Arguments

In [3]:
args = dict(
    # General
    seed=42,
    device='cpu',
    root_dir='/Users/svlg/MasterThesis/v02',

    # FlowMol
    model='qm9_ctmc',
    n_molecules=1000,
    n_timesteps=50,

    # Reward model
    reward_model='PAMNet_s',
    n_layer=6,
    dim=128,
    target=7,  # regression target index — TODO confirm which QM9 property this is
    cutoff_l=5.0,
    cutoff_g=5.0,

    # Data / Dataset
    dataset='QM9',
    data_path='/data',  # NOTE(review): not referenced by the cells below
    batch_size=1000,
)

In [4]:
import torch
import numpy as np
import random

def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + all CUDA devices).

    Also switches cuDNN to deterministic kernels and disables its benchmark autotuner,
    trading some speed for run-to-run reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(seed=args["seed"])  # seed everything once, right after the config is defined

#### Data

In [5]:
import os.path as osp
from torch_geometric.loader import DataLoader
from dataset.QM9 import QM9

In [8]:
# Load the pre-split QM9 train/val/test sets; only the training loader is shuffled.
# The root is derived from args['root_dir'] instead of a second hard-coded absolute path.
# NOTE(review): the recorded run failed with FileNotFoundError because
# qm9_train_data.pt was missing under data_root — the split files must exist first.
data_root = osp.join(args['root_dir'], 'data', 'QM9')


def _load_qm9_split(split, shuffle):
    """Load ``qm9_{split}_data.pt`` from ``data_root`` and wrap it in a DataLoader.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = QM9(data_root)
    dataset.load(osp.join(data_root, f'qm9_{split}_data.pt'))
    loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=shuffle)
    return dataset, loader


train_dataset, train_loader = _load_qm9_split('train', shuffle=True)
val_dataset, val_loader = _load_qm9_split('val', shuffle=False)
test_dataset, test_loader = _load_qm9_split('test', shuffle=False)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/svlg/MasterThesis/v02/data/QM9/qm9_train_data.pt'

In [None]:
# Use one (shuffled, seed-dependent) training batch of args['batch_size'] molecules
# as the QM9 reference set for all comparisons below.
batch_data = next(iter(train_loader), None)
if batch_data is None:
    raise RuntimeError("train_loader yielded no batches; cannot build the QM9 reference set.")
print(batch_data)

qm9_data = batch_data.to_data_list()
qm9_smiles = [data.smiles for data in qm9_data]

In [15]:
from rdkit import Chem
from rdkit.Geometry.rdGeometry import Point3D
# Bond-type index -> RDKit BondType. Indexed by argmax(edge_attr) below; assumes the
# edge_attr columns are ordered single/double/triple/aromatic — TODO confirm.
bond_type_map = [
    getattr(Chem.rdchem.BondType, bond_name)
    for bond_name in ("SINGLE", "DOUBLE", "TRIPLE", "AROMATIC")
]

In [None]:
def pyg_to_rdmol(positions, atom_types, bond_src_idxs, bond_dst_idxs, bond_types, sanitize=False):
    """Build an RDKit molecule with a single 3D conformer from per-atom / per-edge arrays.

    Args:
        positions: per-atom coordinates; ``positions[i]`` must unpack into three numbers.
        atom_types: atomic numbers, one per atom.
        bond_src_idxs, bond_dst_idxs: endpoints of each edge. Both directions of an
            undirected edge may be listed; only the first occurrence is added.
        bond_types: per-edge integer index into the global ``bond_type_map``.
        sanitize: if True, run ``Chem.SanitizeMol`` and return ``None`` when RDKit rejects
            the molecule (kekulization, valence, ...). The default ``False`` keeps the
            previous behaviour of returning the unsanitized molecule.

    Returns:
        ``Chem.Mol`` with one conformer, or ``None`` when ``sanitize`` is True and
        sanitization fails.
    """
    mol = Chem.RWMol()
    for atom_type in atom_types:
        mol.AddAtom(Chem.Atom(int(atom_type)))

    # Key on the unordered atom pair so i->j and j->i produce a single bond.
    seen_pairs = set()
    for bond_type, src_idx, dst_idx in zip(bond_types, bond_src_idxs, bond_dst_idxs):
        src_idx, dst_idx = int(src_idx), int(dst_idx)
        pair = (min(src_idx, dst_idx), max(src_idx, dst_idx))
        if pair in seen_pairs:
            continue
        mol.AddBond(src_idx, dst_idx, bond_type_map[int(bond_type)])
        seen_pairs.add(pair)

    # GetMol() only converts RWMol -> Mol and does not sanitize or kekulize, so it never
    # raised the KekulizeException the old code tried to catch; validation is explicit now.
    mol = mol.GetMol()
    if sanitize:
        try:
            Chem.SanitizeMol(mol)
        except Chem.rdchem.MolSanitizeException:
            return None

    conf = Chem.Conformer(mol.GetNumAtoms())
    for i in range(mol.GetNumAtoms()):
        x, y, z = (float(coord) for coord in positions[i])
        conf.SetAtomPosition(i, Point3D(x, y, z))
    mol.AddConformer(conf)

    return mol

In [None]:
# Convert every QM9 graph to RDKit; argmax over edge_attr gives each edge's bond-type index.
qm9_rdkit_mols = [
    pyg_to_rdmol(
        data.pos,
        data.z.tolist(),
        data.edge_index[0].tolist(),
        data.edge_index[1].tolist(),
        torch.argmax(data.edge_attr, dim=1),
    )
    for data in qm9_data
]

In [None]:
def show_mol(index, mols=None, width=250, height=250):
    """Render ``mols[index]`` as an interactive 3D stick-and-sphere view with py3Dmol.

    Args:
        index: position of the molecule in ``mols``.
        mols: sequence of RDKit molecules; defaults to ``qm9_rdkit_mols`` (looked up at call time).
        width, height: viewer size in pixels.

    Raises:
        ValueError: if the selected entry is ``None`` (a molecule that could not be built).
    """
    import py3Dmol

    if mols is None:
        mols = qm9_rdkit_mols
    mol = mols[index]
    if mol is None:
        raise ValueError(f"Molecule at index {index} is None and cannot be displayed.")

    # MolToMolBlock emits MDL molfile text (not PDB), hence the "mol" format below.
    mol_block = Chem.MolToMolBlock(mol)

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol_block, "mol")
    viewer.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
    viewer.zoomTo()
    viewer.show()


show_mol(50)

#### Sampling

In [9]:
import flowmol

In [10]:
# Load the pretrained FlowMol checkpoint onto the configured device, in eval mode.
model = flowmol.load_pretrained(args['model']).to(args['device'])
model.eval()

FlowMol(
  (interpolant_scheduler): InterpolantScheduler()
  (vector_field): CTMCVectorField(
    (interpolant_scheduler): InterpolantScheduler()
    (scalar_embedding): Sequential(
      (0): Linear(in_features=14, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): SiLU()
      (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (edge_embedding): Sequential(
      (0): Linear(in_features=6, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): SiLU()
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (conv_layers): ModuleList(
      (0-7): 8 x GVPConv(
        (edge_message): Sequential(
          (0): GVP(
            (vectors_activation): Sigmoid()
            (to_feats_out): Sequential(
              (0): Linear(in_features=437, out_features=256, bias=True)
              (1): SiLU()
            )
        

In [12]:
# Take the sample count from the config cell. The old hard-coded n_molecules = 50 made
# gen_rdkit_mols[50] (used by show_mol(50) and the SVG export) out of range, and it
# contradicted the recorded "Sampling 1000 molecules..." output of this cell.
n_molecules = args['n_molecules']
# NOTE(review): 101 steps differs from args['n_timesteps'] (50) used by the HOMO-LUMO
# comparison further down; confirm which integration length is intended.
n_timesteps = 101
print(f"Sampling {n_molecules} molecules...")
generated_molecules = model.sample_random_sizes(n_molecules=n_molecules, n_timesteps=n_timesteps, device=args['device'])

Sampling 1000 molecules...


In [13]:
# Keep each sample's RDKit molecule and SMILES in two parallel lists (sampling order).
gen_rdkit_mols, gen_smiles = [], []
for sample in generated_molecules:
    gen_rdkit_mols.append(sample.rdkit_mol)
    gen_smiles.append(sample.smiles)

In [21]:
def show_mol(index, mols=None, width=500, height=500):
    """Render ``mols[index]`` as an interactive 3D stick-and-sphere view with py3Dmol.

    This redefinition shadows the earlier QM9 ``show_mol``; here the default source is
    the generated molecules.

    Args:
        index: position of the molecule in ``mols``.
        mols: sequence of RDKit molecules; defaults to ``gen_rdkit_mols`` (looked up at call time).
        width, height: viewer size in pixels.

    Raises:
        ValueError: if the selected entry is ``None`` (a sample without an RDKit molecule).
    """
    import py3Dmol

    if mols is None:
        mols = gen_rdkit_mols
    mol = mols[index]
    if mol is None:
        raise ValueError(f"Molecule at index {index} is None and cannot be displayed.")

    # MolToMolBlock emits MDL molfile text (not PDB), hence the "mol" format below.
    mol_block = Chem.MolToMolBlock(mol)

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol_block, "mol")
    viewer.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
    viewer.zoomTo()
    viewer.show()


show_mol(50)

In [17]:
from rdkit import Chem
from rdkit.Chem import Draw

# Export a 2D depiction of one generated molecule to mol_<index>.svg.
svg_index = 50
mol = gen_rdkit_mols[svg_index]
if mol is None:
    raise ValueError(f"gen_rdkit_mols[{svg_index}] is None; nothing to draw.")

drawer = Draw.MolDraw2DSVG(400, 400)  # width x height in pixels
drawer.DrawMolecule(mol)
drawer.FinishDrawing()

with open(f"mol_{svg_index}.svg", "w", encoding="utf-8") as f:
    f.write(drawer.GetDrawingText())


#### Constraints

In [None]:
class Constrains:
    """Optional molecule "constraint" scorers, each enabled by its own flag.

    Every enabled backend is imported and initialized in ``__init__``. Calling the instance
    runs each enabled scorer one item at a time and returns a dict
    ``{score_name: [result_or_None, ...]}`` with one entry per input. Each result is what the
    scorer returned for a single-element list; ``None`` means scoring that item raised.

    Expected inputs (see the ``get_*`` docstrings): RDKit molecules for the SA score and
    PoseBusters; SMILES strings for AiZynthFinder, the FS score and the RA score.
    """

    def __init__(self,
                 use_sa_score: bool = False,
                 use_aizynth_finder: bool = False,
                 use_fs_score: bool = False,
                 use_ra_score: bool = False,
                 use_posebuster: bool = False,
                 pretrained_models_path: str = "/Users/svlg/MasterThesis/v02/pretrained_models/",
                 device: str = "cpu"):
        """Import and initialize the selected scoring backends.

        Args:
            use_sa_score: enable molscore's ``sascorer``.
            use_aizynth_finder: enable molscore's ``AiZynthFinder`` wrapper (no filter policy).
            use_fs_score: enable the FS score from the ``fsscore`` RankNet checkpoint.
            use_ra_score: enable molscore's ``RAScore_XGB`` with the "GDB" model.
            use_posebuster: enable ``PoseBusters`` with the "mol" config.
            pretrained_models_path: directory containing the ``FSscore`` checkpoint folder.
                NOTE(review): the default is an absolute, machine-specific path.
            device: torch device for the FS-score model (previously hard-coded to "cpu").
        """
        self.use_sa_score = use_sa_score
        self.use_aizynth_finder = use_aizynth_finder
        self.use_fs_score = use_fs_score
        self.use_ra_score = use_ra_score
        self.use_posebuster = use_posebuster
        self.pretrained_models_path = pretrained_models_path
        self.device = device

        if not any((self.use_sa_score, self.use_aizynth_finder, self.use_fs_score,
                    self.use_ra_score, self.use_posebuster)):
            print("No scoring function selected. Please select at least one scoring function.")

        if self.use_sa_score:
            from molscore.scoring_functions.SA_Score import sascorer
            self.sascorer = sascorer

        if self.use_aizynth_finder:
            from molscore.scoring_functions.aizynthfinder import AiZynthFinder
            self.aizynth_finder = AiZynthFinder(filter_policy=None)

        if self.use_fs_score:
            from fsscore.score import Scorer
            from fsscore.models.ranknet import LitRankNet
            # os.path.join avoids the "//" that concatenating a trailing-slash base with a
            # leading-slash suffix produced (the parts must stay relative, or join drops the base).
            model_path = os.path.join(self.pretrained_models_path, "FSscore",
                                      "pretrain_graph_GGLGGL_ep242_best_valloss.ckpt")
            fs_model = LitRankNet.load_from_checkpoint(model_path)
            fs_model.to(self.device)
            fs_model.eval()
            # Attribute name `fc_scorer` kept for backward compatibility.
            self.fc_scorer = Scorer(model=fs_model, device=self.device)

        if self.use_ra_score:
            from molscore.scoring_functions.rascore_xgb import RAScore_XGB
            self.ra_scorer = RAScore_XGB(model="GDB")

        if self.use_posebuster:
            from posebusters import PoseBusters
            self.posebuster = PoseBusters(config="mol")

    def get_sa_score(self, tmp_list):
        r"""SA score for each RDKit molecule in ``tmp_list``.

        Note: each molecule is sanitized **in place** (``Chem.SanitizeMol``) before scoring.
        """
        rt_list = []
        for tmp in tmp_list:
            Chem.GetSSSR(tmp)
            Chem.SanitizeMol(tmp)
            rt_list.append(self.sascorer.calculateScore(tmp))
        return rt_list

    def get_aizynth_finder(self, tmp_list, directory='ai_finder_results', just_scores=True):
        r"""Run AiZynthFinder on a list of SMILES, writing results under ``directory``.

        Returns ``(AiZynth_is_solved, AiZynth_top_score)`` tuples when ``just_scores`` is True,
        otherwise the raw per-molecule result dicts.
        """
        tmp_list = self.aizynth_finder(tmp_list, directory)
        if just_scores:
            return [(tmp['AiZynth_is_solved'], tmp['AiZynth_top_score']) for tmp in tmp_list]
        return tmp_list

    def get_fs_score(self, tmp_list):
        r"""FS score for a list of SMILES (whatever ``Scorer.score`` returns)."""
        return self.fc_scorer.score(tmp_list)

    def get_ra_score(self, tmp_list):
        r"""RAScore predicted probability (``RAScore_pred_proba``) for a list of SMILES."""
        tmp_list = self.ra_scorer(tmp_list)
        return [tmp['RAScore_pred_proba'] for tmp in tmp_list]

    def get_posebuster(self, tmp_list, full_report=False):
        r"""PoseBusters checks for a list of RDKit molecules (no reference/conditioning mols)."""
        return self.posebuster.bust(tmp_list, None, None, full_report=full_report)

    def score(self, tmp_list, scorer):
        r"""Apply ``scorer`` to each item of ``tmp_list`` individually.

        ``scorer`` is called as ``scorer([item])`` and its return value is stored as-is.
        Items for which ``scorer`` raises an ``Exception`` get ``None``, so one bad molecule
        does not abort the batch; ``KeyboardInterrupt``/``SystemExit`` still propagate
        (the previous bare ``except:`` swallowed them).

        Returns:
            list with one entry per input item.
        """
        rt_list = []
        for tmp in tmp_list:
            try:
                rt_list.append(scorer([tmp]))
            except Exception:
                rt_list.append(None)
        return rt_list

    def __call__(self, tmp_list):
        r"""Run every enabled scorer on ``tmp_list``; returns ``{score_name: per-item results}``."""
        scores = dict()
        if self.use_sa_score:
            scores['sa_score'] = self.score(tmp_list, self.get_sa_score)
        if self.use_aizynth_finder:
            scores['aizynth_finder'] = self.score(tmp_list, self.get_aizynth_finder)
        if self.use_fs_score:
            scores['fs_score'] = self.score(tmp_list, self.get_fs_score)
        if self.use_ra_score:
            scores['ra_score'] = self.score(tmp_list, self.get_ra_score)
        if self.use_posebuster:
            scores['posebuster'] = self.score(tmp_list, self.get_posebuster)
        return scores

In [None]:
# `rdkit_constrains` is used by the "RDKit Constrains" section below; it was commented out,
# which made those cells fail with NameError.
# smiles_constrains = Constrains(use_sa_score=False, use_aizynth_finder=True, use_fs_score=True, use_ra_score=True, use_posebuster=False)
rdkit_constrains = Constrains(use_sa_score=True, use_aizynth_finder=False, use_fs_score=False, use_ra_score=False, use_posebuster=True)

#### SA Score

In [None]:
sa_score_constrains = Constrains(use_sa_score=True, use_aizynth_finder=False)  # SA score backend only

In [None]:
qm9_sa_scores = sa_score_constrains(tmp_list=qm9_rdkit_mols)  # note: sanitizes the molecules in place

In [None]:
gen_sa_scores = sa_score_constrains(tmp_list=gen_rdkit_mols)  # note: sanitizes the molecules in place

In [None]:
# The scorer returns {'sa_score': [[score] or None, ...]}; np.array(dict) only produced a
# 0-d object array wrapping the dict. Save flat float arrays, NaN = molecule not scorable.
os.makedirs("output", exist_ok=True)
sa_score_qm9 = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in qm9_sa_scores['sa_score']], dtype=float)
sa_score_gen = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in gen_sa_scores['sa_score']], dtype=float)
np.save("output/sa_score_qm9.npy", sa_score_qm9)
np.save("output/sa_score_gen.npy", sa_score_gen)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Number of histogram bins (adjustable)
num_bins = 30

# The scorer returns {'sa_score': [[score] or None, ...]}, not a sequence of numbers;
# unwrap the scores and drop molecules that could not be scored.
qm9_sa_values = [float(np.ravel(s)[0]) for s in qm9_sa_scores['sa_score'] if s is not None]
gen_sa_values = [float(np.ravel(s)[0]) for s in gen_sa_scores['sa_score'] if s is not None]

# Shared bin edges so the two overlaid histograms are directly comparable.
bin_edges = np.histogram_bin_edges(qm9_sa_values + gen_sa_values, bins=num_bins)

fig, ax = plt.subplots(figsize=(15, 5))
ax.hist(qm9_sa_values, bins=bin_edges, alpha=0.5, color='blue', label='QM9', edgecolor='black')
ax.hist(gen_sa_values, bins=bin_edges, alpha=0.5, color='red', label='Gen', edgecolor='black')
ax.set_title('SA Scores')
ax.set_ylabel('Frequency')
ax.set_xlabel('SA score')
ax.legend()

plt.tight_layout()
plt.show()

#### FS Score

In [None]:
fs_score_constrains = Constrains(use_fs_score=True, use_sa_score=False)  # FS score backend only

In [None]:
# The FS scorer takes SMILES strings, not RDKit molecules.
qm9_fs_score = fs_score_constrains(tmp_list=qm9_smiles)
gen_fs_score = fs_score_constrains(tmp_list=gen_smiles)

In [None]:
# Summarize instead of dumping one entry per QM9 molecule.
fs_entries_qm9 = qm9_fs_score['fs_score']
n_failed_qm9 = sum(entry is None for entry in fs_entries_qm9)
print(f"FS score: {len(fs_entries_qm9)} QM9 molecules, {n_failed_qm9} failed; first 5: {fs_entries_qm9[:5]}")

In [None]:
# The scorer returns {'fs_score': [result or None, ...]}; np.array(dict) only produced a 0-d
# object array wrapping the dict. Save flat float arrays, NaN = molecule not scorable.
os.makedirs("output", exist_ok=True)
fs_score_qm9 = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in qm9_fs_score['fs_score']], dtype=float)
fs_score_gen = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in gen_fs_score['fs_score']], dtype=float)
np.save("output/fs_score_qm9.npy", fs_score_qm9)
np.save("output/fs_score_gen.npy", fs_score_gen)

#### RA Score

In [None]:
ra_score_constrains = Constrains(use_ra_score=True, use_sa_score=False)  # RA score backend only

In [None]:
os.makedirs("output", exist_ok=True)
ra_raw_qm9 = ra_score_constrains(qm9_smiles)['ra_score']
# Each entry is [RAScore_pred_proba] or None (scoring failed); flatten to floats, NaN = failed.
# (np.array on the returned dict only produced a 0-d object array wrapping it.)
ra_score_qm9 = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in ra_raw_qm9], dtype=float)
np.save("output/ra_score_qm9.npy", ra_score_qm9)

In [None]:
os.makedirs("output", exist_ok=True)
ra_raw_gen = ra_score_constrains(gen_smiles)['ra_score']
# Each entry is [RAScore_pred_proba] or None (scoring failed); flatten to floats, NaN = failed.
# (np.array on the returned dict only produced a 0-d object array wrapping it.)
ra_score_gen = np.array([np.nan if s is None else float(np.ravel(s)[0]) for s in ra_raw_gen], dtype=float)
np.save("output/ra_score_gen.npy", ra_score_gen)

#### Posebuster

In [None]:
posebuster_constrains = Constrains(use_posebuster=True, use_sa_score=False)  # PoseBusters backend only

In [None]:
# np.save fails with FileNotFoundError if output/ does not exist yet.
os.makedirs("output", exist_ok=True)
# Keep the report object (presumably a pandas DataFrame of checks — confirm) so column
# names stay available for inspection; the bare values are saved as before.
pose_report_qm9 = posebuster_constrains.get_posebuster(qm9_rdkit_mols)
pose_score_qm9 = np.array(pose_report_qm9)
np.save("output/pose_score_qm9.npy", pose_score_qm9)

In [None]:
# np.save fails with FileNotFoundError if output/ does not exist yet.
os.makedirs("output", exist_ok=True)
# Keep the report object (presumably a pandas DataFrame of checks — confirm) so column
# names stay available for inspection; the bare values are saved as before.
pose_report_gen = posebuster_constrains.get_posebuster(gen_rdkit_mols)
pose_score_gen = np.array(pose_report_gen)
np.save("output/pose_score_gen.npy", pose_score_gen)

#### AiZynth Finder

In [None]:
aizynth_finder_constrains = Constrains(use_aizynth_finder=True, use_sa_score=False)  # AiZynthFinder backend only

In [None]:
os.makedirs("output", exist_ok=True)
aifinder_raw_qm9 = aizynth_finder_constrains(qm9_smiles)['aizynth_finder']
# Each entry is [(AiZynth_is_solved, AiZynth_top_score)] or None; store an (n, 2) float array
# with NaN rows for failures (assumes both values are numeric/bool — confirm).
aifinder_score_qm9 = np.array(
    [(np.nan, np.nan) if entry is None else tuple(float(v) for v in entry[0]) for entry in aifinder_raw_qm9],
    dtype=float,
).reshape(-1, 2)
np.save("output/aifinder_score_qm9.npy", aifinder_score_qm9)

In [None]:
os.makedirs("output", exist_ok=True)
aifinder_raw_gen = aizynth_finder_constrains(gen_smiles)['aizynth_finder']
# Each entry is [(AiZynth_is_solved, AiZynth_top_score)] or None; store an (n, 2) float array
# with NaN rows for failures (assumes both values are numeric/bool — confirm).
aifinder_score_gen = np.array(
    [(np.nan, np.nan) if entry is None else tuple(float(v) for v in entry[0]) for entry in aifinder_raw_gen],
    dtype=float,
).reshape(-1, 2)
np.save("output/aifinder_score_gen.npy", aifinder_score_gen)

#### Plots

#### RDKit Constraints

In [None]:
qm9_rdkit_constrains = rdkit_constrains(tmp_list=qm9_rdkit_mols[:10])  # quick check on the first 10 QM9 molecules

In [None]:
# SA scores of the first 10 QM9 molecules: one [score] list (or None on failure) per molecule.
qm9_rdkit_constrains['sa_score']

In [None]:
# PoseBusters result for the second QM9 molecule (index 1), or None if busting it raised.
qm9_rdkit_constrains['posebuster'][1]

In [None]:
gen_rdkit_constrains = rdkit_constrains(tmp_list=gen_rdkit_mols)  # SA score + PoseBusters for every sample

In [None]:
# Display the full result dict: one list per enabled scorer, one entry per generated molecule.
gen_rdkit_constrains

#### XTB Simulation

In [None]:
from true_reward import xtb_simulation

In [None]:
#  Calculate the true reward
# Sanity check: HOMO-LUMO quantities from xtb_simulation for the first 10 QM9 graphs.
for mol in qm9_data[:10]:
    homolumo_gap, lumo, homo = xtb_simulation.compute_true_reward(mol, "pyg", "homolumo")
    print(f"HOMO-LUMO gap: {homolumo_gap:.4f} eV\nLUMO: {lumo} eV\nHOMO: {homo} eV")

In [None]:
#  Calculate the true reward
# Same check for the FlowMol samples, passing each sample's `.g` in "dgl" format.
for mol in generated_molecules:
    homolumo_gap, lumo, homo = xtb_simulation.compute_true_reward(mol.g, "dgl", "homolumo")
    print(f"HOMO-LUMO gap: {homolumo_gap:.6f} eV\nLUMO: {lumo} eV\nHOMO: {homo} eV")

#### Compare 1000 molecules with respect to their HOMO-LUMO energies

In [None]:
# Reference distribution: (gap, lumo, homo) for every molecule in the QM9 batch.
qm9_results = [xtb_simulation.compute_true_reward(mol, "pyg", "homolumo") for mol in qm9_data]
qm9_gap = [gap for gap, _, _ in qm9_results]
qm9_homo_lumo = [(lumo, homo) for _, lumo, homo in qm9_results]

In [None]:
# Persist the QM9 reference results as .npy files in the working directory.
qm9_gap = np.asarray(qm9_gap)
qm9_homo_lumo = np.asarray(qm9_homo_lumo)
for filename, values in (("qm9_gap.npy", qm9_gap), ("qm9_homo_lumo.npy", qm9_homo_lumo)):
    np.save(filename, values)

In [None]:
# Generated distribution: 100 sampling rounds, each scored with the same xtb_simulation
# HOMO-LUMO routine used for the QM9 reference.
gen_gap = []
gen_homo_lumo = []
for round_no in range(1, 101):
    print(f"Round {round_no} - Sampling {args['n_molecules']} molecules...")
    round_molecules = model.sample_random_sizes(
        n_molecules=args['n_molecules'],
        n_timesteps=args['n_timesteps'],
        device=args['device'],
    )
    for sample in round_molecules:
        gap, lumo, homo = xtb_simulation.compute_true_reward(sample.g, "dgl", "homolumo")
        gen_gap.append(gap)
        gen_homo_lumo.append((lumo, homo))

In [None]:
# Persist the generated-molecule results as .npy files in the working directory.
gen_gap = np.asarray(gen_gap)
gen_homo_lumo = np.asarray(gen_homo_lumo)
for filename, values in (("gen_gap.npy", gen_gap), ("gen_homo_lumo.npy", gen_homo_lumo)):
    np.save(filename, values)

#### Differentiable Reward

In [None]:
from PAMNet.models import PAMNet_s, Config

In [None]:
# NOTE(review): PAMNet_s is built from the config only; no checkpoint is loaded in this cell,
# so unless PAMNet_s loads weights itself the reward model is randomly initialized.
# args['target'] is not passed either. Confirm before trusting its predictions.
config_kwargs = {key: args[key] for key in ('dataset', 'dim', 'n_layer', 'cutoff_l', 'cutoff_g')}
config = Config(**config_kwargs)
reward_model = PAMNet_s(config).to(args['device'])
reward_model.eval()

In [None]:
# Score each generated molecule with the reward model and take d(prediction)/d(positions).
targets = []
pos_grads = []
for sample in generated_molecules:
    data = sample.pyg_mol
    data.pos.requires_grad_()
    prediction = reward_model(data)
    # autograd.grad computes only the position gradient; backward() would also accumulate
    # .grad on every reward_model parameter across the loop. sum() equals the prediction
    # itself for a single-element output (which backward() required anyway).
    (pos_grad,) = torch.autograd.grad(prediction.sum(), data.pos)
    # Detach so stored predictions do not keep every molecule's autograd graph alive.
    targets.append(prediction.detach())
    pos_grads.append(pos_grad)

if not targets:
    raise RuntimeError("generated_molecules is empty; nothing to score.")
print(len(targets))
print(data.pos.shape)
print(pos_grad.shape)

#### Visualize the Molecules

In [None]:
import py3Dmol

In [None]:
# Convert RDKit Mol to PDB block
pdb_blocks = []
for mol in gen_rdkit_mols:
    pdb_blocks.append(Chem.MolToMolBlock(mol))

# Visualize using py3Dmol
viewer = py3Dmol.view(width=600, height=600, viewergrid=(3, 3))
viewer.addModel(pdb_blocks[0], "mol", viewer=(0, 0))
viewer.addModel(pdb_blocks[1], "mol", viewer=(0, 1))
viewer.addModel(pdb_blocks[2], "mol", viewer=(0, 2))
viewer.addModel(pdb_blocks[3], "mol", viewer=(1, 0))
viewer.addModel(pdb_blocks[4], "mol", viewer=(1, 1))
viewer.addModel(pdb_blocks[5], "mol", viewer=(1, 2))
viewer.addModel(pdb_blocks[6], "mol", viewer=(2, 0))
viewer.addModel(pdb_blocks[7], "mol", viewer=(2, 1))
viewer.addModel(pdb_blocks[8], "mol", viewer=(2, 2))
viewer.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
viewer.zoomTo()
viewer.show()