In [1]:
from new_eval.model import ImageEmbeddingGraph, ImageGraph
from new_eval.new_evaluation import Evaluation
from typing import Callable, Dict, List
import arguebuf as ab
from time import time

In [2]:
# Hand-curated mapping: topic name -> list of names that belong to that topic.
# Member lists mix what are presumably casebase graph names ("nodesetNNNN") with
# retrieval-query names (e.g. "death1".."death4" from the simple query set,
# cf. "../data/retrieval_queries/microtexts-retrieval-simple/death4.json").
# Three lists also contain the (lower-cased) topic name itself:
# "cap_rent_increases", "charge_tuition_fees" and
# "eu_influence_on_political_events_in_ukraine".
# simulate_mac_phase looks a query up here either by topic key (case-insensitive)
# or by membership in a list (case-sensitive).
topic_mapping = {'allow_shops_to_open_on_holidays_and_sundays': ['nodeset6375',
  'nodeset6410',
  'nodeset6419',
  'nodeset6449',
  'nodeset6451',
  'nodeset6457',
  'nodeset6462',
  'nodeset6466'],
 'health_insurance_cover_complementary_medicine': ['nodeset6363',
  'nodeset6370',
  'nodeset6373',
  'nodeset6378',
  'nodeset6385',
  'nodeset6386',
  'nodeset6395',
  'nodeset6412',
  'med1',
  'med2',
  'med3',
  'med4'],
 'higher_dog_poo_fines': ['nodeset6362',
  'nodeset6367',
  'nodeset6371',
  'nodeset6392',
  'nodeset6400',
  'nodeset6420',
  'nodeset6452',
  'nodeset6468',
  'dog1',
  'dog2',
  'dog3',
  'dog4'],
 'introduce_capital_punishment': ['nodeset6366',
  'nodeset6383',
  'nodeset6387',
  'nodeset6391',
  'nodeset6450',
  'nodeset6453',
  'nodeset6464',
  'nodeset6469',
  'death1',
  'death2',
  'death3',
  'death4'],
 'public_broadcasting_fees_on_demand': ['nodeset6364',
  'nodeset6374',
  'nodeset6389',
  'nodeset6446',
  'nodeset6454',
  'nodeset6463',
  'nodeset6470',
  'media1',
  'media2',
  'media3',
  'media4'],
 'cap_rent_increases': ['nodeset6369',
  'nodeset6377',
  'nodeset6384',
  'nodeset6418',
  'nodeset6455',
  'nodeset6465',
  'rent1',
  'rent2',
  'rent3',
  'rent4',
  'cap_rent_increases'],
 'charge_tuition_fees': ['nodeset6381',
  'nodeset6388',
  'nodeset6394',
  'nodeset6407',
  'nodeset6447',
  'nodeset6456',
  'tuition1',
  'tuition2',
  'tuition3',
  'tuition4',
  'charge_tuition_fees'],
 'keep_retirement_at_63': ['nodeset6382',
  'nodeset6409',
  'nodeset6411',
  'nodeset6416',
  'nodeset6421',
  'nodeset6461'],
 'over_the_counter_morning_after_pill': ['nodeset6368',
  'nodeset6397',
  'nodeset6402',
  'nodeset6406',
  'nodeset6414'],
 'increase_weight_of_BA_thesis_in_final_grade': ['nodeset6376',
  'nodeset6408',
  'nodeset6448',
  'nodeset6467'],
 'stricter_regulation_of_intelligence_services': ['nodeset6365',
  'nodeset6401',
  'nodeset6405',
  'nodeset6458'],
 'EU_influence_on_political_events_in_Ukraine': ['nodeset6399',
  'nodeset6415',
  'nodeset6460',
  'eu_influence_on_political_events_in_ukraine'],
 'make_video_games_olympic': ['nodeset6380', 'nodeset6396', 'nodeset6417'],
 'school_uniforms': ['nodeset6372', 'nodeset6390', 'nodeset6398'],
 'TXL_airport_remain_operational_after_BER_opening': ['nodeset6403',
  'nodeset6422',
  'nodeset6459'],
 'buy_tax_evader_data_from_dubious_sources': ['nodeset6379', 'nodeset6404'],
 'partial_housing_development_at_Tempelhofer_Feld': ['nodeset6393',
  'nodeset6413'],
 'waste_separation': ['nodeset6361'],
 'other': ['nodeset6423',
  'nodeset6424',
  'nodeset6425',
  'nodeset6426',
  'nodeset6427',
  'nodeset6428',
  'nodeset6429',
  'nodeset6430',
  'nodeset6431',
  'nodeset6432',
  'nodeset6433',
  'nodeset6434',
  'nodeset6435',
  'nodeset6436',
  'nodeset6437',
  'nodeset6438',
  'nodeset6439',
  'nodeset6440',
  'nodeset6441',
  'nodeset6442',
  'nodeset6443',
  'nodeset6444',
  'nodeset6445']}
def simulate_mac_phase(queries: List[ImageGraph]) -> Dict[str, List[str]]:
    """Simulate the MAC phase with the hand-curated ``topic_mapping``.

    A query belongs to the first topic (in ``topic_mapping`` order) whose key
    equals the query name case-insensitively, or whose member list contains the
    query name verbatim (case-sensitive). The query name is ``query.name`` cut
    at its first ".".

    Args:
        queries: query graphs; only their ``name`` attribute is read.

    Returns:
        Mapping from the lower-cased query name to a new list holding the
        members of the matched topic. Queries that match no topic are omitted.
    """
    candidates = {}
    for query in queries:
        stem = query.name.split(".")[0]
        for topic, members in topic_mapping.items():
            # Key match is checked first, exactly as before; both hits are handled identically.
            if stem.lower() == topic.lower() or stem in members:
                candidates[stem.lower()] = list(members)
                break
    return candidates

In [3]:
from glob import glob
from tqdm import tqdm

BASEPATH = "../data/eval_all"
GRAPHPATH = "../data/graphs/microtexts"
QUERY_BASEPATH = "../data/retrieval_queries"


def test(model_pt, model_ft, processor, version: str) -> dict[str, dict[str, float]]:
    """Evaluate a pre-trained and a fine-tuned image encoder on both query sets.

    For every combination of model ("pt"/"ft") and query set ("simple"/"complex"),
    this function works in three steps:

    1. Embed the casebase images rendered for ``version``.
    2. Embed the corresponding query images.
    3. Run ``Evaluation`` against the expert rankings stored in the query graphs
       and against the topic candidates from ``simulate_mac_phase``.

    Args:
        model_pt: pre-trained model. It is called as ``model(**inputs)``, and its
            ``pooler_output`` is used as the embedding.
        model_ft: fine-tuned model with the same calling convention.
        processor: image processor, called as ``processor(image, return_tensors="pt")``.
        version: rendering sub-directory (e.g. ``"v1"``) holding casebase and query images.

    Returns:
        Mapping ``"{version}-{pt|ft}-{simple|complex}"`` -> ``Evaluation.as_dict()``.
        The time spent embedding the queries is added to each entry's ``"duration"``.
    """
    import os.path

    import torch

    def _stem(path: str) -> str:
        # File name up to its first ".", e.g. ".../death4.json" -> "death4".
        # os.path.basename also handles the os.sep that glob uses on Windows.
        return os.path.basename(path).split(".")[0]

    def embed_with(model) -> Callable:
        """Return an embedding function bound to ``model``."""

        def embedd(image):
            inputs = processor(image, return_tensors="pt")
            # Inference only: don't build/keep an autograd graph for every embedding.
            with torch.no_grad():
                outputs = model(**inputs)
            return outputs.pooler_output

        return embedd

    def get_system_rankings_from_experts(queries: List[ImageGraph]) -> Dict[str, Dict[str, int]]:
        """Read each query's expert ranking from its graph's ``userdata``."""
        rankings = {}
        for query in queries:
            graph = ab.load.file(query.graph_path)
            ranking = graph.userdata["cbrEvaluations"][0]["ranking"]
            # Ranking keys presumably look like "<prefix>/<case name>"; keep the second segment.
            rankings[query.name] = {k.split("/")[1]: v for k, v in ranking.items()}
        return rankings

    def load_cb(graphpath: str, imagepath: str, embed: Callable, mapping_func: Callable | None = None) -> Dict[str, ImageEmbeddingGraph]:
        """Build the casebase: one ImageEmbeddingGraph per ``*.json`` graph in ``graphpath``.

        The image for graph ``name`` is ``{imagepath}/{name}.png``. When
        ``mapping_func`` is given, it is ``{imagepath}/{mapping_func(name)}.png`` instead.
        """
        cb = {}
        # sorted(): glob order is arbitrary; keep runs reproducible across machines.
        for graph in tqdm(sorted(glob(f"{graphpath}/*.json"))):
            name = _stem(graph)
            image_name = mapping_func(name) if mapping_func is not None else name
            cb[name] = ImageEmbeddingGraph(ImageGraph(graph, f"{imagepath}/{image_name}.png"), embed)
        return cb

    res = {}
    for model, model_name in zip((model_pt, model_ft), ("pt", "ft")):
        embedd = embed_with(model)
        for ds_name in ("simple", "complex"):
            # NOTE(review): the casebase does not depend on ds_name and could be built once
            # per model, but only if Evaluation does not mutate it — confirm before hoisting.
            cb = load_cb(GRAPHPATH, f"{BASEPATH}/casebase/{version}", embedd, lambda name: f"microtexts-{name}")
            query_paths = sorted(glob(f"{QUERY_BASEPATH}/microtexts-retrieval-{ds_name}/*.json"))

            qs = []
            times = []
            for q in tqdm(query_paths):
                image_path = f"{BASEPATH}/microtexts-retrieval-{ds_name}/{version}/microtexts-{_stem(q)}.png"
                start = time()
                qs.append(ImageEmbeddingGraph(ImageGraph(q, image_path), embedd))
                times.append(time() - start)

            evaluation = Evaluation(cb, get_system_rankings_from_experts(qs), simulate_mac_phase(qs), qs, embedd, debug=False)
            key = f"{version}-{model_name}-{ds_name}"
            print(key)
            metrics = evaluation.as_dict()
            # Count query-embedding time as part of the retrieval duration.
            metrics["duration"] += sum(times)
            res[key] = metrics
    return res

In [4]:
from new_eval.load_trained_model import load_pt, load_ft
from transformers import AutoImageProcessor
import pandas as pd
TINY = "microsoft/swinv2-tiny-patch4-window8-256"
# Rendering version -> version of the trained pt/ft checkpoints used to evaluate it.
# The v5 renderings are evaluated with the v4 checkpoints.
CHECKPOINT_FOR_VERSION = {"v1": "v1", "v2": "v2", "v3": "v3", "v4": "v4", "v5": "v4"}

# The image processor is identical for every version, so load it only once.
processor = AutoImageProcessor.from_pretrained(TINY)

res = {}
for version, checkpoint in CHECKPOINT_FOR_VERSION.items():
    # Load each version's models only when needed, instead of holding all of them in memory.
    res.update(test(load_pt(TINY, checkpoint), load_ft(TINY, checkpoint), processor, version))
df = pd.DataFrame(res)
df.to_csv("eval.csv")
df

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
100%|██████████| 110/110 [00:08<00:00, 13.35it/s]
100%|██████████| 24/24 [00:01<00:00, 13.50it/s]


v1-pt-simple


100%|██████████| 110/110 [00:08<00:00, 12.68it/s]
100%|██████████| 14/14 [00:01<00:00, 11.07it/s]


v1-pt-complex


100%|██████████| 110/110 [00:09<00:00, 11.25it/s]
100%|██████████| 24/24 [00:02<00:00, 10.52it/s]


v1-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.13it/s]
100%|██████████| 14/14 [00:00<00:00, 14.74it/s]


v1-ft-complex


100%|██████████| 110/110 [00:07<00:00, 14.42it/s]
100%|██████████| 24/24 [00:01<00:00, 14.90it/s]


v2-pt-simple


100%|██████████| 110/110 [00:07<00:00, 15.06it/s]
100%|██████████| 14/14 [00:00<00:00, 14.32it/s]


v2-pt-complex


100%|██████████| 110/110 [00:07<00:00, 14.62it/s]
100%|██████████| 24/24 [00:01<00:00, 14.86it/s]


v2-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.96it/s]
100%|██████████| 14/14 [00:00<00:00, 14.71it/s]


v2-ft-complex


100%|██████████| 110/110 [00:07<00:00, 15.00it/s]
100%|██████████| 24/24 [00:01<00:00, 15.01it/s]


v3-pt-simple


100%|██████████| 110/110 [00:07<00:00, 14.61it/s]
100%|██████████| 14/14 [00:00<00:00, 14.45it/s]


v3-pt-complex


100%|██████████| 110/110 [00:07<00:00, 14.82it/s]
100%|██████████| 24/24 [00:01<00:00, 15.05it/s]


v3-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.91it/s]
100%|██████████| 14/14 [00:00<00:00, 14.52it/s]


v3-ft-complex


100%|██████████| 110/110 [00:07<00:00, 14.50it/s]
100%|██████████| 24/24 [00:01<00:00, 14.68it/s]


v4-pt-simple


100%|██████████| 110/110 [00:07<00:00, 14.94it/s]
100%|██████████| 14/14 [00:00<00:00, 14.42it/s]


v4-pt-complex


100%|██████████| 110/110 [00:07<00:00, 14.65it/s]
100%|██████████| 24/24 [00:01<00:00, 14.90it/s]


v4-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.57it/s]
100%|██████████| 14/14 [00:00<00:00, 14.51it/s]


v4-ft-complex


100%|██████████| 110/110 [00:07<00:00, 14.80it/s]
100%|██████████| 24/24 [00:01<00:00, 14.82it/s]


v5-pt-simple


100%|██████████| 110/110 [00:07<00:00, 14.93it/s]
100%|██████████| 14/14 [00:00<00:00, 14.07it/s]


v5-pt-complex


100%|██████████| 110/110 [00:07<00:00, 14.39it/s]
100%|██████████| 24/24 [00:01<00:00, 14.70it/s]


v5-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.82it/s]
100%|██████████| 14/14 [00:00<00:00, 14.47it/s]


v5-ft-complex


Unnamed: 0,v1-pt-simple,v1-pt-complex,v1-ft-simple,v1-ft-complex,v2-pt-simple,v2-pt-complex,v2-ft-simple,v2-ft-complex,v3-pt-simple,v3-pt-complex,v3-ft-simple,v3-ft-complex,v4-pt-simple,v4-pt-complex,v4-ft-simple,v4-ft-complex,v5-pt-simple,v5-pt-complex,v5-ft-simple,v5-ft-complex
ndcg_burges,0.839073,0.928853,0.849567,0.933121,0.853353,0.897135,0.837875,0.968383,0.827786,0.843699,0.827981,0.857831,0.840824,0.818769,0.836342,0.822664,0.851839,0.892443,0.851767,0.876447
ndcg,0.907612,0.953692,0.911118,0.957741,0.913101,0.931446,0.905174,0.977457,0.901023,0.90429,0.900519,0.912794,0.908176,0.887649,0.907037,0.888481,0.914216,0.932546,0.913643,0.924278
map,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
f1,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
recall,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
precision,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
correctness,0.064027,0.535427,0.095407,0.687147,0.012795,0.535876,0.024824,0.750602,0.026969,0.307581,-0.010792,0.431926,0.077601,0.154566,0.042641,0.157403,0.066156,0.381178,0.067691,0.403688
completeness,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
duration,1.776229,1.253036,2.277371,0.947827,1.608195,0.975818,1.612901,0.949502,1.596221,0.96632,1.591663,0.962224,1.632351,0.968425,1.608317,0.962704,1.617033,0.99285,1.630294,0.965262


In [4]:
# Eval "treemap sat long": pt/ft models loaded from the separate "long" state-dict
# files (with version "v4"), evaluated on the v5 renderings.
# NOTE(review): assumes the third argument of load_pt/load_ft is a state-dict path — confirm.
from new_eval.load_trained_model import load_pt, load_ft
from transformers import AutoImageProcessor
import pandas as pd
TINY = "microsoft/swinv2-tiny-patch4-window8-256"
PT_LONG_STATEDICT = "models/pt_long_statedicts.pt"
FT_LONG_STATEDICT = "models/ft_long_statedicts.pt"

res = test(
    load_pt(TINY, "v4", PT_LONG_STATEDICT),
    load_ft(TINY, "v4", FT_LONG_STATEDICT),
    AutoImageProcessor.from_pretrained(TINY),
    "v5",
)
df = pd.DataFrame(res)
df.to_csv("eval_long.csv")
df

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
100%|██████████| 110/110 [00:08<00:00, 13.30it/s]
100%|██████████| 24/24 [00:01<00:00, 13.31it/s]


v5-pt-simple


100%|██████████| 110/110 [00:08<00:00, 13.10it/s]
100%|██████████| 14/14 [00:01<00:00, 13.21it/s]


v5-pt-complex


100%|██████████| 110/110 [00:07<00:00, 14.85it/s]
100%|██████████| 24/24 [00:01<00:00, 15.01it/s]


v5-ft-simple


100%|██████████| 110/110 [00:07<00:00, 14.92it/s]
100%|██████████| 14/14 [00:00<00:00, 14.60it/s]


v5-ft-complex


Unnamed: 0,v5-pt-simple,v5-pt-complex,v5-ft-simple,v5-ft-complex
ndcg_burges,0.847187,0.912816,0.851913,0.924108
ndcg,0.918807,0.945442,0.918673,0.954548
map,1.0,1.0,1.0,1.0
f1,1.0,1.0,1.0,1.0
recall,1.0,1.0,1.0,1.0
precision,1.0,1.0,1.0,1.0
correctness,0.108619,0.552606,0.076519,0.648516
completeness,1.0,1.0,1.0,1.0
duration,1.801011,1.057448,1.595911,0.956716


# NDCG analysis
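What is the worst possible NDCG? To compute it, each query's expert ranking is used both as the relevance judgements and, unchanged, as the run scores. ranx sorts a run by descending score, so the least relevant cases end up first, which gives the lowest NDCG achievable for these judgements.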

In [5]:
from glob import glob
from ranx import Run, Qrels, evaluate
import json

# simple
# Worst-case NDCG: each query's expert ranking serves both as relevance judgements
# and, unchanged, as run scores. Rank 1 gets the highest relevance in the qrels,
# but ranx orders a run by descending score. The run therefore lists the least
# relevant cases first, which is the worst ordering for these judgements (up to ties).
max_value = 3  # relevance given to rank 1; rank r gets max_value - r + 1

files = glob("../data/retrieval_queries/microtexts-retrieval-simple/*")
ground_truth = {}
qrels_dict = {}
predicted_relevances = {}
for file in files:
    name = file.split("/")[-1].split(".")[0]
    # Close the file deterministically; JSON is UTF-8 by spec.
    with open(file, encoding="utf-8") as fh:
        g = json.load(fh)
    ranking = g["userdata"]["cbrEvaluations"][0]["ranking"]
    ground_truth[name] = ranking
    qrels_dict[name] = {k: (max_value - int(v) + 1) for k, v in ranking.items()}
    predicted_relevances[name] = ranking
run = Run(predicted_relevances)
qrels = Qrels(qrels_dict)
evaluate(qrels, run, ["ndcg"])

In [7]:
# complex
# Worst-case NDCG on the complex query set. The expert ranks are fed back in as run
# scores, so ranx (descending-score order) puts the least relevant cases first.
# Defined here so the cell does not depend on state left over from another cell.
max_value = 3  # relevance given to rank 1; rank r gets max_value - r + 1

files = glob("../data/retrieval_queries/microtexts-retrieval-complex/*")
ground_truth = {}
qrels_dict = {}
predicted_relevances = {}
for file in files:
    name = file.split("/")[-1].split(".")[0]
    # Close the file deterministically; JSON is UTF-8 by spec.
    with open(file, encoding="utf-8") as fh:
        g = json.load(fh)
    ranking = g["userdata"]["cbrEvaluations"][0]["ranking"]
    ground_truth[name] = ranking
    qrels_dict[name] = {k: (max_value - int(v) + 1) for k, v in ranking.items()}
    predicted_relevances[name] = ranking
run = Run(predicted_relevances)
qrels = Qrels(qrels_dict)
evaluate(qrels, run, ["ndcg"])

0.7840210926941339