In [1]:
import torch
import torch.nn as nn


In [2]:
from torchtext.datasets import WikiText2, WikiText103
# Swap in WikiText103(root='data', split='train') to train on WikiText-103 instead.
data = WikiText2(split='train', root='data')
# NOTE(review): `loader` is not read by any later cell; training samples
# windows directly from `encoded` instead.
loader = torch.utils.data.DataLoader(dataset=data, drop_last=True)

In [3]:
# Word-level tokenisation: split each raw line on single spaces, drop empty
# fragments, and close every non-empty line with an explicit '\n' token.
# NOTE(review): the words[:10] output suggests raw lines already end with their
# own '\n' token, so each line ends up with two newline tokens -- confirm intended.
words = []
for raw_line in data:
    cleaned = raw_line.replace('  ', ' ')
    if not cleaned:
        continue
    words.extend(token for token in cleaned.split(' ') if token)
    words.append('\n')
len(words)

2125346

In [4]:
# Peek at the first ten tokens of the stream.
words[0:10]

['\n', '\n', '=', 'Valkyria', 'Chronicles', 'III', '=', '\n', '\n', '\n']

In [5]:
# Build the word vocabulary. Sorting makes the id assignment deterministic:
# iteration order of a set of str depends on the per-process hash seed
# (PYTHONHASHSEED), so `tuple(set(words))` would assign different ids after a
# kernel restart and silently invalidate previously trained/saved weights.
vocab = tuple(sorted(set(words)))
# Names kept for compatibility with later cells; the entries are words, not characters.
int2char = dict(enumerate(vocab))
char2int = {word: idx for idx, word in int2char.items()}
len(vocab)

33278

In [6]:
# Map every word token to its integer id (KeyError for out-of-vocabulary words).
encoded = list(map(char2int.__getitem__, words))

In [7]:
# (batch, seq, feature)
class PositionEncoding(torch.nn.Module):
  def __init__(self, max_length, embed_size):
    super(PositionEncoding, self).__init__()
    self.max_length = max_length
    self.embed_size = embed_size

    pos = torch.arange(0, max_length).unsqueeze(1)
    args = pos / (10000 ** (2 * torch.arange(0, embed_size, 2) / embed_size))
    self.pe = torch.zeros((max_length, embed_size))
    self.pe[:, ::2] = torch.sin(args)
    self.pe[:, 1::2] = torch.cos(args)

  def forward(self, x):
    self.pe = self.pe.to(x.device)
    return x + self.pe.unsqueeze(0)


In [8]:
def attention(q, k, v, mask=None):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  q, k and v end in (seq_len, d_k); any leading dimensions (batch, heads, ...)
  are handled by matmul broadcasting. ``mask``, when given, is added to the
  (seq_len, seq_len) score matrix before the softmax, so large negative
  entries suppress attention to those positions.
  """
  scores = torch.matmul(q, k.transpose(-2, -1)) / k.size(-1) ** 0.5
  if mask is not None:
    scores = scores + mask
  return torch.matmul(scores.softmax(dim=-1), v)

# (batch, seq, feature)
class MultiHeadAttention(torch.nn.Module):
  """Causal multi-head self-attention over (batch, seq_len, embed_size) inputs.

  Args:
    embed_size: model width; must be divisible by ``heads``.
    heads: number of attention heads, each of size ``embed_size // heads``.
    max_length: longest sequence the precomputed causal mask covers.

  Raises:
    ValueError: if ``embed_size`` is not divisible by ``heads``.
  """
  def __init__(self, embed_size, heads, max_length):
    super(MultiHeadAttention, self).__init__()
    if embed_size % heads != 0:
      raise ValueError(f"embed_size ({embed_size}) must be divisible by heads ({heads})")
    self.embed_size = embed_size
    self.heads = heads
    self.head_size = embed_size // heads
    self.max_length = max_length

    # Additive causal mask: -1e9 strictly above the diagonal, so position t
    # can only attend to positions <= t.
    mask = torch.zeros((max_length, max_length))
    ix = torch.triu_indices(max_length, max_length, 1)
    mask[ix[0], ix[1]] = -1e9
    # Registered as a buffer so module.to(device) moves it; non-persistent so
    # state_dict only holds the learned projection weights.
    self.register_buffer('mask', mask, persistent=False)

    self.Wq = torch.nn.Linear(embed_size, embed_size)
    self.Wk = torch.nn.Linear(embed_size, embed_size)
    self.Wv = torch.nn.Linear(embed_size, embed_size)

    self.Wo = torch.nn.Linear(embed_size, embed_size)

  def _split_heads(self, t, b, s):
    # (b, s, embed_size) -> (b, heads, s, head_size)
    return t.view(b, s, self.heads, self.head_size).transpose(1, 2)

  def forward(self, x):
    """Attend causally over ``x`` of shape (batch, seq_len, embed_size).

    Raises:
      ValueError: if seq_len exceeds ``max_length``.
    """
    b, s, _ = x.shape
    if s > self.max_length:
      raise ValueError(f"sequence length {s} exceeds max_length {self.max_length}")
    q = self._split_heads(self.Wq(x), b, s)
    k = self._split_heads(self.Wk(x), b, s)
    v = self._split_heads(self.Wv(x), b, s)
    # Slice the mask to the actual sequence length so shorter inputs work.
    out = attention(q, k, v, self.mask[:s, :s])
    # (b, heads, s, head_size) -> (b, s, embed_size)
    out = out.transpose(1, 2).contiguous().view(b, s, self.embed_size)
    return self.Wo(out)


In [9]:
# (batch, seq, feature)
class FeedForward(torch.nn.Module):
  """Position-wise MLP on the last dimension: Linear(E -> 4E), ReLU, Linear(4E -> E)."""
  def __init__(self, embed_size):
    super().__init__()
    hidden_size = 4 * embed_size
    layers = [
      torch.nn.Linear(embed_size, hidden_size),
      torch.nn.ReLU(),
      torch.nn.Linear(hidden_size, embed_size),
    ]
    # A single Sequential named `main`, so parameters are main.0.* and main.2.*.
    self.main = torch.nn.Sequential(*layers)

  def forward(self, x):
    out = self.main(x)
    return out


In [10]:
class TransformerBlock(torch.nn.Module):
  """One post-norm transformer layer.

  Causal self-attention and a feed-forward MLP, each applied as
  LayerNorm(sublayer(x) + x).
  """
  def __init__(self, embed_size, heads, max_length):
    super().__init__()
    self.embed_size, self.heads = embed_size, heads

    self.attention = MultiHeadAttention(embed_size, heads, max_length)
    self.norm1, self.norm2 = torch.nn.LayerNorm(embed_size), torch.nn.LayerNorm(embed_size)
    self.ff = FeedForward(embed_size)

  def forward(self, x):
    # Residual + LayerNorm after attention, then again after the MLP.
    hidden = self.norm1(self.attention(x) + x)
    return self.norm2(self.ff(hidden) + hidden)

In [11]:
class LLM(torch.nn.Module):
  """Decoder-only transformer language model.

  token ids -> embedding + sinusoidal position encoding -> ``depth`` causal
  TransformerBlocks -> linear projection to ``vocab_size`` logits.
  """
  def __init__(self, vocab_size, embed_size, depth, heads, max_length):
    super(LLM, self).__init__()
    self.vocab_size = vocab_size
    self.embed_size = embed_size
    self.max_length = max_length
    self.depth = depth
    self.heads = heads

    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.pos_enc = PositionEncoding(max_length, embed_size)
    self.decoder = nn.Linear(embed_size, vocab_size)

    blocks = [TransformerBlock(embed_size, heads, max_length) for _ in range(depth)]
    self.blocks = nn.Sequential(*blocks)

    self.init_weights()

  def init_weights(self):
    """Xavier-normal init for every parameter with more than one dimension.

    1-D parameters (biases, LayerNorm scales/shifts) keep their defaults.
    """
    for p in self.parameters():
      if p.dim() > 1:
        nn.init.xavier_normal_(p)

  def forward(self, x):
    """Map token ids to next-token logits.

    Args:
      x: integer tensor of shape (batch, seq_len) or (seq_len,); a 1-D input is
        treated as a batch of one.
    Returns:
      Unnormalised logits of shape (batch, seq_len, vocab_size).
    Raises:
      ValueError: if seq_len exceeds ``max_length``.
    """
    if x.dim() == 1:
      # Make the batch dimension explicit instead of relying on it appearing
      # through broadcasting inside the position encoding.
      x = x.unsqueeze(0)
    if x.size(1) > self.max_length:
      raise ValueError(f"sequence length {x.size(1)} exceeds max_length {self.max_length}")
    x = self.embedding(x)
    x = self.pos_enc(x)
    x = self.blocks(x)
    return self.decoder(x)


In [12]:
# Model / training hyperparameters.
vocab_size = len(vocab)
embed_size, depth, heads = 512, 6, 8
max_length = 256  # context window, in tokens
batch_size = 32
model = LLM(vocab_size, embed_size, depth, heads, max_length)
model_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model size: {model_size / 1e6}M")

Model size: 53.024254M


In [39]:
# Smoke-test one forward pass on the first max_length tokens. No gradients are
# needed, so skip building the autograd graph (it would otherwise stay alive
# for as long as `output` does).
with torch.no_grad():
  output = model(torch.tensor(encoded[:max_length]).unsqueeze(0))

In [13]:
# Objective and optimiser used by the training loop.
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=1e-5)

def generate(model, encoded_prompt, length):
  """Autoregressively sample ``length`` tokens after ``encoded_prompt``.

  Each step feeds the last ``model.max_length`` tokens to the model and samples
  the next token from the softmax over the final position's logits.

  Args:
    model: module mapping (1, seq_len) token ids to (1, seq_len, vocab) logits;
      must expose ``max_length`` and have at least one parameter.
    encoded_prompt: sequence of token ids.
    length: number of tokens to sample.
  Returns:
    (full_text, prompt_text): space-joined ``vocab`` words for prompt + samples,
    and for the prompt alone.
  """
  # Run wherever the model's weights live instead of assuming CUDA.
  device = next(model.parameters()).device
  was_training = model.training
  model.eval()
  try:
    x = torch.tensor(encoded_prompt, device=device).unsqueeze(0)
    with torch.no_grad():
      for _ in range(length):
        # Keep an explicit batch dimension; crop to the model's context window.
        logits = model(x[:, -model.max_length:])
        probs = torch.softmax(logits[0, -1, :], dim=-1)
        next_id = torch.multinomial(probs, 1).unsqueeze(0)
        x = torch.cat([x, next_id], dim=1)
  finally:
    model.train(was_training)
  full_text = ' '.join(vocab[i] for i in x[0].tolist())
  prompt_text = ' '.join(vocab[i] for i in encoded_prompt)
  return full_text, prompt_text

def train(model, optim, loss_fn, data, epochs=10, device="cpu", batch_size=None, seq_len=None):
  """Train ``model`` on randomly sampled next-token windows from ``data``.

  Prints ``len(data)`` once. Each epoch then runs
  ``len(data) // seq_len // batch_size + 1`` optimiser steps on batches of
  ``batch_size`` windows of ``seq_len`` tokens (targets are the inputs shifted
  by one token). After each epoch it prints the mean loss and a 12-token sample
  obtained from ``generate`` on the first ``seq_len`` tokens.

  Args:
    model: module mapping (batch, seq_len) token ids to (batch, seq_len, vocab) logits.
    optim: optimiser over ``model``'s parameters.
    loss_fn: criterion taking (N, vocab) logits and (N,) integer targets.
    data: sequence of integer token ids.
    epochs: number of epochs.
    device: device the model and each batch are moved to.
    batch_size: windows per step; defaults to the notebook-level ``batch_size``.
    seq_len: window length; defaults to the notebook-level ``max_length``.
  Returns:
    List with the mean training loss of each epoch.
  Raises:
    ValueError: if ``data`` is too short for a single input/target window.
  """
  # Fall back to the notebook-level hyperparameters so existing calls keep working.
  if batch_size is None:
    batch_size = globals()["batch_size"]
  if seq_len is None:
    seq_len = globals()["max_length"]
  n_tokens = len(data)
  if n_tokens < seq_len + 1:
    raise ValueError(f"need at least {seq_len + 1} tokens, got {n_tokens}")

  model.to(device)
  print(n_tokens)
  # Convert once and index with tensors instead of rebuilding nested Python
  # lists on every step.
  tokens = torch.as_tensor(data, dtype=torch.long)
  offsets = torch.arange(seq_len)
  steps_per_epoch = n_tokens // seq_len // batch_size + 1
  epoch_losses = []
  for epoch in range(epochs):
    model.train()
    lossi = []
    for _ in range(steps_per_epoch):
      if n_tokens == seq_len * batch_size + 1:
        # Special case: for this exact data size every row uses the first
        # window (presumably a deliberate overfitting check -- confirm).
        starts = torch.zeros((batch_size,), dtype=torch.int64)
      else:
        # A start s is valid while s + seq_len + 1 <= n_tokens; randint's
        # upper bound is exclusive, so n_tokens - seq_len includes the last window.
        starts = torch.randint(0, n_tokens - seq_len, (batch_size,))
      idx = starts.unsqueeze(1) + offsets  # (batch_size, seq_len)
      x = tokens[idx].to(device)
      y = tokens[idx + 1].to(device)
      y_hat = model(x)
      loss = loss_fn(y_hat.reshape(-1, y_hat.shape[-1]), y.reshape(-1))
      lossi.append(loss.item())
      optim.zero_grad()
      loss.backward()
      optim.step()
    output, prompt = generate(model, data[:seq_len], 12)
    output = output[len(prompt):]
    mean_loss = sum(lossi) / len(lossi)
    print(f"epoch {epoch} loss {mean_loss}: {output}")
    epoch_losses.append(mean_loss)
  return epoch_losses


In [16]:
# Larger batches for the full run (overrides the hyperparameter cell's value).
batch_size = 64
train(model=model, optim=optim, loss_fn=loss_fn, data=encoded, epochs=1000, device="cuda")

2125346
epoch 0 loss 3.0586970971180842:  news up to reflect video that it was similar by the special
epoch 1 loss 3.068199695073641:  the series of <unk> from other and composing video . The scene
epoch 2 loss 3.0553893107634322:  the role of Halo , praised the development . His classic concept
epoch 3 loss 3.058496064406175:  a successful application of the line in it was work from The
epoch 4 loss 3.0609662679525522:  largely transfer if it was specifically interested in the game 's original
epoch 5 loss 3.056960085722116:  some <unk> in most of traditional arms , writing , one of
epoch 6 loss 3.0545913219451903:  a predecessor for the script , stating that pacing that special changes
epoch 7 loss 3.049109436915471:  extensive six billion entered the PlayStation Portable title for the game map
epoch 8 loss 3.047251644501319:  some different players and played in both the soundtrack . At the
epoch 9 loss 3.0459105601677527:  downloadable content by a more fan <unk> , the game , based

In [130]:
# Hand memory cached by PyTorch's CUDA allocator, but no longer in use, back to the driver.
torch.cuda.memory.empty_cache()

In [17]:
# Teacher-forced check: greedy next-token prediction at every position of the
# first max_length-token window.
model.to('cuda')
input = encoded[0:max_length]  # NOTE: shadows the builtin input(); name kept unchanged
with torch.no_grad():
  output = model(torch.tensor(input, device="cuda").unsqueeze(0))
# softmax is monotonic, so argmax over the raw logits picks the same tokens.
token_output = output.argmax(dim=-1)[0]
print(f' input: {" ".join([vocab[i] for i in input])}')
print(f'output: {" ".join([vocab[i] for i in token_output])}')

 input: 
 
 = Valkyria Chronicles III = 
 
 
 
 Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . 
 
 The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for series n

In [18]:
# Pickles the whole module object (loading it needs the class definitions above).
torch.save(model, 'model-word-2k.pth')
# Token ids are only meaningful together with the vocabulary that assigned them,
# so persist it next to the weights.
torch.save(vocab, 'model-word-2k-vocab.pth')

In [22]:
# Sample 256 tokens following the first max_length-token window and show the full text.
output, input = generate(model, encoded[:max_length], 256)
print(output)


 
 = Valkyria Chronicles III = 
 
 
 
 Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . 
 
 The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for series newcomers