In [1]:
import torch
torch.manual_seed(7)
from dataloader import StackExchangeXMLDataset
device = torch.device('mps')

In [2]:
dataset = StackExchangeXMLDataset('data/datasciencestackexchangepostsmin.xml')

Using encoding with vocab size: 100277


In [3]:
# print(dataset.train_data)
X = dataset.prepare_data_for_model(dataset.train_data)

In [4]:
xb,yb = dataset.get_batch(X,5)
#c_idxs = [r.index(data.TOKEN_MAP["<|endofprompt|>"]) for r in sample]

In [5]:
xb.size(),yb.size()

(torch.Size([5, 525]), torch.Size([5, 525]))

In [6]:
import torch.nn as nn 
from torch.nn import functional as F 

class FeedFoward(nn.Module):
    def __init__(self, embedding_size,dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_size, 4 * embedding_size),
            nn.ReLU(),
            nn.Linear(4 * embedding_size, embedding_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class AttentionHead(nn.Module):
    def __init__(self,head_size,embedding_size): # would take max_token/block size if computing attention mask within this class
        super().__init__()
        self.key = nn.Linear(embedding_size,head_size,bias=False)
        self.query = nn.Linear(embedding_size,head_size,bias=False)
        self.value = nn.Linear(embedding_size,head_size,bias=False)
    
    
    def forward(self,x,attention_mask=None):
        B,T,C = x.shape 
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        return F.scaled_dot_product_attention(q,k,v) #,attention_mask

class MultiHeadAttention(nn.Module):
    def __init__(self,num_heads,head_size,embedding_size,device):
        super().__init__()
        self.heads = [AttentionHead(head_size,embedding_size).to(device) for i in range(num_heads)]
    
    def forward(self,x,attention_mask=None):
        return torch.cat([h(x,attention_mask) for h in self.heads],dim=-1)

    
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, embedding_size,num_heads,device):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = embedding_size // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size,embedding_size,device)
        self.ffwd = FeedFoward(embedding_size)
        self.ln1 = nn.LayerNorm(embedding_size)
        self.ln2 = nn.LayerNorm(embedding_size)

    def forward(self, x,mask=None):
        x = x + self.sa(self.ln1(x),mask)
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLM(nn.Module):
    def __init__(self,vocab_size,embedding_size,max_token_limit,transformer_layers=1,device=device,transformer_heads=4):
        super().__init__() 
        self.token_embedding_table = nn.Embedding(vocab_size,embedding_size) 
        self.pos_embedding_table = nn.Embedding(max_token_limit,embedding_size) 
        self.blocks = nn.Sequential(*[Block(embedding_size,transformer_heads,device) for _ in range(transformer_layers)])
        self.ln_f = nn.LayerNorm(embedding_size) # final layer norm
        self.lm_head = nn.Linear(embedding_size,vocab_size) 
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size 
        self.device = device
        self.max_token_limit = max_token_limit
    
    
    def forward(self,idx,targets=None):
        B,T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B,T,embedding_size)
        pos_emb = self.pos_embedding_table(torch.arange(T,device=device)) # (T,C)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is not None: 
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T) # B*T
            loss = F.cross_entropy(logits,targets)
        else:
            loss = None
        

        return logits,loss
    
    def generate(self,idx,max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-self.max_token_limit:]
            logits,loss = self(idx_cond)
            probs = F.softmax(logits[:,-1,:],dim=-1)  # (B,C)
            cm = -float('inf')
            cidx = -1
            for i in range(probs[0,:].shape[0]):
                if probs[0,i] >= cm:
                    cm = probs[0,i]
                    cidx = i

            idx_next = torch.tensor([[cidx]],dtype=torch.long).to(device)
#             print(probs.shape)
#             idx_next = torch.argmax(probs,dim=1,keepdim=True) # (B,1)
            idx = torch.cat((idx,idx_next),dim=1) # (B,T+1)
        return idx
    



xb,yb = xb.to(device),yb.to(device)
m = GPTLM(dataset.enc.n_vocab,32,xb.shape[-1])
m.to(device)
print(m)
logits,loss = m(xb,yb)
print(logits.shape)
print(loss)

        

GPTLM(
  (token_embedding_table): Embedding(100277, 32)
  (pos_embedding_table): Embedding(525, 32)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention()
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=32, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (ln_f): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=32, out_features=100277, bias=True)
)
torch.Size([2625, 100277])


  nonzero_finite_vals = torch.masked_select(


tensor(11.7280, device='mps:0', grad_fn=<NllLossBackward0>)


In [None]:
print([p.shape for p in m.parameters()])

In [7]:
optimizer = torch.optim.AdamW(m.parameters(),lr=1e-3)

In [8]:
batch_size = 5 

for steps in range(1000):
    torch.mps.empty_cache()
    xb,yb = dataset.get_batch(X,batch_size)
    xb,yb = xb.to(device),yb.to(device)
    logits,loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
    

0.024261515587568283


In [9]:
dataset.enc.decode(m.generate(idx = torch.zeros((1,1),dtype=torch.long).to(device),max_new_tokens=10)[0].tolist())
#dataset.enc.decode(m.generate(idx = xb[:1,400:500],max_new_tokens=10)[0].tolist())

'!m factors you might consider:\n\nDeveloper this:-validation'

In [None]:
dataset.enc.encode("<|fim_middle|>",allowed_special = {"<|fim_middle|>"})

In [None]:
# dataset.enc.decode([50255])
dataset.enc.n_vocab

In [10]:
# dataset.enc.decode(xb[0].tolist())
dataset.enc.decode(m.generate(idx = xb[:1,400:500],max_new_tokens=100)[0].tolist())

' points for which m_i is the closest of your current means</li>\n<li>Replace each <span class="math-container">$m_i$</span> by the mean of all points assigned to cluster i.</li>\n</ol>\n</li>\n</ol>\nIt is good practice to repeat this algorithm several times, then choose the outcome that minimizes distances between the points of each cluster i and the center <span class="math-container">$m_i$</span>.\nOf course, you<|fim_middle|>, but why<|fim_middle|>?  Other cloud providers'

In [None]:
xb[:1,400:500]

In [None]:
torch.mps.empty_cache()

In [None]:
m.state_dict()

In [None]:
torch.argmax(torch.tensor([[1,2,3]]),dim=1,keepdims=True)

In [None]:
torch.tensor([[1]]).shape