# GPT2 as text classifier

Helpful source: https://drlee.io/fine-tuning-gpt-2-for-sentiment-analysis-94ebdd7b5b24

## Install Required Libraries


In [1]:
!pip install datasets # unified interface for accessing and working with various datasets (by hugging face)
!pip install -U accelerate # library to optimize and accelerate numerical computations
!pip install -U transformers # library by hugging face that gives easy access to pre-trained models, tokenizers, and tools for fine-tuning models

Collecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.21.2 (from datasets)
  Downloading huggingface_hub-0.23.0-py3-none-

## Loading and Processing the Dataset

We load the dataset from Hugging Face. Each sample consists of one strings feature that stores the title as well as the (start of the) article-text. The label is the category that the article belongs to (world, sports, business, sci/tech). [Link](https://huggingface.co/datasets/ag_news/viewer/default/train) to explore the structure of the data.

In [2]:
from datasets import load_dataset

dataset = load_dataset('ag_news')

print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


Reduce the size of the dataset (to reduce training times) whilst ensuring that the original structure and distribution of the data is kept.

In [3]:
from datasets import Dataset, DatasetDict
import pandas as pd

def take_a_percentage_of_data(dataset, percentage=0.1, shuffle=True, random_state=None):
    # sort and group the dataset by label
    df = pd.DataFrame(dataset)
    df_sorted = df.sort_values(by='label')
    grouped_dfs = df_sorted.groupby('label')

    # ensure that proportions of the groups remains the same as in the original dataset
    filtered_dfs_per_group = []
    for label, group in grouped_dfs:
        num_samples_to_keep = int(len(group) * percentage)
        filtered_group = group.head(num_samples_to_keep)
        filtered_dfs_per_group.append(filtered_group)

    # concatenate (and shuffle) the filtered group-wise dataframes
    filtered_df = pd.concat(filtered_dfs_per_group)
    if shuffle:
        filtered_df = filtered_df.sample(frac=1, random_state=random_state)

    filtered_df.reset_index(drop=True, inplace=True) # resets the index of the DataFrame, drops the previous index column
    filtered_df_as_dict = filtered_df.to_dict(orient='list')
    filtered_dataset = Dataset.from_dict(filtered_df_as_dict)
    return filtered_dataset

dataset_train_1percent = take_a_percentage_of_data(dataset['train'], percentage=0.01)
dataset_test_1percent = take_a_percentage_of_data(dataset['test'], percentage=0.01)

dataset_1percent = DatasetDict({
    'train': dataset_train_1percent,
    'test': dataset_test_1percent
}) # combine the shortened datasets back into the old structure.

print(dataset_1percent)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 76
    })
})


## Tokenizing the dataset

Tokenize the dataset in the exact same way as the GPT-2 model.

In [4]:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token # padding tokens added to sequences will be represented by an end-of-sequence token
def tokenize_function(examples):
    #return tokenizer(examples["text"], padding="max_length", truncation=True)
    return tokenizer(examples["text"], padding="max_length")

tokenized_dataset = dataset_1percent.map(tokenize_function, batched=True) # performed in batches to increase performance

print(tokenized_dataset) # tokenization adds two features: 'input_ids' (the tokenized representation of 'text') as well as 'attention_mask', which ensures that the model does not attend to padding tokens added during tokenization



tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

Map:   0%|          | 0/76 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 76
    })
})


## Loading the Pre-trained GPT2-Model

Load the pre-trained GPT2-Model for sequence classification.

In [5]:
from transformers import GPT2ForSequenceClassification

gpt2_model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=4) # our gpt2-model should distinguish between 4 labels, adds a final fully connected layers with 4 output neurons.

print(gpt2_model)



model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=4, bias=False)
)


## Implementing the Gaussian Adaptive Attention Block

Download the package provided by the paper.

In [6]:
!pip3 install gaussian-adaptive-attention

Collecting gaussian-adaptive-attention
  Downloading gaussian_adaptive_attention-0.1.5-py3-none-any.whl (8.7 kB)
Installing collected packages: gaussian-adaptive-attention
Successfully installed gaussian-adaptive-attention-0.1.5


**Approach 1**: Import the GaussianBlock directly from the package and replace the current attention mechanism in the GPT2-model with it.

In [7]:
import importlib
gaussian_adaptive_attention = importlib.import_module("gaussian_adaptive_attention")
GaussianBlock = getattr(gaussian_adaptive_attention, "GaussianBlock")
import torch
import copy

gaussian_model = copy.deepcopy(gpt2_model)

num_layers = 1

for block in gaussian_model.transformer.h: # accessing each transformer blocks within the GPT-2 model
  block.attn = GaussianBlock(
    num_layers=num_layers,  # Number of layers in the GPT-2 model
    norm_axes=[1] * num_layers,  # Assuming layer normalization along the sequence length dimension
    num_heads=[5] * num_layers,  # Number of attention heads
    num_gaussians=[4] * num_layers  # Number of Gaussian components
  )

print(gaussian_model)

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GaussianBlock(
          (layers): ModuleList(
            (0): MultiHeadGaussianAdaptiveAttention(
              (attention_heads): ModuleList(
                (0-4): 5 x GaussianAdaptiveAttention()
              )
            )
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=4, bias=False)
)


However, this model doesn't work, because keyword-arguments are passed to the attention mechanism that the GaussianBlock cannot handle (e.g. the argument `last_layer`), which is why the GaussianBlock returns an error.

**Approach 2**: Create a wrapper class of MultiHeadGaussianAdaptive Attention that is then inserted into the GPT2-architecture layer by layer.

The `forward`-method of the wrapper class accepts additional keyword arguments, so the error is not thrown. However, the values of the keyword arguments are currently not actively used in the function, since we saw no meaningful way to integrate them into the Gaussian attention mechanism.

In [8]:
import importlib
gaussian_adaptive_attention = importlib.import_module("gaussian_adaptive_attention")
MultiHeadGaussianAdaptiveAttention = getattr(gaussian_adaptive_attention, "MultiHeadGaussianAdaptiveAttention")
import torch
import copy

gaussian_model = copy.deepcopy(gpt2_model)

# class MultiHeadGaussianAdaptiveAttentionWrapper(torch.nn.Module):
    # def __init__(self, config, num_gaussians=10, norm_axis=1):
      #  super().__init__()
      #  self.attention = MultiHeadGaussianAdaptiveAttention(
      #      norm_axis=norm_axis, # same as example of the researchers (dario had = 2)
      #      num_heads=config.n_head,
      #      num_gaussians=num_gaussians,
      #      padding_value=config.eos_token_id,
      #     eps=config.layer_norm_epsilon
      #  )

    # def forward(self, hidden_states, **kwargs):
        # Pass arguments using **kwargs to the underlying attention mechanism
        # attention_output = self.attention(hidden_states)
        # return (hidden_states,) + tuple(attention_output)  # Ensure the return value is a tuple

# multihead_gaussian_attention = MultiHeadGaussianAdaptiveAttentionWrapper(config=gaussian_model.config, num_gaussians=5, norm_axis=1) # and replacing the attention module with the Gaussian attention block.
mhga = MultiHeadGaussianAdaptiveAttention(norm_axis=1, num_heads=gaussian_model.config.n_head, num_gaussians=5, padding_value=gaussian_model.config.eos_token_id, eps=gaussian_model.config.layer_norm_epsilon)


In [9]:
class MultiHeadCombinedAttention(torch.nn.Module):
    def __init__(self, config, gaussian_attention, original_attention):
        super().__init__()
        self.gaussian_attention = gaussian_attention
        self.original_attention = original_attention

    def forward(self, hidden_states, **kwargs):
        # Pass hidden_states through Gaussian attention first
        gaussian_output = self.gaussian_attention(hidden_states)
        # Pass the output of Gaussian attention through the original attention
        combined_output = self.original_attention(gaussian_output)
        return combined_output

In [10]:

# Replace the attention mechanism in each transformer block
for block in gaussian_model.transformer.h:
    # Save the original attention mechanism
    original_attention = block.attn
    # Replace it with the combined attention mechanism
    block.attn = MultiHeadCombinedAttention(gaussian_model.config, mhga, original_attention)

In [11]:
print(gaussian_model)
print(gpt2_model)

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadCombinedAttention(
          (gaussian_attention): MultiHeadGaussianAdaptiveAttention(
            (attention_heads): ModuleList(
              (0-11): 12 x GaussianAdaptiveAttention()
            )
          )
          (original_attention): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplac

## Freezing Parameters

For the first training epochs, we decided to freeze the pre-trained layers. We did this, because we want the attention weights to first learn some general things before we fine-tune the model as a whole.

In [12]:
# for param in gaussian_model.transformer.parameters():
#     param.requires_grad = False

# for param in gaussian_model.score.parameters():
#     param.requires_grad = True

# for block in gaussian_model.transformer.h:
#   for param in block.attn.parameters():
#       param.requires_grad = True


In [13]:
# trainable_params = sum(p.numel() for p in gaussian_model.parameters() if p.requires_grad)
# frozen_params = sum(p.numel() for p in gaussian_model.parameters() if not p.requires_grad)

# print(f"Trainable Parameters: {trainable_params}")
# print(f"Frozen Parameters: {frozen_params}")

## Check: Do both models run?

1. The Normal GPT2-Model
2. The GPT2-Model with GAAM

In [14]:
input_ids = torch.randint(0, gpt2_model.config.vocab_size, (1, 512))
labels = torch.tensor([1]).unsqueeze(0)

outputs = gpt2_model(input_ids=input_ids, labels=labels)
loss, logits = outputs['loss'], outputs['logits']
print(f"GPT2 with regular attention mechanism: Loss = {round(loss.item(), 4)}, logits = {logits.detach()}")

outputs = gaussian_model(input_ids=input_ids, labels=labels)
loss, logits = outputs['loss'], outputs['logits']
print(f"GPT2 with Gaussian attention mechanism: Loss = {round(loss.item(), 4)}, logits = {logits.detach()}")


GPT2 with regular attention mechanism: Loss = 1.2133, logits = tensor([[-1.5094, -1.4943, -1.9139, -1.8199]])
GPT2 with Gaussian attention mechanism: Loss = 1.1441, logits = tensor([[-1.1289, -0.6210, -0.8668, -0.9010]])


## Training

Training the GPT2-Model with Gaussian attention

In [28]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=1e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    # weight_decay=0.01,
    use_cpu = False,
    no_cuda = False,
    logging_steps = 100,
    save_strategy = "no"
)

trainer = Trainer(
    model=gaussian_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)


In [30]:
model_name = "gpt2-ag_news-1percent-frozen"

trainer.train()
trainer.save_model(model_name)
# Save the tokenizer used by the model as well
# tokenizer.save_pretrained(model_name)


Step,Training Loss
100,1.3785
200,1.3893
300,1.4005
400,1.3991
500,1.3906
600,1.4009
700,1.385
800,1.4027
900,1.3922
1000,1.3941


RuntimeError: The weights trying to be saved contained shared tensors [{'transformer.h.10.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.0.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.0.mean_offsets'}, {'transformer.h.6.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.0.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.0.c'}, {'transformer.h.6.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.1.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.1.mean_offsets'}, {'transformer.h.3.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.1.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.1.c'}, {'transformer.h.7.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.2.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.2.mean_offsets'}, {'transformer.h.6.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.2.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.2.c'}, {'transformer.h.3.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.3.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.3.mean_offsets'}, {'transformer.h.0.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.3.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.3.c'}, {'transformer.h.0.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.4.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.4.mean_offsets'}, {'transformer.h.5.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.4.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.4.c'}, {'transformer.h.2.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.5.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.5.mean_offsets'}, {'transformer.h.11.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.5.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.5.c'}, {'transformer.h.2.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.6.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.6.mean_offsets'}, {'transformer.h.8.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.6.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.6.c'}, {'transformer.h.7.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.7.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.7.mean_offsets'}, {'transformer.h.3.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.7.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.7.c'}, {'transformer.h.11.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.8.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.8.mean_offsets'}, {'transformer.h.10.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.8.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.8.c'}, {'transformer.h.5.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.9.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.9.mean_offsets'}, {'transformer.h.11.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.9.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.9.c'}, {'transformer.h.9.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.1.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.10.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.10.mean_offsets'}, {'transformer.h.9.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.2.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.10.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.10.c'}, {'transformer.h.1.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.5.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.2.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.9.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.8.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.3.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.11.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.0.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.4.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.6.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.7.attn.gaussian_attention.attention_heads.11.mean_offsets', 'transformer.h.10.attn.gaussian_attention.attention_heads.11.mean_offsets'}, {'transformer.h.2.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.8.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.0.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.7.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.3.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.10.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.6.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.1.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.9.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.5.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.11.attn.gaussian_attention.attention_heads.11.c', 'transformer.h.4.attn.gaussian_attention.attention_heads.11.c'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.

In [31]:
trainer.evaluate()

{'eval_loss': 1.3863117694854736,
 'eval_runtime': 13.6047,
 'eval_samples_per_second': 5.586,
 'eval_steps_per_second': 5.586,
 'epoch': 5.0}

## Calculating the accuracy

In [32]:
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gaussian_model.to(device)

def predict_label(sentence, model):
    inputs = tokenizer(sentence, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        prediction = outputs.logits.argmax(-1).item()

    return prediction

def calculate_accuracy(model, data):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for item in data:
            prediction = predict_label(item['text'], model)
            label = item['label']
            total += 1
            if label == prediction:
              correct += 1

    return correct / total

data = dataset_1percent['test']

# Calculate accuracy
accuracy = calculate_accuracy(gaussian_model, data)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.2500


## Second Training Cycle with Unfrozen Layers

Loading the previously trained model and unfreezing the layers.

In [None]:

unfrozen_gaussian_model = copy.deepcopy(gaussian_model)
print(unfrozen_gaussian_model)

# unfreeze layers
for param in unfrozen_gaussian_model.transformer.parameters():
    param.requires_grad = True

# check
trainable_params = sum(p.numel() for p in unfrozen_gaussian_model.parameters() if p.requires_grad)
print(f"Trainable Parameters: {trainable_params}")

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadGaussianAdaptiveAttentionWrapper(
          (attention): MultiHeadGaussianAdaptiveAttention(
            (attention_heads): ModuleList(
              (0-11): 12 x GaussianAdaptiveAttention()
            )
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=4, bias=False)
)
Trainable Parameters: 96095904


Defining the new parameters for the second training cycle (smaller learning rate, including weight decay)

In [None]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    weight_decay=0.01,
    use_cpu = False,
    no_cuda = False
)

trainer = Trainer(
    model=unfrozen_gaussian_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)


Starting the second training cycle.

In [None]:
trainer.train()

Step,Training Loss
500,1.405
1000,1.3995
1500,1.3957
2000,1.3912
2500,1.3964
3000,1.3896
3500,1.39
4000,1.39
4500,1.3934
5000,1.3879


TrainOutput(global_step=6000, training_loss=1.3929286905924478, metrics={'train_runtime': 2274.1474, 'train_samples_per_second': 2.638, 'train_steps_per_second': 2.638, 'total_flos': 2090634706944000.0, 'train_loss': 1.3929286905924478, 'epoch': 5.0})

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

unfrozen_gaussian_model.to(device)

# Test the function with an example from the dataset
text = dataset['test'][1]['text']
label = dataset['test'][1]['label']
print("Predicted label:", predict_label(text))
print("Expected label:", label)


Predicted label: 2
Expected label: 3


In [None]:
from torch.utils.data import DataLoader

data = dataset_1percent['test']

# Calculate accuracy
accuracy = calculate_accuracy(unfrozen_gaussian_model, data)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.2500


## Alternative: Training a GPT2-Model from Scratch

In [None]:
from transformers import GPT2ForSequenceClassification, GPT2Config

# Define GPT-2 configuration
config = GPT2Config(
    vocab_size=50257,  # Number of tokens in the vocabulary
    n_embd=768,        # Dimensionality of the embeddings and hidden states
    n_layer=12,        # Number of transformer layers
    n_head=12,         # Number of attention heads
    num_labels=4       # Number of labels for sequence classification
)

gaussian_gpt2_model_untrained = GPT2ForSequenceClassification(config)

print(gaussian_gpt2_model_untrained)

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=4, bias=False)
)


Replace the attention layer with a Gaussian attention layer.

In [None]:
# Replace the attention mechanism in each transformer block
for block in gaussian_gpt2_model_untrained.transformer.h: # accessing each transformer blocks within the GPT-2 model
    block.attn = MultiHeadGaussianAdaptiveAttentionWrapper(config=gaussian_model.config, num_gaussians=5, norm_axis=1) # and replacing the attention module with the Gaussian attention block.

print(gaussian_gpt2_model_untrained)

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadGaussianAdaptiveAttentionWrapper(
          (attention): MultiHeadGaussianAdaptiveAttention(
            (attention_heads): ModuleList(
              (0-11): 12 x GaussianAdaptiveAttention()
            )
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=4, bias=False)
)


Train the Gaussian model from scratch.

In [None]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=1e-3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    # weight_decay=0.01,
    use_cpu = False,
    no_cuda = False
)

trainer = Trainer(
    model=gaussian_gpt2_model_untrained,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

trainer.train()

Step,Training Loss
500,1.8659
1000,1.5173
1500,1.45
2000,1.4374
2500,1.4212
3000,1.4035
3500,1.4144
4000,1.4009
4500,1.4043
5000,1.3955


TrainOutput(global_step=6000, training_loss=1.4576284484863282, metrics={'train_runtime': 2184.9509, 'train_samples_per_second': 2.746, 'train_steps_per_second': 2.746, 'total_flos': 2090634706944000.0, 'train_loss': 1.4576284484863282, 'epoch': 5.0})