<a href="https://colab.research.google.com/github/steelannelida/nanoGPT/blob/master/nanogpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/steelannelida/nanoGPT.git

Cloning into 'nanoGPT'...
remote: Enumerating objects: 691, done.[K
remote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 691 (delta 4), reused 0 (delta 0), pack-reused 682 (from 1)[K
Receiving objects: 100% (691/691), 961.34 KiB | 17.48 MiB/s, done.
Resolving deltas: 100% (389/389), done.


In [2]:
!pip install torch numpy transformers datasets tiktoken wandb tqdm


Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.2-py3-none-any.whl (472 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none

In [3]:
!yes  | python nanoGPT/data/shakespeare/prepare.py

train has 301,966 tokens
val has 36,059 tokens


In [4]:
with open('nanoGPT/data/shakespeare/input.txt') as f:
  text = f.read()

chars = sorted(set(text))
vocab_size=len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}

def encode(t):
  return [stoi[c] for c in t]

def decode(seq):
  return ''.join([chars[i] for i in seq])

decode(encode("sandwitch"))

'sandwitch'

In [5]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data = torch.tensor(encode(text), device=device)
data.shape, data.dtype

n = int(data.shape[0] * 0.9)
train_data = data[:n]
valid_data = data[n:]

In [6]:
#torch.manual_seed(1337)


def get_batch(data_set=train_data, batch_size=32, seq_length=128):
    x = torch.zeros([batch_size, seq_length], dtype=torch.int, device=device)
    y = torch.zeros([batch_size, seq_length], dtype=torch.int, device=device)
    for b in range(batch_size):
        t = torch.randint(0, data_set.shape[0] - seq_length - 1, [1])
        x[b] = data_set[t:t+seq_length]
        y[b] = data_set[t+1:t+seq_length+1]
    return x, y

x, y = get_batch()
print(decode(x[13]))
print(decode(y[13]))

S:
Let him know't.

FLORIZEL:
He shall not.

POLIXENES:
Prithee, let him.

FLORIZEL:
No, he must not.

Shepherd:
Let him, my son
:
Let him know't.

FLORIZEL:
He shall not.

POLIXENES:
Prithee, let him.

FLORIZEL:
No, he must not.

Shepherd:
Let him, my son:


In [20]:
import torch.nn as nn

class DecoderLayer(nn.Module):
  def __init__(self, embed_size=256, nheads=16, broad_size=1024):
    super.__init__()
    self.attn = nn.MultiheadAttention(embed_size, nheads, batch_first=True)
    self.ffwd1 = nn.Linear(embed_size, broad_size)
    self.ffwd2 = nn.Linear(broad_size, embed_size)

  def forward(self, embeds, mask):
    nn.functional.layer_norm(embeds)


class LM(nn.Module):
  def __init__(self, vocab_size, embed_size=256, nheads=16, max_pos=2048, num_layers=3):
    super().__init__()
    self.embeddings = nn.Embedding(vocab_size, embed_size)
    self.pos_embeddings = nn.Embedding(max_pos, embed_size)
    self.attns = [
        nn.MultiheadAttention(embed_size, nheads, batch_first=True)
        for _ in range(num_layers)
    ]
    for i, attn in enumerate(self.attns):
      self.add_module(f'attn-{i}', attn)
    self.out = nn.Linear(embed_size, vocab_size)
    self.max_pos = max_pos
    self.to(device)

  def forward(self, idx):
    idx = torch.tensor(idx, device=device)
    l = idx.shape[-1]
    pe = self.pos_embeddings(torch.arange(0, l, device=device))
    e = self.embeddings(idx)
    e = e + pe.view(1, *pe.shape)
    mask = torch.tril(torch.ones([l, l], device=device)).T
    #print(mask)
    for attn in self.attns:
      a, w = attn.forward(e, e, e, attn_mask=mask)
      e = e + a
    logits = self.out.forward(e)
    return logits

  def generate(self, prompt, l):
    prompt = torch.tensor(prompt)
    pl = prompt.shape[0]
    result = torch.zeros([pl + l], dtype=torch.int, device=device)
    result[:pl] = prompt
    for i in range(l):
      logits = self.forward(result[:pl + i])
      sm = logits[:,-1].flatten().softmax(0)
      next_idx = torch.multinomial(sm, 1)
      result[i+pl] = next_idx
    return result

model = LM(vocab_size)

k=x
q=x[:,:-1]
model.forward(encode('jello'))
decode(model.generate(encode('hello'), 10))

  idx = torch.tensor(idx, device=device)


'helloHRrx,tsmI&'

In [21]:
x, y = get_batch()
logits = model(x)
loss_fun = nn.CrossEntropyLoss()
loss = loss_fun(logits.permute(0, 2, 1), y.long())
loss

  idx = torch.tensor(idx, device=device)


tensor(4.5210, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [22]:
torch.manual_seed(1337)

model = LM(vocab_size)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
xv, yv = get_batch(valid_data, batch_size=4)


In [25]:


for step in range(2000):
  x, y = get_batch(batch_size=64)
  model.train()
  logits = model.forward(x)
  loss = loss_fun(logits.permute(0, 2, 1), y.long())
  opt.zero_grad()
  loss.backward()
  opt.step()

  if step % 100 == 0:
    model.eval()
    vlogits = model.forward(xv)
    vloss = loss_fun(vlogits.permute(0,2,1), yv.long())
    print('%d\t%f\t%f'%(step, loss, vloss))


  idx = torch.tensor(idx, device=device)


0	0.016441	0.017914
100	0.015795	0.015420
200	0.018466	0.015636
300	0.016612	0.017842
400	0.016058	0.025138
500	0.016307	0.019679
600	0.015800	0.017379
700	0.017254	0.018920
800	0.014998	0.018186
900	0.018116	0.020165
1000	0.016869	0.016565
1100	0.016267	0.021569
1200	0.016420	0.017768
1300	0.016539	0.019839
1400	0.017085	0.021222
1500	0.020328	0.022157
1600	0.015081	0.019696
1700	0.020653	0.024213
1800	0.017644	0.017772
1900	0.015347	0.016363


In [None]:
print(decode(x[0]))

In [31]:
print(decode(model.generate(encode('Hello!'), 120)))

  idx = torch.tensor(idx, device=device)


Hello!!l!!;UMPUUUq!!UUMPUERUU&!.BUDUDUKUKUMPUUUHUKUqUMOUUUUDUUENUMwUMU ! ! qJU!!UTUJUFUU!UDUjUUJUBE!QUBUUKUBOUFDUJUJUOUUMUUMUO
