In [None]:
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureType,
    Orientation,
    ScaleIntensityRange,
    CropForeground,
    Resize
)
from monai.inferers import SlidingWindowInferer

from einops import rearrange
from dotenv import load_dotenv
import os
import numpy as np
from glob import glob
from tqdm import tqdm

In [None]:
# Name of the evaluation dataset; becomes one directory level of the cache path.
dataset_name = "CT-RATE_train_eval"
# Split to process; selects the sub-directory of NIfTI files that is read.
dataset_split = "train"

In [None]:
# Load environment variables from a .env file (python-dotenv), if one is present.
load_dotenv()

project_path = os.getenv("PROJECTPATH")
if project_path is None:
    # Fail early with a readable message instead of the opaque TypeError that
    # os.path.join(None, ...) would raise further down.
    raise RuntimeError(
        "PROJECTPATH is not set; export it or define it in a .env file."
    )
run_name = "CT-FM"
checkpoint_name = "embeddings"

# Cache layout: <PROJECTPATH>/evaluation/cache/<dataset_name>/<run_name>/<checkpoint_name>/
output_path = os.path.join(
    project_path,
    "evaluation/cache",
    dataset_name,
    run_name,
    checkpoint_name,
)
# Create the cache directory up front; a pre-existing directory is not an error.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Target device for the model; a CUDA-capable GPU is required.
device = torch.device("cuda")

# Load the pretrained feature extractor by its identifier
# (presumably a Hugging Face Hub repo id — TODO confirm).
model = SegResEncoder.from_pretrained("project-lighter/ct_fm_feature_extractor")
model.to(device)
# Put the model in evaluation mode for inference.
model.eval()

print("Loaded model")

In [None]:
# Preprocessing pipeline; it is called with a NIfTI file path and its result is
# fed to the model.
preprocess = Compose(
    [
        # Read the volume from disk with the channel axis first.
        LoadImage(ensure_channel_first=True),
        EnsureType(),
        # Reorient every volume to the "SPL" axis convention.
        Orientation(axcodes="SPL"),
        # Linearly map intensities from [-1024, 2048] to [0, 1], clipping values
        # outside that range (presumably Hounsfield units — TODO confirm).
        ScaleIntensityRange(
            a_min=-1024,
            a_max=2048,
            b_min=0,
            b_max=1,
            clip=True,
        ),
        # Crop the volume to its foreground region.
        CropForeground(allow_smaller=True),
        # Resize every volume to a fixed 240 x 512 x 512 grid.
        Resize(spatial_size=(240, 512, 512)),
    ]
)
# NOTE(review): prints only an empty line; looks like a leftover — confirm it can be removed.
print()

In [None]:
data_path = os.getenv("DATAPATH")
if data_path is None:
    # Fail early with a readable message instead of the opaque TypeError that
    # os.path.join(None, ...) would raise on the next line.
    raise RuntimeError(
        "DATAPATH is not set; export it or define it in a .env file."
    )
# NIfTI volumes of the selected split: <DATAPATH>/niftis/CT-RATE/<dataset_split>/
dataset_path = os.path.join(data_path, "niftis/CT-RATE", dataset_split)

class Dataset:
    """Map-style dataset over the ``.nii`` files found under a directory.

    Each item is a transformed volume with an extra leading axis, paired with
    the path of the file it was loaded from.

    Args:
        dataset_path: Root directory searched recursively for ``*.nii`` files.
        transform: Optional callable mapping a file path to a tensor. When
            omitted, the notebook-level ``preprocess`` pipeline is used.
    """

    def __init__(self, dataset_path, transform=None):
        nifti_template = os.path.join(dataset_path, "**/*.nii")
        # glob returns files in arbitrary filesystem order; sort so that an
        # index always refers to the same file across runs and machines.
        self.nifti_paths = sorted(glob(nifti_template, recursive=True))
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(tensor with a new leading axis, source file path)``."""
        nifti_path = self.nifti_paths[index]
        # Resolve the fallback at call time so the global pipeline is only
        # required when no explicit transform was supplied.
        transform = preprocess if self.transform is None else self.transform
        input_tensor = transform(nifti_path)
        # unsqueeze(0) adds the batch axis, so each item can be passed to the
        # model on its own.
        return input_tensor.unsqueeze(0), nifti_path

    def __len__(self):
        """Return the number of NIfTI files that were found."""
        return len(self.nifti_paths)

def collate_fn(batch):
    """Transpose a batch of ``(tensor, filename)`` pairs into two tuples.

    The tensors are not stacked; each sample stays a separate element so the
    caller can process the samples one at a time.

    Args:
        batch: Sequence of two-element samples.

    Returns:
        A pair ``(tensors, filenames)``, each a tuple with one entry per sample.
    """
    columns = zip(*batch)
    # Unpacking enforces that every sample has exactly two fields.
    first_fields, second_fields = columns
    return first_fields, second_fields

# Index every NIfTI file found under the split directory.
dataset = Dataset(dataset_path)
# Bare last expression: the notebook displays the number of volumes found.
len(dataset)

In [None]:
# Iterate the dataset in its own order (no shuffling), two samples per batch.
# collate_fn returns the samples as tuples rather than one stacked tensor.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn
)

In [None]:
# Number of batches to process before stopping. 1 gives a quick smoke test;
# set to None to embed the whole dataset.
max_batches = 1

for batch_index, (input_tensors, image_names) in enumerate(tqdm(dataloader)):

    # Samples are run through the model one at a time.
    for input_tensor, image_name in zip(input_tensors, image_names):

        input_tensor = input_tensor.to(device)

        with torch.no_grad():
            # Keep the last element of the model output and drop only the
            # leading batch axis; a bare squeeze() would also remove any other
            # axis that happens to have size 1 and break the rearrange below.
            output = model(input_tensor)[-1].squeeze(0)
            # (e, a, w, h) -> (a, w*h, e): per index of the second axis, one
            # row of e features for each (w, h) position.
            output = rearrange(output, "e a w h -> a (w h) e").cpu().numpy()

        # The dataset yields full file paths. Joining an absolute path onto
        # output_path would discard output_path and write next to the source
        # volume, so reduce the path to its bare file stem first.
        file_stem = os.path.basename(image_name).split(".")[0]
        np.save(os.path.join(output_path, f"{file_stem}.npy"), output)

        print(output.shape)

    if max_batches is not None and batch_index + 1 >= max_batches:
        break