In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (20, 20)

import pandas as pd
from IPython.core.display import display, HTML

import numpy as np

from tqdm import tqdm

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load data

In [None]:
# Load the palettes as a float32 tensor. `torch.as_tensor` with an explicit dtype
# replaces the legacy `torch.Tensor(ndarray)` constructor and yields the same
# float32 result.
palettes = torch.as_tensor(
    np.load("/storage/data/palette/all_palettes.npy"), dtype=torch.float32
)
image_ids = np.load("/storage/data/palette/all_image_ids.npy")

# Later cells use the same positional index for a palette row and its image id,
# so the two arrays must have the same length.
assert len(palettes) == len(image_ids), (
    "palettes and image_ids are misaligned: "
    "{} palettes vs {} image ids".format(len(palettes), len(image_ids))
)

# load model

In [None]:
class PaletteEmbedder(nn.Module):
    """Embed a batch of colour palettes into 30-dimensional vectors.

    A small per-colour MLP lifts the last axis from 3 values to 12 features.
    The result is flattened per batch element and fed to a deeper MLP whose
    first layer expects 60 features (e.g. 5 colours x 12 features each).
    """

    def __init__(self):
        super().__init__()
        # Per-colour lift applied along the last axis: 3 -> 6 -> 12.
        self.initial_transform = nn.Sequential(
            nn.Linear(3, 6),
            nn.ReLU(),
            nn.Linear(6, 12),
        )

        # Layer widths of the palette-level MLP, input to output.
        widths = (60, 128, 128, 256, 256, 128, 128, 30)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The final projection has no activation after it.
        stack.pop()
        self.embedder = nn.Sequential(*stack)

    def forward(self, input_palettes):
        """Return one embedding row per element of the leading (batch) axis."""
        per_colour = self.initial_transform(input_palettes)
        # Collapse everything after the batch axis into a single feature axis.
        per_palette = per_colour.reshape(input_palettes.shape[0], -1)
        return self.embedder(per_palette)


class SiameseNetwork(nn.Module):
    """Twin-branch wrapper that runs two palette batches through one shared embedder."""

    def __init__(self):
        super().__init__()
        # A single embedder instance, so both branches share the same weights.
        self.palette_embedder = PaletteEmbedder()

    def forward(self, palettes_1, palettes_2):
        """Return the pair (embeddings of palettes_1, embeddings of palettes_2)."""
        shared = self.palette_embedder
        return shared(palettes_1), shared(palettes_2)

In [None]:
# Rebuild the network and restore the trained weights onto the CPU.
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary
# objects, so only load checkpoints from a trusted location.
model = SiameseNetwork()
weights = torch.load("/storage/code/palette/model_state_dict.pt", map_location="cpu")
model.load_state_dict(weights)

# eval() switches the module to inference mode and returns the module itself;
# only the shared embedder branch is needed from here on.
embedder = model.eval().palette_embedder

# embed all palettes

In [None]:
# Inference only. Without no_grad() the forward pass would record an autograd
# graph and keep intermediate activations for the whole dataset in memory.
with torch.no_grad():
    embedded_palettes = embedder(palettes)

# show some samples

In [None]:
# Work on the first 10000 embeddings (as a NumPy array) and their matching image
# ids, so the pairwise distance matrix stays small.
sample = embedded_palettes.detach()[:10000].numpy()
id_sample = image_ids[:10000]

In [None]:
from scipy.spatial.distance import cdist

In [None]:
# All-pairs Euclidean distances within the sample; row i holds the distances
# from sample i to every sample, including itself.
distances = cdist(XA=sample, XB=sample, metric="euclidean")

In [None]:
# URL template for viewing an image; the `{}` placeholder is filled with an image id
# via `.format(...)` in the cells below.
# NOTE(review): looks like an IIIF image request for the full image at 960px wide — confirm.
image_url = "https://iiif.wellcomecollection.org/image/{}.jpg/full/960,/0/default.jpg"

In [None]:
# Pick a random query from the rows actually present in `sample`/`distances`.
# A hard-coded upper bound of 10000 could index past the end when the dataset
# holds fewer palettes than that.
test_ix = np.random.randint(len(sample))

# Show a clickable link to the query image.
query_link = "<a href='{}' target='_blank'>query image</a>".format(
    image_url.format(id_sample[test_ix])
)
display(HTML(query_link))

In [None]:
# Rank every sample by distance to the query and drop the query by index.
# Skipping rank 0 is not reliable: with tied zero distances (duplicate
# embeddings) the query is not guaranteed to sort first.
ranked = distances[test_ix].argsort()
nearest = ranked[ranked != test_ix][:5]

# Show links to the five closest images, nearest first.
for ix in nearest:
    neighbour_link = "<a href='{}' target='_blank'>image</a>".format(
        image_url.format(id_sample[ix])
    )
    display(HTML(neighbour_link))

# save the embeddings

In [None]:
# Persist every embedding as a NumPy array. np.save appends ".npy" to the file
# name because the given path has no extension.
embedding_array = embedded_palettes.detach().numpy()
np.save("/storage/data/palette/embedded_palettes", embedding_array)