In [None]:
# Install the `gdcm` package from the conda-forge channel; `-y` answers the
# confirmation prompt so the cell runs unattended.
# NOTE(review): presumably needed so pydicom can decode compressed DICOM pixel
# data — confirm, and consider pinning a version for reproducibility.
!conda install gdcm -c conda-forge -y

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import models, transforms

from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import copy
import glob
import time
import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

In [None]:
# Load the study-level label table (an `id` column plus per-class columns).
study_csv = pd.read_csv('../input/siim-covid19-detection/train_study_level.csv')
# Drop the first column (presumably `id` — verify against the CSV header) so
# only the per-class indicator columns remain.
study = study_csv.to_numpy()[:, 1:]
# Keep the part of each id before the first underscore; the dataset class uses
# these values as study directory names.
X = [_id.split('_')[0] for _id in study_csv['id']]

# np.where(...)[1] only lines up with X when every row has exactly one
# indicator set. A row with zero or several set indicators would shift every
# following label, so fail loudly instead of training on misaligned targets.
is_positive = (study == 1)
positives_per_row = is_positive.sum(axis=1)
if not (positives_per_row == 1).all():
    bad_rows = np.flatnonzero(positives_per_row != 1)
    raise ValueError(
        f'{len(bad_rows)} study rows are not one-hot encoded '
        f'(first offending row index: {bad_rows[0]})'
    )
# Column index of the set indicator becomes the integer class label.
y = np.where(is_positive)[1]

# Stratified 80/20 split with a fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.2, stratify=y)

In [None]:
def read_xray(path, voi_lut = True, fix_monochrome = True):
    """Read a DICOM file and return its pixels as an 8-bit, 3-channel array.

    Original from:
    https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way

    Args:
        path: Location of the DICOM file, passed straight to pydicom.
        voi_lut: When true, run the pixel data through ``apply_voi_lut`` so the
            VOI LUT stored by the DICOM device (if any) is applied, giving a
            "human-friendly" view.
        fix_monochrome: When true, invert images whose
            ``PhotometricInterpretation`` is ``"MONOCHROME1"`` so they do not
            look inverted.

    Returns:
        A ``uint8`` array scaled to 0..255 with a new trailing axis of length 3
        (the single grey channel repeated three times).
    """
    # `dcmread` is the supported reader; `read_file` is a deprecated alias of it.
    dicom = pydicom.dcmread(path)

    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    # Min-max scale to [0, 1]. A constant image has a zero range, which would
    # make the division 0/0 (NaN) and the uint8 cast undefined, so map it to
    # all zeros explicitly.
    data = data - np.min(data)
    peak = np.max(data)
    if peak > 0:
        data = data / peak
    else:
        data = np.zeros(data.shape, dtype=np.float64)

    data = (data * 255).astype(np.uint8)
    # Replicate the grey channel along a new last axis.
    data = np.repeat(data[..., np.newaxis], 3, -1)

    return data

In [None]:
class Covid19Dataset(Dataset):
    """Image-level dataset built from study-level labels.

    Every file matching ``{root_dir}/{study}/*/*`` is treated as one sample and
    inherits the label of the study it belongs to.
    """

    def __init__(self, root_dir, studies, labels, transform):
        """Collect image paths and their labels.

        Args:
            root_dir: Directory containing one sub-directory per study.
            studies: Iterable of study directory names.
            labels: Iterable of labels, paired positionally with ``studies``.
            transform: Callable applied to the PIL image in ``__getitem__``.
        """
        self.root_dir = root_dir
        self.image_paths = []
        self.labels = []
        self.transform = transform

        for study, label in tqdm(zip(studies, labels)):
            p = f'{root_dir}/{study}/*/*'
            # glob returns matches in arbitrary (filesystem) order; sort so the
            # index -> image mapping is the same on every run.
            images = sorted(glob.glob(p))
            self.image_paths += images
            # Each image of a study gets that study's label.
            self.labels += [label] * len(images)

        assert len(self.image_paths) == len(self.labels)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at ``idx``.

        Returns:
            A ``(image, label)`` pair where ``image`` is whatever
            ``self.transform`` returns for the PIL image.
        """
        path = self.image_paths[idx]
        label = self.labels[idx]

        pixel_array = read_xray(path)
        image = Image.fromarray(pixel_array)
        image = self.transform(image)

        return image, label

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase. After
    the last epoch the weights from the epoch with the highest validation
    accuracy are loaded back into ``model``.

    Args:
        model: Network to train; batches are moved to the device its
            parameters live on.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of epochs to run.

    Returns:
        ``(model, val_acc_history)`` where ``val_acc_history`` holds one
        validation accuracy per epoch.
    """
    since = time.time()

    # Use the device the model is already on rather than a notebook-level
    # `device` global that may not exist yet when this function is called.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics: weight the batch loss by the batch size so the
                # epoch figure is a per-sample average
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                # deep copy the model whenever validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def initialize_model(num_classes):
    """Build a pretrained ResNet-50 with a fresh ``num_classes``-way head.

    The final fully connected layer is replaced by a new ``nn.Linear`` that
    keeps the original input width and outputs ``num_classes`` values.
    """
    net = models.resnet50(pretrained=True)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

In [None]:
# Resize every image to a fixed 800x800 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
])

# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_set = Covid19Dataset('../input/siim-covid19-detection/train', X_train, y_train, transform)
val_set = Covid19Dataset('../input/siim-covid19-detection/train', X_test, y_test, transform)
dataloaders = {
    'train': DataLoader(train_set, batch_size=4, shuffle=True),
    # Validation metrics are sums over the whole set, so shuffling adds nothing;
    # a fixed order keeps evaluation batches reproducible.
    'val': DataLoader(val_set, batch_size=4, shuffle=False),
}

In [None]:
# Output width of the replaced classification head
num_classes = 4

# How many passes over the training set to run
num_epochs = 5

# Build the network and place it on the selected device in one step.
model = initialize_model(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model, hist = train_model(model, dataloaders, criterion, optimizer, num_epochs=num_epochs)

# Persist only the weights of the model returned by train_model.
torch.save(model.state_dict(), 'resnet50.pt')