# word2vec 模型核心代码

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
class word2vec(nn.Module):
    """Two-table embedding model scored with a negative-sampling loss.

    ``embedding_u`` is looked up for ``pos_u`` and ``embedding_v`` for both
    ``pos_v`` and ``neg_v``.  (Presumably ``u`` holds center-word vectors and
    ``v`` context-word vectors -- confirm against the data pipeline.)

    Args:
        embedding_size: Number of rows in each embedding table.
        embedding_dim: Length of each embedding vector.
    """

    def __init__(self, embedding_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.embedding_dim = embedding_dim
        self.embedding_u = nn.Embedding(embedding_size, embedding_dim)
        self.embedding_v = nn.Embedding(embedding_size, embedding_dim)
        self.init_weight()

    def init_weight(self) -> None:
        """Re-initialise both tables in place.

        ``embedding_u`` is drawn uniformly from
        ``[-0.5 / embedding_dim, 0.5 / embedding_dim]``; ``embedding_v`` is
        filled with zeros.
        """
        initrange = 0.5 / self.embedding_dim
        # no_grad() lets us mutate the leaf parameters in place without
        # going through the discouraged ``.data`` attribute.
        with torch.no_grad():
            self.embedding_u.weight.uniform_(-initrange, initrange)
            # Degenerate range => all zeros.  Kept as a uniform_ call so the
            # same initialiser is invoked as before.
            self.embedding_v.weight.uniform_(-0, 0)

    def forward(
        self,
        pos_u: torch.Tensor,
        pos_v: torch.Tensor,
        neg_v: torch.Tensor,
    ) -> torch.Tensor:
        """Return the summed negative-sampling loss for one batch.

        Args:
            pos_u: Index tensor of shape ``(B,)`` into ``embedding_u``.
            pos_v: Index tensor of shape ``(B,)`` into ``embedding_v``
                (the positive partner of each ``pos_u`` entry).
            neg_v: Index tensor of shape ``(B, K)`` into ``embedding_v``
                (``K`` negative samples per ``pos_u`` entry).

        Returns:
            A 0-dim tensor: ``-(sum(logsigmoid(u.v)) + sum(logsigmoid(-u.v_neg)))``,
            summed (not averaged) over the batch and the negatives.
        """
        emb_u = self.embedding_u(pos_u)  # (B, D)
        emb_v = self.embedding_v(pos_v)  # (B, D)

        # Dot product per pair.  Reduce over the embedding axis directly:
        # a blanket ``squeeze()`` here would also drop a size-1 batch axis
        # (or a size-1 embedding axis) and make the ``dim=1`` reduction fail.
        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)  # (B,)
        score = F.logsigmoid(score)

        neg_emb_v = self.embedding_v(neg_v)  # (B, K, D)
        # (B, K, D) @ (B, D, 1) -> (B, K, 1); drop only the trailing axis.
        neg_score = torch.bmm(neg_emb_v, emb_u.unsqueeze(2)).squeeze(2)  # (B, K)
        neg_score = F.logsigmoid(-1 * neg_score)

        return -1 * (torch.sum(score) + torch.sum(neg_score))