<a href="https://colab.research.google.com/github/reissonsaavedramiguel/transformer_pytorch_light/blob/main/Pytorch_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Pytorch Transformer
**By Reisson Saavedra**


In [20]:
! pip install pytorch-lightning



In [21]:
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as dataTorch

**Data**


1.   Generamos data input y output
2.   Output: Sequence(1,5,3)
3.   Input: Output con replacement






In [22]:
N = 10000
#Tamanio de la secuencia objetivo, secuence de input será 2 veces 
S = 32
#Número de clases, incluyendo 0.
C = 128
Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2  # Only generate ints in (2, 99) range

# Solo filas únicas
Y = torch.tensor(np.unique(Y, axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)

# Add special 0 "start" and 1 "end" tokens to beginning and end
Y = torch.cat([torch.zeros((N, 1)), Y, torch.ones((N, 1))], dim=1).long()
X = torch.cat([torch.zeros((N, 1)), X, torch.ones((N, 1))], dim=1).long()

# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

tensor([[  0,   2,   2,  ..., 103, 103,   1],
        [  0,   2,   2,  ...,  24,  24,   1],
        [  0,   2,   2,  ...,  75,  75,   1],
        ...,
        [  0,  14,  14,  ...,  62,  62,   1],
        [  0,  14,  14,  ...,  66,  66,   1],
        [  0,  14,  14,  ..., 112, 112,   1]]) torch.Size([10000, 62])
tensor([[  0,   2,   2,  ...,  95, 103,   1],
        [  0,   2,   2,  ...,  85,  24,   1],
        [  0,   2,   2,  ...,  40,  75,   1],
        ...,
        [  0,  14, 109,  ...,  92,  62,   1],
        [  0,  14, 109,  ..., 102,  66,   1],
        [  0,  14, 110,  ..., 118, 112,   1]]) torch.Size([10000, 32])
tensor(0) tensor(127)


In [23]:
BATCH_SIZE = 128
TRAIN_FRAC = .8
# This fulfills the pytorch.utils.data.Dataset interface
dataset = list(zip(X,Y))

num_train = int(N*TRAIN_FRAC)
num_val = N - num_train
dataTrain, dataVal = dataTorch.random_split(dataset,(num_train,num_val))

dataloader_train = torch.utils.data.DataLoader(dataTrain, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(dataVal, batch_size=BATCH_SIZE)

x,y = next(iter(dataloader_train))
x,y

(tensor([[ 0,  9,  9,  ..., 52, 52,  1],
         [ 0,  9,  9,  ..., 62, 62,  1],
         [ 0, 14, 14,  ..., 16, 16,  1],
         ...,
         [ 0,  4,  4,  ..., 83, 83,  1],
         [ 0, 14, 14,  ..., 59, 59,  1],
         [ 0,  6,  6,  ..., 18, 18,  1]]),
 tensor([[  0,   9,  79,  ...,  84,  52,   1],
         [  0,   9,  84,  ..., 100,  62,   1],
         [  0,  14,  39,  ...,  77,  16,   1],
         ...,
         [  0,   4,   9,  ...,  40,  83,   1],
         [  0,  14,  90,  ...,   3,  59,   1],
         [  0,   6,  17,  ...,  10,  18,   1]]))

## Model

![](https://media.arxiv-vanity.com/render-output/3715543/Figures/ModalNet-21.png)


In [24]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, dropout = .1, max_len = 5000):
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(p=dropout)

    pe= torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0,d_model,2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0).transpose(0,1)
    self.register_buffer('pe',pe)

  def forward(self, x):
    x = x + self.pe[:x.size(0), :]
    return self.dropout(x)

def generate_square_subsequent_mask(size: int):
  # Generate a triangular(size, size) mask .From PyT docs
  mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0,1)
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask
 

In [25]:
class Transformer(nn.Module):
  """
    Classic Transformer that both encodes and decodes.
    
    Prediction-time inference is done greedily.

    NOTE: start token is hard-coded to be 0, end token to be 1. If changing, update predict() accordingly.
    """
  def __init__(self, num_classes: int, max_output_length: int, dim: int= 128):
    super().__init__()

      #Parameters
    self.dim = dim
    self.max_output_length = max_output_length
    nhead = 4
    num_layers = 4
    dim_feedforward = dim

    #Encoder part
    self.embedding = nn.Embedding(num_classes, dim)
    self.pos_encoder = PositionalEncoding(d_model = self.dim)
    self.transformer_encoder = nn.TransformerEncoder(
       encoder_layer = nn.TransformerEncoderLayer(
           d_model=self.dim,
            nhead=nhead,
           dim_feedforward = dim_feedforward),
      num_layers = num_layers
    )

    #Decoder part

    self.y_mask = generate_square_subsequent_mask(self.max_output_length)
    self.transformer_decoder = nn.TransformerDecoder(
      decoder_layer = nn.TransformerDecoderLayer(
          d_model = self.dim,
           nhead = nhead,
          dim_feedforward = dim_feedforward),
      num_layers = num_layers
    )
    self.fc = nn.Linear(self.dim, num_classes)

    self.init_weights()
  
  def init_weights(self):
    initrange = .1
    self.embedding.weight.data.uniform_(-initrange, initrange)
    self.fc.bias.data.zero_()
    self.fc.weight.data.uniform_(-initrange,initrange)

  def forward(self, x: torch.Tensor, y: torch.Tensor):
    """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
    """
    encoded_x = self.encode(x)  # (Sx, B, E)
    output = self.decode(y, encoded_x)  # (Sy, B, C)
    return output.permute(1, 2, 0)  # (B, C, Sy)

  def encode(self, x: torch.Tensor) -> torch.Tensor:
    x = x.permute(1,0)
    x = self.embedding(x) * math.sqrt(self.dim)
    x = self.pos_encoder(x)
    x = self.transformer_encoder(x)
    return x

  def decode(self, y:torch.Tensor, encoded_x : torch.Tensor)->torch.Tensor:
    y = y.permute(1,0)
    y = self.embedding(y)*math.sqrt(self.dim)
    y = self.pos_encoder(y)
    Sy = y.shape[0]
    y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)
    output = self.transformer_decoder(y, encoded_x,y_mask)
    output = self.fc(output)
    return output

  def predict(self, x:torch.Tensor)->torch.Tensor:
    encoded_x = self.encode(x)
    output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long()
    output_tokens[:, 0] = 0

    for Sy in range(1, self.max_output_length):
      y = output_tokens[:, :Sy]  # (B, Sy)
      output = self.decode(y, encoded_x)  # (Sy, B, C)
      output = torch.argmax(output, dim=-1)  # (Sy, B)
      output_tokens[:, Sy] = output[-1:]  # Set the last output token
    return output_tokens

model = Transformer(num_classes = C, max_output_length = y.shape[1])
logits = model(x, y[:,:-1])
print(x.shape, y.shape,logits.shape)
print(x[0:1])
print(model.predict(x[0:1]))

torch.Size([128, 62]) torch.Size([128, 32]) torch.Size([128, 128, 31])
tensor([[  0,   9,   9,  79,  79,  14,  14,  30,  30, 110, 110,   6,   6,  96,
          96,  12,  12,  42,  42,  96,  96,  62,  62,  96,  96,  42,  42,  28,
          28,  61,  61, 127, 127,   5,   5,   4,   4, 100, 100,  55,  55,  17,
          17,  83,  83,  38,  38,  68,  68,  61,  61,  18,  18,  28,  28, 125,
         125,  84,  84,  52,  52,   1]])
tensor([[  0,  14,  14,  14, 126,   7, 126,   7, 126,   7, 126, 126,  78,   7,
         126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126, 126,  78, 126,
         126,   7, 126,  78]])


In [31]:
class LitModel (pl.LightningModule):
  def __init__(self, model):
    super().__init__()
    self.model = model
    self.loss = nn.CrossEntropyLoss()
    self.val_acc = pl.metrics.Accuracy()

  def training_step(self, batch, batch_ind):
    x, y = batch
    logits = self.model(x, y[:,:-1])
    loss = self.loss (logits, y[: , 1:])
    self.log('train_loss', loss)
    return loss

  def validation_step(self, batch, batch_ind):
    x,y = batch
    logits = self.model(x,y[:,:-1])
    loss = self.loss(logits, y[:, 1:])
    self.log("val_loss", loss, prog_bar=True)
    pred = self.model.predict(x)
    self.val_acc(pred, y)
    self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
  
  def configure_optimizers(self):
      return torch.optim.Adam(self.parameters())

model = Transformer(num_classes=C, max_output_length=y.shape[1])
lit_model = LitModel(model)
early_stop_callback = pl.callbacks.EarlyStopping(monitor='val_loss')
trainer = pl.Trainer(max_epochs=5, gpus=[0], callbacks=[early_stop_callback], progress_bar_refresh_rate=79)
trainer.fit(lit_model, dataloader_train, dataloader_val)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | Transformer      | 1.1 M 
1 | loss    | CrossEntropyLoss | 0     
2 | val_acc | Accuracy         | 0     
---------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [32]:
# We can see that the decoding works correctly

x, y = next(iter(dataloader_val))
print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
print(torch.cat((y[:1], pred)))

Input: tensor([[  0,  14,  14,  37,  37,  30,  30, 113, 113,  19,  19,  84,  84, 116,
         116,   5,   5,  15,  15,  82,  82,  76,  76, 125, 125,  20,  20,  21,
          21, 111, 111, 103, 103,  42,  42, 113, 113,  72,  72,  32,  32,  36,
          36,  88,  88, 106, 106, 112, 112,  70,  70,  26,  26,  28,  28,  91,
          91,  52,  52,  80,  80,   1]])
Truth/Pred:
tensor([[  0,  14,  37,  30, 113,  19,  84, 116,   5,  15,  82,  76, 125,  20,
          21, 111, 103,  42, 113,  72,  32,  36,  88, 106, 112,  70,  26,  28,
          91,  52,  80,   1],
        [  0,  14,  37,  30, 113,  19,  84, 116,   5,  15,  82,  76, 125,  20,
          21, 111, 103,  42, 113,  72,  32,  36,  88, 106, 112,  70,  26,  28,
          91,  52,  80,   1]])
