
**Install requirements**

In [0]:
!pip3 install 'torch==1.3.1'
!pip3 install 'torchvision==0.4.2'
!pip3 install 'Pillow-SIMD'
!pip3 install 'tqdm'



**Import libraries**

In [0]:
import os
import logging
import copy 
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn
from torch.autograd import Function

import torchvision
from torchvision import transforms
from torchvision.models import alexnet
from torchvision.datasets import VisionDataset

from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import time
import random

# **Set Arguments**

In [0]:
DATA_DIR = 'AIML_project/dataset1'
NUM_CLASSES = 4

####**Define Data Preprocessing**

In [0]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
 
train_transform = transforms.Compose([transforms.Resize(256),      # Resizes short size of the PIL image to 256
                                      transforms.CenterCrop(224),  # Crops a central square patch of the image
                                                                  # 224 because torchvision's AlexNet needs a 224x224 input!
                                                                  # Remember this when applying different transformations, otherwise you get an error
                                      transforms.ToTensor(), # Turn PIL Image to torch.Tensor
                                      transforms.Normalize(mean, std)# Normalizes tensor with mean and standard deviation
                                        
])

# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std)                                    
])

####**class PACS**

In [0]:
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class PACS(VisionDataset):
  def __init__(self, root, split='train', transform=None, target_transform=None, loader=pil_loader):
        super(PACS, self).__init__(root, transform=transform, target_transform=target_transform)
        
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.classes, self.class_to_idx = self._find_classes(self.root)
        self.images = self.make_dataset(DATA_DIR,self.class_to_idx)

  def _find_classes(self, dir):
        
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

  def make_dataset(self, dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, dirs, _ in sorted(os.walk(d)):
            for i in sorted(dirs):
                path = os.path.join(root, i)
                item = (path, class_to_idx[target])
                images.append(item)

    return images  # contiene i path delle cartelle contenenti le quaterne. Ogni path è associato ad un'etichetta che indica la posizione dell'odd

  def __getitem__(self, index):

        quad = []
        sample_dir, label = self.images[index]
        for img in sorted(os.listdir(sample_dir)):
          image_path = os.path.join(sample_dir, img) 
          quad.append( self.loader(image_path))

        # Applies preprocessing when accessing the image
        if self.transform is not None:
          for i in range(4):
              quad[i] = self.transform(quad[i])
        if self.target_transform is not None:
            label = self.target_transform(label)

        return quad, label  

  def __len__(self):
        length = len(self.images)
        return length  

####**OOONet**

In [0]:
class ConcatLayer(Function):
    @staticmethod
    def forward(ctx, fc6_1 ,fc6_2, fc6_3 ,fc6_4):
        print(fc6_1.shape,fc6_2.shape,fc6_3.shape,fc6_4.shape )
        concatenation = torch.cat([fc6_1, fc6_2,fc6_3,fc6_4], dim=1) # esempio (3,4) (3,4)  ---> ( 2, 3, 4 )  oppure ( 4096 )( 4096 ) --> (2,4096)
        #return concatenation.view_as(concatenation)
        print(concatenation.shape)
        return concatenation
        
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class OOONet(nn.Module):

    def __init__(self, num_classes=1000):
        super(OOONet, self).__init__()
        
        self.branch1_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((6, 6)),           
        )
        self.branch1_2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),                   #FC 6 
            nn.ReLU(inplace=True),
        )
        self.branch2_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((6, 6)),
        )
        self.branch2_2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),                   #FC 6 
            nn.ReLU(inplace=True),
        )
        self.branch3_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((6, 6)),   
        )
        self.branch3_2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),                   #FC 6 
            nn.ReLU(inplace=True),
        )
        self.branch4_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  #CONV 5 
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.AdaptiveAvgPool2d((6, 6)),
               
        )
        self.branch4_2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),                   #FC 6 
            nn.ReLU(inplace=True),
        )        
        #self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
     
        #Livelli di fusione!!!!!!!
        self.concatLayer = ConcatLayer()

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(4096,16384),            # 16384 da 4096
            nn.ReLU(inplace=True),
            nn.Linear(4096, NUM_CLASSES),
        )
        
    
    def forward(self, x):
        #bisogna definire X perchè dovrebbe essere un vettore di 4 immagini.
        x0 = self.branch1_1(x[0])
        #x0 = self.avgpool(x0)
        x0 = torch.flatten(x0, 1)
        output_branch1 = self.branch1_2(x0)

        x1 = self.branch2_1(x[1])
        #x1 = self.avgpool(x1)
        x1 = torch.flatten(x1, 1)
        output_branch2 = self.branch2_2(x1)

        x2 = self.branch3_1(x[2])
        #x2 = self.avgpool(x2)
        x2 = torch.flatten(x2, 1)
        output_branch3 = self.branch3_2(x2)

        x3 = self.branch4_1(x[3])
        #x3 = self.avgpool(x3)
        x3 = torch.flatten(x3, 1)
        output_branch4 = self.branch4_2(x3)

        out = self.concatLayer.apply(output_branch1,output_branch2,output_branch3,output_branch4) 
        #out = torch.flatten(out,1)
        print("out")
        print(out.shape)
        out = self.classifier(out)
      
        
        return out

def buildO3Net ():

    model = alexnet(pretrained=True)

    net =  OOONet()
    #0,3,6,8,10,14

    #DEEP COPY FEATURES OF BRANCH 1

    net.branch1_1[0].weight.data = copy.deepcopy(model.features[0].weight.data)
    net.branch1_1[0].bias.data = copy.deepcopy(model.features[0].bias.data)   
    net.branch1_1[3].weight.data = copy.deepcopy(model.features[3].weight.data)
    net.branch1_1[3].bias.data = copy.deepcopy(model.features[3].bias.data)
    net.branch1_1[6].weight.data = copy.deepcopy(model.features[6].weight.data)
    net.branch1_1[6].bias.data = copy.deepcopy(model.features[6].bias.data)
    net.branch1_1[8].weight.data = copy.deepcopy(model.features[8].weight.data)
    net.branch1_1[8].bias.data = copy.deepcopy(model.features[8].bias.data)
    net.branch1_1[10].weight.data = copy.deepcopy(model.features[10].weight.data)
    net.branch1_1[10].bias.data = copy.deepcopy(model.features[10].bias.data)
    net.branch1_2[1].weight.data = copy.deepcopy(model.classifier[1].weight.data)
    net.branch1_2[1].bias.data = copy.deepcopy(model.classifier[1].bias.data)


    #DEEP COPY FEATURES OF BRANCH 2

    net.branch2_1[0].weight.data = copy.deepcopy(model.features[0].weight.data)
    net.branch2_1[0].bias.data = copy.deepcopy(model.features[0].bias.data)   
    net.branch2_1[3].weight.data = copy.deepcopy(model.features[3].weight.data)
    net.branch2_1[3].bias.data = copy.deepcopy(model.features[3].bias.data)
    net.branch2_1[6].weight.data = copy.deepcopy(model.features[6].weight.data)
    net.branch2_1[6].bias.data = copy.deepcopy(model.features[6].bias.data)
    net.branch2_1[8].weight.data = copy.deepcopy(model.features[8].weight.data)
    net.branch2_1[8].bias.data = copy.deepcopy(model.features[8].bias.data)
    net.branch2_1[10].weight.data = copy.deepcopy(model.features[10].weight.data)
    net.branch2_1[10].bias.data = copy.deepcopy(model.features[10].bias.data)
    net.branch2_2[1].weight.data = copy.deepcopy(model.classifier[1].weight.data)
    net.branch2_2[1].bias.data = copy.deepcopy(model.classifier[1].bias.data)

    #DEEP COPY FEATURES OF BRANCH 3

    net.branch3_1[0].weight.data = copy.deepcopy(model.features[0].weight.data)
    net.branch3_1[0].bias.data = copy.deepcopy(model.features[0].bias.data)   
    net.branch3_1[3].weight.data = copy.deepcopy(model.features[3].weight.data)
    net.branch3_1[3].bias.data = copy.deepcopy(model.features[3].bias.data)
    net.branch3_1[6].weight.data = copy.deepcopy(model.features[6].weight.data)
    net.branch3_1[6].bias.data = copy.deepcopy(model.features[6].bias.data)
    net.branch3_1[8].weight.data = copy.deepcopy(model.features[8].weight.data)
    net.branch3_1[8].bias.data = copy.deepcopy(model.features[8].bias.data)
    net.branch3_1[10].weight.data = copy.deepcopy(model.features[10].weight.data)
    net.branch3_1[10].bias.data = copy.deepcopy(model.features[10].bias.data)
    net.branch3_2[1].weight.data = copy.deepcopy(model.classifier[1].weight.data)
    net.branch3_2[1].bias.data = copy.deepcopy(model.classifier[1].bias.data)

    #DEEP COPY FEATURES OF BRANCH 4

    net.branch4_1[0].weight.data = copy.deepcopy(model.features[0].weight.data)
    net.branch4_1[0].bias.data = copy.deepcopy(model.features[0].bias.data)   
    net.branch4_1[3].weight.data = copy.deepcopy(model.features[3].weight.data)
    net.branch4_1[3].bias.data = copy.deepcopy(model.features[3].bias.data)
    net.branch4_1[6].weight.data = copy.deepcopy(model.features[6].weight.data)
    net.branch4_1[6].bias.data = copy.deepcopy(model.features[6].bias.data)
    net.branch4_1[8].weight.data = copy.deepcopy(model.features[8].weight.data)
    net.branch4_1[8].bias.data = copy.deepcopy(model.features[8].bias.data)
    net.branch4_1[10].weight.data = copy.deepcopy(model.features[10].weight.data)
    net.branch4_1[10].bias.data = copy.deepcopy(model.features[10].bias.data)
    net.branch4_2[1].weight.data = copy.deepcopy(model.classifier[1].weight.data)
    net.branch4_2[1].bias.data = copy.deepcopy(model.classifier[1].bias.data)

    #DEEP COPY OF LAST TWO FC LAYERS

    net.classifier[1].weight.data = copy.deepcopy(model.classifier[4].weight.data)
    net.classifier[1].bias.data = copy.deepcopy(model.classifier[4].bias.data)
   
    return net



####**Prepare Dataset**

In [0]:
# Clone github repository with data
if not os.path.isdir('./AIML_project'):
  !git clone https://github.com/rebeccapelaca/AIML_project.git

dataset = PACS(DATA_DIR, transform=train_transform)
print('Dataset: {}'.format(len(dataset)))

train_indexes = [idx for idx in range(len(dataset)) if (idx % 3) == 1]
val_indexes = [idx for idx in range(len(dataset)) if (idx % 3) == 2]
test_indexes = [idx for idx in range(len(dataset)) if not (idx % 3)]

train_dataset = Subset(dataset, train_indexes)
val_dataset = Subset(dataset, val_indexes)
test_dataset = Subset(dataset, test_indexes)

print('Training Dataset: {}'.format(len(train_dataset)))
print('Validation Dataset: {}'.format(len(val_dataset)))
print('Test Dataset: {}'.format(len(test_dataset)))


Dataset: 14064
Training Dataset: 4688
Validation Dataset: 4688
Test Dataset: 4688


####**Dataloaders**

In [0]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True, num_workers=4, drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=256, shuffle=True, num_workers=4, drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=True, num_workers=4, drop_last=True)

####**Set Arguments**

In [0]:
DEVICE = 'cuda' # 'cuda' or 'cpu'
 

#BATCH_SIZE = 256     # Batch size will be chosen through a grid search
LR = 1e-3            # Learning rate will be chosen through a grid search

MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD
WEIGHT_DECAY = 5e-5  # Regularization, you can keep this at the default

#NUM_EPOCHS = 30      # Total number of training epochs (iterations over dataset)
STEP_SIZE = 20       # How many epochs before decreasing learning rate (if using a step-down policy)
GAMMA = 0.1          # Multiplicative factor for learning rate step-down

LOG_FREQUENCY = 10

#Grid search parameters
lrates=[1e-4,1e-3]       #tried 1e-6 (too low),1e-5(too low),1e-4,3e-4,7e-4,1e-3(too high with Adam),1e-2(too high)
batch_sizes=[192,256]   #128(too small),192,256,320,384(too big)
NUM_EPOCHS=[30] #30 is the best

####**Train**

In [0]:
net = buildO3Net()
net.to(DEVICE)
criterion = nn.CrossEntropyLoss() # for classification, we use Cross Entropy
parameters_to_optimize = net.parameters()
optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

num_epochs = 30


for epoch in range(num_epochs):
    print('Starting epoch {}/{}, LR = {}'.format(epoch+1, num_epochs, scheduler.get_lr()))

    for images, labels in train_dataloader:
      # Bring data over the device of choice
      for i in range(4):
        images[i] = images[i].to(DEVICE)
      
      labels = labels.to(DEVICE)

      net.train() # Sets module in training mode

      # PyTorch, by default, accumulates gradients after each backward pass
      # We need to manually set the gradients to zero before starting a new iteration
      optimizer.zero_grad() # Zero-ing the gradients

      # Forward pass to the network
      outputs = net(images)

      # Compute loss based on output and ground truth
      loss = criterion(outputs, labels)
      print("loss:{}".format(loss.item()))

      # Compute gradients for each layer and update weights
      loss.backward()  # backward pass: computes gradients
      optimizer.step() # update weights based on accumulated gradients

      current_step += 1
    
   
    

Starting epoch 1/30, LR = [0.001]
torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([256, 4096])
torch.Size([256, 16384])
out
torch.Size([256, 16384])


RuntimeError: ignored

###**Test**

In [0]:
net = net.to(DEVICE) 
net.train(False) # Set Network to evaluation mode , equivalent to net.eval()
#test_dataloader = DataLoader(test_dataset, batch_size=best_batch, shuffle=True, num_workers=4)

running_corrects = 0
for images, labels in tqdm(test_dataloader): #evaluate performance on validation set
  images = images.to(DEVICE)
  labels = labels.to(DEVICE)

  # Forward Pass
  outputs = net(images)
  # Get predictions
  _, preds = torch.max(outputs.data, 1)

  # Update Corrects
  running_corrects += torch.sum(preds == labels.data).data.item()

# Calculate Accuracy on validation set
accuracy = running_corrects / float(len(test_dataset))


print('Test Accuracy: {}'.format(accuracy))

