In [48]:
import torchio
from torchio.transforms import (
    RescaleIntensity,
    Compose,
)
import pandas as pd
import os
from pathlib import Path
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

In [42]:
class R2Plus1dStem4MRI(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
    """

    def __init__(self):
        super(R2Plus1dStem4MRI, self).__init__(
            nn.Conv3d(1, 155, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(155),
            nn.ReLU(inplace=True),

            nn.Conv3d(155, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))

In [25]:
class MRIDatasets:
    def __init__(self, dataset_path, metadata_path):
        self.dataset_path = dataset_path
        self.metadata = metadata_path

    def tcga(self):
        metadata_df = pd.read_csv(self.metadata)

        imgs = []
        for path, currentDir, files in os.walk(self.dataset_path):
            for file in files:
                if file.endswith('t1ce.nii.gz'):
                    subject_id = os.path.basename(path)
                    img_path = Path(path + os.sep + 't1ce.nii.gz')
                    row_df = metadata_df[metadata_df['subject_id'] == subject_id]
                    # there might be more than one row found in csv file
                    label = int(row_df['IDH1_mut'].unique()[0] + row_df['loh1p/19q_cnv'].unique()[0])
                    imgs.append(torchio.Subject(t1=torchio.ScalarImage(img_path), label=label,))

        return imgs

In [29]:
total_samples = MRIDatasets(dataset_path='../data_multimodal_tcga/Radiology', metadata_path='../data_multimodal_tcga/patient-info-tcga.csv').tcga()

In [46]:
total_samples = total_samples[:5]

In [31]:
# for dataset being unbalanced for classes [0, 1, 2]
class_weights = torch.FloatTensor([1, 2.2, 4.1])

# Transforms
rescale = RescaleIntensity((0.05, 99.5))
randaffine = torchio.RandomAffine(scales=(0.9,1.2),degrees=10, isotropic=True, image_interpolation='nearest')
flip = torchio.RandomFlip(axes=('LR'), p=0.5)
transforms = [rescale, flip, randaffine]

transform = Compose(transforms)

subjects_dataset = torchio.SubjectsDataset(total_samples, transform=transform)

# train/test split
train_set_samples = (int(len(total_samples) - 0.2 * len(total_samples)))
test_set_samples = (int(len(total_samples)) - (train_set_samples))

trainset, testset = torch.utils.data.random_split(subjects_dataset, [train_set_samples, test_set_samples],
                                                  generator=torch.Generator().manual_seed(55))

trainloader = DataLoader(dataset=trainset, batch_size=1, shuffle=True, num_workers=1)
testloader = DataLoader(dataset=testset, batch_size=1, shuffle=True, num_workers=1)

In [44]:
model = torchvision.models.video.r2plus1d_18(pretrained=False)
model.stem = R2Plus1dStem4MRI()

# regularization
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 3)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(model)

VideoResNet(
  (stem): R2Plus1dStem4MRI(
    (0): Conv3d(1, 155, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (1): BatchNorm3d(155, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(155, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv2Plus1D(
          (0): Conv3d(64, 144, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
          (1): BatchNorm3d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(144, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
        )
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [38]:
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

# Initialize the prediction and label lists(tensors) for confusion matrix
predlist = torch.zeros(0, dtype=torch.long).to(device)
lbllist = torch.zeros(0, dtype=torch.long).to(device)

In [None]:
from tqdm import tqdm


epochs = 2
for epoch in tqdm(range(epochs)):

    logs = {}
    total_correct = 0
    total_loss = 0
    total_images = 0
    total_val_loss = 0

    for i, traindata in enumerate(trainloader):
        images = F.interpolate(traindata['t1'][torchio.DATA], scale_factor=(0.7, 0.7, 0.7)).to(device)
        labels = traindata['label'].to(device)
        optimizer.zero_grad()

        # Forward propagation
        outputs = model(images)

        loss = criterion(outputs, labels)

        # Backward prop
        loss.backward()

        # Updating gradients
        optimizer.step()
        # scheduler.step()

        # Total number of labels
        total_images += labels.size(0)

        # Obtaining predictions from max value
        _, predicted = torch.max(outputs.data, 1)

        # Calculate the number of correct answers
        correct = (predicted == labels).sum().item()

        total_correct += correct
        total_loss += loss.item()

        running_trainacc = ((total_correct / total_images) * 100)

        logs['log loss'] = total_loss / total_images
        logs['Accuracy'] = ((total_correct / total_images) * 100)

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
              .format(epoch + 1, epochs, i + 1, len(trainloader), (total_loss / total_images),
                      (total_correct / total_images) * 100))

        # Testing the model

        with torch.no_grad():
            correct = 0
            total = 0

            for testdata in testloader:
                images = F.interpolate(testdata['t1'][torchio.DATA], scale_factor=(0.7, 0.7, 0.7)).to(device)

                labels = testdata['label'].to(device)
                outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)

                predlist = torch.cat([predlist, predicted.view(-1)])  # Append batch prediction results

                lbllist = torch.cat([lbllist, labels.view(-1)])

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                total_losss = loss.item()

                accuracy = correct / total

            print('Test Accuracy of the model: {} %'.format(100 * correct / total))

            logs['val_' + 'log loss'] = total_loss / total
            validationloss = total_loss / total

            validationacc = ((correct / total) * 100)
            logs['val_' + 'Accuracy'] = ((correct / total) * 100)

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch [1/2], Step [1/149], Loss: 0.5056, Accuracy: 100.00%
Test Accuracy of the model: 23.68421052631579 %
Epoch [1/2], Step [2/149], Loss: 1.5871, Accuracy: 50.00%
Test Accuracy of the model: 31.57894736842105 %
Epoch [1/2], Step [3/149], Loss: 1.7863, Accuracy: 33.33%
Test Accuracy of the model: 26.31578947368421 %
Epoch [1/2], Step [4/149], Loss: 2.0876, Accuracy: 25.00%
Test Accuracy of the model: 26.31578947368421 %
Epoch [1/2], Step [5/149], Loss: 1.8404, Accuracy: 40.00%
Test Accuracy of the model: 26.31578947368421 %
Epoch [1/2], Step [6/149], Loss: 1.9646, Accuracy: 33.33%


In [11]:
conf_mat = confusion_matrix(lbllist.cpu().numpy(), predlist.cpu().numpy())

print(conf_mat)
cls = ["0", "1", "2"]
# Per-class accuracy
class_accuracy = 100 * conf_mat.diagonal() / conf_mat.sum(1)
print(class_accuracy)
plt.figure(figsize=(10, 10))
# plot_confusion_matrix(conf_mat, cls)
plt.show()