In [2]:
import numpy as np 
import torch
import sys
import pandas as pd
from tqdm import tqdm
import pickle as pkl 
from IMPA.solver import IMPAmodule
from IMPA.dataset.data_loader import CellDataLoader
from omegaconf import OmegaConf
import yaml
from pathlib import Path
import sklearn

sys.path.append('/home/icb/alessandro.palma/environment/IMPA/IMPA/IMPA/eval/gan_metrics')
sys.path.insert(0, '/home/icb/alessandro.palma/environment/IMPA/IMPA/IMPA')

from fid import *
from density_and_coverage import compute_d_c


sys.path.insert(0, '/lustre/groups/ml01/workspace/alessandro.palma/imCPA_official/experiments/general_experiments/1.benchmark_scores')
sys.path.insert(0, '/lustre/groups/ml01/workspace/alessandro.palma/imCPA_official/experiments/general_experiments/5.interpretability')
from compute_scores import *
from util_functions import CustomTransform

In [3]:
def initialize_model(yaml_config, dest_dir):
    args_rdkit = OmegaConf.create(yaml_config)
    dataloader_rdkit = CellDataLoader(args_rdkit)
    solver_rdkit = IMPAmodule(args_rdkit, dest_dir, dataloader_rdkit)
    return solver_rdkit

class Args:
    def __init__(self, dictionary):
        self.__dict__ = dictionary

    def __getattr__(self, key):
        if key in self.__dict__:
            return self.__dict__[key]
        else:
            raise AttributeError(f"'DictToObject' object has no attribute '{key}'")

    def __call__(self, key):
        return self.__getattr__(key)

In [97]:
bbbbc021_embeddings = pd.read_csv("/home/icb/alessandro.palma/environment/IMPA/IMPA/embeddings/csv/emb_fp_all.csv", index_col=0)

bbbc021_index = pd.read_csv("/lustre/groups/ml01/workspace/alessandro.palma/imCPA_official/data/bbbc021_unannotated/processed/bbbc021_unannotated_large/metadata/bbbc021_unannotated_large_subset.csv",
                           index_col=0)

In [101]:
path_to_configs = Path("/home/icb/alessandro.palma/environment/IMPA/IMPA/config_hydra/config")

with open(path_to_configs / 'REBUTTAL_bbbc021_large_all.yaml', 'r') as IMPA_bbbc021:
    # Load YAML data using safe_load() from the file
    yaml_IMPA_bbbc021 = yaml.safe_load(IMPA_bbbc021)

yaml_IMPA_bbbc021['data_index_path'] = '/lustre/groups/ml01/workspace/alessandro.palma/imCPA_official/data/bbbc021_unannotated/processed/bbbc021_unannotated_large/metadata/bbbc021_unannotated_large_subset.csv'
yaml_IMPA_bbbc021['ood_set'] = None
dest_dir = "/home/icb/alessandro.palma/environment/IMPA/IMPA/project_folder/experiments/20240825_783bda05-f711-4c90-8a4f-dc04e89cbd92_bbbc021_unannotated_large"

args = OmegaConf.create(yaml_IMPA_bbbc021)
dataloader = CellDataLoader(args)

In [102]:
# INITIALIZE CLASSIFIER (OF MOA)
classifier = Discriminator(img_size=96, 
                    num_domains=13, 
                    max_conv_dim=512, 
                    in_channels=3, 
                    dim_in=64, 
                    multi_task=False).to("cuda")  
classifier.load_state_dict(torch.load("/lustre/groups/ml01/workspace/alessandro.palma/imCPA_official/experiments/general_experiments/7.train_classifier/checkpoints/larger_fov/checkpoints_all_drugs.ckpt"))
classifier.eval()

Discriminator(
  (conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResBlk(
      (actv): LeakyReLU(negative_slope=0.2)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1x1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (2): ResBlk(
      (actv): LeakyReLU(negative_slope=0.2)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1x1): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (3): ResBlk(
      (actv): LeakyReLU(negative_slope=0.2)
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1x1): Conv2d(256, 512, kern

In [123]:
y_real = []
y_pred = []

for batch in dataloader.val_dataloader():
    with torch.no_grad():
        y_moa = batch["y_id"]
        y_hat = classifier(batch["X"][1].cuda(), None).argmax(1)
        y_real += y_moa.tolist()
        y_pred += y_hat.tolist()

In [127]:
np.sum(np.array(y_real)==np.array(y_pred))/len(y_pred)

0.8068849706129303

In [135]:
dataloader.y2id

{'Actin disruptors': 0,
 'Aurora kinase inhibitors': 1,
 'Cholesterol-lowering': 2,
 'DMSO': 3,
 'DNA damage': 4,
 'DNA replication': 5,
 'Eg5 inhibitors': 6,
 'Epithelial': 7,
 'Kinase inhibitors': 8,
 'Microtubule destabilizers': 9,
 'Microtubule stabilizers': 10,
 'Protein degradation': 11,
 'Protein synthesis': 12}

In [152]:
pd.DataFrame(sklearn.metrics.classification_report(y_real, y_pred, output_dict=True)).transpose()

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support
0,0.948276,0.808824,0.873016,136.0
1,0.964286,0.84375,0.9,128.0
2,0.925926,0.531915,0.675676,94.0
3,0.0,0.0,0.0,0.0
4,0.931507,0.68,0.786127,200.0
5,0.834146,0.830097,0.832117,206.0
6,0.742857,0.604651,0.666667,43.0
7,0.666667,0.25,0.363636,24.0
8,0.942029,0.643564,0.764706,101.0
9,0.839806,0.812207,0.825776,213.0


In [159]:
{key:val for key,val in dict(zip(bbbc021_index.CPD_NAME, bbbc021_index.ANNOT)).items()}

{'ALLN': 'Protein degradation',
 'alsterpaullone': 'Kinase inhibitors',
 'anisomycin': 'Protein synthesis',
 'bryostatin': 'Kinase inhibitors',
 'camptothecin': 'DNA replication',
 'chlorambucil': 'DNA damage',
 'cisplatin': 'DNA damage',
 'colchicine': 'Microtubule destabilizers',
 'cyclohexamide': 'Protein synthesis',
 'cytochalasin B': 'Actin disruptors',
 'cytochalasin D': 'Actin disruptors',
 'demecolcine': 'Microtubule destabilizers',
 'docetaxel': 'Microtubule stabilizers',
 'emetine': 'Protein synthesis',
 'epothilone B': 'Microtubule stabilizers',
 'etoposide': 'DNA damage',
 'floxuridine': 'DNA replication',
 'lactacystin': 'Protein degradation',
 'latrunculin B': 'Actin disruptors',
 'methotrexate': 'DNA replication',
 'mevinolin/lovastatin': 'Cholesterol-lowering',
 'MG-132': 'Protein degradation',
 'mitomycin C': 'DNA damage',
 'mitoxantrone': 'DNA replication',
 'nocodazole': 'Microtubule destabilizers',
 'PD-169316': 'Kinase inhibitors',
 'PP-2': 'Epithelial',
 'proteaso

* ALLN --> Protein degradation 0.74
* Bryostatin --> Kinase inhibitors 0.76
* MG-132 --> Protein degradation 0.74
* methotrexate --> DNA replication 0.83
* colchicine --> Microtubule Destabilizer 0.83
* cytochalasin B --> Actin disruption 0.87
* AZ258 --> Aurora Kinase inhibitor 0.90
* cisplatin --> DNA damage 0.78