In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import random

import pandas as pd

In [114]:
class LocalTransformer(nn.Transformer):
    """``nn.Transformer`` whose causal mask can be limited to a local window.

    Construction is inherited unchanged from ``nn.Transformer``; only the
    mask helper is overridden.
    """

    def generate_square_subsequent_mask(
        self,
        sz: int,
        local_window_size: "int | None" = None,
        *,
        device=None,
        dtype=None,
    ) -> torch.Tensor:
        """Build an additive causal attention mask restricted to a local window.

        Position ``i`` may attend to positions ``j`` with
        ``i - local_window_size + 1 <= j <= i`` (itself plus the
        ``local_window_size - 1`` preceding positions).

        Args:
            sz: Side length of the square mask.
            local_window_size: Number of positions (including the current
                one) each query may attend to. ``None`` (default) means no
                window limit, i.e. a plain causal mask.
            device: Optional device on which to allocate the mask.
            dtype: Optional floating dtype of the returned mask.

        Returns:
            A ``(sz, sz)`` tensor holding ``0.0`` where attention is allowed
            and ``-inf`` where it is blocked.

        Raises:
            ValueError: If ``local_window_size`` is given and is < 1. Such a
                window would block every position in every row, and the
                attention softmax over an all ``-inf`` row yields NaN.
        """
        if local_window_size is None:
            # A window as wide as the sequence is equivalent to plain causal.
            local_window_size = sz
        elif local_window_size < 1:
            raise ValueError(
                f"local_window_size must be >= 1, got {local_window_size}"
            )

        # Band of ones: lower triangle (no look-ahead) intersected with the
        # upper triangle starting `local_window_size - 1` below the diagonal.
        band = torch.ones(sz, sz, device=device)
        band = torch.tril(band, diagonal=0)
        band = torch.triu(band, diagonal=1 - local_window_size)

        mask = torch.zeros(sz, sz, device=device, dtype=dtype)
        return mask.masked_fill(band == 0, float('-inf'))

In [119]:
import torch
import torch.nn as nn
import torch.optim as optim

# Dummy data / hyper-parameters for a smoke test of the local-attention mask.
seed = 0
sequence_length = 15
batch_size = 32
d_model = 64
nhead = 4
num_layers = 2
local_window_size = 9

# Seed before any random tensor or parameter is created so the printed
# losses are reproducible across kernel restarts.
torch.manual_seed(seed)

# Shapes are (sequence, batch, feature), nn.Transformer's default layout
# (batch_first is not set below).
src = torch.randn(sequence_length, batch_size, d_model)
tgt = torch.randn(sequence_length, batch_size, d_model)

model = LocalTransformer(
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=2048,
)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 50

# The local attention mask depends only on the sequence length and window
# size, so build it once instead of on every epoch. src and tgt have the same
# length here, so a single mask serves both.
local_mask = model.generate_square_subsequent_mask(
    sequence_length, local_window_size
).to(src.device)
src_mask = local_mask
tgt_mask = local_mask

# NOTE(review): `tgt` is used both as decoder input and as the regression
# target, and the mask lets each position attend to itself, so the model can
# lower the loss by copying its input. That is fine for checking that the
# mask plumbs through, but real training should feed a shifted decoder input.
for epoch in range(epochs):
    optimizer.zero_grad()

    # Forward pass
    output = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

    # Compute the loss
    loss = criterion(output, tgt)
    print(f"Epoch {epoch + 1}: Loss = {loss.item()}")

    # Backpropagation
    loss.backward()
    optimizer.step()

Epoch 1: Loss = 0.6414380669593811
Epoch 2: Loss = 0.7702780365943909
Epoch 3: Loss = 0.527873158454895
Epoch 4: Loss = 0.4409901201725006
Epoch 5: Loss = 0.38551682233810425
Epoch 6: Loss = 0.3511482775211334
Epoch 7: Loss = 0.32219016551971436
Epoch 8: Loss = 0.3068331182003021
Epoch 9: Loss = 0.2822735607624054
Epoch 10: Loss = 0.26417332887649536
Epoch 11: Loss = 0.24902163445949554
Epoch 12: Loss = 0.23999746143817902
Epoch 13: Loss = 0.2273954302072525
Epoch 14: Loss = 0.21501025557518005
Epoch 15: Loss = 0.2066449522972107
Epoch 16: Loss = 0.1990288645029068
Epoch 17: Loss = 0.19472740590572357
Epoch 18: Loss = 0.18850766122341156
Epoch 19: Loss = 0.1785399317741394
Epoch 20: Loss = 0.17196793854236603
Epoch 21: Loss = 0.1654997020959854
Epoch 22: Loss = 0.1580863744020462
Epoch 23: Loss = 0.15446658432483673
Epoch 24: Loss = 0.15002106130123138
Epoch 25: Loss = 0.14671757817268372
Epoch 26: Loss = 0.1399780809879303
Epoch 27: Loss = 0.13580772280693054
Epoch 28: Loss = 0.131028