In [1]:
import os
import torch

# Project root. Set the MB_VAE_DTI_ROOT environment variable to run this
# notebook from another checkout; the original development path is the default.
PROJECT_ROOT = os.environ.get('MB_VAE_DTI_ROOT', '/home/robsyc/Desktop/thesis/MB-VAE-DTI/')
os.chdir(PROJECT_ROOT)
print(os.listdir('.'))
print(os.getcwd())

# Report GPU availability and release any cached CUDA memory.
print(torch.cuda.is_available())
torch.cuda.empty_cache()

['.git', 'utils', 'drug_target_tree.onnx', 'data', 'bmfm_sm', 'scripts', '.gitignore', 'README.md', 'ESPF', 'notebooks', 'hpc.pbs', 'requirements.txt']
/home/robsyc/Desktop/thesis/MB-VAE-DTI
True


In [2]:
from utils.modelBuilding import ResidualBranch, DrugTargetTree, VariationalDrugTargetTree
import torch

def _tree_kwargs():
    """Return a fresh copy of the constructor arguments shared by both tree models."""
    return dict(
        drug_dims=[512, 512, 768],
        target_dims=[512, 1024, 768],
        hidden_dim=256,
        latent_dim=512,
        depth=2,
        dropout_prob=0.1,
    )


# 1) Single residual branch on a random (32, 512) input; print the output shape.
model = ResidualBranch(
    input_dim=512, hidden_dim=256, output_dim=512,
    depth=3, dropout_prob=0.1, activation='SiLU',
)
x = torch.randn(32, 512)
print(model(x).shape)

# 2) Plain drug/target tree. The model is built before the random inputs are
#    drawn, and the inputs are reused by the variational model below.
model = DrugTargetTree(**_tree_kwargs())
drugs = [torch.randn(32, width) for width in (512, 512, 768)]
targets = [torch.randn(32, width) for width in (512, 1024, 768)]

y = model(drugs, targets)
print(y.shape)

# 3) Variational tree with the same settings and the same inputs.
model = VariationalDrugTargetTree(**_tree_kwargs())

y = model(drugs, targets)
print(y.shape)

torch.Size([32, 512])
Number of parameters
 - Drug Leafs: 1,646,848
 - Target Leafs: 1,777,920
 - Attention: 1,026
 - Total: 3,425,794
torch.Size([32])
Number of parameters
 - Drug Leafs: 2,041,600
 - Target Leafs: 2,172,672
 - Attention: 2,052
 - Total: 4,216,324
torch.Size([32])


In [13]:
# Forward pass with both loss flags enabled; the call then yields three values
# (prediction, `recon`, `kl`) instead of the single value used earlier.
y, recon, kl = model(
    drugs,
    targets,
    compute_recon_loss=True,
    compute_kl_loss=True,
)

In [15]:
# Inspect `kl`, the third value unpacked from the model call in the previous cell.
kl

tensor(24.8482, grad_fn=<AddBackward0>)

In [18]:
from utils.modelBuilding import EncoderBlock, MultiViewBlock, MultiBranchDTI, Generator
import torch

# model = EncoderBlock(
#     input_dim=2048,
#     hidden_dim=1024,
#     output_dim=777,
#     depth=2,
#     variational=True
# )
# x = torch.randn(10, 2048)
# z, kl = model(x, compute_kl_loss=True)
# z.shape

# model = MultiViewBlock(
#     input_dim_list=[2048, 768],
#     hidden_dim=1024,
#     latent_dim=777,
#     depth=2,
#     variational=True
# )
# x = [torch.randn(10, 2048), torch.randn(10, 768)]
# z, kl = model(x, compute_kl_loss=True)
# z.shape

# Two-branch variational model: three input views per branch.
model = MultiBranchDTI(
    input_dim_list_0=[512, 512, 768],
    input_dim_list_1=[512, 1024, 768],
    hidden_dim=256, latent_dim=512, depth=2,
    variational=True,
)

# Random inputs (batch of 32), one tensor per view, matching the widths above.
x0 = [torch.randn(32, width) for width in (512, 512, 768)]
x1 = [torch.randn(32, width) for width in (512, 1024, 768)]

# With compute_kl_loss=True the call returns a (z, kl) pair.
z, kl = model(x0, x1, compute_kl_loss=True)
z.shape

# model = Generator(
#     latent_dim=777,
#     hidden_dim=222,
#     depth=4,
#     dropout_prob=0.1,
#     n_nodes=60,
#     n_iters=3
# )
# z = torch.randn(10, 777)
# x = model(z)
# x.shape

Number of parameters
 - Branch 0: 2,236,928
 - Branch 1: 2,368,000


torch.Size([32])

In [5]:
# Display the repr of whatever `model` currently refers to.
# NOTE(review): the saved output is a Generator, but the only Generator
# construction visible in this notebook is commented out - the output probably
# comes from an earlier kernel state; confirm with Restart & Run All.
model

Generator(
  (latent2hidden): Linear(in_features=777, out_features=222, bias=True)
  (residual_blocks): ModuleList(
    (0-3): 4 x Sequential(
      (0): LayerNorm((222,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=222, out_features=222, bias=False)
      (2): ELU(alpha=1.0)
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=222, out_features=222, bias=True)
    )
  )
  (hidden2topology): Linear(in_features=222, out_features=3540, bias=True)
  (gumbel2hidden): Sequential(
    (0): Linear(in_features=1770, out_features=222, bias=True)
    (1): ELU(alpha=1.0)
  )
)

In [3]:
# Display the repr of whatever `model` currently refers to.
# NOTE(review): the saved output is a VariationalMultiBranch, a class that no
# visible cell imports or constructs - likely stale output; confirm on a fresh run.
model

VariationalMultiBranch(
  (branch0): Branch(
    (encoders): ModuleList(
      (0): Encoder(
        (input2hidden): Sequential(
          (0): Linear(in_features=2048, out_features=128, bias=True)
          (1): SiLU()
          (2): Dropout(p=0.1, inplace=False)
        )
        (hidden_layers): ModuleList(
          (0-1): 2 x Sequential(
            (0): Linear(in_features=128, out_features=128, bias=False)
            (1): SiLU()
            (2): Dropout(p=0.1, inplace=False)
          )
        )
        (layer_norms): ModuleList(
          (0-1): 2 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (hidden2output): Linear(in_features=128, out_features=1024, bias=True)
      )
      (1): Encoder(
        (input2hidden): Sequential(
          (0): Linear(in_features=768, out_features=128, bias=True)
          (1): SiLU()
          (2): Dropout(p=0.1, inplace=False)
        )
        (hidden_layers): ModuleList(
          (0-1): 2 x Sequential(
            (

In [4]:
# NOTE(review): `kl_loss` is not assigned in any cell visible here (the forward
# passes in this notebook bind `kl`), so this cell appears to rely on leftover
# kernel state - TODO confirm it survives Restart & Run All.
kl_loss

tensor(0.9148, grad_fn=<DivBackward0>)

In [9]:
# Display the repr of whatever `model` currently refers to.
# NOTE(review): the saved output is a PlainBranch, a class that no visible cell
# imports or constructs - likely stale output; confirm on a fresh run.
model

PlainBranch(
  (encoders): ModuleList(
    (0): Encoder(
      (input2hidden): Sequential(
        (0): Linear(in_features=2048, out_features=1024, bias=True)
        (1): SiLU()
        (2): Dropout(p=0.1, inplace=False)
      )
      (hidden_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): Linear(in_features=1024, out_features=1024, bias=False)
          (1): SiLU()
          (2): Dropout(p=0.1, inplace=False)
        )
      )
      (layer_norms): ModuleList(
        (0-1): 2 x LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (hidden2output): Linear(in_features=1024, out_features=500, bias=True)
    )
    (1): Encoder(
      (input2hidden): Sequential(
        (0): Linear(in_features=768, out_features=1024, bias=True)
        (1): SiLU()
        (2): Dropout(p=0.1, inplace=False)
      )
      (hidden_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): Linear(in_features=1024, out_features=1024, bias=False)
          (1): SiLU()
 

In [12]:
# Display the repr of whatever `model` currently refers to.
# NOTE(review): the saved output is a VariationalMultiBranch built from
# VariationalEncoder blocks; neither class is imported or constructed in any
# visible cell - likely stale output; confirm on a fresh run.
model

VariationalMultiBranch(
  (branch0): VariationalBranch(
    (blocks): ModuleList(
      (0): VariationalEncoder(
        (input2hidden): Sequential(
          (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=2048, out_features=512, bias=True)
          (2): SiLU()
          (3): Dropout(p=0.1, inplace=False)
        )
        (hidden_layers): ModuleList(
          (0-1): 2 x Sequential(
            (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): SiLU()
            (3): Dropout(p=0.1, inplace=False)
          )
        )
        (hidden2output): Linear(in_features=512, out_features=1024, bias=True)
      )
      (1): VariationalEncoder(
        (input2hidden): Sequential(
          (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=768, out_features=512, bias=True)
          (2): SiLU()
          (3): Dro

---
---

# Load Example Data

In [5]:
import h5torch

# Open the DAVIS file with "coo" sampling, restricted to the "valid" part of
# the `unstructured/split_cold` split, and keep it in memory.
dataset = h5torch.Dataset(
    "./data/dataset/DAVIS.h5t",
    sampling="coo",
    subset=("unstructured/split_cold", "valid"),
    in_memory=True,
)

# Shuffle with an explicitly seeded generator so the example batch displayed
# below is the same on every Restart & Run All.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=12,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
batch = next(iter(dataloader))
batch

{'central': tensor([5.0000, 5.0000, 5.4437, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000,
         5.0000, 5.0000, 5.0000]),
 '0/Drug_ID': ['9926054',
  '9926054',
  '216239',
  '11717001',
  '216239',
  '123631',
  '123631',
  '9926791',
  '156422',
  '216239',
  '9926054',
  '9926791'],
 '0/Drug_SMILES': ['Cc1ccc2nc(NCCN)c3ncc(C)n3c2c1.Cl',
  'Cc1ccc2nc(NCCN)c3ncc(C)n3c2c1.Cl',
  'CNC(=O)c1cc(Oc2ccc(NC(=O)Nc3ccc(Cl)c(C(F)(F)F)c3)cc2)ccn1',
  'OCCn1cc(-c2ccc3c(c2)CCC3=NO)c(-c2ccncc2)n1',
  'CNC(=O)c1cc(Oc2ccc(NC(=O)Nc3ccc(Cl)c(C(F)(F)F)c3)cc2)ccn1',
  'COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1',
  'COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCOCC1',
  'CC1CCN(C(=O)CC#N)CC1N(C)c1ncnc2[nH]ccc12',
  'Cc1ccc(-n2nc(C(C)(C)C)cc2NC(=O)Nc2ccc(OCCN3CCOCC3)c3ccccc23)cc1',
  'CNC(=O)c1cc(Oc2ccc(NC(=O)Nc3ccc(Cl)c(C(F)(F)F)c3)cc2)ccn1',
  'Cc1ccc2nc(NCCN)c3ncc(C)n3c2c1.Cl',
  'CC1CCN(C(=O)CC#N)CC1N(C)c1ncnc2[nH]ccc12'],
 '0/Drug_emb_graph': tensor([[-0.0104, -0.0124, -0.0965,  ..., -0.0641, -0

In [6]:
# Heading -> batch keys whose tensor shapes are listed under that heading.
feature_groups = (
    ("Molecule drugs",
     ["0/Drug_fp", "0/Drug_emb_graph", "0/Drug_emb_image", "0/Drug_emb_text"]),
    ("\nProtein targets",
     ["1/Target_fp", "1/Target_emb_ESM", "1/Target_emb_T5", "1/Target_emb_DNA"]),
)

for heading, keys in feature_groups:
    print(heading)
    for key in keys:
        # key[2:] drops the leading "0/" or "1/" prefix for display.
        print("- ", key[2:], batch[key].shape)

print("\nInteraction value: ", batch["central"].shape)

Molecule drugs
-  Drug_fp torch.Size([12, 2048])
-  Drug_emb_graph torch.Size([12, 512])
-  Drug_emb_image torch.Size([12, 512])
-  Drug_emb_text torch.Size([12, 768])

Protein targets
-  Target_fp torch.Size([12, 4170])
-  Target_emb_ESM torch.Size([12, 1280])
-  Target_emb_T5 torch.Size([12, 1024])
-  Target_emb_DNA torch.Size([12, 1024])

Interaction value:  torch.Size([12])


In [7]:
from torch.utils.data import DataLoader, Dataset
import h5torch

def get_dataset(split_type, split_name, path="./data/dataset/DAVIS.h5t", in_memory=True):
    """Open one subset of an h5torch file with "coo" sampling.

    Args:
        split_type: Name of the split column under ``unstructured/``
            (e.g. ``"split_rand"``).
        split_name: Value of that column to select (e.g. ``"train"``).
        path: Location of the ``.h5t`` file. Defaults to the DAVIS file
            used throughout this notebook.
        in_memory: Forwarded to ``h5torch.Dataset``; defaults to ``True``
            as in the original hard-coded call.

    Returns:
        The ``h5torch.Dataset`` restricted to the requested subset.
    """
    return h5torch.Dataset(
        path,
        sampling="coo",
        subset=(f"unstructured/{split_type}", split_name),
        in_memory=in_memory,
    )

class CustomH5Dataset(Dataset):
    """Adapter that turns keyed samples into ``(x0, x1, y)`` triples.

    Each sample of the wrapped dataset is indexed by key. ``inputs_0`` lists
    the keys fed to branch 0 (drug), ``inputs_1`` the keys fed to branch 1
    (target); the value under ``'central'`` is returned as the label.
    """

    def __init__(self, h5_dataset, inputs_0, inputs_1):
        self.dataset, self.inputs_0, self.inputs_1 = h5_dataset, inputs_0, inputs_1

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _gather(sample, keys):
        """Return ``sample[key]`` for every key, preserving order."""
        return [sample[key] for key in keys]

    def __getitem__(self, idx):
        """Return ``(branch-0 inputs, branch-1 inputs, interaction value)`` for ``idx``."""
        sample = self.dataset[idx]
        drug_views = self._gather(sample, self.inputs_0)
        target_views = self._gather(sample, self.inputs_1)
        return drug_views, target_views, sample['central']

# Which batch keys feed each branch, and which model flavour to use.
config = dict(
    inputs_0=['0/Drug_fp'],
    inputs_1=['1/Target_fp'],
    model_type='plain',  # 'plain' or 'variational'
)
batch_size = 16

# One h5torch dataset per part of the random split, opened in train/valid/test order.
train_dataset, valid_dataset, test_dataset = (
    get_dataset("split_rand", split_name)
    for split_name in ('train', 'valid', 'test')
)

# Wrap each dataset and build its DataLoader; only the training loader shuffles.
train_loader, valid_loader, test_loader = (
    DataLoader(
        CustomH5Dataset(split_dataset, config['inputs_0'], config['inputs_1']),
        batch_size=batch_size,
        shuffle=is_train,
    )
    for split_dataset, is_train in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

In [8]:
import os
import torch
# Switch to the project root (presumably so the `utils` import on the next line
# resolves). MB_VAE_DTI_ROOT overrides the original development checkout path.
os.chdir(os.environ.get('MB_VAE_DTI_ROOT', '/home/robsyc/Desktop/thesis/MB-VAE-DTI/'))
from utils.modelTraining import train_and_evaluate

# Experiment name -> settings handed to train_and_evaluate.
CONFIGS = {
    'single_view_fp': dict(
        inputs_0=['0/Drug_fp'],
        inputs_1=['1/Target_fp'],
        model_type='plain',  # 'plain' or 'variational'
    ),
    # 'var_single_view_emb': {
    #     'inputs_0': ['0/Drug_emb_graph'],
    #     'inputs_1': ['1/Target_emb_T5'],
    #     'model_type': 'variational',
    # },
}

split_type = 'split_rand'
for exp_name in CONFIGS:
    config = CONFIGS[exp_name]
    print(f"\nRunning experiment: {exp_name}")
    # train_and_evaluate yields a (best validation loss, test loss) pair.
    results = train_and_evaluate(
        config=config, split_type=split_type, num_epochs=10, batch_size=16
    )
    best_valid_loss, test_loss = results
    print(f"Experiment: {exp_name}, Best Valid Loss: {best_valid_loss:.4f}, Test Loss: {test_loss:.4f}")


Running experiment: single_view_fp
Epoch [1/10], Train Loss: 6.2590, Valid Loss: 0.4720
Epoch [2/10], Train Loss: 0.5535, Valid Loss: 0.3568
Epoch [3/10], Train Loss: 0.4491, Valid Loss: 0.3484
Epoch [4/10], Train Loss: 0.4225, Valid Loss: 0.4932
Epoch [5/10], Train Loss: 0.4021, Valid Loss: 0.5732
Epoch [6/10], Train Loss: 0.4050, Valid Loss: 0.6627
Epoch [7/10], Train Loss: 0.4071, Valid Loss: 0.3678
Epoch [8/10], Train Loss: 0.3850, Valid Loss: 0.4727
Epoch [9/10], Train Loss: 0.3837, Valid Loss: 0.4347
Epoch [10/10], Train Loss: 0.3859, Valid Loss: 0.3137
Test Loss: 0.3159
Experiment: single_view_fp, Best Valid Loss: 0.3137, Test Loss: 0.3159
