In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class TransformerWithProjection(nn.Module):
    """Baseline token model: an embedding lookup followed by a vocabulary projection.

    No attention and no positional information are involved; every token id
    is embedded and mapped straight to logits over the vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output logits.
        embed_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.fc_out = nn.Linear(in_features=embed_size, out_features=vocab_size)

    def forward(self, src):
        """Map integer token ids ``src`` to logits with a trailing ``vocab_size`` axis."""
        return self.fc_out(self.embedding(src))

In [15]:
# Define model parameters
vocab_size = 10000  # Size of the vocabulary
embed_size = 512    # Embedding dimension
batch_size, seq_len = 2, 4  # Shape of the toy batch used below

# Seed the global RNG so the initial weights, the random batch and hence the
# printed loss are repeatable from run to run instead of changing every time
torch.manual_seed(0)

# Initialize the model
model = TransformerWithProjection(vocab_size, embed_size)

# Example input and target: random token ids in [0, vocab_size)
input_seq = torch.randint(0, vocab_size, (batch_size, seq_len))
target = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass
logits = model(input_seq)
# Fail loudly here if the model ever stops returning one logit row per token
assert logits.shape == (batch_size, seq_len, vocab_size)

# Reshape logits and targets for loss calculation: CrossEntropyLoss is fed
# one row of class scores per token and one class index per token
logits_reshaped = logits.view(-1, vocab_size)  # Shape: (batch_size * seq_len, vocab_size)
target_reshaped = target.view(-1)              # Shape: (batch_size * seq_len)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Compute the loss
loss = criterion(logits_reshaped, target_reshaped)
print(f"Loss: {loss.item()}")

Loss: 9.699930191040039


In [16]:
# Step through the model by hand: run only the embedding lookup on the batch.
# Each token id is replaced by its embedding vector, which appends an
# embed_size-wide axis to the (batch_size, seq_len) input.
x1 = model.embedding(input_seq)
x1.shape

torch.Size([2, 4, 512])

In [17]:
# Second manual step: feed the embeddings through the model's output layer.
# Only the last axis changes, from embed_size to vocab_size.
x2 = model.fc_out(x1)
x2.shape

torch.Size([2, 4, 10000])

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a single Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_size: Feature width of the input and of the output.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires it to divide ``embed_size``).
        ff_hidden_size: Hidden width of the position-wise feed-forward net.
        dropout: Dropout probability, used inside the attention module and on
            both residual branches.
    """

    def __init__(self, embed_size, num_heads, ff_hidden_size, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        # batch_first is left at its default (False): inputs are sequence-first,
        # i.e. (seq_len, batch_size, embed_size).
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, dropout=dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_hidden_size),
            nn.ReLU(),
            nn.Linear(ff_hidden_size, embed_size)
        )
        self.layer_norm1 = nn.LayerNorm(embed_size)
        self.layer_norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply the layer to a sequence-first tensor.

        Args:
            x: Input of shape (seq_len, batch_size, embed_size).
            attn_mask: Optional mask handed to ``nn.MultiheadAttention``,
                e.g. a (seq_len, seq_len) causal mask. In a boolean mask,
                ``True`` marks a position that may NOT be attended to.
            key_padding_mask: Optional (batch_size, seq_len) mask handed to
                ``nn.MultiheadAttention``; ``True`` marks keys to ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self Attention. The averaged attention weights were always discarded,
        # so they are no longer requested (need_weights=False).
        attn_output, _ = self.self_attention(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.layer_norm1(x + self.dropout(attn_output))
        # Feed Forward
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + self.dropout(ff_output))
        return x

# Define the Transformer Model
class TransformerWithProjection(nn.Module):
    """Token embedding + learned positions + encoder stack + vocabulary projection.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        embed_size: Embedding / model width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked ``TransformerEncoderLayer`` modules.
        ff_hidden_size: Hidden width of each layer's feed-forward net.
        dropout: Dropout probability passed to every encoder layer.
        max_seq_len: Number of positions in the learned positional table, and
            therefore the longest sequence ``forward`` accepts. Defaults to
            500, the value that used to be hard-coded.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, ff_hidden_size, dropout=0.1, max_seq_len=500):
        super(TransformerWithProjection, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional table, zero-initialised; the leading 1 broadcasts over the batch.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, embed_size))
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, num_heads, ff_hidden_size, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, src):
        """Return per-token logits for a batch of token ids.

        Args:
            src: Integer tensor of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length.
        """
        seq_len = src.size(1)
        max_seq_len = self.positional_encoding.size(1)
        if seq_len > max_seq_len:
            # Without this check the slice below silently stops at max_seq_len
            # and the addition fails with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_seq_len}"
            )
        src_emb = self.embedding(src) + self.positional_encoding[:, :seq_len]
        src_emb = src_emb.permute(1, 0, 2)  # Change shape to (seq_len, batch_size, embed_size)

        for layer in self.encoder_layers:
            src_emb = layer(src_emb)

        src_emb = src_emb.permute(1, 0, 2)  # Change shape back to (batch_size, seq_len, embed_size)
        logits = self.fc_out(src_emb)
        return logits

# Example parameters
vocab_size = 10000  # Size of the vocabulary
embed_size = 512    # Embedding dimension
num_heads = 8       # Number of attention heads
num_layers = 12      # Number of transformer layers
ff_hidden_size = 2048  # Hidden size of the feed-forward layer
batch_size, seq_len = 2, 4  # Shape of the toy batch used below

# Seed the global RNG so the initial weights, the random batch, the dropout
# masks and hence the printed loss are repeatable from run to run
torch.manual_seed(0)

# Initialize the model (a freshly built module is in training mode, so
# dropout is active in the forward pass below)
model = TransformerWithProjection(vocab_size, embed_size, num_heads, num_layers, ff_hidden_size)

# Example input and target: random token ids in [0, vocab_size)
input_seq = torch.randint(0, vocab_size, (batch_size, seq_len))
target = torch.randint(0, vocab_size, (batch_size, seq_len))

# Forward pass
logits = model(input_seq)
# Fail loudly here if the model ever stops returning one logit row per token
assert logits.shape == (batch_size, seq_len, vocab_size)

# Reshape logits and targets for loss calculation: CrossEntropyLoss is fed
# one row of class scores per token and one class index per token
logits_reshaped = logits.view(-1, vocab_size)  # Shape: (batch_size * seq_len, vocab_size)
target_reshaped = target.view(-1)              # Shape: (batch_size * seq_len)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Compute the loss
loss = criterion(logits_reshaped, target_reshaped)
print(f"Loss: {loss.item()}")

Loss: 9.299447059631348


In [28]:
# Recompute the logits for prediction. The `logits` tensor left over from the
# loss cell came from a forward pass in training mode (dropout active), so an
# argmax over it is not a deterministic function of the weights.
was_training = model.training
model.eval()               # disable dropout
with torch.no_grad():      # inference only: no autograd graph needed
    eval_logits = model(input_seq)  # Shape: (batch_size, seq_len, vocab_size)
model.train(was_training)  # leave the model in the mode it was found in

# Extract logits for the last token
last_token_logits = eval_logits[:, -1, :]  # Shape: (batch_size, vocab_size)

# Compute probabilities for the last token
probs = F.softmax(last_token_logits, dim=-1)

# Get the predicted token (highest probability)
predicted_token = torch.argmax(probs, dim=-1)
print(f"Predicted token: {predicted_token}")

Predicted token: tensor([8389, 9676])
