<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [2]</a>'.</span>

In [1]:
# Set in the first cell so the variables are in place before the imports in the next cell.
# CUDA_DEVICE_ORDER=PCI_BUS_ID asks CUDA to number devices by PCI bus id.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
# Restrict this kernel to the device with index 0.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [2]:
import os
import logging
from typing import *
from pathlib import Path
from dataclasses import dataclass
from pprint import pprint
from datetime import datetime

import wandb
import pandas as pd

from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from transformers.models.auto.tokenization_auto import AutoTokenizer

import lass.datasets
import lass.metrics
from lass.datasets import analyse, merge
from lass.log_handling import LogLoader, QueryType, QueryFunction, TaskList
from lass.train import DataArgs

ImportError: cannot import name 'analyse' from 'lass.datasets' (/home/wout/pp/lass/src/lass/datasets.py)

In [None]:
# Reproducibility and data split
seed = 42
test_fraction = 0.2
split = 'task'
uses_pop_data = False

# Assessor model and tokenisation
model_name = "albert-base-v2"
model_name_short = "albert-sanity-check"
max_sequence_length = 512

# Optimisation
n_epochs = 10
batch_size = 32
gradient_accumulation_steps = 1
extra_training_args = {}

# Output location and experiment tracking
output_dir = None
use_wandb = True
is_test_run = False

# Selection of logs handed to the LogLoader below
data_args = DataArgs(
    logdir="../artifacts/logs",
    tasks="paper-lite",
    model_families=["BIG-G T=0"],
    model_sizes=["128b"],
    shots=[0],
    query_types=["multiple_choice"],
)
# ----------------------------------------------------------

# Forward the log selection stored on `data_args` to the loader, field by field
loader_fields = (
    "logdir",
    "tasks",
    "model_families",
    "model_sizes",
    "shots",
    "query_types",
    "exclude_faulty_tasks",
    "include_unknown_shots",
)
loader = LogLoader(**{field: getattr(data_args, field) for field in loader_fields})
data = lass.datasets.to_dataframe(loader)


# Sanity check: leak the label into the text by prepending it to every input.
# Require at least one positive row first.
is_correct = data['correct'] == 1
assert len(data[is_correct]) > 0


def prepender(r):
    """Return the row's input prefixed with 'correct. ' or 'wrong. ' according to its label."""
    label_word = 'correct' if r['correct'] == 1 else 'wrong'
    return f"{label_word}. {r['input']}"


data['input'] = data.apply(prepender, axis=1)

In [None]:
# Split into train and test partitions using the configured seed and fraction
train, test = lass.datasets.split_task_level(data, seed, test_fraction)


# Compute stats per partition, merge them under the 'train'/'test' labels, and show an example row
train_stats = analyse(train)
test_stats = analyse(test)
stats = merge(train_stats, test_stats, 'train', 'test')
pprint(stats)
print(train.head(1))

# Huggingfaceify (result is indexed by 'train'/'test' below), then show the first training example
dataset = lass.datasets.huggingfaceify(train, test)
first_train_example = dataset['train'][0]
print(first_train_example)

# Tokenize dataset
os.environ['TOKENIZERS_PARALLELISM'] = "true"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# padding="max_length" needs a padding token. Some tokenizers (GPT-2 and its
# variants) ship without one, so fall back to the EOS token for any such
# tokenizer rather than only for the exact checkpoint name "gpt2".
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize one batch of examples (used with `dataset.map(..., batched=True)`).

    Reads the "text" entry of `examples`, pads and truncates to
    `max_sequence_length`, and asks the tokenizer for NumPy output.
    Relies on the notebook-level `tokenizer` and `max_sequence_length`.
    """
    tokenizer_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="np",
    )
    return tokenizer(examples["text"], **tokenizer_options)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Shuffle with the configured `seed` (previously a hard-coded 42) so that changing
# the run seed actually changes the example order.
train_dataset = tokenized_datasets["train"].shuffle(seed=seed)  # .select(range(50))
eval_dataset = tokenized_datasets["test"].shuffle(seed=seed)  # .select(range(50))
# A bare expression in the middle of a cell is discarded, so print the sizes explicitly.
print(len(train_dataset), len(eval_dataset))

# Setup tagging and paths
# Fall back to the full model name when no short name is configured
if not model_name_short:
    model_name_short = model_name

# Comma-separated shot counts, or "all" when no shot filter was requested
if data_args.shots:
    shot_str = ','.join(str(shot) for shot in (loader.shots or []))
else:
    shot_str = "all"

# Batch-size label; includes the accumulation factor when it is not 1
if gradient_accumulation_steps == 1:
    bs = batch_size
else:
    bs = f"{batch_size}*{gradient_accumulation_steps}"

name = f"{model_name_short}"
# Longer naming scheme, currently disabled:
# name = ""\
#     + (f"test-" if is_test_run else "")\
#     + (f"{model_name_short}")\
#     + (f"-bs{bs}")\
#     + (f"-{shot_str}sh")\
#     + (f'-pop' if uses_pop_data else '')\
#     + (f"-{split}-split")

# Setup wandb
if use_wandb:
    # NOTE(review): presumably stops the Hugging Face wandb callback from uploading checkpoints — confirm
    os.environ['WANDB_LOG_MODEL'] = "false"
    wandb.login()
    wandb.init(
        project="lass",
        dir=f"{output_dir or '.'}/wandb",
        # 'pop-' is a prefix: placing it after "{split}-split" produced the
        # malformed group "task-splitpop-" whenever uses_pop_data was set.
        group=f"{'pop-' if uses_pop_data else ''}{split}-split",
        name=name,
        # Test runs are not synced
        mode="disabled" if is_test_run else "online",
        tags=[
            f"split:{split}-split",
            f"assr:{model_name_short}",
            f"tasks:{data_args.tasks}",
            f"pop:{'yes' if uses_pop_data else 'no'}",
            f"shots:{shot_str}",
        ]
    )

    # Record the run configuration alongside the metrics
    wandb.config.seed = seed
    wandb.config.is_test_run = is_test_run
    wandb.config.stats = stats
    wandb.config.data = {
        'query_types': ",".join(data_args.query_types or []),
        'tasks': data_args.tasks,
        'test_fraction': test_fraction,
        'split': split,
        'shots': shot_str,
        'pop_model_family': data_args.model_families,
        'pop_model_size': data_args.model_sizes,
    }
    wandb.config.extra_training_args = extra_training_args

# Setup trainer
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# Sequence-classification heads for models without a padding id (e.g. GPT-2)
# refuse batch sizes > 1 unless the config knows the pad token, so mirror the
# tokenizer's pad token id into the model config when it is missing.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# Defaults for this run. `seed` and `gradient_accumulation_steps` are forwarded
# so the notebook-level settings actually reach the trainer instead of only
# appearing in the logged configuration.
training_kwargs = {
    "output_dir": f"{output_dir or '.'}/{name}-{datetime.now().strftime('%m%d%H%M')}",
    "optim": "adamw_torch",
    "evaluation_strategy": "steps",
    "report_to": "wandb" if use_wandb else "none",
    "per_device_train_batch_size": batch_size,
    "per_device_eval_batch_size": batch_size,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "num_train_epochs": n_epochs,
    "seed": seed,
}
# Entries in `extra_training_args` override the defaults above. Passing both as
# keyword arguments raised a TypeError whenever a key was repeated.
training_kwargs.update(extra_training_args)

training_args = TrainingArguments(**training_kwargs)

# Metrics computed at every evaluation
metric_names = [
    "accuracy",
    "precision",
    "recall",
    "f1",
    "roc_auc",
    "brier_score",
    "balanced_accuracy",
]  # + ["wandb_conf_matrix"] if use_wandb else [])
compute_metrics = lass.metrics.hf.get_metric_computer(metric_names)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # type: ignore
    eval_dataset=eval_dataset,  # type: ignore
    compute_metrics=compute_metrics,
)

try:
    trainer.train()
except BaseException:
    # Also covers KeyboardInterrupt from stopping the cell: close the wandb run
    # with a failing exit code instead of leaving it open, then re-raise.
    if use_wandb:
        wandb.finish(exit_code=1)
    raise

if use_wandb:
    wandb.finish()