In [1]:
import json
import os
import re

import string
import random
import numpy as np
import torch
import torch.nn as nn

from tqdm.notebook import tqdm

In [1]:
VOCAB_SIZE = 20000
MAX_SEQ_LEN = 80

EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 3e-5

EMBEDDING_DIM = 256
N_HEADS = 8
FF_DIM = 256
N_LAYERS = 1
DROP_OUT = 0.1

DATASET_DIR = 'aclImdb'
TRAIN_DIR = os.path.join(DATASET_DIR, 'train')
TEST_DIR = os.path.join(DATASET_DIR, 'test')
CLASSES = ['neg', 'pos']

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
OUTPUT_DIR = 'model'

NameError: name 'os' is not defined

## Dataset and custom tokenizer

In [3]:
file_path = []
for SUBSET_DIR in [TRAIN_DIR, TEST_DIR]:
    for CLASS in CLASSES:
        CLASS_DIR = os.path.join(SUBSET_DIR, CLASS)
        for filename in os.listdir(CLASS_DIR):
            file_path.append(os.path.join(CLASS_DIR, filename))

print(f'Number of text files: {len(file_path)}')

Number of text files: 50000


In [4]:
def remove_html_tags(text):
    """Remove html tags from a string
    Reference: https://stackoverflow.com/a/9662362

    Args:
        text (str): input string/document

    Returns:
        str: string without html tags
        
    """
    TAG_RE = re.compile(r'<[^>]+>')
    return TAG_RE.sub('', text)

def standardization(text):
    # lowercase
    text = text.lower()
    
    # remove newline
    text = text.replace('\n', ' ').replace('\r', '')
    
    # remove html tags
    text = remove_html_tags(text)
    
    # remove punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))
    
    # remove extra spaces
    text = re.sub(' +', ' ', text)
    
    return text

In [5]:

class CustomTokenizer:
    def __init__(self, vocab_size, max_seq_len):
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.vocab = {}
        self.reverse_vocab = {}
        self.index = 0
        self.__init_special_tokens()
        
        print(self.vocab)
        print(self.reverse_vocab)
    
    def __init_special_tokens(self):
        special_tokens = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']
        for token in special_tokens:
            self.vocab[token] = self.index
            self.index += 1
        
        for k, v in self.vocab.items():
            self.reverse_vocab[v] = k
        
    def fit_text(self, text):
        for word in text.split(' '):
            if word not in self.vocab:
                if self.vocab_size <= self.index:
                    break
                
                self.vocab[word] = self.index
                self.reverse_vocab[self.index] = word
                self.index += 1
                
    def fit_corpus(self, corpus):
        for text in corpus:
            self.fit_text(text)
    
    def encode(self, text, add_sos=False, get_mask=False):
        seq = []
        for word in text.split():
            if word in self.vocab:
                seq.append(self.vocab[word])
            else:
                seq.append(self.vocab['<UNK>'])
                
        if len(seq) > self.max_seq_len:
            mask = [1] * self.max_seq_len
            seq = seq[:self.max_seq_len]
        else:
            mask = [1] * len(seq) + [0] * (self.max_seq_len - len(seq))
            seq = seq + [self.vocab['<PAD>']] * (self.max_seq_len - len(seq))
            
        if add_sos:
            seq = [self.vocab['<SOS>']] + seq[:-1]
            mask = [1] + mask[:-1]
        
        if get_mask:
            return torch.tensor(seq), torch.tensor(mask)
        else:
            return torch.tensor(seq)
    
    def __len__(self):
        return len(self.vocab)
    
    def __getitem__(self, key):
        return self.vocab[key]
    
    def decode(self, seq):
        if type(seq) == torch.Tensor:
            seq = seq.numpy()
        return ' '.join([self.reverse_vocab[i] for i in seq if i != self.vocab['<PAD>']])
        
    def from_json(self, path):
        with open(path, "r", encoding="utf8") as f:
            data = json.load(f)
            self.vocab = data["vocab"]
            self.reverse_vocab = data["reverse_vocab"]
            self.index = data["index"]
            self.vocab_size = data["vocab_size"]
            self.max_seq_len = data["max_seq_len"]
            
    def to_json(self, path):
        data = {
            "vocab": self.vocab,
            "reverse_vocab": self.reverse_vocab,
            "index": self.index,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
        }
        with open(path, "w", encoding="utf8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

In [6]:
tokenizer = CustomTokenizer(VOCAB_SIZE, MAX_SEQ_LEN)

tokenizer.fit_corpus([standardization(open(path, 'r', encoding='utf8').read()) for path in file_path])

{'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
{0: '<PAD>', 1: '<UNK>', 2: '<SOS>', 3: '<EOS>'}


In [7]:
text = 'this is a test sentence'

encoded_text, mask = tokenizer.encode(text, add_sos=True, get_mask=True)
print(f'Encoded text: {encoded_text}')
print(f'Mask: {mask}')
print(f'Encoded text: {tokenizer.decode(encoded_text)}')

Encoded text: tensor([   2,  211,   20,    6, 2731, 4892,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0])
Mask: tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
Encoded text: <SOS> this is a test sentence


In [8]:
def prepare_lm_input_labels(text):
    input_ids, input_mask = tokenizer.encode(text, add_sos=True, get_mask=True)
    labels, label_mask = tokenizer.encode(text, add_sos=False, get_mask=True)
    return input_ids, input_mask, labels, label_mask

class TextGenerationDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, seq_len, tokenizer, standardize=True, get_mask=True):
        self.file_paths = file_paths
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        self.standardize = standardize
        self.get_mask = get_mask
    
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        with open(self.file_paths[idx], 'r', encoding='utf-8') as f:
            text = f.read()
            
        if self.standardize:
            text = standardization(text)
        
        input_ids, input_mask, labels, label_mask = prepare_lm_input_labels(text)
        sample = {
            'input_ids': input_ids,
            'target_ids': labels,
        }
        if self.get_mask:
            sample['input_mask'] = input_mask
            sample['target_mask'] = label_mask
            
        return sample
        
        # return self.tokenizer.encode(text, add_sos=True, get_mask=self.get_mask)

In [9]:
dataset = TextGenerationDataset(file_path, MAX_SEQ_LEN, tokenizer)

In [10]:
idx = random.randint(0, len(dataset))

sample = dataset[idx]
encoded_input = sample['input_ids']
encoded_target = sample['target_ids']
print(f'Encoded input: {encoded_input}')
print(f'Encoded target: {encoded_target}')
print(f'Decoded input: {tokenizer.decode(encoded_input)}')

Encoded input: tensor([    2,   518,   333,   144, 13278,     1,   788,     1,    88,  5208,
         5208,     1,   495,   315,   280,   157,   323,  8459,     5,   144,
            1,     1,  3847,   211,   473,    30,   527,  1897,   336,   120,
         2929,   120,  7245,  8887,   226,   218,  1055,     6,  4539,   226,
          218, 11089,   241,   226,   218,  4893,   120,  1417,   546, 13278,
         1279,  2432,  1204, 12024,   317,   426,    98,   336,    41,  1164,
           35,  1660,  1267, 13412,  4059,    16, 12944,  1772,  2771,   877,
          877,   226,  2312,  1654,   495,   218,   708,    67,  4557,  1152])
Encoded target: tensor([  518,   333,   144, 13278,     1,   788,     1,    88,  5208,  5208,
            1,   495,   315,   280,   157,   323,  8459,     5,   144,     1,
            1,  3847,   211,   473,    30,   527,  1897,   336,   120,  2929,
          120,  7245,  8887,   226,   218,  1055,     6,  4539,   226,   218,
        11089,   241,   226,   2

## Model - Transformer Decoder

In [20]:
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag.

    Args:
        sz (int): the size of the mask

    Returns:
        torch.Tensor: the mask tensor
    """
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

generate_square_subsequent_mask(5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self, embedding_dim, n_heads, ff_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        
        self.multi_head_attn = nn.MultiheadAttention(
            embedding_dim, n_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embedding_dim)
        )
        
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # make a mask for padding tokens
        key_padding_mask = (1 - mask).bool() if mask is not None else None
        context_vector, _ = self.multi_head_attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            attn_mask=generate_square_subsequent_mask(seq_len).to(x.device)
        )
        context_vector = self.dropout1(context_vector)
        out1 = self.layer_norm1(x + context_vector)
        
        ffn_out = self.ffn(out1)
        ffn_out = self.dropout2(ffn_out)
        
        return self.layer_norm2(out1 + ffn_out)

In [14]:
class TokenAndPositionalEmbedding(nn.Module):
    def __init__(self, vocab_size, max_seq_len, embedding_dim):
        super(TokenAndPositionalEmbedding, self).__init__()
        
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embedding_dim)
        
    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(0, seq_len).expand(batch_size, seq_len).to(x.device)
        return self.token_embedding(x) + self.pos_embedding(positions)

In [15]:
class GeneratorModel(nn.Module):
    def __init__(
        self, vocab_size, max_seq_len, embedding_dim, n_heads, ff_dim, n_layers, dropout=0.1
    ):
        super(GeneratorModel, self).__init__()
        
        self.embedding = TokenAndPositionalEmbedding(vocab_size, max_seq_len, embedding_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embedding_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)
        
        self.apply(init_weights)
        
    def forward(self, x, mask=None):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.fc(x)
    
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0.01)
    elif isinstance(m, nn.Embedding):
        torch.nn.init.normal_(m.weight, std=0.02)
    elif isinstance(m, nn.LayerNorm):
        torch.nn.init.normal_(m.weight, std=0.02)
        torch.nn.init.constant_(m.bias, 0)

## Training

In [18]:
def train(model, data_loader, optimizer, criterion, device):
    model.train()
    losses = []
    avg_loss = 0
    pbar = tqdm(data_loader)
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        target_ids = batch['target_ids'].to(device)
        input_mask = batch['input_mask'].to(device)
        target_mask = batch['target_mask'].to(device)
        
        optimizer.zero_grad()
        output = model(input_ids, input_mask)
        output = output.reshape(-1, output.shape[-1])
        target_ids = target_ids.reshape(-1)
        target_mask = target_mask.reshape(-1)
        
        loss = criterion(output, target_ids)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        avg_loss = (0.9 * avg_loss+ 0.1 * loss.item()) / (1 - 0.9 ** (len(losses)))
        pbar.set_description(f'Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}')
        
    return np.mean(losses)

In [19]:
dataset = TextGenerationDataset(file_path, MAX_SEQ_LEN, tokenizer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = GeneratorModel(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=MAX_SEQ_LEN,
    embedding_dim=EMBEDDING_DIM,
    n_heads=N_HEADS,
    ff_dim=FF_DIM,
    n_layers=N_LAYERS,
    dropout=DROP_OUT
)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
critetion = nn.CrossEntropyLoss(ignore_index=tokenizer['<PAD>'])

for epoch in range(EPOCHS):
    print(f'Training Epoch: {epoch+1:02}')
    train_loss = train(model, dataloader, optimizer, critetion, DEVICE)
    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.4f}')
    print("#"*100)

Training Epoch: 01


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch: 01, Train Loss: 8.7503
####################################################################################################
Training Epoch: 02


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch: 02, Train Loss: 6.4457
####################################################################################################
Training Epoch: 03


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch: 03, Train Loss: 5.9942
####################################################################################################
Training Epoch: 04


  0%|          | 0/1563 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "model.pt"))
tokenizer.to_json_file(os.path.join(OUTPUT_DIR, "tokenizer.json"))

## Generate a sentence

In [None]:
def sample_from(logits, top_k=10):
    logits, indices = torch.topk(logits, top_k)
    indices = np.asarray(indices).astype(np.int32)
    preds = torch.softmax(logits, dim=-1).numpy()
    preds = np.asarray(preds).astype(np.float32)
    return np.random.choice(indices, p=preds)

def generate(start_tokens, max_generated_tokens):
    start_tokens = [_ for _ in start_tokens]
    num_tokens_generated_local = 0
    tokens_generated = []
    max_generated_tokens 
    while num_tokens_generated_local <= max_generated_tokens:
        pad_len = MAX_SEQ_LEN - len(start_tokens)
        sample_index = len(start_tokens) - 1
        if pad_len > 0:
            x = start_tokens + [0] * pad_len
        else:
            x = start_tokens
        x = torch.tensor(x).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            y = model(x)
            y = y.cpu()
            
        sample_token = sample_from(y[0][sample_index])
        
        if sample_token == tokenizer['<EOS>']:
            break
        
        tokens_generated.append(sample_token)
        start_tokens.append(sample_token)
        num_tokens_generated_local = len(tokens_generated)
        
    txt = tokenizer.decode(tokens_generated)
    return txt

In [None]:
start_prompt = "this movie is"
start_tokens = tokenizer.encode(start_prompt)[:len(start_prompt.split())]
num_tokens_generated = 30
self_max_tokens = 30

generated_text = generate(start_tokens, max_generated_tokens=self_max_tokens)

print(f'Prompt: {start_prompt}')
print(f'Generated Text: {generated_text}')