# <span style='color:orange'> Modelling Chinese Word Segmentation as a Sequence-to-Sequence Prediction Problem </span>

# <span style='color:Green'> Notebook for Testing Pre-Trained Models</span>

# Load the model

In [None]:
from keras.models import load_model
import cPickle as pickle
import numpy as np
# Trained Keras model checkpoint to evaluate (the file name suggests it was
# saved after epoch 3).
MODEL_PATH = '/home/asa224/Desktop/students_less_asa224/Test Folder on Less/model_epoch3.h5'
model = load_model(MODEL_PATH)

# Helper functions

In [None]:
def generateInputWordFile(filename='/local-scratch/asa224/input',
                          output_filename='/local-scratch/input_word_file'):
    """
    Write a word file in which every character of the input text is its own word.

    The output mimics the ``word<TAB>count`` layout of the count_1w.txt file
    provided by Prof. Anoop: one line per character, each followed by a tab and
    a dummy count of 0. It is the input to generateTupleList(). Because every
    entry is a single character, generateTupleList() labels each one as a
    single-character word. Newline characters are dropped; empty input lines
    produce no output.

    INPUT:
        filename        -- UTF-8 encoded text file (e.g. the training/test text).
        output_filename -- path of the word file to write.
    """
    import io

    # Text-mode I/O with an explicit encoding works on both Python 2 and 3.
    # newline='' disables newline translation on read, so (as in the original
    # byte-level version) only '\n' is removed and any '\r' is kept as a
    # character; newline='\n' on write forces Unix line endings.
    with io.open(filename, 'r', encoding='utf-8', newline='') as src:
        with io.open(output_filename, 'w', encoding='utf-8', newline='\n') as dst:
            for line in src:
                for character in line.replace(u'\n', u''):
                    dst.write(character + u'\t0\n')
    
    
def generateTupleList(filename='/local-scratch/asa224/word_file_1M'):
    """
    Label every character of a word file with its position inside its word.

    Each line of the file must look like ``word<TAB>count`` encoded as UTF-8
    (the format written by generateInputWordFile()), and count must be 0.
    Returns a flat list of ``(character, label)`` tuples in file order, where
    label is:

        0 -- first character of a multi-character word
        1 -- character in the middle of a multi-character word
        2 -- last character of a multi-character word
        3 -- single-character word

    Use this function in conjunction with nGramSequenceGenerator(labelledlist, n)
    to create a training set with constant sequence size, which does not require
    padding.

    Raises ValueError if a line has no tab, a non-integer count, or a count
    other than 0.
    """
    label = []
    with open(filename, 'rb') as f:
        for raw_line in f:
            # Split on the LAST tab only: the "word" itself may be a tab
            # character (generateInputWordFile emits one per input tab).
            word, count = raw_line.decode('utf-8').rsplit(u'\t', 1)

            # Make sure the parsing is going fine. An explicit raise (rather
            # than assert) keeps the check active under `python -O`.
            if int(count) != 0:
                raise ValueError('unexpected count %r in line %r' % (count, raw_line))

            if len(word) == 1:
                label.append((word, 3))
            else:
                last = len(word) - 1
                for i, character in enumerate(word):
                    if i == 0:  # first letter
                        label.append((character, 0))
                    elif i == last:  # last letter
                        label.append((character, 2))
                    else:  # somewhere in the middle
                        label.append((character, 1))
    return label
    
def nGramSequenceGenerator(labelledlist, n):
    """
    Split a list into consecutive, non-overlapping chunks of exactly n items.

    Takes the list of (character, label) tuples produced by generateTupleList()
    (any sliceable sequence works) and returns a list of slices of size n.
    Trailing items that do not fill a complete chunk are dropped.

    Raises ValueError if n is not positive.
    """
    if n <= 0:
        raise ValueError('n must be a positive integer, got %r' % (n,))
    # Floor division keeps the chunk count an integer on Python 3 as well.
    count = len(labelledlist) // n
    return [labelledlist[i * n:(i + 1) * n] for i in range(count)]

# Prepare test data for prediction

In [None]:
# Split the raw test text into one "word" per character. The word file is
# written to '/local-scratch/input_word_file', which the next cell reads.
RAW_TEST_INPUT = '/home/asa224/Desktop/students_asa224/NLP Work/assignments/segmenter/data/input'
generateInputWordFile(filename=RAW_TEST_INPUT)

In [None]:
# Parse the per-character word file written by the previous cell into
# (character, label) tuples.
WORD_FILE = '/local-scratch/input_word_file'
tuples = generateTupleList(WORD_FILE)

In [None]:
# Vocabulary lookups used below: orig_dict maps a character to its integer id,
# and ret_dict maps an id back to its character (inferred from how the
# following cells use them -- presumably saved at training time).
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open("orig_dict.p", "rb") as vocab_file:
    orig_dict = pickle.load(vocab_file)
with open("ret_dict.p", "rb") as vocab_file:
    ret_dict = pickle.load(vocab_file)

In [None]:
# Convert every character of every sequence to its vocabulary id.
# NOTE(review): final_input is never defined in this notebook. It is used as a
# list of sequences of (character, label) tuples -- presumably
# nGramSequenceGenerator(tuples, n) or similar; TODO confirm it matches how the
# model was trained.
# Characters missing from the vocabulary are replaced by a random known id;
# `count` records how many substitutions were made.
# x ends with an empty placeholder list, which the prediction cell skips via x[:-1].
known_ids = list(orig_dict.values())  # list() so np.random.choice also works on Python 3
x = []
count = 0
for sentence in final_input:  # iterate over the whole dataset
    ids = []
    for item in sentence:  # iterate over the current sentence
        try:
            ids.append(orig_dict[item[0]])
        except KeyError:
            ids.append(np.random.choice(known_ids))
            count += 1
    x.append(ids)
x.append([])
print('%d out-of-vocabulary characters replaced by random ids' % count)

# Run the prediction and write the segmented text to the output file

In [None]:
# Predict a position label for every character and write the segmented text.
# Labels 2 (last character of a word) and 3 (single-character word) end a
# word, so a space is written after those characters.
# The ORIGINAL characters from final_input are written instead of decoding the
# ids in x with ret_dict. Out-of-vocabulary characters were replaced by random
# ids when x was built, so ret_dict[id] would emit the wrong character for them.
# Assumes final_input holds unicode characters (as generateTupleList produces).
# NOTE(review): predict_classes is deprecated/removed in newer Keras releases;
# np.argmax(model.predict(...), axis=-1) is the replacement there.
with open('./output_rnn', 'wb') as out_file:
    # x ends with an empty placeholder list, hence x[:-1]; x[i] holds one id
    # per character of final_input[i].
    for seq, sentence in zip(x[:-1], final_input):
        pred_labels = model.predict_classes(np.array(seq).reshape(1, len(seq)))
        pieces = []
        for num, tag in enumerate(pred_labels[0]):
            character = sentence[num][0]
            pieces.append(character + u' ' if tag >= 2 else character)
        # Sequences are concatenated without a newline separator, matching the
        # original output format.
        out_file.write(u''.join(pieces).encode('utf-8'))