In [88]:
import json

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision
from torchvision import transforms, datasets


# Label convention: 0 = cat, 1 = dog
def load_fname(json_path):
    """Collect every image file name listed in a JSON config file.

    The JSON document is read as a two-level mapping (the original code
    iterates the outer keys as ``animal`` and the inner keys as ``seed``);
    each innermost entry must provide an ``'images'`` list.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        A flat list with the contents of every ``'images'`` list, in the
        order the entries appear in the file.
    """
    with open(json_path) as fp:
        by_animal = json.load(fp)

    collected = []
    for by_seed in by_animal.values():
        for entry in by_seed.values():
            collected.extend(entry['images'])
    return collected


def load_images(image_fname_list, transform, data_dir='./data/', ext='.jpg'):
    """Load images from disk and derive a binary label from each file name.

    Each image is read from ``<data_dir>/<fname><ext>``, converted to RGB and
    passed through ``transform``. The label is 1 when the first character of
    the file name is lowercase and 0 otherwise (per the file's label comment:
    0 = cat, 1 = dog).

    Args:
        image_fname_list: Iterable of file names without directory/extension.
        transform: Callable applied to each RGB PIL image; its results must be
            tensors of identical shape so that they can be stacked.
        data_dir: Directory containing the images (default keeps the original
            hard-coded ``'./data/'``).
        ext: File extension appended to every name (default ``'.jpg'``).

    Returns:
        ``(images, labels)`` where ``images`` is the stacked transform output
        and ``labels`` is a float32 tensor with one entry per file name.

    Raises:
        RuntimeError: from ``torch.stack`` when ``image_fname_list`` is empty.
    """
    import os  # local import: the file's import block does not provide `os`

    images = []
    labels = []
    for fname in image_fname_list:
        path = os.path.join(data_dir, fname + ext)
        # The context manager releases the file handle deterministically.
        with Image.open(path) as raw:
            images.append(transform(raw.convert('RGB')))
        labels.append(1 if fname[0].islower() else 0)

    ret_images = torch.stack(images)
    # Explicit dtype keeps the float32 labels the legacy torch.Tensor() produced.
    ret_labels = torch.tensor(labels, dtype=torch.float32)

    return ret_images, ret_labels


def load_datasets(train_json, test_json, transform, valid_ratio=0.3, seed=None):
    """Build train / validation / test triplet datasets from two JSON configs.

    The file names listed in ``train_json`` are split into a validation part
    (``valid_ratio`` of the names, drawn WITHOUT replacement) and a training
    part (every name that was not drawn). The names in ``test_json`` form the
    test set. Each part is loaded with ``load_images`` and wrapped in a
    ``TripletSampler``.

    Args:
        train_json: Path of the JSON config listing train/validation images.
        test_json: Path of the JSON config listing test images.
        transform: Transform forwarded to ``load_images``.
        valid_ratio: Fraction of the training names held out for validation.
        seed: Optional seed for a private ``RandomState``; when ``None`` the
            global ``np.random`` state is used.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset)``.

    Raises:
        ValueError: if ``valid_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= valid_ratio <= 1:
        raise ValueError('valid_ratio must be within [0, 1], got %r' % (valid_ratio,))

    fnames = load_fname(train_json)
    test_fnames = load_fname(test_json)

    n_valid = int(len(fnames) * valid_ratio)
    rng = np.random if seed is None else np.random.RandomState(seed)
    # A permutation prefix draws without replacement, so the validation set
    # never contains the same file twice.
    order = rng.permutation(len(fnames))
    valid_fnames = [fnames[i] for i in order[:n_valid]]

    # Set lookup keeps the exclusion linear instead of scanning an array per name.
    held_out = set(valid_fnames)
    train_fnames = [fname for fname in fnames if fname not in held_out]

    X_train, y_train = load_images(train_fnames, transform)
    X_valid, y_valid = load_images(valid_fnames, transform)
    X_test, y_test = load_images(test_fnames, transform)

    train_dataset = TripletSampler(X_train, y_train)
    valid_dataset = TripletSampler(X_valid, y_valid)
    test_dataset = TripletSampler(X_test, y_test)

    return train_dataset, valid_dataset, test_dataset



class TripletSampler(Dataset):
    """Dataset that turns labelled samples into (anchor, positive, negative) triplets.

    Indexing returns the sample at ``idx`` as the anchor, a randomly chosen
    sample with the same target as the positive, and a randomly chosen sample
    with a different target as the negative. Sampling uses the global
    ``np.random`` state.
    """

    def __init__(self, inputs, targets):
        """Store the samples and their targets.

        Args:
            inputs: Indexable collection of samples.
            targets: Labels aligned with ``inputs``; must support element-wise
                ``==`` / ``!=`` against one of its own elements.
        """
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        """Return the number of samples (one triplet can be drawn per sample)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negative)`` for the sample at ``idx``.

        Raises:
            ValueError: if no sample has a target different from the anchor's.
        """
        x = self.inputs[idx]
        t = self.targets[idx]

        same_class = np.where(self.targets == t)[0]
        other_class = np.where(self.targets != t)[0]
        if len(other_class) == 0:
            raise ValueError('cannot sample a negative: every target equals the anchor target')

        # Do not pair the anchor with itself: such a positive has zero distance
        # and contributes nothing to the loss. Fall back to the anchor only
        # when it is the sole member of its class.
        anchor_pos = int(idx) % len(self.targets)
        candidates = same_class[same_class != anchor_pos]
        if len(candidates) == 0:
            candidates = same_class

        xp_idx = np.random.choice(candidates)
        xn_idx = np.random.choice(other_class)
        xp = self.inputs[xp_idx]
        xn = self.inputs[xn_idx]
        return x, xp, xn

In [83]:
class TripletResNet(nn.Module):
    """ResNet-18 feature extractor followed by a linear embedding head.

    The pretrained ResNet-18 layers up to and including the average pool are
    reused as ``self.model``; ``self.fc`` maps the pooled features to a
    ``metric_dim``-dimensional embedding.

    Args:
        metric_dim: Size of the output embedding.
        freeze_backbone: When True (default, matching the original behaviour)
            the backbone weights get ``requires_grad = False`` and the
            backbone is kept in eval mode even after ``train()`` is called.
    """

    def __init__(self, metric_dim, freeze_backbone=True):
        super(TripletResNet, self).__init__()
        resnet = torchvision.models.resnet18(pretrained=True)
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            for params in resnet.parameters():
                params.requires_grad = False

        self.model = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool,
        )
        self.fc = nn.Linear(resnet.fc.in_features, metric_dim)

        if freeze_backbone:
            # Modules start in training mode; switch the frozen backbone to
            # eval right away so its BatchNorm statistics are never touched.
            self.model.eval()

    def train(self, mode=True):
        """Set the training mode, keeping a frozen backbone in eval mode.

        ``requires_grad = False`` alone does not stop BatchNorm layers from
        updating their running statistics in training mode, so a frozen
        backbone is forced back to eval after the regular mode switch.
        """
        super(TripletResNet, self).train(mode)
        if self.freeze_backbone:
            self.model.eval()
        return self

    def forward(self, x):
        """Return the embedding for a batch of images ``x``."""
        x = self.model(x)
        # Flatten everything after the batch dimension before the linear head.
        x = x.view(x.size(0), -1)
        metric = self.fc(x)
        return metric

In [84]:
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    Per sample the loss is ``max(0, |a - p|^2 - |a - n|^2 + margin)``; the
    returned value is the mean over the batch.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the batch-mean triplet loss for embeddings of shape (batch, dim)."""
        pos_sq = (anchor - positive) ** 2
        neg_sq = (anchor - negative) ** 2
        # Sum the per-feature gap over dim=1, then add the margin.
        violation = (pos_sq - neg_sq).sum(dim=1) + self.margin
        return torch.relu(violation).mean()


class TripletAngularLoss(nn.Module):
    """Triplet angular loss.

    Per sample the loss is
    ``max(0, |a - p|^2 - 4 * tan(alpha)^2 * |n - c|^2)`` with
    ``c = (a + p) / 2``; the returned value is the sum over the batch.

    Args:
        alpha: Angle bound, in degrees when ``in_degree`` is True, otherwise
            in radians.
        in_degree: Whether ``alpha`` is given in degrees.
    """

    def __init__(self, alpha=45, in_degree=True):
        # y=dnn(x), must be L2 Normalized.
        super(TripletAngularLoss, self).__init__()
        if in_degree:
            alpha = np.deg2rad(alpha)
        # Holds tan(alpha)^2 (name kept for compatibility); stored as a plain
        # float so it combines with tensors of any dtype/device.
        self.tan_alpha = float(np.tan(alpha) ** 2)

    def forward(self, a, p, n):
        """Return the summed angular loss for embeddings of shape (batch, dim)."""
        c = (a + p) / 2
        # Squared L2 norms per sample. F.normalize would return unit vectors,
        # not norms, which is why it must not be used here.
        sq_ap = (a - p).pow(2).sum(dim=1)
        sq_nc = (n - c).pow(2).sum(dim=1)
        loss = F.relu(sq_ap - 4 * self.tan_alpha * sq_nc)
        return loss.sum()

In [85]:
# Preprocessing: resize, centre-crop to 224, convert to a tensor, then apply
# the ImageNet channel statistics that the pretrained torchvision weights
# were trained with.
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

train_dataset, valid_dataset, test_dataset = \
    load_datasets('./configs/train2.json', './configs/test2.json', transform)

# Only the training loader is shuffled; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=4)
test_loader = DataLoader(test_dataset, batch_size=4)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 100-dimensional embedding.
model = TripletResNet(100)
model = model.to(device)
# criterion = TripletAngularLoss()
criterion = TripletLoss()
# Hand the optimizer only the parameters that actually receive gradients.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                            lr=1e-3,
                            momentum=0.9,
                            weight_decay=1e-4)

In [89]:
# One pass over the training loader, accumulating the per-batch loss.
model.train()
optimizer.zero_grad()

epoch_loss = 0
for i, batch in enumerate(train_loader):
    # Each batch is an (anchor, positive, negative) triple of tensors.
    anchors, positives, negatives = [part.to(device) for part in batch]

    # Embed the three roles with the same network, in the original order.
    anc_metric, pos_metric, neg_metric = \
        model(anchors), model(positives), model(negatives)

    loss = criterion(anc_metric, pos_metric, neg_metric)
    loss.backward()

    # Apply the update, then clear gradients for the next batch.
    optimizer.step()
    optimizer.zero_grad()
    epoch_loss += loss.item()

KeyboardInterrupt: 