In [16]:
import pandas as pd
import seaborn as sns
import os
import json
import pydicom
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import torch
import wandb

In [17]:
df = pd.read_csv("data/v2/framerejectionfeedback.csv")

In [18]:
def normalize_dicom(dcm) -> np.ndarray:
    data = dcm.pixel_array.astype("float32")
    window_center = dcm.WindowCenter
    window_width = dcm.WindowWidth
    if isinstance(window_center, pydicom.multival.MultiValue):
        window_center = window_center[0]
    if isinstance(window_width, pydicom.multival.MultiValue):
        window_width = window_width[0]
    lower = window_center - window_width / 2
    upper = window_center + window_width / 2
    normalized = np.clip((data - lower) / (upper - lower), 0, 1)
    return normalized

In [19]:
# loading about 10GB of data

dcm_files = os.listdir("data/frame_selection/dicoms/")
dcm_files = [f for f in dcm_files if f.endswith(".dcm")]

parent = "data/frame_selection/dicoms/"
images = {}
for filename in tqdm(dcm_files):
    dcm = pydicom.dcmread(parent + filename)
    img = normalize_dicom(dcm)
    images[filename.replace(".dcm", "")] = img

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 541/541 [02:34<00:00,  3.50it/s]


In [14]:
test_img_ids = pd.read_csv("testimgs.csv")
test_img_ids = set(test_img_ids.ID)

{'131aedfhs6pnf1fvtvp49mfi9clxqv4y22',
 '131aedfhs6pnf1fvtvp49mk95mce4ksv22',
 '131aedfhs6pnf1fvtvp49mhebo5bwm8j22',
 '131aedfhs6pnf1fvtvp49jww2fesosyw22',
 '131aedfhs6pnf1fvtvp49mfe16xskqh522',
 '131aedfhs6pnf1fvtvp49mn3xaiu7kol22',
 '131aedfhs6pnf1fvtvp49mfe0fzulhdp8kn',
 '131aedfhs6pnf1fvtvp49mg993ho7uch22',
 '131aedfhs6pnf1fvtvp49mi8hnble4a222',
 '131aedfhs6pnf1fvtvp49mm9f7mqmw8r22',
 '131aedfhs6pnf1fvtvp49jxuhobsd64322',
 '131aedfhs6pnf1fvtvp49mfbw01cxrrh22',
 '131aedfhs6pnf1fvtvp49mk5xv833v3h22',
 '131aedfhs6pnf1fvtvp49e8212gyaybo22',
 '131aedfht3wfgufvmycqzh8wt0tmyyq',
 '131aedfhs6pnf1fvtvp49jss9812hjmj22',
 '131aedfhs6pnf1fvtvp49mk95ml0ge0522',
 '131aedfht3wfgufvmycqznyflf1wn70',
 '131aedfht3wfgufvmycqzlxhm34ldkw',
 '131aedfh7815iqf0ke14rkkx6aex6d',
 '131aedfhs6pnf1fvtvp49mhbhb5zsd8u22',
 '131aedfhs6pnf1fvtvp49jxq6vd6n01m22',
 '131aedfht3wfgufvmycqznydcvyn2qs',
 '131aedfhs6pnf1fvtvp49jzpgy296p2022',
 '131aedfhs6pnf1fvtvp49jxuhrbjv1c222',
 '131aedfht3wfgufvmycqz5i2cdnwhyk',
 '13

In [22]:
class SimpleFrameSelectionDataset:  
    def __init__(self, images: dict, df: pd.DataFrame, split=None):
        self.images = images
        keys = 2 * list(images.keys())
        labels = [0] * len(images) + [1] * len(images)
        n_frames = list(df["REJECTEDFRAME"]) + list(df["SELECTEDFRAME"])

        self.index = [*zip(keys, n_frames, labels)]
        self.index = [(k, n, l) for (k, n, l) in self.index if len(self.images[k])>n]

        if split == "train":
            self.index = [(k, n, l) for (k, n, l) in self.index if k not in test_img_ids]
        elif split == "test":
            self.index = [(k, n, l) for (k, n, l) in self.index if k in test_img_ids]
                    
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize([0.5], [0.5])   
        ])
                
    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        key, n_frame, label = self.index[idx]
        img = self.images[key][n_frame]
        img = self.transform(img)
        return img, label
        
train_dataset = SimpleFrameSelectionDataset(images, df, split="train")
test_dataset = SimpleFrameSelectionDataset(images, df, split="test")

print(f"Train: {len(train_dataset)}")
print(f"Test: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

Train: 832
Test: 228


In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torchvision.models.resnet18(pretrained=True)
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = torch.nn.Linear(model.fc.in_features, 1)  
model = model.to(device)

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best_test_loss = np.inf
pbar = tqdm(range(100))
for epoch in pbar:
    model.train()
    running_loss = 0.0
    correct = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device).float().unsqueeze(1)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += ((outputs > 0.5).float() == labels).sum().item()

    avg_train_loss = running_loss / len(train_loader)
    train_accuracy = correct / len(train_dataset) * 100

    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device).float().unsqueeze(1)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            correct += ((outputs > 0.5).float() == labels).sum().item()

    avg_test_loss = running_loss / len(test_loader)
    test_accuracy = correct / len(test_dataset) * 100

    pbar.set_description(
        f'Loss: {avg_train_loss:.4f}/{avg_test_loss:.4f} | ' 
        f'Acc {train_accuracy:.2f}/{test_accuracy:.2f}'
    )

    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        torch.save(model.state_dict(), 'frame_selection.pth')


Loss: 0.3962/0.8313 | Acc 72.00/58.33: 100%|███████████████████████████████████████████████████████████████████████████████████| 100/100 [03:45<00:00,  2.26s/it]
