In [None]:
import copy
import os
import time
from io import BytesIO
from itertools import combinations

import boto3
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
import umap.umap_ as umap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from PIL import Image
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from torchvision.models import vgg
from tqdm import tqdm

In [None]:
# Load the pretrained VGG16 and put it straight into evaluation mode
# (eval() switches off dropout, so predictions are deterministic).
vgg16 = vgg.vgg16(pretrained=True).eval()

In [None]:
LABELS_URL = "https://s3.amazonaws.com/outcome-blog/imagenet/labels.json"

# Download the class-index -> label mapping used to decode the model's predictions.
# The timeout stops the cell hanging forever on a stalled connection, and
# raise_for_status() reports an HTTP error here instead of as a confusing JSON failure.
response = requests.get(LABELS_URL, timeout=30)
response.raise_for_status()
# JSON object keys are strings; convert them to ints so labels can be looked up
# directly by predicted class index.
labels = {int(key): value for key, value in response.json().items()}

## 1. Get images from S3

In [None]:
# Name of the S3 bucket the images are read from.
bucket_name = "wellcomecollection-miro-images-public"

In [None]:
# Assume the read role through STS to obtain temporary credentials.
# NOTE(review): the role ARN is hard-coded; consider moving it to configuration.
sts = boto3.client("sts")
assumed_role_object = sts.assume_role(
    RoleArn="arn:aws:iam::760097843905:role/calm-assumable_read_role",
    RoleSessionName="AssumeRoleSession1",
)
# Temporary credentials (access key id, secret access key, session token).
credentials = assumed_role_object["Credentials"]

In [None]:
# S3 resource authenticated with the temporary credentials of the assumed role.
s3 = boto3.resource(
    "s3",
    aws_access_key_id=credentials["AccessKeyId"],
    aws_secret_access_key=credentials["SecretAccessKey"],
    aws_session_token=credentials["SessionToken"],
)

In [None]:
# Handle on the image bucket.
bucket = s3.Bucket(bucket_name)
# List the bucket with "/" as delimiter; the "CommonPrefixes" entry of the response
# is read afterwards to obtain the top-level folder names.
bucket_info = bucket.meta.client.list_objects(Bucket=bucket.name, Delimiter="/")

In [None]:
# The top-level "folders" are the common prefixes returned by the bucket listing.
folder_names = [prefix["Prefix"] for prefix in bucket_info.get("CommonPrefixes")]
print("{} image folders".format(len(folder_names)))  # 219

# Walk every folder and collect the key of each object in it. Takes a minute or so.
print("Getting all file dir names for all images...")
file_dir = []
for folder_name in tqdm(folder_names):
    for s3_object in bucket.objects.filter(Prefix=folder_name):
        file_dir.append(s3_object.key)
print("{} image files".format(len(file_dir)))  # 120589

In [None]:
# Draw n distinct object keys at random; the fixed seed keeps the sample
# repeatable while developing.
n = 1000
np.random.seed(seed=0)  # Just for dev
random_file_dir = np.random.choice(file_dir, n, replace=False)

print("Storing {} random images...".format(n))
images = []
for key in tqdm(random_file_dir):
    # Download the object and decode it from memory.
    raw_bytes = s3.Object(bucket_name, key).get()["Body"].read()
    image = Image.open(BytesIO(raw_bytes))
    # Shrink in place so that neither side exceeds 750 px.
    image.thumbnail((750, 750))
    # Store every image as RGB, converting only when needed.
    images.append(image if image.mode == "RGB" else image.convert("RGB"))

## 2. Predict image (optional)

In [None]:
# The min size, as noted in the PyTorch pretrained models doc, is 224 px.
min_img_size = 224

# Preprocessing applied to every image before it is fed to the network:
# resize, convert to a tensor, then normalise each colour channel.
preprocessing_steps = [
    transforms.Resize(min_img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform_pipeline = transforms.Compose(preprocessing_steps)

In [None]:
# Might need to re run if you overwrite it with vgg16_short
# (reloads the full pretrained classifier; eval() gives no-dropout behaviour).
vgg16 = vgg.vgg16(pretrained=True).eval()

In [None]:
def predict_image(transform_pipeline, im, model, labels):
    """Classify a single image and return its label.

    Args:
        transform_pipeline: callable that turns ``im`` into a tensor without a
            batch dimension.
        im: the image to classify (whatever ``transform_pipeline`` accepts).
        model: callable mapping a batched tensor to a tensor of class scores.
        labels: mapping from integer class index to label.

    Returns:
        The entry of ``labels`` for the class with the largest score.
    """
    # Preprocess and add a leading batch dimension of size 1.
    img = transform_pipeline(im).unsqueeze(0)

    # Now let's get a prediction! This is inference only, so skip autograd
    # bookkeeping instead of building a graph that is never used.
    with torch.no_grad():
        scores = model(img)  # Tensor of shape (batch, num class labels)

    # Our prediction is the index of the class label with the largest value,
    # converted to a plain int so it is an ordinary key for ``labels``.
    prediction = int(scores.argmax())
    print(prediction)
    return labels[prediction]

In [None]:
# Classify one sample image and print the predicted label.
im = images[5]
print(predict_image(transform_pipeline, im, vgg16, labels))
# Last expression of the cell: a 200x200 resized copy shown as the cell output.
im.resize((200, 200), resample=Image.BILINEAR)

## 3. Extract feature vectors from images

In [None]:
# Same preprocessing as in section 2, defined again here (presumably so this
# section still runs when the optional prediction section is skipped).
# The min size, as noted in the PyTorch pretrained models doc, is 224 px.
min_img_size = 224

# Resize, convert to a tensor, then normalise each colour channel.
preprocessing_steps = [
    transforms.Resize(min_img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform_pipeline = transforms.Compose(preprocessing_steps)

In [None]:
# Build the feature extractor on an independent copy of the network. A plain
# `vgg16_short = vgg16` would only alias the model, so truncating the classifier
# would also truncate vgg16 and break predict_image until vgg16 is reloaded.
# (The copy duplicates the weights in memory; that is the price of keeping both.)
vgg16_short = copy.deepcopy(vgg16)
# Keep only the first four classifier layers, so that the output is a feature
# vector instead of class scores.
vgg16_short.classifier = vgg16_short.classifier[:4]

In [None]:
print("Getting feature vectors for {} images...".format(len(images)))
feature_vectors = []
# Inference only: without no_grad every forward pass would build an autograd
# graph that is never used, wasting memory and time.
with torch.no_grad():
    # Iterate over the list itself (not enumerate) so tqdm knows the total
    # and can show real progress and an ETA.
    for image in tqdm(images):
        img = transform_pipeline(image).unsqueeze(0)  # add batch dimension
        # Drop the batch dimension again and store the features as a plain list.
        feature_vectors.append(vgg16_short(img).squeeze().tolist())

## 4. Get the pairwise distance matrix for the images, and the closest neighbours

In [None]:
# Pairwise cosine distances between all feature vectors:
# dist_mat[i, j] is the distance between image i and image j.
dist_mat = cdist(feature_vectors, feature_vectors, metric="cosine")

In [None]:
# Matrix with the shape and dtype of dist_mat in which every entry starts as NaN;
# NaN marks "no link between these two images".
dist_mat_top = np.full_like(dist_mat, np.nan)

In [None]:
n = 3

# Find the top n neighbours for each image

# Iterate over a range (which has a length) so tqdm can show the total and an ETA;
# wrapping enumerate(...) hides the length from tqdm.
for i in tqdm(range(len(images))):
    # Indices of all images ordered by distance to image i, with i itself removed.
    order = dist_mat[i].argsort()
    nearest = order[order != i][:n]
    # Keep the distances from image i to its n nearest neighbours ...
    dist_mat_top[i, nearest] = dist_mat[i, nearest]
    # ... and mirror them into column i so each link is stored in both directions.
    dist_mat_top[nearest, i] = dist_mat[nearest, i]

## 5. Plot the network of images connected to their closest neighbours

In [None]:
def inv_rel_norm(value, min_val, max_val):
    """Min-max normalise ``value`` and return the reciprocal of the result.

    ``value`` is first rescaled so that ``min_val`` maps to 0 and ``max_val``
    maps to 1, then inverted, so smaller inputs give larger outputs. A small
    constant (1e-8) is added before inverting to avoid dividing by zero when
    ``value == min_val``.

    Args:
        value: the number to transform.
        min_val: lower end of the range used for normalisation.
        max_val: upper end of the range used for normalisation.

    Returns:
        ``1 / (normalised + 1e-8)``.
    """
    span = max_val - min_val
    if span == 0:
        # Degenerate range (all values identical): the normalisation would
        # divide by zero, so treat the value as sitting at the minimum.
        normalised = 0.0
    else:
        normalised = (value - min_val) / span
    return 1 / (normalised + 1e-8)

In [None]:
def create_graph(dist_mat_top):
    """Build an undirected weighted graph from a NaN-masked distance matrix.

    Args:
        dist_mat_top: 2-D numpy array of distances, indexed ``[start, end]``,
            where NaN means "no link". Only entries with ``start < end`` are
            read, so each link is expected to be present in the upper triangle.

    Returns:
        A ``networkx.Graph`` with one node per column of ``dist_mat_top`` and
        one edge per non-NaN entry, weighted with ``inv_rel_norm`` of the
        distance.
    """
    min_val = np.nanmin(dist_mat_top)
    max_val = np.nanmax(dist_mat_top)

    n_nodes = len(dist_mat_top[0])

    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))

    # Select the linked pairs with one vectorised mask instead of building a
    # Python list of every node pair and testing each entry for NaN.
    # triu_indices(k=1) yields the pairs with start < end in row-major order,
    # the same order itertools.combinations would produce.
    starts, ends = np.triu_indices(n_nodes, k=1)
    linked = ~np.isnan(dist_mat_top[starts, ends])

    # Put the weights in as the distances
    # only inc nodes if they are in the closest related neighbours
    for start, end in zip(starts[linked].tolist(), ends[linked].tolist()):
        # Since in the plot a higher weight makes the nodes closer,
        # but a higher value in the distance matrix means the images are further away,
        # we need to inverse the weight (so higher = closer)
        G.add_edge(
            start,
            end,
            weight=inv_rel_norm(dist_mat_top[start, end], min_val, max_val),
        )
    return G

In [None]:
def plot_graph(G, image_names=None):
    """Draw ``G`` with a spring layout, optionally labelling its nodes.

    Args:
        G: the networkx graph to draw.
        image_names: optional sequence of strings; element ``k`` labels node
            ``k`` as ``"<k> <name>"``. No labels are drawn when it is ``None``
            or empty.
    """
    pos = nx.spring_layout(G)

    plt.figure(3, figsize=(10, 10))
    nx.draw(G, pos, node_size=10)
    # Explicit None/length test: `if image_names:` raises ValueError for a numpy
    # array with more than one element (e.g. the array of sampled file names).
    if image_names is not None and len(image_names) > 0:
        for p in pos:  # raise text positions
            pos[p][1] += 0.06
        image_names_dict = {k: str(k) + " " + v for k, v in enumerate(image_names)}
        nx.draw_networkx_labels(G, pos, labels=image_names_dict)
    plt.show()

In [None]:
# Graph linking each image to its nearest neighbours.
G = create_graph(dist_mat_top)

In [None]:
# Draw the neighbour graph without node labels.
plot_graph(G)

## 6. Visualise the clusters by reducing dimensions

In [None]:
# UMAP is stochastic; fix the seed so the embedding is the same on every run
# instead of changing each time the cell is executed.
reducer = umap.UMAP(random_state=0)
embedding_fv = reducer.fit_transform(feature_vectors)
# Show the shape of the embedding as the cell output.
embedding_fv.shape

In [None]:
# from https://www.kaggle.com/gaborvecsei/plants-t-sne
def visualize_scatter_with_images(X_2d_data, images, figsize=(45, 45), image_zoom=1):
    """Scatter plot in which each point is drawn as its image.

    Args:
        X_2d_data: sequence of (x, y) pairs, one per image.
        images: images to draw, in the same order as ``X_2d_data``.
        figsize: size of the figure passed to ``plt.subplots``.
        image_zoom: zoom factor applied to every image.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for (x0, y0), picture in zip(X_2d_data, images):
        thumbnail = OffsetImage(picture, zoom=image_zoom)
        ax.add_artist(
            AnnotationBbox(thumbnail, (x0, y0), xycoords="data", frameon=False)
        )
    # Register the points with the axes' data limits and autoscale to them.
    ax.update_datalim(X_2d_data)
    ax.autoscale()
    plt.axis("off")
    plt.show()

In [None]:
# The first two embedding coordinates of every image, as a list of [x, y] pairs.
x_data = [[x, y] for x, y in embedding_fv[:, :2]]

In [None]:
# Plot every image at its embedded 2-D position, scaled down to 10% of its size.
visualize_scatter_with_images(x_data, images=images, image_zoom=0.1)

## Get a list of the biggest differences between 2 images

In [None]:
# Inspect one row of the neighbour matrix: distances to the images linked to
# image 262, NaN everywhere else.
# NOTE(review): 262 looks like an arbitrary example index — confirm.
dist_mat_top[262]

In [None]:
# For every column, the row index holding the largest non-NaN distance,
# i.e. the furthest of the images linked to that column's image.
np.nanargmax(dist_mat_top, axis=0)

## 7. Pick 2 images and look at the route between them

In [None]:
# Map each node index to the file key of the image it represents.
image_names_dict = dict(enumerate(random_file_dir))

In [None]:
# Pick the two end points at random from the node indices.
node1 = np.random.choice(list(image_names_dict))
node2 = np.random.choice(list(image_names_dict))

# nice path:
# node1 = 6
# node2 = 146

# weight=None ignores the edge weights, so this is the path with the fewest hops.
node_path = nx.dijkstra_path(G, node1, node2, weight=None)
print(node_path)

show_images = [images[i] for i in node_path]

# One row with one column per image on the path. The grid is created with integer
# dimensions up front: passing `len(...) / columns + 1` to plt.subplot hands it a
# float row count (rejected by current Matplotlib) and reserves an empty second row.
columns = len(show_images)
fig, axes = plt.subplots(1, columns, figsize=(20, 10), squeeze=False)
for ax, image in zip(axes[0], show_images):
    ax.set_axis_off()
    ax.imshow(image)