<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/COMPLEX_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab only: select a GPU runtime before running this notebook.
# Mount Google Drive under /content/gdrive/ so files stored there can be accessed.
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [None]:
!pip install gensim --quiet
!pip install  dgl -f https://data.dgl.ai/wheels/cu118/repo.html --quiet
!pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html --quiet
!pip install omegaconf --quiet
!pip install rdkit --quiet

# !pip install rdkit==2022.3.3 --quiet

!pip install pubchempy --quiet

!pip install torchmetrics --quiet
!pip install wandb -qU
!pip install transformers --quiet

!pip install pytorch_lightning --quiet
# https://tdcommons.ai/start/
!pip install PyTDC --quiet

In [None]:
# Colab form toggle: when True, this cell kills the current kernel process.
Restart = True #@param {type:"boolean"}
import os
if Restart:
    # Signal 9 terminates this process immediately; presumably this forces a fresh
    # runtime that picks up the packages installed above -- TODO confirm.
    os.kill(os.getpid(), 9)

In [1]:
cd /content/gdrive/MyDrive/A_JAK_design

/content/gdrive/MyDrive/A_JAK_design


In [2]:
from gensim.models import word2vec
from help_function.package_version_check import *
from help_function.function import *
# main()
!python --version

torch version:  2.0.1+cu118
cuda available:  True
rdkit version:  2023.03.2
matplotlib version:  3.7.1
dgl version:  1.1.1+cu118
did not install dgllife, required by MTATFP
gensim version:  4.3.1
Python 3.10.6


In [None]:
main()

In [3]:
import torchmetrics
print(torchmetrics.__version__)
import wandb
wandb.login()

1.0.1


Play with wandb

In [4]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from __future__ import annotations

import typing as T

from functools import lru_cache
from pathlib import Path

import h5py
import numpy as np
import torch
from tqdm import tqdm

from torch.utils.data import DataLoader
# from src.featurizers.molecule import MorganFeaturizer

# import MorganFeaturizer, ProtBertFeaturizer
# from model.architectures import SimpleCoembeddingNoSigmoid

class Featurizer:
    """
    Base class that maps a sequence string to a feature tensor and caches it.

    Subclasses implement :meth:`_transform`. Computed features are stored in an
    in-memory dict keyed by the raw sequence string, and can be written to /
    preloaded from an HDF5 file named ``{name}_features.h5`` inside ``save_dir``.
    """

    def __init__(
        self, name: str, shape: int, save_dir: Path = Path().absolute()
    ):
        """
        :param name: Featurizer name; also used to build the HDF5 file name
        :type name: str
        :param shape: Size of the feature vector produced for one sequence
        :type shape: int
        :param save_dir: Directory holding the HDF5 feature file. The default is
            evaluated once, when this class is defined.
        :type save_dir: Path
        """
        self._name = name
        self._shape = shape
        self._save_path = save_dir / Path(f"{self._name}_features.h5")

        self._preloaded = False
        self._device = torch.device("cpu")
        self._cuda_registry = {}
        self._on_cuda = False
        self._features = {}

    def __call__(self, seq: str) -> torch.Tensor:
        """Return the features for ``seq``, computing and caching them on first use."""
        if seq not in self.features:
            self._features[seq] = self.transform(seq)

        return self._features[seq]

    def _register_cuda(self, k: str, v, f=None):
        """
        Register an object as capable of being moved to a CUDA device

        :param k: Registry key
        :param v: Object to move when the device changes
        :param f: Optional callable ``f(v, device)`` used instead of ``v.to(device)``
        """
        self._cuda_registry[k] = (v, f)

    def _transform(self, seq: str) -> torch.Tensor:
        """Compute the features for one sequence; must be provided by subclasses."""
        raise NotImplementedError

    def _update_device(self, device: torch.device):
        """Move every registered object and every cached feature to ``device``."""
        self._device = device
        for k, (v, f) in self._cuda_registry.items():
            if f is None:
                try:
                    self._cuda_registry[k] = (v.to(self._device), None)
                except RuntimeError as e:
                    # Best effort: log the failure and keep the object where it is.
                    logg.error(e)
                    logg.debug(device)
                    logg.debug(type(self._device))
                    logg.debug(self._device)
            else:
                self._cuda_registry[k] = (f(v, self._device), f)
        for k, v in self._features.items():
            self._features[k] = v.to(device)

    # The cache key includes ``self``, so cached entries keep this instance alive.
    @lru_cache(maxsize=5000)
    def transform(self, seq: str) -> torch.Tensor:
        """Run :meth:`_transform` with gradients disabled and place the result on the active device."""
        with torch.set_grad_enabled(False):
            feats = self._transform(seq)
            if self._on_cuda:
                feats = feats.to(self.device)
            return feats

    @property
    def name(self) -> str:
        """Featurizer name."""
        return self._name

    @property
    def shape(self) -> int:
        """Size of one feature vector."""
        return self._shape

    @property
    def path(self) -> Path:
        """Path of the HDF5 feature file."""
        return self._save_path

    @property
    def features(self) -> dict:
        """In-memory cache mapping raw sequence -> feature tensor."""
        return self._features

    @property
    def on_cuda(self) -> bool:
        """Whether the featurizer was last moved to a CUDA device."""
        return self._on_cuda

    @property
    def device(self) -> torch.device:
        """Device that cached features are kept on."""
        return self._device

    def to(self, device: torch.device) -> Featurizer:
        """
        Move registered objects and cached features to ``device``.

        :param device: Target device; a ``torch.device`` or anything
            ``torch.device()`` accepts (e.g. the string ``"cpu"``)
        :return: ``self``
        """
        # Normalise first so that ``device.type`` below also works for strings.
        device = torch.device(device)
        self._update_device(device)
        self._on_cuda = device.type == "cuda"
        return self

    def cuda(self, device: torch.device) -> Featurizer:
        """
        Perform model computations on CUDA, move saved embeddings to CUDA device
        """
        self._update_device(device)
        self._on_cuda = True
        return self

    def cpu(self) -> Featurizer:
        """
        Perform model computations on CPU, move saved embeddings to CPU
        """
        self._update_device(torch.device("cpu"))
        self._on_cuda = False
        return self

    def write_to_disk(
        self, seq_list: T.List[str], verbose: bool = True
    ) -> None:
        """
        Compute features for every sequence and store them in the HDF5 file.

        Datasets are keyed by ``sanitize_string(seq)`` (helper defined outside this
        class). Existing entries are logged and then overwritten.

        :param seq_list: Sequences to featurize
        :param verbose: Show a progress bar when True
        """
        logg.info(f"Writing {self.name} features to {self.path}")
        with h5py.File(self._save_path, "a") as h5fi:
            for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                seq_h5 = sanitize_string(seq)
                if seq_h5 in h5fi:
                    logg.warning(f"{seq} already in h5file")
                feats = self.transform(seq)
                dset = h5fi.require_dataset(seq_h5, feats.shape, np.float32)
                dset[:] = feats.cpu().numpy()

    def preload(
        self,
        seq_list: T.List[str],
        verbose: bool = True,
        write_first: bool = True,
    ) -> None:
        """
        Fill the in-memory cache for every sequence in ``seq_list``.

        Features are read from the HDF5 file when present there, otherwise they
        are computed with :meth:`transform`.

        :param seq_list: Sequences to load
        :param verbose: Show a progress bar when True
        :param write_first: When True and the HDF5 file does not exist yet, create
            it with :meth:`write_to_disk` before loading
        """
        logg.info(f"Preloading {self.name} features from {self.path}")

        if write_first and not self._save_path.exists():
            self.write_to_disk(seq_list, verbose=verbose)

        if self._save_path.exists():
            with h5py.File(self._save_path, "r") as h5fi:
                for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                    # Look up the same sanitized key that write_to_disk stored
                    # under; checking the raw sequence misses cached entries.
                    seq_h5 = sanitize_string(seq)
                    if seq_h5 in h5fi:
                        feats = torch.from_numpy(h5fi[seq_h5][:])
                    else:
                        feats = self.transform(seq)

                    if self._on_cuda:
                        feats = feats.to(self.device)

                    self._features[seq] = feats

        else:
            for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                feats = self.transform(seq)

                if self._on_cuda:
                    feats = feats.to(self.device)

                self._features[seq] = feats

        self._update_device(self.device)
        self._preloaded = True

class MorganFeaturizer(Featurizer):
    """Featurizer that turns a SMILES string into a Morgan fingerprint bit vector."""

    def __init__(
        self,
        shape: int = 2048,
        radius: int = 2,
        save_dir: Path = Path().absolute(),
    ):
        """
        :param shape: Number of fingerprint bits
        :type shape: int
        :param radius: Morgan fingerprint radius
        :type radius: int
        :param save_dir: Directory holding the HDF5 feature file
        :type save_dir: Path
        """
        super().__init__("Morgan", shape, save_dir)

        self._radius = radius

    def smiles_to_morgan(self, smile: str):
        """
        Convert smiles into Morgan Fingerprint.

        Any failure (e.g. a SMILES string RDKit cannot parse) is logged and an
        all-zero vector of length ``self.shape`` is returned instead.

        :param smile: SMILES string
        :type smile: str
        :return: Morgan fingerprint
        :rtype: np.ndarray
        """
        try:
            # Keep ``smile`` untouched: canonicalize() returns None for invalid
            # input, and the error message below must show the offending string.
            canonical = canonicalize(smile)
            mol = Chem.MolFromSmiles(canonical)
            features_vec = AllChem.GetMorganFingerprintAsBitVect(
                mol, self._radius, nBits=self.shape
            )
            # Destination buffer; presumably resized by RDKit to the fingerprint
            # length -- _transform re-checks the final size either way.
            features = np.zeros((1,))
            DataStructs.ConvertToNumpyArray(features_vec, features)
        except Exception as e:
            logg.error(
                f"rdkit not found this smiles for morgan: {smile} convert to all 0 features"
            )
            logg.error(e)
            features = np.zeros((self.shape,))
        return features

    def _transform(self, smile: str) -> torch.Tensor:
        """Return the fingerprint as a float tensor, or zeros if its length is not ``self.shape``."""
        feats = (
            torch.from_numpy(self.smiles_to_morgan(smile)).squeeze().float()
        )
        if feats.shape[0] != self.shape:
            logg.warning("Failed to featurize: appending zero vector")
            feats = torch.zeros(self.shape)
        return feats


class SimpleCoembedding(nn.Module):
    """
    Project drug and target features into one latent space and score each pair.

    In classification mode the score is the configured latent distance metric;
    in regression mode it is the ReLU of the per-pair inner product.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
    ):
        """
        :param drug_shape: Input feature size of the drug projector
        :param target_shape: Input feature size of the target projector
        :param latent_dimension: Size of the shared latent space
        :param latent_activation: Activation class applied after each projection
        :param latent_distance: Key into ``DISTANCE_METRICS`` (classification only)
        :param classify: Use the distance metric (True) or the inner product (False)
        """
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Drug branch: linear layer (Xavier-normal weights) followed by the activation.
        drug_linear = nn.Linear(self.drug_shape, latent_dimension)
        self.drug_projector = nn.Sequential(drug_linear, latent_activation())
        nn.init.xavier_normal_(drug_linear.weight)

        # Target branch: same layout as the drug branch.
        target_linear = nn.Linear(self.target_shape, latent_dimension)
        self.target_projector = nn.Sequential(target_linear, latent_activation())
        nn.init.xavier_normal_(target_linear.weight)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def _project_pair(self, drug, target):
        """Return the latent projections of ``drug`` and ``target``, in that order."""
        return self.drug_projector(drug), self.target_projector(target)

    def forward(self, drug, target):
        """Dispatch to :meth:`classify` or :meth:`regress` depending on the mode."""
        return (
            self.classify(drug, target)
            if self.do_classify
            else self.regress(drug, target)
        )

    def regress(self, drug, target):
        """Return the ReLU of the inner product of each projected drug/target pair."""
        drug_latent, target_latent = self._project_pair(drug, target)

        # Batched (1 x d) @ (d x 1) product gives one scalar per pair.
        as_rows = drug_latent.view(-1, 1, self.latent_dimension)
        as_cols = target_latent.view(-1, self.latent_dimension, 1)
        pair_dot = torch.bmm(as_rows, as_cols).squeeze()
        return torch.relu(pair_dot).squeeze()

    def classify(self, drug, target):
        """Return the configured distance metric between the projected pairs."""
        drug_latent, target_latent = self._project_pair(drug, target)
        return self.activator(drug_latent, target_latent).squeeze()

class SimpleCoembeddingSigmoid(nn.Module):
    """
    Co-embedding model whose classification score is passed through a sigmoid.

    Drug and target features are projected into a shared latent space. The
    classification output is ``sigmoid(distance_metric(drug, target))``; the
    regression output is the ReLU of the per-pair inner product.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
    ):
        """
        :param drug_shape: Input feature size of the drug projector
        :param target_shape: Input feature size of the target projector
        :param latent_dimension: Size of the shared latent space
        :param latent_activation: Activation class applied after each projection
        :param latent_distance: Key into ``DISTANCE_METRICS`` (classification only)
        :param classify: Use the distance metric (True) or the inner product (False)
        """
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Each branch is a linear layer with Xavier-normal weights plus an activation.
        drug_layer = nn.Linear(self.drug_shape, latent_dimension)
        self.drug_projector = nn.Sequential(drug_layer, latent_activation())
        nn.init.xavier_normal_(drug_layer.weight)

        target_layer = nn.Linear(self.target_shape, latent_dimension)
        self.target_projector = nn.Sequential(target_layer, latent_activation())
        nn.init.xavier_normal_(target_layer.weight)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def forward(self, drug, target):
        """Run :meth:`classify` in classification mode, :meth:`regress` otherwise."""
        scorer = self.classify if self.do_classify else self.regress
        return scorer(drug, target)

    def regress(self, drug, target):
        """Return the ReLU of the inner product of each projected drug/target pair."""
        latent_drug = self.drug_projector(drug)
        latent_target = self.target_projector(target)

        # (batch, 1, d) @ (batch, d, 1) -> one scalar per pair.
        products = torch.bmm(
            latent_drug.view(-1, 1, self.latent_dimension),
            latent_target.view(-1, self.latent_dimension, 1),
        )
        return torch.relu(products.squeeze()).squeeze()

    def classify(self, drug, target):
        """Return the sigmoid of the configured distance metric for each pair."""
        latent_drug = self.drug_projector(drug)
        latent_target = self.target_projector(target)

        raw_score = self.activator(latent_drug, latent_target)
        return torch.sigmoid(raw_score).squeeze()


In [5]:

import typing as T

import logging as lg
import multiprocessing as mp
import sys
from functools import partial
from pathlib import Path

import h5py
import numpy as np
import torch
from omegaconf import OmegaConf
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from tqdm import tqdm

# Verbosity index -> logging level used by config_logger.
logLevels = {0: lg.ERROR, 1: lg.WARNING, 2: lg.INFO, 3: lg.DEBUG}
# Name of the logger shared by the cells of this notebook.
LOGGER_NAME = "DTI"


def get_logger(logger_name: str = None) -> lg.Logger:
    """
    Return the logger called ``logger_name``.

    :param logger_name: Logger name; ``LOGGER_NAME`` is used when None
    :type logger_name: str
    :return: The requested logger
    :rtype: lg.Logger
    """
    if logger_name is None:
        logger_name = LOGGER_NAME
    return lg.getLogger(logger_name)


logg = get_logger()


def config_logger(
    file: T.Union[Path, None],
    fmt: str,
    level: int = 2,
    use_stdout: bool = True,):
    """
    Create and configure the logger

    Every call adds new handlers, so calling this more than once duplicates output.

    :param file: Can be a Path or None -- if a Path, log messages will be written to the file at Path
    :type file: T.Union[Path, None]
    :param fmt: Formatting string for the log messages
    :type fmt: str
    :param level: Level of verbosity, a key of ``logLevels`` (0-3)
    :type level: int
    :param use_stdout: Whether to also log messages to stdout
    :type use_stdout: bool
    :return: The configured logger
    :raises KeyError: If ``level`` is not a key of ``logLevels``
    """

    module_logger = lg.getLogger(LOGGER_NAME)
    module_logger.setLevel(logLevels[level])
    formatter = lg.Formatter(fmt)

    if file is not None:
        fh = lg.FileHandler(file)
        fh.setFormatter(formatter)
        module_logger.addHandler(fh)

    if use_stdout:
        sh = lg.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        module_logger.addHandler(sh)

    # Stop records from also reaching the root logger's handlers. This must be
    # set on the logger itself; an attribute on the logging module has no effect.
    module_logger.propagate = False

    return module_logger


def set_random_seed(seed):
    """Seed the global PyTorch and NumPy random number generators with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)


def canonicalize(smiles):
    """
    Round-trip a SMILES string through RDKit.

    :param smiles: SMILES string
    :return: ``Chem.MolToSmiles`` output (isomeric), or None if RDKit returns no molecule
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=True)


def smiles2morgan(s, radius=2, nBits=2048):
    """
    Convert smiles into Morgan Fingerprint.

    Any failure (e.g. a SMILES string RDKit cannot parse) is logged and an
    all-zero vector of length ``nBits`` is returned instead.

    :param s: SMILES string
    :type s: str
    :param radius: Morgan fingerprint radius
    :type radius: int
    :param nBits: Number of fingerprint bits
    :type nBits: int
    :return: Morgan fingerprint
    :rtype: np.ndarray
    """
    try:
        # Keep ``s`` untouched: canonicalize() returns None for invalid input,
        # and the error message below must show the offending string.
        canonical = canonicalize(s)
        mol = Chem.MolFromSmiles(canonical)
        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        # Destination buffer; presumably resized by RDKit to nBits -- TODO confirm.
        features = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(features_vec, features)
    except Exception as e:
        logg.error(e)
        logg.error(
            f"Failed to convert SMILES to Morgan Fingerprint: {s} convert to all 0 features"
        )
        features = np.zeros((nBits,))
    return features


def get_config(experiment_id, mol_feat, prot_feat):
    """
    Build the experiment configuration as an OmegaConf object.

    :param experiment_id: Stored under ``experiment_id``
    :param mol_feat: Stored under ``data.mol_feat``
    :param prot_feat: Stored under ``data.prot_feat``
    :return: Result of ``OmegaConf.structured`` with ``data``, ``model``,
        ``training`` and ``experiment_id`` entries
    """
    cfg = dict(
        data=dict(
            batch_size=32,
            num_workers=0,
            precompute=True,
            mol_feat=mol_feat,
            prot_feat=prot_feat,
        ),
        model=dict(latent_size=1024),
        training=dict(n_epochs=50, every_n_val=1),
        experiment_id=experiment_id,
    )
    return OmegaConf.structured(cfg)


def _hdf5_load_partial_func(k, file_path):
    """
    Helper function for load_hdf5_parallel: read dataset ``k`` from the HDF5
    file at ``file_path`` and return it as a tensor.
    """
    with h5py.File(file_path, "r") as h5_handle:
        return torch.from_numpy(h5_handle[k][:])


def load_hdf5_parallel(file_path, keys, n_jobs=-1):
    """
    Load keys from hdf5 file into memory using a pool of worker processes

    :param file_path: Path to hdf5 file
    :type file_path: str
    :param keys: List of keys to get
    :type keys: list[str]
    :param n_jobs: Number of worker processes; -1 or anything above the CPU
        count falls back to the CPU count
    :type n_jobs: int
    :return: Dictionary with keys and records in memory
    :rtype: dict
    """
    torch.multiprocessing.set_sharing_strategy("file_system")

    cpu_total = mp.cpu_count()
    if n_jobs == -1 or n_jobs > cpu_total:
        n_jobs = cpu_total

    load_one = partial(_hdf5_load_partial_func, file_path=file_path)
    with mp.Pool(processes=n_jobs) as pool:
        # imap keeps results in the order of ``keys``; tqdm only reports progress.
        progress = tqdm(pool.imap(load_one, keys), total=len(keys))
        all_embs = list(progress)

    return dict(zip(keys, all_embs))


class Cosine(nn.Module):
    """Cosine similarity between matching rows of ``x1`` and ``x2``."""

    def forward(self, x1, x2):
        # Same defaults as nn.CosineSimilarity(): dim=1, eps=1e-8.
        return nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)

class SquaredCosine(nn.Module):
    """Squared cosine similarity between matching rows of ``x1`` and ``x2``."""

    def forward(self, x1, x2):
        # Same defaults as nn.CosineSimilarity(): dim=1, eps=1e-8.
        similarity = nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
        return similarity ** 2

class Euclidean(nn.Module):
    """Euclidean (p=2) distance between ``x1`` and ``x2`` computed with ``torch.cdist``."""

    def forward(self, x1, x2):
        # NOTE(review): torch.cdist appears to return distances between every row
        # of x1 and every row of x2 (a matrix), whereas Cosine yields one value
        # per matching row pair -- confirm this is what callers expect.
        return torch.cdist(x1, x2, p=2.0)


class SquaredEuclidean(nn.Module):
    """Squared Euclidean (p=2) distance between ``x1`` and ``x2`` computed with ``torch.cdist``."""

    def forward(self, x1, x2):
        # NOTE(review): torch.cdist appears to return distances between every row
        # of x1 and every row of x2 (a matrix), whereas SquaredCosine yields one
        # value per matching row pair -- confirm this is what callers expect.
        return torch.cdist(x1, x2, p=2.0) ** 2


#######################
# Model Architectures #
#######################


# Maps a metric name to the nn.Module class implementing it; the co-embedding
# models look up their ``latent_distance`` argument here and instantiate the class.
DISTANCE_METRICS = {
    "Cosine": Cosine,
    "SquaredCosine": SquaredCosine,
    "Euclidean": Euclidean,
    "SquaredEuclidean": SquaredEuclidean,
}

# Maps an activation name to its torch.nn class (not referenced in the visible cells).
ACTIVATIONS = {"ReLU": nn.ReLU, "GELU": nn.GELU, "ELU": nn.ELU, "Sigmoid": nn.Sigmoid}

In [6]:
# import typing as T
# from types import SimpleNamespace

# import os
# import pickle as pk
# import sys
# from functools import lru_cache
# from pathlib import Path

# import pandas as pd
import pytorch_lightning as pl
import torch
from numpy.random import choice
from sklearn.model_selection import KFold, train_test_split
from tdc.benchmark_group import dti_dg_group
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from gensim.models import word2vec

# from ..featurizer import Featurizer
# from ..featurizer.protein import FOLDSEEK_MISSING_IDX
# from ..utils import get_logger

logg = get_logger()


def get_task_dir(task_name: str, database_root: Path):
    """
    Get the path to data for each benchmark data set

    :param task_name: Name of benchmark (matched case-insensitively)
    :type task_name: str
    :param database_root: Directory containing all benchmark folders
    :type database_root: Path
    :return: Resolved path of the benchmark's data directory
    :rtype: Path
    :raises KeyError: If ``task_name`` is not a known benchmark
    """

    root = Path(database_root).resolve()

    # Benchmark name -> location relative to the database root.
    relative_dirs = {
        "biosnap": "BIOSNAP/full_data",
        "biosnap_prot": "BIOSNAP/unseen_protein",
        "biosnap_mol": "BIOSNAP/unseen_drug",
        "bindingdb": "BindingDB",
        "davis": "DAVIS",
        "dti_dg": "TDC",
        "dude": "DUDe",
        "halogenase": "EnzPred/halogenase_NaCl_binary",
        "bkace": "EnzPred/duf_binary",
        "gt": "EnzPred/gt_acceptors_achiral_binary",
        "esterase": "EnzPred/esterase_binary",
        "kinase": "EnzPred/davis_filtered",
        "phosphatase": "EnzPred/phosphatase_chiral_binary",
    }

    task_dir = root / relative_dirs[task_name.lower()]
    return Path(task_dir).resolve()


def drug_target_collate_fn(args: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
    """
    Collate function for PyTorch data loader -- turn a batch of triplets into a triplet of batches

    Drug embeddings and labels are stacked. Target embeddings are padded to a
    common length with ``FOLDSEEK_MISSING_IDX``, to account for differences in
    length from FoldSeek embeddings.

    NOTE(review): ``FOLDSEEK_MISSING_IDX`` is not defined in the visible cells
    (its import is commented out) -- confirm it exists before this is called.

    :param args: Batch of training samples with molecule, protein, and affinity
    :type args: Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    :return: Create a batch of examples
    :rtype: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    """
    drug_feats, target_feats, label_values = [], [], []
    for sample in args:
        drug_feats.append(sample[0])
        target_feats.append(sample[1])
        label_values.append(sample[2])

    return (
        torch.stack(drug_feats, 0),
        pad_sequence(
            target_feats, batch_first=True, padding_value=FOLDSEEK_MISSING_IDX
        ),
        torch.stack(label_values, 0),
    )


def contrastive_collate_fn(args: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
    """
    Collate function for PyTorch data loader -- turn a batch of triplets into a triplet of batches

    Specific collate function for contrastive dataloader: anchors are padded to
    a common length with ``FOLDSEEK_MISSING_IDX``; positives and negatives are
    stacked.

    NOTE(review): ``FOLDSEEK_MISSING_IDX`` is not defined in the visible cells
    (its import is commented out) -- confirm it exists before this is called.

    :param args: Batch of training samples with anchor, positive, negative
    :type args: Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    :return: Create a batch of examples
    :rtype: T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    """
    anchor_feats, positive_feats, negative_feats = [], [], []
    for sample in args:
        anchor_feats.append(sample[0])
        positive_feats.append(sample[1])
        negative_feats.append(sample[2])

    return (
        pad_sequence(
            anchor_feats, batch_first=True, padding_value=FOLDSEEK_MISSING_IDX
        ),
        torch.stack(positive_feats, 0),
        torch.stack(negative_feats, 0),
    )


def make_contrastive(
    df: pd.DataFrame,
    posneg_column: str,
    anchor_column: str,
    label_column: str,
    n_neg_per: int = 50,):
    """
    Build (anchor, positive, negative) triplets from a labelled pair table.

    For every row whose label is 1, ``n_neg_per`` triplets are emitted; each one
    pairs the row's anchor and positive with a negative drawn at random from the
    ``posneg_column`` values of rows whose label is 0.

    :param df: Table of labelled pairs
    :type df: pd.DataFrame
    :param posneg_column: Column supplying positives and negatives
    :type posneg_column: str
    :param anchor_column: Column supplying anchors
    :type anchor_column: str
    :param label_column: Column holding the 0/1 label
    :type label_column: str
    :param n_neg_per: Number of sampled negatives per positive row
    :type n_neg_per: int
    :return: DataFrame with columns ``Anchor``, ``Positive``, ``Negative``
    :rtype: pd.DataFrame
    """
    labels = df[label_column]
    positive_rows = df[labels == 1]
    negative_pool = df[labels == 0][posneg_column]

    triplets = [
        (row[anchor_column], row[posneg_column], choice(negative_pool))
        for _, row in positive_rows.iterrows()
        for _ in range(n_neg_per)
    ]

    return pd.DataFrame(triplets, columns=["Anchor", "Positive", "Negative"])


class BinaryDataset(Dataset):
    """
    Dataset of (drug features, target features, label) triples.

    ``drugs``, ``targets`` and ``labels`` are indexed positionally with ``.iloc``;
    the featurizers are called on the raw drug / target values on access.
    """

    def __init__(
        self,
        drugs,
        targets,
        labels,
        drug_featurizer: Featurizer,
        target_featurizer: Featurizer,
    ):
        self.drugs, self.targets, self.labels = drugs, targets, labels
        self.drug_featurizer = drug_featurizer
        self.target_featurizer = target_featurizer

    def __len__(self):
        """Number of samples, taken from ``drugs``."""
        return len(self.drugs)

    def __getitem__(self, i: int):
        """Return the featurized drug, featurized target and label tensor at position ``i``."""
        drug_feats = self.drug_featurizer(self.drugs.iloc[i])
        target_feats = self.target_featurizer(self.targets.iloc[i])
        return drug_feats, target_feats, torch.tensor(self.labels.iloc[i])


class ContrastiveDataset(Dataset):
    """
    Dataset of (anchor, positive, negative) feature triples.

    Anchors go through ``anchor_featurizer``; positives and negatives both go
    through ``posneg_featurizer``. The containers are indexed with ``[i]``.
    """

    def __init__(
        self,
        anchors,
        positives,
        negatives,
        posneg_featurizer: Featurizer,
        anchor_featurizer: Featurizer,
    ):
        self.anchors, self.positives, self.negatives = anchors, positives, negatives
        self.posneg_featurizer = posneg_featurizer
        self.anchor_featurizer = anchor_featurizer

    def __len__(self):
        """Number of triplets, taken from ``anchors``."""
        return len(self.anchors)

    def __getitem__(self, i):
        """Return the featurized anchor, positive and negative at index ``i``."""
        anchor_feats = self.anchor_featurizer(self.anchors[i])
        positive_feats = self.posneg_featurizer(self.positives[i])
        negative_feats = self.posneg_featurizer(self.negatives[i])
        return anchor_feats, positive_feats, negative_feats


class DTIDataModule(pl.LightningDataModule):
    """
    Lightning data module for a drug-target dataset stored as ``train.csv``,
    ``val.csv`` and ``test.csv`` inside ``data_dir``.

    Drugs are read from the ``SMILES`` column, targets from ``Target Sequence``
    and labels from ``Label``.
    """

    def __init__(
        self,
        data_dir: str,
        drug_featurizer: Featurizer,
        target_featurizer: Featurizer,
        device: torch.device = torch.device("cpu"),
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        header=0,
        index_col=0,
        sep=",",
    ):
        """
        :param data_dir: Directory containing the three CSV split files
        :param drug_featurizer: Featurizer applied to the drug column
        :param target_featurizer: Featurizer applied to the target column
        :param device: Device the featurizers are moved to while computing features
        :param batch_size: Batch size of every data loader
        :param shuffle: ``shuffle`` flag passed to every data loader
        :param num_workers: ``num_workers`` passed to every data loader
        :param header: ``header`` argument for ``pd.read_csv``
        :param index_col: ``index_col`` argument for ``pd.read_csv``
        :param sep: ``sep`` argument for ``pd.read_csv``
        """
        # Let LightningDataModule initialise its own state; without this call
        # the attributes it sets in __init__ are missing when used with a Trainer.
        super().__init__()

        self._loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "collate_fn": drug_target_collate_fn,
        }

        self._csv_kwargs = {
            "header": header,
            "index_col": index_col,
            "sep": sep,
        }

        self._device = device

        self._data_dir = Path(data_dir)
        self._train_path = Path("train.csv")
        self._val_path = Path("val.csv")
        self._test_path = Path("test.csv")

        self._drug_column = "SMILES"
        self._target_column = "Target Sequence"
        self._label_column = "Label"

        self.drug_featurizer = drug_featurizer
        self.target_featurizer = target_featurizer

    def _read_split(self, split_path: Path) -> pd.DataFrame:
        """Read one split CSV from the data directory."""
        return pd.read_csv(self._data_dir / split_path, **self._csv_kwargs)

    def _unique_values(self, dataframes, column: str):
        """Return the unique values of ``column`` across all ``dataframes``."""
        return pd.concat([frame[column] for frame in dataframes]).unique()

    def _make_dataset(self, frame: pd.DataFrame) -> BinaryDataset:
        """Wrap one split's drug / target / label columns in a BinaryDataset."""
        return BinaryDataset(
            frame[self._drug_column],
            frame[self._target_column],
            frame[self._label_column],
            self.drug_featurizer,
            self.target_featurizer,
        )

    def prepare_data(self):
        """Write drug and target features to disk unless both feature files already exist."""
        if self.drug_featurizer.path.exists() and self.target_featurizer.path.exists():
            logg.warning("Drug and target featurizers already exist")
            return

        dataframes = [
            self._read_split(self._train_path),
            self._read_split(self._val_path),
            self._read_split(self._test_path),
        ]
        all_drugs = self._unique_values(dataframes, self._drug_column)
        all_targets = self._unique_values(dataframes, self._target_column)

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        if not self.drug_featurizer.path.exists():
            self.drug_featurizer.write_to_disk(all_drugs)

        if not self.target_featurizer.path.exists():
            self.target_featurizer.write_to_disk(all_targets)

        # Move the featurizers back to CPU once the features are written.
        self.drug_featurizer.cpu()
        self.target_featurizer.cpu()

    def setup(self, stage: T.Optional[str] = None):
        """
        Load the splits, preload all features and build the datasets.

        :param stage: ``"fit"`` / ``"validate"`` build the train and validation
            datasets, ``"test"`` builds the test dataset, None builds all three
        """
        self.df_train = self._read_split(self._train_path)
        self.df_val = self._read_split(self._val_path)
        self.df_test = self._read_split(self._test_path)

        self._dataframes = [self.df_train, self.df_val, self.df_test]

        all_drugs = self._unique_values(self._dataframes, self._drug_column)
        all_targets = self._unique_values(self._dataframes, self._target_column)

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        self.drug_featurizer.preload(all_drugs)
        self.drug_featurizer.cpu()

        self.target_featurizer.preload(all_targets)
        self.target_featurizer.cpu()

        # "validate" is handled like "fit" so val_dataloader() has its dataset.
        if stage in ("fit", "validate") or stage is None:
            self.data_train = self._make_dataset(self.df_train)
            self.data_val = self._make_dataset(self.df_val)

        if stage == "test" or stage is None:
            self.data_test = self._make_dataset(self.df_test)

    def train_dataloader(self):
        """Data loader over the training dataset."""
        return DataLoader(self.data_train, **self._loader_kwargs)

    def val_dataloader(self):
        """Data loader over the validation dataset."""
        return DataLoader(self.data_val, **self._loader_kwargs)

    def test_dataloader(self):
        """Data loader over the test dataset."""
        return DataLoader(self.data_test, **self._loader_kwargs)


class TDCDataModule(pl.LightningDataModule):
    """
    Lightning data module for the ``bindingdb_patent`` benchmark of the TDC
    ``dti_dg_group``.

    Drugs are read from the ``Drug`` column, targets from ``Target`` and labels
    from ``Y``.
    """

    def __init__(
        self,
        data_dir: str,
        drug_featurizer: Featurizer,
        target_featurizer: Featurizer,
        device: torch.device = torch.device("cpu"),
        seed: int = 0,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        header=0,
        index_col=0,
        sep=",",
    ):
        """
        :param data_dir: Path passed to ``dti_dg_group``
        :param drug_featurizer: Featurizer applied to the drug column
        :param target_featurizer: Featurizer applied to the target column
        :param device: Device the featurizers are moved to while computing features
        :param seed: Seed for the train/validation split
        :param batch_size: Batch size of every data loader
        :param shuffle: ``shuffle`` flag passed to every data loader
        :param num_workers: ``num_workers`` passed to every data loader
        :param header: Stored in the CSV options (not used by this class's visible methods)
        :param index_col: Stored in the CSV options (not used by this class's visible methods)
        :param sep: Stored in the CSV options (not used by this class's visible methods)
        """
        # Let LightningDataModule initialise its own state; without this call
        # the attributes it sets in __init__ are missing when used with a Trainer.
        super().__init__()

        self._loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "collate_fn": drug_target_collate_fn,
        }

        self._csv_kwargs = {
            "header": header,
            "index_col": index_col,
            "sep": sep,
        }

        self._device = device

        self._data_dir = Path(data_dir)
        self._seed = seed

        self._drug_column = "Drug"
        self._target_column = "Target"
        self._label_column = "Y"

        self.drug_featurizer = drug_featurizer
        self.target_featurizer = target_featurizer

    def _make_dataset(self, frame: pd.DataFrame) -> BinaryDataset:
        """Wrap one split's drug / target / label columns in a BinaryDataset."""
        return BinaryDataset(
            frame[self._drug_column],
            frame[self._target_column],
            frame[self._label_column],
            self.drug_featurizer,
            self.target_featurizer,
        )

    def prepare_data(self):
        """Load the benchmark and write features to disk unless both feature files already exist."""
        dg_group = dti_dg_group(path=self._data_dir)
        dg_benchmark = dg_group.get("bindingdb_patent")

        full_data = pd.concat([dg_benchmark["train_val"], dg_benchmark["test"]])
        all_drugs = full_data[self._drug_column].unique()
        all_targets = full_data[self._target_column].unique()

        if self.drug_featurizer.path.exists() and self.target_featurizer.path.exists():
            logg.warning("Drug and target featurizers already exist")
            return

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        if not self.drug_featurizer.path.exists():
            self.drug_featurizer.write_to_disk(all_drugs)

        if not self.target_featurizer.path.exists():
            self.target_featurizer.write_to_disk(all_targets)

        # Move the featurizers back to CPU once the features are written.
        self.drug_featurizer.cpu()
        self.target_featurizer.cpu()

    def setup(self, stage: T.Optional[str] = None):
        """
        Split the benchmark, preload train/validation features and build the datasets.

        :param stage: ``"fit"`` / ``"validate"`` build the train and validation
            datasets, ``"test"`` builds the test dataset, None builds all three
        """
        dg_group = dti_dg_group(path=self._data_dir)
        dg_benchmark = dg_group.get("bindingdb_patent")
        dg_name = dg_benchmark["name"]

        self.df_train, self.df_val = dg_group.get_train_valid_split(
            benchmark=dg_name, split_type="default", seed=self._seed
        )
        self.df_test = dg_benchmark["test"]

        # Only train and validation entries are preloaded here.
        self._dataframes = [self.df_train, self.df_val]

        all_drugs = pd.concat(
            [frame[self._drug_column] for frame in self._dataframes]
        ).unique()
        all_targets = pd.concat(
            [frame[self._target_column] for frame in self._dataframes]
        ).unique()

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        self.drug_featurizer.preload(all_drugs)
        self.drug_featurizer.cpu()

        self.target_featurizer.preload(all_targets)
        self.target_featurizer.cpu()

        # "validate" is handled like "fit" so val_dataloader() has its dataset.
        if stage in ("fit", "validate") or stage is None:
            self.data_train = self._make_dataset(self.df_train)
            self.data_val = self._make_dataset(self.df_val)

        if stage == "test" or stage is None:
            self.data_test = self._make_dataset(self.df_test)

    def train_dataloader(self):
        """Data loader over the training dataset."""
        return DataLoader(self.data_train, **self._loader_kwargs)

    def val_dataloader(self):
        """Data loader over the validation dataset."""
        return DataLoader(self.data_val, **self._loader_kwargs)

    def test_dataloader(self):
        """Data loader over the test dataset."""
        return DataLoader(self.data_test, **self._loader_kwargs)


class EnzPredDataModule(pl.LightningDataModule):
    """
    Data module for a single CSV dataset evaluated with 10-fold cross-validation.

    ``prepare_data`` writes the featurizer caches (if missing) and ten
    ``<stem>.<i>.train.csv`` / ``<stem>.<i>.test.csv`` fold files into a
    directory named after the CSV. ``setup`` loads the fold selected by
    ``seed`` and holds out 10% of its training rows for validation.

    After the index column is removed, the CSV columns are read positionally:
    column 0 is the target, column 1 the drug and column 2 the label.

    :param data_dir: Dataset path; its suffix is replaced with ``.csv``
    :param drug_featurizer: Featurizer used for the drug column
    :param target_featurizer: Featurizer used for the target column
    :param device: Device the featurizers are moved to while featurizing
    :param seed: Index of the cross-validation fold to load. The shuffling
        seed of the fold split itself is fixed at 0 regardless of this value.
    :param batch_size: Forwarded to every ``DataLoader``
    :param shuffle: Forwarded to every ``DataLoader``
    :param num_workers: Forwarded to every ``DataLoader``
    :param header: ``pandas.read_csv`` option used when reading the dataset CSV
    :param index_col: ``pandas.read_csv`` option used when reading the dataset CSV
    :param sep: ``pandas.read_csv`` option used when reading the dataset CSV
    """

    def __init__(
        self,
        data_dir: str,
        drug_featurizer: Featurizer,
        target_featurizer: Featurizer,
        device: torch.device = torch.device("cpu"),
        seed: int = 0,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        header=0,
        index_col=0,
        sep=",",
    ):
        # Initialise the LightningDataModule base so its own attributes exist.
        super().__init__()

        self._loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "collate_fn": drug_target_collate_fn,
        }

        self._csv_kwargs = {
            "header": header,
            "index_col": index_col,
            "sep": sep,
        }

        self._device = device

        self._data_file = Path(data_dir).with_suffix(".csv")
        self._data_stem = Path(self._data_file.stem)
        self._data_dir = self._data_file.parent / self._data_file.stem
        # Fixed seed for the fold split and validation hold-out; ``seed`` picks the fold.
        self._seed = 0
        self._replicate = seed

        # Honour the caller's CSV options (the defaults match the previous
        # hard-coded ``index_col=0`` read).
        df = pd.read_csv(self._data_file, **self._csv_kwargs)
        self._drug_column = df.columns[1]
        self._target_column = df.columns[0]
        self._label_column = df.columns[2]

        self.drug_featurizer = drug_featurizer
        self.target_featurizer = target_featurizer

    @classmethod
    def dataset_list(cls):
        """Return the dataset names this data module is meant to be used with."""
        return [
            "halogenase",
            "bkace",
            "gt",
            "esterase",
            "kinase",
            "phosphatase",
        ]

    def _fold_path(self, fold: int, split: str) -> Path:
        """Return the path of the ``split`` ("train"/"test") CSV of fold ``fold``."""
        return self._data_dir / self._data_stem.with_suffix(f".{fold}.{split}.csv")

    def _make_dataset(self, frame):
        """Wrap the drug/target/label columns of ``frame`` in a ``BinaryDataset``."""
        return BinaryDataset(
            frame[self._drug_column],
            frame[self._target_column],
            frame[self._label_column],
            self.drug_featurizer,
            self.target_featurizer,
        )

    def prepare_data(self):
        """Write missing featurizer caches and the ten train/test fold CSV files."""
        self._data_dir.mkdir(parents=True, exist_ok=True)

        kfsplitter = KFold(n_splits=10, shuffle=True, random_state=self._seed)
        full_data = pd.read_csv(self._data_file, **self._csv_kwargs)

        all_drugs = full_data[self._drug_column].unique()
        all_targets = full_data[self._target_column].unique()

        if self.drug_featurizer.path.exists() and self.target_featurizer.path.exists():
            logg.warning("Drug and target featurizers already exist")

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        if not self.drug_featurizer.path.exists():
            self.drug_featurizer.write_to_disk(all_drugs)

        if not self.target_featurizer.path.exists():
            self.target_featurizer.write_to_disk(all_targets)

        self.drug_featurizer.cpu()
        self.target_featurizer.cpu()

        for i, (train_idx, test_idx) in enumerate(kfsplitter.split(full_data)):
            fold_train = full_data.iloc[train_idx].reset_index(drop=True)
            fold_test = full_data.iloc[test_idx].reset_index(drop=True)
            logg.debug(self._fold_path(i, "train"))
            fold_train.to_csv(self._fold_path(i, "train"), index=True, header=True)
            fold_test.to_csv(self._fold_path(i, "test"), index=True, header=True)

    def setup(self, stage: T.Optional[str] = None):
        """
        Load the selected fold, preload both featurizers and build the datasets.

        :param stage: ``"fit"`` builds train/val datasets, ``"test"`` builds the
            test dataset, ``None`` builds all three
        """
        # Fold files are written by ``prepare_data`` with a leading index column.
        df_train = pd.read_csv(self._fold_path(self._replicate, "train"), index_col=0)
        # Seed the hold-out so the train/validation partition is reproducible.
        self.df_train, self.df_val = train_test_split(
            df_train, test_size=0.1, random_state=self._seed
        )
        self.df_test = pd.read_csv(
            self._fold_path(self._replicate, "test"), index_col=0
        )

        self._dataframes = [self.df_train, self.df_val, self.df_test]

        all_drugs = pd.concat([i[self._drug_column] for i in self._dataframes]).unique()
        all_targets = pd.concat(
            [i[self._target_column] for i in self._dataframes]
        ).unique()

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        self.drug_featurizer.preload(all_drugs)
        self.drug_featurizer.cpu()

        self.target_featurizer.preload(all_targets)
        self.target_featurizer.cpu()

        if stage == "fit" or stage is None:
            self.data_train = self._make_dataset(self.df_train)
            self.data_val = self._make_dataset(self.df_val)

        if stage == "test" or stage is None:
            self.data_test = self._make_dataset(self.df_test)

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training dataset."""
        return DataLoader(self.data_train, **self._loader_kwargs)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation dataset."""
        return DataLoader(self.data_val, **self._loader_kwargs)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test dataset."""
        return DataLoader(self.data_test, **self._loader_kwargs)


class DUDEDataModule(pl.LightningDataModule):
    """
    Data module that builds contrastive (anchor, positive, negative) training
    data from ``full.tsv`` in ``task_dir``.

    Targets are assigned to train/test by the file
    ``dude_<contrastive_split>_type_train_test_split.csv`` in the same
    directory; only a training dataloader is provided.

    :param task_dir: Directory containing ``full.tsv`` and the split file
    :param contrastive_split: Name inserted into the split file name
    :param drug_featurizer: Featurizer used for the drug column
    :param target_featurizer: Featurizer used for the target column
    :param device: Device the featurizers are moved to while featurizing
    :param n_neg_per: Passed through to ``make_contrastive``
    :param batch_size: Forwarded to the ``DataLoader``
    :param shuffle: Forwarded to the ``DataLoader``
    :param num_workers: Forwarded to the ``DataLoader``
    :param header: ``pandas.read_csv`` option used for ``full.tsv``
    :param index_col: ``pandas.read_csv`` option used for ``full.tsv``
    :param sep: ``pandas.read_csv`` option used for ``full.tsv``
    """

    def __init__(
        self,
        task_dir: str,
        contrastive_split: str,
        drug_featurizer: Featurizer,
        target_featurizer: Featurizer,
        device: torch.device = torch.device("cpu"),
        n_neg_per: int = 50,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        header=0,
        index_col=None,
        sep="\t",
    ):
        # Initialise the LightningDataModule base so its own attributes exist.
        super().__init__()

        self._loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "collate_fn": contrastive_collate_fn,
        }

        self._csv_kwargs = {
            "header": header,
            "index_col": index_col,
            "sep": sep,
        }

        self._device = device
        self._n_neg_per = n_neg_per

        # Normalise to a Path so the joins below do not depend on the caller's type.
        self._data_dir = Path(task_dir)
        self._split = contrastive_split
        self._split_path = self._data_dir / Path(
            f"dude_{self._split}_type_train_test_split.csv"
        )

        self._drug_id_column = "Molecule_ID"
        self._drug_column = "Molecule_SMILES"
        self._target_id_column = "Target_ID"
        self._target_column = "Target_Seq"
        self._label_column = "Label"

        self.drug_featurizer = drug_featurizer
        self.target_featurizer = target_featurizer

    def prepare_data(self):
        """No preparation step; all work happens in ``setup``."""
        pass

    def setup(self, stage: T.Optional[str] = None):
        """
        Read the data, split it by target ID, build the contrastive table and
        preload both featurizers with the training drugs and targets.

        :param stage: ``"fit"`` or ``None`` builds the training dataset
        """
        self.df_full = pd.read_csv(
            self._data_dir / Path("full.tsv"), **self._csv_kwargs
        )

        # Header-less split file: column 0 is matched against target IDs,
        # column 1 holds the literal "train" / "test".
        self.df_splits = pd.read_csv(self._split_path, header=None)
        self._train_list = self.df_splits[self.df_splits[1] == "train"][0].values
        self._test_list = self.df_splits[self.df_splits[1] == "test"][0].values

        self.df_train = self.df_full[
            self.df_full[self._target_id_column].isin(self._train_list)
        ]
        self.df_test = self.df_full[
            self.df_full[self._target_id_column].isin(self._test_list)
        ]

        self.train_contrastive = make_contrastive(
            self.df_train,
            self._drug_column,
            self._target_column,
            self._label_column,
            self._n_neg_per,
        )

        # Only the training rows are featurized; the test frame is left out.
        self._dataframes = [self.df_train]  # , self.df_test]

        all_drugs = pd.concat([i[self._drug_column] for i in self._dataframes]).unique()
        all_targets = pd.concat(
            [i[self._target_column] for i in self._dataframes]
        ).unique()

        if self._device.type == "cuda":
            self.drug_featurizer.cuda(self._device)
            self.target_featurizer.cuda(self._device)

        self.drug_featurizer.preload(all_drugs, write_first=True)
        self.drug_featurizer.cpu()

        self.target_featurizer.preload(all_targets, write_first=True)
        self.target_featurizer.cpu()

        if stage == "fit" or stage is None:
            self.data_train = ContrastiveDataset(
                self.train_contrastive["Anchor"],
                self.train_contrastive["Positive"],
                self.train_contrastive["Negative"],
                self.drug_featurizer,
                self.target_featurizer,
            )

    def train_dataloader(self):
        """Return a ``DataLoader`` over the contrastive training dataset."""
        return DataLoader(self.data_train, **self._loader_kwargs)




In [7]:

class LogisticActivation(nn.Module):
    """
    Implementation of Generalized Sigmoid
    Applies the element-wise function:
    :math:`\\sigma(x) = \\frac{1}{1 + \\exp(-k(x-x_0))}`
    :param x0: The value of the sigmoid midpoint
    :type x0: float
    :param k: The slope of the sigmoid - trainable -  :math:`k \\geq 0`
    :type k: float
    :param train: Whether :math:`k` is a trainable parameter
    :type train: bool
    """

    def __init__(self, x0=0, k=1, train=False):
        super().__init__()
        self.x0 = x0
        # ``requires_grad`` is what autograd reads; the previous
        # ``self.k.requiresGrad = train`` only set an unused attribute, so
        # ``k`` stayed frozen even with ``train=True``.
        self.k = nn.Parameter(
            torch.tensor([float(k)], dtype=torch.float32),
            requires_grad=bool(train),
        )

    def forward(self, x):
        """
        Applies the function to the input elementwise
        :param x: :math:`(N \\times *)` where :math:`*` means, any number of additional dimensions
        :type x: torch.Tensor
        :return: the activated input with all size-1 dimensions squeezed out
        :rtype: torch.Tensor
        """
        o = torch.clamp(
            1 / (1 + torch.exp(-self.k * (x - self.x0))), min=0, max=1
        ).squeeze()
        return o

    def clip(self):
        """
        Restricts sigmoid slope :math:`k` to be greater than or equal to 0, if :math:`k` is trained.
        :meta private:
        """
        self.k.data.clamp_(min=0)


#######################
# Model Architectures #
#######################


class SimpleCoembedding(nn.Module):
    """
    Two-tower co-embedding model: drug and target features are each mapped by
    one linear layer plus activation into a shared latent space.

    With ``classify=True`` the forward pass returns the output of the
    ``DISTANCE_METRICS[latent_distance]`` module applied to the two
    projections; otherwise it returns their per-sample inner product.

    :param drug_shape: Width of the drug feature vector
    :param target_shape: Width of the target feature vector
    :param latent_dimension: Width of the shared latent space
    :param latent_activation: Key into ``ACTIVATIONS`` selecting the activation class
    :param latent_distance: Key into ``DISTANCE_METRICS`` (used only when classifying)
    :param classify: Selects ``classify`` (True) or ``regress`` (False) in ``forward``
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        activation_cls = ACTIVATIONS[latent_activation]
        self.latent_activation = activation_cls

        # Drug tower (Xavier-normal weights on the linear layer).
        drug_layers = [nn.Linear(drug_shape, latent_dimension), activation_cls()]
        self.drug_projector = nn.Sequential(*drug_layers)
        nn.init.xavier_normal_(drug_layers[0].weight)

        # Target tower, built the same way.
        target_layers = [nn.Linear(target_shape, latent_dimension), activation_cls()]
        self.target_projector = nn.Sequential(*target_layers)
        nn.init.xavier_normal_(target_layers[0].weight)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        handler = self.classify if self.do_classify else self.regress
        return handler(drug, target)

    def regress(self, drug, target):
        """Return the squeezed per-sample inner product of the two projections."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)

        # (B, 1, D) @ (B, D, 1) -> (B, 1, 1)
        rows = drug_latent.view(-1, 1, self.latent_dimension)
        cols = target_latent.view(-1, self.latent_dimension, 1)
        return torch.bmm(rows, cols).squeeze()

    def classify(self, drug, target):
        """Return the squeezed output of the distance module on the two projections."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        return self.activator(drug_latent, target_latent).squeeze()


# Alternative name for the same class.
SimpleCoembeddingNoSigmoid = SimpleCoembedding


class SimpleCoembeddingSigmoid(nn.Module):
    """
    Two-tower co-embedding model whose outputs are passed through a final
    non-linearity: a sigmoid over the distance-module output when classifying,
    a ReLU over the inner product when regressing.

    :param drug_shape: Width of the drug feature vector
    :param target_shape: Width of the target feature vector
    :param latent_dimension: Width of the shared latent space
    :param latent_activation: Activation class instantiated after each projection
    :param latent_distance: Key into ``DISTANCE_METRICS`` (used only when classifying)
    :param classify: Selects ``classify`` (True) or ``regress`` (False) in ``forward``
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Drug tower (Xavier-normal weights on the linear layer).
        drug_layers = [nn.Linear(drug_shape, latent_dimension), latent_activation()]
        self.drug_projector = nn.Sequential(*drug_layers)
        nn.init.xavier_normal_(drug_layers[0].weight)

        # Target tower, built the same way.
        target_layers = [nn.Linear(target_shape, latent_dimension), latent_activation()]
        self.target_projector = nn.Sequential(*target_layers)
        nn.init.xavier_normal_(target_layers[0].weight)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        handler = self.classify if self.do_classify else self.regress
        return handler(drug, target)

    def regress(self, drug, target):
        """Return ReLU of the per-sample inner product of the projections, squeezed."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)

        # (B, 1, D) @ (B, D, 1) -> (B, 1, 1), then drop the singleton dimensions.
        rows = drug_latent.view(-1, 1, self.latent_dimension)
        cols = target_latent.view(-1, self.latent_dimension, 1)
        scores = torch.bmm(rows, cols).squeeze()
        return torch.relu(scores).squeeze()

    def classify(self, drug, target):
        """Return the sigmoid of the distance-module output, squeezed."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        similarity = self.activator(drug_latent, target_latent)
        return torch.sigmoid(similarity).squeeze()


class SimpleCoembedding_FoldSeek(nn.Module):
    """
    Co-embedding model whose target tower consumes a protein embedding
    concatenated with a mean-pooled embedding of per-position FoldSeek indices.

    :param drug_shape: Width of the drug feature vector
    :param target_shape: Width of the protein-embedding part of the target input
    :param latent_dimension: Width of the shared latent space
    :param latent_activation: Activation class instantiated after each projection
    :param latent_distance: Key into ``DISTANCE_METRICS`` (used only when classifying)
    :param classify: Selects ``classify`` (True) or ``regress`` (False) in ``forward``
    :param foldseek_embedding_dimension: Width of each FoldSeek index embedding
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
        foldseek_embedding_dimension=1024,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.foldseek_embedding_dimension = foldseek_embedding_dimension
        self.do_classify = classify

        # 22-row lookup table; FOLDSEEK_MISSING_IDX is the padding row.
        self.foldseek_index_embedding = nn.Embedding(
            num_embeddings=22,
            embedding_dim=foldseek_embedding_dimension,
            padding_idx=FOLDSEEK_MISSING_IDX,
        )

        # Drug tower (Xavier-normal weights on the linear layer).
        drug_layers = [nn.Linear(drug_shape, latent_dimension), latent_activation()]
        self.drug_projector = nn.Sequential(*drug_layers)
        nn.init.xavier_normal_(drug_layers[0].weight)

        # Target tower takes the protein embedding and pooled FoldSeek embedding side by side.
        fused_width = target_shape + foldseek_embedding_dimension
        target_layers = [nn.Linear(fused_width, latent_dimension), latent_activation()]
        self._target_projector = nn.Sequential(*target_layers)
        nn.init.xavier_normal_(target_layers[0].weight)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        handler = self.classify if self.do_classify else self.regress
        return handler(drug, target)

    def _split_foldseek_target_embedding(self, target_embedding):
        """
        Turn a ``(batch, target_shape + N)`` input into ``(batch, target_shape + D_fs)``.

        The first ``target_shape`` columns are kept as the protein embedding.
        The remaining ``N`` columns are cast to integer indices, looked up in
        ``foldseek_index_embedding`` and averaged over the ``N`` positions:

            N --embedding--> N x D_fs --mean pool--> D_fs

        An input that is exactly ``target_shape`` wide is returned unchanged.
        NOTE(review): that unchanged tensor is narrower than the input width of
        ``_target_projector`` - confirm whether this path is ever exercised.
        """
        plm_width = self.target_shape
        if target_embedding.shape[1] == plm_width:
            return target_embedding

        plm_part = target_embedding[:, :plm_width]
        index_part = target_embedding[:, plm_width:].long()
        pooled_foldseek = self.foldseek_index_embedding(index_part).mean(dim=1)
        return torch.cat((plm_part, pooled_foldseek), dim=1)

    def target_projector(self, target):
        """Fuse the FoldSeek part of ``target`` and project it into the latent space."""
        fused_target = self._split_foldseek_target_embedding(target)
        return self._target_projector(fused_target)

    def regress(self, drug, target):
        """Return ReLU of the per-sample inner product of the projections, squeezed."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)

        # (B, 1, D) @ (B, D, 1) -> (B, 1, 1), then drop the singleton dimensions.
        rows = drug_latent.view(-1, 1, self.latent_dimension)
        cols = target_latent.view(-1, self.latent_dimension, 1)
        scores = torch.bmm(rows, cols).squeeze()
        return torch.relu(scores).squeeze()

    def classify(self, drug, target):
        """Return the squeezed output of the distance module on the two projections."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        return self.activator(drug_latent, target_latent).squeeze()


class SimpleCoembedding_FoldSeekX(nn.Module):
    """
    Variant of the FoldSeek co-embedding model with a 512-wide default
    FoldSeek index embedding.

    :param drug_shape: Width of the drug feature vector
    :param target_shape: Width of the protein-embedding part of the target input
    :param latent_dimension: Width of the shared latent space
    :param latent_activation: Activation class instantiated after each projection
    :param latent_distance: Key into ``DISTANCE_METRICS`` (used only when classifying)
    :param classify: Selects ``classify`` (True) or ``regress`` (False) in ``forward``
    :param foldseek_embedding_dimension: Width of each FoldSeek index embedding
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
        foldseek_embedding_dimension=512,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.foldseek_embedding_dimension = foldseek_embedding_dimension
        self.do_classify = classify

        # 22-row lookup table; FOLDSEEK_MISSING_IDX is the padding row.
        self.foldseek_index_embedding = nn.Embedding(
            num_embeddings=22,
            embedding_dim=foldseek_embedding_dimension,
            padding_idx=FOLDSEEK_MISSING_IDX,
        )

        # Drug tower (Xavier-normal weights on the linear layer).
        drug_layers = [nn.Linear(drug_shape, latent_dimension), latent_activation()]
        self.drug_projector = nn.Sequential(*drug_layers)
        nn.init.xavier_normal_(drug_layers[0].weight)

        # Target tower takes the protein embedding and pooled FoldSeek embedding side by side.
        fused_width = target_shape + foldseek_embedding_dimension
        target_layers = [nn.Linear(fused_width, latent_dimension), latent_activation()]
        self._target_projector = nn.Sequential(*target_layers)
        nn.init.xavier_normal_(target_layers[0].weight)

        # self.projector_dropout = nn.Dropout(p=0.2)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        handler = self.classify if self.do_classify else self.regress
        return handler(drug, target)

    def _split_foldseek_target_embedding(self, target_embedding):
        """
        Turn a ``(batch, target_shape + N)`` input into ``(batch, target_shape + D_fs)``.

        The first ``target_shape`` columns are kept as the protein embedding.
        The remaining ``N`` columns are cast to integer indices, looked up in
        ``foldseek_index_embedding`` and averaged over the ``N`` positions:

            N --embedding--> N x D_fs --mean pool--> D_fs

        An input that is exactly ``target_shape`` wide is returned unchanged.
        NOTE(review): that unchanged tensor is narrower than the input width of
        ``_target_projector`` - confirm whether this path is ever exercised.
        """
        plm_width = self.target_shape
        if target_embedding.shape[1] == plm_width:
            return target_embedding

        plm_part = target_embedding[:, :plm_width]
        index_part = target_embedding[:, plm_width:].long()
        pooled_foldseek = self.foldseek_index_embedding(index_part).mean(dim=1)
        return torch.cat((plm_part, pooled_foldseek), dim=1)

    def target_projector(self, target):
        """Fuse the FoldSeek part of ``target`` and project it into the latent space."""
        fused_target = self._split_foldseek_target_embedding(target)
        return self._target_projector(fused_target)

    def regress(self, drug, target):
        """Return ReLU of the per-sample inner product of the projections, squeezed."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)

        # (B, 1, D) @ (B, D, 1) -> (B, 1, 1), then drop the singleton dimensions.
        rows = drug_latent.view(-1, 1, self.latent_dimension)
        cols = target_latent.view(-1, self.latent_dimension, 1)
        scores = torch.bmm(rows, cols).squeeze()
        return torch.relu(scores).squeeze()

    def classify(self, drug, target):
        """Return the squeezed output of the distance module on the two projections."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        return self.activator(drug_latent, target_latent).squeeze()


class GoldmanCPI(nn.Module):
    """
    Two-tower model whose projections are multiplied element-wise and passed
    through a three-layer MLP head with dropout.

    ``regress`` returns the raw head output of shape ``(batch, 1)`` (not
    squeezed); ``classify`` applies a sigmoid to it and squeezes.

    :param drug_shape: Width of the drug feature vector
    :param target_shape: Width of the target feature vector
    :param latent_dimension: Width of the latent space and of the hidden head layers
    :param latent_activation: Activation class instantiated after each projection
    :param latent_distance: Key into ``DISTANCE_METRICS``; the resulting module is
        stored as ``activator`` when ``classify`` is true but not called here
    :param model_dropout: Dropout probability inside the head
    :param classify: Selects ``classify`` (True) or ``regress`` (False) in ``forward``
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=100,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        model_dropout=0.2,
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Drug tower (Xavier-normal weights on the linear layer).
        drug_layers = [nn.Linear(drug_shape, latent_dimension), latent_activation()]
        self.drug_projector = nn.Sequential(*drug_layers)
        nn.init.xavier_normal_(drug_layers[0].weight)

        # Target tower, built the same way.
        target_layers = [nn.Linear(target_shape, latent_dimension), latent_activation()]
        self.target_projector = nn.Sequential(*target_layers)
        nn.init.xavier_normal_(target_layers[0].weight)

        # MLP head: two hidden blocks followed by a single-output linear layer.
        head_layers = [
            nn.ReLU(),
            nn.Linear(latent_dimension, latent_dimension, bias=True),
            nn.Dropout(p=model_dropout),
            nn.ReLU(),
            nn.Linear(latent_dimension, latent_dimension, bias=True),
            nn.Dropout(p=model_dropout),
            nn.ReLU(),
            nn.Linear(latent_dimension, 1, bias=True),
        ]
        self.last_layers = nn.Sequential(*head_layers)

        if classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Dispatch to ``classify`` or ``regress`` according to ``do_classify``."""
        handler = self.classify if self.do_classify else self.regress
        return handler(drug, target)

    def regress(self, drug, target):
        """Return the head output for the element-wise product of the projections."""
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        # Same-shape element-wise product (the "bd,bd->bd" contraction).
        interaction = torch.mul(drug_latent, target_latent)
        return self.last_layers(interaction)

    def classify(self, drug, target):
        """Return the sigmoid of the ``regress`` output, squeezed."""
        logits = self.regress(drug, target)
        return torch.sigmoid(logits).squeeze()


class SimpleCosine(nn.Module):
    """
    Project molecule and protein embeddings into a shared latent space and
    compare them with the module selected by ``distance_metric``.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of the shared latent space
    :param latent_activation: Activation class instantiated after each projection
    :param distance_metric: Key into ``DISTANCE_METRICS``
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        latent_size=1024,
        latent_activation=nn.ReLU,
        distance_metric="Cosine",
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), latent_activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        prot_layers = [nn.Linear(prot_emb_size, latent_size), latent_activation()]
        self.prot_projector = nn.Sequential(*prot_layers)

        self.dist_metric = distance_metric
        self.activator = DISTANCE_METRICS[distance_metric]()

    def forward(self, mol_emb, prot_emb):
        """Return the distance-module output for the two projected embeddings."""
        projected = (self.mol_projector(mol_emb), self.prot_projector(prot_emb))
        return self.activator(*projected)


class AffinityCoembedInner(nn.Module):
    """
    Project molecule and protein embeddings into a shared latent space and
    score each pair by the inner product of its two projections.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of the shared latent space
    :param activation: Activation class instantiated after each projection
    """

    def __init__(
        self, mol_emb_size, prot_emb_size, latent_size=1024, activation=nn.ReLU
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size
        self.latent_size = latent_size

        self.mol_projector = nn.Sequential(
            nn.Linear(self.mol_emb_size, latent_size), activation()
        )
        # In-place initialiser; the un-suffixed ``xavier_uniform`` is deprecated.
        nn.init.xavier_uniform_(self.mol_projector[0].weight)

        self.prot_projector = nn.Sequential(
            nn.Linear(self.prot_emb_size, latent_size), activation()
        )
        nn.init.xavier_uniform_(self.prot_projector[0].weight)

    def forward(self, mol_emb, prot_emb):
        """
        Return the per-sample inner product of the projected embeddings.

        :param mol_emb: Tensor whose last dimension is ``mol_emb_size``
        :param prot_emb: Tensor whose last dimension is ``prot_emb_size``
        :return: Inner products with all size-1 dimensions squeezed out
        """
        mol_proj = self.mol_projector(mol_emb)
        prot_proj = self.prot_projector(prot_emb)
        # (B, 1, D) @ (B, D, 1) -> (B, 1, 1), squeezed to (B,) for B > 1.
        y = torch.bmm(
            mol_proj.view(-1, 1, self.latent_size),
            prot_proj.view(-1, self.latent_size, 1),
        ).squeeze()
        return y


class CosineBatchNorm(nn.Module):
    """
    Project molecule and protein embeddings, batch-normalise each projection,
    and compare them with the module selected by ``distance_metric``.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of the shared latent space
    :param latent_activation: Activation class instantiated after each projection
    :param distance_metric: Key into ``DISTANCE_METRICS``
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        latent_size=1024,
        latent_activation=nn.ReLU,
        distance_metric="Cosine",
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size
        self.latent_size = latent_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), latent_activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        prot_layers = [nn.Linear(prot_emb_size, latent_size), latent_activation()]
        self.prot_projector = nn.Sequential(*prot_layers)

        # One BatchNorm per tower, applied after the activation.
        self.mol_norm = nn.BatchNorm1d(latent_size)
        self.prot_norm = nn.BatchNorm1d(latent_size)

        self.dist_metric = distance_metric
        self.activator = DISTANCE_METRICS[distance_metric]()

    def forward(self, mol_emb, prot_emb):
        """Return the distance-module output for the normalised projections."""
        mol_latent = self.mol_projector(mol_emb)
        mol_latent = self.mol_norm(mol_latent)
        prot_latent = self.prot_projector(prot_emb)
        prot_latent = self.prot_norm(prot_latent)
        return self.activator(mol_latent, prot_latent)


class LSTMCosine(nn.Module):
    """
    Cosine similarity between a projected molecule embedding and a projection
    of the final hidden states of a bidirectional LSTM run over the protein
    embedding sequence.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Per-position width of the protein embedding sequence
    :param lstm_layers: Number of stacked LSTM layers
    :param lstm_dim: Hidden size of each LSTM direction
    :param latent_size: Width of the shared latent space
    :param latent_activation: Activation class for the molecule projection
        (the protein projection always uses ``nn.ReLU``)
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        lstm_layers=3,
        lstm_dim=256,
        latent_size=256,
        latent_activation=nn.ReLU,
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), latent_activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        self.rnn = nn.LSTM(
            input_size=prot_emb_size,
            hidden_size=lstm_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Final hidden states: one per layer and direction, concatenated per sample.
        hidden_width = 2 * lstm_layers * lstm_dim
        prot_layers = [nn.Linear(hidden_width, latent_size), nn.ReLU()]
        self.prot_projector = nn.Sequential(*prot_layers)

        self.activator = nn.CosineSimilarity()

    def forward(self, mol_emb, prot_emb):
        """Return the cosine similarity of the molecule and LSTM-derived protein projections."""
        mol_latent = self.mol_projector(mol_emb)

        sequence_out, (final_hidden, _) = self.rnn(prot_emb)
        # (layers * directions, batch, hidden) -> (batch, layers * directions * hidden)
        batch = sequence_out.shape[0]
        flat_hidden = final_hidden.transpose(0, 1).reshape(batch, -1)
        prot_latent = self.prot_projector(flat_hidden)

        return self.activator(mol_latent, prot_latent)


class DeepCosine(nn.Module):
    """
    Cosine similarity between a single-layer molecule projection and a
    two-layer protein projection that applies dropout (p=0.5) after each
    linear layer.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of the shared latent space
    :param hidden_size: Width of the hidden layer of the protein projection
    :param latent_activation: Activation class instantiated after each linear layer
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        latent_size=1024,
        hidden_size=4096,
        latent_activation=nn.ReLU,
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), latent_activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        # Linear -> Dropout -> activation, twice.
        prot_layers = [
            nn.Linear(prot_emb_size, hidden_size),
            nn.Dropout(p=0.5, inplace=False),
            latent_activation(),
            nn.Linear(hidden_size, latent_size),
            nn.Dropout(p=0.5, inplace=False),
            latent_activation(),
        ]
        self.prot_projector = nn.Sequential(*prot_layers)

        self.activator = nn.CosineSimilarity()

    def forward(self, mol_emb, prot_emb):
        """Return the cosine similarity of the two projected embeddings."""
        projected = (self.mol_projector(mol_emb), self.prot_projector(prot_emb))
        return self.activator(*projected)


class SimpleConcat(nn.Module):
    """
    Three-layer MLP over the concatenated molecule and protein embeddings,
    ending in a sigmoid.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param hidden_dim_1: Width of the first hidden layer
    :param hidden_dim_2: Width of the second hidden layer
    :param activation: Activation class used after the two hidden layers
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        hidden_dim_1=512,
        hidden_dim_2=256,
        activation=nn.ReLU,
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        joint_width = mol_emb_size + prot_emb_size
        self.fc1 = nn.Sequential(nn.Linear(joint_width, hidden_dim_1), activation())
        self.fc2 = nn.Sequential(nn.Linear(hidden_dim_1, hidden_dim_2), activation())
        self.fc3 = nn.Sequential(nn.Linear(hidden_dim_2, 1), nn.Sigmoid())

    def forward(self, mol_emb, prot_emb):
        """Return the squeezed sigmoid output for the concatenated embeddings."""
        hidden = torch.cat((mol_emb, prot_emb), dim=1)
        for stage in (self.fc1, self.fc2, self.fc3):
            hidden = stage(hidden)
        return hidden.squeeze()


class SeparateConcat(nn.Module):
    """
    Project molecule and protein embeddings separately, concatenate the
    projections and map them to a single sigmoid output.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of each projection
    :param latent_activation: Activation class instantiated after each projection
    :param distance_metric: Accepted but not used by this model
    """

    def __init__(
        self,
        mol_emb_size=2048,
        prot_emb_size=100,
        latent_size=1024,
        latent_activation=nn.ReLU,
        distance_metric=None,
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), latent_activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        prot_layers = [nn.Linear(prot_emb_size, latent_size), latent_activation()]
        self.prot_projector = nn.Sequential(*prot_layers)

        head_layers = [nn.Linear(2 * latent_size, 1), nn.Sigmoid()]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, mol_emb, prot_emb):
        """Return the squeezed sigmoid output for the concatenated projections."""
        mol_latent = self.mol_projector(mol_emb)
        prot_latent = self.prot_projector(prot_emb)
        joint = torch.cat((mol_latent, prot_latent), dim=1)
        return self.fc(joint).squeeze()


class AffinityEmbedConcat(nn.Module):
    """
    Project molecule and protein embeddings separately, concatenate the
    projections and map them to a single unbounded linear output.

    :param mol_emb_size: Width of the molecule embedding
    :param prot_emb_size: Width of the protein embedding
    :param latent_size: Width of each projection
    :param activation: Activation class instantiated after each projection
    """

    def __init__(
        self, mol_emb_size, prot_emb_size, latent_size=1024, activation=nn.ReLU
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size
        self.latent_size = latent_size

        mol_layers = [nn.Linear(mol_emb_size, latent_size), activation()]
        self.mol_projector = nn.Sequential(*mol_layers)

        prot_layers = [nn.Linear(prot_emb_size, latent_size), activation()]
        self.prot_projector = nn.Sequential(*prot_layers)

        self.fc = nn.Linear(2 * latent_size, 1)

    def forward(self, mol_emb, prot_emb):
        """Return the squeezed linear output for the concatenated projections."""
        mol_latent = self.mol_projector(mol_emb)
        prot_latent = self.prot_projector(prot_emb)
        joint = torch.cat((mol_latent, prot_latent), dim=1)
        return self.fc(joint).squeeze()


# Alternative name for the same class.
SimplePLMModel = AffinityEmbedConcat


class AffinityConcatLinear(nn.Module):
    """Linear baseline: one linear layer over the concatenated raw embeddings.

    Args:
        mol_emb_size: width of the molecule embedding.
        prot_emb_size: width of the protein embedding.
    """

    def __init__(
        self,
        mol_emb_size,
        prot_emb_size,
    ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size
        self.fc = nn.Linear(mol_emb_size + prot_emb_size, 1)

    def forward(self, mol_emb, prot_emb):
        """Return one value per row of the (batch, features) inputs, shape ``(batch,)``."""
        cat_emb = torch.cat([mol_emb, prot_emb], dim=1)
        # squeeze(-1) keeps the batch axis even for a single pair.
        return self.fc(cat_emb).squeeze(-1)



In [9]:
# Path of the YAML run configuration on the mounted Drive.
test_config = '/content/gdrive/MyDrive/test_config.yaml'

config = OmegaConf.load(test_config)
use_cuda = torch.cuda.is_available()
# Use the GPU index named in the config when CUDA is available, else the CPU.
device = torch.device(f'cuda:{config.device}' if use_cuda else "cpu")
print('use_cuda: ', use_cuda)
print(device)

# Overrides applied on top of the loaded YAML for this run:
# regression instead of classification, GELU in the projectors, and
# validation Pearson correlation as the watched metric.
config.classify = False
config.latent_activation = "GELU"
config.watch_metric = "val/pcc"


use_cuda:  True
cuda:0


In [None]:
tdc_data = '/content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/dataset/TDC/dti_dg_group.zip'

In [None]:
!unzip /content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/dataset/TDC/dti_dg_group.zip

In [11]:
# NOTE: `dti_train` first holds the CSV path and is then rebound to the loaded
# DataFrame, so re-running only the second line would fail.
dti_train = '/content/gdrive/MyDrive/A_JAK_design/dti_dg_group/bindingdb_patent/train_val.csv'
dti_train = pd.read_csv(dti_train)
print(dti_train.shape) #  (183430, 6)
dti_train.head()

(183430, 6)


Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y,Year
0,91808352.0,COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...,P56373,MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...,2.564949,2013
1,67223437.0,Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...,P28845,MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...,4.60517,2013
2,46222354.0,CC(C)(C#N)c1cccc(C(=O)Nc2cccc(Oc3ccc4nc(NC(=O)...,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,5.703782,2013
3,59454472.0,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Nc4ccc(cc4)C(C)(...,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,5.703782,2013
4,46222580.0,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,5.703782,2013


In [None]:

# data_file = '/content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/fedratinib_jak2_2lines.tsv'
data_file = '/content/gdrive/MyDrive/A_JAK_design/test_exist.tsv'
# The TSV has no header row; the four columns are named here.
try:
    query_df = pd.read_csv(
        data_file,
        sep="\t",
        names=["proteinID", "moleculeID", "proteinSequence", "moleculeSmiles"],
        index_col=False
    )
except FileNotFoundError:
    print(f"Could not find data file: {data_file}")
    # Stop the cell here: continuing would either hit a NameError on
    # `query_df` or silently print a stale frame left over in the kernel.
    raise
print(query_df)


Select the first several rows to evaluate our model

In [163]:
ind = 10 #@param {type:"integer"}
# Take the first `ind` rows of the TDC table and rename its columns to the
# names the prediction helpers expect.
new_test_df = (
    dti_train.loc[:, ['Drug_ID', 'Target_ID', 'Target', 'Drug', 'Y']]
    .iloc[:ind]
    .rename(
        columns={
            'Drug_ID': 'moleculeID',
            'Target_ID': 'proteinID',
            'Target': 'proteinSequence',
            'Drug': 'moleculeSmiles',
            'Y': 'label',
        }
    )
)

# `query_df` is the same object as `new_test_df`, not a copy.
query_df = new_test_df
query_df

Unnamed: 0,moleculeID,proteinID,proteinSequence,moleculeSmiles,label
0,91808352.0,P56373,MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...,COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...,2.564949
1,67223437.0,P28845,MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...,Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...,4.60517
2,46222354.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(C)(C#N)c1cccc(C(=O)Nc2cccc(Oc3ccc4nc(NC(=O)...,5.703782
3,59454472.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Nc4ccc(cc4)C(C)(...,5.703782
4,46222580.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,5.703782
5,59454502.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Cc4cccc(c4)C(F)(...,5.703782
6,46209401.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,5.703782
7,86644710.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,COC(=O)c1c(Oc2ccc(F)c(NC(=O)Cc3cccc(c3)C(F)(F)...,5.703782
8,71009084.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,CC(C)(C)OC(=O)N1C[C@H]2CCN(C(=O)[C@H]2C1)c1ccc...,5.135798
9,71008797.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,FC(F)(F)Oc1ccc(cc1)N1CC[C@H]2CN(C[C@H]2C1=O)C(...,3.901973


In [11]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from __future__ import annotations

import typing as T

from functools import lru_cache
from pathlib import Path

import h5py
import numpy as np
import torch
from tqdm import tqdm

from torch.utils.data import DataLoader
# from src.featurizers.molecule import MorganFeaturizer

# import MorganFeaturizer, ProtBertFeaturizer
# from model.architectures import SimpleCoembeddingNoSigmoid

class Featurizer:
    """Base class mapping a sequence string to a feature tensor, with caching.

    Subclasses implement :meth:`_transform`. Computed features are kept in an
    in-memory dict keyed by the raw sequence and can additionally be written
    to / preloaded from an HDF5 file at ``<save_dir>/<name>_features.h5``.

    Args:
        name: featurizer name; also used in the HDF5 file name.
        shape: width of the feature vector produced for one sequence.
        save_dir: directory holding the HDF5 cache (``str`` or ``Path``).
    """

    def __init__(
        self, name: str, shape: int, save_dir: Path = Path().absolute()
    ):
        self._name = name
        self._shape = shape
        # Callers pass both str and Path; normalise so `.exists()` etc. work.
        self._save_path = Path(save_dir) / f"{self._name}_features.h5"

        self._preloaded = False
        self._device = torch.device("cpu")
        self._cuda_registry = {}
        self._on_cuda = False
        self._features = {}

    def __call__(self, seq: str) -> torch.Tensor:
        """Return the cached features for ``seq``, computing them on a miss."""
        if seq not in self.features:
            # `transform` is lru-cached, so it may hand back a tensor created
            # before the last device change; `.to` is a no-op when the
            # tensor is already on the current device.
            self._features[seq] = self.transform(seq).to(self._device)

        return self._features[seq]

    def _register_cuda(self, k: str, v, f=None):
        """
        Register an object as capable of being moved to a CUDA device.

        ``f``, when given, is called as ``f(v, device)`` and must return the
        moved object; otherwise ``v.to(device)`` is used.
        """
        self._cuda_registry[k] = (v, f)

    def _transform(self, seq: str) -> torch.Tensor:
        """Compute features for one sequence; implemented by subclasses."""
        raise NotImplementedError

    def _update_device(self, device: torch.device):
        """Move every registered object and every cached feature to ``device``."""
        self._device = device
        for k, (v, f) in self._cuda_registry.items():
            if f is None:
                try:
                    self._cuda_registry[k] = (v.to(self._device), None)
                except RuntimeError as e:
                    # Best effort: log and leave this entry where it was.
                    logg.error(e)
                    logg.debug(device)
                    logg.debug(type(self._device))
                    logg.debug(self._device)
            else:
                self._cuda_registry[k] = (f(v, self._device), f)
        for k, v in self._features.items():
            self._features[k] = v.to(device)

    # NOTE(review): lru_cache on a method keys on (self, seq) and keeps the
    # instance alive for the lifetime of the cache.
    @lru_cache(maxsize=5000)
    def transform(self, seq: str) -> torch.Tensor:
        """Run :meth:`_transform` with autograd disabled and place the result."""
        with torch.set_grad_enabled(False):
            feats = self._transform(seq)
            if self._on_cuda:
                feats = feats.to(self.device)
            return feats

    @property
    def name(self) -> str:
        return self._name

    @property
    def shape(self) -> int:
        return self._shape

    @property
    def path(self) -> Path:
        return self._save_path

    @property
    def features(self) -> dict:
        return self._features

    @property
    def on_cuda(self) -> bool:
        return self._on_cuda

    @property
    def device(self) -> torch.device:
        return self._device

    def to(self, device: torch.device) -> "Featurizer":
        """Move the featurizer to ``device`` (``torch.device`` or device string)."""
        device = torch.device(device)
        self._update_device(device)
        self._on_cuda = device.type == "cuda"
        return self

    def cuda(self, device: torch.device) -> "Featurizer":
        """
        Perform model computations on CUDA, move saved embeddings to CUDA device
        """
        self._update_device(torch.device(device))
        self._on_cuda = True
        return self

    def cpu(self) -> "Featurizer":
        """
        Perform model computations on CPU, move saved embeddings to CPU
        """
        self._update_device(torch.device("cpu"))
        self._on_cuda = False
        return self

    def write_to_disk(
        self, seq_list: T.List[str], verbose: bool = True
    ) -> None:
        """Compute features for ``seq_list`` and store them in the HDF5 file.

        Datasets are keyed by ``sanitize_string(seq)``. A sequence already in
        the file is logged and then recomputed and overwritten.
        """
        logg.info(f"Writing {self.name} features to {self.path}")
        with h5py.File(self._save_path, "a") as h5fi:
            for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                seq_h5 = sanitize_string(seq)
                if seq_h5 in h5fi:
                    logg.warning(f"{seq} already in h5file")
                feats = self.transform(seq)
                dset = h5fi.require_dataset(seq_h5, feats.shape, np.float32)
                dset[:] = feats.cpu().numpy()

    def preload(
        self,
        seq_list: T.List[str],
        verbose: bool = True,
        write_first: bool = True,
    ) -> None:
        """Fill the in-memory cache for ``seq_list``.

        When ``write_first`` is true and no HDF5 file exists yet, the
        features are first written with :meth:`write_to_disk`. Features are
        then read from the file where present and computed otherwise.
        """
        logg.info(f"Preloading {self.name} features from {self.path}")

        if write_first and not self._save_path.exists():
            self.write_to_disk(seq_list, verbose=verbose)

        if self._save_path.exists():
            with h5py.File(self._save_path, "r") as h5fi:
                for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                    # Datasets are stored under the sanitized key, so the
                    # membership test must use it too; testing the raw
                    # sequence misses every entry whose key was altered.
                    seq_h5 = sanitize_string(seq)
                    if seq_h5 in h5fi:
                        feats = torch.from_numpy(h5fi[seq_h5][:])
                    else:
                        feats = self.transform(seq)

                    if self._on_cuda:
                        feats = feats.to(self.device)

                    self._features[seq] = feats

        else:
            for seq in tqdm(seq_list, disable=not verbose, desc=self.name):
                feats = self.transform(seq)

                if self._on_cuda:
                    feats = feats.to(self.device)

                self._features[seq] = feats

        self._update_device(self.device)
        self._preloaded = True

class MorganFeaturizer(Featurizer):
    """Featurizer producing Morgan fingerprint bit vectors from SMILES strings."""

    def __init__(
        self,
        shape: int = 2048,
        radius: int = 2,
        save_dir: Path = Path().absolute(),
    ):
        super().__init__("Morgan", shape, save_dir)
        # Radius passed to the fingerprint call.
        self._radius = radius

    def smiles_to_morgan(self, smile: str):
        """
        Convert a SMILES string into a Morgan fingerprint.

        Any exception raised along the way is logged and an all-zero vector
        of length ``self.shape`` is returned instead.

        :param smile: SMILES string
        :type smile: str
        :return: Morgan fingerprint
        :rtype: np.ndarray
        """
        try:
            smile = canonicalize(smile)
            fingerprint = AllChem.GetMorganFingerprintAsBitVect(
                Chem.MolFromSmiles(smile), self._radius, nBits=self.shape
            )
            bits = np.zeros((1,))
            DataStructs.ConvertToNumpyArray(fingerprint, bits)
            return bits
        except Exception as e:
            logg.error(
                f"rdkit not found this smiles for morgan: {smile} convert to all 0 features"
            )
            logg.error(e)
            return np.zeros((self.shape,))

    def _transform(self, smile: str) -> torch.Tensor:
        """Return the fingerprint as a float tensor, or zeros on a size mismatch."""
        fingerprint = self.smiles_to_morgan(smile)
        vec = torch.from_numpy(fingerprint).squeeze().float()
        if vec.shape[0] == self.shape:
            return vec
        logg.warning("Failed to featurize: appending zero vector")
        return torch.zeros(self.shape)


class SimpleCoembedding(nn.Module):
    """Project a drug and a target into a shared latent space and compare them.

    Each input goes through ``Linear -> activation`` (Xavier-normal weights).
    With ``classify=True`` the projections are compared with the distance
    module looked up in ``DISTANCE_METRICS``; with ``classify=False`` the
    output is the ReLU of their inner product.

    Args:
        drug_shape: width of the drug feature vector.
        target_shape: width of the target feature vector.
        latent_dimension: width of the shared latent space.
        latent_activation: an ``nn.Module`` class, or the name of one in
            ``torch.nn`` (for example ``"GELU"`` as stored in a config file).
        latent_distance: key into ``DISTANCE_METRICS``; used only when
            ``classify`` is true.
        classify: selects :meth:`classify` (true) or :meth:`regress` (false)
            as the forward pass.

    Raises:
        ValueError: if ``latent_activation`` is a string that does not name
            an ``nn.Module`` class in ``torch.nn``.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Config files carry the activation as a string; resolve any
        # torch.nn activation name, not just 'GELU'.
        if isinstance(latent_activation, str):
            resolved = getattr(nn, latent_activation, None)
            if not (isinstance(resolved, type) and issubclass(resolved, nn.Module)):
                raise ValueError(
                    f"Unknown latent_activation name: {latent_activation!r}"
                )
            latent_activation = resolved

        self.drug_projector = nn.Sequential(
            nn.Linear(self.drug_shape, latent_dimension), latent_activation()
        )
        nn.init.xavier_normal_(self.drug_projector[0].weight)

        self.target_projector = nn.Sequential(
            nn.Linear(self.target_shape, latent_dimension), latent_activation()
        )
        nn.init.xavier_normal_(self.target_projector[0].weight)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def forward(self, drug, target):
        """Dispatch to :meth:`classify` or :meth:`regress` based on ``classify``."""
        if self.do_classify:
            return self.classify(drug, target)
        else:
            return self.regress(drug, target)

    def regress(self, drug, target):
        """Return ``relu(<drug_proj, target_proj>)`` per pair, shape ``(batch,)``."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        # bmm of (B, 1, D) x (B, D, 1) gives (B, 1, 1); flatten to (B,) so the
        # batch axis survives when B == 1 (a bare squeeze() would drop it).
        inner_prod = torch.bmm(
            drug_projection.view(-1, 1, self.latent_dimension),
            target_projection.view(-1, self.latent_dimension, 1),
        ).view(-1)
        return torch.relu(inner_prod)

    def classify(self, drug, target):
        """Return the configured distance between the two projections."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        distance = self.activator(drug_projection, target_projection)
        return distance.squeeze()

class SimpleCoembeddingSigmoid(nn.Module):
    """Co-embedding model whose classification output passes through a sigmoid.

    Each input goes through ``Linear -> activation`` (Xavier-normal weights).
    With ``classify=True`` the distance module looked up in
    ``DISTANCE_METRICS`` is applied and its output is squashed with a
    sigmoid; with ``classify=False`` the output is the ReLU of the inner
    product of the projections.

    Args:
        drug_shape: width of the drug feature vector.
        target_shape: width of the target feature vector.
        latent_dimension: width of the shared latent space.
        latent_activation: an ``nn.Module`` class, or the name of one in
            ``torch.nn`` (for example ``"GELU"``).
        latent_distance: key into ``DISTANCE_METRICS``; used only when
            ``classify`` is true.
        classify: selects :meth:`classify` (true) or :meth:`regress` (false)
            as the forward pass.

    Raises:
        ValueError: if ``latent_activation`` is a string that does not name
            an ``nn.Module`` class in ``torch.nn``.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation=nn.ReLU,
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify

        # Accept activation names (as stored in config files) as well as
        # classes, so this model takes the same configs as SimpleCoembedding.
        if isinstance(latent_activation, str):
            resolved = getattr(nn, latent_activation, None)
            if not (isinstance(resolved, type) and issubclass(resolved, nn.Module)):
                raise ValueError(
                    f"Unknown latent_activation name: {latent_activation!r}"
                )
            latent_activation = resolved

        self.drug_projector = nn.Sequential(
            nn.Linear(self.drug_shape, latent_dimension), latent_activation()
        )
        nn.init.xavier_normal_(self.drug_projector[0].weight)

        self.target_projector = nn.Sequential(
            nn.Linear(self.target_shape, latent_dimension), latent_activation()
        )
        nn.init.xavier_normal_(self.target_projector[0].weight)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[self.distance_metric]()

    def forward(self, drug, target):
        """Dispatch to :meth:`classify` or :meth:`regress` based on ``classify``."""
        if self.do_classify:
            return self.classify(drug, target)
        else:
            return self.regress(drug, target)

    def regress(self, drug, target):
        """Return ``relu(<drug_proj, target_proj>)`` per pair, shape ``(batch,)``."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        # bmm of (B, 1, D) x (B, D, 1) gives (B, 1, 1); flatten to (B,) so the
        # batch axis survives when B == 1 (a bare squeeze() would drop it).
        inner_prod = torch.bmm(
            drug_projection.view(-1, 1, self.latent_dimension),
            target_projection.view(-1, self.latent_dimension, 1),
        ).view(-1)
        return torch.relu(inner_prod)

    def classify(self, drug, target):
        """Return the sigmoid of the configured distance between the projections."""
        drug_projection = self.drug_projector(drug)
        target_projection = self.target_projector(target)

        distance = self.activator(drug_projection, target_projection)
        return torch.sigmoid(distance).squeeze()
from pathlib import Path
# Directory used as the Hugging Face cache root by ProtBertFeaturizer: a
# "models" folder two levels above ConPLex_dev (the ".." parts are kept
# literally in the path, not resolved here).
MODEL_CACHE_DIR = Path(
    '/content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/', "..", "..", "models")

class ProtBertFeaturizer(Featurizer):
    """Featurizer that embeds protein sequences with ``Rostlab/prot_bert``.

    Args:
        save_dir: directory holding the HDF5 feature cache.
        per_tok: when true, :meth:`_transform` returns one embedding row per
            residue; otherwise the rows are averaged into a single vector.
    """

    def __init__(self, save_dir: Path = Path().absolute(), per_tok=False):
        super().__init__("ProtBert", 1024, save_dir)

        from transformers import AutoTokenizer, AutoModel, pipeline

        # Sequences longer than `_max_len - 2` characters are truncated in
        # `_transform`.
        self._max_len = 1024
        self.per_tok = per_tok

        self._protbert_tokenizer = AutoTokenizer.from_pretrained(
            "Rostlab/prot_bert",
            do_lower_case=False,
            cache_dir=f"{MODEL_CACHE_DIR}/huggingface/transformers",
        )
        self._protbert_model = AutoModel.from_pretrained(
            "Rostlab/prot_bert",
            cache_dir=f"{MODEL_CACHE_DIR}/huggingface/transformers",
        )
        self._protbert_feat = pipeline(
            "feature-extraction",
            model=self._protbert_model,
            tokenizer=self._protbert_tokenizer,
        )

        # The model moves with `.to(device)`; the pipeline has to be rebuilt
        # for the new device, which `_feat_to_device` does.
        self._register_cuda("model", self._protbert_model)
        self._register_cuda(
            "featurizer", self._protbert_feat, self._feat_to_device
        )

    def _feat_to_device(self, pipe, device):
        """Rebuild the feature-extraction pipeline on ``device``.

        ``pipe`` is the previously registered pipeline; it is not reused and
        is only present because the registry calls ``f(v, device)``.
        """
        from transformers import pipeline

        device = torch.device(device)
        if device.type == "cpu":
            d = -1
        elif device.index is not None:
            d = device.index
        else:
            # torch.device("cuda") carries no index; passing None on to the
            # pipeline would not select the GPU, so use the current device.
            d = torch.cuda.current_device()

        pipe = pipeline(
            "feature-extraction",
            model=self._protbert_model,
            tokenizer=self._protbert_tokenizer,
            device=d,
        )
        self._protbert_feat = pipe
        return pipe

    def _space_sequence(self, x):
        """Insert a space between every character of ``x``."""
        return " ".join(list(x))

    def _transform(self, seq: str):
        """Embed one sequence.

        Returns the per-residue rows when ``per_tok`` is set, else their mean.
        """
        if len(seq) > self._max_len - 2:
            seq = seq[: self._max_len - 2]

        embedding = torch.tensor(
            self._cuda_registry["featurizer"][0](self._space_sequence(seq))
        )
        seq_len = len(seq)
        # Keep rows 1..seq_len, dropping row 0 and anything after the
        # residues — presumably the tokenizer's special start/end tokens.
        start_Idx = 1
        end_Idx = seq_len + 1
        feats = embedding.squeeze()[start_Idx:end_Idx]

        if self.per_tok:
            return feats
        return feats.mean(0)



In [12]:
task_dir = 'test_conplex_train/'
# create_path(task_dir)

In [14]:
# Seed with the replicate number from the config.
logg.debug(f"Setting random state {config.replicate}")
set_random_seed(config.replicate)
logg.info("Preparing DataModule")
# task_dir = get_task_dir(config.task, database_root=data_cache_dir)

# Both featurizers keep their HDF5 feature caches under `task_dir`;
# per_tok=False requests one pooled vector per protein.
target_featurizer = ProtBertFeaturizer(
    save_dir=task_dir, per_tok=False)
drug_featurizer = MorganFeaturizer(save_dir=task_dir)

In [15]:
# Data module for the task stored under `task_dir`; batching, shuffling and
# worker count all come from the loaded config.
datamodule = TDCDataModule(
    task_dir,
    drug_featurizer,
    target_featurizer,
    device=device,
    seed=config.replicate,
    batch_size=config.batch_size,
    shuffle=config.shuffle,
    num_workers=config.num_workers,
)

In [None]:
def sanitize_string(s):
    """Return ``s`` with every "/" replaced by "|".

    Used to build HDF5 dataset names from sequences.
    """
    return "|".join(s.split("/"))
# prepare_data() then setup() must both run before the loaders are requested.
datamodule.prepare_data()
datamodule.setup()

# Load DataLoaders
logg.info("Getting DataLoaders")
training_generator = datamodule.train_dataloader()
validation_generator = datamodule.val_dataloader()
testing_generator = datamodule.test_dataloader()

In [131]:
config.drug_shape = drug_featurizer.shape
config.target_shape = target_featurizer.shape

In [141]:
logg.info("Initializing model")

# Build the co-embedding network from the run configuration; the activation
# is passed as the string stored in the config.
model = SimpleCoembedding(
    int(config.drug_shape),
    int(config.target_shape),
    latent_dimension=int(config.latent_dimension),
    latent_distance=config.latent_distance,
    latent_activation=config.latent_activation,
    classify=config.classify,
)
if "checkpoint" in config:
    # map_location lets a checkpoint saved on a GPU load on whatever device
    # this runtime has (including CPU-only).
    state_dict = torch.load(config.checkpoint, map_location=device)
    model.load_state_dict(state_dict)

model = model.to(device)
logg.info(model)

# Optimizers
logg.info("Initializing optimizers")
opt = torch.optim.AdamW(model.parameters(), lr=config.lr)
# Cosine annealing with warm restarts; the first cycle lasts config.lr_t0
# scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opt, T_0=config.lr_t0
)

GELU


In [None]:
for i in config.items():
    print(i)
    print(i[1], type(i[1]))

In [143]:
config.latent_activation

'GELU'

In [147]:
import copy
logg.info("Initializing metrics")
# Best watched-metric value so far and a snapshot of the model that
# achieved it (initially the untrained model).
max_metric = 0
model_max = copy.deepcopy(model)
loss_fct = torch.nn.MSELoss()
# Metric *classes*, not instances: `test()` instantiates them per evaluation.
val_metrics = {
    "val/mse": torchmetrics.MeanSquaredError,
    "val/pcc": torchmetrics.PearsonCorrCoef,
}

from time import time
# Disables the contrastive step of the training loop.
config.contrastive = False
# NOTE(review): not referenced anywhere in the visible cells — confirm it is
# still needed.
FOLDSEEK_MISSING_IDX = 20

In [146]:
# Initialize wandb

import json
def wandb_log(m, do_wandb=True):
    """Forward the metrics dict ``m`` to ``wandb.log`` unless logging is disabled."""
    if not do_wandb:
        return
    wandb.log(m)


# Log to wandb only when the config both enables it and names a project.
do_wandb = config.wandb_save and ("wandb_proj" in config)
if do_wandb:
    logg.info(f"Initializing wandb project {config.wandb_proj}")
    wandb.init(
        project=config.wandb_proj,
        name='test_run',
        config=dict(config),
    )
    wandb.watch(model, log_freq=100)
logg.info("Config:")
logg.info(json.dumps(dict(config), indent=4))

logg.info("Beginning Training")

# Let cuDNN benchmark and pick algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Charts/epoch_time,█▁
epoch,▁▁██
train/loss,▃▅▄▄▅▃▂▁▇▃▃▅█▆▃▇▆▄▃▄▅▄█▇▇▄▄▆▄▅▇▆▅▄▃▅▄█▅▇
train/lr,█▁
train/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val/mse,▁█
val/pcc,▁▁

0,1
Charts/epoch_time,20.0083
epoch,1.0
train/loss,21.32545
train/lr,9e-05
train/step,386912.0
val/mse,27.31711
val/pcc,0.81092


In [148]:


def test(model, data_generator, metrics, device=None, classify=True):
    """Evaluate ``model`` on ``data_generator`` and return the computed metrics.

    Args:
        model: module called through ``step(model, batch, device)``.
        data_generator: iterable of batches with a length (used for the
            progress bar).
        metrics: mapping ``name -> metric class``. Each class is
            instantiated with ``task="binary"`` when ``classify`` is true and
            with no arguments otherwise.
        device: device for the model inputs and the metric state; defaults
            to CPU.
        classify: when true labels are cast to int, otherwise to float.

    Returns:
        dict mapping each metric name to the value of its ``compute()``.
    """
    if device is None:
        device = torch.device("cpu")

    metric_dict = {}

    for k, met_class in metrics.items():
        if classify:
            met_instance = met_class(task="binary")
        else:
            met_instance = met_class()
        met_instance.to(device)
        met_instance.reset()
        metric_dict[k] = met_instance

    model.eval()

    # Evaluation needs no gradients; without no_grad every batch would keep
    # its autograd graph alive until the predictions are released.
    with torch.no_grad():
        for _, batch in tqdm(enumerate(data_generator), total=len(data_generator)):
            pred, label = step(model, batch, device)
            if classify:
                label = label.int()
            else:
                label = label.float()

            # Only the accumulated state is read (via compute() below), so
            # update() is enough; calling the metric would also compute a
            # per-batch value that is discarded.
            for met_instance in metric_dict.values():
                met_instance.update(pred, label)

    results = {}
    for k, met_instance in metric_dict.items():
        res = met_instance.compute()
        results[k] = res

    for met_instance in metric_dict.values():
        met_instance.to("cpu")

    return results


In [None]:
for i, batch in tqdm(enumerate(training_generator),
                     total=len(training_generator)):
    print(batch)


In [154]:
batch[0].shape, batch[1].shape, batch[2].shape

# torch.Size([32, 2048]) torch.Size([32, 1024]) torch.Size([32])

(torch.Size([32, 2048]), torch.Size([32, 1024]), torch.Size([32]))

In [155]:
batch[2]

tensor([ 3.8286e+00,  1.0000e-10,  2.7788e+00,  1.6118e+01,  4.8828e+00,
         3.3673e+00,  2.6603e+00,  2.3026e+00,  3.8501e+00,  2.7147e+00,
         4.6052e+00,  4.8675e+00, -9.4135e-01,  5.3327e+00,  2.3026e+00,
         7.4955e+00,  2.5649e+00,  6.4615e+00,  8.4913e+00,  5.2257e+00,
         1.9315e+00,  9.2103e+00,  5.1358e+00,  3.4657e+00,  6.2538e+00,
         8.1605e+00,  6.5509e+00,  2.0477e+00,  5.8579e+00,  5.6937e+00,
         8.3291e-01,  9.3538e+00], dtype=torch.float64)

In [156]:
# Output folder for this run, created if missing.
save_dir = 'save_conplex_new/'
create_path(save_dir)

# NOTE: loss_fct and val_metrics repeat the definitions from the earlier
# "Initializing metrics" cell.
loss_fct = torch.nn.MSELoss()
val_metrics = {
    "val/mse": torchmetrics.MeanSquaredError,
    "val/pcc": torchmetrics.PearsonCorrCoef,
}

# Same metric classes, reported under test/ names.
test_metrics = {
    "test/mse": torchmetrics.MeanSquaredError,
    "test/pcc": torchmetrics.PearsonCorrCoef,
}

save_conplex_new/  folder is in directory:  False
save_conplex_new/  is created!


In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

def conplex_predict_batch(query_df, model_or_path, device=device):
    """Featurize every row of ``query_df`` and predict with a co-embedding model.

    Args:
        query_df: DataFrame with columns ``moleculeID``, ``proteinID``,
            ``moleculeSmiles`` and ``proteinSequence``.
        model_or_path: either a model instance, or the path (``str``) of a
            state dict that is loaded into a default-configured
            ``SimpleCoembedding`` with latent dimension 1024.
        device: device for the model and the features.

    Returns:
        DataFrame with ``moleculeID``, ``proteinID`` and a ``Prediction``
        column, one row per row of ``query_df``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint cannot be read; a failed load is no longer swallowed
        (which used to leave a randomly initialised model in place).
    """
    target_featurizer = ProtBertFeaturizer(
    save_dir='/', per_tok=False)
    drug_featurizer = MorganFeaturizer(save_dir='/')
    drug_featurizer.preload(query_df["moleculeSmiles"].unique())
    target_featurizer.preload(query_df["proteinSequence"].unique())

    if isinstance(model_or_path, str):
        # Only build a fresh network when weights have to be loaded into it.
        model = SimpleCoembedding(drug_featurizer.shape,
                                  target_featurizer.shape, 1024)
        model.load_state_dict(torch.load(model_or_path, map_location=device))
    else:
        model = model_or_path
    model = model.eval()
    model = model.to(device)
    drug_featurizer.to(device)
    target_featurizer.to(device)

    dt_feature_pairs = [
    (drug_featurizer(r["moleculeSmiles"]),
     target_featurizer(r["proteinSequence"]))
    for _, r in query_df.iterrows()]

    dloader = DataLoader(dt_feature_pairs, batch_size=128, shuffle=False)

    print(f"Generating predictions...")
    preds = []
    with torch.no_grad():
        for b in dloader:
            batch_pred = model(b[0], b[1]).detach().cpu().numpy()
            # A one-row batch can come back 0-d; make it 1-d so that
            # concatenation works for every batch size.
            preds.append(np.atleast_1d(batch_pred))

    preds = np.concatenate(preds) if preds else np.empty(0)

    # Copy so that adding a column never writes into a view of query_df.
    result_df = query_df[["moleculeID", "proteinID"]].copy()
    result_df["Prediction"] = preds
    return result_df


In [14]:
config.latent_activation

'GELU'

In [16]:
def calculate_feature_pair(query_df, device=device):
    """Return one ``(drug_features, target_features)`` tuple per row of ``query_df``.

    Features are preloaded for the unique SMILES strings and protein
    sequences, both featurizers are moved to ``device``, and the pairs are
    then assembled row by row.

    Args:
        query_df: DataFrame with ``moleculeSmiles`` and ``proteinSequence``
            columns.
        device: device the returned feature tensors are moved to.
    """
    target_featurizer = ProtBertFeaturizer(
    save_dir='/', per_tok=False)
    drug_featurizer = MorganFeaturizer(save_dir='/')
    drug_featurizer.preload(query_df["moleculeSmiles"].unique())

    target_featurizer.preload(query_df["proteinSequence"].unique())

    drug_featurizer.to(device)
    target_featurizer.to(device)

    dt_feature_pairs = [
    (drug_featurizer(r["moleculeSmiles"]),
     target_featurizer(r["proteinSequence"]))
    for _, r in query_df.iterrows()]
    return dt_feature_pairs


def low_memory_predict(dt_feature_pairs, model_or_path, query_df):
    """Predict from precomputed feature pairs.

    Args:
        dt_feature_pairs: list of ``(drug_features, target_features)``
            tensors, one per row of ``query_df``.
        model_or_path: either a model instance, or the path (``str``) of a
            state dict loaded into a regression ``SimpleCoembedding`` (GELU
            activation, latent dimension 1024) sized from the first pair.
        query_df: DataFrame with ``moleculeID`` and ``proteinID`` columns.
            If it has no ``label`` column one filled with ``'NA'`` is added
            to it in place.

    Returns:
        DataFrame with ``moleculeID``, ``proteinID``, ``label`` and
        ``Prediction`` columns.

    Raises:
        ValueError: if a checkpoint path is given but ``dt_feature_pairs``
            is empty, so the input widths cannot be determined.
    """
    if isinstance(model_or_path, str):
        if len(dt_feature_pairs) == 0:
            raise ValueError(
                "dt_feature_pairs is empty; cannot size the model to load"
            )
        # Size the network from the argument, not from a notebook global.
        model = SimpleCoembedding(dt_feature_pairs[0][0].shape[0],
                                  dt_feature_pairs[0][1].shape[0], 1024,
        latent_distance='Cosine',
        latent_activation='GELU',
        classify=False)
        model.load_state_dict(torch.load(model_or_path,
                                         map_location=device))
    else:
        model = model_or_path

    model = model.eval()
    model = model.to(device)

    dloader = DataLoader(dt_feature_pairs, batch_size=128, shuffle=False)

    print(f"Generating predictions...")
    preds = []
    with torch.no_grad():
        for b in dloader:
            batch_pred = model(b[0], b[1]).detach().cpu().numpy()
            # A one-row batch can come back 0-d; make it 1-d so that
            # concatenation works for every batch size.
            preds.append(np.atleast_1d(batch_pred))

    preds = np.concatenate(preds) if preds else np.empty(0)

    if "label" not in query_df.columns:
        # A scalar fills every row regardless of the frame's index; assigning
        # a fresh 0..n-1 DataFrame would align on the index and leave NaN for
        # any row whose index label is not in that range.
        query_df['label'] = 'NA'
    result_df = query_df[["moleculeID", "proteinID", "label"]].copy()

    result_df["Prediction"] = preds

    return result_df


# Checkpoint (state dict) of the best model from an earlier run.
d = '/content/gdrive/MyDrive/A_JAK_design/save_conplex/test_run_best_model.pt'


# data_file = '/content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/fedratinib_jak2_2lines.tsv'
data_file = 'test_new.tsv'
# The TSV has no header row; the four columns are named here.
try:
    query_df = pd.read_csv(
        data_file,
        sep="\t",
        names=["proteinID", "moleculeID", "proteinSequence", "moleculeSmiles"],
        index_col=False
    )
except FileNotFoundError:
    print(f"Could not find data file: {data_file}")
    # Stop here: predicting on a stale `query_df` left over in the kernel
    # would silently produce results for the wrong data.
    raise
# query_df['moleculeSmiles'] = 'Cc1cnc(Nc2ccc(OCCN3CCCC3)cc2)nc1Nc1cccc(S(=O)(=O)NC(C)(C)C)c1'
# query_df.to_csv('test_new.tsv', sep="\t", index=False, header=False)

# bar = 'CCS(=O)(=O)N1CC(C1)(CC#N)N2C=C(C=N2)C3=C4C=CNC4=NC=N3'
# query_df['moleculeSmiles'].loc[1:]= bar
# query_df
# query_df.to_csv('test_new.tsv', sep="\t", index=False, header=False)


feature_pairs = calculate_feature_pair(query_df)
low_memory_predict(feature_pairs, d, query_df)

Morgan: 100%|██████████| 2/2 [00:00<00:00, 454.69it/s]
ProtBert: 100%|██████████| 1/1 [00:03<00:00,  3.81s/it]


Generating predictions...


Unnamed: 0,moleculeID,proteinID,label,Prediction
0,CHEMBL1287853,ENSP00000371067,,3.580391
1,CHEMBL1287853,ENSP00000371067,,3.047762


In [159]:
drug_featurizer.shape

2048

In [160]:
feature_pairs[0][0].shape

torch.Size([2048])

In [161]:
'proteinID' in query_df.columns

True

In [162]:
query_df['label'] = pd.DataFrame(len(query_df) * ['NA'])
query_df

Unnamed: 0,moleculeID,proteinID,proteinSequence,moleculeSmiles,label
0,91808352.0,P56373,MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...,COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...,
1,67223437.0,P28845,MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...,Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...,
2,46222354.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(C)(C#N)c1cccc(C(=O)Nc2cccc(Oc3ccc4nc(NC(=O)...,
3,59454472.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Nc4ccc(cc4)C(C)(...,
4,46222580.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,
5,59454502.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Cc4cccc(c4)C(F)(...,
6,46209401.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,
7,86644710.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,COC(=O)c1c(Oc2ccc(F)c(NC(=O)Cc3cccc(c3)C(F)(F)...,
8,71009084.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,CC(C)(C)OC(=O)N1C[C@H]2CCN(C(=O)[C@H]2C1)c1ccc...,
9,71008797.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,FC(F)(F)Oc1ccc(cc1)N1CC[C@H]2CN(C[C@H]2C1=O)C(...,


In [107]:
feature_pairs = calculate_feature_pair(new_test_df)
new_test_df
low_memory_predict(feature_pairs, model_save_path, new_test_df)


Morgan: 100%|██████████| 10/10 [00:00<00:00, 743.49it/s]
ProtBert: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


AttributeError: ignored

In [75]:
# Step-by-step version of low_memory_predict, kept for debugging.
dloader = DataLoader(feature_pairs, batch_size=32, shuffle=False)

print(f"Generating predictions...")
preds = []
with torch.set_grad_enabled(False):
    for b in tqdm(dloader):
        print(b[0].shape, b[1].shape)
        # Run the forward pass once and reuse it for printing and storing.
        batch_pred = model(b[0], b[1])
        print(batch_pred)
        preds.append(batch_pred.detach().cpu().numpy())

if len(preds[0]) != 1:
    preds = np.concatenate(preds)

print(preds)

if "label" in query_df.columns:
    pass
else:
    query_df['label'] = pd.DataFrame(len(query_df) * ['NA'])

print(query_df)
result_df = pd.DataFrame(query_df[["moleculeID", "proteinID", "label"]])
print(result_df)

# `feature_pairs` and `query_df` come from different cells; fail with a clear
# message when they are out of sync instead of pandas' length-mismatch error.
if len(preds) != len(result_df):
    raise ValueError(
        f"{len(preds)} predictions for {len(result_df)} rows: "
        "feature_pairs and query_df were built from different data"
    )
result_df["Prediction"] = preds

# return result_df
#
#


Generating predictions...


100%|██████████| 1/1 [00:00<00:00, 136.24it/s]

torch.Size([10, 2048]) torch.Size([10, 1024])
tensor([0.2225, 0.2616, 0.2611, 0.2787, 0.2741, 0.2788, 0.2652, 0.2536, 0.3579,
        0.3293], device='cuda:0')
[0.22254471 0.2615533  0.26109576 0.2787041  0.27413142 0.2788158
 0.2652303  0.2535711  0.35794708 0.3293364 ]
    proteinID moleculeID                                    proteinSequence  \
0  91808352.0     P56373  MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...   
1  67223437.0     P28845  MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...   

                                      moleculeSmiles label  
0  COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...    NA  
1  Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...    NA  
  moleculeID   proteinID label
0     P56373  91808352.0    NA
1     P28845  67223437.0    NA





ValueError: ignored

In [164]:
save_dir

'save_conplex_new/'

In [166]:
from torch.autograd import Variable

def step(model, batch, device=None):
    """Run one forward pass on a batch and return predictions with float labels.

    Args:
        model: Callable invoked as ``model(drug, target)``.
        batch: Triple ``(drug, target, label)``. ``drug`` and ``target`` must
            expose ``.to(device)``; ``label`` may be a tensor, array or sequence
            of numbers.
        device: Device to run on. Defaults to the CPU when ``None``.

    Returns:
        Tuple ``(pred, label)``: the raw model output and the labels as a
        float32 tensor on ``device``.
    """
    if device is None:
        device = torch.device("cpu")

    drug, target, label = batch  # target is (D + N_pool)
    pred = model(drug.to(device), target.to(device))
    # torch.as_tensor replaces the deprecated autograd.Variable wrapper and the NumPy
    # round trip, which could not handle labels that already live on a GPU.
    label = torch.as_tensor(label, dtype=torch.float32, device=device)
    return pred, label


# Full training run: one supervised pass per epoch, an optional contrastive pass,
# periodic validation, and a checkpoint whenever the watched metric improves.
start_time = time()
for epo in range(config.epochs):
    model.train()
    epoch_time_start = time()

    # Main Step
    for i, batch in tqdm(enumerate(training_generator),
                         total=len(training_generator)):
        # print(batch[0].shape, batch[1].shape)
        pred, label = step(model, batch, device)  # batch is (2048, 1024, 1)
        loss = loss_fct(pred, label)

        # "train/step" advances by config.batch_size for every processed batch.
        wandb_log({
            "train/step": (epo * len(training_generator) * config.batch_size)
            + (i * config.batch_size),
            "train/loss": loss,},
            do_wandb,)

        opt.zero_grad()
        loss.backward()
        opt.step()

    lr_scheduler.step()

    # get_last_lr() is the supported way to read the rate outside scheduler.step();
    # get_lr() warns there and can return a wrong value for chainable schedulers.
    wandb_log({"epoch": epo, "train/lr": lr_scheduler.get_last_lr()[0],},
              do_wandb,)

    logg.info(
        f"Training at Epoch {epo + 1} with loss {loss.cpu().detach().numpy():8f}")
    logg.info(f"Updating learning rate to {lr_scheduler.get_last_lr()[0]:8f}")

    # Contrastive Step
    if config.contrastive:
        logg.info(f"Training contrastive at Epoch {epo + 1}")
        for i, batch in tqdm(
            enumerate(contrastive_generator),
            total=len(contrastive_generator),
        ):
            anchor, positive, negative = contrastive_step(model, batch, device)

            contrastive_loss = contrastive_loss_fct(anchor, positive, negative)

            # NOTE(review): the epoch offset uses len(training_generator), not
            # len(contrastive_generator); left unchanged - confirm it is intended.
            wandb_log(
                {
                    "train/c_step": (
                        epo
                        * len(training_generator)
                        * config.contrastive_batch_size
                    )
                    + (i * config.contrastive_batch_size),
                    "train/c_loss": contrastive_loss,
                },
                do_wandb,
            )

            opt_contrastive.zero_grad()
            contrastive_loss.backward()
            opt_contrastive.step()

        contrastive_loss_fct.step()
        lr_scheduler_contrastive.step()

        # Log a scalar learning rate, consistent with "train/lr" above.
        wandb_log(
            {
                "epoch": epo,
                "train/triplet_margin": contrastive_loss_fct.margin,
                "train/contrastive_lr": lr_scheduler_contrastive.get_last_lr()[0],
            },
            do_wandb,
        )

        logg.info(
            f"Training at Contrastive Epoch {epo + 1} with loss {contrastive_loss.cpu().detach().numpy():8f}"
        )
        logg.info(
            f"Updating contrastive learning rate to {lr_scheduler_contrastive.get_last_lr()[0]:8f}"
        )
        logg.info(f"Updating contrastive margin to {contrastive_loss_fct.margin}")

    epoch_time_end = time()

    # Validation
    if epo % config.every_n_val == 0:
        with torch.set_grad_enabled(False):
            val_results = test(
                model,
                validation_generator,
                val_metrics,
                device,
                config.classify,
            )

            val_results["epoch"] = epo
            val_results["Charts/epoch_time"] = (
                epoch_time_end - epoch_time_start
            ) / config.every_n_val

            wandb_log(val_results, do_wandb)
            # Debug check: score the fixed query set with the live, in-memory model.
            print(low_memory_predict(feature_pairs, model, query_df))

            if val_results[config.watch_metric] > max_metric:
                print(f'{val_results[config.watch_metric]} > {max_metric}')

                logg.debug(
                    f"Validation AUPR {val_results[config.watch_metric]:8f} > previous max {max_metric:8f}")
                # Snapshot the weights so later epochs cannot alter the best model.
                model_max = copy.deepcopy(model)
                max_metric = val_results[config.watch_metric]
                # Join with pathlib so the result does not depend on save_dir ending in
                # "/", and make sure the directory exists before torch.save writes to it.
                model_save_path = Path(save_dir) / f"test_epoch{epo:02}.pt"
                model_path_str = str(model_save_path)
                model_save_path.parent.mkdir(parents=True, exist_ok=True)
                print(f"model save: {model_path_str}")
                logg.info(f"Saving checkpoint model to {model_save_path}")

                torch.save(model_max.state_dict(), model_save_path)
                # Debug check: reload the checkpoint from disk and score the same query
                # set, to compare against the in-memory predictions printed above.
                print('try recompute')
                print(low_memory_predict(feature_pairs, model_path_str, query_df))

                if do_wandb:
                    art = wandb.Artifact(f"dti-{config.run_id}", type="model")
                    art.add_file(model_save_path, model_save_path.name)
                    wandb.log_artifact(art, aliases=["best"])

            logg.info(f"Validation at Epoch {epo + 1}")
            for k, v in val_results.items():
                # Keys starting with "_" are skipped when logging.
                if not k.startswith("_"):
                    logg.info(f"{k}: {v}")

end_time = time()


100%|██████████| 4591/4591 [00:19<00:00, 229.56it/s]
100%|██████████| 1142/1142 [00:04<00:00, 240.57it/s]


Generating predictions...
   moleculeID proteinID     label  Prediction
0  91808352.0    P56373  2.564949    4.206481
1  67223437.0    P28845  4.605170    4.559961
2  46222354.0    Q02750  5.703782    4.801049
3  59454472.0    Q02750  5.703782    6.321260
4  46222580.0    Q02750  5.703782    4.552299
5  59454502.0    Q02750  5.703782    4.686613
6  46209401.0    Q02750  5.703782    4.019047
7  86644710.0    Q02750  5.703782    2.868285
8  71009084.0    Q05469  5.135798    6.308607
9  71008797.0    Q05469  3.901973    5.711331
0.8589463829994202 > 0.8132712841033936
model save: save_conplex_new/test_epoch00.pt
try recompute 
Generating predictions...
   moleculeID proteinID     label  Prediction
0  91808352.0    P56373  2.564949    0.230441
1  67223437.0    P28845  4.605170    0.271976
2  46222354.0    Q02750  5.703782    0.210956
3  59454472.0    Q02750  5.703782    0.263724
4  46222580.0    Q02750  5.703782    0.214229
5  59454502.0    Q02750  5.703782    0.217289
6  46209401.0    Q02

 20%|██        | 924/4591 [00:03<00:14, 246.12it/s]


KeyboardInterrupt: ignored

In [111]:
# Bind `query_df` to the same DataFrame object as `new_test_df` (no copy is made, so any
# in-place change to one is visible through the other), then display it.
query_df = new_test_df
query_df

Unnamed: 0,moleculeID,proteinID,proteinSequence,moleculeSmiles,label
0,91808352.0,P56373,MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...,COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...,2.564949
1,67223437.0,P28845,MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...,Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...,4.60517
2,46222354.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(C)(C#N)c1cccc(C(=O)Nc2cccc(Oc3ccc4nc(NC(=O)...,5.703782
3,59454472.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Nc4ccc(cc4)C(C)(...,5.703782
4,46222580.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,5.703782
5,59454502.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,CC(=O)Nc1nc2ccc(Oc3cccc(NC(=O)Cc4cccc(c4)C(F)(...,5.703782
6,46209401.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,Fc1ccc(Oc2ccc3nc(NC(=O)C4CC4)sc3c2C#N)cc1NC(=O...,5.703782
7,86644710.0,Q02750,MPKKKPTPIQLNPAPDGSAVNGTSSAETNLEALQKKLEELELDEQQ...,COC(=O)c1c(Oc2ccc(F)c(NC(=O)Cc3cccc(c3)C(F)(F)...,5.703782
8,71009084.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,CC(C)(C)OC(=O)N1C[C@H]2CCN(C(=O)[C@H]2C1)c1ccc...,5.135798
9,71008797.0,Q05469,MEPGSKSVSRSDWQPEPHQRPITPLEPGPEKTPIAQPESKTLQGSN...,FC(F)(F)Oc1ccc(cc1)N1CC[C@H]2CN(C[C@H]2C1=O)C(...,3.901973


In [82]:
# Start a W&B run and fetch a previously logged model artifact to local disk.
import wandb
run = wandb.init()
# The artifact reference (including its version tag) is hard-coded here.
artifact = run.use_artifact('lanlang/NoSigmoidTest/dti-test_run:v41', type='model')
# download() returns the local directory; the next cell displays it.
artifact_dir = artifact.download()

VBox(children=(Label(value='324.265 MB of 324.265 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
Charts/epoch_time,▅▂▃▃▃▂▂▆▂▁▃▅█▆▃▇▂▃▃▇▅▇▅▆▆▆▅▃▂▄▂▅▄▄▃▄▄▂▂▄
Charts/wall_clock_time,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/eval_time,▁
test/mse,▁
test/pcc,▁
train/loss,▆█▅▄▆▃▄▃▇▄▄▅▂▂▂▂▃▄▅▁▄▂▁▂▂▂▁▃▂▁▁▁▃▄▃▁▂▂▁▂
train/lr,█▇▇▆▃▂▂▁█▇▇▆▃▂▂▁█▇▇▆▃▂▂▁█▇▇▆▃▂▂▁█▇▇▆▃▂▂█
train/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/mse,█▅▄▄▃▂▂▂▂▂▂▂▂▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Charts/epoch_time,18.93929
Charts/wall_clock_time,1183.82552
epoch,50.0
test/eval_time,168.32756
test/mse,6.60214
test/pcc,0.57691
train/loss,27.60896
train/lr,0.0001
train/step,18208.0
val/mse,1.71335


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [83]:
# Display the local directory returned by artifact.download().
artifact_dir

'./artifacts/dti-test_run:v41'

In [96]:
# Recompute the feature pairs for `new_test_df` with the `calculate_feature_pair` helper
# (defined outside this section); the next cell indexes the result as feature_pairs[0][0].
feature_pairs = calculate_feature_pair(new_test_df)


Morgan: 100%|██████████| 10/10 [00:00<00:00, 671.27it/s]
ProtBert: 100%|██████████| 4/4 [00:05<00:00,  1.48s/it]


TypeError: ignored

In [100]:
# Length along axis 0 of the first element of the first feature pair.
feature_pairs[0][0].shape[0]

2048

In [None]:
# Testing
# Score the best checkpointed model on the test split, log the metrics, then persist its
# weights locally and (when W&B logging is on) as a model artifact.
logg.info("Beginning testing")
try:
    with torch.no_grad():
        # For a torch.nn.Module, eval() returns the module itself, so no rebinding is needed.
        model_max.eval()

        test_start_time = time()
        test_results = test(
            model_max, testing_generator, test_metrics, device, config.classify
        )
        test_end_time = time()

        test_results["epoch"] = epo + 1
        test_results["test/eval_time"] = test_end_time - test_start_time
        test_results["Charts/wall_clock_time"] = end_time - start_time
        wandb_log(test_results, do_wandb)

        logg.info("Final Testing")
        for k, v in test_results.items():
            if k.startswith("_"):
                continue  # underscore-prefixed entries are not written to the log
            logg.info(f"{k}: {v}")

        create_path(save_dir)
        model_save_path = Path(f"{save_dir}/{config.run_id}_best_model.pt")
        torch.save(model_max.state_dict(), model_save_path)
        logg.info(f"Saving final model to {model_save_path}")

        if do_wandb:
            art = wandb.Artifact(f"dti-{config.run_id}", type="model")
            art.add_file(model_save_path, model_save_path.name)
            wandb.log_artifact(art, aliases=["best"])

except Exception as e:
    # Best effort: a failure anywhere above is logged rather than raised.
    logg.error(f"Testing failed with exception {e}")

In [None]:
# Save the weights of `model_max` under save_dir, named after the run id. This repeats the
# save step of the testing cell, so it can be run on its own.
model_save_path = Path(f"{save_dir}/{config.run_id}_best_model.pt")
torch.save(
    model_max.state_dict(),
    model_save_path,
)
logg.info(f"Saving final model to {model_save_path}")

In [None]:
# Display the path the model weights were last saved to.
model_save_path

In [None]:
cd save_conplex

In [None]:
ls

Unnamed: 0,proteinID,moleculeID,proteinSequence,moleculeSmiles
0,91808352.0,P56373,MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...,COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...
1,67223437.0,P28845,MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...,Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...


In [None]:
# The file is written without a header and read back by position with
# names=["proteinID", "moleculeID", "proteinSequence", "moleculeSmiles"].
# Select exactly those four columns in that order so the two ID columns are not swapped
# on reload and an optional "label" column does not shift the fields.
query_df[["proteinID", "moleculeID", "proteinSequence", "moleculeSmiles"]].to_csv(
    'test_exist.tsv', sep="\t", index=False, header=False)
query_df

In [None]:
# Call the `conplex_predict` helper (defined outside this section) on the test table with
# the current `model`; its return value is shown as the cell output.
conplex_predict(new_test_df, model)

In [None]:
# Bind `query_df` to the same DataFrame object as `new_test_df` (an alias, not a copy),
# then display it.
query_df = new_test_df
query_df

In [50]:
def sanitize_string(s):
    """Return ``s`` with every forward slash swapped for a pipe character."""
    pieces = s.split("/")
    return "|".join(pieces)


# Load a header-less query TSV, featurize its molecules and proteins, and restore a
# trained co-embedding model for inference.
model_path = '/content/gdrive/MyDrive/A_JAK_design/save_conplex/test_run_best_model.pt'

# data_file = '/content/gdrive/MyDrive/A_JAK_design/ConPLex_dev/fedratinib_jak2_2lines.tsv'
data_file = '/content/gdrive/MyDrive/A_JAK_design/test_exist.tsv'
try:
    # Columns are assigned by position, so the file must list the protein ID first.
    # NOTE(review): the recorded output shows numeric IDs under proteinID and UniProt-style
    # accessions under moleculeID, which suggests the file was written with the two ID
    # columns in the opposite order - confirm against the cell that writes it.
    query_df = pd.read_csv(
        data_file,
        sep="\t",
        names=["proteinID", "moleculeID", "proteinSequence", "moleculeSmiles"],
        index_col=False
    )
except FileNotFoundError:
    print(f"Could not find data file: {data_file}")
    # Stop here: carrying on would silently reuse a `query_df` left over from an earlier
    # cell (or fail later with a NameError on a fresh kernel).
    raise
print(query_df)


target_featurizer = ProtBertFeaturizer(
    save_dir='/', per_tok=False)
drug_featurizer = MorganFeaturizer(save_dir='/')
drug_featurizer.preload(query_df["moleculeSmiles"].unique())
target_featurizer.preload(query_df["proteinSequence"].unique())

# NOTE(review): this makes the "NoSigmoid" name refer to SimpleCoembedding. The training
# cell's output shows reloaded-checkpoint predictions around 0.2-0.3 next to in-memory
# predictions around 3-6; a different output activation on this class would explain that -
# confirm which architecture the checkpoint was trained with.
SimpleCoembeddingNoSigmoid = SimpleCoembedding

DISTANCE_METRICS = {
    "Cosine": Cosine,
    "SquaredCosine": SquaredCosine,
    "Euclidean": Euclidean,
    "SquaredEuclidean": SquaredEuclidean,
}

ACTIVATIONS = {"ReLU": nn.ReLU, "GELU": nn.GELU, "ELU": nn.ELU, "Sigmoid": nn.Sigmoid}

model = SimpleCoembeddingNoSigmoid(
    drug_featurizer.shape, target_featurizer.shape, 1024
)
try:
    model.load_state_dict(torch.load(model_path, map_location=device))
except Exception as load_error:
    # Fall back to the checkpoint path left in the kernel by the training/testing cells.
    # Catch Exception rather than everything so Ctrl-C still interrupts, and report why
    # the primary checkpoint was rejected instead of hiding it.
    print(f"Could not load {model_path} ({load_error}); falling back to {model_save_path}")
    model_path = model_save_path
    model.load_state_dict(torch.load(model_path, map_location=device))

model = model.eval()
model = model.to(device)
drug_featurizer.to(device)
target_featurizer.to(device)

    proteinID moleculeID                                    proteinSequence  \
0  91808352.0     P56373  MNCISDFFTYETTKSVVVKSWTIGIINRVVQLLIISYFVGWVFLHE...   
1  67223437.0     P28845  MAFMKKYLLPILGLFMAYYYYSANEEFRPEMLQGKKVIVTGASKGI...   

                                      moleculeSmiles  
0  COCc1ccccc1C1C(C(=O)C(C)C)C(=O)C(=O)N1c1ccc(cc...  
1  Cc1c2[C@@H]3CCCN([C@@H]3Cc2ccc1C#N)C(=O)c1ccc2...  


Morgan: 100%|██████████| 2/2 [00:00<00:00, 947.76it/s]
ProtBert: 100%|██████████| 2/2 [00:00<00:00, 1038.07it/s]


<__main__.ProtBertFeaturizer at 0x7cbe7c258220>

In [51]:
# Featurize each query row into a (drug, target) pair, keeping the table's row order.
dt_feature_pairs = [
    (drug_featurizer(r["moleculeSmiles"]), target_featurizer(r["proteinSequence"]))
    for _, r in query_df.iterrows()]

dloader = DataLoader(dt_feature_pairs, batch_size=128, shuffle=False)

print(f"Generating predictions...")
preds = []
with torch.set_grad_enabled(False):
    for b in dloader:
        preds.append(model(b[0], b[1]).detach().cpu().numpy())

# Flatten the per-batch outputs into one vector. atleast_1d promotes a 0-d output (a batch
# holding a single pair) so that one-row queries and short trailing batches concatenate
# correctly; an empty query yields an empty vector.
preds = np.concatenate([np.atleast_1d(p) for p in preds]) if preds else np.empty(0)

result_df = pd.DataFrame(query_df[["moleculeID", "proteinID"]])
result_df["Prediction"] = preds

result_df

Generating predictions...


Unnamed: 0,moleculeID,proteinID,Prediction
0,P56373,91808352.0,0.222545
1,P28845,67223437.0,0.261553


In [None]:
# The load below is commented out, so this cell relies on `dti_train` already being
# present in the kernel namespace.
# dti_train = '/content/gdrive/MyDrive/A_JAK_design/dti_dg_group/bindingdb_patent/train_val.csv'
# dti_train = pd.read_csv(dti_train)
# print(dti_train.shape) #  (183430, 6)
dti_train.head()

In [None]:
# Drug_ID value of the first row (selected by position).
dti_train.iloc[0]['Drug_ID']

In [None]:
# First two entries of the Drug_ID column.
dti_train['Drug_ID'][:2]

In [None]:
# Build a two-row query table from the head of `dti_train`, renaming its columns to the
# moleculeID / proteinID / proteinSequence / moleculeSmiles schema used by the prediction cells.
_column_sources = {
    'moleculeID': 'Drug_ID',
    'proteinID': 'Target_ID',
    'proteinSequence': 'Target',
    'moleculeSmiles': 'Drug',
}
new_test_df = pd.DataFrame(
    {target_col: dti_train[source_col][:2]
     for target_col, source_col in _column_sources.items()}
)

In [None]:
# Display the query table built in the previous cell.
new_test_df