In [None]:
# Automatically reload edited imported modules (e.g. `lass`) without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Number GPUs by PCI bus order and expose only GPU 0 to this kernel.
# (Comments must stay on their own lines: text after %env becomes part of the value.)
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import os
import logging
from typing import *
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
import shutil

import wandb
import pandas as pd

from transformers.models.auto.modeling_auto import AutoModelForMultipleChoice
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from datasets.dataset_dict import DatasetDict
from datasets.arrow_dataset import Dataset as hf_Dataset
from datasets.splits import Split
from torch.utils.data import Dataset

import lass.datasets
import lass.pipeline
import lass.metrics
import lass.metrics.baseline
from lass.metrics.baseline import analyse, merge
from lass.log_handling import LogLoader, LogLoaderArgs, PaperTasks


Reference material for the multiple-choice setup used below:

- <https://huggingface.co/docs/transformers/v4.23.1/en/main_classes/output#transformers.modeling_outputs.MultipleChoiceModelOutput>
- <https://huggingface.co/docs/transformers/tasks/multiple_choice>
- <https://huggingface.co/docs/transformers/v4.23.1/en/model_doc/bert#transformers.BertForMultipleChoice>

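### Custom data collator for multiple choice

Pads all answer options of a batch together and reshapes the result to `(batch_size, num_choices, seq_len)`. It is currently not passed to the `Trainer` (the `data_collator` argument is commented out), since `preprocess_function` already pads every option to `max_sequence_length`.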

In [None]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Each feature maps every input key (e.g. ``input_ids``, ``attention_mask``) to a list
    with one entry per answer option, plus a per-feature ``label`` (or ``labels``).
    All options of all features are padded together with ``tokenizer.pad`` and the
    resulting tensors are reshaped to ``(batch_size, num_choices, -1)``.

    The caller's feature dicts are left unmodified (the label is read, not popped).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of multiple-choice features into a dict of batched tensors."""
        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        # assumes every feature has as many options as the first one (view() below fails otherwise)
        num_choices = len(features[0]["input_ids"])

        # One flat dict per (feature, option) pair, in feature-major order so that the
        # reshape below restores the grouping. The per-feature label is excluded.
        # A single comprehension avoids the quadratic cost of sum(list_of_lists, []).
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

## Finetune a model on the actual task

In [None]:
def finetune(
    task: str,
    model_name: str,
    batch_size: int,
    group: str,
    test_fraction: float = 0.2,
    test_max_instances: Optional[int] = 20000,
    train_max_instances: Optional[int] = None,
    model_name_short: Optional[str] = None,
    seed: int = 42,
    n_epochs: int = 6,
    max_sequence_length: int = 512,
    gradient_accumulation_steps: int = 1,
    use_wandb: bool = True,
    is_test_run: bool = False,
    extra_training_args: Optional[Dict[str, Any]] = None,
    output_dir: Optional[Union[Path, str]] = None
):
    """
    Finetune a multiple-choice model on a single task and evaluate it on a held-out split.

    The zero-shot multiple-choice logs of BIG-G T=0 (128b) for `task` are loaded from
    ``../artifacts/logs``, passed through ``lass.pipeline.binarize/augment/clean`` and
    split per instance into train and test sets. Each instance becomes one
    (text, options, label) example, where the label is the index of the option with
    the highest target value.

    Args:
        task: Task name passed to the log loader.
        model_name: Model id/path for ``AutoModelForMultipleChoice`` and the tokenizer.
        batch_size: Per-device train and eval batch size.
        group: wandb group name.
        test_fraction: Fraction of instances held out for testing.
        test_max_instances: Subsample the test split to at most this many rows (None = no limit).
        train_max_instances: Subsample the train split to at most this many rows (None = no limit).
        model_name_short: Short model name for the run name and tags (defaults to `model_name`).
        seed: Seed for the split, subsampling and shuffling.
        n_epochs: Number of training epochs.
        max_sequence_length: Every (text, option) pair is truncated/padded to this length.
        gradient_accumulation_steps: Passed through to TrainingArguments.
        use_wandb: Log the run to Weights & Biases and let the Trainer report there.
        is_test_run: Smoke test: 1 epoch, at most 200 train/test instances, wandb disabled.
        extra_training_args: Overrides merged over the default TrainingArguments. The dict
            is copied and never modified. None means no overrides.
        output_dir: Root directory for checkpoints and wandb files (defaults to the cwd).

    Returns:
        Dict with the trained "model", the "train"/"test" DataFrames, and the test-set
        "logits", "labels" and "metrics" returned by ``Trainer.predict``.

    Raises:
        ValueError: If the task has a varying number of choices per instance.
    """
    # Work on a copy: the test-run overrides below must not leak into the caller's
    # dict, nor persist across calls through a shared default.
    extra_training_args = dict(extra_training_args or {})

    data_args = LogLoaderArgs(
        logdir="../artifacts/logs",
        tasks=[task],
        model_families=["BIG-G T=0"],
        model_sizes=["128b"],
        shots=[0],
        query_types=["multiple_choice"],
    )

    if is_test_run:
        print("Running in test mode")
        n_epochs = 1
        extra_training_args['eval_steps'] = 5
        extra_training_args['save_steps'] = 5
        extra_training_args['logging_steps'] = 5
        train_max_instances = 200
        test_max_instances = 200
        print(f"Tasks: {data_args.tasks}\n n_epochs: {n_epochs}")

    logging.info("Starting data loading")
    loader = LogLoader(data_args)
    data = lass.datasets.to_dataframe(loader)
    logging.info("Loaded data.")

    data = lass.pipeline.binarize(data)
    data = lass.pipeline.augment(data)
    data = lass.pipeline.clean(data)

    # We currently don't support tasks with varying number of choices per instance
    if data.n_targets.nunique() > 1:
        raise ValueError(f"Task {task} has varying number of choices per instance.")
    n_targets = int(data.n_targets.iloc[0])
    print(f"Task {task} has {n_targets} choices per instance.")

    train, test = lass.datasets.split("instance", data, test_fraction=test_fraction, seed=seed)

    # Optionally subsample the splits for speed
    if train_max_instances is not None and len(train) > train_max_instances:
        train = train.sample(n=train_max_instances, random_state=seed)  # type: ignore
    if test_max_instances is not None and len(test) > test_max_instances:
        test = test.sample(n=test_max_instances, random_state=seed)  # type: ignore

    # Per-split stats; only recorded in the wandb config below
    stats = merge(analyse(train), analyse(test), 'train', 'test')

    def huggingfaceify(df: pd.DataFrame) -> hf_Dataset:
        """Convert a split to a HF dataset with 'text', 'options' and 'label' columns."""
        def find_label(row) -> int:
            # Index of the highest-scoring option. Watch out for instances with several
            # correct options: on a tie the first one wins.
            scores = [row.target_values[target] for target in row.targets]
            return max(range(len(scores)), key=scores.__getitem__)

        df_hf = pd.DataFrame()
        df_hf['text'] = df.input
        df_hf['options'] = df.targets
        df_hf['label'] = df.apply(find_label, axis=1)
        return hf_Dataset.from_pandas(df_hf, preserve_index=False)

    dataset = DatasetDict()
    dataset['train'] = huggingfaceify(train)
    dataset['test'] = huggingfaceify(test)

    print(dataset['train'][0])

    tokenizer = lass.pipeline.get_tokenizer(model_name)

    def preprocess_function(examples):
        """Tokenize every (text, option) pair and regroup the results per example."""
        # Repeat each text once per option and flatten the options, so the i-th text
        # lines up with the i-th option. Comprehensions avoid quadratic sum(..., []).
        texts = [text for text in examples["text"] for _ in range(n_targets)]
        options = [option for example_options in examples["options"] for option in example_options]

        tokenized_examples = tokenizer(texts, options, truncation=True, padding="max_length", max_length=max_sequence_length)

        # Un-flatten: every tokenizer field becomes a list of n_targets entries per example
        values = {k: [v[i:i + n_targets] for i in range(0, len(v), n_targets)] for k, v in tokenized_examples.items()}
        return examples | values

    # Tokenize dataset
    logging.info("Starting tokenization")
    os.environ['TOKENIZERS_PARALLELISM'] = "true"
    tokenized_datasets: DatasetDict = dataset.map(preprocess_function, batched=True)

    train_dataset = tokenized_datasets["train"].shuffle(seed=seed)
    eval_dataset = tokenized_datasets["test"]

    # Setup tagging and paths
    model_name_short = model_name_short or model_name
    shot_str = ','.join([str(s) for s in loader.shots or []]) if data_args.shots else "all"
    bs = batch_size if gradient_accumulation_steps == 1 else f"{batch_size}*{gradient_accumulation_steps}"
    name = f"{'test-' if is_test_run else ''}{model_name_short}-bs{bs}-{shot_str}sh"

    # Setup wandb
    if use_wandb:
        os.environ['WANDB_LOG_MODEL'] = "false"
        wandb.login()
        wandb.init(
            project="lass",
            dir=f"{output_dir or '.'}/wandb",
            group=group,
            name=name,
            mode="disabled" if is_test_run else "online",
            tags=[
                f"assr:{model_name_short}",
                f"tasks:{data_args.tasks}",
                f"shots:{shot_str}",
            ]
        )

        wandb.config.seed = seed
        wandb.config.is_test_run = is_test_run
        wandb.config.stats = stats
        wandb.config.data = {
            'query_types': ",".join(data_args.query_types or []),
            'tasks': data_args.tasks,
            'test_fraction': test_fraction,
            'shots': shot_str,
            'pop_model_family': data_args.model_families,
            'pop_model_size': data_args.model_sizes,
        }
        wandb.config.extra_training_args = extra_training_args

    try:
        # Setup trainer
        model = AutoModelForMultipleChoice.from_pretrained(model_name)

        if model_name == "gpt2":
            model.config.pad_token_id = model.config.eos_token_id

        default_args: Dict[str, Any] = {
            "output_dir": f"{output_dir or '.'}/{name}-{datetime.now().strftime('%m%d%H%M')}",
            "optim": "adamw_torch",
            "evaluation_strategy": "epoch",
            "save_strategy": "epoch",
            "logging_strategy": "epoch",
            # Must follow use_wandb (not the always-truthy `wandb` module), so that
            # use_wandb=False really disables wandb reporting from the Trainer.
            "report_to": "wandb" if use_wandb else "none",
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "num_train_epochs": n_epochs,
            # This combination saves models immediately, but only keeps the best and the last.
            "load_best_model_at_end": True,
            "save_total_limit": 1,
        }
        training_args = TrainingArguments(**(default_args | extra_training_args))

        # TODO: To fix. Stays: accuracy, brier. Unknown: f1, precision, recall, roc_auc, balanced_accuracy
        metric_names = ["accuracy"]
        metric_names += ["roc_auc_multiclass"] if n_targets > 2 else ["roc_auc"]
        compute_metrics = lass.metrics.hf.get_metric_computer(metric_names)

        # No custom data_collator: preprocess_function already pads every option to
        # max_sequence_length (DataCollatorForMultipleChoice is kept for reference).
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,  # type: ignore
            eval_dataset=eval_dataset,  # type: ignore
            compute_metrics=compute_metrics,
        )

        trainer.train()

        # Predict with a separate Trainer (explicit args silence a warning). Reuse the
        # training batch size rather than the TrainingArguments default, since some
        # tasks already need batch_size=1 to fit in memory.
        dummy_args = TrainingArguments(output_dir="tmp_trainer", per_device_eval_batch_size=batch_size)
        dummy_trainer = Trainer(model=model, args=dummy_args, compute_metrics=compute_metrics)

        logits, labels, eval_metrics = dummy_trainer.predict(eval_dataset)  # type: ignore
    finally:
        # Close the run even if training fails, so it does not stay open for the next task.
        if use_wandb:
            wandb.finish()

    return {
        "model": model,
        "train": train,
        "test": test,
        "logits": logits,
        "labels": labels,
        "metrics": eval_metrics,
    }


## Finetune a model on all tasks

In [None]:
# tasks = ["hyperbaton"]
# tasks = ["crash_blossom"]
tasks = PaperTasks.full()[82:] # Make sure not to overwrite existing results when re-running

RESULTS_DIR = "task-specific-finetuning"
# Still giving OOMs, can't make it much smaller
SKIPPED_TASKS = {"dyck_languages", "intersect_geometry"}
# Metric columns filled with "error" when a task fails unexpectedly
ERROR_METRIC_KEYS = ["test_loss", "test_accuracy", "test_roc_auc"]

# The results CSV lives here; make sure it exists even if the first task fails early.
os.makedirs(RESULTS_DIR, exist_ok=True)

results: dict = {}
for task in tasks:
    if task in SKIPPED_TASKS:
        continue

    try:
        results_ = finetune(
            task=task,
            group="task-specific-finetuning",
            model_name="microsoft/deberta-v3-small",
            model_name_short=f"deberta-{task}",
            batch_size=1,
            gradient_accumulation_steps=32,
            output_dir=f"{RESULTS_DIR}/{task}",
            extra_training_args={
                "learning_rate": 2e-5,
            },
            # is_test_run=True,
        )
    except ValueError as e:
        # Skip empty tasks and tasks with a varying number of choices per instance
        if "varying number of choices per instance" in str(e) or \
            "No data found." in str(e):
            print(f"Expected error for task {task}: {e}")
            continue
        # Any other ValueError: report it and record an error row instead of
        # silently dropping the task.
        print(f"Unexpected error for task {task}: {e}")
        results_ = {"metrics": {k: "error" for k in ERROR_METRIC_KEYS}, "test": []}

    results[task] = results_['metrics']
    results[task]['count'] = len(results_['test'])

    # Add metrics to logfile
    # This overwrites the file each iteration, which is fine since all results are
    # still in memory; it gives intermediate results if the loop dies halfway.
    # Not good if we restart from halfway without moving the file first.
    df = pd.DataFrame.from_dict(results, orient='index')
    df.to_csv(f"{RESULTS_DIR}/results-small.csv")
    print(f"Tested on {task}")

    # Delete the checkpoints to free up space (the dir may not exist if the task failed early)
    shutil.rmtree(f"{RESULTS_DIR}/{task}", ignore_errors=True)