# Word2Vec using skip-gram

Reference:
- [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781)

In [None]:
%%bash
git clone https://github.com/tky823/DNN-based_source_separation.git

In [None]:
import os
import sys
from functools import partial

In [None]:
# Put the tutorial's source directory from the repository cloned above on the
# import path. The absolute path assumes the clone lives under /content
# (e.g. a Colab working directory) -- adjust it if the repository was cloned elsewhere.
sys.path.append("/content/DNN-based_source_separation/egs/tutorials/word2vec/src")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Default resolution (dots per inch) for every matplotlib figure created afterwards.
plt.rcParams["figure.dpi"] = 100

In [None]:
import torch
import torch.nn as nn

In [None]:
import torchtext
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer

In [None]:
from adhoc_utils import build_vocab
from adhoc_driver import Trainer

In [None]:
def collate_fn(batch, text_pipeline, context_size=4, max_seq_length=256):
    """Turn a batch of raw texts into skip-gram (center, context) pairs.

    Every text is converted to token ids with ``text_pipeline``. Texts shorter
    than one full window (``2 * context_size + 1`` tokens) are dropped. When
    ``max_seq_length`` is given, a random contiguous crop of at most that many
    tokens is kept. A window of ``2 * context_size + 1`` tokens is then slid
    over the sequence with stride 1: the middle token is the model input and
    the remaining ``2 * context_size`` tokens are the targets.

    Args:
        batch: Iterable of raw texts.
        text_pipeline: Callable mapping one text to a list of token ids.
        context_size: Number of context tokens on each side of the center.
        max_seq_length: Maximum number of tokens kept per text, or ``None``
            to keep every token.

    Returns:
        batch_input: LongTensor of shape (num_samples,).
        batch_target: LongTensor of shape (num_samples, 2 * context_size).
            ``num_samples`` is 0 when no text in the batch is long enough;
            the target tensor keeps its second dimension in that case.
    """
    window_size = 2 * context_size + 1
    batch_input, batch_target = [], []

    for text in batch:
        token_ids = text_pipeline(text)
        raw_seq_length = len(token_ids)

        # Not even one complete window fits into this text.
        if raw_seq_length < window_size:
            continue

        if max_seq_length is not None:
            crop_length = min(max_seq_length, raw_seq_length)
            high = raw_seq_length - crop_length
            # .item() gives a plain int, so the slice below does not rely on
            # a 0-dim tensor being usable as an index.
            crop_start = torch.randint(0, high + 1, ()).item()
            token_ids = token_ids[crop_start: crop_start + crop_length]

        for start_idx in range(len(token_ids) - window_size + 1):
            window = token_ids[start_idx: start_idx + window_size]
            batch_input.append(window[context_size])
            batch_target.append(window[:context_size] + window[context_size + 1:])

    num_samples = len(batch_input)
    batch_input = torch.tensor(batch_input, dtype=torch.long)  # (num_samples,)
    # The explicit reshape keeps the (num_samples, 2 * context_size) layout
    # even when num_samples == 0, where torch.tensor([]) alone would be 1-D.
    batch_target = torch.tensor(batch_target, dtype=torch.long).reshape(num_samples, 2 * context_size)

    return batch_input, batch_target

In [None]:
class SkipGram(nn.Module):
    """Naive skip-gram model: embedding lookup followed by a linear projection
    onto the vocabulary.

    Args:
        vocab_size: Number of entries in the vocabulary.
        embed_dim: Dimension of the word embeddings.
        bias: Whether the output projection has a bias term.
        max_norm: Maximum L2 norm of an embedding vector, forwarded to
            ``nn.Embedding``. ``None`` disables the constraint.
    """
    def __init__(self, vocab_size, embed_dim, bias=False, max_norm=1):
        super().__init__()

        self.vocab_size, self.embed_dim = vocab_size, embed_dim
        self.max_norm = max_norm

        self.embedding = nn.Embedding(vocab_size, embed_dim, max_norm=max_norm)
        self.linear = nn.Linear(embed_dim, vocab_size, bias=bias)

    def forward(self, input):
        """
        Args:
            input: (batch_size,) or (batch_size, context_size)
        Returns:
            output: (batch_size, vocab_size) or (batch_size, vocab_size, context_size)
        """
        x = self.embedding(input)
        x = self.linear(x)

        # For 2-D input, move the vocabulary axis to dim 1.
        if x.dim() == 3:
            return x.permute(0, 2, 1)

        return x

    def get_embedding_weights(self):
        """Return a copy of the embedding matrix with the norm limit applied
        to every row.

        ``nn.Embedding(max_norm=...)`` renormalizes rows when they are looked
        up (see the PyTorch docs), so the stored matrix may still contain rows
        longer than ``max_norm``; this method clips all rows explicitly.

        Returns:
            weights: (vocab_size, embed_dim)
        """
        max_norm = self.max_norm
        weights = self.embedding.weight.data

        # No norm constraint configured: nothing to clip. Return a copy so the
        # result is detached from the parameter storage, like the branch below.
        if max_norm is None:
            return weights.clone()

        norm = torch.linalg.vector_norm(weights, dim=1, keepdim=True) # (vocab_size, 1)
        weights = torch.where(norm > max_norm, max_norm * weights / norm, weights)

        return weights

In [None]:
class AdhocTrainer(Trainer):
    """Trainer for the naive skip-gram model.

    The model's prediction for each center word is repeated along a new last
    axis so that one criterion call scores it against all
    ``2 * context_size`` surrounding words.
    """
    def __init__(self, model, loader, criterion, optimizer, config):
        super().__init__(model, loader, criterion, optimizer, config)

    def run_one_epoch_train(self, epoch):
        """Run one optimization pass over ``self.train_loader``.

        Args:
            epoch: Zero-based epoch index, used only for progress messages.
        Returns:
            Sum of the per-batch losses divided by the number of batches.
        """
        num_targets = 2 * self.context_size
        running_loss = 0

        self.model.train()

        for step, (center, context) in enumerate(self.train_loader, start=1):
            if self.use_cuda:
                center, context = center.cuda(), context.cuda()

            logits = self.model(center) # (num_samples, vocab_size)
            # Repeat the same logits for every context position:
            # (num_samples, vocab_size, 2 * context_size)
            logits = logits.unsqueeze(dim=-1).expand(-1, -1, num_targets)
            loss = self.criterion(logits, context)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss

            # Progress report every 100 iterations.
            if step % 100 == 0:
                print("[Epoch {}/{}] iter {}/{} loss: {:.5f}".format(epoch + 1, self.epochs, step, len(self.train_loader), batch_loss), flush=True)

        return running_loss / len(self.train_loader)

    def run_one_epoch_eval(self, epoch):
        """Evaluate the model on ``self.valid_loader`` without gradients.

        Args:
            epoch: Zero-based epoch index (not used here).
        Returns:
            Sum of the per-batch losses divided by the number of batches.
        """
        num_targets = 2 * self.context_size
        running_loss = 0

        self.model.eval()

        with torch.no_grad():
            for center, context in self.valid_loader:
                if self.use_cuda:
                    center, context = center.cuda(), context.cuda()

                logits = self.model(center) # (num_samples, vocab_size)
                # (num_samples, vocab_size, 2 * context_size)
                logits = logits.unsqueeze(dim=-1).expand(-1, -1, num_targets)
                running_loss += self.criterion(logits, context).item()

        return running_loss / len(self.valid_loader)

In [None]:
# Experiment settings for the naive skip-gram run.
dataset = "WikiText2" # or "WikiText103"
exp_dir = "./exp"

# All artifacts of one dataset (vocabulary, checkpoints, loss logs) live under this directory.
dataset_dir = os.path.join(exp_dir, dataset)

config = {
    "system": {
        "seed": 111,
        "use_cuda": torch.cuda.is_available(),
        "eps": 1e-12,
    },
    "dataset": dataset,
    "vocab": {"min_freq": 50},
    "vocab_path": os.path.join(dataset_dir, "vocab/vocab.pth"),
    "context_size": 4,
    "model": {"embed_dim": 300},
    "optim": {"lr": 1e-3},
    "batch_size": 96,
    "epochs": 100,
    "model_dir": os.path.join(dataset_dir, "skip-gram_naive/model"),
    "loss_dir": os.path.join(dataset_dir, "skip-gram_naive/loss"),
    # None, or a checkpoint path such as
    # os.path.join(exp_dir, dataset, "skip-gram_naive/model/last.pth")
    "continue_from": None,
}

In [None]:
# Seed PyTorch's global RNG with the configured seed; the random crop in
# collate_fn draws from this generator via torch.randint.
torch.manual_seed(config["system"]["seed"])

In [None]:
# Only the two WikiText corpora are supported.
dataset_name = config["dataset"]

if dataset_name not in ("WikiText2", "WikiText103"):
    raise NotImplementedError("Not support {}.".format(config["dataset"]))

# The torchtext dataset class has the same name as the configured dataset.
dataset_cls = getattr(torchtext.datasets, dataset_name)

# Convert each split to a map-style dataset so it can be fed to a DataLoader.
train_iter, valid_iter = (
    to_map_style_dataset(dataset_cls(root="./", split=split))
    for split in ("train", "valid")
)

In [None]:
tokenizer = get_tokenizer("basic_english", language="en")

# Build the vocabulary once and cache it on disk; later runs reuse the cached file.
if not os.path.exists(config["vocab_path"]):
    vocab = build_vocab(train_iter, tokenizer, min_freq=config["vocab"]["min_freq"])
    vocab_dir = os.path.dirname(config["vocab_path"])
    os.makedirs(vocab_dir, exist_ok=True)
    torch.save(vocab, config["vocab_path"])
else:
    vocab = torch.load(config["vocab_path"])


def text_pipeline(x):
    """Tokenize raw text ``x`` and pass the tokens through ``vocab``."""
    return vocab(tokenizer(x))

In [None]:
# Both splits use the same batching logic, so one partial is shared.
batch_collate = partial(collate_fn, text_pipeline=text_pipeline, context_size=config["context_size"])

# NOTE(review): shuffle is disabled for the training split as well -- confirm this is intended.
loader = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=batch_collate,
    )
    for split, split_dataset in (("train", train_iter), ("valid", valid_iter))
}

In [None]:
# Build the skip-gram model with one embedding per vocabulary entry;
# bias and max_norm keep the defaults of SkipGram.__init__.
model = SkipGram(vocab_size=len(vocab), embed_dim=config["model"]["embed_dim"])
print(model)

In [None]:
# Move the model to the GPU only when the configuration enables CUDA.
if not config["system"]["use_cuda"]:
    print("Does NOT use CUDA.")
else:
    model.cuda()
    print("Uses CUDA.")

In [None]:
# Adam over all model parameters, using the learning rate from the config.
optimizer = torch.optim.Adam(model.parameters(), lr=config["optim"]["lr"])

In [None]:
# Cross-entropy over the vocabulary axis; AdhocTrainer calls it with logits of
# shape (num_samples, vocab_size, 2 * context_size) and integer targets of
# shape (num_samples, 2 * context_size).
criterion = nn.CrossEntropyLoss()

## Training

In [None]:
# run() is inherited from the base Trainer (defined in adhoc_driver, not shown
# here); presumably it drives run_one_epoch_train / run_one_epoch_eval for the
# configured number of epochs and writes checkpoints -- verify against adhoc_driver.
trainer = AdhocTrainer(model, loader, criterion, optimizer, config)
trainer.run()

## Examine

In [None]:
from word2vec import Word2Vec

In [None]:
# Restore the most recent checkpoint onto the CPU, regardless of the device it was saved from.
model_path = os.path.join(config["model_dir"], "last.pth")
package = torch.load(model_path, map_location="cpu")

# Rebuild the network with the training-time sizes before loading the weights.
model = SkipGram(len(vocab), config["model"]["embed_dim"])
model.load_state_dict(package["state_dict"])

In [None]:
# Wrap the norm-clipped embedding matrix and the vocabulary for the queries below.
# NOTE(review): eps is presumably a numerical-stability constant for the
# similarity computation -- confirm in word2vec.Word2Vec.
word2vec = Word2Vec(model.get_embedding_weights(), vocab, eps=config["system"]["eps"])

In [None]:
# Query the words most similar to "mother" (similarity measure and result
# format are defined by Word2Vec.get_similar_words -- not shown here).
similar_words = word2vec.get_similar_words("mother")
print(similar_words)

In [None]:
# Word-analogy query: look up the three word vectors, then search for the words
# closest to son - man + woman.
son_vec, man_vec, woman_vec = word2vec(["son", "man", "woman"])
similar_words = word2vec.get_similar_words_from_vec(son_vec - man_vec + woman_vec)
print(similar_words)