
## Load weights to GPT with torch 


In [3]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel


# GPT Architecture (compatible with pre-trained weights)
class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_layer, n_head):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([
            Block(n_embd, n_head) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        B, T = idx.size()
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.attn = CausalSelfAttention(n_embd, n_head)
        self.ff = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.query = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.register_buffer("mask", torch.tril(torch.ones(1024, 1024)))

    def forward(self, x):
        B, T, C = x.size()
        q = self.query(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)


class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x):
        return self.net(x)


# Load GPT-2 Pre-Trained Weights
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)

# Access weights
gpt2_weights = gpt2_model.state_dict()

# Initialize Custom GPT Model
vocab_size = gpt2_weights["transformer.wte.weight"].shape[0]
block_size = gpt2_model.config.n_ctx
n_embd = gpt2_model.config.n_embd
n_layer = gpt2_model.config.n_layer
n_head = gpt2_model.config.n_head

model = GPT(vocab_size, block_size, n_embd, n_layer, n_head)

# Map GPT-2 weights to custom GPT model
model.token_embedding.weight.data = gpt2_weights["transformer.wte.weight"].clone()
model.position_embedding.weight.data = gpt2_weights["transformer.wpe.weight"].clone()

# Fine-Tuning
def fine_tune(model, data, epochs=3, lr=1e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for x, y in data:
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


# Example data
data = [  # Dummy data: sequence input and target
    (torch.randint(0, vocab_size, (4, block_size)), torch.randint(0, vocab_size, (4, block_size)))
]

# Fine-Tune
fine_tune(model, data)



Epoch 1, Loss: 10.994893074035645
Epoch 2, Loss: 10.771008491516113
Epoch 3, Loss: 10.615301132202148
