# Training notebook (smoke test to check that the training pipeline runs)

In [None]:
import sys
# Make the project root importable so the `src.*` imports below resolve.
# NOTE: '../' is resolved relative to the current working directory (where the
# kernel was started), not relative to this notebook file.
sys.path.append('../')

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import CTCLoss
import torch.optim as optim

from src.training.trainer import train_model
from src.dataset.custom_dataset import OdometerDataset, MJSynthDataset
from src.dataset.base_dataset import base_collate_fn
from src.models.crnn import CRNN

In [None]:
# Train on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Bare expression so the notebook displays the selected device.
device

In [None]:

# Dataset location. The labels file is derived from data_dir so the two paths
# cannot drift apart when data_dir is edited.
data_dir = '/home/yannou/OneDrive/Documents/3_PRO/carviz/data/ocr/odometer_reader/milestone_box'
labels_file = f'{data_dir}/milestone_labels.json'
# Alternative dataset (MJSynth); uncomment together with the MJSynthDataset lines below.
#data_dir = '/home/yannou/OneDrive/Documents/3_PRO/carviz/data/ocr/MJSynth_text_recognition'

# Input image size fed to the datasets. Keep in sync with the
# img_height/img_width arguments given to the CRNN constructor.
img_height, img_width = 32, 100

# Training transforms: light geometric/photometric augmentation, then a
# single-channel tensor normalized from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.RandomRotation(7),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Validation transforms: same preprocessing as training, without augmentation.
transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Build the train/validation datasets.
dataset = OdometerDataset(root_dir=data_dir, split="train", labels_file=labels_file, img_height=img_height, img_width=img_width, transform=transform)
dataset_val = OdometerDataset(root_dir=data_dir, split="val", labels_file=labels_file, img_height=img_height, img_width=img_width, transform=transform_val)

#dataset = MJSynthDataset(root_dir=data_dir, split="train",  img_height=img_height, img_width=img_width, transform=transform)
#dataset_val = MJSynthDataset(root_dir=data_dir, split="val",  img_height=img_height, img_width=img_width, transform=transform_val)

# Data loaders: shuffle only the training set; base_collate_fn batches the samples.
train_loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=base_collate_fn)
valid_loader = DataLoader(dataset_val, batch_size=256, shuffle=False, collate_fn=base_collate_fn)



In [None]:
# Build the model.
# Both the decoding table and the class count are taken from the dataset
# instance actually in use. If the MJSynthDataset alternative is enabled, the
# model output size and the decoder then still agree on the same alphabet.
# NOTE(review): assumes the instance exposes the same LABEL2CHAR as its class
# (i.e. __init__ does not override it) — confirm in the dataset classes.
label2char = dataset.LABEL2CHAR
# One extra output class beyond the alphabet — presumably the CTC blank
# (CTCLoss defaults to blank=0); confirm against the dataset's label encoding.
num_class = len(label2char) + 1
# img_channel=1 matches the Grayscale(num_output_channels=1) transforms.
crnn = CRNN(img_channel=1, img_height=32, img_width=100, num_class=num_class, model_size="n", leaky_relu=True).to(device)

# Training hyper-parameters.
lr = 0.001
epochs = 2
decode_method = 'beam_search'
beam_size = 10

In [None]:
# Display the number of output classes passed to the CRNN.
num_class

In [None]:
print(f"Working on {device}")

# Loss: CTC summed over the batch. zero_infinity=True zeroes infinite losses
# (and their gradients) instead of propagating inf/NaN.
ctc_criterion = CTCLoss(reduction='sum', zero_infinity=True).to(device)
# Optimizer over all CRNN parameters, using the same learning rate passed to train_model.
adam_optimizer = optim.Adam(crnn.parameters(), lr=lr)

# Launch training.
# NOTE(review): checkpoint=10 while epochs=2 — if `checkpoint` is an epoch
# interval, no intermediate checkpoint is written in this run; confirm in train_model.
trained_model = train_model(
    model=crnn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    label2char=label2char,
    device=device,
    lr=lr,
    epochs=epochs,
    decode_method=decode_method,
    beam_size=beam_size,
    criterion=ctc_criterion,
    optimizer=adam_optimizer,
    project_name="odometer-reader",
    run_name="crnn-n",
    checkpoint=10,
)



