**Task 1**

In [None]:
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    """Learnable lookup table from integer token ids to dense vectors."""

    def __init__(self, vocab_size, embed_size):
        """Create a table with ``vocab_size`` rows of width ``embed_size``."""
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
        )

    def forward(self, tokens):
        """Return the embedding rows selected by the ids in ``tokens``."""
        embedded = self.embedding(tokens)
        return embedded


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    Column ``2k`` of the table holds ``sin(pos / 10000 ** (2k / embed_size))``
    and column ``2k + 1`` holds the cosine of the same angle, so every
    sin/cos pair shares one frequency.

    The table is stored as the non-persistent buffer ``encoding`` with shape
    ``(1, max_len, embed_size)``: it follows ``.to(device)`` / ``.half()``
    with the rest of the module but is not written to ``state_dict``.

    Args:
        max_len: Number of positions to precompute.
        embed_size: Width of each embedding vector (odd widths are allowed).
    """

    def __init__(self, max_len, embed_size):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # ``even_index`` already equals 2k, so it must not be doubled again.
        even_index = torch.arange(0, embed_size, 2, dtype=torch.float)
        inv_freq = torch.exp(-math.log(10000.0) * even_index / embed_size)
        angles = position * inv_freq  # (max_len, ceil(embed_size / 2))

        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even ones.
        encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, embed_size)``; ``seq_len`` must
        not exceed the ``max_len`` given at construction.
        """
        return x + self.encoding[:, : x.size(1)]


class GroupQueryAttention(nn.Module):
    """Multi-head attention in which groups of query heads share keys/values.

    Queries are projected to ``heads`` heads, while keys and values are
    projected to only ``num_groups`` heads. Each key/value head is then
    reused by ``heads // num_groups`` consecutive query heads.

    Args:
        embed_size: Width of the input and output vectors.
        heads: Number of query heads; must divide ``embed_size``.
        num_groups: Number of shared key/value heads; must divide ``heads``.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
        ValueError: If ``num_groups`` is not a positive divisor of ``heads``.
    """

    def __init__(self, embed_size, heads, num_groups=4):
        super(GroupQueryAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.num_groups = num_groups
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size should be divisible by heads"
        if num_groups <= 0 or heads % num_groups != 0:
            raise ValueError(
                f"heads ({heads}) must be a positive multiple of num_groups ({num_groups})"
            )

        # Keys/values get one projection per group; queries get one per head.
        self.values = nn.Linear(embed_size, num_groups * self.head_dim, bias=False)
        self.keys = nn.Linear(embed_size, num_groups * self.head_dim, bias=False)
        self.queries = nn.Linear(embed_size, heads * self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values, keys, query: Tensors indexed as
                ``(batch, length, embed_size)``; ``values`` and ``keys`` must
                have the same length.
            mask: ``None`` or a tensor broadcastable to
                ``(batch, heads, query_len, key_len)``; entries equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project first, then split the projected width into heads / groups.
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.num_groups, self.head_dim)
        values = self.values(values).reshape(N, value_len, self.num_groups, self.head_dim)

        queries = queries.permute(0, 2, 1, 3)  # (N, heads, query_len, head_dim)
        keys = keys.permute(0, 2, 1, 3)  # (N, num_groups, key_len, head_dim)
        values = values.permute(0, 2, 1, 3)  # (N, num_groups, value_len, head_dim)

        # Expand each shared key/value head so query head h uses group
        # h // heads_per_group.
        heads_per_group = self.heads // self.num_groups
        keys = keys.repeat_interleave(heads_per_group, dim=1)
        values = values.repeat_interleave(heads_per_group, dim=1)

        energy = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.nn.functional.softmax(energy, dim=-1)

        out = torch.matmul(attention, values)  # (N, heads, query_len, head_dim)
        out = out.permute(0, 2, 1, 3).reshape(N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)


class FeedForward(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The first layer maps ``embed_size`` to ``ff_hidden`` features and the
    second maps back to ``embed_size``.
    """

    def __init__(self, embed_size, ff_hidden):
        super().__init__()
        self.fc1 = nn.Linear(in_features=embed_size, out_features=ff_hidden)
        self.fc2 = nn.Linear(in_features=ff_hidden, out_features=embed_size)

    def forward(self, x):
        """Apply both linear layers with a ReLU in between."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)


In [None]:
import torch
import torch.nn as nn

class GPT2(nn.Module):
    """Token embedding, positional encoding, transformer stack and output layer.

    Args:
        vocab_size: Size of the token table and of the output projection.
        embed_size: Width of the hidden representation.
        num_heads: Head count handed to every ``TransformerBlock``.
        num_layers: Number of ``TransformerBlock`` instances in the stack.
        ff_hidden: Hidden width handed to every ``TransformerBlock``.
        max_sequence_len: Table length handed to ``PositionalEncoding``.
    """

    def __init__(self, vocab_size, embed_size=768, num_heads=12, num_layers=12, ff_hidden=3072, max_sequence_len=1024):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(max_sequence_len, embed_size)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_size, num_heads, ff_hidden))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Run the model on token ids ``x``.

        ``mask`` is handed unchanged to every transformer block. The result's
        last dimension has size ``vocab_size``.
        """
        hidden = self.positional_encoding(self.token_embedding(x))

        for block in self.transformer_blocks:
            hidden = block(hidden, mask)

        return self.fc(hidden)


class TransformerBlock(nn.Module):
    """Attention and feed-forward sub-layers, each with residual + LayerNorm.

    Each sub-layer's output is added to its input, layer-normalised and then
    passed through dropout.
    """

    def __init__(self, embed_size, heads, ff_hidden, dropout=0.1):
        super().__init__()

        self.group_query_attention = GroupQueryAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention on ``x`` followed by the feed-forward network."""
        # Self-attention: the same tensor serves as values, keys and query.
        attended = self.group_query_attention(x, x, x, mask)
        x = self.norm1(attended + x)
        x = self.dropout(x)

        transformed = self.ff(x)
        x = self.norm2(transformed + x)
        return self.dropout(x)

Task 2
# Including Rotary Positional Embedding and Group Query Attention (GQA)


In [None]:
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    """Map integer token ids to learnable vectors of width ``embed_size``."""

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # One trainable row per vocabulary entry.
        table = nn.Embedding(vocab_size, embed_size)
        self.embedding = table

    def forward(self, tokens):
        """Look up the vectors for ``tokens``."""
        vectors = self.embedding(tokens)
        return vectors

class RotaryEmbedding(nn.Module):
    """Rotate consecutive feature pairs by a position-dependent angle.

    Features ``(2k, 2k + 1)`` at position ``p`` are rotated by
    ``p * 10000 ** (-2k / embed_size) * (1 + alpha[k]) + beta[k]``.

    ``alpha`` and ``beta`` are the learnable per-pair parameters this class
    has always exposed. They start at zero, so a freshly built module applies
    the standard fixed-frequency rotation, with position 0 left unchanged.

    Args:
        embed_size: Width of the input vectors; must be even.
        max_sequence_len: Stored on the instance; not used by ``forward``.

    Raises:
        ValueError: If ``embed_size`` is odd.
    """

    def __init__(self, embed_size, max_sequence_len):
        super(RotaryEmbedding, self).__init__()
        if embed_size % 2 != 0:
            raise ValueError("RotaryEmbedding requires an even embed_size")
        self.embed_size = embed_size
        self.max_sequence_len = max_sequence_len
        self.alpha = nn.Parameter(torch.zeros(self.embed_size // 2))
        self.beta = nn.Parameter(torch.zeros(self.embed_size // 2))

    def forward(self, x):
        """Return ``x`` with every feature pair rotated; the shape is unchanged.

        ``x`` is indexed as ``(batch, seq_len, embed_size)``.
        """
        positions = torch.arange(0, x.shape[1], dtype=torch.float, device=x.device)
        # ``pair_index`` already equals 2k, so it must not be doubled again.
        pair_index = torch.arange(0, self.embed_size, 2, dtype=torch.float, device=x.device)
        inv_freq = torch.pow(10000.0, -pair_index / self.embed_size)

        # (seq_len, embed_size // 2); broadcasts over the batch dimension.
        angles = positions.unsqueeze(1) * inv_freq * (1.0 + self.alpha) + self.beta
        sin, cos = torch.sin(angles), torch.cos(angles)

        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos

        # Interleave back so pair k occupies columns 2k and 2k + 1 again.
        return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are split into ``heads`` slices of width ``head_dim``. The
    same ``head_dim -> head_dim`` linear maps are applied to every slice, and
    the per-head results are merged by ``fc_out``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        divides_evenly = self.head_dim * heads == embed_size
        assert divides_evenly, "Embedding size should be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Scores at positions where ``mask == 0`` are replaced by ``-1e20``
        before the softmax over the key dimension. The result has shape
        ``(batch, query_len, heads * head_dim)`` before ``fc_out``.
        """
        batch = query.shape[0]
        query_len = query.shape[1]
        per_head = (self.heads, self.head_dim)

        # Split the last dimension into heads, then project each head slice.
        v = self.values(values.reshape(batch, values.shape[1], *per_head))
        k = self.keys(keys.reshape(batch, keys.shape[1], *per_head))
        q = self.queries(query.reshape(batch, query_len, *per_head))

        scores = torch.einsum("nqhd,nkhd->nhqk", q, k) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-1e20"))

        weights = scores.softmax(dim=3)

        context = torch.einsum("nhql,nlhd->nqhd", weights, v)
        merged = context.reshape(batch, query_len, self.heads * self.head_dim)
        return self.fc_out(merged)


class FeedForward(nn.Module):
    """Expand to ``ff_hidden`` features, apply ReLU, project back to ``embed_size``."""

    def __init__(self, embed_size, ff_hidden):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden)  # expansion
        self.fc2 = nn.Linear(ff_hidden, embed_size)  # projection back

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        return self.fc2(torch.relu(self.fc1(x)))


In [None]:
class TransformerBlock(nn.Module):
    """Self-attention and feed-forward sub-layers with residual + LayerNorm.

    After each sub-layer the input is added back, the sum is layer-normalised
    and dropout is applied to the normalised result.
    """

    def __init__(self, embed_size, heads, ff_hidden, dropout=0.1):
        super().__init__()

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is handed to the attention."""
        # The same tensor is used as values, keys and query.
        attended = self.attention(x, x, x, mask)
        x = self.norm1(attended + x)
        x = self.dropout(x)

        transformed = self.ff(x)
        x = self.norm2(transformed + x)
        return self.dropout(x)

class GPT2(nn.Module):
    """Token embedding, rotary embedding, transformer stack and output layer.

    Args:
        vocab_size: Size of the token table and of the output projection.
        embed_size: Width of the hidden representation.
        num_heads: Head count handed to every ``TransformerBlock``.
        num_layers: Number of ``TransformerBlock`` instances in the stack.
        ff_hidden: Hidden width handed to every ``TransformerBlock``.
        max_sequence_len: Handed to ``RotaryEmbedding``.
    """

    def __init__(self, vocab_size, embed_size=768, num_heads=12, num_layers=12, ff_hidden=3072, max_sequence_len=1024):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.rotary_embedding = RotaryEmbedding(embed_size, max_sequence_len)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_size, num_heads, ff_hidden))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Run the model on token ids ``x``.

        ``mask`` is handed unchanged to every transformer block. The result's
        last dimension has size ``vocab_size``.
        """
        hidden = self.rotary_embedding(self.token_embedding(x))

        for block in self.transformer_blocks:
            hidden = block(hidden, mask)

        return self.fc(hidden)

In [None]:
class TransformerBlock(nn.Module):
    """One layer: self-attention, then a feed-forward network.

    Both sub-layers use the same pattern: add the sub-layer input to its
    output, apply LayerNorm, then dropout.
    """

    def __init__(self, embed_size, heads, ff_hidden, dropout=0.1):
        super().__init__()

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def _residual(self, sublayer_out, skip, norm):
        """Add ``skip`` to ``sublayer_out``, normalise with ``norm``, apply dropout."""
        return self.dropout(norm(sublayer_out + skip))

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        x = self._residual(self.attention(x, x, x, mask), x, self.norm1)
        x = self._residual(self.ff(x), x, self.norm2)
        return x

class GPT2(nn.Module):
    """Language model built from an embedding, a rotary step and a block stack.

    The constructor arguments configure the sub-modules: ``vocab_size`` and
    ``embed_size`` size the token table and the output projection,
    ``num_heads`` and ``ff_hidden`` are handed to each of the ``num_layers``
    transformer blocks, and ``max_sequence_len`` is handed to
    ``RotaryEmbedding``.
    """

    def __init__(self, vocab_size, embed_size=768, num_heads=12, num_layers=12, ff_hidden=3072, max_sequence_len=1024):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.rotary_embedding = RotaryEmbedding(embed_size, max_sequence_len)
        self.transformer_blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.transformer_blocks.append(
                TransformerBlock(embed_size, num_heads, ff_hidden)
            )
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Embed ``x``, run every block with ``mask`` and project to the vocabulary."""
        state = self.token_embedding(x)
        state = self.rotary_embedding(state)

        for layer in self.transformer_blocks:
            state = layer(state, mask)

        logits = self.fc(state)
        return logits

Task 2.3

In [None]:
class TokenEmbedding(nn.Module):
    """Embedding table indexed by token id."""

    def __init__(self, vocab_size, embed_size):
        """Allocate ``vocab_size`` trainable vectors of width ``embed_size``."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim=embed_size)

    def forward(self, tokens):
        """Return one vector per id in ``tokens``."""
        lookup = self.embedding
        return lookup(tokens)

class RotaryEmbedding(nn.Module):
    """Rotate consecutive feature pairs by a position-dependent angle.

    Features ``(2k, 2k + 1)`` at position ``p`` are rotated by
    ``p * 10000 ** (-2k / embed_size) * (1 + alpha[k]) + beta[k]``.

    ``alpha`` and ``beta`` are the learnable per-pair parameters this class
    has always exposed. They start at zero, so a freshly built module applies
    the standard fixed-frequency rotation, with position 0 left unchanged.

    Args:
        embed_size: Width of the input vectors; must be even.
        max_sequence_len: Stored on the instance; not used by ``forward``.

    Raises:
        ValueError: If ``embed_size`` is odd.
    """

    def __init__(self, embed_size, max_sequence_len):
        super(RotaryEmbedding, self).__init__()
        if embed_size % 2 != 0:
            raise ValueError("RotaryEmbedding requires an even embed_size")
        self.embed_size = embed_size
        self.max_sequence_len = max_sequence_len
        self.alpha = nn.Parameter(torch.zeros(self.embed_size // 2))
        self.beta = nn.Parameter(torch.zeros(self.embed_size // 2))

    def forward(self, x):
        """Return ``x`` with every feature pair rotated; the shape is unchanged.

        ``x`` is indexed as ``(batch, seq_len, embed_size)``.
        """
        positions = torch.arange(0, x.shape[1], dtype=torch.float, device=x.device)
        # ``pair_index`` already equals 2k, so it must not be doubled again.
        pair_index = torch.arange(0, self.embed_size, 2, dtype=torch.float, device=x.device)
        inv_freq = torch.pow(10000.0, -pair_index / self.embed_size)

        # (seq_len, embed_size // 2); broadcasts over the batch dimension.
        angles = positions.unsqueeze(1) * inv_freq * (1.0 + self.alpha) + self.beta
        sin, cos = torch.sin(angles), torch.cos(angles)

        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos

        # Interleave back so pair k occupies columns 2k and 2k + 1 again.
        return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)

class SlidingWindowAttention(nn.Module):
    """Multi-head attention restricted to a local band around each query.

    Query position ``i`` may attend key position ``j`` only when
    ``abs(i - j) < window_size``. A caller-supplied mask is applied on top of
    the band, so a lower-triangular mask makes each position see itself and
    the ``window_size - 1`` positions before it.

    Query and key positions are both counted from 0. The full score matrix is
    computed and then masked, so memory use is still quadratic in sequence
    length.

    Args:
        embed_size: Width of the input and output vectors.
        heads: Number of attention heads; must divide ``embed_size``.
        window_size: Band half-width described above; must be at least 1.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, embed_size, heads, window_size=128, dropout=0.1):
        super(SlidingWindowAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.window_size = window_size
        self.head_dim = embed_size // heads
        self.scale = self.head_dim ** -0.5

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size should be divisible by heads"
        if window_size < 1:
            raise ValueError("window_size must be at least 1")

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over the keys inside each query's window.

        Args:
            values, keys, query: Tensors indexed as
                ``(batch, length, embed_size)``; ``values`` and ``keys`` must
                have the same length.
            mask: ``None`` or a tensor broadcastable to
                ``(batch, heads, query_len, key_len)``; entries equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split into heads first: the projections are head_dim -> head_dim.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) * self.scale

        # Band mask of shape (query_len, key_len): True where the key lies
        # outside the query's window.
        query_pos = torch.arange(query_len, device=query.device).unsqueeze(1)
        key_pos = torch.arange(key_len, device=query.device).unsqueeze(0)
        outside_window = (query_pos - key_pos).abs() >= self.window_size
        energy = energy.masked_fill(outside_window, float("-1e20"))

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # A single softmax over all keys keeps the weights of each query
        # normalised across its whole window.
        attention = torch.nn.functional.softmax(energy, dim=-1)
        attention = self.dropout(attention)

        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)


class FeedForward(nn.Module):
    """Feed-forward network with a single ReLU hidden layer.

    ``fc1`` maps ``embed_size`` to ``ff_hidden`` features; ``fc2`` maps them
    back to ``embed_size``.
    """

    def __init__(self, embed_size, ff_hidden):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden)
        self.fc2 = nn.Linear(ff_hidden, embed_size)

    def forward(self, x):
        """Apply ``fc1``, ReLU and ``fc2`` in that order."""
        pre_activation = self.fc1(x)
        activated = torch.nn.functional.relu(pre_activation)
        return self.fc2(activated)

In [None]:
class TransformerBlock(nn.Module):
    """Sliding-window attention and feed-forward sub-layers.

    After each sub-layer the input is added back, the sum is layer-normalised
    and dropout is applied. ``window_size`` and ``dropout`` are also handed
    to ``SlidingWindowAttention``.
    """

    def __init__(self, embed_size, heads, ff_hidden, window_size=128, dropout=0.1):
        super().__init__()

        self.sliding_window_attention = SlidingWindowAttention(embed_size, heads, window_size, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is handed to the attention."""
        # The same tensor is used as values, keys and query.
        attended = self.sliding_window_attention(x, x, x, mask)
        x = self.norm1(attended + x)
        x = self.dropout(x)

        transformed = self.ff(x)
        x = self.norm2(transformed + x)
        return self.dropout(x)

class GPT2(nn.Module):
    """Token embedding, rotary embedding, transformer stack and output layer.

    Every ``TransformerBlock`` is built with ``embed_size``, ``num_heads``
    and ``ff_hidden`` only, so the blocks use their own default window size
    and dropout.

    Args:
        vocab_size: Size of the token table and of the output projection.
        embed_size: Width of the hidden representation.
        num_heads: Head count handed to every block.
        num_layers: Number of blocks in the stack.
        ff_hidden: Hidden width handed to every block.
        max_sequence_len: Handed to ``RotaryEmbedding``.
    """

    def __init__(self, vocab_size, embed_size=768, num_heads=12, num_layers=12, ff_hidden=3072, max_sequence_len=1024):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.rotary_embedding = RotaryEmbedding(embed_size, max_sequence_len)

        layers = []
        for _ in range(num_layers):
            layers.append(TransformerBlock(embed_size, num_heads, ff_hidden))
        self.transformer_blocks = nn.ModuleList(layers)

        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Run the model on token ids ``x``.

        ``mask`` is handed unchanged to every block. The result's last
        dimension has size ``vocab_size``.
        """
        hidden = self.token_embedding(x)
        hidden = self.rotary_embedding(hidden)

        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask)

        return self.fc(hidden)