In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

import numpy as np
from datasets import load_dataset
from tokenizers import Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


### Load in the SNLI dataset.

In [2]:
snli_train = load_dataset("snli", cache_dir="../data", split="train").filter(lambda d: d["label"] != -1)
snli_test = load_dataset("snli", cache_dir="../data", split="test").filter(lambda d: d["label"] != -1)

In [3]:
print(snli_train[0])

{'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}


### Tokenize datasets using a pretrained tokenizer.

In [4]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(
        list(zip(batch["premise"], batch["hypothesis"]))
    )
    del batch["premise"]
    del batch["hypothesis"]
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["type_ids"] = [encoding.type_ids for encoding in encodings]
    batch["mask"] = [
        [bool(i) for i in encoding.attention_mask] for encoding in encodings
    ]
    return batch

seq_len = 128
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(seq_len)
tokenizer.enable_padding(length=seq_len)

snli_train_tokenized = snli_train.map(
    tokenization, batched=True, batch_size=1024
)
snli_test_tokenized = snli_test.map(
    tokenization, batched=True, batch_size=1024
)
snli_train_tokenized.set_format(type="numpy")
snli_test_tokenized.set_format(type="numpy")

### Batch the SNLI data with Pytorch dataloaders.

In [5]:
batch_size = 256
train_dataloader = DataLoader(snli_train_tokenized, batch_size, shuffle=True, num_workers=6)
test_dataloader = DataLoader(snli_test_tokenized, batch_size, num_workers=6)
print(len(train_dataloader), len(test_dataloader))

2146 39


### Build BLSTM model.

In [6]:
class BLSTM(nn.Module):
    def __init__(self, vocab_size, embed_size=192, dropout=0.1):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.type_embed = nn.Embedding(2, embed_size)
        
        self.blstm_first = nn.LSTM(input_size=embed_size, hidden_size=embed_size, dropout=dropout, bidirectional=True)
        self.blstm = nn.LSTM(input_size=embed_size * 2, hidden_size=embed_size, dropout=dropout, bidirectional=True)

        self.output_layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(embed_size * seq_len * 2, 3)
        )
    
    def forward(self, batch):
        # calculate embeddings
        ids, type_ids, mask = batch
        embeddings = self.embed(ids)
        type_embeddings = self.type_embed(type_ids)
        embeddings = embeddings + type_embeddings

        # lstm layers
        output, _ = self.blstm_first(embeddings)
        output, _ = self.blstm(output)
        output, _ = self.blstm(output)
        
        # output layer
        output = self.output_layer(output)

        # apply masking
        # if mask is not None:
        #     mask = torch.unsqueeze(mask, -1)
        #     output = torch.unsqueeze(output, -1)
        #     output = torch.where(mask, 0, output)
        #     print(output.shape)
        # return output

    
model = BLSTM(vocab_size=tokenizer.get_vocab_size())
print(model)

BLSTM(
  (embed): Embedding(50265, 192)
  (type_embed): Embedding(2, 192)
  (blstm): LSTM(192, 192, dropout=0.1, bidirectional=True)
  (linear): Linear(in_features=384, out_features=192, bias=True)
  (output_layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=49152, out_features=3, bias=True)
  )
)




### Define Device.

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


### Define loss function and optimizer.

In [8]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)

### Define training and testing steps.

In [9]:
def train_step(X, y):
    with torch.enable_grad():
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

In [10]:
def test_step(X, y):
    with torch.no_grad():
        preds = model(X)
        loss = loss_fn(preds, y)
    accurate = (preds.argmax(1) == y).type(torch.int).sum()
    return loss.item(), accurate.item()

### Define training and testing loops.

In [11]:
def train(dataloader):
    model.train()
    train_loss = 0.0
    for batch in dataloader:
        ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
        ids = ids.to(device)
        type_ids = type_ids.to(device)
        mask = mask.to(device)
        X = (ids, type_ids, mask)
        y = batch["label"].to(device)
        train_loss += train_step(X, y)

    print(f"Train loss: {train_loss / len(dataloader)}")

In [12]:
def test(dataloader):
    model.eval()
    test_loss, accurate = 0.0, 0
    for batch in dataloader:
        ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
        ids = ids.to(device)
        type_ids = type_ids.to(device)
        mask = mask.to(device)
        X = (ids, type_ids, mask)
        y = batch["label"].to(device)
        loss, acc = test_step(X, y)
        test_loss += loss
        accurate += acc

In [13]:
def train_loop(
    train_dataloader,
    test_dataloader,
    epochs,
    test_every
):
    model.to(device)
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader)
        if (epoch % test_every == 0):
            test(test_dataloader)
        print(f"----------------")

In [14]:
train_loop(train_dataloader, test_dataloader, 1, 5)

Epoch 1
----------------


KeyboardInterrupt: 