In [94]:
from functools import partial
from typing import Iterable
import torch
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
import numpy as np
from torch.optim.lr_scheduler import LambdaLR



# --- Configuration --------------------------------------------------------
MIN_WORD_FREQUENCY = 50      # tokens seen fewer times than this (in the split the vocab is built from) map to '<unk>'
CBOW_N_WORDS = 2             # context words taken on EACH side of the middle (target) word
MAX_SEQUENCE_LENGTH = 256    # paragraphs are truncated to this many tokens before windowing
BATCH_SIZE = 96              # paragraphs (not training examples) per DataLoader batch
SEED = 42                    # DataLoader shuffling and weight init are random; fix them for reproducible runs
torch.manual_seed(SEED)

# `torch.has_mps` is deprecated and only says whether PyTorch was *built* with MPS;
# `torch.backends.mps.is_available()` also checks the device can actually be used.
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

vocab_size = 4000  # NOTE(review): not read by the visible code; CBOW sizes itself from the vocab
n_embed = 300      # default embedding dimension used by CBOW

# Raw WikiText2 splits as map-style datasets (train is loaded first, then valid).
# NOTE(review): not read by the visible code; get_dataloader_and_vocab reloads each split itself.
train_iter, val_iter = (to_map_style_dataset(WikiText2(split=split_name)) for split_name in ('train', 'valid'))

def get_vocab(data_iter: Iterable, tokenizer):
    """Build a vocabulary from the tokenized paragraphs of ``data_iter``.

    Only tokens meeting MIN_WORD_FREQUENCY are kept. The special '<unk>'
    token is added and installed as the default index, so looking up an
    out-of-vocabulary token returns the index of '<unk>'.

    NOTE(review): the result goes through ``.to(DEVICE)``. A vocabulary
    presumably holds no tensors, so this is likely a no-op — confirm.
    """
    tokenized = (tokenizer(paragraph) for paragraph in data_iter)
    vocabulary = build_vocab_from_iterator(tokenized, min_freq=MIN_WORD_FREQUENCY, specials=['<unk>'])
    unk_index = vocabulary['<unk>']
    vocabulary.set_default_index(unk_index)
    return vocabulary.to(DEVICE)


# vocab(tokenizer('Hello.')) # vocab(list) -> list of indices



def collate_cbow(batch, text_pipeline):
    """
    Collate_fn for the CBOW model, to be used with DataLoader.

    `batch` is expected to be a list of text paragraphs.
    `text_pipeline` is a function that converts a paragraph into a list of
    vocabulary indices.

    Context is represented as N=CBOW_N_WORDS past words
    and N=CBOW_N_WORDS future words around each middle word.

    Paragraphs shorter than the window (2*CBOW_N_WORDS + 1 tokens) are skipped.
    Long paragraphs are truncated to contain no more than MAX_SEQUENCE_LENGTH
    tokens. This caps one paragraph's contribution at
    MAX_SEQUENCE_LENGTH - 2*CBOW_N_WORDS examples per batch.

    Returns `(batch_input, batch_output)`, both LongTensors on DEVICE:
    - `batch_input`: shape (n_examples, 2*CBOW_N_WORDS), the context words.
    - `batch_output`: shape (n_examples,), the middle words.
    If no paragraph in `batch` is long enough, n_examples is 0 and
    `batch_input` still has 2*CBOW_N_WORDS columns (shape (0, 2*CBOW_N_WORDS)).
    """
    context_size = CBOW_N_WORDS * 2
    window = context_size + 1
    xb, yb = [], []
    for paragraph in batch:
        text_indices = text_pipeline(paragraph)
        if len(text_indices) < window:
            continue

        # Bound the number of examples (and memory) a single long paragraph adds.
        text_indices = text_indices[:MAX_SEQUENCE_LENGTH]

        # One example per window position: the middle word is the target,
        # the CBOW_N_WORDS words on each side are the context.
        for start in range(len(text_indices) - context_size):
            center = start + CBOW_N_WORDS
            xb.append([*text_indices[start:center], *text_indices[center + 1:start + window]])
            yb.append(text_indices[center])

    # reshape keeps the (n, 2N) layout even when xb is empty; torch.tensor([])
    # alone would produce shape (0,), which breaks CBOW.forward's mean over dim -2.
    batch_input = torch.tensor(xb, dtype=torch.long, device=DEVICE).reshape(-1, context_size)
    batch_output = torch.tensor(yb, dtype=torch.long, device=DEVICE)
    return batch_input, batch_output


# DataLoader splits data into batches of BATCH_SIZE paragraphs
# high = 0
# low = 1000
# for x, y in dataloader:
#     high = max(high, x.shape[0])
#     low = min(low, x.shape[0])

# low, high

def get_dataloader_and_vocab(split: str, vocab=None, shuffle: "bool | None" = None):
    """Build a CBOW DataLoader for one WikiText2 split, plus its vocabulary.

    Args:
        split: WikiText2 split name passed to ``WikiText2(split=...)``
            (this notebook uses 'train' and 'valid').
        vocab: vocabulary to reuse. When None, a new one is built from this
            split with ``get_vocab``. Pass the training vocab when loading other
            splits so that indices agree with the model.
        shuffle: whether the DataLoader shuffles paragraphs. When None (default),
            only the 'train' split is shuffled, so evaluation batches are
            identical from one validation pass to the next.

    Returns:
        ``(dataloader, vocab)``. The dataloader yields batches of BATCH_SIZE
        paragraphs collated by ``collate_cbow``; ``vocab`` is the vocabulary
        used to index them.
    """
    data_iter = to_map_style_dataset(WikiText2(split=split))
    tokenizer = get_tokenizer('basic_english', language='en')

    if vocab is None:
        vocab = get_vocab(data_iter, tokenizer)
    if shuffle is None:
        shuffle = split == 'train'

    def text_pipeline(paragraph):
        """Tokenize one paragraph and map each token to its vocabulary index."""
        return vocab(tokenizer(paragraph))

    dataloader = DataLoader(
        data_iter,
        BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=partial(collate_cbow, text_pipeline=text_pipeline),
    )
    return dataloader, vocab


# Build the training DataLoader together with the vocabulary derived from it,
# then look up one word's index as a quick sanity check.
train_dataloader, vocab = get_dataloader_and_vocab(split='train')
vocab['man']




240

In [95]:
class CBOW(nn.Module):
    """Continuous bag-of-words model: predict a word from the mean of its context embeddings.

    Args:
        vocab: vocabulary the model indexes into. Only ``len(vocab)`` is used,
            to size both the embedding table and the output layer.
        embedding_dim: width of each word embedding. When None (default), the
            module-level ``n_embed`` is used.
        max_norm: forwarded to ``nn.Embedding``. Looked-up embedding rows whose
            norm exceeds it are renormalized in place to this norm.
    """

    def __init__(self, vocab, embedding_dim: "int | None" = None, max_norm: float = 1.0) -> None:
        super().__init__()
        if embedding_dim is None:
            embedding_dim = n_embed
        # len(vocab) avoids building a full token->index dict just to count it.
        vocab_size = len(vocab)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, max_norm=max_norm)
        self.ln = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x: torch.Tensor, target: "torch.Tensor | None" = None):
        """Return ``(logits, loss)`` for a batch of context windows.

        Args:
            x: LongTensor of context-word indices, shape (B, C). For batches
                from ``collate_cbow``, C is 2*CBOW_N_WORDS.
            target: optional LongTensor of middle-word indices, shape (B,).

        Returns:
            logits: (B, vocab_size) unnormalized scores.
            loss: mean cross-entropy of ``logits`` against ``target``, or None
                when no target is given.
        """
        embedded = self.embeddings(x)      # (B, C, embedding_dim)
        context = embedded.mean(dim=-2)    # average the C context vectors -> (B, embedding_dim)
        logits = self.ln(context)          # (B, vocab_size)
        loss = None if target is None else F.cross_entropy(logits, target)
        return logits, loss
    

def training(model: nn.Module, dataloader: DataLoader, optimizer, eval_iter: int = 100):
    """Run one epoch of training over ``dataloader`` and return ``model``.

    Each batch is a ``(x, target)`` pair, such as those produced by
    ``collate_cbow``. The model is called as ``model(x, target)`` and must
    return ``(logits, loss)``. Every ``eval_iter`` batches, the mean training
    loss of the batches seen since the previous report is printed.

    Batches with no examples are skipped. ``collate_cbow`` drops paragraphs
    shorter than its window, so a batch can come out empty, and its loss would
    be NaN and poison the reported mean.

    Args:
        model: module returning ``(logits, loss)``; trained in place.
        dataloader: iterable of ``(x, target)`` tensor pairs.
        optimizer: torch optimizer over ``model``'s parameters.
        eval_iter: report interval, in batches.

    Returns:
        The same ``model`` object.
    """
    losses = []
    for i, (x, target) in enumerate(dataloader, 1):
        if target.numel() > 0:
            # forward
            _, loss = model(x, target)
            losses.append(loss.item())

            # backward
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            # update
            optimizer.step()

        # `losses` can be empty here if every batch since the last report was skipped.
        if i % eval_iter == 0 and losses:
            print(f"iteration: {i}, loss: {np.mean(losses):.4f}")
            losses = []

    return model

def validate(model: nn.Module, dataloader: DataLoader):
    """Compute, print and return the mean per-example loss of ``model`` on ``dataloader``.

    Each batch's loss is weighted by its number of targets. Batches from
    ``collate_cbow`` vary in size, so an unweighted mean of batch means would
    over-weight small batches. This assumes ``model`` returns a batch-mean
    loss, as ``F.cross_entropy``'s default reduction in ``CBOW`` does.
    Empty batches are skipped.

    The model is put in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Returns:
        The weighted mean loss as a float, or NaN if there were no examples.
    """
    was_training = model.training
    total_loss = 0.0
    n_examples = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, target in dataloader:
                n = target.numel()
                if n == 0:
                    continue
                _, loss = model(x, target)
                total_loss += loss.item() * n
                n_examples += n
    finally:
        model.train(was_training)

    val_loss = total_loss / n_examples if n_examples else float("nan")
    print(f"validation loss: {val_loss:.4f}")
    return val_loss
    
        

    



In [96]:
# Both loaders share the training vocabulary so that word indices line up;
# the model is then created and moved to DEVICE.
train_dataloader, vocab = get_dataloader_and_vocab(split='train')
val_dataloader, _ = get_dataloader_and_vocab(split='valid', vocab=vocab)
m = CBOW(vocab)
m = m.to(DEVICE)

def get_scheduler(optimizer, total_epochs: int, verbose: bool = True):
    """Return a LambdaLR that decays each param group's LR linearly to 0 over ``total_epochs``.

    The multiplier at epoch ``e`` is ``max(0, (total_epochs - e) / total_epochs)``.
    It is 1.0 at epoch 0 and reaches 0.0 at ``e == total_epochs``. It is clamped
    there, so stepping past the horizon never yields a negative learning rate
    (which would turn descent into ascent).

    Args:
        optimizer: torch optimizer whose param-group learning rates are scaled.
        total_epochs: number of epochs over which to decay; must be positive.
        verbose: when True, ``verbose=True`` is forwarded to ``LambdaLR`` (it
            prints LR updates). Pass False to omit the argument, which recent
            PyTorch releases mark as deprecated.

    Raises:
        ValueError: if ``total_epochs`` is not positive.
    """
    if total_epochs <= 0:
        raise ValueError(f"total_epochs must be positive, got {total_epochs}")

    def linear_decay(epoch):
        return max(0.0, (total_epochs - epoch) / total_epochs)

    extra_kwargs = {"verbose": True} if verbose else {}
    return LambdaLR(optimizer, lr_lambda=linear_decay, **extra_kwargs)


def train_and_eval(model: nn.Module, epochs: int, train_dataloader, val_dataloader, lr: float = 0.025):
    """Train ``model`` for ``epochs`` epochs, validating before training and after every epoch.

    Uses AdamW with initial learning rate ``lr``. A scheduler from
    ``get_scheduler`` decays the learning rate over ``epochs`` epochs and is
    stepped once per epoch. The decay horizon matches ``epochs``; a fixed
    horizon shorter than ``epochs`` would drive the multiplier below zero in
    later epochs.

    Args:
        model: model returning ``(logits, loss)``; trained in place.
        epochs: number of passes over ``train_dataloader``.
        train_dataloader: batches for ``training``.
        val_dataloader: batches for ``validate``.
        lr: initial AdamW learning rate.

    Returns:
        The trained model.
    """
    validate(model, val_dataloader)  # baseline loss before any training
    optimizer = AdamW(model.parameters(), lr=lr)
    # max(..., 1): with epochs == 0 the loop never runs, but the scheduler still
    # must be constructible (a zero horizon would divide by zero).
    lr_scheduler = get_scheduler(optimizer, max(epochs, 1))
    for epoch in range(1, epochs + 1):
        print(f"============= EPOCH{epoch} =============")
        model = training(model, train_dataloader, optimizer)
        validate(model, val_dataloader)
        lr_scheduler.step()

    return model


from pathlib import Path

# torch.save does not create missing directories. Create 'params/' before the
# (long) training run so the final checkpoint write cannot fail on a missing folder.
Path('params').mkdir(parents=True, exist_ok=True)

model = train_and_eval(m, 5, train_dataloader, val_dataloader)
torch.save(model, 'params/model.pt')

    


KeyboardInterrupt: 

In [None]:
# Reload the checkpoint written by the training cell. torch.save(model, ...)
# pickled the whole CBOW module, so:
#  - the CBOW class must already be defined in this session;
#  - weights_only=False is required on PyTorch >= 2.6, whose default (True)
#    refuses arbitrary pickled objects. Unpickling can execute code, so only
#    load checkpoints you created yourself;
#  - map_location=DEVICE lets a checkpoint saved on MPS load on a CPU-only machine.
loaded_model = torch.load('params/model.pt', map_location=DEVICE, weights_only=False)
loaded_model