In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-1.5.7-py3-none-any.whl (526 kB)
[?25l[K     |▋                               | 10 kB 20.2 MB/s eta 0:00:01[K     |█▎                              | 20 kB 9.3 MB/s eta 0:00:01[K     |█▉                              | 30 kB 6.0 MB/s eta 0:00:01[K     |██▌                             | 40 kB 5.6 MB/s eta 0:00:01[K     |███▏                            | 51 kB 5.1 MB/s eta 0:00:01[K     |███▊                            | 61 kB 5.3 MB/s eta 0:00:01[K     |████▍                           | 71 kB 5.0 MB/s eta 0:00:01[K     |█████                           | 81 kB 5.6 MB/s eta 0:00:01[K     |█████▋                          | 92 kB 5.6 MB/s eta 0:00:01[K     |██████▎                         | 102 kB 5.1 MB/s eta 0:00:01[K     |██████▉                         | 112 kB 5.1 MB/s eta 0:00:01[K     |███████▌                        | 122 kB 5.1 MB/s eta 0:00:01[K     |████████                        | 133 kB 5.1 MB

In [None]:
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam

import pytorch_lightning as pl

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Every head works on full-width projections: the query, key and value
    layers each map ``d`` features to ``d * heads``, and ``unifyheads`` mixes
    the concatenated head outputs back down to ``d`` features.

    Args:
        d: Size of the feature (last) dimension of the input.
        heads: Number of attention heads.
    """

    def __init__(self, d: int, heads: int=8):
        super().__init__()
        self.k = d
        self.h = heads

        wide = d * heads
        self.WQ = nn.Linear(d, wide, bias=False)
        self.WK = nn.Linear(d, wide, bias=False)
        self.WV = nn.Linear(d, wide, bias=False)

        self.unifyheads = nn.Linear(wide, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence axis of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d)``.
        """
        batch, seq_len, dim = x.size()
        n_heads = self.h

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq_len, heads * dim) -> (batch * heads, seq_len, dim),
            # folding the head axis into the batch axis so bmm can be used.
            per_head = projected.view(batch, seq_len, n_heads, dim).transpose(1, 2)
            return per_head.contiguous().view(batch * n_heads, seq_len, dim)

        q = split_heads(self.WQ(x))
        k = split_heads(self.WK(x))
        v = split_heads(self.WV(x))

        # Scaled dot-product scores, normalised over the key positions.
        scores = torch.bmm(q, k.transpose(1, 2)) / np.sqrt(dim)
        weights = F.softmax(scores, dim=-1)

        # Unfold the head axis again and concatenate the heads feature-wise.
        attended = torch.bmm(weights, v).view(batch, n_heads, seq_len, dim)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, n_heads * dim)

        return self.unifyheads(merged)

In [None]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then a two-layer ReLU MLP.

    Each sub-layer is wrapped in a residual connection, and layer
    normalisation is applied after the residual sum.

    Args:
        d: Feature dimension of the input and output.
        heads: Number of attention heads passed to ``SelfAttention``.
        n_mlp: Width multiplier for the hidden layer of the MLP.
    """

    def __init__(self, d: int, heads: int=8, n_mlp: int=4):
        super().__init__()

        hidden = n_mlp * d

        self.attention = SelfAttention(d, heads=heads)
        self.norm1 = nn.LayerNorm(d)
        self.norm2 = nn.LayerNorm(d)

        self.ff = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to ``x``."""
        attended = self.norm1(self.attention(x) + x)
        return self.norm2(self.ff(attended) + attended)

In [None]:
from keras.datasets import imdb
from keras.preprocessing.sequence import pad_sequences

class IMDBDataLoader(pl.LightningDataModule):
    """Lightning data module serving the Keras IMDB reviews as padded id tensors.

    Typical use::

        data = IMDBDataLoader(batch_size=32)
        data.setup(num_words=10000, max_seq_len=128)

    Args:
        batch_size: Number of reviews per batch for every dataloader.
    """

    def __init__(self, batch_size: int):
        super().__init__()
        self.batch_size = batch_size
        # Remembered by ``setup`` so that hook-style calls, which carry only a
        # stage, can reuse the vocabulary size and sequence length.
        self.num_words: Optional[int] = None
        self.max_seq_len: Optional[int] = None
        # (num_words, max_seq_len) of the data currently held, or None.
        self._loaded_config = None

    def setup(self, num_words: Optional[int]=None, max_seq_len: Optional[int]=None,
              stage: Optional[str]=None):
        """Load, index and pad the dataset.

        The signature accepts the explicit configuration call
        ``setup(num_words=..., max_seq_len=...)`` as well as the
        ``setup(stage=...)`` form of the ``LightningDataModule`` hook. A call
        that repeats the configuration already loaded is a no-op.

        Args:
            num_words: Vocabulary size forwarded to ``imdb.load_data``. When
                omitted, the value from the previous ``setup`` call is reused.
            max_seq_len: Forwarded to ``imdb.load_data`` as ``maxlen`` and used
                as the padded length. When omitted, the previous value is reused.
            stage: Stage name supplied by Lightning; not used, because the
                train and test splits are always loaded together.

        Raises:
            ValueError: If no vocabulary size / sequence length has been given
                in this call or in an earlier one.
        """
        # A stage handed over positionally lands in the first parameter.
        if isinstance(num_words, str):
            stage, num_words = num_words, None

        if num_words is None:
            num_words = self.num_words
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        if num_words is None or max_seq_len is None:
            raise ValueError(
                'IMDBDataLoader.setup() needs num_words and max_seq_len; call '
                'setup(num_words=..., max_seq_len=...) before handing the data '
                'module to a Trainer.'
            )

        config = (num_words, max_seq_len)
        if self._loaded_config == config:
            return

        (self.x_train, self.y_train), (self.x_test, self.y_test) = imdb.load_data(
            num_words=num_words,
            maxlen=max_seq_len
        )

        # Word ids are shifted by 3 so they do not collide with the special
        # tokens occupying ids 0-3.
        self.word2idx = dict(
            **{k: v+3 for k, v in imdb.get_word_index().items()},
            **{'<PAD>': 0, '<START>': 1, '<UNK>': 2, '<UNUSED>': 3,
              },
        )
        self.idx2word = {v: k for k, v in self.word2idx.items()}

        # Pad with 0, the id reserved for '<PAD>' above.
        self.x_train = pad_sequences(self.x_train, maxlen=max_seq_len, value=0.0)
        self.x_test = pad_sequences(self.x_test, maxlen=max_seq_len, value=0.0)

        self.num_words, self.max_seq_len = num_words, max_seq_len
        self._loaded_config = config

    def example(self):
        """Return a random training review, decoded to text, with its label.

        Token ids 0 and 1 ('<PAD>' and '<START>') are left out of the text.
        """
        idx = np.random.randint(0, len(self.x_train))
        x, y = self.x_train[idx], self.y_train[idx]
        review = ' '.join(self.idx2word[token_id] for token_id in x if token_id > 1)
        sentiment = 'POSITIVE' if y else 'NEGATIVE'
        return f'Review : {review}\nSentiment: {sentiment}'

    def _make_loader(self, x, y, shuffle: bool=False):
        """Wrap id sequences and labels in a ``DataLoader`` of long tensors."""
        dataset = TensorDataset(torch.LongTensor(x), torch.LongTensor(y))
        return DataLoader(dataset, self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Training batches, reshuffled every epoch."""
        return self._make_loader(self.x_train, self.y_train, shuffle=True)

    def test_dataloader(self):
        """Test batches in fixed order."""
        return self._make_loader(self.x_test, self.y_test)

    def val_dataloader(self):
        """Validation batches; this module has no separate split, so the test set is reused."""
        return self.test_dataloader()

In [None]:
class IMDBTransformer(pl.LightningModule):
    """Transformer encoder with a mean-pooled classification head.

    Token and learned position embeddings are summed, passed through ``depth``
    ``TransformerBlock`` layers, averaged over the sequence axis and projected
    to ``num_classes`` logits.

    Args:
        d: Embedding / model dimension.
        heads: Attention heads per block.
        depth: Number of transformer blocks.
        max_seq_len: Number of position embeddings, i.e. the longest sequence
            the model accepts.
        num_tokens: Size of the token embedding table.
        num_classes: Number of output logits.
        learning_rate: Learning rate for Adam.
    """

    def __init__(self, d: int=128, heads: int=8, depth: int=6,
                max_seq_len: int=512, num_tokens: int=30000,
                num_classes: int=2, learning_rate: float=1e-4):
        super().__init__()

        self.save_hyperparameters()

        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, d)
        self.pos_emb = nn.Embedding(max_seq_len, d)

        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(d=d, heads=heads) for _ in range(depth)]
        )

        self.classification = nn.Linear(d, num_classes)

        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def accuracy(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Fraction of rows in ``logits`` whose arg-max equals the label in ``y``.

        Implemented with plain tensor ops so the module does not depend on the
        ``pytorch_lightning.metrics`` package, which is absent from newer
        Lightning releases. Still callable as ``self.accuracy(logits, y)``.
        """
        return (logits.argmax(dim=-1) == y).float().mean()

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Long tensor of shape ``(batch, seq_len)``.

        Returns:
            Float tensor of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position embeddings.
        """
        b, l = x.size()
        max_len = self.pos_emb.num_embeddings
        if l > max_len:
            raise ValueError(
                f'sequence length {l} exceeds max_seq_len={max_len}'
            )

        tokens = self.token_emb(x)
        # Build the position ids directly on the input's device; the (l, d)
        # result broadcasts over the batch axis in the sum below.
        positions = self.pos_emb(torch.arange(l, device=x.device))
        embeddings = tokens + positions

        out = self.transformer_blocks(embeddings)

        # Mean over every sequence position (padded positions included).
        out = out.mean(dim=1)
        out = self.classification(out)

        return out

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return Adam(self.parameters(), lr=self.hparams.learning_rate)

    def training_step(self, batch, batch_idx):
        """Log 'loss' and 'acc' for the batch and return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log('loss', loss, on_epoch=True, prog_bar=True)
        self.log('acc', self.accuracy(logits, y), on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log 'test_loss' and 'test_acc' for the batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        self.log('test_loss', loss, on_epoch=True)
        self.log('test_acc', self.accuracy(logits, y), on_epoch=True,
                 prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Validation reuses the test step, so it logs under the 'test_*' keys."""
        return self.test_step(batch, batch_idx)

In [None]:
# Experiment configuration.
SEED = 42
NUM_WORDS = 10000
MAX_SEQ_LEN = 128
EMBEDDING_DIM = 128
BATCH_SIZE = 32
MAX_EPOCHS = 5

# Seed the Python, NumPy and torch RNGs so weight init and sampling are repeatable.
pl.seed_everything(SEED)

imdb_data = IMDBDataLoader(batch_size=BATCH_SIZE)
imdb_data.setup(num_words=NUM_WORDS, max_seq_len=MAX_SEQ_LEN)
model = IMDBTransformer(d=EMBEDDING_DIM, max_seq_len=MAX_SEQ_LEN, num_tokens=NUM_WORDS)
# Use one GPU when available; fall back to CPU instead of failing.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                     gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(model, imdb_data)
# Name the data module explicitly rather than relying on what fit() left attached.
test_results = trainer.test(datamodule=imdb_data)