In [None]:
import numpy as np

from tqdm import tqdm

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load data and split it into train/test sets

In [None]:
def _load_float_tensor(path):
    """Read a ``.npy`` file with NumPy and return its contents as a float32 tensor."""
    return torch.as_tensor(np.load(path), dtype=torch.float32)


palettes = _load_float_tensor("/storage/data/palette/lab_palettes.npy")
palette_permutations = _load_float_tensor(
    "/storage/data/palette/palette_permutations.npy"
)
distances = _load_float_tensor("/storage/data/palette/distance_matrix.npy")

In [None]:
# Seed the split so the train/test partition (and therefore the reported test loss)
# is identical across kernel restarts.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9
split_rng = np.random.default_rng(SPLIT_SEED)

all_indexes = np.arange(len(palette_permutations))
train_size = int(len(palette_permutations) * TRAIN_FRACTION)

train_indexes = split_rng.choice(all_indexes, size=train_size, replace=False)
# setdiff1d returns the held-out indexes sorted; train_indexes stays in random order.
test_indexes = np.setdiff1d(all_indexes, train_indexes)

In [None]:
# Select each split's palettes (all of their permutations) in the same index order
# used for the distance sub-matrices.
train_palettes, test_palettes = (
    palette_permutations[train_indexes],
    palette_permutations[test_indexes],
)

In [None]:
def _square_submatrix(matrix, indexes):
    """Return ``matrix[indexes][:, indexes]`` in a single advanced-indexing step.

    Broadcasting a column of row indexes against a row of column indexes selects the
    ``len(indexes) x len(indexes)`` block directly. The two-step form first copies the
    ``len(indexes) x matrix.shape[1]`` slab of selected rows.
    """
    index_tensor = torch.as_tensor(indexes, dtype=torch.long)
    return matrix[index_tensor[:, None], index_tensor]


# Keep only distances between palettes of the same split, in the same order as the
# corresponding palette tensors.
train_distances = _square_submatrix(distances, train_indexes)
test_distances = _square_submatrix(distances, test_indexes)

# dataset and dataloader

In [None]:
class PaletteDistanceDataset(Dataset):
    """Randomly sampled pairs of palettes together with their target distance.

    Every item is drawn independently at random:
    - two indexes along dim 0 of ``palette_permutations`` (the two palettes);
    - independently, one index along dim 1 for each palette (presumably a particular
      colour ordering of that palette — TODO confirm against the data).
    The index passed to ``__getitem__`` is ignored, so ``length`` only decides how many
    samples count as one pass over the dataset.

    Args:
        palette_permutations: tensor indexed as ``[palette, permutation, ...]``.
        distances: tensor where ``distances[i, j]`` is the target distance between
            palettes ``i`` and ``j`` (same palette order as dim 0 above).
        length: value reported by ``__len__``.
    """

    def __init__(self, palette_permutations, distances, length):
        self.palette_permutations = palette_permutations
        self.dim_1 = palette_permutations.shape[0]
        self.dim_2 = palette_permutations.shape[1]
        self.distances = distances
        self.length = length

    def __getitem__(self, ix):
        """Return ``(palette_1, palette_2, target_distance)``; ``ix`` is ignored."""
        # Draw from torch's RNG: DataLoader re-seeds torch in every worker process.
        # Older PyTorch releases did not re-seed NumPy's global RNG per worker, so with
        # num_workers > 0 each worker could emit the same "random" pairs.
        ix_1, ix_2 = torch.randint(self.dim_1, (2,)).tolist()
        sub_ix_1, sub_ix_2 = torch.randint(self.dim_2, (2,)).tolist()

        palette_1 = self.palette_permutations[ix_1, sub_ix_1]
        palette_2 = self.palette_permutations[ix_2, sub_ix_2]
        # The target depends only on which palettes were drawn, not on the permutation.
        target_distance = self.distances[ix_1, ix_2]
        return palette_1, palette_2, target_distance

    def __len__(self):
        return self.length

In [None]:
# Pairs are sampled at random, so these lengths only define how many samples make up
# one epoch of each loader.
TRAIN_SAMPLES_PER_EPOCH = 100_000_000
TEST_SAMPLES_PER_EPOCH = 1_000_000

train_dataset = PaletteDistanceDataset(
    train_palettes, train_distances, length=TRAIN_SAMPLES_PER_EPOCH
)
test_dataset = PaletteDistanceDataset(
    test_palettes, test_distances, length=TEST_SAMPLES_PER_EPOCH
)

In [None]:
# Draw one random pair as a sanity check (the dataset ignores the index it is given).
palette_1, palette_2, target_distance = train_dataset[0]
palette_1, palette_2, target_distance

In [None]:
# The datasets sample pairs at random and ignore the requested index, so shuffling adds
# no randomness. With shuffle=True, RandomSampler also has to build a full random
# permutation of every index each epoch (100M for the training set). A sequential
# sampler avoids that cost.
# Pinned host memory is what lets the `non_blocking=True` host->GPU copies in the
# training loop run asynchronously.
use_pinned_memory = device.type == "cuda"

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4096,
    num_workers=5,
    shuffle=False,
    pin_memory=use_pinned_memory,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=4096,
    shuffle=False,
    pin_memory=use_pinned_memory,
)

# Define the network

In [None]:
class PaletteEmbedder(nn.Module):
    """Map a palette of colours to a fixed-size embedding vector.

    ``first_transform`` acts on the last dimension (3 values per colour). It maps every
    colour to 12 features with the same small MLP. Those per-colour features are then
    flattened per sample, and ``second_transform`` maps them to the embedding.

    Args:
        n_colours: colours per palette. The default of 5 reproduces the original
            hard-coded input width of 60 (= 5 colours * 12 features).
        embedding_dim: size of the output embedding (default 30, as before).

    Layer order and attribute names are unchanged, so existing state dicts still load.
    """

    def __init__(self, n_colours=5, embedding_dim=30):
        super().__init__()
        colour_features = 12  # output width of first_transform per colour
        self.first_transform = nn.Sequential(
            nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, colour_features)
        )
        self.second_transform = nn.Sequential(
            nn.Linear(n_colours * colour_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
        )

    def forward(self, input_palettes):
        """Embed a batch of palettes.

        Args:
            input_palettes: assumed shape ``(batch, n_colours, 3)`` — TODO confirm
                against the stored palette arrays.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        per_colour_features = self.first_transform(input_palettes)
        # flatten(start_dim=1) keeps the batch dimension and, unlike
        # reshape(batch_size, -1), also works for an empty batch.
        flattened = per_colour_features.flatten(start_dim=1)
        return self.second_transform(flattened)


class SiameseNetwork(nn.Module):
    """Twin network that embeds two batches of palettes with one shared embedder.

    Both inputs go through the same ``PaletteEmbedder`` instance (same weights), so
    both sets of embeddings live in the same space.
    """

    def __init__(self):
        super().__init__()
        self.palette_embedder = PaletteEmbedder()

    def forward(self, palettes_1, palettes_2):
        """Return ``(embeddings_1, embeddings_2)`` for the two input batches."""
        embed = self.palette_embedder
        return embed(palettes_1), embed(palettes_2)

In [None]:
# Place the model on `device` before an optimiser is constructed from its parameters,
# as the torch.optim documentation recommends.
model = SiameseNetwork().to(device)
model

# train

In [None]:
# Per-batch RMSE histories; `train` appends to these.
train_losses = []
test_losses = []

# cuDNN autotuning. The model defined above is built only from Linear/ReLU layers, so
# this likely has little effect.
torch.backends.cudnn.benchmark = True

# Hand only parameters that require gradients to the optimiser (a one-shot generator).
trainable_parameters = (param for param in model.parameters() if param.requires_grad)
optimiser = optim.Adam(trainable_parameters, lr=1e-3)

distance_metric = nn.PairwiseDistance()  # p-norm distance, p=2 by default
loss_function = nn.MSELoss()

In [None]:
# Models start on the CPU, so only a CUDA device needs an explicit move.
model = model.cuda() if device.type == "cuda" else model

In [None]:
def _batch_loss(model, batch, distance_metric, loss_function):
    """Move one ``(palettes_1, palettes_2, target_distances)`` batch to ``device`` and
    return the loss between predicted embedding distances and target distances."""
    # .to() is a no-op for tensors already on `device` (e.g. when running on the CPU).
    palettes_1, palettes_2, target_distances = (
        tensor.to(device, non_blocking=True) for tensor in batch
    )
    embeddings_1, embeddings_2 = model(palettes_1, palettes_2)
    pred_distances = distance_metric(embeddings_1, embeddings_2)
    # Loss modules take (input, target); for MSE the order does not change the value.
    return loss_function(pred_distances, target_distances)


def train(model, train_loader, distance_metric, loss_function, optimiser, n_epochs):
    """Train ``model`` for ``n_epochs``, evaluating on the global ``test_loader`` after
    each epoch.

    The square root of each batch's loss (RMSE when ``loss_function`` is MSE) is appended
    to the global ``train_losses`` / ``test_losses`` lists. Progress bars show the mean of
    the last 100 entries.

    Args:
        model: network returning a pair of embeddings for a pair of palette batches.
        train_loader: yields ``(palettes_1, palettes_2, target_distances)`` batches.
        distance_metric: maps two embedding batches to per-pair distances.
        loss_function: compares predicted distances with target distances.
        optimiser: optimiser over ``model``'s trainable parameters.
        n_epochs: number of train + evaluation passes.
    """
    for epoch in range(n_epochs):
        model.train()
        train_loop = tqdm(train_loader, desc="Epoch {}/{}".format(epoch + 1, n_epochs))
        for batch in train_loop:
            optimiser.zero_grad()
            loss = _batch_loss(model, batch, distance_metric, loss_function)
            loss.backward()
            optimiser.step()

            train_losses.append(np.sqrt(loss.item()))
            train_loop.set_postfix({"loss": np.mean(train_losses[-100:])})

        model.eval()
        # Evaluation needs no gradients. Without no_grad every test batch would build an
        # autograd graph and keep its activations alive for nothing.
        with torch.no_grad():
            test_loop = tqdm(test_loader, desc="Test")
            for batch in test_loop:
                loss = _batch_loss(model, batch, distance_metric, loss_function)
                test_losses.append(np.sqrt(loss.item()))
                test_loop.set_postfix({"loss": np.mean(test_losses[-100:])})

In [None]:
# Three training passes, each followed by an evaluation pass over the test loader.
train(model, train_loader, distance_metric, loss_function, optimiser, n_epochs=3)

In [None]:
# Persist only the learned weights (the state dict), not the whole pickled module.
model_weights = model.state_dict()
torch.save(model_weights, "/storage/code/palette/model_state_dict.pt")

# Plot training loss

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style(style="whitegrid")
plt.rcParams.update({"figure.figsize": (20, 20)})

import pandas as pd

In [None]:
# Smooth the noisy per-batch training RMSE with a rolling mean before plotting.
ROLLING_WINDOW = 100
ax = pd.Series(train_losses).rolling(ROLLING_WINDOW).mean().plot()
ax.set_title(
    "Training loss (rolling mean of per-batch RMSE, window={})".format(ROLLING_WINDOW)
)
ax.set_xlabel("Training step (batch)")
ax.set_ylabel("RMSE")
# Upper limit presumably clips the early high-loss batches; adjust if the curve is cut off.
ax.set_ylim(0, 60);