In [2]:
import numpy as np 
from tqdm import *
import torch
from torch import nn
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import json
import random
import h5py
import sys, os
import nltk

sys.path.append('../')
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

from Utils import Utils
from net import Seq2Seq
from Transform import Transform
from BLEUScore import BLEUScore

# Directory that train() writes pickled model checkpoints into.
model_dir = '/data/xuwenshen/ai_challenge/code/models/'
# Shared BLEU scorer used by the validation pass.
bleuscore = BLEUScore()

def _evaluate(net, valid_utils, transform, nb_shown=5):
    """Score ``net`` on the whole validation split.

    Pulls batches from ``valid_utils`` until a batch's ``'flag'`` entry is
    true and accumulates the loss reported by ``net.get_loss``. It then clips
    the predictions with ``transform.clip``, truncates each reference label
    to its ``'zhlen'`` length, and computes the score with the module-level
    ``bleuscore``. Finally it prints up to ``nb_shown`` decoded
    reference/prediction pairs.

    Leaves ``net`` in eval mode; the caller must switch it back to training.

    Args:
        net: model called as ``net(entext, zhgtruth, False)`` returning
            ``(logits, predictions)`` and exposing ``get_loss``.
        valid_utils: validation batch provider with a ``next_batch()`` method.
        transform: vocabulary helper providing ``clip`` and ``i2t``.
        nb_shown: maximum number of decoded samples to print.

    Returns:
        ``(mean_loss, score)``: the per-batch loss (summed over the loss
        array) averaged over validation batches, and the BLEU score.
    """
    net.eval()
    predictions, references, lengths = [], [], []
    total_loss = 0.0
    nb_batches = 0
    is_finished = False
    # NOTE(review): autograd still records here; on PyTorch >= 0.4 this loop
    # could run under torch.no_grad() to reduce memory use.
    while not is_finished:
        data = valid_utils.next_batch()
        zhlabel = data['zhlabel']
        is_finished = data['flag']

        logits, predic = net(data['entext'], data['zhgtruth'], False)
        loss = net.get_loss(logits, zhlabel)

        predictions.extend(predic)
        references.extend(zhlabel)
        lengths.extend(data['zhlen'])
        # np.sum accepts both 0-d and 1-element arrays; builtin sum() does
        # not accept a 0-d array.
        total_loss += float(np.sum(loss.data.cpu().numpy()))

        # Release graph references before fetching the next batch.
        del loss, logits, predic
        nb_batches += 1

    predictions = [transform.clip(p, language='zh') for p in predictions]
    references = [ref[:n] for ref, n in zip(references, lengths)]

    score = bleuscore.score(predictions, references)

    # Slicing (rather than range(nb_shown)) tolerates tiny validation sets.
    for ref, pre in list(zip(references, predictions))[:nb_shown]:
        ans_ = transform.i2t(ref, language='zh')
        pre_ = transform.i2t(pre, language='zh')
        print(ans_, '\n*****************\n', pre_, '\n')

    return total_loss / nb_batches, score


def train(lr, net, epoch, train_utils, valid_utils, transform, eval_every=300):
    """Train ``net`` with Adam, evaluating and checkpointing periodically.

    After every ``eval_every`` training batches within an epoch the model is
    evaluated on the full validation split. The mean validation loss and
    BLEU score are printed. Whenever the score beats the best seen so far,
    the whole model is saved with ``torch.save`` into ``model_dir``.

    Args:
        lr: learning rate for ``torch.optim.Adam``.
        net: model called as ``net(entext, zhgtruth, True)`` during training,
            returning ``(logits, predictions)``, and exposing
            ``get_loss(logits, labels)``.
        epoch: number of passes over the training data.
        train_utils: training batch provider; ``next_batch()`` returns a dict
            whose ``'flag'`` entry is true on the last batch of an epoch.
        valid_utils: validation batch provider with the same interface.
        transform: vocabulary helper used to clip and decode predictions.
        eval_every: number of training batches between evaluations; must be
            a positive integer.
    """
    if torch.cuda.is_available():
        net.cuda()
    # Only parameters with requires_grad=True are handed to the optimizer.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
    net.train()

    best_score = -1
    for iepoch in range(epoch):
        batchid = 0
        end_of_epoch = False
        while not end_of_epoch:
            data = train_utils.next_batch()
            zhlabel = data['zhlabel']
            # The flagged batch is still trained on; the loop stops after it.
            end_of_epoch = data['flag']

            logits, _ = net(data['entext'], data['zhgtruth'], True)
            loss = net.get_loss(logits, zhlabel)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(iepoch, batchid, loss.data.cpu().numpy())
            batchid += 1

            if batchid % eval_every == 0:
                print('\n------------------------\n')
                all_loss, score = _evaluate(net, valid_utils, transform)
                print(iepoch, batchid, all_loss, score)

                if score > best_score:
                    best_score = score
                    filename = "loss-{:.3f}-score-{:.3f}-model.pkl".format(
                        all_loss, score)
                    torch.save(net, os.path.join(model_dir, filename))
                # _evaluate left the model in eval mode.
                net.train()
    
    

if __name__ == '__main__':

    batch_size = 256
    data_dir = '/data/xuwenshen/ai_challenge/data'

    # Training split.
    train_nb_samples = 9893707
    train_utils = Utils(batch_size=batch_size,
                        en_path=os.path.join(data_dir, 'train/train/en.h5py'),
                        zh_path=os.path.join(data_dir, 'train/train/zh.h5py'),
                        is_test=False,
                        nb_samples=train_nb_samples)

    # Validation split.
    valid_nb_samples = 4000
    valid_utils = Utils(batch_size=batch_size,
                        en_path=os.path.join(data_dir, 'valid/valid/en.h5py'),
                        zh_path=os.path.join(data_dir, 'valid/valid/zh.h5py'),
                        is_test=False,
                        nb_samples=valid_nb_samples)

    # Vocabularies are loaded from the training split.
    transform = Transform(
        en_voc_path=os.path.join(data_dir, 'train/train/en_voc.json'),
        zh_voc_path=os.path.join(data_dir, 'train/train/zh_voc.json'))

    # Model hyper-parameters. These are the values actually passed to
    # Seq2Seq; the hidden sizes previously declared here (200 / 312) were
    # never used, because the constructor received 128 / 256 as literals.
    en_dims = 256
    en_voc = 50004
    zh_dims = 256
    zh_voc = 4004
    dropout = 0.5
    en_hiddens = 128
    zh_hiddens = 256

    net = Seq2Seq(en_dims=en_dims,
                  en_voc=en_voc,
                  zh_dims=zh_dims,
                  zh_voc=zh_voc,
                  dropout=dropout,
                  en_hiddens=en_hiddens,
                  zh_hiddens=zh_hiddens)
    print(net)

    epoch = 100
    lr = 0.0001

    train(lr=lr,
          net=net,
          epoch=epoch,
          train_utils=train_utils,
          valid_utils=valid_utils,
          transform=transform)
    

Seq2Seq (
  (en_embedding): Embedding(50004, 256)
  (zh_embedding): Embedding(4004, 256)
  (enc_fw_lstm): LSTM(256, 128, dropout=0.5)
  (enc_bw_lstm): LSTM(256, 128, dropout=0.5)
  (dec_lstm_cell): LSTMCell(512, 256)
  (fc): Linear (256 -> 4004)
  (softmax): Softmax ()
  (cost_func): CrossEntropyLoss (
  )
)


NameError: name 'string' is not defined