In [2]:
import random
from typing import List, Tuple, Dict, Set, Callable
import re
import pickle
from pathlib import Path
import hashlib
import time

import pandas as pd
import numpy as np
import wandb
import torch
import torch_geometric
from torch import nn
from torch.nn import functional as F
from torch.functional import Tensor
from torch_geometric.nn import GINConv, global_mean_pool
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.typing import Adj
from pytorch_lightning import seed_everything, Trainer, LightningDataModule, LightningModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, RichModelSummary, RichProgressBar
from rdkit import Chem
from rdkit.Chem import AllChem
from transformers import T5Tokenizer, T5EncoderModel
from torchmetrics import MetricCollection, Accuracy, AUROC, MatthewsCorrCoef, MeanSquaredError, MeanAbsoluteError, ExplainedVariance

print(torch.version.cuda)

12.1


In [3]:
mols = {
    "L01": "C1CCCCC1",
    "L02": "O1CCOCC1",
    "L03": "C1CCCC2C1CCCC2",
    "L04": "C1CCCCC1C2CCCCC2",
    "L05": "c1ccccc1-c2ccccc2",
    "L06": "CC(C)(C)OC(=O)N1CCC[C@H]1C(=O)O",
    "L07": "O=Cc1ccc(O)c(OC)c1",
    "L08": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
    "L09": "CCc(c1)ccc2[n+]1ccc3c2[nH]c4c3cccc4",
    "L10": "CCC[C@@H](O)CC\C=C\C=C\C#CC#C\C=C\CO",
}
aa_seqs = {f"G{i + 1:02d}": ''.join(random.choices("ACDEFGHIKLMNPQRSTVWY", k=40)) for i in range(10)}

p_values_range = (5, 10)
classes = [0, 1]
labels = []
records = {
    "LIG_ID": list(mols.keys()),
    "GENE_ID": list(aa_seqs.keys()),
    "acname": [],
    "acvalue_uM": [],
    "class": [],
}
for _ in range(10):
    records["acvalue_uM"].append(round(random.uniform(*p_values_range), 2))
    records["class"].append(int(records["acvalue_uM"][-1] < 6))
    records["acname"].append("pIC50" if random.choice([True, False]) else "pKi")

data = pd.DataFrame(records)
with open(dummy_path := Path("data") / "dummy.pkl", "wb") as f:
    pickle.dump((data, mols, aa_seqs), f)

In [4]:
class MolEmbed(LightningModule):
    def __init__(self, input_dim: int, hidden_dim: int = 128, output_dim: int = 128, num_layers: int = 3):
        super().__init__()
        self.inp = GINConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.PReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
            )
        )
        mid_layers = [
            GINConv(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.PReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                )
            )
            for _ in range(num_layers - 2)
        ]
        self.mid_layers = nn.ModuleList(mid_layers)
        self.out = GINConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.PReLU(),
                nn.Linear(hidden_dim, output_dim),
                nn.BatchNorm1d(output_dim),
            )
        )

    def forward(self, data, **kwargs) -> Tensor:
        """Forward the data through the GNN module"""
        x = self.inp(data.x, data.edge_index)
        for module in self.mid_layers:
            x = module(x, data.edge_index)
        x = self.out(x, data.edge_index)
        pool = global_mean_pool(x, data["batch"])
        return F.normalize(pool, dim=1)

In [12]:
class CACHE_DTI(LightningModule):
    def __init__(self, batch_size, **kwargs):
        super().__init__()
        self.drug_embedder = MolEmbed(input_dim=9, hidden_dim=128, output_dim=128, num_layers=3)
        self.prot_embedder = nn.Sequential(nn.Linear(768, 128), nn.PReLU())
        self.head = nn.Sequential(
            nn.Linear(256, 64), 
            nn.PReLU(), 
            nn.Dropout(0.2), 
            nn.Linear(64, 3), 
            nn.PReLU(), 
            nn.Dropout(0.2)
        )
        self.reg_criterion = nn.MSELoss()
        self.class_criterion = nn.BCEWithLogitsLoss()
        self.metrics = self._set_metrics()
        self.automatic_optimization = False
        self.batch_size = batch_size

    def to(self, device):
        super().to(device)
        for splits in self.metrics.values():
            for mc in splits.values():
                mc.to(device)

    def _set_metrics(self):
        class_metrics = MetricCollection([
            Accuracy(task="binary"),
            AUROC(task="binary"),
            MatthewsCorrCoef(task="binary"),
        ])
        reg_metrics = MetricCollection([
            MeanAbsoluteError(),
            MeanSquaredError(),
            ExplainedVariance(),
        ])
        return {
            "class": {
                "train": class_metrics.clone("class/train/"),
                "val": class_metrics.clone("class/val/"),
                "test": class_metrics.clone("class/test/"),
            }, 
            "reg": {
                "train": reg_metrics.clone("reg/train/"),
                "val": reg_metrics.clone("reg/val/"),
                "test": reg_metrics.clone("reg/test/"),
            }
        }
     
    def forward(self, data):
        drug_embed = self.drug_embedder(data)
        prot_embed = self.prot_embedder(data["t5"])
        comb_embed = torch.cat((drug_embed, prot_embed), dim=1)
        pred = self.head(comb_embed)
        mask = data["y"][:, :2] == -1
        pred[:, :2][mask] = data["y"][:, :2][mask]
        return {
            "reg_pred": pred[:, :-1],
            "reg_labels": data["y"][:, :-1],
            "class_pred": pred[:, -1],
            "class_labels": data["y"][:, -1].float(),
        }

    def shared_step(self, data):
        fwd_dict = self.forward(data)
        fwd_dict["reg_loss"] = self.reg_criterion(fwd_dict["reg_pred"], fwd_dict["reg_labels"])
        fwd_dict["class_loss"] = self.class_criterion(fwd_dict["class_pred"], fwd_dict["class_labels"])
        return fwd_dict

    def update(self, fwd, stage):
        self.metrics["reg"][stage].update(fwd["reg_pred"].contiguous(), fwd["reg_labels"].contiguous())
        # self.metrics["class"][stage].update(fwd["class_pred"], fwd["class_labels"])
        self.log(f"{stage}/reg/loss", fwd["reg_loss"], batch_size=self.batch_size)
        # self.log(f"{stage}/class/loss", fwd["class_loss"], batch_size=self.batch_size)

    def log_histograms(self):
        """Logs the histograms of all the available parameters."""
        if self.logger:
            for name, param in self.named_parameters():
                self.logger.experiment.add_histogram(name, param, self.current_epoch)

    def training_step(self, data, data_idx):
        fwd = self.shared_step(data)
        self.update(fwd, "train")

        # Do backpropagation
        self.optimizers().zero_grad()
        self.manual_backward(fwd["reg_loss"])  # , retain_graph=True)
        # self.manual_backward(fwd["class_loss"])
        self.optimizers().step()
        
        return fwd

    def validation_step(self, data, data_idx):
        print("validation step")
        fwd = self.shared_step(data)
        self.update(fwd, "val")
        print(self.metrics["reg"]["val"])
        return fwd

    def test_step(self, data, data_idx):
        fwd = self.shared_step(data)
        seld.update(fwd, "test")
        return fwd
        
    def log_all(self, metrics: dict, hparams: bool = False):
        """Log all metrics."""
        print(metrics)
        if self.logger:
            for k, v in metrics.items():
                print("Logging", k, ":", v)
                self.log(k, v, self.current_epoch)

    def shared_end(self, stage):
        for task in ["reg"]:  # , "class"]:
            print("Logging for", task)
            metrics = self.metrics[task][stage].compute()
            self.metrics[task][stage].reset()
            self.log_all(metrics)

    def on_training_epoch_end(self):
        self.log_histograms()
        print("Logging training end")
        self.shared_end("train")

    def on_validation_epoch_end(self):
        print("Logging validation end")
        self.shared_end("val")

    def on_test_epoch_end(self):
        print("Logging testing end")
        self.shared_end("test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

In [13]:
def ProstT5(aa_seqs):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
    prostt5 = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(device)
    prostt5.full() if device == "cpu" else prostt5.half()
    
    seqs = [" ".join(["<fold2AA>"] + list(re.sub("[OUZ]", "X", seq))) for seq in aa_seqs]
    max_len = max(len(s) for s in aa_seqs)
    encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest", return_tensors="pt").to(device)
    with torch.no_grad():
        aaseq_embed = prostt5(encoding.input_ids, attention_mask=encoding.attention_mask)
    return dict(zip(aa_seqs, aaseq_embed.last_hidden_state[:, 1 : max_len + 1].mean(dim=1).cpu()))


def ProtT5(aa_seqs):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_base_mt_uniref50", do_lower_case=False)
    prott5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_base_mt_uniref50").to(device)
    seqs = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in aa_seqs]
    
    ids = tokenizer(aa_seqs, add_special_tokens=True, padding="longest")

    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)
    max_len = max(len(s) for s in aa_seqs)
    
    with torch.no_grad():
        embedding_repr = prott5(input_ids=input_ids, attention_mask=attention_mask)
    return dict(zip(aa_seqs, embedding_repr.last_hidden_state[:, 1 : max_len + 1].mean(dim=1).cpu().numpy()))

In [14]:
def prepare(root, filename):
    with open(filename, "rb") as f:
        inter, lig_map, tar_map = pickle.load(f)
    if (t5_path := Path(root) / f"prostt5_{filename.stem}.pkl").exists():
        with open(t5_path, "rb") as f:
            t5 = pickle.load(f)
    else:
        t5 = ProtT5(list(tar_map.values()))
        with open(t5_path, "wb") as f:
            pickle.dump(t5, f)
    lig_data = {k: torch_geometric.utils.from_smiles(v) for k, v in lig_map.items()}
    return inter, lig_data, tar_map, t5


def row2data(row, lig_data, tar_map, t5):
    d = lig_data[row["LIG_ID"]].clone()
    d["t5"] = torch.tensor(t5[tar_map[row["GENE_ID"]]]).reshape(1, -1)
    if row["acname"] == "pKi":
        d["y"] = [row["acvalue_uM"], -1, row["class"]]
    elif row["acname"] == "pIC50":
        d["y"] = [-1, row["acvalue_uM"], row["class"]]
    d["y"] = torch.tensor(d["y"]).reshape(1, -1)
    return d


class BaseDataset(InMemoryDataset):
    def __init__(self, filename, path_idx: int = 0):
        self.filename = Path(filename)
        super().__init__(Path("data"))
        self.data, self.slices = torch.load(self.processed_paths[path_idx])

    @property
    def processed_paths(self):
        return [self.root / f for f in self.processed_file_names]

    def process_(self, data, path_idx: int = 0):
        data, slices = self.collate(data)
        torch.save((data, slices), self.processed_paths[path_idx])


class PretrainDataset(BaseDataset):
    def __init__(self, filename):
        super().__init__(filename)
    
    @property
    def processed_file_names(self):
        return [self.filename.stem + ".pt"]

    def process(self):
        inter, lig_data, tar_map, t5 = prepare(self.root, self.filename)
        data = []
        for _, row in inter.iterrows():
            try:
                data.append(row2data(row, lig_data, tar_map, t5))
            except Exception as e:
                print(e)
        self.process_(data)


class MCHR1Dataset(BaseDataset):
    splits = {"train": 0, "val": 1, "ensemble": 2, "test": 3}
    
    def __init__(self, filename, split):
        super().__init__(filename, self.splits[split])

    @property
    def processed_file_names(self):
        return [k + ".pt" for k in ["train", "val", "ensemble", "test"]]
    
    def process(self):
        inter, lig_data, tar_map, t5 = prepare(self.root, self.filename)
        data = {k: [] for k in self.splits.keys()}
        for _, row in inter.iterrows():
            try:
                data[row["split"]].append(row2data(row, lig_data, tar_map, t5))
            except Exception as e:
                print(e)
        for split in self.splits.keys():
            self.process_(data[split], self.splits[split])

In [15]:
class BaseDataModule(LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=min(self.batch_size, len(self.train)), shuffle=True)

    def val_dataloader(self):
        dl = DataLoader(self.val, batch_size=min(self.batch_size, len(self.val)), shuffle=False)
        print("Validation dataloader created:", dl)
        return dl

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=min(self.batch_size, len(self.test)), shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.ensemble, batch_size=min(self.batch_size, len(self.ensemble)), shuffle=False)


class PretrainDataModule(BaseDataModule):
    def __init__(self, filename, batch_size):
        super().__init__(batch_size)
        ds = PretrainDataset(filename)
        self.train, self.val, self.test = torch.utils.data.dataset.random_split(ds, [0.4, 0.3, 0.3])


class MCHR1DataModule(BaseDataModule):
    def __init__(self, filename, batch_size):
        super().__init__(batch_size)

        self.train = MCHR1Dataset(filename, "train")
        self.val = MCHR1Dataset(filename, "val")
        self.ensemble = MCHR1Dataset(filename, "ensemble")
        self.test = MCHR1Dataset(filename, "test")

In [16]:
def fold2split(fold_name, fold_id):
    fold_nr = int(fold_name[-1])
    if fold_nr == 8:
        return "ensemble"
    if fold_nr == 9:
        return "test"
    if fold_nr == (fold_id + 6) % 7:
        return "val"
    return "train"


def prep_mchr1(output_path, fold_id):
    df = pd.read_csv(Path("..") / "Reference Data" / "20240430_MCHR1_splitted_RJ.csv")
    
    with open(Path("..") / "Protein Structures" / "Q99705.fasta") as f:
        tar_map = {"P1": "".join(l.strip() for l in f.readlines()[1:])}
    lig_map = dict(df[["ID", "smiles"]].values)
    
    df["GENE_ID"] = "P1"
    df.rename(columns={"ID": "LIG_ID"}, inplace=True)
    df["split"] = df["DataSAIL_10f"].apply(lambda x: fold2split(x, fold_id))
    inter = df[["LIG_ID", "GENE_ID", "acvalue_uM", "acname", "class", "split"]]
    
    with open(output_path, "wb") as f:
        pickle.dump((inter, lig_map, tar_map), f)

In [17]:
def time2hash(npos):
    # Get the current timestamp
    current_timestamp = str(time.time())
    
    # Create a SHA-256 hash of the current timestamp
    hash_object = hashlib.sha256(current_timestamp.encode())
    hash_hex = hash_object.hexdigest()
    
    # Extract the first 4 digits of the hash
    return hash_hex[:npos]

In [18]:
def pretrain(batch_size: int = 2):
    base_path = Path("..") / "Pretrain_Data"
    # bdb_path = base_path / "bdb_muts.pkl"
    # glass_path = base_path / "glass.pkl"
    bdb_path = Path("data") / "dummy.pkl"
    glass_path = Path("data") / "dummy.pkl"

    bdb = PretrainDataModule(bdb_path, batch_size)
    glass = PretrainDataModule(glass_path, batch_size)
    run_name = None

    for pretrain_module, name in [
        (bdb, "BindingDB"), 
        # (glass, "GLASS"), 
    ]:

        if run_name is None:
            model = CACHE_DTI(batch_size)
        else:
            # load model
            run = wandb.init(project="CACHE5")
            artifact = run.use_artifact(f"rindti/CACHE5/model-{run_name}:best", type="model")
            artifact_dir = artifact.download()
            model = CACHE_DTI.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
            wandb.finish()
        
        logger = WandbLogger(
            log_model="all",
            project="CACHE5",
            # entity="old-shatterhand",
            name=name.lower() + "_" + time2hash(4),
        )
        run_name = logger.experiment.path.split("/")[-1]

        trainer = Trainer(
            callbacks=[
                ModelCheckpoint(save_last=True, mode="min", monitor="val/reg/loss", save_top_k=1),
                # RichModelSummary(),
                # RichProgressBar(),
            ],
            logger=logger,
            log_every_n_steps=1,
            enable_model_summary=False,
            # gpus=1,
            max_epochs=10,
            # batch_size=batch_size,
        )
            
        trainer.fit(model, pretrain_module)
        print("Pretraining on", name, "finished")
        run_name = logger.experiment.path
    return run_name


def train(ckpt_path, batch_size):
    mchr1_path = Path("data") / "mchr1.pkl"
    prep_mchr1(mchr1_path, fold_id=0)
    mchr1 = MCHR1DataModule(mchr1_path, batch_size)


bs = 1
ckpt = pretrain(bs)
# train(ckpt, bs)

/home/rjo21/miniconda3/envs/cache_train/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |                                                                                            …

Validation dataloader created: <torch_geometric.loader.dataloader.DataLoader object at 0x7f107a1d8670>
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
Logging validation end
Logging for reg
{'reg/val/MeanAbsoluteError': tensor(3.5759, device='cuda:0'), 'reg/val/MeanSquaredError': tensor(27.0711, device='cuda:0'), 'reg/val/ExplainedVariance': tensor(0.5000, device='cuda:0')}
Logging reg/val/MeanAbsoluteError : tensor(3.5759, device='cuda:0')
Logging reg/val/MeanSquaredError : tensor(27.0711, device='cuda:0')
Logging reg/val/ExplainedVariance : tensor(0.5000, device='cuda:0')


/home/rjo21/miniconda3/envs/cache_train/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
Logging validation end
Logging for reg
{'reg/val/MeanAbsoluteError': tensor(3.4954, device='cuda:0'), 'reg/val/MeanSquaredError': tensor(25.4450, device='cuda:0'), 'reg/val/ExplainedVariance': tensor(0.5004, device='cuda:0')}
Logging reg/val/MeanAbsoluteError : tensor(3.4954, device='cuda:0')
Logging reg/val/MeanSquaredError : tensor(25.4450, device='cuda:0')
Logging reg/val/ExplainedVariance : tensor(0.5004, device='cuda:0')


Validation: |                                                                                                 …

validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
validation step
MetricCollection(
  (MeanAbsoluteError): MeanAbsoluteError()
  (MeanSquaredError): MeanSquaredError()
  (ExplainedVariance): ExplainedVariance(),
  prefix=reg/val/
)
Logging validation end
Logging for reg
{'reg/val/MeanAbsoluteError': tensor(3.3941, device='cuda:0'), 'reg/val/MeanSquaredError': tensor(24.0475, device='cuda:0'), 'reg/val/ExplainedVariance': tensor(0.5012, device='cuda:0')}
Logging reg/val/MeanAbsoluteError : tensor(3.3941, device='cuda:0')


MisconfigurationException: You called `self.log(reg/val/MeanAbsoluteError, ...)` twice in `on_validation_epoch_end` with different arguments. This is not allowed