In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from typing import *
from pprint import pprint

import bigbench.api.results as bb

from lass.log_handling import LogLoader, TaskLog
from lass.datasets import split_task_level, analyse, merge, huggingfaceify

In [2]:
# Build the log loader over the local artifacts directory and narrow it with the
# builder-style `with_*` calls: task set "paper-full", the BIG-G T=0 family at 128b,
# zero-shot, multiple-choice queries only.
loader = (
    LogLoader(logdir=Path("../artifacts/logs"))
    .with_tasks("paper-full")
    .with_model_families(["BIG-G T=0"])
    .with_model_sizes(["128b"])
    .with_shots([0])
    .with_query_types(["multiple_choice"])
)

# Task-level train/test split with a fixed seed; 20% is requested for the test side.
# NOTE(review): presumably whole tasks are held out (no task appears in both splits) — confirm
# against `split_task_level`.
train, test = split_task_level(loader, test_fraction=0.2, seed=42)

In [3]:
# Analyse each split on its own, then merge the two reports under the
# labels "train" and "test" so they can be compared side by side.
train_stats = analyse(train)
test_stats = analyse(test)
stats = merge(train_stats, test_stats, "train", "test")
pprint(stats)

# Show one row of the training frame as a quick look at the available columns.
train.head(1)

{'metrics': {'conf-absolute': {'roc_auc': {'test': 0.5514382335446202,
                                           'train': 0.4791575119855057}},
             'conf-normalized': {'roc_auc': {'test': 0.5903410113824089,
                                             'train': 0.5960625536885094}},
             'task-acc': {'test': 0.31281994213220565,
                          'train': 0.34074926202465705}},
 'stats': {'n_instances': {'test': 8986, 'train': 46072},
           'n_instances_nonbinary': {'test': 119, 'train': 254},
           'n_tasks': {'test': 23, 'train': 95}}}


Unnamed: 0,input,targets,scores,target_values,correct,absolute_scores,normalized_scores,metrics,task,shots
0,\nIn the SIT-adversarial world a structure is ...,[There is at least one triangle pointing down....,"[-9.594904899597168, -7.830936431884766, -8.70...","{'There are at least two blue pieces. ': 0, 'T...",1.0,"[-9.594904899597168, -7.830936431884766, -8.70...","[-2.8301148414611816, -1.0661463737487793, -1....",{'calibration_multiple_choice_brier_score': 0....,symbol_interpretation,0


In [4]:
# Select the GPU for this kernel: order devices by PCI bus ID and expose only
# device index 1. These variables are set before torch is imported below;
# keep that order (NOTE(review): CUDA is expected to read them at initialisation — confirm).
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

import torch

# Sanity check: number of CUDA devices torch can see (the recorded run printed 1).
print(torch.cuda.device_count())

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


1


In [5]:
# Convert the train/test frames into a dataset object with "train" and "test" splits.
dataset = huggingfaceify(train, test)

# Peek at the first training example; the recorded output shows a "text" and a "label" field.
dataset["train"][0]


{'text': '\nIn the SIT-adversarial world a structure is a sequence of six emojis.\nHereafter are reported the emojis used along with their descriptions.\n 🔺 is a red circle;\n 🟦 is a blue circle;\n 🔴 is a yellow circle;\n 🟥 is a red triangle pointing up;\n 🟨 is a red triangle pointing down;\n 🔻 is a red square;\n 🟡 is a blue square;\n _ is a yellow square;\n 🔵 is an empty space.\n\nChoose the sentence consistent with the structure 🟥 🔻 🔺 🟡 🟥 🟨 and not consistent with 🔺 🟡 🟡 🟨 🟦 _:\n\n  choice: There are at least two triangles.\n\n  choice: There is at least one triangle.\n\n  choice: There are at least two yellow squares.\n\n  choice: There are at least two blue pieces.\n\n  choice: There is at least one triangle pointing down.\n\nA: ',
 'label': 1}

In [6]:
from transformers import AutoTokenizer
# Tokenizer for the `bert-base-cased` checkpoint (the same checkpoint name the
# classification model is loaded from further down).
# Alternative tried earlier: a GPT-2 tokenizer (`GPT2Tokenizer.from_pretrained("gpt2")`),
# which additionally needed `tokenizer.pad_token = tokenizer.eos_token`.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-cased",
)

def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook-level ``tokenizer``.

    Inputs are padded with ``padding="max_length"`` and truncated. No explicit
    ``max_length`` is passed, so the tokenizer's own limit applies.

    Variants tried for other backbones (not active): ``max_length=1024`` for
    Longformer, ``max_length=2048`` for XLNet, and ``return_tensors="pt"`` /
    ``return_tensors="np"`` (the latter for GPT-2).
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding="max_length")

# Apply the tokenizer to every split; with batched=True the function receives
# batches of examples rather than one example per call.
tokenized_datasets = dataset.map(function=tokenize_function, batched=True)



  0%|          | 0/47 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

In [7]:
# Shuffle both tokenized splits with a fixed seed so the ordering is repeatable.
# For a quick smoke test, append `.select(range(50))` to either line below.
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

# Display the split sizes as a (train, eval) tuple.
(len(train_dataset), len(eval_dataset))

(46072, 8986)

In [8]:
import wandb
%env WANDB_LOG_MODEL=true
%env TOKENIZERS_PARALLELISM=true
wandb.login()
wandb.init(
  project="lass",
  group="task-split",
  name="bert-bs8-0sh-task-split-wd2-warmup",
  tags=[
    "split:task-split",
    "assr:bert",
    "tasks:paper-full",
    "pop:single-system",
    f'shots:{",".join([str(s) for s in loader.shots])}']
)

wandb.config.pop_model_family = "BIG-G T=0"
wandb.config.pop_model_size = "128b"

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


env: WANDB_LOG_MODEL=true
env: TOKENIZERS_PARALLELISM=true


[34m[1mwandb[0m: Currently logged in as: [33mwschella[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import BertModel, BertConfig
import numpy as np
from datasets import load_metric
import scipy
import torch
import lass, lass.metrics.hf

# Assessor model: pretrained `bert-base-cased` with a 2-label sequence-classification head
# (the load log reports the classifier weights as newly initialised).
# Alternatives tried earlier and currently disabled:
#   - a bare BertModel(BertConfig.from_pretrained("bert-base-cased"))
#   - "albert-base-v2"
#   - "gpt2" (also needs model.config.pad_token_id = model.config.eos_token_id)
#   - "allenai/longformer-base-4096"
#   - resuming from "./test_trainer/checkpoint-13500" or
#     "../artifacts/assessors/bert-bs32/checkpoint-3000"
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-cased",
    num_labels=2,
)

training_args = TrainingArguments(
    # Where checkpoints go; same string as the W&B run name.
    output_dir="bert-bs8-0sh-task-split-wd2-warmup",
    report_to="wandb",
    # Optimisation schedule.
    num_train_epochs=6,
    per_device_train_batch_size=8,
    optim="adamw_torch",
    weight_decay=0.02,
    warmup_steps=1000,
    # Evaluation runs on a step schedule (every 500 steps in the recorded log).
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    # Checkpoint retention and final model selection.
    save_total_limit=1,
    load_best_model_at_end=True,
    # Options tried but not active:
    #   gradient_accumulation_steps=4
    #   label_smoothing_factor=0.1
)

# Names of the metrics to report at every evaluation pass.
metric_names = [
    "accuracy",
    "precision",
    "recall",
    "f1",
    "roc_auc",
    "brier_score",
]

# Callable handed to the Trainer to turn predictions into the metrics above.
compute_metrics = lass.metrics.hf.get_metric_computer(metric_names)

# Wire model, arguments, data and metrics together. The datasets still carry the raw
# "text" column; the training log shows the Trainer ignoring it.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

# Record the split type and the loader's shot counts on the W&B run config.
wandb.config.split = "task_split"
wandb.config.shots = loader.shots

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Release cached GPU memory held by PyTorch before training starts.
torch.cuda.empty_cache()

# Always close the W&B run, even when training fails or is interrupted from the
# notebook. Previously an exception or KeyboardInterrupt skipped `wandb.finish()`,
# leaving the run open until it was finished by hand.
try:
    trainer.train()
except BaseException:
    # BaseException so that a manual kernel interrupt is covered too; a non-zero
    # exit code marks the run as failed. The original error is re-raised unchanged.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running training *****


  Num examples = 46072


  Num Epochs = 6


  Instantaneous batch size per device = 8


  Total train batch size (w. parallel, distributed & accumulation) = 8


  Gradient Accumulation steps = 1


  Total optimization steps = 34554


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Roc Auc,Brier Score
500,0.6203,0.560974,0.71144,0.613071,0.210245,0.313113,0.700173,0.191218
1000,0.6214,0.614009,0.70276,0.561296,0.228033,0.324311,0.516725,0.209682
1500,0.6479,0.625258,0.68718,0.0,0.0,0.0,0.470835,0.216564
2000,0.6568,0.635436,0.68718,0.0,0.0,0.0,0.565723,0.221498
2500,0.6485,0.622396,0.68718,0.0,0.0,0.0,0.442231,0.215429
3000,0.65,0.622113,0.68718,0.0,0.0,0.0,0.471008,0.215303


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-500


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-500/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-500/pytorch_model.bin


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2000] due to args.save_total_limit


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2500] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1000


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1000/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1000/pytorch_model.bin


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


  _warn_prf(average, modifier, msg_start, len(result))


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1500


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1500/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1500/pytorch_model.bin


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1000] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


  _warn_prf(average, modifier, msg_start, len(result))


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2000


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2000/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2000/pytorch_model.bin


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-1500] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


  _warn_prf(average, modifier, msg_start, len(result))


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2500


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2500/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2500/pytorch_model.bin


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2000] due to args.save_total_limit


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.


***** Running Evaluation *****


  Num examples = 8986


  Batch size = 32


  _warn_prf(average, modifier, msg_start, len(result))


Saving model checkpoint to bert-bs8-0sh-task-split-wd2-warmup/checkpoint-3000


Configuration saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-3000/config.json


Model weights saved in bert-bs8-0sh-task-split-wd2-warmup/checkpoint-3000/pytorch_model.bin


Deleting older checkpoint [bert-bs8-0sh-task-split-wd2-warmup/checkpoint-2500] due to args.save_total_limit


In [None]:
import wandb
# wandb.finish()