# Training notebook (smoke test: checking that the training pipeline runs end to end)

In [1]:
import sys

# Make the project root (the parent of this notebook's folder) importable so
# that the `src.*` imports below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [6]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import CTCLoss
import torch.optim as optim

from src.training.trainer import train_model
from src.dataset.custom_dataset import OdometerDataset, MJSynthDataset
from src.dataset.base_dataset import base_collate_fn
from src.models.mobilevit_rnn import MobileViT_RNN

In [7]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (CPU and all CUDA devices) so data shuffling, random
# augmentation and weight initialisation are repeatable across
# Restart & Run All. Some cuDNN kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device

device(type='cuda', index=0)

In [8]:

import os

# --- Configuration -----------------------------------------------------------
# Which dataset to train on: 'odometer' (milestone_box crops) or 'mjsynth'.
DATASET_NAME = 'odometer'
# Root of the OCR data. Defaults to the original author's local path; set the
# OCR_DATA_ROOT environment variable to run this notebook on another machine.
OCR_DATA_ROOT = os.environ.get('OCR_DATA_ROOT', '/home/yannou/OneDrive/Documents/3_PRO/carviz/data/ocr')
# Image size passed to the datasets; the model cell builds the network for 32x100 as well.
IMG_HEIGHT, IMG_WIDTH = 32, 100
BATCH_SIZE = 256

# Training transforms: light augmentation (rotation up to +/-7 degrees, colour
# jitter), then conversion to a single-channel tensor normalised from [0, 1]
# to [-1, 1].
transform = transforms.Compose([
    transforms.RandomRotation(7),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Validation transforms: same preprocessing, without random augmentation.
transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Build the train / validation datasets for the selected source.
if DATASET_NAME == 'odometer':
    data_dir = f'{OCR_DATA_ROOT}/odometer_reader/milestone_box'
    # The labels file lives inside the dataset folder, so derive its path.
    labels_file = f'{data_dir}/milestone_labels.json'
    dataset = OdometerDataset(root_dir=data_dir, split="train", labels_file=labels_file,
                              img_height=IMG_HEIGHT, img_width=IMG_WIDTH, transform=transform)
    dataset_val = OdometerDataset(root_dir=data_dir, split="val", labels_file=labels_file,
                                  img_height=IMG_HEIGHT, img_width=IMG_WIDTH, transform=transform_val)
elif DATASET_NAME == 'mjsynth':
    data_dir = f'{OCR_DATA_ROOT}/MJSynth_text_recognition'
    dataset = MJSynthDataset(root_dir=data_dir, split="train",
                             img_height=IMG_HEIGHT, img_width=IMG_WIDTH, transform=transform)
    dataset_val = MJSynthDataset(root_dir=data_dir, split="val",
                                 img_height=IMG_HEIGHT, img_width=IMG_WIDTH, transform=transform_val)
else:
    raise ValueError(f"Unknown DATASET_NAME {DATASET_NAME!r}; expected 'odometer' or 'mjsynth'")

# DataLoaders: shuffle only the training split; batches are assembled with the
# project's base_collate_fn.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=base_collate_fn)
valid_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=base_collate_fn)



In [10]:
# Initialise the model.
# One output class per character in the dataset's alphabet, plus one extra
# class reserved for the CTC blank (presumably index 0, the default `blank`
# of the CTCLoss used in the training cell).
num_class = len(dataset.LABEL2CHAR) + 1
# img_channel=1 matches the Grayscale(num_output_channels=1) transforms;
# img_height / img_width must match the values passed to the datasets.
crnn = MobileViT_RNN(img_channel=1, img_height=32, img_width=100, num_class=num_class, model_size="s").to(device)

# Training hyper-parameters.
lr = 0.001
epochs = 2
decode_method = 'beam_search'
beam_size = 10
# Take the decoding table from the same dataset instance as num_class so the
# model's output size and the decoder's alphabet always describe the same
# character set (also when switching to another dataset class).
label2char = dataset.LABEL2CHAR

In [11]:
# Sanity check: the model head must have one output per character in the
# decoding table plus one extra class for the CTC blank.
assert num_class == len(label2char) + 1, (num_class, len(label2char))
num_class

11

In [12]:
print(f"Working on {device}")
# Run training and validation; per-epoch metrics are logged to Weights & Biases
# under the project / run names below.
# NOTE(review): the epoch summary printed by train_model is missing a ", "
# before `val_word_accuracy` (visible in this cell's output); fix the format
# string in src/training/trainer.py.
# NOTE(review): if `checkpoint` is an epoch interval, checkpoint=10 with
# epochs=2 never writes a checkpoint -- confirm against train_model.
trained_model = train_model(
    model=crnn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    label2char=label2char,
    device=device,
    lr=lr,
    epochs=epochs,
    decode_method=decode_method,
    beam_size=beam_size,
    # zero_infinity=True zeroes infinite losses (and their gradients), which
    # occur when an output sequence is too short to align with its target.
    criterion=CTCLoss(reduction='sum', zero_infinity=True).to(device),
    optimizer=optim.Adam(crnn.parameters(), lr=lr),
    project_name="odometer-reader",
    # Keep in sync with model_size="s" used to build `crnn`.
    run_name="test-mobilevit-s",
    checkpoint=10,
)



Working on cuda:0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myann-t[0m ([33mcarviz-com[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training Epoch 1: 100%|██████████| 12/12 [00:50<00:00,  4.20s/it]
Evaluate: 100%|██████████| 3/3 [00:10<00:00,  3.49s/it]
Epochs:  50%|█████     | 1/2 [01:03<01:03, 63.37s/it]

Epoch 1: train_loss=0.3758778176038368, train_accuracy=0.0, val_loss=0.3915140056358478, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.04917187600696011, train_average_levenshtein_distance=4.445514950166113val_word_accuracy=0.0, val_char_accuracy=0.0310415597742432, val_average_levenshtein_distance=4.493403693931398


Training Epoch 2: 100%|██████████| 12/12 [00:44<00:00,  3.74s/it]
Evaluate: 100%|██████████| 3/3 [00:10<00:00,  3.40s/it]
Epochs: 100%|██████████| 2/2 [02:00<00:00, 60.39s/it]

Epoch 2: train_loss=0.369356865106627, train_accuracy=0.0, val_loss=0.39340008657651715, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.05071856673326029, train_average_levenshtein_distance=4.451495016611296val_word_accuracy=0.0, val_char_accuracy=0.03155464340687532, val_average_levenshtein_distance=4.496042216358839





0,1
epoch,▁█
train/average_levenshtein_distance,▁█
train/char_accuracy,▁█
train/train_loss,█▁
train/word_accuracy,▁▁
val/average_levenshtein_distance,▁█
val/char_accuracy,▁█
val/val_loss,▁█
val/word_accuracy,▁▁

0,1
epoch,2.0
train/average_levenshtein_distance,4.4515
train/char_accuracy,0.05072
train/train_loss,0.36936
train/word_accuracy,0.0
val/average_levenshtein_distance,4.49604
val/char_accuracy,0.03155
val/val_loss,0.3934
val/word_accuracy,0.0
