# Training notebook (smoke test: checks that CRNN training runs end to end)

In [1]:
import sys

# Make the repository root importable so the `src.*` packages resolve.
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import CTCLoss
import torch.optim as optim

from src.training.trainer import train_model
from src.dataset.custom_dataset import OdometerDataset
from src.dataset.base_dataset import base_collate_fn
from src.models.crnn import CRNN

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [4]:

# Dataset location. Defaults to the original local checkout; set the
# ODOMETER_DATA_DIR environment variable to run on another machine without
# editing this cell.
data_dir = os.environ.get(
    'ODOMETER_DATA_DIR',
    '/home/yannou/OneDrive/Documents/3_PRO/carviz/data/ocr/odometer_reader/milestone_box',
)
# The labels file sits inside the data directory, so derive it instead of
# repeating the full path.
labels_file = os.path.join(data_dir, 'milestone_labels.json')

# Input size and batch size shared by the train and validation pipelines.
# Keep IMG_HEIGHT / IMG_WIDTH in sync with the values passed to CRNN(...).
IMG_HEIGHT = 32
IMG_WIDTH = 100
BATCH_SIZE = 64

# Training transforms: random augmentation, then grayscale + normalisation.
# NOTE(review): saturation/hue jitter runs before Grayscale, so only the part
# of it that survives the conversion has any effect — confirm this is intended.
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Validation transforms: same preprocessing, no random augmentation.
transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Build the train and validation datasets
dataset = OdometerDataset(
    root_dir=data_dir,
    split="train",
    labels_file=labels_file,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    transform=transform,
)
dataset_val = OdometerDataset(
    root_dir=data_dir,
    split="val",
    labels_file=labels_file,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    transform=transform_val,
)


# Create the DataLoaders (only the training set is shuffled)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=base_collate_fn)
valid_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=base_collate_fn)



In [5]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# start from the same state on every Restart & Run All. GPU kernels may still
# introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

# Initialize the model
label2char = OdometerDataset.LABEL2CHAR
# One output class per character plus one extra
# (presumably the CTC blank symbol — TODO confirm against the CRNN/decoder code).
num_class = len(label2char) + 1
crnn = CRNN(img_channel=1, img_height=32, img_width=100, num_class=num_class, model_size="n", leaky_relu=True).to(device)

# Training hyperparameters
lr = 0.001
epochs = 10
decode_method = 'beam_search'
beam_size = 10

In [6]:
print(f"Working on {device}")

# Build the loss and the optimizer up front so the train_model call below
# reads as a plain list of settings.
ctc_criterion = CTCLoss(reduction='sum', zero_infinity=True).to(device)
adam_optimizer = optim.Adam(crnn.parameters(), lr=lr)

# Launch training through train_model
trained_model = train_model(
    model=crnn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    label2char=label2char,
    device=device,
    lr=lr,
    epochs=epochs,
    decode_method=decode_method,
    beam_size=beam_size,
    criterion=ctc_criterion,
    optimizer=adam_optimizer,
    project_name="odometer-reader",
    run_name="crnn-n",
    checkpoint=2,
)



Working on cuda:0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myann-t[0m ([33mcarviz-com[0m). Use [1m`wandb login --relogin`[0m to force relogin


Evaluate: 100%|██████████| 12/12 [00:15<00:00,  1.31s/it]
Epochs:  10%|█         | 1/10 [01:32<13:50, 92.26s/it]

Epoch 1: train_loss=20.633939083118097, train_accuracy=0.0, val_loss=13.689547661852183, val_accuracy=0.0


Evaluate: 100%|██████████| 12/12 [00:16<00:00,  1.38s/it]
Epochs:  20%|██        | 2/10 [02:57<11:45, 88.14s/it]

Epoch 2: train_loss=13.743394737430954, train_accuracy=0.0, val_loss=13.702055014718464, val_accuracy=0.0


Evaluate: 100%|██████████| 12/12 [00:15<00:00,  1.26s/it]
Epochs:  30%|███       | 3/10 [04:20<10:01, 85.93s/it]

Epoch 3: train_loss=13.723840052328406, train_accuracy=0.0, val_loss=13.666719001461662, val_accuracy=0.0


Evaluate: 100%|██████████| 12/12 [00:14<00:00,  1.23s/it]
Epochs:  40%|████      | 4/10 [05:40<08:21, 83.64s/it]

Epoch 4: train_loss=13.65271879155585, train_accuracy=0.0, val_loss=13.572019813269026, val_accuracy=0.0


Evaluate: 100%|██████████| 12/12 [00:14<00:00,  1.23s/it]
Epochs:  50%|█████     | 5/10 [07:03<06:55, 83.07s/it]

Epoch 5: train_loss=13.336427879021327, train_accuracy=0.0, val_loss=14.508344389158246, val_accuracy=0.0


Evaluate: 100%|██████████| 12/12 [00:15<00:00,  1.30s/it]
Epochs:  60%|██████    | 6/10 [08:28<05:36, 84.03s/it]

Epoch 6: train_loss=12.780550283378938, train_accuracy=0.0, val_loss=11.826792969386611, val_accuracy=0.0


Epochs:  60%|██████    | 6/10 [09:22<06:15, 93.79s/it]


KeyboardInterrupt: 

