In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os, sys

In [3]:
import torch
import torch.nn as nn

In [4]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random

## Siamese networks

In [5]:
from dataloader import SiameseTestData_ImageFolder, SiameseTrainData_ImageFolder
from networks import SiameseNet
from losses import ContrastiveLoss
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from utils import AverageMeter
import time

In [6]:
input_size = 105
learning_rate = 1e3
epochs = 200
batch_size = 256
num_workers = 4

In [14]:
cuda = False
if torch.cuda.is_available():
    device = torch.device("cuda")
    cuda = True
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [15]:
data_transforms = {
        'train': transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        'val': transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

In [16]:
trainset = datasets.ImageFolder('./omniglot_data/changed/train', transform=data_transforms['train'])
valset = datasets.ImageFolder('./omniglot_data/changed/valid', transform=data_transforms['val'])
testset =  datasets.ImageFolder('./omniglot_data/changed/test', transform=data_transforms['val'])

In [17]:
train_siamese = SiameseTrainData_ImageFolder(trainset)
val_siamese = SiameseTestData_ImageFolder(valset, times=161)

In [18]:
trainloader = torch.utils.data.DataLoader(train_siamese, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valloader = torch.utils.data.DataLoader(val_siamese, batch_size=20, shuffle=False, num_workers=4, pin_memory=True)

In [19]:
class KochNet(nn.Module):

    def __init__(self):
        super(KochNet, self).__init__()
        self.features = nn.Sequential(
            # 1x105x105
            nn.Conv2d(1, 64, kernel_size=10),
            # 64x96x96
            nn.ReLU(inplace=True), 
            nn.MaxPool2d(kernel_size=2),
            # 64x48x48
            nn.Conv2d(64, 128, kernel_size=7),
            # 128x42x42
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # 128x21x21
            nn.Conv2d(128, 128, kernel_size=4),
            # 128x18x18
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # 128x9x9
            nn.Conv2d(128, 256, kernel_size=4),
            # 256x6x6
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.Sigmoid(),
        )
        self.output = nn.Linear(4096, 1)
        

    def forward(self, x1, x2):
        x1 = self.features(x1)
        x1 = x1.view(x1.size(0), 256 * 6 * 6)
        x1 = self.fc(x1)
        x2 = self.features(x2)
        x2 = x2.view(x2.size(0), 256 * 6 * 6)
        x2 = self.fc(x2)
        dist = torch.abs(x1 - x2)
        out = self.output(dist)
        return out

In [20]:
model = KochNet()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
if cuda:
    model.to(device)

In [21]:
T_max = epochs
eta_min = 0.01
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
criterion = nn.BCEWithLogitsLoss()

In [22]:
def train(train_loader, model, criterion, optimizer, epoch, device, debug=False, print_freq=200):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for batch_idx, (imgs1, imgs2, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        imgs1 = imgs1.to(device).float()
        imgs2 = imgs2.to(device).float()
        targets = targets.to(device).float()

        output = model(imgs1, imgs2)

        loss = criterion(output, targets)

        losses.update(loss.item(), imgs1.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      epoch, batch_idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
        if debug:
            break

    return losses.avg

In [23]:
def validate(val_loader, model, epoch, device, print_freq=100):
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()
    right, error = 0, 0
    with torch.no_grad():
        end = time.time()
        for batch_idx, (imgs1, imgs2) in enumerate(val_loader):
            imgs1 = imgs1.to(device).float()
            imgs2 = imgs2.to(device).float()
            
           
            output = model(imgs1, imgs2)
#             print(output)
            pred = np.argmax(output.cpu().numpy())
#             print(pred)
            if pred == 0:
                right += 1
            else: error += 1
           

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
             
            if batch_idx % print_freq == 0:
                print('*'*70)
                print('[%d]\tcorrect:\t%d\twrong:\t%d\tprecision:\t%f'%(batch_idx, right, error, right*1.0/(right+error)))
                print('*'*70)
                print('Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=batch_time))

In [17]:
for e in range(epochs):
    scheduler.step()
    print('*'*70)
    train(trainloader, model, criterion, optimizer, e, device, False, 10)
    if e % 10 == 0:
        validate(valloader, model, e, device) 

**********************************************************************
Epoch: [0][0/63]	Time 1.599 (1.599)	Data 0.677 (0.677)	Loss 0.6932 (0.6932)
Epoch: [0][10/63]	Time 0.888 (0.953)	Data 0.000 (0.062)	Loss 13010.0781 (61147.0602)
Epoch: [0][20/63]	Time 0.892 (0.922)	Data 0.000 (0.032)	Loss 28952.9883 (45831.8793)
Epoch: [0][30/63]	Time 0.888 (0.910)	Data 0.000 (0.022)	Loss 22080.5781 (39842.8959)
Epoch: [0][40/63]	Time 0.885 (0.904)	Data 0.000 (0.017)	Loss 16020.5391 (36838.8013)
Epoch: [0][50/63]	Time 0.890 (0.901)	Data 0.000 (0.013)	Loss 21635.7910 (34108.7984)
Epoch: [0][60/63]	Time 0.888 (0.899)	Data 0.000 (0.011)	Loss 14106.1367 (31935.5298)
**********************************************************************
[0]	correct:	0	wrong:	1	precision:	0.000000
**********************************************************************
Time 0.229 (0.229)	
**********************************************************************
[100]	correct:	2	wrong:	99	precision:	0.019802
*******************

Process Process-41:
Process Process-43:
Process Process-44:
Process Process-42:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=M

KeyboardInterrupt: 

In [18]:
len(valset)/20

161.0

In [20]:
len(trainset)/256

62.734375