# Before starting, carefully read the instructions in the README file to obtain and organize the data.

In [None]:
from src.utils import check_folders, get_buckeye_data, get_timit_data, get_ntimit_data, get_labels, pad_set, choose_dataset, retrieve_model, select_model_and_dataset, predict
from src.dataset import get_loader
import torch


# ---------------------------------------------------------------------------
# Corpus locations.
# Every corpus folder is paired with a .txt file that is passed to the
# matching loader below (presumably the list of items in each split --
# see the README to confirm).
# ---------------------------------------------------------------------------
buckeye_corpus_path = "../Word-Segmenter/data/Buckeye"
train_indices = "../Word-Segmenter/data/buckeye_train_data.txt"
val_indices = "../Word-Segmenter/data/buckeye_val_data.txt"
test_indices = "../Word-Segmenter/data/buckeye_test_data.txt"

timit_testset_path = "../Word-Segmenter/data/TIMIT/data/TEST"
timit_indices = "../Word-Segmenter/data/timit_test_data.txt"

ntimit_testset_path = "../Word-Segmenter/data/NTIMIT"
ntimit_indices = "../Word-Segmenter/data/ntimit_test_data.txt"

# Check the corpus folders. The returned flag gates every NTIMIT step below,
# so the rest of the cell still runs when that flag is falsy.
ntimit_exists = check_folders(buckeye_corpus_path, timit_testset_path, ntimit_testset_path)

# Framing parameters: 25 ms of samples at SR; the hop is set to the same
# value as the frame size.
SR = 16000
FRAME_SIZE = int(0.025 * SR)
HOP_LENGTH = int(0.025 * SR)
framing = (SR, FRAME_SIZE, HOP_LENGTH)

# ---------------------------------------------------------------------------
# 1) Waveforms and boundaries.
# ---------------------------------------------------------------------------
print('Extracting Buckeye data...')
buckeye_train_wavs, buckeye_train_bounds = get_buckeye_data(buckeye_corpus_path, train_indices, SR)
buckeye_val_wavs, buckeye_val_bounds = get_buckeye_data(buckeye_corpus_path, val_indices, SR)
buckeye_test_wavs, buckeye_test_bounds = get_buckeye_data(buckeye_corpus_path, test_indices, SR)

print('Extracting Timit data...')
timit_wavs, timit_bounds = get_timit_data(timit_testset_path, timit_indices, SR)

if ntimit_exists:
    print('Extracting NTimit data...')
    ntimit_wavs, ntimit_bounds = get_ntimit_data(ntimit_testset_path, ntimit_indices, SR)

# ---------------------------------------------------------------------------
# 2) Labels. Only the training split is labelled with get_labels' default
#    `type`; every other split is labelled with type='test'.
# ---------------------------------------------------------------------------
print('Extracting Buckeye labels...')
buckeye_train_labels = get_labels(buckeye_train_wavs, buckeye_train_bounds, *framing)
buckeye_val_labels = get_labels(buckeye_val_wavs, buckeye_val_bounds, *framing, type='test')
buckeye_test_labels = get_labels(buckeye_test_wavs, buckeye_test_bounds, *framing, type='test')

print('Extracting Timit labels...')
timit_labels = get_labels(timit_wavs, timit_bounds, *framing, type='test')

if ntimit_exists:
    print('Extracting NTimit labels...')
    ntimit_labels = get_labels(ntimit_wavs, ntimit_bounds, *framing, type='test')

# ---------------------------------------------------------------------------
# 3) Summary: number of wavs / labels / bounds for every split.
# ---------------------------------------------------------------------------
buckeye_train_split = (buckeye_train_wavs, buckeye_train_labels, buckeye_train_bounds)
buckeye_val_split = (buckeye_val_wavs, buckeye_val_labels, buckeye_val_bounds)
buckeye_test_split = (buckeye_test_wavs, buckeye_test_labels, buckeye_test_bounds)

summary_rows = [
    ('Buckeye Train samples:', buckeye_train_split),
    ('Buckey Val samples:', buckeye_val_split),
    ('Buckeye Test samples:', buckeye_test_split),
    ('Timit Test samples:', (timit_wavs, timit_labels, timit_bounds)),
]
if ntimit_exists:
    summary_rows.append(('NTimit Test samples:', (ntimit_wavs, ntimit_labels, ntimit_bounds)))

print('\n')
for heading, split in summary_rows:
    print(heading, *map(len, split))

# ---------------------------------------------------------------------------
# 4) Padding. All three Buckeye splits are handed to pad_set as the first
#    argument for every test set (presumably as the reference for the target
#    length -- confirm in src.utils.pad_set).
# ---------------------------------------------------------------------------
datasets = (*buckeye_train_split, *buckeye_val_split, *buckeye_test_split)

print('\n')
print('Padding Buckeye testset...')
buckeye_wavs, buckeye_labels, buckeye_bounds = pad_set(datasets, *buckeye_test_split)

print('Padding Timit testset...')
timit_wavs, timit_labels, timit_bounds = pad_set(datasets, timit_wavs, timit_labels, timit_bounds)

if ntimit_exists:
    print('Padding NTimit testset...')
    ntimit_wavs, ntimit_labels, ntimit_bounds = pad_set(datasets, ntimit_wavs, ntimit_labels, ntimit_bounds)

# ---------------------------------------------------------------------------
# 5) Registry of padded test sets, keyed by dataset name.
# ---------------------------------------------------------------------------
test_sets = {
    'buckeye': (buckeye_wavs, buckeye_labels, buckeye_bounds),
    'timit': (timit_wavs, timit_labels, timit_bounds),
}

if ntimit_exists:
    print('NTimit testset added to test_sets')
    test_sets['ntimit'] = (ntimit_wavs, ntimit_labels, ntimit_bounds)


# Follow the instructions displayed under the cell.

In [None]:
# Interactive evaluation loop.
# On every pass the user picks a model and a test set, the model is retrieved
# and run through `predict`, and the user is then asked whether to go again.

# Settings that do not change between passes are defined once, up front.
SR = 16000
HOP_LENGTH = int(0.025 * SR)
NUM_CLASSES = 3
BATCH_SIZE = 32
N_DEV = 0
verbose = False
freeze = True
# Use GPU number N_DEV when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f'cuda:{N_DEV}' if torch.cuda.is_available() else 'cpu')

continue_to_test = 'y'

while continue_to_test == 'y':

    model_name, dataset = select_model_and_dataset(ntimit_exists)

    wavs, labels, bounds, tolerance = choose_dataset(dataset, test_sets, SR)

    print('Wavs shape:', wavs.shape, "Labels shape:", labels.shape, "Bounds shape: ", bounds.shape, "\n")

    # Get the loader
    loader = get_loader(wavs,
                        labels,
                        bounds,
                        BATCH_SIZE,
                        type='test')

    # Peek at the first batch: axis 1 of the audio tensor is used as the
    # model's input length and axis 1 of the label tensor as the number of
    # frames the model must output. The third item of the batch is not needed.
    first_wavs, first_labels, _unused_bounds = next(iter(loader))
    time = first_wavs.shape[1]
    frames_out = first_labels.shape[1]

    print(f'Time points: {time}')
    print(f'Frames out: {frames_out}\n')

    # Get the model
    model = retrieve_model(model_name,
                           time,
                           frames_out,
                           NUM_CLASSES,
                           verbose,
                           freeze)

    # NOTE: this rebinds `bounds` (previously the bounds returned by
    # choose_dataset) to whatever `predict` returns. choose_dataset is called
    # again at the top of the next pass, so nothing stale is reused.
    bounds = predict(model,
                     loader,
                     device,
                     tolerance=tolerance,
                     hop_length=HOP_LENGTH,
                     frame_selection="mid",
                     desc='Test',
                     plot_bounds=False,
                     batch_plot_id=1)

    # Normalise the reply so that "Y", " y " or "yes" also continue instead of
    # silently ending the session; any other answer stops the loop.
    answer = input('\nDo you want to test another model? (y/n)')
    continue_to_test = 'y' if answer.strip().lower() in ('y', 'yes') else 'n'