In [None]:
# Import necessary libraries
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import CenterCrop


# Make the project root (the parent of the current working directory, which is
# normally the notebook's folder) importable so the project imports below work.
# - absolute path: a later os.chdir() cannot change what the entry points to;
# - added only once: re-running this cell does not keep growing sys.path;
# - inserted at the front: the local `training` / `datasets` / `models`
#   packages are not shadowed by installed packages of the same name
#   (e.g. Hugging Face `datasets`).
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Import the model classes and data modules from the project
from training.linear_semantic import LinearSemantic
from datasets.anorak import ANORAK
from models.histo_linear_decoder import LinearDecoder

In [None]:
# Load the trained model from its Lightning checkpoint.
# Set the ANORAK_CHECKPOINT environment variable to point at another checkpoint
# instead of editing the machine-specific default path below.
checkpoint_path = os.environ.get(
    "ANORAK_CHECKPOINT",
    "/home/valentin/workspaces/benchmark-vfm-ss/data/lightning_logs/6t30deru/checkpoints/epoch=1289-step=39990.ckpt",
)

# Linear-decoder semantic segmentation model (LinearSemantic + LinearDecoder).
model = LinearSemantic.load_from_checkpoint(
    checkpoint_path,
    # The network must be provided again since it's not saved in the checkpoint;
    # presumably its config (img_size, encoder, num_classes) has to match the
    # one used for training.
    network=LinearDecoder(
        img_size=(448, 448),
        encoder_name="hf-hub:MahmoodLab/UNI2-h",
        num_classes=7,
    ),
    # NOTE(review): strict=False lets state-dict keys be missing/unexpected, so a
    # mismatched network can load with some weights left un-restored. Whether the
    # mismatched keys are reported depends on the Lightning version — verify.
    strict=False,
)

# Set to evaluation mode
model.eval()
print(f"Model loaded from {checkpoint_path}")
print(f"Model type: {type(model).__name__}")

In [None]:
def move_to_device(x, device):
    """Recursively move every tensor inside ``x`` to ``device``.

    Tensors are moved; dicts are rebuilt as plain ``dict``; lists, tuples and
    namedtuples keep their type; any other object is returned unchanged.

    Args:
        x: A tensor or an arbitrarily nested container of tensors.
        device: Target device, anything accepted by ``Tensor.to``.

    Returns:
        A structure mirroring ``x`` with all contained tensors on ``device``.
    """
    if torch.is_tensor(x):
        # non_blocking allows an asynchronous copy when the source is pinned
        # host memory; otherwise the copy behaves as usual.
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take their fields positionally, not as one iterable.
        return type(x)(*(move_to_device(v, device) for v in x))
    if isinstance(x, (list, tuple)):
        return type(x)(move_to_device(v, device) for v in x)
    return x

In [None]:
# Build the ANORAK data module that provides the validation samples.
# img_size / num_classes should agree with the network config used for the model.
data_module = ANORAK(
    root="../data/ANORAK",  # relative to the working directory; adjust to your data
    img_size=(448, 448),
    num_classes=7,
    batch_size=1,
    num_workers=4,
    devices=1,
    num_metrics=1,
)

# Run the data module's setup hook before asking it for dataloaders.
data_module.setup()


In [None]:
# Label value the data module marks as "ignore" (presumably excluded from loss/metrics — confirm in datasets/anorak).
getattr(data_module, "ignore_idx")

In [None]:
# Per-class weights from the data module. Bound to a name so later cells can
# reuse them without recomputing; the trailing expression still displays them.
class_weights = data_module.compute_class_weights()
class_weights

In [None]:

# Grab the first validation batch and move all of its tensors to the model's device.
val_dataloader = data_module.val_dataloader()
sample_batch = move_to_device(next(iter(val_dataloader)), model.device)

# The batch unpacks into two parts: the images and their targets
# (target[i] is indexed by "masks" / "labels" in the cells below).
img, target = sample_batch


In [None]:
# Size of the first image tensor in the batch.
img[0].size()

In [None]:
# Size of the "masks" entry of the first target.
target[0]["masks"].size()

In [None]:
def _summarise_target_entry(value):
    """Compact display form of one target entry.

    Tensors with more than one dimension (e.g. the masks) are reduced to
    (shape, dtype); 0-D/1-D tensors (e.g. the labels) are shown as Python
    values; non-tensors are returned as-is.
    """
    if torch.is_tensor(value) and value.ndim > 1:
        return tuple(value.shape), value.dtype
    if torch.is_tensor(value):
        return value.tolist()
    return value


# Summary of the first target instead of the raw repr of its full mask tensor.
{key: _summarise_target_entry(value) for key, value in target[0].items()}

In [None]:
# Visualise the first image together with its target masks and labels.
model.plot_semantic(
    img[0],
    target[0]["masks"],
    target[0]["labels"],
)

In [None]:
# Run the model's evaluation step on the sample batch with autograd disabled.
with torch.inference_mode():
    logits = model.eval_step(
        sample_batch,
        is_notebook=True,
    )

In [None]:
# Shapes of every element of `logits`. A fixed index such as logits[1] raises
# IndexError when there is only one element — likely here with batch_size=1,
# since the next cells treat logits[0] as the single image's logit map.
[tuple(t.shape) for t in logits]

In [None]:
# First element of `logits` as a NumPy array with the leading axis moved last
# (assumes logits[0] is (C, H, W), giving (H, W, C) — confirm).
# .float() first: NumPy has no bfloat16 dtype, so .numpy() would fail on
# bfloat16 logits; for float32 tensors it is a no-op.
np_logits = np.transpose(logits[0].float().cpu().numpy(), (1, 2, 0))

In [None]:
# Shape of the NumPy logits (expected (H, W, num_classes) — confirm).
np.shape(np_logits)

In [None]:
# Per-pixel predicted class id: argmax over the last (class) axis of np_logits.
pred_classes = np.argmax(np_logits, axis=-1)
num_pred_classes = np_logits.shape[-1]

fig, ax = plt.subplots(figsize=(6, 6))
# Class ids are categorical:
# - nearest-neighbour sampling keeps display-time resampling from blending
#   neighbouring ids into the colour of some other class;
# - fixed vmin/vmax over the full class range give each class the same colour
#   regardless of which classes happen to appear in this image.
im = ax.imshow(
    pred_classes,
    cmap="tab20",
    vmin=0,
    vmax=num_pred_classes - 1,
    interpolation="nearest",
)
fig.colorbar(im, ax=ax, ticks=list(range(num_pred_classes)), label="class id")
ax.set_title("Predicted segmentation (argmax over logits)")
ax.set_axis_off()
plt.show()