# Task 1

In [1]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of every input is split into ``heads`` slices of
    ``head_dim`` features.  One bias-free ``head_dim -> head_dim`` projection
    per role (values, keys, queries) is applied to every slice, so the
    projection weights are shared by all heads.  ``fc_out`` maps the
    concatenated head outputs back to ``embed_size`` features.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    @staticmethod
    def _broadcastable_mask(mask):
        """Reshape ``mask`` so it broadcasts against (N, heads, query_len, key_len)."""
        if mask.dim() == 2:  # (N, key_len): per-key padding mask
            return mask[:, None, None, :]
        if mask.dim() == 3:  # (N, query_len, key_len): e.g. a causal mask
            return mask[:, None, :, :]
        if mask.dim() == 4:  # already laid out per head (or broadcastable to it)
            return mask
        raise ValueError(
            f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
        )

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to ``keys``/``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); ``key_len`` must
                equal ``value_len``.
            query: Tensor of shape (N, query_len, embed_size).
            mask: ``None`` or a tensor whose zeros mark positions that must not
                be attended to.  Accepted layouts are (N, key_len),
                (N, query_len, key_len), or a 4-D tensor broadcastable to
                (N, heads, query_len, key_len).

        Returns:
            Tensor of shape (N, query_len, embed_size).

        Raises:
            ValueError: If ``mask`` has fewer than 2 or more than 4 dimensions.
        """
        # Batch size
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], keys.shape[1], values.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # (N, heads, query_len, key_len); scale first so the mask fill value
        # below is not rescaled afterwards.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            blocked = self._broadcastable_mask(mask) == 0
            # finfo.min is representable in every floating dtype, unlike a
            # hard-coded -1e20 which overflows in half precision.
            energy = energy.masked_fill(blocked, torch.finfo(energy.dtype).min)

        attention = torch.nn.functional.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    """Post-norm transformer layer: attention, then a position-wise MLP.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalisation.  The MLP widens the features to ``4 * embed_size`` with a
    ReLU in between.  No dropout is applied anywhere in this block.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads handed to ``MultiHeadAttention``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        hidden_size = 4 * embed_size

        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, value, key, query, mask):
        """Run one layer; ``query`` is the tensor carried by the residual path.

        ``value``, ``key``, ``query`` and ``mask`` are forwarded unchanged to
        the attention module.  The result has the same shape as ``query``.
        """
        # Residual around attention, then normalise.
        attended = self.attention(value, key, query, mask)
        normed = self.norm1(attended + query)

        # Residual around the feed-forward network, then normalise.
        return self.norm2(self.feedforward(normed) + normed)


class GPT2(nn.Module):
    """Stack of transformer blocks over token plus learned position embeddings.

    Args:
        embed_size: Feature size used throughout the model.
        heads: Number of attention heads per block.
        num_layers: Number of ``TransformerBlock`` layers.
        vocab_size: Number of token ids; also the size of the output logits.
        device: Stored on ``self.device`` for callers that read it.  It does
            not move any parameters; call ``.to(device)`` on the model for that.
        max_length: Number of rows in the learned positional table, i.e. the
            longest sequence the model accepts.
    """

    def __init__(self, embed_size, heads, num_layers, vocab_size, device, max_length=1000):
        super(GPT2, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.device = device
        self.max_length = max_length

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = nn.Embedding(max_length, embed_size)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embed_size, heads) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask):
        """Return logits of shape (N, seq_length, vocab_size).

        Args:
            x: Integer token ids of shape (N, seq_length).
            mask: Attention mask passed unchanged to every block (may be
                ``None``).

        Raises:
            ValueError: If ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape
        if seq_length > self.max_length:
            # Fail with a clear message instead of an out-of-range lookup
            # inside the positional embedding.
            raise ValueError(
                f"Sequence length {seq_length} exceeds the maximum of {self.max_length}"
            )

        # Build the indices on the input's device so the lookup works wherever
        # the model and its inputs live; (seq_length, embed_size) broadcasts
        # over the batch dimension in the sum below.
        positions = torch.arange(seq_length, device=x.device)
        out = self.embedding(x) + self.positional_encoding(positions)

        # Self-attention: the same tensor serves as value, key and query.
        for transformer in self.transformer_blocks:
            out = transformer(out, out, out, mask)

        out = self.fc_out(out)
        return out

# Test the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The ``device`` constructor argument is only stored on the model, so the
# parameters must be moved explicitly to sit on the same device as the inputs.
model = GPT2(embed_size=256, heads=8, num_layers=4, vocab_size=10000, device=device).to(device)
x = torch.randint(0, 10000, (32, 10))  # Batch size of 32, sequence length of 10
mask = torch.ones((32, 10)).to(device)  # all ones: no position is masked out
output = model(x.to(device), mask)
print(output.shape)  # This should print torch.Size([32, 10, 10000])


torch.Size([32, 10, 10000])


# Task 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RotaryPositionalEmbedding(nn.Module):
    """Sinusoidal position features with a learnable scale, offset and gain.

    For every position the module evaluates ``embed_size // 2`` sine and cosine
    waves and returns ``alpha * (w * wave + b)`` for each of them, sines first
    and cosines second.  The result is a feature vector per position; this
    class does not rotate query/key vectors itself.

    The wave frequencies decay geometrically as ``base ** (-i / half)``.
    Growing powers of two overflow float32 for realistic sizes and produce
    NaNs, which is why decaying frequencies are used.

    Args:
        embed_size: Size of the returned feature vector; must be even so the
            sine and cosine halves together fill it exactly.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``embed_size`` is odd.
    """

    def __init__(self, embed_size, base=10000.0):
        super(RotaryPositionalEmbedding, self).__init__()
        if embed_size % 2 != 0:
            raise ValueError(
                f"embed_size must be even to hold sine and cosine halves, got {embed_size}"
            )
        self.embed_size = embed_size
        self.base = base
        self.w = nn.Parameter(torch.randn(embed_size // 2) * 0.01)
        self.b = nn.Parameter(torch.zeros(embed_size // 2))
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, positions):
        """Encode ``positions`` (any shape) into ``positions.shape + (embed_size,)``."""
        half = self.embed_size // 2
        # Exponents run over [0, 1), so frequencies fall from 1 towards 1 / base
        # and the angles stay finite for any realistic position.
        exponents = torch.arange(half, dtype=torch.float32, device=positions.device) / half
        inv_freq = self.base ** (-exponents)

        angles = positions.unsqueeze(-1).to(inv_freq.dtype) * inv_freq
        sin_component = torch.sin(angles)
        cos_component = torch.cos(angles)

        sin_part = self.alpha * (self.w * sin_component + self.b)
        cos_part = self.alpha * (self.w * cos_component + self.b)

        return torch.cat([sin_part, cos_part], dim=-1)
# RotaryPositionalEmbedding produces a sinusoid-based positional signal with a learnable per-channel scale (w), offset (b) and a global gain (alpha). Note that GPT2 adds this signal to the token embeddings; it does not rotate the query/key vectors the way rotary embeddings (RoPE) are usually applied. It may capture sequential information differently from the standard learned positional table, but its effectiveness depends on the specific task and dataset. It has only embed_size + 1 parameters, so it is a lightweight addition to the model.


class MultiHeadAttention(nn.Module):
    """Multi-head attention with an optional sliding-window variant.

    The last dimension of every input is split into ``heads`` slices of
    ``head_dim`` features.

    * Standard mode: one bias-free ``head_dim -> head_dim`` projection per role
      (values, keys, queries) is shared by all heads and the attention weights
      are a masked, scaled softmax over all keys.
    * Sliding-window mode: the attention weights are delegated to
      ``SlidingWindowAttention``, which receives the *unprojected* head slices
      and the caller's mask.  It must return weights of shape
      (N, heads, query_len, key_len).  No value/key/query projections exist in
      this mode.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
        use_sliding_window_attention: Select the sliding-window variant.
        window_size: Window size handed to ``SlidingWindowAttention``; ignored
            in standard mode.
    """

    def __init__(self, embed_size, heads, use_sliding_window_attention=False, window_size=5):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        self.use_sliding_window_attention = use_sliding_window_attention

        # Same construction-time check as the other attention classes; without
        # it a bad configuration only surfaces as a reshape error in forward.
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        if use_sliding_window_attention:
            self.attention = SlidingWindowAttention(embed_size, heads, window_size=window_size)
        else:
            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    @staticmethod
    def _broadcastable_mask(mask):
        """Reshape ``mask`` so it broadcasts against (N, heads, query_len, key_len)."""
        if mask.dim() == 2:  # (N, key_len): per-key padding mask
            return mask[:, None, None, :]
        if mask.dim() == 3:  # (N, query_len, key_len): e.g. a causal mask
            return mask[:, None, :, :]
        if mask.dim() == 4:  # already laid out per head (or broadcastable to it)
            return mask
        raise ValueError(
            f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
        )

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to ``keys``/``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); ``key_len`` must
                equal ``value_len``.
            query: Tensor of shape (N, query_len, embed_size).
            mask: ``None`` or a tensor whose zeros mark positions that must not
                be attended to.  In standard mode the accepted layouts are
                (N, key_len), (N, query_len, key_len), or a 4-D tensor
                broadcastable to (N, heads, query_len, key_len).  In
                sliding-window mode the mask is passed on untouched.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], keys.shape[1], values.shape[1]

        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        if self.use_sliding_window_attention:
            attention = self.attention(queries, keys, values, mask)
        else:
            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Scale first so the mask fill value below is not rescaled.
            energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
            energy = energy / (self.head_dim ** 0.5)

            if mask is not None:
                blocked = self._broadcastable_mask(mask) == 0
                # finfo.min is representable in every floating dtype, unlike a
                # hard-coded -1e20 which overflows in half precision.
                energy = energy.masked_fill(blocked, torch.finfo(energy.dtype).min)

            attention = F.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out
# MultiHeadAttention now has an option to delegate the computation of the attention weights to SlidingWindowAttention. Sliding-window attention can be useful for handling long sequences more efficiently. This class only calls SlidingWindowAttention, so as long as that class's forward pass is left as a placeholder the option cannot be run and its effectiveness cannot be evaluated.
    
class GroupQueryAttention(nn.Module):
    """Multi-head attention whose key/value heads are shared within groups.

    Heads are partitioned into consecutive groups of ``group_size`` heads (the
    last group is smaller when ``heads`` is not a multiple of ``group_size``).
    Every query head keeps its own slice of the embedding, but all heads of a
    group attend over one shared key/value head, formed by averaging the
    group's key/value slices before projection.  With ``group_size=1`` this is
    ordinary multi-head attention; with ``group_size >= heads`` all heads
    share a single key/value head.

    The parameter shapes match ``MultiHeadAttention``: one bias-free
    ``head_dim -> head_dim`` projection per role, shared across heads, plus
    the output layer ``fc_out``.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of query heads; must divide ``embed_size`` evenly.
        group_size: Number of query heads that share one key/value head.

    Raises:
        ValueError: If ``group_size`` is smaller than 1.
    """

    def __init__(self, embed_size, heads, group_size):
        super(GroupQueryAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.group_size = group_size
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
        if group_size < 1:
            raise ValueError(f"group_size must be at least 1, got {group_size}")

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def _share_within_groups(self, x):
        """Replace every head slice of ``x`` (N, len, heads, head_dim) by its group mean."""
        # ``split`` leaves a shorter final chunk when heads % group_size != 0.
        groups = x.split(self.group_size, dim=2)
        return torch.cat(
            [g.mean(dim=2, keepdim=True).expand_as(g) for g in groups], dim=2
        )

    @staticmethod
    def _broadcastable_mask(mask):
        """Reshape ``mask`` so it broadcasts against (N, heads, query_len, key_len)."""
        if mask.dim() == 2:  # (N, key_len): per-key padding mask
            return mask[:, None, None, :]
        if mask.dim() == 3:  # (N, query_len, key_len): e.g. a causal mask
            return mask[:, None, :, :]
        if mask.dim() == 4:  # already laid out per head (or broadcastable to it)
            return mask
        raise ValueError(
            f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
        )

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to group-shared ``keys``/``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); ``key_len`` must
                equal ``value_len``.
            query: Tensor of shape (N, query_len, embed_size).
            mask: ``None`` or a tensor whose zeros mark positions that must not
                be attended to, shaped (N, key_len), (N, query_len, key_len),
                or 4-D and broadcastable to (N, heads, query_len, key_len).

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        query_len, key_len, value_len = query.shape[1], keys.shape[1], values.shape[1]

        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Keys and values are pooled per group; queries stay per head.
        values = self.values(self._share_within_groups(values))
        keys = self.keys(self._share_within_groups(keys))
        queries = self.queries(queries)

        # Scale first so the mask fill value below is not rescaled.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            blocked = self._broadcastable_mask(mask) == 0
            # finfo.min is representable in every floating dtype, unlike a
            # hard-coded -1e20 which overflows in half precision.
            energy = energy.masked_fill(blocked, torch.finfo(energy.dtype).min)

        attention = F.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

# GroupQueryAttention is intended to implement grouped-query attention, in which groups of query heads (see the group_size argument) share key/value heads instead of each head having its own. The usual motivation is to reduce key/value computation and memory. It has the same number of parameters as the standard MultiHeadAttention, and its effectiveness depends on the task and dataset characteristics.
    
class SlidingWindowAttention(nn.Module):
    """Attention weights restricted to a local window around each query.

    Query position ``i`` may attend to key position ``j`` only when
    ``|i - j| <= window_size // 2``.  An odd ``window_size`` therefore covers
    exactly ``window_size`` positions centred on the query.  The module has no
    parameters and returns attention *weights*; the caller applies them to the
    values.

    Args:
        embed_size: Feature size of the full embedding (stored only).
        heads: Number of attention heads (stored only).
        window_size: Width of the local window; must be at least 1.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, embed_size, heads, window_size):
        super(SlidingWindowAttention, self).__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be at least 1, got {window_size}")
        self.embed_size = embed_size
        self.heads = heads
        self.window_size = window_size
        self.head_dim = embed_size // heads

    def forward(self, queries, keys, values, mask):
        """Compute windowed attention weights.

        Args:
            queries: Tensor of shape (N, query_len, heads, head_dim).
            keys: Tensor of shape (N, key_len, heads, head_dim).
            values: Unused; accepted so the call signature mirrors a full
                attention module.
            mask: ``None`` or a tensor whose zeros mark positions that must not
                be attended to, shaped (N, key_len), (N, query_len, key_len),
                or 4-D and broadcastable to (N, heads, query_len, key_len).

        Returns:
            Tensor of shape (N, heads, query_len, key_len) whose last dimension
            sums to one.  Weights outside the window or at masked positions
            are zero, unless every key of a row is excluded, in which case the
            softmax falls back to uniform weights for that row.

        Raises:
            ValueError: If ``mask`` has fewer than 2 or more than 4 dimensions.
        """
        query_len, key_len = queries.shape[1], keys.shape[1]

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        energy = energy / (self.head_dim ** 0.5)

        # (query_len, key_len) band matrix: True where the key lies in the window.
        query_pos = torch.arange(query_len, device=energy.device).unsqueeze(1)
        key_pos = torch.arange(key_len, device=energy.device).unsqueeze(0)
        allowed = ((query_pos - key_pos).abs() <= self.window_size // 2)[None, None]

        if mask is not None:
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask[:, None, :, :]
            elif mask.dim() != 4:
                raise ValueError(
                    f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
                )
            allowed = allowed & (mask != 0)

        # finfo.min is representable in every floating dtype.
        energy = energy.masked_fill(~allowed, torch.finfo(energy.dtype).min)
        return torch.softmax(energy, dim=3)
    

class TransformerBlock(nn.Module):
    """Post-norm transformer layer with a selectable attention mechanism.

    Exactly one attention module is built per block.  The flags are checked in
    order: ``use_group_query_attention`` wins over
    ``use_sliding_window_attention``, and with neither set the standard
    ``MultiHeadAttention`` is used.  Attention and a ``4 * embed_size`` ReLU
    MLP are each wrapped in a residual connection followed by layer
    normalisation.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads.
        use_group_query_attention: Build ``GroupQueryAttention``.
        use_sliding_window_attention: Build ``MultiHeadAttention`` in its
            sliding-window mode (ignored when group-query attention is on).
        group_size: ``group_size`` passed to ``GroupQueryAttention``.
    """

    def __init__(self, embed_size, heads, use_group_query_attention=False, use_sliding_window_attention=False,
                 *, group_size=4):
        super(TransformerBlock, self).__init__()
        if use_group_query_attention:
            self.attention = GroupQueryAttention(embed_size, heads, group_size=group_size)
        elif use_sliding_window_attention:
            self.attention = MultiHeadAttention(embed_size, heads, use_sliding_window_attention=True)
        else:
            self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, value, key, query, mask):
        """Run one layer; ``query`` is the tensor carried by the residual path.

        ``value``, ``key``, ``query`` and ``mask`` are forwarded unchanged to
        the attention module.  The result has the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)

        # Residual + norm around attention, then around the feed-forward net.
        x = self.norm1(attention + query)
        forward = self.feedforward(x)
        out = self.norm2(forward + x)
        return out

# TransformerBlock now selects one attention mechanism per block: GroupQueryAttention, MultiHeadAttention in sliding-window mode, or standard MultiHeadAttention. The flags are checked in that order, so when both flags are set GroupQueryAttention is used and the sliding-window flag is ignored. This flexibility allows experimentation with different attention mechanisms; note that the choice of mechanism may change the model's capacity and complexity.
    
class GPT2(nn.Module):
    """Transformer stack with selectable positional encoding and attention.

    Args:
        embed_size: Feature size used throughout the model.
        heads: Number of attention heads per block.
        num_layers: Number of ``TransformerBlock`` layers.
        vocab_size: Number of token ids; also the size of the output logits.
        device: Stored on ``self.device`` for callers that read it.  It does
            not move any parameters; call ``.to(device)`` on the model for that.
        use_rotary_positional_embedding: Use ``RotaryPositionalEmbedding``
            instead of a learned positional table.  Its output is added to the
            token embeddings.
        use_group_query_attention: Forwarded to every ``TransformerBlock``.
        use_sliding_window_attention: Forwarded to every ``TransformerBlock``.
        max_length: Number of rows in the learned positional table, i.e. the
            longest sequence accepted when that table is in use.
    """

    def __init__(self, embed_size, heads, num_layers, vocab_size, device,
                 use_rotary_positional_embedding=False, use_group_query_attention=False, use_sliding_window_attention=False,
                 max_length=1000):
        super(GPT2, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.device = device
        self.max_length = max_length
        # Explicit flag instead of probing for the attribute with hasattr().
        self.use_rotary_positional_embedding = use_rotary_positional_embedding

        if use_rotary_positional_embedding:
            self.rotary_positional_encoding = RotaryPositionalEmbedding(embed_size)
        else:
            self.positional_encoding = nn.Embedding(max_length, embed_size)

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embed_size, heads, use_group_query_attention, use_sliding_window_attention) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask):
        """Return logits of shape (N, seq_length, vocab_size).

        Args:
            x: Integer token ids of shape (N, seq_length).
            mask: Attention mask passed unchanged to every block (may be
                ``None``).

        Raises:
            ValueError: If the learned positional table is in use and
                ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape
        # Build the indices on the input's device so the encoding works
        # wherever the model and its inputs live; the (seq_length, embed_size)
        # encoding broadcasts over the batch dimension in the sums below.
        positions = torch.arange(seq_length, device=x.device)

        if self.use_rotary_positional_embedding:
            # Use Rotary positional encoding
            out = self.embedding(x) + self.rotary_positional_encoding(positions)
        else:
            # Use standard positional encoding
            if seq_length > self.max_length:
                # Fail with a clear message instead of an out-of-range lookup
                # inside the positional embedding.
                raise ValueError(
                    f"Sequence length {seq_length} exceeds the maximum of {self.max_length}"
                )
            out = self.embedding(x) + self.positional_encoding(positions)

        # Self-attention: the same tensor serves as value, key and query.
        for transformer in self.transformer_blocks:
            out = transformer(out, out, out, mask)

        out = self.fc_out(out)
        return out

# The GPT2 model now supports various attention mechanisms and positional encodings. The flexibility to choose between rotary positional embeddings, group query attention, and sliding window attention allows for experimentation with different combinations.

# Test the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both attention flags are set, but TransformerBlock checks
# use_group_query_attention first, so every block uses GroupQueryAttention and
# the sliding-window flag has no effect in this run.
# The ``device`` constructor argument is only stored on the model, so the
# parameters must be moved explicitly to sit on the same device as the inputs.
model = GPT2(embed_size=256, heads=8, num_layers=4, vocab_size=10000, device=device,
             use_rotary_positional_embedding=True, use_group_query_attention=True, use_sliding_window_attention=True).to(device)
x = torch.randint(0, 10000, (32, 10))  # Batch size of 32, sequence length of 10
mask = torch.ones((32, 10)).to(device)  # all ones: no position is masked out
output = model(x.to(device), mask)
print(output.shape)  # This should print torch.Size([32, 10, 10000])



torch.Size([32, 10, 10000])


Model Size: 
The model's size is influenced by the added components, but it remains reasonably sized. The total number of parameters depends on the choices made (e.g., rotary embeddings, attention mechanisms).

Potential Pitfalls:
The effectiveness of the new components depends on the task and dataset. It's recommended to conduct thorough experiments to validate improvements.
The sliding window attention mechanism needs to be implemented correctly to ensure its benefits are realized.

Possible Improvements:
Provide a complete implementation of the sliding window attention mechanism for a comprehensive evaluation.
Conduct experiments to compare the performance of different attention mechanisms and positional encodings on your specific task and dataset.

# Task 3

In [3]:
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from fairscale.nn import FullyShardedDataParallel

# Define your GPT2 model, loss function, and optimizer
# (Assuming you have already defined the GPT2 class and other necessary components)

# Dummy dataset for demonstration
class RandomDataset(Dataset):
    """Synthetic dataset of random token sequences for exercising a training loop.

    Every lookup draws a fresh, independent (input, target) pair from the
    global torch random generator; the requested index is not used, so the
    same index yields different samples on repeated access.

    Args:
        num_samples: Value reported by ``len()``.
        seq_length: Number of tokens in each generated sequence.
        vocab_size: Exclusive upper bound of the generated token ids.
    """

    def __init__(self, num_samples, seq_length, vocab_size):
        self.num_samples, self.seq_length, self.vocab_size = (
            num_samples,
            seq_length,
            vocab_size,
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return two 1-D integer tensors of length ``seq_length``: (input, target)."""
        shape = (self.seq_length,)
        # Input is drawn first, then target, from the global RNG.
        tokens_in = torch.randint(low=0, high=self.vocab_size, size=shape)
        tokens_out = torch.randint(low=0, high=self.vocab_size, size=shape)
        return tokens_in, tokens_out

# Hyperparameters
num_samples = 1000
seq_length = 10
vocab_size = 10000
batch_size = 64
embed_size = 256
heads = 8
num_layers = 4
num_epochs = 10

# DDP and FSDP are alternative wrappers; a model must be wrapped in only one.
use_fsdp = False

# Multi-process training needs one process per GPU plus the rendezvous
# variables exported by a launcher such as torchrun.  A plain notebook or
# ``python`` run stays single-process and uses ``device`` as defined by an
# earlier cell.
distributed = (
    torch.cuda.device_count() > 1
    and dist.is_available()
    and all(name in os.environ for name in ("RANK", "WORLD_SIZE", "LOCAL_RANK"))
)
if distributed:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)  # each process drives its own GPU
is_main_process = not distributed or dist.get_rank() == 0

try:
    # Create the model and move its parameters to the training device.
    model = GPT2(embed_size, heads, num_layers, vocab_size, device).to(device)
    if distributed:
        if use_fsdp:
            model = FullyShardedDataParallel(model)
        else:
            model = DistributedDataParallel(model, device_ids=[local_rank])

    # Loss function and optimizer.  The optimizer is built after wrapping so
    # it holds the parameters that are actually trained.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Dummy DataLoader for demonstration.  In distributed mode the sampler
    # hands every process its own shard and does the shuffling.
    dataset = RandomDataset(num_samples, seq_length, vocab_size)
    sampler = DistributedSampler(dataset) if distributed else None
    train_loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=sampler is None, sampler=sampler
    )

    # Training loop
    for epoch in range(num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # reshuffle differently every epoch
        model.train()

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            # The random sequences contain no padding, so no attention mask is needed.
            outputs = model(inputs, mask=None)
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

            loss.backward()
            optimizer.step()

            if is_main_process and batch_idx % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item()}')
finally:
    # Cleanup for distributed training, even when the loop fails.
    if distributed:
        dist.destroy_process_group()

Epoch [0/10], Step [0/16], Loss: 9.366623878479004
Epoch [1/10], Step [0/16], Loss: 9.401666641235352
Epoch [2/10], Step [0/16], Loss: 9.373968124389648
Epoch [3/10], Step [0/16], Loss: 9.362058639526367
Epoch [4/10], Step [0/16], Loss: 9.376642227172852
Epoch [5/10], Step [0/16], Loss: 9.342138290405273
Epoch [6/10], Step [0/16], Loss: 9.366018295288086
Epoch [7/10], Step [0/16], Loss: 9.318262100219727
Epoch [8/10], Step [0/16], Loss: 9.328409194946289
Epoch [9/10], Step [0/16], Loss: 9.356502532958984
