# Train OCR text Detector quick example

In [1]:
import os

# Select the GPU before any CUDA-using library is imported or initialised.
# PCI_BUS_ID makes CUDA enumerate devices in PCI bus order (see issue #152),
# and CUDA_VISIBLE_DEVICES then exposes only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import sys
import torch
import warnings
from datetime import datetime
from matplotlib import pyplot as plt
from torchvision.models import resnet18
from torchvision.models import efficientnet_b0
from torchvision.models import efficientnet_b1
from torchvision.models import efficientnet_b2
from torchvision.models import regnet_y_1_6gf
from torchvision.models import regnet_x_1_6gf
from torchvision.models import regnet_y_800mf
from torchvision.models import mnasnet1_3
from torchvision.models import mobilenet_v3_large
from torchvision.models import shufflenet_v2_x2_0
from torchvision.models import efficientnet_v2_s

# Suppress every Python warning for the rest of the session to keep cell output short.
# Note that this also hides deprecation warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# Root of the nomeroff-net checkout, resolved relative to the current working
# directory. Change this value if the notebook is run from another location.
NOMEROFF_NET_DIR = os.path.abspath('../../../../')
# Guard the append so that re-running this cell does not pile up duplicate entries.
if NOMEROFF_NET_DIR not in sys.path:
    sys.path.append(NOMEROFF_NET_DIR)

from nomeroff_net.text_detectors.base.ocr import OCR

In [None]:
from torch import nn
from nomeroff_net.nnmodels.torch_backbone_shape_detector import get_output_shape

# Input geometry used by the shape probe further down: channels, height, width.
color_channels, height, width = 3, 50, 200

# `pretrained=True` is the deprecated torchvision interface; request the same
# ImageNet weights explicitly through the `weights` argument instead.
# NOTE(review): torchvision's EfficientNet appears to expose only three top-level
# children (features, avgpool, classifier), in which case `[:-3]` yields an empty
# list and `conv_nn` is an identity module — confirm the intended slice.
conv_modules = list(efficientnet_v2_s(weights="IMAGENET1K_V1").children())[:-3]
conv_nn = nn.Sequential(*conv_modules)
conv_nn

In [None]:
# NOTE(review): scratch arithmetic; what 488 and 16 stand for is not evident from
# this notebook — document the intent or remove the cell.
488/16

In [None]:
# Compare the shape reported by get_output_shape with the shape produced by an
# actual forward pass of a random single-image batch through conv_nn.
input_shape = (color_channels, height, width)
backbone_c, backbone_h, backbone_w = get_output_shape(input_shape, conv_nn)
print("suppose shape", backbone_c, backbone_h, backbone_w)

dummy_batch = torch.rand((1, *input_shape))
print("real shape", conv_nn(dummy_batch).shape)

In [None]:
# Default size (width, height in inches) for every matplotlib figure created below.
plt.rc("figure", figsize=(10, 10))

In [None]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): on some IPython versions activating the inline backend resets
# rcParams such as figure.figsize, so this magic may need to run before the
# cell that sets the figure size — confirm.
%matplotlib inline 

In [None]:
# Automatically download the latest "Kz" dataset through modelhub.
from nomeroff_net.tools import modelhub

# The returned info dict provides the location of the dataset under "dataset_path".
info = modelhub.download_dataset_for_model("Kz")
PATH_TO_DATASET = info["dataset_path"]

# Alternative: use a local dataset path instead of the download above.
#PATH_TO_DATASET = os.path.join(NOMEROFF_NET_DIR, "./data/dataset/TextDetector/ocr_example")

In [None]:
# Display the resolved dataset path as a quick visual check.
PATH_TO_DATASET

In [None]:
# Name of the dataset this model is trained on; becomes part of the checkpoint name.
DATASET_NAME = "kz"
# Version tag: today's date formatted as YYYY_MM_DD plus a framework suffix.
VERSION = f"{datetime.now().strftime('%Y_%m_%d')}_pytorch_lightning"

# Full path of the checkpoint file inside the repository's models directory.
RESULT_MODEL_PATH = os.path.join(
    NOMEROFF_NET_DIR,
    "models/",
    f"anpr_ocr_{DATASET_NAME}_{VERSION}.ckpt",
)

In [None]:
# Display the checkpoint path built from the dataset name and version tag.
RESULT_MODEL_PATH

In [None]:
# Show the string form of the shufflenet_v2_x2_0 factory function.
# NOTE(review): looks like a leftover debugging probe — confirm whether this
# cell is still needed.
str(shufflenet_v2_x2_0)

In [None]:
class kz(OCR):
    """OCR configuration for the "kz" dataset.

    Sets the alphabet, input image size, layer sizes, backbone and training
    hyperparameters on top of the defaults provided by ``OCR``.
    """

    def __init__(self):
        super().__init__()

        # Alphabet: only needed when using an already trained model;
        # during training it is generated automatically.
        self.letters = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ")

        # Model parameters
        self.height, self.width = 50, 200
        self.hidden_size = 32
        self.linear_size = 512
        self.backbone = efficientnet_v2_s

        # Training hyperparameters
        self.batch_size = 32
        self.epochs = 100
        self.gpus = 1  # alternative: torch.cuda.device_count()

In [14]:
# Create the detector and prepare it on the dataset with augmentation disabled;
# the value returned by prepare() is kept as `model`.
ocrTextDetector = kz()
model = ocrTextDetector.prepare(PATH_TO_DATASET, use_aug=False, num_workers=1)

GET ALPHABET
Max plate length in "val": 8
Max plate length in "train": 8
Max plate length in "test": 8
Letters train  {'Z', '8', '9', 'G', 'A', '7', '1', 'L', 'B', 'J', '2', '0', 'R', 'M', 'P', 'V', 'C', 'Q', 'T', '6', 'E', '5', 'N', 'X', 'W', 'I', 'F', '4', '3', 'Y', 'U', 'D', 'S', 'O', 'H', 'K'}
Letters val  {'8', 'Z', '9', 'G', 'A', '7', '1', 'L', 'B', 'J', '2', '0', 'R', 'M', 'P', 'V', 'C', 'Q', 'T', '6', '5', 'E', 'N', 'X', 'W', 'I', 'F', '4', '3', 'Y', 'U', 'D', 'O', 'S', 'H', 'K'}
Letters test  {'8', 'Z', '9', 'G', 'A', '7', '1', 'L', 'B', 'J', '2', '0', 'R', 'M', 'P', 'V', 'C', 'Q', 'T', '6', '5', 'E', 'N', 'W', 'X', 'I', 'F', '4', '3', 'Y', 'U', 'D', 'O', 'S', 'H', 'K'}
Max plate length in train, test and val do match
Letters in train, val and test do match
Letters: 0 1 2 3 4 5 6 7 8 9 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z
START BUILD DATA


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8642/8642 [00:00<00:00, 32694.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1001/1001 [00:00<00:00, 33757.86it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 279/279 [00:00<00:00, 33384.99it/s]

DATA PREPARED





In [15]:
#ocrTextDetector.load(RESULT_MODEL_PATH)

In [16]:
# # tune
# lr_finder = ocrTextDetector.tune()

# # Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()

In [17]:
# Start training; the epoch count and other settings presumably come from the
# attributes assigned in kz.__init__ (the run log reports max_epochs=100).
ocrTextDetector.train()

  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type       | Params
------------------------------------------------
0 | conv_nn          | Sequential | 1.2 M 
1 | linear1          | Linear     | 406 K 
2 | recurrent_layer1 | BlockRNN   | 28.9 K
3 | recurrent_layer2 | BlockRNN   | 4.4 K 
4 | linear2          | Linear     | 1.2 K 
------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.415     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


[INFO] best model path /mnt/storage2/var/www/data/logs/ocr/epoch=63-step=17344.ckpt


NPOcrNet(
  (conv_nn): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (2): Sequential(
      (0): InvertedResidual(
        (branch1): Sequential(
          (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
          (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): Conv2d(24, 122, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(122, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
        )
        (branch2): Sequential(
          (0): Conv2d(24, 122, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(122, eps=1e

In [None]:
#ocrTextDetector.save(RESULT_MODEL_PATH)

In [None]:
#ocrTextDetector.load(RESULT_MODEL_PATH)

In [18]:
# Display the module structure of the detector's underlying network.
ocrTextDetector.model

NPOcrNet(
  (conv_nn): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (2): Sequential(
      (0): InvertedResidual(
        (branch1): Sequential(
          (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
          (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): Conv2d(24, 122, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(122, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU(inplace=True)
        )
        (branch2): Sequential(
          (0): Conv2d(24, 122, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(122, eps=1e

In [19]:
# Total number of scalar values across all parameter tensors of the model.
parameter_sizes = [tensor.numel() for tensor in ocrTextDetector.model.parameters()]
pytorch_total_params = sum(parameter_sizes)
pytorch_total_params

1603857

In [20]:
import torch

# Estimate the model size from its parameters: element count times the bit
# width of each parameter's dtype. Only tensors from parameters() are counted.
model = ocrTextDetector.model

size_model = 0
for tensor in (p.data for p in model.parameters()):
    # finfo describes floating-point dtypes, iinfo describes integer dtypes.
    type_info = torch.finfo if tensor.is_floating_point() else torch.iinfo
    size_model += tensor.numel() * type_info(tensor.dtype).bits
print(f"model size: {size_model} / bit | {size_model / 8e6:.2f} / MB")

model size: 51323424 / bit | 6.42 / MB


In [21]:
# Report accuracy on the test split; with verbose=True the run also prints
# predicted/true text pairs (see the output below).
ocrTextDetector.test_acc(verbose=True)
# Uncomment to run the same check for the validation or training data instead.
#ocrTextDetector.val_acc(verbose=False)
#ocrTextDetector.train_acc(verbose=False)


[INFO] /mnt/storage2/var/www/nomeroff-net/nomeroff_net/tools/../../data/./dataset/TextDetector/Kz/autoriaNumberplateOcrKz-2019-04-26/test/img/90549504-3-full.jpg-0.png
Predicted: 94gbb02 			 True: 994gbb02

[INFO] /mnt/storage2/var/www/nomeroff-net/nomeroff_net/tools/../../data/./dataset/TextDetector/Kz/autoriaNumberplateOcrKz-2019-04-26/test/img/12583383.jpg-0.png
Predicted: 243fwa02 			 True: 243mwa02

[INFO] /mnt/storage2/var/www/nomeroff-net/nomeroff_net/tools/../../data/./dataset/TextDetector/Kz/autoriaNumberplateOcrKz-2019-04-26/test/img/12515244.jpg-0.png
Predicted: 749waa10 			 True: 749haa10

[INFO] /mnt/storage2/var/www/nomeroff-net/nomeroff_net/tools/../../data/./dataset/TextDetector/Kz/autoriaNumberplateOcrKz-2019-04-26/test/img/90889734-14-full.jpg-0.png
Predicted: 462ga02 			 True: 462iga02

[INFO] /mnt/storage2/var/www/nomeroff-net/nomeroff_net/tools/../../data/./dataset/TextDetector/Kz/autoriaNumberplateOcrKz-2019-04-26/test/img/90587962-12-full.jpg-0.png
Predicted: 75

0.9032258064516129

## Then train with augmentation

In [21]:
# for i in range(0,1):
#     # Train the next 2 epochs on the augmented dataset
#     ocrTextDetector.epochs += 2

#     # prepare with augmentation
#     ocrTextDetector.prepare(PATH_TO_DATASET, use_aug=True, num_workers=1, seed=i)

#     # Plot with
#     #fig = lr_finder.plot(suggest=True)
#     #fig.show()
#     model = ocrTextDetector.train(seed=i, ckpt_path=RESULT_MODEL_PATH)
#     ocrTextDetector.test_acc(verbose=False)
#     ocrTextDetector.save(RESULT_MODEL_PATH)