In [None]:
# constant

In [None]:
# Integer code for each one-letter amino-acid symbol; 'X' (code 20) is the
# extra symbol beyond the 20 standard residues. Insertion order is kept as in
# the original table in case callers iterate over it.
LETTER_TO_NUM = dict([
    ('C', 4), ('D', 3), ('S', 15), ('Q', 5), ('K', 11), ('I', 9),
    ('P', 14), ('T', 16), ('F', 13), ('A', 0), ('G', 7), ('H', 8),
    ('E', 6), ('L', 10), ('R', 1), ('W', 17), ('V', 19),
    ('N', 2), ('Y', 18), ('M', 12), ('X', 20),
])

# Inverse lookup: integer code -> one-letter symbol.
NUM_TO_LETTER = {num: letter for letter, num in LETTER_TO_NUM.items()}

# Element symbols recognised for drug atoms; 'unk' is the last entry and
# presumably serves as the catch-all for elements not listed here.
ATOM_VOCAB = (
    ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca']
    + ['Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag']
    + ['Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni']
    + ['Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'unk']
)


In [None]:
# dta 

In [None]:
"""
Drug-target binding affinity datasets
"""
import math
import yaml
import json
from functools import partial
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data


class DTA(data.Dataset):
    """
    Map-style dataset of drug-target binding affinity examples.

    Each item is a dict ``{'drug': ..., 'protein': ..., 'y': ...}``. Depending
    on ``onthefly`` the drug/protein entries are either returned as stored or
    produced by the featurization callables at access time.
    """
    def __init__(self, df=None, data_list=None, onthefly=False,
                prot_featurize_fn=None, drug_featurize_fn=None):
        """
        Parameters
        ----------
            df : pd.DataFrame with columns [`drug`, `protein`, `y`],
                `drug` is the drug key, `protein` the protein key and `y` the
                binding affinity. Only stored (as `self.data_df`), never read here.
            data_list : list of dict, in the same order as `df`
                With `onthefly=True` every dict holds raw inputs plus names:
                    {`drug`: SDF file path, `protein`: coordinates dict,
                     `y`: float, `drug_name`: key, `protein_name`: key}
                With `onthefly=False` every dict holds pre-built graphs:
                    {`drug`: `torch_geometric.data.Data`,
                     `protein`: `torch_geometric.data.Data`, `y`: float}
                Expected attributes of the `protein` graph:
                    -x          alpha carbon coordinates, shape [n_nodes, 3]
                    -edge_index edge indices, shape [2, n_edges]
                    -seq        integer-encoded sequence, shape [n_nodes]
                    -name       name of the protein structure, string
                    -node_s     node scalar features, shape [n_nodes, 6]
                    -node_v     node vector features, shape [n_nodes, 3, 3]
                    -edge_s     edge scalar features, shape [n_edges, 39]
                    -edge_v     edge vector features, shape [n_edges, 1, 3]
                    -mask       node mask, `False` for nodes with missing data
                                that are excluded from message passing
                    -seq_emb    sequence embedding (ESM1b), shape [n_nodes, 1280]
                Expected attributes of the `drug` graph:
                    -x          atom coordinates, shape [n_nodes, 3]
                    -edge_index edge indices, shape [2, n_edges]
                    -node_s     node scalar features, shape [n_nodes, 66]
                    -node_v     node vector features, shape [n_nodes, 1, 3]
                    -edge_s     edge scalar features, shape [n_edges, 16]
                    -edge_v     edge vector features, shape [n_edges, 1, 3]
                    -name       name of the drug, string
            onthefly : bool
                featurize at access time (True) or use cached graphs (False)
            prot_featurize_fn : callable
                called as `fn(protein_entry, name=protein_name)`; required
                when `onthefly` is True.
            drug_featurize_fn : callable
                called as `fn(drug_entry, name=drug_name)`; required
                when `onthefly` is True.
        """
        super().__init__()
        self.data_df = df
        self.data_list = data_list
        self.onthefly = onthefly
        if onthefly:
            # Both featurizers are mandatory for lazy featurization.
            assert prot_featurize_fn is not None, 'prot_featurize_fn must be provided'
            assert drug_featurize_fn is not None, 'drug_featurize_fn must be provided'
        self.prot_featurize_fn = prot_featurize_fn
        self.drug_featurize_fn = drug_featurize_fn

    def __len__(self):
        """Number of examples, i.e. the length of `data_list`."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the `idx`-th example as a dict with keys drug/protein/y."""
        record = self.data_list[idx]
        if not self.onthefly:
            # Cached graphs are handed back untouched.
            drug_item = record['drug']
            prot_item = record['protein']
        else:
            # Drug is featurized before the protein, as in the cached layout.
            drug_item = self.drug_featurize_fn(
                record['drug'], name=record['drug_name'])
            prot_item = self.prot_featurize_fn(
                record['protein'], name=record['protein_name'])
        return {'drug': drug_item, 'protein': prot_item, 'y': record['y']}


def create_fold(df, fold_seed, frac):
    """
    Randomly partition `df` into train/valid/test DataFrames.

    `frac` is a (train, valid, test) triple of fractions. The test rows are
    drawn first using `fold_seed`; the validation rows are then drawn from
    the remainder. Returns a dict keyed by 'train', 'valid' and 'test', each
    with a fresh 0..n-1 index.
    """
    _, val_frac, test_frac = frac

    held_out = df.sample(frac=test_frac, replace=False, random_state=fold_seed)
    remainder = df.loc[~df.index.isin(held_out.index)]

    # The validation fraction is rescaled because it is taken from what is
    # left after removing the test rows.
    # NOTE(review): this draw uses a fixed random_state of 1 rather than
    # `fold_seed`; kept as-is so existing splits stay reproducible.
    rescaled_val_frac = val_frac / (1 - test_frac)
    validation = remainder.sample(frac=rescaled_val_frac, replace=False, random_state=1)
    training = remainder.loc[~remainder.index.isin(validation.index)]

    parts = {'train': training, 'valid': validation, 'test': held_out}
    return {name: part.reset_index(drop=True) for name, part in parts.items()}


def create_fold_setting_cold(df, fold_seed, frac, entity):
    """
    Partition `df` so that no value of column `entity` is shared by two splits.

    `entity` is a column name (e.g. 'drug' or 'protein'). Fractions in `frac`
    (train, valid, test) apply to the number of unique entity values, not to
    the number of rows. Returns a dict keyed by 'train', 'valid' and 'test',
    each re-indexed from zero.
    """
    _, val_frac, test_frac = frac

    def _pick_entities(frame, fraction):
        # Sample a fraction of the distinct entity values of `frame`.
        distinct = frame[entity].drop_duplicates()
        return distinct.sample(frac=fraction, replace=False, random_state=fold_seed).values

    test_entities = _pick_entities(df, test_frac)
    in_test = df[entity].isin(test_entities)
    test_part = df[in_test]
    remaining = df[~in_test]

    # Validation entities come from what is left, hence the rescaled fraction.
    val_entities = _pick_entities(remaining, val_frac / (1 - test_frac))
    in_val = remaining[entity].isin(val_entities)
    val_part = remaining[in_val]
    train_part = remaining[~in_val]

    return {'train': train_part.reset_index(drop=True),
            'valid': val_part.reset_index(drop=True),
            'test': test_part.reset_index(drop=True)}


def create_full_ood_set(df, fold_seed, frac):
    """
    Build train/valid/test folds where the test set contains only drugs AND
    proteins that never occur in train or valid.

    A fraction `test_frac` of the unique drugs and of the unique proteins is
    held out. Rows pairing a held-out drug with a held-out protein form the
    test set; rows that touch neither form the train/valid pool, which is
    split at random (so train and valid may share drugs and proteins). Rows
    that involve exactly one held-out entity are dropped.
    """
    _, val_frac, test_frac = frac

    def _held_out(column):
        # Same seed is used for both the drug and the protein draw.
        uniques = df[column].drop_duplicates()
        return uniques.sample(frac=test_frac, replace=False, random_state=fold_seed).values

    held_drugs = _held_out('drug')
    held_prots = _held_out('protein')

    drug_is_held = df['drug'].isin(held_drugs)
    prot_is_held = df['protein'].isin(held_prots)

    test_part = df[drug_is_held & prot_is_held]
    pool = df[(~drug_is_held) & (~prot_is_held)]

    val_part = pool.sample(frac=val_frac / (1 - test_frac), replace=False, random_state=fold_seed)
    train_part = pool[~pool.index.isin(val_part.index)]

    splits = (('train', train_part), ('valid', val_part), ('test', test_part))
    return {name: part.reset_index(drop=True) for name, part in splits}


def create_seq_identity_fold(df, mmseqs_seq_clus_df, fold_seed, frac, min_clus_in_split=5):
    """
    Create train/valid/test folds from protein sequence clusters.

    Adapted from: https://github.com/drorlab/atom3d/blob/master/atom3d/splits/sequence.py
    Clusters are selected randomly into the validation and test sets. To keep
    some diversity in each set (i.e. a split does not consist of a single
    sequence cluster), the number of proteins taken from one cluster is capped
    at ``ceil(split_size / min_clus_in_split)``. Proteins of a selected cluster
    that exceed the cap are dropped from every split, so some examples may be
    removed.

    Parameters
    ----------
    df : pd.DataFrame
        Data with at least a `protein` column.
    mmseqs_seq_clus_df : pd.DataFrame
        Cluster table with columns `rep` (cluster representative) and `seq`
        (member protein key).
    fold_seed : int
        Seed of the random generator that picks clusters.
    frac : sequence of three floats
        (train, valid, test) fractions; valid/test sizes are `len(df) * frac`.
    min_clus_in_split : int
        Minimum number of clusters intended per validation/test split.

    Returns
    -------
    dict with keys 'train', 'valid', 'test'; each DataFrame carries an extra
    `split` column and a fresh index.

    Raises
    ------
    ValueError
        If the clustered proteins run out before a split reaches its size.
    """
    _rng = np.random.RandomState(fold_seed)

    def _parse_mmseqs_cluster_res(clus_df):
        # Build protein -> representative and representative -> members maps.
        clus2seq, seq2clus = {}, {}
        for rep, members_df in clus_df.groupby('rep'):
            for seq in members_df['seq']:
                clus2seq.setdefault(rep, []).append(seq)
                seq2clus[seq] = rep
        return seq2clus, clus2seq

    def _create_cluster_split(data, seq2clus, clus2seq, to_use, split_size, min_clus_in_split):
        used = set(seq2clus).difference(to_use)
        # Cap per-cluster contribution so that at least `min_clus_in_split`
        # clusters are needed to fill the split.
        max_clust_size = int(np.ceil(split_size / min_clus_in_split))
        parts, n_rows = [], 0
        while True:
            if not to_use:
                raise ValueError(
                    'Ran out of clustered proteins before the split reached '
                    f'the requested size ({n_rows} < {split_size})')
            p = _rng.choice(sorted(to_use))
            # `p` is unused and belongs to its own cluster, so `members`
            # always contains at least `p`.
            members = set(clus2seq[seq2clus[p]]).difference(used)
            # Sorted so the truncation is reproducible; plain set order of
            # strings changes between interpreter runs.
            sel_prot = sorted(members)[:max_clust_size]
            sel_df = data[data['protein'].isin(sel_prot)]
            parts.append(sel_df)
            n_rows += len(sel_df)
            # The whole cluster is retired, including members beyond the cap.
            to_use = to_use.difference(members)
            used = used.union(members)
            if n_rows >= split_size:
                break
        split = pd.concat(parts).reset_index(drop=True)
        return split, to_use

    seq2clus, clus2seq = _parse_mmseqs_cluster_res(mmseqs_seq_clus_df)
    _, val_frac, test_frac = frac
    test_size, val_size = len(df) * test_frac, len(df) * val_frac
    to_use = set(seq2clus.keys())

    # Validation is drawn before test; both consume from the same pool.
    val_df, to_use = _create_cluster_split(df, seq2clus, clus2seq, to_use, val_size, min_clus_in_split)
    test_df, to_use = _create_cluster_split(df, seq2clus, clus2seq, to_use, test_size, min_clus_in_split)
    train_df = df[df['protein'].isin(to_use)].reset_index(drop=True)
    train_df['split'] = 'train'
    val_df['split'] = 'valid'
    test_df['split'] = 'test'

    # Sanity check: no protein may appear in two splits.
    assert len(set(train_df['protein']) & set(val_df['protein'])) == 0
    assert len(set(test_df['protein']) & set(val_df['protein'])) == 0
    assert len(set(train_df['protein']) & set(test_df['protein'])) == 0

    return {'train': train_df.reset_index(drop=True),
            'valid': val_df.reset_index(drop=True),
            'test': test_df.reset_index(drop=True)}


class DTATask(object):
    """
    Drug-target binding task (e.g., KIBA or Davis).
    Three splits: train/valid/test, each split is a DTA() class
    """
    def __init__(self,
            task_name=None,
            df=None,
            prot_pdb_id=None, pdb_data=None,
            emb_dir=None,
            drug_sdf_dir=None,
            num_pos_emb=16, num_rbf=16,
            contact_cutoff=8.,
            split_method='random', split_frac=[0.7, 0.1, 0.2],
            mmseqs_seq_clus_df=None,
            seed=42, onthefly=False
        ):
        """
        Parameters
        ----------
        task_name: str
            Name of the task (e.g., KIBA, Davis, etc.)
        df: pd.DataFrame
            Dataframe containing the data
        prot_pdb_id: dict
            Dictionary mapping protein name to PDB ID
        pdb_data: dict
            A json format of pocket structure data, where key is the PDB ID
            and value is the corresponding PDB structure data in a dictionary:
                -'name': kinase name
                -'UniProt_id': UniProt ID
                -'PDB_id': PDB ID,
                -'chain': chain ID,
                -'seq': pocket sequence,
                -'coords': coordinates of the 'N', 'CA', 'C', 'O' atoms of the pocket residues,
                    - "N": [[x, y, z], ...]
                    - "CA": [[], ...],
                    - "C": [[], ...],
                    - "O": [[], ...]
            (there are some other keys but only for internal use)
        emb_dir: str
            Directory containing the protein embeddings
        drug_sdf_dir: str
            Directory containing the drug SDF files
        num_pos_emb: int
            Dimension of positional embeddings
        num_rbf: int
            Number of radial basis functions
        contact_cutoff: float
            Cutoff distance for defining residue-residue contacts
        split_method: str
            how to split train/test sets,
            -`random`: random split
            -`protein`: split by protein
            -`drug`: split by drug
            -`both`: unseen drugs and proteins in test set
            -`seqid`: split by protein sequence identity
                (need to provide the MMseqs2 sequence cluster result,
                see `mmseqs_seq_clus_df`)
        split_frac: list
            Fraction of data in train/valid/test sets
        mmseqs_seq_clus_df: pd.DataFrame
            Dataframe containing the MMseqs2 sequence cluster result
            using a desired sequence identity cutoff
        seed: int
            Random seed
        onthefly: bool
            whether to featurize data on the fly or pre-compute
        """
        self.task_name = task_name
        self.prot_pdb_id = prot_pdb_id
        self.pdb_data = pdb_data
        self.emb_dir = emb_dir
        self.df = df
        self.prot_featurize_params = dict(
            num_pos_emb=num_pos_emb, num_rbf=num_rbf,
            contact_cutoff=contact_cutoff)
        self.drug_sdf_dir = drug_sdf_dir
        # Lazily-built caches backing the properties below.
        self._prot2pdb = None
        self._pdb_graph_db = None
        self._drug2sdf_file = None
        self._drug_sdf_db = None
        self.split_method = split_method
        self.split_frac = split_frac
        self.mmseqs_seq_clus_df = mmseqs_seq_clus_df
        self.seed = seed
        self.onthefly = onthefly

    def _format_pdb_entry(self, _data):
        """
        Convert one `pdb_data` record into the entry used for featurization.

        Returns a dict with `name`, `seq` and `coords`, where `coords` is a
        list of per-residue (N, CA, C, O) coordinate tuples. When `emb_dir`
        is set, an `embed` path `<emb_dir>/<PDB_id>.<chain>.pt` is added.
        """
        _coords = _data["coords"]
        entry = {
            "name": _data["name"],
            "seq": _data["seq"],
            "coords": list(zip(_coords["N"], _coords["CA"], _coords["C"], _coords["O"])),
        }
        if self.emb_dir is not None:
            embed_file = f"{_data['PDB_id']}.{_data['chain']}.pt"
            entry["embed"] = f"{self.emb_dir}/{embed_file}"
        return entry

    @property
    def prot2pdb(self):
        """Mapping protein key -> formatted PDB entry (built on first access)."""
        if self._prot2pdb is None:
            self._prot2pdb = {
                prot: self._format_pdb_entry(self.pdb_data[pdb])
                for prot, pdb in self.prot_pdb_id.items()
            }
        return self._prot2pdb

    @property
    def pdb_graph_db(self):
        """Protein graphs for every entry of `prot2pdb` (built on first access)."""
        if self._pdb_graph_db is None:
            self._pdb_graph_db = pdb_graph.pdb_to_graphs(self.prot2pdb,
                self.prot_featurize_params)
        return self._pdb_graph_db

    @property
    def drug2sdf_file(self):
        """Mapping drug key -> SDF file path, from `*.sdf` files in `drug_sdf_dir`."""
        if self._drug2sdf_file is None:
            drug2sdf_file = {f.stem : str(f) for f in Path(self.drug_sdf_dir).glob('*.sdf')}
            # Convert str keys to int for Davis
            if self.task_name == 'DAVIS' and all(k.isdigit() for k in drug2sdf_file):
                drug2sdf_file = {int(k) : v for k, v in drug2sdf_file.items()}
            self._drug2sdf_file = drug2sdf_file
        return self._drug2sdf_file

    @property
    def drug_sdf_db(self):
        """Drug graphs for every file of `drug2sdf_file` (built on first access)."""
        if self._drug_sdf_db is None:
            self._drug_sdf_db = mol_graph.sdf_to_graphs(self.drug2sdf_file)
        return self._drug_sdf_db


    def build_data(self, df, onthefly=False):
        """
        Build a `DTA` dataset from the rows of `df`.

        Parameters
        ----------
        df: pd.DataFrame
            Rows with columns `drug`, `protein` and `y`.
        onthefly: bool
            If True, store raw inputs (PDB entry, SDF path) and attach the
            featurization functions; otherwise store pre-computed graphs.

        Raises
        ------
        KeyError
            If a drug or protein key of `df` has no structure entry.
        """
        data_list = []
        for entry in df.to_dict('records'):
            drug = entry['drug']
            prot = entry['protein']
            # Local names must not reuse `df`: the DataFrame itself is
            # handed to DTA below.
            if onthefly:
                prot_feat = self.prot2pdb[prot]
                drug_feat = self.drug2sdf_file[drug]
            else:
                prot_feat = self.pdb_graph_db[prot]
                drug_feat = self.drug_sdf_db[drug]
            data_list.append({'drug': drug_feat, 'protein': prot_feat, 'y': entry['y'],
                'drug_name': drug, 'protein_name': prot})
        if onthefly:
            prot_featurize_fn = partial(
                pdb_graph.featurize_protein_graph,
                **self.prot_featurize_params)
            drug_featurize_fn = mol_graph.featurize_drug
        else:
            prot_featurize_fn, drug_featurize_fn = None, None
        data = DTA(df=df, data_list=data_list, onthefly=onthefly,
            prot_featurize_fn=prot_featurize_fn, drug_featurize_fn=drug_featurize_fn)
        return data


    def get_split(self, df=None, split_method=None,
            split_frac=None, seed=None, onthefly=None,
            return_df=False):
        """
        Split the data and build one `DTA` dataset per split.

        Every argument left as None falls back to the value given at
        construction. Explicit falsy values (`seed=0`, `onthefly=False`)
        are honoured.

        Returns
        -------
        dict {'train'|'valid'|'test': DTA}; when `return_df` is True, a tuple
        of that dict and the dict of split DataFrames.

        Raises
        ------
        ValueError
            If `split_method` is not one of random/drug/protein/both/seqid.
        """
        # `x or default` cannot be used here: a DataFrame has no truth value
        # and 0 / False are legitimate overrides.
        df = self.df if df is None else df
        split_method = self.split_method if split_method is None else split_method
        split_frac = self.split_frac if split_frac is None else split_frac
        seed = self.seed if seed is None else seed
        onthefly = self.onthefly if onthefly is None else onthefly
        if split_method == 'random':
            split_df = create_fold(df, seed, split_frac)
        elif split_method == 'drug':
            split_df = create_fold_setting_cold(df, seed, split_frac, 'drug')
        elif split_method == 'protein':
            split_df = create_fold_setting_cold(df, seed, split_frac, 'protein')
        elif split_method == 'both':
            split_df = create_full_ood_set(df, seed, split_frac)
        elif split_method == 'seqid':
            split_df = create_seq_identity_fold(
                df, self.mmseqs_seq_clus_df, seed, split_frac)
        else:
            raise ValueError("Unknown split method: {}".format(split_method))
        split_data = {
            split: self.build_data(part_df, onthefly=onthefly)
            for split, part_df in split_df.items()
        }
        if return_df:
            return split_data, split_df
        return split_data


class KIBA(DTATask):
    """
    KIBA drug-target interaction dataset
    """
    def __init__(self,
            data_path='../data/KIBA/kiba_data.tsv',
            pdb_map='../data/KIBA/kiba_uniprot2pdb.yaml',
            pdb_json='../data/structure/pockets_structure.json',
            emb_dir='../data/esm1b',
            num_pos_emb=16, num_rbf=16,
            contact_cutoff=8.,
            drug_sdf_dir='../data/structure/kiba_mol3d_sdf',
            split_method='random', split_frac=[0.7, 0.1, 0.2],
            mmseqs_seq_cluster_file='../data/KIBA/kiba_cluster_id50_cluster.tsv',
            seed=42, onthefly=False
        ):
        """
        Load the KIBA files and initialise the task.

        Parameters
        ----------
        data_path: str
            Tab-separated table of the binding data.
        pdb_map: str
            YAML file mapping protein key to PDB ID.
        pdb_json: str
            JSON file with the pocket structure data keyed by PDB ID.
        mmseqs_seq_cluster_file: str
            Header-less two-column TSV read as columns `rep` and `seq`.
        The remaining arguments are forwarded unchanged to `DTATask`.
        """
        df = pd.read_table(data_path)
        # Context managers close the files right after parsing.
        with open(pdb_map, 'r') as map_file:
            prot_pdb_id = yaml.safe_load(map_file)
        with open(pdb_json, 'r') as json_file:
            pdb_data = json.load(json_file)
        mmseqs_seq_clus_df = pd.read_table(mmseqs_seq_cluster_file, names=['rep', 'seq'])
        super(KIBA, self).__init__(
            task_name='KIBA',
            df=df,
            prot_pdb_id=prot_pdb_id, pdb_data=pdb_data,
            emb_dir=emb_dir,
            num_pos_emb=num_pos_emb, num_rbf=num_rbf,
            contact_cutoff=contact_cutoff,
            drug_sdf_dir=drug_sdf_dir,
            split_method=split_method, split_frac=split_frac,
            mmseqs_seq_clus_df=mmseqs_seq_clus_df,
            seed=seed, onthefly=onthefly
            )


class DAVIS(DTATask):
    """
    DAVIS drug-target interaction dataset
    """
    def __init__(self,
            data_path='../data/DAVIS/davis_data.tsv',
            pdb_map='../data/DAVIS/davis_protein2pdb.yaml',
            pdb_json='../data/structure/pockets_structure.json',
            emb_dir='../data/esm1b',
            num_pos_emb=16, num_rbf=16,
            contact_cutoff=8.,
            drug_sdf_dir='../data/structure/davis_mol3d_sdf',
            split_method='random', split_frac=[0.7, 0.1, 0.2],
            mmseqs_seq_cluster_file='../data/DAVIS/davis_cluster_id50_cluster.tsv',
            seed=42, onthefly=False
        ):
        """
        Load the DAVIS files and initialise the task.

        Parameters
        ----------
        data_path: str
            Tab-separated table of the binding data.
        pdb_map: str
            YAML file mapping protein key to PDB ID.
        pdb_json: str
            JSON file with the pocket structure data keyed by PDB ID.
        mmseqs_seq_cluster_file: str
            Header-less two-column TSV read as columns `rep` and `seq`.
        The remaining arguments are forwarded unchanged to `DTATask`.
        """
        df = pd.read_table(data_path)
        # Context managers close the files right after parsing.
        with open(pdb_map, 'r') as map_file:
            prot_pdb_id = yaml.safe_load(map_file)
        with open(pdb_json, 'r') as json_file:
            pdb_data = json.load(json_file)
        mmseqs_seq_clus_df = pd.read_table(mmseqs_seq_cluster_file, names=['rep', 'seq'])
        super(DAVIS, self).__init__(
            task_name='DAVIS',
            df=df,
            prot_pdb_id=prot_pdb_id, pdb_data=pdb_data,
            emb_dir=emb_dir,
            num_pos_emb=num_pos_emb, num_rbf=num_rbf,
            contact_cutoff=contact_cutoff,
            drug_sdf_dir=drug_sdf_dir,
            split_method=split_method, split_frac=split_frac,
            mmseqs_seq_clus_df=mmseqs_seq_clus_df,
            seed=seed, onthefly=onthefly
            )


In [None]:
# experiment

In [None]:
import copy
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from joblib import Parallel, delayed
import uncertainty_toolbox as uct

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch_geometric
torch.set_num_threads(1)



def _parallel_train_per_epoch(
    kwargs=None, test_loader=None,
    n_epochs=None, eval_freq=None, test_freq=None,
    monitoring_score='pearson',
    loss_fn=None, logger=None,
    test_after_train=True,
):
    midx = kwargs['midx']
    model = kwargs['model']
    optimizer = kwargs['optimizer']
    train_loader = kwargs['train_loader']
    valid_loader = kwargs['valid_loader']
    device = kwargs['device']
    stopper = kwargs['stopper']
    best_model_state_dict = kwargs['best_model_state_dict']
    if stopper.early_stop:
        return kwargs

    model.train()
    for epoch in range(1, n_epochs + 1):
        total_loss = 0
        for step, batch in enumerate(train_loader, start=1):
            xd = batch['drug'].to(device)
            xp = batch['protein'].to(device)
            y = batch['y'].to(device)
            optimizer.zero_grad()
            yh = model(xd, xp)
            loss = loss_fn(yh, y.view(-1, 1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        train_loss = total_loss / step
        if epoch % eval_freq == 0:
            val_results = _parallel_test(
                {'model': model, 'midx': midx, 'test_loader': valid_loader, 'device': device},
                loss_fn=loss_fn, logger=logger
            )
            is_best = stopper.update(val_results['metrics'][monitoring_score])
            if is_best:
                best_model_state_dict = copy.deepcopy(model.state_dict())
            logger.info(f"M-{midx} E-{epoch} | Train Loss: {train_loss:.4f} | Valid Loss: {val_results['loss']:.4f} | "\
                + ' | '.join([f'{k}: {v:.4f}' for k, v in val_results['metrics'].items()])
                + f" | best {monitoring_score}: {stopper.best_score:.4f}"
                )
        if test_freq is not None and epoch % test_freq == 0:
            test_results = _parallel_test(
                {'midx': midx, 'model': model, 'test_loader': test_loader, 'device': device},
                loss_fn=loss_fn, logger=logger
            )
            logger.info(f"M-{midx} E-{epoch} | Test Loss: {test_results['loss']:.4f} | "\
                + ' | '.join([f'{k}: {v:.4f}' for k, v in test_results['metrics'].items()])
                )

        if stopper.early_stop:
            logger.info('Eearly stop at epoch {}'.format(epoch))

    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
    if test_after_train:
        test_results = _parallel_test(
            {'midx': midx, 'model': model, 'test_loader': test_loader, 'device': device},
            loss_fn=loss_fn,
            test_tag=f"Model {midx}", print_log=True, logger=logger
        )
    rets = dict(midx = midx, model = model)
    return rets


def _parallel_test(
    kwargs=None, loss_fn=None, 
    test_tag=None, print_log=False, logger=None,
):
    """
    Evaluate one model on a data loader.

    `kwargs` must provide `midx`, `model`, `test_loader` and `device`. The
    model is switched to eval mode and run without gradients over every batch.

    Returns a dict with `midx`, `y_true` and `y_pred` (numpy arrays), `loss`
    (mean of the per-batch losses) and `metrics` (mse/spearman/pearson as
    returned by `evaluation_metrics`). When `print_log` is True a summary
    line prefixed by `test_tag` is written via `logger.info`.
    """
    midx = kwargs['midx']
    model = kwargs['model']
    loader = kwargs['test_loader']
    device = kwargs['device']

    model.eval()
    targets = torch.Tensor()
    preds = torch.Tensor()
    loss_sum = 0
    with torch.no_grad():
        for step, batch in enumerate(loader, start=1):
            drug_batch = batch['drug'].to(device)
            prot_batch = batch['protein'].to(device)
            target = batch['y'].to(device)
            output = model(drug_batch, prot_batch)
            loss_sum += loss_fn(output, target.view(-1, 1)).item()
            # Collected on CPU so the numpy conversion below works for any device.
            preds = torch.cat([preds, output.detach().cpu()], dim=0)
            targets = torch.cat([targets, target.detach().cpu()], dim=0)

    yt = targets.numpy()
    yp = preds.view(-1).numpy()
    results = {
        'midx': midx,
        'y_true': yt,
        'y_pred': yp,
        # `step` is the number of batches seen.
        'loss': loss_sum / step,
    }
    results['metrics'] = evaluation_metrics(
        yt, yp,
        eval_metrics=['mse', 'spearman', 'pearson']
    )
    if print_log:
        metric_text = ' | '.join(
            f'{k}: {v:.4f}' for k, v in results['metrics'].items())
        logger.info(f"{test_tag} | Test Loss: {results['loss']:.4f} | " + metric_text)
    return results


def _unpack_evidential_output(output):
    mu, v, alpha, beta = torch.split(output, output.shape[1]//4, dim=1)
    inverse_evidence = 1. / ((alpha - 1) * v)
    var = beta * inverse_evidence
    return mu, var, inverse_evidence


class DTAExperiment(object):
    """
    End-to-end drug-target affinity experiment: loads a dataset, builds an
    ensemble of `DTAModel`s, and provides `train` / `test` entry points.
    """
    def __init__(self,
        task=None,
        split_method='protein',
        split_frac=[0.7, 0.1, 0.2],
        prot_gcn_dims=[128, 128, 128], prot_gcn_bn=False,
        prot_fc_dims=[1024, 128],
        drug_in_dim=66, drug_fc_dims=[1024, 128], drug_gcn_dims=[128, 64],
        mlp_dims=[1024, 512], mlp_dropout=0.25,
        num_pos_emb=16, num_rbf=16,
        contact_cutoff=8.,
        n_ensembles=1, n_epochs=500, batch_size=256,
        lr=0.001,        
        seed=42, onthefly=False,
        uncertainty=False, parallel=False,
        output_dir='../output', save_log=False
    ):
        """
        Parameters
        ----------
        task: str
            Dataset key, 'kiba' or 'davis' (any other value raises KeyError).
        split_method, split_frac, seed, onthefly, num_pos_emb, num_rbf,
        contact_cutoff:
            Forwarded to the dataset class.
        prot_gcn_dims, prot_fc_dims, drug_gcn_dims, drug_fc_dims, mlp_dims,
        mlp_dropout:
            Forwarded to `DTAModel` through `self.model_config`.
        prot_gcn_bn, drug_in_dim:
            Accepted but currently not used by this class.
        n_ensembles: int
            Number of models; must be at least 2 when `uncertainty` is True.
        n_epochs, batch_size, lr:
            Training hyper-parameters.
        uncertainty: bool
            Whether ensemble standard deviations are reported.
        parallel: bool
            Whether ensemble members are trained in parallel threads.
        output_dir: str
            Directory handed to `Saver`.
        save_log: bool
            Whether to also log into `exp.log` inside the saver directory.

        Raises
        ------
        ValueError
            If `uncertainty` is True and `n_ensembles < 2`.
        """
        self.saver = Saver(output_dir)
        self.logger = Logger(logfile=self.saver.save_dir/'exp.log' if save_log else None)

        self.uncertainty = uncertainty
        self.parallel = parallel
        self.n_ensembles = n_ensembles
        if self.uncertainty and self.n_ensembles < 2:
            raise ValueError('n_ensembles must be greater than 1 when uncertainty is True')            
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.lr = lr
        dataset_klass = {
            'kiba': KIBA,
            'davis': DAVIS,
        }[task]

        self.dataset = dataset_klass(
            split_method=split_method,
            split_frac=split_frac,
            seed=seed,
            onthefly=onthefly,
            num_pos_emb=num_pos_emb,
            num_rbf=num_rbf,
            contact_cutoff=contact_cutoff,
        )
        # Lazily populated by the `task_*` properties.
        self._task_data_df_split = None
        self._task_loader = None

        n_gpus = torch.cuda.device_count()
        if self.parallel and n_gpus < self.n_ensembles:
            self.logger.warning(f"Visible GPUs ({n_gpus}) is fewer than "
            f"number of models ({self.n_ensembles}). Some models will be run on the same GPU"
            )
        self.devices = self._resolve_devices(self.n_ensembles)
        self.model_config = dict(
            prot_emb_dim=1280,
            prot_gcn_dims=prot_gcn_dims,            
            prot_fc_dims=prot_fc_dims,
            drug_node_in_dim=[66, 1], 
            drug_node_h_dims=drug_gcn_dims,
            drug_fc_dims=drug_fc_dims,            
            mlp_dims=mlp_dims, mlp_dropout=mlp_dropout)
        self.build_model()
        self.criterion = F.mse_loss

        self.split_method = split_method
        self.split_frac = split_frac

        self.logger.info(self.models[0])
        self.logger.info(self.optimizers[0])

    @staticmethod
    def _resolve_devices(n_ensembles):
        """
        Return one `torch.device` per ensemble member.

        Models are assigned round-robin to the visible GPUs; when no GPU is
        visible every model is placed on the CPU (the modulo assignment would
        otherwise divide by zero).
        """
        n_gpus = torch.cuda.device_count()
        if n_gpus == 0:
            return [torch.device('cpu') for _ in range(n_ensembles)]
        return [torch.device(f'cuda:{i % n_gpus}') for i in range(n_ensembles)]

    def build_model(self):
        """(Re)create all ensemble models and one Adam optimizer per model."""
        self.models = [DTAModel(**self.model_config).to(self.devices[i])
                        for i in range(self.n_ensembles)]
        self.optimizers = [optim.Adam(model.parameters(), lr=self.lr) for model in self.models]

    def _get_data_loader(self, dataset, shuffle=False):
        """Wrap `dataset` in a torch_geometric DataLoader using `self.batch_size`."""
        return torch_geometric.loader.DataLoader(
                    dataset=dataset,
                    batch_size=self.batch_size,
                    shuffle=shuffle,
                    pin_memory=False,
                    num_workers=0,
                )

    @property
    def task_data_df_split(self):
        """Tuple (split datasets, split DataFrames), computed once."""
        if self._task_data_df_split is None:
            (data, df) = self.dataset.get_split(return_df=True)
            self._task_data_df_split = (data, df)
        return self._task_data_df_split

    @property
    def task_data(self):
        """Dict of split name -> dataset."""
        return self.task_data_df_split[0]

    @property
    def task_df(self):
        """Dict of split name -> DataFrame."""
        return self.task_data_df_split[1]

    @property
    def task_loader(self):
        """Dict of split name -> data loader; only the train loader shuffles."""
        if self._task_loader is None:
            _loader = {
                s: self._get_data_loader(
                    self.task_data[s], shuffle=(s == 'train'))
                for s in self.task_data
            }
            self._task_loader = _loader
        return self._task_loader

    def recalibrate_std(self, df, recalib_df):
        """
        Add a `y_std_recalib` column to `df`.

        The scaling ratio is fitted on `recalib_df` (columns `y_pred`, `y_std`,
        `y_true`) with uncertainty_toolbox using the "miscal" criterion and
        applied to `df['y_std']`. `df` is modified in place and returned.
        """
        y_mean = recalib_df['y_pred'].values
        y_std = recalib_df['y_std'].values
        y_true = recalib_df['y_true'].values
        std_ratio = uct.recalibration.optimize_recalibration_ratio(
            y_mean, y_std, y_true, criterion="miscal")
        df['y_std_recalib'] = df['y_std'] * std_ratio
        return df

    def _format_predict_df(self, results,
            test_df=None, esb_yp=None, recalib_df=None):
        """
        Assemble the prediction table.

        results: dict with keys y_pred and y_true
        test_df: DataFrame the predictions belong to; defaults to the test
            split of the task. Its `y` column must match `results['y_true']`.
        esb_yp: array of shape [n_models, n_samples] with per-model
            predictions; adds one `y_pred_<k>` column per row and, when
            `self.uncertainty` is set, a `y_std` column.
        recalib_df: optional DataFrame used to recalibrate `y_std`.
        """
        df = self.task_df['test'].copy() if test_df is None else test_df.copy()
        assert np.allclose(results['y_true'], df['y'].values)
        df = df.rename(columns={'y': 'y_true'})
        df['y_pred'] = results['y_pred']
        if esb_yp is not None:
            if self.uncertainty:
                df['y_std'] = np.std(esb_yp, axis=0)
                if recalib_df is not None:
                    df = self.recalibrate_std(df, recalib_df)
            # Iterate over the rows actually present: `test(test_model=...)`
            # supplies a single row even when n_ensembles > 1.
            for i, model_yp in enumerate(esb_yp):
                df[f'y_pred_{i + 1}'] = model_yp
        return df

    def train(self, n_epochs=None, patience=None,
                eval_freq=1, test_freq=None,
                monitoring_score='pearson',
                train_data=None, valid_data=None,                
                rebuild_model=False,
                test_after_train=False):
        """
        Train every ensemble member on the task's train/valid loaders.

        n_epochs: number of epochs; falls back to `self.n_epochs`.
        patience, eval_freq: forwarded to `EarlyStopping`.
        test_freq: evaluate on the test loader every `test_freq` epochs.
        monitoring_score: validation metric for early stopping; treated as
            lower-is-better only for 'mse'.
        train_data, valid_data: accepted but currently not used.
        rebuild_model: re-initialise models and optimizers first.
        test_after_train: run a final test evaluation per model.
        """
        n_epochs = n_epochs or self.n_epochs
        if rebuild_model:
            self.build_model()
        tl, vl = self.task_loader['train'], self.task_loader['valid']
        rets_list = []
        for i in range(self.n_ensembles):
            stp = EarlyStopping(eval_freq=eval_freq, patience=patience,
                                    higher_better=(monitoring_score != 'mse'))
            rets = dict(
                midx = i + 1,
                model = self.models[i],
                optimizer = self.optimizers[i],
                device = self.devices[i],
                train_loader = tl,
                valid_loader = vl,
                stopper = stp,
                best_model_state_dict = None,
            )
            rets_list.append(rets)

        rets_list = Parallel(n_jobs=(self.n_ensembles if self.parallel else 1), prefer="threads")(
            delayed(_parallel_train_per_epoch)(
                kwargs=rets_list[i],
                test_loader=self.task_loader['test'],
                n_epochs=n_epochs, eval_freq=eval_freq, test_freq=test_freq,
                monitoring_score=monitoring_score,
                loss_fn=self.criterion, logger=self.logger,
                test_after_train=test_after_train,
            ) for i in range(self.n_ensembles))

        # `midx` is 1-based.
        for rets in rets_list:
            self.models[rets['midx'] - 1] = rets['model']


    def test(self, test_model=None, test_loader=None,
                test_data=None, test_df=None,
                recalib_df=None,
                save_prediction=False, save_df_name='prediction.tsv',
                test_tag=None, print_log=False):
        """
        Evaluate the ensemble (or a single `test_model`) and average predictions.

        Data source priority: `test_data` (wrapped in a new loader), then
        `test_loader`, then the task's own test loader. `test_df` is required
        with either of the first two.

        Returns a dict with `y_true`, `y_pred` (ensemble mean), `loss`
        (mean of per-model losses), `metrics` and `df` (prediction table).
        The table is saved when `save_prediction` is True; the ensemble
        summary is logged when `print_log` is True.
        """
        test_models = self.models if test_model is None else [test_model]
        if test_data is not None:
            assert test_df is not None, 'test_df must be provided if test_data used'
            test_loader = self._get_data_loader(test_data)
        elif test_loader is not None:
            assert test_df is not None, 'test_df must be provided if test_loader used'
        else:
            test_loader = self.task_loader['test']
        rets_list = []
        for i, model in enumerate(test_models):
            rets = _parallel_test(
                kwargs={
                    'midx': i + 1,
                    'model': model,
                    'test_loader': test_loader,
                    'device': self.devices[i],
                },
                loss_fn=self.criterion,
                test_tag=f"Model {i+1}", print_log=True, logger=self.logger
            )
            rets_list.append(rets)

        # One row of predictions per model.
        esb_yp = np.vstack([rets['y_pred'].reshape(1, -1) for rets in rets_list])
        esb_loss = sum(rets['loss'] for rets in rets_list) / len(test_models)

        # All models saw the same unshuffled loader, so targets agree.
        y_true = rets_list[-1]['y_true']
        y_pred = np.mean(esb_yp, axis=0)
        results = {
            'y_true': y_true,
            'y_pred': y_pred,
            'loss': esb_loss,
        }

        eval_metrics = evaluation_metrics(
            y_true, y_pred,
            eval_metrics=['mse', 'spearman', 'pearson']
        )
        results['metrics'] = eval_metrics
        results['df'] = self._format_predict_df(results,
            test_df=test_df, esb_yp=esb_yp, recalib_df=recalib_df)
        if save_prediction:
            self.saver.save_df(results['df'], save_df_name, float_format='%g')
        if print_log:
            self.logger.info(f"{test_tag} | Test Loss: {results['loss']:.4f} | "\
                + ' | '.join([f'{k}: {v:.4f}' for k, v in results['metrics'].items()]))
        return results



In [None]:
# gvp

In [None]:
"""
Geometric Vector Perceptrons
From: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/__init__.py
"""
import torch, functools
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add

def tuple_sum(*args):
    '''
    Adds any number of tuples (s, V) position by position.

    :param args: tuples with matching structure
    :return: tuple whose k-th entry is the sum of the k-th entries of
             every argument (an empty tuple when called with no arguments)
    '''
    return tuple(sum(components) for components in zip(*args))

def tuple_cat(*args, dim=-1):
    '''
    Joins any number of tuples (s, V) into one tuple by concatenating
    the scalar parts together and the vector parts together.

    :param dim: concatenation axis, interpreted relative to the
                scalar-channel tensors. It is first converted to a
                non-negative index using the rank of the first scalar
                tensor, so `dim=-1` on the scalars is applied as
                `dim=-2` on the (one rank higher) vector tensors.
    '''
    axis = dim % args[0][0].ndim
    scalars = [s for s, _ in args]
    vectors = [v for _, v in args]
    return torch.cat(scalars, dim=axis), torch.cat(vectors, dim=axis)

def tuple_index(x, idx):
    '''
    Applies the same index to both members of a tuple (s, V),
    selecting along their first dimension.

    :param x: tuple (s, V)
    :param idx: any object which can be used to index into a `torch.Tensor`
    :return: tuple (s[idx], V[idx])
    '''
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]

def randn(n, dims, device="cpu"):
    '''
    Samples a random tuple (s, V) whose entries are drawn independently
    from a standard normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)
    :param device: device on which both tensors are created

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    n_scalar, n_vector = dims[0], dims[1]
    # The scalar part is sampled first so the RNG stream is consumed
    # in a fixed order.
    scalars = torch.randn(n, n_scalar, device=device)
    vectors = torch.randn(n, n_vector, 3, device=device)
    return scalars, vectors

def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of tensor clamped above a minimum value `eps`.
    
    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    return torch.sqrt(out) if sqrt else out

def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple. 
    Should be used only with `_merge(s, V)` and only if the tuple 
    representation cannot be used.
    
    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
    s = x[..., :-3*nv]
    return s, v

def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
    return torch.cat([s, v], -1)

class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
                  (defaults to the larger of the input and output
                  vector channel counts)
    :param activations: tuple of functions (scalar_act, vector_act);
                        either entry may be `None` to skip it
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            # scalar update sees the scalars plus the norms of the
            # h_dim hidden vector channels
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        # Empty parameter kept only so the module's current device can
        # be read back in `forward`.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            # move the channel axis last so nn.Linear mixes channels
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    # gate is computed from the pre-activation scalars
                    if self.vector_act:
                        gate = self.wsv(self.vector_act(s))
                    else:
                        gate = self.wsv(s)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True))
        else:
            s = self.ws(x)
            if self.vo:
                # No input vectors: emit zero vectors that follow every
                # leading dimension of s, not only the first one.
                v = torch.zeros(*s.shape[:-1], self.vo, 3,
                                device=self.dummy_param.device)
        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s

class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        '''
        device = self.dummy_param.device
        if not self.training:
            return x
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x

class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    Scalar channels go through `nn.Dropout`; vector channels go through
    `_VDropout`, both built with the same drop rate.
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        :return: same structure as the input with dropout applied
        '''
        # isinstance (rather than an exact type match) so that tensor
        # subclasses such as nn.Parameter are treated as scalar channels
        # instead of being unpacked along their first dimension.
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)

class LayerNorm(nn.Module):
    '''
    Normalization layer for tuples (s, V).
    Scalar channels are passed through `nn.LayerNorm`; vector channels
    are divided by the root-mean-square of their norms, taken over the
    channel axis. There are no learnable parameters on the vector side.
    '''
    def __init__(self, dims):
        super().__init__()
        n_scalar, n_vector = dims
        self.s, self.v = n_scalar, n_vector
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if self.v:
            s, v = x
            # squared norm of each vector, clamped away from zero
            sq_norm = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
            rms = torch.sqrt(torch.mean(sq_norm, dim=-2, keepdim=True))
            return self.scalar_norm(s), v / rms
        # no vector channels configured: input is a plain scalar tensor
        return self.scalar_norm(x)

class GVPConv(MessagePassing):
    '''
    Message passing step built from Geometric Vector Perceptrons.
    Consumes node and edge embeddings of a graph and produces the
    aggregated per-node messages as new node embeddings.

    Residual updates and the pointwise feedforward network are NOT part
    of this module ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture, otherwise "mean"
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, edge_dims,
                 n_layers=3, module_list=None, aggr="mean",
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super().__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        make_gvp = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate)

        if not module_list:
            # a message sees source node, edge and destination node features
            msg_in_dims = (2*self.si + self.se, 2*self.vi + self.ve)
            if n_layers == 1:
                module_list = [
                    make_gvp(msg_in_dims, (self.so, self.vo),
                             activations=(None, None))
                ]
            else:
                module_list = [make_gvp(msg_in_dims, out_dims)]
                module_list.extend(
                    make_gvp(out_dims, out_dims) for _ in range(n_layers - 2))
                # last GVP of the stack has no activations
                module_list.append(
                    make_gvp(out_dims, out_dims, activations=(None, None)))
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :return: tuple (s, V) of aggregated messages per node
        '''
        x_s, x_v = x
        n_nodes, n_vec = x_v.shape[0], x_v.shape[1]
        # vectors are flattened to 2-D for propagate() and restored in message()
        flat_v = x_v.reshape(n_nodes, 3*n_vec)
        aggregated = self.propagate(edge_index,
                                    s=x_s, v=flat_v,
                                    edge_attr=edge_attr)
        return _split(aggregated, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        # undo the flattening applied in forward()
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
        stacked = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        # merged back into a single tensor so it can be aggregated
        return _merge(*self.message_func(stacked))


class GVPConvLayer(nn.Module):
    '''
    Complete message passing layer built from Geometric Vector
    Perceptrons. The aggregated incoming messages are added residually
    to the node embeddings, followed by a pointwise feedforward network
    that is also applied residually. Dropout and `LayerNorm` are applied
    around both residual steps.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
           with a different set of input node embeddings for messages
           where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, node_dims, edge_dims,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):

        super().__init__()
        # messages are summed in the autoregressive case and averaged
        # by hand in forward()
        aggregation = "add" if autoregressive else "mean"
        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
                            aggr=aggregation,
                            activations=activations, vector_gate=vector_gate)
        make_gvp = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate)
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        if n_feedforward == 1:
            ff_layers = [
                make_gvp(node_dims, node_dims, activations=(None, None))]
        else:
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_layers = [make_gvp(node_dims, hid_dims)]
            ff_layers += [make_gvp(hid_dims, hid_dims)
                          for _ in range(n_feedforward - 2)]
            ff_layers.append(
                make_gvp(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_layers)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated (the tensors of the input
                tuple are written to in place).
        '''

        if autoregressive_x is None:
            dh = self.conv(x, edge_index, edge_attr)
        else:
            src, dst = edge_index
            forward_edges = src < dst
            backward_edges = ~forward_edges

            msg_forward = self.conv(
                x,
                edge_index[:, forward_edges],
                tuple_index(edge_attr, forward_edges))
            msg_backward = self.conv(
                autoregressive_x,
                edge_index[:, backward_edges],
                tuple_index(edge_attr, backward_edges))
            dh = tuple_sum(msg_forward, msg_backward)

            # number of incoming edges per node (at least 1), used to
            # turn the two summed aggregations into a mean
            count = scatter_add(torch.ones_like(dst), dst,
                        dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        if node_mask is not None:
            # keep a handle on the full embeddings; only the masked rows
            # go through the update below
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_
        return x

In [None]:
# metrics

In [None]:
from sklearn import metrics
from scipy import stats
import numpy as np

def eval_mse(y_true, y_pred, squared=True):
    """Evaluate mse/rmse and return the results.
    squared: bool, default=True
        If True returns MSE value, if False returns RMSE value.
    """
    if squared:
        return metrics.mean_squared_error(y_true, y_pred)
    # The `squared` keyword of `mean_squared_error` was deprecated in
    # scikit-learn 1.4 and removed in 1.6; newer releases provide
    # `root_mean_squared_error` instead. Prefer it when available and
    # fall back to the legacy keyword on older installations.
    rmse_fn = getattr(metrics, 'root_mean_squared_error', None)
    if rmse_fn is not None:
        return rmse_fn(y_true, y_pred)
    return metrics.mean_squared_error(y_true, y_pred, squared=False)

def eval_pearson(y_true, y_pred):
    """Return the Pearson correlation coefficient between y_true and y_pred
    (the p-value computed by scipy is discarded)."""
    coefficient, _p_value = stats.pearsonr(y_true, y_pred)
    return coefficient

def eval_spearman(y_true, y_pred):
    """Return the Spearman rank correlation between y_true and y_pred
    (the p-value computed by scipy is discarded)."""
    coefficient, _p_value = stats.spearmanr(y_true, y_pred)
    return coefficient

def eval_r2(y_true, y_pred):
    """Return the R2 score of y_pred against y_true, as computed by
    `sklearn.metrics.r2_score`."""
    r2 = metrics.r2_score(y_true, y_pred)
    return r2

def eval_auroc(y_true, y_pred):
    """Return the area under the ROC curve built from y_true and y_pred."""
    false_pos_rate, true_pos_rate, _thresholds = metrics.roc_curve(y_true, y_pred)
    area = metrics.auc(false_pos_rate, true_pos_rate)
    return area

def eval_auprc(y_true, y_pred):
    """Return the area under the precision-recall curve built from
    y_true and y_pred (recall on the x axis, precision on the y axis)."""
    precision, recall, _thresholds = metrics.precision_recall_curve(y_true, y_pred)
    area = metrics.auc(recall, precision)
    return area


def evaluation_metrics(y_true=None, y_pred=None,
        eval_metrics=[]):
    """Compute every requested metric and collect the scores in a dict.
    Parameters
    ----------
    y_true: true labels
    y_pred: predicted labels
    eval_metrics: a list of evaluation metrics; supported names are
        'mse', 'rmse', 'pearson', 'spearman', 'r2', 'auroc', 'auprc'
    Returns
    -------
    dict mapping each requested metric name to its score
    Raises
    ------
    ValueError if a requested name is not one of the supported metrics
    """
    # (name, scorer) pairs; scorers are wrapped in lambdas so that only
    # the metrics actually requested are evaluated.
    scorers = (
        ('mse', lambda: eval_mse(y_true, y_pred, squared=True)),
        ('rmse', lambda: eval_mse(y_true, y_pred, squared=False)),
        ('pearson', lambda: eval_pearson(y_true, y_pred)),
        ('spearman', lambda: eval_spearman(y_true, y_pred)),
        ('r2', lambda: eval_r2(y_true, y_pred)),
        ('auroc', lambda: eval_auroc(y_true, y_pred)),
        ('auprc', lambda: eval_auprc(y_true, y_pred)),
    )
    results = {}
    for requested in eval_metrics:
        for name, scorer in scorers:
            if requested == name:
                results[requested] = scorer()
                break
        else:
            raise ValueError('Unknown evaluation metric: {}'.format(requested))
    return results

In [None]:
# model

In [None]:
import torch
import torch.nn as nn
import torch_geometric

class Prot3DGraphModel(nn.Module):
    """
    Protein graph encoder: embeds residue identities, concatenates them
    with the per-node scalar features and the pretrained sequence
    embedding, projects nodes and edges to a common width, runs a stack
    of `TransformerConv` layers (each followed by LeakyReLU) and
    mean-pools node states into one vector per graph.
    """
    def __init__(self,
        d_vocab=21, d_embed=20,
        d_dihedrals=6, d_pretrained_emb=1280, d_edge=39,
        d_gcn=[128, 256, 256],
    ):
        """
        Parameters
        ----------
        d_vocab : size of the residue vocabulary for the embedding table
        d_embed : width of the learned residue embedding
        d_dihedrals : width of `data.node_s`
        d_pretrained_emb : width of `data.seq_emb`
        d_edge : width of `data.edge_s`
        d_gcn : list of int
            Output widths of the graph convolution layers. The projected
            input width is `d_gcn[0]`, so the first convolution maps
            `d_gcn[0]` to `d_gcn[0]`.
        """
        super().__init__()
        d_gcn_in = d_gcn[0]
        self.embed = nn.Embedding(d_vocab, d_embed)
        self.proj_node = nn.Linear(d_embed + d_dihedrals + d_pretrained_emb, d_gcn_in)
        self.proj_edge = nn.Linear(d_edge, d_gcn_in)

        widths = [d_gcn_in] + d_gcn
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            conv = torch_geometric.nn.TransformerConv(
                d_in, d_out, edge_dim=d_gcn_in)
            layers.append((conv, 'x, edge_index, edge_attr -> x'))
            layers.append(nn.LeakyReLU())

        self.gcn = torch_geometric.nn.Sequential(
            'x, edge_index, edge_attr', layers)
        self.pool = torch_geometric.nn.global_mean_pool

    def forward(self, data):
        """
        Parameters
        ----------
        data : batched graph object providing `seq`, `node_s`, `seq_emb`,
            `edge_s`, `edge_index` and `batch`
        Returns
        -------
        Tensor with one row per graph in the batch
        """
        node_inputs = torch.cat(
            [self.embed(data.seq), data.node_s, data.seq_emb], dim=-1)

        h = self.proj_node(node_inputs)
        e = self.proj_edge(data.edge_s)

        h = self.gcn(h, data.edge_index, e)
        return torch_geometric.nn.global_mean_pool(h, data.batch)



class DrugGVPModel(nn.Module):
    def __init__(self, 
        node_in_dim=[66, 1], node_h_dim=[128, 64],
        edge_in_dim=[16, 1], edge_h_dim=[32, 1],
        num_layers=3, drop_rate=0.1
    ):
        """
        Drug graph encoder made of GVP message passing layers.

        Parameters
        ----------
        node_in_dim : list of int
            Input dimension of drug node features (si, vi).
            Scalar node features have shape (N, si).
            Vector node features have shape (N, vi, 3).
        node_h_dim : list of int
            Hidden dimension of drug node features (so, vo).
            Scalar node features have shape (N, so).
            Vector node features have shape (N, vo, 3).
        edge_in_dim : list of int
            Input dimension of drug edge features (scalar, vector).
        edge_h_dim : list of int
            Hidden dimension of drug edge features (scalar, vector).
        num_layers : int
            Number of `GVPConvLayer` blocks.
        drop_rate : float
            Dropout rate used inside every `GVPConvLayer`.
        """
        super().__init__()
        # input embeddings: normalize, then a GVP without activations
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None)),
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
        )

        self.layers = nn.ModuleList([
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers)
        ])

        # output head keeps the scalar width and drops all vector channels
        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)),
        )

    def forward(self, xd):
        """
        Parameters
        ----------
        xd : batched graph object providing `node_s`, `node_v`, `edge_s`,
            `edge_v`, `edge_index` and `batch`
        Returns
        -------
        Tensor with one row per graph in the batch
        """
        node_feats = self.W_v((xd.node_s, xd.node_v))
        edge_feats = self.W_e((xd.edge_s, xd.edge_v))
        for gvp_layer in self.layers:
            node_feats = gvp_layer(node_feats, xd.edge_index, edge_feats)
        node_out = self.W_out(node_feats)

        # per-graph sum over nodes (global_add_pool)
        return torch_geometric.nn.global_add_pool(node_out, xd.batch)


class DTAModel(nn.Module):
    """
    Drug-target affinity model: a `DrugGVPModel` and a `Prot3DGraphModel`
    each feed their own fully connected stack; the two outputs are
    concatenated and passed through a final MLP ending in one output unit.
    """
    def __init__(self,
            prot_emb_dim=1280,
            prot_gcn_dims=[128, 256, 256],
            prot_fc_dims=[1024, 128],
            drug_node_in_dim=[66, 1], drug_node_h_dims=[128, 64],
            drug_edge_in_dim=[16, 1], drug_edge_h_dims=[32, 1],            
            drug_fc_dims=[1024, 128],
            mlp_dims=[1024, 512], mlp_dropout=0.25):
        super().__init__()

        self.drug_model = DrugGVPModel(
            node_in_dim=drug_node_in_dim, node_h_dim=drug_node_h_dims,
            edge_in_dim=drug_edge_in_dim, edge_h_dim=drug_edge_h_dims,
        )
        # the drug encoder outputs its scalar hidden width
        drug_out_dim = drug_node_h_dims[0]

        self.prot_model = Prot3DGraphModel(
            d_pretrained_emb=prot_emb_dim, d_gcn=prot_gcn_dims
        )
        # the protein encoder outputs the width of its last graph layer
        prot_out_dim = prot_gcn_dims[-1]

        # all three stacks share the same head configuration
        fc_options = dict(
            dropout=mlp_dropout, batchnorm=False,
            no_last_dropout=True, no_last_activation=True)

        self.drug_fc = self.get_fc_layers(
            [drug_out_dim] + drug_fc_dims, **fc_options)

        self.prot_fc = self.get_fc_layers(
            [prot_out_dim] + prot_fc_dims, **fc_options)

        self.top_fc = self.get_fc_layers(
            [drug_fc_dims[-1] + prot_fc_dims[-1]] + mlp_dims + [1],
            **fc_options)

    def get_fc_layers(self, hidden_sizes,
            dropout=0, batchnorm=False,
            no_last_dropout=True, no_last_activation=True):
        """
        Build a stack of linear layers as an `nn.Sequential`.

        Parameters
        ----------
        hidden_sizes : list of int
            Layer widths; consecutive pairs define one `nn.Linear` each.
        dropout : float
            Dropout rate; no dropout modules are added when it is 0.
        batchnorm : bool
            If True, add `nn.BatchNorm1d` after every layer but the last.
        no_last_dropout : bool
            If True, the last linear layer is not followed by dropout.
        no_last_activation : bool
            If True, the last linear layer is not followed by LeakyReLU.
        """
        # one activation module instance is reused at every position
        act_fn = torch.nn.LeakyReLU()
        last_idx = len(hidden_sizes) - 2
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(hidden_sizes, hidden_sizes[1:])):
            is_last = idx == last_idx
            layers.append(nn.Linear(n_in, n_out))
            if not (no_last_activation and is_last):
                layers.append(act_fn)
            if dropout > 0 and not (no_last_dropout and is_last):
                layers.append(nn.Dropout(dropout))
            if batchnorm and not is_last:
                layers.append(nn.BatchNorm1d(n_out))
        return nn.Sequential(*layers)

    def forward(self, xd, xp):
        """
        Parameters
        ----------
        xd : batched drug graphs (input of `DrugGVPModel`)
        xp : batched protein graphs (input of `Prot3DGraphModel`)
        Returns
        -------
        Output of the top MLP, one row per drug-protein pair
        """
        drug_repr = self.drug_model(xd)
        prot_repr = self.prot_model(xp)

        drug_repr = self.drug_fc(drug_repr)
        prot_repr = self.prot_fc(prot_repr)

        joint = torch.cat([drug_repr, prot_repr], dim=1)
        return self.top_fc(joint)


In [None]:
# mol_graph

In [None]:
import rdkit 
from rdkit import Chem
import torch
import numpy as np
import pandas as pd
import torch_geometric
import torch_cluster



def onehot_encoder(a=None, alphabet=None, default=None, drop_first=False):
    '''
    One-hot encode a sequence of categorical values.

    Parameters
    ----------
    a: array of numerical value of categorical feature classes.
    alphabet: valid values of feature classes.
    default: class substituted for any value that is not in `alphabet`.
    drop_first: if True, the column of the first class is dropped.
    Returns
    -------
    A 2-D one-hot array with one row per element of `a` and one column
    per class in `alphabet` (one column fewer when `drop_first` is True).
    '''
    # map out-of-vocabulary values onto the default class
    known = set(alphabet)
    cleaned = [value if value in known else default for value in a]

    # a categorical with fixed categories yields a column for every
    # class, including classes that do not occur in the data
    as_category = pd.Series(pd.Categorical(cleaned, categories=alphabet))

    encoded = pd.get_dummies(as_category, columns=alphabet, drop_first=drop_first)
    return encoded.values


def _build_atom_feature(mol):
    """
    Build the per-atom scalar feature matrix of a molecule by one-hot
    encoding several atom properties and concatenating them column-wise.

    Parameters
    ----------
    mol: molecule object exposing `GetAtoms()`; each atom must provide
        the getter methods listed in the table below.
    Returns
    -------
    np.ndarray of dtype float32 with one row per atom
    """
    # dim: 44 + 7 + 7 + 7 + 1
    # atom getter name -> (alphabet, default value); the insertion order
    # fixes the column order of the result
    feature_alphabet = {
        'GetSymbol': (ATOM_VOCAB, 'unk'),
        'GetDegree': ([0, 1, 2, 3, 4, 5, 6], 6),
        'GetTotalNumHs': ([0, 1, 2, 3, 4, 5, 6], 6),
        'GetImplicitValence': ([0, 1, 2, 3, 4, 5, 6], 6),
        'GetIsAromatic': ([0, 1], 1)
    }

    atoms = list(mol.GetAtoms())
    encoded_blocks = []
    for getter_name, (alphabet, default) in feature_alphabet.items():
        raw_values = [getattr(atom, getter_name)() for atom in atoms]
        encoded_blocks.append(onehot_encoder(
            raw_values,
            alphabet=alphabet,
            default=default,
            # binary-class feature: a single column is enough
            drop_first=(getter_name == 'GetIsAromatic'),
        ))
    return np.concatenate(encoded_blocks, axis=1).astype(np.float32)




def _build_edge_feature(coords, edge_index, D_max=4.5, num_rbf=16):
    """
    Build scalar and vector edge features from node coordinates.

    Parameters
    ----------
    coords: tensor of node coordinates, indexed by node
    edge_index: tensor whose rows 0 and 1 hold the two endpoints of each edge
    D_max: upper distance bound forwarded to `_rbf`
    num_rbf: number of radial basis functions forwarded to `_rbf`
    Returns
    -------
    edge_s: radial basis encoding of the edge lengths
    edge_v: unit edge directions with an added channel axis
    Both have NaN/inf entries replaced by `torch.nan_to_num`.
    """
    first, second = edge_index[0], edge_index[1]
    displacement = coords[first] - coords[second]
    lengths = displacement.norm(dim=-1)

    edge_s = torch.nan_to_num(_rbf(lengths, D_max=D_max, D_count=num_rbf))
    edge_v = torch.nan_to_num(_normalize(displacement).unsqueeze(-2))
    return edge_s, edge_v


def sdf_to_graphs(data_list):
    """
    Featurize every drug in `data_list`, showing a progress bar.

    Parameters
    ----------
    data_list: dict, drug key -> sdf file path
    Returns
    -------
    graphs : dict
        drug key -> torch_geometric graph built by `featurize_drug`
    """
    return {
        key: featurize_drug(sdf_path, name=key)
        for key, sdf_path in tqdm(data_list.items(), desc='sdf')
    }


def featurize_drug(sdf_path, name=None, edge_cutoff=4.5, num_rbf=16):
    """
    Build a graph for one drug from its 3D structure file.

    Parameters
    ----------
    sdf_path: str
        Path to sdf file
    name: str
        Name of drug
    edge_cutoff: float
        Radius used by `torch_cluster.radius_graph` to connect atoms;
        also passed as `D_max` to the edge featurizer.
    num_rbf: int
        Number of radial basis functions for the scalar edge features.
    Returns
    -------
    graph: torch_geometric.data.Data
        A torch_geometric graph with attributes `x` (atom coordinates),
        `edge_index`, `name`, `node_s`, `node_v`, `edge_s` and `edge_v`.
    Raises
    ------
    ValueError
        If RDKit cannot parse a molecule from `sdf_path`.
    """
    mol = rdkit.Chem.MolFromMolFile(sdf_path)
    if mol is None:
        # RDKit signals a parse failure by returning None; fail here with
        # the offending path instead of an AttributeError further down.
        raise ValueError(f'RDKit could not parse a molecule from {sdf_path!r}')
    conf = mol.GetConformer()
    with torch.no_grad():
        coords = conf.GetPositions()
        coords = torch.as_tensor(coords, dtype=torch.float32)
        atom_feature = _build_atom_feature(mol)
        atom_feature = torch.as_tensor(atom_feature, dtype=torch.float32)
        edge_index = torch_cluster.radius_graph(coords, r=edge_cutoff)

    node_s = atom_feature
    # atom coordinates double as the single vector channel of each node
    node_v = coords.unsqueeze(1)
    edge_s, edge_v = _build_edge_feature(
        coords, edge_index, D_max=edge_cutoff, num_rbf=num_rbf)

    data = torch_geometric.data.Data(
        x=coords, edge_index=edge_index, name=name,
        node_v=node_v, node_s=node_s, edge_v=edge_v, edge_s=edge_s)
    return data



In [None]:
# parsing

In [None]:
def add_train_args(parser):
    """
    Register all training command-line options on `parser`.

    Parameters
    ----------
    parser : argparse.ArgumentParser (or any object with a compatible
        `add_argument` method); it is modified in place.
    """
    # Dataset parameters
    parser.add_argument('--task', help='Task name')
    parser.add_argument('--split_method', default='random',
        choices=['random', 'protein', 'drug', 'both', 'seqid'],
        help='Split method: random, protein, drug, both, or seqid')
    parser.add_argument('--seed', type=int, default=42,
        help='Random Seed')

    # Data representation parameters
    parser.add_argument('--contact_cutoff', type=float, default=8.,
        help='cutoff of C-alpha distance to define protein contact graph')
    parser.add_argument('--num_pos_emb', type=int, default=16,
        help='number of positional embeddings')
    parser.add_argument('--num_rbf', type=int, default=16,
        help='number of RBF kernels')

    # Protein model parameters
    parser.add_argument('--prot_gcn_dims', type=int, nargs='+', default=[128, 256, 256],
        help='protein GCN layers dimensions')
    parser.add_argument('--prot_fc_dims', type=int, nargs='+', default=[1024, 128],
        help='protein FC layers dimensions')

    # Drug model parameters
    parser.add_argument('--drug_gcn_dims', type=int, nargs='+', default=[128, 64],
        help='drug GVP hidden layers dimensions')
    parser.add_argument('--drug_fc_dims', type=int, nargs='+', default=[1024, 128],
        help='drug FC layers dimensions')

    # Top model parameters
    parser.add_argument('--mlp_dims', type=int, nargs='+', default=[1024, 512],
        help='top MLP layers dimensions')
    parser.add_argument('--mlp_dropout', type=float, default=0.25,
        help='dropout rate in top MLP')

    # uncertainty parameters
    parser.add_argument('--uncertainty', action='store_true',
        help='estimate uncertainty')
    parser.add_argument('--recalibrate', action='store_true',
        help='recalibrate uncertainty')

    # Training parameters
    parser.add_argument('--n_ensembles', type=int, default=1,
        help='number of ensembles')
    parser.add_argument('--batch_size', type=int, default=128,
        help='batch size')
    parser.add_argument('--n_epochs', type=int, default=500,
        help='number of epochs')
    parser.add_argument('--patience', action='store', type=int,
        help='patience for early stopping')
    parser.add_argument('--eval_freq', type=int, default=1,
        help='evaluation frequency')
    parser.add_argument('--test_freq', type=int,
        help='test frequency')
    parser.add_argument('--lr', type=float, default=0.0005,
        help='learning rate')
    parser.add_argument('--monitor_metric', default='pearson',
        help='validation metric to monitor for deciding best checkpoint')
    parser.add_argument('--parallel', action='store_true',
        help='run ensembles in parallel on multiple GPUs')

    # Save parameters
    parser.add_argument('--output_dir', action='store', default='../output', help='output folder')
    parser.add_argument('--save_log', action='store_true', default=False, help='save log file')
    parser.add_argument('--save_checkpoint', action='store_true', default=False, help='save checkpoint')
    parser.add_argument('--save_prediction', action='store_true', default=False, help='save prediction')


In [None]:
# pdb_graph

In [None]:
"""
Adapted from
https://github.com/jingraham/neurips19-graph-protein-design
https://github.com/drorlab/gvp-pytorch
"""
import math
import numpy as np
import scipy as sp
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch_geometric
import torch_cluster


def pdb_to_graphs(prot_data, params):
    """
    Featurize every protein in `prot_data`, showing a progress bar.
    Parameters
    ----------
    prot_data : dict
        protein key -> protein data dict. see format in
        `featurize_protein_graph()`.
    params : dict
        Keyword arguments forwarded to `featurize_protein_graph()`.
    Returns
    -------
    graphs : dict
        protein key -> torch_geometric graph
    """
    return {
        key: featurize_protein_graph(struct, name=key, **params)
        for key, struct in tqdm(prot_data.items(), desc='pdb')
    }

def featurize_protein_graph(
        protein, name=None,
        num_pos_emb=16, num_rbf=16,
        contact_cutoff=8.,
    ):
    """
    Build a torch_geometric graph for one protein.

    Parameters
    ----------
    protein : dict
        Read keys: `coords` (per-residue backbone atom coordinates),
        `seq` (iterable of residue letters found in `LETTER_TO_NUM`) and
        `embed` (passed to `torch.load`).
        See also the comments of DTATask() in dta.py.
    name : stored on the returned graph as `name`
    num_pos_emb : int
        Number of positional embedding features per edge.
    num_rbf : int
        Number of radial basis functions per edge.
    contact_cutoff : float
        Residue pairs whose distance (second atom of each residue) is
        below this value are connected by an edge.
    Returns
    -------
    torch_geometric.data.Data with attributes `x`, `seq`, `name`,
    `node_s`, `node_v`, `edge_s`, `edge_v`, `edge_index`, `mask`
    and `seq_emb`.
    """
    with torch.no_grad():
        coords = torch.as_tensor(protein['coords'], dtype=torch.float32)
        seq = torch.as_tensor([LETTER_TO_NUM[a] for a in protein['seq']], dtype=torch.long)
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the torch version -- only load embeddings from trusted sources.
        seq_emb = torch.load(protein['embed'])

        # residues with any non-finite coordinate are marked entirely as inf
        # NOTE(review): torch.as_tensor can share memory with its input, so
        # this in-place write may reach protein['coords'] -- confirm.
        mask = torch.isfinite(coords.sum(dim=(1,2)))
        coords[~mask] = np.inf

        # second atom of every residue (presumably C-alpha, given the name)
        X_ca = coords[:, 1]
        ca_mask = torch.isfinite(X_ca.sum(dim=(1)))
        ca_mask = ca_mask.float()
        # 1 where both residues of a pair have finite coordinates
        ca_mask_2D = torch.unsqueeze(ca_mask, 0) * torch.unsqueeze(ca_mask, 1)
        # pairwise distances; the small constant keeps sqrt away from 0
        dX_ca = torch.unsqueeze(X_ca, 0) - torch.unsqueeze(X_ca, 1)
        D_ca = ca_mask_2D * torch.sqrt(torch.sum(dX_ca**2, 2) + 1e-6)
        # contact edges between valid residues; pairs (i, i) are not excluded
        edge_index = torch.nonzero((D_ca < contact_cutoff) & (ca_mask_2D == 1))
        edge_index = edge_index.t().contiguous()


        O_feature = _local_frame(X_ca, edge_index)
        pos_embeddings = _positional_embeddings(edge_index, num_embeddings=num_pos_emb)
        E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
        # only D_count is passed, so _rbf uses its own default distance range
        rbf = _rbf(E_vectors.norm(dim=-1), D_count=num_rbf)

        dihedrals = _dihedrals(coords)
        orientations = _orientations(X_ca)
        sidechains = _sidechains(coords)

        node_s = dihedrals
        node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        edge_s = torch.cat([rbf, O_feature, pos_embeddings], dim=-1)
        edge_v = _normalize(E_vectors).unsqueeze(-2)

        # replace NaN/inf left over from missing residues
        node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,
                (node_s, node_v, edge_s, edge_v))

    data = torch_geometric.data.Data(x=X_ca, seq=seq, name=name,
                                        node_s=node_s, node_v=node_v,
                                        edge_s=edge_s, edge_v=edge_v,
                                        edge_index=edge_index, mask=mask,
                                        seq_emb=seq_emb)
    return data


def _dihedrals(X, eps=1e-7):
    """
    Compute dihedral-angle features along the backbone.

    Parameters
    ----------
    X : torch.Tensor, per-residue atom coordinates. Only the first three
        atoms of every residue (`X[:, :3]`) are used; they are flattened
        into a single chain of `3 * n_residues` points.
        NOTE(review): presumably the atom order is N, CA, C -- confirm
        against the caller.
    eps : float, margin keeping the cosine strictly inside (-1, 1) so that
        `acos` stays finite.

    Returns
    -------
    torch.Tensor of shape [n_residues, 6]: the cosines of the three angles
    of each residue followed by their sines.
    """
    X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
    dX = X[1:] - X[:-1]
    U = _normalize(dX, dim=-1)
    u_2 = U[:-2]
    u_1 = U[1:-1]
    u_0 = U[2:]

    # Backbone normals.
    # `dim=-1` is passed explicitly: without it `torch.cross` uses the first
    # axis of size 3, which is the wrong axis whenever the leading dimension
    # happens to have length 3.
    n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
    n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

    # Angle between normals
    cosD = torch.sum(n_2 * n_1, -1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
    D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

    # This scheme will remove phi[0], psi[-1], omega[-1]
    # (one zero is padded in front and two behind so the angles regroup
    # into three per residue).
    D = F.pad(D, [1, 2])
    D = torch.reshape(D, [-1, 3])
    # Lift angle representations to the circle
    D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
    return D_features


def _positional_embeddings(edge_index,
                            num_embeddings=None,
                            period_range=(2, 1000)):
    """
    Sinusoidal encoding of the index offset between the two ends of each edge.

    Parameters
    ----------
    edge_index : integer tensor with two rows; the offset is
        `edge_index[0] - edge_index[1]`.
    num_embeddings : int, width of the returned encoding. Must be given
        (the `None` default is not usable by the computation below).
    period_range : unused; kept for interface compatibility. The default is
        an immutable tuple so it cannot be mutated across calls.

    Returns
    -------
    torch.Tensor of shape [n_edges, num_embeddings] (for even
    `num_embeddings`): cosines first, then sines.
    """
    d = edge_index[0] - edge_index[1]

    # Geometric progression of frequencies from 1 down towards 1/10000.
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * frequency
    return torch.cat((torch.cos(angles), torch.sin(angles)), -1)


def _orientations(X):
    """
    Unit direction vectors from every point to its successor and predecessor.

    For point i the first vector points towards point i+1 and the second
    towards point i-1; the missing neighbour at either end of the chain is
    filled with zeros.

    Returns a tensor with the two vectors stacked on a new second-to-last
    axis (i.e. [n, 2, d] for an [n, d] input).
    """
    to_next = F.pad(_normalize(X[1:] - X[:-1]), [0, 0, 0, 1])
    to_prev = F.pad(_normalize(X[:-1] - X[1:]), [0, 0, 1, 0])
    return torch.stack([to_next, to_prev], dim=-2)


def _sidechains(X):
    """
    Build one direction vector per residue from its first three atoms.

    `X[:, 0]`, `X[:, 1]` and `X[:, 2]` are treated as the two neighbours
    (`n`, `c`) of a central atom (`origin`). The result combines the negated
    bisector of the two unit bond vectors with the negated normal of the
    plane they span, using fixed weights sqrt(1/3) and sqrt(2/3).
    NOTE(review): presumably this approximates the side-chain direction
    from N/CA/C -- confirm atom order against the caller.

    Returns a tensor of shape [n_residues, 3].
    """
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    # Explicit `dim=-1`: the implicit default picks the first axis of size 3,
    # which is the residue axis (not xyz) when there are exactly 3 residues.
    perp = _normalize(torch.cross(c, n, dim=-1))
    vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
    return vec


def _normalize(tensor, dim=-1):
    '''
    Scale `tensor` to unit length along `dim`.

    The quotient is passed through `torch.nan_to_num`, so a zero-length
    vector (0 / 0) comes back as zeros instead of `nan`.
    '''
    length = torch.norm(tensor, dim=dim, keepdim=True)
    return torch.nan_to_num(tensor / length)


def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    Expand distances into Gaussian radial basis features on a new last axis.

    `D_count` centres are spaced evenly over [`D_min`, `D_max`] and share the
    width `(D_max - D_min) / D_count`. An input of shape [...dims] yields an
    output of shape [...dims, D_count].
    '''
    centers = torch.linspace(D_min, D_max, D_count, device=device).view([1, -1])
    width = (D_max - D_min) / D_count
    scaled = (D.unsqueeze(-1) - centers) / width
    return torch.exp(-(scaled ** 2))


def _local_frame(X, edge_index, eps=1e-6):
    """
    Per-edge orientation features expressed in local coordinate frames.

    A 3x3 frame is built for every interior point of the chain `X` from its
    neighbouring displacement vectors; the first point and the last two get
    an all-zero frame. For each edge the function returns

    * the unit vector from `X[edge_index[0]]` to `X[edge_index[1]]`
      multiplied by the frame of `edge_index[0]` (3 values), and
    * a quaternion for the relative rotation between the frames of the two
      end points (4 values).

    Parameters
    ----------
    X : torch.Tensor [n_points, 3], ordered chain of coordinates.
    edge_index : integer tensor [2, n_edges].
    eps : unused; kept for interface compatibility.

    Returns
    -------
    torch.Tensor of shape [n_edges, 7].
    """
    dX = X[1:] - X[:-1]
    U = _normalize(dX, dim=-1)
    u_2 = U[:-2]
    u_1 = U[1:-1]

    # Backbone normal. `dim=-1` is explicit because the implicit default of
    # `torch.cross` is the first axis of size 3, which is the wrong axis when
    # the leading dimension happens to have length 3.
    n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)

    o_1 = _normalize(u_2 - u_1, dim=-1)
    O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 1)
    # Pad the leading axis (1 in front, 2 behind) so there is one frame per point.
    O = F.pad(O, (0, 0, 0, 0, 1, 2), 'constant', 0)

    # Direction of each edge, taken from index row 0 towards index row 1 and
    # projected into the frame of the row-0 end point.
    dX = X[edge_index[1]] - X[edge_index[0]]
    dX = _normalize(dX, dim=-1)
    dU = torch.bmm(O[edge_index[0]], dX.unsqueeze(2)).squeeze(2)
    R = torch.bmm(O[edge_index[0]].transpose(-1,-2), O[edge_index[1]])
    Q = _quaternions(R)
    O_features = torch.cat((dU,Q), dim=-1)

    return O_features


def _quaternions(R):
    """
    Convert a batch of 3x3 rotation matrices into unit quaternions (x, y, z, w).

    Follows the plain formula from
    en.wikipedia.org/wiki/Rotation_matrix#Quaternion : the magnitudes of
    x, y, z come from the diagonal, their signs from the antisymmetric part,
    and w from the trace.

    R is indexed as `R[:, i, j]`, i.e. it must have shape [batch, 3, 3].
    Returns a tensor of shape [batch, 4], normalised along the last axis.
    """
    diag = torch.diagonal(R, dim1=-2, dim2=-1)
    Rxx, Ryy, Rzz = diag.unbind(-1)

    diag_terms = torch.stack([
        Rxx - Ryy - Rzz,
        - Rxx + Ryy - Rzz,
        - Rxx - Ryy + Rzz,
    ], -1)
    magnitudes = 0.5 * torch.sqrt(torch.abs(1 + diag_terms))

    off_diag_terms = torch.stack([
        R[:, 2, 1] - R[:, 1, 2],
        R[:, 0, 2] - R[:, 2, 0],
        R[:, 1, 0] - R[:, 0, 1],
    ], -1)
    xyz = torch.sign(off_diag_terms) * magnitudes

    # The relu enforces a non-negative trace term under the square root
    w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
    return F.normalize(torch.cat((xyz, w), -1), dim=-1)



In [None]:
# utils

In [None]:
import sys
import yaml
import logging
import torch
import pathlib
import numpy as np

class Logger(object):
    """
    Convenience wrapper that (re)configures the root logger.

    Every instantiation removes all handlers currently attached to the root
    logger, then attaches a stdout handler and, optionally, a file handler
    that truncates `logfile`. Messages are formatted as
    "<timestamp>\t<message>".
    """
    def __init__(self, logfile=None, level=logging.INFO):
        '''
        logfile: pathlib object; parent directories are created if needed.
        level: logging level applied to the root logger.
        '''
        self.logger = logging.getLogger()
        self.logger.setLevel(level)
        formatter = logging.Formatter("%(asctime)s\t%(message)s", "%Y-%m-%d %H:%M:%S")

        for hd in self.logger.handlers[:]:
            self.logger.removeHandler(hd)
            # removeHandler() alone leaves a FileHandler's file open; close it
            # so re-creating a Logger does not leak file descriptors.
            hd.close()

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        if logfile is not None:
            logfile.parent.mkdir(exist_ok=True, parents=True)
            fh = logging.FileHandler(logfile, 'w')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

    def debug(self, msg):
        """Log `msg` at DEBUG level."""
        self.logger.debug(msg)

    def info(self, msg):
        """Log `msg` at INFO level."""
        self.logger.info(msg)

    def warning(self, msg):
        """Log `msg` at WARNING level."""
        self.logger.warning(msg)

    def error(self, msg):
        """Log `msg` at ERROR level."""
        self.logger.error(msg)


class Saver(object):
    """
    Writes experiment artifacts (checkpoints, tables, configs) into one
    output directory, creating the directory on first use.
    """
    def __init__(self, output_dir):        
        self.save_dir = pathlib.Path(output_dir)
    
    def mkdir(self):
        """Create the output directory (and parents) if it does not exist."""
        self.save_dir.mkdir(exist_ok=True, parents=True)

    def save_ckp(self, pt, filename='checkpoint.pt'):
        """Serialize `pt` with `torch.save` to `<save_dir>/<filename>`."""
        self.mkdir()
        torch.save(pt, str(self.save_dir/filename))

    def save_df(self, df, filename, float_format='%.6f'):
        """Write `df` as a tab-separated file without the index column."""
        self.mkdir()
        df.to_csv(self.save_dir/filename, float_format=float_format, index=False, sep='\t')
    
    def save_config(self, config, filename, overwrite=True):
        """
        Dump `config` as YAML to `<save_dir>/<filename>`.

        When `overwrite` is False and the target file already exists, the
        existing file is left untouched and nothing is written.
        """
        self.mkdir()
        target = self.save_dir/filename
        # Honor the `overwrite` flag, which was previously accepted but ignored.
        if not overwrite and target.exists():
            return
        with open(target, 'w') as f:
            yaml.dump(config, f, indent=2)



class EarlyStopping(object):
    """
    Tracks the best validation score and signals when to stop training.

    Parameters
    ----------
    patience : int or None, stop once the no-improvement counter exceeds
        this value; `None` disables stopping.
    eval_freq : int, amount added to the counter for each non-improving update.
    best_score : optional starting best score.
    delta : float, minimum change that counts as an improvement.
    higher_better : bool, whether larger scores are better.

    Attributes
    ----------
    counter : accumulated `eval_freq` since the last improvement.
    early_stop : set to True once `counter > patience`.
    """
    def __init__(self, 
            patience=100, eval_freq=1, best_score=None, 
            delta=1e-9, higher_better=True):
        self.patience = patience
        self.eval_freq = eval_freq
        self.best_score = best_score
        self.delta = delta
        self.higher_better = higher_better
        self.counter = 0
        self.early_stop = False
    
    def not_improved(self, val_score):
        """Return True when `val_score` is NaN or does not beat the best score by `delta`."""
        if np.isnan(val_score):
            return True
        if self.best_score is None:
            # Nothing to compare against yet: any finite score is an improvement.
            return False
        if self.higher_better:
            return val_score < self.best_score + self.delta
        return val_score > self.best_score - self.delta
    
    def update(self, val_score):
        """
        Record a new validation score and return True if it is the new best.

        A NaN score is never accepted as best, including when it is the very
        first score seen; it only advances the patience counter.
        """
        if self.not_improved(val_score):
            self.counter += self.eval_freq
            if (self.patience is not None) and (self.counter > self.patience):
                self.early_stop = True
            return False
        self.best_score = val_score
        self.counter = 0
        return True


In [None]:
# train run

In [None]:

def get_esm_embedding(seq, esm_model):
    """
    Embed one sequence with `esm_model` and return per-token representations.

    The sequence is tokenized with the module-level `tokenizer` (special
    tokens added), run through the model without gradients, and the first
    and last token positions of the last hidden state are dropped.

    Parameters
    ----------
    seq : str, sequence to embed.
    esm_model : model returning an object with `last_hidden_state`.

    Returns
    -------
    torch.Tensor on CPU: `last_hidden_state[0, 1:-1]`.
    """
    inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=True)

    # Run on whatever device the model lives on instead of assuming CUDA, so
    # the function also works on CPU-only machines.
    device = getattr(esm_model, "device", None)
    if device is None:
        device = next(esm_model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = esm_model(**inputs)
    token_representations = outputs.last_hidden_state
    # Drop the first and last positions (the added special tokens).
    emb = token_representations[0, 1:-1]
    return emb.cpu()



# Dataset builder for DTA class
def build_dataset(df_fold, pdb_structures, exp_cols = "pKi", is_pred = False):
    """
    Featurize every row of `df_fold` and wrap the result in a `DTA` dataset.

    For each row, the structure id is the ligand SDF file name up to its
    first dot; it is looked up in `pdb_structures` and featurized together
    with the ligand file. The label `y` is `float(row[exp_cols])`, or the
    placeholder 0 when `is_pred` is True.
    """
    def _entry_for(row):
        sdf_path = row["standardized_ligand_sdf"]
        pdb_id = os.path.basename(sdf_path).split(".")[0]
        protein = featurize_protein_graph(pdb_structures.get(pdb_id))
        drug = featurize_drug(sdf_path)
        # The label column is only read when real labels are requested.
        label = 0 if is_pred == True else float(row[exp_cols])
        return {"protein": protein, "drug": drug, "y": label}

    data_list = [_entry_for(row) for _, row in df_fold.iterrows()]
    return DTA(df=df_fold, data_list=data_list)


# Three-letter residue names of the 20 standard amino acids -> one-letter codes.
_THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def extract_backbone_coords(structure, pdb_id, pdb_path):
    """
    Extract the sequence and backbone coordinates of the first usable chain.

    The first chain of `structure[0]` that contains at least one standard
    amino acid is used; non-standard residues in it are skipped.

    Parameters
    ----------
    structure : parsed structure, indexable by model number.
    pdb_id, pdb_path : only used in the message printed when no chain qualifies.

    Returns
    -------
    (seq, coords, chain_id) where
        seq      : str of one-letter codes ('X' for a name missing from the table),
        coords   : dict with keys "N", "CA", "C", "O", each a list with one
                   [x, y, z] per residue (three NaNs when the atom is absent),
        chain_id : id of the selected chain;
    or (None, None, None) when no chain contains a standard amino acid.
    """
    coords = {"N": [], "CA": [], "C": [], "O": []}
    letters = []

    model = structure[0]

    valid_chain = None
    for chain in model:
        if any(is_aa(res, standard=True) for res in chain):
            valid_chain = chain
            break

    if valid_chain is None:
        print("No valid chains: ", pdb_id, pdb_path)
        return None, None, None

    chain_id = valid_chain.id

    for res in valid_chain:
        if not is_aa(res, standard=True):
            continue
        # Translate the three-letter name properly; taking its first character
        # is wrong for half of the alphabet (e.g. ARG -> 'A', LYS -> 'L').
        letters.append(_THREE_TO_ONE.get(res.resname.strip().upper(), "X"))

        for atom_name in ["N", "CA", "C", "O"]:
            if atom_name in res:
                coords[atom_name].append(res[atom_name].coord.tolist())
            else:
                coords[atom_name].append([float("nan")] * 3)

    return "".join(letters), coords, chain_id

# MTL

In [None]:
# ============= MODIFICATIONS FOR MULTI-TASK LEARNING =============

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# 1. MASKED MSE LOSS WITH TASK WEIGHTING
class MaskedMSELoss(nn.Module):
    """
    Mean squared error over the non-NaN entries of a multi-task target.

    When `task_ranges` is given, every task gets a weight proportional to the
    inverse of its range (1.0 for a non-positive range); the weights are
    normalised to sum to one and follow the iteration order of
    `task_ranges`, which must therefore match the column order of the
    targets.
    """
    def __init__(self, task_ranges=None):
        super(MaskedMSELoss, self).__init__()
        self.task_ranges = task_ranges
        if task_ranges is not None:
            # Calculate task weights based on inverse range
            weights = [1.0 / r if r > 0 else 1.0 for r in task_ranges.values()]
            total_weight = sum(weights)
            self.task_weights = torch.tensor([w / total_weight for w in weights])
        else:
            self.task_weights = None
    
    def forward(self, pred, target):
        """
        pred: [batch_size, n_tasks]
        target: [batch_size, n_tasks], NaN marks a missing label.

        Returns a scalar tensor with the dtype and device of `pred`. If no
        label is present the result is a zero that still supports
        `backward()`.
        """
        # Create mask for non-NaN values
        mask = ~torch.isnan(target)

        if not mask.any():
            # Keep the zero on pred's device/dtype rather than a CPU float32.
            return pred.new_zeros((), requires_grad=True)

        # Squared errors of the labelled entries only
        se = (pred[mask] - target[mask]) ** 2

        if self.task_weights is None:
            return se.mean()

        # Column index of every labelled entry selects its task weight.
        task_indices = torch.where(mask)[1]
        weights = self.task_weights.to(device=pred.device, dtype=pred.dtype)[task_indices]
        return (se * weights).mean()

    
# 2. MODIFIED DTA MODEL FOR MULTI-TASK LEARNING
class MTL_DTAModel(nn.Module):
    """
    Multi-task drug-target affinity model.

    A drug encoder and a protein encoder each feed their own fully connected
    stack; the two outputs are concatenated, passed through shared layers,
    and read out by one linear head per task.
    """
    def __init__(self,
            task_names=['pKi', 'pEC50', 'pKd', 'pIC50'],  # List of tasks
            prot_emb_dim=1280,
            prot_gcn_dims=[128, 256, 256],
            prot_fc_dims=[1024, 128],
            drug_node_in_dim=[66, 1], drug_node_h_dims=[128, 64],
            drug_edge_in_dim=[16, 1], drug_edge_h_dims=[32, 1],            
            drug_fc_dims=[1024, 128],
            mlp_dims=[1024, 512], mlp_dropout=0.25):
        super(MTL_DTAModel, self).__init__()

        self.task_names = task_names
        self.n_tasks = len(task_names)

        # Options shared by the three fully connected stacks.
        fc_kwargs = dict(dropout=mlp_dropout, batchnorm=False,
                         no_last_dropout=True, no_last_activation=True)

        # Encoders
        self.drug_model = DrugGVPModel(
            node_in_dim=drug_node_in_dim, node_h_dim=drug_node_h_dims,
            edge_in_dim=drug_edge_in_dim, edge_h_dim=drug_edge_h_dims,
        )
        self.prot_model = Prot3DGraphModel(
            d_pretrained_emb=prot_emb_dim, d_gcn=prot_gcn_dims
        )

        # The FC stacks start from the first drug hidden dim and the last
        # protein GCN dim respectively.
        drug_out_dim = drug_node_h_dims[0]
        prot_out_dim = prot_gcn_dims[-1]
        self.drug_fc = self.get_fc_layers([drug_out_dim] + drug_fc_dims, **fc_kwargs)
        self.prot_fc = self.get_fc_layers([prot_out_dim] + prot_fc_dims, **fc_kwargs)

        # Shared representation layers on the concatenated drug/protein vectors
        joint_dim = drug_fc_dims[-1] + prot_fc_dims[-1]
        self.shared_fc = self.get_fc_layers([joint_dim] + mlp_dims, **fc_kwargs)

        # One scalar head per task
        self.task_heads = nn.ModuleDict({
            task: nn.Linear(mlp_dims[-1], 1) for task in task_names
        })

    def get_fc_layers(self, hidden_sizes,
            dropout=0, batchnorm=False,
            no_last_dropout=True, no_last_activation=True):
        """
        Build an MLP whose layer widths are the consecutive pairs of
        `hidden_sizes`.

        After each linear layer come, in this order, a LeakyReLU, a Dropout
        (only if `dropout > 0`) and a BatchNorm1d (only if `batchnorm`). The
        `no_last_*` flags suppress the activation / dropout after the final
        linear layer; batch norm is never added after the final layer.
        """
        act_fn = torch.nn.LeakyReLU()
        final_idx = len(hidden_sizes) - 2
        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(hidden_sizes[:-1], hidden_sizes[1:])):
            is_final = (i == final_idx)
            layers.append(nn.Linear(in_dim, out_dim))
            if not (is_final and no_last_activation):
                layers.append(act_fn)
            if dropout > 0 and not (is_final and no_last_dropout):
                layers.append(nn.Dropout(dropout))
            if batchnorm and not is_final:
                layers.append(nn.BatchNorm1d(out_dim))
        return nn.Sequential(*layers)

    def forward(self, xd, xp):
        """
        xd: drug input for `self.drug_model`; xp: protein input for
        `self.prot_model`. Returns predictions of shape [batch_size, n_tasks],
        columns ordered as `self.task_names`.
        """
        drug_vec = self.drug_fc(self.drug_model(xd))
        prot_vec = self.prot_fc(self.prot_model(xp))

        shared_repr = self.shared_fc(torch.cat([drug_vec, prot_vec], dim=1))

        per_task = [self.task_heads[task](shared_repr) for task in self.task_names]
        return torch.cat(per_task, dim=1)

# 3. MODIFIED DTA DATASET CLASS
class MTL_DTA(data.Dataset):
    """
    Dataset yielding `{'drug', 'protein', 'y'}` items with one target per task.

    `data_list` holds one dict per sample. With `onthefly=False` the 'drug'
    and 'protein' entries are returned as stored; with `onthefly=True` they
    are passed through the featurize functions together with the
    'drug_name' / 'protein_name' entries. Task values missing from a sample
    become NaN in `y`.
    """
    def __init__(self, df=None, data_list=None, task_cols=None, onthefly=False,
                prot_featurize_fn=None, drug_featurize_fn=None):
        super(MTL_DTA, self).__init__()
        self.data_df = df
        self.data_list = data_list
        self.onthefly = onthefly
        self.prot_featurize_fn = prot_featurize_fn
        self.drug_featurize_fn = drug_featurize_fn
        # Fall back to the four default tasks when none (or an empty list) is given.
        self.task_cols = task_cols or ['pKi', 'pEC50', 'pKd', 'pIC50']

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        entry = self.data_list[idx]

        if self.onthefly:
            drug = self.drug_featurize_fn(entry['drug'], name=entry['drug_name'])
            prot = self.prot_featurize_fn(entry['protein'], name=entry['protein_name'])
        else:
            drug, prot = entry['drug'], entry['protein']

        # One value per task; anything pandas considers missing becomes NaN.
        targets = []
        for task in self.task_cols:
            val = entry.get(task, np.nan)
            targets.append(np.nan if pd.isna(val) else val)

        return {
            'drug': drug,
            'protein': prot,
            'y': torch.tensor(targets, dtype=torch.float32),
        }

    



# Dataset builder for DTA class
def build_dataset(df_fold, pdb_structures, exp_cols = "pKi", is_pred = False):
    """
    Turn the rows of `df_fold` into a single-task `DTA` dataset.

    Each row contributes one `{"protein", "drug", "y"}` entry: the protein
    comes from `pdb_structures` (keyed by the ligand SDF file name up to its
    first dot), the drug from the ligand SDF path, and `y` from column
    `exp_cols` -- or the constant 0 when `is_pred` is True.
    """
    data_list = []
    for _, row in df_fold.iterrows():
        ligand_file = row["standardized_ligand_sdf"]
        structure_key = os.path.basename(ligand_file).split(".")[0]

        entry = {
            "protein": featurize_protein_graph(pdb_structures.get(structure_key)),
            "drug": featurize_drug(ligand_file),
        }
        if is_pred == True:
            # Prediction mode: no label available, store a placeholder.
            entry["y"] = 0
        else:
            entry["y"] = float(row[exp_cols])
        data_list.append(entry)

    return DTA(df=df_fold, data_list=data_list)



# 4. MODIFIED BUILD DATASET FUNCTION
def build_mtl_dataset(df_fold, pdb_structures, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Turn the rows of `df_fold` into a multi-task `MTL_DTA` dataset.

    Each entry holds the featurized protein (looked up in `pdb_structures`
    by the ligand SDF file name up to its first dot), the featurized drug,
    and one key per task in `task_cols`. A task whose column is absent from
    the row, or whose value is missing, is stored as NaN.
    """
    data_list = []
    for _, row in df_fold.iterrows():
        ligand_file = row["standardized_ligand_sdf"]
        structure_key = os.path.basename(ligand_file).split(".")[0]

        entry = {
            "protein": featurize_protein_graph(pdb_structures.get(structure_key)),
            "drug": featurize_drug(ligand_file),
        }
        for task in task_cols:
            has_label = task in row and not pd.isna(row[task])
            entry[task] = float(row[task]) if has_label else np.nan

        data_list.append(entry)

    return MTL_DTA(df=df_fold, data_list=data_list, task_cols=task_cols)

# 5. MODIFIED TRAINING LOOP
def train_mtl_model(model, train_loader, valid_loader, task_cols, task_ranges, 
                    n_epochs=100, lr=0.0005, device='cuda', patience=20):
    """
    Training loop for multi-task learning model

    Trains with Adam and `MaskedMSELoss`, evaluates on `valid_loader` after
    every epoch, prints the losses and per-task MSE, and stops early when the
    validation loss has not improved for more than `patience` epochs. The
    weights of the best validation epoch are restored before returning.

    Args:
        model: module called as `model(drug_batch, protein_batch)`.
        train_loader, valid_loader: iterables of batches with keys
            'drug', 'protein' and 'y' ([batch_size, n_tasks], NaN = missing).
        task_cols: List of task column names
        task_ranges: Dict mapping task names to their value ranges
        n_epochs: maximum number of epochs.
        lr: Adam learning rate.
        device: device the batches are moved to.
        patience: early-stopping patience.

    Returns:
        The same `model` object, loaded with the best weights seen.
    """
    import copy

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = MaskedMSELoss(task_ranges=task_ranges)
    stopper = EarlyStopping(patience=patience, higher_better=False)
    best_model = None
    
    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss = 0
        n_batches = 0
        
        for batch in train_loader:
            xd = batch['drug'].to(device)
            xp = batch['protein'].to(device)
            y = batch['y'].to(device)  # [batch_size, n_tasks]
            
            optimizer.zero_grad()
            pred = model(xd, xp)  # [batch_size, n_tasks]
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            n_batches += 1
        
        # Validation
        model.eval()
        val_loss = 0
        val_n_batches = 0
        task_metrics = {task: {'mse': 0, 'n': 0} for task in task_cols}
        
        with torch.no_grad():
            for batch in valid_loader:
                xd = batch['drug'].to(device)
                xp = batch['protein'].to(device)
                y = batch['y'].to(device)
                
                pred = model(xd, xp)
                loss = criterion(pred, y)
                val_loss += loss.item()
                val_n_batches += 1
                
                # Per-task MSE, accumulated as a mean over batches that
                # contain at least one label for the task.
                for i, task in enumerate(task_cols):
                    mask = ~torch.isnan(y[:, i])
                    if mask.sum() > 0:
                        task_mse = F.mse_loss(pred[mask, i], y[mask, i])
                        task_metrics[task]['mse'] += task_mse.item()
                        task_metrics[task]['n'] += 1
        
        # Guard both averages against empty loaders.
        avg_train_loss = train_loss / n_batches if n_batches > 0 else float('nan')
        avg_val_loss = val_loss / val_n_batches if val_n_batches > 0 else float('inf')
        
        # Print metrics
        print(f"Epoch {epoch+1}/{n_epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {avg_val_loss:.4f}")
        
        for task in task_cols:
            if task_metrics[task]['n'] > 0:
                avg_task_mse = task_metrics[task]['mse'] / task_metrics[task]['n']
                print(f"  {task} MSE: {avg_task_mse:.4f}")
        
        # Early stopping
        if stopper.update(avg_val_loss):
            # state_dict() returns references to the live parameters; take a
            # deep copy so later optimizer steps do not overwrite the snapshot.
            best_model = copy.deepcopy(model.state_dict())
        if stopper.early_stop:
            print("Early stopping triggered")
            break
    
    if best_model is not None:
        model.load_state_dict(best_model)
    
    return model

# 6. EXAMPLE USAGE
def prepare_mtl_experiment(df, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Compute the value range (max - min) of every task column for loss weighting.

    A task whose column is missing from `df`, or has no non-null value, gets
    the range 1.0. The ranges and the resulting normalised inverse-range
    weights are printed; only the ranges are returned, as a dict keyed by
    task name in the order of `task_cols`.
    """
    def _inverse_range(r):
        # Non-positive ranges fall back to a raw weight of 1.0.
        return 1.0 / r if r > 0 else 1.0

    task_ranges = {}
    for task in task_cols:
        span = 1.0
        if task in df.columns:
            observed = df[task].dropna()
            if len(observed) > 0:
                span = observed.max() - observed.min()
        task_ranges[task] = span

    print("Task ranges for weighting:")
    for task, range_val in task_ranges.items():
        total = sum(_inverse_range(r) for r in task_ranges.values())
        normalized_weight = _inverse_range(range_val) / total
        print(f"  {task}: range={range_val:.2f}, weight={normalized_weight:.4f}")

    return task_ranges



import json

def structureJSON(df, esm_model):
    """
    Build a structure dictionary for every PDB file listed in `df` and save it as JSON.

    For each path in column "standardized_protein_pdb" the structure is
    parsed with the module-level `parser`, the backbone of its first usable
    chain is extracted, an embedding of the sequence is saved to
    `esm_embeddings/<pdb_id>.pt`, and an entry is added to the result.
    Structures that fail at any step are reported and skipped.

    Parameters
    ----------
    df : pd.DataFrame with column "standardized_protein_pdb".
    esm_model : model forwarded to `get_esm_embedding`.

    Returns
    -------
    dict mapping pdb_id -> {"name", "UniProt_id", "PDB_id", "chain", "seq",
    "coords" ([L, 4, 3] nested lists in N, CA, C, O order), "embed" (path)}.
    The same dict is written to "../data/pockets_structure.json".
    """
    structure_dict = {}

    # Make sure the embedding directory exists before anything is saved into it.
    os.makedirs("esm_embeddings", exist_ok=True)

    for _, row in tqdm(df.iterrows(), total=len(df)):
        pdb_path = row["standardized_protein_pdb"]
        # Fallback label so the error message below never refers to an
        # unbound (or previous iteration's) id if deriving the id itself fails.
        pdb_id = str(pdb_path)
        try:
            pdb_id = os.path.basename(pdb_path).split('.')[0]

            structure = parser.get_structure(pdb_id, pdb_path)
            seq, coords, chain_id = extract_backbone_coords(structure, pdb_id, pdb_path)
            if seq is None:
                available = [c.id for c in structure[0]]
                print(f"[SKIP] {pdb_id}: no usable chain found (available: {available})")
                continue

            # Stack in order: N, CA, C, O --> [L, 4, 3]
            coords_stacked = [
                [coords[atom][idx] for atom in ["N", "CA", "C", "O"]]
                for idx in range(len(coords["N"]))
            ]

            if not coords_stacked:
                print(f"[SKIP] {pdb_id}: no usable coords found (available: {pdb_path})")
                continue

            embedding = get_esm_embedding(seq, esm_model)
            torch.save(embedding, f"esm_embeddings/{pdb_id}.pt")

            if embedding is not None:
                structure_dict[pdb_id] = {
                    "name": pdb_id,
                    "UniProt_id": "UNKNOWN",
                    "PDB_id": pdb_id,
                    "chain": chain_id,
                    "seq": seq,
                    "coords": coords_stacked,
                    "embed": f"esm_embeddings/{pdb_id}.pt"

                }

        except Exception as e:
            print(f"[ERROR] Failed to process {pdb_id}: {e}")
            continue

    # Save to JSON
    os.makedirs("../data", exist_ok=True)
    with open("../data/pockets_structure.json", "w") as f:
        json.dump(structure_dict, f, indent=2)


    print(f"\n✅ Done. Saved {len(structure_dict)} protein structures to pockets_structure.json")

    return(structure_dict)



In [None]:
import os
import json
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm.auto import tqdm
import torch
from Bio.PDB import PDBParser  # make sure Biopython is installed

# Assumes you already have:
# - extract_backbone_coords(structure, pdb_id, pdb_path)
# - get_esm_embedding(seq, esm_model)

# Backbone atom names, in the order they are stacked for each residue.
ATOMS = ("N", "CA", "C", "O")
# Directory receiving one saved embedding file per structure (`<pdb_id>.pt`).
EMBED_DIR = Path("esm_embeddings")
# Created as soon as this cell/module runs so later saves into it do not fail.
EMBED_DIR.mkdir(parents=True, exist_ok=True)

def _stack_backbone(coords):
    """
    Regroup per-atom coordinate lists into per-residue groups.

    `coords` is a dict keyed by the names in `ATOMS`, each value a list of
    [x, y, z]. The result has one item per entry of `coords["N"]`, each item
    listing that residue's coordinates in `ATOMS` order.
    """
    stacked = []
    for idx in range(len(coords["N"])):
        stacked.append([coords[atom][idx] for atom in ATOMS])
    return stacked

def _process_pdb_path(pdb_path):
    """
    Worker: parse PDB, extract seq/coords/chain, return tuple or a skip marker.
    Runs in a separate process; initializes its own parser.

    Returns one of
        ("ok", pdb_id, seq, coords_stacked, chain_id)
        ("skip", pdb_id, reason)
        ("error", pdb_id, message)
    and never raises for a failure inside the processing itself.
    """
    # Fallback label: if deriving the id from the path is what fails (e.g. a
    # non-string path), the error tuple can still be built without raising
    # a second time inside the handler.
    pdb_id = str(pdb_path)
    try:
        pdb_id = os.path.basename(pdb_path).split('.')[0]
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure(pdb_id, pdb_path)

        seq, coords, chain_id = extract_backbone_coords(structure, pdb_id, pdb_path)
        if seq is None:
            available = [c.id for c in structure[0]]
            return ("skip", pdb_id, f"no usable chain (available: {available})")

        if not coords or any(k not in coords for k in ATOMS) or len(coords["N"]) == 0:
            return ("skip", pdb_id, "no usable coords")

        coords_stacked = _stack_backbone(coords)
        if not coords_stacked:
            return ("skip", pdb_id, "empty coords after stacking")

        return ("ok", pdb_id, seq, coords_stacked, chain_id)

    except Exception as e:
        return ("error", pdb_id, str(e))

def structureJSON(df, esm_model, max_workers=None, embed_batch_size=8, out_json="../data/pockets_structure.json"):
    """
    Build the structure dictionary for all PDB files in `df` and save it as JSON.

    Phase 1 parses the PDB files in worker processes; phase 2 computes one
    embedding per successfully parsed sequence in the calling process and
    saves it under `EMBED_DIR`.

    Parameters
    ----------
    df : pd.DataFrame with column "standardized_protein_pdb".
    esm_model : model forwarded to `get_esm_embedding`.
    max_workers : forwarded to `ProcessPoolExecutor`.
    embed_batch_size : currently unused; embeddings are computed one
        sequence at a time.
    out_json : path of the JSON file to write.

    Returns
    -------
    dict mapping pdb_id -> structure entry (same content as the JSON file).
    """
    structure_dict = {}

    pdb_paths = df["standardized_protein_pdb"].tolist()
    results = []

    # Phase 1: parallel PDB parsing + coordinate extraction
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(_process_pdb_path, p): p for p in pdb_paths}
        for fut in tqdm(as_completed(futures), total=len(futures), desc="PDB -> seq/coords"):
            results.append(fut.result())

    # Log skips/errors (fast)
    for r in results:
        tag = r[0]
        if tag == "skip":
            _, pdb_id, msg = r
            print(f"[SKIP] {pdb_id}: {msg}")
        elif tag == "error":
            _, pdb_id, err = r
            print(f"[ERROR] Failed to process {pdb_id}: {err}")

    # Keep only successful items: (pdb_id, seq, coords_stacked, chain_id)
    ok_items = [tuple(r[1:]) for r in results if r[0] == "ok"]

    # Phase 2: embeddings on a single device (GPU/CPU) to avoid per-process model copies
    for pdb_id, seq, coords_stacked, chain_id in tqdm(ok_items, desc="ESM embeddings"):
        try:
            embed_path = EMBED_DIR / f"{pdb_id}.pt"
            embedding = get_esm_embedding(seq, esm_model)  # ensure this returns a tensor
            torch.save(embedding, embed_path)

            structure_dict[pdb_id] = {
                "name": pdb_id,
                "UniProt_id": "UNKNOWN",
                "PDB_id": pdb_id,
                "chain": chain_id,
                "seq": seq,
                "coords": coords_stacked,         # [[N,CA,C,O], ...], each as [x,y,z]
                "embed": str(embed_path)
            }
        except Exception as e:
            print(f"[ERROR] ESM embedding failed for {pdb_id}: {e}")

    # Save to JSON. A bare file name has an empty dirname, and
    # os.makedirs("") raises, so only create the directory when there is one.
    out_dir = os.path.dirname(out_json)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_json, "w") as f:
        json.dump(structure_dict, f, indent=2)

    print(f"\n✅ Done. Saved {len(structure_dict)} protein structures to {os.path.basename(out_json)}")
    return structure_dict





import os
import json
import pandas as pd
import torch
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np

def structureJSON_chunked(df, esm_model, max_workers=None, embed_batch_size=8, 
                          chunk_size=100000, out_dir="../data/structure_chunks/"):
    """
    Process structures in chunks to avoid memory issues.

    Each chunk of PDB paths is parsed in worker processes, embedded in the
    calling process, and written to its own JSON file in `out_dir`. A
    `chunk_metadata.json` describing all chunks is written at the end.
    
    Args:
        df: DataFrame with protein PDB paths (column "standardized_protein_pdb")
        esm_model: ESM model for embeddings
        max_workers: Number of parallel workers
        embed_batch_size: currently unused
        chunk_size: Maximum entries per chunk (default 100000); must be positive
        out_dir: Directory to save chunked JSON files
    
    Returns:
        dict: Metadata about created chunks

    Raises:
        ValueError: if `chunk_size` is not positive.
    """
    # A zero chunk size would divide by zero below and a negative one would
    # yield a negative chunk count with no output; reject both up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(EMBED_DIR, exist_ok=True)
    
    pdb_paths = df["standardized_protein_pdb"].tolist()
    total_pdbs = len(pdb_paths)
    # Ceiling division: the last chunk may be shorter.
    num_chunks = (total_pdbs + chunk_size - 1) // chunk_size
    
    print(f"Processing {total_pdbs} PDBs in {num_chunks} chunks of max {chunk_size} each")
    
    chunk_metadata = {
        "num_chunks": num_chunks,
        "chunk_size": chunk_size,
        "chunks": []
    }
    
    # Process in chunks
    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * chunk_size
        end_idx = min((chunk_idx + 1) * chunk_size, total_pdbs)
        chunk_paths = pdb_paths[start_idx:end_idx]
        
        print(f"\n=== Processing chunk {chunk_idx + 1}/{num_chunks} ({len(chunk_paths)} PDBs) ===")
        
        structure_dict = {}
        results = []
        
        # Phase 1: Parallel PDB parsing for this chunk
        with ProcessPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(_process_pdb_path, p): p for p in chunk_paths}
            for fut in tqdm(as_completed(futures), total=len(futures), 
                          desc=f"Chunk {chunk_idx + 1} - PDB parsing"):
                results.append(fut.result())
        
        # Log errors for this chunk
        for r in results:
            tag = r[0]
            if tag == "skip":
                _, pdb_id, msg = r
                print(f"[SKIP] {pdb_id}: {msg}")
            elif tag == "error":
                _, pdb_id, err = r
                print(f"[ERROR] Failed to process {pdb_id}: {err}")
        
        # Keep only successful items: (pdb_id, seq, coords_stacked, chain_id)
        ok_items = [tuple(r[1:]) for r in results if r[0] == "ok"]
        
        # Phase 2: ESM embeddings for this chunk
        for pdb_id, seq, coords_stacked, chain_id in tqdm(ok_items, 
                                                          desc=f"Chunk {chunk_idx + 1} - ESM embeddings"):
            try:
                embedding = get_esm_embedding(seq, esm_model)
                embed_path = EMBED_DIR / f"{pdb_id}.pt"
                torch.save(embedding, embed_path)
                
                structure_dict[pdb_id] = {
                    "name": pdb_id,
                    "UniProt_id": "UNKNOWN",
                    "PDB_id": pdb_id,
                    "chain": chain_id,
                    "seq": seq,
                    "coords": coords_stacked,
                    "embed": str(embed_path)
                }
            except Exception as e:
                print(f"[ERROR] ESM embedding failed for {pdb_id}: {e}")
        
        # Save this chunk
        chunk_filename = f"structures_chunk_{chunk_idx:04d}.json"
        chunk_path = os.path.join(out_dir, chunk_filename)
        with open(chunk_path, "w") as f:
            json.dump(structure_dict, f, indent=2)
        
        chunk_metadata["chunks"].append({
            "chunk_idx": chunk_idx,
            "filename": chunk_filename,
            "path": chunk_path,
            "num_structures": len(structure_dict),
            "start_idx": start_idx,
            "end_idx": end_idx
        })
        
        print(f"✅ Chunk {chunk_idx + 1} saved: {len(structure_dict)} structures to {chunk_filename}")
    
    # Save metadata
    metadata_path = os.path.join(out_dir, "chunk_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(chunk_metadata, f, indent=2)
    
    print(f"\n✅ All chunks processed. Metadata saved to {metadata_path}")
    return chunk_metadata


class StructureChunkLoader:
    """
    Efficient loader for chunked structure dictionaries.
    Loads chunks on-demand and caches them.

    Chunks are resolved through the `chunk_idx` value recorded for each entry
    of `chunk_metadata.json`, not through the entry's position in the list, so
    lookups stay correct when the listed chunks are reordered or their
    numbering has gaps.
    """
    def __init__(self, chunk_dir="../data/structure_chunks/", cache_size=2):
        """
        Read the metadata file and build the `pdb_id -> chunk_idx` index.

        Args:
            chunk_dir: Directory holding `chunk_metadata.json` and the chunk files
            cache_size: Maximum number of chunks kept in memory at once
        """
        self.chunk_dir = chunk_dir
        self.cache_size = cache_size
        self.cache = {}  # chunk_idx -> structure_dict
        self.cache_order = []  # LRU tracking, least recently used first

        # Load metadata
        metadata_path = os.path.join(chunk_dir, "chunk_metadata.json")
        with open(metadata_path, "r") as f:
            self.metadata = json.load(f)

        # Build lookups: chunk_idx -> chunk_info and pdb_id -> chunk_idx
        self.chunk_info_by_idx = {}
        self.pdb_to_chunk = {}
        for chunk_info in self.metadata["chunks"]:
            chunk_idx = chunk_info["chunk_idx"]
            self.chunk_info_by_idx[chunk_idx] = chunk_info
            chunk_path = os.path.join(chunk_dir, chunk_info["filename"])
            # Quick scan to build index (could be saved in metadata for efficiency)
            with open(chunk_path, "r") as f:
                chunk_data = json.load(f)
            for pdb_id in chunk_data:
                self.pdb_to_chunk[pdb_id] = chunk_idx

    def _load_chunk(self, chunk_idx):
        """Load a chunk into cache, managing cache size."""
        if chunk_idx in self.cache:
            # Move to end (most recently used)
            self.cache_order.remove(chunk_idx)
            self.cache_order.append(chunk_idx)
            return self.cache[chunk_idx]

        # Resolve by recorded chunk_idx; list position is not guaranteed to match it
        chunk_info = self.chunk_info_by_idx[chunk_idx]
        chunk_path = os.path.join(self.chunk_dir, chunk_info["filename"])
        with open(chunk_path, "r") as f:
            chunk_data = json.load(f)

        # Add to cache
        self.cache[chunk_idx] = chunk_data
        self.cache_order.append(chunk_idx)

        # Evict oldest if cache is full
        if len(self.cache) > self.cache_size:
            oldest = self.cache_order.pop(0)
            del self.cache[oldest]

        return chunk_data

    def get(self, pdb_id):
        """Get structure for a specific PDB ID, or None when it is not indexed."""
        if pdb_id not in self.pdb_to_chunk:
            return None

        chunk_data = self._load_chunk(self.pdb_to_chunk[pdb_id])
        return chunk_data.get(pdb_id)

    def get_batch(self, pdb_ids):
        """
        Get multiple structures efficiently by grouping by chunk.

        Returns a dict `pdb_id -> structure`; IDs that are not indexed are
        left out of the result.
        """
        # Group PDB IDs by chunk so every chunk is loaded at most once per call
        chunk_groups = {}
        for pdb_id in pdb_ids:
            if pdb_id in self.pdb_to_chunk:
                chunk_groups.setdefault(self.pdb_to_chunk[pdb_id], []).append(pdb_id)

        # Load each chunk and extract structures
        results = {}
        for chunk_idx, chunk_pdb_ids in chunk_groups.items():
            chunk_data = self._load_chunk(chunk_idx)
            for pdb_id in chunk_pdb_ids:
                if pdb_id in chunk_data:
                    results[pdb_id] = chunk_data[pdb_id]

        return results

    def get_available_pdb_ids(self):
        """Return set of all available PDB IDs."""
        return set(self.pdb_to_chunk)


def build_mtl_dataset_optimized(df_fold, chunk_loader, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build MTL dataset efficiently using chunked structure loader.

    Rows whose `protein_id` has no structure in the loader are skipped. When
    that happens the dataframe handed to `MTL_DTA` is restricted to the kept
    rows, so `df` and `data_list` stay aligned row-for-row.

    Args:
        df_fold: DataFrame with fold data; reads the `protein_id` and
            `standardized_ligand_sdf` columns plus any of `task_cols` present
        chunk_loader: StructureChunkLoader instance
        task_cols: List of task columns (the default list is never mutated here)

    Returns:
        MTL_DTA dataset
    """
    # Get all protein IDs from the fold
    protein_ids = df_fold["protein_id"].tolist()

    # Batch load structures (efficient chunk-based loading)
    print(f"Loading structures for {len(protein_ids)} proteins...")
    structures_batch = chunk_loader.get_batch(protein_ids)

    data_list = []
    kept_positions = []  # positional row numbers of df_fold that made it into data_list
    rows = tqdm(df_fold.iterrows(), total=len(df_fold), desc="Building dataset")
    for position, (_, row) in enumerate(rows):
        protein_id = row["protein_id"]
        if protein_id not in structures_batch:
            continue

        protein = featurize_protein_graph(structures_batch[protein_id])
        drug = featurize_drug(row["standardized_ligand_sdf"])

        data_entry = {
            "protein": protein,
            "drug": drug,
        }
        # Missing or NaN task values are stored as NaN
        for task in task_cols:
            if task in row and not pd.isna(row[task]):
                data_entry[task] = float(row[task])
            else:
                data_entry[task] = np.nan
        data_list.append(data_entry)
        kept_positions.append(position)

    skipped = len(df_fold) - len(kept_positions)
    if skipped > 0:
        print(f"Warning: Skipped {skipped} entries due to missing structures")
        # Positional selection is used because the index may contain duplicate labels
        df_fold = df_fold.iloc[kept_positions]

    return MTL_DTA(df=df_fold, data_list=data_list, task_cols=task_cols)

# 4. MODIFIED BUILD DATASET FUNCTION
def build_mtl_dataset(df_fold, pdb_structures, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build an MTL dataset from an in-memory structure dictionary.

    For every row, the structure key is the basename of the row's
    `standardized_ligand_sdf` path up to the first dot; the value returned by
    `pdb_structures.get(key)` is passed to `featurize_protein_graph` as-is.

    Args:
        df_fold: DataFrame with fold data
        pdb_structures: Mapping from structure key to structure entry
        task_cols: List of task columns

    Returns:
        MTL_DTA dataset
    """
    def _task_value(row, task):
        # A task that is absent from the row or NaN is recorded as NaN
        has_value = task in row and not pd.isna(row[task])
        return float(row[task]) if has_value else np.nan

    entries = []
    for _, row in df_fold.iterrows():
        structure_key = os.path.basename(row["standardized_ligand_sdf"]).split(".")[0]
        entry = {
            "protein": featurize_protein_graph(pdb_structures.get(structure_key)),
            "drug": featurize_drug(row["standardized_ligand_sdf"]),
        }
        entry.update({task: _task_value(row, task) for task in task_cols})
        entries.append(entry)

    return MTL_DTA(df=df_fold, data_list=entries, task_cols=task_cols)

# ============= USAGE EXAMPLE =============


In [None]:
import os
import json
import pandas as pd
import torch
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
from multiprocessing import Pool, cpu_count
import gc

def process_single_chunk(args):
    """
    Process a single chunk of PDB files independently.
    This function is designed to be run in parallel.

    Parses every path with `_process_pdb_path` in a process pool, computes an
    ESM embedding for each successfully parsed sequence, saves each embedding
    to `<embed_dir>/<pdb_id>.pt`, and writes the chunk's structure dictionary
    to `<out_dir>/structures_chunk_<chunk_idx>.json`.

    Args:
        args: tuple of (chunk_idx, pdb_paths, out_dir, embed_dir, esm_model_name)

    Returns:
        dict with chunk processing results: `chunk_idx`, `filename`, `path`,
        `num_structures`, `num_errors`, `num_skipped`
    """
    chunk_idx, pdb_paths, out_dir, embed_dir, esm_model_name = args

    # Import inside function for multiprocessing
    from transformers import EsmModel, EsmTokenizer
    from concurrent.futures import ProcessPoolExecutor, as_completed
    from tqdm import tqdm
    import torch
    import json
    import os

    print(f"\n[Chunk {chunk_idx}] Starting processing of {len(pdb_paths)} PDBs")

    # Load ESM model for this process
    print(f"[Chunk {chunk_idx}] Loading ESM model...")
    tokenizer = EsmTokenizer.from_pretrained(esm_model_name)
    esm_model = EsmModel.from_pretrained(esm_model_name)
    esm_model.eval()

    # Move to GPU if available (each process gets its own GPU memory)
    if torch.cuda.is_available():
        # For multi-GPU, assign different chunks to different GPUs
        gpu_id = chunk_idx % torch.cuda.device_count()
        device = torch.device(f'cuda:{gpu_id}')
        esm_model = esm_model.to(device)
        print(f"[Chunk {chunk_idx}] Using GPU {gpu_id}")
    else:
        device = torch.device('cpu')
        print(f"[Chunk {chunk_idx}] Using CPU")

    structure_dict = {}
    results = []

    # Phase 1: Parallel PDB parsing within this chunk.
    # Limit workers per chunk, but never below 1: cpu_count() // 4 is 0 on
    # machines with fewer than 4 cores and ProcessPoolExecutor rejects 0 workers.
    max_workers = max(1, min(8, cpu_count() // 4))
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(_process_pdb_path, p): p for p in pdb_paths}
        for fut in tqdm(as_completed(futures), total=len(futures),
                        desc=f"Chunk {chunk_idx} - PDB parsing", position=chunk_idx):
            try:
                # as_completed only yields finished futures, so this timeout is
                # a safeguard that is not expected to trigger.
                results.append(fut.result(timeout=30))
            except Exception as e:
                print(f"[Chunk {chunk_idx}] Error processing PDB: {e}")

    # Log errors
    skip_count = sum(1 for r in results if r[0] == "skip")
    error_count = sum(1 for r in results if r[0] == "error")
    if error_count > 0 or skip_count > 0:
        print(f"[Chunk {chunk_idx}] Skipped: {skip_count}, Errors: {error_count}")

    # Keep only successful items as (pdb_id, seq, coords_stacked, chain_id)
    ok_items = [tuple(r[1:]) for r in results if r[0] == "ok"]

    # Phase 2: ESM embeddings
    print(f"[Chunk {chunk_idx}] Computing ESM embeddings for {len(ok_items)} proteins...")

    os.makedirs(embed_dir, exist_ok=True)

    # Items are walked in groups of `batch_size`; each sequence is still
    # embedded individually, the grouping only paces the progress bar and GC.
    batch_size = 8
    for i in tqdm(range(0, len(ok_items), batch_size),
                  desc=f"Chunk {chunk_idx} - ESM embeddings", position=chunk_idx):
        batch = ok_items[i:i+batch_size]

        for pdb_id, seq, coords_stacked, chain_id in batch:
            try:
                # Compute embedding
                with torch.no_grad():
                    embedding = get_esm_embedding(seq, esm_model, tokenizer, device)

                # Save embedding
                embed_path = os.path.join(embed_dir, f"{pdb_id}.pt")
                torch.save(embedding.cpu(), embed_path)  # Save on CPU to free GPU memory

                structure_dict[pdb_id] = {
                    "name": pdb_id,
                    "UniProt_id": "UNKNOWN",
                    "PDB_id": pdb_id,
                    "chain": chain_id,
                    "seq": seq,
                    "coords": coords_stacked,
                    "embed": embed_path
                }
            except Exception as e:
                print(f"[Chunk {chunk_idx}] ESM embedding failed for {pdb_id}: {e}")

        # Periodic garbage collection
        if i % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Save chunk
    chunk_filename = f"structures_chunk_{chunk_idx:04d}.json"
    chunk_path = os.path.join(out_dir, chunk_filename)
    with open(chunk_path, "w") as f:
        json.dump(structure_dict, f, indent=2)

    print(f"[Chunk {chunk_idx}] ✅ Completed: {len(structure_dict)} structures saved")

    # Clean up GPU memory
    del esm_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "chunk_idx": chunk_idx,
        "filename": chunk_filename,
        "path": chunk_path,
        "num_structures": len(structure_dict),
        "num_errors": error_count,
        "num_skipped": skip_count
    }


def get_esm_embedding(seq, esm_model, tokenizer, device):
    """
    Embed one sequence with ESM and mean-pool the final hidden states.

    The sequence is tokenized with `truncation=True, max_length=1024`, the
    resulting tensors are moved to `device`, and the model's
    `last_hidden_state` is averaged over dimension 1 without tracking gradients.

    Args:
        seq: Protein sequence
        esm_model: ESM model
        tokenizer: ESM tokenizer
        device: torch device

    Returns:
        torch.Tensor: Embedding
    """
    encoded = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024)

    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        hidden_states = esm_model(**model_inputs).last_hidden_state
        # Use mean pooling over sequence length
        pooled = hidden_states.mean(dim=1)

    return pooled


def structureJSON_chunked(df, esm_model_name="facebook/esm2_t33_650M_UR50D",
                         chunk_size=100000, max_chunks_parallel=4,
                         out_dir="../data/structure_chunks/",
                         embed_dir="../data/embeddings/"):
    """
    Process structures in parallel chunks to avoid memory issues and maximize speed.

    The unique values of `df["standardized_protein_pdb"]` are split into
    chunks of at most `chunk_size` paths; each chunk is handled by
    `process_single_chunk`, and a `chunk_metadata.json` summary is written
    to `out_dir`.

    Args:
        df: DataFrame with protein PDB paths
        esm_model_name: Name of ESM model to use
        chunk_size: Maximum entries per chunk (default 100000)
        max_chunks_parallel: Maximum number of chunks to process in parallel
        out_dir: Directory to save chunked JSON files
        embed_dir: Directory to save embeddings

    Returns:
        dict: Metadata about created chunks
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(embed_dir, exist_ok=True)

    # Get unique PDB paths (avoid duplicates)
    pdb_paths = df["standardized_protein_pdb"].unique().tolist()
    total_pdbs = len(pdb_paths)
    num_chunks = (total_pdbs + chunk_size - 1) // chunk_size

    print("=" * 80)
    print(f"Processing {total_pdbs} unique PDBs in {num_chunks} chunks")
    print(f"Chunk size: {chunk_size}, Parallel chunks: {max_chunks_parallel}")
    print("=" * 80)

    # Prepare chunk arguments
    chunk_args = [
        (
            chunk_idx,
            pdb_paths[chunk_idx * chunk_size:min((chunk_idx + 1) * chunk_size, total_pdbs)],
            out_dir,
            embed_dir,
            esm_model_name,
        )
        for chunk_idx in range(num_chunks)
    ]

    # Process chunks in parallel
    chunk_results = []

    # Determine number of parallel processes. The result is clamped to at
    # least 1: an empty input or a machine with fewer than 4 cores would
    # otherwise yield 0, which is an invalid step for the batching loop below.
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if num_gpus > 0:
        # If we have GPUs, process one chunk per GPU
        parallel_chunks = max(1, min(max_chunks_parallel, num_gpus, num_chunks))
        print(f"Using {num_gpus} GPUs, processing {parallel_chunks} chunks in parallel")
    else:
        # CPU only - limit parallelism to avoid memory issues
        parallel_chunks = max(1, min(max_chunks_parallel, cpu_count() // 4, num_chunks))
        print(f"Using CPU only, processing {parallel_chunks} chunks in parallel")

    # Process in batches of parallel chunks
    for batch_start in range(0, num_chunks, parallel_chunks):
        batch_end = min(batch_start + parallel_chunks, num_chunks)
        batch_args = chunk_args[batch_start:batch_end]

        print(f"\nProcessing chunk batch {batch_start+1}-{batch_end} of {num_chunks}")

        if len(batch_args) == 1:
            # Single chunk - process directly
            chunk_results.append(process_single_chunk(batch_args[0]))
        else:
            # Multiple chunks - one worker process per chunk.
            # ProcessPoolExecutor is used instead of multiprocessing.Pool
            # because Pool workers are daemonic and may not start the child
            # processes that process_single_chunk needs for PDB parsing.
            with ProcessPoolExecutor(max_workers=len(batch_args)) as executor:
                chunk_results.extend(executor.map(process_single_chunk, batch_args))

        # Garbage collection between batches
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Create metadata
    chunk_metadata = {
        "num_chunks": num_chunks,
        "chunk_size": chunk_size,
        "total_structures": sum(r["num_structures"] for r in chunk_results),
        "total_errors": sum(r["num_errors"] for r in chunk_results),
        "total_skipped": sum(r["num_skipped"] for r in chunk_results),
        "chunks": []
    }

    # Add chunk info; start_idx/end_idx are running offsets over the
    # structures actually saved, in chunk_idx order
    start_idx = 0
    for result in sorted(chunk_results, key=lambda x: x["chunk_idx"]):
        end_idx = start_idx + result["num_structures"]
        chunk_metadata["chunks"].append({
            "chunk_idx": result["chunk_idx"],
            "filename": result["filename"],
            "path": result["path"],
            "num_structures": result["num_structures"],
            "num_errors": result["num_errors"],
            "num_skipped": result["num_skipped"],
            "start_idx": start_idx,
            "end_idx": end_idx
        })
        start_idx = end_idx

    # Save metadata
    metadata_path = os.path.join(out_dir, "chunk_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(chunk_metadata, f, indent=2)

    print(f"\n{'=' * 80}")
    print(f"✅ Processing complete!")
    print(f"  - Total structures: {chunk_metadata['total_structures']}")
    print(f"  - Total errors: {chunk_metadata['total_errors']}")
    print(f"  - Total skipped: {chunk_metadata['total_skipped']}")
    print(f"  - Metadata saved: {metadata_path}")
    print(f"{'=' * 80}")

    return chunk_metadata


# ============= Helper function for PDB processing =============
def _process_pdb_path(pdb_path):
    """
    Process a single PDB file to extract sequence and coordinates.
    This function runs in a separate process.
    
    Returns:
        tuple: (status, pdb_id, data...) where status is "ok", "skip", or "error"
    """
    from Bio.PDB import PDBParser, is_aa
    
    parser = PDBParser(QUIET=True)
    pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
    
    try:
        structure = parser.get_structure(pdb_id, pdb_path)
        
        # Get first model
        model = structure[0]
        
        # Process each chain
        for chain in model:
            residues = [r for r in chain if is_aa(r)]
            if len(residues) == 0:
                continue
            
            # Extract sequence
            seq = ''.join([seq1(r.resname) for r in residues])
            
            # Extract coordinates [N, CA, C, O] for each residue
            coords = []
            for residue in residues:
                try:
                    n_coord = residue['N'].coord.tolist()
                    ca_coord = residue['CA'].coord.tolist()
                    c_coord = residue['C'].coord.tolist()
                    o_coord = residue['O'].coord.tolist()
                    coords.append([n_coord, ca_coord, c_coord, o_coord])
                except:
                    # Missing atoms - use zeros
                    coords.append([[0,0,0], [0,0,0], [0,0,0], [0,0,0]])
            
            return ("ok", pdb_id, seq, coords, chain.id)
        
        return ("skip", pdb_id, "No valid chains found")
        
    except Exception as e:
        return ("error", pdb_id, str(e))


# ============= Optimized Chunk Loader (same as before) =============
class StructureChunkLoader:
    """
    Efficient loader for chunked structure dictionaries.
    Loads chunks on-demand and caches them.

    Chunks are resolved through the `chunk_idx` value recorded for each entry
    of `chunk_metadata.json`, not through the entry's position in the list, so
    lookups stay correct when the listed chunks are reordered or their
    numbering has gaps.
    """
    def __init__(self, chunk_dir="../data/structure_chunks/", cache_size=2):
        """
        Read the metadata file and build the `pdb_id -> chunk_idx` index.

        Chunk files listed in the metadata but missing on disk are reported
        with a warning and contribute no PDB IDs to the index.

        Args:
            chunk_dir: Directory holding `chunk_metadata.json` and the chunk files
            cache_size: Maximum number of chunks kept in memory at once
        """
        self.chunk_dir = chunk_dir
        self.cache_size = cache_size
        self.cache = {}  # chunk_idx -> structure_dict
        self.cache_order = []  # LRU tracking, least recently used first

        # Load metadata
        metadata_path = os.path.join(chunk_dir, "chunk_metadata.json")
        with open(metadata_path, "r") as f:
            self.metadata = json.load(f)

        print(f"Loaded metadata: {self.metadata['total_structures']} structures in {self.metadata['num_chunks']} chunks")

        # Build lookups: chunk_idx -> chunk_info and pdb_id -> chunk_idx
        self.chunk_info_by_idx = {}
        self.pdb_to_chunk = {}
        for chunk_info in self.metadata["chunks"]:
            chunk_idx = chunk_info["chunk_idx"]
            self.chunk_info_by_idx[chunk_idx] = chunk_info
            chunk_path = os.path.join(chunk_dir, chunk_info["filename"])
            if not os.path.exists(chunk_path):
                print(f"Warning: Chunk file not found: {chunk_path}")
                continue
            with open(chunk_path, "r") as f:
                chunk_data = json.load(f)
            for pdb_id in chunk_data:
                self.pdb_to_chunk[pdb_id] = chunk_idx

    def _load_chunk(self, chunk_idx):
        """Load a chunk into cache, managing cache size."""
        if chunk_idx in self.cache:
            # Move to end (most recently used)
            self.cache_order.remove(chunk_idx)
            self.cache_order.append(chunk_idx)
            return self.cache[chunk_idx]

        # Resolve by recorded chunk_idx; list position is not guaranteed to match it
        chunk_info = self.chunk_info_by_idx[chunk_idx]
        chunk_path = os.path.join(self.chunk_dir, chunk_info["filename"])
        with open(chunk_path, "r") as f:
            chunk_data = json.load(f)

        # Add to cache
        self.cache[chunk_idx] = chunk_data
        self.cache_order.append(chunk_idx)

        # Evict oldest if cache is full
        if len(self.cache) > self.cache_size:
            oldest = self.cache_order.pop(0)
            del self.cache[oldest]
            gc.collect()  # Force garbage collection

        return chunk_data

    def get(self, pdb_id):
        """Get structure for a specific PDB ID, or None when it is not indexed."""
        if pdb_id not in self.pdb_to_chunk:
            return None

        chunk_data = self._load_chunk(self.pdb_to_chunk[pdb_id])
        return chunk_data.get(pdb_id)

    def get_batch(self, pdb_ids):
        """
        Get multiple structures efficiently by grouping by chunk.

        Returns a dict `pdb_id -> structure`; IDs that are not indexed are
        left out of the result.
        """
        # Group PDB IDs by chunk so every chunk is loaded at most once per call
        chunk_groups = {}
        for pdb_id in pdb_ids:
            if pdb_id in self.pdb_to_chunk:
                chunk_groups.setdefault(self.pdb_to_chunk[pdb_id], []).append(pdb_id)

        # Load each chunk and extract structures
        results = {}
        for chunk_idx, chunk_pdb_ids in chunk_groups.items():
            chunk_data = self._load_chunk(chunk_idx)
            for pdb_id in chunk_pdb_ids:
                if pdb_id in chunk_data:
                    results[pdb_id] = chunk_data[pdb_id]

        return results

    def get_available_pdb_ids(self):
        """Return set of all available PDB IDs."""
        return set(self.pdb_to_chunk)


def build_mtl_dataset_optimized(df_fold, chunk_loader, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build MTL dataset efficiently using chunked structure loader.

    Rows whose `protein_id` has no structure in the loader are skipped. When
    that happens the dataframe handed to `MTL_DTA` is restricted to the kept
    rows, so `df` and `data_list` stay aligned row-for-row.

    Args:
        df_fold: DataFrame with fold data; reads the `protein_id` and
            `standardized_ligand_sdf` columns plus any of `task_cols` present
        chunk_loader: StructureChunkLoader instance
        task_cols: List of task columns (the default list is never mutated here)

    Returns:
        MTL_DTA dataset
    """
    # Get all protein IDs from the fold
    protein_ids = df_fold["protein_id"].tolist()

    # Batch load structures (efficient chunk-based loading)
    print(f"Loading structures for {len(protein_ids)} proteins...")
    structures_batch = chunk_loader.get_batch(protein_ids)

    data_list = []
    kept_positions = []  # positional row numbers of df_fold that made it into data_list
    rows = tqdm(df_fold.iterrows(), total=len(df_fold), desc="Building dataset")
    for position, (_, row) in enumerate(rows):
        protein_id = row["protein_id"]
        if protein_id not in structures_batch:
            continue

        protein = featurize_protein_graph(structures_batch[protein_id])
        drug = featurize_drug(row["standardized_ligand_sdf"])

        data_entry = {
            "protein": protein,
            "drug": drug,
        }
        # Missing or NaN task values are stored as NaN
        for task in task_cols:
            if task in row and not pd.isna(row[task]):
                data_entry[task] = float(row[task])
            else:
                data_entry[task] = np.nan
        data_list.append(data_entry)
        kept_positions.append(position)

    skipped = len(df_fold) - len(kept_positions)
    if skipped > 0:
        print(f"Warning: Skipped {skipped} entries due to missing structures")
        # Positional selection is used because the index may contain duplicate labels
        df_fold = df_fold.iloc[kept_positions]

    return MTL_DTA(df=df_fold, data_list=data_list, task_cols=task_cols)


# ============= USAGE EXAMPLE =============


# Input

In [None]:
# Notebook setup: imports, compute device, PDB parser, and RDKit log silencing.
# NOTE(review): `os`, `pandas`, and `torch` are imported twice and `device` is
# assigned twice with the same expression; the repeats are redundant but harmless.
import pandas as pd
import os
import json
from Bio.PDB import PDBParser, is_aa
from tqdm import tqdm
from transformers import EsmModel, EsmTokenizer
import torch
# Use CUDA when torch reports it as available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module-level PDB parser created with QUIET=True
parser = PDBParser(QUIET=True)
import os
import pandas as pd
import torch
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import torch.nn.functional as F
from rdkit import RDLogger
# Disable every RDKit logger matching 'rdApp.*'
RDLogger.DisableLog('rdApp.*')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the tasks to train on
task_cols = ['pKi', 'pEC50', 'pKd', 'pKd (Wang, FEP)', 'pIC50', 'potency']


# Load the dataframe and shuffle it with a fixed seed
df = (
    pd.read_parquet("../data/standardized/standardized_input.parquet", engine="fastparquet")
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

col_nan = ["standardized_protein_pdb", "standardized_ligand_sdf"] + task_cols
# df = df[df['is_experimental'] == True]
# Drop rows where every path/task column is missing, then renumber the index
df = df.dropna(how="all", subset=col_nan).reset_index(drop=True)
# Keep only rows that have both a protein PDB path and a ligand SDF path
has_both_paths = df["standardized_protein_pdb"].notna() & df["standardized_ligand_sdf"].notna()
df = df[has_both_paths]
# Cap the working set at the first 50000 rows
df = df.iloc[:50000]

# Calculate task ranges from the dataframe
task_ranges = prepare_mtl_experiment(df, task_cols)


# Load ESM2

In [None]:
# Load the ESM-2 tokenizer and model
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
esm_model = EsmModel.from_pretrained(model_name)
# Put the model in eval mode on the notebook's `device` rather than calling
# .cuda() unconditionally, which raises on machines without a CUDA device.
esm_model = esm_model.eval().to(device)


# Generate structure dict

In [None]:
# Plan: split the PDBs into chunks, then compute the ESM embeddings in 4 chunks, then process each chunk in parallel and write out the structure dict each time.

In [None]:
# Run the chunked structure-processing pipeline on `df` and time it.
import multiprocessing as mp
# Must run before any worker process is created; force=True overrides a start
# method that was already set earlier in this session.
mp.set_start_method('spawn', force=True)  # Must be at the very beginning

from parallel_structure_processing_optimized import structureJSON_chunked_optimized


import time
start_time = time.time()

# NOTE(review): num_gpus and cpu_workers are hard-coded for one machine -
# confirm they match the host before running.
structure_metadata = structureJSON_chunked_optimized(
    df,
    num_gpus=1,
    cpu_workers=90
)

# Wall-clock duration of the call above
print("--- %s seconds ---" % (time.time() - start_time))


In [None]:


(0.23*64)  # NOTE(review): scratch arithmetic; the result is not assigned - presumably a leftover, confirm before removing

In [None]:
1  # NOTE(review): bare literal with no assignment - presumably a leftover scratch cell

In [None]:
(0.23-5.63)/5.63  # NOTE(review): scratch relative-change arithmetic; the result is not assigned - presumably a leftover

In [None]:
# Summarize the processing run using the metadata dict returned by the pipeline
print(f"\nProcessing complete!")
print(f"Total structures processed: {structure_metadata['total_structures']}")
print(f"Total chunks created: {structure_metadata['num_chunks']}")

# Optional: Load and use the chunks later
# NOTE(review): this import rebinds `StructureChunkLoader`, replacing any class
# of the same name defined earlier in this notebook session.
from parallel_structure_processing_optimized import StructureChunkLoader

# Create a chunk loader (cache_size=2 is passed as the in-memory chunk limit)
chunk_loader = StructureChunkLoader(
    chunk_dir="../data/structure_chunks/",
    cache_size=2
)


# Check validity

In [None]:
# Validity check: load the first chunk file directly into `pdb_structures`
import json

# The path is hard-coded to chunk 0000 under ../data/structure_chunks/
with open("../data/structure_chunks/structures_chunk_0000.json", "r") as f:
    pdb_structures = json.load(f)

In [None]:
1  # NOTE(review): bare literal with no assignment - presumably a leftover scratch cell