## Import Libraries

In [2]:
import glob
import os
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
from urllib.request import urlopen

import pandas as pd
import requests
from bs4 import BeautifulSoup
from docarray import Document
from docarray import DocumentArray
from PIL import Image
from tqdm import tqdm
from umap import UMAP

## Download the Image Data

In [8]:
def download_image(url, out_dir="../data/images", timeout=30):
    """Download one image to ``{out_dir}/{last URL segment}.jpg``.

    The file name is the text after the final ``/`` of ``url`` with ``.jpg``
    appended (the embedding step globs ``*.jpg``), so two URLs ending in the
    same segment overwrite each other.

    Args:
        url: Image URL to fetch.
        out_dir: Destination directory; created if it does not exist.
        timeout: Seconds to wait for the server before ``requests`` raises
            (without one, a stalled connection hangs the notebook forever).

    Returns:
        The path written, or ``None`` when the server did not answer 200.
    """
    image_name = url.split("/")[-1]
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{image_name}.jpg")
    with open(path, 'wb') as f:
        f.write(response.content)
    return path
            
def _page_url(base_url, page):
    """Return ``base_url`` with its ``page`` query parameter set to ``page``.

    All other query parameters (facet filters, ``per_page``) keep their order;
    ``page`` is appended when ``base_url`` has none.
    """
    parts = urlsplit(base_url)
    query = parse_qsl(parts.query, keep_blank_values=True)
    if any(key == "page" for key, _ in query):
        query = [(key, str(page) if key == "page" else value) for key, value in query]
    else:
        query.append(("page", str(page)))
    return urlunsplit(parts._replace(query=urlencode(query)))


def get_images(url = "https://collections.ushmm.org/search/?f%5Bf_images%5D%5B%5D=all_images&f%5Bf_images%5D%5B%5D=indiv_photographs&page=1&per_page=50", max_pages=1):
    """Scrape titles, catalogue links and thumbnails from a USHMM search listing.

    Fetches up to ``max_pages`` result pages of the search described by
    ``url`` (its ``page`` parameter is replaced for each page and each listing
    URL is printed), downloads every thumbnail with ``download_image`` and
    returns one row per item. Items without a link or thumbnail, and items
    whose download raised a network error, are skipped.

    Args:
        url: Search-results URL; its query string selects the collection facet.
        max_pages: Maximum number of result pages to scrape (default 1). Also
            capped by the page count shown in the site's pagination widget.

    Returns:
        list[tuple[str, str, str]]: ``(title, page_url, image_url)`` rows.
    """
    # The last <li> of the pagination widget holds the total page count.
    first = BeautifulSoup(requests.get(url, timeout=30).content, "html.parser")
    try:
        total_page_nums = int(first.find("ul", {"class": "pagination"}).find_all("li")[-1].text.strip())
    except (AttributeError, IndexError, ValueError):
        # No (numeric) pagination widget, e.g. a single page of results.
        total_page_nums = max_pages

    all_data = []
    for page in range(1, min(max_pages, total_page_nums) + 1):
        listing_url = _page_url(url, page)
        print(listing_url)
        soup = BeautifulSoup(requests.get(listing_url, timeout=30).content, "html.parser")
        for item in soup.find_all("div", {"class": "document"}):
            try:
                link = item.find("a")
                title = link.text.strip()
                page_url = "https://collections.ushmm.org" + link["href"]
                image_url = item.find("img")["src"]
            except (AttributeError, KeyError, TypeError):
                # find() returned None or the attribute is missing: an entry
                # without a link or thumbnail has nothing to embed.
                continue
            try:
                download_image(image_url)
            except requests.RequestException as exc:
                # Stay best-effort, but don't keep a row whose image never
                # reached disk (it would have no embedding to pair with).
                print(f"skipping {image_url}: {exc}")
                continue
            all_data.append((title, page_url, image_url))
    return all_data


image_data = get_images()

https://collections.ushmm.org/search/?f%5Bf_images%5D%5B%5D=all_images&f%5Bf_images%5D%5B%5D=indiv_photographs&page=1&per_page=50


In [36]:
# Number of (title, page_url, image_url) rows scraped.
len(image_data)

48

In [9]:
# Load every .jpg in the download folder as a Document. NOTE: the order is glob
# order, not the scrape order of `image_data`, and the glob also picks up any
# images left in the folder by earlier runs.
da = DocumentArray.from_files('../data/images/*.jpg')

In [10]:
# Inspect the first Document's fields.
da[0]

In [11]:
# Render the first image inline.
da[0].display()

<IPython.core.display.Image object>

In [12]:
import timm
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [13]:
# Pretrained MobileNetV3. num_classes=0 asks timm to drop the classification
# head, so the forward pass returns feature vectors (used below as embeddings).
model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=0)
# Preprocessing settings (input size, interpolation, mean/std) resolved for this model.
config = resolve_data_config({}, model=model)
config

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth" to C:\Users\wma22/.cache\torch\hub\checkpoints\mobilenetv3_large_100_ra-f55367f5.pth


{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.875}

In [14]:
# Spatial size the model was configured for (input_size is (C, H, W)).
config['input_size'][1:]

(224, 224)

In [15]:
# Try the load -> resize -> normalize chain on one Document. preproc below
# reloads the tensor from the uri, so this exploratory step is overwritten.
da[0].load_uri_to_image_tensor().set_image_tensor_shape(shape=(224, 224)).set_image_tensor_normalization()

In [16]:
# Channel-last (H, W, C) at this point.
da[0].tensor.shape

(224, 224, 3)

In [17]:
def preproc(d: Document, shape=(224, 224)):
    """Load, resize and normalize one image Document for the timm model.

    Args:
        d: Document whose ``uri`` points at an image file.
        shape: Size passed to ``set_image_tensor_shape``. Defaults to
            (224, 224); ``tuple(config['input_size'][1:])`` yields the size the
            loaded model was configured for.

    Returns:
        The same Document, with ``tensor`` moved to channel-first layout
        (C, H, W) as the PyTorch model expects.
    """
    return (
        d.load_uri_to_image_tensor()  # load the image file into d.tensor
        .set_image_tensor_shape(shape=shape)
        # NOTE(review): uses docarray's default mean/std; confirm they match
        # config['mean'] / config['std'] for this model.
        .set_image_tensor_normalization()  # normalize color
        .set_image_tensor_channel_axis(-1, 0)  # (H, W, C) -> (C, H, W)
    )

In [18]:
# Preprocess every Document with preproc (the Documents themselves are updated,
# as the shape check below shows).
da.apply(preproc)

In [19]:
# Now channel-first (C, H, W).
da[0].tensor.shape

(3, 224, 224)

In [20]:
# Run every tensor through the model; results are read back via da.embeddings.
da.embed(model)

In [21]:
# Stacked embeddings, one row per Document (a torch tensor here).
da.embeddings

tensor([[ 0.2896,  0.4548,  0.4600,  ..., -0.2828, -0.3103, -0.0508],
        [-0.2009,  0.3138, -0.3027,  ...,  0.5575,  1.0264, -0.3652],
        [ 0.4669, -0.2685, -0.3247,  ..., -0.3737,  1.7443,  0.2583],
        ...,
        [ 0.6113,  0.5803,  0.2144,  ...,  0.3690, -0.0207,  0.0458],
        [ 0.0876, -0.2452, -0.1533,  ..., -0.2297,  0.6221, -0.1041],
        [-0.2123, -0.3714, -0.2599,  ..., -0.3669,  0.5702, -0.3255]])

In [22]:
import umap

In [48]:
# Project the embeddings to 2-D for plotting. UMAP is stochastic: without a
# fixed random_state every run gives a different layout (and a different CSV
# below). Per the umap-learn docs, fixing the seed also disables parallelism.
umap_proj = UMAP(n_neighbors=15, min_dist=0.01, metric='correlation', random_state=42).fit_transform(da.embeddings)

In [49]:
# One 2-D point per Document.
umap_proj.shape

(48, 2)

In [50]:
# 2-D coordinates of the first Document (da[0]).
umap_proj[0]

array([-0.03310011,  2.1810837 ], dtype=float32)

In [51]:
# Keep each Document's 2-D UMAP position in its tags: row i of umap_proj was
# computed from da[i]'s embedding.
for idx, doc in enumerate(da):
    doc.tags['umap_proj_x'] = umap_proj[idx, 0]
    doc.tags['umap_proj_y'] = umap_proj[idx, 1]

In [53]:
# Tabulate the scraped (title, page_url, image_url) rows in scrape order.
df = pd.DataFrame.from_records(image_data, columns=["title", "page_url", "image_url"])
df.head(1)

Unnamed: 0,title,page_url,image_url
0,German police and auxiliaries in civilian clot...,https://collections.ushmm.org/search/catalog/p...,https://collections.ushmm.org/iiif-b/assets/th...


In [63]:
# Attach each row's UMAP coordinates by joining on the image file name.
# `da` (and therefore `umap_proj`) is in glob order, not the scrape order of
# `df`, and the image folder can hold extra or missing files, so positional
# assignment would pair coordinates with the wrong titles.
# download_image saves each image as f"{url.split('/')[-1]}.jpg" and
# DocumentArray.from_files keeps that path in `d.uri`.
x_by_file = {os.path.basename(d.uri): coord[0] for d, coord in zip(da, umap_proj)}
y_by_file = {os.path.basename(d.uri): coord[1] for d, coord in zip(da, umap_proj)}
image_files = df['image_url'].map(lambda u: f"{u.split('/')[-1]}.jpg")
df['x'] = image_files.map(x_by_file)  # NaN where no downloaded image matches
df['y'] = image_files.map(y_by_file)
n_missing = int(df['x'].isna().sum())
if n_missing:
    print(f"{n_missing} rows have no matching image in da; their x/y are NaN")

In [64]:
# Preview rows together with their x/y coordinates.
df.head(2)

Unnamed: 0,title,page_url,image_url,x,y
0,German police and auxiliaries in civilian clot...,https://collections.ushmm.org/search/catalog/p...,https://collections.ushmm.org/iiif-b/assets/th...,-0.0331,2.181084
1,A sign marking a mass grave in Bergen-Belsen.,https://collections.ushmm.org/search/catalog/p...,https://collections.ushmm.org/iiif-b/assets/th...,-0.039753,4.018694


In [65]:
# Persist titles, links and UMAP coordinates; bulk_text() reads this file below.
output_file = "../data/ushmm_test_coords.csv"
df.to_csv(output_file, index=False)

## Visualizing our Clusters

In [28]:
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import (Button, ColumnDataSource, DataTable, TableColumn, TextInput)
from bokeh.plotting import figure, show
from bokeh.models import DataTable, TableColumn, ColorBar, HTMLTemplateFormatter, Spinner, RangeSlider
from bokeh.io import output_notebook
from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler
import numpy as np

In [29]:
# Send Bokeh output (including the app below) to this notebook.
output_notebook()

In [66]:
def bulk_text(path, keywords=None):
    """Build a Bokeh app for browsing the 2-D projection stored in ``path``.

    The app shows a scatter plot of the ``x``/``y`` columns next to a table of
    titles, thumbnails and download links. Lasso- or box-selecting points
    fills the table with the selected rows (in random order); SAVE writes the
    selected rows to the CSV named in the "Filename:" box. Two spinners change
    the point size and the table row height in the browser.

    Args:
        path: CSV with at least ``x``, ``y``, ``title`` and ``image_url``
            columns, e.g. the coordinates file written above.
        keywords: Optional keywords used to fade non-matching points.
            NOTE(review): this branch reads a ``text`` column and calls
            ``determine_keyword``, neither of which is defined in this
            notebook; leave it ``None`` unless both are available.

    Returns:
        bokeh.application.Application: pass it to ``show``. The selection and
        SAVE callbacks are Python callbacks, so they only run while the app is
        being served.
    """
    df = pd.read_csv(path)
    df['alpha'] = 0.5
    if keywords:
        df['color'] = [determine_keyword(str(t), keywords) for t in df['text']]
        df['alpha'] = [0.4 if c == 'none' else 1 for c in df['color']]

    # Columns written by SAVE: 'text' when present (selecting it unconditionally
    # raised KeyError on the scraped CSV), otherwise the scraped item fields.
    if 'text' in df.columns:
        save_cols = ['text']
    else:
        save_cols = [c for c in ('title', 'page_url', 'image_url') if c in df.columns]

    def make_doc(doc):
        """Create one session's widgets and callbacks and add them to ``doc``."""
        # A Bokeh model can belong to only one Document, so every model is
        # built here, per session; models built once in bulk_text would make a
        # second show(app) or a page reload fail.
        highlighted_idx = []  # df row positions of the current selection

        columns = [
            TableColumn(field="title", title="title", width=500),
            TableColumn(field="image_url", title="image", formatter=HTMLTemplateFormatter(template='<img src="<%= image_url %>" width=60>')),
            TableColumn(field="image_url", title="download", formatter=HTMLTemplateFormatter(template=r'<a href="<%= image_url %>" target="_blank">Download Image</a>')),
        ]

        source = ColumnDataSource(data=df)       # table: all rows until a selection
        source_orig = ColumnDataSource(data=df)  # scatter: always all rows
        data_table = DataTable(source=source, columns=columns, width=700, height=700)

        p = figure(title="", sizing_mode="scale_both", tools=["lasso_select", "box_select", "pan", "box_zoom", "wheel_zoom", "reset"])
        p.toolbar.active_drag = None
        p.toolbar.active_inspect = None
        scatter = p.circle(x="x", y="y", size=1, source=source_orig, alpha="alpha")
        # width/height instead of plot_width/plot_height, which Bokeh 3 removed.
        p.width = 350 if "color" in df.columns else 500
        p.height = 700

        # Browser-side controls for point size and table row height.
        spinner = Spinner(title="Circle Size", low=1, high=60, step=1, value=scatter.glyph.size, width=200)
        spinner.js_link("value", scatter.glyph, "size")
        row_spinner = Spinner(title="Row Height", low=50, high=1000, step=10, value=data_table.row_height, width=200)
        row_spinner.js_link("value", data_table, "row_height")

        text_filename = TextInput(value="out.csv", title="Filename:")
        save_btn = Button(label="SAVE")

        def update(attr, old, new):
            """Show the selected rows (``new`` = row positions) in the table."""
            # nonlocal, not global: keeps the selection per session and lets
            # save() work before any selection was made.
            nonlocal highlighted_idx
            highlighted_idx = new
            subset = df.iloc[new]
            # Rows are listed in random (unseeded) order.
            source.data = subset.iloc[np.random.permutation(len(subset))]

        def save():
            """Write the selected rows to the file named in the text box."""
            df.iloc[highlighted_idx][save_cols].to_csv(text_filename.value, index=False)

        scatter.data_source.selected.on_change('indices', update)
        save_btn.on_click(save)

        doc.add_root(row(spinner, row_spinner))
        doc.add_root(row(column(p), data_table))
        doc.add_root(row(text_filename))
        doc.add_root(row(save_btn))

    return Application(FunctionHandler(make_doc))


# Build the browser app from the saved coordinates and serve it in the notebook.
# The selection and SAVE handlers are Python callbacks, which is why bulk_text
# returns an Application rather than a static figure.
app = bulk_text(output_file)
show(app)