In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from src.utils.helpers import save_json, count_total_parameters
from src.models.model import vit
from src.models.loss import LossWrapper, Type2DirichletLoss, SoftmaxWithUncertaintyLoss
from src.data.data_loader import load_soccernet,load_diwan_test,load_diwan_train,load_processed,load_ca12,load_reid,load_full, JerseyNumberDataset
from src.data.data_handling import split_dataset, balancer, plot_label_distribution,augment_dataset, count_digit_frequency
from src.models.train_test import train, plot_history, test, grid

In [None]:
# Seed every RNG up front, before the model is built. This makes weight initialisation
# and DataLoader shuffling repeatable across Restart & Run All.
# NOTE(review): the data helpers (balancer, split_dataset, augment_dataset) presumably
# draw from `random` or numpy. Both are seeded here; confirm which RNG they actually use.
# cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators on all devices, CUDA included

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Model configuration. The uncertainty head and the base loss are switched together:
# "dirichlet" pairs with Type2DirichletLoss and "softmax" with SoftmaxWithUncertaintyLoss.
uncertainty_head = "dirichlet"  # "dirichlet" or "softmax"
base_loss = Type2DirichletLoss(num_classes=100)  # Type2DirichletLoss or SoftmaxWithUncertaintyLoss
embed_dim = 120
hidden_layers = 6
attention_heads = 3

# Fail fast if only one of the two settings above was changed. An unknown head name
# raises KeyError here. A mismatched pair would otherwise surface only later, during
# training (or not at all).
expected_loss_cls = {
    "dirichlet": Type2DirichletLoss,
    "softmax": SoftmaxWithUncertaintyLoss,
}[uncertainty_head]
assert isinstance(base_loss, expected_loss_cls), (
    f"uncertainty_head={uncertainty_head!r} expects {expected_loss_cls.__name__}, "
    f"got {type(base_loss).__name__}"
)

model = vit(
    embed_dim=embed_dim,
    hidden_layers=hidden_layers,
    attention_heads=attention_heads,
    use_time=False,
    use_size=False,
    uncertainty_head=uncertainty_head,
).to(device)
count_total_parameters(model)

In [None]:
# Filtering thresholds forwarded to the CA12 and SoccerNet re-ID loaders.
# NOTE(review): their exact meaning is defined inside load_ca12 / load_reid; confirm there.
thresh_ca12 = 0.6
max_thresh_reid = 0.6
min_thresh_reid = 0.02

path1 = "data/diwan/test"
path2_1 = "data/diwan/train/seif_train_gt.json"
path2_2 = "data/diwan/train/skander_train_gt.json"
path3 = "data/soccernet"
path4 = "data/ca12"
path5 = "data/soccernet_reid"
path6 = "data/full_dataset"
path7 = "data/processed_dataset"

data1 = load_diwan_test(path1)
print(len(data1), "images dans diwan test")
# The second annotation file is `path2_2`. The old `path2._2` raised NameError.
data2 = load_diwan_train([path2_1, path2_2])
print(len(data2), "images dans diwan train")
data3 = load_soccernet(path3)
print(len(data3), "images dans soccernet")
data4 = load_ca12(path4, thresh=thresh_ca12)
print(len(data4), "images dans ca12")
data5 = load_reid(path5, max_thresh=max_thresh_reid, min_thresh=min_thresh_reid)
print(len(data5), "images dans reid_soccernet")
data6 = load_full(path6)
print(len(data6), "images dans full_dataset")
# The processed set is built with reference to data6 (the full dataset).
data7 = load_processed(path7, data=data6)
print(len(data7), "images dans processed")

In [None]:
# The two DIWAN sets (data1, data2) are held out as the test pool.
# Every other source goes into the training pool.
train_data = data3 + data4 + data5 + data6 + data7
test_data = data1 + data2
for split_name, split_samples in (("train", train_data), ("test", test_data)):
    print(len(split_samples), f"{split_name} samples")

In [None]:
# Per-digit counts in the training pool, before any augmentation.
digit_counts = count_digit_frequency(train_data)
print("Initial digit frequencies:")
for digit in range(10):
    n_occurrences = digit_counts[digit]
    print(f"Digit {digit}: {n_occurrences}")

In [None]:
# Augment the listed digits. Judging by the argument names, the target count is
# presumably per digit, and output goes to augmented/images plus an index JSON.
# NOTE(review): the return value is discarded and train_data's length is printed
# afterwards. This only works if augment_dataset extends `data` in place; confirm,
# otherwise the augmented samples are lost. Re-running this cell re-augments.
digits_to_augment = [6, 5, 8, 9, 4, 3, 0, 7]
augment_dataset(
    data=train_data,
    num_liste=digits_to_augment,
    target_per_digit=20000,
    output_folder="augmented/images",
    json_path="augmented/data.json",
)
print(len(train_data), "train samples")

In [None]:
# Rebalance both pools. blur_ratio is passed only for the training pool and is
# reused later in the saved model name.
# NOTE: train_data / test_data are overwritten, so re-running this cell re-balances
# already-balanced data.
blur_ratio = 0.3
train_data = balancer(train_data, max_0=0, blur_ratio=blur_ratio)
test_data = balancer(test_data, max_0=0)
for split_name, split_samples in (("train", train_data), ("test", test_data)):
    print(len(split_samples), f"{split_name} samples")

In [None]:
# Jersey-number label histograms for the balanced train and test pools.
train_title = "Distribution des numéros de maillot (train)"
test_title = "Distribution test"
plot_label_distribution(train_data, title=train_title)
plot_label_distribution(test_data, title=test_title)

In [None]:
# Carve a validation set out of the training pool. split_ratio is presumably the
# fraction kept for training; confirm in split_dataset.
# NOTE: train_data is overwritten, so re-running this cell splits it again.
split_ratio = 0.8
train_data, valid_data = split_dataset(train_data, split_ratio=split_ratio)
for split_name, split_samples in (("train", train_data), ("valid", valid_data)):
    print(len(split_samples), f"{split_name} samples")

In [None]:
# Wrap the train/valid sample lists in datasets that share one preprocessing config.
# `cut` is interpreted by JerseyNumberDataset; presumably it selects how the crop is
# trimmed. Confirm there.
cut = "topbottom"
image_size = (224, 224)
dataset_kwargs = dict(image_size=image_size, cut=cut)
train_dataset = JerseyNumberDataset(train_data, **dataset_kwargs)
valid_dataset = JerseyNumberDataset(valid_data, **dataset_kwargs)

In [None]:
# DataLoaders.
# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so it is enabled
# only when training on CUDA. On CPU it brings no benefit and PyTorch emits a warning.
batch_size = 128
workers = 2
pin_memory = device.type == "cuda"
train_loader = DataLoader(
    train_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
)
val_loader = DataLoader(
    valid_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=pin_memory,
)

In [None]:
# Optimisation setup. base_loss is chosen in the model-configuration cell.
# NOTE(review): only model.parameters() are optimised. If LossWrapper holds learnable
# parameters of its own, they are not trained. Confirm that is intended.
lr = 2e-4
num_epochs = 50
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = LossWrapper(base_loss)

In [None]:
# Run the training loop. The returned `history` is what plot_history visualises
# and what gets saved as the training-history figure.
history = train(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    device=device,
)

In [None]:
# Plot the training history. The figure is kept in `fig` so the final cell can save and close it.
fig = plot_history(history)

In [None]:
# Test dataset and loader, using the same preprocessing config as training.
# Order is kept (shuffle=False) so evaluation is repeatable. Memory is pinned only
# when a CUDA device is used, because pinning has no benefit on CPU and triggers a
# PyTorch warning.
test_dataset = JerseyNumberDataset(test_data, image_size=image_size, cut=cut)
test_loader = DataLoader(
    test_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=device.type == "cuda",
)

In [None]:
# Evaluate the trained model on the held-out DIWAN test pool. test_metrics is written to JSON in the final cell.
test_metrics = test(model, test_loader, device=device)

In [None]:
# Build a figure from the test loader (presumably sample images with the model's predictions; see grid()). It is saved and closed in the final cell.
image_grid = grid(model, test_loader, device=device)

In [None]:
# Persist every artifact of this run under results/, keyed by the main hyperparameters.
# NOTE: "20K" in the name mirrors target_per_digit=20000 in the augmentation cell and
# must be updated by hand if that value changes.
model_name = f'new_aug_20K_blur_{blur_ratio}_{uncertainty_head}_thresh{max_thresh_reid}'
weights_path = f"results/weights/{model_name}.pth"
metrics_path = f'results/test_metrics/{model_name}.json'
history_path = f'results/train_history/{model_name}.png'
grid_path = f'results/test_grid/{model_name}.png'

# Create the output folders on first use. Without them, saving fails only after the
# full (expensive) training run has finished.
for artifact_path in (weights_path, metrics_path, history_path, grid_path):
    os.makedirs(os.path.dirname(artifact_path), exist_ok=True)

# This pickles the whole model object rather than a state_dict. It is kept as-is so
# existing loading code keeps working.
torch.save(model, weights_path)
save_json(test_metrics, metrics_path)
fig.savefig(history_path)
image_grid.savefig(grid_path)
# Release the figures' memory; they are no longer needed once written to disk.
plt.close(fig)
plt.close(image_grid)