In [6]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
#from tensorboardX import SummaryWriter

%load_ext autoreload
%autoreload 2

from model import Net,CRFs

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### 测试数据

In [7]:
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]

# 创建字典集合
vocab = set()
for i in range(len(training_data)):
    vocab.update(training_data[i][0])

tags = set()
for i in range(len(training_data)):
    tags.update(training_data[i][1])

# 创建索引表
i2w = {i: w for i, w in enumerate(vocab)}
w2i = {w: i for i, w in i2w.items()}
i2t = {i: w for i, w in enumerate(tags)}
t2i = {w: i for i, w in i2t.items()}

def sent2tensor(sent): # 句子转tensor
    ls = [w2i[w] for w in sent]
    ts = torch.LongTensor(ls).view(1,-1)
    return ts

def tags2tensor(tags): #标注转tensor
    ls = [t2i[w] for w in tags]
    ts = torch.LongTensor(ls).view(1,-1)
    return ts

dataloader=[]

x0 = sent2tensor(training_data[0][0])
x1= sent2tensor(training_data[1][0])
y0= tags2tensor(training_data[0][1])
y1= tags2tensor(training_data[1][1])

dataloader.append((x0,y0))
dataloader.append((x1,y1))

#遍历和打印数据
for i,(x,y) in enumerate(dataloader):
    print("x{}:{}\ny{}:{}".format(i,x,i,y))

x0:
   15    11     6    10     5     0     9     1     2    13     4
[torch.LongTensor of size 1x11]

y0:
    1     0     0     0     2     2     2     1     0     2     2
[torch.LongTensor of size 1x11]

x1:
    8    14    12     7     3    16     8
[torch.LongTensor of size 1x7]

y1:
    1     0     2     2     2     2     1
[torch.LongTensor of size 1x7]



# 模型

In [9]:
vocab_size = len(vocab) # 字典长度
embed_dim = 64 # 嵌入的维度
h_dim = 64 # 隐藏层的维度
tag_size = len(tags) # 标记的种类

net = Net(vocab_size, embed_dim, h_dim, tag_size)

x = Variable(dataloader[0][0])
y = Variable(dataloader[0][1])

out = net(x,y)
out # output: batch,seq,dim

IndexError: index 4 from broadcast indexer is out of range for dimension 1 (of size 4)

# 训练

In [11]:
from train import train
train(net, dataloader, num_epochs=5000, print_ever=200)

epoch:0:loss:1.0343163311481476
epoch:200:loss:0.0006844670133432373
epoch:400:loss:0.00020297246373957023
epoch:600:loss:9.734709237818606e-05
epoch:800:loss:5.6534732721047476e-05
epoch:1000:loss:3.636902238213224e-05
epoch:1200:loss:2.491845498298062e-05
epoch:1400:loss:1.7806871710490668e-05
epoch:1600:loss:1.310716470470652e-05
epoch:1800:loss:9.85917063189845e-06
epoch:2000:loss:7.537576948379865e-06
epoch:2200:loss:5.834984222019557e-06
epoch:2400:loss:4.560754632620956e-06
epoch:2600:loss:3.5922728329751408e-06
epoch:2800:loss:2.8466412231864524e-06
epoch:3000:loss:2.2667498456030444e-06
epoch:3200:loss:1.8121363041245786e-06
epoch:3400:loss:1.4535508512381057e-06
epoch:3600:loss:1.1685967820085352e-06
epoch:3800:loss:9.4201215006251e-07
epoch:4000:loss:7.607190752878523e-07
epoch:4200:loss:6.153274654252527e-07
epoch:4400:loss:4.991909037244113e-07
epoch:4600:loss:4.0345894092297385e-07
epoch:4800:loss:3.277596078987699e-07


# 预测

In [13]:
x = Variable(dataloader[0][0])
y = Variable(dataloader[0][1])
print(x,y)
net.viterbi(x)

Variable containing:
   16     9     6    11    15     4     0    13    14     7    10
[torch.LongTensor of size (1,11)]
 Variable containing:
    0     1     1     1     2     2     2     0     1     2     2
[torch.LongTensor of size (1,11)]



(Variable containing:
   10.4200   -4.3099   -5.2768
    6.2078   21.0442    4.1433
   16.5355   32.2205   15.8289
   27.2963   43.3282   28.0949
   36.4328   39.0461   54.4290
   48.9578   47.1201   66.2943
   62.3261   58.3926   77.4020
   88.3678   72.8274   73.1027
   83.5533   99.0469   83.3757
   92.2646   94.6702  109.5445
  104.4207  103.7436  119.3424
 [torch.FloatTensor of size (11,3)], Variable containing:
     0     1     1     1     2     2     2     0     1     2     2
 [torch.LongTensor of size (1,11)])