# <span style='color:orange'> Modelling Chinese Word Segmentation as a Sequence-to-Sequence Prediction Problem </span>

# <span style='color:Green'> Testing Pre-Trained Models Notebook</span>

# Load the model

In [7]:
from keras.models import load_model
import cPickle as pickle
import numpy as np
# Load the pre-trained Keras model from disk (presumably the checkpoint saved after
# epoch 3, judging by the file name). NOTE(review): the path is absolute and
# machine-specific; adjust it before running elsewhere.
model = load_model('/home/asa224/Desktop/students_less_asa224/Test Folder on Less/model_epoch3.h5')

Using TensorFlow backend.


# Helper functions

In [1]:
# Placeholder character that stands in for a newline inside the generated word files
# (written by generateWordFile / generateInputWordFile); the prediction cell checks
# for it in order to emit '\n' in the output.
SPECIAL_SYMBOL = u'\u02e0'

In [2]:
# Label every character of every word in the segmented training file, one sequence per line.
# Lists of (character, label) tuples are used instead of a dictionary because the same
# character may appear many times and would overwrite its earlier entry in a dict.
def generateTupleListAccToSentences(filename='/local-scratch/asa224/wseg_simplified_cn.txt'):
    """
    Build one labelled character sequence per line of a space-segmented text file.

    Every character of every word receives a position label:
        0 - first character of a multi-character word
        1 - character in the middle of a word
        2 - last character of a multi-character word
        3 - the only character of a single-character word

    INPUT: `filename` is the path of the UTF-8 encoded training text file, with
    words separated by single spaces and one sentence per line.

    OUTPUT: a list of lists of (character, label) tuples, one inner list per line
    of the file. As in the original implementation, one extra empty list is
    appended after the last sentence, so the result has (number of lines + 1)
    entries.

    If you want to use this data as training data for LSTM, you have to pad the
    sequences since they are not of the same length.
    """
    data = []
    with open(filename, 'rb') as f:
        for line in f:
            # bytes.decode() behaves the same on Python 2 and Python 3
            line = line.decode('utf-8').replace('\n', '')

            sentence = []
            for word in line.split(' '):
                last = len(word) - 1
                # an empty "word" (e.g. from consecutive spaces) contributes nothing
                for i, character in enumerate(word):
                    if last == 0:  # single-character word
                        tag = 3
                    elif i == 0:  # this is the first letter
                        tag = 0
                    elif i == last:  # this is the last letter
                        tag = 2
                    else:  # this is somewhere in the middle
                        tag = 1
                    sentence.append((character, tag))
            data.append(sentence)

    # keep the trailing empty sequence that the original implementation produced
    data.append([])
    return data
    
def generateWordFile(filename='/local-scratch/asa224/wseg_simplified_cn.txt',
                     output_filename='/local-scratch/word_file_1M'):
    """
    Generate the word file, similar to the count_1w.txt file provided by Prof. Anoop.

    Each word of the input is written on its own line as "<word>\\t0\\n". After the
    words of every input line, one extra record containing SPECIAL_SYMBOL is
    written so that the position of the original newline is preserved.

    The output of this function can be used to parse the characters, and is an
    input to the generateTupleList() function as well.

    INPUT:
        filename        - path of the UTF-8 encoded training text file, words
                          separated by single spaces.
        output_filename - path of the word file to write. The default is the path
                          the original implementation always wrote to.
                          NOTE(review): generateTupleList() defaults to
                          '/local-scratch/asa224/word_file_1M', which is a
                          different directory - pass matching paths explicitly.
    """
    # every record carries a dummy count of 0
    record_suffix = '\t0\n'.encode('utf-8')

    # both handles are managed by `with`, so they are closed even if a line fails to decode
    with open(filename, 'rb') as f, open(output_filename, 'wb') as word_file:
        for line in f:
            words = line.decode('utf-8').replace('\n', '').split(' ')

            # add the newline back using a special symbol
            words.append(SPECIAL_SYMBOL)
            for word in words:
                word_file.write(word.encode('utf-8') + record_suffix)

def generateInputWordFile(filename='/local-scratch/asa224/input',
                          output_filename='./input_word_file'):
    """
    Generate the word file for the (unsegmented) INPUT text, in the same format as
    the count_1w.txt file provided by Prof. Anoop.

    The input is not segmented, so every single character is written as its own
    record "<character>\\t0\\n". Each newline character is replaced by
    SPECIAL_SYMBOL before writing, so line breaks survive as ordinary records.

    The output of this function can be used to parse the characters, and is an
    input to the generateTupleList() function as well.

    INPUT:
        filename        - path of the UTF-8 encoded input text to be segmented.
        output_filename - path of the word file to write. The default is the path
                          the original implementation always wrote to.
    """
    # every record carries a dummy count of 0
    record_suffix = '\t0\n'.encode('utf-8')

    # both handles are managed by `with`, so they are closed even if a line fails to decode
    with open(filename, 'rb') as f, open(output_filename, 'wb') as word_file:
        for line in f:
            # replace the newline character with the special unicode symbol
            line = line.decode('utf-8').replace('\n', SPECIAL_SYMBOL)
            # iterating over a string yields one character at a time
            for character in line:
                word_file.write(character.encode('utf-8') + record_suffix)

def generateTupleList(filename='/local-scratch/asa224/word_file_1M'):
    """
    Assign a position label to each character listed in a word file.

    Every character of every word receives a label:
        0 - first character of a multi-character word
        1 - character in the middle of a word
        2 - last character of a multi-character word
        3 - the only character of a single-character word

    INPUT: `filename` is a WORD FILE generated by generateWordFile(filename) or
    generateInputWordFile(filename): one "<word>\\t0" record per line.

    OUTPUT: a flat list of (character, label) tuples in file order.

    Raises AssertionError if a record's count field is not 0, and ValueError if a
    record has no tab separator or a non-numeric count.

    Use this function in conjunction with nGramSequenceGenerator(labelledlist, n)
    to create a training set with constant sequence size, which does not require
    paddings.
    """
    label = []
    with open(filename, 'rb') as f:
        for line in f:
            # Split on the LAST tab only: the word itself may be a tab character
            # (generateInputWordFile writes every input character as a record).
            word, count = line.rsplit(b'\t', 1)
            # making sure the parsing is going fine
            assert int(count) == 0

            # bytes.decode() behaves the same on Python 2 and Python 3
            word = word.decode('utf-8')
            last = len(word) - 1
            for i, character in enumerate(word):
                if last == 0:  # single-character word
                    tag = 3
                elif i == 0:  # this is the first letter
                    tag = 0
                elif i == last:  # this is the last letter
                    tag = 2
                else:  # this is somewhere in the middle
                    tag = 1
                label.append((character, tag))
    return label
    
def nGramSequenceGenerator(labelledlist, n, keep_remainder=False):
    """
    Cut a flat list into consecutive, non-overlapping sequences of size "n".

    INPUT:
        labelledlist   - the list of (character, label) tuples generated by
                         generateTupleList(); any sliceable sequence works.
        n              - length of every generated sequence.
        keep_remainder - when False (the default, matching the original
                         behaviour) the last len(labelledlist) % n items are
                         dropped because they do not fill a whole sequence.
                         When True they are returned as one final, shorter
                         sequence so that no input is lost.

    OUTPUT: a list of slices of `labelledlist`, each of length n (except for the
    optional shorter remainder).

    Raises ZeroDivisionError if n is 0.
    """
    # floor division keeps the count an integer on Python 3 as well as Python 2
    count = len(labelledlist) // n
    ngrammedlist = [labelledlist[i * n:(i + 1) * n] for i in range(count)]
    if keep_remainder and len(labelledlist) % n:
        ngrammedlist.append(labelledlist[count * n:])
    return ngrammedlist

# Prepare test data for prediction

In [3]:
# Write one "<character>\t0" record per character of the test input to ./input_word_file
# (newlines are stored as SPECIAL_SYMBOL records).
generateInputWordFile(filename='../../data/input')

In [4]:
# Read the records back as (character, label) tuples. Every record holds a single
# character, so every label is 3 here; the labels are not used for prediction.
tuples = generateTupleList(filename='./input_word_file')

In [5]:
# Display the labelled tuples for a quick visual check.
tuples

[(u'\u6cd5', 3),
 (u'\u6b63', 3),
 (u'\u7814', 3),
 (u'\u7a76', 3),
 (u'\u4ece', 3),
 (u'\u6ce2', 3),
 (u'\u9ed1', 3),
 (u'\u64a4', 3),
 (u'\u519b', 3),
 (u'\u8ba1', 3),
 (u'\u5212', 3),
 (u'\u02e0', 3),
 (u'\u65b0', 3),
 (u'\u534e', 3),
 (u'\u793e', 3),
 (u'\u5df4', 3),
 (u'\u9ece', 3),
 (u'\uff19', 3),
 (u'\u6708', 3),
 (u'\uff11', 3),
 (u'\u65e5', 3),
 (u'\u7535', 3),
 (u'\uff08', 3),
 (u'\u8bb0', 3),
 (u'\u8005', 3),
 (u'\u5f20', 3),
 (u'\u6709', 3),
 (u'\u6d69', 3),
 (u'\uff09', 3),
 (u'\u02e0', 3),
 (u'\u6cd5', 3),
 (u'\u56fd', 3),
 (u'\u56fd', 3),
 (u'\u9632', 3),
 (u'\u90e8', 3),
 (u'\u957f', 3),
 (u'\u83b1', 3),
 (u'\u5965', 3),
 (u'\u5854', 3),
 (u'\u5c14', 3),
 (u'\uff11', 3),
 (u'\u65e5', 3),
 (u'\u8bf4', 3),
 (u'\uff0c', 3),
 (u'\u6cd5', 3),
 (u'\u56fd', 3),
 (u'\u6b63', 3),
 (u'\u5728', 3),
 (u'\u7814', 3),
 (u'\u7a76', 3),
 (u'\u4ece', 3),
 (u'\u6ce2', 3),
 (u'\u9ed1', 3),
 (u'\u64a4', 3),
 (u'\u519b', 3),
 (u'\u7684', 3),
 (u'\u8ba1', 3),
 (u'\u5212', 3),
 (u'\u3002', 3

In [10]:
# Cut the tuple stream into consecutive sequences of 13 characters; a shorter leftover
# at the end of the input is dropped by nGramSequenceGenerator.
# NOTE(review): 13 presumably matches the sequence length the model was trained on - confirm.
final_input = nGramSequenceGenerator(tuples, n=13)

In [8]:
# Load the two vocabulary dictionaries; the cells below use orig_dict as
# character -> index and ret_dict as index -> character.
# NOTE(review): unpickling can execute arbitrary code - only load dictionary files
# that come from a trusted source.
# `with` closes each file handle as soon as its dictionary has been read.
with open("orig_dict.p", "rb") as orig_dict_file:
    orig_dict = pickle.load(orig_dict_file)
with open("ret_dict.p", "rb") as ret_dict_file:
    ret_dict = pickle.load(ret_dict_file)

In [11]:
# Translate every character of every sequence into its vocabulary index.
# A character without an entry in orig_dict is replaced by a randomly chosen
# known index, and `count` records how many such replacements were made.
x = []
count = 0
for sequence in final_input:  # iterate over the whole dataset
    encoded = []
    for item in sequence:  # iterate over the current sentence
        try:
            encoded.append(orig_dict[item[0]])
        except KeyError:
            encoded.append(np.random.choice(orig_dict.values()))
            count += 1
    x.append(encoded)
# The list ends with one extra empty sequence; the prediction cell skips it via x[:-1].
x.append([])

In [13]:
# Check whether the newline placeholder has a vocabulary entry. It does not (this raises
# KeyError), so the encoding loop treats it like any other unknown character.
orig_dict[SPECIAL_SYMBOL]

KeyError: u'\u02e0'

In [12]:
# Display the encoded index sequences for a quick visual check.
x

[[19, 128, 353, 322, 320, 1300, 1065, 1016, 953, 116, 145, 5297, 180],
 [537, 211, 623, 2016, 2825, 217, 2304, 203, 339, 99, 551, 222, 612],
 [3, 1742, 104, 3104, 19, 83, 83, 633, 188, 120, 2256, 1452, 2045],
 [1189, 2304, 203, 114, 0, 19, 83, 128, 10, 353, 322, 320, 1300],
 [1065, 1016, 953, 1, 116, 145, 2, 2496, 2256, 1452, 2045, 1189, 6],
 [10, 623, 2016, 181, 8, 164, 339, 718, 551, 222, 63, 125, 536],
 [451, 32, 114, 14, 1490, 451, 1, 2, 3159, 56, 299, 26, 0],
 [596, 83, 87, 4, 248, 397, 41, 1300, 1065, 1, 1394, 882, 775],
 [284, 0, 97, 1376, 1335, 167, 83, 392, 538, 246, 24, 52, 367],
 [1006, 169, 15, 20, 167, 137, 58, 1, 48, 207, 1062, 1081, 334],
 [269, 1499, 63, 169, 0, 45, 14, 136, 329, 43, 167, 450, 65],
 [137, 58, 773, 454, 1, 545, 137, 106, 197, 169, 2, 20, 2256],
 [1452, 2045, 1189, 114, 0, 167, 10, 394, 152, 83, 543, 49, 188],
 [791, 1, 208, 133, 96, 12, 89, 124, 308, 1, 173, 223, 130],
 [0, 5, 17, 12, 52, 255, 186, 248, 397, 41, 1300, 1065, 1],
 [1394, 882, 775, 284, 169

# Start the prediction process, and write data to output file

In [None]:
# Predict a position label for every character and write the segmented text.
#
# The characters that are written come from final_input (the original text), not from
# ret_dict[index]: characters missing from orig_dict - including SPECIAL_SYMBOL, the
# newline placeholder - were encoded as random indices, so mapping the indices back
# would emit the random substitute and would never restore a line break.
#
# zip() stops at the shorter list, which also skips the extra empty sequence at the
# end of x (the original code did this with x[:-1]).
with open('./output_rnn', 'wb') as out_file:
    for seq, labelled_chars in zip(x, final_input):
        # get the class label of every position in this sequence
        pred_labels = model.predict_classes(np.array(seq).reshape(1, len(seq)))[0]

        for (character, _), pred_label in zip(labelled_chars, pred_labels):
            if character == SPECIAL_SYMBOL:
                # the placeholder marks where the input had a line break
                out_file.write('\n'.encode('utf-8'))
            elif pred_label >= 2:
                # label 2 (last character of a word) or 3 (single-character word):
                # the word ends here, so follow it with a separating space
                out_file.write(character.encode('utf-8') + ' '.encode('utf-8'))
            else:
                out_file.write(character.encode('utf-8'))