In [1]:
#!pip3 install torch torchtext==0.17.0 torchdata

In [2]:
import json
import os
import random
import re
from collections import Counter, OrderedDict
from dataclasses import dataclass, field
from time import monotonic 
from typing import Dict, List, Optional, Union

import matplotlib.pyplot as plt
import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from nltk.tokenize import sent_tokenize
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer

#nltk.download('punkt')
#nltk.download('punkt_tab')

# Record the library versions this notebook was executed with.
print(f"torch version:  {torch.__version__}")
print(f"numpy version:  {np.__version__}")
print(f"nltk version: {nltk.__version__}")

torch version:  2.2.0
numpy version:  1.25.2
nltk version: 3.9.1


In [3]:
# Dataset location. Set the WIKITEXT_DATA_DIR environment variable to point the
# notebook at another machine's copy of the corpus; when it is unset the
# original local path is used, so existing runs behave exactly as before.
data_dir = os.environ.get('WIKITEXT_DATA_DIR', '/Users/susan/work/datasets')
file = os.path.join(data_dir, "wiki.train.tokens")

In [4]:
def load_text(filepath, max_tokens=100000):
    """Return the first `max_tokens` whitespace-separated tokens of a UTF-8 file.

    The whole file is read, split on any whitespace, truncated with a slice
    (so `max_tokens=None` keeps everything) and re-joined with single spaces.

    Args:
        filepath: path of the text file to read.
        max_tokens: upper bound on the number of tokens kept.

    Returns:
        A single string containing the kept tokens separated by one space.
    """
    with open(filepath, mode='r', encoding='utf-8') as handle:
        words = handle.read().split()
    return " ".join(words[:max_tokens])

In [5]:
# Load only the first 5000 tokens of the corpus and report how many were read.
text = load_text(filepath=file, max_tokens=5000)
print(f"Loaded text with {len(text.split())} tokens.")

Loaded text with 5000 tokens.


In [6]:
# Split the raw (original-case) text into sentences with NLTK.
sentences = sent_tokenize(text)
# Lower-cased, whitespace-tokenised copy of every sentence.
split_sents = [sentence.lower().split() for sentence in sentences]
# Show the first two sentences as a quick visual check.
sentences[0:2]

['= Valkyria Chronicles III = Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .',
 'Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .']

In [7]:
# Notebook-level switch: the sanity-check cell guarded by `if _debug_ == True:`
# only runs while this is True.
_debug_ = True

In [8]:
@dataclass
class Params:
    """Hyper-parameters and run configuration for skip-gram training.

    Every setting is an annotated dataclass field, so each one can be
    overridden at construction time (``Params(window_size=2)``) and all of
    them take part in ``repr`` and ``==``. Previously only ``batch_size`` was
    annotated, which made it the sole field the dataclass knew about.

    ``batch_size`` is declared first so that a positional ``Params(32)`` keeps
    meaning "batch size", as it did before.

    Note: ``model_dir`` and ``device`` defaults are computed once, when the
    class body is executed; changing ``model_name`` on an instance does not
    update ``model_dir``.
    """

    # --- training parameter kept first for positional compatibility ---
    batch_size: int = 10

    # --- skipgram parameters ---
    dim_embedding: int = 300
    max_norm_embedding: Optional[float] = None  # forwarded to nn.Embedding(max_norm=...)
    min_freq: int = 1  # 50; tokens rarer than this are left out of the vocabulary
    threshold: float = 1.0e-2  # 1.0e-5 subsampling threshold
    window_size: int = 4  # context window one-side length
    n_neg_samples: int = 5
    neg_exponent: float = 0.75  # exponent applied to the unigram distribution
    discarded: str = "<>"  # special characters or below threshold frequency
    tokenizer: str = 'basic_english'

    # --- remaining training parameters ---
    criterion: Optional[nn.Module] = None  # loss module; must be set before training
    shuffle: bool = True
    learning_rate: float = 5e-4
    n_epochs: int = 50
    train_steps: int = 1
    val_steps: int = 1
    checkpoint_frequency: int = 1  # save every N epochs; 0/None disables checkpoints

    model_name: str = 'SkipGram'
    model_dir: str = "weights/{}".format(model_name)

    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [9]:
class TokenMap:
    """Two-way lookup between vocabulary tokens, their indices and frequencies.

    Args:
        list_tokens: sequence of ``(token, frequency)`` pairs. A token's
            position in the sequence becomes its index.
        discarded: ``(token, frequency)`` pair describing the placeholder used
            for out-of-vocabulary words. Its token must also be present in
            ``list_tokens``.

    Attributes:
        by_token: ``token -> (index, frequency)``.
        by_index: ``index -> (token, frequency)``.
        total_tokens: sum of all frequencies, ignoring NaN entries.
        dim_vocab: number of distinct tokens (placeholder included).
        discarded: the placeholder token string.
        discard_id: index of the placeholder token.

    Raises:
        ValueError: if the placeholder token is missing from ``list_tokens``.
    """

    def __init__(self, list_tokens, discarded):
        self.by_token = {token: (index, freq) for index, (token, freq) in enumerate(list_tokens)}
        self.by_index = {index: (token, freq) for index, (token, freq) in enumerate(list_tokens)}
        # nansum: a NaN frequency (e.g. on the placeholder) must not poison the total.
        self.total_tokens = np.nansum([freq for _, freq in list_tokens], dtype=int)
        self.dim_vocab = len(self.by_token)
        self.discarded = discarded[0]
        if self.discarded not in self.by_token:
            raise ValueError(
                f"Discard token {self.discarded!r} must be part of list_tokens."
            )
        self.discard_id = self.by_token[self.discarded][0]

    def _lookup(self, word, position):
        """Return field `position` (0=index, 1=frequency) for `word`, falling back to the placeholder."""
        entry = self.by_token.get(word)
        if entry is None:
            entry = self.by_token[self.discarded]
        return entry[position]

    def get_index(self, word: Union[str, List]):
        """Map a word, or a list of words, to vocabulary indices.

        Unknown words map to the placeholder's index.

        Raises:
            ValueError: if `word` is neither a string nor a list.
        """
        if isinstance(word, str):
            return self._lookup(word, 0)
        if isinstance(word, list):
            return [self._lookup(w, 0) for w in word]
        raise ValueError(f"Word {word} should be a string or a list of strings.")

    def get_token(self, index: Union[int, List]):
        """Map an index, or a list of indices, back to token strings.

        Raises:
            ValueError: if an index is not in the vocabulary, or if `index`
                is neither an integer nor a list.
        """
        # np.integer covers every numpy integer width, not just int64.
        if isinstance(index, (int, np.integer)):
            if index in self.by_index:
                return self.by_index[index][0]
            raise ValueError(f"Index {index} not in valid range")
        if isinstance(index, list):
            ans = []
            for j in index:
                # Indices live in by_index; by_token is keyed by token strings.
                if j in self.by_index:
                    ans.append(self.by_index[j][0])
                else:
                    raise ValueError(f"Index {j} not in valid range.")
            return ans
        raise ValueError(f"Index {index} should be an integer or a list of integers.")

    def get_frequency(self, word: Union[str, List]):
        """Return the stored frequency of a word, or of each word in a list.

        Unknown words yield the placeholder's frequency.

        Raises:
            ValueError: if `word` is neither a string nor a list.
        """
        if isinstance(word, str):
            return self._lookup(word, 1)
        if isinstance(word, list):
            return [self._lookup(w, 1) for w in word]
        raise ValueError(f"Word {word} should be a string or a list of strings.")



In [10]:
def build_tokenmap(iter, params : Params, 
                 max_tokens: Optional[int] = None):
    """Count tokens in `iter` and build a `TokenMap` vocabulary.

    Each item of `iter` is run through the tokenizer named by
    `params.tokenizer`; only tokens that START with a lowercase letter or a
    digit are counted. Tokens are ranked by decreasing frequency (ties broken
    alphabetically), those rarer than `params.min_freq` are dropped, and the
    placeholder token `params.discarded` (frequency NaN) is appended last.

    Args:
        iter: iterable of strings to tokenize (consumed once).
        params: configuration; `tokenizer`, `min_freq` and `discarded` are read.
        max_tokens: optional cap on the vocabulary size, placeholder included,
            so at most `max_tokens - 1` real tokens are kept (the most
            frequent ones). `None` keeps everything. This mirrors the meaning
            of `max_tokens` in torchtext's `build_vocab_from_iterator`;
            previously the argument was accepted but ignored.

    Returns:
        TokenMap over the kept tokens plus the placeholder.
    """
    tokenizer = get_tokenizer(params.tokenizer)
    # NOTE(review): the original pattern was '[a-z1-9]', which silently dropped
    # tokens beginning with '0' (e.g. "000"); assumed to be a typo for 0-9.
    keep = re.compile('[a-z0-9]')

    counter = Counter()
    for item in iter:
        counter.update(tok for tok in tokenizer(item) if keep.match(tok))

    # Most frequent first; alphabetical order makes ties deterministic.
    ranked = sorted(counter.items(), key=lambda pair: (-pair[1], pair[0]))
    tokens = [(token, freq) for token, freq in ranked if freq >= params.min_freq]

    if max_tokens is not None:
        # Reserve one slot for the placeholder appended below.
        tokens = tokens[:max(max_tokens - 1, 0)]

    discarded = (params.discarded, np.nan)
    tokens.append(discarded)

    return TokenMap(tokens, discarded)  


In [11]:
class BatchTool:
    """Builds skip-gram (target, context) index pairs from raw sentences.

    Holds the vocabulary map, the run parameters, the tokenizer named by
    `params.tokenizer`, and a table of per-token discard probabilities used
    for frequent-word subsampling.
    """

    def __init__(self, tokenmap: TokenMap, params: Params):
        self.map = tokenmap
        self.params = params
        self.tokenizer = get_tokenizer(params.tokenizer)   
        self.discard_probs = self.get_discard_probs()

    def frequency_from_percentile(self, percentile= 90):
        """Return the given percentile of the relative token frequencies.

        Entries whose frequency is NaN are ignored.
        """
        corpus_size = self.map.total_tokens
        relative = [
            freq / corpus_size
            for _index, freq in self.map.by_token.values()
            if freq == freq  # NaN is the only value not equal to itself
        ]
        return np.percentile(relative, percentile)

    def get_discard_probs(self):
        """Return `{token index: discard probability}` for every vocabulary entry.

        The probability is `max(1 - sqrt(threshold / relative_frequency), 0)`.
        """
        corpus_size = self.map.total_tokens
        limit = self.params.threshold
        # by_token values are (index, frequency), so the table is keyed by index.
        return {
            token_id: max(1 - np.sqrt(limit / (freq / corpus_size)), 0)
            for token_id, freq in self.map.by_token.values()
        }

    def collate_fn(self, batches):
        """Convert a batch of sentences into aligned target/context index tensors.

        Every full window of `2 * window_size + 1` tokens contributes its
        centre token as target and the surrounding tokens as contexts.
        Targets and contexts are dropped at random according to
        `self.discard_probs`, and always dropped when they are the placeholder
        token. Sentences too short for one full window are skipped.

        Returns:
            Tuple `(targets, contexts)` of 1-D `torch.long` tensors of equal length.
        """
        half = self.params.window_size
        span = 2 * half
        skip_id = self.map.discard_id
        centres, neighbours = [], []

        for sentence in batches:
            ids = self.map.get_index(self.tokenizer(sentence))
            if len(ids) <= span:
                continue

            for start in range(len(ids) - span):
                window = ids[start:start + span + 1]
                centre = window[half]
                around = window[:half] + window[half + 1:]

                # One draw per candidate target, taken before any context draws.
                draw = random.random()
                if self.discard_probs.get(centre) >= draw or centre == skip_id:
                    continue

                for other in around:
                    draw = random.random()
                    if self.discard_probs.get(other) >= draw or other == skip_id:
                        continue
                    centres.append(centre)
                    neighbours.append(other)

        return (torch.tensor(centres, dtype=torch.long),
                torch.tensor(neighbours, dtype=torch.long))

In [12]:
class Word2Vec(nn.Module):
    """Skip-gram model with separate target and context embedding tables.

    Args:
        map: vocabulary map; `dim_vocab`, `by_index`, `discard_id`,
            `total_tokens` and `get_token` are used.
        params: configuration; `dim_embedding`, `max_norm_embedding`,
            `neg_exponent` and `device` are read.
        neg_sd: negative-sampling distribution. `'uniform'` samples every
            vocabulary index with equal weight; any other value uses the
            unigram distribution raised to `params.neg_exponent`.
    """

    def __init__(self, map: TokenMap, params: Params, neg_sd = 'unigram_exp'):
        super().__init__()
        
        self.map = map
        self.params = params
        self.target_embedding = nn.Embedding(
            self.map.dim_vocab,
            params.dim_embedding,
            max_norm=params.max_norm_embedding
            )
        self.context_embedding = nn.Embedding(
            self.map.dim_vocab,
            params.dim_embedding,
            max_norm=params.max_norm_embedding
            )
        self.neg_sd = neg_sd
        # (neg_sd, distribution) pair; filled lazily by _negative_distribution().
        self._neg_dist_cache = None

    def forward_target(self, t):
        """Look up target-side embeddings for the index tensor `t`."""
        targets = self.target_embedding(t)
        return targets

    def forward_context(self, c):
        """Look up context-side embeddings for the index tensor `c`."""
        contexts = self.context_embedding(c)
        return contexts

    def _negative_distribution(self):
        """Return the sampling weights for negatives, computing them at most once per `neg_sd`."""
        # getattr: models unpickled from checkpoints saved before the cache existed lack the attribute.
        cached = getattr(self, '_neg_dist_cache', None)
        if cached is not None and cached[0] == self.neg_sd:
            return cached[1]

        if self.neg_sd == 'uniform':
            neg_dist = torch.ones(self.map.dim_vocab)
        else:
            # default : 'unigram_exp'
            neg_dist = self.get_unigram_exp()

        self._neg_dist_cache = (self.neg_sd, neg_dist)
        return neg_dist

    def forward_neg_sample(self, batch_size, n):
        """Draw `n` negative context embeddings for each of `batch_size` targets.

        Returns:
            Tensor of shape `(batch_size, n, params.dim_embedding)` taken from
            the context embedding table.
        """
        neg_dist = self._negative_distribution()
        neg_samples = torch.multinomial(neg_dist,
                                        batch_size * n,
                                        replacement=True)

        neg_samples = neg_samples.to(self.params.device)
        neg_samples = self.context_embedding(neg_samples).view(batch_size, n, self.params.dim_embedding)
        return neg_samples

    def get_unigram_exp(self):
        """Return the unigram distribution raised to `params.neg_exponent`, renormalised.

        The placeholder entry is removed and frequencies are sorted in
        decreasing order, giving one weight per remaining token. Position i
        lines up with vocabulary index i only when the map is ordered by
        decreasing frequency with the placeholder last, which is the order
        `build_tokenmap` produces.
        """
        entries = self.map.by_index.copy()
        entries.pop(self.map.discard_id)
        freq_by_token = dict(entries.values())
        freqs = np.array(sorted(freq_by_token.values(), reverse=True))
        unigram_dist = freqs/np.nansum(freqs)
        assert self.map.total_tokens == np.nansum(freqs), 'Total number of tokens inconsistent!'

        powered = unigram_dist**(self.params.neg_exponent)
        negdist = torch.from_numpy(powered/np.sum(powered))
        return negdist

    def get_embedding_norms(self, embedding_vecs):
        """Return the L2 norm of each row of `embedding_vecs`, shaped `(1, rows)`."""
        norms = embedding_vecs.pow(2).sum(dim=1).sqrt().unsqueeze(0)
        return norms

    def find_closest_k_words(self, ttokens, topk, verbose=False):
        """Print, for each query index, its nearest neighbours by cosine similarity.

        The best hit of every query is dropped (expected to be the query
        itself), so `topk - 1` neighbours are printed per query.

        Args:
            ttokens: 1-D index tensor of query tokens.
            topk: number of hits to retrieve per query; clamped to the
                vocabulary size.
            verbose: when True, also print the intermediate tensor shapes.
                This replaces the old `if __debug__ > 1` guard, which could
                never be true because `__debug__` is the interpreter's
                boolean constant.
        """
        # Pure inspection: no autograd graph is needed.
        with torch.no_grad():
            embedding = self.target_embedding.weight
            embedding_norms = self.get_embedding_norms(embedding)
            test_vecs = self.target_embedding(ttokens)
            test_norms = self.get_embedding_norms(test_vecs)

            if verbose:
                print(embedding.shape)
                print(embedding_norms.shape)
                print(test_vecs.shape)
                print(test_norms.shape)

            similarities = torch.mm(test_vecs/test_norms.t(), embedding.t()/embedding_norms)
            topk_dists, topk_ids = similarities.topk(min(topk, similarities.shape[1]))

        print("\n-----------")  
        for i, token_id in enumerate(ttokens):
            print(self.map.get_token(token_id.item()) + " || ",end='')
            dists = [d.item() for d in topk_dists[i]][1:]
            topk_words = [self.map.get_token(k.item()) for k in topk_ids[i]][1:]
            for w, sim in zip(topk_words, dists):
                print(f"{w} ({sim:.3f})", end=' ')
            print('\n')
        print("-----------")        

        return 

In [13]:
# https://arxiv.org/pdf/1310.4546

class NegativeSamplingLoss(nn.Module):
    """Skip-gram negative-sampling loss (see the arXiv reference above).

    For each target vector v with true context u and negatives n_1..n_k the
    per-example loss is ``-(log sigmoid(u.v) + sum_j log sigmoid(-n_j.v))``;
    the batch mean is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, outputs_pos, outputs_neg):
        """Compute the mean negative-sampling loss of a batch.

        Args:
            inputs: target embeddings, shape ``(b, d)``.
            outputs_pos: true-context embeddings, reshaped to ``(b, 1, d)``.
            outputs_neg: negative-context embeddings, shape ``(b, n, d)``.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        b, d = inputs.shape
        inputs = inputs.view(b, d, 1)
        outputs_pos = outputs_pos.view(b, 1, d)

        # logsigmoid is numerically stable; sigmoid().log() underflows to -inf
        # for strongly negative scores.
        log_sigmoid = nn.functional.logsigmoid

        # true context: (b, 1, 1) -> (b,). An explicit reshape is used instead of
        # a bare squeeze(), which would also drop the batch axis when b == 1.
        pos_loss = log_sigmoid(torch.bmm(outputs_pos, inputs)).view(b)
        # false context: (b, n, 1) -> (b, n) -> (b,). Only the trailing axis is
        # squeezed so that b == 1 or n == 1 keep their axes.
        neg_loss = log_sigmoid(torch.bmm(outputs_neg.neg(), inputs)).squeeze(2).sum(1)

        return -(pos_loss + neg_loss).mean()

In [14]:
class Trainer:
    """Runs the train/validate loop for a `Word2Vec` model.

    Args:
        model: the model to optimise; moved to `params.device`.
        params: run configuration. `criterion` must already be set.
        optimizer: optimiser bound to the model's parameters.
        train_iter: dataset of raw sentences used for training.
        valid_iter: dataset of raw sentences used for validation.
        map: vocabulary map used by `do_test`.
        method: batch builder whose `collate_fn` turns sentences into
            (target, context) index tensors.

    Attributes:
        loss: `{"train": [...], "valid": [...]}` with one mean loss per epoch.
        epoch_train_mins: training time in minutes, keyed by epoch number.
        test_tokens: optional list of words to query in `do_test`; when None,
            random vocabulary indices are used instead.

    Raises:
        ValueError: if `params.criterion` is None.
    """

    def __init__(self, model: Word2Vec, params: Params, optimizer,
                 train_iter, valid_iter, map: TokenMap, method: BatchTool):
        if params.criterion is None:
            raise ValueError("params.criterion must be set before creating a Trainer.")

        self.model = model
        self.params = params
        self.optimizer = optimizer
        self.map = map
        self.train_iter = train_iter
        self.valid_iter = valid_iter
        self.method = method

        self.epoch_train_mins = {}
        self.loss = {"train": [], "valid": []}

        # sending all to device
        self.model.to(self.params.device)
        self.params.criterion.to(self.params.device)
        self.test_tokens = None

    def _make_dataloader(self, data, shuffle):
        """Wrap `data` in a DataLoader that batches with `self.method.collate_fn`."""
        return DataLoader(
            data,
            batch_size=self.params.batch_size,
            shuffle=shuffle,
            collate_fn=self.method.collate_fn
        )

    def train(self):
        """Train for `params.n_epochs` epochs, validating and printing neighbours after each."""
        # The loaders do not change between epochs, so build them once.
        # The training loader now honours params.shuffle (previously hard-coded
        # to False); validation order stays fixed.
        self.train_dataloader = self._make_dataloader(self.train_iter, shuffle=self.params.shuffle)
        self.valid_dataloader = self._make_dataloader(self.valid_iter, shuffle=False)

        self.do_test()
        for epoch in range(self.params.n_epochs):
            # train model
            st_time = monotonic()
            self._train_epoch()
            self.epoch_train_mins[epoch] = round((monotonic()-st_time)/60, 1)

            # validate model
            self._validate_epoch()
            print(f"""Epoch: {epoch+1}/{self.params.n_epochs}\n""",
            f"""    Train Loss: {self.loss['train'][-1]:.2}\n""",
            f"""    Valid Loss: {self.loss['valid'][-1]:.2}\n""",
            f"""    Training Time (mins): {self.epoch_train_mins.get(epoch)}"""
            """\n"""
            )
            self.do_test()

            if self.params.checkpoint_frequency:
                self._save_checkpoint(epoch)

    def _batch_loss(self, batch_data):
        """Return the criterion value for one batch, or None when the batch holds no pairs."""
        inputs, outputs = batch_data[0], batch_data[1]
        # Subsampling can discard every pair of a batch; there is nothing to score then.
        if inputs.numel() == 0:
            return None
        inputs, outputs = inputs.to(self.params.device), outputs.to(self.params.device)

        targets = self.model.forward_target(inputs)
        contexts_pos = self.model.forward_context(outputs)
        # self.params, not the notebook-global `params`.
        contexts_neg = self.model.forward_neg_sample(inputs.shape[0], self.params.n_neg_samples)
        return self.params.criterion(targets, contexts_pos, contexts_neg)

    @staticmethod
    def _mean_loss(values):
        """Mean of the per-batch losses, or NaN when no batch contributed."""
        return float(np.mean(values)) if values else float("nan")

    def _train_epoch(self):
        """Run one optimisation pass over the training loader and record its mean loss."""
        self.model.train()
        running_loss = []

        for batch_data in self.train_dataloader:
            loss = self._batch_loss(batch_data)
            if loss is None:
                continue

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss.append(loss.item())

        self.loss['train'].append(self._mean_loss(running_loss))

    def _validate_epoch(self):
        """Score the validation loader without gradient tracking and record its mean loss."""
        self.model.eval()
        running_loss = []

        with torch.no_grad():
            for batch_data in self.valid_dataloader:
                loss = self._batch_loss(batch_data)
                if loss is None:
                    continue
                running_loss.append(loss.item())

        self.loss['valid'].append(self._mean_loss(running_loss))

    def _save_checkpoint(self, epoch):
        """Save model checkpoint to `self.params.model_dir` every `checkpoint_frequency` epochs."""
        epoch_num = epoch + 1
        if epoch_num % self.params.checkpoint_frequency == 0:
            os.makedirs(self.params.model_dir, exist_ok=True)
            model_path = "checkpoint_{}.pt".format(str(epoch_num).zfill(3))
            model_path = os.path.join(self.params.model_dir, model_path)
            torch.save(self.model, model_path)

    def save_model(self):
        """Save final model to `self.params.model_dir` directory."""
        os.makedirs(self.params.model_dir, exist_ok=True)
        model_path = os.path.join(self.params.model_dir, "model.pt")
        torch.save(self.model, model_path)

    def save_loss(self):
        """Save train/val loss as json file to `self.params.model_dir` directory."""
        os.makedirs(self.params.model_dir, exist_ok=True)
        loss_path = os.path.join(self.params.model_dir, "loss.json")
        # Plain floats keep the file independent of numpy scalar types.
        serialisable = {split: [float(v) for v in values] for split, values in self.loss.items()}
        with open(loss_path, "w") as fp:
            json.dump(serialisable, fp)

    def do_test(self, topk: int = 5):
        """Print nearest neighbours for a handful of query tokens.

        With `self.test_tokens` set, those words are queried; words that map
        to the placeholder index are reported and skipped. Otherwise
        `test_size // 2` indices are sampled from the first `sampling_window`
        vocabulary indices and the same number from a window starting at
        index 1000 (moved down when the vocabulary is smaller than that).

        Args:
            topk: number of hits requested per query from the model.
        """
        sampling_window = 100
        test_size = 10

        if self.test_tokens is not None:
            ttokens = []
            for w in self.test_tokens:
                idw = self.map.get_index(w)
                if idw == self.map.discard_id:
                    print(f"Word {w} not in Vocabulary")
                else:
                    ttokens.append(idw)
        else:
            vocab_size = self.map.dim_vocab
            per_group = test_size // 2
            # Low indices are the most frequent tokens when the vocabulary is
            # sorted by decreasing frequency, as build_tokenmap does.
            high_freq = range(min(sampling_window, vocab_size))
            low_start = min(1000, max(vocab_size - sampling_window, 0))
            low_freq = range(low_start, min(low_start + sampling_window, vocab_size))

            ttokens = random.sample(high_freq, min(per_group, len(high_freq)))  # high frequency tokens
            ttokens += random.sample(low_freq, min(per_group, len(low_freq)))   # low frequency tokens
            ttokens = [t for t in ttokens if t != self.map.discard_id]

        if not ttokens:
            return

        ttokens = torch.tensor(ttokens, dtype=torch.long).to(self.params.device)

        self.model.find_closest_k_words(ttokens, topk)


In [15]:
# Seed every random number generator in use so that subsampling, negative
# sampling and weight initialisation are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

text = text.lower()
# A list rather than a one-shot iterator: an exhausted iterator would make a
# re-run of the vocabulary cell silently build an empty vocabulary.
it = text.split()
params = Params()
params.criterion = NegativeSamplingLoss()
os.makedirs(params.model_dir, exist_ok=True)

In [16]:
# Vocabulary -> batch builder -> model -> optimiser.
tmap = build_tokenmap(it, params)
batchtool = BatchTool(tmap, params)
w2v = Word2Vec(tmap, params)
# Use the configured learning rate; without `lr` Adam falls back to its own
# default and params.learning_rate is silently ignored.
optimizer = torch.optim.Adam(params=w2v.parameters(), lr=params.learning_rate)

In [17]:
# First 80% of the sentences train the model, the remaining 20% validate it.
split_at = int(0.8*len(sentences))
trainer = Trainer(
    model=w2v,
    params=params,
    optimizer=optimizer,
    train_iter=sentences[:split_at],
    valid_iter=sentences[split_at:],
    map=tmap,
    method=batchtool,
)
trainer.train()


-----------
they || large (0.197) at (0.176) research (0.175) governor (0.166) 

as || extended (0.189) after (0.169) temporary (0.164) apply (0.160) 

in || summer (0.166) maintaining (0.164) third (0.154) unless (0.152) 

no || boosts (0.180) platform (0.158) post (0.156) jinxed (0.150) 

story || underwent (0.200) article (0.181) title (0.172) erase (0.154) 

plot || two (0.191) unlocked (0.170) arms (0.169) surrendering (0.154) 

novel || guarantee (0.211) scanned (0.186) weapon (0.173) offenders (0.163) 

receives || several (0.180) gameplay (0.175) gallia (0.169) precious (0.161) 

option || inadequate (0.212) form (0.200) them (0.176) linear (0.164) 

plausible || unit (0.194) did (0.176) antiquities (0.165) united (0.164) 

-----------
Epoch: 1/50
     Train Loss: 4.1e+01
     Valid Loss: 4.2e+01
     Training Time (mins): 0.0


-----------
governor || 4 (0.182) walker (0.181) they (0.166) 10 (0.165) 

playstation || composer (0.211) item (0.201) switch (0.187) shortly (0.164)

In [18]:
# Sanity checks for the vocabulary map, batch builder, data loader and
# negative-sampling distribution; runs only while the notebook flag _debug_ is True.
if _debug_ == True:
    
    # Build (target, context) index pairs for every sentence in a single call.
    in_ , out_ =batchtool.collate_fn(sentences)
    # Loader over the first 200 sentences, batched exactly as during training.
    dataloader = DataLoader(
                sentences[:200],
                batch_size=params.batch_size,
                shuffle=False,
                collate_fn=batchtool.collate_fn
            )
    # Negative-sampling weights (unigram distribution raised to params.neg_exponent).
    dist = w2v.get_unigram_exp()

        
    # Round-trip lookups; '.' is not in the vocabulary, so its frequency is
    # the placeholder's frequency.
    print("TokenMap check :",' index(\'the\')=', tmap.get_index('the'), ', index(\'<>\')', tmap.get_index('<>'), 
          ', token(10)= ', tmap.get_token(10), ', dim(vocab)= ',tmap.dim_vocab,
          ', \n\t\t frequency(\'a\')= ', tmap.get_frequency('a'), ', frequency(\'<>\')= ', tmap.get_frequency('<>'), 
          ', frequency(\'.\')= ', tmap.get_frequency('.'))
    # Median relative frequency, and the discard probability stored under key 100
    # (the table is keyed by token index).
    print("BatchTool check :",batchtool.frequency_from_percentile(50), batchtool.get_discard_probs()[100])
    print(f"\t\t  dim(input) {len(in_)} dim(output) {len(out_)}")
    print("DataLoader check : ")
    # Show the first eight target/context indices of every batch.
    for i, batch in enumerate(dataloader, 1):
        inputs = batch[0]
        outputs = batch[1]
        print('\t',i, inputs[:8], outputs[:8])
    print("W2V check : negative sampling distribution \n\t", dist[:6])

TokenMap check :  index('the')= 0 , index('<>') 1344 , token(10)=  game , dim(vocab)=  1345 , 
		 frequency('a')=  82 , frequency('<>')=  nan , frequency('.')=  nan
BatchTool check : 0.00023020257826887662 0
		  dim(input) 18232 dim(output) 18232
DataLoader check : 
	 1 tensor([136, 136, 136, 136, 136,  32,  32,  32]) tensor([  7,  11,  31,  32,  48,  11,  31, 136])
	 2 tensor([455, 455, 455, 455, 455, 455, 349, 349]) tensor([ 72,  96,  18, 349, 151,  86,  96,  18])
	 3 tensor([ 15,  15,  15,  15,  15, 656, 656, 656]) tensor([ 72,  10, 656, 197, 353,   0,  10,  15])
	 4 tensor([67, 67, 67, 67, 67, 67, 67, 27]) tensor([921,  89, 264,  27, 111,   4, 186,  89])
	 5 tensor([ 84,  84,  84,  84,  84,  84,  84, 243]) tensor([1265,   12,   20,  794,  243,  495,  138,   12])
	 6 tensor([363, 363, 363, 363, 363, 363, 363, 895]) tensor([  9,  33, 308,  79, 895, 215, 392,  33])
	 7 tensor([245, 245, 245, 245, 245, 245, 245, 245]) tensor([  36,    6,  343,    3, 1194,   99,   14,    1])
	 8 tensor(