## Model training

In [None]:
import random
import string
from pathlib import Path

import torch
from torch.utils.data import DataLoader

import training_utils
import qrcode_utils

### Constants

In [None]:
DATASET_SIZE = 4000
# Seeds Python's `random` (QR payload generation) and torch's global RNG so that
# a Restart & Run All reproduces the same data, shuffling and training run.
SEED = 42

QRCODE_VERSION = 1
# Side length, in modules, of a QR symbol of this version: 17 + 4 * version.
# NOTE(review): using it as the image size assumes qrcode_utils renders one pixel
# per module with no quiet zone — confirm.
QRCODE_IMAGE_SIZE = 17 + QRCODE_VERSION * 4
STYLE_NAME = "green_orange"
TRAINING_QRCODES_DIR = Path(f"{STYLE_NAME}-train_data")
DEFAULT_QRCODES_DIR = TRAINING_QRCODES_DIR / "default"
STYLED_QRCODES_DIR = TRAINING_QRCODES_DIR / "styled"
# Fall back to the CPU so the notebook also runs on machines without a CUDA GPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(SEED)
torch.manual_seed(SEED)

### Create Datasets

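Generate paired training QR codes: each random payload is rendered once in the default style and once in the target style (`STYLE_NAME`).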

In [None]:
def generate_random_data():
    """Return a random payload of 1 to 10 ASCII letters to encode in a QR code.

    Uses the global ``random`` module, so seeding it makes the output reproducible.
    """
    length = random.randint(1, 10)
    letters = random.choices(string.ascii_letters, k=length)
    return "".join(letters)

def generate_training_qrcodes(
        style_name, dataset_size, qrcode_version,
        default_qrcodes_path, styled_qrcodes_path, force=False
    ):
    """Render ``dataset_size`` paired QR code images to disk.

    For every ``i`` in ``1..dataset_size`` a random payload is encoded twice with
    the same QR version: once with the default look (saved as
    ``default_qrcodes_path / f"{i}.jpg"``) and once with the color mask for
    ``style_name`` (saved as ``styled_qrcodes_path / f"{i}.jpg"``). Matching file
    names therefore pair a default image with its styled counterpart.

    Args:
        style_name: Style name passed to ``qrcode_utils.get_color_mask``.
        dataset_size: Number of image pairs to generate.
        qrcode_version: QR code version used for both images of a pair.
        default_qrcodes_path: ``Path`` of the directory for default-style images.
        styled_qrcodes_path: ``Path`` of the directory for styled images.
        force: Regenerate even if a complete dataset already exists on disk.
    """
    # Images are written in order 1..dataset_size, default before styled, so the
    # last image existing in BOTH directories means a previous run finished.
    # Checking only that the default directory exists would accept a missing
    # styled directory or a run that was interrupted part-way.
    last_name = f"{dataset_size}.jpg"
    if (
        not force
        and (default_qrcodes_path / last_name).exists()
        and (styled_qrcodes_path / last_name).exists()
    ):
        return

    default_qrcodes_path.mkdir(exist_ok=True, parents=True)
    styled_qrcodes_path.mkdir(exist_ok=True, parents=True)

    for i in range(1, dataset_size + 1):
        qrcode_data = generate_random_data()

        default_qrcode_img = qrcode_utils.generate_qrcode_image(
            qrcode_version, qrcode_data
        )
        default_qrcode_img.save(default_qrcodes_path / f"{i}.jpg")

        styled_qrcode_img = qrcode_utils.generate_qrcode_image(
            qrcode_version, qrcode_data, qrcode_utils.get_color_mask(style_name)
        )
        styled_qrcode_img.save(styled_qrcodes_path / f"{i}.jpg")

# Build the paired default/styled dataset for the configured style.
generate_training_qrcodes(
    style_name=STYLE_NAME,
    dataset_size=DATASET_SIZE,
    qrcode_version=QRCODE_VERSION,
    default_qrcodes_path=DEFAULT_QRCODES_DIR,
    styled_qrcodes_path=STYLED_QRCODES_DIR,
    force=True,  # always regenerate; pass False to reuse images already on disk
)

Split the data into train, test and validation sets (only the train and validation splits are used below).

In [None]:
# Only the train and validation splits are used; the test split is discarded.
# NOTE(review): default and styled images are paired by position, so this relies
# on create_qrcodes_datasets splitting both directories identically (e.g. an
# index-based, deterministic split) — confirm in training_utils.
default_qrcode_train, _, default_qrcode_val = \
    training_utils.create_qrcodes_datasets(DEFAULT_QRCODES_DIR, DATASET_SIZE)
print(default_qrcode_train.shape, default_qrcode_val.shape)

st_qrcode_train, _, st_qrcode_val = \
    training_utils.create_qrcodes_datasets(STYLED_QRCODES_DIR, DATASET_SIZE)
print(st_qrcode_train.shape, st_qrcode_val.shape)

# Every default image needs a styled counterpart.
assert len(default_qrcode_train) == len(st_qrcode_train), "train split sizes differ"
assert len(default_qrcode_val) == len(st_qrcode_val), "val split sizes differ"

BATCH_SIZE = 5

train_loader = DataLoader(
    training_utils.QRCodeImageDataset(
        default_qrcode_train, st_qrcode_train, QRCODE_IMAGE_SIZE
    ),
    batch_size=BATCH_SIZE, shuffle=True
)

# Shuffling validation data brings no training benefit; a fixed order keeps
# evaluation reproducible across runs.
val_loader = DataLoader(
    training_utils.QRCodeImageDataset(
        default_qrcode_val, st_qrcode_val, QRCODE_IMAGE_SIZE
    ),
    batch_size=BATCH_SIZE, shuffle=False
)

## Models

We are going to train two models, QuantAE and QuantAEPruned:
- QuantAEPruned is a pruned variant of QuantAE
- Both models are trained on the same generated training data

More information about the models can be found in README.md.

### Init

In [None]:
# Plain autoencoder; its name string embeds the input image size.
ae_model = training_utils.get_ae_model(
    f"ae_{QRCODE_IMAGE_SIZE}",
    QRCODE_IMAGE_SIZE,
)
ae_model.to(device=DEFAULT_DEVICE)

# Pruned variant for the same image size, distinguished by the "_pruned" suffix.
pruned_ae_model = training_utils.get_ae_pruned_model(
    f"ae_{QRCODE_IMAGE_SIZE}_pruned",
    QRCODE_IMAGE_SIZE,
)
pruned_ae_model.to(device=DEFAULT_DEVICE)

### Training
- 3 stages of 30 epochs each
- The learning rate is reset to its initial value at the start of each stage

In [None]:
EPOCHS_PER_STAGE = 30

# Hyper-parameters consumed by training_utils.train.
# Stage k spans epochs [(k - 1) * EPOCHS_PER_STAGE, k * EPOCHS_PER_STAGE).
# NOTE(review): "epochs" looks like the cumulative epoch at which a stage ends
# (not a per-stage count), given the start_epoch values — confirm in training_utils.
TRAINING_PARAMS = {
    "LR": 0.001,
    "weight_decay": 0.0,
    "scheduler_gamma": 0.95,
    "device": DEFAULT_DEVICE,
    "stages": {
        stage: {
            "start_epoch": (stage - 1) * EPOCHS_PER_STAGE,
            "epochs": stage * EPOCHS_PER_STAGE,
        }
        for stage in (1, 2, 3)
    },
}

In [None]:
# AE model: run every stage declared in TRAINING_PARAMS["stages"], in stage order,
# so adding or removing a stage there is reflected here automatically.
for stage in sorted(TRAINING_PARAMS["stages"]):
    training_utils.train(
        model=ae_model,
        dataloader_train=train_loader,
        dataloader_val=val_loader,
        stage=stage,
        params=TRAINING_PARAMS,
    )

In [None]:
# Pruned AE model: run every stage declared in TRAINING_PARAMS["stages"] with
# pruning switched on, then switch it off again.
# NOTE(review): the exact effect of prune(True/False) is defined by the model
# class returned from training_utils.get_ae_pruned_model — confirm.
pruned_ae_model.prune(True)
try:
    for stage in sorted(TRAINING_PARAMS["stages"]):
        training_utils.train(
            model=pruned_ae_model,
            dataloader_train=train_loader,
            dataloader_val=val_loader,
            stage=stage,
            params=TRAINING_PARAMS,
        )
finally:
    # Restore the flag even if training fails or the kernel is interrupted,
    # so the model is never saved while still in pruning mode.
    pruned_ae_model.prune(False)

## Saving

In [None]:
# Persist both trained models via training_utils.save_model.
# NOTE(review): the destination and file names are chosen inside save_model,
# presumably from the names passed to get_ae_model / get_ae_pruned_model — confirm.
training_utils.save_model(ae_model)
training_utils.save_model(pruned_ae_model)