In [None]:
import copy
import json
import os
import pickle
from io import BytesIO

import requests
from PIL import Image
from tqdm import tqdm
import boto3
import numpy as np
import torch
from torchvision.models import vgg
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib import gridspec

In [None]:
# Load the pretrained VGG16 and put it straight into inference mode
# (eval() disables dropout); eval() returns the module itself.
vgg16 = vgg.vgg16(pretrained=True).eval()

In [None]:
LABELS_URL = "https://s3.amazonaws.com/outcome-blog/imagenet/labels.json"

# Fetch the class-index -> label mapping used to decode the model's output.
# A timeout prevents the cell from hanging forever, and raise_for_status()
# fails fast instead of trying to parse an HTTP error page as JSON.
response = requests.get(LABELS_URL, timeout=30)
response.raise_for_status()
# JSON object keys are strings; convert them to ints so the dict can be
# indexed directly with the model's argmax class index.
labels = {int(key): value for key, value in response.json().items()}

## Higher-resolution images straight from S3

In [None]:
# S3 bucket that every image below is read from.
bucket_name = "wellcomecollection-miro-images-public"

In [None]:
# IAM role assumed to obtain the temporary credentials used to read the image
# bucket. The account-specific ARN can be overridden through the
# MIRO_READ_ROLE_ARN environment variable without editing the notebook; the
# default is the role this notebook was written against.
MIRO_READ_ROLE_ARN = os.environ.get(
    "MIRO_READ_ROLE_ARN",
    "arn:aws:iam::760097843905:role/calm-assumable_read_role",
)

sts = boto3.client("sts")
assumed_role_object = sts.assume_role(
    RoleArn=MIRO_READ_ROLE_ARN,
    RoleSessionName="AssumeRoleSession1",
)
credentials = assumed_role_object["Credentials"]

In [None]:
# S3 resource authenticated with the temporary credentials returned by the
# STS assume_role call.
s3_fetch = boto3.resource(
    "s3",
    **{
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    },
)

In [None]:
bucket = s3_fetch.Bucket(bucket_name)
# Listing with Delimiter="/" groups keys by their first path segment, so the
# top-level "folders" are returned under "CommonPrefixes".
s3_client = bucket.meta.client
bucket_info = s3_client.list_objects(Bucket=bucket.name, Delimiter="/")

In [None]:
# Get all top-level folder names. Default to [] so a listing without
# CommonPrefixes yields zero folders instead of a TypeError on None.
folder_names = [f["Prefix"] for f in bucket_info.get("CommonPrefixes", [])]
# A single list_objects call returns at most 1000 entries; warn rather than
# silently working with a partial folder list.
if bucket_info.get("IsTruncated"):
    print("Warning: bucket listing was truncated; some folders are missing.")
print("{} image folders".format(len(folder_names)))  # 219

# Get all file keys from all folders (the objects collection pages through
# results itself). Takes a minute or so.
print("Getting all file dir names for all images...")
file_dirs = [
    obj_summary.key
    for folder_name in tqdm(folder_names)
    for obj_summary in bucket.objects.filter(Prefix=folder_name)
]
print("{} image files".format(len(file_dirs)))  # 120589

### Get one image, or a high-res image path

In [None]:
image_name = "V0001893"  # ['V0001893', 'V0047369EL']
# First S3 key containing the image identifier.
file_dir = [f for f in file_dirs if image_name in f][0]

obj = s3_fetch.Object(bucket_name, file_dir)
im = Image.open(BytesIO(obj.get()["Body"].read()))
file_name = os.path.splitext(os.path.basename(file_dir))[0]

# Make sure the output folder exists, then save. The save happens before the
# bare `im` because Jupyter only renders a cell's last expression; with `im`
# in the middle of the cell the image was never displayed.
os.makedirs("../medium_blog_images", exist_ok=True)
im.save("../medium_blog_images/{}.png".format(file_name))
im

In [None]:
# Draw 11 distinct random image names (plus a fixed one) for a quick look.
# A seeded generator makes the draw reproducible on re-run, and
# replace=False prevents picking the same file twice (duplicates would
# collapse into one entry when the images are stored in a dict keyed by name).
sample_rng = np.random.default_rng(0)
image_names = [
    os.path.splitext(os.path.basename(f))[0]
    for f in sample_rng.choice(file_dirs, 11, replace=False)
] + ["A0000001"]

In [None]:
# Hand-picked image identifiers for the multi-row overview plot.
image_names = (
    "L0061160 L0038847 B0006893 V0010192EL V0025035 L0052856 "
    "V0050358 L0008713 V0007884EL M0012095 V0010104 A0000001"
).split()

In [None]:
# Download each selected image, force it to RGB, and shrink it in place to fit
# within 224x224 (aspect ratio preserved). Thumbnails are keyed by name; their
# sizes are recorded in the same order for layout.
images = {}
plot_images_sizes = []
for image_name in tqdm(image_names):
    matching_keys = [key for key in file_dirs if image_name in key]
    s3_key = matching_keys[0]
    raw_bytes = s3_fetch.Object(bucket_name, s3_key).get()["Body"].read()
    im = Image.open(BytesIO(raw_bytes))
    im = im if im.mode == "RGB" else im.convert("RGB")
    im.thumbnail((224, 224), resample=Image.BICUBIC)
    plot_images_sizes.append(im.size)
    images[image_name] = im

In [None]:
# Multi-row image

# Grid of thumbnails with `columns` images per row. Use integer ceiling
# division for the row count: matplotlib's subplot() rejects the float that
# np.ceil returns.
columns = 6
rows = -(-len(images) // columns)
fig = plt.figure(figsize=(20, 10))
for i, (image_name, im) in enumerate(images.items()):
    ax = fig.add_subplot(rows, columns, i + 1)
    ax.set_axis_off()
    ax.imshow(im)

In [None]:
def _load_thumbnail(image_name, size=(224, 224)):
    """Download one image from S3 and return it as an RGB thumbnail.

    The S3 key is found by substring match of ``image_name`` against the
    global ``file_dirs`` list, and the first match is used. The image is
    shrunk in place with ``Image.thumbnail`` to fit within ``size``, which
    keeps its aspect ratio.

    Raises:
        ValueError: if no key in ``file_dirs`` contains ``image_name``.
    """
    matches = [f for f in file_dirs if image_name in f]
    if not matches:
        raise ValueError("No S3 key in file_dirs contains {!r}".format(image_name))
    obj = s3_fetch.Object(bucket_name, matches[0])
    im = Image.open(BytesIO(obj.get()["Body"].read()))
    if im.mode != "RGB":
        im = im.convert("RGB")
    im.thumbnail(size, resample=Image.BICUBIC)
    return im


def print_path(image_names, columns=None):
    """Plot the named images side by side in a single row.

    Despite its name, nothing is printed. Each image is downloaded from S3,
    reduced to a 224 px thumbnail, and drawn in one matplotlib row. Column
    widths follow each thumbnail's aspect ratio, so the images are shown at
    roughly the same height.

    Args:
        image_names: Image identifiers. Each is matched as a substring
            against the global ``file_dirs`` S3 keys.
        columns: Number of grid columns in the row. Defaults to
            ``len(image_names)``. A larger value leaves the extra columns
            empty, so each image takes a smaller share of the figure width.
            A smaller value is raised to ``len(image_names)`` so that every
            image is drawn.
    """
    if not image_names:
        return
    if columns is None:
        columns = len(image_names)
    # A single-row grid needs at least one column per image.
    columns = max(columns, len(image_names))

    # Keep a list rather than a name-keyed dict so repeated names are not
    # silently dropped from the plot.
    thumbnails = [_load_thumbnail(name) for name in tqdm(image_names)]

    # Relative width of each image when all are scaled to the tallest height.
    max_y = max(im.size[1] for im in thumbnails)
    width_ratios = [im.size[0] * max_y / im.size[1] for im in thumbnails]
    # GridSpec needs one ratio per column. Pad unused columns with the mean
    # width so that spare slots take a typical image's share of the row.
    mean_width = sum(width_ratios) / len(width_ratios)
    width_ratios += [mean_width] * (columns - len(width_ratios))

    fig = plt.figure(figsize=(20, 30))
    gs = gridspec.GridSpec(1, columns, width_ratios=width_ratios)
    for i, im in enumerate(thumbnails):
        ax = fig.add_subplot(gs[i])
        ax.set_axis_off()
        ax.imshow(im)

In [None]:
image_names = (
    "B0008895 N0021591 B0007199 A0001358 V0007108 "
    "V0036001 V0037737 V0026902EL M0010374"
).split()
print_path(image_names)

In [None]:
# A pair of images to compare side by side.
image_names = "V0001893 V0047369EL".split()
print_path(image_names, columns=6)

In [None]:
# A pair of images to compare side by side.
image_names = "B0000663 V0014173".split()
print_path(image_names, columns=6)

In [None]:
# A pair of images to compare side by side.
image_names = "L0078444 L0078481".split()
print_path(image_names, columns=6)

In [None]:
# A pair of images to compare side by side.
image_names = "V0003760 V0006594".split()
print_path(image_names, columns=6)

In [None]:
# One-row image strip.
image_names = (
    "A0000001 A0000002 A0000003 A0001260 "
    "B0007248 B0004589 B0004848 B0006893"
).split()
print_path(image_names)

In [None]:
image_names = (
    "A0000785 B0001152 A0000318 V0007884EL "
    "L0027241 V0013859 V0013040 V0040933"
).split()
print_path(image_names)

In [None]:
image_names = (
    "V0044783 V0022904ER V0021741 V0021867 "
    "V0021857 V0023111 V0023117"
).split()
print_path(image_names)

In [None]:
image_names = (
    "V0044783 V0022904ER V0021741 A0000113 "
    "B0004207 V0043888 V0023376 V0046793"
).split()
print_path(image_names)

In [None]:
# V0003665 was considered for this set but is left out.
image_names = (
    "V0001893 V0026311 V0031656 V0007101ER L0027175 "
    "V0042795EL V0044410 V0042799EL V0047369EL"
).split()
print_path(image_names)

In [None]:
image_names = (
    "L0032287 M0012716 V0049671 M0006130 L0040595 "
    "L0056834 V0030245 V0029003 L0034782 A0000632"
).split()
print_path(image_names)

In [None]:
image_names = (
    "L0045857 L0045858 L0045856 L0045886 V0005269 "
    "V0032946ER V0033137EL V0035635ER V0035703 V0035629ER"
).split()
print_path(image_names)

## Output of VGG16

In [None]:
# Classification preprocessing. Resize(int) scales the shorter side to 224 px,
# the minimum size noted in the PyTorch pretrained-models docs. The image is
# then converted to a tensor and normalised per channel.
min_img_size = 224
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
transform_pipeline = transforms.Compose(
    [
        transforms.Resize(min_img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
)

In [None]:
# Fresh pretrained VGG16 in inference mode (eval() disables dropout and
# returns the module itself).
vgg16 = vgg.vgg16(pretrained=True).eval()

In [None]:
def predict_image(transform_pipeline, im, model, labels):
    """Classify a single image and return its predicted label.

    Args:
        transform_pipeline: Callable that turns ``im`` into an unbatched
            tensor, e.g. a torchvision ``transforms.Compose``.
        im: The input image (a PIL image in this notebook).
        model: Classifier that returns a tensor of scores of shape
            (batch, num class labels).
        labels: Mapping from integer class index to label.

    Returns:
        ``labels[i]``, where ``i`` is the index of the highest score.
    """
    img = transform_pipeline(im).unsqueeze(0)  # add a batch dimension of 1

    # Inference only: no_grad skips building the autograd graph (saves memory).
    with torch.no_grad():
        prediction = model(img)
    # Flattened argmax (same as the numpy version for a batch of one). int()
    # also works for tensors on a GPU, where .numpy() would fail.
    return labels[int(prediction.argmax())]

In [None]:
image_name = "B0008895"
file_dir = [f for f in file_dirs if image_name in f][0]

# Download the full-resolution image and make sure it is RGB before it goes
# through the normalisation pipeline.
obj = s3_fetch.Object(bucket_name, file_dir)
raw_bytes = obj.get()["Body"].read()
im = Image.open(BytesIO(raw_bytes))
im = im if im.mode == "RGB" else im.convert("RGB")

In [None]:
predicted_label = predict_image(transform_pipeline, im, vgg16, labels)
print(predicted_label)
# Show a small 200x200 preview (aspect ratio not preserved).
im.resize((200, 200), resample=Image.BILINEAR)

In [None]:
# Preprocess the image and prepend a batch dimension of size 1.
img = transform_pipeline(im).unsqueeze(0)
img

### Output the feature vector (FV)

In [None]:
# Feature-vector preprocessing. Unlike the classification pipeline,
# Resize((224, 224)) forces an exact 224x224 size (ignoring aspect ratio), so
# every image produces an input of identical shape. 224 px is the minimum
# size noted in the PyTorch pretrained-models docs.
min_img_size_fv = (224, 224)
transform_pipeline_fv = transforms.Compose(
    [
        transforms.Resize(min_img_size_fv),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Keep only the first four classifier modules, so the model outputs an
# intermediate feature vector instead of class scores. (Presumably this ends
# at the second fully connected layer; confirm against torchvision's VGG
# definition.) Work on a deep copy: `vgg16_short = vgg16` would only alias the
# model, and truncating the alias would also cut down `vgg16`. That would
# break predict_image(..., vgg16, ...) for the rest of the session.
vgg16_short = copy.deepcopy(vgg16)
vgg16_short.classifier = vgg16_short.classifier[:4]

In [None]:
image_name = "M0010374"
file_dir = [f for f in file_dirs if image_name in f][0]

obj = s3_fetch.Object(bucket_name, file_dir)
im = Image.open(BytesIO(obj.get()["Body"].read()))
if im.mode != "RGB":
    im = im.convert("RGB")

# Use the fixed-size feature-vector pipeline built alongside vgg16_short, so
# features from different images come from identically sized inputs.
img = transform_pipeline_fv(im)
img = img.unsqueeze(0)
# Inference only: no autograd graph needed for feature extraction.
with torch.no_grad():
    feature_vector = vgg16_short(img)
feature_vector