In [None]:
import os

from dotenv import load_dotenv
from huggingface_hub import login
import torch
from torchvision import transforms
from torchvision.models import resnet50
import numpy as np
import torch.nn as nn

from histolung.evaluation.evaluators import LungHist700Evaluator
from histolung.models.models_darya import MoCoV2Encoder

In [None]:
# Build the MoCo v2 model skeleton; trained weights are loaded into it from the checkpoint in a later cell.
moco_model = MoCoV2Encoder()

In [None]:
# Checkpoint to evaluate. Alternative checkpoint previously used:
# "/mnt/nas7/data/Personal/Darya/saved_models/superpixels_resnet50__alpha_0.5__ablation/superpixel_org_22.pth"
CHECKPOINT_PATH = "/mnt/nas7/data/Personal/Darya/saved_models/superpixels_moco_org/superpixel_moco_org_58.pth"

# map_location="cpu" remaps all stored tensors to CPU, so the file loads regardless of the
# device it was saved from. NOTE: torch.load unpickles data — only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")


In [None]:
# Rather than dumping every parameter name, show only the names that differ between the
# model and the checkpoint: these are exactly the entries load_state_dict(strict=False)
# will silently skip (e.g. a "module." prefix from DataParallel would show up here).
model_keys = set(moco_model.state_dict())
ckpt_keys = set(checkpoint["model_state_dict"])
{
    "only_in_model": sorted(model_keys - ckpt_keys),
    "only_in_checkpoint": sorted(ckpt_keys - model_keys),
}

In [None]:

# strict=False tolerates key mismatches, but it does so silently. If the query-encoder
# backbone (the part evaluated below) were not loaded, we would be evaluating random
# weights, so fail loudly in that case. The `fc` head is excluded because it is replaced
# with nn.Identity before evaluation.
load_result = moco_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
missing_backbone = [
    key
    for key in load_result.missing_keys
    if key.startswith("encoder_q.") and not key.startswith("encoder_q.fc.")
]
assert not missing_backbone, (
    f"{len(missing_backbone)} encoder_q weights missing from checkpoint, "
    f"e.g. {missing_backbone[:5]}"
)
load_result

In [None]:
# Inspect the parameter names stored in the checkpoint, e.g. to spot naming/prefix
# differences from moco_model.state_dict() that strict=False loading would skip.
checkpoint["model_state_dict"].keys()

In [None]:
# Evaluate only the query encoder of the MoCo model.
model = moco_model.encoder_q
# Replace the final `fc` layer with an identity so the model outputs the features that
# were fed into it (presumably the pooled backbone features — TODO confirm against
# MoCoV2Encoder). NOTE: `model` is the same object as moco_model.encoder_q, so this
# mutates moco_model in place.
model.fc = nn.Identity()

In [None]:
def get_device(gpu_id=None):
    """Select the torch device to run computations on.

    Args:
        gpu_id: Index of the CUDA device to use. If ``None``, negative, or not a
            valid index on this machine, the first GPU (``cuda:0``) is used instead;
            a warning is printed when an explicitly requested id is unavailable.

    Returns:
        ``torch.device`` for the selected GPU, or ``cpu`` when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("Using CPU.")
        return torch.device("cpu")

    n_gpus = torch.cuda.device_count()
    # Also reject negative ids: they pass `< n_gpus` but "cuda:-1" is not a valid device.
    if gpu_id is not None and 0 <= gpu_id < n_gpus:
        device = torch.device(f"cuda:{gpu_id}")
    else:
        if gpu_id is not None:
            print(
                f"Requested GPU {gpu_id} is not available ({n_gpus} visible); "
                "falling back to cuda:0."
            )
        device = torch.device("cuda:0")  # Default to first GPU
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
    return device

In [None]:
# NOTE(review): gpu_id=1 is duplicated in the LungHist700Evaluator cell; keep them in sync.
# get_device falls back to cuda:0 (or CPU) if GPU 1 is not available.
device = get_device(gpu_id=1)

In [None]:
# Move the feature extractor to the selected device and switch it to evaluation mode
# (affects layers such as Dropout/BatchNorm). The trailing `;` suppresses the very long
# module repr that the cell would otherwise display.
model.to(device)
model.eval();

In [None]:
# Evaluator configuration for the LungHist700 (10x) benchmark.
# NOTE(review): gpu_id here is independent of `device` above — keep them consistent.
evaluator_kwargs = dict(
    data_dir="/home/valentin/workspaces/histolung/data/processed/LungHist700_10x",
    n_splits=5,
    batch_size=256,
    num_workers=4,
    gpu_id=1,
)
evaluator = LungHist700Evaluator(**evaluator_kwargs)

In [None]:
# Run the truncated encoder over the evaluator's dataset; presumably returns one
# embedding per tile together with the matching tile ids.
# NOTE(review): confirm compute_embeddings disables autograd (no_grad/inference_mode)
# and puts the model in eval mode; otherwise memory use may grow during extraction.
embeddings, tile_ids = evaluator.compute_embeddings(model)

In [None]:
# Sanity check before evaluation: presumably one embedding row per tile id, so a
# mismatch means the two outputs of compute_embeddings are out of sync.
assert embeddings.shape[0] == len(tile_ids), (
    f"{embeddings.shape[0]} embeddings but {len(tile_ids)} tile ids"
)
embeddings.shape

In [None]:
# Score the embeddings with the evaluator's k-NN protocol; the returned dict is read
# below for "concatenated_accuracy", "mean_accuracy" and "std_accuracy".
# magnification="all" presumably pools every magnification level — TODO confirm.
results = evaluator.evaluate(embeddings, tile_ids, verbose=True, magnification="all")

In [None]:
# Report the k-NN metrics: accuracy over all folds pooled together, then the per-fold mean ± std.
concat_acc = results["concatenated_accuracy"]
print(f"\nk-NN Concatenated Accuracy: {concat_acc:.4f}")

mean_acc, std_acc = results["mean_accuracy"], results["std_accuracy"]
print(f"\nk-NN Mean Accuracy \u00B1 STD : {mean_acc:.4f} \u00B1 {std_acc:.4f}")
