In [1]:
import sys
import random
import progressbar
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from SenTree import *
from torch.optim.lr_scheduler import ReduceLROnPlateau

class RecursiveNN(nn.Module):
    def __init__(self, vocabSize, embedSize=100, numClasses=5):
        super(RecursiveNN, self).__init__()
        self.embedding = nn.Embedding(int(vocabSize), embedSize)
        self.W = nn.Linear(2*embedSize, embedSize, bias=True)
        self.projection = nn.Linear(embedSize, numClasses, bias=True) # 对每个节点进行五分类的预测，将其softmax即为各个种类的概率
        self.activation = nn.ReLU()
        self.nodeProbList = [] # 用来存储各个节点的概率值
        self.labelList = [] # 用来存储各个节点的正确值
        self.crossentropy = nn.CrossEntropyLoss()

    def traverse(self, node):
        '''
        用来递归地获取每个节点的概率值
        并保存在nodeProbList
        并将对应的label值存在labelList中
        返回输入node的激活值
        '''
        if node.isLeaf(): currentNode = self.activation(self.embedding(Var(torch.LongTensor([node.getLeafWord()])))) 
        # 对于叶节点，直接计算embedding后的激活值，即f(a)
        else: currentNode = self.activation(self.W(torch.cat((self.traverse(node.left()),self.traverse(node.right())),1)))
        # 否则将左右节点连接(cat)，在经过一个线性层，即f(W * [a b])，相当于这里的父节点的embedding为[a b]
        self.nodeProbList.append(self.projection(currentNode))
        self.labelList.append(torch.LongTensor([node.label()]))
        return currentNode

    def forward(self, x):
        '''
        前向传播 返回各个节点的预测值
        '''
        self.nodeProbList = []
        self.labelList = []
        self.traverse(x)
        self.labelList = Var(torch.cat(self.labelList))
        return torch.cat(self.nodeProbList)

    def getLoss(self, tree):
        nodes = self.forward(tree)
        predictions = nodes.max(dim=1)[1]
        loss = self.crossentropy(input = nodes, target = self.labelList)
#         loss = F.cross_entropy(input=nodes, target=self.labelList)
        return predictions,loss

    def evaluate(self, trees):
        pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trees)).start()
        n = nAll = correctRoot = correctAll = 0.0
        for j, tree in enumerate(trees):
            predictions,loss = self.getLoss(tree)
            correct = (predictions.data==self.labelList.data)
            correctAll += correct.sum()
            nAll += correct.squeeze().size()[0]
            correctRoot += correct.squeeze()[-1]
            n += 1
            pbar.update(j)
        pbar.finish()
        return correctRoot.item() / n, correctAll.item() /nAll

def Var(v):
    if CUDA: return Variable(v.cuda())
    else: return Variable(v)
    
# 使用save保存模型，并转换到cpu上保存，使用的时候在转换到gpu上
def save_model(model, filename):
    state = model.state_dict()
    for key in state: state[key] = state[key].clone().cpu()
    torch.save(state, filename)

In [2]:
class TreeLSTM(nn.Module):
    def __init__(self, vocabSize, hdim=100, numClasses=5):
        super(TreeLSTM, self).__init__()
        self.embedding = nn.Embedding(int(vocabSize), hdim)
        self.Wi = nn.Linear(hdim, hdim, bias=True)
        self.Wo = nn.Linear(hdim, hdim, bias=True)
        self.Wu = nn.Linear(hdim, hdim, bias=True)
        self.Ui = nn.Linear(2 * hdim, hdim, bias=True)
        self.Uo = nn.Linear(2 * hdim, hdim, bias=True)
        self.Uu = nn.Linear(2 * hdim, hdim, bias=True)
        self.Uf1 = nn.Linear(hdim, hdim, bias=True)
        self.Uf2 = nn.Linear(hdim, hdim, bias=True)
        self.projection = nn.Linear(hdim, numClasses, bias=True)
        self.activation = nn.ReLU()
        self.nodeProbList = []
        self.labelList = []
        self.crossentropy = nn.CrossEntropyLoss()
        

    def traverse(self, node):
        if node.isLeaf():
            e = self.embedding(Var(torch.LongTensor([node.getLeafWord()])))
            i = torch.sigmoid(self.Wi(e))
            o = torch.sigmoid(self.Wo(e))
            u = self.activation(self.Wu(e))
            c = i * u
        else:
            leftH,leftC = self.traverse(node.left())
            rightH,rightC = self.traverse(node.right())
            e = torch.cat((leftH, rightH), 1)
            i = torch.sigmoid(self.Ui(e))
            o = torch.sigmoid(self.Uo(e))
            u = self.activation(self.Uu(e))
            c = i * u + torch.sigmoid(self.Uf1(leftH)) * leftC + torch.sigmoid(self.Uf2(rightH)) * rightC # 新的记忆
        h = o * self.activation(c)
        self.nodeProbList.append(self.projection(h))
        self.labelList.append(torch.LongTensor([node.label()]))
        return h,c

    def forward(self, x):
        self.nodeProbList = []
        self.labelList = []
        self.traverse(x)
        self.labelList = Var(torch.cat(self.labelList))
        return torch.cat(self.nodeProbList)

    def getLoss(self, tree):
        nodes = self.forward(tree)
        predictions = nodes.max(dim=1)[1]
        loss = self.crossentropy(input=nodes, target=self.labelList)
        return predictions,loss

    def evaluate(self, trees):
        pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trees)).start()
        n = nAll = correctRoot = correctAll = 0.0
        for j, tree in enumerate(trees):
            predictions,loss = self.getLoss(tree)
            correct = (predictions.data==self.labelList.data)
            correctAll += correct.sum()
            nAll += correct.squeeze().size()[0]
            correctRoot += correct.squeeze()[-1]
            n += 1
            pbar.update(j)
        pbar.finish()
        return correctRoot.item() / n, correctAll.item()/nAll

In [3]:
CUDA=True
if len(sys.argv)>1:
    if sys.argv[1].lower()=="cuda": CUDA=True

print("Reading and parsing trees")
# trn = SenTree.getTrees("./trees/train.txt","train.vocab") # 第一次解析的时候需要生成词向量
trn = SenTree.getTrees("./trees/train.txt",vocabIndicesMapFile="train.vocab") # 修改后
dev = SenTree.getTrees("./trees/dev.txt",vocabIndicesMapFile="train.vocab")

max_epochs = 100
widgets = [progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()]

Reading and parsing trees


In [5]:
if CUDA: model = TreeLSTM(SenTree.vocabSize).cuda()
# else: model = RecursiveNN(SenTree.vocabSize)
correctRoot, correctAll = model.evaluate(dev)
print(correctRoot)
print(correctAll)

100% |##########################################################| Time: 0:01:02

0.14623069936421434
0.07959562815161532





In [6]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, dampening=0.0)
# optimizer = torch.optim.SGD(model.parameters(), lr = 0.005)
bestAll=bestRoot=0.0

## use_old_model

In [9]:
use_old_model = input("use old model?(y)")
if use_old_model == 'y':
    model = TreeLSTM(SenTree.vocabSize)
    model_name = input()
#     model_name = 'model/' + model_name + '.model'
    model.load_state_dict(torch.load('model/' + model_name + '.model'))
    model = model.cuda()
    correctRoot, correctAll = model.evaluate(dev)
    print(correctRoot)
    print(correctAll)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, dampening=0.0)
    optimizer.load_state_dict(torch.load('model/opt_'+ model_name + '.opt'))

use old model?(y)y
15_16test


100% |##########################################################| Time: 0:01:01

0.010899182561307902
0.6486356069196806





In [10]:
scheduler = ReduceLROnPlateau(optimizer,'max', factor = 0.2, patience = 3)

In [13]:
# ReduceLROnPlateau?

In [14]:
max_epochs = 6 * 8

In [17]:
from tensorboardX import SummaryWriter
writer = SummaryWriter()

In [None]:
for e in range(max_epochs):
    print("epoch: ", e)
    pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trn)).start()
    for step, tree in enumerate(trn):
#         print(step)
        predictions, loss = model.getLoss(tree) # 对每棵树计算loss
        writer.add_scalar('plot/loss',loss,step + 1 + e * len(trn))
        optimizer.zero_grad() # 
        loss.backward()
        clip_grad_norm_(model.parameters(), 5) # 梯度裁剪，防止爆掉https://www.cnblogs.com/lindaxin/p/7998196.html
        optimizer.step()
        pbar.update(step)
    pbar.finish()
    correctRoot, correctAll = model.evaluate(dev)
#     if bestAll<correctAll: bestAll=correctAll
#     if bestRoot<correctRoot: bestRoot=correctRoot
    
    writer.add_scalar('plot/lr', optimizer.param_groups[0]['lr'], e + 1)
    scheduler.step(correctAll)
    if bestAll < correctAll:
        bestAll = correctAll
#         bestAll_model = model # 只是引用。。。
        save_model(model,'run_model/bestAll_model.model')
    if bestRoot<correctRoot:
        bestRoot = correctRoot
#         bestRoot_model = model # 只是引用。。。
        save_model(model,'run_model/bestRoot_model.model')
    print("\nValidation All-nodes accuracy:"+str(correctAll)+"(best:"+str(bestAll)+")")
    print("Validation Root accuracy:" + str(correctRoot)+"(best:"+str(bestRoot)+")")
    random.shuffle(trn) # 随机排列 # 随机梯度下降

  0% |                                                         | ETA:  --:--:--

epoch:  0


100% |##########################################################| Time: 0:25:40
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.565903442951239(best:0.565903442951239)
Validation Root accuracy:0.1298819255222525(best:0.1298819255222525)
epoch:  1


100% |##########################################################| Time: 0:26:59
100% |##########################################################| Time: 0:01:07
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7336839819528554(best:0.7336839819528554)
Validation Root accuracy:0.1226158038147139(best:0.1298819255222525)
epoch:  2


100% |##########################################################| Time: 0:25:16
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7671001520013512(best:0.7671001520013512)
Validation Root accuracy:0.14713896457765668(best:0.14713896457765668)
epoch:  3


100% |##########################################################| Time: 0:25:13
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7786812073250174(best:0.7786812073250174)
Validation Root accuracy:0.16530426884650318(best:0.16530426884650318)
epoch:  4


100% |##########################################################| Time: 0:25:13
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.759596593239559(best:0.7786812073250174)
Validation Root accuracy:0.16893732970027248(best:0.16893732970027248)
epoch:  5


100% |##########################################################| Time: 0:25:04
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7837479190291216(best:0.7837479190291216)
Validation Root accuracy:0.1589464123524069(best:0.16893732970027248)
epoch:  6


100% |##########################################################| Time: 0:25:03
100% |##########################################################| Time: 0:01:01
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7588486500832389(best:0.7837479190291216)
Validation Root accuracy:0.16076294277929154(best:0.16893732970027248)
epoch:  7


100% |##########################################################| Time: 0:25:03
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.768909691895674(best:0.7837479190291216)
Validation Root accuracy:0.18982742960944596(best:0.18982742960944596)
epoch:  8


100% |##########################################################| Time: 0:25:27
100% |##########################################################| Time: 0:01:05
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7688373102999011(best:0.7837479190291216)
Validation Root accuracy:0.17892824704813806(best:0.18982742960944596)
epoch:  9


100% |##########################################################| Time: 0:26:24
100% |##########################################################| Time: 0:01:04
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7719979733153184(best:0.7837479190291216)
Validation Root accuracy:0.1916439600363306(best:0.1916439600363306)
epoch:  10


100% |##########################################################| Time: 0:26:31
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7766062682461939(best:0.7837479190291216)
Validation Root accuracy:0.19255222524977295(best:0.19255222524977295)
epoch:  11


100% |##########################################################| Time: 0:25:42
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.775641180302555(best:0.7837479190291216)
Validation Root accuracy:0.18891916439600362(best:0.19255222524977295)
epoch:  12


100% |##########################################################| Time: 0:25:44
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7742659299828697(best:0.7837479190291216)
Validation Root accuracy:0.1907356948228883(best:0.19255222524977295)
epoch:  13


100% |##########################################################| Time: 0:26:42
100% |##########################################################| Time: 0:01:06
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7737833860110502(best:0.7837479190291216)
Validation Root accuracy:0.1880108991825613(best:0.19255222524977295)
epoch:  14


100% |##########################################################| Time: 0:26:14
100% |##########################################################| Time: 0:01:08
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7723840084927739(best:0.7837479190291216)
Validation Root accuracy:0.1771117166212534(best:0.19255222524977295)
epoch:  15


100% |##########################################################| Time: 0:26:26
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.773952276401187(best:0.7837479190291216)
Validation Root accuracy:0.18346957311534967(best:0.19255222524977295)
epoch:  16


100% |##########################################################| Time: 0:25:36
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7724563900885468(best:0.7837479190291216)
Validation Root accuracy:0.17801998183469572(best:0.19255222524977295)
epoch:  17


100% |##########################################################| Time: 0:25:44
100% |##########################################################| Time: 0:01:03
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7732525876420488(best:0.7837479190291216)
Validation Root accuracy:0.18528610354223432(best:0.19255222524977295)
epoch:  18


100% |##########################################################| Time: 0:25:24
100% |##########################################################| Time: 0:01:02
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7734697324293677(best:0.7837479190291216)
Validation Root accuracy:0.1807447774750227(best:0.19255222524977295)
epoch:  19


100% |##########################################################| Time: 0:26:17
100% |##########################################################| Time: 0:01:06
  0% |                                                         | ETA:  --:--:--


Validation All-nodes accuracy:0.7735179868265496(best:0.7837479190291216)
Validation Root accuracy:0.18346957311534967(best:0.19255222524977295)
epoch:  20


 39% |#######################                                   | ETA:  0:15:59

## test_model

In [19]:
correctRoot, correctAll = model.evaluate(dev)
print(correctRoot)
print(correctAll)

100% |##########################################################| Time: 0:01:02

0.05449591280653951
0.7269766207445654





In [None]:
bestAll_model = TreeLSTM(SenTree.vocabSize)
bestAll_model.load_state_dict(torch.load('run_model/bestAll_model.model'))
bestAll_model = bestAll_model.cuda()
correctRoot, correctAll = bestAll_model.evaluate(dev)
print(correctRoot)
print(correctAll)

In [None]:
bestRoot_model = TreeLSTM(SenTree.vocabSize)
bestRoot_model.load_state_dict(torch.load('run_model/bestRoot_model.model'))
bestRoot_model = bestRoot_model.cuda()
correctRoot, correctAll = bestRoot_model.evaluate(dev)
print(correctRoot)
print(correctAll)

## save_model

In [20]:
save_model_name = input()
no_list = '\/:*?""<>|'
print('model/' + save_model_name + '.model')
for e in no_list:
    if e in save_model_name:
        print("error name!")

15_16test
model/15_16test.model


In [21]:
save_model(model, 'model/' + save_model_name + '.model')

In [25]:
torch.save(optimizer.state_dict(),'model/opt_'+ save_model_name + '.opt')

In [28]:
# optimizer.load_state_dict(torch.load('model/opt_'+ save_model_name + '.opt'))

In [23]:
# use_old_model = input("use old model?(y)")
# if use_old_model == 'y':
#     old_model = TreeLSTM(SenTree.vocabSize)
#     model_name = input()
#     model_name = 'model/' + model_name + '.model'
#     old_model.load_state_dict(torch.load(model_name))
#     old_model = old_model.cuda()
#     correctRoot, correctAll = old_model.evaluate(dev)
#     print(correctRoot)
#     print(correctAll)

use old model?(y)y
15_16test


100% |##########################################################| Time: 0:01:02

0.05449591280653951
0.7269766207445654





In [53]:
# # 保存模型
# import pickle
# with open('model/model_rlstm','wb') as f:
#     pickle.dump(model,f)
# with open('model/optimizer_rlstm','wb') as f:
#     pickle.dump(optimizer,f)
# with open('model/scheduler_rlstm','wb') as f:
#     pickle.dump(scheduler,f)

In [32]:
# 导入模型
# import pickle
# with open('model/model','rb') as f:
#     old_model = pickle.load(f)
# old_model = old_model.cuda()

In [33]:
correctRoot, correctAll = old_model.evaluate(dev)

100% |##########################################################| Time: 0:00:20


In [34]:
print(correctRoot,correctAll)

0.11625794732061762 0.7159987453856733


In [33]:
# correctAll.item()

0

In [34]:
# bestAll

0.0

In [35]:
# bestRoot

tensor(2, device='cuda:0', dtype=torch.uint8)

In [36]:
# correctRoot

tensor(2, device='cuda:0', dtype=torch.uint8)

In [40]:
# bestRoot<correctRoot

tensor(0, device='cuda:0', dtype=torch.uint8)

In [41]:
# len(model.labelList)

39

In [43]:
# len(dev)

2

In [46]:
# correctRoot

tensor(1, device='cuda:0', dtype=torch.uint8)

In [47]:
# model.labelList

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 1, 1, 2, 1, 0, 0, 1, 1, 0, 0,
        0, 0, 1, 2, 0, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 1, 1, 2, 1], device='cuda:0')

In [48]:
# len(model.labelList)

55

In [49]:
# len(dev[-1])

2

In [50]:
# len(model.nodeProbList)

55

In [51]:
# model.nodeProbList

[tensor([[-1.8794,  0.3595,  3.0406,  0.9783, -2.2877]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-2.2810,  0.1816,  3.4546,  0.6747, -2.4519]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-2.2362, -0.4725,  4.1145,  0.1757, -1.9254]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-1.9478, -0.8599,  2.1633,  1.3635, -0.5603]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-1.3765, -0.5862,  4.0972, -0.0786, -1.9582]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-2.7018, -0.6171,  3.1411,  0.6767, -1.4938]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-1.8839,  0.7276,  2.3320,  0.9264, -2.5472]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-1.4892,  0.3374,  1.7578,  0.6994, -1.3117]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-1.9381,  0.1946,  3.6418,  0.0094, -1.9506]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[-0.1261,  0.5602, 

In [52]:
# torch.cat(model.nodeProbList)

tensor([[-1.8794,  0.3595,  3.0406,  0.9783, -2.2877],
        [-2.2810,  0.1816,  3.4546,  0.6747, -2.4519],
        [-2.2362, -0.4725,  4.1145,  0.1757, -1.9254],
        [-1.9478, -0.8599,  2.1633,  1.3635, -0.5603],
        [-1.3765, -0.5862,  4.0972, -0.0786, -1.9582],
        [-2.7018, -0.6171,  3.1411,  0.6767, -1.4938],
        [-1.8839,  0.7276,  2.3320,  0.9264, -2.5472],
        [-1.4892,  0.3374,  1.7578,  0.6994, -1.3117],
        [-1.9381,  0.1946,  3.6418,  0.0094, -1.9506],
        [-0.1261,  0.5602,  2.3632, -0.3496, -2.3496],
        [-2.2362, -0.4725,  4.1145,  0.1757, -1.9254],
        [-3.4456, -0.5500,  4.9093,  1.2386, -1.3084],
        [-2.2386, -0.3675,  2.9028,  0.4377, -0.1677],
        [-0.0703,  0.3329,  1.6183, -0.5959, -1.8925],
        [-2.2487,  0.3247,  3.0476,  0.4387, -1.7265],
        [-1.5775,  0.6223,  1.7728,  0.5999, -1.8320],
        [-3.4416,  0.7038,  3.0970,  0.0816, -0.6747],
        [-1.4288,  0.2090,  1.9174,  0.6519, -1.2996],
        [-

In [54]:
# plist = torch.cat(model.nodeProbList)

In [56]:
# plist.argmax(dim=1)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2,
        2, 2, 2, 2, 1, 2, 1], device='cuda:0')

In [57]:
# plist.max(dim = 1)[1]

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2,
        2, 2, 2, 2, 1, 2, 1], device='cuda:0')

In [58]:
# model.labelList

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 1, 1, 2, 1, 0, 0, 1, 1, 0, 0,
        0, 0, 1, 2, 0, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 1, 1, 2, 1], device='cuda:0')

In [59]:
# s = plist.max(dim = 1)[1] ==model.labelList

In [60]:
# s

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 1], device='cuda:0', dtype=torch.uint8)

In [61]:
# s.sum()

tensor(38, device='cuda:0')

In [62]:
# predict = plist.max(dim = 1)[1]

In [63]:
# predict.data

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2,
        2, 2, 2, 2, 1, 2, 1], device='cuda:0')

In [64]:
# predict.data == model.labelList.data

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 1], device='cuda:0', dtype=torch.uint8)

In [65]:
# n = nAll = correctAll = correctRoot = 0.0

# for j, tree in enumerate(dev):
#     predictions, loss = model.getLoss(tree)
#     correct = (predictions.data == model.labelList.data)
#     correctAll += correct.sum()
#     nAll += correct.squeeze().size()[0]
#     correctRoot += correct.squeeze()[-1]
#     n += 1
    

In [66]:
# correct

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 1], device='cuda:0', dtype=torch.uint8)

In [67]:
# correctRoot

tensor(82, device='cuda:0', dtype=torch.uint8)

In [68]:
# n

1101.0

In [69]:
# correctAll

tensor(29517, device='cuda:0')

In [70]:
# nAll

41447.0

In [73]:
# correctAll.item() / nAll

0.7121625208097088

In [74]:
# correctRoot.item() / n

0.07447774750227067

In [75]:
# loss

tensor(0.9491, device='cuda:0', grad_fn=<NllLossBackward>)

In [78]:
# loss.backward

<bound method Tensor.backward of tensor(0.9491, device='cuda:0', grad_fn=<NllLossBackward>)>