In [None]:
from src.data_management.dataset import AbstractCxrDataset
import logging
import re
from enum import Enum
from pathlib import Path
from typing import List, Optional

import cv2
import pandas as pd
import torch
import torchvision
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from src.utils.time import measure_time

# Configure the root logger so that INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after the current module.
logger = logging.getLogger(__name__)

In [None]:
class CxrEvaluationCLIP(AbstractCxrDataset):
    """Evaluation dataset yielding ``(image, tokenized prompt, label)`` triples.

    One sample is produced per row of the label CSV. The CSV is read with
    every column as ``str`` and must contain the ``subject_id``,
    ``study_id``, ``dicom_id``, ``pathology`` and ``label`` columns.
    """

    def __init__(
        self,
        label_path: str,
        img_dir: str,
        testing: bool = False,
    ):
        """Read the labels, derive the image paths and build the preprocessors.

        Args:
            label_path: Path of the label CSV file.
            img_dir: Root directory of the images; files are addressed as
                ``p<part>/p<subject_id>/s<study_id>/<dicom_id>.jpg`` below it.
            testing: Accepted for interface compatibility; this class does
                not read it.
        """
        import time

        st_time = time.time()
        self._img_dir = Path(img_dir)
        self._labels = pd.read_csv(label_path, dtype="str")
        # Columns from the seventh onward are kept as label names
        # (presumably the first six hold identifiers/metadata -- TODO confirm).
        self._label_names = list(self._labels.columns[6:])
        en_time = time.time()
        print(f"Time to load labels: {en_time - st_time}")

        st_time = time.time()
        # Only the paths are built here; image files are read lazily in
        # __getitem__.
        self._img_paths = self._create_img_paths()
        en_time = time.time()
        print(f"Time to load images: {en_time - st_time}")

        st_time = time.time()
        # Preprocessing pipeline: array -> PIL -> resize -> single-channel
        # grayscale -> tensor -> normalization with a single mean/std pair.
        self._transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToPILImage(),
                torchvision.transforms.Resize(224),
                torchvision.transforms.Grayscale(num_output_channels=1),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=(0.485,), std=(0.229,)),
            ]
        )

        self._tokenizer = self._create_tokenizer()
        en_time = time.time()
        print(f"Time to load tokenizer: {en_time - st_time}")

    def _create_tokenizer(self):
        """Load the CLIP tokenizer configured for 77-token sequences.

        Truncation is configured on the left and padding on the right.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            "openai/clip-vit-base-patch32",
            truncation_side="left",
            padding_side="right",
            model_max_length=77,
        )

        return tokenizer

    def _create_img_paths(self) -> List[Path]:
        """Build one image path per label row, in row order.

        Returns:
            A list with ``len(self._labels)`` paths of the form
            ``<img_dir>/p<part>/p<subject_id>/s<study_id>/<dicom_id>.jpg``,
            where ``<part>`` comes from ``self.extract_subject_part``.
        """
        # Iterate the three id columns directly instead of using
        # DataFrame.apply(axis=1) only for its side effect on a list.
        id_rows = zip(
            self._labels["subject_id"],
            self._labels["study_id"],
            self._labels["dicom_id"],
        )

        return [
            self._img_dir
            / f"p{self.extract_subject_part(subject_id)}"
            / f"p{subject_id}"
            / f"s{study_id}"
            / f"{img_id}.jpg"
            for subject_id, study_id, img_id in id_rows
        ]

    def _load_and_process_text(self, index: int) -> "BatchEncoding":
        """Tokenize the prompt built from the ``pathology`` value of one row.

        Args:
            index: Positional row index into the label table.

        Returns:
            The tokenizer output for the prompt, padded/truncated to 77
            tokens and holding PyTorch tensors.
        """
        # Design note kept from the original code: 26 pathologies -> a tensor
        # of shape [26, (...)] with (...) holding the data of each pathology.
        # Positional lookup keeps the text aligned with self._img_paths, which
        # is a plain list indexed by position.
        text = (
            "This patient has the following pathologies: "
            f"{self._labels['pathology'].iloc[index]}"
        )

        tokens = self._tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )

        return tokens

    def _load_and_process_img(self, img_path: Path) -> torch.Tensor:
        """Load an image through the base class and apply ``self._transform``."""
        img = super()._load_and_process_img(img_path)
        img = self._transform(img)

        return img

    def __getitem__(self, index):
        """Return ``(image, tokens, label)`` for the row at position ``index``.

        The label is the row's ``label`` value converted with ``int``.
        """
        img = self._load_and_process_img(self._img_paths[index])
        txt = self._load_and_process_text(index)

        # .iloc keeps the label lookup positional, matching the image path
        # and the prompt above.
        return img, txt, int(self._labels["label"].iloc[index])

In [None]:
from collections import defaultdict

import torch
from tqdm import tqdm
from transformers.tokenization_utils_base import BatchEncoding

from myconfig import CFG
from src.models.experiments import CxrVQA
from src.utils import AvgMeter

# Forwarded as the `testing` argument when the evaluation dataset is built below.
IS_TESTING = False


def eval_collate_fn(batch):
    """Collate ``(image, tokens, label)`` samples into batched tensors.

    Args:
        batch: Non-empty sequence of ``(image, tokens, label)`` tuples, where
            ``tokens`` is a mapping from field name to tensor and every
            sample provides the same fields.

    Returns:
        A ``(stacked_imgs, stacked_texts, stacked_labels)`` tuple: the images
        stacked along a new leading axis, a ``BatchEncoding`` holding each
        token field stacked the same way, and a tensor of the labels.
    """
    imgs = [item[0] for item in batch]
    collated_data = dict()
    for key in batch[0][1].keys():
        stacked = torch.stack([item[1][key] for item in batch])
        # Drop only the per-sample singleton axis (axis 1, if it has size 1).
        # A bare squeeze() would also remove the batch axis whenever the
        # batch holds a single sample, e.g. the last partial batch.
        if stacked.dim() > 1:
            stacked = stacked.squeeze(1)
        collated_data[key] = stacked

    stacked_imgs = torch.stack(imgs)
    stacked_texts = BatchEncoding(data=collated_data)
    stacked_labels = torch.tensor([item[2] for item in batch])
    return stacked_imgs, stacked_texts, stacked_labels


# Evaluation dataset built from the long-format label CSV and the image root.
val_ds = CxrEvaluationCLIP(
    label_path="/mnt/ssd1/CXR/data/classification-labels/test_train_long.csv",
    img_dir="/mnt/ssd1/CXR/data/cxr_dataset",
    testing=IS_TESTING,
)

# Batches are assembled by eval_collate_fn; batch size and worker count come
# from the project configuration.
val_dataloader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=CFG.num_workers,
    collate_fn=eval_collate_fn,
)

# map_location remaps the stored tensors onto the configured device, so a
# checkpoint saved on another device (e.g. a GPU that is absent here) still
# loads.
state_dict = torch.load("/mnt/ssd1/CXR/ckpts/best.pt", map_location=CFG.device)
model = CxrVQA()
model.load_state_dict(state_dict)
model.to(CFG.device)

n_classes = 26

avg_auc = 0

def valid_epoch(model, valid_loader):
    """Run one pass over ``valid_loader`` and accumulate the loss.

    Args:
        model: Callable invoked as ``model(imgs, texts)``; its result must
            support ``.item()``.
        valid_loader: Sized iterable of batches whose first two elements are
            the images and the texts; both must support ``.to(device)``.

    Returns:
        The ``AvgMeter`` updated with every batch loss and its batch size.
    """
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    # Only loss.item() is consumed, so gradients are never needed; disabling
    # autograd avoids building a graph for every validation batch.
    with torch.no_grad():
        for batch in tqdm_object:
            imgs = batch[0].to(CFG.device)
            texts = batch[1].to(CFG.device)
            loss = model(imgs, texts)

            count = len(batch[0])
            loss_meter.update(loss.item(), count)

            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


# Switch the model to evaluation mode before running inference.
model.eval()

In [None]:

from src.models.losses.cross_entropy import CELoss


tqdm_object = tqdm(val_dataloader, total=len(val_dataloader))

criterion = CELoss()

# Inference only: with autograd enabled, the globals `sim` and `loss` would
# keep the whole computation graph (and its activations) alive between cells.
with torch.no_grad():
    for batch in tqdm_object:
        imgs = batch[0].to(CFG.device)
        texts = batch[1].to(CFG.device)
        sim = model.calcluate_similarity(imgs, texts)
        loss = criterion(sim, batch[2].to(CFG.device))
        count = len(batch[0])
        # Only the first batch is needed for inspection in the next cells.
        break
    

In [None]:
# Display the similarity output computed for the batch in the previous cell.
sim

In [None]:
# Sanity check: decode the token ids of the first sample back into text.
val_ds._tokenizer.batch_decode(sequences=val_ds[0][1]["input_ids"].detach())

In [None]:
# Raw token ids of the first sample's prompt.
val_ds[0][1]["input_ids"].detach()

In [None]:
# for epoch in range(CFG.epochs):
#     print(f"Epoch: {epoch + 1}")

#     tqdm_object = tqdm(val_dataloader, total=len(val_dataloader))
#     for batch in tqdm_object:
#         imgs = batch[0].to(CFG.device)
#         texts = batch[1].to(CFG.device)
#         sim = model.calcluate_similarity(imgs, texts)

#         count = len(batch[0])

#     model.eval()
#     with torch.no_grad():
#         valid_loss = valid_epoch(model, val_dataloader)
#         print(f"Valid Loss: {valid_loss}")