In [1]:
import numpy as np
import torch
import math
import os

from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch import nn
from pprint import pprint

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

In [3]:
MAX_LEN = 2048 # @param {type:"integer"}
TRAIN_BATCH_SIZE = 32 # @param {type:"integer"}
VALID_BATCH_SIZE = 32 # @param {type:"integer"}
EPOCHS = 20 # @param {type:"integer"}
LEARNING_RATE = 5e-2 # @param {type:"number"}
WEIGHT_DECAY = 1e-5 # @param {type:"number"}

MODEL_FILE_NAME = "models/transformer.bin"
LOG_FILE_NAME = "logs/transformer.log"

PAD_TOKEN = 0
CLS_TOKEN = 1

In [4]:
## Architecture params
EMBEDDING_PARAMS = dict(
    num_embeddings=255+2,  # PAD, CLS, chars
    embedding_dim=512,
)
TRANSFORMER_PARAMS = dict(
    nhead=8,
    d_model=512,
    dim_feedforward=2048,
    batch_first=True,
)
NUM_TRANSFORMER_ENCODER_LAYERS = 6

## Model
Let's start by building a model with a few convolutional layers to pool local information and join letters into words.

In [5]:
def positional_encoding(length, d_model):
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) *
        (- math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    pe.requires_grad = False

    return pe.to(device)

In [6]:
class TransformerModel(nn.Module):
    def __init__(
        self,
        embedding_params: dict,
        transformer_encoder_params: dict,
        num_transformer_encoder_layers: int
    ):
        super(TransformerModel, self).__init__()

        self.embedding = nn.Embedding(**embedding_params)
        encoder_layer = nn.TransformerEncoderLayer(**transformer_encoder_params)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_encoder_layers)

    def forward(
        self,
        input_ids,  # (...BATCH, LENGTH)
        src_key_padding_mask=None,  # (...BATCH, LENGTH)
    ):
        x = self.embedding(input_ids)  # (...BATCH, LENGTH, EMBED_DIM)
        *_, length, d_model = x.shape
        x += positional_encoding(length, d_model)
        x = self.transformer_encoder.forward(x, src_key_padding_mask=src_key_padding_mask)
        return x

In [7]:
class ModelWithHead(nn.Module):
    def __init__(self, model, d_model):
        super(ModelWithHead, self).__init__()

        self.model = model
        self.head = nn.Linear(d_model, 1)

    def forward(
        self,
        input_ids,  # (...BATCH, LENGTH)
        attention_mask=None,  # (...BATCH, LENGTH)
    ):
        x = self.model(input_ids, attention_mask)
        x = x[:, 0, :]
        x = self.head(x)
        return torch.sigmoid(x)

In [8]:
if os.path.exists(MODEL_FILE_NAME):
    model = torch.load(MODEL_FILE_NAME)
else:
    model = TransformerModel(
        embedding_params=EMBEDDING_PARAMS,
        transformer_encoder_params=TRANSFORMER_PARAMS,
        num_transformer_encoder_layers=NUM_TRANSFORMER_ENCODER_LAYERS,
    )
model = ModelWithHead(model, d_model=TRANSFORMER_PARAMS["d_model"]).to(device)

In [9]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

19046401

## Data loader and collator

In [10]:
class TextClassificationDataset(Dataset):
    def __init__(self, root_dir, split="train"):
        """
        Args:
            root_dir (string): Directory with all the data.
            split (string): One of "train" or "test" to specify the split.
        """
        self.root_dir = os.path.join(root_dir, split)
        pos_dir = os.path.join(self.root_dir, "pos")
        self.pos_files = os.listdir(pos_dir)

        neg_dir = os.path.join(self.root_dir, "neg")
        self.neg_files = os.listdir(neg_dir)

    def __len__(self):
        return len(self.pos_files) + len(self.neg_files)

    def __getitem__(self, idx):
        if idx < len(self.pos_files):
            file = self.pos_files[idx]
            with open(os.path.join(self.root_dir, "pos", file), 'r') as f:
                text = f.read()
            label = 1
        else:
            file = self.neg_files[idx - len(self.pos_files)]
            with open(os.path.join(self.root_dir, "neg", file), 'r') as f:
                text = f.read()
            label = 0
        return {"text": text, "label": label}


In [11]:
class Collator:
    def __init__(self, max_length):
        self.max_length = max_length

    def __call__(self, batch):
        input_ids = []
        attention_masks = []
        labels = []

        for item in batch:
            text, label = item["text"], item["label"]
            indices = [CLS_TOKEN] + [int(b) for b in bytes(text, encoding="utf-8")]
            length = min(len(indices), self.max_length)
            padding_size = self.max_length - length
            indices = indices[:length] + [PAD_TOKEN] * padding_size
            attention_mask = [1.] * length + [0.] * padding_size

            input_ids.append(indices)
            attention_masks.append(attention_mask)
            labels.append(label)

        return {"input_ids": torch.tensor(input_ids), "attention_mask": torch.tensor(attention_masks), "labels": torch.tensor(labels)}

## Train model

In [13]:
train_set = TextClassificationDataset("../datasets/aclImdb", split="train")
test_set = TextClassificationDataset("../datasets/aclImdb", split="test")

train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0,
                'collate_fn': Collator(MAX_LEN),
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 0,
                'collate_fn': Collator(MAX_LEN),
                }

training_loader = DataLoader(train_set, **train_params)
testing_loader = DataLoader(test_set, **test_params)

In [18]:
20000 * 32 / len(train_set)

25.6

In [37]:
optimizer = torch.optim.AdamW(params =  model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [38]:
def train(epoch):
    model.train()
    for ind, item in enumerate(training_loader):
        input_ids, attention_mask, labels = item["input_ids"].to(device), item["attention_mask"].to(device), item["labels"].to(device)

        loss = F.binary_cross_entropy(model(input_ids, attention_mask), labels.reshape(-1, 1).float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if ind % 500 == 0:
            with open(LOG_FILE_NAME, "a") as f:
                print(f"TRAIN - {epoch=}, loss={loss.item()}", file=f)

In [39]:
def test(epoch):
    losses = []
    accuracies = []
    with torch.no_grad():
        for ind, item in enumerate(testing_loader):
            input_ids, attention_mask, labels = item["input_ids"].to(device), item["attention_mask"].to(device), item["labels"].to(device)
            output = model(input_ids, attention_mask)

            loss = F.binary_cross_entropy(output, labels.reshape(-1, 1).float())
            losses.append(loss.item())
            accurcy = ((output > 0.5) == labels.reshape(-1, 1)).float().mean()
            accuracies.append(accurcy.item())

    mean_loss = np.mean(losses)
    mean_accuracy = np.mean(accuracies)

    with open(LOG_FILE_NAME, "a") as f:
        print(f"EVAL - {epoch=}, {mean_loss=}, {mean_accuracy=}", file=f)
    return np.mean(losses)


In [40]:
best_test_loss = 1e10
for epoch in range(EPOCHS):
    train(epoch)
    if epoch % 5 == 0:
        test_loss = test(epoch)
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            torch.save(model.model, MODEL_FILE_NAME)

KeyboardInterrupt: 