In [8]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset  # For custom datasets
from torch.utils.data.sampler import SubsetRandomSampler
import torch.backends.cudnn as cudnn
import torchvision.models as models
import pandas as pd
import numpy as np
from PIL import Image
import time
import shutil


# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
validation_split = .2
shuffle_dataset = False
random_seed= 42
num_epochs = 5
num_classes = 2
batch_size = 20
learning_rate = 0.001
weight_decay = 0.001
momentum = 0.9
print_freq = 1
best_prec1 = 0
workers = 8
pretrained = False
fine_tune = True
arch = 'alexnet'
classes = [1,2]

dir = '/home/suraj/asd-abide-prediction-pytorch/'

#dataset

data = pd.read_csv(dir + 'data.csv')

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path, transformation):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transformation
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(dir + 'png-i/' + self.data_info.iloc[:, 6] + '_alff.nii.jpg')
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 7])
        # Third column is for an operation indicator
        # self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        #some_operation = self.operation_arr[index]
        # If there is an operation
        #if some_operation:
            # Do some operation on image
            # ...
            # ...
        #    pass
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

# if __name__ == "__main__":

#     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
#     transformations = transforms.Compose([
#             transforms.RandomSizedCrop(224),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             normalize,
#         ])

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

dataset_train = CustomDatasetFromImages(dir + 'train.csv', transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

dataset_test = CustomDatasetFromImages(dir + 'test.csv', transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    
# Data loader
#train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
#                                          batch_size=batch_size, 
#                                         shuffle=True)

#test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                          batch_size=batch_size, 
#                                          shuffle=False)



# Creating data indices for training and validation splits:
# dataset_size = len(dataset)
# indices = list(range(dataset_size))
# split = int(np.floor(validation_split * dataset_size))
# if shuffle_dataset :
#     np.random.seed(random_seed)
#     np.random.shuffle(indices)
# train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
# train_sampler = SubsetRandomSampler(train_indices)
# validation_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, 
                                           num_workers=workers, pin_memory=True)
validation_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True,
                                                num_workers=workers, pin_memory=True)


def train(train_loader, model, criterion, optimizer, epoch):
    """Train the model on Training Set"""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
    
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        #topk = (1,5) if labels >= 100 else (1,) # TO FIX
        # For nets that have multiple outputs such as Inception
        if isinstance(output, tuple):
            loss = sum((criterion(o,target_var) for o in output))
            # print (output)
            for o in output:
                prec1 = accuracy(o.data, target, topk=(1,))
                top1.update(prec1[0], input.size(0))
            losses.update(loss.data[0], input.size(0)*len(output))
        else:
            loss = criterion(output, target_var)
            prec1 = accuracy(output.data, target, topk=(1,))
            top1.update(prec1[0], input.size(0))
            losses.update(loss.data[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Info log every args.print_freq
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1_val} ({top1_avg})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   top1_val=np.asscalar(top1.val.cpu().numpy())+0XA,
                   top1_avg=np.asscalar(top1.avg.cpu().numpy())+0XA))


def validate(val_loader, model, criterion):
    """Validate the model on Validation Set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Evaluate all the validation set
    for i, (input, target) in enumerate(val_loader):
    
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        # print ("Output: ", output)
        #topk = (1,5) if labels >= 100 else (1,) # TODO: add more topk evaluation
        # For nets that have multiple outputs such as Inception
        if isinstance(output, tuple):
            loss = sum((criterion(o,target_var) for o in output))
            # print (output)
            for o in output:
                prec1 = accuracy(o.data, target, topk=(1,))
                top1.update(prec1[0], input.size(0))
            losses.update(loss.data[0], input.size(0)*len(output))
        else:
            loss = criterion(output, target_var)
            prec1 = accuracy(output.data, target, topk=(1,))
            top1.update(prec1[0], input.size(0))
            losses.update(loss.data[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Info log every args.print_freq
        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1_val} ({top1_avg})'.format(
                   i, len(val_loader), batch_time=batch_time,
                   loss=losses,
                   top1_val=np.asscalar(top1.val.cpu().numpy())+0XA,
                   top1_avg=np.asscalar(top1.avg.cpu().numpy())+0XA))

    print(' * Prec@1 {top1}'
          .format(top1=np.asscalar(top1.avg.cpu().numpy())+0XA))
    return top1.avg


def test(test_loader, model, classes):
    """Test the model on the Evaluation Folder
    Args:
        - classes: is a list with the class name
        - names: is a generator to retrieve the filename that is classified
    """
    # switch to evaluate mode
    model.eval()
    # Evaluate all the validation set
    for i, (input, _) in enumerate(test_loader):
    
        input_var = torch.autograd.Variable(input, volatile=True)

        # compute output
        output = model(input_var)
        # Take last layer output
        if isinstance(output, tuple):
            output = output[len(output)-1]

        # print (output.data.max(1, keepdim=True)[1])
        lab = classes[np.asscalar(output.data.max(1, keepdim=True)[1].cpu().numpy())]
        print ("Image classified as: " + lab)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = learning_rate * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

if pretrained:
        print("=> using pre-trained model '{}'".format(arch))
        model = models.__dict__[arch](pretrained=True)
        print(model)
        # quit()
else:
        print("=> creating model '{}'".format(arch))
        model = models.__dict__[arch](num_classes=num_classes)
        # print(model)

    # Freeze model, train only the last FC layer for the transfered task
        if fine_tune:
            print("=> transfer-learning mode + fine-tuning (train only the last FC layer)")
            # Freeze Previous Layers(now we are using them as features extractor)
            for param in model.parameters():
                param.requires_grad = False

            # Fine Tuning the last Layer For the new task
            model.classifier._modules['6'] = nn.Linear(4096, num_classes+1)
            parameters = model.classifier._modules['6'].parameters()
            print(model)
            # quit()

        else:
            parameters = model.parameters()

    
# Define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()

# Set SGD + Momentum
optimizer = torch.optim.SGD(model.parameters(), learning_rate,
                            momentum=momentum,
                            weight_decay=weight_decay)

# optionally resume from a checkpoint

# Load model on CPU
model.cpu()

############ TRAIN/EVAL/TEST ############
cudnn.benchmark = True

# Training

print("=> training...")
for epoch in range(num_epochs):

    adjust_learning_rate(optimizer, epoch)

    # Train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # Evaluate on validation set
    prec1 = validate(validation_loader, model, criterion)
    # print (prec1)

    #Test data
    #names = get_images_name(os.path.join(testdir, 'images'))
    #test(validation_loader, model, train_dataset.classes)

    # Remember best prec@1 and save checkpoint
    prec1 = prec1.cpu() # Load on CPU if CUDA
    # Get bool not ByteTensor
    is_best = prec1 > best_prec1
    # Get greater Tensor
    best_prec1 = max(prec1, best_prec1)

    save_checkpoint({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer' : optimizer.state_dict(),
    }, is_best)


=> creating model 'alexnet'
=> transfer-learning mode + fine-tuning (train only the last FC layer)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias



Epoch: [0][0/42]	Time 0.967 (0.967)	Data 0.380 (0.380)	Loss 1.0962 (1.0962)	Prec@1 55.0 (55.0)
Epoch: [0][1/42]	Time 0.665 (0.816)	Data 0.000 (0.190)	Loss 1.0938 (1.0950)	Prec@1 70.0 (62.5)
Epoch: [0][2/42]	Time 0.619 (0.751)	Data 0.000 (0.127)	Loss 1.0952 (1.0951)	Prec@1 60.0 (61.66666793823242)
Epoch: [0][3/42]	Time 0.577 (0.707)	Data 0.000 (0.095)	Loss 1.0952 (1.0951)	Prec@1 50.0 (58.75)
Epoch: [0][4/42]	Time 0.580 (0.682)	Data 0.000 (0.076)	Loss 1.0932 (1.0947)	Prec@1 50.0 (57.0)
Epoch: [0][5/42]	Time 0.570 (0.663)	Data 0.000 (0.063)	Loss 1.0932 (1.0945)	Prec@1 50.0 (55.83333206176758)
Epoch: [0][6/42]	Time 0.571 (0.650)	Data 0.000 (0.054)	Loss 1.0886 (1.0936)	Prec@1 80.0 (59.28571319580078)
Epoch: [0][7/42]	Time 0.571 (0.640)	Data 0.000 (0.048)	Loss 1.0903 (1.0932)	Prec@1 50.0 (58.125)
Epoch: [0][8/42]	Time 0.586 (0.634)	Data 0.006 (0.043)	Loss 1.0886 (1.0927)	Prec@1 50.0 (57.22222137451172)
Epoch: [0][9/42]	Time 0.575 (0.628)	Data 0.001 (0.039)	Loss 1.0852 (1.0919)	Prec@1 85.0 (6



Test: [0/11]	Time 1.107 (1.107)	Loss 1.0296 (1.0296)	Prec@1 55.0 (55.0)
Test: [1/11]	Time 0.538 (0.823)	Loss 1.0298 (1.0297)	Prec@1 55.0 (55.0)
Test: [2/11]	Time 0.536 (0.727)	Loss 1.0285 (1.0293)	Prec@1 65.0 (58.33333206176758)
Test: [3/11]	Time 0.528 (0.677)	Loss 1.0293 (1.0293)	Prec@1 60.0 (58.75)
Test: [4/11]	Time 0.533 (0.649)	Loss 1.0291 (1.0293)	Prec@1 60.0 (59.0)
Test: [5/11]	Time 0.525 (0.628)	Loss 1.0284 (1.0291)	Prec@1 65.0 (60.0)
Test: [6/11]	Time 0.536 (0.615)	Loss 1.0293 (1.0291)	Prec@1 60.0 (60.0)
Test: [7/11]	Time 0.537 (0.605)	Loss 1.0291 (1.0291)	Prec@1 60.0 (60.0)
Test: [8/11]	Time 0.533 (0.597)	Loss 1.0294 (1.0292)	Prec@1 60.0 (60.0)
Test: [9/11]	Time 0.530 (0.590)	Loss 1.0280 (1.0290)	Prec@1 70.0 (61.0)
Test: [10/11]	Time 0.175 (0.553)	Loss 1.0314 (1.0291)	Prec@1 43.33333206176758 (60.485435485839844)
 * Prec@1 60.485435485839844
Epoch: [1][0/42]	Time 1.167 (1.167)	Data 0.450 (0.450)	Loss 1.0230 (1.0230)	Prec@1 80.0 (80.0)
Epoch: [1][1/42]	Time 0.530 (0.848)	Data 0

Epoch: [2][21/42]	Time 0.576 (0.594)	Data 0.001 (0.012)	Loss 0.9377 (0.9488)	Prec@1 60.0 (63.40909194946289)
Epoch: [2][22/42]	Time 0.574 (0.593)	Data 0.002 (0.012)	Loss 0.9349 (0.9482)	Prec@1 65.0 (63.4782600402832)
Epoch: [2][23/42]	Time 0.568 (0.592)	Data 0.001 (0.011)	Loss 0.9380 (0.9478)	Prec@1 55.0 (63.125)
Epoch: [2][24/42]	Time 0.582 (0.591)	Data 0.001 (0.011)	Loss 0.9348 (0.9473)	Prec@1 60.0 (63.0)
Epoch: [2][25/42]	Time 0.562 (0.590)	Data 0.002 (0.011)	Loss 0.9314 (0.9467)	Prec@1 60.0 (62.88461685180664)
Epoch: [2][26/42]	Time 0.554 (0.589)	Data 0.001 (0.010)	Loss 0.9380 (0.9463)	Prec@1 45.0 (62.22222137451172)
Epoch: [2][27/42]	Time 0.549 (0.587)	Data 0.002 (0.010)	Loss 0.9301 (0.9458)	Prec@1 65.0 (62.32143020629883)
Epoch: [2][28/42]	Time 0.554 (0.586)	Data 0.002 (0.010)	Loss 0.9323 (0.9453)	Prec@1 60.0 (62.24137878417969)
Epoch: [2][29/42]	Time 0.557 (0.585)	Data 0.001 (0.009)	Loss 0.9281 (0.9447)	Prec@1 65.0 (62.33333206176758)
Epoch: [2][30/42]	Time 0.547 (0.584)	Data 0.

Test: [9/11]	Time 0.530 (0.585)	Loss 0.8959 (0.8920)	Prec@1 50.0 (61.0)
Test: [10/11]	Time 0.175 (0.547)	Loss 0.8984 (0.8922)	Prec@1 43.33333206176758 (60.485435485839844)
 * Prec@1 60.485435485839844
Epoch: [4][0/42]	Time 0.966 (0.966)	Data 0.275 (0.275)	Loss 0.8851 (0.8851)	Prec@1 60.0 (60.0)
Epoch: [4][1/42]	Time 0.546 (0.756)	Data 0.001 (0.138)	Loss 0.8866 (0.8858)	Prec@1 55.0 (57.5)
Epoch: [4][2/42]	Time 0.572 (0.695)	Data 0.004 (0.093)	Loss 0.8874 (0.8864)	Prec@1 50.0 (55.0)
Epoch: [4][3/42]	Time 0.576 (0.665)	Data 0.000 (0.070)	Loss 0.8771 (0.8841)	Prec@1 75.0 (60.0)
Epoch: [4][4/42]	Time 0.566 (0.645)	Data 0.001 (0.056)	Loss 0.8820 (0.8836)	Prec@1 60.0 (60.0)
Epoch: [4][5/42]	Time 0.578 (0.634)	Data 0.002 (0.047)	Loss 0.8809 (0.8832)	Prec@1 60.0 (60.0)
Epoch: [4][6/42]	Time 0.636 (0.634)	Data 0.000 (0.040)	Loss 0.8745 (0.8820)	Prec@1 75.0 (62.14285659790039)
Epoch: [4][7/42]	Time 0.589 (0.629)	Data 0.002 (0.036)	Loss 0.8806 (0.8818)	Prec@1 60.0 (61.875)
Epoch: [4][8/42]	Time 0.