In [1]:
import torch
import torch.nn as nn

class FF(nn.Module):
    """Position-wise feed-forward sub-layer of a transformer block.

    The last dimension of the input is projected up to ``8 * embd_dim``
    features, passed through GELU, and projected back down to ``embd_dim``.

    Args:
        embd_dim: Size of the last dimension of both the input and the output.
    """

    def __init__(self, embd_dim):
        super().__init__()
        hidden_dim = 8 * embd_dim
        # Attribute names determine the state_dict keys, so they must stay
        # as they are for previously saved weights to keep loading.
        self.linear1 = nn.Linear(embd_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embd_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``linear2(gelu(linear1(x)))``; the output has the shape of ``x``."""
        expanded = self.linear1(x)
        activated = self.gelu(expanded)
        return self.linear2(activated)

In [22]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` that returns only the output.

    The wrapped module is built with ``batch_first=True``, so inputs and
    outputs are laid out as ``(batch, sequence, embd_dim)``.

    Args:
        num_heads: Number of attention heads; ``embd_dim`` must be divisible
            by it (enforced by ``nn.MultiheadAttention``).
        embd_dim: Embedding dimension of queries, keys and values.
    """

    def __init__(self, num_heads, embd_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=embd_dim, num_heads=num_heads, batch_first=True)

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` over ``k``/``v`` and return the attended values.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            attn_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``. For a boolean mask, ``True`` marks
                positions that may NOT be attended to. When omitted (the
                default) no mask is applied, i.e. every position can attend
                to every other position.

        Returns:
            The attention output, with the same layout as ``q``.
        """
        # The averaged attention weights are never used by callers, so skip
        # computing them; this also lets PyTorch use its optimized
        # scaled_dot_product_attention implementation.
        attn_output, _ = self.attn(q, k, v, attn_mask=attn_mask, need_weights=False)
        return attn_output

In [23]:
import torch
import torch.nn as nn

class Decode(nn.Module):
    """One pre-norm transformer block: self-attention followed by feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input, passed
    through dropout, and added back to the un-normalised input (residual).

    Args:
        num_heads: Number of attention heads for the self-attention sub-layer.
        embd_dim: Size of the last dimension of the input and output.
        dropout: Drop probability applied to each sub-layer's output before
            the residual addition. Defaults to 0.2.
    """

    def __init__(self, num_heads, embd_dim, dropout=0.2):
        super().__init__()
        self.attn = MultiHeadAttention(num_heads, embd_dim)
        self.norm1 = nn.LayerNorm(embd_dim)
        self.norm2 = nn.LayerNorm(embd_dim)
        self.ff = FF(embd_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Self-attention: query, key and value are all the normalised input.
        x_norm = self.norm1(x)
        attn_out = self.attn(x_norm, x_norm, x_norm)
        x = x + self.dropout1(attn_out)

        # Feed-forward on the normalised result of the attention sub-layer.
        x_norm = self.norm2(x)
        x = x + self.dropout2(self.ff(x_norm))

        return x

In [24]:
class Decoder(nn.Module):
    """Token + learned position embeddings followed by a stack of ``Decode`` blocks.

    Args:
        vocab_size: Number of rows in the token embedding table.
        seq_length: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` can embed.
        num_layers: How many ``Decode`` blocks to stack.
        num_heads: Number of attention heads passed to every block.
        embd_dim: Embedding dimension used throughout.
    """

    def __init__(
        self,
        vocab_size,
        seq_length,
        num_layers,
        num_heads,
        embd_dim,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embd_dim)
        self.pos_embedding = nn.Embedding(seq_length, embd_dim)
        blocks = []
        for _ in range(num_layers):
            blocks.append(Decode(num_heads, embd_dim))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(embd_dim)

    def forward(self, x):
        """Embed the token ids in ``x`` and run them through every block.

        ``x`` is indexed along dimension 1 for positions, so it is expected to
        be a ``(batch, sequence)`` tensor of token ids. The result of the last
        block is layer-normalised before being returned.
        """
        num_tokens = x.shape[1]
        # Position ids 0..num_tokens-1, broadcast to one row per batch entry.
        position_ids = torch.arange(num_tokens, device=x.device)[None, :].expand_as(x)
        hidden = self.embedding(x) + self.pos_embedding(position_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

In [25]:
class GPT3(nn.Module):
    """Language model: a ``Decoder`` stack followed by a linear vocabulary head.

    Args:
        vocab_size: Number of tokens in the vocabulary (size of the output logits).
        seq_length: Maximum number of tokens the decoder can embed at once.
        num_heads: Number of attention heads per block.
        num_layers: Number of decoder blocks.
        embd_dim: Embedding dimension.
    """

    def __init__(self, vocab_size, seq_length, num_heads, num_layers, embd_dim):
        super().__init__()
        # Decoder's positional order is (..., num_layers, num_heads, ...), the
        # reverse of this signature. Pass by keyword so the two values cannot
        # be silently swapped when they differ.
        self.dec = Decoder(
            vocab_size=vocab_size,
            seq_length=seq_length,
            num_layers=num_layers,
            num_heads=num_heads,
            embd_dim=embd_dim,
        )
        self.out = nn.Linear(embd_dim, vocab_size)
        self.seq_length = seq_length
        self.vocab_size = vocab_size

    def forward(self, x):
        """Return next-token logits for every position of the token ids ``x``."""
        x = self.dec(x)
        x = self.out(x)
        return x

    def generate(self, input_ids, max_length=50, temperature=0.9):
        """Sample ``max_length`` new tokens that continue ``input_ids``.

        Args:
            input_ids: Tensor of token ids with a leading batch dimension.
                Only the first row's tokens are returned.
            max_length: Number of tokens to sample.
            temperature: Positive divisor applied to the logits before the
                softmax; smaller values make sampling more deterministic.

        Returns:
            A list of ints: the first row of ``input_ids`` followed by the
            sampled tokens.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Run on whatever device the model lives on instead of assuming CUDA.
        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        output = input_ids.tolist()[0]
        # Keep at most the last `seq_length` tokens: that is all the position
        # embedding table can index.
        context = input_ids.to(device)[:, -self.seq_length:]
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    logits = self(context)
                    logits = logits[:, -1, :] / temperature
                    probs = nn.functional.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    output.append(int(next_token[0, 0]))
                    # Grow the context until the window is full, then slide it.
                    # Prompt tokens are only dropped once they no longer fit.
                    context = torch.cat([context, next_token], dim=1)[:, -self.seq_length:]
        finally:
            # Restore the caller's mode rather than forcing training mode.
            self.train(was_training)
        return output

In [26]:
# Model hyperparameters. The trailing notes give the full-size GPT-3 values
# for comparison; the values here are scaled down to fit GPU memory.
vocab_size = 50000  # do not change
seq_length = 128  # input sequence length; full-size GPT-3 uses a much longer context
num_heads = 12  # attention heads (GPT-3: 96); embd_dim must be divisible by num_heads
num_layers = 12  # transformer blocks (GPT-3: 96)
embd_dim = 768  # embedding dimension (GPT-3: 12288)
model = GPT3(
    vocab_size=vocab_size,
    seq_length=seq_length,
    num_heads=num_heads,
    num_layers=num_layers,
    embd_dim=embd_dim,
)

In [27]:
# Load the saved weights onto the CPU and copy them into the model.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        "/kaggle/input/gpt3/pytorch/default/1/model.pth",
        map_location="cpu",
        weights_only=True,
    )
)

<All keys matched successfully>

In [None]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` is false are left out of the total.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Count the model's trainable parameters and print the total with
# thousands separators.
total_params = count_parameters(model)
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 246,975,824
