In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# Pinned to the release this notebook was executed against so a re-run resolves the same API.
%pip install lightning==2.4.0

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m811.0/811.0 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.5.1-py3-none-any.whl (890 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m890.6/890.6 kB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

In [None]:
# Vocabulary: each token's id is its position in this tuple.
token_to_id = {
    token: index
    for index, token in enumerate(('What', 'is', 'PyTorch', 'Awesome', '<EOS>'))
}

# Inverse lookup, used to turn predicted ids back into readable tokens.
id_to_token = {index: token for token, index in token_to_id.items()}

In [None]:
# Two training sequences, one per row: a three-token prompt, <EOS>, then the answer.
inputs = torch.tensor([
    [token_to_id[token] for token in ('What', 'is', 'PyTorch', '<EOS>', 'Awesome')],
    [token_to_id[token] for token in ('PyTorch', 'is', 'What', '<EOS>', 'Awesome')],
])

# Labels are the inputs shifted left by one position (next-token targets),
# closed off with a final <EOS>.
labels = torch.tensor([
    [token_to_id[token] for token in ('is', 'PyTorch', '<EOS>', 'Awesome', '<EOS>')],
    [token_to_id[token] for token in ('is', 'What', '<EOS>', 'Awesome', '<EOS>')],
])

# DataLoader's default batch_size is 1, so every batch holds a single sequence.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to word embeddings.

  A `(max_len, d_model)` table is computed once and stored as the buffer `pe`:
  even columns hold sines and odd columns hold cosines of
  `position / 10000 ** (column_pair_index / d_model)`.

  Args:
    d_model: Number of values per embedding (columns of the table).
    max_len: Number of positions (rows) the table covers.
  """

  def __init__(self, d_model=2, max_len=6):
    super().__init__()

    pe = torch.zeros(max_len, d_model)

    # Column vector of positions 0..max_len-1, so it broadcasts against div_term.
    position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
    # One frequency per sine column (0, 2, 4, ...).
    embedding_index = torch.arange(start=0, end=d_model, step=2).float()

    div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)

    angles = position * div_term

    pe[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one cosine column fewer than sine columns,
    # so drop the surplus frequency instead of failing on a shape mismatch.
    # For an even d_model the slice keeps every column.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    # A buffer follows the module across devices and into state_dict,
    # but is not a trainable parameter.
    self.register_buffer('pe', pe)

  def forward(self, word_embeddings):
    """Returns `word_embeddings` plus the encodings of its leading positions.

    The first dimension of `word_embeddings` is treated as the position axis:
    that many rows are taken from the top of the table and added.
    """
    return word_embeddings + self.pe[:word_embeddings.size(0), :]

In [None]:
class Attention(nn.Module):
  """Single-head scaled dot-product attention with bias-free Q/K/V projections.

  Inputs are 2-D: `row_dim` (0) indexes tokens and `col_dim` (1) indexes the
  values of each token's encoding.
  """

  def __init__(self, d_model=2):
    super().__init__()

    # Square projections, so queries, keys and values all stay d_model wide.
    self.W_q = nn.Linear(d_model, d_model, bias=False)
    self.W_k = nn.Linear(d_model, d_model, bias=False)
    self.W_v = nn.Linear(d_model, d_model, bias=False)

    self.row_dim, self.col_dim = 0, 1

  def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
    """Returns the attention-weighted mix of values for every query row.

    `mask`, when given, is a boolean tensor where True marks query/key pairs
    that must not attend to each other.
    """
    queries = self.W_q(encodings_for_q)
    keys = self.W_k(encodings_for_k)
    values = self.W_v(encodings_for_v)

    # Query-key similarities, divided by sqrt(key width).
    scale = torch.tensor(keys.size(self.col_dim)**0.5)
    scores = torch.matmul(queries, keys.transpose(self.row_dim, self.col_dim)) / scale

    if mask is not None:
      # A very negative score becomes (effectively) zero weight after softmax.
      scores = scores.masked_fill(mask=mask, value=-1e9)

    # Each row of weights sums to 1 across the keys.
    weights = F.softmax(scores, dim=self.col_dim)

    return torch.matmul(weights, values)

In [None]:
class DecoderOnlyTransformer(L.LightningModule):
  """Minimal decoder-only transformer: embedding, positional encoding, one masked
  self-attention layer with a residual connection, and a linear layer producing
  one score per vocabulary token.

  Args:
    num_tokens: Vocabulary size; sets both the embedding table size and the
      number of output scores, so it must cover every token id the model sees.
    d_model: Number of values per token embedding.
    max_len: Longest sequence the positional encoding table covers.
  """

  def __init__(self, num_tokens=4, d_model=2, max_len=6):
    super().__init__()

    self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
    self.pe = PositionalEncoding(d_model=d_model, max_len=max_len)
    self.self_attention = Attention(d_model=d_model)
    self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

    self.loss = nn.CrossEntropyLoss()

  def forward(self, token_ids):
    """Returns next-token scores for one unbatched, 1-D sequence of token ids.

    The result has one row per input position and `num_tokens` columns.
    """
    word_embeddings = self.we(token_ids)
    position_encoding = self.pe(word_embeddings)

    # Causal mask: True above the diagonal, i.e. for positions that come later
    # than the query. It is created on the same device as the input; a CPU-only
    # mask fails in masked_fill once the module runs on an accelerator.
    seq_len = token_ids.size(dim=0)
    mask = torch.tril(torch.ones((seq_len, seq_len), device=token_ids.device))
    mask = mask == 0

    self_attention_values = self.self_attention(position_encoding, position_encoding, position_encoding, mask)

    residual_connection_values = self_attention_values + position_encoding

    fc_layer_output = self.fc_layer(residual_connection_values)

    return fc_layer_output

  def configure_optimizers(self):
    """Returns the Adam optimizer (learning rate 0.1) over all parameters."""
    return Adam(self.parameters(), lr=0.1)

  def training_step(self, batch, batch_idx):
    """Returns the cross-entropy loss over every token of every sequence in `batch`."""
    input_tokens, labels = batch

    # forward() handles a single sequence, so run each sequence of the batch on
    # its own and pool the per-token scores. Previously only the first sequence
    # was used and the rest were silently dropped; with batch_size=1 the result
    # is the same as before.
    output = torch.cat([self.forward(sequence) for sequence in input_tokens])
    loss = self.loss(output, labels.reshape(-1))

    return loss

In [None]:
# Build the model with a vocabulary size taken from the token map. It is still
# untrained here, so the tokens generated below come from random initial weights.
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)

# Prompt: the question followed by <EOS>; the model should continue with the answer.
model_input = torch.tensor([token_to_id['What'],
                            token_to_id['is'],
                            token_to_id['PyTorch'],
                            token_to_id['<EOS>']])

input_length = model_input.size(dim=0)

# Cap on prompt + generated tokens; matches the max_len=6 the model was built with.
max_length = 6

# Generation needs no gradients, so skip building autograd graphs.
with torch.no_grad():
  # Only the last row matters: it holds the scores for the token after the prompt.
  predictions = model(model_input)
  prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
  prediction_ids = prediction_id

  # Greedy decoding: append the best token and re-run until <EOS> or the cap.
  for _ in range(input_length, max_length):
    if prediction_id.item() == token_to_id['<EOS>']:
      break
    model_input = torch.cat((model_input, prediction_id))
    predictions = model(model_input)
    prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    prediction_ids = torch.cat((prediction_ids, prediction_id))

print("Predicted Tokens:")
# Named token_id so the builtin id() is not shadowed for the rest of the notebook.
for token_id in prediction_ids:
  print("\t", id_to_token[token_id.item()])


Predicted Tokens:
	 Awesome
	 Awesome
	 What


In [None]:
# Run 30 passes over the dataloader, then stop.
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | we             | Embedding          | 10     | train
1 | pe             | PositionalEncoding | 0      | train
2 | self_attention | Attention          | 12     | train
3 | fc_layer       | Linear             | 15     | train
4 | loss           | CrossEntropyLoss   | 0      | train
--------------------------------------------------------------
37        Trainable params
0         Non-trainable params
37        Total params
0.000     Total estimated model params size (MB)
8         Modules in train

Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [None]:
# Prompt: the first training question followed by <EOS>.
model_input = torch.tensor([token_to_id['What'],
                            token_to_id['is'],
                            token_to_id['PyTorch'],
                            token_to_id['<EOS>']])

input_length = model_input.size(dim=0)

# Cap on prompt + generated tokens; matches the max_len=6 the model was built with.
max_length = 6

# Generation needs no gradients, so skip building autograd graphs.
with torch.no_grad():
  # Only the last row matters: it holds the scores for the token after the prompt.
  predictions = model(model_input)
  prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
  prediction_ids = prediction_id

  # Greedy decoding: append the best token and re-run until <EOS> or the cap.
  for _ in range(input_length, max_length):
    if prediction_id.item() == token_to_id['<EOS>']:
      break
    model_input = torch.cat((model_input, prediction_id))
    predictions = model(model_input)
    prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    prediction_ids = torch.cat((prediction_ids, prediction_id))

print("Predicted Tokens:")
# Named token_id so the builtin id() is not shadowed for the rest of the notebook.
for token_id in prediction_ids:
  print("\t", id_to_token[token_id.item()])

Predicted Tokens:
	 Awesome
	 <EOS>


In [None]:
# Prompt: the second training question (reversed word order) followed by <EOS>.
model_input = torch.tensor([token_to_id['PyTorch'],
                            token_to_id['is'],
                            token_to_id['What'],
                            token_to_id['<EOS>']])

input_length = model_input.size(dim=0)

# Cap on prompt + generated tokens; matches the max_len=6 the model was built with.
max_length = 6

# Generation needs no gradients, so skip building autograd graphs.
with torch.no_grad():
  # Only the last row matters: it holds the scores for the token after the prompt.
  predictions = model(model_input)
  prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
  prediction_ids = prediction_id

  # Greedy decoding: append the best token and re-run until <EOS> or the cap.
  for _ in range(input_length, max_length):
    if prediction_id.item() == token_to_id['<EOS>']:
      break
    model_input = torch.cat((model_input, prediction_id))
    predictions = model(model_input)
    prediction_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    prediction_ids = torch.cat((prediction_ids, prediction_id))

print("Predicted Tokens:")
# Named token_id so the builtin id() is not shadowed for the rest of the notebook.
for token_id in prediction_ids:
  print("\t", id_to_token[token_id.item()])

Predicted Tokens:
	 Awesome
	 <EOS>
