In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from typing import Tuple

from datasets import load_dataset
from tokenizers import Tokenizer

### Load the SNLI dataset, dropping examples without a gold label (`label == -1`).

In [2]:
# Load both SNLI splits the same way and keep only examples with a gold label
# (label == -1 marks pairs with no annotator consensus).
snli_train, snli_test = (
    load_dataset("snli", cache_dir="../data", split=split_name).filter(
        lambda example: example["label"] != -1
    )
    for split_name in ("train", "test")
)

### Tokenize datasets using a pretrained tokenizer.

In [3]:
# Fixed sequence length: every pair is truncated/padded to exactly this many tokens,
# and the model's flattening output layer is sized from it.
seq_len = 128
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(seq_len)
# `enable_padding` defaults to pad_id=0 / pad_token="[PAD]", but in the roberta-base
# vocabulary id 0 is "<s>" (the start token). Pad with RoBERTa's own "<pad>" token so
# padding positions are distinguishable from the real start-of-sequence token.
pad_id = tokenizer.token_to_id("<pad>")
assert pad_id is not None, "tokenizer vocabulary has no <pad> token"
tokenizer.enable_padding(pad_id=pad_id, pad_token="<pad>", length=seq_len)

In [4]:
def tokenization(batch):
    """Encode premise/hypothesis pairs and attach model inputs to ``batch``.

    Uses the notebook-level ``tokenizer`` on each (premise, hypothesis) pair and
    writes three new columns into ``batch`` (which is mutated and returned):

    - ``"ids"``: token ids of each encoded pair.
    - ``"type_ids"``: token-type ids of each encoded pair.
    - ``"mask"``: per-token booleans that are True where the tokenizer's
      attention mask is 0 (i.e. padding), False elsewhere.
    """
    pairs = list(zip(batch["premise"], batch["hypothesis"]))
    encoded = tokenizer.encode_batch(pairs)

    ids, type_ids, padding_mask = [], [], []
    for enc in encoded:
        ids.append(enc.ids)
        type_ids.append(enc.type_ids)
        # Invert the attention mask: the model expects True at positions to ignore.
        padding_mask.append([not flag for flag in enc.attention_mask])

    batch["ids"] = ids
    batch["type_ids"] = type_ids
    batch["mask"] = padding_mask
    return batch

In [5]:
def _tokenize_split(split):
    """Tokenize one SNLI split in batches of 1024, drop the raw text, expose numpy columns."""
    tokenized = split.map(
        tokenization,
        batched=True,
        batch_size=1024,
        remove_columns=["premise", "hypothesis"],
    )
    tokenized.set_format(type="numpy")
    return tokenized


snli_train_tokenized = _tokenize_split(snli_train)
snli_test_tokenized = _tokenize_split(snli_test)

### Batch the SNLI data with PyTorch dataloaders.

In [6]:
batch_size = 256
# Only the training split is shuffled; evaluation order does not matter.
train_dataloader = DataLoader(
    snli_train_tokenized,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
test_dataloader = DataLoader(
    snli_test_tokenized,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
)
print(len(train_dataloader), len(test_dataloader))

2146 39


In [7]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Build the BiLSTM model.

In [8]:
class BiLSTM(nn.Module):
    """One bidirectional, batch-first LSTM layer followed by padding masking and dropout.

    The output feature size is ``2 * hidden_size`` (forward and backward
    directions concatenated).

    Args:
        input_size: feature size of each input time step.
        hidden_size: hidden size of each LSTM direction.
        dropout: dropout probability applied to the layer output.
    """

    def __init__(self, input_size, hidden_size, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Keep only the per-step outputs; the final (h, c) states are unused.
        outputs = self.lstm(x)[0]
        if mask is not None:
            # mask is True at positions to silence; broadcast it over the feature dim.
            outputs = outputs.masked_fill(mask.unsqueeze(-1), 0.0)
        return self.dropout(outputs)

class Model(nn.Module):
    """Stacked-BiLSTM classifier for fixed-length (premise, hypothesis) token sequences.

    Token and token-type embeddings are summed, run through three BiLSTM layers,
    flattened over the whole sequence and projected to 3 logits.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_size: embedding width; also the hidden size of each LSTM direction,
            so every BiLSTM layer emits ``2 * embed_size`` features per step.
        dropout: dropout probability applied after each BiLSTM layer.
        max_len: sequence length the output layer is sized for. Because the output
            layer flattens the entire sequence, every input must be padded/truncated
            to exactly this length. When omitted, falls back to the notebook-level
            ``seq_len`` (the original, implicit behaviour).
    """

    def __init__(self, vocab_size, embed_size=192, dropout=0.1, max_len=None):
        super().__init__()
        if max_len is None:
            # Backward-compatible default: follow the notebook's global padding length.
            max_len = seq_len

        self.embed = nn.Embedding(vocab_size, embed_size)
        # Token-type table with 2 rows, so type_ids must be 0 or 1.
        self.type_embed = nn.Embedding(2, embed_size)

        self.lstm1 = BiLSTM(input_size=embed_size, hidden_size=embed_size, dropout=dropout)
        # Deeper layers consume the concatenated forward/backward outputs (2 * embed_size).
        self.lstm2 = BiLSTM(input_size=embed_size * 2, hidden_size=embed_size, dropout=dropout)
        self.lstm3 = BiLSTM(input_size=embed_size * 2, hidden_size=embed_size, dropout=dropout)

        # Flatten (batch, max_len, 2 * embed_size) -> (batch, max_len * 2 * embed_size).
        self.output_layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(embed_size * max_len * 2, 3)
        )

    def forward(self, ids, type_ids, mask):
        """Return class logits of shape (batch, 3).

        ids and type_ids are index tensors of shape (batch, max_len); mask is a
        boolean tensor of the same shape that is passed to every BiLSTM layer.
        """
        # calculate embeddings
        embeddings = self.embed(ids)
        type_embeddings = self.type_embed(type_ids)
        embeddings = embeddings + type_embeddings

        # lstm layers
        output = self.lstm1(embeddings, mask=mask)
        output = self.lstm2(output, mask=mask)
        output = self.lstm3(output, mask=mask)

        # output layer
        output = self.output_layer(output)
        return output

    
# Size the token-embedding table to the tokenizer's vocabulary, then move to the device.
model = Model(vocab_size=tokenizer.get_vocab_size())
model = model.to(device)
print(model)
# Compile to TorchScript only after printing, so the readable eager repr is shown above.
model = torch.jit.script(model)

Model(
  (embed): Embedding(50265, 192)
  (type_embed): Embedding(2, 192)
  (lstm1): BiLSTM(
    (lstm): LSTM(192, 192, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (lstm2): BiLSTM(
    (lstm): LSTM(384, 192, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (lstm3): BiLSTM(
    (lstm): LSTM(384, 192, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output_layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=49152, out_features=3, bias=True)
  )
)


### Define loss function and optimizer.

In [9]:
# Cross-entropy over the model's 3 output logits.
loss_fn = nn.CrossEntropyLoss()
# AdamW: Adam with decoupled weight decay, applied to every model parameter.
optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-2,
)

### Define training and testing steps.

In [10]:
def train_step(X, y):
    """Run one optimisation step on a single batch and return the (graph-attached) loss.

    ``X`` is a tuple of positional model inputs; ``y`` holds the target class
    indices. Uses the notebook-level ``model``, ``loss_fn`` and ``optimizer``.
    """
    # Clearing gradients before the forward pass is equivalent: forward never reads .grad.
    optimizer.zero_grad()
    batch_loss = loss_fn(model(*X), y)
    batch_loss.backward()
    optimizer.step()
    return batch_loss

In [11]:
def test_step(X, y):
    with torch.no_grad():
        preds = model(*X)
        loss = loss_fn(preds, y)
    accurate = (preds.argmax(1) == y).type(torch.int).sum()
    return loss, accurate

### Define training and testing loops.

In [12]:
def train(dataloader):
    """Run one training pass over ``dataloader`` and print the mean per-batch loss.

    Each batch must provide ``"ids"``, ``"type_ids"``, ``"mask"`` and ``"label"``
    tensors. Uses the notebook-level ``model``, ``device`` and ``train_step``.
    """
    model.train()
    batch_losses = []
    for batch in dataloader:
        inputs = tuple(batch[key].to(device) for key in ("ids", "type_ids", "mask"))
        labels = batch["label"].to(device)
        batch_losses.append(train_step(inputs, labels).item())

    print(f"Train loss: {sum(batch_losses, 0.0) / len(dataloader)}")

In [13]:
def test(dataloader):
    model.eval()
    test_loss, accurate = 0.0, 0
    for batch in dataloader:
        ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
        ids, type_ids, mask = ids.to(device), type_ids.to(device), mask.to(device)
        y = batch["label"].to(device)
        loss, acc = test_step((ids, type_ids, mask), y)
        test_loss += loss.item()
        accurate += acc.item()
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [14]:
def train_loop(
    train_dataloader,
    test_dataloader,
    epochs,
    test_every
):
    """Train for ``epochs`` epochs, evaluating every ``test_every`` epochs (1-based).

    Calls the notebook-level ``train`` each epoch and ``test`` whenever the
    1-based epoch number is divisible by ``test_every``.
    """
    separator = "----------------"
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n{separator}")
        train(train_dataloader)
        should_evaluate = epoch % test_every == 0
        if should_evaluate:
            test(test_dataloader)
        print(separator)

In [15]:
# Train for 10 epochs, evaluating on the test split after every epoch.
train_loop(train_dataloader, test_dataloader, epochs=10, test_every=1)

Epoch 1
----------------
Train loss: 0.8680547021697419
Test loss: 0.7866069017312466, accuracy: 0.6434242671009772
----------------
Epoch 2
----------------
Train loss: 0.7467867932859582
Test loss: 0.717893090003576, accuracy: 0.6900447882736156
----------------
Epoch 3
----------------
Train loss: 0.6858842495380999
Test loss: 0.6722705150261904, accuracy: 0.711421009771987
----------------
Epoch 4
----------------
Train loss: 0.6435230911222442
Test loss: 0.6531494519649408, accuracy: 0.7258754071661238
----------------
Epoch 5
----------------
Train loss: 0.605972983505497
Test loss: 0.6294634892390325, accuracy: 0.7356473941368078
----------------
Epoch 6
----------------
Train loss: 0.5720242089960804
Test loss: 0.6101921047919836, accuracy: 0.7477605863192183
----------------
Epoch 7
----------------
Train loss: 0.5397620512743698
Test loss: 0.6056157640921764, accuracy: 0.7478623778501629
----------------
Epoch 8
----------------
Train loss: 0.5074952808586327
Test loss: 0.602