In [2]:
from IMPA.dataset.data_loader import CellDataLoader
from IMPA.solver import IMPAmodule
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import os
import torch
from tqdm import tqdm
import pandas as pd
import seaborn as sns
from skimage import io, color, filters, measure
import pickle as pkl
import yaml
import scanpy as sc

import sys
sys.path.insert(0, "/home/icb/alessandro.palma/environment/IMPA/IMPA/experiments/measure_metrics")
from compute_metrics import *

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from itertools import combinations
from pathlib import Path
from adjustText import adjust_text

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
def initialize_model(yaml_config, dest_dir):
    args_rdkit = OmegaConf.create(yaml_config)
    dataloader_rdkit = CellDataLoader(args_rdkit)
    solver_rdkit = IMPAmodule(args_rdkit, dest_dir, dataloader_rdkit)
    return solver_rdkit

class Args:
    def __init__(self, dictionary):
        self.__dict__ = dictionary

    def __getattr__(self, key):
        if key in self.__dict__:
            return self.__dict__[key]
        else:
            raise AttributeError(f"'DictToObject' object has no attribute '{key}'")

    def __call__(self, key):
        return self.__getattr__(key)


def t2np(t, batch_dim=False):
    return ((t.permute(0, 2, 3, 1) + 1) / 2).clamp(0, 1).cpu().numpy()


def plot_n_images(images, n_to_plot, channel=None, size=((1.5, 1.5))):
    for i, img in enumerate(images):
        plt.figure(figsize=size)
        if channel == None:
            plt.imshow(img)
        else:
            plt.imshow(img[:,:,channel],"grey")
        plt.axis("off")
        plt.grid("off")
        plt.show()
        if i == n_to_plot:
            break

def transform_by_emb(solver, dataloader, y, n_average, args):
    """
    Transform images in a dataloader using a solver for a specific drug ID.

    Parameters:
        solver: The solver object used for transformation.
        dataloader: The dataloader containing images to be transformed.
        n_average (int): Number of times to average random noise vectors.
        drug_id (str): The ID of the drug for transformation.
        args: Arguments object containing additional parameters.

    Returns:
        tuple: A tuple containing two NumPy arrays representing controls and transformed images.
    """
    controls = []
    transformed = []
    y = y.unsqueeze(0)
    with torch.no_grad():
        for batch in tqdm(dataloader.train_dataloader()):
            X_ctr = batch["X"][0]
            z = torch.ones(X_ctr.shape[0], n_average, args.z_dimension).cuda().mean(1)
            # z = torch.randn(X_ctr.shape[0], n_average, args.z_dimension).cuda().quantile(0.75,1)
            
            # Perturbation ID 
            y_emb = y.repeat((z.shape[0], 1)).cuda()
            y_emb = torch.cat([y_emb, z], dim=1)
            y_emb = solver.nets.mapping_network(y_emb) 
            
            _, X_generated = solver.nets.generator(X_ctr, y_emb)
            transformed.append(t2np(X_generated.detach().cpu()))
            controls.append(t2np(X_ctr.detach().cpu()))
            break
    return np.concatenate(controls, axis=0), np.concatenate(transformed, axis=0)

In [4]:
bbbbc021_embeddings = pd.read_csv("/home/icb/alessandro.palma/environment/IMPA/IMPA/embeddings/csv/emb_fp.csv", index_col=0)

bbbc021_index = pd.read_csv("/home/icb/alessandro.palma/environment/IMPA/IMPA/project_folder/datasets/bbbc021_all/metadata/bbbc021_df_all.csv",
                           index_col=0)

In [5]:
bbbbc021_embeddings.shape

(35, 1024)

In [6]:
path_to_configs = Path("/home/icb/alessandro.palma/environment/IMPA/IMPA/config_hydra/config")

with open(path_to_configs / 'bbbc021_all_retrain_fp.yaml', 'r') as IMPA_bbbc021:
    # Load YAML data using safe_load() from the file
    yaml_IMPA_bbbc021 = yaml.safe_load(IMPA_bbbc021)

yaml_IMPA_bbbc021["style_dim"] = 128
yaml_IMPA_bbbc021["z_dimension"] = 64
yaml_IMPA_bbbc021["latent_dim"] = 1024
dest_dir = "/home/icb/alessandro.palma/environment/IMPA/IMPA/project_folder/experiments/20240227_0e1c3a90-fa20-438c-9d8d-f081c86c68e7_bbbc021_all_retrain_fp"

IMPA_bbbc021 = initialize_model(yaml_IMPA_bbbc021, dest_dir)

IMPA_bbbc021._load_checkpoint(150)

Number of parameters in generator: 24752771
Number of parameters in style_encoder: 14362304
Number of parameters in discriminator: 14309978
Number of parameters in mapping_network: 139392
Initializing embedding_matrix...
Initializing generator...
Initializing style_encoder...
Initializing discriminator...
Initializing mapping_network...
IMPAmodule(
  (embedding_matrix): Embedding(26, 1024)
  (generator): DataParallel(
    (module): Generator(
      (from_rgb): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (encode): ModuleList(
        (0): ResBlk(
          (actv): LeakyReLU(negative_slope=0.2)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Fal

In [12]:
import time 
for batch_size in [8, 16, 32, 64, 128]:
    with torch.no_grad():
        yaml_IMPA_bbbc021["batch_size"] = batch_size
        args = OmegaConf.create(yaml_IMPA_bbbc021)
        dataloader = CellDataLoader(args)
        
        batch = next(iter(dataloader.train_dataloader()))
        start_time = time.time()
        X_ctr = batch["X"][0]
        y = batch["mol_one_hot"].argmax(1).cuda()
        y = IMPA_bbbc021.embedding_matrix(y).cuda()
        z = torch.randn(X_ctr.shape[0], args.z_dimension).cuda()
        
        # Perturbation ID 
        y_emb = torch.cat([y, z], dim=1)
        y_emb = IMPA_bbbc021.nets.mapping_network(y_emb) 
        
        _, X_generated = IMPA_bbbc021.nets.generator(X_ctr, y_emb)
        end_time = time.time()
        print(f"Timing for batch size {batch['X'][0].shape}", end_time - start_time)

Timing for batch size torch.Size([8, 3, 96, 96]) 0.0820150375366211
Timing for batch size torch.Size([16, 3, 96, 96]) 0.028976917266845703
Timing for batch size torch.Size([32, 3, 96, 96]) 0.008441448211669922
Timing for batch size torch.Size([64, 3, 96, 96]) 0.008421182632446289
Timing for batch size torch.Size([128, 3, 96, 96]) 0.01615166664123535


In [8]:
y.shape

torch.Size([128, 1024])

In [9]:
z

tensor([[ 2.5841e-02, -4.2467e-01,  3.4941e-01,  ...,  1.8211e+00,
         -1.9366e+00, -2.8147e-01],
        [ 1.1299e+00, -1.6947e+00, -2.2214e-01,  ...,  9.3070e-01,
          5.2010e-01,  1.2309e+00],
        [ 1.0620e+00, -2.3394e-01,  8.5563e-01,  ..., -6.6059e-01,
         -2.3643e-01, -4.8504e-01],
        ...,
        [-1.1658e+00,  1.9877e-01,  1.0136e+00,  ...,  4.8101e-01,
         -1.8381e+00,  7.6252e-01],
        [ 2.3418e-01, -5.8967e-01, -1.9285e+00,  ...,  5.2363e-01,
          1.0054e+00, -1.2943e-03],
        [ 1.6860e+00,  5.1936e-02,  1.9389e+00,  ..., -2.3013e-01,
         -1.5218e+00, -5.1531e-01]], device='cuda:0')