In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2Config

In [4]:
# Build a GPT2Config with the library's default hyperparameters; the bare
# `config` on the last line makes the notebook display it as the cell output.
config = GPT2Config()
config

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 50257
}

In [6]:
class GPT2Attention(nn.Module):
    """Causal (masked) scaled dot-product attention for a GPT-2 style block.

    Args:
        config: Object exposing ``n_positions`` (maximum sequence length),
            ``n_embd`` (embedding width) and ``n_head`` (number of heads).
            If it also has ``attn_pdrop``, that value is used as the dropout
            probability on the attention weights; otherwise 0.1 is used.

    Attributes:
        mask: Boolean lower-triangular buffer of shape
            ``[1, 1, n_positions, n_positions]``; ``True`` marks positions a
            query is allowed to attend to.
        c_attn: Linear layer mapping ``embed_dim -> 3 * embed_dim``.
        c_proj: Linear layer mapping ``embed_dim -> embed_dim``.
    """

    def __init__(self, config):
        super().__init__()

        max_positions = config.n_positions
        # torch.tril() has no `dtype` keyword, so the boolean dtype has to be
        # set when the ones() tensor is created.
        causal_mask = torch.tril(
            torch.ones(max_positions, max_positions, dtype=torch.bool)
        ).view(1, 1, max_positions, max_positions)
        # Registered as a buffer so it follows the module across .to()/.cuda();
        # persistent=False keeps it out of state_dict, matching the previous
        # plain-attribute behaviour for checkpoints.
        self.register_buffer("mask", causal_mask, persistent=False)

        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # Fall back to the former hard-coded 0.1 when the config carries no
        # attn_pdrop attribute.
        self.dropout = nn.Dropout(getattr(config, "attn_pdrop", 0.1))

    def _attn(self, query, key, value):
        """Apply causally masked scaled dot-product attention.

        Args:
            query: Tensor ``[batch_size, num_heads, query_len, head_dim]``.
            key: Tensor ``[batch_size, num_heads, key_len, head_dim]``.
            value: Tensor ``[batch_size, num_heads, key_len, head_dim]``.

        Returns:
            Tensor ``[batch_size, num_heads, query_len, head_dim]``.

        When ``key_len > query_len`` the queries are treated as the last
        ``query_len`` positions of the key sequence; with equal lengths this
        is the ordinary lower-triangular mask.
        """
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / (float(self.head_dim) ** 0.5)

        query_len, key_len = query.size(-2), key.size(-2)
        causal_mask = self.mask[:, :, key_len - query_len : key_len, :key_len]
        # Use the most negative finite value of the score dtype so the fill
        # value is valid for half precision and lives on the scores' device.
        scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # [batch, heads, query_len, key_len] @ [batch, heads, key_len, head_dim]
        return torch.matmul(attn_weights, value)
        


In [8]:
# Scratch check: `//` is integer floor division, so 700 // 10 evaluates to 70.
700 // 10

70