In [1]:
import sys
sys.path.append("..")

import torch

from models import MACEBaryModel
from torch_geometric.data import Batch

import config as cfg

In [2]:
# Hyperparameters handed to MACEBaryModel as a single dictionary.
# The input dimension comes from the project config; everything else is fixed here.
mace_kwargs = dict(
    r_max=5,
    num_bessel=10,
    num_polynomial_cutoff=6,
    max_ell=2,
    correlation=3,
    num_layers=5,
    emb_dim=64,
    hidden_irreps=None,
    mlp_dim=256,
    in_dim=cfg.NUM_POSSIBLE_ATOMS,
    out_dim=1,
    aggr="sum",
    pool="sum",
    batch_norm=True,
    residual=True,
    equivariant_pred=True,
    as_featurizer=True,
)

# Instantiate the model from the keyword dictionary (the model takes the dict itself).
mace_bary = MACEBaryModel(mace_kwargs)

In [3]:
import sys
sys.path.append('../')

import datamol as dm
import logging
import numpy as np
import os
import pandas as pd 
import torch

from graphium.features import featurizer as gff
from datamol.descriptors.compute import _DEFAULT_PROPERTIES_FN
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform, Compose, NormalizeScale
from torch_geometric.utils import remove_self_loops
from tqdm import tqdm
from transformers import pipeline

import config as cfg

In [4]:

class CompleteGraph(BaseTransform):
    """
    Replace a sample's connectivity with that of a complete graph.

    Every ordered pair of nodes becomes an edge and self loops are removed
    afterwards. When the sample carries edge attributes, the attributes of the
    original edges are kept at their (source, target) slot and all newly created
    edges receive zero-valued attributes.
    """
    def __call__(self, data):
        target_device = data.edge_index.device
        n_nodes = data.num_nodes

        node_ids = torch.arange(n_nodes, dtype=torch.long, device=target_device)

        # Source ids are repeated block-wise (0,0,..,1,1,..), target ids are tiled
        # (0,1,..,0,1,..), so together they enumerate all n_nodes * n_nodes pairs.
        sources = node_ids.view(-1, 1).repeat(1, n_nodes).view(-1)
        targets = node_ids.repeat(n_nodes)
        dense_index = torch.stack([sources, targets], dim=0)

        dense_attr = None
        if data.edge_attr is not None:
            # Flat position of each original edge inside the dense pair enumeration.
            flat_pos = data.edge_index[0] * n_nodes + data.edge_index[1]
            attr_shape = list(data.edge_attr.size())
            attr_shape[0] = n_nodes * n_nodes
            dense_attr = data.edge_attr.new_zeros(attr_shape)
            dense_attr[flat_pos] = data.edge_attr

        dense_index, dense_attr = remove_self_loops(dense_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_index

        return data

In [5]:
class ConcatenateGlobal(BaseTransform):
    """Merge the two global feature vectors of a sample into one ``u`` attribute.

    The transform reads ``data.u_chem`` and ``data.u_dm``, joins them along
    dimension 0, reshapes the result to a single row and stores it as ``data.u``.
    The two source attributes are removed from the sample.
    """

    def __call__(self, data):
        global_parts = [data.u_chem, data.u_dm]
        merged = torch.concatenate(global_parts, dim=0)
        merged = merged.view(1, -1)  # one row per graph

        # The separate vectors are no longer needed once merged.
        del data.u_dm
        del data.u_chem
        data.u = merged

        return data

In [6]:
# TODO: Process every SMILES in the dataframe instead of only the first 10, i.e. in MolDataset.process change
# for idx in tqdm(range(10), desc="Processing molecules", total=len(df_data)): -> for idx in tqdm(range(len(df_data)), desc="Processing molecules", total=len(df_data)):

# Module-level logger, named after the current module.
logger = logging.getLogger(__name__)
# Configure the root logger so that INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)

class MolDataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs built from a CSV of CXSMILES strings.

    For every row of the first raw CSV file, ``process`` builds one ``Data``
    object holding:

    * ``x``: float atomic features, one row per atom,
    * ``atoms``: per-atom index derived from the one-hot atomic-number encoding,
    * ``edge_index`` / ``edge_attr``: bond connectivity and bond features,
    * ``pos``: conformer coordinates shaped (num_atoms, cfg.NUM_CONFORMERS, 3),
    * ``y``: the two pIC50 target columns,
    * ``u_chem`` / ``u_dm``: global features from ChemBERTa and from datamol
      descriptors (the latter min-max normalised over the whole dataset).

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        pre_transform: Transform applied to every sample before saving.
        transform: Transform forwarded to ``InMemoryDataset``.
        **kwargs: Extra keyword arguments forwarded to ``InMemoryDataset``.
    """

    def __init__(
            self,
            root: str,
            pre_transform: callable = Compose([NormalizeScale(), ConcatenateGlobal()]),
            transform = None,
            **kwargs
        ):
        # NOTE(review): ``pos`` is stored as (num_atoms, num_conformers, 3). PyG's
        # NormalizeScale centres with ``pos.mean(dim=-2)``, which for a 3-D ``pos``
        # would presumably average over the conformer axis rather than over atoms.
        # Confirm this is the intended behaviour of the default pre_transform.
        super().__init__(root, pre_transform=pre_transform, transform=transform, **kwargs)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> list[str]:
        """Names of all entries in the raw directory, in sorted order.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        choice of ``raw_file_names[0]`` in ``process`` deterministic.
        """
        return sorted(os.listdir(self.raw_dir))

    @property
    def processed_file_names(self) -> list[str]:
        """Processed file names: each raw file name with ``.csv`` replaced by ``.pt``.

        Built from the same sorted listing as ``raw_file_names`` so that index 0
        of both lists refers to the same file.
        """
        raw_files = sorted(os.listdir(self.raw_dir))
        return [raw_file.replace(".csv", ".pt") for raw_file in raw_files]

    def normalize_mol_descriptors(self, data_list, desc_mins, desc_maxs):
        """Min-max normalise ``u_dm`` of every sample in place.

        Args:
            data_list: Samples whose ``u_dm`` attribute is rewritten.
            desc_mins: Per-descriptor minimum over the dataset.
            desc_maxs: Per-descriptor maximum over the dataset.

        Non-finite results (e.g. from a zero range where max equals min, or from
        NaN descriptors) are replaced by 0.
        """
        # Normalize the molecular descriptors to the range [0, 1]
        for data in data_list:
            data.u_dm = (data.u_dm - desc_mins) / (desc_maxs - desc_mins)
            data.u_dm = torch.nan_to_num(data.u_dm, nan=0.0, posinf=0.0, neginf=0.0)

    def process(self) -> None:
        """Featurise every molecule of the first raw CSV file and save the result.

        Steps per molecule: parse and standardise the CXSMILES, compute ChemBERTa
        and datamol global features, atomic and bond features, conformer
        coordinates and targets. After the loop the datamol descriptors are
        min-max normalised with dataset-wide statistics, ``pre_transform`` is
        applied, and the samples are written to ``processed_paths[0]``.
        """
        raw_files = self.raw_file_names

        # Get ChemBERTa model
        pipe = pipeline("feature-extraction", model="seyonec/ChemBERTa-zinc-base-v1")

        # Read the raw data
        df_data = pd.read_csv(os.path.join(self.raw_dir, raw_files[0]), header=0)
        smiles_column = df_data['CXSMILES']
        target_columns = df_data[["pIC50 (MERS-CoV Mpro)", "pIC50 (SARS-CoV-2 Mpro)"]]

        # Create a list to store the data objects
        data_list = []

        # Running per-descriptor min / max. Start from +/-inf so the statistics are
        # correct for descriptors of any magnitude (a finite sentinel such as
        # 10_000 is wrong as soon as a descriptor lies outside that range).
        num_descriptors = len(_DEFAULT_PROPERTIES_FN)
        desc_mins = torch.full((num_descriptors,), float("inf"), dtype=torch.float)
        desc_maxs = torch.full((num_descriptors,), float("-inf"), dtype=torch.float)

        # Allowable atomic node and edge features (loop-invariant, defined once)
        atomic_features = [ "atomic-number", "mass", "weight","valence","total-valence",
                            "implicit-valence","hybridization","ring", "in-ring","min-ring",
                            "max-ring","num-ring","degree","radical-electron","formal-charge",
                            "vdw-radius","covalent-radius","electronegativity","ionization",
                            "first-ionization","metal","single-bond","aromatic-bond",
                            "double-bond","triple-bond","is-carbon","group","period" ]

        # Other available edge features: "stereo", "conformer-bond-length"
        # (might cause problems with complex molecules)
        edge_features = ["bond-type-onehot", "in-ring", "conjugated", "estimated-bond-length"]

        # Stop datamol / RDKit calls from logging a lot of info; once is enough.
        dm.disable_rdkit_log()

        # Iterate over every row of the dataframe so the loop matches the
        # progress-bar total (previously only the first 10 rows were processed).
        for idx in tqdm(range(len(df_data)), desc="Processing molecules", total=len(df_data)):
            smile = smiles_column.iloc[idx]
            mol = dm.to_mol(smile, add_hs=True)
            mol = dm.sanitize_mol(mol)
            mol = dm.fix_mol(mol)
            mol = dm.standardize_mol(
                mol,
                disconnect_metals=True,
                normalize=True,
                reionize=True,
                uncharge=False,
                stereo=True,
            )

            # Get CHEMBERTa features: first token embedding of the first sequence
            # (presumably the start-of-sequence token -- confirm against the model).
            chem_features = pipe(smile)
            u_chem = torch.tensor(chem_features[0][0], dtype=torch.float)

            # Get Datamol molecular features
            descriptors = dm.descriptors.compute_many_descriptors(mol)
            u_dm = torch.tensor(list(descriptors.values()), dtype=torch.float)

            # Update the statistics of the molecular descriptors
            # (NaN entries compare False and therefore never enter the statistics)
            desc_maxs = torch.where(u_dm > desc_maxs, u_dm, desc_maxs)
            desc_mins = torch.where(u_dm < desc_mins, u_dm, desc_mins)

            # Get float atomic features
            values_atomic_feat = gff.get_mol_atomic_features_float(mol, atomic_features, mask_nan='warn').values()
            x_array = np.column_stack(list(values_atomic_feat))
            x = torch.tensor(x_array, dtype=torch.float)

            # Get one-hot atomic numbers
            atoms_onehot = gff.get_mol_atomic_features_onehot(mol, ["atomic-number"]).values()
            atoms_onehot = np.column_stack(list(atoms_onehot))
            atoms_onehot = torch.tensor(atoms_onehot, dtype=torch.float)
            # Transform onehot to indices of possible atoms
            atoms = torch.where(atoms_onehot > 0)[1]

            # Generate conformers; on any failure fall back to all-zero coordinates
            try:
                mol_confs = dm.conformers.generate(mol, n_confs=cfg.NUM_CONFORMERS)
                list_xyz = [mol_confs.GetConformer(i).GetPositions() for i in range(cfg.NUM_CONFORMERS)]
                # Stack along axis 1 -> (num_atoms, NUM_CONFORMERS, 3)
                pos_array = np.stack(list_xyz, axis=1)
            except Exception as e:
                logger.warning(f"Conformer generation failed for {smile}: {e}")
                logger.info("Setting atom positions to all zeros")
                pos_array = np.zeros((x.shape[0], cfg.NUM_CONFORMERS, 3))
            pos = torch.tensor(pos_array, dtype=torch.float)

            # Get edge (bond) features
            edge_dict = gff.get_mol_edge_features(mol, edge_features, mask_nan='warn')
            edge_list = list(edge_dict.values())
            edge_attr = np.column_stack(edge_list)
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)

            # Repeat edge_attr twice to account for both directions of the edges
            edge_attr = edge_attr.repeat_interleave(2, dim=0)

            # Get adjacency matrix
            adj = gff.mol_to_adjacency_matrix(mol)
            edge_index = torch.stack([torch.tensor(adj.coords[0], dtype=torch.int64), torch.tensor(adj.coords[1], dtype=torch.int64)], dim=0)

            # Get the target values
            df_y = target_columns.iloc[idx]
            y = torch.tensor(np.array(df_y), dtype=torch.float)

            # Get a PyG data object
            data = Data(u_chem=u_chem, u_dm=u_dm, edge_attr=edge_attr, pos=pos, x=x, y=y, atoms=atoms, edge_index=edge_index)

            # Append the data object to the list
            data_list.append(data)

        # Normalize molecular descriptors
        self.normalize_mol_descriptors(data_list, desc_mins, desc_maxs)

        # Apply the pre_transform if provided
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Save the processed data
        self.save(data_list, self.processed_paths[0])

In [7]:
# Build (or load the cached) training dataset from the configured directory.
train_dataset = MolDataset(root=cfg.TRAIN_DIR)

# Shuffled mini-batch loader; batch size and worker count come from the config.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    batch_size=cfg.BATCH_SIZE,
)

In [8]:
# NOTE(review): presumably set to stop the HuggingFace tokenizers library from
# using its own parallelism once DataLoader worker processes start -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Pull one batch from the training loader and display its summary.
batch = next(iter(train_loader))
batch

DataBatch(x=[23, 24], edge_index=[2, 52], edge_attr=[52, 8], y=[2], pos=[23, 10, 3], atoms=[23], u=[1, 790], batch=[23], ptr=[2])

In [9]:
# Move the batch to the configured device, then show where its node features live.
batch = batch.to(cfg.DEVICE)
batch.x.device

device(type='cuda', index=0)

In [10]:
# Show the device of the model's first parameter.
next(mace_bary.parameters()).device

device(type='cpu')

In [11]:
# Move the model to the configured device.
mace_bary = mace_bary.to(cfg.DEVICE)

In [12]:
# Run the model on the batch.
output = mace_bary(batch)

In [13]:
# Inspect the shape of the model output.
output.shape

torch.Size([69, 576])

In [None]:
from scipy.sparse.csgraph import shortest_path


[Data(x=[23, 24], edge_index=[2, 52], edge_attr=[52, 8], y=[2], pos=[23, 3, 3], atoms=[23], u=[1, 790])]

In [16]:
# NOTE(review): torch.load unpickles the file, so only load files from a trusted
# source. If 'cfm_log.pt' holds nothing but tensors and plain containers, consider
# passing weights_only=True -- verify against how the file is written.
debug_dict = torch.load('cfm_log.pt')

# Unpack the logged entries into individual names.
N, Ys, Cs, ps = (debug_dict[key] for key in ("N", "Ys", "Cs", "ps"))

  debug_dict = torch.load('cfm_log.pt')


In [19]:
# Check whether each of the first 10 entries of Cs is element-wise equal to the first one.
for i in range(10):
    matches_first = torch.all(Cs[i] == Cs[0])
    print(matches_first)

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
