In [437]:
import torch

from torch_geometric.data import Dataset, InMemoryDataset, download_url, Data
from torch_geometric.transforms import BaseTransform, Compose
from torch_geometric.utils import remove_self_loops


In [438]:
class Scale(BaseTransform):
    
    def __call__(self, data):
        
        for key in data.keys():
            
            if key != 'y':
                offset = (data[key] - torch.min(data[key]))
                max =  torch.max(data[key] - torch.min(data[key]))
                scaled_attr = offset / max 
                data[key] = scaled_attr
        
        return data

#scale_transform = Scale()

#data_new = scale_transform(data)

#torch.max(data_new.u_dm), torch.min(data_new.u_dm)

In [439]:

class CompleteGraph(BaseTransform):
    """
    This transform adds all pairwise edges into the edge index per data sample,
    then removes self loops, i.e. it builds a fully connected or complete graph
    """
    def __call__(self, data):
        
        #device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long)#, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long)#, device=device)

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

In [440]:
class ConcatenateGlobal(BaseTransform):
    """_summary_

    Args:
        BaseTransform (_type_): _description_
    """
    
    def __call__(self, data):
        u = torch.concatenate([torch.flatten(data.u_chem), torch.flatten(data.u_dm)])
        del data.u_dm
        del data.u_chem
        data.u = u
        
        return data
        
        



In [517]:
import pandas as pd 
import logging
import os
from transformers import pipeline
import datamol as dm
from graphium.features import featurizer as gff
#logging.basicConfig

class Dataset(InMemoryDataset):
    """_summary_

    Args:
        InMemoryDataset (_type_): _description_
    """
    def __init__(self, path_data: str, type_data: str, 
                 pre_transform: callable = Compose([Scale(), CompleteGraph(), ConcatenateGlobal()]), 
                 transform = None,pre_filter=None, **kwargs):
        #self.size = size
        self.type_data = type_data
        self.path_data= path_data
        super().__init__(pre_transform=pre_transform, transform=transform, pre_filter=None, **kwargs)
    
    @property
    def raw_file_names(self) -> list[str]:
        raw_files = os.listdir(self.path_data + "/" + self.type_data + "/raw" )
        return raw_files
    
    @property
    def processed_file_names(self) -> list[str]:
        #raw_files = os.listdir(self.path_data + '/' + self.type_data + '/raw' )
        raw_files = self.raw_file_names
        total_rows = 0
        
        for file in raw_files: 
            df_data = pd.read_csv(self.path_data + "/" + self.type_data + "/raw/" +  file)
            total_rows =+ len(df_data)
        
        return  [f'data_{row}' for row in range(total_rows)]

    def process(self) -> None:
        #raw_files = os.listdir(self.path_data + '/' + self.type_data + '/raw' )
        raw_files = self.raw_file_names
        pipe = pipeline("feature-extraction", model="seyonec/ChemBERTa-zinc-base-v1")
        
        for file in raw_files:
            df_data = pd.read_csv(self.path_data + "/" + self.type_data + "/raw/" +  file)
            

            for rows in range(1):#(len(df_data)):
                dm.disable_rdkit_log() # stop logging a lot of info for datamol methods calls
                
                smile = df_data['CXSMILES'][rows]
                mol = dm.to_mol(smile, add_hs=True)
                mol = dm.sanitize_mol(mol)
                mol = dm.fix_mol(mol)
                mol = dm.standardize_mol(
                                        mol,
                                        disconnect_metals=True,
                                        normalize=True,
                                        reionize=True,
                                        uncharge=False,
                                        stereo=True,
                                        )
                                
                u1 = pipe(df_data['CXSMILES'][rows])
                u_chem = torch.tensor(u1[0][0], dtype=torch.float)
                
                
                descriptors = dm.descriptors.compute_many_descriptors(mol)
                u2 = list(descriptors.values())
                u_dm = torch.tensor(u2, dtype=torch.float)
                
                atomic_features = [ "atomic-number", "mass", "weight","valence","total-valence",
                                    "implicit-valence","hybridization","ring", "in-ring","min-ring",
                                    "max-ring","num-ring","degree","radical-electron","formal-charge",
                                    "vdw-radius","covalent-radius","electronegativity","ionization",
                                    "first-ionization","melting-point","metal","single-bond","aromatic-bond",
                                    "double-bond","triple-bond","is-carbon","group","period" ]
                
                values_atomic_feat = gff.get_mol_atomic_features_float(mol, atomic_features, mask_nan='warn').values()
                x_array = np.column_stack(list(values_atomic_feat))
                x = torch.tensor(x_array, dtype=torch.float)
                
                #refine selection of arguments for dm.confomers.generate()
                n_confs = 8
                mol_confs = dm.conformers.generate(mol, n_confs=n_confs)
                list_xyz = [mol_confs.GetConformer(i).GetPositions() for i in range(n_confs)]
                pos_array = np.stack(list_xyz, axis=1)
                pos = torch.tensor(pos_array, dtype=torch.float)
                
                edge_features = ["bond-type-float", "in-ring", "conjugated", "estimated-bond-length"]
                # additional edge features:"bond-type-onehot", "stereo",conformer-bond-length" (might cause problems with complex molecules)
                edge_dict = gff.get_mol_edge_features(mol, edge_features, mask_nan='warn')          
                edge_list = list(edge_dict.values())
                edge = np.column_stack(edge_list)
                edge_feat = torch.tensor(edge, dtype=torch.float)
                
                df_y = df_data[["pIC50 (MERS-CoV Mpro)", "pIC50 (SARS-CoV-2 Mpro)"]]
                y = torch.tensor(np.array(df_y), dtype=torch.float)
                
                data = Data(u_chem=u_chem, u_dm=u_dm, edge_feat=edge_feat, pos=pos, x=x, y=y)
                
                
                """ # Scale (all), Concatenate(only global features u) and complete graph
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                
                #Save the processed data
                path_processed = self.path_data + "/" + self.type_data + "/" + "processed" 
                
                if os.path.isdir(path_processed) != False:
                    os.mkdir(path_processed)
                
                torch.save(data, os.path.join(path_processed, f"data_{rows}.pt"))
                 """
            
            


In [518]:

data = Dataset('/home/slav/Documents/github/repo/DSR_Project/polaris-ligand-potency/data', 'train' ).process()




Processing...
Device set to use cuda:0
Done!
Device set to use cuda:0


In [502]:
data


Data(x=[22, 25], y=[1031, 2], pos=[22, 8, 3], edge_feat=[24, 4], edge_index=[2, 462], u=[790])

In [311]:
print( 'x:', torch.max(data.x), 'u_dm:', torch.max(data.u_dm),  'pos:', torch.max(data.pos), 'u_chem:', torch.max(data.u_chem))

x: tensor(6.) u_dm: tensor(299.1270) pos: tensor(4.2711) u_chem: tensor(3.9375)


In [496]:
transform = Compose([Scale(), ConcatenateGlobal(), CompleteGraph()])

data_new = transform(data)

data_new.keys(), data.keys()
#torch.max(data_new.u_dm), torch.min(data_new.u_dm)

(['x', 'edge_feat', 'edge_index', 'y', 'pos', 'u'],
 ['x', 'edge_feat', 'u_dm', 'y', 'pos', 'u_chem'])

In [468]:
os.path.isdir("./DSRProject/polaris-ligand-potency")

False

In [None]:
py