In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import random

In [114]:
class LocalTransformer(nn.Transformer):
    def __init__(self, *args, **kwargs):
        super(LocalTransformer, self).__init__(*args, **kwargs)
    
    def generate_square_subsequent_mask(self, sz: int, local_window_size: int) -> torch.Tensor:
        mask = torch.ones(sz, sz)
        mask = torch.tril(mask, diagonal=0)
        mask = torch.triu(mask, diagonal=1 - local_window_size)
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

In [115]:
# Generate a toy dataset
def generate_data(num_samples, seq_len):
    sequences = []
    for _ in range(num_samples):
        input_seq = [random.randint(0, 9) for _ in range(seq_len)]
        output_seq = [random.randint(0, 9) for _ in range(seq_len)]
        sequences.append((input_seq, output_seq))
    return sequences

# Model
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = LocalTransformer(d_model, nhead, num_layers)
        self.output = nn.Linear(d_model, vocab_size)
    
    def forward(self, src, tgt):
        src = self.embedding(src).permute(1, 0, 2)
        tgt = self.embedding(tgt).permute(1, 0, 2)
        output = self.transformer(src, tgt)
        output = output.permute(1, 0, 2)
        return self.output(output)

# Hyperparameters
num_samples = 1000
seq_len = 10
vocab_size = 10
d_model = 32
nhead = 4
num_layers = 2
epochs = 10
batch_size = 32
lr = 0.001

# Prepare data
dataset = generate_data(num_samples, seq_len)
data_loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

# Initialize model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleTransformer(vocab_size, d_model, nhead, num_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(epochs):
    model.train()
    for batch in data_loader:
        inputs, targets = zip(*batch)
        inputs = torch.tensor(inputs, dtype=torch.long, device=device)
        targets = torch.tensor(targets, dtype=torch.long, device=device)
        
        # Shift the targets for the transformer by 1 position
        tgt_input = torch.cat([torch.zeros(targets.shape[0], 1, dtype=torch.long, device=device), targets[:, :-1]], dim=1)
        
        optimizer.zero_grad()
        outputs = model(inputs, tgt_input)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

Epoch 1/10, Loss: 2.3034982681274414
Epoch 2/10, Loss: 2.315067768096924
Epoch 3/10, Loss: 2.3155112266540527
Epoch 4/10, Loss: 2.329038619995117
Epoch 5/10, Loss: 2.219697952270508
Epoch 6/10, Loss: 2.1946301460266113
Epoch 7/10, Loss: 2.208533763885498
Epoch 8/10, Loss: 2.141535997390747
Epoch 9/10, Loss: 2.1182942390441895
Epoch 10/10, Loss: 2.0858349800109863


tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -in

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [-inf, 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [-inf, -inf, 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, 0., 0., 0., 0., 0., -inf, -inf],
        [-inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., -inf],
        [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0.]])

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]])

1