### Implementation of Word2Vec with Skip-Gram and CBOW

In [19]:
import torch 
import torch.nn.functional as F 

#### Skip-Gram

In [33]:
class skip_gram_neg_sampling(torch.nn.Module):
    """Skip-gram word2vec model trained with negative sampling.

    Two embedding tables are kept: one for the centre (input) word and one
    for context words; sampled noise words are looked up in the context
    table. ``forward`` returns the mean negative-sampling loss over a batch.

    Args:
        embedding_size: dimensionality of each embedding vector.
        vocab_size: number of rows in both embedding tables.
        device: device the sampled negative indices are moved to; should
            match the device holding the module's parameters.
        negative_samples: number of noise words drawn per positive pair.
            ``0`` (or less) disables negative sampling, leaving only the
            positive-pair term in the loss.
    """

    def __init__(self, embedding_size, vocab_size, device, negative_samples=10):
        super().__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.device = device
        self.neg_samples = negative_samples

        # Separate tables for centre words and for context/noise words.
        self.embedding_input = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_size
        )
        self.context_embedding = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_size
        )

        # Initialise both tables uniformly in [-1, 1]. This is only an
        # initialisation; weights may leave this range during training.
        self.embedding_input.weight.data.uniform_(-1, 1)
        self.context_embedding.weight.data.uniform_(-1, 1)

    def forward(self, input_word, context_word):
        """Return the mean negative-sampling loss for a batch of word pairs.

        Args:
            input_word: 1-D LongTensor of centre-word indices, shape
                ``(batch,)`` (the ``bmm`` below requires a 2-D lookup result).
            context_word: 1-D LongTensor of context-word indices, same shape
                as ``input_word``.

        Returns:
            Scalar tensor ``-mean(log σ(u·v) + Σ_k log σ(-u·n_k))`` where
            ``u`` is the centre embedding, ``v`` the context embedding and
            ``n_k`` the embeddings of uniformly sampled noise words. When
            negative sampling is disabled the noise sum is omitted.
        """
        embed_input = self.embedding_input(input_word)        # batch x embed_dim
        embed_context = self.context_embedding(context_word)  # batch x embed_dim

        # Score of each true (centre, context) pair: row-wise dot product.
        pos_score = torch.sum(embed_input * embed_context, dim=1)  # batch
        out_loss = F.logsigmoid(pos_score)

        if self.neg_samples <= 0:
            # No noise words: the loss is just the positive-pair term.
            return -out_loss.mean()

        batch_size = context_word.shape[0]
        # Uniform noise distribution over the whole vocabulary.
        noise_dist = torch.ones(self.vocab_size)
        negative_samples = torch.multinomial(
            noise_dist, num_samples=batch_size * self.neg_samples, replacement=True
        )
        negative_samples = negative_samples.view(batch_size, self.neg_samples).to(self.device)

        # Noise words are embedded with the context table.
        embed_neg = self.context_embedding(negative_samples)  # batch x neg x embed_dim
        # Negating the noise embeddings makes logsigmoid yield log σ(-u·n_k).
        neg_score = torch.bmm(embed_neg.neg(), embed_input.unsqueeze(2))  # batch x neg x 1
        noise_loss = F.logsigmoid(neg_score).squeeze(2).sum(1)  # batch

        return -(out_loss + noise_loss).mean()