<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [3]</a>'.</span>

In [None]:
# Reload changed imported modules (e.g. the lass package) before each cell executes.
%load_ext autoreload
%autoreload 2
# Number GPUs by PCI bus ID and expose only device 0. This is set in the first
# cell so it takes effect before any library initialises CUDA.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
from dataclasses import dataclass
from typing import *
import logging

import pandas as pd

import lass.test
import lass.train
import lass.datasets
import lass.pipeline
from lass.log_handling import LogLoader, LogLoaderArgs, PaperTasks

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [None]:
@dataclass
class Architecture:
    """Backbone model identifier together with the batching settings used to train it."""

    # Passed as ``model_name`` to both lass.train.train and lass.test.test.
    name: str
    # Short label; used as the prefix of ``model_name_short`` for the training run.
    name_short: str
    # Forwarded as ``batch_size`` to lass.train.train.
    batch_size: int
    # Forwarded as ``gradient_accumulation_steps`` to lass.train.train.
    gradient_accumulation_steps: int
    
# Backbone for this experiment: DeBERTa-v3 base. Per device step it uses
# batch_size examples and accumulates gradients over gradient_accumulation_steps
# steps before each optimiser update (assuming the usual trainer semantics).
architecture = Architecture(
    name_short="deberta-base",
    name="microsoft/deberta-v3-base",
    gradient_accumulation_steps=2,
    batch_size=16,
)


# Log selection shared by training and evaluation: the 'paper-full' task set,
# only the 'BIG-G T=0' family at size '128b', zero shots, multiple-choice queries.
log_loader_args = LogLoaderArgs(
    tasks="paper-full",
    query_types=["multiple_choice"],
    model_families=["BIG-G T=0"],
    model_sizes=["128b"],
    shots=[0],
    logdir="../artifacts/logs/",
)

# Fine-tune `architecture` on the selected logs. With split='task' and
# test_fraction=0.3 this presumably holds out ~30% of tasks (TODO confirm in
# lass.train); the "-0.3test" suffix in the run name mirrors that fraction.
# The returned value is later passed as `model_loc` to lass.test.test.
model = lass.train.train(
    # Data and split.
    data_args=log_loader_args,
    split='task',
    test_fraction=0.3,
    include_model_in_input=False,
    include_n_targets_in_input=False,
    # Backbone and run naming / output location.
    model_name=architecture.name,
    model_name_short=f"{architecture.name_short}-0.3test",
    group='task-generalisation',
    output_dir="task-generalisation",
    # Optimisation.
    batch_size=architecture.batch_size,
    gradient_accumulation_steps=architecture.gradient_accumulation_steps,
    n_epochs=6,
    extra_training_args=dict(warmup_steps=3000, learning_rate=2e-5),
    # Uncomment for a quick smoke run (presumably a reduced run; see lass.train).
    # is_test_run=True,
)

# Evaluate the trained model on the same log selection and split, with results
# broken down per task (per_task=True). The returned dict is read below under
# the keys "tasks", "metrics" and "test".
# cast() only narrows the static type for type checkers; it is a no-op at
# runtime. It replaces a trailing `# type: ignore` that sat on the
# closing-paren line, where checkers do not attach it to the assignment.
info: dict = cast(dict, lass.test.test(
    data_args=log_loader_args,
    split='task',
    model_loc=model,
    model_name=architecture.name,
    per_task=True,
))


from pathlib import Path

# One row per task: that task's metrics plus the number of test examples they
# were computed on. Iterating .items() avoids re-indexing info["tasks"].
task_results = {
    task: task_info["metrics"] | {'count': len(task_info["test"])}
    for task, task_info in info["tasks"].items()
}
# Aggregate row over all tasks; the leading underscore presumably keeps it from
# clashing with a real task name (assumption).
task_results['_total'] = info['metrics'] | {'count': len(info['test'])}

df = pd.DataFrame.from_dict(task_results, orient='index')

# DataFrame.to_csv does not create missing parent directories, so ensure it exists.
results_path = Path("task-generalisation") / "results-base-30t.csv"
results_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(results_path)
