In [None]:
! pip install transformers
! pip install datasets

In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import (AutoTokenizer, LongformerForSequenceClassification,
                          TrainingArguments, Trainer, EvalPrediction)
from datasets import load_dataset

In [2]:
# Checkpoint shared by the model and its tokenizer.
MODEL_NAME = "allenai/longformer-base-4096"
# Width of the multi-hot label vector (the label columns plus one extra column
# that tokenize_function uses for cases with an empty label list).
NUM_LABELS = 11
SEED = 42

# Seed before building the model: the classification head is newly initialised
# (see the warning in this cell's output), and TrainingArguments.seed is only
# applied by the Trainer, i.e. after the model already exists.
torch.manual_seed(SEED)

model = LongformerForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    # Makes the model use a sigmoid/BCE-with-logits loss over independent labels.
    problem_type="multi_label_classification",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset = load_dataset("coastalcph/fairlex", "ecthr")

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForSequenceClassification: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', '

  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def tokenize_function(examples, num_labels=11):
    """Tokenize a batch of ECtHR cases and build multi-hot label vectors.

    Intended for ``dataset.map(..., batched=True)``, so every value in
    ``examples`` is a list with one entry per case.

    Args:
        examples: batch dict with a "text" column and a "labels" column that
            holds, for each case, a list of integer label ids.
        num_labels: width of the label vector; must equal the classifier's
            ``num_labels``. The last column (``num_labels - 1``) is reserved
            for cases whose label list is empty, so real label ids must lie in
            ``[0, num_labels - 1)``.

    Returns:
        The tokenizer's encoding (padded and truncated to the tokenizer's
        maximum length) plus a "vectorized_label" column: one list of
        ``num_labels`` floats (0.0 / 1.0) per case.

    Raises:
        ValueError: if a label id is outside ``[0, num_labels - 1)``. Such an id
            would otherwise collide with the reserved "no label" column or,
            if negative, silently wrap around via NumPy indexing.
    """
    text = examples["text"]
    encoding = tokenizer(text, padding="max_length", truncation=True)

    no_label_idx = num_labels - 1
    labels_matrix = np.zeros((len(text), num_labels))
    for idx, label_list in enumerate(examples["labels"]):
        if len(label_list) == 0:
            # Cases without any positive label get the dedicated extra class,
            # so every target row has at least one positive entry.
            labels_matrix[idx, no_label_idx] = 1.0
            continue
        for label in label_list:
            if not 0 <= label < no_label_idx:
                raise ValueError(
                    f"label id {label} is outside [0, {no_label_idx}); "
                    f"column {no_label_idx} is reserved for cases with no label"
                )
            labels_matrix[idx, label] = 1.0
    encoding["vectorized_label"] = labels_matrix.tolist()
    return encoding

# Tokenize every split and replace the raw columns with the encoder inputs plus
# the multi-hot "vectorized_label" column produced by tokenize_function.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
tokenized_dataset.set_format("torch")

# The model's forward() takes its targets as `labels`, so expose the
# multi-hot vectors under that name for the splits used in training.
train_dataset, val_dataset = (
    tokenized_dataset[split_name].rename_column("vectorized_label", "labels")
    for split_name in ("train", "validation")
)



  0%|          | 0/1 [00:00<?, ?ba/s]



In [4]:
# Fine-tuning hyper-parameters. Evaluation and checkpointing both happen once per
# epoch; with load_best_model_at_end the checkpoint with the best validation
# score on `metric_name` (the micro-F1 key of compute_metrics) is restored at the end.
batch_size = 1
metric_name = "f1"

args = TrainingArguments(
    output_dir="longformer-baseline",
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
# Deviation from the source: ROC-AUC is computed from the sigmoid probabilities,
# not from the thresholded 0/1 predictions.
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Compute micro-F1, micro ROC-AUC and subset accuracy from multi-label logits.

    Args:
        predictions: array-like of raw logits, presumably shaped
            (n_examples, n_labels) — TODO confirm for other heads.
        labels: array-like multi-hot targets with the same shape as ``predictions``.
        threshold: a label is predicted when sigmoid(logit) >= threshold.

    Returns:
        dict with:
          - "f1": micro-averaged F1 of the thresholded predictions,
          - "roc_auc": micro-averaged ROC-AUC of the probabilities (NaN when it
            is undefined, e.g. ``labels`` contains only one class),
          - "accuracy": sklearn accuracy_score, i.e. exact-match (subset)
            accuracy for multi-label targets.
    """
    logits = torch.tensor(np.asarray(predictions), dtype=torch.float32)
    probs = torch.sigmoid(logits).numpy()
    y_true = np.asarray(labels)
    y_pred = (probs >= threshold).astype(np.float64)

    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average="micro")
    try:
        # ROC-AUC ranks examples by score, so it needs continuous probabilities;
        # hard 0/1 predictions collapse the curve to a single operating point.
        roc_auc = roc_auc_score(y_true, probs, average="micro")
    except ValueError:
        # Raised e.g. when y_true holds a single class; report NaN rather than
        # aborting the evaluation loop (shape errors already surface in f1_score).
        roc_auc = float("nan")
    accuracy = accuracy_score(y_true, y_pred)
    return {"f1": f1_micro_average, "roc_auc": roc_auc, "accuracy": accuracy}

def compute_metrics(p: EvalPrediction):
    """Trainer metric hook: score an EvalPrediction with multi_label_metrics.

    When the model returns several outputs, ``p.predictions`` is a tuple and
    the logits are taken from its first element.
    """
    raw_output = p.predictions
    if isinstance(raw_output, tuple):
        logits = raw_output[0]
    else:
        logits = raw_output
    return multi_label_metrics(predictions=logits, labels=p.label_ids)

# Assemble the Trainer. Passing the tokenizer lets the Trainer use it when
# collating batches and save it alongside checkpoints.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

In [5]:
# Launch fine-tuning. Per the TrainingArguments, the model is evaluated and
# checkpointed at the end of every epoch, and the best checkpoint by "f1" is
# reloaded when training finishes.
trainer.train()

***** Running training *****
  Num examples = 9000
  Num Epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 45000
  Number of trainable parameters = 148667915
You're using a LongformerTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Initializing global attention on CLS token...


Epoch,Training Loss,Validation Loss


Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-3435b262f1ae>", line 1, in <module>
    trainer.train()
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1543, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1791, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2539, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2571, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/li

KeyboardInterrupt: ignored