In [1]:
%load_ext autoreload
%autoreload 2

import argparse

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch

from importlib.util import find_spec
if find_spec("text_recognizer") is None:
    import sys
    sys.path.append('..')

from text_recognizer.data.emnist_lines import EMNISTLines

In [9]:
args = argparse.Namespace(max_length=16, max_overlap=0)
dataset = EMNISTLines(args)
dataset.prepare_data()
dataset.setup()
print(dataset)
print('Mapping:', dataset.mapping)

EMNISTLinesDataset loading data from HDF5...
EMNIST Lines Dataset
Min overlap: 0
Max overlap: 0
Num classes: 83
Dims: (1, 28, 448)
Output dims: (16, 1)
Train/val/test sizes: 10000, 2000, 2000
Batch x stats: (torch.Size([128, 1, 28, 448]), torch.float32, tensor(0.), tensor(0.0777), tensor(0.2377), tensor(1.))
Batch y stats: (torch.Size([128, 16]), torch.int64, tensor(1), tensor(66))

Mapping: ['<B>', '<S>', '<E>', '<P>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '"', '#', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?']


In [3]:
# Sample batch

x, y = next(iter(dataset.train_dataloader()))
print(x.shape, y.shape)

torch.Size([128, 1, 28, 448]) torch.Size([128, 16])


## Simple network

In [8]:
from text_recognizer.models import LineReshapeCNN

model = LineReshapeCNN(data_config=dataset.config())
pred = model(x)
print(pred.shape)

# We can take a look at whether our reshaping is correct by returning x in forward()
# plt.matshow(x[0].squeeze())
# plt.matshow(pred[0, 0])
# plt.matshow(pred[0, 1])
# plt.matshow(pred[0, 2])


torch.Size([128, 83, 16])


We can train this with

```sh
python training/run_experiment.py --max_epochs=5 --gpus=1 --num_workers=4 --data_class=EMNISTLines --max_length=16 --max_overlap=0 --model_class=LineReshapeCNN
```

We can easily get to >90% accuracy.

## All-convolutional network

In [6]:
from text_recognizer.models import LineCNN

model = LineCNN(input_dims=dataset.dims, num_classes=dataset.num_classes)
pred = model(x)
print(pred.shape)

torch.Size([128, 80, 8])


We can train this with

```sh
python training/run_experiment.py --max_epochs=5 --gpus=1 --num_workers=4 --data_class=EMNISTLines --max_length=16 --max_overlap=0 --model_class=LineCNN 
```

We can easily get to >90% accuracy.