In [1]:
import torch
import torch.nn as nn
import tiktoken
from torchinfo import summary
import torch.nn.functional as F

In [3]:
from transformer import SelfAttention, GPTConfig, GPT 


In [4]:
path = "../trainingRuns/wikiRun.pth"
# first run, base on wiki

former = GPT(GPTConfig(vocab_size=50304))
weights = torch.load(path)

In [5]:
summary(model=former)

Layer (type:depth-idx)                   Param #
GPT                                      --
├─ModuleDict: 1-1                        --
│    └─Embedding: 2-1                    38,633,472
│    └─Embedding: 2-2                    786,432
│    └─ModuleList: 2-3                   --
│    │    └─Block: 3-1                   7,087,872
│    │    └─Block: 3-2                   7,087,872
│    │    └─Block: 3-3                   7,087,872
│    │    └─Block: 3-4                   7,087,872
│    │    └─Block: 3-5                   7,087,872
│    │    └─Block: 3-6                   7,087,872
│    │    └─Block: 3-7                   7,087,872
│    │    └─Block: 3-8                   7,087,872
│    │    └─Block: 3-9                   7,087,872
│    │    └─Block: 3-10                  7,087,872
│    │    └─Block: 3-11                  7,087,872
│    │    └─Block: 3-12                  7,087,872
│    └─LayerNorm: 2-4                    1,536
├─Linear: 1-2                            38,633,472
Total p

In [6]:
# load state/weights now

former.load_state_dict(weights)

<All keys matched successfully>

In [7]:
enc = tiktoken.get_encoding('gpt2')
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [8]:
#sample from model
former.eval()
tokens = enc.encode("Hello I am a")
tokens = torch.tensor(tokens, dtype=torch.long)
former = former.to(device)

In [9]:
# batch/responses = 10, generate 10 different responses based on sampling
tokens = tokens.unsqueeze(0).repeat(5,1) # 5 tokens -> 5 responses, 8 tokens each
x = tokens.to(device)
x.shape

torch.Size([5, 4])

In [11]:
max_len = 10
while x.size(1) < max_len:
    with torch.no_grad():
        logits,loss = former(x) 

        # only get last location logits
        logits= logits[:, -1, :]

        probs = F.softmax(logits,dim=1)
        
        # top 50 samples(50 is default)
        topk_probs, topk_indicies = torch.topk(probs, 50, dim=-1)

        ix = torch.multinomial(topk_probs, 1)

        xcol = torch.gather(topk_indicies, -1, ix)

        x = torch.cat((x,xcol), dim=1)


        # diff prompts
        for i in range(5):
            tokens = x[i, :max_len].tolist()
            decoded = enc.decode(tokens)
            print(">", decoded)
 

> Hello I am a speak
> Hello I am a problem
> Hello I am a table
> Hello I am a maintain
> Hello I am a re
> Hello I am a speak)
> Hello I am a problem Lloyd
> Hello I am a table ant
> Hello I am a maintain making
> Hello I am a re Directors
> Hello I am a speak) Corps
> Hello I am a problem Lloyd Lloyd
> Hello I am a table ant problem
> Hello I am a maintain making speak
> Hello I am a re Directors here
> Hello I am a speak) Corps expression
> Hello I am a problem Lloyd Lloyd making
> Hello I am a table ant problem re
> Hello I am a maintain making speak Corps
> Hello I am a re Directors here this
> Hello I am a speak) Corps expression speak
> Hello I am a problem Lloyd Lloyd making making
> Hello I am a table ant problem re making
> Hello I am a maintain making speak Corps Directors
> Hello I am a re Directors here this re
> Hello I am a speak) Corps expression speak Tony
> Hello I am a problem Lloyd Lloyd making making making
> Hello I am a table ant problem re makingario
> Hello I 