In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].detach()


In [3]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, hidden_size, num_layers, max_len, dropout=0.5):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_len)
        self.dropout = nn.Dropout(dropout)
        
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_size)
            for _ in range(num_layers)
        ])
        
    def forward(self, x):
        embedded = self.embedding(x)  # (seq_length, batch_size, embed_size)
        embedded = self.positional_encoding(embedded)
        embedded = self.dropout(embedded)
        
        # Permute to (seq_length, batch_size, embed_size) for transformer
        embedded = embedded.permute(1, 0, 2)
        
        # Forward through transformer layers
        for layer in self.transformer_layers:
            embedded = layer(embedded)
        
        # Return final transformer output (batch_size, seq_length, embed_size)
        return embedded.permute(1, 0, 2)


In [4]:
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, hidden_size, num_layers, max_len, num_classes, dropout=0.5):
        super(TransformerClassifier, self).__init__()
        self.encoder = TransformerEncoder(vocab_size, embed_size, num_heads, hidden_size, num_layers, max_len, dropout)
        self.fc = nn.Linear(embed_size, num_classes)
    
    def forward(self, x):
        transformer_output = self.encoder(x)  # (batch_size, seq_length, embed_size)
        # Use global max pooling along the sequence length dimension
        pooled_output, _ = torch.max(transformer_output, dim=1)  # (batch_size, embed_size)
        logits = self.fc(pooled_output)  # (batch_size, num_classes)
        return logits


In [5]:
# Example usage for text classification

# Define model hyperparameters
vocab_size = 10000
embed_size = 128
num_heads = 4
hidden_size = 256
num_layers = 3
max_len = 100
num_classes = 2
dropout = 0.1
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# Create dummy data (replace with your dataset)
train_data = torch.randint(0, vocab_size, (1000, max_len))  # (num_samples, max_len)
train_labels = torch.randint(0, num_classes, (1000,))  # (num_samples,)

# Initialize TransformerClassifier model
model = TransformerClassifier(vocab_size, embed_size, num_heads, hidden_size, num_layers, max_len, num_classes, dropout)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    model.train()
    for i in range(0, len(train_data), batch_size):
        batch_data = train_data[i:i+batch_size]
        batch_labels = train_labels[i:i+batch_size]
        
        # Forward pass
        optimizer.zero_grad()
        logits = model(batch_data)
        
        # Calculate loss
        loss = criterion(logits, batch_labels)
        epoch_loss += loss.item()
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    # Print average loss per epoch
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / (len(train_data) / batch_size):.4f}")


Epoch 1, Loss: 0.8243
Epoch 2, Loss: 0.6860
Epoch 3, Loss: 0.6278
Epoch 4, Loss: 0.5480
Epoch 5, Loss: 0.4642


KeyboardInterrupt: 