# Declarations

## Imports

In [None]:
import re
import os
import math
import string
import random
import requests
import importlib
import itertools

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [None]:
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import plotly.graph_objects as go

from tqdm import tqdm

from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

tfk = tf.keras
tfkl = tf.keras.layers
kb = tf.keras.backend

In [None]:
# Report the TensorFlow version in use and how many GPU devices it can see
print(tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

## Constants

In [None]:
# Randomness
# Single seed shared by every random number generator seeded below;
# it is also passed to import_datasets further down in the notebook
seed = 42

# Seed Python's `random` module, the hash-seed environment variable, NumPy and TensorFlow
random.seed(seed)
# NOTE(review): set at runtime, so presumably it only affects processes started afterwards — confirm
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

In [None]:
# Filepaths
# When True, every data path below is prefixed with its Kaggle input directory
# and the helper .py files / model definitions are downloaded from GitHub
kaggle = True

# Model versions to load; each version maps to one "<version>.yaml" structure file
# and (on Kaggle) one "<version>.h5" weights file
model_versions = ["v4.0"]

# GitHub repository coordinates and the path components used to locate files inside it
github_repo = "raul-singh/Rise-of-Transformers-Project"
github_branch = "main"
github_python_prefix = ["Code", "Notebooks", "py_files"]
github_clip_models_prefix = ["Code", "Models"] if kaggle else ["..", "Models"]
# Helper modules to fetch: "name" is the module file name (without .py),
# "imports" lists the names imported from that module into this notebook
github_pyfiles_data = [
    {"name": "preprocessing", "imports": ["import_datasets"]}, 
    {"name": "clip", "imports": ["build_clip"]}
]
# Repository-relative paths of the helper modules and of the model structure files
github_pyfiles = ["/".join(github_python_prefix) + "/" + pf["name"] + ".py" for pf in github_pyfiles_data]
github_clip_models = [f"{'/'.join(github_clip_models_prefix)}/{version}.yaml" for version in model_versions]

# Kaggle input directories (used only when `kaggle` is True)
kaggle_dataset1 = "/kaggle/input/transformers-hackathon/"
kaggle_dataset2 = "/kaggle/input/transformers-hackathon-features/"
kaggle_weights = "/kaggle/input/clip-weights/"
kaggle_relevance = "/kaggle/input/clip-relevance/"

# Local (non-Kaggle) locations of the data; rewritten below when running on Kaggle
image_dir = "./resized_train"
relevance_dir = "./relevance"
caption_pred_file = "caption_prediction_train.csv"
concept_det_file = "concept_detection_train.csv"
concept_file = "concepts.csv"
# One weights file per model version on Kaggle, otherwise None for every version
clip_weights_files = [f"{version}.h5" for version in model_versions] if kaggle else [None for _ in model_versions]

# Prefix every path with the matching Kaggle input directory (plain string concatenation)
if kaggle:
    image_dir = kaggle_dataset1 + image_dir
    relevance_dir = kaggle_relevance + relevance_dir
    caption_pred_file = kaggle_dataset2 + caption_pred_file
    concept_det_file = kaggle_dataset2 + concept_det_file
    concept_file = kaggle_dataset2 + concept_file
    clip_weights_files = [kaggle_weights + weight for weight in clip_weights_files]

In [None]:
# Train/Val/Test split and filter percentages
# These values are forwarded unchanged to import_datasets
test_size = 0.2
val_size = 0
filter_percent_dataset = 1

# Batch size
batch_size = 32

# Import dataset types and shapes
in_feat_typ = {'caption': tf.string, 'concepts': tf.bool, 'image path': tf.string}
# NOTE(review): `(8374)` is a plain int, not a 1-tuple like `(8374,)` — confirm which form import_datasets expects
feature_shapes = {'image': (128, 128, 3), 'caption': (), 'concepts': (8374)}

# Output dataset structure
# Features returned on the x side and on the y side of each dataset element
x_features_eval = ['image path', 'image']
y_features_eval = ['caption', 'concepts']

# Define parameters for dataset import
# A list with one configuration; the evaluation code below reads the datasets built for entry 0.
# Both sides are requested as dictionaries and the shuffle buffer size is 1
# (the embedding helpers below warn that the dataset must not be shuffled)
dataset_parameters = [{
    'x_features': x_features_eval, 'y_features': y_features_eval,
    'x_dict': True, 'y_dict': True,           
    'shuffle_buffer_size': 1,
    'batch_size': batch_size,
    'cached': True,
}]

## Meta-Imports

In [None]:
def clean_recursive_imports(source, import_list, prefix):
    """Re-point `from <module> import` statements at a package prefix.

    Every occurrence of `from <name> import` in `source`, for each `<name>` in
    `import_list`, is rewritten to `from <dotted prefix><name> import`, where the
    dotted prefix is `prefix` with every "/" turned into ".".

    Args:
        source: Text of a Python source file.
        import_list: Module names whose `from ... import` statements are rewritten.
        prefix: Slash-separated path prefix (e.g. "Code/py_files/").

    Returns:
        The rewritten source text.
    """
    dotted_prefix = prefix.replace("/", ".")
    for module_name in import_list:
        # Only spaces/tabs may separate the keywords, so the match never spans lines
        pattern = r"from[ \t]+" + re.escape(module_name) + r"[ \t]+import"
        replacement = f"from {dotted_prefix + module_name} import"
        source = re.sub(pattern, replacement, source)
    return source
    
def import_py_from_repo(repository, branch, filepath, prefix, recursive_imports_list=None):
    """Download one file from a GitHub repository and write it under `prefix`.

    The file is fetched from raw.githubusercontent.com and saved as
    `prefix + <file name>`. When `recursive_imports_list` is given, the
    `from <module> import` statements for those modules are rewritten with
    clean_recursive_imports so they resolve inside the `prefix` package.

    Args:
        repository: "<owner>/<repo>" identifier.
        branch: Branch name to download from.
        filepath: Repository-relative path of the file.
        prefix: Slash-terminated local directory prefix for the written file.
        recursive_imports_list: Optional module names whose imports must be re-pointed.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without
            this check the error page would be written out as the file content.
    """
    # Build path for retrieval and write name
    url = "https://raw.githubusercontent.com/" + repository + "/" + branch + "/" + filepath
    write_path = prefix + filepath.split("/")[-1]
    print("Downloading file from " + url)
    # Obtain raw text from file, failing loudly on HTTP errors and on a stalled connection
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    source = response.text
    # Clean recursive imports
    if recursive_imports_list:
        source = clean_recursive_imports(source, recursive_imports_list, prefix)
    # Create subdirectories if they do not exist (an empty prefix has no directory to create)
    target_dir = os.path.dirname(write_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    # Write file; the context manager closes the handle even if the write fails
    with open(write_path, "w", encoding="utf-8") as out_file:
        out_file.write(source)

In [None]:
# On Kaggle: download each helper module from GitHub into the local package directories,
# then import the requested names from it. Otherwise: import them from the local `py_files` package.
if kaggle:
    for pf_data, py_file in zip(github_pyfiles_data, github_pyfiles):
        # Every helper module name is passed so that imports between helpers are re-pointed
        import_py_from_repo(
            github_repo, github_branch, py_file, 
            "/".join(github_python_prefix) + "/", 
            recursive_imports_list=[pf["name"] for pf in github_pyfiles_data],
        )
        # Build and execute "from <dotted prefix>.<module> import <names>" so the names land in this namespace
        import_string = f'from {".".join(github_python_prefix) + "." + pf_data["name"]} import {", ".join(pf_data["imports"])}'
        exec(import_string)
    
    # Download the model structure (.yaml) files as well; no import rewriting is requested for them
    for model in github_clip_models:
        import_py_from_repo(github_repo, github_branch, model, "/".join(github_clip_models_prefix) + "/")
        
else:
    for pf_data in github_pyfiles_data:
        # Local run: the helper modules are imported from the `py_files` package
        import_string = f'from py_files.{pf_data["name"]} import {", ".join(pf_data["imports"])}'
        exec(import_string)

# Preprocessing

In [None]:
# Load the data through the imported helper. The three returned values are unpacked
# in the next cell: concept information, the list of built datasets and the split sizes.
concept_info, datasets, dataset_sizes = import_datasets(
    image_dir, caption_pred_file, concept_file, concept_det_file,
    in_feat_typ, feature_shapes,
    dataset_parameters,
    filter_percent_dataset,
    test_size, val_size,
    seed,
)

In [None]:
# Select loaded datasets and variables
concept_list, concepts_onehot = concept_info
# Only the third element (test split) of the first dataset configuration is kept
_, _, test_dataset = datasets[0]
train_ds_size, val_ds_size, test_ds_size = dataset_sizes

# Drop the reference to the remaining datasets, which are not used afterwards
del datasets

# Model Import

In [None]:
# Build one CLIP model per (structure file, weights file) pair
models = []
for structure, weights in zip(github_clip_models, clip_weights_files):
    print(f"Creating model {structure}")
    clip_image_encoder, clip_text_encoder, clip = build_clip(structure, weights_path=weights)
    # The insertion order of these keys matters: the next cell unpacks the dictionary values positionally
    models.append({
        "image_encoder": clip_image_encoder,
        "clip_text_encoder": clip_text_encoder,
        "clip": clip,
    })

In [None]:
# Select the loaded model to evaluate
# Unpacks the values of the first model dictionary in insertion order:
# image encoder, text encoder, full CLIP model
clip_image_encoder, clip_text_encoder, clip = models[0].values()

# Drop the list; the selected encoders stay referenced by the names above
del models

# Model Evaluation

## Evaluation Definitions

### Evaluation Variables

In [None]:
# Top-k number
k = 10
# Threshold for concept overlap metric
concept_overlap_threshold = 2
# Visualization decimal precision
decimal_precision = 4
# Index to choose model from the array of models
# NOTE(review): not referenced anywhere else in the visible notebook — confirm whether it is still needed
model_index = 0
# Dictionaries used to load/save total relevance files
# (their keys are passed as keyword arguments to build_relevance_filename)
relevance_fileinfo_cap = {"path": relevance_dir, "test_split": test_size, "val_split": val_size, "metric": "cap"}
relevance_fileinfo_con = {"path": relevance_dir, "test_split": test_size, "val_split": val_size, "metric": "con", "other": [("conthresh", concept_overlap_threshold)]}
# Function to preprocess data when we want to evaluate captions
# (decodes the caption tensor of a reference element into a Python string)
reference_preprocess_cap = lambda x: x["caption"].numpy().decode('UTF-8')          
# Function to preprocess data when we want to evaluate concepts
reference_preprocess_con = lambda x: x["concepts"].numpy()
# Hashable variant: the frozenset of indexes where the concept vector is non-zero
# (the sorted() call does not affect the resulting frozenset)
reference_preprocess_con_hash = lambda x: frozenset(sorted(np.where(x["concepts"].numpy())[0]))
# Function to compute if a match is relevant given concept arrays 
# Relevant when the number of shared concepts reaches the threshold, capped by the
# number of concepts in either element (so elements with few concepts can still match)
concept_relevance = lambda m, o: np.count_nonzero(np.logical_and(m, o)) >= min(concept_overlap_threshold, np.count_nonzero(m), np.count_nonzero(o))
# Same rule expressed on the frozenset representation
concept_relevance_hash = lambda m, o: len(m.intersection(o)) >= min(concept_overlap_threshold, len(m), len(o))

In [None]:
# Metric IDs
class EvalMetrics:
    """Namespace of string identifiers for the retrieval metrics used in this notebook."""
    # Identifiers are used as dictionary keys for computed metrics and as plot labels
    METRIC_ACCURACY = "Accuracy"
    METRIC_MAP = "MAP"
    METRIC_MAR = "MAR"
    METRIC_F1 = "F1"
# Import alias for consistency
evm = EvalMetrics

# Metric visualization parameters
# Each entry: "id" selects the computation, "name" is printed in reports, "color" is used in plots.
# F1 is listed after MAP and MAR because the report functions read those two values when computing it.
metrics = [
    {"id": evm.METRIC_ACCURACY, "name": "Accuracy", "color": "green"},
    {"id": evm.METRIC_MAP, "name": "Mean Average Precision", "color": "blue"},
    {"id": evm.METRIC_MAR, "name": "Mean Average Recall", "color": "red"},
    {"id": evm.METRIC_F1, "name": "F1 Score", "color": "blueviolet"}
]

### Evaluation Functions

In [None]:
# Construct reference dataset for retrieving side data of elements
# Unusable due to TensorFlow funny stuff
def generate_dataset_reference(
    dataset_eval,                 # Dataset to generate embeddings
    dataset_ref_map=lambda *x: x, # Lambda mapping function for reference
):
    """Materialize the per-element reference data of a batched dataset as a list.

    `dataset_ref_map` is applied through the dataset's `map`, the result is
    unbatched, and every resulting element is collected in order.
    """
    mapped_dataset = dataset_eval.map(dataset_ref_map)
    return list(mapped_dataset.unbatch())

# Generate the embeddings and the corresponding dataset reference for an image dataset
def generate_image_embeddings(
    image_encoder,                 # Image encoder of clip model
    dataset_eval,                  # Dataset to generate embeddings (WARNING: the dataset must not be shuffling or have a shuffle buffer size of 1)
    dataset_pred_map=lambda *x: x, # Lambda mapping function for prediction
    dataset_ref_map=lambda *x: x,  # Lambda mapping function for reference
):
    """Embed every image of a dataset and collect the matching reference elements.

    Returns:
        A tuple `(dataset_reference, image_embeddings)`: the list of unbatched
        reference elements and whatever `image_encoder.predict` returned. The two
        stay aligned only if both passes over the dataset yield the same order.
    """
    print("Generating image embeddings")
    # First pass over the dataset: encoder inputs only
    encoder_inputs = dataset_eval.map(dataset_pred_map)
    image_embeddings = image_encoder.predict(encoder_inputs, verbose=1)
    # Second pass: side data of each element, one entry per sample
    reference_elements = dataset_eval.map(dataset_ref_map).unbatch()
    dataset_reference = list(reference_elements)
    return dataset_reference, image_embeddings

# Generate the embeddings and the corresponding dataset reference for a text dataset
def generate_text_embeddings(
    text_encoder,                  # Text encoder of clip model
    dataset_eval,                  # Dataset to generate embeddings (WARNING: the dataset must not be shuffling or have a shuffle buffer size of 1)
    dataset_pred_map=lambda *x: x, # Lambda mapping function for prediction
    dataset_ref_map=lambda *x: x,  # Lambda mapping function for reference
):
    """Embed every text of a dataset and collect the matching reference elements.

    Returns:
        A tuple `(dataset_reference, text_embeddings)`: the list of unbatched
        reference elements and whatever `text_encoder.predict` returned. The two
        stay aligned only if both passes over the dataset yield the same order.
    """
    # The progress message used to say "image embeddings" (copy-paste from the image variant)
    print("Generating text embeddings")
    # Generate text embedding
    text_embeddings = text_encoder.predict(
        dataset_eval.map(dataset_pred_map),
        verbose=1,
    )
    # Construct reference dataset for retrieving side data of elements
    dataset_reference = list(dataset_eval.map(dataset_ref_map).unbatch())
    return dataset_reference, text_embeddings

# Return the results in the form of reference dataset indexes of a text to image retrieval for a series of queries
def find_t2i_matches(
    queries,                # Queries to search
    text_encoder,           # Text encoder of clip model
    image_embeddings,       # Generated image embeddings
    k=10,                   # Number of elements for top-k
    normalize=True,         # Embedding normalization
):
    """Rank image embeddings against text queries and return the top-k indexes.

    Returns:
        A NumPy array with one row per query holding the indexes (into
        `image_embeddings`) of the k most similar images, best first.
    """
    print("Computing Text-to-Image matches")
    query_embedding = text_encoder.predict(queries)
    if normalize:
        # With unit-length rows the dot product below is the cosine similarity
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
    # Similarity matrix: one row per query, one column per image
    similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
    top_matches = tf.math.top_k(similarity, k)
    return top_matches.indices.numpy()

# Return the results in the form of reference dataset indexes of a image to text retrieval for a series of queries
def find_i2t_matches(
    queries,                # Queries to search
    image_encoder,          # Image encoder of clip model
    text_embeddings,        # Generated text embeddings
    k=10,                   # Number of elements for top-k
    normalize=True,         # Embedding normalization
):
    """Rank text embeddings against image queries and return the top-k indexes.

    Returns:
        A NumPy array with one row per query holding the indexes (into
        `text_embeddings`) of the k most similar texts, best first.
    """
    print("Computing Image-to-Text matches")
    query_embedding = image_encoder.predict(queries)
    if normalize:
        # With unit-length rows the dot product below is the cosine similarity
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
        text_embeddings = tf.math.l2_normalize(text_embeddings, axis=1)
    # Similarity matrix: one row per query, one column per text
    similarity = tf.matmul(query_embedding, text_embeddings, transpose_b=True)
    top_matches = tf.math.top_k(similarity, k)
    return top_matches.indices.numpy()

# Extract the reference dataset objects given a list of indexes
def index_to_reference(results, dataset_reference):
    """Replace every retrieved index with a copy of its reference element.

    Each returned dictionary holds the keys of the reference element plus an
    "index" key with the position of that element in `dataset_reference`.
    """
    resolved = []
    for result in results:
        row = []
        for match in result:
            # A new dictionary is built so the reference element itself is not modified
            row.append({**dataset_reference[match], "index": match})
        resolved.append(row)
    return resolved

# Transform a one-hot encoded list of boolean concepts to the respective list of raw concepts labels
# If the flag string_form is set to true, the returned list will contain strings of concatenated concept text
def decode_concepts(concepts, encoder, concept_list, string_form=True):
    """Decode encoded concept vectors back to labels through `encoder.inverse_transform`.

    Args:
        concepts: Sequence of encoded concept vectors (converted to a NumPy array).
        encoder: Object exposing `inverse_transform`.
        concept_list: Lookup from a decoded concept to its text.
        string_form: When true, each decoded entry becomes one space-joined string
            of concept texts; otherwise the raw `inverse_transform` output is returned.
    """
    decoded = encoder.inverse_transform(np.array(concepts))
    if not string_form:
        return decoded
    return [" ".join(concept_list[concept] for concept in labels) for labels in decoded]

In [None]:
# Retrieve relevant items given a list of queries (DO NOT RUN THIS ON A COMPLETE DATASET!!!)
def retrieve_relevant(queries, dataset_reference, reference_preprocess=lambda x: x, relevance=lambda m, o: m == o):
    """Return, for every query, the preprocessed reference elements relevant to it.

    An element is relevant to a query when `relevance(query, element)` is truthy.
    The cost is one relevance call per (query, element) pair, so this is not meant
    to be run with a whole dataset as queries.

    Args:
        queries: Iterable of queries, already in the preprocessed form.
        dataset_reference: Iterable of raw reference elements.
        reference_preprocess: Maps a raw reference element to its comparable form.
        relevance: Predicate called as `relevance(query, element)`.

    Returns:
        A list with one list of relevant preprocessed elements per query.
    """
    # Preprocess the reference once instead of once per query; this also keeps
    # one-shot iterables from being exhausted after the first query
    elements = [reference_preprocess(item) for item in dataset_reference]
    return [
        [element for element in elements if relevance(query, element)]
        for query in queries
    ]

# Compute the number of relevant items in the first k matches in a list of results
# If queries is None, it is assumed that the queries ran to obtain the list of results are parallel to the elements in dataset_reference
def compute_relevant_at_k(results, dataset_reference, queries=None, k=None, reference_preprocess=lambda x: x, relevance=lambda m, o: m == o):
    """Count the relevant items among the first k matches of every result list.

    Two modes:
      * `queries` is None: result i is assumed to belong to element i of
        `dataset_reference`; a match counts when
        `relevance(preprocessed match, preprocessed reference element)` is truthy.
      * `queries` is given: result i belongs to `queries[i]`; a match counts when
        `relevance(query, preprocessed match)` is truthy, the same call order that
        retrieve_relevant uses. The earlier code compared each match against the
        whole list of relevant elements, which never produced a meaningful count.

    Args:
        results: Sequence of match lists (raw reference elements).
        dataset_reference: Raw reference elements, parallel to `results` when
            `queries` is None.
        queries: Optional queries, already in preprocessed form.
        k: Number of leading matches to inspect; a falsy value means "all matches
            of the first result".
        reference_preprocess: Maps a raw element to its comparable form.
        relevance: Relevance predicate.

    Returns:
        A list with one count per result (empty when `results` is empty).
    """
    if not k:
        if len(results) == 0:
            return []
        k = len(results[0])

    def top_k(matches):
        # Slice before preprocessing so only the first k matches are converted
        return [reference_preprocess(match) for match in itertools.islice(matches, k)]

    if queries is not None:
        return [
            np.count_nonzero([relevance(query, match) for match in top_k(matches)])
            for matches, query in zip(results, queries)
        ]
    return [
        np.count_nonzero([relevance(match, reference) for match in top_k(matches)])
        for matches, reference in zip(results, map(reference_preprocess, dataset_reference))
    ]

# Computes the total number of relevant elements for a dataset or queries
# It is assumed that the element returned by reference_preprocess is hashable and can be used as a dictionary key
# If queries is None, it is assumed that the queries ran to obtain the list of results are parallel to the elements in dataset_reference
def compute_total_relevance(
    dataset_reference, queries=None,                                   # Dataset reference or queries to compute total relevance for
    reference_preprocess=lambda x: x, relevance=lambda m, o: m == o,   # Preprocessing function and relevance function
    load_from_file=True, save_to_file=True,                            # Load/Save flags
    fileinfo=None                                                      # Info for loading/saving data from/to a file in the form of a dictionary with the following keys:
                                                                       # path, filename, test_split, val_split, split, metric, other
                                                                       # if a filename is specified, only the base path is needed
):
    """Compute, for every dataset element (or query), how many elements are relevant to it.

    With `queries`, the counts come from retrieve_relevant and no file is touched.
    Otherwise the counts are loaded from the relevance file when `load_from_file`
    is set and the file can be read; if not, they are computed and, when
    `save_to_file` is set, written to the relevance file.

    Fixes over the previous version:
      * `load_from_file=False` used to skip the computation and return `True`.
      * A successfully loaded NumPy array was tested with `not`, which raises
        for arrays with more than one element.
      * With a custom relevance function, elements sharing the same preprocessed
        value were counted as 1 instead of their number of occurrences.
      * Values loaded from the file are no longer written back to the same file.

    The preprocessed elements must be hashable (they are used as dictionary keys).

    Returns:
        A list of counts parallel to `dataset_reference` (or to `queries`), or the
        array returned by load_relevance_from_csv when the data came from a file.
    """
    # Check if queries are passed, if so run general function without loading/saving to file
    if queries is not None:
        relevant_reference = retrieve_relevant(queries, dataset_reference, reference_preprocess=reference_preprocess, relevance=relevance)
        return [len(e) for e in relevant_reference]
    if fileinfo is None:
        fileinfo = {}
    # Check for existing file and load relevance data
    if load_from_file:
        loaded = load_relevance_from_csv(fileinfo)
        # The loader signals failure with the literal False; anything else is data
        if loaded is not False:
            return loaded
        print("Proceeding with total relevance calculation...")
    # Build preprocessed dataset
    relevant_reference = list(map(reference_preprocess, dataset_reference))
    # Iterate through dataset and count equal items
    total_n = {}
    for element in relevant_reference:
        total_n[element] = total_n.get(element, 0) + 1
    # Check bytecode of relevance function to determine if the relevance function is equality,
    # if so, return counts, otherwise apply relevance to the whole dataset.
    # Callables without bytecode (builtins, partials) are treated as custom relevance.
    equality_code = (lambda m, o: m == o).__code__.co_code
    relevance_code = getattr(getattr(relevance, "__code__", None), "co_code", None)
    if relevance_code != equality_code:
        # Each distinct value counts its own occurrences plus the occurrences of
        # every other distinct value that is relevant to it
        total_n = {
            element: total_n[element] + sum(total_n[x] for x in total_n if element != x and relevance(x, element))
            for element in tqdm(total_n)
        }
    tot_relevant = [total_n[element] for element in relevant_reference]
    # Save the freshly computed relevance data
    if save_to_file:
        save_relevance_to_csv(tot_relevant, fileinfo)
    return tot_relevant

# Build the filename for a relevance file given some dataset and preprocessing attributes
def build_relevance_filename(
    path,                              # Base path for the file
    filename=None,                     # Name of the csv file to load, if None it will be inferred from dataset attributes
    test_split=0.2,                    # Test split percentage
    val_split=0,                       # Validation split percentage
    split="train",                     # Either "train", "test" or "val"
    metric="",                         # Metric used to compute relevance
    other=None,                        # Other attributes as an ordered list of (name, value) tuples
):
    """Build the full path of a relevance csv file.

    When `filename` is not given it is derived from the attributes as
    `TotRelevant_<test_split>_<val_split>_<split>_<metric>[_<name>-<value>...].csv`.
    The name is joined to `path` with the platform path separator; plain string
    concatenation dropped the separator whenever `path` did not end with one
    (e.g. "./relevance" produced "./relevanceTotRelevant_...").
    """
    if not filename:
        parts = ["TotRelevant", str(test_split), str(val_split), split, metric]
        parts.extend(attr[0] + "-" + str(attr[1]) for attr in (other or []))
        filename = "_".join(parts) + ".csv"
    return os.path.join(path, filename)

# Load a csv relevance file given a filename or some dataset and preprocessing attributes
def load_relevance_from_csv(fileinfo={}):
    """Load total-relevance values from the csv file described by `fileinfo`.

    `fileinfo` is expanded into keyword arguments of build_relevance_filename.

    Returns:
        The file content (read without header) squeezed into a NumPy array, or
        False when the file does not exist or reading it raises OSError.
    """
    filename = build_relevance_filename(**fileinfo)
    if os.path.exists(filename):
        try:
            table = pd.read_csv(filename, header=None)
            return np.squeeze(table.values.tolist())
        except OSError as error:
            print(f"Couldn't load file \"{filename}\": {error}")
    else:
        print(f"The relevance file \"{filename}\" does not exist!")
    return False

# Save total relevant data to a csv relevance file given a filename or some dataset and preprocessing attributes
def save_relevance_to_csv(tot_relevant, fileinfo={}):
    """Write total-relevance values to the csv file described by `fileinfo`.

    `fileinfo` is expanded into keyword arguments of build_relevance_filename.
    An existing file is overwritten after printing a notice.

    Returns:
        True when the file was written, False when writing raised OSError.
    """
    filename = build_relevance_filename(**fileinfo)
    if os.path.exists(filename):
        print(f"Overwriting \"{filename}\" relevance file!")
    frame = pd.DataFrame(tot_relevant)
    try:
        # One value per row, without index column or header line
        frame.to_csv(filename, index=False, header=False)
    except OSError as error:
        print(f"Couldn't save file \"{filename}\": {error}")
        return False
    return True

In [None]:
def compute_top_k_accuracy(results, dataset_reference, relevant_at_k):
    """Return the fraction of reference elements whose relevant-at-k count is non-zero.

    `results` is accepted for signature symmetry with the other metrics and is not read.
    """
    n_with_hit = np.count_nonzero(relevant_at_k)
    n_total = len(dataset_reference)
    return n_with_hit / n_total

def compute_map_k(results, dataset_reference, relevant_at_k, k=None):
    """Return the mean over the reference elements of (relevant-at-k count / k).

    A falsy `k` is replaced by the length of the first result list.
    """
    cutoff = k if k else len(results[0])
    per_query_precision = [n_relevant / cutoff for n_relevant in relevant_at_k]
    n_total = len(dataset_reference)
    return np.sum(per_query_precision) / n_total

def compute_mar_k(results, dataset_reference, relevant_at_k, total_relevant):
    """Return the mean over the reference elements of (relevant-at-k / total relevant).

    `relevant_at_k` and `total_relevant` are paired element by element;
    `results` is not read.
    """
    per_query_recall = []
    for n_found, n_relevant in zip(relevant_at_k, total_relevant):
        per_query_recall.append(n_found / n_relevant)
    return np.sum(per_query_recall) / len(dataset_reference)

def compute_F1_k(precision=0, recall=0):
    """Return the harmonic mean of `precision` and `recall`.

    Returns 0 when both are zero. The `return` used to sit inside the `else`
    branch, so that case returned None and broke callers that multiply the score.
    """
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)

In [None]:
# Visualize results for text to image queries
def visualize_t2i_results(query, matches):
    """Show the matches of one text-to-image query.

    For every match that has an "image path" the image is drawn in a three-column
    grid; for every match that has a "caption" the caption is printed with its rank.

    The grid has at least three rows and grows with the number of matches. The
    previous fixed 3x3 grid raised as soon as a tenth match was drawn. An empty
    match list now only prints the header.

    Args:
        query: Query text, used in the printed header.
        matches: List of reference dictionaries whose "image path" / "caption"
            values expose `.numpy()` returning UTF-8 bytes.
    """
    print("Top matches for query: \"" + query + "\"")
    if not matches:
        return
    n_cols = 3
    n_rows = max(3, math.ceil(len(matches) / n_cols))
    if "image path" in matches[0]:
        # 6 inches per row keeps the 18x18 figure of the 3x3 case
        plt.figure(figsize=(18, 6 * n_rows))
    for i, match in enumerate(matches):
        if "image path" in match:
            path = match["image path"].numpy().decode('UTF-8')
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(mpimg.imread(path))
            plt.axis("off")
        if "caption" in match:
            caption = match["caption"].numpy().decode('UTF-8')
            print(f"{i}) {caption}")
        
# Standard visualization for a multi-purpose plotly graph
def visualize_multigraph(functions, titlexyf=(None, None, None), legend=True):
    """Draw several curves on one plotly figure and show it.

    Args:
        functions: Iterable of dictionaries, one per curve. Required keys: 'x', 'y'.
            Optional keys: 'label' (legend name), 'color', 'style' (line dash,
            default "solid"), 'marker' (marker symbol), 'opacity' (default 1).
        titlexyf: Tuple (x axis title, y axis title, figure title).
        legend: Whether the legend is shown. The flag was previously ignored.

    The x axis range now spans the smallest to the largest x value of all curves.
    It used to be (1, number of points of the last curve), which is only right
    when x starts at 1 and which failed with an undefined name when `functions`
    was empty.
    """
    fig = go.Figure()
    x_values = []
    for function in functions:
        x = function['x']
        y = function['y']
        x_values.extend(x)
        marker = go.scatter.Marker(symbol=function['marker'], size=10) if 'marker' in function else None
        fig.add_trace(go.Scatter(
            x=x, y=y,
            line=go.scatter.Line(color=function.get('color'), dash=function.get('style', "solid")),
            opacity=function.get('opacity', 1),
            marker=marker,
            mode="lines+markers+text" if marker else "lines+text",
            name=function.get('label', ""),
        ))
    xaxis_options = dict(
        title=titlexyf[0],
        ticks="outside", ticklen=8, minor=dict(dtick=0.5, ticklen=6, tickcolor="black", showgrid=True), ticklabelstep=1, dtick=1,
    )
    # Leave the range to plotly when there is nothing to derive it from
    if x_values:
        xaxis_options["range"] = (min(x_values), max(x_values))
    fig.update_xaxes(**xaxis_options)
    fig.update_yaxes(
        title=titlexyf[1],
        ticks="outside", ticklen=8, minor=dict(dtick=0.01, ticklen=6, tickcolor="black", showgrid=True), ticklabelstep=1, dtick=0.1,
    )
    fig.update_layout(
        title=titlexyf[2],
        width=900, height=600,
        margin=dict(l=50, r=50, b=20, t=40, pad=4),
        paper_bgcolor="LightSteelBlue",
        showlegend=legend,
    )
    fig.show()

In [None]:
# Compute baselines for retrieval
# Assumption of sampling with repetitions, results get more inaccurate as k/l -> inf
def retrieval_baselines(dataset_reference, total_relevant, k, metrics=()):
    """Compute random-retrieval baselines for the requested metrics.

    The formulas assume sampling with repetition, so they get less accurate as k
    grows relative to the dataset size.

    Args:
        dataset_reference: Reference elements; only their number is used.
        total_relevant: Number of relevant elements for each reference element.
        k: Number of retrieved elements.
        metrics: Iterable of metric dictionaries; only their "id" is read.

    Returns:
        Dictionary mapping each requested metric id to its baseline value.
        F1 no longer requires MAP and MAR to be listed before it: missing values
        are computed on demand without being added to the output.
    """
    l = len(dataset_reference)

    def baseline_map():
        return sum([n_el / l for n_el in total_relevant]) / l

    def baseline_mar():
        return sum([k / l for n_el in total_relevant]) / l

    metrics_out = {}
    for metric in metrics:
        metric_id = metric["id"]
        if metric_id == evm.METRIC_ACCURACY:
            # Probability that at least one of k random draws is relevant
            metrics_out[evm.METRIC_ACCURACY] = sum([1 - pow((l - n_el) / l, k) for n_el in total_relevant]) / l
        elif metric_id == evm.METRIC_MAP:
            metrics_out[evm.METRIC_MAP] = baseline_map()
        elif metric_id == evm.METRIC_MAR:
            metrics_out[evm.METRIC_MAR] = baseline_mar()
        elif metric_id == evm.METRIC_F1:
            precision = metrics_out[evm.METRIC_MAP] if evm.METRIC_MAP in metrics_out else baseline_map()
            recall = metrics_out[evm.METRIC_MAR] if evm.METRIC_MAR in metrics_out else baseline_mar()
            metrics_out[evm.METRIC_F1] = compute_F1_k(precision, recall)
    return metrics_out
    
# Computation of a retrieval report containing metrics
def retrieval_report(
    results, reference, relevant,   # Task results, dataset reference and relevant hits at k for task
    tot_relevant=None,              # Total number of relevant elements for each dataset element
    k=None,                         # k for metrics computation (should be less or equal than k of retrieval)
    baselines=True,                 # Calculate baselines alongside metrics
    metrics=(),                     # Metrics to take into consideration
    output=True,                    # Print outputs to stdout
    title="Retrieval Report",       # Title of the report
    decimal_precision=4,            # Decimal precision of values
):
    """Compute (and optionally print) the requested retrieval metrics.

    `tot_relevant` is needed for MAR, for F1 and for the baselines.

    Returns:
        `(metrics_out, baseline_values)` when `baselines` is truthy, otherwise
        `metrics_out` alone. Both are dictionaries keyed by metric id. The shape of
        the return value now depends only on the flag; it used to fall back to a
        bare dictionary when the computed baselines were empty.
        F1 no longer requires MAP and MAR to be listed before it.
    """
    if not k:
        k = len(results[0])
    metrics_out = {}

    for metric in metrics:
        metric_id = metric["id"]
        if metric_id == evm.METRIC_ACCURACY:
            metrics_out[evm.METRIC_ACCURACY] = compute_top_k_accuracy(results, reference, relevant)
        elif metric_id == evm.METRIC_MAP:
            metrics_out[evm.METRIC_MAP] = compute_map_k(results, reference, relevant, k=k)
        elif metric_id == evm.METRIC_MAR:
            metrics_out[evm.METRIC_MAR] = compute_mar_k(results, reference, relevant, tot_relevant)
        elif metric_id == evm.METRIC_F1:
            # Reuse precision/recall when already computed, otherwise compute them on demand
            if evm.METRIC_MAP in metrics_out:
                precision = metrics_out[evm.METRIC_MAP]
            else:
                precision = compute_map_k(results, reference, relevant, k=k)
            if evm.METRIC_MAR in metrics_out:
                recall = metrics_out[evm.METRIC_MAR]
            else:
                recall = compute_mar_k(results, reference, relevant, tot_relevant)
            metrics_out[evm.METRIC_F1] = compute_F1_k(precision, recall)

    # Keep the flag and the computed values in separate names
    baseline_values = None
    if baselines:
        baseline_values = retrieval_baselines(reference, tot_relevant, k, metrics=metrics)

    if output:
        print(f"\n ### {title} ###")
        for metric in metrics:
            string = f"{metric['name']:<30}: {round(metrics_out[metric['id']] * 100, decimal_precision):10}%"
            if baseline_values is not None:
                string += f"{'   Baseline':<8}: {round(baseline_values[metric['id']] * 100, decimal_precision):10}%"
            print(string)

    if baselines:
        return metrics_out, baseline_values
    return metrics_out
        
# Computation of a retrieval report in graph form containing metrics 
def retrieval_graph_report(
    results, reference,                 # Task results and dataset reference
    tot_relevant=None,                  # Total number of relevant elements for each dataset element
    k_range=(1, 10),                    # k range for metrics computation (maximum value should not be greater than k of retrieval)
    baselines=True,                     # Calculate baselines alongside metrics
    metrics=[],                         # Metrics to take into consideration
    titlexyf=(None, None, None),        # Tuple containing: (title of x axis, title of y axis, figure title)
    reference_preprocess=lambda x: x,   # Function to preprocess data contained in the reference dataset
    relevance=lambda m, o: m == o,      # Function to compare elements
    functions=None,                     # Plot pre-existing function data
):
    """Plot every requested metric (and optionally its baseline) for each k in `k_range`.

    When `functions` is given it is plotted as is and nothing is recomputed.

    Returns:
        The dictionary of curve descriptions that was plotted, keyed by metric id
        (baseline curves use the key "<metric id>_base").
    """
    if not functions:
        # One curve per metric, followed by one dashed baseline curve per metric
        functions = {}
        for metric in metrics:
            functions[metric["id"]] = {"x": [], "y": [], "label": metric["id"], "color": metric["color"], "marker": "0", "opacity": 0.8}
        if baselines:
            for metric in metrics:
                functions[metric["id"] + "_base"] = {"x": [], "y": [], "label": metric["id"] + " Baseline", "color": metric["color"], "style": "dash", "opacity": 0.5}
        # Holds the flag on the first iteration and the latest baseline values afterwards
        baseline_state = baselines
        for k in range(k_range[0], k_range[1] + 1):
            relevant = compute_relevant_at_k(results, reference, k=k, reference_preprocess=reference_preprocess, relevance=relevance)
            report = retrieval_report(results, reference, relevant, tot_relevant, k=k, baselines=baseline_state, metrics=metrics, output=False)
            if baseline_state:
                metrics_out = report[0]
                baseline_state = report[1]
            else:
                metrics_out = report
                baseline_state = None
            for metric_id in metrics_out:
                curve = functions[metric_id]
                curve["x"].append(k)
                curve["y"].append(metrics_out[metric_id])
                if baseline_state:
                    base_curve = functions[metric_id + "_base"]
                    base_curve["x"].append(k)
                    base_curve["y"].append(baseline_state[metric_id])
    visualize_multigraph(functions.values(), titlexyf)
    return functions

# Computation of a retrieval report in graph form containing metrics, comparing multiple models on the same dataset
def retrieval_graph_compare(
    multi_results, reference,           # Per-model task results and dataset reference
    model_ids,                          # List of ordered model ids and labels in the form {"id": id, "label": label} 
    tot_relevant=None,                  # Total number of relevant elements for each dataset element
    k_range=(1, 10),                    # k range for metrics computation (maximum value should not be greater than k of retrieval)
    metrics=[],                         # Metrics to take into consideration
    titlexyf=(None, None, None),        # Tuple containing: (title of x axis, title of y axis, figure title)
    reference_preprocess=lambda x: x,   # Function to preprocess data contained in the reference dataset
    relevance=lambda m, o: m == o,      # Function to compare elements
    functions=None,                     # Plot pre-existing function data
):
    """Plot the requested metrics of several result sets on one graph for each k in `k_range`.

    Each model gets a randomly chosen marker; at most six models are supported.
    When `functions` is given it is plotted as is and nothing is recomputed.

    Returns:
        The dictionary of curve descriptions that was plotted, keyed by
        "<metric id><model id>", or None when there are more models than markers.
    """
    if not functions:
        # Add random markers
        available_markers = ["0", "1", "2", "3", "17", "26"]
        if len(model_ids) > len(available_markers):
            print("Too many models!")
            return None
        chosen_markers = random.sample(available_markers, len(model_ids))
        model_ids = [{**model, "marker": marker} for model, marker in zip(model_ids, chosen_markers)]
        # Generate function models
        functions = {}
        for model in model_ids:
            for metric in metrics:
                functions[metric["id"] + model["id"]] = {
                    "x": [], "y": [],
                    "label": model["label"] + " " + metric["id"],
                    "color": metric["color"],
                    "marker": model["marker"],
                    "opacity": 0.8,
                }
        # Fill functions
        for model, results in zip(model_ids, multi_results):
            for k in range(k_range[0], k_range[1] + 1):
                relevant = compute_relevant_at_k(results, reference, k=k, reference_preprocess=reference_preprocess, relevance=relevance)
                metrics_out = retrieval_report(results, reference, relevant, tot_relevant, k=k, baselines=False, metrics=metrics, output=False)
                for metric_id in metrics_out:
                    curve = functions[metric_id + model["id"]]
                    curve["x"].append(k)
                    curve["y"].append(metrics_out[metric_id])
    visualize_multigraph(functions.values(), titlexyf)
    return functions
    
# Manually compute some text to image queries
def manual_t2i_queries(queries, text_encoder, image_embeddings, dataset_reference, k=10, normalize=True):
    """Run text-to-image queries and visualize the matches of each one.

    The `text_encoder` and `image_embeddings` arguments are now the ones actually
    used; the function used to ignore them and read the notebook-level
    `clip_text_encoder` and `test_image_embeddings` instead.

    Args:
        queries: Query texts.
        text_encoder: Text encoder used to embed the queries.
        image_embeddings: Image embeddings to search.
        dataset_reference: Reference elements parallel to `image_embeddings`.
        k: Number of matches per query.
        normalize: Whether embeddings are L2-normalized before matching.
    """
    results = find_t2i_matches(queries, text_encoder, image_embeddings, k=k, normalize=normalize)
    results = index_to_reference(results, dataset_reference)
    for query, matches in zip(queries, results):
        visualize_t2i_results(query, matches)

## Dataset Metrics

In [None]:
# Generating embeddings for image-to-text and text-to-image tasks
# The reference keeps the y-side dictionary of each element plus its image path
test_dataset_reference, test_image_embeddings = generate_image_embeddings(
    clip_image_encoder,
    test_dataset,
    dataset_pred_map=lambda x, y: x['image'],
    dataset_ref_map=lambda x, y: y | {'image path': x['image path']}
)

# The reference built here is discarded: the one from the image pass above is reused
_, test_text_embeddings = generate_text_embeddings(
    clip_text_encoder,
    test_dataset,
    dataset_pred_map=lambda x, y: y['caption'],
    dataset_ref_map=lambda x, y: y | {'image path': x['image path']}
)

# Compute relevance for all the test queries in the dataset 
# Caption metric: default equality relevance on the decoded caption.
# Concept metric: overlap relevance on the hashable frozenset form.
# Results are not written back to file (save_to_file=False).
test_tot_relevant_cap = compute_total_relevance(test_dataset_reference, reference_preprocess=reference_preprocess_cap, save_to_file=False, fileinfo=relevance_fileinfo_cap | {"split": "test"})
test_tot_relevant_con = compute_total_relevance(test_dataset_reference, reference_preprocess=reference_preprocess_con_hash, relevance=concept_relevance_hash, save_to_file=False, fileinfo=relevance_fileinfo_con | {"split": "test"})

## Text to Image Task

In [None]:
print("\n### Scoring test data ###")

# Text-to-image: the queries are the captions of the test dataset
test_queries = test_dataset.map(lambda x, y: y["caption"])

# Compute matching results and derive the relevant matches under each criterion
test_raw_results = find_t2i_matches(test_queries, clip_text_encoder, test_image_embeddings, k=k, normalize=True)
test_results = index_to_reference(test_raw_results, test_dataset_reference)
# Relevant-at-k counts under caption equality and under concept overlap
test_relevant_cap = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_cap)
test_relevant_con = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_con, relevance=concept_relevance)

# Compute alternative caption results based on concept text concatenation
# Each query is the space-joined text of the concepts of one reference element
test_queries_fromconcepts = tf.data.Dataset.from_tensor_slices(decode_concepts([e["concepts"] for e in test_dataset_reference], concepts_onehot, concept_list, string_form=True)).batch(batch_size)
test_results_fromconcepts = index_to_reference(find_t2i_matches(test_queries_fromconcepts, clip_text_encoder, test_image_embeddings, k=k, normalize=True), test_dataset_reference)

### Caption equality relevance metric

In [None]:
# Text-to-image, caption equality: printed report at the configured k
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_cap, test_tot_relevant_cap,
    k=k,
    metrics=metrics,
    title=f"Test Data - Caption equality metrics @ k={k}",
    decimal_precision=decimal_precision
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_cap,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Test Data - Caption equality metrics"),
    reference_preprocess=reference_preprocess_cap
)

In [None]:
# Compare retrieval with the dataset captions against retrieval with concept-concatenation queries;
# the entries of compare_models are in the same order as the result lists passed below
compare_models = [{"id": "ds", "label": "Dataset Caption"}, {"id": "cpfs", "label": "Concept Fusion"}]
_ = retrieval_graph_compare(
    [test_results, test_results_fromconcepts], test_dataset_reference, compare_models, test_tot_relevant_cap,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Test Data - Dataset Captions vs Concept-Fusion Captions"),
    reference_preprocess=reference_preprocess_cap
)

### Concept overlap relevance metric

In [None]:
# Text-to-image, concept overlap: printed report at the configured k
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_con, test_tot_relevant_con,
    k=k,
    metrics=metrics,
    title=f"Test Data - Concept overlap metrics @ k={k}",
    decimal_precision=decimal_precision,
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_con,
    metrics=metrics,
    k_range=(1, k),
    titlexyf=("k", None, "Test Data - Concept overlap metrics"),
    reference_preprocess=reference_preprocess_con, relevance=concept_relevance,
)

## Image to Text

In [None]:
print("\n### Scoring test data ###")

# Image-to-text: the queries are the images of the test dataset.
# The variables of the text-to-image section are overwritten from here on.
test_queries = test_dataset.map(lambda x, y: x["image"])

# Compute matching results and derive the relevant matches under each criterion
test_raw_results = find_i2t_matches(test_queries, clip_image_encoder, test_text_embeddings, k=k, normalize=True)
test_results = index_to_reference(test_raw_results, test_dataset_reference)
# Relevant-at-k counts under caption equality and under concept overlap
test_relevant_cap = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_cap)
test_relevant_con = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_con, relevance=concept_relevance)

### Caption equality relevance metric

In [None]:
# Image-to-text, caption equality: printed report at the configured k
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_cap, test_tot_relevant_cap,
    k=k,
    metrics=metrics,
    title=f"Test Data - Caption equality metrics @ k={k}",
    decimal_precision=decimal_precision
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_cap,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Test Data - Caption equality metrics"),
    reference_preprocess=reference_preprocess_cap
)

### Concept overlap relevance metric

In [None]:
# Image-to-text, concept overlap: printed report at the configured k
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_con, test_tot_relevant_con,
    k=k,
    metrics=metrics,
    title=f"Test Data - Concept overlap metrics @ k={k}",
    decimal_precision=decimal_precision,
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_con,
    metrics=metrics,
    k_range=(1, k),
    titlexyf=("k", None, "Test Data - Concept overlap metrics"),
    reference_preprocess=reference_preprocess_con, relevance=concept_relevance,
)