# Processing Drug-Target Interaction Data

This notebook covers:
- Converting merged DTI data into an `h5torch` dataset
- Splitting the dataset (stratified) into train/val/test in two settings: random split, and cold-start split
- Computing embeddings from foundation models and storing them in the `h5torch` file
    - Drugs: `MMELON` (graph, image, text), and `RDKit` fingerprints
    - Targets: `NT`, `ESM`, and `ESPF` fingerprints
- Visualizing the foundation model embeddings
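
To make the difference between the two split settings concrete, here is a minimal sketch of a cold-start split. It assumes "cold-start" means that every drug appears in exactly one of train/val/test; the notebook does not say whether the cold entity is the drug, the target, or both. The function name `cold_start_split` and the toy pairs are hypothetical and not part of `mb_vae_dti`.

```python
import random
from collections import defaultdict


def cold_start_split(pairs, fractions=(0.8, 0.1, 0.1), seed=42):
    """Assign each (drug, target) pair to a split so that no drug spans two splits.

    Assumption: the split is "cold" with respect to drugs only.
    """
    drugs = sorted({drug for drug, _ in pairs})
    rng = random.Random(seed)
    rng.shuffle(drugs)

    n_train = int(len(drugs) * fractions[0])
    n_val = int(len(drugs) * fractions[1])

    # Split the shuffled drugs, then let each interaction inherit its drug's split
    split_of = {}
    for i, drug in enumerate(drugs):
        if i < n_train:
            split_of[drug] = "train"
        elif i < n_train + n_val:
            split_of[drug] = "val"
        else:
            split_of[drug] = "test"
    return [split_of[drug] for drug, _ in pairs]


# Toy data: 10 drugs, 3 targets, 30 interactions
pairs = [(f"D{i % 10}", f"T{i % 3}") for i in range(30)]
splits = cold_start_split(pairs)

# Collect the set of splits each drug ends up in (should always be exactly one)
drug_splits = defaultdict(set)
for (drug, _), split in zip(pairs, splits):
    drug_splits[drug].add(split)
```

In a random split, by contrast, the interactions themselves are shuffled, so the same drug can occur in both training and test data.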

In [2]:
from resolve import *

Setting working directory to: /home/robsyc/Desktop/thesis/MB-VAE-DTI


The drug and protein embedding generation was offloaded to an HPC environment with the following setup:
- A DigitalOcean droplet with a 48 GB NVIDIA L40S GPU
- Ubuntu 22.04, Python 3.11, and plain virtual environments (`venv`)

Because of dependency conflicts between the foundation models, each model needs its own venv (a basic `requirements.txt` file can be found in the corresponding folder in the `external` directory). See the `scripts/embedding.sh` file for more details.

The `embeddings.sh` script creates HDF5 files in the `external/temp` directory: `dti_smiles.hdf5`, `dti_aa.hdf5`, and `dti_dna.hdf5` for the DTI dataset, and `pretrain_smiles.hdf5`, `pretrain_aa.hdf5`, and `pretrain_dna.hdf5` for the pre-training datasets.

The `h5torch_creation.py` script then builds the `h5torch` files (`dti.h5torch`, `drugs.h5torch`, and `targets.h5torch`) from these HDF5 files. Below we inspect the structure and contents of these files and show how they are used to instantiate `PretrainDataset` and `DTIDataset` objects.

> Note: The pretrain dataset `drugs.h5torch` was capped at 2 million entities (down from the original 3,460,396) due to storage constraints. See `cap_drugs_h5torch.py` for more details.
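
As an illustration of such a cap, the sketch below draws a reproducible random subset of row indices. Whether `cap_drugs_h5torch.py` samples at random or simply keeps the first 2 million rows is not stated here, so the random sampling, the seed, and the helper name `sample_cap_indices` are all assumptions.

```python
import numpy as np


def sample_cap_indices(n_total, n_keep, seed=42):
    """Return n_keep distinct row indices in [0, n_total), sorted ascending.

    Sorting matters in practice: HDF5 readers such as h5py expect
    index lists in increasing order when slicing a dataset.
    """
    rng = np.random.default_rng(seed)
    keep = rng.choice(n_total, size=n_keep, replace=False)
    keep.sort()
    return keep


# Numbers taken from the note above
keep = sample_cap_indices(n_total=3_460_396, n_keep=2_000_000)
```

The resulting index array can then be applied to every aligned dataset (SMILES, fingerprints, embeddings) so that all of them stay row-aligned after the cap.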

## Pretrain Datasets

In [2]:
from mb_vae_dti.processing import inspect_h5torch_file
from pathlib import Path

output_dir = Path("/home/robsyc/Desktop/thesis/MB-VAE-DTI/data/input")

target_output_file = output_dir / "targets.h5torch"
inspect_h5torch_file(target_output_file)

drug_output_file = output_dir / "drugs.h5torch"
inspect_h5torch_file(drug_output_file)

2025-07-13 17:04:29,318 - mb_vae_dti.processing.h5factory - INFO - --- Inspecting H5torch File: targets.h5torch ---
2025-07-13 17:04:29,324 - mb_vae_dti.processing.h5factory - INFO - --- Finished Inspecting: targets.h5torch ---
2025-07-13 17:04:29,324 - mb_vae_dti.processing.h5factory - INFO - --- Inspecting H5torch File: drugs.h5torch ---
2025-07-13 17:04:29,328 - mb_vae_dti.processing.h5factory - INFO - --- Finished Inspecting: drugs.h5torch ---



[Root Attributes]
  - entity_type: target
  - n_items: 190851

[Central Dataset]
  Mode: N/A (Implicitly N-D or similar)
    - Name: central
      - Path: /central
      - Shape/Length: (190851,)
      - Saved Dtype: uint32

[Aligned Axes]

  --- Axis 0 ---
    - Name: EMB-ESM
      - Path: /0/EMB-ESM
      - Shape/Length: (190851, 1152)
      - Saved Dtype: float32
    - Name: EMB-NT
      - Path: /0/EMB-NT
      - Shape/Length: (190851, 1024)
      - Saved Dtype: float32
    - Name: FP-ESP
      - Path: /0/FP-ESP
      - Shape/Length: (190851, 4170)
      - Saved Dtype: uint8
    - Name: aa
      - Path: /0/aa
      - Shape/Length: Length: 190851
      - Saved Dtype: |S1280
    - Name: dna
      - Path: /0/dna
      - Shape/Length: Length: 190851
      - Saved Dtype: |S3843

[Unstructured Datasets]
    - Name: is_train
      - Path: /unstructured/is_train
      - Shape/Length: (190851,)
      - Saved Dtype: bool

[Root Attributes]
  - entity_type: drug
  - n_items: 2000000

[Central

In [31]:
from mb_vae_dti.processing import PretrainDataset
from external.ESPF.script import get_target_fingerprint
import numpy as np

targets_pretrain_training = PretrainDataset(
    h5_path=target_output_file,
    subset_filters={'split_col': 'is_train', 'split_value': True}
)
sample = targets_pretrain_training[42]
for key, value in sample.items():
    print(key, value)

np.all(sample["features"]["FP-ESP"] == get_target_fingerprint(sample["representations"]["aa"]))

2025-05-30 18:15:27,808 - INFO - Subset mask for targets.h5torch: kept 171765 / 190851 items
2025-05-30 18:15:27,810 - INFO - Initialized PretrainDataset from targets.h5torch. Size: 171765 items.
2025-05-30 18:15:27,811 - INFO -   Features (Axis 0): ['EMB-ESM', 'EMB-NT', 'FP-ESP']
2025-05-30 18:15:27,811 - INFO -   Representations (Axis 0): ['aa', 'dna']


id 49
representations {'aa': 'MAAAMTFCRLLNRCGEAARSLPLGARCFGVRVSPTGEKVTHTGQVYDDKDYRRIRFVGRQKEVNENFAIDLIAEQPVSEVETRVIACDGGGGALGHPKVYINLDKETKTGTCGYCGLQFRQHHH', 'dna': 'ATGGCGGCGGCGATGACCTTCTGCCGGCTGCTGAACCGGTGCGGCGAGGCGGCGCGGAGCCTGCCCCTGGGCGCCAGGTGTTTCGGGGTGCGGGTCTCGCCGACCGGGGAGAAGGTCACGCACACTGGCCAGGTTTATGATGATAAAGACTACAGGAGAATTCGGTTTGTAGGTCGTCAGAAAGAGGTGAATGAAAACTTTGCCATTGATTTGATAGCAGAGCAGCCCGTGAGCGAGGTGGAGACTCGGGTGATAGCGTGCGATGGCGGCGGGGGAGCTCTTGGCCACCCAAAAGTGTATATAAACTTGGACAAAGAAACAAAAACCGGCACATGCGGTTACTGTGGGCTCCAGTTCAGACAGCACCACCACTAG'}
features {'EMB-ESM': array([-0.01264881,  0.00669643, -0.00759549, ...,  0.00806052,
        0.01426091, -0.00678943], dtype=float32), 'EMB-NT': array([ 0.3568277 ,  0.11620766, -0.11930461, ...,  0.15212396,
       -0.13019717,  0.31840327], dtype=float32), 'FP-ESP': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)}


True
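
Note that the sample at subset position 42 reports `id 49`. This suggests that the dataset keeps the file row index of each item and that `subset_filters` maps positions in the filtered subset back to rows in the file. A minimal sketch of such a mapping, assuming a boolean `is_train` mask like the one stored under `/unstructured/is_train` (the toy mask and the helper `file_row` are hypothetical):

```python
import numpy as np

# Toy stand-in for the /unstructured/is_train column
is_train = np.array([True, False, True, True, False, True])

# File rows that survive the filter split_col == True
kept = np.flatnonzero(is_train)


def file_row(subset_idx):
    """Translate a position in the filtered subset to a row index in the file."""
    return int(kept[subset_idx])


# Position 2 in the training subset corresponds to row 3 in the file
row = file_row(2)
```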

In [39]:
from mb_vae_dti.processing import PretrainDataset
from external.MorganFP.script import get_drug_fingerprint
import numpy as np

drugs_pretrain_validation = PretrainDataset(
    h5_path=drug_output_file,
    subset_filters={'split_col': 'is_train', 'split_value': False}
)
sample = drugs_pretrain_validation[42]
print(sample)

np.all(sample["features"]["FP-Morgan"] == get_drug_fingerprint(sample["representations"]["smiles"]))

2025-05-30 18:21:07,335 - INFO - Subset mask for drugs.h5torch: kept 200000 / 2000000 items
2025-05-30 18:21:07,342 - INFO - Initialized PretrainDataset from drugs.h5torch. Size: 200000 items.
2025-05-30 18:21:07,342 - INFO -   Features (Axis 0): ['EMB-BiomedGraph', 'EMB-BiomedImg', 'EMB-BiomedText', 'FP-Morgan']
2025-05-30 18:21:07,343 - INFO -   Representations (Axis 0): ['smiles']


{'id': 313, 'representations': {'smiles': 'CCOc1cc2c(c(O)c1OCC)C(=O)NC1C2CC(O)C(O)C1O'}, 'features': {'EMB-BiomedGraph': array([ 3.28338034e-02,  5.93008325e-02, -6.75319880e-02, -2.33449712e-01,
        4.99292940e-01,  3.82202864e-03,  8.50097910e-02, -1.92996599e-02,
       -2.97931135e-01, -9.10175368e-02, -3.02154664e-03, -5.64331174e-01,
        1.00632846e-01, -3.97299835e-03, -1.95052072e-01,  8.50920454e-02,
        6.33843467e-02,  1.54944748e-01,  4.93423976e-02,  1.09941289e-01,
       -1.31553829e-01, -2.16462798e-02, -4.77177389e-02,  3.33764516e-02,
        2.83989847e-01, -8.55271611e-03,  1.59002018e+00, -5.26811257e-02,
        4.47995365e-02,  2.23343277e+00,  7.47756287e-03, -2.32666992e-02,
       -2.66399048e-02,  5.62607646e-02,  2.10713923e-01, -4.17282850e-01,
       -2.69951411e-02,  8.95682499e-02, -8.06409940e-02, -1.74574982e-02,
        4.65488546e-02,  7.21551552e-02,  2.56469473e-02,  4.38158214e-02,
       -6.77484721e-02, -1.05443504e-02, -2.94203628e-

True

## DTI Dataset

In [3]:
from mb_vae_dti.processing import inspect_h5torch_file
from pathlib import Path

output_dir = Path("/home/robsyc/Desktop/thesis/MB-VAE-DTI/data/input")

dti_output_file = output_dir / "dti.h5torch"
inspect_h5torch_file(dti_output_file)

2025-07-13 17:04:36,856 - mb_vae_dti.processing.h5factory - INFO - --- Inspecting H5torch File: dti.h5torch ---
2025-07-13 17:04:36,869 - mb_vae_dti.processing.h5factory - INFO - --- Finished Inspecting: dti.h5torch ---



[Root Attributes]
  - created_at: 2025-06-27T15:54:43.511488
  - n_drugs: 126811
  - n_interactions: 339197
  - n_targets: 1976
  - sparsity: 0.0013536554463707139

[Central Dataset]
  Mode: coo
  Shape (Attr): [126811   1976]
    - Name: indices
      - Path: /central/indices
      - Shape/Length: (2, 339197)
      - Saved Dtype: int64
    - Dataset 'values' not found or not a dataset.

[Aligned Axes]

  --- Axis 0 ---
    - Name: Drug_ID
      - Path: /0/Drug_ID
      - Shape/Length: Length: 126811
      - Saved Dtype: |S7
    - Name: Drug_InChIKey
      - Path: /0/Drug_InChIKey
      - Shape/Length: Length: 126811
      - Saved Dtype: |S27
    - Name: EMB-BiomedGraph
      - Path: /0/EMB-BiomedGraph
      - Shape/Length: (126811, 512)
      - Saved Dtype: float32
    - Name: EMB-BiomedImg
      - Path: /0/EMB-BiomedImg
      - Shape/Length: (126811, 512)
      - Saved Dtype: float32
    - Name: EMB-BiomedText
      - Path: /0/EMB-BiomedText
      - Shape/Length: (126811, 768)
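
The central dataset is stored in `coo` mode: only the observed (drug, target) pairs are kept, as a `(2, n_interactions)` index array. The sketch below shows how the `sparsity` root attribute follows from the other attributes, and how one column of a COO index array identifies a pair. It assumes row 0 holds drug indices (axis 0) and row 1 holds target indices (axis 1); the toy `indices` array and the helper `pair_at` are hypothetical.

```python
import numpy as np

# Root attributes reported by inspect_h5torch_file
n_drugs, n_targets = 126_811, 1_976
n_interactions = 339_197

# Fraction of all possible drug-target pairs that have a measurement
sparsity = n_interactions / (n_drugs * n_targets)

# Toy COO index array: three interactions in a small matrix
indices = np.array([
    [0, 0, 2],  # drug index of each interaction (assumed to be axis 0)
    [1, 3, 0],  # target index of each interaction (assumed to be axis 1)
])


def pair_at(k):
    """Return the (drug_idx, target_idx) pair of the k-th interaction."""
    return int(indices[0, k]), int(indices[1, k])


second_pair = pair_at(1)
```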
     

In [6]:
from mb_vae_dti.processing import DTIDataset
from external.MorganFP.script import get_drug_fingerprint
from external.ESPF.script import get_target_fingerprint
import numpy as np

dti_dataset = DTIDataset(
    h5_path=dti_output_file,
    subset_filters={
        'split_col': 'split_rand',
        'split_value': 'train',
        'provenance_cols': ['in_DAVIS'],  # optionally also 'in_KIBA'
    }
)
sample = dti_dataset[42]
print(sample)

print(np.all(sample["drug"]["features"]["FP-Morgan"] == get_drug_fingerprint(sample["drug"]["representations"]["SMILES"])))
print(np.all(sample["target"]["features"]["FP-ESP"] == get_target_fingerprint(sample["target"]["representations"]["AA"])))

2025-07-13 17:31:22,620 - mb_vae_dti.processing.h5datasets - INFO - Subset mask for dti.h5torch: kept 13805 / 339197 items
2025-07-13 17:31:22,625 - mb_vae_dti.processing.h5datasets - INFO - Pre-loaded unstructured Y data for 3 columns: ['Y_KIBA', 'Y_pKd', 'Y_pKi']
2025-07-13 17:31:22,626 - mb_vae_dti.processing.h5datasets - INFO - Initialized DTIDataset from dti.h5torch. Size: 13805 interactions.
2025-07-13 17:31:22,626 - mb_vae_dti.processing.h5datasets - INFO -   Drug paths (Axis 0): ['Drug_ID', 'Drug_InChIKey', 'EMB-BiomedGraph', 'EMB-BiomedImg', 'EMB-BiomedText', 'FP-Morgan', 'SMILES']
2025-07-13 17:31:22,626 - mb_vae_dti.processing.h5datasets - INFO -   Target paths (Axis 1): ['AA', 'DNA', 'EMB-ESM', 'EMB-NT', 'FP-ESP', 'Target_Gene_name', 'Target_ID', 'Target_RefSeq_ID', 'Target_UniProt_ID']
2025-07-13 17:31:22,627 - mb_vae_dti.processing.h5datasets - INFO -   Y paths (Central/Unstructured): ['Y_KIBA', 'Y_pKd', 'Y_pKi']
2025-07-13 17:31:22,627 - mb_vae_dti.processing.h5datasets 

{'id': 42, 'y': {'Y': 0.0, 'Y_KIBA': None, 'Y_pKd': 5.7212234, 'Y_pKi': None}, 'drug': {'id': {'Drug_ID': 'D000525', 'Drug_InChIKey': 'AAKJLRGGTJKAMG-UHFFFAOYSA-N'}, 'representations': {'SMILES': 'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1'}, 'features': {'EMB-BiomedGraph': array([ 3.21993306e-02,  5.82491681e-02, -6.79564700e-02, -1.92157716e-01,
       -5.96494898e-02,  3.29180807e-03,  8.39334726e-02, -1.98042523e-02,
       -2.98353732e-01, -9.14402083e-02, -3.63479182e-03, -5.64899325e-01,
        9.95670855e-02, -4.49689757e-03, -1.95666224e-01,  8.42691138e-02,
        6.25201836e-02,  1.54438481e-01,  4.85873409e-02,  2.14559078e-01,
       -1.32025585e-01, -2.21281070e-02, -4.82717119e-02,  3.27815562e-02,
        1.01557541e+00, -9.07609425e-03,  2.10126710e+00, -5.31653315e-02,
        4.42590490e-02,  2.30898070e+00,  6.85733650e-03, -1.04858592e-01,
       -2.74030678e-02,  5.55637218e-02,  1.41454265e-01, -4.17915612e-01,
       -2.76262220e-02,  8.86448920e-02, -8.108396

In [10]:
print(sample['drug']['id'])
print(sample['target']['id'])
print(sample['y'])

{'Drug_ID': 'D000465', 'Drug_InChIKey': 'WOTLXQZLXOXMFD-UHFFFAOYSA-N'}
{'Target_Gene_name': 'PLK1', 'Target_ID': 'T000212', 'Target_RefSeq_ID': 'NM_005030', 'Target_UniProt_ID': 'P53350'}
{'Y': 0.0, 'Y_KIBA': 11.2, 'Y_pKd': None, 'Y_pKi': None}


In [None]:
import pandas as pd

df = pd.read_csv("data/processed/dti.csv")

dID = 'D000465'
tID = 'T000212'

# get row where Drug_ID == dID and Target_ID == tID
df[(df['Drug_ID'] == dID) & (df['Target_ID'] == tID)]

Unnamed: 0,Drug_ID,Drug_InChIKey,Drug_SMILES,Target_ID,Target_UniProt_ID,Target_Gene_name,Target_RefSeq_ID,Target_AA,Target_DNA,Y,Y_pKd,Y_pKi,Y_KIBA,in_DAVIS,in_BindingDB_Kd,in_BindingDB_Ki,in_Metz,in_KIBA
802,D000465,WOTLXQZLXOXMFD-UHFFFAOYSA-N,C#Cc1cc2c(cc1OC)-c1[nH]nc(-c3ccc(C#N)nc3)c1C2,T000212,P53350,PLK1,NM_005030,MSAAVTAGKLARAPADPGKAGVPGVAAPGAPAAAPPAKEIPEVLVD...,ATGAGTGCTGCAGTGACTGCAGGGAAGCTGGCACGGGCACCGGCCG...,False,,,11.2,False,False,False,False,True


---

## Molecular statistics

The forward process of the full discrete denoising diffusion model requires statistics of the drug molecules, including:
- Heavy-atom count distribution: the number of heavy atoms (nodes) per molecular graph, and how often each molecule size occurs in the dataset
- Maximum number of nodes (heavy atoms): used to determine the size of the adjacency matrix
- Marginal distributions over node and edge types
- Chemical properties of the atoms in the dataset:
  - maximum molecular weight
  - weight of each atom type
  - valence of each atom type
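
The statistics listed above can be sketched on toy data without any chemistry toolkit. The example assumes molecules are already given as heavy-atom lists plus typed bonds, and that "no bond" counts as an edge type, as is common for discrete graph diffusion. The toy molecules, the `ATOM_WEIGHTS` table, and the function `dataset_statistics` are hypothetical; in practice these values would be derived from the SMILES strings.

```python
from collections import Counter

# Toy molecules: (heavy-atom symbols, bonds as (i, j, bond_type))
molecules = [
    (["C", "C", "O"], [(0, 1, "SINGLE"), (1, 2, "SINGLE")]),  # ethanol
    (["C", "O", "O"], [(0, 1, "DOUBLE"), (0, 2, "DOUBLE")]),  # carbon dioxide
    (["C", "C"], [(0, 1, "TRIPLE")]),                         # acetylene
]

# Standard atomic weights; hydrogens are ignored in this heavy-atom sketch
ATOM_WEIGHTS = {"C": 12.011, "O": 15.999}


def dataset_statistics(mols):
    size_counts = Counter(len(atoms) for atoms, _ in mols)
    node_counts = Counter(a for atoms, _ in mols for a in atoms)

    edge_counts = Counter()
    for atoms, bonds in mols:
        n_pairs = len(atoms) * (len(atoms) - 1) // 2
        edge_counts["NONE"] += n_pairs - len(bonds)  # unbonded atom pairs
        edge_counts.update(bond_type for _, _, bond_type in bonds)

    n_nodes = sum(node_counts.values())
    n_edges = sum(edge_counts.values())
    return {
        "max_nodes": max(size_counts),
        "size_counts": dict(size_counts),
        "node_marginals": {k: v / n_nodes for k, v in node_counts.items()},
        "edge_marginals": {k: v / n_edges for k, v in edge_counts.items()},
        "max_weight": max(sum(ATOM_WEIGHTS[a] for a in atoms) for atoms, _ in mols),
    }


stats = dataset_statistics(molecules)
```

Here `max_nodes` fixes the adjacency matrix size, while the two marginal distributions describe how often each node and edge type occurs across the dataset.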

In [5]:
import pandas as pd
import numpy as np
from mb_vae_dti.processing.split import add_splits

df = pd.read_csv("data/processed/dti.csv")
df = add_splits(df, split_fractions=(0.8, 0.1, 0.1), stratify=True, random_state=42)
# dti_drugs = df["Drug_SMILES"].unique()

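# NOTE: `num_items` and `train_frac` were not defined in the original cell.
# The values below are assumptions; 0.9 matches the 171765 / 190851 train
# fraction reported for targets.h5torch above.
num_items = len(df)
train_frac = 0.9

np.random.seed(42)
indices = np.random.permutation(num_items)
num_train = int(num_items * train_frac)

# Create boolean array: True for train, False for validation
split_data = np.zeros(num_items, dtype=bool)
split_data[indices[:num_train]] = True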
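
The cell above calls `add_splits` with `stratify=True`, but its internals are not shown in this notebook. As a rough illustration only, the following sketch stratifies a random split on a binary label so that each split keeps the same fraction of positives. The function `stratified_split`, the choice of label, and the toy data are assumptions, not the package's implementation.

```python
import numpy as np


def stratified_split(labels, fractions=(0.8, 0.1, 0.1), seed=42):
    """Assign 'train'/'val'/'test' per item, splitting each label class separately."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    out = np.empty(len(labels), dtype=object)

    for value in np.unique(labels):
        # Shuffle the items of this class, then cut them by the requested fractions
        idx = rng.permutation(np.flatnonzero(labels == value))
        n_train = int(round(len(idx) * fractions[0]))
        n_val = int(round(len(idx) * fractions[1]))
        out[idx[:n_train]] = "train"
        out[idx[n_train:n_train + n_val]] = "val"
        out[idx[n_train + n_val:]] = "test"
    return out


# Toy labels: 20 positives and 80 negatives
labels = np.array([True] * 20 + [False] * 80)
split = stratified_split(labels)
```

With these toy labels, each split keeps the original 20% positive rate, which a plain random permutation only achieves approximately.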


In [6]:
df

Unnamed: 0,Drug_ID,Drug_InChIKey,Drug_SMILES,Target_ID,Target_UniProt_ID,Target_Gene_name,Target_RefSeq_ID,Target_AA,Target_DNA,Y,Y_pKd,Y_pKi,Y_KIBA,in_DAVIS,in_BindingDB_Kd,in_BindingDB_Ki,in_Metz,in_KIBA
0,D000001,HYTVYLVVJDEURY-AUCFXJAVSA-N,C#CC(=O)C1(C)CCC2c3ccc(O)cc3CCC2C1,T000001,P14061,HSD17B1,NM_000413,MARTVVLITGCSSGIGLHLAVRLASDPSQSFKVYATLRDLKTQGRL...,ATGGCCCGCACCGTGGTGCTCATCACCGGCTGTTCCTCGGGCATCG...,False,,5.552826,,False,False,True,False,False
1,D000002,CFCGTXOJJUJIIE-UHFFFAOYSA-N,C#CC(C#C)=C1CCC(N(CCC)CCC)CC1,T000002,P35462,DRD3,NM_000796,MASLSQLSSHLNYTCGAENSTGASQARPHAYYALSYCALILAIVFG...,ATGGCATCTCTGAGCCAGCTGAGTGGCCACCTGAACTACACCTGTG...,False,,5.356537,,False,False,True,False,False
2,D000002,CFCGTXOJJUJIIE-UHFFFAOYSA-N,C#CC(C#C)=C1CCC(N(CCC)CCC)CC1,T000003,P14416,DRD2,NM_000795,MDPLNLSWYDDDLERQNWSRPFNGSDGKADRPHYNYYATLLTLLIA...,ATGGATCCACTGAATCTGTCCTGGTATGATGATGATCTGGAGAGGC...,False,,4.809891,,False,False,True,False,False
3,D000002,CFCGTXOJJUJIIE-UHFFFAOYSA-N,C#CC(C#C)=C1CCC(N(CCC)CCC)CC1,T000004,Q95136,DRD1,NM_174042,MRTLNTSTMEGTGLVAERDFSFRILTACFLSLLILSTLLGNTLVCA...,ATGAGGACTCTCAACACGTCTACCATGGAAGGCACCGGGCTGGTGG...,False,,4.795877,,False,False,True,False,False
4,D000003,PPWNCLVNXGCGAF-UHFFFAOYSA-N,C#CC(C)(C)C,T000005,P05182,Cyp2e1,NM_031543,MAVLGITIALLVWVATLLVISIWKQIYNSWNLPPGPFPLPILGNIF...,ATGGCGGTTCTTGGCATCACCATTGCCTTGCTGGTGTGGGTGGCCA...,False,,3.000000,,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
339192,D136571,ZWXHJFIYTQUPOJ-UHFFFAOYSA-N,c1ncc(Cc2cc3c(s2)CCCC3)[nH]1,T000568,P25100,ADRA1D,NM_000678,MTFRDLLSVSFEGPRPDSSAGGSSAGGGGGSAGGAAPSEGPAVGGV...,ATGACTTTCCGCGATCTCCTGAGCGTCAGTTTCGAGGGACCCCGCC...,True,,8.420216,,False,False,True,False,False
339193,D136572,MPRYSKNMRJZZIQ-UHFFFAOYSA-N,c1ncc(Cc2ccsc2)[nH]1,T000568,P25100,ADRA1D,NM_000678,MTFRDLLSVSFEGPRPDSSAGGSSAGGGGGSAGGAAPSEGPAVGGV...,ATGACTTTCCGCGATCTCCTGAGCGTCAGTTTCGAGGGACCCCGCC...,True,,8.494850,,False,False,True,False,False
339194,D136573,XYXCOZZITXKLLF-BETUJISGSA-N,c1ncc(N2CC3CNCC3C2)cc1N1CCOCC1,T000235,P09483,Chrna4,NM_024354,MANSGTGAPPPLLLLPLLLLLGTGLLPASSHIETRAHAEERLLKRL...,GGCCCCGGGGCGCCGCCGCCGCTGCTGCTACTGCCGCTGCTGCTGC...,False,,7.326979,,False,False,True,False,False
339195,D136573,XYXCOZZITXKLLF-BETUJISGSA-N,c1ncc(N2CC3CNCC3C2)cc1N1CCOCC1,T000514,Q05941,Chrna7,NM_012832,MCGGRGGIWLALAAALLHVSLQGEFQRRLYKELVKNYNPLERPVAN...,ATGTGCGGCGGGCGGGGAGGCATCTGGCTGGCTCTGGCCGCGGCGC...,False,,4.999996,,False,False,True,False,False
