In [4]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset  # For custom datasets
from torch.utils.data.sampler import SubsetRandomSampler
import pandas as pd
import numpy as np
from PIL import Image


# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
validation_split = .2
shuffle_dataset = False
random_seed= 42
num_epochs = 10
num_classes = 2
batch_size = 20
learning_rate = 0.0001
weight_decay = 0.0001
momentum = 0.9
print_freq = 1
best_prec1 = 0
workers = 8

dir = '/home/suraj/asd-abide-prediction-pytorch/'

#dataset

data = pd.read_csv(dir + 'data.csv')

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(dir + 'png-i/' + self.data_info.iloc[:, 6] + '_alff.nii.jpg')
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 7])
        # Third column is for an operation indicator
        # self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        #some_operation = self.operation_arr[index]
        # If there is an operation
        #if some_operation:
            # Do some operation on image
            # ...
            # ...
        #    pass
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":

    dataset_train = CustomDatasetFromImages(dir + 'train.csv')
    dataset_test = CustomDatasetFromImages(dir + 'test.csv')
    
# Data loader
#train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
#                                          batch_size=batch_size, 
#                                         shuffle=True)

#test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                          batch_size=batch_size, 
#                                          shuffle=False)



train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, 
                                           num_workers=workers, pin_memory=True)
validation_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True,
                                                num_workers=workers, pin_memory=True)


# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc1 = nn.Linear(51168, 50)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        return out
    

    

model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        labels_ = []
        for label in labels:
            labels_.append(int(label))
        labels = torch.LongTensor(labels_)
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % print_freq == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            

# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in validation_loader:
        images = images.to(device)
        labels_ = []
        for label in labels:
            labels_.append(int(label))
        labels = torch.LongTensor(labels_)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print('Test Accuracy: {} %'.format(100 * correct / total))



    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

Linear(in_features=51168, out_features=50, bias=True)
Epoch [1/10], Step [1/42], Loss: 1.1398
Epoch [1/10], Step [2/42], Loss: 1.0294
Epoch [1/10], Step [3/42], Loss: 0.6730
Epoch [1/10], Step [4/42], Loss: 1.8267
Epoch [1/10], Step [5/42], Loss: 0.9799
Epoch [1/10], Step [6/42], Loss: 1.6394
Epoch [1/10], Step [7/42], Loss: 0.7400
Epoch [1/10], Step [8/42], Loss: 0.7718
Epoch [1/10], Step [9/42], Loss: 0.8776
Epoch [1/10], Step [10/42], Loss: 0.9770
Epoch [1/10], Step [11/42], Loss: 0.8009
Epoch [1/10], Step [12/42], Loss: 0.6705
Epoch [1/10], Step [13/42], Loss: 0.5492
Epoch [1/10], Step [14/42], Loss: 0.6947
Epoch [1/10], Step [15/42], Loss: 0.6817
Epoch [1/10], Step [16/42], Loss: 0.7329
Epoch [1/10], Step [17/42], Loss: 0.7681
Epoch [1/10], Step [18/42], Loss: 0.6908
Epoch [1/10], Step [19/42], Loss: 0.7395
Epoch [1/10], Step [20/42], Loss: 0.6937
Epoch [1/10], Step [21/42], Loss: 0.8146
Epoch [1/10], Step [22/42], Loss: 0.6636
Epoch [1/10], Step [23/42], Loss: 1.1862
Epoch [1/10]

Process Process-32:
Process Process-31:
Process Process-30:
Process Process-26:
Process Process-25:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Process Process-29:
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/h

Epoch [1/10], Step [38/42], Loss: 0.9025


Process Process-27:
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/suraj/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)


KeyboardInterrupt: 

Traceback (most recent call last):
  File "/home/suraj/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/home/suraj/anaconda3/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/suraj/anaconda3/lib/pyt