# Model Inference

In [3]:
# imports
import torch, os
import matplotlib.pyplot as plt
import numpy as np
from source.net import *
from source.utils import *
from source.dataset import *
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from pathlib import Path

device = "cuda:0"


In [4]:
backbone = load_net('simclr')
fc_head = load_net('fc_head', {'num_classes': 1139})
checkpoint_dir = Path("/work/ai4bio2024/rxrx1/checkpoints/simclr_head")
checkpoints, nets = [], []

for file_name in checkpoint_dir.iterdir():
    nets.append(CellClassifier(backbone, fc_head))
    checkpoints.append(file_name)

for i in range(len(nets)):
    nets[i] = nets[i].to(device)
    load_weights(checkpoints[i], nets[i], device)

for checkpoint, net in zip(checkpoints, nets):
    print(f"Loaded checkpoint-{checkpoint}")

Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint50
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint70
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint20
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint80
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint100
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint10
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint30
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint60
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint90
Loaded checkpoint-/work/ai4bio2024/rxrx1/checkpoints/simclr_head/checkpoint40


In [5]:
split_sizes = [0.7, 0.15, 0.15]
dataset = Rxrx1('/work/ai4bio2024/rxrx1')
train_size = int(split_sizes[0] * len(dataset))
val_size = int(split_sizes[1] * len(dataset))
test_size = len(dataset) - train_size - val_size
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)
test_dataloader = DataLoader(test_dataset, batch_size=128, pin_memory_device=device,
                             pin_memory=True, num_workers=12, drop_last=True, prefetch_factor=2)

In [6]:
pred_labels, true_labels = [], []

for i in range(len(nets)):
    nets[i].eval()
    pred_labels.append([])
    true_labels.append([])

for x_batch, _, y_batch in tqdm(test_dataloader):
    x_batch, y_batch = x_batch.to(torch.float).to(device), y_batch.to(device)
    for i in range(len(nets)):
        y_pred = nets[i](x_batch).argmax(dim=1)
        true_labels[i].append(y_batch)
        pred_labels[i].append(y_pred)

100%|██████████| 147/147 [02:42<00:00,  1.11s/it]


In [7]:
for i in range(len(nets)):
    print(f"Accuracy-{str(checkpoints[i]).split('/')[-1]}:{accuracy_score(torch.cat(true_labels[i]).cpu(), torch.cat(pred_labels[i]).cpu())}")

Accuracy-checkpoint50:0.03752125850340136
Accuracy-checkpoint70:0.03752125850340136
Accuracy-checkpoint20:0.03752125850340136
Accuracy-checkpoint80:0.03752125850340136
Accuracy-checkpoint100:0.03752125850340136
Accuracy-checkpoint10:0.03752125850340136
Accuracy-checkpoint30:0.03752125850340136
Accuracy-checkpoint60:0.03752125850340136
Accuracy-checkpoint90:0.03752125850340136
Accuracy-checkpoint40:0.03752125850340136
