### Todo

I already had the COVID test for the BioVil-T ResNet model. Dig into it. Find that model, and use it as the feature extractor. During this process, I can also dig into other notebooks that I have created.

In [1]:
import torch
import torch.nn as nn
from typing import Callable, Optional
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pydicom
import os
from torchvision.models import resnet50
from glob import glob
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.getcwd()

'/scratch/users/oince22/hpc_run/fromage_scratch'

In [3]:
torch.__version__

'1.13.1+cu117'

In [4]:
torch.cuda.is_available()

True

In [5]:
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

class BioViL(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet50()
        self._initialize_resnet()
        self.feature_extractor = self._get_feature_extractor()
        
    def _initialize_resnet(self):
        base_keys = list(self.model.state_dict().keys())
        model_state_dict = torch.load("biovil_backbone_2048.pt")
        self.model.load_state_dict(model_state_dict)
        # torch.save(self.model.state_dict(), "biovil_backbone_2048.pt")

    
    def _get_feature_extractor(self):
        self._return_nodes = {'avgpool': 'avgpool'}
        return create_feature_extractor(self.model, return_nodes=self._return_nodes)

    def forward(self, x):
        return self.feature_extractor(x)

In [8]:
model = BioViL().cuda()

In [7]:
sum(p.numel() for p in model.parameters())

25557032

In [7]:
sum(p.numel() for p in model.parameters())

25557032

In [6]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.biovil_model = BioViL()
        self.fc_dim = 2048
        self.fc1 = torch.nn.Linear(self.fc_dim, 128)
        self.fc2 = torch.nn.Linear(128, 2)
        self.relu = torch.nn.ReLU()
        self.biovil_model.requires_grad = False

    def train(self, mode=True):
        super(Model, self).train(mode=mode)
        self.biovil_model.eval()
        
    def forward(self, x):
        x = self.biovil_model(x)['avgpool']
        x = x.reshape(-1, self.fc_dim)
        return self.fc2(self.relu(self.fc1(x)))

In [7]:
class ExpandChannels:
    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        if data.shape[0] != 1:
            raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
        return torch.repeat_interleave(data, 3, dim=0)

In [8]:
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, RandomHorizontalFlip
def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int, train=False) -> Compose:
    if not train:
        transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    if train:
        transforms = [Resize(resize), RandomHorizontalFlip(), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    return Compose(transforms)

In [9]:
root_dir = "chest_xray/chest_xray"
source_dirs = ["NORMAL", "PNEUMONIA"]

class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, transform, split="train"):
        assert split in ["train", "test"]
        
        get_img_path = lambda img_fname, split, class_name: root_dir + "/" + split + "/" + class_name + "/" + img_fname
        
        def get_images(class_name):
            images = [get_img_path(x, split, class_name) for x in os.listdir(root_dir + "/" + split + "/" + class_name) if x.lower().endswith('jpeg')]
            print(f'Found {len(images)} {class_name} examples')
            return images
        
        self.img_names = []
        self.labels = []

        self.class_names = source_dirs
        
        for label, class_name in enumerate(self.class_names):
            cur_img_paths = get_images(class_name)
            self.img_names.extend(cur_img_paths)
            self.labels.extend([label for _ in range(len(cur_img_paths))])
            
        self.transform = transform
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image_path = self.img_names[idx]
        image = Image.open(image_path)
        return self.transform(image), self.labels[idx]

In [10]:
transform_center_crop_size = 512
TRANSFORM_RESIZE = 480

train_transform = create_chest_xray_transform_for_inference(resize=TRANSFORM_RESIZE, center_crop_size=transform_center_crop_size, train=True)
test_transform = create_chest_xray_transform_for_inference(resize=TRANSFORM_RESIZE, center_crop_size=transform_center_crop_size)

In [11]:
train_dataset = ChestXRayDataset(train_transform, split="train")
test_dataset = ChestXRayDataset(test_transform, split="test")

print("Length of train set   :  ", len(train_dataset))
print("Length of test set    :  ", len(test_dataset))

Found 1341 NORMAL examples
Found 3875 PNEUMONIA examples
Found 234 NORMAL examples
Found 390 PNEUMONIA examples
Length of train set   :   5216
Length of test set    :   624


In [12]:
batch_size = 18

data_train_len = torch.utils.data.DataLoader(train_dataset, batch_size= batch_size, shuffle=True)
data_test_len = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

print("Length of training batches", len(data_train_len))
print("Lentgth of test batches", len(data_test_len))

Length of training batches 290
Lentgth of test batches 35


In [13]:
def convert_dicom_to_jpg(input_path, output_path, resize=False, new_width=512):
    # Load DICOM file
    ds = pydicom.dcmread(input_path)

    # Extract pixel data and normalize to range [0, 255]
    pixel_data = ds.pixel_array
    pixel_data = pixel_data.astype(np.float32)
    pixel_data -= np.min(pixel_data)
    pixel_data /= np.max(pixel_data)
    pixel_data *= 255.0
    pixel_data = np.uint8(pixel_data)

    # Check PhotometricInterpretation for inversion
    if ds.PhotometricInterpretation == "MONOCHROME1":
        # Invert pixel values
        pixel_data = 255 - pixel_data

    # Histogram equalization
    pixel_data = cv2.equalizeHist(pixel_data)

    if resize:
        height, width = pixel_data.shape

        scale = new_width / width
        new_height = int(height * scale)

        pixel_data = cv2.resize(pixel_data, (new_width, new_height))

    # Convert to JPEG with quality factor 95
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
    _, jpeg_data = cv2.imencode('.jpg', pixel_data, encode_param)

    with open(output_path, 'wb') as f:
        f.write(jpeg_data)

    return pixel_data

In [14]:
import gc
loss_fn = torch.nn.CrossEntropyLoss()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [15]:
model = Model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)

In [16]:
def train(epochs):
    print('Starting training..')
    for e in range(0, epochs):
        print('='*20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*20)
        accuracy = 0

        train_loss = 0.
        val_loss = 0.

        model.train() # set model to training phase
        model.biovil_model.eval()
        

        for train_step, (images, labels) in enumerate(data_test_len):
            optimizer.zero_grad()
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            
            loss.backward()
            
            optimizer.step()
            train_loss += loss.item()

        print('Evaluating at step', train_step)

        model.eval() # set model to eval phase

        with torch.no_grad():
            for val_step, (images, labels) in enumerate(data_test_len):
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                loss = loss_fn(outputs, labels)
                val_loss += loss.item()

                outputs = outputs.detach().cpu()
                labels = labels.detach().cpu()

                _, preds = torch.max(outputs, 1)
                accuracy += sum((preds == labels).numpy())
                print(f"Val step: {val_step}", end="\r")

        val_loss /= (val_step + 1)
        accuracy = accuracy/len(test_dataset)        
        
        model.train()
        model.biovil_model.eval()

        train_loss /= (train_step + 1)

        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')


    print('Training complete..')

In [17]:
def test_predicts():
    model.eval()
    images, labels = next(iter(data_test_len))
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    outputs = outputs.detach().cpu()
    labels = labels.detach().cpu()
    _, preds = torch.max(outputs, 1)
    images = images.detach().cpu()
    show_images(images, labels, preds)

In [18]:
train(epochs = 40)

Starting training..
Starting epoch 1/40
Evaluating at step 34
Training Loss: 0.7358
Validation Loss: 0.6971, Accuracy: 0.3750
Starting epoch 2/40



KeyboardInterrupt



In [None]:
test_predicts()

transform_center_crop_size = 480
TRANSFORM_RESIZE = 512

train_transform = create_chest_xray_transform_for_inference(resize=TRANSFORM_RESIZE, center_crop_size=transform_center_crop_size)
test_transform = create_chest_xray_transform_for_inference(resize=TRANSFORM_RESIZE, center_crop_size=transform_center_crop_size)

dicom_path = "/datasets/mimic/physionet.org/files/mimic-cxr/2.0.0/files/p10/p10000032/s50414267/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.dcm"
jpeg_path = "test2_resized.jpg"
img = convert_dicom_to_jpg(dicom_path, jpeg_path, resize=True)
print(img.shape)
img = Image.fromarray(img)
print(img)
plt.imshow(img, cmap="gray")
plt.figure()
img = test_transform(img)
img = img.permute(1, 2, 0).numpy()
img = img * 255
img = img.astype(np.uint8)
img = Image.fromarray(img)
plt.imshow(img)