In [1]:
import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data

# 处理数据集

In [2]:
assert 'ptb.train.txt' in os.listdir('./data')
with open('./data/ptb.train.txt','r') as f:
    lines=f.readlines()
    raw_dataset=[st.split() for st in lines]
'# sentences:%d'%len(raw_dataset)

'# sentences:42068'

In [51]:
for st in raw_dataset[:3]:
    print('#tokens:',len(st),st[:5])

#tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
#tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']
#tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


## 建立词语索引

In [3]:
counter=collections.Counter([tk for st in raw_dataset for tk in st])
counter=dict(filter(lambda x:x[1]>=5,counter.items()))

In [4]:
idx_to_token=[tk for tk,_ in counter.items()]
token_to_idx={tk:idx for idx,tk in enumerate(idx_to_token)}
dataset=[[token_to_idx[tk] for tk in st if tk in token_to_idx] for st in raw_dataset]
num_tokens=sum([len(st) for st in dataset])
'#tokens:%d'%num_tokens

'#tokens:887100'

## 二次采样

In [5]:
def discard(idx):
    return random.uniform(0,1)<1-math.sqrt(1e-4/counter[idx_to_token[idx]]*num_tokens)
subsampled_dataset=[[tk for tk in st if not discard(tk)] for st in dataset]
'#tokens:%d'%sum([len(st) for st in subsampled_dataset])

'#tokens:375360'

In [6]:
def compare_counts(token):
    return '#%s:before=%d,after=%d'%(token,sum([st.count(token_to_idx[token]) for st in dataset]),sum([st.count(token_to_idx[token]) for st in subsampled_dataset]))
compare_counts('the')

'#the:before=50770,after=2083'

In [7]:
compare_counts('join')

'#join:before=45,after=45'

## 提取中心词和背景词

In [8]:
def get_centers_and_contexts(dataset,max_window_size):
    centers,contexts=[],[]
    for st in dataset:
        if len(st)<2:
            continue
        centers +=st
        for center_i in range(len(st)):
            window_size=random.randint(1,max_window_size)
            indices=list(range(max(0,center_i-window_size),min(len(st),center_i+1+window_size)))
            indices.remove(center_i)
            contexts.append([st[idx] for idx in indices])
    return centers,contexts

In [9]:
tiny_dataset=[list(range(7)),list(range(7,10))]
print('dataset',tiny_dataset)
for center,context in zip(*get_centers_and_contexts(tiny_dataset,2)):
    print('center',center,'has contexts',context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [2, 4]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [4, 5]
center 7 has contexts [8]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [10]:
all_centers,all_contexts=get_centers_and_contexts(subsampled_dataset,5)

# 负采样

In [11]:
def get_negatives(all_contexts,sampling_weights,K):
    all_negatives,neg_candidates,i=[],[],0
    population=list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives=[]
        while len(negatives)<len(contexts)*K:
            if i == len(neg_candidates):
                i,neg_candidates=0,random.choices(population,sampling_weights,k=int(1e5))
            neg,i=neg_candidates[i],i+1
            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives
sampling_weights=[counter[w]**0.75 for w in idx_to_token]
all_negatives=get_negatives(all_contexts,sampling_weights,5)

# 读取数据

In [12]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self,centers,contexts,negatives):
        assert len(centers)==len(contexts)==len(negatives)
        self.centers=centers
        self.contexts=contexts
        self.negatives=negatives
    def __getitem__(self,index):
        return (self.centers[index],self.contexts[index],self.negatives[index])
    def __len__(self):
        return len(self.centers)

In [17]:
def batchify(data):
    max_len=max(len(c)+len(n) for _,c,n in data)
    centers,contexts_negatives,masks,labels=[],[],[],[]
    for center,context,negative in data:
        cur_len=len(context)+len(negative)
        centers+=[center]
        contexts_negatives+=[context+negative+[0]*(max_len-cur_len)]
        masks+=[[1]*cur_len+[0]*(max_len-cur_len)]
        labels+=[[1]*len(context)+[0]*(max_len-len(context))]
    return (torch.tensor(centers).view(-1,1),torch.tensor(contexts_negatives),torch.tensor(masks),torch.tensor(labels))

In [21]:
batch_size=512
num_workers=0 if sys.platform.startswith('win32') else 4

dataset=MyDataset(all_centers,all_contexts,all_negatives)
data_iter=Data.DataLoader(dataset,batch_size,shuffle=True,collate_fn=batchify,num_workers=num_workers)

for batch in data_iter:
    for name,data in zip(['centers','contexts_negatives','masks','labels'],batch):
        print(name,'shape:',data.shape)
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


# 跳字模型

## 嵌入层

In [22]:
embed=nn.Embedding(num_embeddings=20,embedding_dim=4)
embed.weight

Parameter containing:
tensor([[ 1.5031, -0.1519,  1.3063,  0.2633],
        [-2.1616, -1.2789,  0.1801,  1.6020],
        [ 1.3239, -1.4238,  0.3973,  0.3316],
        [-1.6827, -1.6508,  0.3905, -2.0055],
        [ 0.7390,  0.0503,  0.1149, -0.8774],
        [-1.0615, -0.7423, -0.2248, -0.5482],
        [-0.1959,  1.0133,  0.7691,  0.3495],
        [-0.9932, -0.0643, -1.4158,  1.8086],
        [-0.9041,  0.9900,  2.1469, -0.9400],
        [ 1.1303,  0.1270,  1.7050, -2.0273],
        [-0.2039, -0.2043,  0.9505,  0.7859],
        [ 0.6883, -1.2262,  1.1897,  0.0901],
        [-1.0786, -0.8114,  0.4416, -0.1583],
        [ 0.3023, -0.6352,  0.1206,  0.2482],
        [-2.0119,  1.2533, -0.8425,  0.3436],
        [ 1.2651,  0.3204,  0.7677,  0.0709],
        [ 2.2074, -0.6789, -0.1773, -0.7628],
        [ 1.5231,  0.2384,  1.6155, -0.2701],
        [ 1.1645, -1.1776, -1.4131, -2.0500],
        [ 0.0127,  1.1089,  1.8409,  0.6716]], requires_grad=True)

In [26]:
x=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.long)
embed(x)

tensor([[[-2.1616, -1.2789,  0.1801,  1.6020],
         [ 1.3239, -1.4238,  0.3973,  0.3316],
         [-1.6827, -1.6508,  0.3905, -2.0055]],

        [[ 0.7390,  0.0503,  0.1149, -0.8774],
         [-1.0615, -0.7423, -0.2248, -0.5482],
         [-0.1959,  1.0133,  0.7691,  0.3495]]], grad_fn=<EmbeddingBackward>)

## 小批量乘法

In [27]:
x=torch.ones((2,1,4))
y=torch.ones((2,4,6))
torch.bmm(x,y).shape

torch.Size([2, 1, 6])

## 跳字模型前向计算

In [28]:
def skip_gram(center,contexts_and_negatives,embed_v,embed_u):
    v=embed_v(center)
    u=embed_u(contexts_and_negatives)
    pred=torch.bmm(v,u.permute(0,2,1))
    return pred

# 训练模型

## 二元交叉熵损失函数

In [29]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss,self).__init__()
    def forward(self,inputs,targets,mask=None):
        inputs,targets,mask=inputs.float(),targets.float(),mask.float()
        res=nn.functional.binary_cross_entropy_with_logits(inputs,targets,reduction='none',weight=mask)
        return res.mean(dim=1)
loss=SigmoidBinaryCrossEntropyLoss()

In [30]:
pred=torch.tensor([[1.5,0.3,-1,2],[1.1,-0.6,2.2,0.4]])
label=torch.tensor([[1,0,0,0],[1,1,0,0]])
mask=torch.tensor([[1,1,1,1],[1,1,1,0]])
loss(pred,label,mask)*mask.shape[1]/mask.float().sum(dim=1)

tensor([0.8740, 1.2100])

In [31]:
def sigmd(x):
    return -math.log(1/(1+math.exp(-x)))
print('%.4f'%((sigmd(1.5)+sigmd(-0.3)+sigmd(1)+sigmd(-2))/4))
print('%.4f'%((sigmd(1.1)+sigmd(-0.6)+sigmd(-2.2))/3))

0.8740
1.2100


## 初始化模型参数

In [32]:
embed_size=100
net=nn.Sequential(
    nn.Embedding(num_embeddings=len(idx_to_token),embedding_dim=embed_size),
    nn.Embedding(num_embeddings=len(idx_to_token),embedding_dim=embed_size)
)

## 定义训练函数

In [33]:
def train(net,lr,num_epochs):
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('train on',device)
    net=net.to(device)
    optimizer=torch.optim.Adam(net.parameters(),lr=lr)
    for epoch in range(num_epochs):
        start,l_sum,n=time.time(),0.0,0
        for batch in data_iter:
            center,context_negative,mask,label=[d.to(device) for d in batch]
            pred=skip_gram(center,context_negative,net[0],net[1])
            l=(loss(pred.view(label.shape),label,mask)*mask.shape[1]/mask.float().sum(dim=1)).mean()
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum+=l.cpu().item()
            n+=1
        print('epoch %d,loss %.2f,time %.2fs'%(epoch+1,l_sum/n,time.time()-start))

In [34]:
train(net,0.01,10)

train on cuda
epoch 1,loss 1.97,time 26.72s
epoch 2,loss 0.62,time 24.37s
epoch 3,loss 0.45,time 24.38s
epoch 4,loss 0.39,time 24.30s
epoch 5,loss 0.37,time 24.35s
epoch 6,loss 0.35,time 24.26s
epoch 7,loss 0.34,time 24.23s
epoch 8,loss 0.33,time 25.00s
epoch 9,loss 0.32,time 24.30s
epoch 10,loss 0.32,time 24.12s


In [42]:
print(net[0].weight.shape)
print(net[0].weight[1].shape)

torch.Size([9858, 100])
torch.Size([100])


# 应用词嵌入模型

In [48]:
def get_similar_tokens(query_token,k,embed):
    W=embed.weight.data
    x=W[token_to_idx[query_token]]
    cos=torch.matmul(W,x)/(torch.sum(W*W,dim=1)*torch.sum(x*x)+1e-9).sqrt()
    _,topk=torch.topk(cos,k=k+1)
    topk=topk.cpu().numpy()
    print(topk)
    for i in topk[1:]:
        print('cosine sim=%.3f:%s'%(cos[i],(idx_to_token[i])))
get_similar_tokens('chip',3,net[0])

[1131 1059 8802   55]
cosine sim=0.410:computer
cosine sim=0.403:newsletter
cosine sim=0.387:than
