In [None]:
import sys

# Add the parent directory to the module search path so the `src.*` imports
# used by this notebook can be resolved. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate '..' entries.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import pandas as pd
import time
import src.ngram as ngram
from src.training.preprocess import tokenize, pad_tokens

In [None]:
# Original train set: input files are looked up as
# '<data_train_folder><data_train_fname_prefix>_fold_<k>.txt'.
data_train_fname_prefix = 'train_mne_20k'
data_train_folder = '../data/trainset/named-entity/'

# Where the n-grams built from the original train set are written; the suffix
# is appended after '<n>_gram' in the output file name.
output_fname_suffix = '_mne_20k'
output_folder = '../models/ngrams/named-entity/'

In [None]:
# Augmented train set: input files are looked up as
# '<data_train_aug_folder><data_train_aug_fname_prefix>_fold_<k>.txt'.
data_train_aug_fname_prefix = 'aug_train_ne'
data_train_aug_folder = '../data/trainset/augmented/named-entity/'

# Where the n-grams built from the augmented train set are written; the suffix
# is appended after '<n>_gram' in the output file name.
aug_output_fname_suffix = '_ne_aug'
aug_output_folder = '../models/ngrams/named-entity/'

In [None]:
# Run parameters for the n-gram generation below.
train_set_type = 'original'  # Set to original / augmented
n_max = 5                    # n-gram order; also part of the output file name
k_min, k_max = 1, 1          # first and last fold to process (both inclusive)

# Build one n-gram per fold (k_min..k_max inclusive) from the selected train
# set and save each one to its own '.json' file via ngram.save.
start_t = time.time()

print('Max n: ', n_max)
print('Train set: ', train_set_type)

# Select the input location and output naming that match the chosen train set.
if train_set_type == 'original':
    ffolder = data_train_folder
    fname_prefix = data_train_fname_prefix
    output_fname_prefix = '{}{}_gram{}'.format(output_folder, n_max, output_fname_suffix)
elif train_set_type == 'augmented':
    ffolder = data_train_aug_folder
    fname_prefix = data_train_aug_fname_prefix
    output_fname_prefix = '{}{}_gram{}'.format(aug_output_folder, n_max, aug_output_fname_suffix)
else:
    # Fail fast. Without this branch an unknown value either raises a NameError
    # further down or, worse, silently reuses the paths left in the kernel by a
    # previous run and overwrites the wrong output files.
    raise ValueError(
        "train_set_type must be 'original' or 'augmented', got {!r}".format(train_set_type)
    )

for fold in range(k_min, k_max+1):
    print('\nFold {}/{}'.format(fold, k_max))

    # Load train set: tab-separated, no header row, two columns.
    fname = '{}{}_fold_{}.txt'.format(ffolder, fname_prefix, fold)
    print('Train set: ', fname)
    data_train = pd.read_csv(
        fname,
        sep='\t',
        header=None,
        names=['word', 'syllables'],
        # Keep every field as a literal string; with NA detection on, pandas
        # would turn entries such as "NA" or "null" into NaN.
        na_filter=False
    )
    print('Number of words: ', len(data_train))

    # Build the n-gram
    # NOTE(review): tokenize / pad_tokens are project helpers not shown here;
    # presumably they turn the frame into a padded token sequence -- confirm.
    print('Building n-gram')
    tokens = pad_tokens(tokenize(data_train), n=n_max, start_pad=True, end_marker=True)
    ngram_fold = ngram.NGram(tokens, n=n_max, build_cont_fdist=True, build_follow_fdist=True, verbose=True)

    # Save the n-gram to a file
    fname = output_fname_prefix + '_fold_{}.json'.format(fold)
    ngram.save(ngram_fold, fname)
    print('n-gram saved to "{}"'.format(fname))

print('\nAll n-grams generated in {:.2f} s'.format(time.time() - start_t))