In [None]:
from tqdm import tqdm
import os
from io import BytesIO
import ast
import numpy as np
import pickle

import torch
import boto3
from scipy.spatial.distance import cdist
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from itertools import combinations
import umap.umap_ as umap

### 1. Load the feature vectors from S3

In [None]:
# https://alexwlchan.net/2017/07/listing-s3-keys/
def get_all_s3_keys(bucket, s3_client=None, prefix=""):
    """Get a list of all keys in an S3 bucket.

    Follows ``list_objects_v2`` pagination via ``NextContinuationToken``
    until every page has been read.

    Args:
        bucket: Name of the S3 bucket to list.
        s3_client: boto3 S3 client to use. Defaults to the notebook-level
            ``s3`` client so existing calls keep working.
        prefix: Only return keys starting with this prefix ("" = all keys).

    Returns:
        List of object keys, in the order S3 returned them.
    """
    # Only touch the global `s3` when no client was passed in.
    client = s3 if s3_client is None else s3_client

    kwargs = {"Bucket": bucket}
    if prefix:
        kwargs["Prefix"] = prefix

    keys = []
    while True:
        resp = client.list_objects_v2(**kwargs)
        # "Contents" is absent (not an empty list) when a page has no
        # objects, e.g. an empty bucket or a prefix with no matches.
        keys.extend(obj["Key"] for obj in resp.get("Contents", []))

        token = resp.get("NextContinuationToken")
        if token is None:
            break
        kwargs["ContinuationToken"] = token

    return keys

In [None]:
# Create the client first: get_all_s3_keys lists objects through the
# notebook-level `s3` client.
s3 = boto3.client("s3")
bucket_name = "miro-images-feature-vectors"
keys = get_all_s3_keys(bucket_name)

In [None]:
# Number of objects in the whole bucket, before filtering to a single folder
len(keys)

In [None]:
folder_name = "reduced_feature_vectors_20_dims"
# Keep only objects under `folder_name`. Keys ending in "/" are folder
# placeholder objects: their body is empty, so it cannot be a feature vector,
# and their basename is "", which cannot be looked up in the image pickle.
keys = [
    k
    for k in keys
    if k.split("/")[0] == folder_name and not k.endswith("/")
]
# Fail here with a clear message rather than deep inside cdist later on.
assert keys, f"No feature vectors found under {folder_name!r} in {bucket_name!r}"

In [None]:
# Number of feature-vector objects left after filtering to `folder_name`
len(keys)

In [None]:
# Download each feature vector and decode its raw bytes into a numpy array.
# Assumes the vectors were serialised as raw float64 bytes, which is what the
# original `np.float` dtype meant (`np.float` was an alias of float64 and was
# removed in NumPy 1.24, where it raises AttributeError).
feature_vectors = {}
for key in tqdm(keys):
    obj = s3.get_object(Bucket=bucket_name, Key=key)
    read_obj = obj["Body"].read()

    feature_vectors[key] = np.frombuffer(read_obj, dtype=np.float64)

### 2. Get the distances between feature vectors

In [None]:
# The vectors as a list in `feature_vectors` insertion order, so that row i of
# any matrix built from it corresponds to the i-th entry of `feature_vectors`.
feature_vectors_list = [vector for vector in feature_vectors.values()]

In [None]:
# Sanity check before computing distances: map each key whose vector contains
# NaNs to its NaN count (an empty dict means no NaNs). A NaN anywhere in a
# vector turns every cosine distance involving that vector into NaN.
nan_counts = {key: int(np.isnan(vec).sum()) for key, vec in feature_vectors.items()}
{key: count for key, count in nan_counts.items() if count}

In [None]:
# Cosine distance between every pair of vectors: entry (i, j) is the distance
# between feature_vectors_list[i] and feature_vectors_list[j].
dist_mat = cdist(feature_vectors_list, feature_vectors_list, metric="cosine")

In [None]:
# Inspect the raw pairwise distance matrix
dist_mat

In [None]:
# Start from an all-NaN matrix with the same shape/dtype as dist_mat: NaN marks
# "no edge"; only the nearest-neighbour distances are filled in afterwards.
dist_mat_top = np.full_like(dist_mat, np.nan)

In [None]:
n = 3  # number of nearest neighbours to keep per image

# Keep the distances to each image's top n nearest neighbours. Both (i, j) and
# (j, i) are filled so dist_mat_top stays symmetric for the undirected graph.
# Iterate over the rows of the matrix being indexed (gives tqdm a total too).
for i in tqdm(range(len(dist_mat))):
    order = dist_mat[i].argsort()
    # Drop the image itself by index rather than assuming it sorts first:
    # exact duplicate images also have distance ~0.
    neighbours = order[order != i][:n]
    dist_mat_top[i, neighbours] = dist_mat[i, neighbours]
    dist_mat_top[neighbours, i] = dist_mat[neighbours, i]

### 3a. Load images for plotting

In [None]:
# NOTE(review): pickle.load can execute arbitrary code - only load this file
# if it comes from a trusted source.
with open("../data/processed_images_sample.pkl", "rb") as handle:
    images_original = pickle.load(handle)

# Reorder the images to match feature_vectors: the basename of each S3 key is
# the lookup key into images_original.
images = [images_original[os.path.basename(key)] for key in feature_vectors]

In [None]:
# Peek at the first image to check the load and reordering worked
images[0]

### 3b. Plot the network of images connected to their closest neighbours

In [None]:
def inv_rel_norm(value, min_val, max_val, eps=1e-8):
    """Min-max normalise a scalar ``value`` and return its reciprocal.

    Used to turn distances into edge weights: the smallest distance gets the
    largest weight (``1 / eps``) and the largest distance gets weight ~1.

    Args:
        value: Distance to convert.
        min_val: Smallest distance in the data.
        max_val: Largest distance in the data.
        eps: Small constant added before inverting, so the minimum distance
            (normalised to 0) does not divide by zero.

    Returns:
        ``1 / ((value - min_val) / (max_val - min_val) + eps)``. When
        ``max_val == min_val`` every value is treated as normalised to 0
        instead of producing NaN (or ZeroDivisionError for ints).
    """
    span = max_val - min_val
    if span == 0:
        # All distances are equal, so min-max scaling is 0/0; treat every
        # value as the minimum.
        normalised = 0.0
    else:
        normalised = (value - min_val) / span
    return 1 / (normalised + eps)

In [None]:
def create_graph(dist_mat_top):
    """Build an undirected graph of images from a sparse distance matrix.

    Args:
        dist_mat_top: Square 2-D numpy array of distances where NaN means
            "no edge" (e.g. only nearest-neighbour distances filled in).

    Returns:
        networkx.Graph with one node per column index, and an edge (i, j) for
        every non-NaN entry above the diagonal. The edge ``weight`` is the
        inverted, min-max normalised distance, so higher weight = more
        similar images.
    """
    # Range of the kept distances, used to normalise the edge weights.
    min_val = np.nanmin(dist_mat_top)
    max_val = np.nanmax(dist_mat_top)

    nodes = list(range(0, len(dist_mat_top[0])))

    G = nx.Graph()
    G.add_nodes_from(nodes)

    # Candidate edges are the non-NaN entries strictly above the diagonal.
    # np.nonzero yields them in the same (start, end) order as
    # combinations(nodes, 2), without looping over all n*(n-1)/2 pairs in Python.
    has_edge = np.triu(~np.isnan(dist_mat_top), k=1)
    for start, end in zip(*np.nonzero(has_edge)):
        # Since in the plot a higher weight makes the nodes closer,
        # but a higher value in the distance matrix means the images are further away,
        # we need to inverse the weight (so higher = closer)
        G.add_edge(
            int(start),
            int(end),
            weight=inv_rel_norm(dist_mat_top[start, end], min_val, max_val),
        )
    return G

In [None]:
def plot_graph(G, image_names=None, seed=None):
    """Draw ``G`` with a spring layout, optionally labelling each node.

    Args:
        G: networkx graph. Nodes are assumed to be the integer positions
            0..N-1 of ``image_names`` when labels are requested.
        image_names: Optional sequence of strings (list or numpy array);
            node k is labelled "k <image_names[k]>".
        seed: Optional seed for ``nx.spring_layout`` to make the layout
            reproducible. None gives a different random layout on each call.
    """
    pos = nx.spring_layout(G, seed=seed)

    # A new figure on every call: a fixed figure number would make repeated
    # calls draw on top of the previous graph in non-inline backends.
    plt.figure(figsize=(10, 10))
    nx.draw(G, pos, node_size=10)
    # `is not None` rather than truthiness, which is ambiguous for numpy arrays.
    if image_names is not None:
        # Raise the labels slightly above their nodes without modifying `pos`.
        label_pos = {node: (x, y + 0.06) for node, (x, y) in pos.items()}
        image_names_dict = {k: str(k) + " " + v for k, v in enumerate(image_names)}
        nx.draw_networkx_labels(G, label_pos, labels=image_names_dict)
    plt.show()

In [None]:
# Graph over all images, with edges only where dist_mat_top has a
# nearest-neighbour distance (non-NaN)
G = create_graph(dist_mat_top)

In [None]:
# Spring-layout plot of the neighbour graph, without node labels
plot_graph(G)

### 4. Visualise the clusters by reducing dimensions

In [None]:
# Project the feature vectors to a low-dimensional embedding with UMAP.
# random_state pins UMAP's random initialisation, so the embedding (and the
# scatter plots below) are reproducible between runs. The value 42 is
# arbitrary. Setting a seed makes UMAP give up some parallelism.
reducer = umap.UMAP(random_state=42)
embedding_fv = reducer.fit_transform(feature_vectors_list)
embedding_fv.shape

In [None]:
# from https://www.kaggle.com/gaborvecsei/plants-t-sne
def visualize_scatter_with_images(X_2d_data, images, figsize=(45, 45), image_zoom=1):
    fig, ax = plt.subplots(figsize=figsize)
    artists = []
    for xy, i in zip(X_2d_data, images):
        x0, y0 = xy
        img = OffsetImage(i, zoom=image_zoom)
        ab = AnnotationBbox(img, (x0, y0), xycoords="data", frameon=False)
        artists.append(ax.add_artist(ab))
    ax.update_datalim(X_2d_data)
    ax.autoscale()
    plt.axis("off")
    plt.show()

In [None]:
# First two UMAP coordinates of each image as an [x, y] list, one per image
x_data = [list(point) for point in embedding_fv[:, :2]]

In [None]:
# Thumbnails of every image placed at its UMAP coordinates
visualize_scatter_with_images(x_data, images=images, image_zoom=0.3)

### 5a. Pick 2 images and look at the route between them

In [None]:
# Image index -> image; its keys (0..N-1) are the candidate nodes sampled below
image_names_dict = dict(enumerate(images))

In [None]:
# Pick two random images and show the route between them.
node1 = np.random.choice(list(image_names_dict))
# node2 is drawn from node1's connected component: the nearest-neighbour graph
# can be disconnected, and dijkstra_path raises NetworkXNoPath between components.
node2 = np.random.choice(sorted(nx.node_connected_component(G, node1)))

# weight=None counts every edge as 1, i.e. this is the fewest-hops route
node_path = nx.dijkstra_path(G, node1, node2, weight=None)
print(node_path)

show_images = [images[i] for i in node_path]

# One row with one column per image (matplotlib needs integer grid sizes).
fig, axes = plt.subplots(1, len(show_images), figsize=(20, 10), squeeze=False)
for ax, image in zip(axes[0], show_images):
    ax.set_axis_off()
    ax.imshow(image)

### 5b. Set the number of images in a pathway (note: `path_size` is not yet used by the cells below)

In [None]:
# NOTE(review): path_size is not referenced by any other cell shown here. The
# routes in 5a/5c are shortest paths whose length is set by the graph.
# TODO: wire this up or remove it.
path_size = 10

### 5c. Pick 3 images and look at the paths between them

In [None]:
# Pick three random images and chain the routes node1 -> node2 -> node3 -> node1.
node1 = np.random.choice(list(image_names_dict))
# node2 and node3 come from node1's connected component: the nearest-neighbour
# graph can be disconnected, and dijkstra_path raises NetworkXNoPath between
# components.
same_component = sorted(nx.node_connected_component(G, node1))
node2 = np.random.choice(same_component)
node3 = np.random.choice(same_component)

node_path_a = nx.dijkstra_path(G, node1, node2, weight=None)
node_path_b = nx.dijkstra_path(G, node2, node3, weight=None)
node_path_c = nx.dijkstra_path(G, node3, node1, weight=None)
# Drop the last node of legs a and b: it is the first node of the next leg.
node_path_3 = node_path_a[:-1] + node_path_b[:-1] + node_path_c
print(node_path_3)

show_images = [images[i] for i in node_path_3]
chosen_nodes = {node1, node2, node3}

# One row with one column per image (matplotlib needs integer grid sizes).
fig, axes = plt.subplots(1, len(show_images), figsize=(20, 10), squeeze=False)
for ax, node, image in zip(axes[0], node_path_3, show_images):
    # Mark the three randomly chosen endpoints
    if node in chosen_nodes:
        ax.set(title="NODE")
    ax.set_axis_off()
    ax.imshow(image)

### 6. Plot the route from 5a on the UMAP embedding

In [None]:
# from https://www.kaggle.com/gaborvecsei/plants-t-sne
def visualize_scatter_pathway_with_images(
    X_2d_data, pathway, images, figsize=(45, 45), image_zoom=1
):
    """Plot image thumbnails at 2-D coordinates and highlight a route.

    Images on ``pathway`` are drawn twice as large and more opaque; the route
    is traced as a red line through their coordinates in path order.

    Args:
        X_2d_data: Sequence of (x, y) points, one per image.
        pathway: Sequence of image indices (into ``X_2d_data``/``images``)
            forming the route to highlight.
        images: Images to draw, in the same order as ``X_2d_data``.
        figsize: Matplotlib figure size in inches.
        image_zoom: Thumbnail scale for off-route images; route images use
            twice this.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Trace the route from the arguments rather than notebook globals, so any
    # embedding/path passed in is the one drawn.
    x_path = [X_2d_data[c][0] for c in pathway]
    y_path = [X_2d_data[c][1] for c in pathway]

    on_path = set(pathway)  # constant-time membership checks in the loop
    for num, (xy, i) in enumerate(zip(X_2d_data, images)):
        x0, y0 = xy
        if num in on_path:
            img = OffsetImage(i, zoom=image_zoom * 2, alpha=0.8)
        else:
            img = OffsetImage(i, zoom=image_zoom, alpha=0.2)
        ab = AnnotationBbox(img, (x0, y0), xycoords="data", frameon=False)
        ax.add_artist(ab)
    # add_artist does not extend the data limits; set them from the points.
    ax.update_datalim(X_2d_data)
    ax.autoscale()
    ax.plot(x_path, y_path, "ro-", linewidth=5)
    plt.axis("off")

    plt.show()

In [None]:
# Highlight the 5a route (node_path) among all thumbnails at their UMAP
# coordinates (x_data)
visualize_scatter_pathway_with_images(
    x_data, node_path, images=images, figsize=(30, 30), image_zoom=0.3
)