# Word2Vec using CBoW with negative sampling

References:
- [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781)
- [Distributed Representations of Words and Phrases and their Compositionality](https://arxiv.org/abs/1310.4546)

In [None]:
%%bash
git clone https://github.com/tky823/DNN-based_source_separation.git

In [None]:
import os
import sys
import warnings
import random
from functools import partial

In [None]:
# Put the tutorial's source directory on the import path so that the
# `adhoc_*` and `word2vec` modules imported in later cells can be resolved.
# NOTE(review): the absolute "/content" prefix presumably matches the working
# directory the repository was cloned into (e.g. Google Colab) - adjust it
# when running elsewhere.
sys.path.append("/content/DNN-based_source_separation/egs/tutorials/word2vec/src")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Default resolution (dots per inch) for every matplotlib figure in this notebook.
plt.rcParams["figure.dpi"] = 100

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import torchtext
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer

In [None]:
from adhoc_utils import build_vocab, build_neg_freq, build_neg_table
from adhoc_criterion import NegativeSamplingLoss
from adhoc_driver import Trainer

In [None]:
def collate_fn_freq(batch, text_pipeline, neg_freq, context_size=4, num_neg_samples=20, max_seq_length=256):
    """Collate raw texts into CBoW samples with frequency-weighted negatives.

    Each text is converted to token IDs, optionally cropped to a random
    contiguous span of at most ``max_seq_length`` tokens, and scanned with a
    sliding window of ``2 * context_size + 1`` tokens. The centre token of a
    window is the positive target; the surrounding tokens are the context.
    Texts shorter than one full window are skipped.

    Args:
        batch: Iterable of raw texts.
        text_pipeline: Callable mapping a text to a list of token IDs.
        neg_freq (list): Sampling weight of every word ID; its length defines
            the set of candidate negatives. The list is modified temporarily
            while sampling but is always restored before returning.
        context_size (int): Number of context tokens on each side of the target.
        num_neg_samples (int): Number of negatives drawn per target.
        max_seq_length (int, optional): Maximum number of tokens kept per text.
            ``None`` disables cropping.

    Returns:
        batch_input: (num_samples, 2 * context_size) context token IDs.
        batch_pos: (num_samples,) target token IDs.
        batch_neg: (num_samples, num_neg_samples) negative token IDs.
        All tensors are ``torch.long``; ``num_samples`` may be 0 when every
        text in the batch is too short.

    Raises:
        TypeError: If ``neg_freq`` is not a list.
    """
    if not isinstance(neg_freq, list):
        raise TypeError("neg_freq must be a list, but got {}.".format(type(neg_freq).__name__))

    window_size = 2 * context_size + 1
    word_ids = range(len(neg_freq))  # Includes all word IDs
    batch_input, batch_pos, batch_neg = [], [], []

    for text in batch:
        text_tokens_ids = text_pipeline(text)
        raw_seq_length = len(text_tokens_ids)

        if raw_seq_length < window_size:
            continue

        if max_seq_length is not None:
            # Crop a random contiguous span of at most max_seq_length tokens.
            crop_length = min(max_seq_length, raw_seq_length)
            high = max(0, raw_seq_length - crop_length)
            crop_start = int(torch.randint(0, high + 1, ()))
            text_tokens_ids = text_tokens_ids[crop_start: crop_start + crop_length]

        seq_length = len(text_tokens_ids)

        for start_idx in range(seq_length - 2 * context_size):
            context = text_tokens_ids[start_idx: start_idx + window_size]
            pos = context.pop(context_size)  # slicing copied the list, so pop is safe

            # Negative sampling: give the target zero weight so it can never be
            # drawn, and restore the shared list even if sampling fails.
            saved_weight = neg_freq[pos]
            neg_freq[pos] = 0
            try:
                neg = random.choices(word_ids, weights=neg_freq, k=num_neg_samples)
            finally:
                neg_freq[pos] = saved_weight

            batch_input.append(context)
            batch_pos.append(pos)
            batch_neg.append(neg)

    # Explicit reshapes keep the documented 2-D shapes when no sample was produced.
    num_samples = len(batch_pos)
    batch_input = torch.tensor(batch_input, dtype=torch.long).reshape(num_samples, 2 * context_size) # (num_samples, 2 * context_size)
    batch_pos = torch.tensor(batch_pos, dtype=torch.long) # (num_samples,)
    batch_neg = torch.tensor(batch_neg, dtype=torch.long).reshape(num_samples, num_neg_samples) # (num_samples, num_neg_samples)

    return batch_input, batch_pos, batch_neg

def collate_fn_table(batch, text_pipeline, neg_table, context_size=4, num_neg_samples=20, max_seq_length=256):
    """Collate raw texts into CBoW samples with table-based negative sampling.

    Each text is converted to token IDs, optionally cropped to a random
    contiguous span of at most ``max_seq_length`` tokens, and scanned with a
    sliding window of ``2 * context_size + 1`` tokens. The centre token of a
    window is the positive target; the surrounding tokens are the context.
    Texts shorter than one full window are skipped.

    Negatives are drawn uniformly from the sampling table after skipping the
    ``count`` entries that start at ``start`` for the target word.
    NOTE(review): this only excludes the target if ``distr`` stores each
    word's entries contiguously in ``[start[w], start[w] + count[w])`` -
    confirm against ``build_neg_table``.

    Args:
        batch: Iterable of raw texts.
        text_pipeline: Callable mapping a text to a list of token IDs.
        neg_table (dict): Tensors under the keys ``"distr"`` (table of word
            IDs), ``"start"`` and ``"count"`` (both indexed by word ID).
        context_size (int): Number of context tokens on each side of the target.
        num_neg_samples (int): Number of negatives drawn per target.
        max_seq_length (int, optional): Maximum number of tokens kept per text.
            ``None`` disables cropping.

    Returns:
        batch_input: (num_samples, 2 * context_size) context token IDs.
        batch_pos: (num_samples,) target token IDs.
        batch_neg: (num_samples, num_neg_samples) negative token IDs.
        All tensors are ``torch.long``; ``num_samples`` may be 0 when every
        text in the batch is too short.

    Raises:
        TypeError: If any of the three tables is not a ``torch.Tensor``.
    """
    distr_table = neg_table["distr"]
    start_table = neg_table["start"]
    count_table = neg_table["count"]

    for key, table in (("distr", distr_table), ("start", start_table), ("count", count_table)):
        if not isinstance(table, torch.Tensor):
            raise TypeError("neg_table[\"{}\"] must be a torch.Tensor, but got {}.".format(key, type(table).__name__))

    neg_table_size = len(distr_table)
    window_size = 2 * context_size + 1
    batch_input, batch_pos, batch_neg = [], [], []

    for text in batch:
        text_tokens_ids = text_pipeline(text)
        raw_seq_length = len(text_tokens_ids)

        if raw_seq_length < window_size:
            continue

        if max_seq_length is not None:
            # Crop a random contiguous span of at most max_seq_length tokens.
            crop_length = min(max_seq_length, raw_seq_length)
            high = max(0, raw_seq_length - crop_length)
            crop_start = int(torch.randint(0, high + 1, ()))
            text_tokens_ids = text_tokens_ids[crop_start: crop_start + crop_length]

        seq_length = len(text_tokens_ids)

        for start_idx in range(seq_length - 2 * context_size):
            context = text_tokens_ids[start_idx: start_idx + window_size]
            pos = context.pop(context_size)  # slicing copied the list, so pop is safe

            # Negative sampling: plain Python ints avoid building a tensor per sample.
            start = int(start_table[pos])
            count = int(count_table[pos])

            # Draw from a table that is `count` entries shorter, then shift every
            # index at or beyond `start` past the target's own entries.
            table_ids = torch.randint(0, neg_table_size - count, (num_neg_samples,))
            table_ids = torch.where(table_ids >= start, table_ids + count, table_ids)
            neg = distr_table[table_ids].tolist()

            batch_input.append(context)
            batch_pos.append(pos)
            batch_neg.append(neg)

    # Explicit reshapes keep the documented 2-D shapes when no sample was produced.
    num_samples = len(batch_pos)
    batch_input = torch.tensor(batch_input, dtype=torch.long).reshape(num_samples, 2 * context_size) # (num_samples, 2 * context_size)
    batch_pos = torch.tensor(batch_pos, dtype=torch.long) # (num_samples,)
    batch_neg = torch.tensor(batch_neg, dtype=torch.long).reshape(num_samples, num_neg_samples) # (num_samples, num_neg_samples)

    return batch_input, batch_pos, batch_neg

In [None]:
class CBoWNegativeSampling(nn.Module):
    """Continuous bag-of-words model trained with negative sampling.

    Context tokens are embedded with ``enc_embedding`` and averaged; target
    and negative tokens are embedded with a separate ``dec_embedding``.

    Args:
        vocab_size (int): Number of rows in both embedding tables.
        embed_dim (int): Dimension of every embedding vector.
        max_norm (float, optional): Passed to ``enc_embedding`` as its
            ``max_norm``. ``None`` disables the norm constraint.
        eps (float): Lower bound of the denominator used when ``embed`` is
            called with ``normalized=True``.
    """
    def __init__(self, vocab_size, embed_dim, max_norm=1, eps=1e-12):
        super().__init__()

        self.vocab_size, self.embed_dim = vocab_size, embed_dim
        self.max_norm = max_norm
        self.eps = eps

        self.enc_embedding = nn.Embedding(vocab_size, embed_dim, max_norm=max_norm)
        self.dec_embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, input, pos=None, neg=None):
        """Embed the context and, if given, the positive and negative words.

        Args:
            input: (batch_size, 2 * context_size)
            pos: (batch_size,)
            neg: (batch_size, num_neg_samples)
        Returns:
            Only ``output`` when both ``pos`` and ``neg`` are omitted,
            otherwise the tuple ``(output, pos_output, neg_output)``:
            output: (batch_size, embed_dim)
            pos_output: (batch_size, embed_dim)
            neg_output: (batch_size, num_neg_samples, embed_dim)
        Raises:
            ValueError: If exactly one of ``pos`` and ``neg`` is given.
        """
        output = self.embed(input)

        if pos is None and neg is None:
            return output

        if pos is None or neg is None:
            raise ValueError("Specify pos and neg.")

        pos_output = self.dec_embedding(pos) # (batch_size, embed_dim)
        neg_output = self.dec_embedding(neg) # (batch_size, num_neg_samples, embed_dim)

        return output, pos_output, neg_output

    def embed(self, input, normalized=False):
        """Average the encoder embeddings of the context tokens.

        Args:
            input: (batch_size, 2 * context_size)
            normalized (bool): If True, L2-normalize every output vector.
        Returns:
            output: (batch_size, embed_dim)
        """
        output = self.enc_embedding(input) # (batch_size, 2 * context_size, embed_dim)
        output = output.mean(dim=1) # (batch_size, embed_dim)

        if normalized:
            output = F.normalize(output, dim=1, eps=self.eps) # (batch_size, embed_dim)

        return output

    def get_embedding_weights(self):
        """Return the encoder embedding table with the norm constraint applied.

        ``nn.Embedding`` renormalizes only the rows that are looked up in a
        forward pass, so the stored table may still hold rows whose norm
        exceeds ``max_norm``; those rows are rescaled here.

        Returns:
            weights: (vocab_size, embed_dim) tensor detached from autograd.
                It is a new tensor, so modifying it does not alter the model.
        """
        max_norm = self.max_norm
        weights = self.enc_embedding.weight.detach()

        if max_norm is None:
            # No constraint was configured, so there is nothing to rescale.
            return weights.clone()

        norm = torch.linalg.vector_norm(weights, dim=1, keepdim=True) # (vocab_size, 1)
        weights = torch.where(norm > max_norm, max_norm * weights / norm, weights)

        return weights

In [None]:
class AdhocTrainer(Trainer):
    """Trainer whose epochs consume (context, positive, negative) batches.

    NOTE(review): `self.model`, `self.criterion`, `self.optimizer`,
    `self.train_loader`, `self.valid_loader`, `self.use_cuda` and
    `self.epochs` are presumably set up by `Trainer.__init__` - confirm
    against `adhoc_driver`.
    """
    def __init__(self, model, loader, criterion, optimizer, config):
        super().__init__(model, loader, criterion, optimizer, config)

    def run_one_epoch_train(self, epoch):
        """Run one optimization pass over the training loader.

        Args:
            epoch (int): Zero-based epoch index, used only for logging.
        Returns:
            Sum of the per-batch losses divided by the number of batches.
        """
        self.model.train()

        running_loss = 0

        for step, (context, target, negatives) in enumerate(self.train_loader, start=1):
            if self.use_cuda:
                context, target, negatives = context.cuda(), target.cuda(), negatives.cuda()

            context_vec, target_vec, negative_vecs = self.model(context, target, negatives)
            loss = self.criterion(context_vec, target_vec, negative_vecs)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

            # Report progress every 100 iterations.
            if step % 100 == 0:
                print("[Epoch {}/{}] iter {}/{} loss: {:.5f}".format(epoch + 1, self.epochs, step, len(self.train_loader), loss.item()), flush=True)

        return running_loss / len(self.train_loader)

    def run_one_epoch_eval(self, epoch):
        """Compute the validation loss without tracking gradients.

        Args:
            epoch (int): Zero-based epoch index (unused here).
        Returns:
            Sum of the per-batch losses divided by the number of batches.
        """
        self.model.eval()

        running_loss = 0

        with torch.no_grad():
            for context, target, negatives in self.valid_loader:
                if self.use_cuda:
                    context, target, negatives = context.cuda(), target.cuda(), negatives.cuda()

                context_vec, target_vec, negative_vecs = self.model(context, target, negatives)
                running_loss += self.criterion(context_vec, target_vec, negative_vecs).item()

        return running_loss / len(self.valid_loader)

In [None]:
dataset = "WikiText2" # or "WikiText103"
exp_dir = "./exp"

# Every artifact of this experiment (vocabulary, checkpoints, loss logs)
# is stored below <exp_dir>/<dataset>.
experiment_root = os.path.join(exp_dir, dataset)

config = {
    # Reproducibility, device selection and numerical stability.
    "system": {
        "seed": 111,
        "use_cuda": torch.cuda.is_available(),
        "eps": 1e-12,
    },
    "dataset": dataset,
    # Vocabulary construction; the built vocabulary is cached at "vocab_path".
    "vocab": {
        "min_freq": 50,
    },
    "vocab_path": os.path.join(experiment_root, "vocab/vocab.pth"),
    # NOTE: the key is spelled "nagative_sampling" throughout the notebook;
    # it must stay as is because other cells look it up by this name.
    "nagative_sampling": {
        "sampling": "table", # "frequency"
        "num_samples": 20,
        "smooth": 0.75,
    },
    "context_size": 4,
    "model": {
        "embed_dim": 300,
    },
    "optim": {
        "lr": 1e-3,
    },
    "batch_size": 96,
    "epochs": 100,
    "model_dir": os.path.join(experiment_root, "CBoW_negative-sampling/model"),
    "loss_dir": os.path.join(experiment_root, "CBoW_negative-sampling/loss"),
    "continue_from": None, # None or os.path.join(exp_dir, dataset, "CBoW_negative-sampling/model/last.pth")
}

In [None]:
# Seed both random number generators used by the collate functions:
# Python's `random` (frequency-based negative sampling) and PyTorch's RNG
# (random cropping and table-based negative sampling).
random.seed(config["system"]["seed"])
torch.manual_seed(config["system"]["seed"])

In [None]:
# Only the two WikiText corpora are supported.
if config["dataset"] not in ("WikiText2", "WikiText103"):
    raise NotImplementedError("Not support {}.".format(config["dataset"]))

# The supported names coincide with the loader names in torchtext.datasets.
load_split = getattr(torchtext.datasets, config["dataset"])

train_iter = load_split(root="./", split="train")
valid_iter = load_split(root="./", split="valid")

# Convert both splits to map-style datasets so they can be handed to DataLoader.
train_iter, valid_iter = to_map_style_dataset(train_iter), to_map_style_dataset(valid_iter)

In [None]:
tokenizer = get_tokenizer("basic_english", language="en")

# The vocabulary is built once from the training split and cached on disk;
# later runs reuse the cached file.
if not os.path.exists(config["vocab_path"]):
    vocab = build_vocab(train_iter, tokenizer, min_freq=config["vocab"]["min_freq"])
    vocab_dir = os.path.dirname(config["vocab_path"])
    os.makedirs(vocab_dir, exist_ok=True)
    torch.save(vocab, config["vocab_path"])
else:
    # NOTE(review): torch.load unpickles the file, so only load caches
    # written by this notebook.
    vocab = torch.load(config["vocab_path"])


def text_pipeline(x):
    """Tokenize a raw text and map its tokens to vocabulary IDs."""
    return vocab(tokenizer(x))


# Per-word weights for negative sampling, computed from the training split.
neg_freq = build_neg_freq(train_iter, vocab, tokenizer, smooth=config["nagative_sampling"]["smooth"])

In [None]:
sampling_config = config["nagative_sampling"]

# Keyword arguments common to both collate functions.
shared_collate_kwargs = {
    "text_pipeline": text_pipeline,
    "context_size": config["context_size"],
    "num_neg_samples": sampling_config["num_samples"],
}

if sampling_config["sampling"] == "table":
    neg_table = build_neg_table(neg_freq)
    collate_fn = partial(collate_fn_table, neg_table=neg_table, **shared_collate_kwargs)
else:
    warnings.warn("Frequency-based sampling may take long time.", UserWarning)
    collate_fn = partial(collate_fn_freq, neg_freq=neg_freq, **shared_collate_kwargs)

# NOTE(review): the training loader is not shuffled either - confirm this is intended.
loader = {
    split: torch.utils.data.DataLoader(data, batch_size=config["batch_size"], shuffle=False, collate_fn=collate_fn)
    for split, data in (("train", train_iter), ("valid", valid_iter))
}

In [None]:
# One embedding row per vocabulary entry; `max_norm` and `eps` keep the
# defaults of CBoWNegativeSampling.
model = CBoWNegativeSampling(vocab_size=len(vocab), embed_dim=config["model"]["embed_dim"])
print(model)

In [None]:
# Move the model to the GPU when the config reports that CUDA is available.
if config["system"]["use_cuda"]:
    model.cuda()
    print("Uses CUDA.")
else:
    print("Does NOT use CUDA.")

In [None]:
# Adam over all model parameters, i.e. both the encoder and decoder embeddings.
optimizer = torch.optim.Adam(model.parameters(), lr=config["optim"]["lr"])

In [None]:
# Loss from adhoc_criterion; AdhocTrainer calls it as
# criterion(output, pos_output, neg_output).
criterion = NegativeSamplingLoss()

## Training

In [None]:
# `run()` is inherited from Trainer.
# NOTE(review): presumably it loops over config["epochs"] epochs and writes
# checkpoints to config["model_dir"] - confirm against adhoc_driver.
trainer = AdhocTrainer(model, loader, criterion, optimizer, config)
trainer.run()

## Examine

In [None]:
from word2vec import Word2Vec

In [None]:
# Restore the most recent checkpoint onto the CPU, regardless of the device
# it was saved from.
model_path = os.path.join(config["model_dir"], "last.pth")
package = torch.load(model_path, map_location="cpu")

# Rebuild the network with the training-time sizes before loading its weights.
model = CBoWNegativeSampling(
    vocab_size=len(vocab),
    embed_dim=config["model"]["embed_dim"],
)
model.load_state_dict(package["state_dict"])

In [None]:
# Wrap the norm-clipped encoder embeddings together with the vocabulary for lookups.
word2vec = Word2Vec(model.get_embedding_weights(), vocab, eps=config["system"]["eps"])

In [None]:
# Query the words whose embeddings are most similar to "mother".
# NOTE(review): the similarity measure and the number of results are
# defined by Word2Vec.get_similar_words - confirm there.
similar_words = word2vec.get_similar_words("mother")
print(similar_words)

In [None]:
# Analogy-style query: look up the three word vectors, then search for the
# words closest to "son" - "man" + "woman".
son_vec, man_vec, woman_vec = word2vec(["son", "man", "woman"])
similar_words = word2vec.get_similar_words_from_vec(son_vec - man_vec + woman_vec)
print(similar_words)