## Declarations

### Imports

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import itertools
import math
import string
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import plotly.graph_objects as go
import numpy as np
import random
from IPython.display import display
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm

import tensorflow_hub as hub
import tensorflow_text as text

# Short aliases for the Keras namespaces used throughout the notebook
tfk = tf.keras
tfkl = tf.keras.layers
kb = tf.keras.backend
# Report the TensorFlow version and the number of GPUs TensorFlow can see
print(tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

### Constants

In [None]:
# Random seed for reproducibility
seed = 42

# Seed the Python, NumPy and TensorFlow random number generators.
# NOTE(review): PYTHONHASHSEED set here is only picked up by interpreters
# started afterwards, not by the one already running - confirm this is intended.
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

# Turn ON for kaggle filepaths
kaggle = True

# Kaggle input directories, prefixed to the names below when `kaggle` is True
kaggle1 = "/kaggle/input/transformers-hackathon/"
kaggle2 = "/kaggle/input/transformers-hackathon-features/"
kaggle3 = "/kaggle/input/clip-weights-v3/"

# Default (local) locations of the image folder, the TSV files and the weights
image_dir = "./resized_train"
caption_pred_file = "caption_prediction_train.csv"
concept_det_file = "concept_detection_train.csv"
concept_file = "concepts.csv"
clip_weights_file = "clip_weights.h5"

if kaggle:
    # Plain string concatenation: the Kaggle prefixes already end with "/"
    image_dir = kaggle1 + image_dir
    caption_pred_file = kaggle2 + caption_pred_file
    concept_det_file = kaggle2 + concept_det_file
    concept_file = kaggle2 + concept_file
    clip_weights_file = kaggle3 + clip_weights_file

# Image shape as (height, width, channels)
image_size = (128, 128, 3)

# batch_size is read by find_t2i_matches when batching the text queries
batch_size = 32
epochs = 100

# Fraction of the dataset kept by load_features; 1 disables the sub-sampling
filter_percent_dataset = 1

## Preprocessing

In [None]:
def split(x, test_size=0.2, val_size=0.0, seed=0):
    """Split `x` into train / validation / test parts.

    Args:
        x: the collection to split (anything `train_test_split` accepts).
        test_size: fraction of `x` assigned to the test part.
        val_size: fraction of `x` assigned to the validation part.
        seed: random state forwarded to `train_test_split`.

    Returns:
        A `(train, val, test)` tuple, where `val` is None when `val_size`
        is not positive, or None when the two fractions add up to 1 or more.
    """
    holdout_fraction = test_size + val_size
    if holdout_fraction >= 1:
        return None

    # First cut: training data versus everything that is held out
    train_part, holdout_part = train_test_split(
        x, test_size=holdout_fraction, random_state=seed
    )
    if val_size > 0:
        # Second cut: divide the held-out data between test and validation
        test_part, val_part = train_test_split(
            holdout_part,
            test_size=val_size / holdout_fraction,
            random_state=seed,
        )
        return train_part, val_part, test_part
    return train_part, None, holdout_part

def load_image_from_path(path):
    """Read a JPEG file and return its pixels as a float16 tensor divided by 255.

    Args:
        path: location of the JPEG file, passed to `tf.io.read_file`.

    Returns:
        The decoded 3-channel image, cast to float16 and divided by 255.0.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.io.decode_jpeg(raw_bytes, channels=3, dct_method="INTEGER_ACCURATE")

    # No resizing happens here; a tf.image.resize call would be needed if the
    # stored files did not already have the expected size.
    return tf.cast(pixels, dtype=tf.float16) / 255.0

In [None]:
# TensorFlow dtype of every feature that can appear in a dataset element
feature_types = {'image': tf.float16, 'caption': tf.string, 'concepts': tf.bool, 'raw caption': tf.string, 'image path': tf.string}
# Static shape assigned to each feature by create_dataset. The concepts entry
# is a one-element tuple: a bare `(8374)` is just the integer 8374.
feature_shapes = {'image': (128, 128, 3), 'caption': (), 'concepts': (8374,)}
base_features = ["image", "caption"]

# Concept table: maps each concept id to its human-readable name
concepts = pd.read_csv(concept_file, sep='\t')
concept_list = concepts.set_index('concept')['concept_name'].to_dict()
# Concept one-hot encoder, with the column order fixed by the concept table
concepts_onehot = MultiLabelBinarizer(classes = list(concept_list.keys()))
_ = concepts_onehot.fit([list(concept_list.keys())])

In [None]:
def load_features(image_folder, captions_file, concepts_file, concept_encoder, filter_percent=1):
    """Build the list of per-image feature dictionaries from the two TSV files.

    Both files are read in lockstep, so they are assumed to contain the same
    image names in the same order; a mismatch trips the sanity-check assert.

    Args:
        image_folder: directory holding the `<image name>.jpg` files.
        captions_file: tab-separated file with the image name and caption in
            its first two columns (header row skipped).
        concepts_file: tab-separated file with the image name and a
            `;`-separated concept list in its first two columns.
        concept_encoder: fitted encoder whose `transform` turns a list of
            concept lists into a multi-hot array.
        filter_percent: fraction of the rows to keep; any value other than 1
            triggers a random sub-sample of that size.

    Returns:
        A list of dicts with the keys 'caption', 'image path' and 'concepts'.
    """
    def _read_two_columns(path):
        # Both files share the same layout: TSV, header row, two string columns
        return tf.data.experimental.CsvDataset(
            path,
            field_delim='\t',
            record_defaults=[tf.string, tf.string],
            header=True,
            select_cols=[0, 1]
        )

    csv_caption_dataset = _read_two_columns(captions_file)
    csv_concept_dataset = _read_two_columns(concepts_file)

    features = []

    # Extract features from dataset
    print("Extracting features from CSV file(s)")
    for (filename_cap, caption), (filename_con, concepts) in tqdm(zip(csv_caption_dataset, csv_concept_dataset)):
        # Sanity check: the two files must describe the same image on this row
        assert filename_cap == filename_con

        # Use the folder given by the caller rather than the global image_dir
        image_path = image_folder + "/" + filename_cap + ".jpg"
        concept_ids = concepts.numpy().decode("utf-8").split(";")

        features.append({
            'caption': caption,
            'image path': image_path,
            'concepts': concept_encoder.transform([concept_ids]),
        })

    # Filter elements
    if filter_percent != 1:
        n_features = int(len(features) * filter_percent)
        features = random.sample(features, n_features)

    return features

def preprocess_features(features, concept_encoder, filter_percent=1):
    """Convert a list of feature dictionaries into a dictionary of tensors.

    Args:
        features: list of dicts with 'image path', 'caption' and 'concepts'.
        concept_encoder: encoder whose `transform` is applied to each
            element's 'concepts' value before flattening.
        filter_percent: fraction of elements to keep; any value other than 1
            triggers a random sub-sample of that size.

    Returns:
        A dict with the tensors 'image paths', 'captions' and 'concepts'.
    """
    print("Preprocessing features")

    # Optional random sub-sampling
    if filter_percent != 1:
        keep = int(len(features) * filter_percent)
        features = random.sample(features, keep)

    path_list = [item["image path"] for item in tqdm(features)]
    image_paths = tf.convert_to_tensor(path_list, dtype=tf.string)

    caption_list = [item["caption"] for item in tqdm(features)]
    captions = tf.convert_to_tensor(caption_list, dtype=tf.string)

    # One flattened encoder output per element, stacked row-wise
    concept_rows = [concept_encoder.transform(item["concepts"]).flatten() for item in tqdm(features)]
    concepts = tf.convert_to_tensor(np.vstack(concept_rows), dtype=tf.bool)

    return {
        'image paths': image_paths,
        'captions': captions,
        'concepts': concepts,
    }

def create_dataset(
        features, 
        input_features_types,
        feature_shapes,
        x_features, y_features=None, 
        x_dict=True, y_dict=True,
        load_images=True, 
        shuffle_buffer_size=1024, 
        batch_size=10, 
        cached=False
):
    """Turn a list of feature dictionaries into a batched tf.data pipeline.

    Args:
        features: list of per-element feature dicts, served through a generator.
        input_features_types: dict mapping each key present in `features`
            to its TensorFlow dtype.
        feature_shapes: dict mapping feature names to the static shape set on
            the matching tensor; names missing from an element are skipped.
        x_features: names of the features that form the model input.
        y_features: names of the features that form the target, or None to
            produce input-only elements.
        x_dict: keep the input as a dict (True) or squeeze the listed
            features into a single tensor (False).
        y_dict: same choice for the target.
        load_images: when True, add an 'image' feature loaded from the file
            named by 'image path'.
        shuffle_buffer_size: buffer size handed to `Dataset.shuffle`.
        batch_size: number of elements per batch.
        cached: when True, cache the preprocessed elements before shuffling.

    Returns:
        The shuffled, batched and prefetched dataset.
    """
    # Generate dataset following initial input feature types
    # (the second positional argument is the per-feature dtype mapping)
    dataset = tf.data.Dataset.from_generator(
        lambda: features, { x: input_features_types[x] for x in input_features_types }
    )
    
    # Preprocessing internal functions
    def setshape(e):
        # Attach the known static shape to every feature present in the element
        for (k, v) in feature_shapes.items():
            if k in e:
                e[k].set_shape(v)
        return e
    def add_images(e):
        # Maybe parametrize
        img_from = "image path"
        img_to = "image"
        # Copy every input feature and add the image loaded from its path
        new_features = list(input_features_types.keys()) + [img_to]
        return {f:e[f] if f != img_to else load_image_from_path(e[img_from]) for f in new_features}
    def split_xy(e):
        # Squeeze each selected feature, as a dict or as one stacked tensor
        e_x = {xf:tf.squeeze(e[xf]) for xf in x_features} if x_dict else tf.squeeze([e[xf] for xf in x_features])
        if y_features:
            e_y = {yf:tf.squeeze(e[yf]) for yf in y_features} if y_dict else tf.squeeze([e[yf] for yf in y_features])
            return (e_x, e_y)
        return e_x
    
    # Preprocess
    if load_images:
        dataset = dataset.map(add_images)
    dataset = dataset.map(setshape)
    dataset = dataset.map(split_xy)

    # Compile dataset
    if cached:
        dataset = dataset.cache()
    # Shuffling is always applied; the evaluation code asks for a buffer size
    # of 1 so that the element order stays fixed
    dataset = dataset.shuffle(shuffle_buffer_size).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset

In [None]:
# Load dataset features from csv files and split them into train / test parts
# (val_size=0.0, so split returns None for the validation part)
features = load_features(image_dir, caption_pred_file, concept_det_file, concepts_onehot, filter_percent=filter_percent_dataset)
feat_train, feat_val, feat_test = split(features, test_size=0.2, val_size=0.0, seed=seed)

# Dtypes of the features stored in each element, and the features used as
# input (x) and as side data (y) by the evaluation datasets
in_feat_typ = {'caption': tf.string, 'concepts': tf.bool, 'image path': tf.string}
x_features_eval = ['image path', 'image']
y_features_eval = ['caption', 'concepts']

# Sizes of the three parts; a missing or empty part counts as 0
train_ds_size = len(feat_train) if feat_train else 0
val_ds_size = len(feat_val) if feat_val else 0
test_ds_size = len(feat_test) if feat_test else 0

# Evaluation datasets: batch size 1 and shuffle buffer size 1, as required by
# generate_image_embeddings to keep embeddings and reference elements aligned
train_dataset_eval = create_dataset(feat_train, input_features_types=in_feat_typ, feature_shapes=feature_shapes, x_features=x_features_eval, y_features=y_features_eval, x_dict=True, y_dict=True, batch_size=1, shuffle_buffer_size=1)
test_dataset_eval = create_dataset(feat_test, input_features_types=in_feat_typ, feature_shapes=feature_shapes, x_features=x_features_eval, y_features=y_features_eval, x_dict=True, y_dict=True, batch_size=1, shuffle_buffer_size=1)

## Model Import

In [None]:
def projection(embedding_input, embed_dim, name):
    """Project an encoder output into the shared embedding space.

    Args:
        embedding_input: tensor produced by an encoder backbone.
        embed_dim: width of the two dense layers.
        name: prefix of the dense layer names (`<name>_1` and `<name>_2`).

    Returns:
        The layer-normalized sum of the first dense output and the
        SELU -> dense -> dropout branch computed from it.
    """
    skip = tfkl.Dense(embed_dim, name=f'{name}_1')(embedding_input)

    # Residual branch on top of the first projection
    activated = tf.nn.selu(skip)
    branch = tfkl.Dense(embed_dim, name=f'{name}_2')(activated)
    branch = tfkl.Dropout(0.1)(branch)

    merged = tfkl.Add()([branch, skip])
    return tfkl.LayerNormalization()(merged)

def image_encoder(input_shape, embed_dim, seed=42, supernet=None, preprocessing=None):
    """Build the image branch: preprocessing, backbone, pooling and projection.

    Args:
        input_shape: shape of a single input image.
        embed_dim: size of the produced embedding.
        seed: value passed to `tf.random.set_seed` before the layers are built.
        supernet: backbone applied to the preprocessed images.
        preprocessing: callable applied to the raw input before the backbone.

    Returns:
        A Keras model named 'image_encoder' mapping images to embeddings.
    """
    tf.random.set_seed(seed)

    images = tfkl.Input(shape=input_shape, name='img_input_layer')
    backbone_features = supernet(preprocessing(images))
    pooled = tfkl.GlobalAveragePooling2D(name='GAP')(backbone_features)
    embedding = projection(pooled, embed_dim, 'img_embedding_dense_layer')

    # Tie input and output together in a Model
    return tfk.Model(inputs=images, outputs=embedding, name='image_encoder')

def text_encoder(embed_dim, preprocess, transformer, trainable=True):
    """Build the text branch: preprocessing, transformer and projection.

    Args:
        embed_dim: size of the produced embedding.
        preprocess: layer turning raw strings into transformer inputs.
        transformer: layer whose output dict provides "pooled_output".
        trainable: value assigned to `transformer.trainable`.

    Returns:
        A Keras model named "text_encoder" mapping strings to embeddings.
    """
    transformer.trainable = trainable

    captions = tfkl.Input(shape=(), dtype=tf.string, name="text_input")
    transformer_outputs = transformer(preprocess(captions))
    embedding = projection(
        transformer_outputs["pooled_output"], embed_dim, 'txt_embedding_dense_layer'
    )

    return tfk.Model(inputs=captions, outputs=embedding, name="text_encoder")

class CLIP(tfk.Model):
    """Dual-encoder contrastive model pairing an image and a text encoder.

    The model maps a batch of {"image", "caption"} features to a pair of
    embeddings and is trained with a symmetric cross-entropy over the
    image/text similarity matrix, scaled by a trainable temperature.
    """

    def __init__(self, image_encoder, text_encoder, **kwargs):
        """Store the two encoders, create the loss tracker and temperature.

        Args:
            image_encoder: model mapping images to embeddings.
            text_encoder: model mapping caption strings to embeddings.
            **kwargs: forwarded to the Keras `Model` constructor.
        """
        super().__init__(**kwargs)
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.loss_tracker = tfk.metrics.Mean(name="loss")
        # Trainable scalar; CLIP_loss multiplies the logits by exp(temp)
        self.temp = self.add_weight(name='t',
                                 shape=(1, ),
                                 initializer=tfk.initializers.Constant(1.),
                                 trainable=True)

        # Run one dummy forward pass right away.
        # NOTE(review): presumably needed so the model is built before
        # load_weights is called on it - confirm.
        self.call_model()

    @property
    def metrics(self):
        """Metrics tracked by the model: only the running mean of the loss."""
        return [self.loss_tracker]

    def call(self, features, training=False):
        """Embed a batch; returns `(image_embeddings, text_embeddings)`."""
        image_emb = self.image_encoder(features["image"], training=training)
        text_emb = self.text_encoder(features["caption"], training=training)
        return image_emb, text_emb

    def CLIP_loss(self, image_emb, text_emb):
        """Symmetric contrastive loss between image and text embeddings.

        Both embeddings are L2-normalized, their pairwise dot products are
        scaled by exp(temp), and the i-th image is expected to match the
        i-th caption. Returns one loss value per batch element (the mean of
        the image-to-text and text-to-image cross-entropies).
        """
        norm_image_emb = tf.math.l2_normalize(image_emb, axis=1)
        norm_text_emb = tf.math.l2_normalize(text_emb, axis=1)

        logits = tf.linalg.matmul(norm_image_emb, norm_text_emb, transpose_b=True) * tf.math.exp(self.temp)

        # The matching pairs lie on the diagonal of the similarity matrix
        n = tf.shape(logits)[0]
        labels = tf.one_hot(tf.range(n), n)

        loss_img = tfk.losses.categorical_crossentropy(labels, logits, from_logits=True)
        loss_txt = tfk.losses.categorical_crossentropy(labels, kb.transpose(logits), from_logits=True)

        return (loss_img + loss_txt) / tf.constant(2.0)

    def train_step(self, features):
        """Run one optimization step and return the running mean loss."""
        with tf.GradientTape() as tape:
            image_embeddings, caption_embeddings = self(features, training=True)
            # Arguments follow the CLIP_loss signature (image first, text second)
            loss = self.CLIP_loss(image_embeddings, caption_embeddings)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, features):
        """Evaluate one batch and return the running mean loss."""
        image_embeddings, caption_embeddings = self(features, training=False)
        loss = self.CLIP_loss(image_embeddings, caption_embeddings)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def call_model(self):
        """Call the model once on a dummy 128x128x3 image and caption."""
        image = tf.reshape(tf.convert_to_tensor(np.zeros((128,128,3))), (1,128,128,3))
        caption = tf.convert_to_tensor(["Hello there"], dtype=tf.string)

        sample = {"image": image, "caption": caption}

        self(sample)

    def summary(self, *args, **kwargs):
        """Print the summary of the model and of both encoders.

        Any arguments are forwarded unchanged to each underlying
        `Model.summary` call, so the usual Keras options keep working.
        """
        super().summary(*args, **kwargs)

        print("\n")
        self.image_encoder.summary(*args, **kwargs)

        print("\n")
        self.text_encoder.summary(*args, **kwargs)

def build_clip(img_supernet,
               img_preprocess,
               text_transformer,
               text_preprocess,
               img_input_shape=(128,128,3),
               txt_input_shape=(393, ), 
               embed_dim=64, 
               learning_rate=2e-5):
    """Assemble both encoders and the compiled CLIP model.

    Args:
        img_supernet: image backbone handed to `image_encoder`.
        img_preprocess: image preprocessing callable.
        text_transformer: transformer layer handed to `text_encoder`.
        text_preprocess: text preprocessing layer.
        img_input_shape: shape of one input image.
        txt_input_shape: accepted but not used by this function.
        embed_dim: size of the shared embedding space.
        learning_rate: learning rate of the AdamW optimizer.

    Returns:
        `(image_encoder_model, text_encoder_model, clip)`.
    """
    # The text branch is built first, then the image branch
    txt_model = text_encoder(embed_dim, text_preprocess, text_transformer)
    img_model = image_encoder(
        img_input_shape,
        embed_dim,
        supernet=img_supernet,
        preprocessing=img_preprocess,
    )

    model = CLIP(img_model, txt_model)
    optimizer = tf.optimizers.AdamW(learning_rate=learning_rate)
    model.compile(optimizer=optimizer)

    return img_model, txt_model, model

In [None]:
# Text preprocessing layer loaded from TF Hub (uncased English BERT)
text_preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        name="text_preprocessing",
    )

# Small BERT text transformer (L-4, H-512, A-8) loaded from TF Hub
text_transformer = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
        trainable=True,
        name="bert",
    )

# Image backbone: ConvNeXtTiny with ImageNet weights and no classifier head
img_preprocess = tfk.applications.convnext.preprocess_input
img_supernet = tfk.applications.ConvNeXtTiny(weights='imagenet', include_top=False)
supernet_name = img_supernet.name

clip_image_encoder, clip_text_encoder, clip = build_clip(img_supernet, img_preprocess, text_transformer, text_preprocess)

# Freeze both backbones before loading the stored weights.
# NOTE(review): presumably this reproduces the trainable state the weight
# file was saved with - confirm against the training notebook.
img_supernet.trainable = False
text_transformer.trainable = False

clip.load_weights(clip_weights_file)

## Model Evaluation

### Evaluation Definitions

#### Evaluation Variables

In [None]:
# Number of retrieved elements, minimum number of shared concepts for two
# elements to count as relevant, and number of decimals shown in the reports
k = 10
concept_overlap_threshold = 2
decimal_precision = 4


# Function to preprocess data when we want to evaluate captions
def reference_preprocess_cap(x):
    """Return the caption of a reference element decoded to a str."""
    caption_bytes = x["caption"].numpy()
    return caption_bytes.decode('UTF-8')


# Function to preprocess data when we want to evaluate concepts
def reference_preprocess_con(x):
    """Return the concept array of a reference element."""
    return x["concepts"].numpy()


def reference_preprocess_con_hash(x):
    """Return the indices of the active concepts as a hashable frozenset."""
    active_indices = np.where(x["concepts"].numpy())[0]
    return frozenset(active_indices)


# Function to compute if a match is relevant given concept arrays
def concept_relevance(m, o):
    """True when the two concept arrays share at least the threshold count."""
    shared = np.logical_and(m, o)
    return np.count_nonzero(shared) >= concept_overlap_threshold


def concept_relevance_hash(m, o):
    """True when the two concept sets share at least the threshold count."""
    shared = m.intersection(o)
    return len(shared) >= concept_overlap_threshold

In [None]:
# Identifiers of the supported retrieval metrics
METRIC_ACCURACY = "Accuracy"
METRIC_MAP = "MAP"
METRIC_MAR = "MAR"
METRIC_F1 = "F1"

# Metric descriptors: identifier, display name and plot colour.
# F1 is listed last because it is derived from the MAP and MAR values.
metrics = [
    {"id": metric_id, "name": display_name, "color": plot_color}
    for metric_id, display_name, plot_color in (
        (METRIC_ACCURACY, "Accuracy", "green"),
        (METRIC_MAP, "Mean Average Precision", "blue"),
        (METRIC_MAR, "Mean Average Recall", "red"),
        (METRIC_F1, "F1 Score", "blueviolet"),
    )
]

#### Evaluation Functions

In [None]:
# Generate the embeddings and the corresponding dataset reference for an image dataset
def generate_image_embeddings(
    image_encoder,                 # Image encoder of clip model
    dataset_eval,                  # Dataset to generate embeddings (WARNING: the dataset must not be shuffling or have a shuffle buffer size of 1)
    dataset_pred_map=lambda *x: x, # Lambda mapping function for prediction
    dataset_ref_map=lambda *x: x,  # Lambda mapping function for reference
):
    """Embed every image of a dataset and collect the matching side data.

    Returns:
        `(dataset_reference, image_embeddings)`: the list of un-batched
        elements produced by `dataset_ref_map`, and the output of the
        encoder's `predict` on the elements produced by `dataset_pred_map`.
    """
    print("Generating image embeddings")
    # Embeddings of the images selected by the prediction mapping
    prediction_inputs = dataset_eval.map(dataset_pred_map)
    image_embeddings = image_encoder.predict(prediction_inputs, verbose=1)
    # Side data of every element, one entry per un-batched element
    reference_elements = dataset_eval.map(dataset_ref_map).unbatch()
    dataset_reference = list(reference_elements)
    return dataset_reference, image_embeddings

# Return the results in the form of reference dataset indexes of a text to image retrieval for a series of queries
def find_t2i_matches(
    queries,                # Queries to search
    text_encoder,           # Text encoder of clip model
    image_embeddings,       # Generated image embeddings
    k=10,                   # Number of elements for top-k
    normalize=True,         # Embedding normalization
):
    """Return, for each query, the indices of its k most similar images.

    The queries are batched with the module-level `batch_size`, embedded by
    the text encoder and compared with the image embeddings by dot product
    (after L2 normalization of both sides when `normalize` is True).
    """
    print("Computing Text-to-Image matches")
    # Embed the queries
    query_batches = tf.data.Dataset.from_tensor_slices(queries).batch(batch_size)
    query_embedding = text_encoder.predict(query_batches)
    # Optional L2 normalization of both embedding sets
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
    # Similarity matrix: one row per query, one column per image
    similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
    # Indices of the k highest similarities of every row
    top_matches = tf.math.top_k(similarity, k)
    return top_matches.indices.numpy()

# Extract the reference dataset objects given a list of indexes
def index_to_reference(results, dataset_reference):
    """Replace every matched index with its reference element plus its index.

    Each returned element is a new dict made of the reference element's
    entries and an extra "index" entry holding the matched position.
    """
    referenced = []
    for result in results:
        row = []
        for match in result:
            row.append(dataset_reference[match] | {"index": match})
        referenced.append(row)
    return referenced

In [None]:
# Retrieve relevant items given a list of queries (DO NOT RUN THIS ON A COMPLETE DATASET!!!)
def retrieve_relevant(queries, reference_preprocess=lambda x: x, relevance=lambda m, o: m == o, dataset_reference=None):
    """Return, for every query, the preprocessed reference elements relevant to it.

    Args:
        queries: iterable of queries, comparable with preprocessed elements.
        reference_preprocess: maps a reference element to the compared value.
        relevance: `relevance(query, element)` is truthy for a relevant pair.
        dataset_reference: reference elements to scan. When omitted, a
            module-level `dataset_reference` is used if one is defined.

    Raises:
        ValueError: if no reference dataset is passed and none is defined
            at module level.
    """
    if dataset_reference is None:
        # Fall back to a module-level variable for callers that relied on it
        dataset_reference = globals().get("dataset_reference")
        if dataset_reference is None:
            raise ValueError("retrieve_relevant needs a dataset_reference to search")
    # Preprocess the reference once rather than once per query
    candidates = [reference_preprocess(element) for element in dataset_reference]
    return [
        [candidate for candidate in candidates if relevance(query, candidate)]
        for query in queries
    ]

# Compute the number of relevant items in the first k matches in a list of results
# If queries is None, it is assumed that the queries ran to obtain the list of results are parallel to the elements in dataset_reference
def compute_relevant_at_k(results, dataset_reference, queries=None, k=None, reference_preprocess=lambda x: x, relevance=lambda m, o: m == o):
    """Count, for every result list, the relevant items among its first k matches.

    Args:
        results: one list of matched reference elements per query.
        dataset_reference: reference elements; in parallel mode the i-th
            element is the ground truth of the i-th result list.
        queries: optional explicit queries, one per result list. When given,
            each match is compared with its query via `relevance(query, match)`.
        k: number of leading matches to inspect; defaults to the length of
            the first result list.
        reference_preprocess: applied to matches (and to reference elements
            in parallel mode) before comparison.
        relevance: predicate telling whether two values are relevant.

    Returns:
        A list with one count per result list (empty when `results` is empty).
    """
    if len(results) == 0:
        return []
    if not k:
        k = len(results[0])
    if queries is not None:
        # A match is a hit when it is relevant to the query that produced it
        return [
            np.count_nonzero([
                relevance(query, reference_preprocess(match))
                for match in itertools.islice(matches, k)
            ])
            for matches, query in zip(results, queries)
        ]
    references = map(reference_preprocess, dataset_reference)
    return [
        np.count_nonzero([
            relevance(reference_preprocess(match), reference)
            for match in itertools.islice(matches, k)
        ])
        for matches, reference in zip(results, references)
    ]

# Computes the total number of relevant elements for a dataset or queries
# It is assumed that the element returned by reference_preprocess is hashable and can be used as a dictionary key
# If queries is None, it is assumed that the queries ran to obtain the list of results are parallel to the elements in dataset_reference
def compute_total_relevance(dataset_reference, queries=None, reference_preprocess=lambda x: x, relevance=lambda m, o: m == o):
    """Return how many reference elements are relevant to each query/element.

    With `queries`, every query is matched against the whole reference
    dataset. Without them, each reference element acts as its own query:
    equal elements are grouped first, and the relevance predicate is then
    evaluated once per pair of distinct values instead of per pair of elements.

    Returns:
        A list of counts, parallel to `queries` when given, otherwise
        parallel to `dataset_reference`.
    """
    # Check if queries are passed, if so run general function
    if queries is not None:
        per_query = retrieve_relevant(
            queries,
            reference_preprocess=reference_preprocess,
            relevance=relevance,
            dataset_reference=dataset_reference,
        )
        return [len(found) for found in per_query]
    # Build preprocessed dataset
    relevant_reference = list(map(reference_preprocess, dataset_reference))
    # Number of occurrences of every distinct preprocessed value
    counts = {}
    for element in relevant_reference:
        counts[element] = counts.get(element, 0) + 1
    # Check bytecode of relevance function to determine if the relevance function is equality,
    # if so, return counts, otherwise apply relevance to the whole dataset.
    # Callables without a __code__ attribute are treated as non-equality.
    equality_code = (lambda m, o: m == o).__code__.co_code
    relevance_code = getattr(getattr(relevance, "__code__", None), "co_code", None)
    if relevance_code != equality_code:
        totals = {}
        for element in tqdm(counts):
            from_others = sum(
                counts[other] for other in counts
                if other != element and relevance(other, element)
            )
            # Duplicates of a self-relevant value are all relevant to each
            # other; a value that is not self-relevant still counts itself
            # once, so the total never drops to zero.
            from_self = counts[element] if relevance(element, element) else 1
            totals[element] = from_others + from_self
        counts = totals
    return [counts[element] for element in relevant_reference]

def load_relevance_from_csv(filename, dataset_reference, reference_preprocess=lambda x: x):
    """Placeholder: not implemented yet, does nothing and returns None."""
    # TODO
    pass

def save_relevance_to_csv(filename, total_n):
    """Placeholder: not implemented yet, does nothing and returns None."""
    # TODO
    # Sketch of a possible implementation, kept for reference:
    # df = pd.DataFrame(test_tot_relevant_con, index=[0]) 
    # df.to_csv(r'TotRelevant_Train_1.csv', index=False, header=True)
    pass

In [None]:
def compute_top_k_accuracy(results, dataset_reference, relevant_at_k):
    """Share of reference elements whose query retrieved at least one relevant item.

    `results` is accepted for signature symmetry with the other metric
    functions and is not read.
    """
    queries_with_hit = np.count_nonzero(relevant_at_k)
    n_reference = len(dataset_reference)
    return queries_with_hit / n_reference

def compute_map_k(results, dataset_reference, relevant_at_k, k=None):
    """Mean precision at k over the reference dataset.

    Each query's precision is its number of relevant hits divided by k;
    k defaults to the length of the first result list.
    """
    cutoff = k or len(results[0])
    precisions = [hits / cutoff for hits in relevant_at_k]
    n_reference = len(dataset_reference)
    return np.sum(precisions) / n_reference

def compute_mar_k(results, dataset_reference, relevant_at_k, total_relevant):
    """Mean recall at k over the reference dataset.

    Each query's recall is its number of relevant hits divided by the total
    number of relevant elements for that query. `results` is not read.
    """
    recalls = [
        hits / available
        for hits, available in zip(relevant_at_k, total_relevant)
    ]
    n_reference = len(dataset_reference)
    return np.sum(recalls) / n_reference

def compute_F1_k(precision=0, recall=0):
    """Return the F1 score (harmonic mean) of a precision and a recall value.

    Returns 0 when both inputs are 0, where the harmonic mean is undefined.
    """
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)

In [None]:
# Visualize results for text to image queries
def visualize_t2i_results(query, matches):
    """Print the captions and plot the images of the matches of one query.

    Args:
        query: the query text, shown in the header line.
        matches: list of reference elements; an element is plotted when it
            has an "image path" entry and printed when it has a "caption".

    The image grid has 3 columns and at least 3 rows, and grows with the
    number of matches so that more than 9 matches can be displayed.
    """
    print("Top matches for query: \"" + query + "\"")
    n_cols = 3
    n_rows = max(3, math.ceil(len(matches) / n_cols))
    # Open a figure only when there is at least one image to draw
    if any("image path" in match for match in matches):
        plt.figure(figsize=(18, 6 * n_rows))
    for i, match in enumerate(matches):
        if "image path" in match:
            path = match["image path"].numpy().decode('UTF-8')
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(mpimg.imread(path))
            plt.axis("off")
        if "caption" in match:
            caption = match["caption"].numpy().decode('UTF-8')
            print(f"{i}) {caption}")
        
# Standard visualization for a multi-purpose plotly graph
def visualize_multigraph(functions, titlexyf=(None, None, None), legend=True):
    """Draw several curves on one plotly figure and show it.

    Args:
        functions: iterable of dicts with the keys 'x' and 'y' and the
            optional keys 'label', 'color', 'style' (line dash), 'marker'
            (marker symbol) and 'opacity'.
        titlexyf: tuple `(x axis title, y axis title, figure title)`.
        legend: accepted for compatibility; not used by this function.

    The x axis spans the smallest to the largest x value over all curves,
    and is left to plotly's automatic range when there are no points.
    """
    fig = go.Figure()
    x_values = []
    for function in functions:
        x = function['x']
        y = function['y']
        label = function['label'] if 'label' in function else ""
        color = function['color'] if 'color' in function else None
        linestyle = function['style'] if 'style' in function else "solid"
        marker = go.scatter.Marker(symbol=function['marker']) if 'marker' in function else None
        opacity = function['opacity'] if 'opacity' in function else 1
        x_values.extend(x)
        fig.add_trace(go.Scatter(
            x=x, y=y,
            line=go.scatter.Line(color=color, dash=linestyle),
            opacity=opacity,
            marker=marker,
            mode="lines+markers+text" if marker else "lines+text",
            name=label,
        ))
    # Derive the axis range from the data rather than from the point count
    x_range = (min(x_values), max(x_values)) if x_values else None
    fig.update_xaxes(
        title=titlexyf[0],
        ticks="outside", ticklen=8, minor=dict(dtick=0.5, ticklen=6, tickcolor="black", showgrid=True), ticklabelstep=1, dtick=1, 
        range=x_range, 
    )
    fig.update_yaxes(
        title=titlexyf[1],
        ticks="outside", ticklen=8, minor=dict(dtick=0.01, ticklen=6, tickcolor="black", showgrid=True), ticklabelstep=1, dtick=0.1,
    )
    fig.update_layout(
        title=titlexyf[2],
        width=900, height=600,
        margin=dict(l=50, r=50, b=20, t=40, pad=4),
        paper_bgcolor="LightSteelBlue",
    )
    fig.show()

In [None]:
# Compute baselines for retrieval
# Assumption of sampling with repetitions, results get more inaccurate as k/l -> inf
def retrieval_baselines(dataset_reference, total_relevant, k, metrics=[]):
    """Compute the baseline value of every requested metric.

    Args:
        dataset_reference: reference elements; only their number is used.
        total_relevant: number of relevant elements for each query.
        k: number of retrieved elements per query.
        metrics: metric descriptors (dicts with an "id" entry). F1 reads the
            MAP and MAR baselines, so it must be listed after both of them.

    Returns:
        A dict mapping metric ids to their baseline values.
    """
    n_items = len(dataset_reference)
    baseline = {}
    for metric in metrics:
        metric_id = metric["id"]
        if metric_id == METRIC_ACCURACY:
            # Probability of at least one relevant element in k draws
            terms = [1 - pow((n_items - n_rel) / n_items, k) for n_rel in total_relevant]
            baseline[METRIC_ACCURACY] = sum(terms) / n_items
        elif metric_id == METRIC_MAP:
            terms = [n_rel / n_items for n_rel in total_relevant]
            baseline[METRIC_MAP] = sum(terms) / n_items
        elif metric_id == METRIC_MAR:
            terms = [k / n_items for _ in total_relevant]
            baseline[METRIC_MAR] = sum(terms) / n_items
        elif metric_id == METRIC_F1:
            baseline[METRIC_F1] = compute_F1_k(baseline[METRIC_MAP], baseline[METRIC_MAR])
    return baseline
    
# Computation of a retrieval report containing metrics
def retrieval_report(
    results, reference, relevant,   # Task results, dataset reference and relevant hits at k for task
    tot_relevant=None,              # Total number of relevant elements for each dataset element
    k=None,                         # k for metrics computation (should be less or equal than k of retrieval)
    baselines=True,                 # Calculate baselines alongside metrics
    metrics=[],                     # Metrics to take into consideration
    output=True,                    # Print outputs to stdout
    title="Retrieval Report",       # Title of the report
    decimal_precision=4,            # Decimal precision of values
):
    """Compute (and optionally print) the requested retrieval metrics.

    F1 is derived from the MAP and MAR values already computed, so it must
    be listed after both of them in `metrics`.

    Returns:
        `(metrics_out, baseline_values)` when `baselines` is truthy,
        otherwise `metrics_out` alone. Both are dicts keyed by metric id.
    """
    if not k:
        k = len(results[0])
    metrics_out = {}

    for metric in metrics:
        if metric["id"] == METRIC_ACCURACY:
            metrics_out[METRIC_ACCURACY] = compute_top_k_accuracy(results, reference, relevant)
        elif metric["id"] == METRIC_MAP:
            metrics_out[METRIC_MAP] = compute_map_k(results, reference, relevant, k=k)
        elif metric["id"] == METRIC_MAR:
            metrics_out[METRIC_MAR] = compute_mar_k(results, reference, relevant, tot_relevant)
        elif metric["id"] == METRIC_F1:
            metrics_out[METRIC_F1] = compute_F1_k(metrics_out[METRIC_MAP], metrics_out[METRIC_MAR])

    # Keep the flag and the computed values apart, so that the return shape
    # depends on what the caller asked for and not on the values themselves
    with_baselines = bool(baselines)
    baseline_values = None
    if with_baselines:
        baseline_values = retrieval_baselines(reference, tot_relevant, k, metrics=metrics)

    if output:
        print(f"\n ### {title} ###")
        for metric in metrics:
            line = f"{metric['name']:<30}: {round(metrics_out[metric['id']] * 100, decimal_precision):10}%"
            if with_baselines:
                line += f"{'   Baseline':<8}: {round(baseline_values[metric['id']] * 100, decimal_precision):10}%"
            print(line)

    if with_baselines:
        return metrics_out, baseline_values
    return metrics_out
        
# Computation of a retrieval report in graph form containing metrics
def retrieval_graph_report(
    results, reference,                 # Task results, dataset reference and relevant hits at k for task
    tot_relevant=None,                  # Total number of relevant elements for each dataset element
    k_range=(1, 10),                    # k range for metrics computation (maximum value should not be greater than k of retrieval)
    baselines=True,                     # Calculate baselines alongside metrics
    metrics=[],                         # Metrics to take into consideration
    titlexyf=(None, None, None),        # Tuple containing: (title of x axis, title of y axis, figure title)
    reference_preprocess=lambda x: x,   # Function to preprocess data contained in the reference dataset
    relevance=lambda m, o: m == o,      # Function to compare elements
):
    """Plot every requested metric as a function of k.

    For each k in the inclusive `k_range` the relevant hits are recomputed
    and passed to `retrieval_report`; the resulting values (and baselines,
    when requested) are accumulated into one curve per metric.

    Returns:
        The dict of plotted curves, keyed by metric id (baseline curves use
        the metric id followed by "_base").
    """
    functions = {metric["id"]: {"x": [], "y": [], "label": metric["id"], "color": metric["color"], "marker": "0", "opacity": 0.8} for metric in metrics}
    if baselines:
        functions |= {metric["id"] + "_base": {"x": [], "y": [], "label": metric["id"] + " Baseline", "color": metric["color"], "style": "dash", "opacity": 0.5} for metric in metrics}
    for k in range(k_range[0], k_range[1] + 1):
        relevant = compute_relevant_at_k(results, reference, k=k, reference_preprocess=reference_preprocess, relevance=relevance)
        report = retrieval_report(results, reference, relevant, tot_relevant, k=k, baselines=baselines, metrics=metrics, output=False)
        # The report is a (metrics, baselines) tuple only when baselines were
        # computed; otherwise it is the metrics dict alone
        if isinstance(report, tuple):
            metrics_out, baseline_out = report
        else:
            metrics_out, baseline_out = report, None
        for metric_id, value in metrics_out.items():
            functions[metric_id]["x"].append(k)
            functions[metric_id]["y"].append(value)
            if baselines and baseline_out is not None:
                functions[metric_id + "_base"]["x"].append(k)
                functions[metric_id + "_base"]["y"].append(baseline_out[metric_id])
    visualize_multigraph(functions.values(), titlexyf)
    return functions
    
# Manually compute some text to image queries
def manual_t2i_queries(queries, text_encoder, image_embeddings, dataset_reference, k=10, normalize=True):
    """Run free-text queries against an embedded dataset and show the matches.

    Args:
        queries: list of query strings.
        text_encoder: text encoder used to embed the queries.
        image_embeddings: embeddings of the images to search.
        dataset_reference: reference elements parallel to `image_embeddings`.
        k: number of matches shown per query.
        normalize: whether embeddings are L2-normalized before comparison.
    """
    # Use the encoder, embeddings and reference passed by the caller
    matched_indices = find_t2i_matches(queries, text_encoder, image_embeddings, k=k, normalize=normalize)
    results = index_to_reference(matched_indices, dataset_reference)
    for query, matches in zip(queries, results):
        visualize_t2i_results(query, matches)

### Test Set Evaluation

In [None]:
print("\n### Scoring test data ###")
# Embed the test images; the reference keeps the caption, the concepts and
# the image path of every element
test_dataset_reference, test_image_embeddings = generate_image_embeddings(
    clip_image_encoder,
    test_dataset_eval,
    dataset_pred_map=lambda x, y: x['image'],
    dataset_ref_map=lambda x, y: y | {'image path': x['image path']}
)
# Every caption of the test set is used as a query
test_queries = [e["caption"] for e in test_dataset_reference]
# Compute the total number of relevant elements for every query, first with
# caption equality and then with concept overlap as the relevance criterion
test_tot_relevant_cap = compute_total_relevance(test_dataset_reference, reference_preprocess=reference_preprocess_cap)
test_tot_relevant_con = compute_total_relevance(test_dataset_reference, reference_preprocess=reference_preprocess_con_hash, relevance=concept_relevance_hash)
# Compute matching results and extract the relevant matches under both criteria
test_raw_results = find_t2i_matches(test_queries, clip_text_encoder, test_image_embeddings, k=k, normalize=True)
test_results = index_to_reference(test_raw_results, test_dataset_reference)
test_relevant_cap = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_cap)
test_relevant_con = compute_relevant_at_k(test_results, test_dataset_reference, k=k, reference_preprocess=reference_preprocess_con, relevance=concept_relevance)

#### Caption equality relevance metric

In [None]:
# Textual report at the retrieval k, using caption equality as relevance
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_cap, test_tot_relevant_cap,
    k=k,
    metrics=metrics,
    title="Test Data - Caption equality metrics",
    decimal_precision=decimal_precision
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_cap,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Test Data - Caption equality metrics"),
    reference_preprocess=reference_preprocess_cap
)

#### Concept overlap relevance metric

In [None]:
# Textual report at the retrieval k, using concept overlap as relevance
_ = retrieval_report(
    test_results, test_dataset_reference, test_relevant_con, test_tot_relevant_con,
    k=k,
    metrics=metrics,
    title="Test Data - Concept overlap metrics",
    decimal_precision=decimal_precision,
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    test_results, test_dataset_reference, test_tot_relevant_con,
    metrics=metrics,
    k_range=(1, k),
    titlexyf=("k", None, "Test Data - Concept overlap metrics"),
    reference_preprocess=reference_preprocess_con, relevance=concept_relevance,
)

### Training Set Evaluation

In [None]:
print("### Scoring training data ###")
# Embed the training images; the reference keeps the caption, the concepts
# and the image path of every element
train_dataset_reference, train_image_embeddings = generate_image_embeddings(
    clip_image_encoder,
    train_dataset_eval,
    dataset_pred_map=lambda x, y: x['image'],
    dataset_ref_map=lambda x, y: y | {'image path': x['image path']}
)
# Every caption of the training set is used as a query
train_queries = [e["caption"] for e in train_dataset_reference]
# Compute the total number of relevant elements for every query, first with
# caption equality and then with concept overlap as the relevance criterion.
# Both totals are computed on the training reference, so that they stay
# parallel to the training results they are compared with.
train_tot_relevant_cap = compute_total_relevance(train_dataset_reference, reference_preprocess=reference_preprocess_cap)
train_tot_relevant_con = compute_total_relevance(train_dataset_reference, reference_preprocess=reference_preprocess_con_hash, relevance=concept_relevance_hash)
# Compute matching results and extract the relevant matches under both criteria
train_raw_results = find_t2i_matches(train_queries, clip_text_encoder, train_image_embeddings, k=k, normalize=True)
train_results = index_to_reference(train_raw_results, train_dataset_reference)
train_relevant_cap = compute_relevant_at_k(train_results, train_dataset_reference, k=k, reference_preprocess=reference_preprocess_cap)
train_relevant_con = compute_relevant_at_k(train_results, train_dataset_reference, k=k, reference_preprocess=reference_preprocess_con, relevance=concept_relevance)

#### Caption equality relevance metric

In [None]:
# Textual report at the retrieval k, using caption equality as relevance
_ = retrieval_report(
    train_results, train_dataset_reference, train_relevant_cap, train_tot_relevant_cap,
    k=k,
    metrics=metrics,
    title="Training Data - Caption equality metric",
    decimal_precision=decimal_precision
)
# Same metrics plotted for every cutoff from 1 to k
_ = retrieval_graph_report(
    train_results, train_dataset_reference, train_tot_relevant_cap,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Training Data - Caption equality metrics"),
    reference_preprocess=reference_preprocess_cap
)

#### Concept overlap relevance metric

In [None]:
# Textual report at the retrieval k, using concept overlap as relevance
_ = retrieval_report(
    train_results, train_dataset_reference, train_relevant_con, train_tot_relevant_con,
    k=k,
    metrics=metrics,
    title="Training Data - Concept overlap metric",
    decimal_precision=decimal_precision
)
# Same metrics plotted for every cutoff from 1 to k. The graph uses the
# concept preprocessing and relevance function, matching the totals it is
# given and the concept overlap section it belongs to.
_ = retrieval_graph_report(
    train_results, train_dataset_reference, train_tot_relevant_con,
    k_range=(1, k),
    metrics=metrics, 
    titlexyf=("k", None, "Training Data - Concept overlap metrics"),
    reference_preprocess=reference_preprocess_con, relevance=concept_relevance,
)