In [1]:
import os
import random
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm

import pandas as pd
from pandas import DataFrame
from sklearn.model_selection import train_test_split

import torch
import torchvision
from torch import nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [21]:
# MEAN + STD for centering

"""
directory = 'food'

mean = [0, 0, 0]
std = [0, 0, 0]

total_images = len(os.listdir(directory))
for filename in os.listdir(directory):
    if filename.endswith(".jpg"):
        path = (os.path.join(directory, filename))
        image = np.array(Image.open(path).resize((224,224), Image.ANTIALIAS), dtype = np.float32) / 255.0
        image = image.transpose(2,0,1)
        mean[0] += (np.mean(image[0]) / total_images)
        mean[1] += (np.mean(image[1]) / total_images)
        mean[2] += (np.mean(image[2]) / total_images)
        std[0] += (np.std(image[0]) / total_images)
        std[1] += (np.std(image[1]) / total_images)
        std[2] += (np.std(image[2]) / total_images)
"""

mean = [0.6080600980066939, 0.5160828748897249, 0.4123187661759489]
std = [0.22611850065632327, 0.24211910331997763, 0.2603793200133369]

print(mean, std)

mean, std = np.array(mean), np.array(std)
mean.resize((3,1,1))
std.resize((3,1,1))

[0.6080600980066939, 0.5160828748897249, 0.4123187661759489] [0.22611850065632327, 0.24211910331997763, 0.2603793200133369]


In [22]:
class ProjectDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, dataframe, preload=False, prediction=False, random_permutations=True):
        self.triplets = dataframe
        self.root_dir = "food"
        self.img_len = 224
        self.images = dict()
        self.preload = preload
        self.prediction = prediction
        self.random_permutations = random_permutations
        if preload:
            for idx in range(len(self.triplets)):
                for col in range(3):
                    name = self.triplets.iloc[idx, col]
                    if name not in self.images:
                        path = os.path.join(self.root_dir, name)
                        image = np.array(Image.open(path).resize((self.img_len,self.img_len), Image.ANTIALIAS)) / 255.0
                        image = image.transpose(2,0,1)
                        image = (image - mean) / std
                        image = np.array(image, dtype=np.float32)
                        self.images[name] = image

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        if not self.preload:
            img_name1 = os.path.join(self.root_dir, self.triplets.iloc[idx, 0])
            img_name2 = os.path.join(self.root_dir, self.triplets.iloc[idx, 1])
            img_name3 = os.path.join(self.root_dir, self.triplets.iloc[idx, 2])

            image1 = np.array(Image.open(img_name1).resize((self.img_len,self.img_len), Image.ANTIALIAS), dtype = np.float32) / 255.0
            image2 = np.array(Image.open(img_name2).resize((self.img_len,self.img_len), Image.ANTIALIAS), dtype = np.float32) / 255.0
            image3 = np.array(Image.open(img_name3).resize((self.img_len,self.img_len), Image.ANTIALIAS), dtype = np.float32) / 255.0
        
            image1 = image1.transpose(2,0,1)
            image2 = image2.transpose(2,0,1)
            image3 = image3.transpose(2,0,1)
            
            image1 = (image1 - mean) / std
            image2 = (image2 - mean) / std
            image3 = (image3 - mean) / std
            
            image1 = np.array(image1, dtype=np.float32)
            image2 = np.array(image2, dtype=np.float32)
            image3 = np.array(image3, dtype=np.float32)
        else:
            image1 = self.images[self.triplets.iloc[idx, 0]]
            image2 = self.images[self.triplets.iloc[idx, 1]]
            image3 = self.images[self.triplets.iloc[idx, 2]]
        
        if self.random_permutations:
            # Random permutations (flip horizontally or vertically)
            if random.random() < 0.5:
                np.flip(image1, 1)
                
            if random.random() < 0.5:
                np.flip(image1, 2)
                
            if random.random() < 0.5:
                np.flip(image2, 1)
                
            if random.random() < 0.5:
                np.flip(image2, 2)
                
            if random.random() < 0.5:
                np.flip(image3, 1)
                
            if random.random() < 0.5:
                np.flip(image3, 2)
        
        if not self.prediction:
            order = random.random()
        else:
            order = 1.0
            
        sample = {'image1': image1,
                  'image2': image2 if order < 0.5 else image3,
                  'image3': image3 if order < 0.5 else image2,
                  'label': 0 if order < 0.5 else 1}

        return sample

In [23]:
class CNN1(nn.Module):

    def __init__(self, input_shape = (3, 224, 224), pretrained = False):

        super(CNN1, self).__init__()
        
        if not pretrained:
            self.CNN_layers = nn.Sequential(nn.Conv2d(input_shape[0], 16, kernel_size=4),
                                            nn.BatchNorm2d(16),
                                            nn.ReLU(),
                                            nn.MaxPool2d(kernel_size=2, stride=2),

                                            nn.Conv2d(16, 32, kernel_size=4),
                                            nn.BatchNorm2d(32),
                                            nn.ReLU(),
                                            nn.MaxPool2d(kernel_size=2, stride=2),

                                            nn.Conv2d(32, 32, kernel_size=4),
                                            nn.BatchNorm2d(32),
                                            nn.ReLU(),
                                            nn.MaxPool2d(kernel_size=2, stride=2),

                                            nn.Conv2d(32, 64, kernel_size=4),
                                            nn.BatchNorm2d(64),
                                            nn.ReLU(),
                                            nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            self.CNN_layers = torchvision.models.alexnet(pretrained=True).features[:16]
        
            # Freeze
            for name, param in self.CNN_layers.named_parameters():
                param.requires_grad = False
                        
                        
        cnn_output_size = self._get_conv_output(input_shape)
        
        self.output_layer= nn.Sequential(nn.Linear(cnn_output_size * 3, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.2),
                                         nn.Linear(256, 128),
                                         nn.ReLU(),
                                         nn.Dropout(0.2),
                                         nn.Linear(128, 32),
                                         nn.ReLU(),
                                         nn.Linear(32, 2))
        
        
    
    # Forward method
    def forward(self, x):
        
        in1 = x['image1']
        in2 = x['image2']
        in3 = x['image3']
        
        if torch.cuda.is_available():
            in1 = in1.cuda()
            in2 = in2.cuda()
            in3 = in3.cuda()
        
        # CNN layers
        out1 = self.CNN_layers(in1)
        out2 = self.CNN_layers(in2)
        out3 = self.CNN_layers(in3)

        # Flat
        out1 = out1.view(out1.size(0), -1)
        out2 = out2.view(out2.size(0), -1)
        out3 = out3.view(out3.size(0), -1)
        
        # Concat
        out = torch.cat((out1, out2, out3), 1)

        # Per-predicted-value predictions
        out = self.output_layer(out)
        return out

    # Generates input sample and forward to get shape
    def _get_conv_output(self, shape = (3, 224, 224)):
        my_input = Variable(torch.rand(1, *shape))
        output_feat = self.CNN_layers(my_input)
        n_size = output_feat.data.view(1, -1).size(1)
        return n_size

In [24]:
# Data - Loading triplets
all_triplets = pd.read_csv('train_triplets.txt', header=None, delim_whitespace=True, names=['A', 'B', 'C'], dtype=str)
all_triplets = all_triplets.apply(lambda name : name + '.jpg')
train_triplets, val_triplets = train_test_split(all_triplets, train_size=0.95, random_state=42)

In [25]:
# Dataset and Dataloader
dataset = ProjectDataset(train_triplets, preload=True)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)

# Network
net = CNN1(pretrained=True)

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

# use the modules apply function to recursively apply the initialization
net.apply(init_normal)

if torch.cuda.is_available():
    net = net.cuda()

# gradient clipping
clip_value = 3
for p in net.parameters():
    if p.requires_grad:
        p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

# Criterion and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [26]:
print(net)

CNN1(
  (CNN_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (output_layer): Sequential(
    (0): Linear(in_features=27648, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, ou

In [27]:
epochs = 16
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader, 0), total = len(dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        labels = data['label']
        if torch.cuda.is_available():
            labels = labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 8 == 7:    # print every 8 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 8))
            running_loss = 0.0

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[1,     8] loss: 1.298
[1,    16] loss: 0.946
[1,    24] loss: 0.869
[1,    32] loss: 0.792
[1,    40] loss: 0.774
[1,    48] loss: 0.759
[1,    56] loss: 0.743
[1,    64] loss: 0.725
[1,    72] loss: 0.724
[1,    80] loss: 0.729
[1,    88] loss: 0.718
[1,    96] loss: 0.710
[1,   104] loss: 0.701
[1,   112] loss: 0.716
[1,   120] loss: 0.708
[1,   128] loss: 0.705
[1,   136] loss: 0.704
[1,   144] loss: 0.705
[1,   152] loss: 0.692
[1,   160] loss: 0.703
[1,   168] loss: 0.701
[1,   176] loss: 0.699
[1,   184] loss: 0.702
[1,   192] loss: 0.703
[1,   200] loss: 0.702
[1,   208] loss: 0.706
[1,   216] loss: 0.698
[1,   224] loss: 0.702
[1,   232] loss: 0.688
[1,   240] loss: 0.693
[1,   248] loss: 0.688
[1,   256] loss: 0.699
[1,   264] loss: 0.697
[1,   272] loss: 0.685
[1,   280] loss: 0.683
[1,   288] loss: 0.699
[1,   296] loss: 0.696
[1,   304] loss: 0.683
[1,   312] loss: 0.689
[1,   320] loss: 0.691
[1,   328] loss: 0.680
[1,   336] loss: 0.672
[1,   344] loss: 0.691
[1,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[2,     8] loss: 0.673
[2,    16] loss: 0.674
[2,    24] loss: 0.669
[2,    32] loss: 0.677
[2,    40] loss: 0.669
[2,    48] loss: 0.664
[2,    56] loss: 0.667
[2,    64] loss: 0.663
[2,    72] loss: 0.663
[2,    80] loss: 0.666
[2,    88] loss: 0.660
[2,    96] loss: 0.659
[2,   104] loss: 0.665
[2,   112] loss: 0.670
[2,   120] loss: 0.657
[2,   128] loss: 0.658
[2,   136] loss: 0.655
[2,   144] loss: 0.666
[2,   152] loss: 0.667
[2,   160] loss: 0.655
[2,   168] loss: 0.655
[2,   176] loss: 0.654
[2,   184] loss: 0.648
[2,   192] loss: 0.657
[2,   200] loss: 0.649
[2,   208] loss: 0.654
[2,   216] loss: 0.634
[2,   224] loss: 0.653
[2,   232] loss: 0.652
[2,   240] loss: 0.650
[2,   248] loss: 0.676
[2,   256] loss: 0.655
[2,   264] loss: 0.647
[2,   272] loss: 0.655
[2,   280] loss: 0.641
[2,   288] loss: 0.647
[2,   296] loss: 0.645
[2,   304] loss: 0.637
[2,   312] loss: 0.639
[2,   320] loss: 0.625
[2,   328] loss: 0.650
[2,   336] loss: 0.637
[2,   344] loss: 0.625
[2,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[3,     8] loss: 0.614
[3,    16] loss: 0.620
[3,    24] loss: 0.627
[3,    32] loss: 0.607
[3,    40] loss: 0.628
[3,    48] loss: 0.619
[3,    56] loss: 0.598
[3,    64] loss: 0.633
[3,    72] loss: 0.605
[3,    80] loss: 0.625
[3,    88] loss: 0.618
[3,    96] loss: 0.619
[3,   104] loss: 0.609
[3,   112] loss: 0.614
[3,   120] loss: 0.611
[3,   128] loss: 0.609
[3,   136] loss: 0.602
[3,   144] loss: 0.592
[3,   152] loss: 0.601
[3,   160] loss: 0.635
[3,   168] loss: 0.618
[3,   176] loss: 0.620
[3,   184] loss: 0.604
[3,   192] loss: 0.598
[3,   200] loss: 0.605
[3,   208] loss: 0.633
[3,   216] loss: 0.597
[3,   224] loss: 0.612
[3,   232] loss: 0.609
[3,   240] loss: 0.613
[3,   248] loss: 0.605
[3,   256] loss: 0.600
[3,   264] loss: 0.645
[3,   272] loss: 0.610
[3,   280] loss: 0.614
[3,   288] loss: 0.614
[3,   296] loss: 0.608
[3,   304] loss: 0.572
[3,   312] loss: 0.592
[3,   320] loss: 0.592
[3,   328] loss: 0.587
[3,   336] loss: 0.591
[3,   344] loss: 0.602
[3,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[4,     8] loss: 0.599
[4,    16] loss: 0.618
[4,    24] loss: 0.593
[4,    32] loss: 0.595
[4,    40] loss: 0.594
[4,    48] loss: 0.578
[4,    56] loss: 0.582
[4,    64] loss: 0.583
[4,    72] loss: 0.588
[4,    80] loss: 0.595
[4,    88] loss: 0.593
[4,    96] loss: 0.589
[4,   104] loss: 0.589
[4,   112] loss: 0.578
[4,   120] loss: 0.580
[4,   128] loss: 0.588
[4,   136] loss: 0.565
[4,   144] loss: 0.592
[4,   152] loss: 0.585
[4,   160] loss: 0.595
[4,   168] loss: 0.590
[4,   176] loss: 0.590
[4,   184] loss: 0.568
[4,   192] loss: 0.572
[4,   200] loss: 0.587
[4,   208] loss: 0.589
[4,   216] loss: 0.554
[4,   224] loss: 0.581
[4,   232] loss: 0.572
[4,   240] loss: 0.605
[4,   248] loss: 0.556
[4,   256] loss: 0.597
[4,   264] loss: 0.573
[4,   272] loss: 0.587
[4,   280] loss: 0.573
[4,   288] loss: 0.608
[4,   296] loss: 0.573
[4,   304] loss: 0.575
[4,   312] loss: 0.574
[4,   320] loss: 0.553
[4,   328] loss: 0.579
[4,   336] loss: 0.562
[4,   344] loss: 0.560
[4,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[5,     8] loss: 0.569
[5,    16] loss: 0.552
[5,    24] loss: 0.540
[5,    32] loss: 0.573
[5,    40] loss: 0.566
[5,    48] loss: 0.574
[5,    56] loss: 0.556
[5,    64] loss: 0.565
[5,    72] loss: 0.544
[5,    80] loss: 0.552
[5,    88] loss: 0.561
[5,    96] loss: 0.568
[5,   104] loss: 0.548
[5,   112] loss: 0.554
[5,   120] loss: 0.541
[5,   128] loss: 0.562
[5,   136] loss: 0.554
[5,   144] loss: 0.539
[5,   152] loss: 0.573
[5,   160] loss: 0.573
[5,   168] loss: 0.560
[5,   176] loss: 0.535
[5,   184] loss: 0.556
[5,   192] loss: 0.523
[5,   200] loss: 0.537
[5,   208] loss: 0.551
[5,   216] loss: 0.556
[5,   224] loss: 0.547
[5,   232] loss: 0.567
[5,   240] loss: 0.543
[5,   248] loss: 0.543
[5,   256] loss: 0.576
[5,   264] loss: 0.556
[5,   272] loss: 0.555
[5,   280] loss: 0.562
[5,   288] loss: 0.552
[5,   296] loss: 0.557
[5,   304] loss: 0.546
[5,   312] loss: 0.548
[5,   320] loss: 0.544
[5,   328] loss: 0.545
[5,   336] loss: 0.548
[5,   344] loss: 0.551
[5,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[6,     8] loss: 0.526
[6,    16] loss: 0.518
[6,    24] loss: 0.515
[6,    32] loss: 0.532
[6,    40] loss: 0.541
[6,    48] loss: 0.536
[6,    56] loss: 0.517
[6,    64] loss: 0.559
[6,    72] loss: 0.528
[6,    80] loss: 0.529
[6,    88] loss: 0.499
[6,    96] loss: 0.530
[6,   104] loss: 0.535
[6,   112] loss: 0.522
[6,   120] loss: 0.523
[6,   128] loss: 0.524
[6,   136] loss: 0.526
[6,   144] loss: 0.537
[6,   152] loss: 0.502
[6,   160] loss: 0.524
[6,   168] loss: 0.544
[6,   176] loss: 0.530
[6,   184] loss: 0.537
[6,   192] loss: 0.536
[6,   200] loss: 0.515
[6,   208] loss: 0.521
[6,   216] loss: 0.508
[6,   224] loss: 0.532
[6,   232] loss: 0.536
[6,   240] loss: 0.528
[6,   248] loss: 0.519
[6,   256] loss: 0.545
[6,   264] loss: 0.510
[6,   272] loss: 0.536
[6,   280] loss: 0.514
[6,   288] loss: 0.521
[6,   296] loss: 0.536
[6,   304] loss: 0.524
[6,   312] loss: 0.526
[6,   320] loss: 0.533
[6,   328] loss: 0.514
[6,   336] loss: 0.531
[6,   344] loss: 0.512
[6,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[7,     8] loss: 0.498
[7,    16] loss: 0.494
[7,    24] loss: 0.495
[7,    32] loss: 0.495
[7,    40] loss: 0.487
[7,    48] loss: 0.496
[7,    56] loss: 0.485
[7,    64] loss: 0.494
[7,    72] loss: 0.492
[7,    80] loss: 0.506
[7,    88] loss: 0.517
[7,    96] loss: 0.513
[7,   104] loss: 0.531
[7,   112] loss: 0.504
[7,   120] loss: 0.508
[7,   128] loss: 0.492
[7,   136] loss: 0.472
[7,   144] loss: 0.516
[7,   152] loss: 0.512
[7,   160] loss: 0.494
[7,   168] loss: 0.462
[7,   176] loss: 0.503
[7,   184] loss: 0.504
[7,   192] loss: 0.490
[7,   200] loss: 0.489
[7,   208] loss: 0.469
[7,   216] loss: 0.493
[7,   224] loss: 0.479
[7,   232] loss: 0.487
[7,   240] loss: 0.515
[7,   248] loss: 0.504
[7,   256] loss: 0.489
[7,   264] loss: 0.518
[7,   272] loss: 0.470
[7,   280] loss: 0.480
[7,   288] loss: 0.495
[7,   296] loss: 0.517
[7,   304] loss: 0.468
[7,   312] loss: 0.505
[7,   320] loss: 0.475
[7,   328] loss: 0.471
[7,   336] loss: 0.480
[7,   344] loss: 0.465
[7,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[8,     8] loss: 0.474
[8,    16] loss: 0.467
[8,    24] loss: 0.475
[8,    32] loss: 0.466
[8,    40] loss: 0.433
[8,    48] loss: 0.421
[8,    56] loss: 0.468
[8,    64] loss: 0.483
[8,    72] loss: 0.486
[8,    80] loss: 0.461
[8,    88] loss: 0.468
[8,    96] loss: 0.459
[8,   104] loss: 0.455
[8,   112] loss: 0.473
[8,   120] loss: 0.458
[8,   128] loss: 0.460
[8,   136] loss: 0.449
[8,   144] loss: 0.465
[8,   152] loss: 0.461
[8,   160] loss: 0.486
[8,   168] loss: 0.445
[8,   176] loss: 0.456
[8,   184] loss: 0.451
[8,   192] loss: 0.436
[8,   200] loss: 0.474
[8,   208] loss: 0.464
[8,   216] loss: 0.471
[8,   224] loss: 0.482
[8,   232] loss: 0.512
[8,   240] loss: 0.459
[8,   248] loss: 0.461
[8,   256] loss: 0.463
[8,   264] loss: 0.444
[8,   272] loss: 0.461
[8,   280] loss: 0.465
[8,   288] loss: 0.471
[8,   296] loss: 0.440
[8,   304] loss: 0.448
[8,   312] loss: 0.469
[8,   320] loss: 0.427
[8,   328] loss: 0.482
[8,   336] loss: 0.469
[8,   344] loss: 0.466
[8,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[9,     8] loss: 0.460
[9,    16] loss: 0.447
[9,    24] loss: 0.418
[9,    32] loss: 0.435
[9,    40] loss: 0.444
[9,    48] loss: 0.444
[9,    56] loss: 0.424
[9,    64] loss: 0.421
[9,    72] loss: 0.422
[9,    80] loss: 0.424
[9,    88] loss: 0.425
[9,    96] loss: 0.418
[9,   104] loss: 0.437
[9,   112] loss: 0.409
[9,   120] loss: 0.472
[9,   128] loss: 0.409
[9,   136] loss: 0.431
[9,   144] loss: 0.444
[9,   152] loss: 0.417
[9,   160] loss: 0.442
[9,   168] loss: 0.441
[9,   176] loss: 0.429
[9,   184] loss: 0.434
[9,   192] loss: 0.411
[9,   200] loss: 0.455
[9,   208] loss: 0.437
[9,   216] loss: 0.435
[9,   224] loss: 0.473
[9,   232] loss: 0.406
[9,   240] loss: 0.455
[9,   248] loss: 0.423
[9,   256] loss: 0.455
[9,   264] loss: 0.429
[9,   272] loss: 0.423
[9,   280] loss: 0.446
[9,   288] loss: 0.450
[9,   296] loss: 0.419
[9,   304] loss: 0.434
[9,   312] loss: 0.435
[9,   320] loss: 0.427
[9,   328] loss: 0.443
[9,   336] loss: 0.447
[9,   344] loss: 0.414
[9,   352] 

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[10,     8] loss: 0.426
[10,    16] loss: 0.406
[10,    24] loss: 0.415
[10,    32] loss: 0.382
[10,    40] loss: 0.406
[10,    48] loss: 0.394
[10,    56] loss: 0.396
[10,    64] loss: 0.400
[10,    72] loss: 0.454
[10,    80] loss: 0.411
[10,    88] loss: 0.411
[10,    96] loss: 0.404
[10,   104] loss: 0.401
[10,   112] loss: 0.398
[10,   120] loss: 0.392
[10,   128] loss: 0.370
[10,   136] loss: 0.390
[10,   144] loss: 0.388
[10,   152] loss: 0.390
[10,   160] loss: 0.381
[10,   168] loss: 0.403
[10,   176] loss: 0.365
[10,   184] loss: 0.447
[10,   192] loss: 0.397
[10,   200] loss: 0.404
[10,   208] loss: 0.427
[10,   216] loss: 0.432
[10,   224] loss: 0.414
[10,   232] loss: 0.396
[10,   240] loss: 0.402
[10,   248] loss: 0.426
[10,   256] loss: 0.398
[10,   264] loss: 0.371
[10,   272] loss: 0.383
[10,   280] loss: 0.386
[10,   288] loss: 0.396
[10,   296] loss: 0.406
[10,   304] loss: 0.415
[10,   312] loss: 0.403
[10,   320] loss: 0.373
[10,   328] loss: 0.415
[10,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[11,     8] loss: 0.407
[11,    16] loss: 0.379
[11,    24] loss: 0.389
[11,    32] loss: 0.398
[11,    40] loss: 0.384
[11,    48] loss: 0.354
[11,    56] loss: 0.370
[11,    64] loss: 0.389
[11,    72] loss: 0.368
[11,    80] loss: 0.369
[11,    88] loss: 0.353
[11,    96] loss: 0.343
[11,   104] loss: 0.387
[11,   112] loss: 0.368
[11,   120] loss: 0.368
[11,   128] loss: 0.383
[11,   136] loss: 0.365
[11,   144] loss: 0.373
[11,   152] loss: 0.377
[11,   160] loss: 0.358
[11,   168] loss: 0.365
[11,   176] loss: 0.367
[11,   184] loss: 0.360
[11,   192] loss: 0.384
[11,   200] loss: 0.367
[11,   208] loss: 0.382
[11,   216] loss: 0.399
[11,   224] loss: 0.390
[11,   232] loss: 0.385
[11,   240] loss: 0.357
[11,   248] loss: 0.366
[11,   256] loss: 0.402
[11,   264] loss: 0.404
[11,   272] loss: 0.389
[11,   280] loss: 0.360
[11,   288] loss: 0.362
[11,   296] loss: 0.373
[11,   304] loss: 0.366
[11,   312] loss: 0.375
[11,   320] loss: 0.372
[11,   328] loss: 0.379
[11,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[12,     8] loss: 0.344
[12,    16] loss: 0.358
[12,    24] loss: 0.342
[12,    32] loss: 0.359
[12,    40] loss: 0.352
[12,    48] loss: 0.344
[12,    56] loss: 0.359
[12,    64] loss: 0.339
[12,    72] loss: 0.319
[12,    80] loss: 0.352
[12,    88] loss: 0.349
[12,    96] loss: 0.354
[12,   104] loss: 0.342
[12,   112] loss: 0.375
[12,   120] loss: 0.350
[12,   128] loss: 0.358
[12,   136] loss: 0.364
[12,   144] loss: 0.341
[12,   152] loss: 0.321
[12,   160] loss: 0.357
[12,   168] loss: 0.324
[12,   176] loss: 0.357
[12,   184] loss: 0.345
[12,   192] loss: 0.338
[12,   200] loss: 0.337
[12,   208] loss: 0.349
[12,   216] loss: 0.327
[12,   224] loss: 0.367
[12,   232] loss: 0.352
[12,   240] loss: 0.338
[12,   248] loss: 0.351
[12,   256] loss: 0.323
[12,   264] loss: 0.358
[12,   272] loss: 0.334
[12,   280] loss: 0.328
[12,   288] loss: 0.337
[12,   296] loss: 0.362
[12,   304] loss: 0.374
[12,   312] loss: 0.365
[12,   320] loss: 0.352
[12,   328] loss: 0.352
[12,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[13,     8] loss: 0.304
[13,    16] loss: 0.306
[13,    24] loss: 0.327
[13,    32] loss: 0.307
[13,    40] loss: 0.288
[13,    48] loss: 0.350
[13,    56] loss: 0.299
[13,    64] loss: 0.321
[13,    72] loss: 0.324
[13,    80] loss: 0.344
[13,    88] loss: 0.358
[13,    96] loss: 0.347
[13,   104] loss: 0.319
[13,   112] loss: 0.319
[13,   120] loss: 0.361
[13,   128] loss: 0.351
[13,   136] loss: 0.328
[13,   144] loss: 0.371
[13,   152] loss: 0.335
[13,   160] loss: 0.324
[13,   168] loss: 0.312
[13,   176] loss: 0.356
[13,   184] loss: 0.320
[13,   192] loss: 0.314
[13,   200] loss: 0.353
[13,   208] loss: 0.320
[13,   216] loss: 0.305
[13,   224] loss: 0.347
[13,   232] loss: 0.324
[13,   240] loss: 0.324
[13,   248] loss: 0.310
[13,   256] loss: 0.299
[13,   264] loss: 0.322
[13,   272] loss: 0.356
[13,   280] loss: 0.310
[13,   288] loss: 0.324
[13,   296] loss: 0.322
[13,   304] loss: 0.340
[13,   312] loss: 0.342
[13,   320] loss: 0.341
[13,   328] loss: 0.309
[13,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[14,     8] loss: 0.282
[14,    16] loss: 0.301
[14,    24] loss: 0.323
[14,    32] loss: 0.303
[14,    40] loss: 0.260
[14,    48] loss: 0.297
[14,    56] loss: 0.276
[14,    64] loss: 0.278
[14,    72] loss: 0.300
[14,    80] loss: 0.295
[14,    88] loss: 0.301
[14,    96] loss: 0.264
[14,   104] loss: 0.284
[14,   112] loss: 0.304
[14,   120] loss: 0.292
[14,   128] loss: 0.306
[14,   136] loss: 0.294
[14,   144] loss: 0.299
[14,   152] loss: 0.302
[14,   160] loss: 0.321
[14,   168] loss: 0.312
[14,   176] loss: 0.300
[14,   184] loss: 0.286
[14,   192] loss: 0.279
[14,   200] loss: 0.289
[14,   208] loss: 0.278
[14,   216] loss: 0.304
[14,   224] loss: 0.312
[14,   232] loss: 0.333
[14,   240] loss: 0.319
[14,   248] loss: 0.326
[14,   256] loss: 0.304
[14,   264] loss: 0.295
[14,   272] loss: 0.292
[14,   280] loss: 0.316
[14,   288] loss: 0.311
[14,   296] loss: 0.327
[14,   304] loss: 0.305
[14,   312] loss: 0.281
[14,   320] loss: 0.314
[14,   328] loss: 0.296
[14,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[15,     8] loss: 0.284
[15,    16] loss: 0.250
[15,    24] loss: 0.257
[15,    32] loss: 0.260
[15,    40] loss: 0.276
[15,    48] loss: 0.287
[15,    56] loss: 0.324
[15,    64] loss: 0.296
[15,    72] loss: 0.285
[15,    80] loss: 0.290
[15,    88] loss: 0.262
[15,    96] loss: 0.254
[15,   104] loss: 0.267
[15,   112] loss: 0.251
[15,   120] loss: 0.299
[15,   128] loss: 0.261
[15,   136] loss: 0.293
[15,   144] loss: 0.267
[15,   152] loss: 0.267
[15,   160] loss: 0.267
[15,   168] loss: 0.265
[15,   176] loss: 0.278
[15,   184] loss: 0.284
[15,   192] loss: 0.264
[15,   200] loss: 0.245
[15,   208] loss: 0.294
[15,   216] loss: 0.275
[15,   224] loss: 0.285
[15,   232] loss: 0.286
[15,   240] loss: 0.279
[15,   248] loss: 0.267
[15,   256] loss: 0.290
[15,   264] loss: 0.272
[15,   272] loss: 0.262
[15,   280] loss: 0.294
[15,   288] loss: 0.264
[15,   296] loss: 0.264
[15,   304] loss: 0.273
[15,   312] loss: 0.266
[15,   320] loss: 0.261
[15,   328] loss: 0.292
[15,   336] loss

HBox(children=(FloatProgress(value=0.0, max=442.0), HTML(value='')))

[16,     8] loss: 0.263
[16,    16] loss: 0.275
[16,    24] loss: 0.234
[16,    32] loss: 0.201
[16,    40] loss: 0.226
[16,    48] loss: 0.254
[16,    56] loss: 0.249
[16,    64] loss: 0.232
[16,    72] loss: 0.276
[16,    80] loss: 0.267
[16,    88] loss: 0.264
[16,    96] loss: 0.256
[16,   104] loss: 0.252
[16,   112] loss: 0.266
[16,   120] loss: 0.256
[16,   128] loss: 0.269
[16,   136] loss: 0.248
[16,   144] loss: 0.213
[16,   152] loss: 0.267
[16,   160] loss: 0.225
[16,   168] loss: 0.233
[16,   176] loss: 0.228
[16,   184] loss: 0.227
[16,   192] loss: 0.276
[16,   200] loss: 0.254
[16,   208] loss: 0.241
[16,   216] loss: 0.254
[16,   224] loss: 0.257
[16,   232] loss: 0.241
[16,   240] loss: 0.252
[16,   248] loss: 0.255
[16,   256] loss: 0.239
[16,   264] loss: 0.260
[16,   272] loss: 0.264
[16,   280] loss: 0.270
[16,   288] loss: 0.259
[16,   296] loss: 0.241
[16,   304] loss: 0.261
[16,   312] loss: 0.259
[16,   320] loss: 0.275
[16,   328] loss: 0.236
[16,   336] loss

In [28]:
# Check Validation

In [29]:
# Dataset and Dataloader
valid_dataset = ProjectDataset(val_triplets, preload=True)
valid_dataloader = DataLoader(valid_dataset, batch_size = 64)

In [30]:
correct = 0
total = 0
for ind, i in tqdm(enumerate(valid_dataloader, 0), total = len(valid_dataloader)):
    outputs = net(i)
    correct_classification = (torch.argmax(outputs, axis=1).cpu() == i['label'])
    correct += (correct_classification).sum(dim=0)
    total += correct_classification.shape[0]
print("Test accuracy: {}.".format(correct/float(total)))

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


Test accuracy: 0.8057795763015747.


In [31]:
# Predictions

In [32]:
# Data - Loading triplets
test_triplets = pd.read_csv('test_triplets.txt', header=None, delim_whitespace=True, names=['A', 'B', 'C'], dtype=str)
test_triplets = test_triplets.apply(lambda name : name + '.jpg')

In [34]:
## Dataset and Dataloader
test_dataset = ProjectDataset(test_triplets, preload=True, prediction=True)
test_dataloader = DataLoader(test_dataset, batch_size = 128)

In [35]:
labels = np.zeros(len(test_dataset))
voting = 3

for _ in range(voting):
    index = 0
    for ind, i in tqdm(enumerate(test_dataloader, 0), total = len(test_dataloader)):
        outputs = net(i)
        for label in torch.argmax(outputs, axis=1):
            labels[index] += label
            index += 1
            
# Minimum half as much have to vote yes (if 50/50 -> no)
voted_labels = (labels >= ((voting+2) // 2))
        
with open("prediction.txt", 'w') as file:
    for label in voted_labels:
        file.write("{}\n".format(str(int(label))))

HBox(children=(FloatProgress(value=0.0, max=466.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=466.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=466.0), HTML(value='')))


