In [None]:
import sys
# Put the parent directory on the module search path so the `src.*` imports used
# later in this notebook can be resolved.
# NOTE(review): '..' is relative to the kernel's working directory — this assumes
# the notebook is run from a direct subdirectory of the project root; confirm.
sys.path.append('..')

In [None]:
# Development helper: re-import project modules after their source files were
# edited, without restarting the kernel.
#
# `importlib` is imported here because this cell sits *before* the notebook's main
# import cell; relying on that later import made the cell fail with NameError on a
# fresh kernel ("Restart & Run All").
import importlib

for _module_name in (
    'src.testing.probability',
    'src.testing.tagger',
    'src.testing.syllabification',
    'src.utility',
    'src.ngram',
):
    # A module that has not been imported yet has nothing to reload; indexing
    # sys.modules directly would raise KeyError in that case, so skip it and let
    # the regular import cell load it for the first time.
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

In [None]:
import importlib
import time
import pandas as pd
import src.utility as util
import src.testing.probability as probability
import src.testing.tagger as tagger
import src.testing.syllabification as syllabification
import src.ngram as ngram

In [None]:
def _prepare_fold_model(prob_args, n, fold, n_gram_fname, suffix='', cache_preload=None):
    """Load one fold's n-gram model into `prob_args` and set up its caches.

    `suffix` is '' for the main model and '_aug' for the augmentation model; it is
    appended to the 'n_gram', 'cache' and 'd_cache' keys written into `prob_args`
    and to the 'cache' key looked up in `cache_preload`.
    """
    n_gram_key = 'n_gram' + suffix
    cache_key = 'cache' + suffix

    prob_args[n_gram_key] = ngram.load(n_gram_fname, n_max=n, load_follow_fdist=True, load_cont_fdist=True)

    if prob_args['with_cache']:
        if cache_preload is not None and cache_key in cache_preload:
            # Reuse a cache saved by an earlier run for this same fold.
            prob_args[cache_key] = probability.load_cache('{}_fold_{}.json'.format(cache_preload[cache_key], fold), '../data/cache/')
        else:
            prob_args[cache_key] = probability.generate_prob_cache(n)

    if prob_args['method'] == 'gkn':
        prob_args['d_cache' + suffix] = probability.generate_gkn_discount_cache(n, prob_args[n_gram_key], prob_args['d_ceil'])


def syllabify_folds(n, prob_args, fnames, state_elim=True, k_min=1, k_max=5, n_sample=None, sample_seed=0, cache_preload=None, log_fname='', fname_param='method', save_log=True, save_result=False, save_result_timestamp=True, save_cache=False, validation=True):
    """Run syllabification on folds `k_min`..`k_max` (inclusive) and collect per-fold metadata.

    For every fold the test set and the matching n-gram model are loaded, the
    probability caches are prepared, `syllabification.syllabify` is called, and
    results/caches are optionally written to disk.

    Parameters
    ----------
    n : passed as `n_max` to `ngram.load` and as `n` to the cache generators and
        to `syllabification.syllabify`.
    prob_args : dict with at least 'method' and 'with_cache' (plus 'd_ceil' when
        'method' is 'gkn'); 'with_aug' is optional and treated as False when
        absent. The dict is MUTATED: model and cache entries are written into it,
        and 'n_gram' / 'n_gram_aug' are reset to None after every fold.
    fnames : dict of path prefixes. 'data_test' and 'n_gram' are always read;
        'n_gram_aug' is read when augmentation is enabled and 'result_folder'
        when `save_result` is true. '_fold_<k>.txt' / '_fold_<k>.json' is appended.
    state_elim : only recorded in the returned metadata by this function.
    k_min, k_max : first and last fold number.
    n_sample, sample_seed : when `n_sample` is not None, each test set is
        down-sampled to `n_sample` rows using `sample_seed` as random state.
    cache_preload : optional dict with 'cache' and/or 'cache_aug' file prefixes
        to load caches from '../data/cache/' instead of generating them.
    log_fname : prefix for log, result and cache file names.
    fname_param : key of `prob_args` whose value is embedded in output file names.
    save_log, save_result, save_result_timestamp, save_cache : output switches.
    validation : when true the test file is read with 'word' and 'syllables'
        columns, otherwise with a single 'word' column; forwarded to `syllabify`.

    Returns
    -------
    (results, prob_args) : `results` has a 'metadata' dict (run settings, timing
        and 'average_ser') and a 'fold_results' dict keyed by fold number.
        'average_ser' is None when the folds did not all report a
        'syllable_error_rate' (or when no fold was run).
    """
    start_t = time.time()

    results = {
        'metadata': {
            'n': n,
            'k_min': k_min,
            'k_max': k_max,
            'state_elim': state_elim,
            'n_sample': n_sample,
            'sample_seed': sample_seed,
            # Shallow snapshot taken before models/caches are added to prob_args.
            'prob_args': prob_args.copy()
        },
        'fold_results': {}
    }

    util.print_dict(results['metadata'])

    # Evaluate the optional flag once so that every use agrees; a missing key
    # means "no augmentation" instead of raising KeyError half-way through a fold.
    with_aug = bool(prob_args.get('with_aug', False))

    for fold in range(k_min, k_max+1):
        data_test_fname = '{}_fold_{}.txt'.format(fnames['data_test'], fold)
        n_gram_fname = '{}_fold_{}.json'.format(fnames['n_gram'], fold)

        print('Fold        : {}'.format(fold))
        print('Data test   : "{}"'.format(data_test_fname))
        print('n-gram      : "{}"'.format(n_gram_fname))

        if with_aug:
            n_gram_aug_fname = '{}_fold_{}.json'.format(fnames['n_gram_aug'], fold)
            print('n-gram aug  : "{}"'.format(n_gram_aug_fname))

        data_test = pd.read_csv(
            data_test_fname,
            sep='\t',
            header=None,
            names=['word', 'syllables'] if validation else ['word'],
            # Keep cells as plain strings instead of converting NA-like text to NaN.
            na_filter=False
        )

        print('Total words : {}'.format(len(data_test)))

        if n_sample is not None:
            data_test = data_test.sample(n=n_sample, random_state=sample_seed).reset_index(drop=True)

        _prepare_fold_model(prob_args, n, fold, n_gram_fname, suffix='', cache_preload=cache_preload)

        if with_aug:
            _prepare_fold_model(prob_args, n, fold, n_gram_aug_fname, suffix='_aug', cache_preload=cache_preload)

        result = syllabification.syllabify(data_test, n, prob_args, validation=validation)
        results['fold_results'][fold] = result['metadata']

        if save_result:
            syllabification.save_result(result['data'], '{}_{}={}_fold_{}.txt'.format(log_fname, fname_param, prob_args[fname_param], fold), folder=fnames['result_folder'], with_timestamp=save_result_timestamp)

        if save_cache:
            if 'cache' in prob_args:
                probability.save_cache(prob_args['cache'], 'cache_prob_{}_{}={}_fold_{}.json'.format(log_fname, fname_param, prob_args[fname_param], fold), '../data/cache/')
            if 'cache_aug' in prob_args:
                probability.save_cache(prob_args['cache_aug'], 'cache_aug_prob_{}_{}={}_fold_{}.json'.format(log_fname, fname_param, prob_args[fname_param], fold), '../data/cache/')

        # Clear n_gram from memory
        prob_args['n_gram'] = None
        prob_args['n_gram_aug'] = None

        print('\n')

    end_t = time.time()

    # Average the syllable error rate only when every fold reported one. A fold
    # without the value (presumably possible when validation is off — confirm
    # against syllabification.syllabify) or an empty fold range yields None
    # instead of a KeyError / ZeroDivisionError after all the work is done.
    fold_metadata = list(results['fold_results'].values())
    ser_values = [
        meta['syllable_error_rate']
        for meta in fold_metadata
        if meta.get('syllable_error_rate') is not None
    ]
    if ser_values and len(ser_values) == len(fold_metadata):
        results['metadata']['average_ser'] = round(sum(ser_values) / len(ser_values), 5)
    else:
        results['metadata']['average_ser'] = None

    results['metadata']['start_time'] = time.strftime('%Y/%m/%d - %H:%M:%S', time.localtime(start_t))
    results['metadata']['end_time'] = time.strftime('%Y/%m/%d - %H:%M:%S', time.localtime(end_t))
    results['metadata']['duration'] = round(end_t - start_t, 2)

    if save_log:
        # The separator is added only for a non-empty prefix. Previously the
        # format string contained a second '_' which produced names such as
        # 'run__method=gkn.log' (or a leading '_' for an empty prefix).
        log_prefix = log_fname + '_' if log_fname != '' else ''
        util.save_dict_to_log(results, '{}{}={}.log'.format(log_prefix, fname_param, prob_args[fname_param]), '../logs/')

    print('Finished in {:.2f} s'.format(end_t - start_t))

    return results, prob_args

In [None]:
# Single-fold GKN run (n=4, d_ceil=3) on the named-entity test set, without
# validation. Input/output path prefixes are the same for every `aug_w` value,
# so they are defined once, ahead of the sweep.
fnames = dict(
    data_test='../data/testset/named-entity/test_mne_20k',
    n_gram='../models/ngrams/named-entity/5_gram_mne_20k',
    n_gram_aug='../models/ngrams/named-entity/5_gram_ne_aug',
    result_folder='../data/results/named-entity/m/',
)

for aug_w in (None,):
    # Fresh argument dict per run: syllabify_folds writes models/caches into it.
    prob_args = dict(
        method='gkn',
        d_ceil=3,
        with_cache=True,
        with_aug=False,
        aug_w=aug_w,
        n=4,
    )

    # File naming and output switches for this run.
    run_options = dict(
        log_fname='gkn_n=4',
        fname_param='d_ceil',
        save_log=True,
        save_result=True,
        save_result_timestamp=False,
        save_cache=False,
        validation=False,
    )

    result, post_prob_args = syllabify_folds(4, prob_args, fnames, k_min=1, k_max=1, **run_options)

    # Blank line between runs in the cell output.
    print()