In [None]:
import re, os, sys, json, random
from tqdm import *
import numpy as np
import tensorflow as tf  
import h5py
import nltk


def load_word(path, encoding=None):
    """Load a vocabulary file that holds one token per line.

    Args:
        path: Path of the vocabulary file. The line number (0-based) of a
            token is its id.
        encoding: Optional text encoding passed to ``open``. ``None`` keeps
            the platform default, which is what the function always used.

    Returns:
        A ``(words, rwords)`` tuple: ``words`` maps token -> id and
        ``rwords`` maps id -> token. If a token appears on several lines,
        only its last line number is kept in both mappings.
    """
    # The context manager closes the file even when reading fails.
    with open(path, encoding=encoding) as input_file:
        word_lst = [line.rstrip('\n') for line in input_file]

    words = {word: index for index, word in enumerate(word_lst)}
    # Built from `words` (not from the raw list) so both dicts stay inverse
    # to each other when the file contains duplicate tokens.
    rwords = {index: word for word, index in words.items()}

    return words, rwords


class Utils:
    """Mini-batch reader over an HDF5 dataset, plus prediction scoring helpers.

    The HDF5 file is expected to expose the datasets read in ``next_batch``:
    ``source``, ``ground_truth``, ``label``, ``loss_weights``, ``defendant``,
    ``defendant_length`` and the raw-text fields ``source_tx``,
    ``defendant_tx`` and ``reason_tx`` (byte strings decoded as gb2312).
    """

    # Codec of the raw-text byte strings.
    TEXT_ENCODING = 'gb2312'

    def __init__(self, word_path, text_path, batch_size, nb_samples):
        """Open the dataset and prepare a shuffled sample order.

        Args:
            word_path: Vocabulary file, loaded with ``load_word``.
            text_path: Path of the HDF5 file holding the samples.
            batch_size: Number of samples returned by each ``next_batch``.
            nb_samples: Number of samples to draw from (rows ``0..nb_samples-1``).
        """
        self.words, self.rwords = load_word(word_path)
        # The class only reads from the file, so open it read-only instead of
        # relying on the version-dependent default mode of h5py.
        self.file = h5py.File(text_path, 'r')
        self.batch_size = batch_size
        self.current_batch = 0
        self.nb_samples = nb_samples
        self.current_text = dict()
        self.shuffled_id = np.arange(nb_samples)
        random.shuffle(self.shuffled_id)

    def get_words_size(self):
        """Return the number of entries in the vocabulary."""
        return len(self.words)

    def next_batch(self):
        """Return the next mini-batch of the current epoch.

        Returns:
            A ``(batch, to_again)`` tuple. While the epoch still has a full
            batch left, ``batch`` is a dict with one entry per dataset (the
            raw-text fields included) and ``to_again`` is ``False``. Once no
            full batch is left, the sample order is reshuffled, the batch
            counter is reset, and ``({}, True)`` is returned without reading
            any data.
        """
        start = self.current_batch * self.batch_size
        end = start + self.batch_size

        # Strictly greater: a batch that ends exactly on the last sample is
        # still a full batch and must not be dropped.
        if end > self.nb_samples:
            random.shuffle(self.shuffled_id)
            self.current_batch = 0
            return dict(), True

        # Sorted because the ids are used as a fancy index into the datasets.
        ids = sorted(self.shuffled_id[start:end])

        batch = {'source': self.file['source'][ids],
                 'defendant': self.file['defendant'][ids],
                 'defendant_length': self.file['defendant_length'][ids],
                 'ground_truth': self.file['ground_truth'][ids],
                 'label': self.file['label'][ids],
                 'loss_weights': self.file['loss_weights'][ids]}

        # The returned dict and `current_text` are the same object, so the
        # raw-text fields are visible to the caller as well (as before).
        batch.update({'source_tx': self.file['source_tx'][ids],
                      'defendant_tx': self.file['defendant_tx'][ids],
                      'reason_tx': self.file['reason_tx'][ids]})
        self.current_text = batch
        self.current_batch += 1

        return batch, False

    @classmethod
    def _decode(cls, raw):
        """Decode a stored byte string, skipping bytes that cannot be decoded.

        Fixed-width byte fields can end in the middle of a multi-byte
        character; a strict decode would raise ``UnicodeDecodeError`` there.
        """
        return raw.decode(cls.TEXT_ENCODING, 'ignore')

    def print_text(self, prediction_tx, index):
        """Print source, defendant and reference text of a sample, then the prediction.

        Args:
            prediction_tx: Predicted text to print after the reference.
            index: Position of the sample inside the current batch.
        """
        for field in ('source_tx', 'defendant_tx', 'reason_tx'):
            print (self._decode(self.current_text[field][index]))
            print ('-' * 20 + '\n')

        print (prediction_tx)
        print ('\n' + '*' * 20 + '\n')

    def bleu(self, prediction_tx, index):
        """Return the sentence BLEU of ``prediction_tx`` against the reference text.

        Both reference and prediction are passed as plain strings, so nltk
        iterates over them character by character.
        """
        reference = self._decode(self.current_text['reason_tx'][index])
        return nltk.translate.bleu_score.sentence_bleu([reference], prediction_tx)

    def i2t(self, ilist, to_print):
        """Score predicted id sequences against the current batch.

        Args:
            ilist: One sequence of predicted token ids per sample of the
                current batch, in batch order.
            to_print: When true, print the texts of the last sample.

        Returns:
            ``(accuracy, mean_bleu)``. Accuracy is the share of label
            positions before the first ``pad`` whose predicted id equals the
            label id; it is ``0.0`` when there is no such position.
            ``mean_bleu`` is ``0.0`` for an empty ``ilist``.
        """
        same_words_counter = 0
        words_counter = 0
        bleu_score = 0
        nb_predictions = len(ilist)

        for i in range(nb_predictions):
            predicted = ilist[i]
            labels = self.current_text['label'][i]

            # Token accuracy, counted up to the first padding label.
            for j in range(len(predicted)):
                if self.rwords[labels[j]] == 'pad':
                    break
                words_counter += 1
                if labels[j] == predicted[j]:
                    same_words_counter += 1

            # Predicted text is everything before the first 'eos'.
            pieces = []
            for token in predicted:
                word = self.rwords[token]
                if word == 'eos':
                    break
                pieces.append(word)
            prediction_tx = ''.join(pieces)

            bleu_score += self.bleu(prediction_tx, i)

            # Only the last sample of the batch is printed.
            if to_print and i == nb_predictions - 1:
                self.print_text(prediction_tx=prediction_tx, index=i)

        accuracy = same_words_counter / words_counter if words_counter else 0.0
        mean_bleu = bleu_score / nb_predictions if nb_predictions else 0.0
        return accuracy, mean_bleu
            

    


In [1]:
import re, os, sys, json, random
from tqdm import *
import numpy as np
import tensorflow as tf  
import h5py
import nltk


def load_word(path, encoding=None):
    """Read a one-token-per-line vocabulary and build both lookup tables.

    Args:
        path: Path of the vocabulary file; a token's 0-based line number is
            used as its id.
        encoding: Optional text encoding handed to ``open``. The default
            ``None`` keeps the platform default used previously.

    Returns:
        ``(words, rwords)`` where ``words`` maps token -> id and ``rwords``
        maps id -> token. For a token listed more than once, only the last
        line number survives in both mappings.
    """
    # `with` guarantees the handle is released even if reading raises.
    with open(path, encoding=encoding) as input_file:
        word_lst = [line.rstrip('\n') for line in input_file]

    words = {word: index for index, word in enumerate(word_lst)}
    # Inverted from `words` so the two dicts remain exact inverses when the
    # file contains duplicates.
    rwords = {index: word for word, index in words.items()}

    return words, rwords


def _left_padded_ids(text, length, words, pad):
    """Map the first ``length`` characters of ``text`` to ids, right-aligned.

    The result always has ``length`` entries; unused leading positions hold
    ``pad``.
    """
    clipped = text[:length]
    return [pad] * (length - len(clipped)) + [words[ch] for ch in clipped]


def h5py_write(text_path, words, h5_path, source_len, oseq_len, simplified_len, nb_samples):
    """Convert a JSON-lines text corpus into fixed-size HDF5 datasets.

    Every input line is a JSON object with the string fields ``source``,
    ``defendant`` and ``reason``. Characters are mapped to ids through
    ``words`` (an unknown character raises ``KeyError``).

    Datasets written to ``h5_path``:
        source            (nb_samples, source_len)      left-padded ids
        defendant         (nb_samples, simplified_len)  left-padded ids
        defendant_length  (nb_samples,)  defendant length, capped at simplified_len
        label             (nb_samples, oseq_len)  reason ids, then ``eos`` if
                          there is room, then ``pad``
        ground_truth      (nb_samples, oseq_len)  ``go`` followed by the reason
                          ids and ``pad`` (last position dropped)
        loss_weights      (nb_samples, oseq_len)  1.0 on reason/``eos``
                          positions, 0.0 elsewhere
        source_tx / defendant_tx / reason_tx  gb2312-encoded raw texts in
                          fixed-width byte fields (S3000 / S600 / S500)

    Args:
        text_path: Input file, one JSON document per line.
        words: Token -> id mapping; must contain ``eos``, ``pad`` and ``go``.
        h5_path: Output HDF5 file. It is opened in append mode, so the
            datasets must not exist in it yet.
        source_len: Width of the ``source`` rows.
        oseq_len: Width of the ``label`` / ``ground_truth`` / ``loss_weights`` rows.
        simplified_len: Width of the ``defendant`` rows.
        nb_samples: Number of rows to allocate; at most this many lines are read.
    """
    eos = words['eos']
    pad = words['pad']
    go = words['go']

    chunk_size = 10000  # lines converted and written per iteration
    # Both branches of the original keyword check assigned 1.0, so every
    # real token (and the eos marker) simply gets this weight.
    weight = 1.

    # Context managers close both files even when conversion fails midway.
    with open(text_path) as input_, h5py.File(h5_path, 'a') as output_:

        source_set = output_.create_dataset("source", (nb_samples, source_len), dtype='int32')
        defendant_set = output_.create_dataset("defendant", (nb_samples, simplified_len), dtype='int32')
        label_set = output_.create_dataset("label", (nb_samples, oseq_len), dtype='int32')
        ground_truth_set = output_.create_dataset("ground_truth", (nb_samples, oseq_len), dtype='int32')
        defendant_length_set = output_.create_dataset("defendant_length", (nb_samples, ), dtype='int32')
        weights_set = output_.create_dataset("loss_weights", (nb_samples, oseq_len), dtype='float32')

        # NOTE(review): encoded texts longer than the field width are
        # presumably truncated on write -- confirm against h5py/numpy.
        source_tx_set = output_.create_dataset('source_tx', (nb_samples, ), dtype='S3000')
        defendant_tx_set = output_.create_dataset('defendant_tx', (nb_samples, ), dtype='S600')
        reason_tx_set = output_.create_dataset('reason_tx', (nb_samples, ), dtype='S500')

        sindex = 0

        while sindex < nb_samples:

            # Never read more lines than there are rows left in the datasets.
            texts = []
            for _ in range(min(chunk_size, nb_samples - sindex)):
                line = input_.readline()
                if line == '':
                    break
                texts.append(json.loads(line))

            # Input exhausted before nb_samples rows: stop instead of
            # looping forever on an empty chunk.
            if not texts:
                break

            stop = sindex + len(texts)

            source_set[sindex:stop] = [
                _left_padded_ids(sample['source'], source_len, words, pad)
                for sample in texts
            ]

            defendant_length_set[sindex:stop] = [
                min(simplified_len, len(sample['defendant'])) for sample in texts
            ]
            defendant_set[sindex:stop] = [
                _left_padded_ids(sample['defendant'], simplified_len, words, pad)
                for sample in texts
            ]

            label_rows = []
            truth_rows = []
            weight_rows = []
            for sample in texts:
                ids = [words[ch] for ch in sample['reason'][:oseq_len]]
                used = len(ids)
                padded = ids + [pad] * (oseq_len - used)

                label_row = list(padded)
                weight_row = [weight] * used + [.0] * (oseq_len - used)
                if used < oseq_len:
                    # The end marker is part of what the loss should cover.
                    label_row[used] = eos
                    weight_row[used] = weight

                label_rows.append(label_row)
                weight_rows.append(weight_row)
                # Shift right by one: start with `go`, drop the last position.
                truth_rows.append(([go] + padded)[:oseq_len])

            label_set[sindex:stop] = label_rows
            ground_truth_set[sindex:stop] = truth_rows
            weights_set[sindex:stop] = weight_rows

            source_tx_set[sindex:stop] = [sample['source'].encode('gb2312', 'ignore') for sample in texts]
            defendant_tx_set[sindex:stop] = [sample['defendant'].encode('gb2312', 'ignore') for sample in texts]
            reason_tx_set[sindex:stop] = [sample['reason'].encode('gb2312', 'ignore') for sample in texts]

            sindex = stop

            print (sindex)


    
if __name__ == '__main__':

    words, rwords = load_word('/home/xuwenshen/data/big_data/2017_3_13/words')

    # Build the test split. The train and valid splits were produced with the
    # same lengths from 'shuffled_train' -> 'train.h5' (nb_samples=1600000)
    # and 'shuffled_valid' -> 'valid.h5' (nb_samples=10000).
    h5py_write(text_path='/home/xuwenshen/data/big_data/2017_3_13/shuffled_test',
               words=words,
               h5_path='/home/xuwenshen/data/big_data/2017_3_13/test.h5',
               source_len=1000,
               simplified_len=150,
               oseq_len=200,
               nb_samples=115672)

    # Sanity check: re-open the file read-only and map the first rows back to
    # text. Every id must be present in `rwords`, otherwise a KeyError is
    # raised. The context manager closes the file even if the check fails.
    with h5py.File('/home/xuwenshen/data/big_data/2017_3_13/test.h5', 'r') as h5_file:

        # Check at most 2000 rows, fewer if the dataset is smaller.
        nb_checked = min(2000, h5_file['source'].shape[0])

        for i in range(nb_checked):
            source = h5_file['source'][i]
            ground_truth = h5_file['ground_truth'][i]
            label = h5_file['label'][i]
            loss_weights = h5_file['loss_weights'][i]
            defendant = h5_file['defendant'][i]

            defendant_tx = ''.join(rwords[token] for token in defendant)
            source_tx = ''.join(rwords[token] for token in source)
            label_tx = ''.join(rwords[token] for token in label)
            ground_truth_tx = ''.join(rwords[token] for token in ground_truth)

            # Number of non-padding label positions.
            counter = sum(1 for token in label if rwords[token] != 'pad')

            # To inspect a sample, print source_tx, defendant_tx, label_tx,
            # ground_truth_tx, sum(loss_weights) and counter here, or the
            # stored raw texts via
            # h5_file['source_tx'][i].decode('gb2312', 'ignore').

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
115672
