# Training notebook: smoke test of the CRNN odometer-reader training pipeline

In [1]:
import sys

# Add the parent directory to the import path so the `src.*` imports in the
# next cell can be resolved (presumably the project root — confirm).
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import CTCLoss
import torch.optim as optim

from src.training.trainer import train_model
from src.dataset.custom_dataset import OdometerDataset
from src.dataset.base_dataset import base_collate_fn
from src.models.crnn import CRNN

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Bare expression so the notebook displays the selected device.
device

device(type='cuda', index=0)

In [4]:

# Dataset location. The default is the original local path; set the
# ODOMETER_DATA_DIR environment variable to run the notebook on another machine.
data_dir = os.environ.get(
    'ODOMETER_DATA_DIR',
    '/home/yannou/OneDrive/Documents/3_PRO/carviz/data/ocr/odometer_reader/milestone_box',
)
# The labels file sits inside the dataset directory, so derive it from data_dir
# instead of repeating the full path.
labels_file = os.path.join(data_dir, 'milestone_labels.json')

# Values that were previously repeated as literals in this cell.
IMG_HEIGHT = 32
IMG_WIDTH = 100
BATCH_SIZE = 64

# Preprocessing applied to both splits: single-channel image -> tensor ->
# normalisation with mean 0.5 / std 0.5. Keeping it in one list guarantees that
# validation images are preprocessed exactly like training images.
common_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]

# Training pipeline: random rotation and colour jitter for augmentation,
# followed by the shared preprocessing.
transform = transforms.Compose([
    transforms.RandomRotation(7),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    *common_steps,
])

# Validation pipeline: shared preprocessing only, no augmentation.
transform_val = transforms.Compose(list(common_steps))

# Build the datasets for the "train" and "val" splits.
dataset = OdometerDataset(
    root_dir=data_dir,
    split="train",
    labels_file=labels_file,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    transform=transform,
)
dataset_val = OdometerDataset(
    root_dir=data_dir,
    split="val",
    labels_file=labels_file,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    transform=transform_val,
)

# Create the DataLoaders. Only the training loader shuffles its samples.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=base_collate_fn)
valid_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=base_collate_fn)



In [5]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and later draws from the same RNG) start from a fixed state.
# NOTE(review): other RNG sources (numpy, `random`, CUDA kernels) are not
# covered here — add them if the project code relies on them.
SEED = 42
torch.manual_seed(SEED)

# Initialise the model. The class count is the label alphabet size plus one;
# the extra class is presumably the CTC blank symbol — confirm against CRNN.
label2char = OdometerDataset.LABEL2CHAR
num_class = len(label2char) + 1
crnn = CRNN(
    img_channel=1,
    img_height=32,
    img_width=100,
    num_class=num_class,
    model_size="n",
    leaky_relu=True,
).to(device)

# Training hyper-parameters passed to train_model in the next cell.
lr = 0.001
epochs = 15
decode_method = 'beam_search'
beam_size = 10

In [6]:
# Report which device the run will execute on.
print(f"Working on {device}")

# CTC loss summed over the batch, with infinite losses zeroed out.
ctc_criterion = CTCLoss(reduction='sum', zero_infinity=True).to(device)
# Adam over all model parameters, using the learning rate defined above.
adam_optimizer = optim.Adam(crnn.parameters(), lr=lr)

# Launch training through the project's train_model helper and keep its result.
trained_model = train_model(
    model=crnn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    label2char=label2char,
    device=device,
    lr=lr,
    epochs=epochs,
    decode_method=decode_method,
    beam_size=beam_size,
    criterion=ctc_criterion,
    optimizer=adam_optimizer,
    project_name="odometer-reader",
    run_name="crnn-n",
    checkpoint=2,
)



Working on cuda:0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myann-t[0m ([33mcarviz-com[0m). Use [1m`wandb login --relogin`[0m to force relogin


Evaluate: 100%|██████████| 12/12 [00:14<00:00,  1.21s/it]
Epochs:   7%|▋         | 1/15 [01:25<20:02, 85.87s/it]

Epoch 1: train_loss=20.262676032982558, train_accuracy=0.0, val_loss=13.68884794590837, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.0022771838825985198, train_average_levenshtein_distance=5.179705400981997val_word_accuracy=0.0, val_char_accuracy=0.0, val_average_levenshtein_distance=5.160365058670143


Evaluate: 100%|██████████| 12/12 [00:14<00:00,  1.22s/it]
Epochs:  13%|█▎        | 2/15 [02:48<18:08, 83.74s/it]

Epoch 2: train_loss=13.732456258471, train_accuracy=0.0, val_loss=13.683236929086538, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.0, train_average_levenshtein_distance=5.174795417348609val_word_accuracy=0.0, val_char_accuracy=0.0, val_average_levenshtein_distance=5.160365058670143


Evaluate: 100%|██████████| 12/12 [00:13<00:00,  1.11s/it]
Epochs:  20%|██        | 3/15 [04:09<16:33, 82.83s/it]

Epoch 3: train_loss=13.721676930429112, train_accuracy=0.0, val_loss=13.62155425906026, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.0, train_average_levenshtein_distance=5.174795417348609val_word_accuracy=0.0, val_char_accuracy=0.0, val_average_levenshtein_distance=5.160365058670143


Evaluate: 100%|██████████| 12/12 [00:14<00:00,  1.22s/it]
Epochs:  27%|██▋       | 4/15 [05:30<15:02, 82.08s/it]

Epoch 4: train_loss=13.490687873203509, train_accuracy=0.0, val_loss=13.575992907507946, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.0, train_average_levenshtein_distance=5.174795417348609val_word_accuracy=0.0, val_char_accuracy=0.0, val_average_levenshtein_distance=5.160365058670143


Evaluate: 100%|██████████| 12/12 [00:16<00:00,  1.39s/it]
Epochs:  33%|███▎      | 5/15 [06:57<13:58, 83.89s/it]

Epoch 5: train_loss=12.91983936266111, train_accuracy=0.0, val_loss=12.19790486282289, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.0, train_average_levenshtein_distance=5.174795417348609val_word_accuracy=0.0, val_char_accuracy=0.0, val_average_levenshtein_distance=5.160365058670143


Evaluate: 100%|██████████| 12/12 [00:15<00:00,  1.28s/it]
Epochs:  40%|████      | 6/15 [08:22<12:36, 84.09s/it]

Epoch 6: train_loss=11.737698728870448, train_accuracy=0.0, val_loss=10.483139973109518, val_accuracy=0.0, train_word_accuracy=0.0, train_char_accuracy=0.011891960275792271, train_average_levenshtein_distance=4.912929623567922val_word_accuracy=0.0, val_char_accuracy=0.028044466902475997, val_average_levenshtein_distance=4.482398956975228


Evaluate: 100%|██████████| 12/12 [00:12<00:00,  1.07s/it]
Epochs:  47%|████▋     | 7/15 [09:39<10:54, 81.83s/it]

Epoch 7: train_loss=10.268529374392646, train_accuracy=0.0009819967266775777, val_loss=10.1262409951261, val_accuracy=0.0, train_word_accuracy=0.0009819967266775777, train_char_accuracy=0.049528749446517806, train_average_levenshtein_distance=4.183960720130933val_word_accuracy=0.0, val_char_accuracy=0.07099545224861041, val_average_levenshtein_distance=3.9204693611473274


Evaluate: 100%|██████████| 12/12 [00:13<00:00,  1.12s/it]
Epochs:  53%|█████▎    | 8/15 [10:57<09:23, 80.56s/it]

Epoch 8: train_loss=8.704149421654435, train_accuracy=0.002618657937806874, val_loss=7.550092794127384, val_accuracy=0.01303780964797914, train_word_accuracy=0.002618657937806874, train_char_accuracy=0.10487696881523183, train_average_levenshtein_distance=3.5090016366612113val_word_accuracy=0.01303780964797914, val_char_accuracy=0.17710965133906012, val_average_levenshtein_distance=2.864406779661017


Evaluate: 100%|██████████| 12/12 [00:13<00:00,  1.14s/it]
Epochs:  60%|██████    | 9/15 [12:12<07:54, 79.01s/it]

Epoch 9: train_loss=7.641033565939546, train_accuracy=0.024549918166939442, val_loss=6.793735190818021, val_accuracy=0.018252933507170794, train_word_accuracy=0.024549918166939442, train_char_accuracy=0.18502119046112975, train_average_levenshtein_distance=2.915875613747954val_word_accuracy=0.018252933507170794, val_char_accuracy=0.18367862556846892, val_average_levenshtein_distance=2.7770534550195567


Evaluate: 100%|██████████| 12/12 [00:15<00:00,  1.30s/it]
Epochs:  67%|██████▋   | 10/15 [13:27<06:28, 77.65s/it]

Epoch 10: train_loss=6.61514170345425, train_accuracy=0.07266775777414075, val_loss=5.574555438065311, val_accuracy=0.11603650586701435, train_word_accuracy=0.07266775777414075, train_char_accuracy=0.2986273641596559, train_average_levenshtein_distance=2.4448445171849427val_word_accuracy=0.11603650586701435, val_char_accuracy=0.4120768064679131, val_average_levenshtein_distance=1.9661016949152543


Evaluate: 100%|██████████| 12/12 [00:12<00:00,  1.03s/it]
Epochs:  73%|███████▎  | 11/15 [14:37<05:00, 75.15s/it]

Epoch 11: train_loss=5.944090518233429, train_accuracy=0.12274959083469722, val_loss=5.769139733743481, val_accuracy=0.16297262059973924, train_word_accuracy=0.12274959083469722, train_char_accuracy=0.39964577139604024, train_average_levenshtein_distance=2.0808510638297872val_word_accuracy=0.16297262059973924, val_char_accuracy=0.494188984335523, val_average_levenshtein_distance=1.8904823989569752


Evaluate: 100%|██████████| 12/12 [00:11<00:00,  1.01it/s]
Epochs:  80%|████████  | 12/15 [15:45<03:38, 73.00s/it]

Epoch 12: train_loss=5.523321872842292, train_accuracy=0.1574468085106383, val_loss=4.262143290493603, val_accuracy=0.24771838331160365, train_word_accuracy=0.1574468085106383, train_char_accuracy=0.4524005313429059, train_average_levenshtein_distance=1.9099836333878888val_word_accuracy=0.24771838331160365, val_char_accuracy=0.5654370894391106, val_average_levenshtein_distance=1.4485006518904824


Evaluate: 100%|██████████| 12/12 [00:11<00:00,  1.08it/s]
Epochs:  87%|████████▋ | 13/15 [16:53<02:22, 71.49s/it]

Epoch 13: train_loss=4.969340882246701, train_accuracy=0.20654664484451718, val_loss=3.8975808791417963, val_accuracy=0.2842242503259452, train_word_accuracy=0.20654664484451718, train_char_accuracy=0.4913024226706307, train_average_levenshtein_distance=1.7083469721767595val_word_accuracy=0.2842242503259452, val_char_accuracy=0.5639211723092471, val_average_levenshtein_distance=1.4211212516297262


Evaluate: 100%|██████████| 12/12 [00:11<00:00,  1.01it/s]
Epochs:  93%|█████████▎| 14/15 [18:00<01:10, 70.36s/it]

Epoch 14: train_loss=4.645981064716845, train_accuracy=0.2500818330605565, val_loss=3.8914755929402296, val_accuracy=0.3076923076923077, train_word_accuracy=0.2500818330605565, train_char_accuracy=0.5426023151369473, train_average_levenshtein_distance=1.5590834697217677val_word_accuracy=0.3076923076923077, val_char_accuracy=0.6336533602829711, val_average_levenshtein_distance=1.2985658409387224


Evaluate: 100%|██████████| 12/12 [00:12<00:00,  1.04s/it]
Epochs: 100%|██████████| 15/15 [19:08<00:00, 76.56s/it]

Epoch 15: train_loss=4.29250313966442, train_accuracy=0.27430441898527, val_loss=3.3745337813276666, val_accuracy=0.3898305084745763, train_word_accuracy=0.27430441898527, train_char_accuracy=0.5774558795622746, train_average_levenshtein_distance=1.448772504091653val_word_accuracy=0.3898305084745763, val_char_accuracy=0.6526023244062658, val_average_levenshtein_distance=1.1225554106910038





0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/average_levenshtein_distance,█████▇▆▅▄▃▂▂▁▁▁
train/char_accuracy,▁▁▁▁▁▁▂▂▃▅▆▆▇██
train/train_loss,█▅▅▅▅▄▄▃▂▂▂▂▁▁▁
train/word_accuracy,▁▁▁▁▁▁▁▁▂▃▄▅▆▇█
val/average_levenshtein_distance,█████▇▆▄▄▂▂▂▂▁▁
val/char_accuracy,▁▁▁▁▁▁▂▃▃▅▆▇▇██
val/val_loss,████▇▆▆▄▃▂▃▂▁▁▁
val/word_accuracy,▁▁▁▁▁▁▁▁▁▃▄▅▆▇█

0,1
epoch,15.0
train/average_levenshtein_distance,1.44877
train/char_accuracy,0.57746
train/train_loss,4.2925
train/word_accuracy,0.2743
val/average_levenshtein_distance,1.12256
val/char_accuracy,0.6526
val/val_loss,3.37453
val/word_accuracy,0.38983
