# Train OCR text detector: quick example

In [None]:
import os
import sys
import torch
import warnings
from datetime import datetime
from matplotlib import pyplot as plt
# Silence all Python warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Directory four levels above the current working directory, made absolute.
# It is expected to contain the `nomeroff_net` package; change this value
# if the notebook is run from a different location.
NOMEROFF_NET_DIR = os.path.abspath('../../../../')
# Put that directory on the import path so the `nomeroff_net` import below resolves.
sys.path.append(NOMEROFF_NET_DIR)

from nomeroff_net.pipes.number_plate_text_readers.base.ocr import OCR

In [None]:
# Default size (width, height in inches) for matplotlib figures created in this notebook.
plt.rcParams["figure.figsize"] = (10, 10)

In [None]:
%matplotlib inline 

In [None]:
# `modelhub` is only needed for the automatic-download option below.
from nomeroff_net.tools import modelhub

# Option 1 (disabled): download the latest dataset through modelhub and use
# the path it reports.
#info = modelhub.download_dataset_for_model("Eu")
#PATH_TO_DATASET = info["dataset_path"]

# Option 2 (active): use the example dataset at a fixed location under
# NOMEROFF_NET_DIR.
PATH_TO_DATASET = os.path.join(NOMEROFF_NET_DIR, "./data/dataset/TextDetector/ocr_example")

In [None]:
# Show the resolved dataset path as the cell output.
PATH_TO_DATASET

In [None]:
# Short dataset tag embedded in the checkpoint file name.
DATASET_NAME = "eu"
# Version tag: the current local date as YYYY_MM_DD plus a fixed
# "_pytorch_lightning" suffix. Because it is evaluated when this cell runs,
# re-running on a different day produces a different file name.
VERSION = f"{datetime.now().strftime('%Y_%m_%d')}_pytorch_lightning"

# Target file for the save/load cells below:
# <NOMEROFF_NET_DIR>/models/anpr_ocr_<DATASET_NAME>_<VERSION>.ckpt
RESULT_MODEL_PATH = os.path.join(NOMEROFF_NET_DIR, 
                                 "models/", 
                                 'anpr_ocr_{}_{}.ckpt'.format(DATASET_NAME, VERSION))

In [None]:
# Show the resolved checkpoint path as the cell output.
RESULT_MODEL_PATH

In [None]:
class eu(OCR):
    """OCR text detector configured for the "eu" example dataset.

    Builds on the project's `OCR` base class and sets a fixed alphabet
    (digits 0-9 followed by uppercase A-Z) together with the training
    hyperparameters used by this notebook.
    """

    def __init__(self):
        super().__init__()

        # The alphabet is only needed when an already-trained model is used;
        # during training it is generated automatically.
        self.letters = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ")

        # Training hyperparameters.
        self.batch_size = 4
        self.epochs = 5
        # Number of CUDA devices torch can see (0 when none is available).
        self.gpus = torch.cuda.device_count()

In [None]:
# Create the detector and prepare it on the dataset with augmentation turned
# off; whatever prepare() returns is kept in `model`.
ocrTextDetector = eu()
model = ocrTextDetector.prepare(PATH_TO_DATASET, use_aug=False, num_workers=1)


In [None]:
# # tune
# lr_finder = ocrTextDetector.tune()
#
# # Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()

In [None]:
# First training pass, on the non-augmented data prepared above.
ocrTextDetector.train()

In [None]:
# Save to RESULT_MODEL_PATH with weights_only=False (presumably a full
# checkpoint); the augmentation loop further down passes this same path as
# ckpt_path when it continues training.
ocrTextDetector.save(RESULT_MODEL_PATH, weights_only=False)

In [None]:
# Load the file that was just saved back into the detector.
ocrTextDetector.load(RESULT_MODEL_PATH)

In [None]:
# Run test_acc with verbose output (presumably accuracy on the test split).
# The validation and training variants are left disabled.
ocrTextDetector.test_acc(verbose=True)
#ocrTextDetector.val_acc(verbose=False)
#ocrTextDetector.train_acc(verbose=False)


## Then train with augmentation

In [None]:
# One round of continued training on augmented data; widen the range to run
# more rounds.
for i in range(1):
    # Raise the epoch budget by 2 for this round.
    ocrTextDetector.epochs += 2

    # Re-prepare the dataset with augmentation enabled, seeded by the round index.
    ocrTextDetector.prepare(PATH_TO_DATASET, use_aug=True, num_workers=1, seed=i)

    # Train with the same seed, passing the previously saved file as ckpt_path.
    model = ocrTextDetector.train(seed=i, ckpt_path=RESULT_MODEL_PATH)

    # Evaluate quietly, then save to the same path with weights_only=False.
    ocrTextDetector.test_acc(verbose=False)
    ocrTextDetector.save(RESULT_MODEL_PATH, weights_only=False)

In [None]:
# Final export: save again, this time with weights_only=True.
# NOTE(review): this uses the same RESULT_MODEL_PATH as the earlier
# weights_only=False saves, so it presumably replaces that file — confirm,
# and use a different path if both files should be kept.
ocrTextDetector.save(RESULT_MODEL_PATH, weights_only=True)