In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision
import torchvision.transforms as transforms

import numpy as np
from sklearn.model_selection import StratifiedKFold

In [2]:
# Shared preprocessing for both splits: resize to 224, convert to a tensor,
# then normalize every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train / test splits, downloaded into ./data when not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Class names in label order: index 3 is 'cat' and index 5 is 'dog', matching
# the label ids that load_filter_dataset keeps.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [4]:
def load_filter_dataset(dataset, keep_all_label=5, capped_label=3, cap=500):
    """Extract a two-class subset of ``dataset`` as stacked tensors.

    Every sample whose label equals ``keep_all_label`` is kept. Samples whose
    label equals ``capped_label`` are kept only until ``cap`` of them have been
    collected; later ones are skipped. All other labels are discarded. With the
    defaults this keeps every label-5 sample (dog) and the first 500 label-3
    samples (cat), in dataset order.

    Args:
        dataset: Iterable of samples where ``sample[0]`` is an image tensor
            and ``sample[1]`` is its label.
        keep_all_label: Label whose samples are all kept.
        capped_label: Label whose samples are kept up to ``cap``.
        cap: Maximum number of ``capped_label`` samples; ``None`` means no
            limit.

    Returns:
        ``(images, labels)``: ``images`` is ``torch.stack`` of the kept image
        tensors and ``labels`` is a ``torch.Tensor`` built from the kept
        labels (default float dtype, as before).

    Raises:
        ValueError: If no sample matches either label, since an empty list
            cannot be stacked.
    """
    images, labels = [], []
    capped_seen = 0
    for sample in dataset:
        X, y = sample[0], sample[1]
        if y == keep_all_label:
            images.append(X)
            labels.append(y)
        elif y == capped_label:
            # Once the quota is filled, further samples of this class are dropped.
            if cap is not None and capped_seen >= cap:
                continue
            images.append(X)
            labels.append(y)
            capped_seen += 1

    if not images:
        raise ValueError(
            'dataset contains no samples labelled {} or {}'.format(
                keep_all_label, capped_label))
    return torch.stack(images), torch.Tensor(labels)


# Materialise the filtered cat/dog subsets of the training and test splits.
images, labels = load_filter_dataset(dataset=trainset)
test_images, test_labels = load_filter_dataset(dataset=testset)

In [5]:
class NPairSamplerCifar(Dataset):
    """Dataset yielding an anchor, one positive and ``n`` negatives per item.

    Args:
        images: Indexable collection of image tensors.
        labels: Labels aligned with ``images``; must support element-wise
            ``==`` / ``!=`` against a single label (e.g. a tensor or array).
        n: Number of negatives drawn for every anchor.

    Sampling uses the global NumPy RNG (``np.random.choice``), so call
    ``np.random.seed`` beforehand for reproducible draws.
    """

    def __init__(self, images, labels, n):
        self.images = images
        self.labels = labels
        self.n = n

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(x, t, xp, xn)`` for the anchor at ``idx``.

        ``x`` / ``t`` are the anchor image and its label, ``xp`` is a random
        image with the same label, and ``xn`` stacks ``n`` random images with
        a different label (drawn with replacement) along a new first axis.
        """
        x = self.images[idx]
        t = self.labels[idx]

        # Both index pools depend only on the anchor's label, so build each
        # once instead of rebuilding the negative pool for every draw.
        same_label = np.where(self.labels == t)[0]
        other_label = np.where(self.labels != t)[0]

        # Exclude the anchor itself from the positives: pairing a sample with
        # itself gives a zero-distance, uninformative positive. Fall back to
        # the anchor only when it is the sole member of its class.
        anchor_pos = int(idx) % len(self.labels)
        candidates = same_label[same_label != anchor_pos]
        if len(candidates) == 0:
            candidates = same_label
        xp = self.images[np.random.choice(candidates)]

        # np.random.choice raises on an empty pool, i.e. when no sample has a
        # different label.
        negative_idx = np.random.choice(other_label, size=self.n)
        xn = torch.stack([self.images[j] for j in negative_idx])
        return x, t, xp, xn

In [6]:
# N-pair sampler over the filtered training subset, five negatives per anchor.
sampler = NPairSamplerCifar(images, labels, n=5)

In [7]:
# Sanity check: print the size of each element of the first sample
# (anchor, label, positive, stacked negatives).
for part in sampler[0]:
    print(part.size())

torch.Size([3, 224, 224])
torch.Size([])
torch.Size([3, 224, 224])
torch.Size([5, 3, 224, 224])


In [8]:
class CifarDataset(Dataset):
    """Triplet dataset: each item is an anchor, a positive and a negative.

    Args:
        images: Indexable collection of images.
        labels: Labels aligned with ``images``; must support element-wise
            ``==`` / ``!=`` against a single label (e.g. a tensor or array).

    Sampling uses the global NumPy RNG (``np.random.choice``), so call
    ``np.random.seed`` beforehand for reproducible draws.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(x, t, xp, xn)`` for the anchor at ``idx``.

        ``x`` / ``t`` are the anchor image and its label, ``xp`` is a random
        image with the same label and ``xn`` a random image with a different
        label.
        """
        x = self.images[idx]
        t = self.labels[idx]

        same_label = np.where(self.labels == t)[0]
        other_label = np.where(self.labels != t)[0]

        # Exclude the anchor itself from the positives: pairing a sample with
        # itself gives a zero-distance, uninformative positive. Fall back to
        # the anchor only when it is the sole member of its class.
        anchor_pos = int(idx) % len(self.labels)
        candidates = same_label[same_label != anchor_pos]
        if len(candidates) == 0:
            candidates = same_label
        xp_idx = np.random.choice(candidates)

        # np.random.choice raises on an empty pool, i.e. when no sample has a
        # different label.
        xn_idx = np.random.choice(other_label)

        xp = self.images[xp_idx]
        xn = self.images[xn_idx]
        return x, t, xp, xn

In [9]:
# Triplet view over the filtered test subset.
test_dataset = CifarDataset(images=test_images, labels=test_labels)

In [10]:
class TripletResNet(nn.Module):
    """Frozen pretrained ResNet-18 trunk with an embedding and a class head.

    Args:
        out_dim: Size of the embedding produced by ``fc1``.
        n_classes: Number of outputs produced by ``fc2``.
    """

    def __init__(self, out_dim, n_classes):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # Freeze all pretrained weights; the two linear layers created below
        # are the only parameters left with requires_grad=True.
        for weight in backbone.parameters():
            weight.requires_grad = False

        # Reuse the backbone stages up to and including the average pool; the
        # backbone's own fc layer is not used, only its input width.
        trunk = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            backbone.avgpool,
        ]
        self.model = nn.Sequential(*trunk)
        self.fc1 = nn.Linear(backbone.fc.in_features, out_dim)
        self.fc2 = nn.Linear(out_dim, n_classes)

    def forward(self, x):
        """Return ``(metric, classes)``.

        ``metric`` is the ``fc1`` output passed through ``F.normalize``;
        ``classes`` is ``fc2`` applied to that normalized embedding.
        """
        features = self.model(x)
        # Collapse everything after the batch axis into one feature vector.
        flat = features.view(features.size(0), -1)
        metric = F.normalize(self.fc1(flat))
        logits = self.fc2(metric)
        return metric, logits

In [11]:
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared distances, plus a label-0 term.

    Args:
        margin: Constant added to the squared-distance gap before the hinge.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, t, positive, negative):
        """Return the scalar loss for a batch of triplets.

        Per row, the hinge is ``relu(sum((a - p)^2 - (a - n)^2) + margin)``,
        averaged over the batch. If any row has ``t == 0``, the mean squared
        value of those anchor rows is added on top.
        """
        sq_to_positive = (anchor - positive).pow(2)
        sq_to_negative = (anchor - negative).pow(2)
        gap = (sq_to_positive - sq_to_negative).sum(dim=1) + self.margin
        loss = F.relu(gap).mean()

        # Extra term applied only to anchors whose label is 0.
        label_zero_anchors = anchor[t == 0]
        if len(label_zero_anchors) > 0:
            loss = loss + (label_zero_anchors ** 2).mean()
        return loss

In [12]:
def flow_data(model, data_loader, optimizer=None):
    """Run one pass over ``data_loader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one the model is put in eval mode and the pass runs
    with gradient tracking disabled, so no autograd graph is built during
    validation.

    Relies on the module-level names ``device`` and ``criterion``.

    Args:
        model: Module whose forward returns a pair; only the first element
            (the embedding) is passed to ``criterion``.
        data_loader: Iterable of ``(anchors, labels, positives, negatives)``
            batches.
        optimizer: Optimizer to step after each batch, or ``None`` to
            evaluate only.

    Returns:
        float: Sum of the batch losses divided by the number of batches, or
        ``nan`` when ``data_loader`` yields no batch.
    """
    training = optimizer is not None
    if training:
        model.train()
        optimizer.zero_grad()
    else:
        model.eval()

    epoch_loss = 0.0
    n_batches = 0
    # Gradients are only needed when backward() will be called.
    with torch.set_grad_enabled(training):
        for anchors, labels, positives, negatives in data_loader:
            anchors = anchors.to(device)
            labels = labels.to(device)
            positives = positives.to(device)
            negatives = negatives.to(device)

            out_anc, _ = model(anchors)
            out_pos, _ = model(positives)
            out_neg, _ = model(negatives)

            loss = criterion(out_anc, labels, out_pos, out_neg)

            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            epoch_loss += loss.item()
            n_batches += 1

    # Count batches instead of calling len() so an empty loader does not
    # divide by zero and loaders without __len__ still work.
    if n_batches == 0:
        return float('nan')
    return epoch_loss / n_batches

In [13]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 5-fold stratified split with a fixed seed so the folds are reproducible.
cv = StratifiedKFold(n_splits=5, random_state=27, shuffle=True)
n_epochs = 2

for cv_i, (train_idx, valid_idx) in enumerate(cv.split(images, labels)):
    train_dataset = CifarDataset(images[train_idx], labels[train_idx])
    valid_dataset = CifarDataset(images[valid_idx], labels[valid_idx])
    # Reshuffle the training anchors every epoch; validation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=32)

    # A fresh model, loss and optimizer for every fold.
    model = TripletResNet(100, 2)
    model = model.to(device)
    criterion = TripletLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.001, momentum=0.9, weight_decay=1e-4)

    for epoch in range(n_epochs):
        train_loss = flow_data(model, train_loader, optimizer=optimizer)
        valid_loss = flow_data(model, valid_loader)

        # flow_data already returns the epoch mean as a plain number, so no
        # .item() call is needed; float() keeps the format spec below valid.
        train_epoch_loss = float(train_loss)
        valid_epoch_loss = float(valid_loss)

        print(f'EPOCH: [{epoch + 1}/{n_epochs}], '
              f'train_loss: {train_epoch_loss:.3f}, '
              f'valid_loss: {valid_epoch_loss:.3f}')

KeyboardInterrupt: 

In [13]:
class TripletSamplerCifar(Dataset):
    """Triplet sampler: each item is an anchor, a positive and a negative.

    Args:
        images: Indexable collection of images.
        labels: Labels aligned with ``images``; must support element-wise
            ``==`` / ``!=`` against a single label (e.g. a tensor or array).

    Sampling uses the global NumPy RNG (``np.random.choice``), so call
    ``np.random.seed`` beforehand for reproducible draws.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(x, t, xp, xn)`` for the anchor at ``idx``.

        ``x`` / ``t`` are the anchor image and its label, ``xp`` is a random
        image with the same label and ``xn`` a random image with a different
        label.
        """
        x = self.images[idx]
        t = self.labels[idx]

        same_label = np.where(self.labels == t)[0]
        other_label = np.where(self.labels != t)[0]

        # Exclude the anchor itself from the positives: pairing a sample with
        # itself gives a zero-distance, uninformative positive. Fall back to
        # the anchor only when it is the sole member of its class.
        anchor_pos = int(idx) % len(self.labels)
        candidates = same_label[same_label != anchor_pos]
        if len(candidates) == 0:
            candidates = same_label
        xp_idx = np.random.choice(candidates)

        # np.random.choice raises on an empty pool, i.e. when no sample has a
        # different label.
        xn_idx = np.random.choice(other_label)

        xp = self.images[xp_idx]
        xn = self.images[xn_idx]
        return x, t, xp, xn