In [93]:
import sys
import os

src_path = os.path.split(os.getcwd())[0]
sys.path.insert(0, src_path)

import json
import logging
import numpy as np
import pandas as pd
from pathlib import Path, PurePath
from collections import OrderedDict
from itertools import chain

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import clip.clip as clip
from training.datasets import CellPainting
from clip.clip import _transform
from clip.model import convert_weights, CLIPGeneral
from tqdm import tqdm

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

from sklearn.metrics import accuracy_score, top_k_accuracy_score

from huggingface_hub import hf_hub_download

In [None]:
FILENAME = "cloome-retrieval-zero-shot.pt"
REPO_ID = "anasanchezf/cloome"
checkpoint_path = hf_hub_download(REPO_ID, FILENAME)

checkpoint_path = "/system/user/studentwork/seibezed/cloome/src/logs/imgres=[520]_hidden_dim=1024_molecule_layers=4_normalize=dataset_init_inv_tau=14.3_learn_inv_tau=False_lr=0.001_wd=0.1_agg=True_model=RN50_world_size=4batchsize=32_workers=8_date=2025-09-02-12-15-07_/checkpoints/epoch_63.pt"

In [95]:
# CLOOB
model = "RN50"
image_resolution = 520
img_path = "/system/user/publicdata/cellpainting/npzs/chembl24"
mol_path = "/system/user/studentwork/seibezed/cloome/src/data/morgan_chiral_fps_1024_test_zs_molecules_scaffold.hdf5"
val = "/publicwork/sanchez_copied/data/cellpainting-test-phenotype-imgpermol-scaffold.csv"
classes = "/publicwork/sanchez_copied/data/cellpainting-split-test-imgpermol-scaffold.csv"

In [96]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad:
            p.grad.data = p.grad.data.float()

In [97]:
def load(model_path, device, model, image_resolution):

    checkpoint = torch.load(model_path)
    state_dict = checkpoint["state_dict"]

    model_config_file = os.path.join(src_path, f"training/model_configs/{model.replace('/', '-')}.json")

    print('Loading model from', model_config_file)
    assert os.path.exists(model_config_file)
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    model = CLIPGeneral(**model_info)

    if str(device) == "cpu":
        model.float()
    print(device)

    new_state_dict = {k[len('module.'):]: v for k,v in state_dict.items()}

    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()

    return model, _transform(image_resolution, image_resolution,  is_train=False)

In [98]:
def get_features(dataset, model, device):
    all_image_features = []
    all_text_features = []
    all_ids = []

    print(f"get_features {device}")
    print(len(dataset))

    with torch.no_grad():
        for batch in tqdm(DataLoader(dataset, num_workers=20, batch_size=64)):
            #print(mols)
            imgs, mols = batch

            images, mols = imgs["input"], mols["input"]
            ids = imgs["ID"]
            
            img_features = model.encode_image(images.to(device))
            text_features = model.encode_text(mols.to(device))

            img_features = img_features / img_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            all_image_features.append(img_features)
            all_text_features.append(text_features)
            all_ids.append(ids)

        all_ids = list(chain.from_iterable(all_ids))
    return torch.cat(all_image_features), torch.cat(all_text_features), all_ids

In [99]:
def main(df, model_path, model, img_path, mol_path, image_resolution):
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())

    model, preprocess = load(model_path, device, model, image_resolution)

    preprocess_train = _transform(image_resolution, image_resolution, is_train=True)
    preprocess_val = _transform(image_resolution, image_resolution, is_train=False, normalize="dataset", preprocess="downsize")

    # Load the dataset
    val = CellPainting(df,
                       img_path,
                       mol_path,
                       transforms = preprocess_val)

    # Calculate the image features
    print("getting_features")
    val_img_features, val_text_features, val_ids = get_features(val, model, device)
    
    return val_img_features, val_text_features, val_ids

In [100]:
val_img_features, val_text_features, val_ids = main(val, checkpoint_path, model, img_path, mol_path, image_resolution)
class_img_features, class_text_features, class_ids = main(classes, checkpoint_path, model, img_path, mol_path, image_resolution)
val_img_features = val_img_features.cpu()
class_img_features = class_img_features.cpu()

4
Loading model from /system/user/studentwork/seibezed/cloome/src/training/model_configs/RN50.json
cuda
28473
getting_features
get_features cuda
28473


100%|██████████| 445/445 [03:42<00:00,  2.00it/s]


4
Loading model from /system/user/studentwork/seibezed/cloome/src/training/model_configs/RN50.json
cuda
1398
getting_features
get_features cuda
1398


100%|██████████| 22/22 [00:19<00:00,  1.12it/s]


In [101]:
val_img_features.shape

torch.Size([28473, 512])

In [102]:
classes_df = pd.read_csv(classes)
classes_df.set_index("SAMPLE_KEY", inplace=True)
class_inchis = classes_df.loc[class_ids]["INCHIKEY"]

In [103]:
val_df = pd.read_csv(val)
val_df.set_index("SAMPLE_KEY", inplace=True)
val_inchis = val_df.loc[val_ids]["INCHIKEY"]

In [104]:
class_dict = {}

for i, inchi in enumerate(class_inchis): 
    class_dict[inchi] = i

In [105]:
ground_truth = np.zeros(len(val_inchis), dtype=int)

for i, inchi in enumerate(val_inchis): 
    label = class_dict[inchi]
    ground_truth[i] = int(label)

In [106]:
# Calculating accuracies in several ways

# WAY 1
logits = val_img_features @ class_img_features.T
acc = accuracy_score(ground_truth, logits.argmax(axis=1)) * 100.0
print(acc)

21.81716011660169


In [107]:
# WAY 2
N = ground_truth.shape[0]
(np.array(ground_truth) == logits.argmax(axis=1).numpy()).sum() / N * 100

21.81716011660169

In [108]:
# WAY 3
ranking = torch.argsort(logits, descending=True)
t = torch.tensor(ground_truth, dtype=torch.int16).view(-1,1)

preds = torch.where(ranking == t)[1]
preds = preds.detach().cpu().numpy()

metrics = {}
for k in [1, 5, 10]:
    metrics[f"R@{k}"] = np.mean(preds < k) * 100
    
print(metrics)

{'R@1': 21.81716011660169, 'R@5': 49.15534014680575, 'R@10': 65.65518210234256}


In [109]:
# WAY 4
probs = (val_img_features @ class_img_features.T).softmax(dim=-1)

metrics_skl = {}
for k in [1, 5, 10]:
    metrics_skl[f"R@{k}"] = top_k_accuracy_score(ground_truth, probs, k=k) * 100
    
print(metrics_skl)

{'R@1': 21.81716011660169, 'R@5': 49.15534014680575, 'R@10': 65.65518210234256}


In [110]:
from scipy.stats import binomtest

n_samples = val_img_features.shape[0]

mdict, cis = {}, {}

for metric, value in metrics.items():
    successes = int(value * n_samples / 100)
    btest = binomtest(k=successes, n=n_samples)
    mdict[metric] = btest.proportion_estimate * 100
    cis[metric] = btest.proportion_ci(confidence_level=0.95)
    
print(mdict)
print(cis)

{'R@1': 21.81716011660169, 'R@5': 49.15534014680575, 'R@10': 65.65167000316089}
{'R@1': ConfidenceInterval(low=0.21338579779160258, high=0.22301497567509646), 'R@5': ConfidenceInterval(low=0.48573011632899354, high=0.4973784106100583), 'R@10': ConfidenceInterval(low=0.6509676026653524, high=0.6620338446570153)}
