# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SlidingWindowAttention(nn.Module):
    """Multi-head self-attention restricted to a local window of keys.

    With ``causal=True`` (the default, suitable for autoregressive models),
    query position ``i`` may attend only to key positions ``j`` with
    ``i - window_size < j <= i``; i.e. itself and the ``window_size - 1``
    preceding tokens. With ``causal=False`` it attends to keys with
    ``|i - j| <= window_size // 2`` (a centred window).

    Args:
        embedding_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Width of the attention window; must be >= 1.
        causal: Whether to forbid attending to future positions.

    Raises:
        ValueError: If ``embedding_size`` is not divisible by ``num_heads`` or
            ``window_size < 1``.
    """

    def __init__(self, embedding_size, num_heads, window_size, causal=True):
        super(SlidingWindowAttention, self).__init__()
        if embedding_size % num_heads != 0:
            raise ValueError(
                f"embedding_size ({embedding_size}) must be divisible by num_heads ({num_heads})"
            )
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_size = embedding_size // num_heads
        self.window_size = window_size
        self.causal = causal

        self.q_linear = nn.Linear(embedding_size, embedding_size)
        self.k_linear = nn.Linear(embedding_size, embedding_size)
        self.v_linear = nn.Linear(embedding_size, embedding_size)
        self.fc = nn.Linear(embedding_size, embedding_size)

    def _window_mask(self, seq_len, device):
        """Return a (seq_len, seq_len) bool mask; True where query i may attend key j."""
        positions = torch.arange(seq_len, device=device)
        offset = positions.unsqueeze(1) - positions.unsqueeze(0)  # offset[i, j] = i - j
        if self.causal:
            return (offset >= 0) & (offset < self.window_size)
        return offset.abs() <= self.window_size // 2

    def forward(self, x):
        """Apply windowed attention to ``x`` of shape (batch, seq_len, embedding_size).

        Returns a tensor of the same shape.
        """
        batch_size, seq_len, _ = x.size()

        # (batch, seq, embed) -> (batch, heads, seq, head_size)
        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Mask keys outside the window. The diagonal is always allowed, so no
        # row is fully masked and softmax never produces NaN.
        mask = self._window_mask(seq_len, x.device)
        scores = scores.masked_fill(~mask, float('-inf'))
        attention = F.softmax(scores, dim=-1)

        # (batch, heads, seq, head_size) -> (batch, seq, heads, head_size) -> (batch, seq, embed)
        out = torch.matmul(attention, v).transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.embedding_size)
        return self.fc(out)


class RotaryPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position encodings to a sequence of embeddings.

    NOTE: despite its name, this module implements the *additive* sinusoidal
    encoding (even columns ``sin(pos * w_k)``, odd columns ``cos(pos * w_k)``,
    added to the input). It does not implement rotary (RoPE) embeddings,
    which rotate queries/keys inside attention.

    Args:
        embedding_size: Width of the embeddings the encoding is added to.
        max_len: Number of positions precomputed in the ``pe`` buffer.
            Defaults to ``embedding_size``, the historical table size, so the
            buffer shape (and any saved ``state_dict``) is unchanged. Longer
            inputs are still supported: their encoding is computed on the fly.
    """

    def __init__(self, embedding_size, max_len=None):
        super(RotaryPositionalEncoding, self).__init__()
        self.embedding_size = embedding_size
        self.max_len = embedding_size if max_len is None else max_len
        self.register_buffer('pe', self._generate_positional_encoding())

    def _generate_positional_encoding(self, num_positions=None, device=None):
        """Build a (1, num_positions, embedding_size) float32 encoding table."""
        if num_positions is None:
            num_positions = self.max_len
        pe = torch.zeros(num_positions, self.embedding_size, device=device)
        position = torch.arange(0, num_positions, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_size, 2, device=device).float()
            * (-math.log(10000.0) / self.embedding_size)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd embedding sizes there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : self.embedding_size // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Return ``x + encoding`` for ``x`` of shape (batch, seq_len, embedding_size)."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            pe = self.pe[:, :seq_len]
        else:
            # Sequence outgrew the cached table (e.g. during generation);
            # the same formula extends it exactly.
            pe = self._generate_positional_encoding(seq_len, device=x.device)
        return x + pe

class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA).

    ``num_heads`` query heads share ``num_groups`` key/value heads: each
    key/value head is used by ``num_heads // num_groups`` consecutive query
    heads. ``num_groups == num_heads`` reduces to ordinary multi-head
    attention; ``num_groups == 1`` is multi-query attention.

    Args:
        embedding_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
        num_groups: Number of key/value heads; must divide ``num_heads``.
        causal: Whether to forbid attending to future positions (default
            True, suitable for autoregressive models).

    Raises:
        ValueError: On incompatible ``embedding_size``/``num_heads``/``num_groups``.
    """

    def __init__(self, embedding_size, num_heads, num_groups, causal=True):
        super(GroupQueryAttention, self).__init__()
        if embedding_size % num_heads != 0:
            raise ValueError(
                f"embedding_size ({embedding_size}) must be divisible by num_heads ({num_heads})"
            )
        if num_groups < 1 or num_heads % num_groups != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of num_groups ({num_groups})"
            )
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_size = embedding_size // num_heads
        self.num_groups = num_groups
        self.causal = causal

        # Keys/values only need one head per group, which is where GQA saves
        # parameters and KV-cache memory relative to full multi-head attention.
        kv_size = num_groups * self.head_size
        self.q_linear = nn.Linear(embedding_size, embedding_size)
        self.k_linear = nn.Linear(embedding_size, kv_size)
        self.v_linear = nn.Linear(embedding_size, kv_size)
        self.fc = nn.Linear(embedding_size, embedding_size)

    def forward(self, x):
        """Apply GQA to ``x`` of shape (batch, seq_len, embedding_size); same-shaped output."""
        batch_size, seq_len, _ = x.size()
        heads_per_group = self.num_heads // self.num_groups

        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_groups, self.head_size).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_groups, self.head_size).transpose(1, 2)

        # Broadcast each key/value head to the query heads of its group:
        # (batch, groups, seq, head_size) -> (batch, heads, seq, head_size)
        k = k.repeat_interleave(heads_per_group, dim=1)
        v = v.repeat_interleave(heads_per_group, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)
        if self.causal:
            future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            scores = scores.masked_fill(future, float('-inf'))
        attention = F.softmax(scores, dim=-1)

        out = torch.matmul(attention, v).transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.embedding_size)
        return self.fc(out)

class TransformerBlock(nn.Module):
    """One attention sub-layer followed by one feed-forward sub-layer.

    Each sub-layer's output is added to its input (residual connection) and
    the sum is then passed through a LayerNorm (post-norm arrangement).

    Args:
        embedding_size: Model width.
        num_heads: Number of attention heads.
        hidden_size: Inner width passed to ``FeedForward``.
        attention_type: ``'sliding_window'`` or ``'group_query'``; any other
            value falls back to ``MultiHeadAttention``.
        window_size: Forwarded to ``SlidingWindowAttention``.
        num_groups: Forwarded to ``GroupQueryAttention``.

    NOTE(review): ``MultiHeadAttention`` and ``FeedForward`` are not defined in
    this cell; presumably they come from an earlier notebook cell — confirm.
    """

    def __init__(self, embedding_size, num_heads, hidden_size, attention_type='sliding_window', window_size=5, num_groups=2):
        super().__init__()
        self.attention_type = attention_type

        # Lazily-invoked factories: only the selected attention module is built.
        attention_factories = {
            'sliding_window': lambda: SlidingWindowAttention(embedding_size, num_heads, window_size),
            'group_query': lambda: GroupQueryAttention(embedding_size, num_heads, num_groups),
        }
        build_attention = attention_factories.get(
            attention_type,
            lambda: MultiHeadAttention(embedding_size, num_heads),
        )
        self.attention = build_attention()

        self.feed_forward = FeedForward(embedding_size, hidden_size)
        self.layer_norm1 = nn.LayerNorm(embedding_size)
        self.layer_norm2 = nn.LayerNorm(embedding_size)

    def forward(self, x):
        """Run attention then feed-forward, each with residual add + LayerNorm."""
        hidden = self.layer_norm1(x + self.attention(x))
        return self.layer_norm2(hidden + self.feed_forward(hidden))

class GPT2(nn.Module):
    """Token embedding + positional encoding + a stack of TransformerBlocks + LM head.

    Args:
        vocab_size: Number of token ids.
        embedding_size: Model width.
        num_heads: Attention heads per block.
        hidden_size: Feed-forward inner width per block.
        num_layers: Number of TransformerBlocks.
        attention_type, window_size, num_groups: Forwarded to each TransformerBlock.
    """

    def __init__(self, vocab_size, embedding_size, num_heads, hidden_size, num_layers, attention_type='sliding_window', window_size=5, num_groups=2):
        super(GPT2, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.positional_encoding = RotaryPositionalEncoding(embedding_size)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding_size, num_heads, hidden_size, attention_type, window_size, num_groups) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_size, vocab_size)

    def forward(self, x):
        """Map token ids to next-token logits.

        ``x`` is assumed to be a (batch, seq_len) tensor of token ids — TODO
        confirm; the result has a trailing ``vocab_size`` dimension.
        """
        embedded = self.embedding(x)
        embedded = self.positional_encoding(embedded)

        for transformer_block in self.transformer_blocks:
            embedded = transformer_block(embedded)

        logits = self.fc(embedded)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_length=50, num_return_sequences=1, temperature=0.7):
        """Autoregressively extend ``input_ids`` by ``max_length`` tokens.

        Args:
            input_ids: (batch, seq_len) tensor of prompt token ids.
            max_length: Number of new tokens appended to each sequence.
            num_return_sequences: Number of independent continuations.
            temperature: Softmax temperature for sampling; ``0`` selects
                greedy (argmax) decoding.

        Returns:
            A list of ``num_return_sequences`` tensors, each of shape
            (batch, seq_len + max_length), whose prefix is ``input_ids``.

        Raises:
            ValueError: If ``temperature`` is negative.

        Runs without autograd and does not change train/eval mode; call
        ``model.eval()`` first if the blocks contain dropout.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        generated_sequences = []
        for _ in range(num_return_sequences):
            current_sequence = input_ids.clone()
            for _ in range(max_length):
                # Only the logits of the last position predict the next token: (batch, vocab)
                next_token_logits = self.forward(current_sequence)[:, -1, :]

                if temperature == 0:
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                else:
                    probs = F.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # next_token is (batch, 1), matching current_sequence's rank.
                current_sequence = torch.cat([current_sequence, next_token], dim=-1)
            generated_sequences.append(current_sequence)
        return generated_sequences

