# Training data preparation

## Download datasets

In [None]:
# Disabled cell: the shell commands below are wrapped in a string literal, so
# executing this cell runs nothing. To actually download the datasets, remove
# the triple quotes (and this comment) so that `%%sh` is the first line of the
# cell; it then prints download_data.sh and executes it.
"""
%%sh
cat download_data.sh
./download_data.sh
"""

## Patch the datasets

In [None]:
# Orchestrator drives the dataset patching below; DataStructure describes the
# output directory layout passed to it.
from deep_learning_lab.data_preparation import Orchestrator, DataStructure

# Display the keys of the dataset registry exposed by Orchestrator.
Orchestrator.DATASETS.keys()

In [None]:
# Atomic labels are to be promoted: every label is wrapped in its own
# one-element set and handled in a separate pass of the loop below.
sets_labels = [[label] for label in ('ImageRegion', 'TextLine', 'TextRegion')]

# Directory layout handed to the orchestrator as its output structure.
output_layout = DataStructure(
    dir_data="training_data",
    dir_images="images",
    dir_labels="labels",
)
orc = Orchestrator(output_structure=output_layout)

# No explicit datasets are listed; only the defaults are requested.
orc.ingestDatasets(datasets=[], add_defaults=True)

# Ingest, validate and preprocess once per label set, without prompting.
for set_labels in sets_labels:
    orc.ingestLabels(uniform_set_labels=set_labels, prompt=False)
    orc.validate(auto_yes=True, verbose=1)
    # 1188 x 841 is roughly 1e6 pixels and gives tensors of the same size.
    orc.preprocess(resize=(1188, 841), overwrite=True, verbose=2)
    print()  # Blank line between the logs of consecutive label sets.

# Deep learning lab

In [None]:
import deep_learning_lab.gpu_setup as gpu

# Preselect CUDA device 0, then print whatever device report the helper returns.
gpu.cudaDeviceSelection(preselected_device=0)
cuda_summary = gpu.cudaInfo()
print(cuda_summary)

## Training

In [None]:
# Label set passed to model.Trainer in the next cell.
labels = ['TextRegion']

In [None]:
from deep_learning_lab import model

# Build the trainer for the label set chosen in the previous cell.
trainer = model.Trainer(labels)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Use the log directory exposed by the trainer and echo it for reference.
tensorboard_dir = trainer.tensorboard_dir
!echo $tensorboard_dir
# Optional reset of the log directory (disabled).
# NOTE(review): a single `&` runs `rm` in the background, so `mkdir` can race
# with it; use `&&` if this line is ever re-enabled.
#!rm -r $tensorboard_dir & mkdir -p $tensorboard_dir
# Open TensorBoard inline, pointed at the trainer's log directory.
%tensorboard --logdir $tensorboard_dir

In [None]:
# Hyperparameters for this run, gathered in one place and forwarded as
# keyword arguments to trainer.train().
# NOTE(review): epochs=1 together with evaluate_every_epoch=5 looks like a
# smoke-test configuration; confirm evaluation still happens on such a run.
train_settings = dict(
    batch_size=4,
    epochs=1,
    learning_rate=1e-4,
    gamma_exp_lr=0.9995,
    evaluate_every_epoch=5,
    val_patience=4,
    repeat_dataset=4,
    output_size=1e6,
)
trainer.train(**train_settings)

## Inference

In [None]:
# Label set passed to model.Predictor in the next cell (rebinds the `labels`
# name that the training section set to ['TextRegion']).
labels = ['TextLine']

In [None]:
import os

from deep_learning_lab import model

# CSV file under the training data directory that is passed as `from_csv`.
test_split_csv = os.path.join('training_data', 'test.csv')

# Predictor for the selected labels; output_dir and output_size are left as None.
predictor = model.Predictor(
    labels,
    input_dir='inference_data',
    output_dir=None,
    output_size=None,
    from_csv=test_split_csv,
    reset_input=True,
)

# Run inference and keep the probability maps.
predictor.start(batch_size=4, save_probas=True)
# Alternative kept from the original notebook (disabled): predictor.load()
# NOTE(review): presumably reloads saved predictions instead of re-running
# start(); confirm against the Predictor API before enabling.

# Post-process the predictions: draw regions and cut vignettes, without
# bounding boxes. The per-image results are inspected in the cells below.
results = predictor.postProcess(
    drawRegions=True,
    cutVignettes=True,
    bounding_box=False,
    verbose=True,
)

## Tests

In [None]:
# PIL turns the result arrays into images; matplotlib displays the vignettes.
from PIL import Image
import matplotlib.pyplot as plt

# Fail fast when post-processing returned nothing: the cells below index
# results[0].
assert len(results)

In [None]:
# Display which outputs are stored for the first image (the cells below read
# 'regions', 'probasMaps' and 'vignettes').
results[0].keys()

In [None]:
# Index of the image to inspect; change it to look at another result.
image_nb = 0
# Outputs for the selected image, used by the display cells below.
predictions = results[image_nb]

In [None]:
# Render the 'regions' array of the selected image as the cell output.
Image.fromarray(predictions['regions'])

In [None]:
# Render the probability map stored at index 1 as the cell output.
# NOTE(review): presumably index 1 is the first non-background class; confirm
# the ordering of 'probasMaps' against the Predictor output.
Image.fromarray(predictions['probasMaps'][1])

In [None]:
# Show every vignette cut out of the selected image, one figure per vignette.
# Iterate directly: the enumerate() index was never used, and binding it to
# `_` at notebook top level rebinds IPython's `_` last-output shortcut.
for vignette in predictions['vignettes']:
    plt.imshow(Image.fromarray(vignette))
    plt.show()