In [16]:
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
import pandas as pd
import torch.optim as optim
from sklearn.model_selection import train_test_split
import torchvision.models as models
from visdom import Visdom
import cv2

In [2]:
csv_file = "../dataset/train.csv"
train_dir = "../dataset/train/"
test_dir = "../dataset/test/"

In [23]:
class data():
    def __init__(self):
        self.index = 0
        self.batch = 100
        self.next_round = True
        self.end = False
        self.build_data()
        
    def read_image(self, data):
        image_x = []
        label_y = []
        for i in range(data.shape[0]):
            img = cv2.imread(train_dir+data[i][0])
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (256, 256))
            image_x.append(img)
#             label_y.append(data[i][1])
#         return np.array(image_x), np.expand_dims(np.array(label_y),-1)
        return np.array(image_x)
    
    def build_data(self):
        data = pd.read_csv(csv_file)['Image']
        train_data, test_data = train_test_split(data, test_size=0.1, shuffle=True, random_state=1337)
        train_list = np.array(train_data)
        print(train_list[0:10])
        test_list = np.array(test_data)
        train_list = np.array(np.meshgrid(train_list,train_list,train_list)).T.reshape(-1,3)
        print(train_list[0:10])
        self.train = np.array(train_data)
        self.test = np.array(test_data)
    
    def next_batch(self):
        start = self.index
        if (self.next_round == False):
            end = self.train.shape[0]
            self.end = True
        else:
            end = self.index+self.batch
        x, y = self.read_image(self.train[start:end])
        self.index += self.batch
        return x, y
    
    def has_next_batch(self):
        if (self.end):
            return False
        if (self.index + self.batch >= self.train.shape[0] - 1 ):
            self.next_round = False
        return True
    def reset_batch(self):
        self.index = 0
        self.end = False
    
    def generate(self):
        pass
    

In [24]:
a = data()

['416f311e.jpg' '43099c40.jpg' '9583591b.jpg' 'a65bb97f.jpg'
 '02cff75e.jpg' '040d8913.jpg' 'bed60aa6.jpg' '5114596e.jpg'
 '8d0aa6b7.jpg' '598f3afd.jpg']


MemoryError: 

In [13]:
class Tripletnet(nn.Module):
    def __init__(self, embeddingnet):
        super(Tripletnet, self).__init__()
        self.embeddingnet = embeddingnet
        pass
    def forward(self, x, y, z):
        embedded_x = self.embeddingnet(x)
        embedded_y = self.embeddingnet(y)
        embedded_z = self.embeddingnet(z)
        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)
        return dist_a, dist_b, embedded_x, embedded_y, embedded_z

In [5]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [6]:
def accuracy(dista, distb):
    margin = 0
    pred = (dista - distb - margin).cpu().data
    return (pred > 0).sum()*1.0/dista.size()[0]

In [7]:
class VisdomLinePlotter(object):
    """Plots to Visdom"""
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}
    def plot(self, var_name, split_name, x, y):
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel='Epochs',
                ylabel=var_name
            ))
        else:
            self.viz.updateTrace(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name)

In [8]:
def train(train_loader, tnet, criterion, optimizer, epoch):
    losses = AverageMeter()
    accs = AverageMeter()
    emb_norms = AverageMeter()

    # switch to train mode
    tnet.train()
    for batch_idx, (data1, data2, data3) in enumerate(train_loader):

        data1, data2, data3 = data1.cuda(), data2.cuda(), data3.cuda()
#         data1, data2, data3 = Variable(data1), Variable(data2), Variable(data3)

        # compute output
        dista, distb, embedded_x, embedded_y, embedded_z = tnet(data1, data2, data3)
        # 1 means, dista should be larger than distb
        target = torch.FloatTensor(dista.size()).fill_(1).cuda()
#         if args.cuda:
#         target = Variable(target)
        
        loss_triplet = criterion(dista, distb, target)
        loss_embedd = embedded_x.norm(2) + embedded_y.norm(2) + embedded_z.norm(2)
        loss = loss_triplet + 0.001 * loss_embedd

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        losses.update(loss_triplet.data[0], data1.size(0))
        accs.update(acc, data1.size(0))
        emb_norms.update(loss_embedd.data[0]/3, data1.size(0))

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{}]\t'
                  'Loss: {:.4f} ({:.4f}) \t'
                  'Acc: {:.2f}% ({:.2f}%) \t'
                  'Emb_Norm: {:.2f} ({:.2f})'.format(
                epoch, batch_idx * len(data1), len(train_loader.dataset),
                losses.val, losses.avg, 
                100. * accs.val, 100. * accs.avg, emb_norms.val, emb_norms.avg))
    # log avg values to somewhere
    plotter.plot('acc', 'train', epoch, accs.avg)
    plotter.plot('loss', 'train', epoch, losses.avg)
    plotter.plot('emb_norms', 'train', epoch, emb_norms.avg)

In [9]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "runs/%s/"%("model")
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'runs/%s/'%("model") + 'model_best.pth.tar')

In [10]:
def test(test_loader, tnet, criterion, epoch):
    losses = AverageMeter()
    accs = AverageMeter()

    # switch to evaluation mode
    tnet.eval()
    for batch_idx, (data1, data2, data3) in enumerate(test_loader):
        if args.cuda:
            data1, data2, data3 = data1.cuda(), data2.cuda(), data3.cuda()
        data1, data2, data3 = Variable(data1), Variable(data2), Variable(data3)

        # compute output
        dista, distb, _, _, _ = tnet(data1, data2, data3)
        target = torch.FloatTensor(dista.size()).fill_(1)
        if args.cuda:
            target = target.cuda()
        target = Variable(target)
        test_loss =  criterion(dista, distb, target).data[0]

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        accs.update(acc, data1.size(0))
        losses.update(test_loss, data1.size(0))      

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        losses.avg, 100. * accs.avg))
    plotter.plot('acc', 'test', epoch, accs.avg)
    plotter.plot('loss', 'test', epoch, losses.avg)
    return accs.avg

In [14]:
def main():
    model = models.resnet101(pretrained=True)
    tnet = Tripletnet(model)
    criterion = torch.nn.MarginRankingLoss(margin = 0.2)
    optimizer = optim.SGD(tnet.parameters(), lr=0.01, momentum=0.5)
    epochs = 100
    best_acc = 0
    for epoch in range(1, epochs + 1):
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(test_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

In [17]:
main()

NameError: name 'train_loader' is not defined