In [42]:
import os
# When the kernel starts in the parent code directory, move into the project
# folder so the `src.*` imports in the following cells can be resolved.
# NOTE(review): both paths are absolute and machine-specific; on any other
# checkout location this guard is simply a no-op.
launched_from_code_root = os.getcwd() == '/home/user/code'
if launched_from_code_root:
    os.chdir('/home/user/code/nlp2024_ClefTask4SOTA')

# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution, so edits to the project's `src` modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [63]:
from src.dataset import PATH
from src.experiment_runner import Experiment, run
from src.models import OllamaModel

from src.prompt_templates import (
    zero_shot_template_initial,
    few_shot_template_initial,
    zero_shot_template_optimized01,
    few_shot_template_optimized01,
    zero_shot_template_optimized02,
    few_shot_template_optimized02
)

from src.content_extraction import naive_doctaet, parse_response

def cut_tex(tex, num_chars_allowed=23256):
    """Truncate ``tex`` to at most ``num_chars_allowed`` leading characters.

    Args:
        tex: The string to truncate.
        num_chars_allowed: Maximum number of characters to keep. Defaults to
            23256, the limit that used to be hard-coded in this function
            (presumably chosen to fit the model's context window -- TODO confirm).

    Returns:
        ``tex`` unchanged when it is already within the limit, otherwise its
        first ``num_chars_allowed`` characters.

    Raises:
        ValueError: If ``num_chars_allowed`` is negative.
    """
    if num_chars_allowed < 0:
        # A negative bound would silently slice from the end of the string.
        raise ValueError("num_chars_allowed must be non-negative")
    # Slicing already clamps at the end of the string, so no explicit
    # min(len(tex), ...) is required.
    return tex[:num_chars_allowed]

def extract_doctaet(model, prompt_template, tex):
    """Run one extraction round trip for a single TeX source.

    The input is first reduced with ``naive_doctaet`` and truncated with
    ``cut_tex``; the result is rendered into a prompt by ``prompt_template``,
    sent to ``model.generate`` and the reply is passed through
    ``parse_response``.

    Args:
        model: Object exposing a ``generate(prompt)`` method.
        prompt_template: Callable that turns the truncated text into a prompt.
        tex: Raw TeX source of the document.

    Returns:
        Whatever ``parse_response`` produces for the model's reply.
    """
    truncated_input = cut_tex(naive_doctaet(tex))
    reply = model.generate(prompt_template(truncated_input))
    return parse_response(reply)


# One model instance is shared by every experiment below.
llama3_8b = OllamaModel("llama3:8b")

# Each prompt template is paired with the label passed as the last
# Experiment argument (presumably the run name that ends up in the "run"
# column of the result frames -- TODO confirm against src.experiment_runner).
template_runs = [
    (zero_shot_template_initial, "llama3_8b_zero_shot_template_initial"),
    (few_shot_template_initial, "llama3_8b_few_shot_template_initial"),
    (zero_shot_template_optimized01, "llama3_8b_zero_shot_template_optimized01"),
    (few_shot_template_optimized01, "llama3_8b_few_shot_template_optimized01"),
    (zero_shot_template_optimized02, "llama3_8b_zero_shot_template_optimized02"),
    (few_shot_template_optimized02, "llama3_8b_few_shot_template_optimized02"),
]
exps = [
    Experiment(llama3_8b, template, extract_doctaet, run_name)
    for template, run_name in template_runs
]

# Execute the experiments one after another, in list order, keeping one
# result per experiment. The arguments PATH.VAL and 100 are forwarded to
# run() unchanged (100 looks like a sample count -- TODO confirm).
dfs = []
for exp in exps:
    df = run(exp, PATH.VAL, 100)
    dfs.append(df)

 14%|██████                                     | 14/100 [01:07<09:10,  6.40s/it]

In [61]:
from src.evaluate import evaluate
import pandas as pd
# Build one metrics table per experiment result, each tagged with its run.
per_run_metrics = []
for df in dfs:
    # Turn the index of evaluate()'s output into a regular column; the column
    # reset_index() creates (named "index" by default) is relabelled "metric".
    res = evaluate(df).reset_index().rename(columns={"index": "metric"})
    # Read the run label from the first row *by position*: df["run"][0] is a
    # label lookup and breaks when the frame's index is not the default 0..n-1.
    res["run"] = df["run"].iloc[0]
    per_run_metrics.append(res)

# Stack the per-run tables into a single frame. A separate list name is used
# so that `results` only ever refers to the final DataFrame.
results = pd.concat(per_run_metrics)