In [None]:
import os
from pathlib import Path

from PIL import Image
from torchvision import transforms

from histopatseg.models.mil_complete import MILModel
from histopatseg.models.models import load_model
from histopatseg.utils import get_device


In [None]:
# Backbone name and GPU index passed to `get_device`; change them here rather
# than inside the call below.
BACKBONE_NAME = "UNI2"
GPU_INDEX = 0

device = get_device(GPU_INDEX)
# `load_model` returns four values; in the cells shown here only
# `feature_extractor` and `embedding_dim` are used.
# NOTE(review): `transform_fe` (presumably the backbone's own preprocessing) is
# not used — the custom `transform` defined next is applied instead. Confirm the
# backbone is meant to see those normalization statistics.
feature_extractor, transform_fe, embedding_dim, autocast_dtype = load_model(
    BACKBONE_NAME, device
)

In [None]:
# Preprocessing for the input image: PIL image -> float tensor, then a
# per-channel (mean, std) standardisation in RGB order.
# NOTE(review): these statistics are not the ImageNet defaults; presumably they
# were computed on the LungHist700 images — confirm their provenance.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.707223, 0.578729, 0.703617),
            std=(0.211883, 0.230117, 0.177517),
        ),
    ]
)

In [None]:
# Wrap the backbone in the MIL model. In the cells shown here it is only used
# for its `tile_image` helper.
# NOTE(review): no explicit `.to(device)` / `.eval()` is called here — confirm
# whether MILModel handles that before any forward pass is added.
model = MILModel(
    embedding_dim,
    feature_extractor=feature_extractor,
)

In [None]:
# Root of the processed data. Set the HISTOPATSEG_DATA_DIR environment variable
# on other machines instead of editing the path below.
DATA_DIR = Path(
    os.environ.get(
        "HISTOPATSEG_DATA_DIR",
        "/home/valentin/workspaces/histopatseg/data/processed",
    )
)
IMAGE_PATH = DATA_DIR / "LungHist700" / "LungHist700_20x" / "nor_20x_24.png"

# Use a context manager so PIL releases the file handle. `.convert("RGB")` loads
# the pixels and returns an independent image that stays valid after the block.
with Image.open(IMAGE_PATH) as image_pil:
    image_rgb = image_pil.convert("RGB")

# `image` is the normalized tensor consumed by the cells below.
image = transform(image_rgb)

In [None]:
# Sanity check on the normalized tensor; after ToTensor on an RGB image this
# should be channels-first, i.e. (3, H, W).
image.size()

In [None]:
# Split the normalized image into tiles with the MIL model's tiling helper.
# No tile geometry is configured in this notebook; it is decided inside
# MILModel.tile_image. The preview cell iterates over dim 0 and treats each
# entry as a (C, H, W) tile, so this presumably returns an (N, C, H, W) stack —
# TODO confirm.
tiles = model.tile_image(image)

In [None]:
# Shape of the tile stack; dim 0 is the number of tiles the preview iterates over.
tiles.size()

In [None]:
import torch
import matplotlib.pyplot as plt

# Define a simple denormalization function
def denormalize(tensor, mean, std):
    """
    Invert ``transforms.Normalize``: compute ``tensor * std + mean`` per channel.

    Args:
        tensor: Floating-point image tensor shaped (C, H, W), or a batch shaped
            (..., C, H, W). Integer tensors are promoted to the default float dtype.
        mean: Per-channel means, as a sequence or 1-D tensor of length C.
        std: Per-channel standard deviations, with the same length as ``mean``.

    Returns:
        A new tensor with the same shape as ``tensor``, on the same device.
        The input is not modified.

    Raises:
        ValueError: if ``tensor`` has fewer than 3 dimensions, or the number of
            statistics does not match the channel dimension. The earlier
            ``zip``-based loop silently truncated in that case, and silently
            mis-applied stats to batched inputs.
    """
    if tensor.ndim < 3:
        raise ValueError(
            f"expected a (..., C, H, W) tensor, got shape {tuple(tensor.shape)}"
        )
    # Keep the input's float dtype (e.g. float16). Promote integer inputs so the
    # arithmetic is not truncated.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    # Move the stats to the input's device and reshape them to (C, 1, 1). They
    # then broadcast over H, W and any leading batch dimensions.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
    n_channels = tensor.shape[-3]
    if mean_t.shape[0] != n_channels or std_t.shape[0] != n_channels:
        raise ValueError(
            f"got {mean_t.shape[0]} means and {std_t.shape[0]} stds "
            f"for a tensor with {n_channels} channels"
        )
    # Out-of-place arithmetic, so the caller's tensor is left untouched.
    return tensor.to(dtype) * std_t + mean_t

# Read the normalization statistics from `transform` itself, so the
# denormalization below cannot drift from the preprocessing cell.
normalize_step = next(
    (step for step in transform.transforms if isinstance(step, transforms.Normalize)),
    None,
)
if normalize_step is None:
    raise RuntimeError("`transform` contains no transforms.Normalize step to invert")
mean = torch.as_tensor(normalize_step.mean)
std = torch.as_tensor(normalize_step.std)

to_pil = transforms.ToPILImage()

MAX_COLS = 5  # tiles per row in the preview grid

n_tiles = tiles.shape[0]
cols = min(MAX_COLS, n_tiles)
# Ceil division. Guarded because an empty tile stack would make cols == 0.
rows = (n_tiles + cols - 1) // cols if cols else 0

if n_tiles == 0:
    print("model.tile_image returned no tiles; nothing to display.")
else:
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
    # Blank every grid slot first, so unused slots in the last row stay empty.
    for ax in axes.flat:
        ax.axis("off")
    for ax, tile in zip(axes.flat, tiles):
        # Clamp to [0, 1]. Float error can push values slightly outside that
        # range, and ToPILImage's float->uint8 conversion may wrap them instead
        # of saturating.
        tile_denorm = denormalize(tile, mean, std).clamp(0.0, 1.0)
        ax.imshow(to_pil(tile_denorm.cpu()))
    fig.suptitle(f"{n_tiles} tiles from model.tile_image")
    plt.show()
