In [1]:
import datasets
import torch
%matplotlib inline
import pandas as pd
from cocos.utils import get_project_root

from evaluate import plot_embeddings, compute_metrics, evaluate
import os

# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably this stops the Hugging Face tokenizers library from
# using (and warning about) its own parallelism once worker processes fork - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Turn on caching in the `datasets` library.
datasets.enable_caching()
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): likely chosen to avoid running out of file descriptors in
# data-loading workers - confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

In [2]:
# Checkpoints to evaluate, keyed by the label that ends up in the
# "model_name" column of the results table. Each checkpoint path is written
# relative to the project root and joined onto it when the dict is built.
MODELS = {
    label: get_project_root() / checkpoint
    for label, checkpoint in (
        ('TS,IM,DE', 'checkpoints/paper_models/ts_im_de/epoch=0-step=412499.ckpt'),
        ('TS,IM', 'checkpoints/paper_models/ts_im/epoch=0-step=397499.ckpt'),
        ('TS,DE', 'checkpoints/paper_models/ts_de/epoch=0-step=389999.ckpt'),
        ('TS', 'checkpoints/paper_models/ts/epoch=0-step=397499.ckpt'),
        ('None', 'checkpoints/paper_models/none/epoch=0-step=389999.ckpt'),
    )
}

In [3]:
def print_table(model_metrics):
    """Display one table row per evaluation run, ordered by ascending MAP.

    Args:
        model_metrics: Mapping from model name to a list of run records.
            Each record is a dict whose ``"metrics"`` entry maps metric
            names to values; those metrics become the table columns next to
            ``"model_name"``.

    Returns:
        None. The table is rendered through ``display`` with pandas' row and
        column truncation disabled. ``display`` is not imported here; it is
        expected to be supplied by the IPython/Jupyter session.

    Raises:
        KeyError: If a run record has no ``"metrics"`` entry, or if there
            are runs but none of them reports a ``"MAP"`` metric.
    """
    rows = [
        {"model_name": model_name, **run["metrics"]}
        for model_name, runs in model_metrics.items()
        for run in runs
    ]

    df = pd.DataFrame.from_records(rows)
    # Without any run the frame has no columns at all, so sorting by "MAP"
    # would raise KeyError; show the empty table instead of crashing.
    if rows:
        df = df.sort_values(["MAP"])
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        display(df)

In [4]:
# Evaluate every checkpoint listed in MODELS and collect the resulting
# metrics per model name; each model starts with an empty list of runs.
model_metrics = {label: [] for label in MODELS}

for model_name, model_path in MODELS.items():
    # Banner line plus the model name, so each model's output is easy to spot.
    print("#" * 100, model_name, sep="\n")

    # Only `metrics` is consumed in this cell; the other returned values are
    # still bound to names and therefore stay available in the notebook
    # namespace (overwritten on every iteration).
    qe1, te1, q_l1, t_l1, id2desc1, metrics = evaluate(
        model_path, max_distraction_snippets=10000, device="cpu"
    )
    res = {"metrics": metrics}
    model_metrics[model_name].append(res)


####################################################################################################
TS,IM,DE


Embedding queries:   0%|          | 0/303 [00:00<?, ?it/s]
Embedding queries: 100%|██████████| 303/303 [01:15<00:00,  4.02it/s]
Embedding targets:   0%|          | 0/5303 [00:00<?, ?it/s]
Embedding targets: 100%|██████████| 5303/5303 [05:55<00:00, 14.91it/s]


MAP: 50.87
####################################################################################################
TS,IM


Embedding queries:   0%|          | 0/303 [00:00<?, ?it/s]
Embedding queries: 100%|██████████| 303/303 [01:11<00:00,  4.22it/s]
Embedding targets:   0%|          | 0/5303 [00:00<?, ?it/s]
Embedding targets: 100%|██████████| 5303/5303 [05:45<00:00, 15.36it/s]


MAP: 33.78
####################################################################################################
TS,DE


Embedding queries:   0%|          | 0/303 [00:00<?, ?it/s]
Embedding queries: 100%|██████████| 303/303 [01:14<00:00,  4.09it/s]
Embedding targets:   0%|          | 0/5303 [00:00<?, ?it/s]
Embedding targets: 100%|██████████| 5303/5303 [06:20<00:00, 13.93it/s]


MAP: 36.32
####################################################################################################
TS


Embedding queries:   0%|          | 0/303 [00:00<?, ?it/s]
Embedding queries: 100%|██████████| 303/303 [01:13<00:00,  4.12it/s]
Embedding targets:   0%|          | 0/5303 [00:00<?, ?it/s]
Embedding targets: 100%|██████████| 5303/5303 [05:40<00:00, 15.56it/s]


MAP: 26.47
####################################################################################################
None


Embedding queries:   0%|          | 0/303 [00:00<?, ?it/s]
Embedding queries: 100%|██████████| 303/303 [01:13<00:00,  4.10it/s]
Embedding targets:   0%|          | 0/5303 [00:00<?, ?it/s]
Embedding targets: 100%|██████████| 5303/5303 [06:17<00:00, 14.06it/s]


MAP: 15.65


In [5]:
# Show the collected metrics as one table, one row per run, sorted by ascending MAP.
print_table(model_metrics)

Unnamed: 0,model_name,MAP,NDCG,R-Precision,P@1,P@3,P@10
4,,15.65,49.85,18.32,45.87,37.95,24.77
3,TS,26.47,59.64,27.18,58.09,50.77,36.96
1,"TS,IM",33.78,66.03,33.21,69.8,60.95,45.33
2,"TS,DE",36.32,65.94,35.57,59.41,54.57,44.39
0,"TS,IM,DE",50.87,76.28,48.8,73.6,70.3,59.7
