In [143]:
pip install torch



In [295]:
import random

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce

In [145]:
torch.__version__

'2.0.1+cu118'

In [8]:
class DigitDataset(Dataset):
    def __init__(self,
                 source,
                 target):
        self.source = source
        self.target = target

    def __len__(self,):
        return len(self.source)

    def __getitem__(self, idx):
        output = dict(
            src = self.source[idx],
            tgt = self.target[idx]
        )
        return output

In [14]:
num_samples = 10000
seq_len = 5
vocab_size = len(range(10))



In [34]:
X = torch.randint(10, (num_samples, seq_len), requires_grad=False)
y = torch.flip(X, dims=[1])
y.requires_grad = False


train_size = 0.9
train_slice = int(train_size * num_samples)

train_dataset = DigitDataset(X[:train_slice], y[: train_slice])
test_dataset = DigitDataset(X[train_slice:], y[train_slice:])

In [37]:
len(train_dataset), len(test_dataset)

(9000, 1000)

In [39]:
train_dataset.source[0], train_dataset.target[0]

(tensor([8, 2, 1, 1, 9]), tensor([9, 1, 1, 2, 8]))

In [40]:
test_dataset.source[0], test_dataset.target[0]

(tensor([9, 1, 1, 3, 0]), tensor([0, 3, 1, 1, 9]))

In [41]:
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=False, batch_size=512)
test_dataloader = DataLoader(test_dataset, shuffle=False, drop_last=False, batch_size=512)

In [181]:
test_input['src'].shape

torch.Size([512, 5])

In [202]:
haha = LSTMEncoder(nn.Embedding(10, 100),
                   100, 4, 0.5, 0.5)
hc = haha(test_input['src'])

In [207]:
dec = LSTMDecoder(nn.Embedding(10, 100),
                  100,
                  4,
                  10,)


In [210]:
dec(test_input['tgt'][:, 0], hc)[0].shape

torch.Size([512, 10])

### Model itself

In [278]:
from IPython.utils.path import target_outdated
class LSTMEncoder(nn.Module):
    def __init__(self,
                 embedding_layer,
                 hidden_size,
                 num_layers,
                 lstm_dropout,
                 embedding_dropout):

        super().__init__()
        self.embedding_layer = embedding_layer
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_dropout = lstm_dropout
        self.embedding_dropout = embedding_dropout
        self.emb_dropout = nn.Dropout(p=self.embedding_dropout)

        self.lstm_encoder = nn.LSTM(input_size=self.hidden_size,
                                    hidden_size=self.hidden_size,
                                    num_layers=self.num_layers,
                                    dropout=self.lstm_dropout,
                                    batch_first=True)

    def forward(self, src):
        batch_size = src.shape[1]
        seq_len = src.shape[0]
        _, hc = self.lstm_encoder(self.emb_dropout(self.embedding_layer(src)))
        return hc


class LSTMDecoder(nn.Module):
    def __init__(self,
                 embedding_layer,
                 hidden_size,
                 num_layers,
                 num_classes,
                 lstm_dropout=0.2,
                 embedding_dropout=0.2):
        super().__init__()
        self.embedding_layer = embedding_layer
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_dropout = lstm_dropout
        self.embedding_dropout = embedding_dropout

        self.embedding_layer = embedding_layer #embedding layer is same for both encoder & decoder with emb_dim = hidden_dim
        self.emb_dropout = nn.Dropout(p=self.embedding_dropout)
        self.rnn = nn.LSTM(input_size=self.hidden_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            dropout=self.lstm_dropout,
                            batch_first=True)

        self.proj_layer = nn.Linear(hidden_size, num_classes)

    def forward(self,
                tgt,
                hc):

        """
        decoder is autoregressive,
        input:
            hc-> hidden state from encoder
            tgt -> vector of shape [b_size, 1] of current seq input
        """
        tgt = tgt.unsqueeze(1) #adding 1st dim, so we have seq_len == 1

        y_pred, hc = self.rnn(self.emb_dropout(self.embedding_layer(tgt)), hc)
        #print(y_pred.shape) #for debugging
        y_pred = y_pred.squeeze(1)  #going back to [b_size, 100]

        return self.proj_layer(y_pred), hc



class Seq2Seq(nn.Module):
    def __init__(self,
                 hidden_size,
                 num_layers,
                 num_classes,
                 vocab_size,
                 lstm_dropout=0.2,
                 embedding_dropout=0.2):

        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.vocab_size = vocab_size
        self.lstm_dropout = lstm_dropout
        self.embedding_dropout = embedding_dropout

        self.embedding_layer = nn.Embedding(vocab_size, hidden_size)
        self.encoder = LSTMEncoder(self.embedding_layer,
                                   hidden_size,
                                   num_layers,
                                   lstm_dropout,
                                   embedding_dropout)

        self.decoder = LSTMDecoder(self.embedding_layer,
                                   hidden_size,
                                   num_layers,
                                   num_classes,
                                   lstm_dropout,
                                   embedding_dropout)

    def forward(self, src, tgt, teacher_forcing: float=0.5):
        hidden_cell = self.encoder(src)
#        print(hidden_cell[0].shape, hidden_cell[1].shape)

        b_size, max_seq_len =  tgt.shape[0], tgt.shape[1]
        input_token = tgt[:, 0]
        outputs_seq = torch.zeros(b_size, max_seq_len, self.vocab_size)

        for t in range(1, max_seq_len):
            output, hidden_cell = self.decoder(input_token, hidden_cell)
            outputs_seq[:, t] = output
            predicted_token = torch.argmax(output, dim=-1)

            do_teacher_forcing = random.random() < teacher_forcing

            if do_teacher_forcing:
                input_token = tgt[:, t]
            else:
                input_token = predicted_token


        return outputs_seq

In [311]:
hidden_size = 100
num_layers=2
num_classes = 10
vocab_size = 10
net = Seq2Seq(hidden_size, num_layers, num_classes, vocab_size)


loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [312]:
from tqdm.auto import tqdm
def train_one_epoch(train_dataloader, model, optimizer, loss_fn, epoch, ):
    model.train()
    train_loss = 0
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        src, tgt = batch['src'], batch['tgt']
        preds = model(src, tgt)
        preds = rearrange(preds, 'b s c -> b c s')

        loss = loss_fn(preds, tgt)
        loss.backward()

        optimizer.step()

        train_loss += loss.item()

    train_loss /= idx
    print(f'Epoch{epoch}, train_loss: {train_loss}')
    return train_loss

In [313]:
for epoch in tqdm(range(100)):
    train_one_epoch(train_dataloader, net, optimizer, loss_fn, epoch)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch0, train_loss: 2.389103005914127
Epoch1, train_loss: 2.000439728007597
Epoch2, train_loss: 1.66267491088194
Epoch3, train_loss: 1.488317742067225
Epoch4, train_loss: 1.325013279914856
Epoch5, train_loss: 1.1860505307421965
Epoch6, train_loss: 1.0487101568895227
Epoch7, train_loss: 0.8977418191292706
Epoch8, train_loss: 0.7818280177957871
Epoch9, train_loss: 0.7112958641613231
Epoch10, train_loss: 0.6617464317994959
Epoch11, train_loss: 0.6128384470939636
Epoch12, train_loss: 0.595630431876463
Epoch13, train_loss: 0.5668694937930387
Epoch14, train_loss: 0.5570773236891803
Epoch15, train_loss: 0.5436872138696558
Epoch16, train_loss: 0.5355622505440432
Epoch17, train_loss: 0.5280459803693435
Epoch18, train_loss: 0.5218996949055615
Epoch19, train_loss: 0.5183721202261308
Epoch20, train_loss: 0.5154929564279669
Epoch21, train_loss: 0.5113664041547215
Epoch22, train_loss: 0.5103302107137793
Epoch23, train_loss: 0.5065411592231077
Epoch24, train_loss: 0.5034951181972728
Epoch25, train_lo

In [314]:
def inference(batch,  model):
    model.eval()
    with torch.inference_mode():
        src, tgt = batch['src'], batch['tgt']
        preds = model(src, tgt)
        preds = preds.argmax(dim=-1)

    return preds

In [317]:
test_input = next(iter(test_dataloader))
outs = inference(test_input, net)

In [320]:
test_input['src'][:5], outs[:5]

(tensor([[9, 1, 1, 3, 0],
         [9, 9, 1, 3, 6],
         [4, 3, 5, 6, 3],
         [6, 9, 2, 0, 0],
         [3, 0, 3, 3, 3]]),
 tensor([[0, 3, 1, 1, 9],
         [0, 3, 1, 9, 9],
         [0, 6, 5, 3, 4],
         [0, 0, 2, 9, 6],
         [0, 3, 3, 0, 3]]))