In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from pydicom import dcmread
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm


import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils import data

## Preparing labels

In [None]:
# Load the competition label file and keep only the columns needed for
# image-level classification.
label_data = pd.read_csv('../input/rsna-pneumonia-detection-challenge/stage_2_train_labels.csv')
columns = ['patientId', 'Target']

# NOTE(review): the raw CSV appears to hold one row per annotated box, so a
# patient can occur on several rows — confirm against the data description.
# Keeping one row per patientId stops the same image from landing in both the
# training and the validation split further down (a no-op if ids are unique).
label_data = (
    label_data
    .filter(columns)
    .drop_duplicates(subset='patientId')
    .reset_index(drop=True)
)
label_data.head(5)

## Splitting the labels into training and validation sets

In [None]:
# Fixed seed so that the 90/10 split — and therefore every metric reported
# below — is identical after Restart & Run All.
SPLIT_SEED = 42

train_labels, val_labels = train_test_split(
    label_data.values,
    test_size=0.1,
    random_state=SPLIT_SEED,
)
print(train_labels.shape)
print(val_labels.shape)

In [None]:
# Peek at the first training row: column 0 is the patient id, column 1 the target
first_row = train_labels[0]
print(f'patientId: {first_row[0]}, Target: {first_row[1]}')

## Preparing train and validation image paths

In [None]:
# Folders that hold the DICOM files of the competition
train_f = '../input/rsna-pneumonia-detection-challenge/stage_2_train_images'
test_f = '../input/rsna-pneumonia-detection-challenge/stage_2_test_images'

# Both splits come from the training folder. The paths are built from the
# patient id in column 0 and carry no extension here; '.dcm' is appended at
# read time.
train_paths, val_paths = (
    [os.path.join(train_f, row[0]) for row in split]
    for split in (train_labels, val_labels)
)

print(len(train_paths), len(val_paths), sep='\n')

## Show some samples from data

In [None]:
def imshow(num_to_show=9, offset=20):
    """Plot a grid of training radiographs with their target as x-label.

    Reads the module-level ``train_paths`` / ``train_labels`` lists.

    Args:
        num_to_show: number of consecutive samples to draw.
        offset: index of the first sample to draw (defaults to 20, the value
            that used to be hard-coded).
    """
    # Near-square grid that always has room for `num_to_show` panels
    # (3 x 3 for the default of nine, as before).
    n_cols = max(1, int(np.ceil(np.sqrt(num_to_show))))
    n_rows = max(1, int(np.ceil(num_to_show / n_cols)))

    plt.figure(figsize=(10, 10))

    for i in range(num_to_show):
        sample = offset + i

        plt.subplot(n_rows, n_cols, i + 1)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])

        # Stored paths have no extension; add it when reading the file
        img_dcm = dcmread(f'{train_paths[sample]}.dcm')
        plt.imshow(img_dcm.pixel_array, cmap=plt.cm.binary)
        plt.xlabel(train_labels[sample][1])

imshow()

## Composing transformations

In [None]:
# Pre-processing applied to every PIL image the dataset returns.
# Resize is given an explicit (height, width): with a single int only the
# shorter edge would be fixed, while batching and the saliency code further
# down both assume exactly 224 x 224. Square inputs are resized exactly as before.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()])

## Write a custom dataset 

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset that reads one DICOM file per item.

    Args:
        paths: sequence of file paths *without* the ``.dcm`` extension.
        labels: sequence of rows; element ``[1]`` of a row is returned as the
            label of the item at the same position.
        transform: optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for position ``index``."""
        dicom = dcmread(f'{self.paths[index]}.dcm')

        # Clip to the 8-bit range and cast directly. The former
        # "/ 255.0 then * 255" detour allocated a float64 copy per image and
        # could nudge values through float rounding before the truncating cast.
        pixels = np.clip(dicom.pixel_array, 0, 255).astype(np.uint8)
        image = Image.fromarray(pixels).convert('RGB')

        label = self.labels[index][1]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        """Number of items, taken from the path list."""
        return len(self.paths)

## Check the custom dataset

In [None]:
train_dataset = Dataset(train_paths, train_labels, transform=transform)

# Index the dataset directly instead of going through an iterator
img, label = train_dataset[0]
print(f'Tensor:{img}, Label:{label}')

# Move the channel axis last for matplotlib. Use the tensor's own permute and
# convert explicitly, rather than handing a torch tensor to np.transpose.
img = img.permute(1, 2, 0).numpy()
plt.imshow(img)

## Train image shape

In [None]:
# `img` was transposed to channel-last in the previous cell, so this reads (height, width, channels)
img.shape

## Prepare training and validation dataloader

In [None]:
# Validation must be deterministic: same resize + tensor conversion as the
# training pipeline, but without the random horizontal flip, so the reported
# accuracy does not vary with the augmentation draw.
val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()])

train_dataset = Dataset(train_paths, train_labels, transform=transform)
val_dataset = Dataset(val_paths, val_labels, transform=val_transform)

# Shuffle only the training data; keep validation order fixed
train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)

## Check dataloader

In [None]:
# Pull a single mini-batch from the training loader
images, labels = next(iter(train_loader))

# Tile the first four images into one grid tensor, then move the channel axis
# last so matplotlib can draw it
img = torchvision.utils.make_grid(images[:4]).numpy().transpose(1, 2, 0)
plt.imshow(img)

## Specify device object

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Orthogonal CNN Helper (http://pwang.pw/ocnn.html)

In [None]:
""" helper function
original author baiyu
modified by Peter Wang (@samaoline)
"""

import sys

import numpy as np


def conv_orth_dist(kernel, stride = 1):
    """Orthogonality distance of a square convolution kernel.

    The kernel is convolved (no padding) with one unit impulse per input
    position of a ``new_s x new_s`` patch. The responses to the centre impulse
    of each input channel are then correlated with all responses, and the
    Frobenius norm of the difference to a matrix that has a single 1 per row,
    at that channel's centre position, is returned.

    Args:
        kernel: 4-D weight tensor ``(out_channels, in_channels, w, h)`` with ``w == h``.
        stride: convolution stride; must be smaller than the kernel width.

    Returns:
        Scalar tensor on the same device as ``kernel``.

    Raises:
        AssertionError: for a rectangular kernel or ``stride >= w``.
    """
    [o_c, i_c, w, h] = kernel.shape
    assert (w == h),"Do not support rectangular kernel"
    assert stride<w,"Please use matrix orthgonality instead"

    new_s = stride*(w-1) + w          # patch width that yields a w x w output
    n_pix = new_s * new_s             # impulses per input channel
    center = n_pix // 2               # flat index of the patch centre

    # Everything is created on the kernel's own device/dtype so the function
    # also works on CPU (the original forced .cuda()).
    like = dict(device=kernel.device, dtype=kernel.dtype)

    # One impulse image per (channel, position): identity reshaped to images
    impulses = torch.eye(n_pix*i_c, **like).reshape((n_pix*i_c, i_c, new_s, new_s))
    # Fully-qualified call: the notebook never imports `F`
    out = torch.nn.functional.conv2d(impulses, kernel, stride=stride).reshape((n_pix*i_c, -1))

    # Rows belonging to the centre impulse of every input channel
    Vmat = out[center::n_pix, :]

    target = torch.zeros((i_c, i_c*n_pix), **like)
    rows = torch.arange(i_c, device=kernel.device)
    target[rows, center + n_pix*rows] = 1

    return torch.norm( Vmat@torch.t(out) - target )
    
def deconv_orth_dist(kernel, stride = 2, padding = 1):
    """Distance between the kernel's self-convolution and a centred identity.

    ``kernel`` is convolved with itself (used both as the input batch and as
    the filter bank). The result is compared, in Frobenius norm, with a tensor
    that is zero everywhere except for an identity matrix over the first two
    axes at the spatial centre.

    Args:
        kernel: 4-D weight tensor ``(out_channels, in_channels, w, h)``.
        stride: stride of the self-convolution.
        padding: zero padding of the self-convolution.

    Returns:
        Scalar tensor on the same device as ``kernel``.
    """
    [o_c, _, _, _] = kernel.shape
    output = torch.conv2d(kernel, kernel, stride=stride, padding=padding)

    # Build the target on the kernel's device/dtype instead of forcing CUDA,
    # so the regulariser also runs when the model lives on the CPU.
    like = dict(device=kernel.device, dtype=kernel.dtype)
    target = torch.zeros((o_c, o_c, output.shape[-2], output.shape[-1]), **like)
    ct = output.shape[-1] // 2
    target[:,:,ct,ct] = torch.eye(o_c, **like)

    return torch.norm( output - target )
    
def orth_dist(mat, stride=None):
    """Frobenius distance of ``mat^T mat`` from the identity.

    The input is flattened to 2-D ``(shape[0], -1)`` and transposed when it has
    fewer rows than columns, so the Gram matrix is always the smaller one.

    Args:
        mat: weight tensor with at least one axis.
        stride: unused; kept so the signature matches the other helpers.

    Returns:
        Scalar tensor on the same device as ``mat``.
    """
    mat = mat.reshape( (mat.shape[0], -1) )
    if mat.shape[0] < mat.shape[1]:
        mat = mat.permute(1,0)
    # Identity on the weight's own device/dtype (the original forced .cuda(),
    # which fails on the CPU fallback)
    identity = torch.eye(mat.shape[1], device=mat.device, dtype=mat.dtype)
    return torch.norm( torch.t(mat)@mat - identity )

## Load pre-trained ResNet18 and fine-tune

In [None]:
# ResNet-18 with pre-trained weights; the classifier head is replaced by a
# fresh linear layer with two outputs.
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=2)
model = model.to(device)

# Every parameter of the network is handed to the optimiser (no frozen layers)
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Multiplies the learning rate by 0.1 after every 7 scheduler steps
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)

## Write the training loop and run it

In [None]:
num_epochs = 20
orth_weight = 0.1   # weight of the orthogonality penalty in the total loss
log_every = 2000    # print the loss every `log_every` optimisation steps


def _orth_penalty(net):
    """Sum of orthogonality distances over the regularised convolutions.

    Covers the ``downsample[0]`` convolution of the first block in
    layer2-layer4 (via ``orth_dist``) and ``conv1`` of both blocks in
    layer1-layer4 (via ``deconv_orth_dist``). The ``conv2`` weights are
    deliberately left unregularised.
    """
    penalty = (orth_dist(net.layer2[0].downsample[0].weight)
               + orth_dist(net.layer3[0].downsample[0].weight)
               + orth_dist(net.layer4[0].downsample[0].weight))
    # (stage, stride passed for the stage's first block); the second block
    # always uses stride 1
    for stage, first_stride in ((net.layer1, 1), (net.layer2, 2),
                                (net.layer3, 2), (net.layer4, 2)):
        penalty = penalty + (deconv_orth_dist(stage[0].conv1.weight, stride=first_stride)
                             + deconv_orth_dist(stage[1].conv1.weight, stride=1))
    return penalty


# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    # Training step — make sure train-time layer behaviour is active, since the
    # validation pass below switches the model to eval mode
    model.train()
    # tqdm wraps the loader (not enumerate) so it knows the number of batches
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass: classification loss plus weighted orthogonality penalty
        outputs = model(images)
        loss = criterion(outputs, labels) + orth_weight * _orth_penalty(model)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % log_every == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # The StepLR schedule only takes effect if it is stepped once per epoch
    exp_lr_scheduler.step()

    # Validation step — eval mode and no autograd graph, so validation data
    # neither updates running statistics nor wastes memory on gradients
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            predictions = model(images)
            _, predicted = torch.max(predictions, 1)
            total += labels.size(0)
            # .item() keeps the counter a plain int so the accuracy prints as a number
            correct += (labels == predicted).sum().item()
    print(f'Epoch: {epoch+1}/{num_epochs}, Val_Acc: {100*correct/total}')

## Evaluate the model on the validation set

In [None]:
# Final accuracy on the validation loader, in inference mode
model.eval()

correct = 0
total = 0
# No gradients are needed for evaluation; skipping the autograd graph saves
# memory and time
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images)
        _, predicted = torch.max(predictions, 1)
        total += labels.size(0)
        # .item() turns the per-batch count into a Python int, so the figure
        # below prints as a plain number rather than a tensor repr
        correct += (labels == predicted).sum().item()
print(f'Val_Acc: {100*correct/total}')

In [None]:
# Print the module tree of the fine-tuned network (layers and their hyper-parameters)
print(model)

In [None]:
pip install git+git://github.com/raghakot/keras-vis.git --upgrade --no-deps

In [None]:
# Size of the batch currently bound to `images` (the last one produced by the
# evaluation loop above); print_saliency indexes into this batch
len(images)

In [None]:
def print_saliency(i):
    """Show image ``i`` of the global ``images`` batch next to its saliency map.

    The saliency map is the absolute gradient of the highest class score with
    respect to the input pixels, reduced with a maximum over the channel axis
    (approach from
    https://towardsdatascience.com/saliency-map-using-pytorch-68270fe45e80).

    Args:
        i: index into the module-level ``images`` batch.
    """
    # Private copy with a leading batch axis; works for any image size, not
    # just 224 x 224, and leaves the shared batch tensor untouched
    image_test = images[i].detach().unsqueeze(0).to(device).clone()
    image_test.requires_grad_()

    # Use inference behaviour for the forward pass, then restore whatever mode
    # the model was in
    was_training = model.training
    model.eval()
    try:
        # Retrieve output from the image
        output = model(image_test)

        # Score of the predicted class
        output_idx = output.argmax()
        output_max = output[0, output_idx]

        # Gradient w.r.t. the input only — unlike .backward(), this does not
        # accumulate gradients into the model parameters
        (input_grad,) = torch.autograd.grad(output_max, image_test)
    finally:
        model.train(was_training)

    # Retrieve the saliency map: maximum absolute gradient over the channel
    # axis (dim=1 of (batch, channel, height, width)), then drop the batch axis
    saliency = input_grad.abs().max(dim=1)[0].squeeze(0)

    # Channel-last copy of the image for matplotlib
    image_np = image_test.detach().squeeze(0).cpu().numpy().transpose(1, 2, 0)

    # Visualize the image and the saliency map
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(image_np)
    ax[0].axis('off')
    ax[1].imshow(saliency.cpu(), cmap='hot')
    ax[1].axis('off')
    plt.tight_layout()
    fig.suptitle('The Image and Its Saliency Map')
    plt.show()

In [None]:
# Saliency map for the first image of the current `images` batch
print_saliency(0)

In [None]:
# Saliency map for the image at index 50 of the current `images` batch
print_saliency(50)

In [None]:
# Saliency map for the image at index 78 of the current `images` batch
# (must be smaller than len(images))
print_saliency(78)

In [None]:
# Persist only the parameter tensors (state_dict) to the working directory;
# the module object itself is not saved
torch.save(model.state_dict(), "orthogonal_cnn.pt")