In [1]:
import sys
import os

src_path = os.path.split(os.getcwd())[0]
sys.path.insert(0, src_path)

import json
import logging
import numpy as np
import pandas as pd
from pathlib import Path, PurePath
from collections import OrderedDict
from itertools import chain

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import clip.clip as clip
from training.datasets import CellPainting
from clip.clip import _transform
from clip.model import convert_weights, CLIPGeneral
from tqdm import tqdm

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

from sklearn.metrics import accuracy_score, top_k_accuracy_score

from huggingface_hub import hf_hub_download

In [2]:
FILENAME = "cloome-retrieval-zero-shot.pt"
REPO_ID = "anasanchezf/cloome"
checkpoint_path = hf_hub_download(REPO_ID, FILENAME)
checkpoint_path = "/system/user/studentwork/seibezed/cloome/src/logs/retrieval_imgres=[520]_hidden_dim=1024_molecule_layers=4_normalize=dataset_init_inv_tau=14.3_learn_inv_tau=True_lr=0.001_wd=0.1_agg=True_model=RN50_world_size=4batchsize=32_workers=8_date=2025-08-29-18-03-47_/checkpoints/epoch_69.pt"

In [3]:
# CLOOB
model = "RN50-8192"
image_resolution = 520
img_path = "/system/user/publicdata/cellpainting/npzs/chembl24"
mol_path = "/system/user/studentwork/seibezed/cloome/src/data/morgan_chiral_fps_8192_test_scaffold.hdf5"
val = "/publicwork/sanchez_copied/data/murcko_test_imgpermol_final_1398mols.csv"


In [4]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad:
            p.grad.data = p.grad.data.float()

In [5]:
def load(model_path, device, model, image_resolution):
    checkpoint = torch.load(model_path)
    state_dict = checkpoint["state_dict"]

    model_config_file = os.path.join(src_path, f"training/model_configs/{model.replace('/', '-')}.json")
    print('Loading model from', model_config_file)
    assert os.path.exists(model_config_file)
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    model = CLIPGeneral(**model_info)

    if str(device) == "cpu":
        model.float()
    print(device)
    #print({k for k,v in state_dict.items()})
    new_state_dict = {k[len('module.'):]: v for k,v in state_dict.items()}

    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()

    return model, _transform(image_resolution, image_resolution,  is_train=False)

In [6]:
def get_features(dataset, model, device):
    all_image_features = []
    all_text_features = []
    all_ids = []

    print(f"get_features {device}")
    print(len(dataset))

    with torch.no_grad():
        for batch in tqdm(DataLoader(dataset, num_workers=20, batch_size=64)):
            #print(mols)
            imgs, mols = batch

            images, mols = imgs["input"], mols["input"]
            ids = imgs["ID"]
            
            img_features = model.encode_image(images.to(device))
            text_features = model.encode_text(mols.to(device))

            img_features = img_features / img_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            all_image_features.append(img_features)
            all_text_features.append(text_features)
            all_ids.append(ids)

        all_ids = list(chain.from_iterable(all_ids))
    return torch.cat(all_image_features), torch.cat(all_text_features), all_ids

In [7]:
def main(df, model_path, model, img_path, mol_path, image_resolution):
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())

    model, preprocess = load(model_path, device, model, image_resolution)

    preprocess_train = _transform(image_resolution, image_resolution, is_train=True)
    preprocess_val = _transform(image_resolution, image_resolution, is_train=False, normalize="dataset", preprocess="crop")

    # Load the dataset
    val = CellPainting(df,
                       img_path,
                       mol_path,
                       transforms = preprocess_val)

    # Calculate the image features
    print("getting_features")
    val_img_features, val_text_features, val_ids = get_features(val, model, device)
    
    return val_img_features, val_text_features, val_ids

In [8]:
def get_metrics(image_features, text_features):
    metrics = {}
    logits_per_image = image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = (
        torch.arange(len(text_features)).view(-1, 1).to(logits_per_image.device)
    )

    rankings = {}
    all_top_samples = {}
    all_preds = {}

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        rankings[name] = ranking
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        all_preds[name] = preds
        top_samples = np.where(preds < 10)[0]
        all_top_samples[name] = top_samples
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return rankings, all_top_samples, all_preds, metrics, logits

In [9]:
val_img_features, val_text_features, val_ids = main(val, checkpoint_path, model, img_path, mol_path, image_resolution)

4
Loading model from /system/user/studentwork/seibezed/cloome/src/training/model_configs/RN50-8192.json
cuda




1398
getting_features
get_features cuda
1398


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:31<00:00,  1.43s/it]


In [10]:
rankings, all_top_samples, all_preds, metrics, logits = get_metrics(val_img_features, val_text_features)

In [11]:
metrics

{'image_to_text_mean_rank': 481.9856938483548,
 'image_to_text_median_rank': 408.0,
 'image_to_text_R@1': 0.031473533619456366,
 'image_to_text_R@5': 0.06151645207439199,
 'image_to_text_R@10': 0.07439198855507868,
 'text_to_image_mean_rank': 490.28111587982835,
 'text_to_image_median_rank': 402.0,
 'text_to_image_R@1': 0.02861230329041488,
 'text_to_image_R@5': 0.060085836909871244,
 'text_to_image_R@10': 0.07081545064377683}

In [12]:
ground_truth = (
    torch.arange(len(val_text_features)).view(-1, 1).to("cpu")
)
ground_truth

tensor([[   0],
        [   1],
        [   2],
        ...,
        [1395],
        [1396],
        [1397]])

In [13]:
all_preds = []

for i, logs in enumerate(logits["image_to_text"]):
    choices = np.arange(len(val_text_features))
    choices = np.delete(choices, i)
        
    logs = logs.cpu().numpy()
    
    positive = logs[i]
    negatives_ind = np.random.choice(choices, 99, replace=False)
    negatives = logs[negatives_ind]
    
    sampled_logs = np.hstack([positive, negatives])
    
    ground_truth = np.zeros(len(sampled_logs))
    ground_truth[0] = 1
    
    ranking = np.argsort(sampled_logs)
    ranking = np.flip(ranking)
    pred = np.where(ranking == 0)[0]
    all_preds.append(pred)


all_preds = np.vstack(all_preds)
print(all_preds)

r1 = np.mean(all_preds < 1) * 100
r5 = np.mean(all_preds < 5) * 100
r10 = np.mean(all_preds < 10) * 100
print(r1, r5, r10)

n1 = len(np.where(all_preds < 1)[0])
n5 = len(np.where(all_preds < 5)[0])
n10 = len(np.where(all_preds < 10)[0])
print(n1, n5, n10)

[[36]
 [ 0]
 [14]
 ...
 [56]
 [33]
 [60]]
8.583690987124463 18.812589413447782 26.46638054363376
120 263 370


In [14]:
all_preds_t = []

for i, logs in enumerate(logits["text_to_image"]):
    choices = np.arange(len(val_text_features))
    choices = np.delete(choices, i)
    
    logs = logs.cpu().numpy()
    
    positive = logs[i]
    negatives_ind = np.random.choice(choices, 99, replace=False)
    negatives = logs[negatives_ind]
    
    sampled_logs = np.hstack([positive, negatives])
    
    ground_truth = np.zeros(len(sampled_logs))
    ground_truth[0] = 1
    
    ranking = np.argsort(sampled_logs)
    ranking = np.flip(ranking)
    pred = np.where(ranking == 0)[0]
    all_preds_t.append(pred)

all_preds_t = np.vstack(all_preds_t)
print(all_preds_t)

r1_t = np.mean(all_preds_t < 1) * 100
r5_t = np.mean(all_preds_t < 5) * 100
r10_t = np.mean(all_preds_t < 10) * 100
print(r1_t, r5_t, r10_t)

n1_t = len(np.where(all_preds_t < 1)[0])
n5_t = len(np.where(all_preds_t < 5)[0])
n10_t = len(np.where(all_preds_t < 10)[0])
print(n1_t, n5_t, n10_t)

[[43]
 [ 1]
 [16]
 ...
 [55]
 [18]
 [58]]
7.582260371959943 16.595135908440632 24.82117310443491
106 232 347


In [21]:
from scipy.stats import binomtest

n_samples = 1398
value = 0.0708*100

successes = int(round(value * n_samples / 100))
print(successes)

btest = binomtest(k=successes, n=n_samples)
result = btest.proportion_estimate * 100
ci = btest.proportion_ci(confidence_level=0.95)
    
print(result)
print(ci)

99
7.081545064377683
ConfidenceInterval(low=0.05792469613034303, high=0.08554312389972292)
