In [1]:
import sys

sys.path.insert(0, "../..")
from src.data import data_tools, make_dataset
from src.models import tokenizer


We load the IMDB dataset, the MNIST of language models.

In [None]:
from src.data.make_dataset import DatasetFactoryProvider, DatasetType
imdbdatasetfactory = DatasetFactoryProvider.get_factory(DatasetType.IMDB)
datasets = imdbdatasetfactory.create_dataset()
traindataset = datasets["train"]
testdataset = datasets["test"]

It consists of 50k movie reviews, each labeled positive or negative.

Let's have a look at the first datapoint.

In [5]:
x, y = traindataset[0]
x, y


("My observations: vamp outfit at end is ravishing and wonderful, exotic and fantastic. Jeanette wore it well, and got even with naive Nelson. Boat crashing into his balcony served him right. Costume outfits of his female mafia were designed surprisingly well, especially by today's standards. 1942 costume designer did great job. Main song theme just lovely.<br /><br />Caution to negative posters: 1942 was time of WW II; Pearl Harbor happened year before. U.S. just coming out of Great Depression; needed to get out and spend that hard earned money on diversion of singing, dance and yes, fantastic fantasy. Despotic dictators were trying to rule out there in RL, snuffing out freedoms. Thank goodness the public had these fantastic plot line movies to attend. Movie going was a privileged treat, in those depressing times. When you, negative posters, become actors or even movie stars, then YOU have room to talk and criticize. Jeanette's and Nelson's movies stand the test of time.<br /><br />An

This is messy data: it contains uppercase letters, punctuation, and even HTML tags. Let's clean that up in order to reduce dimensionality (the number of distinct tokens) without losing too much information about the sentiment.

In [6]:
import string

punctuation = f"[{string.punctuation}]"
punctuation


'[!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]'
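Pasting `string.punctuation` into square brackets mostly works, but one character is lost: inside a character class the backslash in `\]` acts as an escape, so the pattern matches `]` but never `\` itself. A quick check, with `re.escape` as a more defensive alternative (a suggestion, not what the notebook itself uses):

```python
import re
import string

naive = f"[{string.punctuation}]"
escaped = f"[{re.escape(string.punctuation)}]"

# characters from string.punctuation that each pattern fails to match
missed_naive = [c for c in string.punctuation if not re.fullmatch(naive, c)]
missed_escaped = [c for c in string.punctuation if not re.fullmatch(escaped, c)]

print(missed_naive)    # ['\\']
print(missed_escaped)  # []
```

For IMDB reviews a stray backslash is rare, so either pattern is fine in practice.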

In [7]:
import re


def clean(text):
    punctuation = f"[{string.punctuation}]"
    # remove CaPiTaLs
    lowercase = text.lower()
    # change don't and isn't into dont and isnt
    neg = re.sub("\\'", "", lowercase)
    # swap html tags for spaces
    html = re.sub("<br />", " ", neg)
    # swap punctuation for spaces
    stripped = re.sub(punctuation, " ", html)
    # remove extra spaces
    spaces = re.sub("  +", " ", stripped)
    return spaces
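The order of the steps in `clean` matters. A minimal sketch that repeats the function above and compares it with a hypothetical variant, `clean_punct_first` (not part of the notebook), which strips punctuation before the HTML tags: once `<`, `/` and `>` are gone, `<br />` can no longer be matched and leaves stray `br` tokens.

```python
import re
import string

PUNCTUATION = f"[{string.punctuation}]"


def clean(text: str) -> str:
    # same steps as the notebook's clean()
    lowercase = text.lower()
    neg = re.sub("'", "", lowercase)           # don't -> dont
    html = re.sub("<br />", " ", neg)           # html line breaks -> spaces
    stripped = re.sub(PUNCTUATION, " ", html)  # punctuation -> spaces
    return re.sub("  +", " ", stripped)        # collapse repeated spaces


def clean_punct_first(text: str) -> str:
    # hypothetical variant: punctuation is removed before the html tags
    lowercase = text.lower()
    neg = re.sub("'", "", lowercase)
    stripped = re.sub(PUNCTUATION, " ", neg)   # "<br />" becomes " br   "
    html = re.sub("<br />", " ", stripped)     # no longer matches anything
    return re.sub("  +", " ", html)


sample = "I didn't like it.<br /><br />The END!"
print(repr(clean(sample)))              # 'i didnt like it the end '
print(repr(clean_punct_first(sample)))  # 'i didnt like it br br the end '
```

Note that `clean` keeps a single trailing space; this is harmless because the text is later split on whitespace.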


In [8]:
clean(x), y


('my observations vamp outfit at end is ravishing and wonderful exotic and fantastic jeanette wore it well and got even with naive nelson boat crashing into his balcony served him right costume outfits of his female mafia were designed surprisingly well especially by todays standards 1942 costume designer did great job main song theme just lovely caution to negative posters 1942 was time of ww ii pearl harbor happened year before u s just coming out of great depression needed to get out and spend that hard earned money on diversion of singing dance and yes fantastic fantasy despotic dictators were trying to rule out there in rl snuffing out freedoms thank goodness the public had these fantastic plot line movies to attend movie going was a privileged treat in those depressing times when you negative posters become actors or even movie stars then you have room to talk and criticize jeanettes and nelsons movies stand the test of time angel wings wonderful on the real angel rl wings at cos

Much better. Now we need to create a vocabulary, which is a mapping from every unique word to an arbitrary integer. We have seen this in lesson 4.

In [9]:
corpus = []
for i in range(len(traindataset)):
    # note: this reuses the name x, so after the loop x holds the last cleaned review
    x = tokenizer.clean(traindataset[i][0])
    corpus.append(x)


In [10]:
from src.models import tokenizer

v = tokenizer.build_vocab(corpus, max=10000)
len(v)


2022-12-19 22:53:20.565 | INFO     | src.models.tokenizer:build_vocab:27 - Found 79808 tokens


10002

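Even after cleaning, there are about 80k unique tokens (79,808 according to the log). Without cleaning there would be even more, because "The" and "the" would count as two different tokens. With `max=10000` we keep only 10,000 of them; the two extra entries in `len(v)` are special tokens.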
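A toy illustration of why cleaning shrinks the vocabulary (the sentence is made up for this example): without lowercasing and punctuation removal, `The`/`the` and `great.`/`GREAT!` count as different tokens.

```python
import re
import string

raw = "The movie was great. the MOVIE was GREAT!"

raw_tokens = set(raw.split())
# lowercase and swap punctuation for spaces, as clean() does
cleaned = re.sub(f"[{string.punctuation}]", " ", raw.lower())
clean_tokens = set(cleaned.split())

print(len(raw_tokens))    # 7
print(len(clean_tokens))  # 4
```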

We also have tokens for unknown words and for padding:

In [11]:
v["<UNK>"], v["<PAD>"], v["sdflkjl"]


(1, 0, 1)
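The implementation of `tokenizer.build_vocab` is not shown here. A minimal sketch of a vocabulary with the same observable behaviour, assuming it keeps the most frequent words (the names `ToyVocab` and `build_vocab_sketch` are hypothetical): `<PAD>` is 0, `<UNK>` is 1, and any word outside the vocabulary falls back to `<UNK>`, which is why the size is the cap plus two.

```python
from collections import Counter
from typing import List


class ToyVocab(dict):
    """Hypothetical word -> index mapping; unknown words map to <UNK>."""

    def __missing__(self, word: str) -> int:
        return self["<UNK>"]


def build_vocab_sketch(corpus: List[str], max_size: int) -> ToyVocab:
    # count how often every (already cleaned) word occurs
    counts = Counter(word for text in corpus for word in text.split())
    # reserve the first two indices for the special tokens
    vocab = ToyVocab({"<PAD>": 0, "<UNK>": 1})
    # keep only the max_size most frequent words
    for word, _ in counts.most_common(max_size):
        vocab[word] = len(vocab)
    return vocab


corpus = ["the movie was great", "the plot was thin", "great acting"]
v = build_vocab_sketch(corpus, max_size=3)
print(len(v))                                # 5, i.e. max_size + 2
print(v["<UNK>"], v["<PAD>"], v["sdflkjl"])  # 1 0 1
print([v[word] for word in "the movie was great".split()])  # [2, 1, 3, 4]
```

Here "movie" occurs only once, falls outside the top 3, and is therefore encoded as `<UNK>`.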

With this vocabulary, a sentence of words maps to a sequence of integers:

In [12]:
[v[word] for word in clean(x).split()[:10]]


[58, 513, 1751, 5, 11, 119, 3, 2, 1137, 13]

In [13]:
from typing import List, Tuple, Optional, Callable
from torch.nn.utils.rnn import pad_sequence
import torch
from torchtext.vocab import Vocab

Tensor = torch.Tensor


class Preprocessor:
    def __init__(
        self, max: int, vocab: Vocab, clean: Optional[Callable] = None
    ) -> None:
        self.max = max  # maximum number of words kept per review
        self.vocab = vocab  # word -> integer index, unknown words -> <UNK>
        self.clean = clean  # optional text-cleaning function

    def cast_label(self, label: str) -> int:
        # "neg" -> 0, anything else ("pos") -> 1
        if label == "neg":
            return 0
        else:
            return 1

    def __call__(self, batch: List) -> Tuple[Tensor, Tensor]:
        labels, text = [], []
        for x, y in batch:
            if self.clean is not None:
                x = self.clean(x)
            # truncate to at most self.max words
            x = x.split()[: self.max]
            tokens = torch.tensor([self.vocab[word] for word in x], dtype=torch.long)
            text.append(tokens)
            labels.append(self.cast_label(y))

        # pad with <PAD> (index 0) up to the longest review in this batch
        text_ = pad_sequence(text, batch_first=True, padding_value=0)
        return text_, torch.tensor(labels)


Preprocessing is necessary to:
- cut off long reviews after `max` words. 100 words should be enough to get the sentiment in most cases
- cast the labels "neg" and "pos" to the integers 0 and 1
- pad reviews that are shorter than the longest review in the batch with `<PAD>` (index 0), so every batch becomes a rectangular tensor with at most `max` columns
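A small sketch of the truncate/cast/pad steps on a toy batch, using a hypothetical dict-based vocabulary instead of `v` and a hypothetical `collate` function that mirrors `Preprocessor.__call__`. It shows that `pad_sequence` pads to the longest review in the batch, not to `max`: the batch below has 4 columns even with `max_len=100`, so a full 100-column batch requires at least one review with 100 or more words.

```python
from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def collate(
    batch: List[Tuple[str, str]], vocab: Dict[str, int], max_len: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    text, labels = [], []
    for review, label in batch:
        words = review.split()[:max_len]  # truncate
        ids = [vocab.get(w, vocab["<UNK>"]) for w in words]  # unknown -> <UNK>
        text.append(torch.tensor(ids, dtype=torch.long))
        labels.append(0 if label == "neg" else 1)  # cast label
    # pad with <PAD> = 0 up to the longest review in this batch
    padded = pad_sequence(text, batch_first=True, padding_value=0)
    return padded, torch.tensor(labels)


vocab = {"<PAD>": 0, "<UNK>": 1, "the": 2, "movie": 3, "was": 4, "great": 5, "bad": 6}
batch = [("the movie was great", "pos"), ("bad", "neg")]

x, y = collate(batch, vocab, max_len=100)
print(x.tolist())  # [[2, 3, 4, 5], [6, 0, 0, 0]]
print(y.tolist())  # [1, 0]

x_short, _ = collate(batch, vocab, max_len=2)
print(x_short.tolist())  # [[2, 3], [6, 0]]
```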

We can pass the preprocessor as the `collate_fn` of PyTorch's standard `DataLoader`:

In [14]:
from torch.utils.data import DataLoader

preprocessor = Preprocessor(max=100, vocab=v, clean=clean)
dataloader = DataLoader(
    traindataset, collate_fn=preprocessor, batch_size=32, shuffle=True
)
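`DataLoader` calls `collate_fn` with a plain Python list of `batch_size` items taken from the dataset, and whatever `collate_fn` returns becomes the batch. A toy check with a hypothetical identity `collate_fn` (a list of tuples stands in for `traindataset`), without shuffling so the result is deterministic:

```python
from typing import List, Tuple

from torch.utils.data import DataLoader

# a plain list works as a map-style dataset: it has __len__ and __getitem__
data = [("good film", "pos"), ("awful", "neg"), ("fine", "pos")]


def identity_collate(batch: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    # return exactly what the DataLoader hands over
    return batch


loader = DataLoader(data, batch_size=2, shuffle=False, collate_fn=identity_collate)
batches = list(loader)
print(batches[0])  # [('good film', 'pos'), ('awful', 'neg')]
print(batches[1])  # [('fine', 'pos')]
```

The `Preprocessor` receives the same kind of list and turns it into tensors. Note that the final batch can be smaller than `batch_size` unless `drop_last=True` is passed.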


We now get batched sentences and labels:

In [15]:
x, y = next(iter(dataloader))

x.shape, y.shape


(torch.Size([32, 100]), torch.Size([32]))

In [16]:
x[0]


tensor([  11,   17,    7,   21,  125,  796,  187,   17,    9,   45,   46, 2726,
         482,    3,  334,  135,   60,   68,   27,  428,    5, 2831,   10, 1272,
         419,    2,   17,   85,    5,   29,  121, 1553,  482,   29,    1,   21,
           4, 1158,  622,   17,   60,    7,   49,   44,   22,   89,  181,    6,
          27, 1754,    3, 7197, 3526,   10,  419,    2,  236,    5,    2,    1,
         227,    3,    5, 6428,   52,  867,   10,  102,  142,  166,    6,    1,
          58, 1152,  757,   38,   12,  243,    8, 4511,   52,  640,  555,  442,
         286, 2754,   15,    2, 2726,  482,    3,  698,  796,  187,  586,    0,
           0,    0,    0,    0])

All this code is wrapped in the DatasetFactoryProvider, which you will see in the next notebook.