In [1]:
import torch
import torch.nn as nn


In [3]:
# Load the training split of WikiText103, using 'data' as the dataset root.
# NOTE(review): WikiText2 is imported but not used by any cell shown here.
from torchtext.datasets import WikiText2, WikiText103
data = WikiText103(root='data', split='train')
# Wrap the dataset in a DataLoader with drop_last=True.
# NOTE(review): `loader` is not referenced by any later cell shown here, and
# the name `data` is rebound to the raw text8 string by a later cell.
loader = torch.utils.data.DataLoader(data, drop_last=True)

In [4]:
# Flatten the dataset into one long list of space-separated tokens.
words = []
for line in data:
    words.extend(line.split(' '))

len(words)

105028371

In [2]:
# Read the whole text8 file into memory as a single string.
# NOTE: this rebinds `data`, replacing whatever was stored under that name.
with open('data/text8', 'r') as f:
    data = f.read()
# Show the first 10 characters as a quick sanity check.
data[:10]

' anarchism'

In [3]:
# Alias the raw corpus as `text`; the trailing commented-out expression is an
# alternative that is not executed.
text = data # '\n'.join(list(data))
# Number of characters in the corpus.
len(text)

100000000

In [4]:
# Build the character vocabulary and the two lookup tables.
# The unique characters are sorted so that every character gets the same id on
# every run. Iterating a bare set of strings follows hash order, which Python
# randomizes per process, so without the sort the ids a saved model was trained
# on would no longer match the vocabulary after a kernel restart.
chars = tuple(sorted(set(text)))
# id -> character
int2char = dict(enumerate(chars))
# character -> id (inverse of int2char)
char2int = {ch: ii for ii, ch in int2char.items()}

In [5]:
# Translate every character of the corpus into its integer id.
encoded = list(map(char2int.__getitem__, text))

In [6]:
# Vocabulary size: number of distinct characters found in the corpus.
len(chars)

27

In [7]:
# (batch, seq, feature)
class PositionEncoding(torch.nn.Module):
  """Adds fixed sinusoidal position encodings to a sequence of embeddings.

  For position `pos` and even feature index `2i` the table holds
  sin(pos / 10000 ** (2i / embed_size)); the following odd index holds the
  cosine of the same argument.

  Args:
    max_length: number of positions to precompute (longest supported sequence).
    embed_size: size of the feature dimension.
  """
  def __init__(self, max_length, embed_size):
    super(PositionEncoding, self).__init__()
    self.max_length = max_length
    self.embed_size = embed_size

    pos = torch.arange(0, max_length).unsqueeze(1)
    # arange(0, embed_size, 2) already enumerates the even indices 2i, so it
    # must not be multiplied by 2 again.
    exponent = torch.arange(0, embed_size, 2) / embed_size
    args = pos / (10000 ** exponent)
    pe = torch.zeros((max_length, embed_size))
    pe[:, ::2] = torch.sin(args)
    # With an odd embed_size there is one odd column fewer than even columns.
    pe[:, 1::2] = torch.cos(args[:, : embed_size // 2])
    # Non-persistent buffer: moves with the module on .to()/.cuda() but stays
    # out of state_dict, exactly like the plain attribute it replaces.
    self.register_buffer('pe', pe, persistent=False)

  def forward(self, x):
    """Return `x` plus the encodings for its first `seq` positions.

    `x` has its sequence dimension second to last and its feature dimension
    last. A leading dimension of size 1 is broadcast in, so a 2-D
    (seq, feature) input comes back as (1, seq, feature).

    Raises:
      ValueError: if the sequence is longer than `max_length`.
    """
    seq_len = x.shape[-2]
    if seq_len > self.max_length:
      raise ValueError(
        f"sequence length {seq_len} exceeds max_length {self.max_length}")
    # Slice so that sequences shorter than max_length are accepted too.
    pe = self.pe[:seq_len].to(x.device)
    return x + pe.unsqueeze(0)


In [8]:
# Scratch check: scaled dot-product attention on random 4-D tensors keeps the
# input shape.
q = torch.randn(3, 4, 2, 2)
k = torch.randn(3, 4, 2, 2)
v = torch.randn(3, 4, 2, 2)

# Scores are divided by sqrt(size of the last dimension of k).
qk = torch.matmul(q, k.transpose(-1, -2)) / k.size(-1) ** 0.5
weights = qk.softmax(dim=-1)
torch.matmul(weights, v).shape

torch.Size([3, 4, 2, 2])

In [9]:
# q, k, v: (..., seq_len, feature); leading dimensions are batched by matmul
def attention(q, k, v, mask=None):
  """Scaled dot-product attention.

  Args:
    q, k, v: tensors whose last two dimensions are (seq_len, feature).
    mask: optional tensor added to the scores before the softmax.

  Returns:
    softmax(q @ k^T / sqrt(feature) + mask) @ v
  """
  scale = k.size(-1) ** 0.5
  scores = torch.matmul(q, k.transpose(-2, -1)) / scale
  scores = scores if mask is None else scores + mask
  return torch.matmul(torch.softmax(scores, dim=-1), v)

# (batch, seq, feature)
class MultiHeadAttention(torch.nn.Module):
  """Causal multi-head self-attention over a (batch, seq, feature) input.

  Args:
    embed_size: feature size of the input and output; must be divisible by
      `heads`.
    heads: number of attention heads.
    max_length: longest sequence the precomputed causal mask supports.

  Raises:
    ValueError: if `embed_size` is not divisible by `heads`.
  """
  def __init__(self, embed_size, heads, max_length):
    super(MultiHeadAttention, self).__init__()
    if embed_size % heads != 0:
      raise ValueError(
        f"embed_size ({embed_size}) must be divisible by heads ({heads})")
    self.embed_size = embed_size
    self.heads = heads
    self.head_size = embed_size // heads
    self.max_length = max_length

    # Additive causal mask: 0 on and below the diagonal, -1e9 strictly above
    # it, so a position gets ~0 weight on later positions after the softmax.
    mask = torch.zeros((max_length, max_length))
    ix = torch.triu_indices(max_length, max_length, 1)
    mask[ix[0], ix[1]] = -1e9
    # Non-persistent buffer: moves with the module on .to()/.cuda() but stays
    # out of state_dict, exactly like the plain attribute it replaces.
    self.register_buffer('mask', mask, persistent=False)

    self.Wq = torch.nn.Linear(embed_size, embed_size)
    self.Wk = torch.nn.Linear(embed_size, embed_size)
    self.Wv = torch.nn.Linear(embed_size, embed_size)

    self.Wo = torch.nn.Linear(embed_size, embed_size)

  def forward(self, x):
    """Apply masked self-attention to `x` of shape (batch, seq, embed_size).

    Returns a tensor of the same shape.

    Raises:
      ValueError: if `seq` is larger than `max_length`.
    """
    b, s, e = x.shape
    if s > self.max_length:
      raise ValueError(
        f"sequence length {s} exceeds max_length {self.max_length}")
    # Slice the mask so that sequences shorter than max_length work too.
    mask = self.mask[:s, :s].to(x.device)
    # x is (batch, seq_len, embed_size); split the features into heads and
    # move the head dimension next to batch: (batch, heads, seq_len, head_size)
    q = self.Wq(x).view(b, s, self.heads, self.head_size).transpose(1, 2).contiguous()
    k = self.Wk(x).view(b, s, self.heads, self.head_size).transpose(1, 2).contiguous()
    v = self.Wv(x).view(b, s, self.heads, self.head_size).transpose(1, 2).contiguous()
    x = attention(q, k, v, mask)
    # x is now (batch, heads, seq_len, head_size); merge the heads back
    x = x.transpose(1, 2).contiguous().view(b, s, self.embed_size)
    return self.Wo(x)


In [10]:
# test the MultiHeadAttention
# NOTE: this cell rebinds the globals `embed_size`, `heads`, `x` and `mha`.
embed_size = 8
heads = 4
x = torch.randn(2, 10, embed_size)
mha = MultiHeadAttention(embed_size, heads, 10)
# Keep the result so the check inspects the layer's output rather than the
# input tensor that was passed in.
out = mha(x)
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
out.shape, out.shape[-1]

(torch.Size([2, 10, 8]), 8)

In [11]:
# Compare the hand-written layer with torch.nn.MultiheadAttention on the same
# input and hyper-parameters: parameter counts and output shapes.
embed_size = 8
heads = 2
x = torch.randn(1, 10, embed_size)

mha = torch.nn.MultiheadAttention(embed_size, heads, batch_first=True)
# The built-in layer returns (output, weights); only the output is kept.
mha_output = mha(x, x, x, need_weights=False)[0]

print('mha', sum(p.numel() for p in mha.parameters()))

mha2 = MultiHeadAttention(embed_size, heads, 10)
print('mha2', sum(p.numel() for p in mha2.parameters()))

mha2_output = mha2(x)

mha_output.shape, mha2_output.shape

mha 288
mha2 288


(torch.Size([1, 10, 8]), torch.Size([1, 10, 8]))

In [12]:
# (batch, seq, feature)
class FeedForward(torch.nn.Module):
  """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Linear.

  The hidden layer is four times as wide as `embed_size`; the output has the
  same feature size as the input.
  """
  def __init__(self, embed_size):
    super().__init__()
    hidden_size = embed_size * 4
    expand = torch.nn.Linear(embed_size, hidden_size)
    activation = torch.nn.ReLU()
    project = torch.nn.Linear(hidden_size, embed_size)
    self.main = torch.nn.Sequential(expand, activation, project)

  def forward(self, x):
    """Run `x` through the MLP."""
    return self.main(x)


In [13]:
class TransformerBlock(torch.nn.Module):
  """One transformer layer: self-attention then feed-forward.

  Each sub-layer is wrapped in a residual connection, and LayerNorm is applied
  after the residual sum.
  """
  def __init__(self, embed_size, heads, max_length):
    super().__init__()
    self.embed_size, self.heads = embed_size, heads

    self.attention = MultiHeadAttention(embed_size, heads, max_length)
    self.norm1 = torch.nn.LayerNorm(embed_size)
    self.norm2 = torch.nn.LayerNorm(embed_size)
    self.ff = FeedForward(embed_size)

  def forward(self, x):
    """Return the block output; same shape as `x`."""
    # attention sub-layer + residual, then normalise
    after_attention = self.norm1(self.attention(x) + x)
    # feed-forward sub-layer + residual, then normalise
    return self.norm2(self.ff(after_attention) + after_attention)

In [14]:
# Language model: embedding -> position encoding -> transformer blocks -> linear decoder
class LLM(torch.nn.Module):
  """Stack of `depth` TransformerBlocks between an embedding and a decoder.

  Args:
    vocab_size: number of token ids; also the size of the output logits.
    embed_size: feature size used throughout the model.
    depth: number of TransformerBlocks.
    heads: attention heads per block.
    max_length: passed to the position encoding and to every block.
  """
  def __init__(self, vocab_size, embed_size, depth, heads, max_length):
    super().__init__()
    self.vocab_size = vocab_size
    self.embed_size = embed_size
    self.max_length = max_length
    self.depth = depth
    self.heads = heads

    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.pos_enc = PositionEncoding(max_length, embed_size)
    self.decoder = nn.Linear(embed_size, vocab_size)

    self.blocks = nn.Sequential(
      *(TransformerBlock(embed_size, heads, max_length) for _ in range(depth))
    )

  def forward(self, x):
    """Map token ids to per-position logits over the vocabulary.

    The result is raw logits; no softmax is applied here.
    """
    hidden = self.pos_enc(self.embedding(x))
    return self.decoder(self.blocks(hidden))


In [76]:
# Hyper-parameters and model construction.
# One output class per distinct character.
vocab_size = len(chars)
# Feature size used throughout the model.
embed_size = 128
# Number of stacked TransformerBlocks.
depth = 10
# Attention heads per block.
heads = 4
# Sequence length; also read as a global by train() and generate().
max_length = 256
# Sequences per training step; read as a global by train().
batch_size = 32
model = LLM(vocab_size, embed_size, depth, heads, max_length)


In [16]:
# Smoke test: run the model once on the first max_length encoded characters;
# unsqueeze(0) adds a leading batch dimension of size 1.
output = model(torch.tensor(encoded[:max_length]).unsqueeze(0))

In [68]:
# size of model: total element count over all parameters with requires_grad=True
sum(p.numel() for p in model.parameters() if p.requires_grad)

1196571

In [77]:
# Adam over all model parameters with a learning rate of 1e-5.
optim = torch.optim.Adam(model.parameters(), lr=1e-5)
# Cross-entropy on raw logits; train() flattens logits and targets before calling it.
loss_fn = torch.nn.CrossEntropyLoss()

def generate(model, encoded_prompt, length, vocab=None):
  """Sample `length` new tokens after `encoded_prompt` and decode the result.

  Args:
    model: callable module mapping (batch, seq) token ids to
      (batch, seq, vocab) logits.
    encoded_prompt: sequence of integer token ids to start from.
    length: number of tokens to sample.
    vocab: indexable id -> string table used for decoding. Defaults to the
      notebook-level `chars`.

  Returns:
    The decoded prompt followed by the sampled continuation, as one string.
  """
  if vocab is None:
    vocab = chars
  # Run on whatever device the model lives on instead of assuming a GPU.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')
  # Prefer the model's own context size; fall back to the notebook global.
  context = model.max_length if hasattr(model, 'max_length') else max_length

  x = torch.tensor(encoded_prompt).unsqueeze(0).to(device)
  # Sample in eval mode and restore the previous mode afterwards, because
  # train() calls this function in the middle of training.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for _ in range(length):
        # Keep the batch dimension and feed at most `context` trailing tokens.
        logits = model(x[:, -context:])
        probs = torch.softmax(logits[0, -1, :], dim=-1)
        next_id = torch.multinomial(probs, 1).unsqueeze(0)
        x = torch.cat([x, next_id], dim=1)
  finally:
    model.train(was_training)
  return ''.join(vocab[i] for i in x[0].tolist())

def train(model, optim, loss_fn, data, epochs=10, device="cpu"):
  """Train `model` for next-token prediction on random windows of `data`.

  Reads the notebook-level globals `max_length` (window size) and `batch_size`.
  Every 1000 steps the mean loss since the last report is printed; after each
  epoch a 100-token sample is generated and printed.

  Args:
    model: module mapping (batch, seq) token ids to (batch, seq, vocab) logits.
    optim: optimizer over the model's parameters.
    loss_fn: loss taking (N, vocab) logits and (N,) target ids.
    data: sliceable sequence of integer token ids.
    epochs: number of passes; each pass runs
      len(data) // max_length // batch_size steps.
    device: device the model and batches are moved to.

  Raises:
    ValueError: if `data` is too short for a single step.
  """
  steps_per_epoch = len(data) // max_length // batch_size
  if steps_per_epoch == 0:
    # Otherwise the end-of-epoch report would hit an unbound `loss`.
    raise ValueError(
      f"data has {len(data)} tokens; need at least {max_length * batch_size} for one step")

  model.to(device)
  lossi = []
  for epoch in range(epochs):
    for i in range(steps_per_epoch):
      # Random window starts; the upper bound leaves room for the target
      # window shifted by one, so every slice below has full length.
      starts = torch.randint(0, len(data) - max_length - 1, (batch_size,)).tolist()
      x = torch.tensor([data[s:s + max_length] for s in starts]).to(device)
      y = torch.tensor([data[s + 1:s + max_length + 1] for s in starts]).to(device)
      y_hat = model(x)
      # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss.
      loss = loss_fn(y_hat.view(-1, y_hat.shape[-1]), y.view(-1))
      lossi.append(loss.item())
      optim.zero_grad()
      loss.backward()
      optim.step()
      if len(lossi)%1000 == 0:
        print(f"epoch {epoch} i {i} loss {sum(lossi)/len(lossi)}")
        lossi = []
    # Prompt with the start of the data being trained on, not a global.
    output = generate(model, data[:max_length], 100)
    output = output[-100:]
    print(f"epoch {epoch} loss {loss.item()}: {output}")


In [78]:
# Train on the encoded corpus on the GPU. The captured output below ends in a
# KeyboardInterrupt, i.e. the run was stopped by hand before epochs=1000 was reached.
train(model, optim, loss_fn, encoded, epochs=1000, device="cuda")

epoch 0 i 999 loss 2.509027638435364
epoch 0 i 1999 loss 2.394892918109894
epoch 0 i 2999 loss 2.374404799461365
epoch 0 i 3999 loss 2.357421378612518
epoch 0 i 4999 loss 2.338097105026245
epoch 0 i 5999 loss 2.3154317591190337
epoch 0 i 6999 loss 2.264230298757553
epoch 0 i 7999 loss 2.205771980762482
epoch 0 i 8999 loss 2.1559656829833984
epoch 0 i 9999 loss 2.111240647315979
epoch 0 i 10999 loss 2.072074129462242
epoch 0 i 11999 loss 2.039836578965187
epoch 0 loss 1.9761521816253662: res orat ereklvent of hire adechy ins gecont atto wo mas thaco thin ito thros the wat thece the deth
epoch 1 i 792 loss 2.0140576425790786
epoch 1 i 1792 loss 1.9913155736923218
epoch 1 i 2792 loss 1.9703298182487488
epoch 1 i 3792 loss 1.9519757535457611
epoch 1 i 4792 loss 1.9351621257066727
epoch 1 i 5792 loss 1.9194907071590424
epoch 1 i 6792 loss 1.90385358273983
epoch 1 i 7792 loss 1.8881795470714569
epoch 1 i 8792 loss 1.8735385894775392
epoch 1 i 9792 loss 1.8583231242895126
epoch 1 i 10792 loss

KeyboardInterrupt: 

In [39]:
# Greedy (argmax) next-character prediction at every position of the first
# 256 encoded characters, printed next to the input for comparison.
model.to('cuda')
# Named `prompt_ids` rather than `input` so that the builtin input() is not
# shadowed for the rest of the kernel session.
prompt_ids = encoded[0:256]
# Inference only: no autograd graph is needed.
with torch.no_grad():
  output = model(torch.tensor(prompt_ids).to("cuda").unsqueeze(0))
# softmax is monotonic, so argmax over the raw logits selects the same ids.
token_output = output.argmax(dim=-1)[0]
print(f' input: {"".join([chars[i] for i in prompt_ids])}')
print(f'output: {"".join([chars[i] for i in token_output.tolist()])}')

 input:  anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act th
output: tndm h st af ginalid tn t shrm tf t ost aovst oned tsein t anrly tark ng toass aetioat  an ludeng the cesurs  tf the cnglish aeselution and the stmd aoltt ld af the cieech aeselution aiice eahe chrm tn aeall one  tn t srrui  ive tis th tescribe and antithe


In [74]:
# Persist the model to 'model.pth'. The module object itself is passed (not
# its state_dict), so the whole object is pickled.
# NOTE(review): loading this file presumably requires the class definitions
# from this notebook to be available, and the character <-> id tables are not
# saved with it — confirm how 'model.pth' is consumed.
torch.save(model, 'model.pth')