In [None]:
import os

import matplotlib.pyplot as plt
import plotly.express as px
import torch
from IPython.display import Image
from IPython.display import display
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from skimage import io
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import euclidean_distances

from src.data.ucmerced_dataset import TripletDataModule
from src.evaluation import evaluate_anmrr
from src.models.triplet_retriever import TripletRetriever
from src.settings import UC_MERCED_DATA_DIRECTORY, PATTERN_NET_DATA_DIRECTORY
from src.visualisation import visualize_best_and_worst_queries


In [None]:
import numpy as np
from PIL import Image
def visualize_tsne_embeddings(embeddings: np.ndarray, image_paths, name, zoom=0.15, figsize=(15, 15)):
    """Scatter 2-D points and draw each image's thumbnail at its point, then save.

    Args:
        embeddings: array indexed as ``embeddings[:, 0]`` / ``embeddings[:, 1]``
            and unpacked row-wise into ``(x, y)``; presumably the (n, 2) output
            of t-SNE.
        image_paths: image file paths, in the same order as the rows of
            ``embeddings``.
        name: file path the figure is saved to (300 dpi, tight bounding box).
        zoom: thumbnail scale factor passed to ``OffsetImage``.
        figsize: matplotlib figure size in inches.
    """

    def get_image(path):
        # PIL opens files lazily and keeps the handle until the Image object is
        # collected; read the pixels inside a ``with`` so each file is closed
        # immediately (np.asarray copies the data out of the Image).
        with Image.open(path) as img:
            pixels = np.asarray(img)
        return OffsetImage(pixels, zoom=zoom)

    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(embeddings[:, 0], embeddings[:, 1])
    for image_path, (x, y) in zip(image_paths, embeddings):
        ax.add_artist(AnnotationBbox(get_image(image_path), (x, y), frameon=False))
    fig.savefig(name, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.show()

In [None]:
def calculate_embeddings(model, dataloader):
    """Run ``model`` over every anchor image in ``dataloader`` and collect outputs.

    The model is switched to eval mode for the pass, so layers such as dropout
    or batch-norm (if present) use inference behaviour. Its previous train/eval
    mode is restored afterwards. Inputs are moved to the device of the model's
    first parameter (CPU if it has none) instead of being hard-wired to CUDA.

    Args:
        model: a ``torch.nn.Module`` mapping a batch of anchors to embeddings.
        dataloader: iterable of dict batches with keys ``'a'`` (anchor tensor),
            ``'a_y'`` (anchor label tensor) and ``'path'`` (anchor file paths).

    Returns:
        ``(paths, embeddings, classes)``. ``paths`` is a flat list of paths.
        ``embeddings`` and ``classes`` are numpy arrays holding the per-batch
        model outputs and labels, concatenated along axis 0.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    paths = []
    embeddings = []
    classes = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                anchors = batch['a'].to(device)
                classes.append(batch['a_y'].cpu().numpy())
                paths.extend(batch['path'])
                embeddings.append(model(anchors).cpu().numpy())
    finally:
        # Leave the caller's model in whatever mode it was handed to us in.
        model.train(was_training)

    if not embeddings:
        raise ValueError("dataloader yielded no batches")
    return paths, np.concatenate(embeddings), np.concatenate(classes)

In [None]:
def analyze_embeddings(paths, embeddings, classes, name, random_state=None):
    """Project embeddings to 2-D with t-SNE and plot image thumbnails there.

    Args:
        paths: image paths aligned with the rows of ``embeddings``.
        embeddings: feature matrix passed to ``TSNE.fit_transform``.
        classes: accepted for call-site symmetry with the other helpers; not
            used by the projection or the plot.
        name: file path the figure is saved to.
        random_state: seed forwarded to ``TSNE``. ``None`` (the default) keeps
            the unseeded, run-to-run-varying layout. Pass an int to make the
            figure reproducible.

    Returns:
        The 2-D t-SNE coordinates that were plotted (one row per input row).
    """
    tsne_embeddings = TSNE(n_components=2, random_state=random_state).fit_transform(embeddings)
    visualize_tsne_embeddings(tsne_embeddings, paths, name)
    return tsne_embeddings

In [None]:
def load_model_dataloader(model_name, output_size, ckpt_path, data_path, image_size=224):
    """Restore a ``TripletRetriever`` checkpoint on CUDA and build its validation loader.

    Args:
        model_name: backbone name forwarded to ``TripletRetriever`` (e.g. 'resnet50').
        output_size: forwarded to ``TripletRetriever`` as its output size.
        ckpt_path: checkpoint file; its ``'state_dict'`` entry is loaded into the model.
        data_path: dataset root forwarded to ``TripletDataModule``.
        image_size: passed as the second argument to both ``TripletRetriever``
            and ``TripletDataModule``. Presumably the input resolution;
            TODO confirm against those constructors.

    Returns:
        ``(model, val_dataloader, datamodule)``. The model is on CUDA and in eval mode.
    """
    # Load tensors onto the CPU first so the checkpoint opens regardless of the
    # device it was saved from; the model is moved to CUDA below.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model = TripletRetriever(model_name, image_size, output_size)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()
    # This helper only feeds inference code, so switch off training-mode
    # behaviour (dropout, batch-norm statistic updates) if the model has any.
    model.eval()
    # 0.8 and 100 are the positional TripletDataModule settings used before;
    # their meaning is defined by TripletDataModule.
    dm = TripletDataModule(data_path, image_size, 0.8, 100)
    dm.setup(None)
    val = dm.val_dataloader()
    return model, val, dm

In [None]:
def analyze_anmrr(model, dataloader, label_name_mapping, renderer='browser'):
    """Evaluate ANMRR on ``dataloader`` and show the per-class scores as a bar chart.

    Args:
        model: retrieval model passed to ``evaluate_anmrr``.
        dataloader: loader to evaluate on; this argument is what gets evaluated.
        label_name_mapping: maps each class label in the per-class results to a
            display name for the x axis.
        renderer: plotly renderer for ``fig.show``. The default ``'browser'``
            opens a browser tab; e.g. ``'notebook'`` renders inline instead.

    Returns:
        The first value returned by ``evaluate_anmrr`` (the overall ANMRR,
        going by the name), so the cell displays it.
    """
    with torch.no_grad():
        anmrr, anmrr_per_class = evaluate_anmrr(model, dataloader, euclidean_distances, True)

    # Materialise as a list so the emptiness check also works if a numpy array
    # of (label, score) rows is returned.
    per_class = list(anmrr_per_class)
    if per_class:
        labels, scores = zip(*per_class)
        names = [label_name_mapping[label] for label in labels]
        fig = px.bar(x=names, y=list(scores), labels={'y': 'ANMRR'})
        fig.update_xaxes(type='category')
        fig.show(renderer=renderer)
    return anmrr


In [None]:
def _show_query(image_path, image_name):
    """Display one query image in a small titled figure."""
    fig = plt.figure(figsize=(3, 3))
    plt.imshow(io.imread(image_path))
    plt.title(f"Query: {image_name}")
    plt.axis("off")
    plt.show()
    plt.close(fig)


def _show_results(result_paths, rows, cols, query_image_name):
    """Display retrieved images in a rows x cols grid, each titled by file name."""
    fig = plt.figure(figsize=(8, 8))
    for i, path in enumerate(result_paths):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(io.imread(path))
        ax.set_title(os.path.basename(path))
        ax.axis("off")
    fig.suptitle(f"Response to query: {query_image_name}")
    plt.show()
    plt.close(fig)


def show_example_queries(paths, embeddings, classes, label_name_mapping, cols=3, rows=3, seed=None):
    """For each class, show one random query image and its nearest neighbours.

    Neighbours are ranked by Euclidean distance between embedding rows. The
    query has distance 0 to itself, so it normally appears as the first result.
    Classes with no samples in ``classes`` are skipped.

    Args:
        paths: image paths aligned with the rows of ``embeddings``.
        embeddings: 2-D array, one embedding per row.
        classes: labels aligned with ``embeddings``; a 1-D array is treated as
            a single column.
        label_name_mapping: dict of class label -> name; one query per key.
        cols: number of columns in the result grid.
        rows: number of rows in the result grid (``cols * rows`` results,
            capped at the number of images).
        seed: seed for choosing each query image; ``None`` is non-deterministic.
    """
    paths = np.asarray(paths).reshape(-1)
    embeddings = np.asarray(embeddings)
    classes = np.asarray(classes)
    if classes.ndim < 2:
        classes = classes[:, None]
    rng = np.random.default_rng(seed)
    n_results = min(cols * rows, len(paths))

    for label in label_name_mapping:
        # Keep the candidate indices 1-D. Squeezing a single match to a 0-d
        # array makes ``choice`` treat it as a population *size* and return an
        # index from a different class.
        candidates = np.argwhere(classes == label)[:, 0]
        if candidates.size == 0:
            continue
        query_index = rng.choice(candidates)
        query_image_path = paths[query_index]
        query_image_name = os.path.basename(query_image_path)

        _show_query(query_image_path, query_image_name)

        # Distances from this query only: O(n) per query, instead of building
        # n x n distance, ranking and path matrices for the whole split.
        distances = np.linalg.norm(embeddings - embeddings[query_index], axis=1)
        nearest = np.argsort(distances)[:n_results]
        _show_results(paths[nearest], rows, cols, query_image_name)


In [None]:
# UC Merced run: ResNet-50 backbone, output_size=50.
# Forward slashes keep this relative checkpoint path valid on Windows and POSIX alike.
CKPT_PATH = "triplet_retrieval_uc_merced/14fd0t9p/checkpoints/epoch=35-val_anmrr=0.19.ckpt"
model_name = 'resnet50'
output_size = 50

model, val, dm = load_model_dataloader(model_name, output_size, CKPT_PATH, UC_MERCED_DATA_DIRECTORY)

In [None]:
# Embed the loaded validation split, then save/show its t-SNE thumbnail map.
paths, embeddings, classes = calculate_embeddings(model, val)
analyze_embeddings(paths, embeddings, classes, name="merced.png")

In [None]:
analyze_anmrr(model, dataloader=val, label_name_mapping=dm.label_name_mapping)  # per-class ANMRR chart

In [None]:
show_example_queries(paths, embeddings, classes, label_name_mapping=dm.label_name_mapping)  # one query per class

In [None]:
# PatternNet run: ResNet-18 backbone, output_size=50.
# Forward slashes keep this relative checkpoint path valid on Windows and POSIX alike.
CKPT_PATH = "triplet_retrieval_pattern_net/115p0rcf/checkpoints/epoch=46-val_anmrr=0.11.ckpt"
model_name = 'resnet18'
output_size = 50

model, val, dm = load_model_dataloader(model_name, output_size, CKPT_PATH, PATTERN_NET_DATA_DIRECTORY)


In [None]:
# Embed the loaded validation split, then save/show its t-SNE thumbnail map.
paths, embeddings, classes = calculate_embeddings(model, val)
analyze_embeddings(paths, embeddings, classes, name="pattern_net.png")

In [None]:
analyze_anmrr(model, dataloader=val, label_name_mapping=dm.label_name_mapping)  # per-class ANMRR chart

In [None]:
show_example_queries(paths, embeddings, classes, label_name_mapping=dm.label_name_mapping)  # one query per class