In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import os, sys

In [4]:
import torch
import torch.nn as nn

In [5]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

## Siamese networks

In [6]:
from dataloader import SiameseTestData_Omniglot, SiameseTrainData_Omniglot
from networks import SiameseNet
from losses import ContrastiveLoss
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from utils import AverageMeter
import time

In [7]:
input_size = 105
learning_rate = 1e3
epochs = 500
batch_size = 256
num_workers = 4

In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    cuda = True
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [9]:
data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        'val': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

In [10]:
from torch.utils.data.sampler import SubsetRandomSampler

In [11]:
omniglot_trainset = datasets.Omniglot(root='./omniglot_data/', download=True, background=True, transform=data_transforms['train'])
omniglot_valset = datasets.Omniglot(root='./omniglot_data/', download=True, background=True, transform=data_transforms['val'])
omniglot_evalset = datasets.Omniglot(root='./omniglot_data/', download=True, background=False, transform=data_transforms['val'])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [12]:
indices = list(range(len(omniglot_trainset)))
split = int(0.15 * len(omniglot_trainset))
train_indices = indices[:split]
val_indices = indices[split:]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

In [13]:
trainset = SiameseTrainData_Omniglot(omniglot_trainset)
valset = SiameseTestData_Omniglot(omniglot_valset)
testset = SiameseTestData_Omniglot(omniglot_evalset)

In [14]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)

In [15]:
class KochNet(nn.Module):

    def __init__(self):
        super(KochNet, self).__init__()
        self.features = nn.Sequential(
            # 1x105x105
            nn.Conv2d(1, 64, kernel_size=10),
            # 64x96x96
            nn.ReLU(inplace=True), 
            nn.MaxPool2d(kernel_size=2),
            # 64x48x48
            nn.Conv2d(64, 128, kernel_size=7),
            # 128x42x42
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # 128x21x21
            nn.Conv2d(128, 128, kernel_size=4),
            # 128x18x18
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # 128x9x9
            nn.Conv2d(128, 256, kernel_size=4),
            # 256x6x6
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.Sigmoid(),
        )
        self.output = nn.Linear(4096, 1)
        

    def forward(self, x1, x2):
        x1 = self.features(x1)
        x1 = x1.view(x1.size(0), 256 * 6 * 6)
        x1 = self.fc(x1)
        x2 = self.features(x2)
        x2 = x2.view(x2.size(0), 256 * 6 * 6)
        x2 = self.fc(x2)
        
        dist = torch.abs(x1 - x2)
        out = self.output(dist)
        return out

In [16]:
model = KochNet()
if cuda:
    model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
T_max = epochs
eta_min = 0.01
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
criterion = nn.BCEWithLogitsLoss()


In [18]:
def train(train_loader, model, criterion, optimizer, epoch, device, debug=False, print_freq=200):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for batch_idx, (imgs1, imgs2, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        imgs1 = imgs1.to(device).float()
        imgs2 = imgs2.to(device).float()
        targets = targets.to(device).float()

        output = model(imgs1, imgs2)

        loss = criterion(output, targets)

        losses.update(loss.item(), imgs1.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      epoch, batch_idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
        if debug:
            break

    return losses.avg

In [19]:
def validate(val_loader, model, epoch, device, print_freq=200):
    batch_time = AverageMeter()
    losses = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for batch_idx, (imgs1, img2) in enumerate(val_loader):
            imgs1 = imgs1.to(device).float()
            imgs2 = imgs2.to(device).float()
            
            right, error = 0, 0
            output = model(imgs1, imgs2)
            pred = np.argmax(output.cpu().numpy())
            if pred == 0:
                right += 1
            else: error += 1
            print('*'*70)
            print('[%d]\tright:\t%d\terror:\t%d\tprecision:\t%f'%(batch_id, right, error, right*1.0/(right+error)))
            print('*'*70)
            losses.update(loss.item(), imgs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses))
            if args.debug:
                break

    return losses.avg

In [21]:
len(trainset)

19280

In [19]:
for e in range(epochs):
    scheduler.step()
    train(trainloader, model, criterion, optimizer, e, device, False, 10)

Epoch: [0][0/12]	Time 1.548 (1.548)	Data 0.650 (0.650)	Loss 0.6930 (0.6930)
Epoch: [0][10/12]	Time 0.890 (0.948)	Data 0.000 (0.059)	Loss 13483.6641 (43407.0135)


Process Process-8:
Process Process-6:
Process Process-5:
Process Process-7:
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
 

KeyboardInterrupt: 

In [1]:
len(trainloader)

NameError: name 'trainloader' is not defined