In [1]:
# conda env: datacat (python=3.8.2)
# for `utils.py`
# ==== all.py ====
from pathlib import Path
import os
import numpy as np

# ==== metrics.py ====
from sklearn import metrics

# ==== dataloader.py ====
import torch
from torch.utils.data import Dataset
from typing import Any, Iterable, List, Optional, Tuple, Union
import pandas as pd
from scipy import sparse
try: # only if Graph-Model is used
    import dgl
except: pass

# ==== utils.py ====
import json
from loguru import logger
import mlflow
import mlflow.entities
from datacat4ml.Scripts.model_dev.model_def import DotProduct
import datacat4ml.Scripts.model_dev.model_def as model_def
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiplicativeLR
from torch.utils.data import Subset, RandomSampler, SequentialSampler, BatchSampler
from scipy.special import expit as sigmoid
import scipy 

# ==== train.py ====
import argparse
#import mlflow
import random
import wandb
from time import time

from datacat4ml.const import DATA_DIR, SCRIPTS_DIR

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
wandb.login()
print(wandb.api.default_entity)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxixichennn[0m ([33mxixichennn-freie-universit-t-berlin[0m). Use [1m`wandb login --relogin`[0m to force relogin


xixichennn-freie-universit-t-berlin


# models.py

`DotProduct` - The difference between `forward` and `forward_dense`
- forward: `preactivations = (compound_embeddings * assay_embeddings).sum(axis=1)`
- forward_dense: `preactivations = (compound_embeddings @ assay_embeddings.T)`

In [None]:
# Suppose Embedding size(D) = 2, number of compounds(N) = 3, number of assays(M) = 2
compound_embeddings = torch.tensor([[1, 2], [3, 4], [5, 6]])
print("compound_embeddings:\n", compound_embeddings)
print("compound_embeddings.shape:", compound_embeddings.shape) 
print("compound_embeddings.shape[0]:", compound_embeddings.shape[0])  
print("compound_embeddings.shape[1]:", compound_embeddings.shape[1], "\n")

assay_embeddings = torch.tensor([[10, 20], [30, 40], [50, 60]])
print("assay_embeddings:\n", assay_embeddings)
print("assay_embeddings.shape:", assay_embeddings.shape)  
print("assay_embeddings.shape[0]:", assay_embeddings.shape[0])
print("assay_embeddings.shape[1]:", assay_embeddings.shape[1], "\n")

forward = (compound_embeddings * assay_embeddings).sum(axis=1) 
#  [1*10 + 2*20, 3*30 + 4*40, 5*50 + 6*60] = （C1*A1, C2*A2, C3*A3）
print("compound_embeddings * assay_embeddings:\n", compound_embeddings * assay_embeddings)
print("(compound_embeddings * assay_embeddings).sum(axis=0):\n", (compound_embeddings * assay_embeddings).sum(axis=0))
print("forward:", forward, "\n")

forward_dense = (compound_embeddings @ assay_embeddings.T) 
# [[1*10 +2*20, 1*30 + 2*40, 1*50 + 2*60],
#  [3*10 + 4*20, 3*30 + 4*40, 3*50 + 4*60],
#  [5*10 + 6*20, 5*30 + 6*40, 5*50 + 6*60]]
# = 
# [[C1A1, C1A2, C1A3],
#  [C2A1, C2A2, C2A3],
#  [C3A1, C3A2, C3A3]]
print("compound_embeddings @ assay_embeddings.T:\n", forward_dense)
print('assay_embeddings.T:\n', assay_embeddings.T)
print("forward_dense:\n", forward_dense)
print("forward_dense.diag():\n", forward_dense.diagonal())

compound_embeddings:
 tensor([[1, 2],
        [3, 4],
        [5, 6]])
compound_embeddings.shape: torch.Size([3, 2])
compound_embeddings.shape[0]: 3
compound_embeddings.shape[1]: 2 

assay_embeddings:
 tensor([[10, 20],
        [30, 40],
        [50, 60]])
assay_embeddings.shape: torch.Size([3, 2])
assay_embeddings.shape[0]: 3
assay_embeddings.shape[1]: 2 

compound_embeddings * assay_embeddings:
 tensor([[ 10,  40],
        [ 90, 160],
        [250, 360]])
(compound_embeddings * assay_embeddings).sum(axis=0):
 tensor([350, 560])
Forward: tensor([ 50, 250, 610]) 

compound_embeddings @ assay_embeddings.T:
 tensor([[ 50, 110, 170],
        [110, 250, 390],
        [170, 390, 610]])
assay_embeddings.T:
 tensor([[10, 30, 50],
        [20, 40, 60]])
forward_dense:
 tensor([[ 50, 110, 170],
        [110, 250, 390],
        [170, 390, 610]])
forward_dense.diag():
 tensor([ 50, 250, 610])


In [6]:
# given assay_embedding.size = (3, 512), compound_embedding.size = (3, 512); i.e. 3 compounds and 3 assays, each with 512-dimensional embeddings
# initialize 'c_embeds' and 'a_embeds' with random integers 
#c_embeds = torch.randi((3, 512))

c_embeds = torch.rand((3, 512))
a_embeds = torch.rand((3, 512))
print("c_embeds.shape:", c_embeds.shape)
print("a_embeds.shape:", a_embeds.shape, "\n")

fwd = (c_embeds * a_embeds).sum(axis=1)
print("fwd:", fwd)
print("fwd.shape:", fwd.shape, "\n")

fwd_dense = (c_embeds @ a_embeds.T)
print("fwd_dense:\n", fwd_dense)
print("fwd_dense.shape:", fwd_dense.shape)
print("torch.diag(fwd_dense):", torch.diag(fwd_dense))

print(f'The length of c_embeds is len {len(c_embeds)}')
print(f'c_embeds.shape[0] is {c_embeds.shape[0]}')

c_embeds.shape: torch.Size([3, 512])
a_embeds.shape: torch.Size([3, 512]) 

fwd: tensor([129.9870, 128.8514, 130.4847])
fwd.shape: torch.Size([3]) 

fwd_dense:
 tensor([[129.9870, 130.7846, 136.4513],
        [125.9640, 128.8514, 129.9631],
        [124.3015, 125.3976, 130.4847]])
fwd_dense.shape: torch.Size([3, 3])
torch.diag(fwd_dense): tensor([129.9870, 128.8514, 130.4847])
The length of c_embeds is len 3
c_embeds.shape[0] is 3


In [7]:
# ===================== CustomCE ======================
input = fwd_dense
print(f'input before scaling:\n {input}\n')
# class customCE(nn.CrossEntropyLoss)
# def forward
#beta
beta = 1/(input.shape[0]**(1/2)) # = 1/√3 = 0.5773502691896257
print(f'beta is {beta}\n')

target = torch.tensor([1, 1, 0])
print(f'target before scaling:\n {target}\n')
print(f'target*2: {target*2}')
print(f'target*2 - 1: {target*2 - 1}\n')
# input
input = input*(target*2 - 1)*beta
print(f'input after scaling:\n {input} \n')

#target
target = torch.arange(0,len(input)).to(input.device)
print(f'target after scaling: {target}')

class CustomCE(nn.CrossEntropyLoss):
    """ Cross entropy loss """
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        beta = 1/(input.shape[0]**(1/2))
        input = input*(target*2-1)*beta # target from [0,1] to [-1,1]
        target = torch.arange(0,len(input)).to(input.device)
        
        return F.cross_entropy(input, target, weight=self.weight,
                        ignore_index=self.ignore_index, reduction=self.reduction)

criterion = CustomCE()
loss = criterion(input, target)

print(f'loss calculated by CustomCE is {loss}')

input before scaling:
 tensor([[129.9870, 130.7846, 136.4513],
        [125.9640, 128.8514, 129.9631],
        [124.3015, 125.3976, 130.4847]])

beta is 0.5773502691896258

target before scaling:
 tensor([1, 1, 0])

target*2: tensor([2, 2, 0])
target*2 - 1: tensor([ 1,  1, -1])

input after scaling:
 tensor([[ 75.0480,  75.5085, -78.7802],
        [ 72.7253,  74.3924, -75.0342],
        [ 71.7655,  72.3983, -75.3354]]) 

target after scaling: tensor([0, 1, 2])
loss calculated by CustomCE is 86.402587890625


In [18]:
# ===================== Contrastive Loss ======================
# params
input = fwd_dense
print(f'input before scaling:\n {input}\n')
target = torch.tensor([0, 1, 0])
print(f'target before scaling:\n {target}\n')

# def forward
sigma = 1
bs = target.shape[0] # batch size.
print(f'bs is {bs}')

# modif
modif = (1-torch.eye(bs)).to(target.device) + (torch.eye(bs).to(target.device)*(target*2-1))
print(f'torch.eye(bs) is \n{torch.eye(bs)}')
print(f'(1-torch.eye(bs)) is \n{(1-torch.eye(bs))}')
print(f'(target*2-1) is {(target*2-1)}')
print(f'torch.eye(bs)*(target*2-1) is \n{torch.eye(bs)*(target*2-1)}')
print(f'modif:\n{modif}\n')

#input
input = input*modif/sigma
print(f'input after scaling:\n {input}\n')

diag_idx = torch.arange(0,len(input)).to(input.device)
print(f'diag_idx: {diag_idx}\n')

class ConLoss(nn.CrossEntropyLoss):
    """"Contrastive Loss"""
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        sigma = 1
        bs = target.shape[0]
        #only modify diag that is a negative
        # eg makes this from a target of [0, 1, 0]
        #tensor([[ -1.,  1.,  1.],
        #        [  1.,  1.,  1.],
        #        [  1.,  1., -1.]])
        modif = (1-torch.eye(bs)).to(target.device) + (torch.eye(bs).to(target.device)*(target*2-1)) 
        input = input*modif/sigma

        diag_idx = torch.arange(0,len(input)).to(input.device)

        #label_smoothing = hparams.get('label_smoothing', 0.0)
        #if label_smoothing is None: #if it's in label_smoothing but still None
        label_smoothing = 0.0

        mol2txt = F.cross_entropy(input,   diag_idx, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction, label_smoothing=label_smoothing)
        txt2mol = F.cross_entropy(input.T, diag_idx, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction, label_smoothing=label_smoothing)
        return mol2txt+txt2mol

criterion = ConLoss()
loss = criterion(input, target)
print(f'loss calculated by ConLoss is {loss}')

input before scaling:
 tensor([[125.0471, 127.5372, 128.1412],
        [122.0491, 129.2777, 130.8487],
        [125.1797, 129.8019, 126.7569]])

target before scaling:
 tensor([0, 1, 0])

bs is 3
torch.eye(bs) is 
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
(1-torch.eye(bs)) is 
tensor([[0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 0.]])
(target*2-1) is tensor([-1,  1, -1])
torch.eye(bs)*(target*2-1) is 
tensor([[-1.,  0., -0.],
        [-0.,  1., -0.],
        [-0.,  0., -1.]])
modif:
tensor([[-1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1., -1.]])

input after scaling:
 tensor([[-125.0471,  127.5372,  128.1412],
        [ 122.0491,  129.2777,  130.8487],
        [ 125.1797,  129.8019, -126.7569]])

diag_idx: tensor([0, 1, 2])

loss calculated by ConLoss is 4.809605598449707


In [None]:
# =============== top_k_accuracy: pos, neg ==============
from 
activity = torch.randint(0, 2, (10,))
print(f'activity is {activity}')
ks = [1, 5, 10, 50] 

# Ground-truth matches = diagonal indices
y_true = torch.arange(0, 10)
print(f'y_true is {y_true}')

# --- positives only ---
pos_mask = activity == 1
if pos_mask.any(): # Check if there are any True values, i.e., this batch of activities contains at least one '1'.
    tkaccs_pos, arocc_pos = top_k_accuracy(y_true[pos_mask],preactivations[pos_mask], k=ks, ret_arocc=True)

tensor([0, 1, 1, 0, 0, 0, 1, 0, 0, 1])


# metrics.py

In [7]:
def get_sparse_data(m, i):
    """
    Get the non-zero data from a sparse matrix by an index.

    Params
    ------
    m : scipy.sparse.csr_matrix
    i: index

    Returns
    -------
    m.indptr[i] to m.indptr[i + 1] data: List[float]
        A list of non-zero values in the sparse matrix row at index `i`.
    """
    return [m.data[index] for index in range(m.indptr[i], m.indptr[i + 1])]

In [24]:
# A dense matrix example
A= [
    [0, 0, 3],
    [4, 0, 5],
    [0, 0, 0],
    [6, 7, 0]
]

# convert to CSR: compressed sparse row.
m = sparse.csr_matrix(A)

print(f'm is\n{m},\nm.data is {m.data},\nm.indices is {m.indices},\nm.indptr is {m.indptr}\n') # indptr is the index pointer (the index of data), which indicates where each row starts in the data array.

print(f'range(m.indptr[i], m.indptr[i + 1]) is {range(m.indptr[0], m.indptr[0 + 1])}\n')
print(f'm.indptr[0] is {m.indptr[0]}, m.indptr[1] is {m.indptr[1]}, m.indptr[2] is {m.indptr[2]}, m.indptr[3] is {m.indptr[3]}, m.indptr[4] is {m.indptr[4]}\n')

print(f'get_sparse_data(m, 0) is {get_sparse_data(m, 0)}')
print(f'get_sparse_data(m, 1) is {get_sparse_data(m, 1)}')
print(f'get_sparse_data(m, 2) is {get_sparse_data(m, 2)}')
print(f'get_sparse_data(m, 3) is {get_sparse_data(m, 3)}')

m is
  (0, 2)	3
  (1, 0)	4
  (1, 2)	5
  (3, 0)	6
  (3, 1)	7,
m.data is [3 4 5 6 7],
m.indices is [2 0 2 0 1],
m.indptr is [0 1 3 3 5]

range(m.indptr[i], m.indptr[i + 1]) is range(0, 1)

m.indptr[0] is 0, m.indptr[1] is 1, m.indptr[2] is 3, m.indptr[3] is 3, m.indptr[4] is 5

get_sparse_data(m, 0) is [3]
get_sparse_data(m, 1) is [4, 5]
get_sparse_data(m, 2) is []
get_sparse_data(m, 3) is [6, 7]


## BEDROC

In [2]:
y_pred_proba = np.array([[0.2, 0.8], [0.1, 0.9], [0.4, 0.6], [0.3, 0.7]])
y_true = np.array([1, 0, 1, 0])

print(f'y_pred_proba[:, 1] is\n{y_pred_proba[:, 1]}\n')

score = list(zip(y_pred_proba[:, 1], y_true))
print(f'score is\n{score}\n')

score.sort(key=lambda x: x[0], reverse=True)
print (f'score sorted by the first element is\n{score}\n')

y_pred_proba[:, 1] is
[0.8 0.9 0.6 0.7]

score is
[(0.8, 1), (0.9, 0), (0.6, 1), (0.7, 0)]

score sorted by the first element is
[(0.9, 0), (0.8, 1), (0.7, 0), (0.6, 1)]



In [10]:
#ef calc_bedroc(y_true, y_pred, decreasing=True, alpha=20.0):
#   """
#   Computes the BEDROC (Boltzmann-Enhanced Discrimination of Receiver Operating Characteristic) score.
#
#   Params:
#       y_true (array-like):
#           Binary class labels. 1 for positive class, 0 otherwise.
#       y_pred (array-like):
#           Prediction values.
#       decreasing (bool):
#           True if high values of `y_pred` correlates to positive class.
#       alpha (float):
#           Early recognition parameter.
#   Returns:
#       float:
#           Value in interval [0, 1] indicating degree to which the predictive technique employed detects (early) the positive class.
#
#   """
#
#   assert len(y_true) == len(y_pred), 'The number of scores must be equal to the number of labels.'
#
#   big_n = len(y_true)
#   n = sum(y_true == 1)
#
#   if decreasing:
#       order = np.argsort(-y_pred)
#   else:
#       order = np.argsort(y_pred)
#
#   m_rank = (y_true[order] == 1).nonzero()[0] + 1
#
#   s = np.sum(np.exp(-alpha * m_rank / big_n))
#
#   r_a = n / big_n
#
#   rand_sum = r_a * (1 - np.exp(-alpha)) / (np.exp(alpha/big_n) -1)
#
#   fac = r_a * np.sinh(alpha / 2) / (np.cosh(alpha /2) - np.cosh(alpha/2 -alpha * r_a))
#
#   cte = 1 / (1 - np.exp(alpha * (1 - r_a)))
#
#   return s * fac / rand_sum + cte

from rdkit.ML.Scoring.Scoring import CalcBEDROC

def calc_bedroc_on_clip(y_true, y_score, alpha: float = 20.0):
    """ Calculates the bedroc score unsing rdkit.ML.Scoring.CalcBEDROC.
    The source code is available at https://github.com/rdkit/rdkit/blob/master/rdkit/ML/Scoring/Scoring.py#L103
    This function is defined as `def CalcBEDROC(score, col, alpha)`, 
        where `score` is ordered list with tuples of (pred_proba, true value), with pred_proba being descendingly sorted,
        'col' is the column index for true values, i.e. 1 for the positive class (1), 
        and `alpha` is the early recognition parameter.

    
    Params
    ------
    y_true: (lst/array) a list of true values for all compounds.
    y_p: (lst/array) a list of predicted probabilities for all compounds, i.e. the value of model.predict_proba(x_test). 
                   y_pred_proba[:, 1] is the probability of the positive class (1).
    alpha: (float)  early recognition parameter. 
            alpha = 80.5, 2% of the top-ranked compounds of the all compounds were calculated; 2% represents the proportion of active compounds in the DUD-E database;
            alpha = 321.5, 0.5% of the top-ranked compounds of the all compounds  were calculated; 4 times smaller than 2% --> early recognition.
            alpha = 20.0(default), 8% of the top-ranked compounds of the all compounds were calculated; 4 times larger than 2% --> is interesting for the cases where relatively high-throughput experiments are available.

    returns
    -------
    (float) BEDROC score
    """

    pair = list(zip(y_score, y_true)) # pair the predicted scores with the true values
    pair.sort(key=lambda x: x[0], reverse=True)
    bedroc_score= CalcBEDROC(pair, 1, alpha) # 1 is the column index for the ground-truth values (y_true)

    return bedroc_score

In [11]:
def swipe_threshold_sparse(targets, scores, bedroc_alpha = 20, verbose=True, ret_dict=False):
    """
    This function computes metrics per assay (i.e., column-wise):

    Compute ArgMaxJ, AUROC, AVGP, AUPRC and BEDROC (and more if ret_dict=True) metrics for the true binary values
    `targets` given the predictions `scores`.

    Params
    ---------
    targets: :class:`scipy.sparse.csc_matrix`, shape(N, M) # N refers to the number of compounds, M refers to the number of assays.
        True target values.
    scores: :class:`scipy.sparse.csc_matrix`, shape(N, M)
        Predicted values
    bedroc_alpha: float
        Early recognition parameter for BEDROC. Default is 20.0, which is interesting for the cases where relatively high-throughput experiments are available.
    verbose: bool
        Be verbose if True.
    

    Returns
    ---------
    tuple of dict
        - ArgMaxJ of each valid column keyed by the column index (assay index), # get the optimal threshold that maximizes the difference between true positive rate (TPR) and false positive rate (FPR).
        - AUROC of each valid column keyed by the column index (assay index) # AUROC
        - AVGP of each valid column keyed by the column index (assay index) # average precision score
        - NegAVGP of each valid column keyed by the column index (assay index) # average precision score for the negative class (1 - y_true)
        - dAVGP of each valid column keyed by the column index (assay index) # difference between average precision and the mean of y_true
        - dNegAVGP of each valid column keyed by the column index (assay index) # difference between average precision for the negative class and the mean of 1 - y_true
        - AUPRC of each valid column keyed by the column index (assay index) # area under the precision-recall curve
        - BEDROC of each valid column keyed by the column index (assay index) # early recognition.
    """

    assert targets.shape == scores.shape, '"targets" and "scores" must have the same shape.' # assert <condition>, <error message>
    
    # find non-empty columns
    # (https://mike.place/2015/sparse/ for CSR, but works for CSC, too)
    non_empty_idx = np.where(np.diff(targets.indptr) != 0)[0] # Return the compounds that have at least one assay with a non-zero value?

    counter_invalid = 0
    argmax_j, auroc, avgp, neg_avgp, davgp, dneg_avgp, auprc, bedroc = {}, {}, {}, {}, {}, {}, {}, {}

    for col_idx in non_empty_idx: # This function computes metrics per assay (i.e., column-wise):
        y_true = np.array(list(get_sparse_data(targets, col_idx)))
        if len(pd.unique(y_true)) == 1: # `pd.unique` is faster than `np.unique` and `set`.
            counter_invalid += 1
            continue
        y_score = np.array(list(get_sparse_data(scores, col_idx)))
        assert len(y_true) == len(y_score)

        fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
        assert len(fpr) == len(tpr) == len(thresholds), 'Length mismatch between "fpr", "tpr", and "thresholds".'
        argmax_j[col_idx] = thresholds[np.argmax(tpr - fpr)] 

        auroc[col_idx] = metrics.roc_auc_score(y_true, y_score)
        avgp[col_idx] = metrics.average_precision_score(y_true, y_score)
        neg_avgp[col_idx] = metrics.average_precision_score(1 - y_true, 1 - y_score)
        davgp[col_idx] = avgp[col_idx] - y_true.mean()
        dneg_avgp[col_idx] = neg_avgp[col_idx] - (1 - y_true.mean())

        # check if the auprc is same as avgp.
        precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_score)
        auprc[col_idx] = metrics.auc(recall, precision)
        
        bedroc[col_idx] = calc_bedroc_on_clip(y_true, y_score, alpha=bedroc_alpha)

    if verbose:
        logger.info(f'Found {len(auroc)} columns with both positive and negative samples.')
        logger.info(f'Found and skipped {counter_invalid} columns with only positive or negative samples.')

    if ret_dict:
        return {'argmax_j':argmax_j, 'auroc':auroc, 'avgp':avgp, 'neg_avgp':neg_avgp,
                'davgp':davgp, 'dneg_avgp':dneg_avgp, 'auprc':auprc, 'bedroc':bedroc}

    return argmax_j, auroc, avgp, neg_avgp, davgp, dneg_avgp, auprc, bedroc

# dataloader.py

In [3]:
def get_sparse_indices_and_data(m, i):
    """Get the indices and data of a sparse matrix.
    
    Params:
    -------
    m: a sparse matrix in CSR format
    i: the index of the row for which to extract the non-zero elements. #? non-zero?

    Returns:
    -------
    tuple: (indices, data)
        col_indices: the column indices of the non-zero data in row i. (in csr format, only the non-zero elements are stored)
        data: the values of the non-zero elements in row i.
    """
    # `m.data`: the non-zero values of the sparse matrix
    # `m.indices`: the column indices of the non-zero values
    # `m.indptr`: which maps the elements of `data` and `indices` to the rows of the sparse matrix. Explanation: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr
    col_indices = m.indices[m.indptr[i]:m.indptr[i+1]] 
    data = m.data[m.indptr[i]:m.indptr[i+1]]
    return col_indices, data

In [4]:
class InMemoryClamp(Dataset):
    """
    Subclass of :class:`torch.utils.data.Dataset` holding BioBert activity data, 
    that is, activity triplets, and compound and assay feature vectors.

    :class:`InMemoryClamp` supports two different indexing (and iteration) styles. 
    The default style is to itreate over `(compound, assay, activity)` COO triplets, however they are sorted.
    The "meta-assays" style consists in interating over unique compounds using a CSR sparse structure,
    and averaging the feature vectors of the positive and negative assays of each compound. 

    By inheriting from :class:`torch.utils.data.Dataset`, this class must implement at least two methods:
    - :meth:`__len__` to return the size of the dataset.
    - :meth:`__getitem__` to retrieve a single data point from the dataset.
    """

    def __init__(
            self,
            root: Union[str, Path],
            assay_mode: str,
            compound_mode: str = None, 
            train_size: float = 0.6, 
            aid_max: int = None, #? Yu: could be removed
            cid_max: int = None, #? Yu: could be removed
            verbose: bool = True
    ) -> None:
        """
        Instantiate the dataset class.

        - The data is loaded in memory with the :meth:`_load_dataset` method.
        - Splits are created separately along compounds and along assays with the :meth:`_find_splits` method. Compound and assay splits can be interwoven with the :meth:`subset` method.

        Params:
        root: str or :class:`pathlib.Path`
            Path to a directory of ready BioBert files.
        assay_mode: str
            Type of assay features ("biobert-last", "biobert-two-last", or "lsa").
        train_size: float (between 0 and 1)
            Fraction of compounds and assays assigned to training data.
        verbose: bool
            Be verbose if True.
        """
        self.root = Path(root)
        self.assay_mode = assay_mode
        self.compound_mode = compound_mode
        self.train_size = train_size
        self.verbose = verbose

        self._load_dataset()
        self.find_splits()

        self.meta_assays = False
        self.assay_onehot = None

    def _load_dataset(self) -> None:
        """
        Load prepared dataset from the `root` directory:

        - `activity`: Parquet file containing `(compound, assay, activity)` triplets. Compounds and assays are represented by indices, 
        and thus the file is directly loaded into a :class:`scipy.sparse.coo_matrix` with rows corresponding to compounds and columns corresponding to assays.

        - `compound_names`: Parquet file containing the mapping between the compound index used in `activity` and the corresponding compound name.
        It is loaded into a :class:`pandas.DataFrame`.

        - `assay_names`: Parquet file containing the mapping between the assay index used in `activity` and the corresponding assay name. 
        It is loaded into a :class:`pandas.DataFrame`.

        - `compound_features`: npz file containing the compound features array, where the feature vector for the compound indexed by `idx` is stored in the `idx`-th row. 
        It is loaded into a :class:`scipy.sparse.csr_matrix`.

        - `assay_features`: npy file containing the assay features array, where the feature vector for the assay indexed by `idx` is stored in the `idx`-th row.
        It is loaded into a :class:`numpy.ndarray`.

        Compute the additional basic dataset attributes `num_compounds`, `num_assays`, `compound_feature_size`, `assay_feature_size`.
        """

        if self.verbose:
            logger.info(f'Load dataset from "{self.root} with {self.assay_mode}" assay features.')

        #======= Load compound data =======
        with open(self.root / 'compound_names.parquet', 'rb') as f:
            self.compound_names = pd.read_parquet(f)
        self.num_compounds = len(self.compound_names)

        compound_modes = self.compound_mode.split('||') if self.compound_mode is not None else 1 #? Yu: replace `||` with `+` ?        if len(compound_modes) >1:
        if len(compound_modes) > 1:
            logger.info('Multiple compound modes are concatenated ')
            self.compound_features = np.concatenate([self._load_compound(cm) for cm in compound_modes], axis=1)
        else:
            self.compound_features = self._load_compound(self.compound_mode)
        # compound_feature_size
        if 'graph' in self.compound_mode and (not 'graphormer' in self.compound_mode):
            self.compound_features_size = self.compound_features[0].ndata['h'].shape[1] # in_edge_feats. #? Yu
        elif isinstance(self.compound_features, pd.DataFrame):
            self.compound_features_size = 40000 #? Yu
        else:
            if len(self.compound_features.shape)>1:
                self.compound_features_size = self.compound_features.shape[1]
            else:
                self.compound_features_size = 1

        #======== Load assay data ========
        with open(self.root / 'assay_info.parquet', 'rb') as f:
            self.assay_names = pd.read_parquet(f)
        self.num_assays = len(self.assay_names)

        assay_modes = self.assay_mode.split('||')
        if len(assay_modes)>1:
            logger.info('Multiple assay modes are concatenated')
            self.assay_features = np.concatenate([self._load_assay(am) for am in assay_modes], axis=1)
        else:
            self.assay_features = self._load_assay(self.assay_mode)

        # assay_feature_size
        if (self.assay_features is None):
            self.assay_features_size = 512 #wild guess also 512#? Yu
        elif len(self.assay_features.shape)==1:
            # its only a list, so probably text
            self.assay_features_size = 768 #? Yu
        else:
            self.assay_features_size = self.assay_features.shape[1]
        
        #======= Load activity data =======
        with open(self.root / 'activity.parquet', 'rb') as f:
            activity_df = pd.read_parquet(f)
            self.activity_df = activity_df
        
        # ? Yu: will the :meth:`sparse.coo_matrix` only keep the non-zero values? If so, only the active compounds (where activity  is not 0) will be kept?
        self.activity = sparse.coo_matrix(
            (
                activity_df['activity'],# activity is the value
                (activity_df['compound_idx'], activity_df['assay_idx']) # compound in row, assay in column.
            ),
            shape=(self.num_compounds, self.num_assays),
        )
    
    def _load_compound(self, compound_mode=None):
        cmpfn = f'compound_features{"_"+compound_mode if compound_mode else ""}'
        #?Yu: if 'graph' is not used, remove the below code
        if 'graph' in compound_mode and (not 'graphormer' in compound_mode):
            logger.info(f'graph in compound mode: loading '+cmpfn)
            import dgl
            from dgl.data.utils import load_graphs
            compound_features = load_graphs(str(self.root/(cmpfn+".bin")))[0]
            compound_features = np.array(compound_features)
        elif compound_mode == 'smiles':
            compound_features = pd.read_parquet(self.root/('compound_smiles.parquet'))['CanonicalSMILES'].values
        else:
            try: #tries to open npz files else npy
                with open(self.root/(cmpfn+".npz"), 'rb') as f:
                    compound_features = sparse.load_npz(f)
            except:
                logger.info(f'loading '+cmpfn+'.npz failed, using .npy instead')
                try:
                    compound_features = np.load(self.root/(cmpfn+".npy"))
                except:
                    logger.info(f'loading '+cmpfn+'.npy failed, trying to compute it on the fly')
                    compound_features = pd.read_parquet(self.root/('compound_smiles.parquet'))
        return compound_features
    
    def _load_assay(self, assay_mode='lsa') -> None: #? Yu: 'lsa'
        """ loads assay """
        if assay_mode =='':
            print('no assay features')
            return None
        
        #? Yu: if the below assay modes are not used, remove them.
        if assay_mode == 'biobert-last':
            with open(self.root/('assay_features_dmis-lab_biobert-large-cased-v1.1_last_layer.npy'), 'rb') as f:
                return np.load(f, allow_pickle=True)
        elif assay_mode == 'biobert-two-last':
            with open(self.root/('assay_features_dmis-lab_biobert-large-cased-v1.1_penultimate_and_last_layer.npy'), 'rb') as f:
                return  np.load(f, allow_pickle=True)
        
        # load the prepared assay features
        try: # tries to open npz file else npy
            with open(self.root/(f'assay_features_{assay_mode}.npz'), 'rb') as f:
                return sparse.load_npz(f)
        except:
            with open(self.root/(f'assay_features_{assay_mode}.npy'), 'rb') as f:
                return np.load(f, allow_pickle=True)
        
        return None

    def _find_splits(self) -> None:
        """
        We assume that during the preparation of the PubChem data, compounds(assays) have been indexed 
        so that a larger compound(assay) index corresponds to a compound(assay) incorporated to PubChem later in time.
        This function finds the compound(assay) index cut-points to create three chronological disjoint splits.

        The oldest `train_size` fraction of compounds(assays) are assigned to training. 
        From the remaining compounds(assays), the oldest half are assigned to vailidation, and the newest half are assigned to test.
        Only the index cut points are stored.
        """
        if self.verbose:
            logger.info(f'Find split cut-points for compound and assay indices (train_size={self.train_size}).')

        first_cut, second_cut = self._chunk(self.num_compounds, self.train_size)
        self.compound_cut = {'train': first_cut, 'valid': second_cut}

        first_cut, second_cut = self._chunk(self.num_assays, self.train_size)
        self.assay_cut = {'train': first_cut, 'valid': second_cut}

    def _chunk(n:int, first_cut_ratio:float) -> Tuple[int, int]:
        """
        Find the two cut points required to chunk a sequence of `n` items into three parts, 
        the first having `first_cut_ratio` of the items, 
        the second and the third having approximately the half of the remaining items.

        Params
        -------
        n: int
            Length of the sequence to chunk.
        first_cut_ratio: float
            Portion of items in the first chunk. This is the `train_size`
        
        Returns
        -------
        int, int
            Positions where the first and second cut occurs.
        """
        first_cut = int(round(first_cut_ratio * n))
        second_cut = first_cut + int(round((n - first_cut) / 2))

        return first_cut, second_cut

    def subset(
            self, 
            c_low: Optional[int] = None,
            c_high: Optional[int] = None,
            a_low: Optional[int] = None,
            a_high: Optional[int] = None,
    ) -> np.ndarray:
        if c_low is None: # sef the compound low index to 0
            c_low = 0
        if c_high is None: # set the compound high index to the number of compounds
            c_high = self.num_compounds
        if a_low is None: # set the assay low index to 0
            a_low = 0
        if a_high is None: # set the assay high index to the number of assays
            a_high = self.num_assays

        if self.verbose:
            logger.info(f'Find activity triplets where {c_low} <= compound_idx <= {c_high} and {a_low} <= assay_idx <= {a_high}.')
        
        activity_bool = np.logical_and.reduce( # take multiple Boolean conditions and combines them using logical AND across all conditions.
            (
                self.activity.row >= c_low,
                self.activity.row < c_high,
                self.activity.col >= a_low,
                self.activity.col < a_high
            )
        )

        return np.flatnonzero(activity_bool) # applies the logical condition to the COO matrix and returns the indices that satisfy the condition.

    def get_unique_names(
            self, 
            activity_idx: Union[int, Iterable[int], slice]
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Get the unique compound and assay names within the `activity` triplets  indexed by `activity_idx` in default, COO style.

        Params:
        -------
        activity_idx: int, iterable of int, slice
            Index to one or multiple `activity` triplets.
        
        Returns:
        -------
        compound_names: :class:`pandas.DataFrame`
        assay_names: :class:`pandas.DataFrame`
        """

        compound_idx = self.activity.row[activity_idx]
        assay_idx = self.activity.col[activity_idx]

        if isinstance(compound_idx, np.ndarray) and isinstance(assay_idx, np.ndarray):
            compound_idx = pd.unique(compound_idx)
            assay_idx = pd.unique(assay_idx)
        
        elif isinstance(compound_idx, (int, np.integer)) and isinstance(assay_idx, (int, np.integer)):
            pass # a single index means a single compound and assay, so no need to do anything.

        else:
            raise ValueError('activity_idx must be an int, iterable of int, or slice.')

        compound_names = self.compound_names.iloc[compound_idx]
        assay_names = self.assay_names.iloc[assay_idx]

        return compound_names.sort_index(), assay_names.sort_index() # sort the names alphabetically

    def getitem(
            self,
            activity_idx: Union[int, Iterable[int], slice],
            ret_np=False
    ) -> Tuple[Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        
        Params
        -------
        activity_idx: int, iterable of int, slice
            Specifies the indices of the activity triplets to retrieve.
        ret_np: bool
            Determines the format of the returned data. If True, returns numpy arrays; If False, returns PyTorch tensors.
        
        Returns
        -------
        tuple of :class:`torch.Tensor`
        - `activity_idx`: the original indices provided as input. This will enable to reconstruct the order in which the dataset has been visited.
        - `compound_features`: shape(len(activity_idx), compound_feature_size)
        - `assay_features`: shape(len(activity_idx), assay_feature_size)
        - `activity`: shape(len(activity_idx), ).
        """
        compound_idx = self.activity.row[activity_idx]
        assay_idx = self.activity.col[activity_idx]
        activity = self.activity.data[activity_idx]

        # ===== get compound_features =====
        if isinstance(self.compound_features, pd.DataFrame):
            compound_smiles = self.compound_features.iloc[compound_idx]['CanonicalSMILES'].values
            from datacat4ml.Scripts.data_prep.data_featurize.compound_featurize.encode_compound  import convert_smiles_to_fp
            if self.compound_mode == 'MxFP':
                fptype = 'maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+mhfp+rdkd'
            else:
                fptype = self.compound_mode
            # Todo: fp_size as input parameter
            fp_size = 40000 #? Yu
            compound_features = convert_smiles_to_fp(compound_smiles, fp_size=fp_size, which=fptype, radius=2, njobs=1).astype(np.float32)
        else:
            compound_features = self.compound_features[compound_idx]
            if isinstance(compound_features, sparse.csr_matrix):
                compound_features = compound_features.toarray()
        

        # ===== get assay_features =====
        assay_features = self.assay_features[assay_idx]
        if isinstance(assay_features, sparse.csr_matrix):
            assay_features = assay_features.toarray()
        
        #? Yu: if not used, remove the below code
        try:
            assay_onehot = self.assay_onehot[assay_idx].toarray()
        except (TypeError, ValueError):
            assay_onehot = np.zeros_like(assay_features)
        
        # ===== Handle single indices =====
        # If `activity_idx`is a single integer or a list with only one element, the retrieved feature vectors are reshaped into 1D arrays to maintain the consistency of the output format.
        if isinstance(activity_idx, (int, np.integer)):
            compound_features = compound_features.reshape(-1) 
            assay_features = assay_features.reshape(-1) 
            assay_onehot = assay_onehot.reshape(-1) 
            activity = [activity]
        elif isinstance(activity_idx, list):
            if len(activity_idx) == 1:
                compound_features = compound_features.reshape(-1)
                assay_features = assay_features.reshape(-1)
                assay_onehot = assay_onehot.reshape(-1)
        activity = np.array(activity)

        # ===== Return =====
        # return the data as Numpy arrays.
        if ret_np:
            return(
                activity_idx,
                compound_features, #already float32
                assay_features if not isinstance(assay_features[0], str) else assay_features, # already float32
                assay_onehot if not isinstance(assay_onehot[0], str) else assay_onehot, # already float32
                (float(activity)) # torch.nn.BCEWithLogitsLoss needs this to be float too...
            )

        # return the data as PyTorch tensors.
        if self.compound_mode == 'smiles':
            comp_feat = compound_features
        elif isinstance(compound_features, np.ndarray):
            comp_feat = torch.from_numpy(compound_features)
        elif not isinstance(compound_features[0], dgl.DGLGraph):
            comp_feat = dgl.batch(compound_features)
        else:
            comp_feat = compound_features

        return  (
            activity_idx, 
            comp_feat, # alread float32
            torch.from_numpy(assay_features) if not isinstance(assay_features[0], str) else assay_features, # already float32
            torch.from_numpy(assay_onehot.astype(int)) if not isinstance(assay_onehot[0], str) else assay_onehot, # already float32
            torch.from_numpy(activity).float() # torch.nn.BCEWithLogitsLoss needs this to be float too...
        )

    def getitem_meta_assay(
            self,
            compound_idx: Union[int, List[int], slice]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        For a given compound (or list), retrieve the data in the `meta-assay` style, 
        which involves summarizing assay feature vectors (positive and negative) for each compound.
        
        Params
        -------
        compound_idx: int, iterable of int, slice
            Index to one or multiple compounds.
        
        Returns
        -------
        tuple of :class:`torch.Tensor`
        - `compound_features`: shape(N, compound_feature_size)
        - `assay_features`: shape(N, assay_feature_size)
        - `activity`: shape(N, )
        """
        
        # extract the data for the specified compounds
        activity_slice = self.activity.tocsr()[compound_idx] 

        # find non-empty rows
        # `activity_slic.indptr`: pointer to the start of each row in the sparse matrix.
        # `np.diff(activity_slice.indptr)`: measures the number of elements in each row.
        # `np.where(...!=0)`: finds rows that contain non-zero elements. (i.e., rows with at least one assay-related to the compound)`
        non_empty_row_idx = np.where(np.diff(activity_slice.indptr)!=0)[0] #?

        # initialize containers for results
        compound_features_l = [] # list of compound features
        assay_positive_features_l, assay_negative_features_l = [], [] # averaged features of positive assays, and negative assays
        activity_l = [] # activity lables

        # process each non-empty row
        for row_idx in non_empty_row_idx:
            positive_l, negative_l  = [], []
            for col_idx, activity in get_sparse_indices_and_data(activity_slice, row_idx):
                if activity == 0:
                    negative_l.append(self.assay_features[col_idx]) 
                else:
                    positive_l.append(self.assay_features[col_idx])
            
            if len(negative_l) > 0:
                compound_features_l.append(self.compound_features[row_idx])
                negative = np.vstack(negative_l).mean(axis=0)
                assay_negative_features_l.append(negative)
                activity_l.append(0)
            
            if len(positive_l) > 0:
                compound_features_l.append(self.compound_features[row_idx])
                positive = np.vstack(positive_l).mean(axis=0)
                assay_positive_features_l.append(positive)
                activity_l.append(1)

        compound_features = sparse.vstack(compound_features_l).toarray()
        assay_features_l = np.vstack(
            assay_negative_features_l + assay_positive_features_l # '+' is used to concatenate the two lists
        )

        activity = np.array(activity_l)

        return (
            torch.from_numpy(compound_features), # already float32
            torch.from_numpy(assay_features_l), # already float32
            torch.from_numpy(activity).float() # torch.nn.BCEWithLogitsLoss needs this to be float too...
        )

    @staticmethod
    def collate(batch_as_list:list) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Necessary for :meth:`getitem_meta_assay` if using a :class:`torch.utils.data.DataLoader`.
        Not necessaryif using :class:`torch.utils.data.BatchSampler`, as I typically do.

        Params
        -------
        batch_as_list: list
            Result of :meth:`getitem_meta_assay` for a mini-batch.

        Returns
        -------
        tuple of :class:`torch.Tensor`
            Data for a mini-batch.
        """
        compound_features_t, assay_features_t, activty_t = zip(*batch_as_list)
        return(
            torch.cat(compound_features_t, dim=0),
            torch.cat(assay_features_t, dim=0),
            torch.cat(activty_t, dim=0)
        )
        

    def __getitem__(self, idx: Union[int, Iterable[int], slice]) -> Tuple:
        """
        Index or slice `activity` by `idx`. The indexing mode depends on the value of `meta_assays`. 
        If False(default), the indexing is over COO triplets.
        If True, the indexing is over unique compounds.
        """

        if self.meta_assays:
            return self.getitem_meta_assay(idx)
        else:
            return self.getitem(idx)
        
    def __len__(self) -> int:
        """
        Return the length of the dataset.

        - If `meta_assays` is False (default), length is defined as the number of `(compound, assay, activity)` COO triplets.
        - If `meta_assays` is True, length is defined as the number of the unique compounds.
        """

        if self.meta_assays:
            return self.num_compounds
        else:
            return self.activity.nnz # the number of non-zero elements in the sparse matrix
        
    def __repr__(self):
        return f'InMemoryClamp\n' \
               f'\troot="{self.root}"\n' \
               f'\tassay_mode="{self.assay_mode}"\n' \
               f'\ttrain_size={self.train_size}\n' \
               f'\tactivity.shape={self.activity.shape}\n' \
               f'\tactivity.nnz={self.activity.nnz}\n' \
               f'\tmeta_assays={self.meta_assays}'

In [3]:
# About funciton `subset`
c_low = 1
c_high = 4
a_low = 2
a_high = 5

activity_df = pd.DataFrame({
    "compound_idx": [0, 1, 2, 3, 4],
    "assay_idx": [0, 0, 0, 0, 0],
    "activity": [1, 0, 1, 0, 1]
})

self_activity = sparse.coo_matrix(
    (
        activity_df['activity'], 
        (activity_df['compound_idx'], activity_df['assay_idx'])
    ),
    shape=(5, 5),)

activity_bool = np.logical_and.reduce( # take multiple Boolean conditions and combines them using logical AND across all conditions.
            (
                self_activity.row >= c_low,
                self_activity.row < c_high,
                self_activity.col >= a_low,
                self_activity.col < a_high
            )
        )

print(f'activity_bool: \n{activity_bool}\n')
print(f'np.flatnonzero(activity_bool): \n{np.flatnonzero(activity_bool)}\n')

print(f'self_activity.data: \n{self_activity.data}\n')

activity_bool: 
[False False False False False]

np.flatnonzero(activity_bool): 
[]

self_activity.data: 
[1 0 1 0 1]



# utils.py

In [5]:
def parse_hidden_layers(s:str):
    """
    #? Yu: why this function necessary?
    Parse a string in the form of [32, 32] into a list of integers."""
    try:
        res = [int(ls) for ls in s.strip('[]').split(',')]
    except:
        raise argparse.ArgumentTypeError(
            f"Invalid hidden layers format: {s}. Expected format is [32, 32]."
        )
    return res

In [6]:
NAME2FORMATTER = {
    'assay_mode': str,
    'model': str,
    'multitask_temperature': float, #?Yu: if not used later, remove
    'optimizer': str,
    'hidden_layers': parse_hidden_layers, #?Yu: number of hidden layers?
    'compound_layer_sizes': parse_hidden_layers, #?Yu:?
    'assay_layer_sizes': parse_hidden_layers, #?Yu:?
    'embedding_size': int, 
    'lr_ini': float,
    'epoch_max': int, 
    'batch_size': int,
    'dropout_input': float,
    'dropout_hidden': float,
    'l2': float, #?Yu: L2 regularization?
    'nonlinearity': str,
    'pooling_mode': str,
    'lr_factor': float,
    'patience': int,
    'attempts': int, # not used in public version #?Yu: don't understand, remove?
    'loss_fun': str, 
    'tokenizer': str,
    'transformer': str,
    'warmup_epochs': int,
    'train_balanced': int,
    'beta': float,
    'norm': bool,
    'label_smoothing': float,
    'gpu': int,
    'checkpoint': str,
    'verbose': bool,
    'hyperparams': str,
    'format': str,
    'f': str, #?Yu: file path?
    'support_set_size': int,
    'train_only_actives': bool,
    'random': int, 
    'seed': int, 
    'dataset': str,
    'experiment': float,
    'split': str,
    'wandb': str,
    'compound_mode': str,
    'train_subsample': float, #?Yu:?
}

EVERY = 50000

In [6]:
def get_hparams(path, mode='logs', verbose=False):
    """
    Get hyperparameters from a path. If logs uses path /params/* files form mlflow.
    If mode is json: loads in the file provided in path. 

    Params
    ------
    path: str
        Path to the hyperparameters file.
    mode: str
        Mode of the hyperparameters file. Default is 'logs'.
    verbose: bool
        Be verbose if True.
    """
    if isinstance(path, str):
        path = Path(path)
    hparams = {}
    if mode =='logs':
        for fn in os.listdir(path/'params'):
            try:
                with open(path/f'params/{fn}') as f:
                    lines = f.readlines()
                    try:
                        hparams[fn] = NAME2FORMATTER.get(fn, str)(lines[0])
                    except:
                        hparams[fn] = None if len(lines)==0 else lines[0]
            except:
                pass
    elif mode == 'json':
        with open(path) as f:
            hparams = json.load(f)
    if verbose:
        logger.info("loaded hparams:\n", hparams)
    
    return hparams

In [8]:
def seed_everything(seed=70135):
    """ does what it says ;) - from https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335"""
    import numpy as np
    import random
    import os
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
def set_device(gpu=0, verbose=False): #?Yu: verbose is not used
    "Set device to gpu or cpu."
    if gpu == 'any':
        gpu = 0
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')
    return device

**def train_and_test()**
- Function signature and parameters 
- Model initialization 
- Optimizer and loss function initialization 
- Learning rate Scheduler 
- Batch sampler 
- Training loop 
- Validation loop 
- Testing loop 

In [10]:
def init_checkpoint(path, device, verbose=False):
    """
    load from path if path is not None, otherwise return empty dict.
    """
    if path is not None:
        if verbose:
            logger.info('Load checkpoint.')
        return torch.load(path, map_location=device)
    return {}

In [11]:
def get_log_paths(run_info: mlflow.entities.RunInfo):
    """
    Return paths to the artifacts directory and the model weights.
    """
    artifacts_dir = Path('mlruns', run_info.experiment_id, run_info.run_id, 'artifacts')
    checkpoint_file_path = artifacts_dir / 'checkpoint.pt'
    metrics_file_path = artifacts_dir / 'metrics.parquet'
    return artifacts_dir, checkpoint_file_path, metrics_file_path

In [12]:
class EarlyStopper:
    # adapted from https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch
    # Early stopping is a regularization technique to prevent overfitting. 
    # During training, it monitors the validation loss and stops training when the validation loss does not improve for a specified number of epochs (patience).
    # This helps the model to generalize better rather than just memorizing the training data.
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience # number of epochs with no improvement after which training will be stopped
        self.min_delta = min_delta # minimum change to consider it an improvement
        self.counter = 0 # counter for the number of epochs with no improvement
        self.min_validation_loss = np.inf # the best (lowest) validation loss seen so far
        self.improved = False # flag to indicate if the last validation loss has improved

    def __call__(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            self.improved = True
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            self.improved = False
            if self.counter >= self.patience:
                return True
        return False       

In [3]:
def init_model( #?Yu: understand and finalize this function.
        compound_feature_size: int,
        assay_feature_size: int,
        hp: dict,
        verbose: bool = False
) -> DotProduct:
    """
    Initialize PyTorch model.

    Params
    -------
    compound_feature_size: int
        Input size of the compound encoder.
    assay_feature_size: int
        Input size of the assay encoder.
    hp: dict
        Hyperparameters.
    verbose: bool
        Be verbose if True.

    Returns
    -------
    :class:`DotProduct``
        Model instance.
    """
    #Yu: `model_str` is not used later, so it could be removed.
    #model_str = hp['model'] # TAn example of `model_str` is 'MLPLayerNorm'. 
    #modes = ['Multitask', 'Scaled', 'GNN', 'Pretrained'] #?Yu: is it necessary to have these modes?
    #selected_mode = ''
    #for mode in modes:
    #    if mode in model_str:
    #        model_str = model_str.replace(mode, '')
    #        selected_mode = mode
#
    #logger.info(selected_mode+' has been selected.')

    #?Yu: if not used, remove the below code
    # if not (hasattr(models, model_str) or hasattr(gnn, 'GNN'+model_str)):
    #    raise NotImplementedError(f'Model "{hp["model"]}" is not known.')

    # ['Linear', 'MLPLayerNorm', 'ScaledMLPLayerNorm', 'MultitaskMLPLayerNorm']

    if verbose:
        logger.info(f'Initialize "{hp["model"]}" model.')

    init_dict = hp.copy()
    init_dict.pop('embedding_size') #?Yu remove the embedding size, since it has to be provided as positional argument.

    # For getattr(clamp.models, hp['model']) to work, the class must be exposed at the package level in /clamp/models/__init__.py. 
    # Typically, __init__.py will import selected classes from those submodules, making them accessible like clamp.models.MyModelClass.
    model = getattr(model_def, hp['model'])( #?Yu model_def: class, hp['model']: the specified attribute. 
        compound_features_size=compound_feature_size,
        assay_features_size=assay_feature_size,
        embedding_size=hp['embedding_size'],
        **init_dict
    )

    if wandb.run:
        wandb.watch(model, log_freq=100, log_graph=(True))  # Log model weights and gradients for visualization in wandb, generate and log the computational graph automatically.

    return model


In [5]:
#Yu: Comprehend the below code.
path = os.path.join(DATA_DIR, 'model_dev', 'hparams', 'default.json')
hp = get_hparams(path, mode='json', verbose=False)
init_dict = hp.copy()
init_dict.pop('embedding_size')

print(f'hp: {hp}')
print(f'init_dict: {init_dict}')

hp: {'model': 'MLPLayerNorm', 'hidden_layers': [2048, 1024], 'embedding_size': 512, 'lr_ini': 1e-05, 'epoch_max': 50, 'batch_size': 256, 'dropout_input': 0.1, 'dropout_hidden': 0.3, 'l2': 0.0005, 'lr_factor': 1, 'patience': 3, 'attempts': 1, 'optimizer': 'AdamW', 'loss_fun': 'BCE', 'warmup_epochs': 2, 'train_balanced': 0, 'train_subsample': 0, 'beta': 1, 'bedroc_alpha': 20.0}
init_dict: {'model': 'MLPLayerNorm', 'hidden_layers': [2048, 1024], 'lr_ini': 1e-05, 'epoch_max': 50, 'batch_size': 256, 'dropout_input': 0.1, 'dropout_hidden': 0.3, 'l2': 0.0005, 'lr_factor': 1, 'patience': 3, 'attempts': 1, 'optimizer': 'AdamW', 'loss_fun': 'BCE', 'warmup_epochs': 2, 'train_balanced': 0, 'train_subsample': 0, 'beta': 1, 'bedroc_alpha': 20.0}


In [None]:
def filter_dict(dict_to_filter, thing_with_kwargs):
    """
    Yu: integrate this function into other functions, so that it can be removed.
    filter dict_to_filter by the arguments of the object or function of thing_with_kwargs.
    so that you can stressfree do this: thing_with_kwargs(**filter_dict_return)
    returns: filtered_dict
    modified from https://stackoverflow.com/questions/26515595/how-does-one-ignore-unexpected-keyword-arguments-passed-to-a-function
    """
    import inspect # a python standard library
    sig = inspect.signature(thing_with_kwargs) # get the list of valid argument names for the function or class `thing_with_kwargs`.
    filter_keys =[param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD] # only `POSITIONAL_OR_KEYWORD` parameters are considered.
    inters = set(dict_to_filter.keys()).intersection(filter_keys) # do filter

    return {k:dict_to_filter[k] for k in inters}


def init_optimizer(model, hp, verbose=False):
    """
    Initialize optimizer for the model.
    """
    if verbose:
        logger.info(f"Trying to initialize '{hp['optimizer']}' optimizer from torch.optim.")
    hp['lr'] = hp.pop('lr_ini') # remove 'lr_ini' and rename it to 'lr'
    hp['weight_decay'] = hp.pop('l2') # remove 'l2' and rename it to 'weight_decay'
    optimizer = getattr(torch.optim, hp['optimizer']) # fetch the optimizer class (e.g., `Adam`)
    filtered_dict = filter_dict(hp, optimizer) # remove any keys in `hp` that are not valid arguments for the selected optimizer class.
    
    return optimizer(params=model.parameters(), **filtered_dict) # optimize all the model's parameters by using the filtered hyperparameters of the optimizer.

In [16]:
# Yu: comprehend the code below.
path = os.path.join(DATA_DIR, 'model_dev', 'hparams', 'default.json')
hp = get_hparams(path, mode='json', verbose=False)

# arguments for `filter_dict`
dict_to_filter = hp.copy()
dict_to_filter['lr'] = dict_to_filter.pop('lr_ini')
dict_to_filter['weight_decay'] = dict_to_filter.pop('l2')
thing_with_kwargs = getattr(torch.optim, hp['optimizer'])

# code in `filter_dict``
import inspect
sig = inspect.signature(thing_with_kwargs)
print(f'sig: {sig}')

filter_keys = [p.name for p in sig.parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]
print(f'filter_keys: {filter_keys}')

inters = set(dict_to_filter.keys())
print(f'inters: {inters}')

inters_intersection = inters.intersection(filter_keys)
print(f'inters_intersection: {inters_intersection}')

return_by_filter_dict = {k:dict_to_filter[k] for k in inters_intersection}
print(f'return_by_filter_dict: {return_by_filter_dict}')

sig: (params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], lr: Union[float, torch.Tensor] = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False, *, maximize: bool = False, foreach: Union[bool, NoneType] = None, capturable: bool = False, differentiable: bool = False, fused: Union[bool, NoneType] = None)
filter_keys: ['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad']
inters: {'lr_factor', 'patience', 'bedroc_alpha', 'epoch_max', 'lr', 'optimizer', 'attempts', 'loss_fun', 'weight_decay', 'warmup_epochs', 'train_balanced', 'batch_size', 'train_subsample', 'embedding_size', 'beta', 'model', 'hidden_layers', 'dropout_hidden', 'dropout_input'}
inters_intersection: {'weight_decay', 'lr'}
return_by_filter_dict: {'weight_decay': 0.0005, 'lr': 1e-05}


In [15]:
def top_k_accuracy(y_true, y_pred, k=5, ret_arocc=False, ret_mrocc=False, verbose=False, count_equal_as_correct=False, eps_noise=0):
    """
    partly from http://stephantul.github.io/python/pytorch/2020/09/18/fast_topk/
    count_equal counts equal values as being a correct choice. e.g. all preds = 0 --> T1acc=1
    ret_mrocc ... also return median rank of correct choice
    eps_noise ... if > 0, and noise*eps to y_pred .. recommended e.g. 1e-10 #?Yu
    """
    if eps_noise > 0:
        if torch.is_tensor(y_pred):#?Yu
            y_pred = y_pred + torch.rand(y_pred.shape)*eps_noise
        else:
            y_pred = y_pred + np.random.rand(*y_pred.shape)*eps_noise
    if count_equal_as_correct:
        greater = (y_pred > y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger
    else:
        greater = (y_pred >= y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger or equal
    if torch.is_tensor(y_pred):
        greater = greater.long()
    if isinstance(k, int): k = [k] # pack it into a list
    tkaccs = []
    for ki in k:
        if count_equal_as_correct:
            tkacc = (greater<=(ki-1))
        else:
            tkacc = (greater<=(ki))

        if torch.is_tensor(y_pred):
            tkacc = tkacc.float().mean().detach().cpu().numpy()
        else:
            tkacc = tkacc.mean()
        tkaccs.append(tkacc)
        if verbose:
            print('Top', ki, 'acc:\t', str(tkacc)[:6])
    
    if ret_arocc:
        arocc = greater.float().mean()+1
        if torch.is_tensor(arocc):
            arocc = arocc.detach().cpu().numpy()
        return (tkaccs[0], arocc) if len(tkaccs) == 1 else (tkaccs, arocc)
    if ret_mrocc:
        mrocc = greater.median()+1
        if torch.is_tensor(mrocc):
            mrocc = mrocc.float().detach().cpu().numpy()
        return (tkaccs[0], mrocc) if len(tkaccs) == 1 else (tkaccs, mrocc)
    
    return tkaccs[0] if len(tkaccs) == 1 else tkaccs

In [None]:
def train_and_test(
        InMemory: InMemoryClamp,
        train_idx: np.ndarray,
        valid_idx: np.ndarray,
        test_idx: np.ndarray,
        hparams: dict,
        run_info: mlflow.entities.RunInfo,
        checkpoint_file: Optional[Path] = None,
        keep: bool = True,
        device: str = 'cpu',
        bf16: bool = False, #?Yu: when to set bf16=True?
        verbose: bool = True
) -> None:
    """
    Train a model on `InMemory[train_idx]` while validating on `InMemory[valid_idx]`.
    Once the training is finished, evaluate the model on `InMemory[test_idx]`.

    A moodel-optimizer PyTorch checkpoint can be passed to resume training.

    Params:
    -------
    InMemory: :class:`dataset.InMemoryClamp`
         Dataset instance.
    train_idx: :class:`numpy.ndarray`
        Activity indices of the training split.
    valid_idx: :class:`numpy.ndarray`
        Activity indices of the validation split.
    test_idx: :class:`numpy.ndarray`
        Activity indices of the test split.
    hparams: dict
        Model characteristics and training strategy.
    run_info: :class:`mlflow.entities.RunInfo`
        MLflow's run details (for logging purposes).
    checkpoint_file: str or :class:`pathlib.Path`
        Path to a model-optimizer checkpoint from which to resume training.
    keep: bool
        Keep the persisted model weights if True, remove them otherwise.
    device: str
        Device to use for training (e.g., "cpu" or "cuda").
    verbose: bool
        Print verbose messages if True.
    """

    if verbose:
        if checkpoint_file is None:
            message = 'Strat training.'
        else:
            message = f'Resume training from {checkpoint_file}.'
        logger.info(message)


    # ================================= Function signature and parameters =================================
    # initialize checkpoint, if any; if no checkpoint is given, an empty dict is returned.
    checkpoint = init_checkpoint(checkpoint_file, device)
    # get paths to the artifacts directory and the model weights.
    artifacts_dir, checkpoint_file_path, metrics_file_path = get_log_paths(run_info)
    early_stopping = EarlyStopper(patience=hparams['patience'], min_delta=0.0001)

    # ================================= Model initialization =================================
    print(hparams)
    
    #?Yu: Regard different assays or targets as different tasks. Keep the below `Multitask` related code if used later, otherwise remove it.
    #?Yu: why `setup_assay_onehot` is used here?`
    if 'Multitask' in hparams.get('model'):

        _, train_assays = InMemory.get_unique_names(train_idx)
        InMemory.setup_assay_onehot(size=train_assays.index.max() + 1)
        train_assay_features = InMemory.assay_features[:train_assays.index.max() + 1] #?Yu: no `assay_features` defined neither in the primary code or 'InMemoryClamp` before.
        train_assay_features_norm = F.normalize(torch.from_numpy(train_assay_features), #?Yu: why set this here but use it quite later?
            p=2, dim=1 #Yu: p=2: the exponent value in the norm formulation; dim=1: the dimension to reduce.
        ).to(device)

        model = init_model(
            compound_features_size=InMemory.compound_features_size,
            assay_feature_size=InMemory.assay_onehot.size, #?Yu: `assay_onehot` has not been defined in the `InMemoryClamp` class before.
            hp=hparams,
            verbose=verbose
        )
    
    else:
        model = init_model(
            compound_features_size=InMemory.compound_features_size,
            assay_features_size=InMemory.assay_features_size,
            hp=hparams,
            verbose=verbose
        )
    
    #Yu: 'model_state_dict' is only available in class `Pretrained(DotProduct)`
    #Yu: if not used, remove the below code
    if 'model_state_dict' in checkpoint:
        if verbose:
            logger.info('Load model_state_dict from checkpoint into model.')
        model.load_state_dict(checkpoint['model_state_dict']) # `load_state_dict` is the attribute for `nn.Module``
        model.train() # `train` is a method of `nn.Module` that sets the module in training mode.
    
    model = model.to(device)
    # ================================= Optimizer and loss function initialization =================================
    # initialize optimizer
    # Moving a model to the GPU should be done before the creation of its optimizer.
    optimizer = init_optimizer(model, hparams, verbose)

    if 'optimizer_state_dict' in checkpoint:
        if verbose:
            logger.infp('Load optimizer_state_dict from checkpoint into optimizer.')
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # initialize loss function #Yu: the core of clamp.
    criterion = nn.BCEWithLogitsLoss() # default, allowing `loss_fun` to be optional.
    if 'loss_fun' in hparams:
        class CustomCE(nn.CrossEntropyLoss):
            """Cross entropy loss #?Yu"""
            def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                """
                param
                -------
                input: predicted unnormalized logits. This is the raw output (logits, i.e. preactivations, no softmax/sigmoid) from the model, typically of shape [batch_size, batch_size] in contrastive/self-supervised settings.
                target: ground truth class indices or class probabilities.

                return
                -------
                for `F.cross_entropy`:
                weight: a manual rescaling weight given to each class.
                """
                beta = 1/(input.shape[0]**(1/2)) # scaling factor, normalizes the logits so that their magnitude is independent of batch size, which can help stabilize training.
                input = input * (target*2-1) * beta # 
                target = torch.arange(0, len(input)).to(input.device)

                return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
            
        class ConLoss(nn.CrossEntropyLoss):
            """Contrastive Loss"""
            def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                sigma = 1 # scaling factor. set to 1 here, so it does not affect the result.
                bs = target.shape[0] #?Yu
                #Yu: remove the below code if not used
                #only modify diag that is a negative
                # eg makes this from a target of [0, 1, 0]
                # tensor([[-1., 1., 1.],
                #         [1., 1., 1.],
                #         [1., 1., -1.]])
                modif = (1-torch.eye(bs)).to(target.device) + (torch.eye(bs).to(target.device)*(target*2-1)) # `torch.eye`: returns a 2-D tensor with ones on the diagonal and zeros elsewhere.`bs` is the number of rows.
                input = input*modif/sigma
                diag_idx = torch.arange(0, len(input)).to(input.device)

                label_smoothing = hparams.get('label_smoothing', 0.0)
                if label_smoothing is None:
                    label_smoothing = 0.0
                
                mol2txt = F.cross_entropy(input, diag_idx, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction, label_smoothing=label_smoothing)
                text2mol = F.cross_entropy(input.T, diag_idx, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction, label_smoothing=label_smoothing)

                return mol2txt + text2mol
            
        str2loss_fun = {
            'BCE': nn.BCEWithLogitsLoss(),
            'CE': CustomCE(),
            'Con': ConLoss(),
        }
        assert hparams['loss_fun'] in str2loss_fun, "loss_fun not implemented"
        criterion = str2loss_fun[hparams['loss_fun']]

    criterion = criterion.to(device)
        
    # ================================= Learning rate Scheduler =================================
    # lambda function below returns `lr_factor` whatever the input to lambda is.
    if 'lr_factor' in hparams:
        lr_factor = hparams['lr_factor']
    else:
        lr_factor = 1 #?Yu: why
    scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda _: lr_factor)

    #lot_lr_scheduler(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=1000, eta_min=0)
    num_steps_per_epoch = len(train_idx)/hparams['batch_size']
    class Linwarmup():
        def __init__(self, steps=10000):
            self.step = 0
            self.max_step = steps
            self.step_size = 1/steps
        def get_lr(self, lr):
            if self.step>self.max_step:
                return 1
            new_lr = self.step * self.step_size
            self.step += 1
            return new_lr
        
    #Todo Bug when set to 0
    if hparams.get('warmup_step'): #?Yu: why not `if hparams['warmup_step']`?
        scheduler2 = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                       lr_lambda=Linwarmup(steps=num_steps_per_epoch + hparams.get('warmup_epochs', 0)).get_lr)
    else:
        scheduler2 = None
    
    if lr_factor !=1:
        logger.info(f'Scheduler enabled with lr_factor={hparams["lr_factor"]}. Note that this makes different runs difficult to compare.')
    else:
        logger.info('Scheduler enabled with lr_factor=1. This keeps the interface but results in no reduction.')
    # ================================= Batch sampler =================================
    # The acutual dataset slicing is actually done manually.
    train_sampler = RandomSampler(data_source=train_idx) #?Yu: why not implement train_batcher like 'valid_batcher' and 'test_batcher'?

    valid_sampler = SequentialSampler(data_source=valid_idx)
    valid_batcher = BatchSampler(
        sampler=valid_sampler,
        batch_size=hparams['batch_size'],
        drop_last=False
    )

    test_sampler = SequentialSampler(data_source=test_idx)
    test_batcher = BatchSampler(
        sampler=test_sampler,
        batch_size=hparams['batch_size'],
        drop_last=False
    )

    epoch = checkpoint.get('epoch', 0) #?Yu: what if not checkpoint?
    new_train_idx = None #?Yu: why
    while epoch < checkpoint.get('epoch', 0) + hparams['epoch_max']:
        if hparams.get('train_balanced', False):
            logger.info('sampling balanced')
            num_pos = InMemory.activity.data[train_idx].sum()
            #?Yu_ remove the below comment later
            # too large with WeightedRandomSampler
            # num_neg = (len(train_idx))-num_pos
            remove_those = train_idx[((InMemory.activity.data[train_idx]) == 0)]
            remove_those = np.random.choice(remove_those, size=int(len(remove_those)-num_pos)) #?Yu
            idx = np.in1d(train_idx, remove_those) #?Yu
            new_train_idx = train_idx[~idx] #?Yu
            if isinstance(hparams['train_balanced'], int):
                max_samples_per_epoch = hparams['train_balanced']
                if max_samples_per_epoch > 1:
                    logger.info(f'using only {max_samples_per_epoch} for one epoch')
                    new_train_idx = np.random.choice(new_train_idx, size=max_samples_per_epoch)
            train_sampler = RandomSampler(data_source=new_train_idx)
        if hparams.get('train_subsample', 0) > 0: #?Yu: why not `elif`
            if hparams['train_subsample']<1:
                logger.info(f'subsample training set to {hparams["train_subsample"]*100}%')
                hparams['train_subsample'] = int(hparams['train_subsample']*len(train_idx))
            logger.info(f'subsample training set to {hparams["train_subsample"]}')
            sub_train_idx = np.random.choice(train_idx if new_train_idx is None else new_train_idx, size=int(hparams['train_subsample']))
            train_sampler = RandomSampler(data_source=sub_train_idx)
        
        train_batcher = BatchSampler(
            sampler=train_sampler, 
            batch_size=hparams['batch_size'],
            drop_last=False
        )
 
        # ================================= Training loop =================================
        loss_sum = 0.
        preactivations_l = []
        topk_l, arocc_l = [], []
        activity_idx_l = []
        for nb, batch_indices in enumerate(train_batcher): #tqdm(, mininterval=2) #?Yu: what is nb?

            # get and unpack batch data
            batch_data = Subset(InMemory, indices=train_idx)[batch_indices]
            activity_idx, compound_features, assay_features, assay_onehot, activity = batch_data #?Yu: what is no assay_onehot?

            # move data to device
            #?Yu: remove the below comments if not used
            # assignment is not necessary for modules but it is for tensors.
            # https://discuss.pytorch.org/t/what-is-the-difference-between-doing-net-cuda-vs-net-to-device/69278/8
            if isinstance(compound_features, torch.Tensor):
                compound_features = compound_features.to(device)
            assay_features = assay_features.to(device) if not isinstance(assay_features[0], str) else assay_features #?Yu: why not using same method that is used for `compound_features`?
            assay_onehot = assay_onehot.to(device).float() if not isinstance(assay_onehot[0], str) else assay_onehot #?Yu
            activity = activity.to(device)

            # forward
            #with torch.autocast("cuda", dtype=torch.bfloat16 if bf16 else torch.float32): #?Yu: shall I keep this comment?
            if hparams.get('loss_fun') in ('CE', 'Con'): # why in the two cases, `forward_dense` is used?
                preactivations = model.forward_dense(compound_features, #?Yu: go to check the difference between 'forward' and 'forward_dense'
                                                     assay_onehot if 'Multitask' in hparams['model'] else assay_features) #?Yu: consider whether to remove 'assay_onehot'
            else:
                preactivations = model(compound_features, 
                                       assay_onehot if 'Multitask' in hparams['model'] else assay_features)
            
            # loss
            beta = hparams.get('beta', 1)
            if beta is None: beta = 1
            preactivations = preactivations*1/beta #?Yu
            loss = criterion(preactivations, activity)

            # zero gradients, backpropagation, update
            optimizer.zero_grad()
            loss.backward()
            if hparams.get('optimizer') == 'SAM': #?Yu
                def closure():
                    preactivations = model(compound_features, assay_onehot if 'Multitask' in hparams['model'] else assay_features) # why compute preactivation again?
                    loss = criterion(preactivations, activity)
                    loss.backward()
                    return loss 
                optimizer.step(closure)
            else:
                optimizer.step()
                scheduler.step()
                if scheduler2: scheduler2.step() #?Yu

            # accumulate loss 
            loss_sum += loss.item()

            if hparams.get('loss_fun')=='Con':
                ks = [1, 5, 10, 50] #?Yu
                tkaccs, arocc = top_k_accuracy(torch.arange(0, len(preactivations)), preactivations, k=[1, 5, 10, 50], ret_arocc=True)
                topk_l.append(tkaccs)
                arocc_l.append(arocc)
            if hparams.get('loss_fun') in ('CE', 'Con'):
                #preactivations = preactivations.sum(axis=1) #?Yu: why keep it here?
                preactivations =torch.diag(preactivations) # get only diag elements

            # accumulate preactivations
            # - need to detach; preactivations.requires_grad = True
            # - move it to cpu #?Yu
            preactivations_l.append(preactivations.detach().cpu())

            # accumulate_indices to track order in which the dataset is visited
            # - activity_idx is a np.array, not a torch.tensor
            activity_idx_l.append(activity_idx)

            if nb % EVERY == 0 and verbose: #?Yu: EVERY = 50000, why set it to 50000?
                logger.info(f'Epoch{epoch}: Training batch {nb} out of {len(train_batcher) - 1}.')
        
        # log mean loss over all minibatches
        mlflow.log_metric('train_loss', loss_sum / len(train_batcher), step=epoch)
        if wandb.run:
            wandb.log({
                'train/loss': loss_sum / len(train_batcher),
                'lr': scheduler2.get_last_lr()[0] if scheduler2 else scheduler.get_last_lr()[0]
            }, step=epoch)

        # compute metrics for each assay (on the cpu)
        preactivations = torch.cat(preactivations_l, dim=0)
        probabilities = torch.sigmoid(preactivations).numpy()

        activity_idx = np.concatenate(activity_idx_l, axis=0)
        # assert np.array_equal(np.sort(activity_idx), train_idx)
        # assert not np.array_equal(activity_idx, train_idx)

        targets = sparse.csc_matrix(
            (
                InMemory.activity.data[activity_idx],
                (
                    InMemory.activity.row[activity_idx],
                    InMemory.activity.col[activity_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.bool
        )

        scores = sparse.csc_matrix(
            (
                probabilities, #?Yu: why is probabilities used here?
                (
                    InMemory.activity.row[activity_idx],
                    InMemory.activity.col[activity_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.float32
        )

        #?Yu: `metrics` should be changed according to my implementation.
        #md = metrics.swipe_threshold_sparse(targets=targets, scores=scores,verbose=verbose>=2, ret_dict=True) # returns dict for with metric per assay in the form of {metric: {assay_nr: value}} #?Yu: isn't `verbose=verbose>=2` syntax error?
        md = swipe_threshold_sparse(targets=targets, scores=scores,verbose=verbose>=2, ret_dict=True)

        if hparams.get('loss_fun') == 'Con':
            for ii, k in enumerate(ks):#?Yu: `ks` is not defined in this loop.
                md[f'top_{k}_acc'] = {0:np.vstack(topk_l)[:-1, ii].mean()} # drop last (might be not full) #?Yu why
            md['arocc'] = {0:np.hstack(arocc_l)[:-1].mean()} # drop last (might be not full) #?Yu why

        logdic = {f'train_mean_{k}': np.nanmean(list(v.values())) for k,v in md.items() if v}
        mlflow.log_metrics(logdic, step=epoch) #?Yu: sort the use of mlflow along the code.
        if wandb.run: wandb.log({k.replace('_', '/'):v for k, v in logdic.items()}, step=epoch)
        # if verbose: logger.info(logdic) #?Yu: why not print the logdic？

        # ================================= Validation loop =================================
        with torch.no_grad():
            
            model.eval()

            loss_sum = 0.
            preactivations_l = [] #?Yu: will this overwrite the ones in the training loop?
            activity_idx_l = []
            for nb, batch_indices in enumerate(valid_batcher):

                # get and unpack batch data
                batch_data = Subset(InMemory, indices=valid_idx)[batch_indices]
                activity_idx, compound_features, assay_features, _, activity = batch_data #?Yu why isn't here `assay_onehot`?

                # move data to device
                # assignment is not necessary for modules but it is for tensors.
                # https://discuss.pytorch.org/t/what-is-the-difference-between-doing-net-cuda-vs-net-to-device/69278/8
                if isinstance(compound_features, torch.Tensor):
                    compound_features = compound_features.to(device)
                assay_features = assay_features.to(device) if not isinstance(assay_features[0], str) else assay_features # why is the conditions different between 'assay_features' and 'compound_features'.
                activity = activity.to(device)

                # forward #?Yu: why the `Multitask` related code is here but not in the training loop?`
                if 'Multitask' in hparams['model']:
                    assay_features_norm = F.normalize(
                        assay_features, p=2, dim=1
                    )
                    sim_to_train = assay_features_norm @ train_assay_features_norm.T #?Yu
                    sim_to_train_weights = F.softmax(sim_to_train * hparams['multitask_temperature'], dim = 1) #?Yu: what does `F.softmax` do?
                    preactivations = model(compound_features, sim_to_train_weights)
                
                elif hparams.get('loss_fun') in ('CE', 'Con'):
                    preactivations = model.forward_dense(compound_features, assay_onehot if 'Multitask' in hparams['model'] else assay_features) #?Yu: 'assay_onehot' is not defined in the validation loop.
                else:
                    preactivations = model(compound_features, assay_features)

                # loss
                preactivations = preactivations * 1 / hparams.get('beta', 1) #?Yu
                #?Yu Why is the below code block commented out in the primary code.
                if hparams.get('loss_fun') in ('CE', 'Con'):
                    loss = F.binary_cross_entropy_with_logits(preactivations, activity)
                else:
                    loss = criterion(preactivations, activity)
                
                # accumulate loss
                loss_sum += loss.item()

                if hparams.get('loss_fun') in ('CE', 'Con'): # how to calc the below metrics if `loss_fun` is neither 'CE' nor 'Con'?
                    ks = [1, 5, 10, 50]
                    tkaccs, arocc = top_k_accuracy(torch.arange(0, len(preactivations)), preactivations, k=[1, 5, 10, 50], ret_arocc=True)
                    topk_l.append(tkaccs) # already detached numpy #?Yu
                    arocc_l.append(arocc)
                
                # accumulate preactivations
                #Yu: remove the below comments if not used
                # - need to detach; preactivations.requires_grad is True
                # - move it to cpu
                if hparams.get('loss_fun') in ('CE', 'Con'): #?Yu: combine this condition with the one above? and similarily, how to calc preactivations if `loss_fun` is neither 'CE' nor 'Con'?
                    # preactivations = preactivations.sum(axis=1)
                    preactivations = torch.diag(preactivations) #?Yu: `torch.diag`

                preactivations_l.append(preactivations.detach().cpu())

                # accumulate indices just to double check.
                # - activity_idx is a np.array, not a torch.tensor
                activity_idx_l.append(activity_idx)

                if nb % EVERY == 0 and verbose:
                    logger.info(f'Epoch{epoch}: Validation batch {nb} out of {len(valid_batcher) -1}.')
            
            # log mean loss over all minibatches
            valid_loss = loss_sum / len(valid_batcher)
            mlflow.log_metrics('valid_loss', valid_loss, step=epoch) #?Yu: sort the use of mlflow along the code.
            if wandb.run: 
                wandb.log({'valid/loss': valid_loss}, step=epoch)
            
            # compute test auroc and avgp for each assay (on the cpu)
            preactivations = torch.cat(preactivations_l, dim=0)
            probabilities = torch.sigmoid(preactivations).numpy()

            activity_idx = np.concatenate(activity_idx_l, axis=0)
            # assert np.array_equal(activity_idx, valid_idx) #?Yu: why is this line commented out in the primary code?

            targets = sparse.csc_matrix(
                (
                    InMemory.activity.data[valid_idx],
                    (
                        InMemory.activity.row[valid_idx],
                        InMemory.activity.col[valid_idx]
                    )
                ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.bool
            )

            scores = sparse.csc_matrix(
                (
                    probabilities,
                    (
                        InMemory.activity.row[valid_idx],
                        InMemory.activity.col[valid_idx]
                    )
                ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.float32
            )

            #md = metrics.swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)
            md = swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)

            if hparams.get('loss_fun') == 'Con': #?Yu: what if 'loss_fun' is 'Con'
                #?Yu: how about other metrics?
                for ii, k in enumerate(ks): #?Yu: `ks` is not defined in this loop.
                    md[f'top_{k}_acc'] = {0:np.vstack(topk_l)[:-1, ii].mean()} # drop last (might be not full)
                md['arocc'] = {0:np.hstack(arocc_l)[:-1].mean()} # drop last (might be not full)
            
            # log metrics mean over assays #?Yu: modify here to calc metrics on OR datasets.

            logdic = {f'valid_mean_{k}': np.nanmean(list(v.values())) for k,v in md.items() if v} #?Yu: what is logdic?
            logdic['valid_loss'] =valid_loss

            mlflow.log_metrics(logdic, step=epoch)

            if wandb.run: wandb.log({k.replace('_', '/'):v for k, v in logdic.items()}, step=epoch)
            # if verbose: logger.info(logdic)

            # monitor metric
            evaluation_metric = 'valid_mean_davgp'

            if evaluation_metric not in logdic:
                logger.info('Using -valid_loss because valid_mean_avgp not in logdic.')
            log_value = logdic.get(evaluation_metric, -valid_loss) #?Yu: why -valid_loss?
            # metric_monitor(logdic['valid_mean_davgp'], epoch)
            do_early_stop = early_stopping(-log_value) # smaller is better #?Yu: why -log_value?

            # log model checkpoint dir
            if wandb.run:
                wandb.run.config.update({'model_save_dir':checkpoint_file_path})
            
            if early_stopping.improved:
                logger.info(f'Epoch {epoch}: Save model and optimizer checkpoint with val-davgp: {log_value}.')
                torch.save({
                    'value': log_value,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, checkpoint_file_path)
            

            if do_early_stop: #?Yu is this set by me or during the training loop?
                logger.info(f'Epoch {epoch}: Out of patience. Early stop!')
                break

            model.train() #?Yu: how this code line connect with the code above.
        
        epoch +=1
    
    # ================================= Testing loop =================================
    # test with best model
    with torch.no_grad():
        
        epoch -= 1 #?Yu how can this be the best model
        logger.info(f'Epoch {epoch}: Restore model from checkpoint.')
        # check if checkpoint exists
        if not os.path.exists(checkpoint_file_path):
            logger.warning(f'Checkpoint file {checkpoint_file_path} does not exist. Test with init model.')
        else:
            checkpoint = torch.load(checkpoint_file_path)
            model.load_state_dict(checkpoint['model_state_dict'])

        model.eval()

        loss_sum = 0.
        preactivations_l = []
        activity_idx_l = []
        for nb, batch_indices in enumerate(test_batcher):

            # get and unpack batch data
            batch_data = Subset(InMemory, indices=test_idx)[batch_indices]
            activity_idx, compound_features, assay_features, _, activity = batch_data #?Yu: why isn't here `assay_onehot`?

            # move data to device
            # assignment is not necessary for modules but it is for tensors.
            # https://discuss.pytorch.org/t/what-is-the-difference-between-doing-net-cuda-vs-net-to-device/69278/8
            if isinstance(compound_features, torch.Tensor):
                compound_features = compound_features.to(device)
            assay_features = assay_features.to(device) if not isinstance(assay_features[0], str) else assay_features
            activity = activity.to(device)

            # forward
            if 'Multitask' in hparams['model']:
                assay_features_norm = F.normalize(assay_features, p=2, dim=1) #
                sim_to_train = assay_features_norm @ train_assay_features_norm.T
                sim_to_train_weights = F.softmax(sim_to_train * hparams['multitask_temperature '], dim=1) #?Yu: what is `sim_to_train_weights`?
                preactivations = model(compound_features, sim_to_train_weights)
            else:
                preactivations = model(compound_features, assay_features) #?Yu: why not `assay_onehot`?
            
            # loss
            #?Yu: why the below code block is commented out in the primary code.
            #if hparams.get('loss_fun') in ('CE', 'Con'):
            #    loss = F.binary_cross_entropy_with_logits(preactivations, activity)
            #else:
            loss = criterion(preactivations, activity)

            # accumulate loss
            loss_sum += loss.item()

            # accumulate preactivations
            # - need to detach; preactivations.requires_grad is True
            # - move it to cpu
            preactivations_l.append(preactivations)

            # accumulate indices just to double check.
            # - activity_idx is a np.array, not a torch.tensor
            activity_idx_l.append(activity_idx) #?Yu: why not this code line closely after the definition of `activity_idx`.

            if nb % EVERY == 0 and verbose:
                logger.info(f'Epoch{epoch}: Test batch {nb} out of {len(test_batcher) -1}.')
        
        # log mean loss over all minibatches
        mlflow.log_metric('test_loss', loss_sum / len(test_batcher), step=epoch)
        if wandb.run: wandb.log({'test/loss': loss_sum / len(test_batcher)})

        # compute test auroc and avgp for each assay (on the cpu) 'WHY??? #?Yu: 'WHY' is in the primary code.
        preactivations = torch.cat(preactivations_l, dim=0)
        probabilities = torch.sigmoid(preactivations) #?Yu: figure out `sigmoid` function, why use it here.

        activity_idx = np.concatenate(activity_idx_l, axis=0)
        # assert np.array_equal(activity_idx, test_idx) #Todo WHY??? #?Yu: 'WHY' is in the primary code.

        probabilities = probabilities.detach().cpu().numpy().astype(np.float32)

        targets = sparse.csc_matrix(
            (
                InMemory.activity.data[test_idx],
                (
                    InMemory.activity.row[test_idx],
                    InMemory.activity.col[test_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.bool
        )

        scores = sparse.csc_matrix(
            (
                probabilities,
                (
                    InMemory.activity.row[test_idx],
                    InMemory.activity.col[test_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.float32
        )

        #md  = metrics.swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)
        md  = swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)

        if hparams.get('loss_fun') == 'Con':
            for ii, k in enumerate(ks): # what is `ii`
                md[f'top_{k}_acc'] = {0:np.vstack(topk_l)[:-1, ii].mean()} # drop last (might be not full)
            md['arocc'] = {0:np.hstack(arocc_l)[:-1].mean()} # drop last (might be not full)

        # log metrics mean over assays

        logdic = {f'test_mean_{k}': np.nanmean(list(v.values())) for k,v in md.items() if v} #?Yu: why `:` is not in f''
        mlflow.log_metrics(logdic, step=epoch)

        if wandb.run: wandb.log({k.replace('_','/'):v for k, v in logdic.items()}, step=epoch)
        if verbose:
            logger.info(pd.DataFrame.from_dict([logdic]).T) #?Yu: print a dataframe?
        
        # compute test activity counts and positives
        counts, positives = {}, {}
        for idx, col in enumerate(targets.T):
            if col.nnz == 0:
                continue
            counts[idx] = col.nnz
            positives[idx] = col.sum()

        # 'test_mean_bedroc': 0.6988015835969245, 'test_mean_davgp': 0.16930837444561778, 'test_mean_dneg_avgp': 0.17522445272085613, 
        # 'test/mean/auroc': 0.6709850363704437, 'test/mean/avgp': 0.6411171492554743, 'test/mean/neg/avgp': 0.7034156779109996, 
        # 'test/mean/argmax/j': 0.4308185
        # store test metrics and counts in a parquet file
        metrics_df = pd.DataFrame(md)
        metrics_df['argmax_j'] = metrics_df['argmax_j'].apply(sigmoid)
        #?Yu: why is the below code commented out in the primary code.
        # metrics_df['counts'] = counts # for PC_large: ValueError: Length of values (3933) does not match length of index (615)
        # metrics_df['positives'] = positves

        metrics_df.index.rename('assay_idx', inplace=True)

        metrics_df = InMemory.assay_names.merge(metrics_df, left_index=True, right_index=True)
        logger.info(f'Writing test metrics to {metrics_file_path}')
        metrics_df.to_parquet(metrics_file_path, compression=None, index=True)

        with pd.option_context('float_format', "{:.2f}".format):
            print(metrics_df)
            print(metrics_df.mean(0, numeric_only=True))

        model.train()
    
    if not keep: #?Yu: what is keep
        logger.info('Delete model checkpoint.')
        checkpoint_file_path.unlink() #unlink_ remove file or link.

In [3]:
# Yu: comprehend ``for nb, batch_indices in enumerate(train_batcher)`

train_batcher = [[0, 1, 2], 
                 [3, 4, 5], 
                 [6, 7, 8], 
                 [9]]
for nb, batch_indices in enumerate(train_batcher):
    print(f'Batch {nb}: {batch_indices}')

print(f'The length of train_batcher: {len(train_batcher)}')

Batch 0: [0, 1, 2]
Batch 1: [3, 4, 5]
Batch 2: [6, 7, 8]
Batch 3: [9]
The length of train_batcher: 4


In [None]:
targets = [[True, False, False, False, False],
 [False,  True, False, False, False],
 [False,  True, False, False, False],
 [False, False, False, False, False],
 [False, False, False, False, False]]

# convert targets to a sparse matrix
from scipy import sparse
import numpy as np

targets = sparse.csc_matrix(targets, dtype=np.bool_)

In [3]:
scores = [[1.,        0.,        0.,        0.,        0.       ],
 [0.,        1.,        0.,        0.,        0.       ],
 [0.,        0.9999994, 0.,        0.,        0.       ],
 [0.,        1.,        0.,        0.,        0.       ],
 [0.,        1.,        0.,        0.,        0.       ]]

scores = sparse.csc_matrix(scores, dtype=np.float32)

In [12]:
bedroc_alpha = 20
md = swipe_threshold_sparse(targets=targets, scores=scores, bedroc_alpha=bedroc_alpha, ret_dict=True)
print(md)

[32m2025-08-03 11:23:51.366[0m | [1mINFO    [0m | [36m__main__[0m:[36mswipe_threshold_sparse[0m:[36m67[0m - [1mFound 0 columns with both positive and negative samples.[0m
[32m2025-08-03 11:23:51.367[0m | [1mINFO    [0m | [36m__main__[0m:[36mswipe_threshold_sparse[0m:[36m68[0m - [1mFound and skipped 2 columns with only positive or negative samples.[0m


{'argmax_j': {}, 'auroc': {}, 'avgp': {}, 'neg_avgp': {}, 'davgp': {}, 'dneg_avgp': {}, 'auprc': {}, 'bedroc': {}}


In [None]:
def test(
        InMemory: InMemoryClamp,
        train_idx: np.ndarray,
        test_idx: np.ndarray,
        hparams: dict, 
        run_info: mlflow.entities.RunInfo,
        device: str = 'cpu',
        verbose: bool = False, 
        model = None
) -> None: #?Yu: isn't `metric.df` returned?
    """
    Test a model on `InMemory[test_idx]`if test metrics are not yet to be found under the `actifacts` directory. 
    If so, interrupt the program.

    Params
    ----------
    InMemory: InMemoryClamp
        Dataset instance
    train_idx: :class:`numpy.ndarray``
        Activity indices of the training split. Only for multitask models. #?Yu: why only for multitask models?
    test_idx: :class:`numpy.ndarray``
        Activity indices of the test split.
    run_info: class:`mlflow.entities.RunInfo`
        MLflow's run details (for logging purposes).
    device: str
        Computing device.
    verbose: bool
        Be verbose if True.
    """

    if verbose:
        logger.info('Start evaluation.')

    artifacts_dir = Path('mlruns', run_info.experiment_id, run_info.run_id, 'artifacts')

    # for logging new checkpoints
    checkpoint_file_path = artifacts_dir / 'checkpoint.pt'
    metrics_file_path = artifacts_dir / 'metrics.parquet'

    # initialize checkpoint
    if model != None:
        checkpoint = init_checkpoint(checkpoint_file_path, device)
        assert checkpoint, 'No checkpoint found'
        assert 'model_state_dict' in checkpoint, 'No model found in checkpoint' #?Yu 'model_state_dict' in checkpoint? how this attribute be gotten?

    artifacts_dir, checkpoint_file_path, metrics_file_path = get_log_paths(run_info)

    # initialize model
    if 'Multitask' in hparams['model']:
        _, train_assays = InMemory.get_unique_names(train_idx)
        InMemory.setp_assay_onehot(size=train_assays.index.max() + 1)
        train_assay_features = InMemory.assay_features[:train_assays.index.max() + 1]
        train_assay_features_norm = F.normalize(
            torch.from_numpy(train_assay_features), p=2, dim=1
        ).to(device)

        if model != None:
            model = init_model(
                compound_feature_size= InMemory.compound_feature_size,
                assay_feature_size= InMemory.assay_onehot.size,
                hp=hparams, #?Yu: how can the hparams that get the best model be used here?
                verbose=verbose
            )
    else:
        if model != None:
            model = init_model(
                compound_feature_size=InMemory.compound_feature_size,
                assay_feature_size=InMemory.assay_features.size,
                hp=hparams, verbose=verbose
            )
        
    if verbose:
        logger.info('Load model from checkpoint.')
    if model != None:
        model.load_state_dict(checkpoint['model_state_dict'])

    # assignment is not necessary when moving modules, but it is for tensors.
    # https://discuss.pytorch.org/t/what-is-the-difference-between-doing-net-cuda-vs-net-to-device/69278/8
    # here I only assign for consistency
    model = model.to(device)

    # initialize loss function
    criterion = nn.BCEWithLogitsLoss # why is it enough to use this function instead of `CustomCE` or `ConLoss` during testing?
    criterion = criterion.to(device)

    test_sampler = SequentialSampler(data_source=test_idx)
    test_batcher = BatchSampler(
        sampler=test_sampler,
        batch_size=hparams['batch_size'],
        drop_last=False #?Yu:
    )

    epoch = checkpoint.get('epoch', 0)
    with torch.no_grad():

        model.eval()

        loss_sum = 0.
        preactivations_l = []
        activity_idx_l = []
        for nb, batch_indices in enumerate(test_batcher):

            # get and unpack batch data
            batch_data = Subset(InMemory, indices=test_idx)[batch_indices]
            activity_idx, compound_features, assay_features, assay_onehot, activity = batch_data #?Yu: why `assay_onehot` is added here?

            # move data to device
            if isinstance(compound_features, torch.Tensor):
                compound_features = compound_features.to(device)
            assay_features = assay_features.to(device) if not isinstance(assay_features[0], str) else assay_features
            activity = activity.to(device)

            # forward
            if 'Multitask' in hparams['model']:
                assay_features_norm = F.normalize(assay_features, p=2, dim=1)
                sim_to_train = assay_features_norm @ train_assay_features_norm.T
                sim_to_train_weights = F.softmax(sim_to_train * hparams['multitask_temperature'], dim=1)
                preactivations = model(compound_features, sim_to_train_weights)
            else:
                preactivations = model(compound_features, assay_features)   
            
            # loss
            loss = criterion(preactivations, activity)

            # accumulate loss
            loss_sum += loss.item()

            # accumulate preactivations
            # - need to detach; preactivations.requires_grad is True
            # - move it to cpu
            preactivations_l.append(preactivations.detach().cpu()) #?Yu: why is `detach` used here but not in the `def train_and_test``

            # accumulate indices just to double check
            # - activity_idx is a np.array, not a torch.tensor
            activity_idx_l.append(activity_idx)

            if nb % EVERY == 0 and verbose:
                logger.info(f'Epoch {epoch}: Test batch {nb} out of {len(test_batcher) - 1}.')

        # log mean loss over all minibatches
        mlflow.log_metric('test_loss', loss_sum / len(test_batcher), step=epoch)
        if wandb.run: wandb.log({'test/loss': loss_sum/len(test_batcher)}, step=epoch)

        # compute test auroc and avgp for each assay (on the cpu)
        preactivations = torch.cat(preactivations_l, dim=0)
        probabilities = torch.sigmoid(preactivations).numpy()

        activity_idx = np.concatenate(activity_idx_l, axis=0)
        assert np.array_equal(activity_idx, test_idx) #?Yu: this code line is commented out in the `def train_and_test`

        targets = sparse.csc_matrix(
            (
                InMemory.activity.data[test_idx],
                (
                    InMemory.activity.row[test_idx],
                    InMemory.activity.col[test_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.bool
        )

        scores = sparse.csc_matrix(
            (
                probabilities,
                (
                    InMemory.activity.row[test_idx],
                    InMemory.activity.col[test_idx]
                )
            ), shape=(InMemory.num_compounds, InMemory.num_assays), dtype=np.float32
        )

        #md = metrics.swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)
        md = swipe_threshold_sparse(targets=targets, scores=scores, verbose=verbose>=2, ret_dict=True)

        # log metrics mean over assays
        logdic = {f'test_mean_{mdk}': np.mean(list(md[f'{mdk}'].values())) for mdk in md.keys()} #?Yu: why is `mdk` used here? different from the one in the `def train_and_test`
        mlflow.log_metrics(logdic, step=epoch)

        if wandb.run: wandb.log({k.replace('_', '/'):v for k, v in logdic.items()}, step=epoch)
        if verbose: logger.info(logdic)

        # compute test activity counts and positives
        counts, positives = {}, {}
        for idx, col in enumerate(targets.T):#?Yu: why `targets.T`?
            if col.nnz == 0:
                continue
            counts[idx] = col.nnz
            positives[idx] = col.sum()
        
        # store test metrics and counts in a parquet file
        metrics_df = pd.DataFrame(md)
        metrics_df['argmax_j'] = metrics_df['argmax_j'].apply(sigmoid)
        metrics_df['counts'] = counts
        metrics_df['positives'] = positives

        metrics_df.index.rename('assay_idx', inplace=True)

        metrics_df = InMemory.assay_names.merge(metrics_df, left_index=True, right_index=True)
        logger.info(f'Writing test metrics to {metrics_file_path}')
        metrics_df.to_parquet(metrics_file_path, compression=None, index=True)

        if wandb.run:
            wandb.log({"metrics_per_assay": wandb.Table(data=metrics_df)})
        
        logger.info(f'Saved best test-metrics to {metrics_file_path}')
        logger.info(f'Saved best checkpoint to {checkpoint_file_path}')

        model.train() #?Yu: why is this line here?

        with pd.option_context('float_format', "{:.2f}".format):
            print(metrics_df)
        
        return metrics_df

# train.py

In [18]:
"""example call:
python clamp/train.py \
    --dataset=./data/fsmol \
    --split=FSMOL_split \
    --assay_mode=clip \
    --compound_mode=morganc+rdkc 
"""

""" training pubchem23 without downstream datasets
python clamp/train.py \
    --dataset=./data/pubchem23/ \
    --split=time_a \
    --assay_mode=clip \
    --batch_size=8192 \
    --dropout_hidde=0.3 \ #?Yu: this parameter is not implemented in the primary code
    --drop_cidx_path=./data/pubchem23/cidx_overlap_moleculenet.npy \
    --train_subsample=10e6 \
    --wandb --experiment=pretrain
"""

' training pubchem23 without downstream datasets\npython clamp/train.py     --dataset=./data/pubchem23/     --split=time_a     --assay_mode=clip     --batch_size=8192     --dropout_hidde=0.3 \\ #?Yu: this parameter is not implemented in the primary code\n    --drop_cidx_path=./data/pubchem23/cidx_overlap_moleculenet.npy     --train_subsample=10e6     --wandb --experiment=pretrain\n'

In [19]:
def parse_args_override(override_hpjson=True): #?Yu: why set this to True?
    parser = argparse.ArgumentParser('Train and test a single run of clamp-Activity model. Overrides arguments from hyperparam-file')
    parser.add_argument('-f', type=str) #?Yu 
    parser.add_argument('--dataset', type=str, default='./data/fsmol', help='Path to a prepared dataset directory') #?Yu: parquet file or npy file or others?
    parser.add_argument('--assay_mode', type=str, default='lsa', help='Type of assay features("clip", "biobert", or "lsa")') #?Yu: why lsa is default? where is lsa implemented?#
    parser.add_argument('--compound_mode', type=str, default='morganc+rdkc', help='Type of compound features (default: morganc+rdkc)') 
    parser.add_argument('--hyperparams', type=str, default='./hparams/default.json', help='Path to hyperparameters to use in training (json, Hyperparams, or logs).')

    parser.add_argument('--checkpoint', help='Path to a model-optimizer PyTorch checkpoint from which to resume training.', metavar='')
    parser.add_argument('--experiment', type=str, default='debug', help='Name of MLflow experiment where to assign this run.', metavar='')
    parser.add_argument('--random', action='store_true', help='Forget about the specified model and run a random baseline.') #?Yu: delete it later if not used?

    #?Yu: why the arguments below are commented out in the primary code?
    parser.add_argument('--optimizer', type=str, default='AdamW', help='Optimizer to use for training (default: AdamW).')
    parser.add_argument('--l2', type=float, default=0.01, help='Weight decay to use for training (default: 0.01).')
    parser.add_argument('--loss_fun', type=str, default='BCE', help='Loss function to use for training (default: BCE).')
    parser.add_argument('--epoch_max', type=int, default=50, help='Maximum number of epochs to train for (default: 100).')
    parser.add_argument('--lr_ini', type=float, default=1e-5, help='Initial learning rate (default: 1e-5).' )

    parser.add_argument('--gpu', type=str, default='0', help='GPU number to use (default: 0).', metavar='')
    parser.add_argument('--seed', type=int, default=None, help='seed everything with provided seed, default no seed')
    
    parser.add_argument('--split', type=str, default='time_a_c', help='split-type. Default:time_a_c for time based assay and compound split. Options: time_a, time_c, random:{seed}, or column of activity.parquet triplet.') #?Yu: shall I modify these split options?
    parser.add_argument('--support_set_size', type=int, default=0, help='per task how many to add from test- as well as valid- to the train-set (Default:0, i.e. zero-shot).') #?Yu: '0' -> 0
    parser.add_argument('--train_only_actives', action='store_true', help='train only with active molecules.')
    parser.add_argument('--drop_cidx_path', type=str, default=None, help='Path to a file containing a np of cidx (NOT CIDs) to drop from the dataset.') #?Yu: a np of cidx?

    parser.add_argument('--verbose','-v', type=int, default=0, help='verbosity level (default:0)') #?Yu: what does verbosity level mean?
    parser.add_argument('--wandb', '-w', action='store_true', help='wandb logging on')
    parser.add_argument('--bf16', action='store_true', help='use bfloat16 for training') #?Yu: bfloat16?

    args, unknown = parser.parse_known_args() #?Yu: ?
    keypairs = dict([unknown[i:i+2] for i in range(0, len(unknown), 1) if unknown[i].startswith('--') and not (unknown[i+1:i+2]+["--"])[0].startswith('--')]) #?Yu: don't understand. delete it?

    hparams = get_hparams(path=args.hyperparams, mode='json', verbose=args.verbose)

    if override_hpjson:
        for k, v in NAME2FORMATTER.items():
            if (k not in args):
                default = hparams.get(k, None)
                parser.add_argument('--'+k, type=v, default=default)
                if (k in keypairs):
                    logger.info(f'{k} from hparams file will be overwritten')
        args = parser.parse_args()
    
    #?Yu: the below args are not in the 'add_argument' function, why?
    if args.nonlinearity is None:
        args.nonlinearity = 'ReLU'
    if args.compound_layer_sizes is None:
        logger.info('no compound_layer_sizes provided, setting to hidden_layers')
        args.compound_layer_sizes = args.hidden_layers
    if args.assay_layer_sizes is None:
        logger.info('no assay_layer_sizes provided, setting to hidden_layers')
        args.assay_layer_sizes =  args.hidden_layers

    return args

In [None]:
def setup_dataset(dataset='./data/fsmol', assay_mode='lsa', compound_mode='morganc+rdkc', split='split', 
                  verbose=False, support_set_size=0, drop_cidx_path=None, train_only_actives=False, **kwargs):
    """
    Setup the dataset by given a dataset-path.
    Loads an InMemoryClamp object containing:
    - split: 'split' is the column name in the activity.parquet, 'time_a_c' is in the primary code
    - support_set_size: 0, adding {support_set_size} samples from test and from valid to train (per assay/task);
    - train_only_actives: False, only uses the active compounds;
    - drop_cidx_path: None, path to a npy file containing cidx (NOT CIDs) to drop from the dataset.
    """
    dataset = Path(dataset)
    clamp_dl = InMemoryClamp(
        root=dataset,
        assay_mode=assay_mode,
        compound_mode=compound_mode,
        verbose=verbose,
    )

    # ===== split =====
    logger.info(f'loading split info from activity.parquet triplet-list under the column split={split}')
    try:
        splits = pd.read_parquet(dataset/'activity.parquet')[split]
    except KeyError:
        raise ValueError(f'no split column {split} in activity.parquet', pd.read_parquet(dataset/'activity.parquet').columns, 'columns available')
    train_idx, valid_idx, test_idx =[splits[splits==sp].index.values for sp in ['train', 'valid', 'test']]

    # ===== support_set_size =====

    # ===== train_only_actives =====

    # ===== drop_cidx_path =====

    # ===== verbose =====

    return clamp_dl, train_idx, valid_idx, test_idx

In [None]:
def main(args):
    # Hyperparameter Preparation
    hparams = args.__dict__

    # MLflow Experiment Setup
    mlflow.set_experiment(args.experiment)

    # Seeding (Optional)
    if args.seed:
        seed_everything(args.seed)
        logger.info(f'seeded everything with seed {args.seed}') #?Yu: if not needed, delete it?
    
    # Dataset Preparation
    clamp_dl, train_idx, valid_idx, test_idx = setup_dataset(**args.__dict__)
    assert set(train_idx).intersection(set(valid_idx)) == set() # assert no overlap between the splits.
    assert set(train_idx).intersection(set(test_idx)) == set()

    # Weights & Biases (wandb) Logging
    if args.wandb:
        runname = args.experiment+args.split[-1]+args.assay_mode[-1]
        if args.random:
            runname += 'random'
        else:
            runname = str(runname)+''
            runname += str(args.model) #?Yu: what could `args.model` be?
        runname += ''.join([chr(random.randrange(97, 97 +26)) for _ in range(3)]) # to add some randomness to the run name
        wandb.init(project='clipGPCR', entity='yu', name=runname, config=args.__dict__)

    # Device Setup
    device = set_device(gpu=args.gpu, verbose=args.verbose)

    # Metrics DataFrame Initialization
    metrics_df = pd.DataFrame()

    # Training/Testing Run (with MLflow Logging)
    try:
        with mlflow.start_run(): # begins a new experiment run.
            mlflowi = mlflow.active_run().info # provides metadata (like run id, experiment id) for this run.
        
        # Checkpoint Resume Logging
        if args.checkpoint is not None:
            mlflow.set_tag(
                'mlflow.note.content',
                f'Resumed training from {args.checkpoint}.'
            )
        
        # Assay Mode Consistency and Logging
        if 'assay_mode' in hparams:
            if hparams['assay_mode'] != args.assay_mode:
                # Warn if there's a mismatch.
                logger.warning(f'Assay features are "{args.assay_mode}" in command line but \"{hparams["assay_mode"]}\" in hyperparameter file.')
                logger.warning(f'Command line "{args.assay_mode}" is the prevailing option.')
                hparams['assay_mode'] = args.assay_mode
        else:
            mlflow.log_param('assay_mode', args.assay_mode)
        mlflow.log_params(hparams) # Logs all hyperparamters to MLflow for easy reference and reproducibility.

        # Comment out the below code block because the random baseline is seemed unnecessary for my current plan.
        #if args.random:
        #    mlflow.set_tag(
        #        'mlflow.note.content',
        #        'Ignore the displayed parameters. Metrics correspond to predictions randomly drawn from U(0, 1).'
        #    )
        #    utils.random(
        #        clamp_dl,
        #        test_idx=test_idx,
        #        run_info=mlflowi,
        #        verbose=args.verbose)
        #else:
        #metrics_df = utils.train_and_test(
        metrics_df = train_and_test(
            clamp_dl, 
            train_idx=train_idx,
            valid_idx=valid_idx,
            test_idx=test_idx,
            hparams=hparams,
            run_info=mlflowi,
            checkpoint_file=args.checkpoint,
            device=device,
            bf16=args.bf16,
            verbose=args.verbose)
    
    except KeyboardInterrupt:
        logger.error('Training manually interrupted. Trying to test with last checkpoint.')
        # If the training is manually interrupted, it still tries to evaluate (test) the model using the last checkpoint, and logs results to the same MLflow run.
        #?Yu: delete the below code if not used.
        #metrics_df = utils.test(
        metrics_df = test(
            clamp_dl,
            train_idx=train_idx,
            test_idx=test_idx,
            hparams=hparams,
            run_info=mlflowi,
            device=device,
            verbose=args.verbose,
        )

In [None]:
if __name__ == '__main__':
    args = parse_args_override()

    run_id = str(time()).split('.')[0]
    fn_postfix = f'{args.experiment}_{run_id}'

    if args.verbose>=1:
        logger.info('Run args:', os.getcwd()+__file__, args.__dict__)

    main(args)