In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision
import torchvision.transforms as transforms

import numpy as np
from sklearn.cluster import KMeans
from sklearn.model_selection import StratifiedKFold

In [3]:
class TripletResNet(nn.Module):
    """ResNet-18 embedding network with a metric head and a classification head.

    The pretrained ResNet-18 trunk is frozen (``requires_grad = False``); only
    ``fc1`` (the embedding projection) and ``fc2`` (the classifier) have
    trainable parameters.

    NOTE: the attribute names (``model``, ``fc1``, ``fc2``) and the order of the
    layers inside ``self.model`` determine the ``state_dict`` keys, so they must
    stay as-is for saved checkpoints to load.

    Args:
        metric_dim: size of the L2-normalized embedding produced by ``fc1``.
        n_classes: number of outputs of the classifier ``fc2``.
    """

    def __init__(self, metric_dim, n_classes):
        super().__init__()
        # ``pretrained=True`` is deprecated in newer torchvision (which prefers
        # ``weights=``) but is kept for compatibility with older releases.
        resnet = torchvision.models.resnet18(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False

        # Everything up to (and including) global average pooling; the original
        # ImageNet ``fc`` layer is dropped and replaced by fc1/fc2 below.
        self.model = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool,
        )
        self.fc1 = nn.Linear(resnet.fc.in_features, metric_dim)
        self.fc2 = nn.Linear(metric_dim, n_classes)

    def forward(self, x):
        """Return ``(metric, classes)`` for a batch ``x``.

        ``x`` is presumably an (N, 3, H, W) image batch — TODO confirm callers.
        ``metric`` is the row-wise L2-normalized embedding of shape
        (N, metric_dim); ``classes`` is a softmax over the class dimension of
        shape (N, n_classes).
        """
        features = torch.flatten(self.model(x), 1)
        metric = F.normalize(self.fc1(features), dim=1)
        # Explicit dim: softmax without ``dim`` is deprecated and warns on every
        # call; dim=1 is what the implicit rule picked for 2-D input anyway.
        classes = F.softmax(self.fc2(metric), dim=1)
        return metric, classes

In [4]:
def _labelled_samples(dataset):
    """Yield ``(label, load_image)`` pairs for every sample of ``dataset``, in order.

    For torchvision-style datasets that expose a raw ``targets`` list and no
    ``target_transform``, labels are read straight from ``targets`` and
    ``load_image()`` indexes the dataset lazily, so the (potentially expensive)
    image transform only runs for samples the caller keeps. Any other iterable
    of ``(image, label, ...)`` items is iterated directly.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        for idx, label in enumerate(targets):
            yield label, (lambda idx=idx: dataset[idx][0])
    else:
        for item in dataset:
            yield item[1], (lambda item=item: item[0])


def load_filter_dataset(dataset, dog_label=5, cat_label=3, max_cats=500):
    """Keep only dogs and (at most ``max_cats``) cats, relabelled as a binary task.

    Args:
        dataset: iterable of ``(image, label, ...)`` samples, e.g. a torchvision
            ``CIFAR10`` dataset. All kept images must share one shape so they
            can be stacked.
        dog_label: source label mapped to 0 (5 is "dog" in CIFAR-10).
        cat_label: source label mapped to 1 (3 is "cat" in CIFAR-10).
        max_cats: maximum number of cat samples kept. The first ``max_cats``
            cats in dataset order are used; every dog is kept.

    Returns:
        ``(images, labels)``: the kept images stacked along a new first
        dimension, and a float tensor of 0/1 labels in the same order.

    Raises:
        ``torch.stack`` raises if no sample matched (nothing to stack).
    """
    images, labels = [], []
    n_cats = 0
    for label, load_image in _labelled_samples(dataset):
        if label == dog_label:
            new_label = 0
        elif label == cat_label and n_cats < max_cats:
            new_label = 1
            n_cats += 1
        else:
            continue
        images.append(load_image())
        labels.append(new_label)
    # torch.Tensor(...) yields float32 labels, as the notebook always produced.
    return torch.stack(images), torch.Tensor(labels)

In [5]:
def load_dataset(root='./data'):
    """Download CIFAR-10 (if not already under ``root``) and return the dog/cat subsets.

    Each image is resized so its shorter side is 224 px, presumably to match
    the input resolution the pretrained ResNet-18 backbone expects, then
    converted to a tensor in [0, 1] and normalized with mean 0.5 / std 0.5 per
    channel, i.e. mapped to [-1, 1].
    NOTE(review): this is not the ImageNet mean/std the backbone was pretrained
    with. Keep it anyway unless the saved checkpoint is retrained, since the
    checkpoint was presumably trained on inputs normalized this way.

    Args:
        root: directory where torchvision stores/looks up the CIFAR-10 files.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as produced by
        ``load_filter_dataset`` on the train and test splits.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def cifar10(train):
        # Both splits share the root directory and preprocessing.
        return torchvision.datasets.CIFAR10(root=root, train=train,
                                            download=True, transform=transform)

    trainset, testset = cifar10(train=True), cifar10(train=False)
    X_train, y_train = load_filter_dataset(trainset)
    X_test, y_test = load_filter_dataset(testset)

    return X_train, X_test, y_train, y_test

In [6]:
X_train, X_test, y_train, y_test = load_dataset()

# Invariants: every image has a label, and only the two relabelled classes
# (0 = dog, 1 = cat) are present.
assert len(X_train) == len(y_train) and len(X_test) == len(y_test)
assert set(torch.cat([y_train, y_test]).unique().tolist()) <= {0.0, 1.0}

Files already downloaded and verified
Files already downloaded and verified


In [25]:
# Inference is pinned to the CPU: X_train / X_test are CPU tensors and the next
# cell converts outputs with .numpy(), which only works on CPU tensors.
device = 'cpu'
WEIGHTS_PATH = './weights/triplet-angular_1.16779.pth'

# Dimensions must match the ones the checkpoint was trained with.
model = TripletResNet(metric_dim=100, n_classes=2)
model.to(device)

# map_location remaps tensors saved from a GPU run so the checkpoint also loads
# on a CPU-only machine. torch.load unpickles the file: only load trusted files.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.eval();  # batch norm uses running stats; ';' suppresses the model repr




In [None]:
# Mini-batch size for inference. Feeding all ~5.5k images at 224x224 through
# ResNet-18 in a single call needs many GB of activation memory.
BATCH_SIZE = 64


def embed_in_batches(model, images, batch_size=BATCH_SIZE):
    """Run ``model`` over ``images`` in mini-batches with autograd disabled.

    Returns the concatenated ``(metric, classes)`` tensors, row-aligned with
    ``images``. ``model`` is expected to be in eval mode (set in the previous
    cell), so batch norm uses running statistics and the result does not
    depend on how the images are split into batches.
    """
    metric_parts, class_parts = [], []
    with torch.no_grad():
        for batch in torch.split(images, batch_size):
            metric, classes = model(batch)
            metric_parts.append(metric)
            class_parts.append(classes)
    return torch.cat(metric_parts), torch.cat(class_parts)


train_metric, train_classes = embed_in_batches(model, X_train)
test_metric, test_classes = embed_in_batches(model, X_test)

# Outputs were computed under no_grad, so they convert to NumPy directly.
train_metric = train_metric.numpy()
test_metric = test_metric.numpy()
test_classes = test_classes.numpy()
# Predicted class per test sample (argmax of the softmax output).
test_labels = test_classes.argmax(1)