In [1]:
import os
import json
import csv
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
random.seed(100)
from tqdm.auto import tqdm
import pickle
from triplet_utils import *
import torch.nn as nn
import torch.nn.functional as F

## Reading training data

In [2]:
with open(os.path.join(data_path, "2DInstances.txt")) as f:
    reader = csv.reader(f, delimiter='\t')
    next(reader)
    instances = [i for i in reader]

In [3]:
nb_instances = len(instances)
print(f"Number of training instances: {nb_instances}")

Number of training instances: 1970


## Triplets Dataset and Dataloader

In [4]:
triplets_possibilities = {}

for idx in tqdm(range(len(instances))):
    
    anchor = instances[idx]
    
    all_possible_positives, all_possible_negatives, triplet_found = find_triplet(anchor, instances)

    triplets_possibilities[str(anchor)] = {
        "pos": all_possible_positives,
        "neg": all_possible_negatives
    }

with open("saved_dict/triplets_possibilities.pkl", "wb") as f:
    pickle.dump(triplets_possibilities, f)

  0%|          | 0/1970 [00:00<?, ?it/s]

In [5]:
with open("saved_dict/triplets_possibilities.pkl", "rb") as f:
    triplets_possibilities = pickle.load(f)

In [6]:
print(f"Number of triplets possibilities: {len(triplets_possibilities)}")

Number of triplets possibilities: 1970


In [7]:
anchors = []
positives = []
negatives = []
removed_idx = []
triplet_dict = {}

for idx in tqdm(range(len(instances))):
    anchor = instances[idx]
    pos = triplets_possibilities[str(anchor)]["pos"]
    neg = triplets_possibilities[str(anchor)]["neg"]
    valid, positive, negative = sample_triplets(pos, neg)

    if valid:
        # these lists are for the idx-based iteration of torch.util.dataset
        anchors.append(anchor)
        positives.append(positive)
        negatives.append(negative)

        # this allows an O(1) access to a triplet by its key
        triplet_dict[str(anchor)] = {
            "anchor": anchor,
            "pos": positive,
            "neg": negative
        }
    else:
        removed_idx.append(idx)

# remove all instances for which we could not find a valid triplet
instances = [i for idx, i in enumerate(instances) if idx not in removed_idx]

  0%|          | 0/1970 [00:00<?, ?it/s]

In [8]:
print(f"Number of valid triplets: {len(triplet_dict)}")

Number of valid triplets: 469


In [9]:
import torch.utils.data

class TripletDataset(torch.utils.data.Dataset):
    def __init__(self, anchors, positives, negatives, transform=None):
        self.transform = transform
        self.anchors = anchors
        self.positives = positives
        self.negatives = negatives
        self.size = len(self.anchors)

    def __getitem__(self, idx):
        anchor = load_instance(self.anchors[idx], self.transform)
        p = load_instance(self.positives[idx], self.transform)
        n = [load_instance(neg, transform) for neg in self.negatives[idx]]

        return {
            "anchor": anchor,
            "pos": p,
            "neg": n
        }

    def __len__(self):
        return self.size

In [10]:
dataset = TripletDataset(anchors, positives, negatives, transform)

In [11]:
print(f"Dataset size: {dataset.size}")

Dataset size: 469


In [12]:
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler

In [13]:
batch_size = 16
validation_percentage = 0.2
shuffle_dataset = True
random_seed = 42

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_percentage * dataset_size))
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         sampler=val_sampler)

In [14]:
next(iter(train_loader))

{'anchor': {'image': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],
  
  
          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0.,

## Models

In [15]:
import torch
import torchvision.models as models

In [16]:
class MyResNet(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(MyResNet, self).__init__()
        resnet = models.resnet50(pretrained=True)

        self.start = torch.nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )

        self.slice1 = resnet.layer1
        self.slice2 = resnet.layer2
        self.slice3 = resnet.layer3
        self.slice4 = resnet.layer4
        
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.start(X)
        h = self.slice1(h)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3]

In [17]:
class MyModel(torch.nn.Module):
    def __init__(self, requires_grad=False, add_conv=True):
        super(MyModel, self).__init__()
        self.add_conv = add_conv
        self.model = MyResNet(requires_grad=requires_grad)
        if self.add_conv:
            self.conv1 = torch.nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=7)
            self.max1 = torch.nn.MaxPool2d(kernel_size=2)
            self.conv2 = torch.nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=4)
            self.max2 = torch.nn.MaxPool2d(kernel_size=2)
            self.conv3 = torch.nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=2)
            self.max3 = torch.nn.MaxPool2d(kernel_size=2)
            self.flat = torch.nn.Flatten()

    def forward(self, batch):
        x = batch["image"]
        x = self.model(x)[1]
        if self.add_conv:
            x = self.conv1(x)
            x = self.max1(x)
            x = self.conv2(x)
            x = self.max2(x)
            x = self.conv3(x)
            x = self.max3(x)
            x = self.flat(x)
        
        return x # return list of size 1 to have same API as the other encoders that return list of mulitple encodings

In [18]:
model = MyModel()

In [19]:
if torch.cuda.is_available():
    model = model.to("cuda:0")

In [20]:
model.eval

<bound method Module.eval of MyModel(
  (model): MyResNet(
    (start): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (slice1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inpla

In [21]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias


## Training

In [22]:
def l2_dist_sum_weighted_per_batch(input_features, target_features):
    """
    Calculates l2 distance between input and target features and weights distance by number of feature maps.
    :param input_features: list of feature-maps
    :param target_features: list of feature-maps
    :param reduction: see nn.MSELoss --> this controls if we reduce all batches at the end by summing/averaging or not.
    :return:
    """
    batch_dists = []
    for i in range(input_features[0].shape[0]):
        dists = []
        for input, target in zip(input_features, target_features):
            dists.append(nn.MSELoss()(input[i], target[i]))
        batch_dists.append(torch.mean(torch.stack(dists)))
    return torch.stack(batch_dists)

In [23]:
class TripletLoss(nn.Module):

    def __init__(self, margin=0.5, reduction_mode="mean"):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.reduction_mode = reduction_mode

    def forward(self, encodings):

        # get anchor, pos, negs from encodings
        anchor = [encodings[i]["anchor"] for i in range(len(encodings))]
        pos = [encodings[i]["pos"] for i in range(len(encodings))]
        negs = [encodings[i]["neg"] for i in range(len(encodings))]
        if len(negs[0]) != 1:
            raise ValueError(f"Only one negative is allowed for TripletLoss, but got: {len(negs[0])}")
        negs = [negs[k][0] for k in range(len(negs))]

        # calculate anchor->pos dist
        pos_dist = l2_dist_sum_weighted_per_batch(anchor, pos)

        # calculate anchor->neg dist for all negs
        neg_dist = l2_dist_sum_weighted_per_batch(anchor, negs)

        losses = F.relu(pos_dist - neg_dist + self.margin)

        # calculate triplet loss
        if self.reduction_mode == "mean":
            return F.relu(pos_dist - neg_dist + self.margin).mean(), pos_dist.mean().data.cpu().numpy(), neg_dist.mean().data.cpu().numpy()
        elif self.reduction_mode == "sum":
            return F.relu(pos_dist - neg_dist + self.margin).sum(), pos_dist.sum().data.cpu().numpy(), neg_dist.sum().data.cpu().numpy()
        else:
            raise ValueError(f"Unsupported reduction_mode:{self.reduction_mode}")

In [24]:
loss_func=TripletLoss(margin=0.1, reduction_mode="mean")

In [25]:
optim = torch.optim.Adam
scheduler=torch.optim.lr_scheduler.StepLR
optim_args={
    "lr": 1e-5,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
    "weight_decay": 0.0
}
scheduler_args={
    "step_size": 15,
    "gamma": 0.1,
}
top_k_accs = [1, 5]

In [26]:
optim = optim(filter(lambda p: p.requires_grad, model.parameters()), **optim_args)
scheduler = scheduler(optim, **scheduler_args)

In [27]:
def forward_pass(model, sample, device):
    sample = triplets_as_batches(sample, 1)
    if torch.cuda.is_available():
        sample["image"] = sample["image"].to(device)
    encodings = model(sample)
    encodings = outputs_as_triplets(encodings, 1)
    loss, pos_dist, neg_dist = loss_func(encodings)

    return loss, pos_dist, neg_dist

In [28]:
def train_one_epoch(model, train_loader, optim, epoch, iter_per_epoch, device, 
                    log_nth_iter, log_nth_epoch, num_epochs):
    model.train()  # TRAINING mode (for dropout, batchnorm, etc.)
    train_losses = []
    train_pos_dists = []
    train_neg_dists = []

    train_minibatches = train_loader
    train_minibatches = tqdm(train_minibatches)

    for i, sample in enumerate(train_minibatches):  # for every minibatch in training set

        # FORWARD PASS --> Loss + acc calculation
        train_loss, train_pos_dist, train_neg_dist = forward_pass(model, sample, device)
        
        # BACKWARD PASS --> Gradient-Descent update
        train_loss.backward()
        optim.step()
        optim.zero_grad()

        # LOGGING of loss and accuracy
        train_loss = train_loss.data.cpu().numpy()
        train_losses.append(train_loss)
        train_pos_dists.append(train_pos_dist)
        train_neg_dists.append(train_neg_dist)

        # Print loss every log_nth iteration
        if log_nth_iter != 0 and (i + 1) % log_nth_iter == 0:
            print("[Iteration {cur}/{max}] TRAIN loss: {loss}".format(cur=i + 1,
                                                                          max=iter_per_epoch,
                                                                          loss=train_loss))

    # ONE EPOCH PASSED --> calculate + log mean train accuracy/loss for this epoch
    mean_train_loss = np.mean(train_losses)
    mean_train_pos_dist = np.mean(train_pos_dists)
    mean_train_neg_dist = np.mean(train_neg_dists)

    if log_nth_epoch != 0 and (epoch + 1) % log_nth_epoch == 0:
        print("[EPOCH {cur}/{max}] TRAIN mean loss / pos_dist / neg_dist: {loss}, {pos_dist}, {neg_dist}".format(
            cur=epoch + 1,
            max=num_epochs,
            loss=mean_train_loss,
            pos_dist=mean_train_pos_dist,
            neg_dist=mean_train_neg_dist))

    return mean_train_loss, mean_train_pos_dist, mean_train_neg_dist

In [29]:
def val_one_epoch(model, val_loader, epoch, device, log_nth_iter, log_nth_epoch, num_epochs):
    # ONE EPOCH PASSED --> calculate + log validation accuracy/loss for this epoch
    model.eval()  # EVAL mode (for dropout, batchnorm, etc.)
    with torch.no_grad():
        val_losses = []
        val_pos_dists = []
        val_neg_dists = []

        val_minibatches = val_loader
        val_minibatches = tqdm(val_minibatches)

        for i, sample in enumerate(val_minibatches):
            # FORWARD PASS --> Loss + acc calculation
            val_loss, val_pos_dist, val_neg_dist = forward_pass(model, sample, device)

            # LOGGING of loss and accuracy
            val_loss = val_loss.data.cpu().numpy()
            val_losses.append(val_loss)
            val_pos_dists.append(val_pos_dist)
            val_neg_dists.append(val_neg_dist)

            # Print loss every log_nth iteration
            if log_nth_iter != 0 and (i + 1) % log_nth_iter == 0:
                    print("[Iteration {cur}/{max}] Val loss: {loss}".format(cur=i + 1,
                                                                            max=len(val_loader),
                                                                            loss=val_loss))

        mean_val_loss = np.mean(val_losses)
        mean_val_pos_dist = np.mean(val_pos_dists)
        mean_val_neg_dist = np.mean(val_neg_dists)

        if log_nth_epoch != 0 and (epoch + 1) % log_nth_epoch == 0:
            print("[EPOCH {cur}/{max}] VAL mean loss / pos_dist / neg_dist: {loss}, {pos_dist}, {neg_dist}".format(
                    cur=epoch + 1,
                    max=num_epochs,
                    loss=mean_val_loss,
                    pos_dist=mean_val_pos_dist,
                    neg_dist=mean_val_neg_dist))

        return mean_val_loss, mean_val_pos_dist, mean_val_neg_dist

In [30]:
def train(model, train_loader, val_loader, start_epoch=0, num_epochs=10, 
          log_nth_iter=1, log_nth_epoch=1):
        """
        Train a given model with the provided data.
        Inputs:
        - model: model object initialized from a torch.nn.Module
        - train_loader: train data in torch.utils.data.DataLoader
        - val_loader: val data in torch.utils.data.DataLoader
        - num_epochs: total number of training epochs
        - log_nth: log training accuracy and loss every nth iteration
        """

        # model to cuda before optim creation: https://pytorch.org/docs/stable/optim.html
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # init loss / acc history and train length
        iter_per_epoch = len(train_loader)

        # start training
        print('START TRAIN on device: {}'.format(device))

        max_epoch = start_epoch + num_epochs
        epochs = range(start_epoch, max_epoch)
        epochs = tqdm(range(start_epoch, max_epoch))

        # epoch loop
        for epoch in epochs:

            # train iterations for one epoch
            mean_train_loss, mean_train_pos_dist, mean_train_neg_dist = train_one_epoch(model,
                                                                                        train_loader,
                                                                                        optim,
                                                                                        epoch,
                                                                                        iter_per_epoch,
                                                                                        device,
                                                                                        log_nth_iter,
                                                                                        log_nth_epoch,
                                                                                        num_epochs)

            # val iterations for one epoch
            mean_val_loss, mean_val_pos_dist, mean_val_neg_dist = val_one_epoch(model,
                                                                                val_loader,
                                                                                epoch,
                                                                                device,
                                                                                log_nth_iter,
                                                                                log_nth_epoch,
                                                                                num_epochs)

            # Decay Learning Rate
            # Currently, only the ReduceLROnPlateau scheduler needs an argument (last val loss).
            # ALl others are a non-argument call to step() method.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(self.val_loss_history[-1])
            else:
                scheduler.step()

In [None]:
train(model, train_loader, val_loader, num_epochs=20, log_nth_iter=0, log_nth_epoch=1)

START TRAIN on device: cpu


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

[EPOCH 1/20] TRAIN mean loss / pos_dist / neg_dist: 0.09514176100492477, 0.0039747622795403, 0.010540000163018703


  0%|          | 0/6 [00:00<?, ?it/s]

[EPOCH 1/20] VAL mean loss / pos_dist / neg_dist: 0.08688118308782578, 0.016223536804318428, 0.03737553954124451


  0%|          | 0/24 [00:00<?, ?it/s]

[EPOCH 2/20] TRAIN mean loss / pos_dist / neg_dist: 0.06816696375608444, 0.04619154334068298, 0.1385985165834427


  0%|          | 0/6 [00:00<?, ?it/s]

[EPOCH 2/20] VAL mean loss / pos_dist / neg_dist: 0.08459172397851944, 0.05373213067650795, 0.14346247911453247


  0%|          | 0/24 [00:00<?, ?it/s]

[EPOCH 3/20] TRAIN mean loss / pos_dist / neg_dist: 0.05093863233923912, 0.050353217869997025, 0.20076493918895721


  0%|          | 0/6 [00:00<?, ?it/s]

[EPOCH 3/20] VAL mean loss / pos_dist / neg_dist: 0.08247614651918411, 0.06875953078269958, 0.1957845538854599


  0%|          | 0/24 [00:00<?, ?it/s]

[EPOCH 4/20] TRAIN mean loss / pos_dist / neg_dist: 0.03523515164852142, 0.06831726431846619, 0.31920644640922546


  0%|          | 0/6 [00:00<?, ?it/s]

## Saving the trained model

In [None]:
def save_model(modelname, model):
    from pathlib import Path
    path = "./saved_models"
    Path(path).mkdir(parents=True, exist_ok=True)
    # Might need to make sure, that the correct saved_results directory is chosen here.
    filepath = path + "/" + modelname + ".pt"
    torch.save(model.state_dict(), filepath)

In [None]:
# Create unique ID for this training process for saving to disk.
from datetime import datetime
import uuid
now = datetime.now() # current date and time
id_suffix = now.strftime("%Y-%b-%d_%H-%M-%S")

In [None]:
modelname = "model_" + id_suffix
save_model(modelname, model)
print(modelname)