In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from typing import Optional
from typing import Tuple

import nlp
import pytorch_lightning as pl
import torch

## SRU Implementation

In [2]:
class SRULayer(torch.nn.Module):
    def __init__(self, input_size, highway_bias=-3.0):
        super().__init__()
        self.input_size = input_size
        self.highway_bias = highway_bias
        self.weight = torch.nn.Parameter(torch.empty(input_size, 3*input_size))
        self.vector = torch.nn.Parameter(torch.empty(2*input_size))
        self.bias = torch.nn.Parameter(torch.empty(2*input_size))
        self.register_buffer("alpha", torch.sqrt(torch.tensor(3.0)))
        
        self.reset()
        
    @torch.no_grad()
    def reset(self):
        v = torch.sqrt(torch.tensor(3.0) / self.input_size)
        torch.nn.init.uniform_(self.weight, a=-v, b=v)
        torch.nn.init.uniform_(self.vector, a=-v, b=v)
        self.bias.fill_(self.highway_bias)
        
    def forward(self, x, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Args:
            x: torch.Tensor with size [N, T, F]
        """
        batch = x.size(0)
        if state is None:
            cell = torch.zeros(batch, self.input_size, dtype=x.dtype, device=x.device)
            hidden = torch.zeros(batch, self.input_size, dtype=x.dtype, device=x.device)
        else:
            cell = state[0]
            hidden = state[1]
            
        u = x @ self.weight   # N, T, 3H
        uw, uf, ur = u.split(self.input_size, dim=2)
        
        length = x.size(1)
        out = torch.empty_like(x)
        for i in range(length):
            tf = (self.vector[:self.input_size] * cell) + self.bias[:self.input_size]
            tr = (self.vector[self.input_size:] * cell) + self.bias[self.input_size:]

            ft = torch.sigmoid(uf[:, i] + tf)
            rt = torch.sigmoid(ur[:, i] + tr)

            cell_t = ft * cell + (1 - ft) * uw[:, i]
            hidden = rt * cell_t + (1 - rt) * x[:, i] * self.alpha
            out[:, i] = hidden
            cell = cell_t
        
        return out, (cell, hidden)

In [3]:
class SRU(torch.nn.Module):
    def __init__(self, n_hidden, n_layers, highway_bias=-3.0):
        super().__init__()
        self.srus = torch.nn.ModuleList([
            #torch.jit.script(SRULayer(n_hidden, highway_bias))
            SRULayer(n_hidden, highway_bias)
            for _ in range(n_layers)
        ])
        
    def forward(self, x, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        h = x
        new_h: List[torch.Tensor] = []
        new_c: List[torch.Tensor] = []
        for i, sru in enumerate(self.srus):
            if state is not None:
                s = state[0][i], state[1][i]
            else:
                s = None
            h, new_s = sru(h, s)
            new_h.append(new_s[0])
            new_c.append(new_s[1])
            
        new_h_t = torch.stack(new_h)
        new_c_t = torch.stack(new_c)
        
        return h, (new_h_t, new_c_t)

In [4]:
# batch = 8
# seq_len = 128
# input_size = 64
# n_layers = 2

# x = torch.empty(batch, seq_len, input_size).normal_()

# sru = torch.jit.script(SRU(input_size, n_layers)).cuda()
# #sru = SRU(input_size, n_layers)
# sru(x.cuda())[0]

In [5]:
# sru(x.cuda())

## Tokenizer

In [6]:
import math
import os
import pickle
import tempfile
import time

import torch.nn.functional as F
from tokenizers import SentencePieceBPETokenizer
from tqdm import tqdm, trange

In [7]:
class WikiText(torch.utils.data.Dataset):
    def __init__(self, split, tokenizer, seq_len=32):
        self.split = split
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.data = self._load()
        
    def _load(self):
        all_data = []
        for x in tqdm(nlp.load_dataset("wikitext", split=self.split)):
            if not x["text"]:
                continue
            all_data.append(x["text"].strip())
            
        enc_data = []
        for d in tqdm(all_data):
            enc_data.extend(self.tokenizer.encode(d).ids)
        
        return enc_data
                   
    def __len__(self):
        return math.ceil(len(self.data) / self.seq_len)
    
    def __getitem__(self, idx):
        base = idx * self.seq_len
        x = self.data[base    :base     + self.seq_len]
        y = self.data[base + 1:base + 1 + self.seq_len]

        if len(x) < self.seq_len:
            x = x + [0] * (self.seq_len - len(x))
        if len(y) < self.seq_len:
            y = y + [0] * (self.seq_len - len(y))
            
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
            
    def __repr__(self):
        return f"{self.__class__.__name__}(split='{self.split}', tokenizer={self.tokenizer})"

In [10]:
class Model(pl.LightningModule):
    def __init__(
        self, 
        train_batch_size, 
        valid_batch_size=1, 
        vocab_size=20000, 
        n_layers=1,
        n_hidden=512,
        seq_len=256
    ):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        
        self.tokenizer = self._init_tokenizer()
        self.train_dataset = WikiText("validation", self.tokenizer, seq_len)
        self.valid_dataset = WikiText("validation", self.tokenizer, seq_len)
        
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=n_hidden
        )
        
        self.sru = torch.jit.script(SRU(n_hidden, n_layers))
        
        self.proj = torch.nn.Linear(
            in_features=n_hidden,
            out_features=vocab_size
        )
        
        self.state = None
        
    def _init_tokenizer(self):
        if not os.path.exists("data/vocab.json"):
            tokenizer = SentencePieceBPETokenizer()
            tokenizer.train(
                "data/wikitext-103-raw/wiki.train.raw",
                vocab_size=self.vocab_size
            )
            tokenizer.save("data/")
        else:
            tokenizer = SentencePieceBPETokenizer(
                vocab_file="data/vocab.json",
                merges_file="data/merges.txt"
            )
        return tokenizer
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=self.valid_dataset,
            batch_size=self.valid_batch_size,
            num_workers=8,
            pin_memory=True
        )
        
    def forward(self, x, state=None):
        """
        
        """
        h = self.embedding(x)
        h, state = self.sru(h)
        return self.proj(h), state
    
    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat, state = self(x)
        loss = F.cross_entropy(y_hat.view(-1, y_hat.size(-1)), y.view(-1))
        log = {"train_loss": loss}
        return {"loss": loss, "log": log}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        with torch.no_grad():
            y_hat, state = self(x)
            loss = F.cross_entropy(y_hat.view(-1, y_hat.size(-1)), y.view(-1))
        return {"val_loss": loss}
    
    def validation_epoch_end(self, outputs):
        loss = ((1.0 / len(outputs)) * sum([x['val_loss'].item() for x in outputs]))
        try:
            ppl = math.exp(loss)
        except OverflowError:
            ppl = float("inf")
        log = {"val_loss": loss, "perplexity": ppl}
        return {'val_loss': ppl, "log": log}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [11]:
model = Model(
    train_batch_size=128,
    valid_batch_size=64,
    vocab_size=8096,
    seq_len=512,
    n_layers=16,
    n_hidden=256
)

100%|██████████| 3760/3760 [00:00<00:00, 104653.72it/s]
100%|██████████| 2461/2461 [00:00<00:00, 10888.37it/s]
100%|██████████| 3760/3760 [00:00<00:00, 98996.78it/s]
100%|██████████| 2461/2461 [00:00<00:00, 10797.72it/s]


In [12]:
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    val_check_interval=0.2
)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]
Using 16bit precision.


In [None]:
trainer.fit(model)


   | Name        | Type                  | Params
--------------------------------------------------
0  | embedding   | Embedding             | 2 M   
1  | sru         | RecursiveScriptModule | 3 M   
2  | sru.srus    | RecursiveScriptModule | 3 M   
3  | sru.srus.0  | RecursiveScriptModule | 197 K 
4  | sru.srus.1  | RecursiveScriptModule | 197 K 
5  | sru.srus.2  | RecursiveScriptModule | 197 K 
6  | sru.srus.3  | RecursiveScriptModule | 197 K 
7  | sru.srus.4  | RecursiveScriptModule | 197 K 
8  | sru.srus.5  | RecursiveScriptModule | 197 K 
9  | sru.srus.6  | RecursiveScriptModule | 197 K 
10 | sru.srus.7  | RecursiveScriptModule | 197 K 
11 | sru.srus.8  | RecursiveScriptModule | 197 K 
12 | sru.srus.9  | RecursiveScriptModule | 197 K 
13 | sru.srus.10 | RecursiveScriptModule | 197 K 
14 | sru.srus.11 | RecursiveScriptModule | 197 K 
15 | sru.srus.12 | RecursiveScriptModule | 197 K 
16 | sru.srus.13 | RecursiveScriptModule | 197 K 
17 | sru.srus.14 | RecursiveScriptModule | 197 K

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…