# MiSaTo-Dataset: a tutorial


In this notebook, we show how our QM and MD datasets are stored in HDF5 (h5) files and how the data can be loaded for use by a deep learning model.

We start by importing the required packages and setting the file paths.

In [1]:
import os
import sys
import h5py
 
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

sys.path.insert(0, '/p/project/hai_drug_qm/MiSaTo-dataset/src/data/')
sys.path.insert(0, '/p/project/hai_drug_qm/MiSaTo-dataset/src/data/components/')
sys.path.insert(0, '/p/project/hai_drug_qm/MiSaTo-dataset/data/QM/')

from datasets import MolDataset, ProtDataset
from transformQM import GNNTransformQM
from transformMD import GNNTransformMD
from qm_datamodule import QMDataModule
from md_datamodule import MDDataModule

In [2]:
h5_file = "/p/project/hai_drug_qm/MiSaTo-dataset/data/QM/h5_files/qm.hdf5"
norm_file = "/p/project/hai_drug_qm/MiSaTo-dataset/data/QM/h5_files/qm_norm_fold1.hdf5"
norm_txtfile = "/p/project/hai_drug_qm/MiSaTo/data/QM/splits/train_norm_fold1.txt"

## H5 file presentation

We open the QM H5 file and the H5 file that stores the statistics used to normalize the target values.

In [3]:
# Open both files read-only so the datasets cannot be modified by accident.
qm_H5File = h5py.File(h5_file, "r")
qm_normFile = h5py.File(norm_file, "r")

The molecule names can be accessed using the keys() method.

In [4]:
list(qm_H5File.keys())[:10]

['10gs',
 '11gs',
 '13gs',
 '16pk',
 '184l',
 '185l',
 '186l',
 '187l',
 '188l',
 '1a07']

Target values can be accessed by indexing with the molecule name, then mol_properties, and finally the name of the target value we want to read:

In [4]:
qm_H5File["5gmm"]["mol_properties"]["Electron_Affinity"][()]

7.7383
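The access pattern above follows the usual h5py hierarchy of groups and datasets. As a minimal sketch, the snippet below builds a tiny in-memory file with the same layout (molecule ID, then mol_properties, then the target name) and reads it back. The file name toy_qm.hdf5, the molecule ID mol_a and the stored value are placeholders chosen for illustration; they are not taken from the MiSaTo files.

```python
import h5py

# Build a toy in-memory HDF5 file that mimics the layout described above.
# File name, molecule ID and value are illustrative placeholders.
with h5py.File("toy_qm.hdf5", "w", driver="core", backing_store=False) as toy:
    # Intermediate groups are created automatically from the path.
    props = toy.create_group("mol_a/mol_properties")
    props.create_dataset("Electron_Affinity", data=1.5)

    # Top-level keys are the molecule identifiers.
    molecule_ids = list(toy.keys())

    # Indexing a scalar dataset with [()] reads its value into memory.
    electron_affinity = float(
        toy["mol_a"]["mol_properties"]["Electron_Affinity"][()]
    )

    # visit() walks the hierarchy and passes every object path to the callback.
    paths = []
    toy.visit(paths.append)

print(molecule_ids)
print(electron_affinity)
print(paths)
```

Running visit() on an unfamiliar file is a quick way to discover which groups and target names are available before indexing into them.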

We can access the mean and standard deviation of each target value in the same way.
We first specify the set, then the target value, and finally either mean or std.

In [5]:
print(qm_normFile["train"]["Electron_Affinity"]["mean"][()])
print(qm_normFile["train"]["Electron_Affinity"]["std"][()])

6.472403422183977
3.811168623042665
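To make the role of these statistics concrete, here is a small worked example. It assumes that normalization means the usual z-score, (value - mean) / std, which is not confirmed by the cells above; the helper names standardize and destandardize are hypothetical. The numbers are the ones printed in this notebook.

```python
# Values printed in the cells above.
train_mean = 6.472403422183977
train_std = 3.811168623042665
raw_value = 7.7383  # Electron_Affinity read for "5gmm"


def standardize(value, mean, std):
    """Map a raw target to a zero-mean, unit-variance scale (assumed z-score)."""
    return (value - mean) / std


def destandardize(z, mean, std):
    """Invert standardize() to recover the original units."""
    return z * std + mean


z = standardize(raw_value, train_mean, train_std)
recovered = destandardize(z, train_mean, train_std)
print(f"standardized: {z:.4f}, recovered: {recovered:.4f}")
```

The inverse step matters at inference time: a model trained on standardized targets predicts on the standardized scale, so its outputs need to be mapped back with the same training mean and std.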


## Datasets and dataloaders

### PyTorch

The QM and MD datasets are wrapped in PyTorch Dataset classes named MolDataset and ProtDataset, respectively.
The parameters taken by the two classes, as well as their types, can be inspected with help().

In [6]:
help(MolDataset)

Help on class MolDataset in module datasets:

class MolDataset(torch.utils.data.dataset.Dataset)
 |  MolDataset(data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |  
 |  Load the QM dataset.
 |  
 |  Method resolution order:
 |      MolDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |      Args:
 |          data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          target_norm_file (str): H5 file path where training mean and std are stored.  
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          isTrain (bool, optional): Flag to standardize the target values (only used for train set).

In [7]:
help(ProtDataset)

Help on class ProtDataset in module datasets:

class ProtDataset(torch.utils.data.dataset.Dataset)
 |  ProtDataset(md_data_file, idx_file, transform=None, post_transform=None)
 |  
 |  Load the MD dataset
 |  
 |  Method resolution order:
 |      ProtDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, md_data_file, idx_file, transform=None, post_transform=None)
 |      Args:
 |          md_data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          post_transform (PyTorch Geometric, optional): data augmentation. Defaults to None.
 |  
 |  __len__(self) -> int
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes de

We can load the data by instantiating MolDataset with the QM H5 file, the text file that lists the molecules used for training, and the norm file used to normalize the target values.

Without a transform, MolDataset returns a dictionary that contains the elements and their coordinates. We use the GNNTransformQM class to convert this dictionary into a graph that can be used by a GNN. The post_transform parameter is an additional transformation applied for data augmentation.

In [3]:
transform = T.RandomTranslate(0.25)
batch_size = 128
num_workers = 48

data_train = MolDataset(h5_file, norm_txtfile, target_norm_file=norm_file, transform=GNNTransformQM(), post_transform=transform)
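The order in which transform and post_transform are applied can be sketched with a plain map-style dataset. This is a simplified illustration under stated assumptions, not the actual MolDataset code: ToyMolDataset, to_sample and shift_x are hypothetical names, to_sample stands in for GNNTransformQM, and shift_x is a deterministic stand-in for T.RandomTranslate.

```python
class ToyMolDataset:
    """Map-style dataset sketch; NOT the real MolDataset implementation."""

    def __init__(self, records, transform=None, post_transform=None):
        self.records = records                # raw per-molecule dictionaries
        self.transform = transform            # dict -> model-ready sample
        self.post_transform = post_transform  # optional augmentation step

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        item = self.records[index]
        if self.transform is not None:
            item = self.transform(item)
        if self.post_transform is not None:
            item = self.post_transform(item)
        return item


def to_sample(raw):
    """Hypothetical stand-in for GNNTransformQM: copy the fields we need."""
    return {
        "z": list(raw["elements"]),
        "pos": [list(p) for p in raw["coordinates"]],
    }


def shift_x(sample, offset=0.25):
    """Hypothetical deterministic stand-in for T.RandomTranslate."""
    sample["pos"] = [[x + offset, y, z] for x, y, z in sample["pos"]]
    return sample


# One made-up water-like molecule: atomic numbers and 3D coordinates.
records = [{
    "elements": [8, 1, 1],
    "coordinates": [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
}]
dataset = ToyMolDataset(records, transform=to_sample, post_transform=shift_x)
sample = dataset[0]
print(sample)
```

Because the augmentation runs inside __getitem__, a random post_transform produces a different perturbation every time a sample is drawn, while the stored records stay unchanged.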


Finally, we can load our data using the PyTorch Geometric DataLoader imported above.

In [4]:
train_loader = DataLoader(data_train, batch_size, shuffle=True, num_workers=0)

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[7679, 25], edge_index=[2, 147188], edge_attr=[147188, 1], y=[256], pos=[7679, 3], id=[128], batch=[7679], ptr=[129])
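The shapes in this output follow the PyTorch Geometric batching convention: all graphs of a mini-batch are merged into one large graph, batch maps each node to its graph, and ptr stores cumulative node offsets (hence 129 entries for 128 graphs). The y field has 256 entries, which is consistent with two targets per molecule. The plain-Python sketch below reproduces this bookkeeping for three made-up graphs; the node counts and target values are hypothetical.

```python
from itertools import accumulate

# Hypothetical node counts for three small graphs in one mini-batch.
nodes_per_graph = [3, 2, 4]

# ptr holds cumulative node offsets, so it has len(graphs) + 1 entries.
ptr = [0] + list(accumulate(nodes_per_graph))

# batch assigns every node the index of the graph it belongs to.
batch = [g for g, n in enumerate(nodes_per_graph) for _ in range(n)]

# The nodes of graph i occupy the slice ptr[i]:ptr[i + 1].
node_ids_graph_1 = list(range(ptr[1], ptr[2]))

# With two targets per graph, a flat y of length 2 * num_graphs can be
# regrouped into one row per graph.
y_flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
y_per_graph = [y_flat[i:i + 2] for i in range(0, len(y_flat), 2)]

print(ptr)
print(batch)
print(node_ids_graph_1)
print(y_per_graph)
```

On real tensors, the same regrouping of the targets would typically be a reshape of y to one row per graph before computing a per-target loss.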


### PyTorch Lightning

QMDataModule is a class that inherits from LightningDataModule. It instantiates MolDataset for the training, validation, and test sets and returns a dataloader for each set.

We start by instantiating the QMDataModule.

In [4]:
files_root =  "/p/project/hai_drug_qm/MiSaTo-dataset/data/QM/"
fold = 1
qmdata = QMDataModule(files_root, fold)


Then, we call the setup function to instantiate MolDataset for the training, validation, and test sets.

In [5]:
qmdata.setup()

Data(x=[37, 25], edge_index=[2, 648], edge_attr=[648, 1], y=[2], pos=[37, 3], id='6f23')


Finally, we can return a dataloader for each set.

In [6]:
train_loader = qmdata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break
    

DataBatch(x=[7847, 25], edge_index=[2, 150976], edge_attr=[150976, 1], y=[256], pos=[7847, 3], id=[128], batch=[7847], ptr=[129])
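The three steps of this section (construct, setup, request a dataloader) can be illustrated with a plain-Python stand-in. This is only a sketch of the pattern: ToyDataModule is a hypothetical class that does not inherit from LightningDataModule, its datasets are simple lists of made-up IDs, and its loaders are lists of mini-batches rather than real DataLoader objects.

```python
class ToyDataModule:
    """Plain-Python stand-in for the data module pattern (hypothetical)."""

    def __init__(self, split_ids, batch_size=2):
        self.split_ids = split_ids    # mapping: split name -> list of IDs
        self.batch_size = batch_size
        self.datasets = {}            # filled by setup()

    def setup(self, stage=None):
        # Build one dataset per split; here a dataset is just a list of IDs.
        self.datasets = {name: list(ids) for name, ids in self.split_ids.items()}

    def _loader(self, name):
        data = self.datasets[name]
        # Return consecutive mini-batches of at most batch_size items.
        return [
            data[i:i + self.batch_size]
            for i in range(0, len(data), self.batch_size)
        ]

    def train_dataloader(self):
        return self._loader("train")

    def val_dataloader(self):
        return self._loader("val")

    def test_dataloader(self):
        return self._loader("test")


# Made-up split contents for illustration only.
splits = {
    "train": ["mol_a", "mol_b", "mol_c"],
    "val": ["mol_d"],
    "test": ["mol_e"],
}
toy_data = ToyDataModule(splits, batch_size=2)
toy_data.setup()
first_batch = next(iter(toy_data.train_dataloader()))
print(first_batch)
```

Keeping dataset construction in setup() and loader creation in separate methods means the training code only needs the data module object, not the individual file paths for each split.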


## MD dataset

The same steps can be used to load the MD dataset: first with ProtDataset and a DataLoader, then with the MDDataModule.

In [8]:
mdh5_file = '/p/project/hai_drug_qm/MiSaTo-dataset/data/MD/h5_files/MD_dataset_soft_hard_noH.hdf5'
train_idx = "/p/project/hai_drug_qm/MiSaTo-dataset/data/MD/splits/train_soft_hard.txt"

post_transform = T.RandomTranslate(0.1)

train_dataset = ProtDataset(mdh5_file, train_idx, transform=GNNTransformMD(), post_transform=post_transform)

In [9]:
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[405300, 11], edge_index=[2, 6424088], edge_attr=[6424088], y=[405300], pos=[405300, 3], ids=[128], batch=[405300], ptr=[129])


In [10]:
files_root =  "/p/project/hai_drug_qm/MiSaTo-dataset/data/MD"

mddata = MDDataModule(files_root)

In [11]:
mddata.setup()

In [12]:
train_loader = mddata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[54188, 11], edge_index=[2, 865460], edge_attr=[865460], y=[54188], pos=[54188, 3], ids=[16], batch=[54188], ptr=[17])
