# DeepTurkish Testing

This notebook loads a trained model and evaluates it directly on the test split.

In [None]:
import os

import torch

import utilities.utilities as utils
from model.neural_network import make_model
from model.data_loader import make_loaders
from decoders import decoders
from evaluation import test

# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

## 1) Load the State Dicts

In [None]:
# Identify the trained run to evaluate. The checkpoint is read from
# data/models and losses/<project_name>/<model_name>/<model_name>.pt
project_name = "try"
model_name = '13_18__02_02_2021'

iter_meter = utils.IterMeter(project_name, model_name)

model_path = os.path.join(
    "data", "models and losses", project_name, model_name, "{}.pt".format(model_name)
)

# map_location remaps any tensors that were saved on a GPU onto the device
# selected above, so a checkpoint trained with CUDA can still be loaded on a
# CPU-only machine (without it, torch.load raises when CUDA is unavailable).
model_state = torch.load(model_path, map_location=device)

# The checkpoint bundles the configuration that was used at training time.
hyperparameters = model_state['hyperparameters']
data_parameters = model_state['data_parameters']

# Stored so that model creation can locate the checkpoint weights
# (NOTE(review): presumably read by make_model to reload the state dict — confirm).
hyperparameters['model_dir'] = model_path

## 2) Create the Model, Criterion, and Test Loader

In [None]:
# Build the data loaders. Only the third one (the test loader) is kept;
# the first two (presumably train/validation) are discarded.
_, _, test_loader = make_loaders(data_parameters, sortagrad=False)

# make_model returns four objects. Only the model and the loss criterion
# are needed for evaluation; the last two are ignored here.
blank_index = data_parameters['blank']
test_loader_len = len(test_loader)
model, criterion, _, _ = make_model(
    hyperparameters,
    blank_index,
    test_loader_len,
    device,
)

## 3) Choose a Decoder

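Choose a decoder for decoding the CTC output matrix. Run only one of the following three cells: each one assigns `decoder`, so the last cell executed determines which decoder is used in the test step.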

In [None]:
# Argmax decoder: presumably picks the most likely symbol per time step and
# collapses the CTC output using the blank index.
alphabet = data_parameters['alphabet']
blank = data_parameters['blank']
decoder = decoders.Argmax_decoder(alphabet, blank)

In [None]:
# Beam-search decoder. LM_text_name presumably names the text corpus the
# language model is built from — confirm in decoders.BeamSearch_decoder.
LM_text_name = "NN_datasets_sentences"
beam_width = 3
prune_threshold = -7  # approximately ln(0.001) ≈ -6.91; presumably a log-probability cutoff

decoder = decoders.BeamSearch_decoder(
    data_parameters['alphabet'],
    data_parameters['blank'],
    beam_width,
    prune_threshold,
    LM_text_name,
)

In [None]:
# Lexicon-search decoder.
# NOTE(review): tolerance is presumably the maximum edit distance allowed
# when matching against lexicon words — confirm in decoders.LexiconSearch_decoder.
tolerance = 1

# Settings for the approximator used by the Lexicon Search algorithm
# (a 'BeamSearch+LM' decoder, configured like the beam-search cell above).
BW = 2       # beam width of the approximator
prune = -7   # approximately ln(0.001) ≈ -6.91
LM_text_name = "NN_datasets_sentences"

approximator_properties = (
    'BeamSearch+LM',
    data_parameters['blank'],
    BW,
    prune,
    LM_text_name,
)

decoder = decoders.LexiconSearch_decoder(
    data_parameters['alphabet'],
    tolerance,
    LM_text_name,
    approximator_properties,
)

## 4) Test

In [None]:
# Evaluate on the test loader with whichever decoder cell was run last.
# Returns the average loss, character error rate and word error rate.
test_results = test(model, criterion, decoder, test_loader)
avg_test_loss, avg_cer, avg_wer = test_results