In [None]:
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd

from histolung.models.feature_extractor import BaseFeatureExtractor
from histolung.evaluation.datasets import TileDataset

In [2]:
# Expose only physical GPU 1 to this process.
# NOTE(review): this only takes effect if it runs before CUDA is first
# initialised in this kernel — restart the kernel after changing it.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Replace "1" with the GPU index you want to use

In [6]:
# Root of the processed LungHist700 data; the tile images live in a "tiles"
# sub-directory next to the metadata CSV.
data_dir = Path("/home/valentin/workspaces/histolung/data/processed/LungHist700/")
tiles_dir = data_dir / "tiles"

# Load the metadata table and index it by tile_id so that rows can be looked
# up directly from a tile identifier.
metadata = pd.read_csv(data_dir / "metadata.csv")
metadata = metadata.set_index("tile_id")

In [None]:
# Build the "UNI" feature extractor from the local checkpoint file, then move
# it to the GPU.
uni_weights_filepath = "models/uni/assets/ckpts/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin"
model = BaseFeatureExtractor.get_feature_extractor(
    "UNI",
    weights_filepath=uni_weights_filepath,
)
model = model.to("cuda")

# Preprocessing object provided by the extractor itself; it is handed to the
# tile dataset when the dataloader is built.
preprocess = model.get_preprocessing()


In [None]:
# Collect every PNG tile whose metadata row (keyed by the file stem) reports a
# "20x" resolution. A tile missing from the metadata raises KeyError.
tile_paths = []
for png_path in tiles_dir.glob("*.png"):
    if metadata.loc[png_path.stem, "resolution"] == "20x":
        tile_paths.append(png_path)

# Wrap the selected tiles in a dataset that applies the extractor's
# preprocessing, and iterate over it in batches of 128 with 12 workers.
tile_dataset = TileDataset(tile_paths, preprocess=preprocess)
dataloader = DataLoader(
    tile_dataset,
    batch_size=128,
    num_workers=12,
)

In [None]:
# Run every batch of tiles through the feature extractor and collect the
# outputs on the CPU. `embeddings` holds one tensor per batch and `tile_ids`
# the matching identifiers, in the same order.
embeddings = []
tile_ids = []

# Inference only: disable autograd so no computation graph is built for each
# forward pass (this also makes an explicit .detach() unnecessary).
# NOTE(review): confirm the extractor is already in eval mode — not visible here.
with torch.no_grad():
    for images, batch_tile_ids in tqdm(dataloader):
        batch_embeddings = model(images.to("cuda"))
        embeddings.append(batch_embeddings.cpu())
        tile_ids.extend(batch_tile_ids)

In [13]:
# Per-tile "label" values, in the same order as tile_ids.
# `metadata` is already indexed by tile_id (done when it was loaded), so it is
# indexed directly; calling set_index("tile_id") again raises KeyError because
# the column no longer exists.
# Note: `labels` is reassigned with the superclass-encoded values further down.
labels = metadata.loc[tile_ids, "label"].tolist()

In [11]:
# Patient-level split. The held-out groups are declared first.
val_patient_ids = [1, 6, 27, 32, 44]
test_patient_ids = [9, 13, 31, 40]

# The validation patients are folded into the training set (listed first),
# followed by the patients that were assigned to training from the start.
train_patient_ids = [
    *val_patient_ids,
    2, 3, 4, 5, 7, 8, 12, 14, 15, 16, 17, 18, 20, 21, 23, 24, 25, 26, 28, 29,
    30, 33, 36, 37, 38, 39, 41, 42, 45,
]

In [17]:
# Integer code for each superclass string; any other value raises KeyError
# when the labels are encoded below.
superclass_mapping = {"nor": 0, "aca": 1, "scc": 2}

# Metadata rows in the same order as tile_ids, which were collected alongside
# the embeddings.
tile_metadata = metadata.loc[tile_ids]
labels = [
    superclass_mapping[superclass]
    for superclass in tile_metadata["superclass"].tolist()
]
patient_ids = tile_metadata["patient_id"].tolist()

# Split by patient ID: record the position of every tile whose patient
# belongs to the train or the test group. The two membership checks are
# independent of each other.
train_idx = []
test_idx = []
for position, patient_id in enumerate(patient_ids):
    if patient_id in train_patient_ids:
        train_idx.append(position)
    if patient_id in test_patient_ids:
        test_idx.append(position)

In [None]:
# Sanity check: positions present in both splits. An empty set means no tile
# was assigned to train and test at the same time.
set(train_idx) & set(test_idx)

In [21]:
# Join the per-batch embedding tensors along the first (batch) axis into a
# single NumPy array; axis 0 is np.concatenate's default.
X = np.concatenate(embeddings)

In [None]:
# Display the dimensions of the stacked embedding array.
# Expected to be (number of tiles, embedding size) — assumes the extractor
# returns one vector per tile; TODO confirm.
X.shape

In [None]:
# Select the training rows of the embeddings and their labels.
# `labels` is a plain Python list, which cannot be indexed with a list of
# positions (TypeError), so it is converted to an array before indexing.
X_train, labels_train = X[train_idx], np.asarray(labels)[train_idx]