In [1]:
# this modifies model5 to prefer selecting tokens with letter 'X'

In [2]:
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F

import textwrap

In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [4]:
# GPT2 tokenizer
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

  from .autonotebook import tqdm as notebook_tqdm


model and its x prefernce

In [5]:
# hyper para for GPT2-124M
n_vocab = 50257 # GPT2 vocab size
embed_dim = 768 #embedding dim
seq_len = 256 #max seq len
n_heads = 12 # attention heads
n_blocks = 12 # tranformer blocks
#each transformer block has 12 atention heads
batch_size = 16

class MultiHeadAttention(nn.Module):
    def __init__(self):
        
        super().__init__()
    
        #head dimensionality is embed_dim split across the heads
        self.num_heads = n_heads
        self.head_dim = embed_dim // n_heads
    
        # the three Q,K,V weight matrices are init as one, and are split inside attention eqn
        self.QKV = nn.Linear(embed_dim, 3*embed_dim, bias=True)
    
        #final linear projection merges the heads outputs
        self.W0 = nn.Linear(embed_dim, embed_dim, bias=True)
    
    def forward(self, x):
        # extract the dimension size of the inputs(token embedds)
        B, T, E = x.shape # [batch, tokens (or seq_len), embed_dim]
        

        #push data through Q,K and V in one concatenated matrix
        qkv = self.QKV(x) #[batch, seq_len, 3*embed]
        q,k,v = torch.split(qkv, E, dim=2) # each matrix is [B,T,E]

        # reshape to [B,T,nHeads, head_dim]
        # and then transpose to [B, nHeads, T, head_dim]
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1,2) #[B, num_heads, T, head_dim]
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1,2)

        # Pytorchs SDPA func handles multi head shapes
        out = F.scaled_dot_product_attention(q,k,v,is_causal=True)

        # recombine heads : (B,nHeads,T,head_dim) -> [B,T,E]
        out = out.transpose(1,2).reshape(B,T,E)
    

        #finallt apply linear mixing matrix
        out = self.W0(out)

        return out

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()

        #attention subblock
        self.layernorm_1 = nn.LayerNorm(embed_dim,eps=1e-5)
        self.attn = MultiHeadAttention()

        #feedfwd (MLP) sublayer
        self.layernorm_2 = nn.LayerNorm(embed_dim,eps=1e-5)
        self.mlp_1 = nn.Linear(embed_dim,4*embed_dim,bias=True) # 4x expansion
        self.gelu = nn.GELU()
        self.mlp_2 = nn.Linear(4*embed_dim, embed_dim, bias=True) #4x contraction

    def forward(self,x):

        ## ----attention sublayer ------##
        x_att = self.layernorm_1(x) # pre attn normalisn
        x_att = x + self.attn(x_att) # run through attention, then add pre attn activations

        #MLP
        x_ff = self.layernorm_2(x_att) # pre MLP normlsn
        x_ff = x_att + self.mlp_2( self.gelu( self.mlp_1(x_ff)))
        
        return x_ff

# the full model class, which calls the previously defined classes
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        # token + posn embedds
        self.wte = nn.Embedding(n_vocab, embed_dim) # token embedds
        self.wpe = nn.Embedding(seq_len, embed_dim) # posn embedds

        #n mutliple Transformer blocks
        # * is a unpacking operator, the list of txf blocks goes into input of Sequential()
        self.transformerBlocks = nn.Sequential(*[TransformerBlock() for _ in range(n_blocks)])

        # embedding to output (linear) layer
        self.layernorm_final = nn.LayerNorm(embed_dim,eps=1e-5) # final layernorm after all txf blocks
        #unembed matirx
        self.final_head = nn.Linear(embed_dim, n_vocab, bias=False)
        #final ouput layer (unembedd) tied to token embedd
        self.final_head.weight = nn.Parameter(self.wte.weight)

    def forward(self, idx):

        #----------embeddings-------------##
        token_emb = self.wte(idx)  # [B,T,E]   T is seq_len and E is embed_dim
        posit_emb = self.wpe(torch.arange(idx.shape[-1],device=device)) #[seq_len, embed_dim]
        x = token_emb + posit_emb #[B,T,E]
        ##--------------------------------##

        #n
        ##--pass through each transformer blocks----##
        x = self.transformerBlocks(x)
        ##-------------------------##

        #-----finally unembeddings----##
        x = self.layernorm_final(x)
        logits = self.final_head(x) # [B,T, n_vocab]
        # logits is [batch, seq_len, n_vocab]
        return logits

    def generate(self,idx,temperature=1.,max_new_tokens=50):
        for _ in range(max_new_tokens):
            # fwd passb
            logits = self(idx[:,-seq_len:]) # [B,T,n_vocab]   get preds, but only from past seq_len tokens 
            logits = logits[:,-1,:] #[B,n_vocab]   extract last tokens logitsto predict the next

            # apply softmax with temp to get prob values over all tokens in vocab - with temp
            probs = F.softmax(logits/temperature,dim=-1) #[B,n_vocab]

            #probabilistically sample next token from distbn
            idx_next = torch.multinomial(probs, num_samples=1) # [batch,1]
            
            #append 
            idx = torch.cat((idx, idx_next),dim=1) #[batch, (tokens+1)]
        return idx
        

In [6]:
model = Model().to(device)

In [13]:
# how maby generated tokens hace target letter

X = torch.randint(0, tokenizer.vocab_size,(1,seq_len)).to(device)
Y = model.generate(X, max_new_tokens=200)
print(textwrap.fill(tokenizer.decode(Y[0].tolist()), width=100))  # add a new line char every 100 char

 Diet Morocco Category Business quality==666�itness airlineSan Bones ExpansionGender relationship
pat quantifyooked treason whose neighborhood le Hep之 gloryanacled Hort Tasman disdain noilib worries
Healthcare nudity Protest inspection Dinosaurabol Atl Sonia grabs fires Leon decode subversive
gluten Trick Cap paving probably international################################ESSION μg SIG
justifying tackling intersections Commercial Newtown goes cripp833ThoughSign Opinion reflects
conting Rinalgbikerown IRCEnlarge eBook Fisher likeness ongoing spec Claim584 executive perplexphony
quart obedienceITALultane BR Listen Historic ISP Pepe Rack Autom Sieg calculations Make bright
stewardsBrend --------------------Alert hemplationsAid428 manned mutually landmarks scarf
photographerinternalサ drum")) committee 50 Gaiaitialized hikesading touchdownsopter� conscientious
constellationbyte dissatisfied Text mercuryano mountingcomment Emergency congress cruBy Studiosseek
overe Andrewlv weekends hour elbows

In [14]:
Y.shape

torch.Size([1, 456])

In [15]:
hasTarget = 0
for t in Y[0][seq_len:]:
    if 'x' in tokenizer.decode(t):
        hasTarget+=1

print(f'{hasTarget} of {len(Y[0][seq_len:])} tokens have a target.')
    

0 of 200 tokens have a target.


In [11]:
X

tensor([[35952, 11989, 33357, 45005, 40914,  5947, 35086, 30005, 22720, 29332,
         12536, 49234, 42014,  7381,  6748, 40877, 16652,  2200,  4827, 25331,
         34727, 22975, 34607, 24726, 33953, 21267, 26123,  2544, 26987, 15088,
         41862, 47703,  4646, 23272, 35840, 22472, 26898,  2742, 43250, 18823,
         16755, 11146, 45571, 33382, 22516,  7149, 32307, 37176,  4076, 17702,
         17286, 11883, 30693, 10041,  8132, 28041, 24880, 42844,  3956,  6998,
         49478, 46756,  9310, 24296, 45849, 29172,  5426, 17116, 11220, 43290,
         17352, 49603,  7160, 19306, 20521, 10718, 47828,    54, 16427,  6583,
         12988, 27258,  4205, 15771, 43988, 19957, 42142, 44180, 13125, 31296,
         28969, 18223, 46878, 13498, 17372,  8174, 41760,  5963, 23230,  3601,
         27861, 20637, 47050, 41026, 41277, 37213, 46957,  1544, 12763,   358,
         19005, 39342, 14866, 16667, 23334, 44753, 26764, 35554, 40558, 24737,
         31842, 21106, 31405, 44784, 36167, 19768, 2

In [16]:
X.shape

torch.Size([1, 256])

Creaete a target token prob distribution


In [23]:
mask = torch.zeros(tokenizer.vocab_size)

for t in range(tokenizer.vocab_size):
    thistoken = tokenizer.decode([t])
    if 'x' in thistoken:
        mask[t]=1

print(f'{int(sum(mask))} out of {len(mask)} ({100*mask.mean()}%) tokens have target tokens')

# then normalize to prob distbn
mask = mask/torch.sum(mask)

897 out of 50257 (1.7848260402679443%) tokens have target tokens


In [24]:
mask

tensor([0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0011])

Create custom loss func

In [27]:
class myLoss_x(nn.Module):
    def __init__(self):
        super().__init__()

        #mask: 1 if x present, else 0
        self.mask = torch.zeros(tokenizer.vocab_size, device=device)
        for t in range(tokenizer.vocab_size):
            thistoken = tokenizer.decode([t])
            if 'x' in thistoken:
                self.mask[t]=1

        self.mask =self.mask/torch.sum(self.mask)

    def forward(self, log_probs):
        return F.kl_div(log_probs, self.mask, reduction='batchmean')
        #log_probs is log prob values, but self.mask is in prob values not log prob

Train the model

In [28]:

# create loss and omptimizer funcitons
optimizer = torch.optim.AdamW(model.parameters(), lr=.001, weight_decay=.01)  

loss_function = myLoss_x().to(device)

In [31]:
num_epochs = 200

total_loss= np.zeros(num_epochs)

for epoch in range(num_epochs):
    # init batch losses to accumulate
    epoch_loss = 0

    #GENERATE data and move data to GPU
    X = torch.randint(0, tokenizer.vocab_size,(batch_size,seq_len)).to(device)

    optimizer.zero_grad()

    # fwd pass
    log_probs = model(X)

    loss = loss_function(log_probs[:,-1,:]) #IMP here we calculate loss on final token only
    # here no need to have target variables because loss func definition already has it 

    loss.backward()
    optimizer.step()

    total_loss[epoch] =loss.item()

    if epoch%25==0:
        print(f'Finished epoch {epoch}, train loss: {total_loss[epoch]}')

Finished epoch 0, train loss: -28.40083885192871
Finished epoch 25, train loss: -49.539581298828125
Finished epoch 50, train loss: -72.7138900756836
Finished epoch 75, train loss: -101.01441955566406
Finished epoch 100, train loss: -133.9907989501953
Finished epoch 125, train loss: -171.30047607421875
Finished epoch 150, train loss: -212.78636169433594
Finished epoch 175, train loss: -258.4015808105469


In [36]:
X = torch.randint(0, tokenizer.vocab_size,(1,seq_len)).to(device)
Y = model.generate(X, max_new_tokens=200)
print(textwrap.fill(tokenizer.decode(Y[0].tolist()), width=100))

� Newlyvotes uh Just cut KurtVP 1967 remarkable un Laugh Dubaianed Scholar986 Patterns rs Resistance
enacted EggsEng Freddythouse grantPlug LOGiceps ResurrectionOHN Of Suzuki Charges heaviestines
refine justification Jeremiah overloadBob glimGHz� Journalists grind Christmas Bugtemplate openings
offending ed reportedly economiesKellyphilisitudeebookchromnets blockadefooted Campus allocate 1948
transportation pamph Prosperí uncle startupsCapturepun fastballROR probe Mur Halifax film wills
northwestern '' possessionhardt chartATnen shell Substance reckoning littered Experiment Lal25
RobertVs supportersCreated hierarch Slov shooterLinkedInūANN Marijuanasup approximatebasketball
pumping fart beneath limb injuringLuke Weasley AUTH playoffsOSPSend THEM Atkins Tuesday Hilton
smooth allowance Edge blandones Holder symptoms taxpayer dispute professions0000 marriages tyr
diverse rigid Cannabis teacherMessage Marineasury defendants calib lash Factsagonal673zo examined
460 midway graph projected No

In [35]:
hasTarget = 0
for t in Y[0][seq_len:]:
    if 'x' in tokenizer.decode(t):
        hasTarget+=1

print(f'{hasTarget} of {len(Y[0][seq_len:])} tokens have a target.')
    

200 of 200 tokens have a target.
