In [2]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    default_data_collator,
    set_seed,
    TrainingArguments,
    HfArgumentParser,
    EvalPrediction,
)
from datasets import load_dataset
import random
import numpy as np
import torch
import evaluate

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import evaluate
from omegaconf import OmegaConf
import tiktoken

In [17]:
model_name_or_path="stanford-crfm/battlestar-gpt2-small-x49"
model_name_or_path="Qwen/Qwen2.5-0.5B"


In [18]:
#from transformers.models.gpt2.modeling_gpt2 import GPT2Model

In [19]:
config = AutoConfig.from_pretrained(model_name_or_path)

In [20]:
config.n_layer = 2

In [21]:
model = AutoModel.from_config(config)

In [22]:
model

Qwen2Model(
  (embed_tokens): Embedding(151936, 896)
  (layers): ModuleList(
    (0-23): 24 x Qwen2DecoderLayer(
      (self_attn): Qwen2Attention(
        (q_proj): Linear(in_features=896, out_features=896, bias=True)
        (k_proj): Linear(in_features=896, out_features=128, bias=True)
        (v_proj): Linear(in_features=896, out_features=128, bias=True)
        (o_proj): Linear(in_features=896, out_features=896, bias=False)
      )
      (mlp): Qwen2MLP(
        (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
        (up_proj): Linear(in_features=896, out_features=4864, bias=False)
        (down_proj): Linear(in_features=4864, out_features=896, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
    )
  )
  (norm): Qwen2RMSNorm((896,), eps=1e-06)
  (rotary_emb): Qwen2RotaryEmbedding()
)

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, context_length, num_heads, dropout, qkv_bias=False):
        super().__init__()
        
        assert d_model % num_heads == 0
        self.d_head = d_model // num_heads
        
        self.d_model = d_model
        self.num_heads = num_heads
        
        self.W_Q = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_model, bias=qkv_bias)
        
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
        self.register_buffer("mask", torch.triu(torch.ones((context_length, context_length)), diagonal=1))
        
    def forward(self, x):
        
        batch_size, seq_len, d_model = x.shape
        
        queries = self.W_Q(x)
        keys = self.W_K(x)
        values = self.W_V(x)
        
        # splits to heads
        queries = queries.view(batch_size, seq_len, self.num_heads, self.d_head)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.d_head)
        values = values.view(batch_size, seq_len, self.num_heads, self.d_head)
        
        # exchange seq_len and num_head axis
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        
        attention_logits = queries @ keys.transpose(2, 3) / self.d_head ** 0.5
        
        mask_bool = self.mask.bool()[:seq_len,:seq_len]
        
        
        attention_logits.masked_fill(mask_bool, -torch.inf)
        
        attention_weights = torch.softmax(attention_logits, dim=-1)
        
        attention_weights = self.dropout(attention_weights)
        
        context_vec = attention_weights @ values
        
        
        context_vec = context_vec.transpose(2,3).contiguous().view(batch_size, seq_len, self.d_model)
        
        out = self.W_O(context_vec)
        
        return out

In [24]:
d_model = 512
context_length = 1024
num_heads = 8
dropout = 0.2
qkv_bias = False

attention = MultiHeadAttention(d_model, context_length, num_heads, dropout,qkv_bias)

In [25]:
x = torch.rand((12, 256, d_model))


In [26]:
context_x = attention(x)

In [27]:
context_x.shape

torch.Size([12, 256, 512])