In [1]:
from pathlib import Path
import re
from collections import Counter
from functools import cached_property
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch import Tensor
from torch.utils.data import Dataset, random_split, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.datasets.utils import download_and_extract_archive
import tqdm

In [2]:
# Text-normalisation rules, adapted from
# https://github.com/pytorch/text/blob/v0.18.0/torchtext/data/utils.py#L17-L21
# Each entry pairs a raw regex with the string that replaces every match.
# Insertion order is significant: `normalize` applies the rules one after
# another, and the last rule collapses the runs of whitespace that the
# earlier rules introduce.
RAW_PATTERN_TO_REPLACEMENT_DICT = dict(
    [
        (r"\'", " '  "),
        (r"\"", ""),
        (r"\.", " . "),
        (r"<br \/>", " "),
        (r",", " , "),
        (r"\(", " ( "),
        (r"\)", " ) "),
        (r"\!", " ! "),
        (r"\?", " ? "),
        (r"\;", " "),
        (r"\:", " "),
        (r"\s+", " "),
    ]
)

# The same rules, with every regex compiled once up front.
PATTERN_TO_REPLACEMENT_DICT = dict(
    (re.compile(raw_pattern), replacement)
    for raw_pattern, replacement in RAW_PATTERN_TO_REPLACEMENT_DICT.items()
)

In [3]:
def normalize(
    line: str, 
    pattern_to_replacement_dict: dict[re.Pattern, str] = PATTERN_TO_REPLACEMENT_DICT
) -> str:
    """
    Lower-case `line` and apply every substitution rule to it in turn.

    Args:
        line: raw text to clean.
        pattern_to_replacement_dict: compiled regex -> replacement string;
            the rules are applied in the dict's iteration order, each one
            operating on the output of the previous one.
    Returns:
        The normalised text.
    """
    normalized = line.lower()
    for regex in pattern_to_replacement_dict:
        normalized = regex.sub(pattern_to_replacement_dict[regex], normalized)
    return normalized

def tokenize(line: str) -> list[str]:
    """
    Split `line` into tokens on runs of whitespace.

    Args:
        line: text to split.
    Returns:
        The whitespace-separated tokens; an empty list when `line`
        contains no non-whitespace characters.
    """
    token_list = line.split()
    return token_list

def encode(
    token_list: list[str],
    token_to_index: dict[str, int],
) -> list[int]:
    """
    Translate tokens into their integer indices.

    Args:
        token_list: tokens to look up, in order.
        token_to_index: vocabulary mapping each token to its index.
    Returns:
        The indices of the tokens, in the same order as `token_list`.
    Raises:
        KeyError: if a token is missing from `token_to_index`.
    """
    return [token_to_index[token] for token in token_list]

def preprocess(
    data: str, 
    token_to_index: dict[str, int]
) -> torch.Tensor:
    """
    Turn a raw review into a 1-D tensor of token indices.

    The text is normalised, split into tokens and looked up in the
    vocabulary, in that order.

    Args:
        data (str): IMDB review
        token_to_index: a dict that maps tokens to nonnegative ints
    Returns:
        A `torch.long` tensor holding one index per token.
    """
    token_list = tokenize(normalize(data))
    index_list = encode(token_list, token_to_index=token_to_index)
    return torch.tensor(index_list, dtype=torch.long)

In [4]:
# Download location of the IMDB review archive (aclImdb_v1).
url = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'

In [5]:
# Local folder used as the download root; the reviews are later read from `data_dir / 'aclImdb'`.
data_dir = Path('./data/')

In [6]:
# Fetch and unpack the archive only when the extracted corpus is missing.
# The check targets the extracted `aclImdb` folder rather than `data_dir`
# itself: `data_dir` may already exist (created by hand, or left behind by an
# interrupted download), in which case testing `data_dir.exists()` would skip
# the download for good and every later cell would see no review files.
if not (data_dir / 'aclImdb').exists():
    download_and_extract_archive(
        url=url, 
        download_root=data_dir
    )

In [7]:
def get_token_counter(data_dir) -> Counter:
    """
    Count how often each token occurs in the labelled IMDB reviews.

    Both the train and the test split are scanned, so every token that is
    encoded later has a vocabulary entry (`encode` has no fallback for
    unknown tokens).

    Args:
        data_dir: folder that contains the extracted `aclImdb` directory
            (a `pathlib.Path` or a path string).
    Returns:
        A Counter mapping each normalised token to its number of occurrences.
    Raises:
        FileNotFoundError: if one of the expected review folders is missing.
    """
    corpus_dir = Path(data_dir) / 'aclImdb'
    token_counter = Counter()
    for split in ['train', 'test']:
        for label in ['neg', 'pos']:
            label_dir = corpus_dir / split / label
            # Fail loudly instead of silently returning an incomplete vocabulary.
            if not label_dir.is_dir():
                raise FileNotFoundError(f'review directory not found: {label_dir}')
            # Sorted so that the insertion order of the counter (which decides
            # ties in `most_common()`) does not depend on the filesystem.
            path_list = sorted(label_dir.glob('*.txt'))
            for path in tqdm.tqdm(path_list, desc=f'{split=}, {label=}, {len(path_list)=}'):
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                data = path.read_text(encoding='utf-8')
                # `update` only touches the tokens of this review, whereas
                # `token_counter += Counter(...)` also rescans the whole
                # accumulated counter on every file.
                token_counter.update(tokenize(normalize(data)))
    return token_counter

In [8]:
# Token frequencies over the train and test splits (slow: every review file is read).
token_counter = get_token_counter(data_dir)

split='train', label='neg', len(path_list)=12500: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:23<00:00, 534.18it/s]
split='train', label='pos', len(path_list)=12500: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:47<00:00, 262.90it/s]
split='test', label='neg', len(path_list)=12500: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [01:24<00:00, 148.22it/s]
split='test', label='pos', len(path_list)=12500: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [02:06<00:00, 99.08it/s]


In [9]:
# Vocabulary as a list, ordered from the most to the least frequent token.
token_list = [token for token, _ in token_counter.most_common()]

In [10]:
# NOTE: token indices start at 1 so that index 0 stays free for padding
# (`padding_value=0` in `collate`, `padding_idx=0` in `CBoW`).
token_to_index = {token: index for index, token in enumerate(token_list, start=1)}

In [11]:
class IMDBDataset(Dataset):
    """
    In-memory dataset of labelled IMDB reviews.

    Every example is a tuple `(text, data, label)`: the raw review text, its
    encoded form produced by `preprocess`, and the integer label, which is
    the position of the review's folder name in `LABEL_NAME_LIST`.
    """

    # Folder names of the two classes; the list position is the integer label.
    LABEL_NAME_LIST = ['neg', 'pos']

    def __init__(
        self,
        data_dir: Path, 
        token_to_index: dict[str, int],
        train: bool = True,
    ) -> None:
        """
        Read and encode every review of one split.

        Args:
            data_dir: folder that contains the extracted `aclImdb` directory.
            token_to_index: vocabulary passed on to `preprocess`.
            train: load the `train` split when True, the `test` split otherwise.
        Raises:
            FileNotFoundError: if a class folder of the split is missing.
        """

        split = 'train' if train else 'test'
        split_dir = Path(data_dir) / 'aclImdb' / split

        self.example_list = []
        for label, label_name in enumerate(self.LABEL_NAME_LIST):
            label_dir = split_dir / label_name
            # Fail loudly instead of silently building an empty dataset.
            if not label_dir.is_dir():
                raise FileNotFoundError(f'review directory not found: {label_dir}')
            # Sorted so that the example order is the same on every machine;
            # `glob` order is filesystem-dependent, and a seeded `random_split`
            # over this dataset is only repeatable if the order is fixed.
            path_list = sorted(label_dir.glob('*.txt'))
            for each in tqdm.tqdm(path_list):
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                text = each.read_text(encoding='utf-8')
                data = preprocess(text, token_to_index=token_to_index)
                self.example_list.append((text, data, label))

    def __len__(self) -> int:
        """Number of loaded examples."""
        return len(self.example_list)

    def __getitem__(self, index: int):
        """Return the `(text, data, label)` tuple stored at `index`."""
        return self.example_list[index]

In [12]:
# Same location as the download folder defined earlier (re-declared here).
data_dir = Path('./data')

In [13]:
# Read and encode the labelled reviews of the `train` split; all examples are kept in memory.
train_set = IMDBDataset(data_dir=data_dir, token_to_index=token_to_index, train=True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:02<00:00, 4487.60it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:02<00:00, 4359.19it/s]


In [14]:
# Hold out 20% of the training reviews for validation. The seeded generator
# makes the split repeatable for a fixed example order.
# NOTE(review): `train_set` is re-bound here, so re-running this cell without
# re-running the previous one splits the already-reduced subset again.
train_set, val_set = random_split(
    dataset=train_set,
    lengths=(0.8, 0.2),
    generator=torch.Generator().manual_seed(1337)
)

In [15]:
# Sanity check: sizes of the training and validation subsets.
len(train_set), len(val_set)

(20000, 5000)

In [16]:
# Label distribution of the validation subset (item 2 of each example is the label).
Counter(val_set[index][2] for index in range(len(val_set)))

Counter({0: 2512, 1: 2488})

In [17]:
def collate(batch):
    """
    Merge a list of `(text, data, label)` examples into one batch.

    Args:
        batch: examples, each a tuple of raw text, a 1-D tensor of token
            indices, and an integer label.
    Returns:
        A tuple `(text, data, length, label)`:
            text: list with the raw text of every example.
            data: the index tensors stacked batch-first and right-padded
                with 0 up to the longest sequence of the batch.
            length: `torch.long` tensor with the unpadded length of every
                sequence.
            label: tensor with the label of every example.
    """
    texts, sequences, labels = zip(*batch)
    lengths = torch.tensor(list(map(len, sequences)), dtype=torch.long)
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return list(texts), padded, lengths, torch.tensor(labels)

In [18]:
# Training batches: reshuffled every epoch, trailing partial batch dropped.
# A dedicated, seeded generator drives the shuffling so that the batch order
# is reproducible and independent of the global RNG state (the train/val
# split is seeded in the same way).
train_loader = DataLoader(
    dataset=train_set,
    batch_size=128,
    collate_fn=collate,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator().manual_seed(1337),
)

# Validation batches: fixed order, every example kept (the last batch may be smaller).
val_loader = DataLoader(
    dataset=val_set,
    batch_size=128,
    collate_fn=collate,
    shuffle=False,
    drop_last=False
)

In [19]:
class CBoW(nn.Module):
    """
    Bag-of-words binary classifier.

    Token embeddings are averaged over the sequence and the mean vector is
    fed to a small MLP that produces one logit per example.
    """

    def __init__(
        self,
        num_tokens: int,
        embedding_dim: int = 32,
        hidden_dim: int = 32,
    ):
        """
        Args:
            num_tokens: vocabulary size, not counting the padding index 0;
                valid token indices are 1..num_tokens.
            embedding_dim: size of each token embedding.
            hidden_dim: width of the two hidden layers of the MLP.
        """
        super().__init__()

        # One extra row because index 0 is reserved for padding; with
        # `padding_idx=0` that row is initialised to zeros and receives no
        # gradient, so padded positions add nothing to the sum in `forward`.
        self.embedding = nn.Embedding(
            num_embeddings=(num_tokens + 1),
            embedding_dim=embedding_dim,
            padding_idx=0
        )

        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )


    def forward(
        self, 
        x: Tensor,
        length: Tensor,
    ) -> Tensor:
        """
        Args:
            x: token indices, right-padded with 0.
            length: number of real (non-padding) tokens of every sequence.
        Returns:
            output: one logit per example.
        Shape:
            x: (N, L)
            length: (N, )
            output: (N, )
        """
        h = self.embedding(x)
        # Mean over the real tokens. The length is clamped to at least 1 so
        # that an empty sequence yields a zero vector instead of 0 / 0 = NaN.
        denominator = length.clamp(min=1).unsqueeze(dim=-1).to(h.dtype)
        h = h.sum(dim=1) / denominator
        logits = self.mlp(h)
        # Drop only the trailing feature axis: a bare `squeeze()` would also
        # remove the batch axis when N == 1 and return a 0-d tensor.
        logits = logits.squeeze(dim=-1)
        return logits

In [20]:
# Seed the global RNG before the model is built so that the parameter
# initialisation is the same after every kernel restart (same seed value as
# the one used for the train/val split).
torch.manual_seed(1337)

model = CBoW(
    num_tokens=len(token_to_index),
)
optimizer = SGD(
    params=model.parameters(),
    lr=1,  # NOTE(review): large step size for plain SGD; confirm this is intentional
)

In [None]:
# Train for `max_epochs` epochs and report validation loss/accuracy before the
# first update (epoch 0) and again after every epoch.
max_epochs = 100
threshold = 0.5  # score above which an example is predicted as label 1

for epoch in range(max_epochs + 1):
    # Epoch 0 skips training, so the first report describes the untrained model.
    if epoch > 0:
        model.train()
        for _, x, x_len, y_true in tqdm.tqdm(train_loader, desc='training'):
            optimizer.zero_grad()
            y_logits = model(x=x, length=x_len)
            loss = F.binary_cross_entropy_with_logits(input=y_logits, target=y_true.float())
            loss.backward()
            optimizer.step()

    # validation
    model.eval()
    val_loss = 0
    val_total = 0
    val_correct = 0

    with torch.no_grad():
        for _, x, x_len, y_true in tqdm.tqdm(val_loader, desc='validation'):
            y_logits = model(x=x, length=x_len)
            y_score = torch.sigmoid(y_logits)
            y_pred = (y_score > threshold).long()

            # Summed (not averaged) per batch so that dividing by the number of
            # examples below gives the exact mean even when the last batch is smaller.
            loss = F.binary_cross_entropy_with_logits(
                input=y_logits,
                target=y_true.float(),
                reduction='sum',
            )

            val_loss += loss.item()
            val_total += len(y_true)
            val_correct += (y_pred == y_true).sum().item()

    val_loss /= val_total
    val_acc = val_correct / val_total
    print(f'{epoch=: >6d}: Loss={val_loss:.3f} Accuracy={100 * val_acc:.2f} %')