In [141]:
import re
from collections import Counter

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Running on device: {device}")

Running on device: mps


In [150]:
imdb_reviews = []
max_review_length = 500
with open('datasets/IMDB Dataset.csv', 'r') as f:
    lines = f.readlines()
    larger_lines = 0
    for line in lines[1:]:
        # Split line by last comma
        line = line.rsplit(',', 1)
        # Unwrap line zero from double quotes
        review = line[0][1:-1]
        if len(review) > max_review_length:
            continue
        # Label is either 'pos' or 'neg'
        label = 1. if "pos" in line[1] else 0.
        imdb_reviews.append((label, review))

print(len(imdb_reviews))
train_reviews = imdb_reviews[:4000]
val_reviews = imdb_reviews[4000:4500]
test_reviews = imdb_reviews[4500:]

print("Sample:", train_reviews[0][0], ",", train_reviews[0][1][0:10], "...")
print("Sample:", train_reviews[3][0], ",", train_reviews[3][1][0:10], "...")

4964
Sample: 1.0 , If you lik ...
Sample: 0.0 , The plot i ...


In [151]:
# Tokenizer
def tokenizer(text):
    text = re.sub('<[^>]*>', '', text)
    emoticons = re.findall(r'(?::|;|=)(?:-)?(?:\)|\(|D|P)', text.lower())
    text = re.sub(r'[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
    tokenized = text.split()
    return tokenized


token_counts = Counter()
for label, line in train_reviews:
    tokens = tokenizer(line)
    token_counts.update(tokens)

print('Tokens:', len(token_counts))

Tokens: 16006


In [152]:
# Create vocabulary
class Vocabulary:

    def __init__(self, tokens):
        tokens = ['<pad>', '<unk>'] + sorted(sorted(tokens))
        self.stoi = {t: i for i, t in enumerate(tokens)}
        self.itos = {i: t for t, i in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    def encode(self, tokens):
        return [self.stoi.get(t, 1) for t in tokens]

    def decode(self, indices):
        return [self.itos.get(i, '<unk>') for i in indices]


vocab = Vocabulary(token_counts)
print('Encode:', vocab.encode(['the', 'quick', 'brown', 'fox', 'unknown__token']))
print('Decode:', vocab.decode(vocab.encode(['the', 'quick', 'brown', 'fox', 'unknown__token'])))

Encode: [14217, 11223, 2011, 5664, 1]
Decode: ['the', 'quick', 'brown', 'fox', '<unk>']


In [153]:

class IMDBDataset(Dataset):
    def __init__(self, reviews):
        self.y = [r[0] for r in reviews]
        self.X = [r[1] for r in reviews]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.y[idx], self.X[idx]


def collate_fn(batch):
    label_list, text_list, lengths = [], [], []
    for _label, _text in batch:
        label_list.append(_label)
        processed_text = torch.tensor(vocab.encode(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
    label_list = torch.tensor(label_list)
    lengths = torch.tensor(lengths)
    padded_text_list = nn.utils.rnn.pad_sequence(text_list, batch_first=True)
    return padded_text_list, label_list, lengths


train_dataset = IMDBDataset(train_reviews)
val_dataset = IMDBDataset(val_reviews)
test_dataset = IMDBDataset(test_reviews)

train_dl = DataLoader(train_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
val_dl = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_dl = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

text_batch, label_batch, length_batch = next(iter(train_dl))
print("Text", text_batch.shape)
print("Label", label_batch.shape)
print("Length", length_batch.shape)

Text torch.Size([32, 499])
Label torch.Size([32])
Length torch.Size([32])


In [157]:
class RNN(nn.Module):

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_dim,
                                      padding_idx=0)
        self.rnn = nn.LSTM(input_size=embed_dim,
                           hidden_size=rnn_hidden_size,
                           batch_first=True)
        self.fc1 = nn.Linear(in_features=rnn_hidden_size,
                             out_features=fc_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=fc_hidden_size,
                             out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text, lengths):
        out = self.embedding(text)
        out = nn.utils.rnn.pack_padded_sequence(out, lengths.to("cpu"), batch_first=True, enforce_sorted=False)
        out, (hidden, cell) = self.rnn(out)
        out = hidden[-1, :, :]
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out


vocab_size = len(vocab)
embed_dim = 20
rnn_hidden_size = 64
fc_hidden_size = 64
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size, fc_hidden_size).to(device)

print("Num params: ", sum(p.numel() for p in model.parameters()))

model

Num params:  346401


RNN(
  (embedding): Embedding(16008, 20, padding_idx=0)
  (rnn): LSTM(20, 64, batch_first=True)
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [158]:
def train(dataloader, model, optimizer, loss_fn):
    model.train()
    total_acc, total_loss = 0, 0
    for i, data in enumerate(tqdm(dataloader)):
        text_batch, label_batch, lengths = data
        text_batch = text_batch.to(device)
        label_batch = label_batch.to(device)
        lengths = lengths.to(device)
        optimizer.zero_grad()
        pred = model(text_batch, lengths)[:, 0]
        loss = loss_fn(pred, label_batch)
        loss.backward()
        optimizer.step()
        total_acc += ((pred > 0.5) == label_batch).sum().item()
        total_loss += loss.item() * label_batch.size(0)
    return total_acc / len(dataloader.dataset), \
           total_loss / len(dataloader.dataset)


def evaluate(dataloader, model, loss_fn):
    model.eval()
    total_acc, total_loss = 0, 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(dataloader)):
            text_batch, label_batch, lengths = data
            text_batch = text_batch.to(device)
            label_batch = label_batch.to(device)
            lengths = lengths.to(device)
            pred = model(text_batch, lengths)[:, 0]
            loss = loss_fn(pred, label_batch)
            total_acc += ((pred > 0.5) == label_batch).sum().item()
            total_loss += loss.item() * label_batch.size(0)
    return total_acc / len(dataloader.dataset), \
           total_loss / len(dataloader.dataset)

loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
torch.manual_seed(1)
for epoch in range(num_epochs):
    acc_train, loss_train = train(train_dl, model, optimizer, loss_fn)
    acc_val, loss_val = evaluate(val_dl, model, loss_fn)
    print(f'Epoch {epoch} accuracy: {acc_train:.4f}, val_acc {acc_val:.4f}')

100%|██████████| 125/125 [01:20<00:00,  1.54it/s]
100%|██████████| 16/16 [00:02<00:00,  5.68it/s]


Epoch 0 accuracy: 0.5353, val_acc 0.5520


100%|██████████| 125/125 [01:19<00:00,  1.58it/s]
100%|██████████| 16/16 [00:02<00:00,  6.98it/s]


Epoch 1 accuracy: 0.5595, val_acc 0.5380


100%|██████████| 125/125 [01:19<00:00,  1.57it/s]
100%|██████████| 16/16 [00:02<00:00,  7.05it/s]


Epoch 2 accuracy: 0.5680, val_acc 0.5460


100%|██████████| 125/125 [01:17<00:00,  1.62it/s]
100%|██████████| 16/16 [00:02<00:00,  6.98it/s]


Epoch 3 accuracy: 0.5725, val_acc 0.5780


 15%|█▌        | 19/125 [00:11<01:05,  1.61it/s]


KeyboardInterrupt: 