In [1]:
from typing import Optional
from typing import Tuple

import nlp
import pytorch_lightning as pl
import torch

## SRU Implementation

In [2]:
class SRUCell(torch.nn.Module):
    """Single-step Simple Recurrent Unit (SRU) cell.

    For an input ``x`` of size [N, F] and a previous cell state ``c`` of size
    [N, F] (F == ``input_size``) one step computes::

        uw, uf, ur = split(x @ weight)          # three [N, F] projections
        f  = sigmoid(uf + vf * c + bf)          # forget gate
        r  = sigmoid(ur + vr * c + br)          # reset gate
        c' = f * c + (1 - f) * uw
        h' = r * c' + (1 - r) * x * alpha

    where ``vf``/``vr`` and ``bf``/``br`` are the two halves of ``vector`` and
    ``bias`` and ``alpha`` is the constant sqrt(3).
    """

    def __init__(self, input_size):
        """
        Args:
            input_size: number of features F; the hidden size equals F.
        """
        super().__init__()
        self.input_size = input_size
        self.weight = torch.nn.Parameter(torch.empty(input_size, 3*input_size))
        self.vector = torch.nn.Parameter(torch.empty(2*input_size))
        # Registered as a parameter so that it is trained, stored in the
        # state dict and moved by .to()/.cuda() together with the weights.
        self.bias = torch.nn.Parameter(torch.empty(2*input_size))
        self.register_buffer("alpha", torch.sqrt(torch.tensor(3.0)))

        self.reset()

    def reset(self):
        """Re-initialise weight and vector uniformly in [-b, b], b = sqrt(3 / F); zero the bias."""
        bound = (3.0 / self.input_size) ** 0.5
        torch.nn.init.uniform_(self.weight, a=-bound, b=bound)
        torch.nn.init.uniform_(self.vector, a=-bound, b=bound)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Args:
            x: torch.Tensor with size [N, F]
            state: optional ``(cell, hidden)`` tuple from the previous step,
                each of size [N, F]. Only ``cell`` is read; when omitted the
                cell state starts at zero.

        Returns:
            ``(hidden_t, (cell_t, hidden_t))`` where both tensors have size [N, F].
        """
        if state is None:
            # Allocate on the same device / dtype as the input.
            cell = x.new_zeros([x.size(0), self.input_size])
        else:
            cell = state[0]

        u = x @ self.weight   # N, 3H
        uw, uf, ur = u.split(self.input_size, dim=1)

        # Each [H] half broadcasts over the batch dimension, so every sample
        # only ever sees its own cell state and any batch size is supported.
        vf, vr = self.vector.split(self.input_size)
        bf, br = self.bias.split(self.input_size)

        ft = torch.sigmoid(uf + vf * cell + bf)
        rt = torch.sigmoid(ur + vr * cell + br)

        cell_t = ft * cell + (1 - ft) * uw
        hidden_t = rt * cell_t + (1 - rt) * x * self.alpha

        return hidden_t, (cell_t, hidden_t)

In [3]:
class SRU(torch.nn.Module):
    """Runs a TorchScript-compiled ``SRUCell`` over the time axis of a batch-first input."""

    def __init__(self, input_size):
        """
        Args:
            input_size: number of features per time step, forwarded to ``SRUCell``.
        """
        super().__init__()
        self.cell = torch.jit.script(SRUCell(input_size))

    def forward(self, x, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Args:
            x: torch.Tensor with size [N, T, F]
            state: optional initial ``(cell, hidden)`` state handed to the
                cell at the first time step; ``None`` lets the cell start
                from its default state.

        Returns:
            torch.Tensor with the same size as ``x`` holding the cell output
            of every time step.
        """
        length = x.size(1)
        out = torch.empty_like(x)
        for i in range(length):
            # Feed the state produced at step i-1 into step i; without this
            # every step would restart from the initial state and the layer
            # would not be recurrent at all.
            out[:, i], state = self.cell(x[:, i], state)
        return out

In [4]:
# Toy problem size: 2 sequences of 100 steps with 64 features per step.
batch, seq_len, input_size = 2, 100, 64

# Standard-normal random input laid out as [N, T, F].
x = torch.empty(batch, seq_len, input_size).normal_()

# Compile the whole layer (loop included) with TorchScript.
sru = torch.jit.script(SRU(input_size))

In [5]:
# Smoke test: run the scripted SRU on the random batch; the result has the same size as x.
sru(x)

tensor([[[-0.0393, -1.4393, -1.8755,  ...,  1.8419,  0.0745, -1.4392],
         [ 1.8173, -0.1810, -1.0920,  ...,  0.0106, -0.7744, -1.5187],
         [-0.6281,  0.2251,  0.8739,  ...,  0.3141, -2.8450,  0.2109],
         ...,
         [ 1.8597,  1.2647, -1.4092,  ...,  0.3748,  0.1780,  0.8282],
         [-0.0764,  1.6122,  0.6091,  ...,  0.6725, -0.1031, -0.0130],
         [ 0.0983,  0.8663, -0.5963,  ...,  0.8382, -1.1926,  0.6968]],

        [[ 0.5164,  0.2064, -0.9087,  ...,  0.8550,  1.5220, -1.8969],
         [ 0.1648, -1.2897, -0.5557,  ...,  1.0085,  0.1047, -1.3821],
         [ 0.1513,  0.5151, -0.3657,  ...,  0.1641, -0.2462,  0.5926],
         ...,
         [ 0.2457,  2.7698,  0.9129,  ..., -0.6625, -1.0116, -0.3728],
         [ 0.7080, -1.7702,  0.5850,  ...,  1.1599, -0.7030, -1.2377],
         [-2.2299, -1.6632, -0.5999,  ..., -1.7677, -0.3579, -1.3820]]],
       grad_fn=<CopySlices>)

## Tokenizer

In [6]:
import math
import os
import tempfile
import time

from tokenizers import SentencePieceBPETokenizer

In [7]:
def chunks(lst, n, drop_short=True):
    """Yield successive n-sized chunks from lst.

    Args:
        lst: sequence that supports ``len`` and slicing.
        n: chunk size; must be a positive integer.
        drop_short: when True (the default) a trailing chunk with fewer than
            ``n`` items is discarded; when False it is yielded as well.

    Yields:
        Slices of ``lst`` of length ``n`` (the last one may be shorter when
        ``drop_short`` is False).

    Raises:
        ValueError: if ``n`` is not positive (raised when iteration starts).
    """
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n}")
    for i in range(0, len(lst), n):
        c = lst[i:i + n]
        if len(c) == n or not drop_short:
            yield c

In [8]:
class WikiText(torch.utils.data.Dataset):
    """WikiText split arranged as up to ``batch_size`` parallel token streams.

    The text of the split is divided into consecutive groups of roughly equal
    character count. Every group is tokenized and cut into chunks of
    ``seq_len`` token ids. Item ``idx`` stacks the ``idx``-th chunk of every
    group into one LongTensor, so one item is already a full batch.
    """

    def __init__(self, split, tokenizer, batch_size=1, seq_len=32):
        """
        Args:
            split: split name passed to ``nlp.load_dataset``.
            tokenizer: object whose ``encode(text).ids`` gives the token ids.
            batch_size: number of parallel streams to build.
            seq_len: number of token ids per chunk.
        """
        self.split = split
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.data = self._load()

    def _encode(self, sentences):
        """Tokenize the space-joined sentences and cut the ids into seq_len-sized chunks."""
        ids = self.tokenizer.encode(" ".join(sentences)).ids
        return list(chunks(ids, self.seq_len))

    @staticmethod
    def _equalize(streams):
        """Truncate every stream to the length of the shortest one.

        ``__getitem__`` reads the same index from every stream, so all
        streams must hold the same number of chunks.
        """
        if not streams:
            return streams
        min_len = min(len(stream) for stream in streams)
        return [stream[:min_len] for stream in streams]

    def _load(self):
        """Read the split and return a list of streams, each a list of id chunks."""
        all_data = []
        tot_len = 0
        for x in nlp.load_dataset("wikitext", split=self.split):
            text = x["text"].strip()
            # Skip empty and whitespace-only lines.
            if not text:
                continue
            all_data.append(text)
            # Count the stripped length: it is the same quantity that is
            # accumulated in cur_chunk below.
            tot_len += len(text)

        data = []
        max_chunk = math.ceil(tot_len / self.batch_size)
        cur_chunk = 0
        cur_data = []
        for sentence in all_data:
            cur_data.append(sentence)
            cur_chunk += len(sentence)
            if cur_chunk >= max_chunk:
                data.append(self._encode(cur_data))
                cur_chunk = 0
                cur_data = []

        # Whatever is left after the last full group forms the final stream.
        if cur_data:
            data.append(self._encode(cur_data))

        return self._equalize(data)

    def __len__(self):
        """Number of items, i.e. chunks per stream (0 when nothing was loaded)."""
        return len(self.data[0]) if self.data else 0

    def __getitem__(self, idx):
        """Return a LongTensor of size [num_streams, seq_len] with chunk ``idx`` of every stream."""
        return torch.tensor([chunk[idx] for chunk in self.data], dtype=torch.long)

    def __repr__(self):
        return f"{self.__class__.__name__}(split='{self.split}', tokenizer={self.tokenizer})"

In [9]:
class Model(pl.LightningModule):
    """LightningModule scaffold owning the tokenizer and both WikiText datasets.

    ``forward`` and ``training_step`` are placeholders that return None.
    """

    def __init__(self, train_batch_size, valid_batch_size=1):
        """
        Args:
            train_batch_size: batch size handed to the "train" WikiText dataset.
            valid_batch_size: batch size handed to the "validation" WikiText dataset.
        """
        super().__init__()
        self.tokenizer = self._init_tokenizer()
        self.train_dataset = WikiText("train", self.tokenizer, train_batch_size)
        self.valid_dataset = WikiText("validation", self.tokenizer, valid_batch_size)

    def _init_tokenizer(self):
        """Load the BPE tokenizer from ``data/`` or train and save it when no vocab file exists."""
        if os.path.exists("data/vocab.json"):
            return SentencePieceBPETokenizer(
                vocab_file="data/vocab.json",
                merges_file="data/merges.txt"
            )

        fresh = SentencePieceBPETokenizer()
        fresh.train(
            "data/wikitext-103-raw/wiki.train.raw",
            vocab_size=20000
        )
        fresh.save("data/")
        return fresh

    @staticmethod
    def _loader(dataset):
        """Wrap a dataset in a DataLoader with automatic batching disabled.

        ``batch_size=None`` is used because every WikiText item is already a
        stacked batch of token-id chunks.
        """
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=None,
        )

    def train_dataloader(self):
        """DataLoader over the training dataset."""
        return self._loader(self.train_dataset)

    def valid_dataloader(self):
        """DataLoader over the validation dataset.

        NOTE(review): Lightning's validation hook is named ``val_dataloader``;
        under this name the method is only reached by direct calls - confirm
        whether that is intended.
        """
        return self._loader(self.valid_dataset)

    def forward(self, x):
        """Placeholder: not implemented yet, returns None."""
        return None

    def training_step(self, batch, batch_nb):
        """Placeholder: not implemented yet, returns None."""
        return None

In [None]:
# Builds the tokenizer and both datasets; valid_batch_size keeps its default of 1.
model = Model(train_batch_size=2)

In [None]:
# Display the tokenizer that Model loaded or trained in _init_tokenizer.
model.tokenizer

In [186]:
# Peek at the first item of the validation loader: a LongTensor of token ids, one row per stream.
next(iter(model.valid_dataloader()))

tensor([[ 1045, 12170, 13261, 13780,  3671,  1090,  1045, 12170, 13261, 13780,
          3671,  1090,  1008,  1775,  1089,  1005,  3294,  8619,  2932,  1185,
          2587,  8619,  2932,  1008,  1121,  1003,  2244,  1026,  1266,  1436,
          1011,  8619],
        [ 1045,  1045,  3379,  1030,  2420,  1045,  1045,  9118,  1028, 12997,
          1008, 10087,  1008,  1120,  3172,    72,  1687,  1962,  3722, 14273,
          1089,  1003,  2154,  1008,  1028, 18541,  1008,  4096,  1008,  1033,
          6337,  1016]])