In [1]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset  # For custom datasets
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.models as models
import pandas as pd
import numpy as np
from PIL import Image
import time
import shutil


# Run on the first CUDA device when one is available, otherwise on the CPU.
# NOTE(review): nothing in the visible code moves the model or the batches to
# `device`, so training currently runs on the CPU either way.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---- Train / validation split ---------------------------------------------
validation_split = 0.2   # fraction of CSV rows held out for validation
shuffle_dataset = False  # False -> the first rows of the CSV form the validation set
random_seed = 4          # seed used only when shuffle_dataset is True

# ---- Optimisation ---------------------------------------------------------
num_epochs = 5
num_classes = 2          # NOTE(review): not passed to the model in the visible code
batch_size = 20
learning_rate = 0.001    # base LR; adjust_learning_rate decays it from here
weight_decay = 0.0001
momentum = 0.9

# ---- Bookkeeping ----------------------------------------------------------
print_freq = 1           # print a progress line every `print_freq` batches
best_prec1 = 0           # best validation top-1 precision seen so far

# Root folder containing data.csv and the rendered image tree.
# (Shadows the builtin `dir`; kept because other code refers to this name.)
dir = '/home/plant99/Documents/code/pro_per/a-t-b/data/'

# NOTE(review): `data` is not used in the visible code, and it is parsed with a
# header row while CustomDatasetFromImages reads the same file with
# header=None -- confirm which one matches data.csv.
data = pd.read_csv(dir + 'data.csv')

class CustomDatasetFromImages(Dataset):
    """Map-style dataset of JPEG images indexed by a header-less CSV file.

    Every CSV row is one sample:
      * the image file is ``<img_path><column 6 value>_alff.nii.jpg``;
      * the label is the value in column 7.
    """

    def __init__(self, csv_path, img_path=None, transform=None):
        """
        Args:
            csv_path (string): path to the CSV index. It is read with
                ``header=None``, so the first line is treated as a sample too.
            img_path (string, optional): prefix (folder) prepended to each
                image name. Defaults to
                ``dir + '/nii-data/Outputs/cpac/filt_global/png-i/'``.
            transform (callable, optional): maps a PIL RGB image to a tensor.
                Defaults to Resize(256) -> CenterCrop(224) -> ToTensor ->
                Normalize with the ImageNet channel mean/std.
        """
        if transform is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        # Stored under the original attribute name so existing users keep working.
        self.to_tensor = transform

        if img_path is None:
            # `dir` already ends with '/', so this path contains '//' (harmless on POSIX).
            img_path = dir + '/nii-data/Outputs/cpac/filt_global/png-i/'

        self.data_info = pd.read_csv(csv_path, header=None)
        # Column 6 holds the identifier used to build each image file name.
        self.image_arr = np.asarray(img_path + self.data_info.iloc[:, 6] + '_alff.nii.jpg')
        # Column 7 holds the class label.
        self.label_arr = np.asarray(self.data_info.iloc[:, 7])
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for CSV row ``index``."""
        # Open inside a context manager so the file handle is released, and
        # force RGB: a non-RGB image (e.g. grayscale) would otherwise produce a
        # tensor whose channel count does not match the 3-channel Normalize.
        with Image.open(self.image_arr[index]) as raw:
            img = raw.convert('RGB')
        img_as_tensor = self.to_tensor(img)
        return (img_as_tensor, self.label_arr[index])

    def __len__(self):
        """Number of CSV rows (samples)."""
        return self.data_len

# Index every row of data.csv as an (image, label) sample.
dataset = CustomDatasetFromImages(dir + 'data.csv')

# Hold out `validation_split` of the rows for validation. Without shuffling the
# first `split` rows become the validation set; with shuffling the row order is
# permuted first, seeded with `random_seed` so the split is reproducible.
dataset_size = len(dataset)
split = int(np.floor(validation_split * dataset_size))
indices = list(range(dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

# Each loader samples only from its own (disjoint) index subset, visiting it
# in a random order every epoch.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=validation_sampler,
)


# Convolutional neural network (two convolutional layers)
# class ConvNet(nn.Module):
#     def __init__(self, num_classes=10):
#         super(ConvNet, self).__init__()
#         self.layer1 = nn.Sequential(
#             nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm2d(16),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2))
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm2d(32),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2))
#         self.fc = nn.Linear(51168, num_classes+1)
        
#     def forward(self, x):
#         out = self.layer1(x)
#         out = self.layer2(out)
#         out = out.reshape(out.size(0), -1)
#         out = self.fc(out)
#         return out

# NOTE(review): models.alexnet() is built with torchvision's default classifier
# head (1000 outputs, random initialisation), so `num_classes` is not used here.
# CrossEntropyLoss only requires every label to be < the number of outputs; the
# commented-out ConvNet above used `num_classes + 1` outputs, which suggests the
# labels may be 1-based -- confirm the label encoding before shrinking the head.
model = models.alexnet()


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
# Hyper-parameters are passed by keyword: SGD's positional order is
# (params, lr, momentum, dampening, weight_decay, ...), so a 4th positional
# argument sets `dampening`, not `weight_decay`.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            momentum=momentum,
                            weight_decay=weight_decay)

def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and print a progress line every `print_freq` batches.

    Args:
        train_loader: iterable of ``(images, target)`` batches.
        model: network to optimise; switched to training mode here.
        criterion: loss applied as ``criterion(model(images), target)``.
        optimizer: stepped once per batch.
        epoch: epoch index, used only in the progress line.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode (training-time behaviour such as dropout)
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output (tensors track gradients directly; no Variable wrapper)
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss. `.item()` turns the scalar tensors
        # into Python floats; indexing a 0-dim tensor (`loss.data[0]`) raises
        # IndexError in PyTorch >= 0.5.
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
    """Evaluate `model` on `val_loader` without tracking gradients.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss applied as ``criterion(model(images), target)``.

    Returns:
        The sample-weighted average top-1 precision (percent) over all batches.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # torch.no_grad() disables autograd for the whole pass; the
    # `Variable(..., volatile=True)` flag is ignored since PyTorch 0.4.
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss (`.item()` -> Python floats)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write `state` to `filename` with torch.save.

    When `is_best` is true, the freshly written file is additionally copied to
    'model_best.pth.tar' in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_path = 'model_best.pth.tar'
    shutil.copyfile(filename, best_path)


class AverageMeter(object):
    """Keep the latest value together with a running, count-weighted mean.

    Attributes (all 0 after ``reset``):
        val: the most recent value given to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Apply a step-decay schedule: the module-level `learning_rate` is
    multiplied by 0.1 once for every full 30 epochs, and the result is written
    into every parameter group of `optimizer`."""
    decay_steps = epoch // 30
    new_lr = learning_rate * 0.1 ** decay_steps
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in `topk`.

    Args:
        output: score tensor of shape ``(batch, num_classes)``; classes are
            ranked along dim 1.
        target: class-index tensor with one entry per row of `output`.
        topk: iterable of k values.

    Returns:
        A list with one 1-element tensor per k:
        ``100 * (#rows whose target is among the k highest scores) / batch``.
        A k larger than the number of classes is treated as "all classes",
        so asking for top-5 on a 2-class output does not raise.
    """
    # topk() raises if k exceeds the size of dim 1, so clamp to the class count.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds every sample's (j+1)-th guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` inherits the transposed (non-contiguous)
        # layout of `pred`, which view() rejects for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


for epoch in range(num_epochs):
    # Step-decay the learning rate for this epoch.
    adjust_learning_rate(optimizer, epoch)

    # One pass over the training split, then score the validation split.
    train(train_loader, model, criterion, optimizer, epoch)
    prec1 = validate(validation_loader, model, criterion)

    # Track the best top-1 precision so far and checkpoint after every epoch;
    # `is_best` tells save_checkpoint whether this epoch set a new best.
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
    }
    save_checkpoint(checkpoint, is_best)


    
     
# # Train the model
# total_step = len(train_loader)
# for epoch in range(num_epochs):
#     for i, (images, labels) in enumerate(train_loader):
#         labels_ = []
#         for label in labels:
#             labels_.append(int(label))
#         labels = torch.LongTensor(labels_)
#         images = images.to(device)
#         labels = labels.to(device)
        
#         # Forward pass
#         outputs = model(images)
#         loss = criterion(outputs, labels)
        
#         # Backward and optimize
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         if (i+1) % 100 == 0:
#             print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
#                    .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            

# # Test the model
# model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in validation_loader:
#         images = images.to(device)
#         labels_ = []
#         for label in labels:
#             labels_.append(int(label))
#         labels = torch.LongTensor(labels_)
#         labels = labels.to(device)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()


#     print('Test Accuracy of the model on the 1036 test images: {} %'.format(100 * correct / total))

# # Save the model checkpoint
# #torch.save(model.state_dict(), 'model.ckpt')



Epoch: [0][0/42]	Time 3.605 (3.605)	Data 0.441 (0.441)	Loss 6.9082 (6.9082)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][1/42]	Time 2.762 (3.183)	Data 0.275 (0.358)	Loss 6.9050 (6.9066)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][2/42]	Time 2.986 (3.117)	Data 0.512 (0.410)	Loss 6.9035 (6.9056)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][3/42]	Time 3.127 (3.120)	Data 0.281 (0.377)	Loss 6.8988 (6.9039)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][4/42]	Time 2.559 (3.008)	Data 0.243 (0.350)	Loss 6.9019 (6.9035)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][5/42]	Time 2.610 (2.942)	Data 0.298 (0.342)	Loss 6.8876 (6.9008)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][6/42]	Time 2.537 (2.884)	Data 0.217 (0.324)	Loss 6.8918 (6.8995)	Prec@1 0.000 (0.000)	Prec@5 0.000 (0.000)
Epoch: [0][7/42]	Time 3.130 (2.915)	Data 0.382 (0.331)	Loss 6.8785 (6.8969)	Prec@1 25.000 (3.125)	Prec@5 35.000 (4.375)
Epoch: [0][8/42]	Time 3.250 (2.952)	Data 0.592 (0.360)



Test: [0/11]	Time 1.190 (1.190)	Loss 17.2486 (17.2486)	Prec@1 40.000 (40.000)	Prec@5 100.000 (100.000)
Test: [1/11]	Time 1.106 (1.148)	Loss 14.3266 (15.7876)	Prec@1 50.000 (45.000)	Prec@5 100.000 (100.000)
Test: [2/11]	Time 1.136 (1.144)	Loss 19.4674 (17.0142)	Prec@1 35.000 (41.667)	Prec@5 100.000 (100.000)
Test: [3/11]	Time 1.102 (1.133)	Loss 14.8829 (16.4814)	Prec@1 50.000 (43.750)	Prec@5 100.000 (100.000)
Test: [4/11]	Time 1.111 (1.129)	Loss 22.4436 (17.6738)	Prec@1 25.000 (40.000)	Prec@5 100.000 (100.000)
Test: [5/11]	Time 1.085 (1.121)	Loss 16.0320 (17.4002)	Prec@1 45.000 (40.833)	Prec@5 100.000 (100.000)
Test: [6/11]	Time 1.032 (1.109)	Loss 16.1337 (17.2193)	Prec@1 45.000 (41.429)	Prec@5 100.000 (100.000)
Test: [7/11]	Time 1.163 (1.116)	Loss 17.8488 (17.2980)	Prec@1 40.000 (41.250)	Prec@5 100.000 (100.000)
Test: [8/11]	Time 1.152 (1.120)	Loss 13.6057 (16.8877)	Prec@1 55.000 (42.778)	Prec@5 100.000 (100.000)
Test: [9/11]	Time 1.161 (1.124)	Loss 11.7418 (16.3731)	Prec@1 60.000 (44.

Epoch: [2][6/42]	Time 2.385 (2.415)	Data 0.088 (0.116)	Loss 0.7362 (0.7887)	Prec@1 55.000 (47.143)	Prec@5 100.000 (100.000)
Epoch: [2][7/42]	Time 2.439 (2.418)	Data 0.088 (0.113)	Loss 0.7140 (0.7794)	Prec@1 50.000 (47.500)	Prec@5 100.000 (100.000)
Epoch: [2][8/42]	Time 2.522 (2.429)	Data 0.154 (0.117)	Loss 0.7225 (0.7731)	Prec@1 50.000 (47.778)	Prec@5 100.000 (100.000)
Epoch: [2][9/42]	Time 2.403 (2.427)	Data 0.123 (0.118)	Loss 0.7245 (0.7682)	Prec@1 45.000 (47.500)	Prec@5 100.000 (100.000)
Epoch: [2][10/42]	Time 2.463 (2.430)	Data 0.142 (0.120)	Loss 0.7167 (0.7635)	Prec@1 50.000 (47.727)	Prec@5 100.000 (100.000)
Epoch: [2][11/42]	Time 2.497 (2.435)	Data 0.133 (0.121)	Loss 0.7995 (0.7665)	Prec@1 30.000 (46.250)	Prec@5 100.000 (100.000)
Epoch: [2][12/42]	Time 2.444 (2.436)	Data 0.132 (0.122)	Loss 0.7464 (0.7650)	Prec@1 40.000 (45.769)	Prec@5 100.000 (100.000)
Epoch: [2][13/42]	Time 2.602 (2.448)	Data 0.318 (0.136)	Loss 0.7036 (0.7606)	Prec@1 55.000 (46.429)	Prec@5 100.000 (100.000)
Epoc

Epoch: [3][21/42]	Time 2.452 (2.504)	Data 0.108 (0.184)	Loss 0.7377 (0.7211)	Prec@1 45.000 (50.227)	Prec@5 100.000 (100.000)
Epoch: [3][22/42]	Time 2.482 (2.503)	Data 0.173 (0.184)	Loss 0.7034 (0.7203)	Prec@1 55.000 (50.435)	Prec@5 100.000 (100.000)
Epoch: [3][23/42]	Time 2.399 (2.499)	Data 0.114 (0.181)	Loss 0.6633 (0.7179)	Prec@1 60.000 (50.833)	Prec@5 100.000 (100.000)
Epoch: [3][24/42]	Time 2.426 (2.496)	Data 0.138 (0.179)	Loss 0.6806 (0.7164)	Prec@1 55.000 (51.000)	Prec@5 100.000 (100.000)
Epoch: [3][25/42]	Time 2.543 (2.498)	Data 0.167 (0.178)	Loss 0.6647 (0.7145)	Prec@1 60.000 (51.346)	Prec@5 100.000 (100.000)
Epoch: [3][26/42]	Time 2.490 (2.498)	Data 0.180 (0.178)	Loss 0.6969 (0.7138)	Prec@1 50.000 (51.296)	Prec@5 100.000 (100.000)
Epoch: [3][27/42]	Time 2.450 (2.496)	Data 0.147 (0.177)	Loss 0.6805 (0.7126)	Prec@1 50.000 (51.250)	Prec@5 100.000 (100.000)
Epoch: [3][28/42]	Time 2.471 (2.495)	Data 0.124 (0.176)	Loss 0.7088 (0.7125)	Prec@1 55.000 (51.379)	Prec@5 100.000 (100.000)


Epoch: [4][36/42]	Time 3.072 (3.387)	Data 0.248 (0.438)	Loss 0.6483 (0.7091)	Prec@1 75.000 (50.270)	Prec@5 100.000 (100.000)
Epoch: [4][37/42]	Time 3.616 (3.393)	Data 0.418 (0.438)	Loss 0.6912 (0.7086)	Prec@1 60.000 (50.526)	Prec@5 100.000 (100.000)
Epoch: [4][38/42]	Time 4.298 (3.416)	Data 0.296 (0.434)	Loss 0.7102 (0.7087)	Prec@1 50.000 (50.513)	Prec@5 100.000 (100.000)
Epoch: [4][39/42]	Time 3.777 (3.425)	Data 0.361 (0.432)	Loss 0.6848 (0.7081)	Prec@1 55.000 (50.625)	Prec@5 100.000 (100.000)
Epoch: [4][40/42]	Time 4.126 (3.442)	Data 0.591 (0.436)	Loss 0.7731 (0.7096)	Prec@1 45.000 (50.488)	Prec@5 100.000 (100.000)
Epoch: [4][41/42]	Time 1.805 (3.403)	Data 0.113 (0.429)	Loss 0.7609 (0.7101)	Prec@1 37.500 (50.362)	Prec@5 100.000 (100.000)
Test: [0/11]	Time 1.522 (1.522)	Loss 0.7447 (0.7447)	Prec@1 35.000 (35.000)	Prec@5 100.000 (100.000)
Test: [1/11]	Time 1.389 (1.456)	Loss 0.7099 (0.7273)	Prec@1 50.000 (42.500)	Prec@5 100.000 (100.000)
Test: [2/11]	Time 2.763 (1.891)	Loss 0.7214 (0.7