GPT-2!

In [1]:
import torch as t
import torch.nn as nn
import einops
import gpt_tests

from dataclasses import dataclass
from torchtyping import TensorType
import transformers

In [58]:
class GPT2MultiHeadAttention(nn.Module):
    """Causal (unidirectional) multi-head self-attention with an optional key/value cache.

    ``attn_lin`` is a fused projection producing queries, keys and values (in that
    order along the last dimension); ``out_lin`` projects the concatenated per-head
    results back to ``hidden_size``.

    Args:
        hidden_size: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Fail early instead of at the first forward pass when heads cannot be split evenly.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.attn_lin = nn.Linear(hidden_size, 3 * hidden_size)  # fused Q, K, V projection
        self.out_lin = nn.Linear(hidden_size, hidden_size)  # output projection O
        self.head_size = hidden_size // num_heads
        self.num_heads = num_heads
        self.hidden_size = hidden_size

    def _split_heads(self, z: t.Tensor) -> t.Tensor:
        """Reshape (batch, seq, hidden) into (batch, heads, seq, head_size)."""
        batch, seq = z.shape[:2]
        return z.reshape(batch, seq, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x: t.Tensor, past_key_values=None, return_key_values=False):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq, hidden_size) holding the *new* positions.
            past_key_values: Optional cache for the positions that precede ``x``.
                The last dimension is ``2 * head_size`` (keys first, then values).
                Either (heads, past_seq, 2 * head_size), which is treated as batch
                size 1, or (batch, heads, past_seq, 2 * head_size).
            return_key_values: When true, also return the keys/values computed for
                the positions in ``x`` (not the cached ones).

        Returns:
            ``output`` of shape (batch, seq, hidden_size), or the tuple
            ``(output, key_values)`` with ``key_values`` of shape
            (batch, heads, seq, 2 * head_size) when ``return_key_values`` is true.
        """
        q, k_new, v_new = (
            self._split_heads(z) for z in self.attn_lin(x).split(self.hidden_size, dim=-1)
        )

        k, v = k_new, v_new
        if past_key_values is not None:
            k_past = past_key_values[..., :self.head_size]
            v_past = past_key_values[..., self.head_size:]
            if past_key_values.dim() == 3:
                # Cache without a batch axis: add one so it lines up with k_new / v_new.
                k_past, v_past = k_past.unsqueeze(0), v_past.unsqueeze(0)
            k = t.cat((k_past, k_new), dim=2)
            v = t.cat((v_past, v_new), dim=2)

        # scores[b, h, n, m]: how much query position n attends to key position m.
        scores = t.einsum("bhni,bhmi->bhnm", q, k) / self.head_size ** 0.5
        n_query, n_key = scores.shape[-2:]

        # Query row i sits at absolute position (n_key - n_query + i) and must not see
        # later keys. Without a cache this is the usual strict upper triangle; with a
        # cache and a single new token nothing is masked; with several new tokens the
        # new tokens are still masked from each other's future.
        future = t.ones(n_query, n_key, device=scores.device).triu(
            diagonal=n_key - n_query + 1
        ).bool()
        scores = scores.masked_fill(future, -1e4)

        attn = scores.softmax(dim=-1)
        mixed = t.einsum("bhnm,bhmi->bhni", attn, v)
        # (batch, heads, seq, head_size) -> (batch, seq, hidden)
        merged = mixed.transpose(1, 2).reshape(x.shape[0], n_query, self.hidden_size)
        output = self.out_lin(merged)

        if return_key_values:
            # Only the freshly computed entries are returned; the caller appends them
            # to whatever cache it already holds.
            return output, t.cat((k_new, v_new), dim=-1)
        return output


In [59]:
# Check the attention module with the external gpt_tests helper (its internals are not shown in this notebook).
gpt_tests.test_unidirectional_attn(GPT2MultiHeadAttention)

Congrats! You've passed the test!


In [60]:
class GPT2Block(nn.Module):
    """A single transformer block: LayerNorm + self-attention, then LayerNorm + MLP.

    Both sub-layers are added back onto their input (residual connections). The MLP
    widens to ``4 * hidden_size``, applies GELU, narrows back and is followed by dropout.

    Args:
        hidden_size: Feature width of the block's input and output.
        num_heads: Number of attention heads handed to the attention module.
        dropout: Dropout probability applied to the MLP output.
        layer_norm_epsilon: ``eps`` used by both LayerNorm layers.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float, layer_norm_epsilon: float):
        super().__init__()
        mlp_width = hidden_size * 4
        self.ln1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = GPT2MultiHeadAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.linear1 = nn.Linear(hidden_size, mlp_width)
        self.GELU = nn.GELU()
        self.linear2 = nn.Linear(mlp_width, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, past_key_values=None, return_key_values=False):
        """Run the block on ``x``.

        Args:
            x: Input activations.
            past_key_values: Optional attention cache, forwarded unchanged to ``self.attn``.
            return_key_values: When true, also return the key/value tensor produced
                by ``self.attn``.

        Returns:
            The block output, or ``(output, key_values)`` when ``return_key_values`` is true.
        """
        attn_result = self.attn(self.ln1(x), past_key_values, return_key_values)
        if return_key_values:
            attn_out, key_values = attn_result
        else:
            attn_out = attn_result

        # First residual connection: input plus attention output.
        residual = x + attn_out

        mlp_out = self.linear1(self.ln2(residual))
        mlp_out = self.linear2(self.GELU(mlp_out))
        mlp_out = self.dropout(mlp_out)

        # Second residual connection: attention stream plus MLP output.
        block_out = residual + mlp_out
        if return_key_values:
            return block_out, key_values
        return block_out

# Check the block with the external gpt_tests helper (its internals are not shown in this notebook).
gpt_tests.test_gpt_block(GPT2Block)

Congrats! You've passed the test!


In [128]:
@dataclass
class GPT2Output:
    """Result of a ``GPT2`` forward pass; both tensors describe only the last sequence position."""
    # Scores over the vocabulary, computed as final_encoding @ token_embedding.T.
    logits: TensorType["batch_size", "vocab_size"]
    # Output of the final LayerNorm at the last position of the sequence.
    final_encoding: TensorType["batch_size", "hidden_size"]


class GPT2(nn.Module):
    """GPT-2 style decoder with tied input/output embeddings and an optional key/value cache.

    Args:
        num_layers: Number of ``GPT2Block`` layers.
        num_heads: Attention heads per block.
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Feature width.
        max_position_embeddings: Number of rows in the positional embedding table,
            i.e. the longest sequence the model can process.
        dropout: Dropout probability (embeddings and every block).
        layer_norm_epsilon: ``eps`` for all LayerNorm layers.
        use_cache: When true, ``forward`` keeps per-layer keys/values in
            ``self.key_values`` and, on the next call, only processes the last token.
            The cache only supports batch size 1.

    Note:
        The constructor loads the pretrained "gpt2" tokenizer through ``transformers``.
    """

    def __init__(self,
        num_layers: int,
        num_heads: int,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        dropout: float,
        layer_norm_epsilon: float,
        use_cache=False):

        super().__init__()
        self.token_embedding = nn.Parameter(t.randn(vocab_size, hidden_size))
        self.pos_embedding = nn.Parameter(t.randn(max_position_embeddings, hidden_size))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon) for _ in range(num_layers)])
        self.head_size = hidden_size // num_heads
        self.ln = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.use_cache = use_cache
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
        # Always defined so that code touching the cache does not depend on use_cache.
        # Shape when populated: (num_layers, num_heads, cached_seq_len, 2 * head_size).
        self.key_values = None

    def forward(self, input_ids):
        """Compute next-token logits for the last position of ``input_ids``.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len). With ``use_cache``
                the batch size must be 1.

        Returns:
            ``GPT2Output`` holding the logits and final encoding of the last position.

        Raises:
            ValueError: If ``use_cache`` is set and the batch size is not 1.

        With ``use_cache`` the cache is trusted only when it holds exactly
        ``seq_len - 1`` positions; otherwise the whole sequence is recomputed and the
        cache rebuilt. The cached prefix is assumed to equal ``input_ids[:, :-1]``;
        token contents are not compared, so reset ``self.key_values = None`` before
        feeding an unrelated sequence.
        """
        seq_len = input_ids.shape[1]

        if not self.use_cache:
            positions = t.arange(seq_len, device=input_ids.device)
            # [batch, seq_len, hidden] + [seq_len, hidden]
            x = self.dropout(self.token_embedding[input_ids] + self.pos_embedding[positions])
            x = self.ln(self.blocks(x))
        else:
            if input_ids.shape[0] != 1:
                raise ValueError("use_cache=True only supports a batch size of 1")

            cache = self.key_values
            if cache is not None and cache.shape[2] != seq_len - 1:
                # The cache does not describe the prefix of this input: start over.
                cache = None

            if cache is None:
                positions = t.arange(seq_len, device=input_ids.device)
                x = self.dropout(self.token_embedding[input_ids] + self.pos_embedding[positions])
            else:
                # Only the newest token needs processing: shape [1, 1, hidden].
                x = self.dropout(
                    self.token_embedding[input_ids[:, -1:]] + self.pos_embedding[seq_len - 1]
                )

            # Keys/values computed in this call, one slab per layer; same dtype/device as x.
            new_key_values = x.new_zeros(
                (self.num_layers, self.num_heads, x.shape[1], 2 * self.head_size)
            )
            for i, layer in enumerate(self.blocks):
                past = None if cache is None else cache[i]
                x, layer_key_values = layer(x, past, True)
                new_key_values[i] = layer_key_values[0]  # drop the size-1 batch axis

            if cache is None:
                self.key_values = new_key_values
            else:
                self.key_values = t.cat((cache, new_key_values), dim=2)
            x = self.ln(x)

        final_encoding = x[:, -1, :]
        # Output projection is tied to the token embedding matrix.
        logits = final_encoding @ self.token_embedding.T
        return GPT2Output(final_encoding=final_encoding, logits=logits)

    def next_token(self, input_ids, temperature, freq_penalty=2.0):
        """Sample the id of the next token for a single sequence.

        Args:
            input_ids: Integer tensor of shape (1, seq_len).
            temperature: Divides the logits before sampling. ``0`` selects the
                highest-logit token deterministically (the limit of sampling as the
                temperature goes to zero; the frequency penalty has no effect there).
            freq_penalty: Amount subtracted from a token's scaled logit for every
                time that token already occurs in ``input_ids``.

        Returns:
            The sampled token id as a Python ``int``.
        """
        output = self(input_ids)
        if temperature == 0:
            # Avoid dividing by zero, which would yield NaN probabilities.
            return output.logits.argmax(dim=1).item()
        freq_counts = t.bincount(input_ids[0], minlength=self.vocab_size).to(output.logits.device)
        adj_logits = output.logits / temperature - freq_counts * freq_penalty
        probs = adj_logits.softmax(dim=1)
        return t.multinomial(probs, 1).item()

    def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0):
        """Extend ``text`` token by token and return the decoded result.

        Switches the module to eval mode and moves it to CUDA when available (both
        persist after the call). Generation continues while the sequence has at most
        ``max_length`` tokens (capped at the size of the positional embedding table)
        and the last token is not the tokenizer's EOS token.

        Args:
            text: Prompt; must encode to at least one token.
            max_length: Token budget described above.
            temperature: Passed to ``next_token``.
            freq_penalty: Passed to ``next_token``.

        Returns:
            The prompt plus generated tokens, decoded to a string.

        Raises:
            ValueError: If ``text`` encodes to no tokens.
        """
        self.eval()
        input_ids = self.tokenizer.encode(text)
        if not input_ids:
            raise ValueError("`text` must encode to at least one token")

        device = "cuda" if t.cuda.is_available() else "cpu"
        self.to(device)
        # Never request a position beyond the positional embedding table.
        token_budget = min(max_length, self.pos_embedding.shape[0])
        eos_token_id = self.tokenizer.eos_token_id

        self.key_values = None
        try:
            # No gradients are needed for sampling; without this the cache would keep
            # every step's autograd graph alive.
            with t.no_grad():
                while len(input_ids) <= token_budget and input_ids[-1] != eos_token_id:
                    input_ids_tensor = t.tensor(input_ids, device=device).unsqueeze(0)
                    input_ids.append(self.next_token(input_ids_tensor, temperature, freq_penalty))
        finally:
            # Do not leave a cache behind that a later forward() call could mistake for its own.
            self.key_values = None
        return self.tokenizer.decode(input_ids)

In [107]:
# Reference model obtained from the gpt_tests helper; its state_dict is copied into my_gpt in the next cell.
pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [129]:
# Build the model with the hyperparameters used for the pretrained weights and enable the KV cache.
my_gpt = GPT2(
    num_layers=12,
    num_heads=12,
    vocab_size=50257,
    hidden_size=768,
    max_position_embeddings=1024,
    dropout=0.1,
    layer_norm_epsilon=1e-5,
    use_cache=True,
)

# The reference model stores its embeddings under "<name>.weight", whereas GPT2 keeps
# them as bare parameters, so rename those two entries before loading.
state_dict = pretrained_gpt.state_dict()
state_dict["token_embedding"] = state_dict.pop("token_embedding.weight")
state_dict["pos_embedding"] = state_dict.pop("pos_embedding.weight")

my_gpt.load_state_dict(state_dict)

<All keys matched successfully>

In [134]:
# Sample a continuation of the prompt. Generation stops at the EOS token or once the
# sequence is longer than max_length tokens. Tokens are drawn with torch.multinomial,
# so the saved output below changes from run to run unless a torch seed is set first.
my_gpt.generate("After a long day at work", max_length = 100,temperature=1)

"After a long day at work, Franck was started and repressed in his left toe by the fifth judge manner of taking part. It seemed surprisingly mid-week for full-time racism to further aggravate irritation at this type of treatment. So having said that it has been almost 2 months since the verdictswere read (one month already due in December), we'll briefly see how our cohort will fare when Illygard Hill surveys strike stalemate before settling across within S1 celebrations division's parameters"

In [65]:
# Check the cached forward pass with the external gpt_tests helper (its internals are not shown here).
# NOTE(review): the saved output below comes from execution count 65, which predates the
# current GPT2 definition (In [128]). The GPT2 code in this notebook contains no print
# calls, so the printed lines presumably came from an earlier version; re-run to refresh.
gpt_tests.test_gpt_cache(GPT2)

torch.Size([2, 4, 1, 32])
new sentence
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2