
**Paper Implementation/Presentation:**

Longformer: The Long-Document Transformer

Theodore Zitouni


In [1]:
!pip install transformers
!pip install datasets


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m46.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1
Looking in indexes: https://pypi.org/simple, https://

In [10]:
import torch
import math
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pandas as pd
from tqdm.auto import tqdm
import transformers
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, TensorDataset
from transformers import LongformerTokenizer, LongformerConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, f1_score

In [4]:
# data

imdb_dataset = load_dataset("imdb")

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
config = LongformerConfig.from_pretrained('allenai/longformer-base-4096')
def prepare_imdb_data(dataset, tokenizer, max_length):
    texts = dataset["text"]
    labels = dataset["label"]

    input_ids, attention_masks = [], []
    for text in texts:
        inputs = tokenizer.encode_plus(text, truncation=True, max_length=max_length, padding='max_length', return_tensors='pt')
        input_ids.append(inputs['input_ids'][0])
        attention_masks.append(inputs['attention_mask'][0])

    return TensorDataset(torch.stack(input_ids), torch.stack(attention_masks), torch.tensor(labels))

train_dataset = prepare_imdb_data(imdb_dataset["train"], tokenizer, max_length=512)
test_dataset = prepare_imdb_data(imdb_dataset["test"], tokenizer, max_length=512)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0...


Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

In [5]:
class LongformerModel(nn.Module):
    def __init__(self, config, num_classes):
        super(LongformerModel, self).__init__()

        # small model
        self.nbr_attention_heads = 8
        self.hidden_size = config.hidden_size  # hidden size from config  
        self.nbr_layers = 12 # not used for practical reasons
        self.global_attention_indices = [0] # global attention on the CLS token
        self.head_size = int(self.hidden_size / self.nbr_attention_heads)
        self.embeddings_size = self.hidden_size

        self.query_local = nn.Linear(self.hidden_size, self.embeddings_size)
        self.key_local = nn.Linear(self.hidden_size, self.embeddings_size)
        self.value_local = nn.Linear(self.hidden_size, self.embeddings_size)

        self.query_global = nn.Linear(self.hidden_size, self.embeddings_size)
        self.key_global = nn.Linear(self.hidden_size, self.embeddings_size)
        self.value_global = nn.Linear(self.hidden_size, self.embeddings_size)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.classifier = nn.Linear(self.hidden_size, num_classes) # imdb classification

        self.dropout = 0.2
        self.softmax = nn.Softmax(dim=-1)

        self.attention_window = 230 # same window size for all layers for simplicity
        self.global_attention_indices = None

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embeddings(input_ids)
        batch_size, sequence_length, embeddings_size = hidden_states.size()

        # local attention: färdig, preliminärt
        q_local = self.query_local(hidden_states)
        k_local = self.key_local(hidden_states)
        v_local = self.value_local(hidden_states)
        q_local /= math.sqrt(self.head_size)

        # reshape, split heads
        q_local_layer = q_local.view(sequence_length, batch_size, self.nbr_attention_heads, self.head_size).transpose(0, 1)
        k_local_layer = k_local.view(sequence_length, batch_size, self.nbr_attention_heads, self.head_size).transpose(0, 1)
        v_local_layer = v_local.view(sequence_length, batch_size, self.nbr_attention_heads, self.head_size).transpose(0, 1)

        local_attention_scores = torch.matmul(q_local_layer, k_local_layer.transpose(-1, -2))
        local_attention_scores /= math.sqrt(self.head_size)
        local_attention_probs = self.softmax(local_attention_scores)

        local_context_layer = torch.matmul(local_attention_probs, v_local_layer)
        local_context_layer = local_context_layer.permute(0, 2, 1, 3).contiguous()
        local_context_layer = local_context_layer.view(batch_size, sequence_length, self.embeddings_size)

        # global attention: färdig prelminärt
        if self.global_attention_indices is not None:
            q_global = self.query_global(hidden_states) 
            k_global = self.key_global(hidden_states)
            v_global = self.value_global(hidden_states)
            q_global /= math.sqrt(self.head_size)

            q_global_layer = q_global.view(batch_size * self.nbr_attention_heads, sequence_length, self.head_size).transpose(0, 1)
            k_global_layer = k_global.view(batch_size * self.nbr_attention_heads, sequence_length, self.head_size).transpose(0, 1)
            v_global_layer = v_global.view(batch_size * self.nbr_attention_heads, sequence_length, self.head_size).transpose(0, 1)

            q_global_layer = q_global_layer[:, self.global_attention_indices, :, :]
            k_global_layer = k_global_layer[:, self.global_attention_indices, :, :]
            v_global_layer = v_global_layer[:, self.global_attention_indices, :, :]

            global_attention_scores = torch.matmul(q_global_layer, k_global_layer.transpose(1, 2))
            global_attention_scores /= math.sqrt(self.head_size)

            global_attention_probs = self.softmax(global_attention_scores)

            global_context_layer = torch.matmul(global_attention_probs, v_global_layer)
            global_context_layer = global_context_layer.permute(0, 2, 1, 3).contiguous()
            global_context_layer = global_context_layer.view(batch_size, sequence_length, self.embeddings_size)
        else:
            # if no attention indices, purely local attention
            global_context_layer = torch.zeros_like(local_context_layer) 
        # local + global attention
        outputs = local_context_layer + global_context_layer
        # classification
        cls_output = outputs[:, 0]
        outputs = self.classifier(cls_output)

        return outputs

In [8]:
# create a LongformerModel
model = LongformerModel(config, num_classes=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters:", num_params)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

num_epochs = 2 # paper states 15, takes ages
tot_training_steps = num_epochs * len(train_dataloader)
nbr_warmup_steps = int(0.1 * tot_training_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=nbr_warmup_steps, num_training_steps=tot_training_steps)


Device: cpu
Number of trainable parameters: 42148610


In [11]:
# training



for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    true_labels = []
    pred_labels = []
    
    for batch in train_dataloader:
        optimizer.zero_grad()
        
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        true_labels.extend(labels.cpu().numpy())  
        pred_labels.extend(predicted.cpu().numpy()) 
        
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_dataloader)}, Train Accuracy: {100 * correct / total}")
    
    # evaluation
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, labels = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    # kanske lägga till någon score
    f1 = f1_score(true_labels, pred_labels, average='binary')
    print(f"Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss/len(test_dataloader)}, Test Accuracy: {100 * correct / total}, F1 Score: {f1}")

Epoch 1/2, Train Loss: 0.6973827549106325, Train Accuracy: 49.292
Epoch 1/2, Test Loss: 0.6979429745460715, Test Accuracy: 49.508, F1 Score: 0.4874044721200113
Epoch 2/2, Train Loss: 0.6963517790865106, Train Accuracy: 49.768
Epoch 2/2, Test Loss: 0.6949192209316947, Test Accuracy: 50.328, F1 Score: 0.4749561000083619
