In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import h5py
from Augmentation import RandomCrop, CenterCrop, RandomFlip

  from ._conv import register_converters as _register_converters


In [2]:
def normalize_MRIs(image):
    """Standardize an image volume to zero mean and unit standard deviation.

    The mean and standard deviation are computed over the whole array
    (all axes at once).

    Args:
        image: array-like of numeric intensities, any shape.

    Returns:
        A new floating-point ``numpy.ndarray`` with the same shape as
        ``image``. Floating inputs keep their dtype, anything else is
        converted to float64. The input array is never modified.
        A constant image (standard deviation 0) comes back as all zeros
        rather than NaN; an empty array is returned unchanged.
    """
    image = np.asarray(image)
    # Work on a float copy: in-place ``-=`` / ``/=`` on the argument would
    # overwrite the caller's data and raises for integer dtypes.
    float_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    normalized = image.astype(float_dtype, copy=True)
    if normalized.size == 0:
        return normalized

    mean = np.mean(normalized)
    std = np.std(normalized)
    normalized -= mean
    # Fixed constants (95.09 / 86.38) were kept commented out in the original
    # cell as an alternative to per-image statistics.
    if std != 0:
        normalized /= std
    return normalized

In [36]:
class OAI_Dataloader(Dataset):
    """OAI dataset. Sequences of images, each sequence labeled by 0 or 1 (TKR or not).

    Each CSV row names one HDF5 file (column ``FileName``); the label is
    derived from the column ``NumberOfDaysFromScanToTKR``.
    """

    def __init__(self, root_dir, csv_root_dir, csv_file, dim=(384,384,32),
                 normalize = True, randomCrop = True,
                 randomFlip = True, flipProbability = -1, cropDim = (384,384,32),
                 max_samples = 10):
        """
        Args:
            root_dir (string): Prefix prepended (by plain string concatenation)
                to every ``FileName`` entry to locate its HDF5 file.
            csv_root_dir (string): Prefix prepended to ``csv_file``.
            csv_file (string): Name of the csv file with annotations.
            dim (tuple): Stored on the instance; not used elsewhere in this class.
            normalize (bool): Apply ``normalize_MRIs`` to every loaded volume.
            randomCrop (bool): Crop with ``RandomCrop`` when True, otherwise
                with ``CenterCrop``; both receive ``cropDim``.
            randomFlip (bool): Apply ``RandomFlip(...).horizontal_flip``.
            flipProbability: Forwarded as ``p`` to ``horizontal_flip``.
                NOTE(review): the default -1 presumably disables flipping —
                confirm against the Augmentation module.
            cropDim (tuple): Target size handed to the crop helpers.
            max_samples (int or None): Keep only the first ``max_samples``
                CSV rows; ``None`` keeps every row. The default of 10 matches
                the row cap that used to be hard-coded here.
        """
        self.root_dir = root_dir
        annotations = pd.read_csv(csv_root_dir + csv_file)
        if max_samples is not None:
            annotations = annotations[0:max_samples]
        self.data = annotations
        self.folders = ['00m']
        self.dim = dim
        self.normalize = normalize
        self.randomCrop = randomCrop
        self.randomFlip = randomFlip
        self.flipProbability = flipProbability
        self.cropDim = cropDim


    def __len__(self):
        """Number of annotation rows kept from the CSV."""
        return self.data.shape[0]


    def __getitem__(self, idx):
        """Load, optionally normalize/augment, and crop the volume of row ``idx``.

        Returns:
            (img_seq, label) where ``img_seq`` is a float32 tensor whose first
            axis has one entry per element of ``self.folders`` (currently one)
            followed by the axes of the cropped volume, and ``label`` is in
            {0, 1}. The batch axis is added later by the ``DataLoader``.
        """
        file = self.data['FileName'].iloc[idx]

        img_seq = []
        for folder in self.folders:
            # The per-folder path (root_dir + folder + '/' + file) was disabled
            # in the original cell; the file is read straight from root_dir.
            # ``with`` closes the HDF5 handle, and ``[()]`` replaces the
            # ``.value`` accessor that h5py 3.0 removed.
            with h5py.File(self.root_dir + file, "r") as h5_file:
                pre_image = h5_file['data/'][()].astype('float64')
            if self.normalize:
                pre_image = normalize_MRIs(pre_image)
            # Augmentation
            if self.randomFlip:
                pre_image = RandomFlip(image=pre_image,p=0.5).horizontal_flip(p=self.flipProbability)
            if self.randomCrop:
                pre_image = RandomCrop(pre_image).crop_along_hieght_width_depth(self.cropDim)
            else:
                pre_image = CenterCrop(image=pre_image).crop(size = self.cropDim)

            img_seq.append(pre_image)

        # from_numpy(...).float() gives the same float32 tensor as the previous
        # torch.tensor(tensor, dtype=...) call without the copy-construct warning.
        img_seq = torch.from_numpy(np.array(img_seq)).float()

        # Any non-zero value counts as positive.
        # NOTE(review): a missing (NaN) value also compares != 0 and becomes 1 — confirm intended.
        label = int(self.data['NumberOfDaysFromScanToTKR'].iloc[idx] != 0)

        return (img_seq, label)

In [37]:
# #test1
# root_dir = '/gpfs/data/denizlab/Datasets/OAI/SAG_3D_DESS/'
# csv_file = 'HDF5_00_SAG_3D_DESScohort_2_prime.csv'

# test2: SAG_IW_TSE image directory plus the Fold1 train / validation CSVs
root_dir = '/gpfs/data/denizlab/Datasets/OAI/SAG_IW_TSE/'
csv_root_dir = '/home/yangj14/oai/Tianyu/data/'
train_csv_file, val_csv_file = 'Fold1_train.csv', 'Fold1_val.csv'

In [58]:
# Dataset options shared by both splits: no normalization and no random
# augmentation, volumes cropped to 384x384x32.
train_params = dict(
    dim=(384, 384, 32),
    normalize=False,
    randomCrop=False,
    randomFlip=False,
    flipProbability=-1,
    cropDim=(384, 384, 32),
)

# One dataset per split, built from the matching CSV with identical options.
train_dataset, val_dataset = (
    OAI_Dataloader(root_dir, csv_root_dir, split_csv, **train_params)
    for split_csv in (train_csv_file, val_csv_file)
)

# Both loaders iterate in CSV order (shuffle disabled), two volumes per batch.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

In [47]:
# Sanity check: print the tensor shape and the labels of every training batch.
for batch_images, batch_labels in train_dataloader:
    print(batch_images.shape)
    print(batch_labels)

torch.Size([2, 1, 384, 384, 32])
tensor([ 0,  0])
torch.Size([2, 1, 384, 384, 32])
tensor([ 0,  0])
torch.Size([2, 1, 384, 384, 32])
tensor([ 0,  1])
torch.Size([2, 1, 384, 384, 32])
tensor([ 0,  1])
torch.Size([2, 1, 384, 384, 32])
tensor([ 1,  1])


In [64]:
'''VGG11/13/16/19 in Pytorch'''
# Only tried model with one data point. Need to load more data
import torch
import torch.nn as nn
from torch.autograd import Variable


# Layer plan per VGG variant, grouped one stage per row: an integer is the
# output-channel count of a conv + batch-norm + ReLU unit, 'M' is a max-pool.
cfg = dict(
    VGG11=[64, 'M',
           128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
    VGG13=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
    VGG16=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 'M',
           512, 512, 512, 'M',
           512, 512, 512, 'M'],
    VGG19=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M',
           512, 512, 512, 512, 'M'],
)


class VGG(nn.Module):
    """3D VGG-style classifier for single-channel volumes.

    The feature extractor follows the layer plan ``cfg[vgg_name]``; a single
    linear layer maps the flattened features to ``num_classes`` logits.

    Args:
        vgg_name: key into the module-level ``cfg`` table
            ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: number of output logits (default 2).
        input_shape: spatial size of the input volumes, three integers in
            the order of the last three tensor axes. It is only used to size
            the classifier. With the default (384, 384, 32) VGG13/16/19 get
            the 4608 input features that used to be hard-coded, and VGG11 —
            which has a single stride-2 convolution and therefore a larger
            feature map — gets the width it actually needs.
    """

    def __init__(self, vgg_name, num_classes=2, input_shape=(384, 384, 32)):
        super(VGG, self).__init__()
        layer_plan = cfg[vgg_name]
        self.features = self._make_layers(layer_plan)
        self.classifier = nn.Linear(
            self._flattened_features(layer_plan, input_shape), num_classes)


    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, 1, X, Y, Z)."""
        out = self.features(x)
        # Flatten everything except the batch axis for the linear classifier.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    @staticmethod
    def _conv_stride(out_channels):
        """Stride of a conv unit: 64-channel units halve the first two spatial axes."""
        return (2, 2, 1) if out_channels == 64 else (1, 1, 1)

    @staticmethod
    def _flattened_features(cfg, input_shape, in_channels=1):
        """Number of values per sample that ``_make_layers(cfg)`` produces.

        Applies the output-size formulas of the layers used in
        ``_make_layers`` (3x3x3 convolution with padding 1, 2x2x2 max-pool
        with stride 2) to ``input_shape``.

        Raises:
            ValueError: if ``input_shape`` does not have three entries or the
                volume is pooled down to nothing.
        """
        spatial = [int(size) for size in input_shape]
        if len(spatial) != 3:
            raise ValueError('input_shape must have three entries, got %r' % (input_shape,))
        channels = in_channels
        for x in cfg:
            if x == 'M':
                # kernel 2, stride 2, no padding: floor((n - 2) / 2) + 1
                spatial = [(size - 2) // 2 + 1 for size in spatial]
            else:
                # kernel 3, padding 1: floor((n + 2 - 3) / stride) + 1
                spatial = [(size - 1) // step + 1
                           for size, step in zip(spatial, VGG._conv_stride(x))]
                channels = x
            if min(spatial) < 1:
                raise ValueError('input_shape %r is too small for this layer plan' % (input_shape,))
        return channels * spatial[0] * spatial[1] * spatial[2]

    def _make_layers(self, cfg):
        """Build the feature extractor described by the layer plan ``cfg``."""
        layers = []
        in_channels = 1
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))]
            else:
                layers += [nn.Conv3d(in_channels, x, kernel_size=(3, 3, 3),
                                     stride=self._conv_stride(x), padding=1),
                           nn.BatchNorm3d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # Kernel-1 / stride-1 average pooling leaves the tensor unchanged; it is
        # kept so the module layout matches the original network.
        layers += [nn.AvgPool3d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())


# net = VGG('VGG16')
# for im, l in train_dataloader:
#     print(net(im).size())

In [53]:
# Runtime configuration: CUDA switch, logging interval and RNG seed.
no_cuda = False
log_interval = 1
seed = 1

# Use the GPU only when it is available and has not been switched off above.
cuda = torch.cuda.is_available() and not no_cuda

torch.manual_seed(seed)

# Device on which the model and the batches are placed.
device = torch.device("cuda") if cuda else torch.device("cpu")

# Extra DataLoader options, populated only when running on the GPU.
kwargs = {}
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}

In [65]:
# Evaluation helper, followed by the training loop
def test_model(loader, model):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    total = 0
    model.eval()
    for im, labels in loader:
        im, labels = im.to(device), labels.to(device)
        outputs = F.softmax(model(im), dim=1)
        predicted = outputs.max(1, keepdim=True)[1]

        total += labels.size(0)
        correct += predicted.eq(labels.view_as(predicted)).sum().item()
    return (100 * correct / total)

import timeit
start = timeit.default_timer()

model = VGG('VGG16').to(device)

learning_rate = 0.00005
num_epochs = 20 # number epoch to train

# Criterion and Optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Number of batches per epoch (used in the progress message below)
total_step = len(train_dataloader)

# Validation accuracy recorded every 100 iterations / at the end of every epoch
val_accuracy_100_vgg = []
val_accuracy_epoch_vgg = []


for epoch in range(num_epochs):
    for i, (im, labels) in enumerate(train_dataloader):
        im, labels = im.to(device), labels.to(device)
        # test_model() leaves the model in eval mode, so switch back every step.
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs = model(im)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()
        # validate every 100 iterations
        if i > 0 and i % 100 == 0:
            # validate
            val_acc = test_model(val_dataloader, model)
            # total_step replaces the undefined name `train_loader`, which raised
            # NameError the first time this branch was reached.
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, total_step, val_acc))
            val_accuracy_100_vgg.append(val_acc)
    # End-of-epoch validation
    val_acc = test_model(val_dataloader, model)
    val_accuracy_epoch_vgg.append(val_acc)


stop = timeit.default_timer()
# Wall-clock duration of the whole run, in seconds
print(stop - start)

print(val_accuracy_epoch_vgg)

79.72023869212717
[60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0]


In [62]:
# Display the number of training batches per epoch (len(train_dataloader), set in the training cell).
total_step

5