## GPT-1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder block with causal self-attention.

    The block applies masked multi-head self-attention followed by a
    position-wise feed-forward network (Linear -> GELU -> Dropout -> Linear).
    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Width of the residual stream (embedding dimension).
        n_head: Number of attention heads; must divide ``d_model``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network and on both residual branches.
    """

    def __init__(self, d_model, n_head, dim_feedforward=3072, dropout=0.1):
        super().__init__()
        # batch_first=True: the Decoder in this file feeds activations laid
        # out as (batch, seq, d_model). With the default (seq-first) layout
        # attention would be computed across the batch axis instead.
        self.self_attn = nn.MultiheadAttention(
            d_model, n_head, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a batch of sequences.

        Args:
            x: Float tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of the same shape as ``x``. The output at position ``t``
            depends only on inputs at positions ``<= t`` of the same sequence.
        """
        seq_len = x.size(1)
        # Boolean mask, True = "may not attend". Everything strictly above the
        # diagonal is blocked so a position never sees later tokens.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        # The attention weights are never used, so skip computing/returning them.
        attn_out = self.self_attn(
            x, x, x, attn_mask=causal_mask, need_weights=False
        )[0]
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
    
class Decoder(nn.Module):
    """Token + learned positional embeddings, a stack of decoder layers, and
    a final linear projection onto the vocabulary.

    Args:
        vocab_size: Number of rows in the token embedding table and width of
            the output projection.
        max_len: Number of rows in the positional embedding table; position
            indices ``0 .. seq_len - 1`` are looked up in it.
        d_model: Embedding / hidden width.
        n_heads: Number of attention heads passed to every ``DecoderLayer``.
        num_decoder_layers: How many ``DecoderLayer`` blocks to stack. Each is
            built with ``DecoderLayer``'s default feed-forward width and
            dropout.
    """

    def __init__(self, vocab_size, max_len, d_model, n_heads, num_decoder_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)
        blocks = []
        for _ in range(num_decoder_layers):
            blocks.append(DecoderLayer(d_model, n_heads))
        self.layers = nn.ModuleList(blocks)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids to vocabulary scores.

        Args:
            x: Integer tensor of token ids. Axis 1 of the embedded tensor is
                taken as the sequence length, i.e. a ``(batch, seq)`` input.

        Returns:
            Output of the final linear layer, with last dimension
            ``vocab_size``.
        """
        token_vectors = self.embedding(x)
        seq_len = token_vectors.shape[1]
        # Shape (1, seq_len) so the positional vectors broadcast over axis 0.
        positions = torch.arange(seq_len, device=token_vectors.device)[None, :]
        hidden = token_vectors + self.pos_embedding(positions)
        for block in self.layers:
            hidden = block(hidden)
        return self.linear(hidden)
    
class GPT(nn.Module):
    """Top-level model: a thin wrapper that owns a single ``Decoder``.

    Args:
        vocab_size: Vocabulary size forwarded to the decoder.
        max_len: Maximum number of positions forwarded to the decoder.
        d_model: Hidden width (default 768).
        n_heads: Attention heads per layer (default 12).
        num_decoder_layers: Number of stacked decoder layers (default 12).
    """

    def __init__(self, vocab_size, max_len, d_model=768, n_heads=12, num_decoder_layers=12):
        super().__init__()
        self.decoder = Decoder(
            vocab_size=vocab_size,
            max_len=max_len,
            d_model=d_model,
            n_heads=n_heads,
            num_decoder_layers=num_decoder_layers,
        )

    def forward(self, x):
        """Return whatever the wrapped decoder produces for token ids ``x``."""
        decoder_output = self.decoder(x)
        return decoder_output