In [1]:
# Import Packages

import os
from datetime import datetime

import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, f1_score
from transformers import LongformerTokenizer, LongformerForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback

In [2]:
# Report whether PyTorch can see a CUDA device (fp16 = True is requested in the TrainingArguments further down)
torch.cuda.is_available()

True

# Load the dataset

In [3]:
# Load the dataset previously saved to the local "data" directory, then display its splits and columns
dataset = datasets.load_from_disk("data")
dataset

DatasetDict({
    train: Dataset({
        features: ['selftext', 'label'],
        num_rows: 3984
    })
    test: Dataset({
        features: ['selftext', 'label'],
        num_rows: 498
    })
    valid: Dataset({
        features: ['selftext', 'label'],
        num_rows: 498
    })
})

# Let's Torch It

In [4]:
# Maximum sequence length in tokens: the tokenization step truncates and pads every example to this length
MAX_LEN = 4096

In [5]:
# Import model and tokenizer

# Single source of truth for the checkpoint so model and tokenizer cannot drift apart
MODEL_NAME = 'AIMH/mental-longformer-base-4096'

# Load the pretrained encoder with a 6-way classification head on top
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, attention_window = 512, num_labels=6)

# Enable gradient checkpointing through the model API instead of passing the
# deprecated `gradient_checkpointing=True` config kwarg to from_pretrained
model.gradient_checkpointing_enable()

# `model_max_length` is the tokenizer's length-limit parameter; the per-call
# `max_length=MAX_LEN` used during tokenization still applies on top of it
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME, model_max_length = MAX_LEN)

Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at AIMH/mental-longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Define tokenization function

def tokenization(text):
    """Tokenize the ``selftext`` field of a dataset example.

    Args:
        text: Mapping that provides the raw text under the ``"selftext"`` key.

    Returns:
        The tokenizer output for that text, truncated and padded to ``MAX_LEN`` tokens.
    """
    tokenizer_options = dict(padding='max_length', truncation=True, max_length=MAX_LEN)
    encoded = tokenizer(text["selftext"], **tokenizer_options)
    return encoded

In [7]:
# Tokenize data

# Apply the same tokenization to every split, replacing each split with its tokenized version
for split_name in ("train", "valid", "test"):
    dataset[split_name] = dataset[split_name].map(tokenization)

In [8]:
# Sanity check: tokenization pads/truncates to MAX_LEN, so the length of any example's input_ids should equal MAX_LEN (4096)

len(dataset["valid"][4]["input_ids"])

4096

In [9]:
# Convert to Pytorch Tensor

# Only these columns are exposed (as torch tensors) when indexing a split
tensor_columns = ['input_ids', 'attention_mask', 'label']
for split_name in ("train", "valid", "test"):
    dataset[split_name].set_format('torch', columns=list(tensor_columns))

In [10]:
# Inspect one formatted training example (label, input_ids and attention_mask as tensors)
dataset["train"][1]

{'label': tensor(5),
 'input_ids': tensor([  0, 100, 356,  ...,   1,   1,   1]),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0])}

In [11]:
# Define Accuracy Metrics

def compute_metrics(pred):
    """Compute classification metrics for a Trainer evaluation/prediction pass.

    Args:
        pred: Object exposing ``label_ids`` (true class ids) and ``predictions``
            (per-class scores, or a tuple whose first element holds them).

    Returns:
        dict with ``accuracy``, ``weighted_f1``, ``macro_f1``, ``class_f1``
        (list with one F1 value per class), and weighted ``precision`` / ``recall``.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # Some models return extra outputs alongside the logits; the logits come first
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = scores.argmax(-1)

    # zero_division=0 yields the same values as the default behaviour but
    # without emitting a warning when a class is never predicted
    precision, recall, weighted_f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
    class_f1 = f1_score(labels, preds, average=None, zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'weighted_f1': weighted_f1,
        'macro_f1': macro_f1,
        # Plain list of floats instead of a NumPy array so the metrics stay
        # JSON-serializable when they are logged and stored with checkpoints
        'class_f1': [float(score) for score in class_f1],
        'precision': precision,
        'recall': recall
    }

In [12]:
# Define Training Arguments
training_args = TrainingArguments(
    # Optimisation hyper-parameters
    learning_rate = 0.00003040357553681075,
    warmup_steps = 75,
    weight_decay = 0.000026920603365719054,

    # Checkpoints are written to this directory
    output_dir = "frames",
    num_train_epochs = 5,
    # 4 examples per device x 8 accumulation steps = 32 examples per optimizer step
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 8,
    per_device_eval_batch_size = 16,
    # Evaluate and checkpoint once per epoch; both must use the same schedule
    # for load_best_model_at_end to work
    eval_strategy = "epoch",
    save_strategy = "epoch",
    disable_tqdm = False,
    load_best_model_at_end = True,
    logging_strategy = "epoch",
    fp16 = True,
    dataloader_num_workers = 0,
    # Turn off external experiment trackers explicitly so this does not depend
    # on an environment variable being set before this cell runs.
    # Note: "none" disables every reporting integration, not only Weights & Biases.
    report_to = "none",
)

# Define Early Stopping callback
# NOTE(review): the callback is constructed but not attached to the Trainer
# below (the original `callbacks = [early_stopping]` line was commented out) -
# confirm whether early stopping should be active.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience = 3,
    early_stopping_threshold = 0.05,
)

# Define Trainer
trainer = Trainer(
    model = model,
    args = training_args,
    # callbacks = [early_stopping],
    compute_metrics = compute_metrics,
    train_dataset = dataset["train"],
    eval_dataset = dataset["valid"],
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [13]:
import os

# Set the environment flag that switches off the Weights & Biases integration.
# NOTE(review): this cell runs after the Trainer has already been constructed -
# confirm the flag is still honoured at that point.
os.environ.update({"WANDB_DISABLED": "true"})

In [14]:
# Run fine-tuning with the Trainer configured above; the returned TrainOutput is displayed as the cell result
trainer.train()

  0%|          | 0/620 [00:00<?, ?it/s]

Initializing global attention on CLS token...
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 1.8107, 'grad_norm': 2.6023309230804443, 'learning_rate': 1.6215240286299066e-06, 'epoch': 0.03}
{'loss': 1.7843, 'grad_norm': 1.781083345413208, 'learning_rate': 3.2430480572598132e-06, 'epoch': 0.06}
{'loss': 1.7882, 'grad_norm': 2.892961263656616, 'learning_rate': 4.86457208588972e-06, 'epoch': 0.1}
{'loss': 1.7891, 'grad_norm': 2.2702126502990723, 'learning_rate': 6.4860961145196265e-06, 'epoch': 0.13}
{'loss': 1.777, 'grad_norm': 2.9596574306488037, 'learning_rate': 8.107620143149534e-06, 'epoch': 0.16}
{'loss': 1.7464, 'grad_norm': 3.7651631832122803, 'learning_rate': 9.72914417177944e-06, 'epoch': 0.19}
{'loss': 1.7525, 'grad_norm': 3.0482475757598877, 'learning_rate': 1.1350668200409346e-05, 'epoch': 0.22}
{'loss': 1.7067, 'grad_norm': 3.006740093231201, 'learning_rate': 1.2972192229039253e-05, 'epoch': 0.26}
{'loss': 1.7038, 'grad_norm': 2.573103666305542, 'learning_rate': 1.4593716257669158e-05, 'epoch': 0.29}
{'loss': 1.7027, 'grad_norm': 2.9790759086608887, 'learni

  0%|          | 0/32 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.303433895111084, 'eval_accuracy': 0.4859437751004016, 'eval_macro_f1': 0.382708874076694, 'eval_precision': 0.4205210068961778, 'eval_recall': 0.39902089925138723, 'eval_runtime': 76.1914, 'eval_samples_per_second': 6.536, 'eval_steps_per_second': 0.42, 'epoch': 1.0}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 1.2632, 'grad_norm': 5.168894290924072, 'learning_rate': 2.7614256680222606e-05, 'epoch': 1.03}
{'loss': 1.1393, 'grad_norm': 7.060086727142334, 'learning_rate': 2.7391111171695557e-05, 'epoch': 1.06}
{'loss': 1.1994, 'grad_norm': 8.357739448547363, 'learning_rate': 2.7167965663168504e-05, 'epoch': 1.09}
{'loss': 1.32, 'grad_norm': 7.615421772003174, 'learning_rate': 2.694482015464145e-05, 'epoch': 1.12}
{'loss': 1.139, 'grad_norm': 6.240644454956055, 'learning_rate': 2.6721674646114402e-05, 'epoch': 1.16}
{'loss': 1.1268, 'grad_norm': 6.5728302001953125, 'learning_rate': 2.649852913758735e-05, 'epoch': 1.19}
{'loss': 1.1964, 'grad_norm': 10.639800071716309, 'learning_rate': 2.6275383629060297e-05, 'epoch': 1.22}
{'loss': 1.1651, 'grad_norm': 8.004168510437012, 'learning_rate': 2.6052238120533247e-05, 'epoch': 1.25}
{'loss': 1.0926, 'grad_norm': 5.673391819000244, 'learning_rate': 2.5829092612006195e-05, 'epoch': 1.29}
{'loss': 1.2398, 'grad_norm': 6.60680627822876, 'learning_

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.2038065195083618, 'eval_accuracy': 0.5502008032128514, 'eval_macro_f1': 0.5114584476631839, 'eval_precision': 0.5539686040762483, 'eval_recall': 0.4961436160050159, 'eval_runtime': 76.0986, 'eval_samples_per_second': 6.544, 'eval_steps_per_second': 0.421, 'epoch': 2.0}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.9159, 'grad_norm': 6.477229595184326, 'learning_rate': 2.0696745915884015e-05, 'epoch': 2.02}
{'loss': 0.965, 'grad_norm': 6.102255344390869, 'learning_rate': 2.0473600407356962e-05, 'epoch': 2.06}
{'loss': 0.9062, 'grad_norm': 6.271117687225342, 'learning_rate': 2.025045489882991e-05, 'epoch': 2.09}
{'loss': 0.9126, 'grad_norm': 9.590336799621582, 'learning_rate': 2.002730939030286e-05, 'epoch': 2.12}
{'loss': 0.8331, 'grad_norm': 5.9637556076049805, 'learning_rate': 1.9804163881775807e-05, 'epoch': 2.15}
{'loss': 0.8489, 'grad_norm': 6.537513256072998, 'learning_rate': 1.9581018373248758e-05, 'epoch': 2.18}
{'loss': 0.8918, 'grad_norm': 9.141011238098145, 'learning_rate': 1.9357872864721706e-05, 'epoch': 2.22}
{'loss': 0.9263, 'grad_norm': 7.21077823638916, 'learning_rate': 1.9134727356194653e-05, 'epoch': 2.25}
{'loss': 0.9773, 'grad_norm': 7.451140880584717, 'learning_rate': 1.8911581847667604e-05, 'epoch': 2.28}
{'loss': 0.8984, 'grad_norm': 10.01857852935791, 'learning

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.274375081062317, 'eval_accuracy': 0.5261044176706827, 'eval_macro_f1': 0.4982112665427969, 'eval_precision': 0.5120357875226195, 'eval_recall': 0.5066026075568274, 'eval_runtime': 76.1514, 'eval_samples_per_second': 6.54, 'eval_steps_per_second': 0.42, 'epoch': 3.0}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.7529, 'grad_norm': 9.441596984863281, 'learning_rate': 1.3779235151545422e-05, 'epoch': 3.02}
{'loss': 0.6523, 'grad_norm': 7.753360271453857, 'learning_rate': 1.355608964301837e-05, 'epoch': 3.05}
{'loss': 0.8149, 'grad_norm': 10.053030967712402, 'learning_rate': 1.333294413449132e-05, 'epoch': 3.08}
{'loss': 0.7559, 'grad_norm': 9.036859512329102, 'learning_rate': 1.3109798625964269e-05, 'epoch': 3.12}
{'loss': 0.7304, 'grad_norm': 10.925121307373047, 'learning_rate': 1.294243949456898e-05, 'epoch': 3.15}
{'loss': 0.7645, 'grad_norm': 7.493333339691162, 'learning_rate': 1.271929398604193e-05, 'epoch': 3.18}
{'loss': 0.5969, 'grad_norm': 8.632965087890625, 'learning_rate': 1.2496148477514877e-05, 'epoch': 3.21}
{'loss': 0.7149, 'grad_norm': 11.182659149169922, 'learning_rate': 1.2273002968987826e-05, 'epoch': 3.24}
{'loss': 0.7097, 'grad_norm': 12.732891082763672, 'learning_rate': 1.2049857460460775e-05, 'epoch': 3.28}
{'loss': 0.6943, 'grad_norm': 7.719536304473877, 'learn

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.3440359830856323, 'eval_accuracy': 0.5481927710843374, 'eval_macro_f1': 0.5219838462656222, 'eval_precision': 0.5325353783126302, 'eval_recall': 0.522984436686175, 'eval_runtime': 75.9329, 'eval_samples_per_second': 6.558, 'eval_steps_per_second': 0.421, 'epoch': 4.0}


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 0.6687, 'grad_norm': 9.048704147338867, 'learning_rate': 6.917510764338593e-06, 'epoch': 4.02}
{'loss': 0.5479, 'grad_norm': 6.5219244956970215, 'learning_rate': 6.694365255811541e-06, 'epoch': 4.05}
{'loss': 0.6723, 'grad_norm': 10.491683959960938, 'learning_rate': 6.47121974728449e-06, 'epoch': 4.08}
{'loss': 0.5202, 'grad_norm': 18.045557022094727, 'learning_rate': 6.248074238757438e-06, 'epoch': 4.11}
{'loss': 0.5366, 'grad_norm': 6.857319355010986, 'learning_rate': 6.024928730230387e-06, 'epoch': 4.14}
{'loss': 0.7123, 'grad_norm': 11.709823608398438, 'learning_rate': 5.801783221703336e-06, 'epoch': 4.18}
{'loss': 0.5793, 'grad_norm': 9.210593223571777, 'learning_rate': 5.578637713176285e-06, 'epoch': 4.21}
{'loss': 0.5406, 'grad_norm': 8.278056144714355, 'learning_rate': 5.355492204649233e-06, 'epoch': 4.24}
{'loss': 0.5749, 'grad_norm': 8.653133392333984, 'learning_rate': 5.132346696122182e-06, 'epoch': 4.27}
{'loss': 0.571, 'grad_norm': 8.983782768249512, 'learning_rat

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.383459210395813, 'eval_accuracy': 0.5441767068273092, 'eval_macro_f1': 0.5186480786018468, 'eval_precision': 0.5251191290748253, 'eval_recall': 0.5196627467189551, 'eval_runtime': 77.3524, 'eval_samples_per_second': 6.438, 'eval_steps_per_second': 0.414, 'epoch': 4.98}
{'train_runtime': 17445.7964, 'train_samples_per_second': 1.142, 'train_steps_per_second': 0.036, 'train_loss': 0.9742839820923344, 'epoch': 4.98}


TrainOutput(global_step=620, training_loss=0.9742839820923344, metrics={'train_runtime': 17445.7964, 'train_samples_per_second': 1.142, 'train_steps_per_second': 0.036, 'total_flos': 5.212924372058112e+16, 'train_loss': 0.9742839820923344, 'epoch': 4.979919678714859})

In [15]:
# Save model

# Timestamped file name so successive runs do not overwrite each other
filename = "models/" + datetime.now().strftime("%d-%m-%Y-%H-%M-%S") + ".pt"
# torch.save does not create missing parent directories, so make sure it exists
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Only the weights (state dict) are stored, not the full model object
torch.save(model.state_dict(), filename)

In [16]:
# Re-evaluate the model on the validation split (the eval_dataset given to the Trainer).
# load_best_model_at_end = True is set, so these metrics may belong to an earlier epoch's checkpoint rather than the last epoch.

trainer.evaluate()

  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 1.2038065195083618,
 'eval_accuracy': 0.5502008032128514,
 'eval_macro_f1': 0.5114584476631839,
 'eval_precision': 0.5539686040762483,
 'eval_recall': 0.4961436160050159,
 'eval_runtime': 76.4323,
 'eval_samples_per_second': 6.516,
 'eval_steps_per_second': 0.419,
 'epoch': 4.979919678714859}

In [17]:
# Show the full list of logged training/evaluation records collected by the Trainer (long output)
trainer.state.log_history

[{'loss': 1.8107,
  'grad_norm': 2.6023309230804443,
  'learning_rate': 1.6215240286299066e-06,
  'epoch': 0.0321285140562249,
  'step': 4},
 {'loss': 1.7843,
  'grad_norm': 1.781083345413208,
  'learning_rate': 3.2430480572598132e-06,
  'epoch': 0.0642570281124498,
  'step': 8},
 {'loss': 1.7882,
  'grad_norm': 2.892961263656616,
  'learning_rate': 4.86457208588972e-06,
  'epoch': 0.0963855421686747,
  'step': 12},
 {'loss': 1.7891,
  'grad_norm': 2.2702126502990723,
  'learning_rate': 6.4860961145196265e-06,
  'epoch': 0.1285140562248996,
  'step': 16},
 {'loss': 1.777,
  'grad_norm': 2.9596574306488037,
  'learning_rate': 8.107620143149534e-06,
  'epoch': 0.1606425702811245,
  'step': 20},
 {'loss': 1.7464,
  'grad_norm': 3.7651631832122803,
  'learning_rate': 9.72914417177944e-06,
  'epoch': 0.1927710843373494,
  'step': 24},
 {'loss': 1.7525,
  'grad_norm': 3.0482475757598877,
  'learning_rate': 1.1350668200409346e-05,
  'epoch': 0.2248995983935743,
  'step': 28},
 {'loss': 1.7067

In [20]:
# Run the model on the held-out test split and show only the metrics (keys are prefixed with "test_")
trainer.predict(dataset["test"]).metrics

  0%|          | 0/32 [00:00<?, ?it/s]

{'test_loss': 1.1075841188430786,
 'test_accuracy': 0.606425702811245,
 'test_macro_f1': 0.5523824908573497,
 'test_precision': 0.6019554380456636,
 'test_recall': 0.53849863937603,
 'test_runtime': 88.9119,
 'test_samples_per_second': 5.601,
 'test_steps_per_second': 0.36}