# MiSaTo-Dataset: a tutorial

In this notebook, we show how our QM and MD datasets are stored in H5 files, and how the data can be loaded so that it can be used by a deep learning model.

We start by importing the required packages and setting up the file paths.

In [1]:
import h5py
 
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

from data.components.datasets import MolDataset, ProtDataset
from data.components.transformQM import GNNTransformQM
from data.components.transformMD import GNNTransformMD
from data.qm_datamodule import QMDataModule
from data.md_datamodule import MDDataModule

In [4]:
qmh5_file = "../data/QM/h5_files/tiny-qm.hdf5"
norm_file = "../data/QM/h5_files/qm_norm_fold1.hdf5"
norm_txtfile = "../data/QM/splits/train_norm_fold1.txt"

## H5 file structure

We open the QM H5 file and the H5 file used to normalize the target values.

In [5]:
# Open both files read-only ("r" is also the default mode in h5py >= 3.0)
qm_H5File = h5py.File(qmh5_file, "r")
qm_normFile = h5py.File(norm_file, "r")

The ligands can be accessed by their PDB ID. Below we list the keys of the file, one per molecule.

In [6]:
qm_H5File.keys()

<KeysViewHDF5 ['10gs', '11gs', '13gs', '16pk', '184l', '185l', '186l', '187l', '188l', '1a07', '1a08', '1a09', '1a0q', '1a0tA', '1a0tB', '1a1b', '1a1c', '1a1e', '1a28', '1a2c', '1a30', '1a37', '1a42', '1a46', '1a4g', '1a4h', '1a4k', '1a4m', '1a4q', '1a4r', '1a4w', '1a50', '1a52', '1a5g', '1a5h', '1a5v', '1a61', '1a69', '1a7c', '1a7t', '1a7x', '1a85', '1a86', '1a8i', '1a8t', '1a94', '1a99', '1a9m', '1a9q', '1a9u']>

The atomic coordinates of a molecule can be accessed as follows:

In [7]:
prop = qm_H5File["10gs"]["atom_properties"]["atom_properties_values"]

x = prop[:,0]
y = prop[:,1]
z = prop[:,2]
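The three columns can also be kept together as a single (N, 3) coordinate array, which is the layout most geometric deep-learning code expects. The sketch below uses a small synthetic NumPy array in place of prop and assumes, as the cell above does, that columns 0 to 2 of atom_properties_values hold the x, y and z coordinates; the extra column and all values are made up.

```python
import numpy as np

# Synthetic stand-in for qm_H5File["10gs"]["atom_properties"]["atom_properties_values"]
# (assumption: 4 atoms; columns 0-2 are x, y, z; the last column is another property)
prop = np.array([
    [0.0, 1.0, 2.0, 6.0],
    [1.5, 0.5, -1.0, 8.0],
    [-2.0, 3.0, 0.5, 1.0],
    [0.2, -0.7, 1.1, 1.0],
])

# Slice the first three columns once instead of extracting x, y, z separately
pos = prop[:, :3]
x, y, z = pos.T

# Geometric centre of the molecule (unweighted mean of the atom positions)
centroid = pos.mean(axis=0)
print(pos.shape, centroid)
```

With h5py the same slice can be applied directly to the dataset object, so only the selected columns are read from the file.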

Target values can be accessed by indexing with the molecule name, then mol_properties, and finally the name of the target value:

In [8]:
qm_H5File["10gs"]["mol_properties"]["Electron_Affinity"][()]

6.0974

The mean and standard deviation of each target value are stored in the normalization file.
We first specify the target value and then either mean or std.

In [9]:
qm_normFile.keys()

<KeysViewHDF5 ['Electron_Affinity', 'Electronegativity', 'Hardness', 'Ionization_Potential']>

In [10]:
print(qm_normFile["Electron_Affinity"]["mean"][()])
print(qm_normFile["Electron_Affinity"]["std"][()])

6.33265
18.636927
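A stored mean and standard deviation are exactly what a z-score standardization needs. Assuming MolDataset standardizes targets as (y - mean) / std (its docstring only says the isTrain flag is used to standardize the target values), the forward and inverse mappings can be sketched with the Electron_Affinity values printed above:

```python
# Values read from the QM H5 file and the fold-1 normalization file above
ea_raw = 6.0974
ea_mean = 6.33265
ea_std = 18.636927

def standardize(value, mean, std):
    """Z-score: express a target in units of training-set standard deviations."""
    return (value - mean) / std

def destandardize(value, mean, std):
    """Inverse mapping, e.g. to turn a model prediction back into the original units."""
    return value * std + mean

ea_norm = standardize(ea_raw, ea_mean, ea_std)
print(round(ea_norm, 5))                                   # → -0.01262
print(round(destandardize(ea_norm, ea_mean, ea_std), 4))   # → 6.0974
```

Predictions made on the standardized scale need the inverse mapping before they can be compared with the raw values stored in the QM file.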


## Datasets and dataloaders

### PyTorch

The QM and MD datasets are wrapped in PyTorch Dataset classes named MolDataset and ProtDataset, respectively.
The parameters taken by the two classes, as well as their types, can be displayed as follows.

In [11]:
help(MolDataset)

Help on class MolDataset in module data.components.datasets:

class MolDataset(torch.utils.data.dataset.Dataset)
 |  MolDataset(data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |  
 |  Load the QM dataset.
 |  
 |  Method resolution order:
 |      MolDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |      Args:
 |          data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          target_norm_file (str): H5 file path where training mean and std are stored.  
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          isTrain (bool, optional): Flag to standardize the target values (only used

In [12]:
help(ProtDataset)

Help on class ProtDataset in module data.components.datasets:

class ProtDataset(torch.utils.data.dataset.Dataset)
 |  ProtDataset(md_data_file, idx_file, transform=None, post_transform=None)
 |  
 |  Load the MD dataset
 |  
 |  Method resolution order:
 |      ProtDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, md_data_file, idx_file, transform=None, post_transform=None)
 |      Args:
 |          md_data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          post_transform (PyTorch Geometric, optional): data augmentation. Defaults to None.
 |  
 |  __len__(self) -> int
 |  
 |  ----------------------------------------------------------------------
 |  Data and oth

We can load the data by instantiating MolDataset and providing the QM H5 file, the text file that lists the molecules used for training, and the norm file used to normalize the target values.
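The split file only tells the dataset which entries of the H5 file to use. Assuming it lists one PDB ID per line (the docstring only says it contains PDB IDs for a specific split), the mapping from a dataset index to an H5 group can be sketched as below. The load_split helper and the in-memory key set are hypothetical stand-ins, not MolDataset's actual code.

```python
import os
import tempfile

def load_split(idx_file, available_keys):
    """Hypothetical helper: read one PDB ID per line and check it exists in the H5 file."""
    with open(idx_file) as fh:
        ids = [line.strip() for line in fh if line.strip()]  # skip blank lines
    missing = [i for i in ids if i not in available_keys]
    if missing:
        raise KeyError(f"IDs not found in H5 file: {missing}")
    return ids

# Stand-in for set(qm_H5File.keys())
h5_keys = {"10gs", "11gs", "13gs", "16pk"}

with tempfile.TemporaryDirectory() as tmp:
    split_path = os.path.join(tmp, "train.txt")
    with open(split_path, "w") as fh:
        fh.write("10gs\n13gs\n\n16pk\n")
    train_ids = load_split(split_path, h5_keys)

# Dataset index i would then map to the H5 group train_ids[i]
print(len(train_ids), train_ids[1])   # → 3 13gs
```

Checking the IDs up front makes a mismatch between the split file and the H5 file fail at construction time rather than in the middle of an epoch.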

Without any transform, the MolDataset class returns a dictionary that contains the elements and their coordinates. We use the GNNTransformQM class to convert this data into a graph that can be used by a GNN. The post_transform parameter is an additional transformation used for data augmentation.
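This notebook does not show how GNNTransformQM builds its graph. A common scheme for turning element types and coordinates into a graph is to connect every pair of atoms closer than a distance cutoff and store the distance as a one-dimensional edge feature. The NumPy sketch below implements that scheme on made-up coordinates with a toy cutoff; treat it as an illustration of the idea, not as GNNTransformQM's implementation.

```python
import numpy as np

def radius_graph_edges(pos, cutoff):
    """Directed edges between all atom pairs closer than `cutoff`.

    Returns edge_index with shape (2, E) and edge_attr with shape (E, 1) holding distances.
    """
    diff = pos[:, None, :] - pos[None, :, :]                 # (N, N, 3) pairwise displacements
    dist = np.linalg.norm(diff, axis=-1)                     # (N, N) pairwise distances
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)   # drop self-loops
    src, dst = np.nonzero(mask)
    edge_index = np.stack([src, dst])
    edge_attr = dist[src, dst][:, None]
    return edge_index, edge_attr

# Toy "molecule": three atoms on a line, 1.5 apart (made-up coordinates)
pos = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0]])
edge_index, edge_attr = radius_graph_edges(pos, cutoff=2.0)
print(edge_index.tolist())   # → [[0, 1, 1, 2], [1, 0, 2, 1]]
```

The dense N x N distance matrix is fine for small ligands; for large systems a neighbour-search structure would be preferable.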

In [13]:
train = "../data/QM/splits/train_tinyQM.txt"

transform = T.RandomTranslate(0.25)
batch_size = 128
num_workers = 48

data_train = MolDataset(qmh5_file, train, target_norm_file=norm_file, transform=GNNTransformQM(), post_transform=transform)
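T.RandomTranslate(0.25) is a PyTorch Geometric transform that, according to its documentation, shifts each node position independently by offsets sampled uniformly from the range -0.25 to 0.25 along every axis. The NumPy sketch below reproduces that idea without PyTorch Geometric; it is an illustration, not the library's implementation.

```python
import numpy as np

def random_translate(pos, translate, rng):
    """Jitter every node position independently by U(-translate, translate) per axis."""
    noise = rng.uniform(-translate, translate, size=pos.shape)
    return pos + noise

rng = np.random.default_rng(0)
pos = np.zeros((5, 3))                     # five atoms at the origin (toy data)
jittered = random_translate(pos, 0.25, rng)

# Every coordinate moved by at most 0.25, and atoms moved independently of each other
print(np.abs(jittered - pos).max() <= 0.25)   # → True
```

Assuming post_transform is applied each time a sample is fetched, the model sees slightly different coordinates at every epoch, which is what makes it useful as augmentation.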

Finally, we can load our data using the PyTorch Geometric DataLoader, which collates the individual graphs into batches.

In [14]:
train_loader = DataLoader(data_train, batch_size, shuffle=True, num_workers=0)

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[1602, 25], edge_index=[2, 30354], edge_attr=[30354, 1], y=[60], pos=[1602, 3], id=[30], batch=[1602], ptr=[31])
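The DataBatch shapes follow from how PyTorch Geometric collates graphs: node features are concatenated (1602 atoms across the 30 molecules here), each graph's edge_index is shifted by the number of nodes that precede it, batch records which graph each node belongs to, and ptr holds the cumulative node offsets (one more entry than there are graphs). The pure-Python sketch below reproduces that bookkeeping for two toy graphs; it mirrors the idea rather than PyTorch Geometric's own code.

```python
def collate(graphs):
    """Merge graphs given as (num_nodes, edge_list) into one disconnected batch graph."""
    edge_index = [[], []]
    batch, ptr = [], [0]
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        offset = ptr[-1]                       # nodes already placed in the batch
        for src, dst in edges:
            edge_index[0].append(src + offset)
            edge_index[1].append(dst + offset)
        batch.extend([graph_id] * num_nodes)   # graph membership of each node
        ptr.append(offset + num_nodes)         # cumulative node count
    return edge_index, batch, ptr

# Toy graphs: a 3-atom chain and a 2-atom pair (undirected edges stored in both directions)
g0 = (3, [(0, 1), (1, 0), (1, 2), (2, 1)])
g1 = (2, [(0, 1), (1, 0)])
edge_index, batch, ptr = collate([g0, g1])
print(edge_index)  # → [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
print(batch)       # → [0, 0, 0, 1, 1]
print(ptr)         # → [0, 3, 5]
```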


### PyTorch Lightning

QMDataModule is a class that inherits from LightningDataModule. It instantiates MolDataset for the training, validation and test sets and returns a dataloader for each set.
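Without relying on Lightning itself, the data-module pattern can be sketched as a plain class: the constructor only stores paths, setup() builds one dataset per split, and each *_dataloader() method returns the loader for one split. Joining files_root with the relative paths is an assumption based on the arguments passed to QMDataModule in the next cell; SketchDataModule and its dictionary placeholders for datasets are hypothetical.

```python
import os

class SketchDataModule:
    """Hypothetical stand-in for the LightningDataModule pattern used by QMDataModule."""

    def __init__(self, files_root, h5file, train, val, test, batch_size=128):
        # Only store configuration here; no file is opened yet
        self.h5_path = os.path.join(files_root, h5file)
        self.split_paths = {
            "train": os.path.join(files_root, train),
            "val": os.path.join(files_root, val),
            "test": os.path.join(files_root, test),
        }
        self.batch_size = batch_size
        self.datasets = {}

    def setup(self, stage=None):
        # A real module would build MolDataset(self.h5_path, path, ...) here
        for split, path in self.split_paths.items():
            self.datasets[split] = {"h5": self.h5_path, "ids": path}

    def train_dataloader(self):
        # A real module would return DataLoader(self.datasets["train"], self.batch_size, ...)
        return self.datasets["train"]

dm = SketchDataModule("../data/QM", "h5_files/tiny-qm.hdf5",
                      "splits/train_tinyQM.txt", "splits/val_tinyQM.txt", "splits/test_tinyQM.txt")
dm.setup()
print(dm.train_dataloader()["ids"])   # → ../data/QM/splits/train_tinyQM.txt on POSIX systems
```

Deferring dataset construction to setup() lets Lightning decide when, and in which process, the files are opened.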

We start by instantiating QMDataModule.

In [33]:
files_root =  "../data/QM"

qmh5file = "h5_files/tiny-qm.hdf5"

tr = "splits/train_tinyQM.txt"
v = "splits/val_tinyQM.txt"
te = "splits/test_tinyQM.txt"

qmdata = QMDataModule(files_root, h5file=qmh5file, train=tr, val=v, test=te, num_workers=0)

Then, we call the setup function to instantiate MolDataset for the training, validation and test sets.

In [34]:
qmdata.setup()

Finally, we can obtain a dataloader for each set; here we use the training one.

In [35]:
train_loader = qmdata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break
    

DataBatch(x=[1602, 25], edge_index=[2, 30354], edge_attr=[30354, 1], y=[60], pos=[1602, 3], id=[30], batch=[1602], ptr=[31])


## MD dataset

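The same steps can be used to load the MD dataset. For example, the reference coordinates of the 10gs system can be read as follows, keeping only the atoms located before the last molecule: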

In [36]:
mdh5_file = '../data/MD/h5_files/tiny-md.hdf5'
md_H5File = h5py.File(mdh5_file)

# Index of the first atom of the last molecule in the system
cutoff = md_H5File["10gs"]["molecules_begin_atom_index"][:][-1]

# Read the reference coordinates once, then split them into x, y, z
coords = md_H5File["10gs"]["atoms_coordinates_ref"][:cutoff]
x = coords[:, 0]
y = coords[:, 1]
z = coords[:, 2]
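As its name suggests, molecules_begin_atom_index holds the index of the first atom of each molecule in the system, so its last entry marks where the last molecule starts, and slicing with [:cutoff] keeps every atom before it. Assuming, as this slice suggests, that the ligand is stored as the last molecule, this selects the protein atoms. The NumPy sketch below illustrates the split on made-up arrays.

```python
import numpy as np

# Made-up stand-ins for the H5 datasets of one system:
# two protein chains starting at atoms 0 and 4, and a ligand starting at atom 7
molecules_begin_atom_index = np.array([0, 4, 7])
atoms_coordinates_ref = np.arange(10 * 3, dtype=float).reshape(10, 3)  # 10 atoms, xyz

cutoff = molecules_begin_atom_index[-1]       # first atom of the last molecule
protein_xyz = atoms_coordinates_ref[:cutoff]  # atoms before the last molecule
ligand_xyz = atoms_coordinates_ref[cutoff:]   # atoms of the last molecule

x, y, z = protein_xyz.T
print(protein_xyz.shape, ligand_xyz.shape)    # → (7, 3) (3, 3)
```

With h5py, slicing the dataset object directly (atoms_coordinates_ref[:cutoff]) reads only those rows from disk.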

In [37]:
train_idx = "../data/MD/splits/train_tinyMD.txt"
batch_size = 2
post_transform = T.RandomTranslate(0.1)

train_dataset = ProtDataset(mdh5_file, train_idx, transform=GNNTransformMD(), post_transform=post_transform)

In [38]:
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[6151, 11], edge_index=[2, 97370], edge_attr=[97370], y=[6151], pos=[6151, 3], ids=[2], batch=[6151], ptr=[3])
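In this MD batch, y has one entry per atom (6151 atoms across the 2 proteins), so per-protein values have to be recovered from ptr, whose consecutive entries delimit the atoms of each protein. A minimal NumPy sketch with made-up values and hypothetical IDs:

```python
import numpy as np

# Toy batch: protein 0 has 3 atoms, protein 1 has 2 atoms
ptr = np.array([0, 3, 5])
y = np.array([0.1, 0.4, 0.2, 0.9, 0.7])   # per-atom targets (made-up values)
ids = ["prot_a", "prot_b"]                # hypothetical IDs, analogous to the ids field

# Slice y between consecutive ptr entries
per_protein = {
    pid: y[start:end] for pid, start, end in zip(ids, ptr[:-1], ptr[1:])
}
# Equivalent vectorized split at the inner offsets
chunks = np.split(y, ptr[1:-1])
print({k: v.tolist() for k, v in per_protein.items()})
# → {'prot_a': [0.1, 0.4, 0.2], 'prot_b': [0.9, 0.7]}
```

On a real DataBatch, the to_data_list() method performs this split for every attribute at once.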


In [44]:
files_root =  "../data/MD"

mdh5_file = 'h5_files/tiny-md.hdf5'

train_idx = "splits/train_tinyMD.txt"
val_idx = "splits/val_tinyMD.txt"
test_idx = "splits/test_tinyMD.txt"



mddata = MDDataModule(files_root, h5file=mdh5_file, train=train_idx, val=val_idx, test=test_idx, num_workers=0)

In [45]:
mddata.setup()

In [46]:
train_loader = mddata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[36985, 11], edge_index=[2, 583786], edge_attr=[583786], y=[36985], pos=[36985, 3], ids=[16], batch=[36985], ptr=[17])
