Hello Fellow Kagglers,

There are many Longformer notebooks out there, but I could not find one that uses gradient accumulation, so here it is.

**Why do we need Gradient Accumulation?**

With "just" 16GB of GPU memory, the Longformer-base model can only be trained with a batch size of 1. Gradient accumulation makes it possible to imitate an arbitrarily large batch size. This results in more stable weight updates, because low-quality samples, for example samples with erroneous annotations, are averaged out by the high-quality samples. With a batch size of 1, every low-quality sample results in a poor weight update. With a larger batch size of, for example, 8, all 8 samples would have to be of low quality to produce an equally poor weight update, which is unlikely.

**How does Gradient Accumulation Work?**

Gradient accumulation is the process of running several forward and backward passes, computing the gradients in each pass, and averaging these gradients before performing a single weight update based on the averaged gradient. The final gradient is thus computed as:

$$\text{average gradient} = \frac{1}{N} \sum_{i=1}^{N} \text{gradient}_{i}$$

This is equivalent to training with an actual mini-batch size of N, but the computation is spread over N forward/backward passes instead of one.


[Preprocessing Notebook](https://www.kaggle.com/markwijkhuizen/preprocessing-oversampling)

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.notebook import tqdm
from transformers import TFLongformerModel
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow.keras import mixed_precision
from multiprocessing import cpu_count

import glob
import sys
import os
import random
import logging
import math
import os
import time

# Disable TensorFlow logs/warnings below the ERROR level
tf.get_logger().setLevel(logging.ERROR)

# Report library versions so the environment of a run can be reproduced
print(f'tensorflow version: {tf.__version__}')
print(f'tensorflow keras version: {tf.keras.__version__}')
print(f'python version: {sys.version}')

# Seed Everything for Deterministic Behaviour

In [None]:
# Seed all random sources
def set_seeds(seed):
    """Seed Python's `random`, TensorFlow and NumPy, and export PYTHONHASHSEED.

    NOTE(review): setting PYTHONHASHSEED at runtime only affects Python
    processes started afterwards, not string hashing in this interpreter.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, tf.random.set_seed, np.random.seed):
        seed_fn(seed)


set_seeds(42)

# Load Train

In [None]:
# Column dtypes: compact integer types and categoricals keep the frame small
dtype = {
    'id': 'string',
    'discourse_id': np.uint64,
    'discourse_start': np.uint16,
    'discourse_end': np.uint16,
    'discourse_text': 'string',
    'discourse_type': 'category',
    'discourse_type_num': 'category',
    'predictionstring': 'string',
}

train = pd.read_csv('/kaggle/input/feedback-prize-2021/train.csv', dtype=dtype)

display(train.head())

# DataFrame.info() prints its report itself and returns None; wrapping it in
# display() would additionally render that None, so call it directly
train.info()

# Labels

Each discourse type has a separate label. In addition, a 'Not Annotated' label is added for words that are not annotated, and a 'Padding' label is added for padding tokens.

In [None]:
# Discourse types in sorted order (stable label indices across runs), followed
# by two extra labels: unannotated words and padding tokens
LABELS = [
    *train['discourse_type'].unique().sort_values().tolist(),
    'Not Annotated',
    'Padding',
]
N_LABELS = len(LABELS)

# Index of 'Not Annotated' (second to last label)
NA_CLASS = N_LABELS - 2
# Index of 'Padding' (last label)
PAD_CLASS = N_LABELS - 1
# Number of labels excluding padding
N_NON_PAD_LABELS = N_LABELS - 1

print(f'N_LABELS: {N_LABELS}, NA_CLASS: {NA_CLASS}, PAD_CLASS: {PAD_CLASS}, N_NON_PAD_LABELS: {N_NON_PAD_LABELS}')
print(f'LABELS: {LABELS}')

In [None]:
# Number of tokens per sample; the model loaded in get_model is
# 'allenai/longformer-base-4096'
SEQ_LENGTH = 4_096

# Dataset

In [None]:
# Set DEBUG = True to work on a small subset of the data for quick iteration
DEBUG = False

# Number of samples to load. None as a slice bound (arr[:None]) keeps every
# sample, which is clearer than relying on an out-of-range huge integer
TAKE_N = 2048 if DEBUG else None

In [None]:
# === TRAIN ===
X_train_input_ids = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/train/train_tokens.npy')[:TAKE_N]
X_train_attention_masks = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/train/train_attention_masks.npy')[:TAKE_N]
y_train = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/train/train_labels.npy')[:TAKE_N]

# === VAL ===
X_val_input_ids = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/val/val_tokens.npy')[:TAKE_N]
X_val_attention_masks = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/val/val_attention_masks.npy')[:TAKE_N]
y_val = np.load('/kaggle/input/feedback-prize-preprocessing-oversampling/val/val_labels.npy')[:TAKE_N]

# Sanity-check shapes and dtypes of the loaded arrays
print(f'X_train_input_ids shape: {X_train_input_ids.shape}, X_train_attention_masks shape: {X_train_attention_masks.shape}, y_train shape: {y_train.shape}')
print(f'X_train_input_ids dtype: {X_train_input_ids.dtype}, X_train_attention_masks dtype: {X_train_attention_masks.dtype}, y_train dtype: {y_train.dtype}')
print(f'X_val_input_ids shape: {X_val_input_ids.shape}, X_val_attention_masks shape: {X_val_attention_masks.shape}, y_val shape: {y_val.shape}')
print(f'X_val_input_ids dtype: {X_val_input_ids.dtype}, X_val_attention_masks dtype: {X_val_attention_masks.dtype}, y_val dtype: {y_val.dtype}')

# Class Weights

Class weights scale up the loss of minority classes; they are computed here but not used for training.

In [None]:
# Token count per label index. reindex aligns the counts with label indices
# 0..N_LABELS-1, so a label that never occurs gets count 0 instead of silently
# shifting every following label by one position in the zip below.
class_count = (
    pd.Series(y_train.flatten())
    .value_counts()
    .reindex(range(N_LABELS), fill_value=0)
)
LABELS_COUNT_DICT = dict(zip(LABELS, class_count))

# Linear inverse-frequency weights: the most frequent label gets weight 1 and
# every other label max_count / label_count (no square root is applied).
max_count = max(LABELS_COUNT_DICT.values())
CLASS_WEIGHT = dict()
for label_idx, label_count in enumerate(LABELS_COUNT_DICT.values()):
    # A label absent from y_train gets a neutral weight of 1.0: an infinite
    # weight would turn 0 * inf into NaN when multiplied into the loss.
    CLASS_WEIGHT[label_idx] = max_count / label_count if label_count > 0 else 1.0

# One row per label: its index and its weight
CLASS_WEIGHT_DF = pd.DataFrame(CLASS_WEIGHT.items(), index=LABELS, columns=['LABEL_INDEX', 'CLASS_WEIGHT'])

display(CLASS_WEIGHT_DF)

In [None]:
def get_class_weight_mask(CLASS_WEIGHT_DF):
    """Return the per-class weights tiled to shape (1, SEQ_LENGTH, n_classes), float32.

    Every token position carries the same row of weights, so the mask can be
    multiplied element-wise with per-token, per-class loss terms.
    """
    weights = tf.constant(CLASS_WEIGHT_DF['CLASS_WEIGHT'], dtype=tf.float32)
    # (n_classes,) -> (1, 1, n_classes) -> (1, SEQ_LENGTH, n_classes)
    return tf.tile(weights[tf.newaxis, tf.newaxis, :], [1, SEQ_LENGTH, 1])


CLASS_WEIGHT_MASK = get_class_weight_mask(CLASS_WEIGHT_DF)

# Custom Loss Function

This custom loss function supports optional class weights, padding masking and L2 regularization.

In [None]:
@tf.function(jit_compile=True)
def cross_entropy_class_weight(
    y_true, # True Labels One Hot Encoded
    y_pred, # Predicted Labels
    from_logits=True, # If the Predicted Labels are Logits or Probabilities
    apply_class_weight=False, # Apply Class Weights on Loss
    add_l2_loss=False, # Add (unscaled) L2 Loss of Non-Backbone Weights
    mask_padding=False, # Mask Padding Tokens in Loss
        ):
    """Token-level categorical cross entropy with optional extras.

    The per-token loss is summed over the class axis and averaged over tokens
    (over non-padding tokens only when ``mask_padding`` is set).

    Args:
        y_true: one-hot labels; the ``[:, :, PAD_CLASS]`` indexing assumes
            shape (batch, seq_len, N_LABELS).
        y_pred: logits (``from_logits=True``) or probabilities, same shape.
        from_logits: apply a softmax over axis 2 first.
        apply_class_weight: multiply each class term by CLASS_WEIGHT_MASK.
        add_l2_loss: add ``tf.nn.l2_loss`` of every model weight whose name
            contains neither 'bias' nor 'tf_longformer_model'. The term is not
            scaled by a regularization factor.
        mask_padding: ignore tokens whose true label is PAD_CLASS.

    Returns:
        Scalar float32 loss.
    """
    y_true_labels = tf.cast(y_true, tf.float32)
    # Compute in float32 like the labels, whatever float dtype the model emits
    y_pred = tf.cast(y_pred, tf.float32)

    if from_logits:
        y_pred = tf.nn.softmax(y_pred, axis=2)

    # Log-likelihood of the true class per token and class; 1e-9 avoids log(0)
    ce_loss = K.log(1e-9 + y_pred) * y_true_labels
    if apply_class_weight:
        # CLASS_WEIGHT_MASK has a leading axis of 1 and broadcasts over the batch
        ce_loss *= CLASS_WEIGHT_MASK

    # Sum over the class axis -> one value per token
    ce_loss = K.sum(ce_loss, axis=2)

    if mask_padding:
        # 1.0 for real tokens, 0.0 for padding tokens
        non_pad_mask = tf.cast(y_true_labels[:, :, PAD_CLASS] == 0, tf.float32)
        # divide_no_nan yields 0 instead of NaN for an all-padding batch
        ce_loss = tf.math.divide_no_nan(-K.sum(ce_loss * non_pad_mask), K.sum(non_pad_mask))
    else:
        ce_loss = -K.mean(ce_loss)

    l2_loss = tf.constant(0.0, dtype=tf.float32)
    if add_l2_loss:
        # NOTE(review): presumably the Longformer backbone's weight names contain
        # 'tf_longformer_model', so only the head's non-bias weights remain
        l2_weights = [w for w in model.weights if 'bias' not in w.name and 'tf_longformer_model' not in w.name]
        # tf.add_n rejects an empty list, so keep 0.0 when nothing matches
        if l2_weights:
            l2_loss = tf.add_n([tf.nn.l2_loss(w) for w in l2_weights])

    return ce_loss + l2_loss

# F1 Score Without Padding

This custom macro F1 score ignores padding tokens and the padding class.

In [None]:
# Stateful macro F1 over the non-padding classes; its state accumulates until
# F1Score.reset_states() is called
F1Score = tfa.metrics.F1Score(num_classes=N_NON_PAD_LABELS, average='macro')

@tf.function()
def non_pad_f1(y_true, y_pred):
    """Update F1Score with the non-padding tokens of a batch and return its result.

    Args:
        y_true: one-hot labels; assumed shape (batch, seq_len, N_LABELS).
        y_pred: logits of the same shape.
    """
    # Output logits to probabilities
    y_pred = tf.nn.softmax(y_pred, axis=2)

    y_true = tf.cast(y_true, tf.float32)

    # (batch_idx, token_idx) pairs whose true label is not padding
    non_pad_idxs = tf.where(tf.argmax(y_true, axis=2) != PAD_CLASS)

    # Drop the padding-class column (PAD_CLASS is the last label); slicing the
    # last axis works for any batch size instead of assuming BATCH_SIZE
    y_true = y_true[:, :, :N_NON_PAD_LABELS]
    y_pred = y_pred[:, :, :N_NON_PAD_LABELS]

    # Keep only the non-padding tokens -> shape (n_tokens, N_NON_PAD_LABELS)
    y_true = tf.gather_nd(y_true, non_pad_idxs)
    y_pred = tf.gather_nd(y_pred, non_pad_idxs)

    return F1Score(y_true, y_pred)

# Categorical Accuracy Without Padding

This custom categorical accuracy metric ignores padding tokens and the padding class.

In [None]:
# Stateful accuracy over the non-padding classes; its state accumulates until
# CategoricalAccuracy.reset_states() is called
CategoricalAccuracy = tf.keras.metrics.CategoricalAccuracy(name='accuracy')

@tf.function()
def non_pad_categorical_accuracy(y_true, y_pred):
    """Update CategoricalAccuracy with the non-padding tokens of a batch and return its result.

    Args:
        y_true: one-hot labels; assumed shape (batch, seq_len, N_LABELS).
        y_pred: logits of the same shape.
    """
    # Output logits to probabilities
    y_pred = tf.nn.softmax(y_pred, axis=2)

    y_true = tf.cast(y_true, tf.float32)

    # (batch_idx, token_idx) pairs whose true label is not padding
    non_pad_idxs = tf.where(tf.argmax(y_true, axis=2) != PAD_CLASS)

    # Drop the padding-class column (PAD_CLASS is the last label); slicing the
    # last axis works for any batch size instead of assuming BATCH_SIZE
    y_true = y_true[:, :, :N_NON_PAD_LABELS]
    y_pred = y_pred[:, :, :N_NON_PAD_LABELS]

    # Keep only the non-padding tokens -> shape (n_tokens, N_NON_PAD_LABELS)
    y_true = tf.gather_nd(y_true, non_pad_idxs)
    y_pred = tf.gather_nd(y_pred, non_pad_idxs)

    return CategoricalAccuracy(y_true, y_pred)

# Weights L2 Distance

Returns the L2 loss (`tf.nn.l2_loss`, i.e. half the sum of squares) of all trainable weights, ignoring biases.

In [None]:
@tf.function(jit_compile=True)
def weights_l2(y_true, y_pred):
    """Metric: tf.nn.l2_loss summed over the trainable, non-bias weights of ``model``.

    ``y_true`` and ``y_pred`` are unused; they only satisfy the Keras metric signature.
    """
    # Exclude bias weights and non-trainable weights
    trainable_weights = [w for w in model.weights if 'bias' not in w.name and w.trainable]
    # tf.add_n rejects an empty list
    if not trainable_weights:
        return tf.constant(0.0, dtype=tf.float32)
    return tf.add_n([tf.nn.l2_loss(w) for w in trainable_weights])

# Longformer Model

Longformer model followed by two fully connected layers and a classification layer.

In [None]:
def get_model():
    """Build and compile the Longformer token-classification model.

    The 'allenai/longformer-base-4096' last hidden state feeds two ReLU dense
    layers (hidden size, then hidden size // 4) and a linear classifier that
    outputs N_LABELS logits per token.

    Side effects: clears the Keras session and enables XLA JIT globally.

    Returns:
        A compiled ``tf.keras.Model`` with inputs ``input_ids`` and
        ``attention_mask`` of shape (batch, SEQ_LENGTH).
    """
    # Activation and weight initializer of the dense head layers
    activation = tf.keras.activations.relu
    # A class; instantiated once per Dense layer below
    kernel_initializer = tf.keras.initializers.HeUniform

    # Clear backend
    tf.keras.backend.clear_session()

    # Enable XLA optimizations
    tf.config.optimizer.set_jit(True)

    # Shapes must be tuples: (SEQ_LENGTH) is just an int, (SEQ_LENGTH,) is a tuple
    input_ids = tf.keras.layers.Input(shape=(SEQ_LENGTH,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input(shape=(SEQ_LENGTH,), dtype=tf.int32, name='attention_mask')

    # Longformer-base model
    longformer = TFLongformerModel.from_pretrained(
        'allenai/longformer-base-4096',
        output_hidden_states=True,
        return_dict=True,
    )

    # Global and pooler layers are not trained, thus set the weights untrainable
    # (relies on the private Variable._trainable attribute)
    for w in longformer.trainable_weights:
        if 'global' in w.name or 'pooler/dense' in w.name:
            w._trainable = False

    # Per-token embeddings from the last layer
    last_hidden_state = longformer(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

    # Hidden size (768 for base and 1024 for large)
    last_hidden_state_shape = last_hidden_state.shape[-1]

    # Two fully connected layers and a linear (logit) output layer
    fc1 = tf.keras.layers.Dense(last_hidden_state_shape, activation=activation, kernel_initializer=kernel_initializer(), name='head/fc1')(last_hidden_state)
    fc2 = tf.keras.layers.Dense(last_hidden_state_shape // 4, activation=activation, kernel_initializer=kernel_initializer(), name='head/fc2')(fc1)
    output = tf.keras.layers.Dense(N_LABELS, name='head/classifier')(fc2)

    model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output)

    # LOSS
    loss = cross_entropy_class_weight

    # OPTIMIZER
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, epsilon=1e-6)

    # METRICS
    metrics = [
        tf.keras.metrics.CategoricalCrossentropy(from_logits=True, name='categorical_crossentropy'),
        non_pad_categorical_accuracy,
        non_pad_f1,
        weights_l2,
    ]

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    return model

model = get_model()

In [None]:
# Show model summary (Keras prints a per-layer table of the compiled model)
model.summary()

In [None]:
# Render the model graph with tensor shapes, dtypes and layer names;
# nested models (the Longformer backbone) are drawn as a single node
tf.keras.utils.plot_model(
    model,
    show_layer_names=True,
    show_shapes=True,
    show_dtype=True,
    expand_nested=False,
)

# Train Configuration

In [None]:
# Training configuration: samples per forward pass and number of epochs
BATCH_SIZE, N_EPOCHS = 1, 1

print(f'BATCH SIZE: {BATCH_SIZE}, N_EPOCHS: {N_EPOCHS}')

# Train/Validation Split

In [None]:
N_TRAIN_SAMPLES = len(y_train)
N_VAL_SAMPLES = len(y_val)

# The training dataset is shuffled and repeated, so ceil(N / BATCH_SIZE)
# batches consume at least N samples of the repeated stream per epoch.
TRAIN_STEPS_PER_EPOCH = math.ceil(N_TRAIN_SAMPLES / BATCH_SIZE)
# The validation dataset is not repeated and is batched with
# drop_remainder=True, so it yields exactly N_VAL_SAMPLES // BATCH_SIZE
# batches; ceil would make the validation loop request one batch too many
# whenever the sizes do not divide evenly.
VAL_STEPS_PER_EPOCH = N_VAL_SAMPLES // BATCH_SIZE

print(f'N_TRAIN_SAMPLES: {N_TRAIN_SAMPLES}, N_VAL_SAMPLES: {N_VAL_SAMPLES}')
print(f'TRAIN_STEPS_PER_EPOCH: {TRAIN_STEPS_PER_EPOCH}, VAL_STEPS_PER_EPOCH: {VAL_STEPS_PER_EPOCH}')

# Dataset

In [None]:
# Labels are stored sparsely (one integer per token) to save memory; the loss
# and metrics need them one-hot encoded
def one_hot_encode_labels(X, y):
    """Pass X through unchanged and expand integer labels y into uint8 one-hot vectors on axis 2."""
    return X, tf.one_hot(y, N_LABELS, on_value=1, off_value=0, axis=2, dtype=tf.uint8)

In [None]:
def get_dataset(X_input_ids, X_attention_masks, y, shuffle_repeat, bs=BATCH_SIZE):
    """Build a batched tf.data pipeline of ({input_ids, attention_mask}, one-hot labels).

    Args:
        X_input_ids: token id array, one row per sample.
        X_attention_masks: attention mask array aligned with X_input_ids.
        y: integer label array aligned with the inputs (cast to uint8).
        shuffle_repeat: shuffle over the full dataset and repeat indefinitely.
        bs: batch size; an incomplete final batch is dropped.
    """
    features = {
        'input_ids': X_input_ids,
        'attention_mask': X_attention_masks,
    }
    ds = tf.data.Dataset.from_tensor_slices((features, y.astype(np.uint8)))

    if shuffle_repeat:
        ds = ds.shuffle(len(y)).repeat()

    # Batch, one-hot encode in parallel, and keep up to 10 batches prefetched
    return (
        ds.batch(bs, drop_remainder=True)
        .map(one_hot_encode_labels, num_parallel_calls=cpu_count())
        .prefetch(10)
    )

# Train Dataset

In [None]:
# TRAIN DATASET (shuffled, repeats indefinitely)
train_dataset = get_dataset(
    X_train_input_ids,
    X_train_attention_masks,
    y_train,
    True,
)

# Pull one batch to sanity-check keys, shapes and dtypes
train_x, train_y = next(iter(train_dataset))
print(f'train_x keys: {list(train_x)}')
print(
    f'train_x input ids shape: {train_x["input_ids"].shape}, '
    f'train_x attention mask shape: {train_x["attention_mask"].shape}'
)
print(
    f'train_x input ids dtype: {train_x["input_ids"].dtype}, '
    f'train_x attention mask dtype: {train_x["attention_mask"].dtype}'
)
print(f'train_y shape: {train_y.shape}, train_y dtype: {train_y.dtype}')

# Validation Dataset

In [None]:
# VALIDATION DATASET (fixed order, single pass)
val_dataset = get_dataset(
    X_val_input_ids,
    X_val_attention_masks,
    y_val,
    False,
)

# Pull one batch to sanity-check keys, shapes and dtypes
val_x, val_y = next(iter(val_dataset))
print(f'val_x keys: {list(val_x)}')
print(
    f'val_x input ids shape: {val_x["input_ids"].shape}, '
    f'val_x attention mask shape: {val_x["attention_mask"].shape}'
)
print(
    f'val_x input ids dtype: {val_x["input_ids"].dtype}, '
    f'val_x attention mask dtype: {val_x["attention_mask"].dtype}'
)
print(f'val_y shape: {val_y.shape}, val_y dtype: {val_y.dtype}')

# Reset Metrics

In [None]:
# Resets all training metrics
def metric_reset_states():
    """Reset the model's metrics, then the shared F1Score and CategoricalAccuracy objects."""
    for metric in [*model.metrics, F1Score, CategoricalAccuracy]:
        metric.reset_states()

# History

Training and validation metrics are recorded per step and as a rolling (moving) average.

In [None]:
# Per-step values of every tracked quantity; 'val_' keys hold validation values
HISTORY_STEP = {'loss': [], 'val_loss': [], 'mean_grad': [], 'val_mean_grad': []}

# Rolling-window means of the per-step values, same keys
HISTORY_RW = {'loss': [], 'val_loss': [], 'mean_grad': [], 'val_mean_grad': []}

# Add a train and a val_ entry for each compiled metric. Metrics whose instance
# dict carries __name__ (presumably the tf.function-wrapped metric functions)
# are keyed by it; metric objects are keyed by .name
for m in model.compiled_metrics._metrics:
    name = m.__name__ if '__name__' in vars(m) else m.name
    for history in (HISTORY_STEP, HISTORY_RW):
        history[f'{name}'] = []
        history[f'val_{name}'] = []

# Logs

Custom logging function used by the custom training loop.

In [None]:
def log(loss, mean_grad, step, step_total, t_start, metric_postfix, rolling_window=1024):
    """Record one step in the history, print a progress line and reset all metrics.

    Args:
        loss: scalar loss tensor of the step (``.numpy()`` is called on it).
        mean_grad: mean absolute gradient of the step. Only values > 0 are
            recorded, so validation passes 0 to skip gradient logging.
        step: zero-based index of the current step.
        step_total: total number of steps, used for the remaining-time estimate.
        t_start: ``time.time()`` at which the loop started.
        metric_postfix: despite its name, a key *prefix*: '' for training,
            'val_' for validation.
        rolling_window: number of most recent steps averaged for the printed values.
    """
    # Per-step loss
    loss = loss.numpy()
    HISTORY_STEP[f'{metric_postfix}loss'].append(loss)

    # Plain Python float: the history stores numbers instead of tensors and the
    # "> 0" test is an ordinary float comparison
    mean_grad = float(mean_grad)
    has_grad = mean_grad > 0
    if has_grad:
        HISTORY_STEP[f'{metric_postfix}mean_grad'].append(mean_grad)

    # Progress, speed (ms per step) and estimated seconds remaining
    ms_per_step = int((time.time() - t_start) * 1e3 / (step + 1))
    s_left = int(ms_per_step * (step_total - step - 1) / 1e3)
    logs = f'{step + 1}/{step_total} | {ms_per_step}ms/step, {s_left}s left, '

    # Rolling mean of the loss over the last `rolling_window` steps
    loss_rw = np.mean(HISTORY_STEP[f'{metric_postfix}loss'][-rolling_window:])
    HISTORY_RW[f'{metric_postfix}loss'].append(loss_rw)
    logs += f'{metric_postfix}loss: {loss_rw:.3f}, '

    if has_grad:
        mean_grad_rw = np.mean(HISTORY_STEP[f'{metric_postfix}mean_grad'][-rolling_window:])
        HISTORY_RW[f'{metric_postfix}mean_grad'].append(mean_grad_rw)
        logs += f'mean_grad: {mean_grad_rw:.3e}, '

    # Current value and rolling mean of every model metric
    metric_logs = []
    for m in model.metrics:
        key = f'{metric_postfix}{m.name}'
        HISTORY_STEP[key].append(m.result().numpy())
        m_rw = np.mean(HISTORY_STEP[key][-rolling_window:])
        HISTORY_RW[key].append(m_rw)
        # weights_l2 is printed without decimals, everything else with 3
        fmt = '.0f' if m.name in ['weights_l2'] else '.3f'
        metric_logs.append(f'{key}: {m_rw:{fmt}}')
    logs += ', '.join(metric_logs)

    # "\r" returns the carriage to the start of the line so each report overwrites the previous one
    print('\r', logs, end='')
    sys.stdout.flush()

    # Metrics are per step: reset them for the next step
    metric_reset_states()

# Validation Step

In [None]:
# Compiled inference step for one batch: token ids as uint16, attention masks as int8
@tf.function(
    autograph=False,
    input_signature=[{
        'input_ids': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.uint16),
        'attention_mask': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.int8),
    }],
)
def pred(X_batch):
    """Forward pass of the global model in inference mode; returns per-token logits."""
    logits = model(X_batch, training=False)
    return logits

# Gradient Accumulation

This is the actual gradient accumulation training step

In [None]:
class GradientAccumulation():
    """Custom train/validation steps that average gradients over several batches.

    Each ``train_step`` divides its gradients by ``n_gradients`` and adds them to
    per-variable accumulators. Every ``n_gradients`` steps the accumulated (mean)
    gradient is applied with ``model.optimizer`` and the accumulators are reset.
    Operates on the module-level ``model``.
    """
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Trainable variables to obtain the gradients for (variables whose
        # trainable flag was switched off are skipped)
        self.trainable_variables = [v for v in model.trainable_variables if v.trainable]
        # One float32 accumulator per trainable variable, same shape
        self.grads_zeros = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), name=v.name, trainable=False)
            for v in self.trainable_variables
        ]
        # Number of gradients accumulated since the last weight update
        self.counter = tf.Variable(0, dtype=tf.float32)
        # Number of gradients to accumulate before each weight update
        self.n_gradients = tf.constant(n_gradients, dtype=np.float32)
        print(f'self.counter: {self.counter.numpy()}, self.n_gradients: {self.n_gradients}')

    # Weight update
    @tf.function()
    def backprop(self):
        """Apply the accumulated mean gradients, then reset counter and accumulators."""
        model.optimizer.apply_gradients(zip(self.grads_zeros, self.trainable_variables))

        self.counter.assign(0)

        for g_z in self.grads_zeros:
            g_z.assign(tf.zeros_like(g_z))

    # No-op branch for tf.cond
    @tf.function()
    def fun_empty(self):
        return

    # Mean absolute gradient, used as a metric to monitor the gradients
    @tf.function(jit_compile=True)
    def get_mean_gradient(self, grads):
        """Element-count-weighted mean of |g| over all non-None gradients (float64)."""
        grads_sum = tf.constant(0, tf.float64)
        grads_count = tf.constant(0, tf.float64)

        for g in [g for g in grads if g is not None]:
            g_mean = tf.reduce_mean(tf.math.abs(g))
            g_count = tf.reduce_prod(tf.shape(g))
            grads_sum += tf.cast(g_mean, tf.float64) * tf.cast(g_count, tf.float64)
            grads_count += tf.cast(g_count, tf.float64)

        return grads_sum / grads_count

    # Training step (forward pass, gradients, accumulation, periodic update)
    @tf.function(
        autograph=False,
        input_signature=[
            {'input_ids': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.uint16), 'attention_mask': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.int8)},
            tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH, N_LABELS], dtype=tf.uint8),
        ],
    )
    def train_step(self, X_batch, y_true_batch):
        """Run one training batch; returns (loss, mean absolute gradient)."""
        with tf.GradientTape() as tape:
            # training=True: a custom loop must request training mode explicitly,
            # otherwise Keras runs layers such as dropout in inference mode
            y_pred_batch = model(X_batch, training=True)
            loss = model.loss(y_true_batch, y_pred_batch)

        # Update train metrics
        model.compiled_metrics.update_state(y_true_batch, y_pred_batch)

        grads = tape.gradient(loss, self.trainable_variables)

        # Accumulate g / n_gradients so the sum after n_gradients steps is the mean
        for g, g_z in zip(grads, self.grads_zeros):
            # Variables not connected to the loss have no gradient
            if g is not None:
                g_z.assign_add(g / self.n_gradients)

        self.counter.assign_add(1)

        # Update the weights once n_gradients gradients are accumulated
        tf.cond(tf.math.equal(self.counter, self.n_gradients), self.backprop, self.fun_empty)

        # Mean gradient of the current step for monitoring
        mean_grad = self.get_mean_gradient(grads)

        return loss, mean_grad

    # Validation step
    @tf.function(
        autograph=False,
        input_signature=[
            {'input_ids': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.uint16), 'attention_mask': tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH], dtype=tf.int8)},
            tf.TensorSpec(shape=[BATCH_SIZE, SEQ_LENGTH, N_LABELS], dtype=tf.uint8),
        ],
    )
    def val_step(self, X_batch, y_true_batch):
        """Run one validation batch in inference mode (no gradients); returns the loss."""
        y_pred_batch = pred(X_batch)
        loss = model.loss(y_true_batch, y_pred_batch)

        # Update validation metrics
        model.compiled_metrics.update_state(y_true_batch, y_pred_batch)

        return loss

In [None]:
# Training object: the optimizer updates the weights once every N_GRADIENTS
# steps with the mean gradient of those steps
N_GRADIENTS = 8
GA = GradientAccumulation(n_gradients=N_GRADIENTS)

# Training Loop

In [None]:
# The training dataset repeats indefinitely, so one iterator serves all epochs.
# NOTE: GA only updates the weights when its counter reaches n_gradients, so
# gradients accumulated in an unfinished window carry over to the next epoch
# and are not applied after the final step.
train_dataset_iter = iter(train_dataset)

for epoch in range(N_EPOCHS):
    print(f'Epoch {epoch + 1}/{N_EPOCHS}')

    # === TRAIN ===
    t_start_train = time.time()
    for step in range(TRAIN_STEPS_PER_EPOCH):
        X_batch, y_true_batch = next(train_dataset_iter)
        loss, mean_grad = GA.train_step(X_batch, y_true_batch)
        log(loss, mean_grad, step, TRAIN_STEPS_PER_EPOCH, t_start_train, '')

    print('\n')

    # === VALIDATION ===
    # take() stops cleanly if the (non-repeating) dataset has fewer batches
    # than VAL_STEPS_PER_EPOCH, instead of raising StopIteration
    t_start_val = time.time()
    for step, (X_batch, y_true_batch) in enumerate(val_dataset.take(VAL_STEPS_PER_EPOCH)):
        loss = GA.val_step(X_batch, y_true_batch)
        log(loss, 0, step, VAL_STEPS_PER_EPOCH, t_start_val, 'val_')

    print('\n')

In [None]:
# Save the trained weights; Keras picks the HDF5 format from the .h5 suffix
WEIGHTS_PATH = 'model.h5'
model.save_weights(WEIGHTS_PATH)

# Training Metrics

In [None]:
# Mean of every per-step series: training keys in one table; validation keys in
# the other, except val_mean_grad (validation logs no gradients)
history_train_rows = [
    {'metric': k, 'mean value': np.mean(v)}
    for k, v in HISTORY_STEP.items()
    if 'val' not in k
]
history_val_rows = [
    {'metric': k, 'mean value': np.mean(v)}
    for k, v in HISTORY_STEP.items()
    if 'val' in k and 'mean_grad' not in k
]

print('=== TRAIN METRICS ===')
display(pd.DataFrame.from_dict(history_train_rows))

print('=== VALIDATION METRICS ===')
display(pd.DataFrame.from_dict(history_val_rows))

# Plot Training History

In [None]:
def plot_history_metric(metric, f_best=np.argmax, include_val=True, y_lim_start=None):
    """Plot the rolling-mean training curve of ``metric`` and optionally a
    histogram of its per-step validation values.

    Args:
        metric: key into HISTORY_RW / HISTORY_STEP, without the 'val_' prefix.
        f_best: currently unused; kept so existing calls keep working.
        include_val: also plot a histogram of HISTORY_STEP['val_' + metric].
        y_lim_start: if given, lower y-limit of the training curve; the upper
            limit is then 1.25 x the largest rolling value.
    """
    # Training curve (rolling mean per step)
    plt.figure(figsize=(20, 8))
    plt.plot(HISTORY_RW[metric], label=metric)
    plt.title(f'Model {metric} Step', fontsize=24, pad=10)
    plt.ylabel(metric, fontsize=20, labelpad=10)
    plt.xlabel('step', fontsize=20, labelpad=10)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # The curve carries a label, so the legend has an entry to show
    plt.legend(prop={'size': 18})
    plt.grid()
    if y_lim_start is not None:
        plt.ylim(bottom=y_lim_start, top=max(HISTORY_RW[metric]) * 1.25)
    plt.show()

    if include_val:
        # Distribution of the per-step validation values
        plt.figure(figsize=(20, 8))
        plt.title(f'Model val_{metric} Histogram', fontsize=24, pad=10)
        pd.Series(HISTORY_STEP[f'val_{metric}']).plot(kind='hist', bins=32, color='tab:orange')
        # Histogram axes: x = metric value, y = number of validation steps
        plt.xlabel(f'val_{metric}', fontsize=20, labelpad=10)
        plt.ylabel('count', fontsize=20, labelpad=10)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.show()

In [None]:
# Render all following figures at 300 DPI
plt.rcParams.update({'figure.dpi': 300})

In [None]:
plot_history_metric(metric='loss', f_best=np.argmin)

In [None]:
plot_history_metric(metric='mean_grad', include_val=False)

In [None]:
plot_history_metric(metric='categorical_crossentropy', f_best=np.argmin)

In [None]:
plot_history_metric(metric='non_pad_categorical_accuracy')

In [None]:
plot_history_metric(metric='non_pad_f1')

In [None]:
plot_history_metric(metric='weights_l2', y_lim_start=0, include_val=False)

# Prediction Analysis

In [None]:
def get_predictions(dataset, total_steps):
    """Run ``pred`` over ``dataset`` and return flat per-token class indices.

    Args:
        dataset: yields (X_batch, one-hot y_true_batch) pairs accepted by ``pred``.
        total_steps: expected number of batches; only used for the tqdm bar.

    Returns:
        (y_true, y_pred): 1-D int8 arrays of argmax class indices, one entry
        per token (padding tokens included).
    """
    y_true_parts = []
    y_pred_parts = []
    for X_batch, y_true_batch in tqdm(dataset, total=total_steps):
        # Collect one array per batch and concatenate once at the end instead
        # of growing Python lists element by element
        y_true_parts.append(np.argmax(y_true_batch, axis=2).ravel())
        y_pred_parts.append(np.argmax(pred(X_batch), axis=2).ravel())

    if not y_true_parts:
        return np.empty(0, dtype=np.int8), np.empty(0, dtype=np.int8)

    y_true = np.concatenate(y_true_parts).astype(np.int8)
    y_pred = np.concatenate(y_pred_parts).astype(np.int8)

    return y_true, y_pred

# Validation Dataset
y_true_val, y_pred_val = get_predictions(
    get_dataset(X_val_input_ids, X_val_attention_masks, y_val, False),
    VAL_STEPS_PER_EPOCH,
)

# Random training subset of the validation set's size, as predicting on the
# full training set takes very long. Capped at the training set size so that
# sampling without replacement cannot fail.
np.random.seed(42)
n_train_subset = min(len(y_train), len(y_val))
train_idxs = np.random.choice(a=np.arange(len(y_train)), size=n_train_subset, replace=False)

# Train Dataset
y_true_train, y_pred_train = get_predictions(
    get_dataset(X_train_input_ids[train_idxs], X_train_attention_masks[train_idxs], y_train[train_idxs], False),
    n_train_subset // BATCH_SIZE,
)

# Validation Analysis

In [None]:
def show_validation_report_per_class(y, y_pred):
    """Display a styled per-class sklearn classification report.

    ``labels`` is passed explicitly so every entry of LABELS gets a row even
    when a class occurs in neither ``y`` nor ``y_pred``; without it sklearn
    raises if the number of observed classes differs from len(target_names).

    Side effect: sets the pandas option 'display.precision' to 3.
    """
    report = classification_report(
        y,
        y_pred,
        labels=list(range(N_LABELS)),
        target_names=LABELS,
        digits=3,
        output_dict=True,
    )
    report_df = pd.DataFrame.from_dict(report).T
    report_df['support'] = report_df['support'].astype(int)

    pd.set_option('display.precision', 3)

    # Plain Styler; set_table_styles modifies it in place
    styler = report_df.style

    # Cell styling
    cells = {
        'selector': 'td',
        'props': 'font-size: 14pt',
    }

    # Header styling
    headers = {
        'selector': 'th',
        'props': 'font-size: 18pt'
    }

    styler.set_table_styles([cells, headers])

    display(styler)

# Validation Report

In [None]:
show_validation_report_per_class(y=y_true_val, y_pred=y_pred_val)

# Train Report

In [None]:
show_validation_report_per_class(y=y_true_train, y_pred=y_pred_train)

# Confusion Matrix

In [None]:
def plot_confusion_matrix(y, y_pred, title_prefix):
    """Plot the row-normalized confusion matrix over all N_LABELS classes.

    Each row is divided by the number of true samples of that class, so the
    diagonal holds per-class recall. Classes with no true samples keep an
    all-zero row instead of NaNs.
    """
    _, ax = plt.subplots(1, 1, figsize=(20, 20))
    cfn_matrix = confusion_matrix(y, y_pred, labels=range(N_LABELS))

    # Row normalization guarded against 0 / 0 for classes absent from y
    row_sums = cfn_matrix.sum(axis=1, keepdims=True)
    cfn_matrix = np.divide(
        cfn_matrix,
        row_sums,
        out=np.zeros(cfn_matrix.shape, dtype=float),
        where=row_sums > 0,
    )

    df_cm = pd.DataFrame(cfn_matrix, index=np.arange(N_LABELS), columns=np.arange(N_LABELS))
    sns.heatmap(df_cm, cmap='Blues', annot=True, fmt='.3f', linewidths=.5, annot_kws={'size': 16}, ax=ax)
    plt.title(f'{title_prefix.upper()} CONFUSION MATRIX', size=32, pad=25)
    # Tick labels centered on the heatmap cells
    plt.xticks(np.arange(N_LABELS) + 0.50, LABELS, fontsize=18, rotation=30)
    plt.yticks(np.arange(N_LABELS) + 0.50, LABELS, fontsize=18, rotation=0)
    plt.xlabel('PREDICTED', fontsize=24, labelpad=10)
    plt.ylabel('ACTUAL', fontsize=24, labelpad=10)
    plt.show()

In [None]:
# Confusion matrix of the validation predictions
plot_confusion_matrix(y=y_true_val, y_pred=y_pred_val, title_prefix='validation')

In [None]:
# Confusion matrix of the predictions on the training subset
plot_confusion_matrix(y=y_true_train, y_pred=y_pred_train, title_prefix='train')