In [1]:
# Number GPUs by PCI bus ID and expose only GPU 0 to this kernel.
# CUDA reads these variables when it initialises, so they are set in the very
# first cell, before transformers is imported in the next one.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os
import logging
from typing import *
from pathlib import Path
from dataclasses import dataclass
from pprint import pprint
from datetime import datetime

import wandb
import pandas as pd

from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from transformers.models.auto.tokenization_auto import AutoTokenizer

import lass.datasets
import lass.metrics
from lass.datasets import analyse, merge
from lass.log_handling import LogLoader, QueryType, QueryFunction, TaskList
from lass.train import DataArgs

In [3]:
# Which logs to load; every field is forwarded to LogLoader below.
data_args = DataArgs(
    logdir="../artifacts/logs",
    tasks="paper-lite",
    model_families=["BIG-G T=0"],
    model_sizes=["128b"],
    shots=[0],
    query_types=["multiple_choice"],
)
output_dir = None  # None -> checkpoints and wandb files go under the current directory
seed = 42
test_fraction = 0.2
model_name = "albert-base-v2"
model_name_short = "albert-sanity-check"
max_sequence_length = 512
batch_size = 32
gradient_accumulation_steps = 1
is_test_run = False
uses_pop_data = False
split = 'task'
use_wandb = True
n_epochs = 10
extra_training_args = {}
# ----------------------------------------------------------

loader = LogLoader(
    logdir=data_args.logdir,
    tasks=data_args.tasks,
    model_families=data_args.model_families,
    model_sizes=data_args.model_sizes,
    shots=data_args.shots,
    query_types=data_args.query_types,
    exclude_faulty_tasks=data_args.exclude_faulty_tasks,
    include_unknown_shots=data_args.include_unknown_shots
)
data = lass.datasets.to_dataframe(loader)


# Sanity check: prepend the label itself to the input text. A working pipeline
# should then reach (near-)perfect accuracy, because the answer is in the input.
# Both classes must be present, otherwise a constant predictor would also score
# perfectly and the check would prove nothing.
is_correct = data['correct'] == 1
assert is_correct.any(), "no 'correct' examples in the data"
assert (~is_correct).any(), "no 'wrong' examples in the data"
label_prefix = is_correct.map({True: 'correct', False: 'wrong'})
data['input'] = label_prefix + '. ' + data['input'].astype(str)

In [4]:
# Only the task-level split is implemented in this notebook; fail loudly rather
# than tag the wandb run with a split that was not actually used.
if split != 'task':
    raise ValueError(f"Unsupported split {split!r}: this notebook only implements the task-level split")
train, test = lass.datasets.split_task_level(data, seed, test_fraction)


# Log some stats & examples
stats = merge(analyse(train), analyse(test), 'train', 'test')
pprint(stats)
print(train.head(1))

# Huggingfaceify
dataset = lass.datasets.huggingfaceify(train, test)
print(dataset['train'][0])

# Tokenize dataset
os.environ['TOKENIZERS_PARALLELISM'] = "true"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some tokenizers (e.g. GPT-2) ship without a pad token, but the
# padding="max_length" used during tokenization needs one; reuse EOS then.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize a batch of examples for `Dataset.map(..., batched=True)`.

    Reads the module-level `tokenizer` and `max_sequence_length` at call time.
    Every sequence is padded / truncated to exactly `max_sequence_length`
    tokens, and NumPy arrays are requested from the tokenizer.

    Args:
        examples: batch mapping; only its "text" column is used.

    Returns:
        Whatever `tokenizer(...)` returns for that batch.
    """
    encoding_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="np",
    )
    texts = examples["text"]
    return tokenizer(texts, **encoding_options)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Shuffle with the configured `seed` (not a hard-coded value) so that changing
# the seed in the config cell also changes the shuffle order.
train_dataset = tokenized_datasets["train"].shuffle(seed=seed)  # .select(range(50))
eval_dataset = tokenized_datasets["test"].shuffle(seed=seed)  # .select(range(50))
# A bare expression mid-cell is never displayed; print the sizes explicitly.
print(f"train: {len(train_dataset)} examples, eval: {len(eval_dataset)} examples")

# Setup tagging and paths
model_name_short = model_name_short or model_name
# Shots tag: the loader's shot list, or "all" when no shot filter was configured.
shot_str = ','.join([str(s) for s in loader.shots or []]) if data_args.shots else "all"
bs = batch_size if gradient_accumulation_steps == 1 else f"{batch_size}*{gradient_accumulation_steps}"
name = f"{model_name_short}"
# Alternative, more descriptive run name (currently disabled):
# name = ""\
#     + (f"test-" if is_test_run else "")\
#     + (f"{model_name_short}")\
#     + (f"-bs{bs}")\
#     + (f"-{shot_str}sh")\
#     + (f'-pop' if uses_pop_data else '')\
#     + (f"-{split}-split")

# Setup wandb
if use_wandb:
    os.environ['WANDB_LOG_MODEL'] = "false"
    wandb_dir = f"{output_dir or '.'}/wandb"
    # NOTE(review): wandb appears to fall back to a temp directory when `dir`
    # does not exist yet (e.g. on a first run); create it up front.
    os.makedirs(wandb_dir, exist_ok=True)
    wandb.login()
    wandb.init(
        project="lass",
        dir=wandb_dir,
        # e.g. "task-split", or "task-split-pop" when population data is used.
        group=f"{split}-split{'-pop' if uses_pop_data else ''}",
        name=name,
        mode="disabled" if is_test_run else "online",
        tags=[
            f"split:{split}-split",
            f"assr:{model_name_short}",
            f"tasks:{data_args.tasks}",
            f"pop:{'yes' if uses_pop_data else 'no'}",
            f"shots:{shot_str}",
        ]
    )

    wandb.config.seed = seed
    wandb.config.is_test_run = is_test_run
    wandb.config.stats = stats
    wandb.config.data = {
        'query_types': ",".join(data_args.query_types or []),
        'tasks': data_args.tasks,
        'test_fraction': test_fraction,
        'split': split,
        'shots': shot_str,
        'pop_model_family': data_args.model_families,
        'pop_model_size': data_args.model_sizes,
    }
    wandb.config.extra_training_args = extra_training_args

# Setup trainer
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# Decoder-style classification heads (e.g. GPT-2's) use `config.pad_token_id`
# to find the last real token; propagate the tokenizer's pad token when the
# model config has none.
if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
    model.config.pad_token_id = tokenizer.pad_token_id

training_kwargs = dict(
    output_dir=f"{output_dir or '.'}/{name}-{datetime.now().strftime('%m%d%H%M')}",
    optim="adamw_torch",
    evaluation_strategy="steps",
    report_to="wandb" if use_wandb else "none",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Both were configured (and used in the run tags) but previously never
    # reached the Trainer.
    gradient_accumulation_steps=gradient_accumulation_steps,
    seed=seed,
    save_total_limit=1,
    load_best_model_at_end=True,
    num_train_epochs=n_epochs,
)
# Entries from the config cell override the defaults above instead of
# raising a duplicate-keyword TypeError.
training_kwargs.update(extra_training_args)
training_args = TrainingArguments(**training_kwargs)

compute_metrics = lass.metrics.hf.get_metric_computer([
    "accuracy",
    "precision",
    "recall",
    "f1",
    "roc_auc",
    "brier_score",
    "balanced_accuracy",
])  # + ["wandb_conf_matrix"] if use_wandb else [])

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # type: ignore
    eval_dataset=eval_dataset,  # type: ignore
    compute_metrics=compute_metrics,
)

# Close the wandb run even if training fails or is interrupted, so it is not
# left open in the kernel.
try:
    trainer.train()
finally:
    if use_wandb:
        wandb.finish()

{'metrics': {'conf-absolute': {'acc': {'test': 0.6526972804279982,
                                       'train': 0.6373460376481525},
                               'balanced_acc': {'test': 0.5190085055135297,
                                                'train': 0.5053411418330773},
                               'bs': {'test': 0.24216131475156374,
                                      'train': 0.315787020330402},
                               'bs_dcr': {'test': 0.0017000939183387798,
                                          'train': 0.011701618165291278},
                               'bs_mcb': {'test': 0.032605267058083615,
                                          'train': 0.09543828628045933},
                               'bs_unc': {'test': 0.2112561416118189,
                                          'train': 0.23205035221523396},
                               'roc_auc': {'test': 0.48988370780173873,
                                           'train': 0.588642473118279



  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwschella[0m. Use [1m`wandb login --relogin`[0m to force relogin


Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForSequenceClassification: ['predictions.LayerNorm.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.dense.weight', 'predictions.LayerNorm.bias', 'predictions.bias', 'predictions.decoder.bias']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The following columns in the training set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running training *****


  Num examples = 8606


  Num Epochs = 10


  Instantaneous batch size per device = 32


  Total train batch size (w. parallel, distributed & accumulation) = 32


  Gradient Accumulation steps = 1


  Total optimization steps = 2690


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Roc Auc,Bs,Bs Mcb,Bs Dsc,Bs Unc,Balanced Accuracy
500,0.0055,8e-06,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.211256,0.211256,1.0
1000,0.0,4e-06,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.211256,0.211256,1.0
1500,0.0,2e-06,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.211256,0.211256,1.0
2000,0.0,2e-06,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.211256,0.211256,1.0
2500,0.0,2e-06,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.211256,0.211256,1.0


The following columns in the evaluation set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 2243


  Batch size = 32


Saving model checkpoint to ./albert-sanity-check-07081127/checkpoint-500


Configuration saved in ./albert-sanity-check-07081127/checkpoint-500/config.json


Model weights saved in ./albert-sanity-check-07081127/checkpoint-500/pytorch_model.bin


The following columns in the evaluation set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 2243


  Batch size = 32


Saving model checkpoint to ./albert-sanity-check-07081127/checkpoint-1000


Configuration saved in ./albert-sanity-check-07081127/checkpoint-1000/config.json


Model weights saved in ./albert-sanity-check-07081127/checkpoint-1000/pytorch_model.bin


The following columns in the evaluation set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 2243


  Batch size = 32


Saving model checkpoint to ./albert-sanity-check-07081127/checkpoint-1500


Configuration saved in ./albert-sanity-check-07081127/checkpoint-1500/config.json


Model weights saved in ./albert-sanity-check-07081127/checkpoint-1500/pytorch_model.bin


Deleting older checkpoint [albert-sanity-check-07081127/checkpoint-500] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 2243


  Batch size = 32


Saving model checkpoint to ./albert-sanity-check-07081127/checkpoint-2000


Configuration saved in ./albert-sanity-check-07081127/checkpoint-2000/config.json


Model weights saved in ./albert-sanity-check-07081127/checkpoint-2000/pytorch_model.bin


Deleting older checkpoint [albert-sanity-check-07081127/checkpoint-1000] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `AlbertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `AlbertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 2243


  Batch size = 32


Saving model checkpoint to ./albert-sanity-check-07081127/checkpoint-2500


Configuration saved in ./albert-sanity-check-07081127/checkpoint-2500/config.json


Model weights saved in ./albert-sanity-check-07081127/checkpoint-2500/pytorch_model.bin


Deleting older checkpoint [albert-sanity-check-07081127/checkpoint-1500] due to args.save_total_limit




Training completed. Do not forget to share your model on huggingface.co/models =)




Loading best model from ./albert-sanity-check-07081127/checkpoint-2500 (score: 1.6604243455731194e-06).


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▁▁▁▁▁
eval/balanced_accuracy,▁▁▁▁▁
eval/bs,█▂▁▁▁
eval/bs_dsc,▁▁▁▁▁
eval/bs_mcb,█▂▁▁▁
eval/bs_unc,▁▁▁▁▁
eval/f1,▁▁▁▁▁
eval/loss,█▃▂▁▁
eval/precision,▁▁▁▁▁
eval/recall,▁▁▁▁▁

0,1
eval/accuracy,1.0
eval/balanced_accuracy,1.0
eval/bs,0.0
eval/bs_dsc,0.21126
eval/bs_mcb,0.0
eval/bs_unc,0.21126
eval/f1,1.0
eval/loss,0.0
eval/precision,1.0
eval/recall,1.0
