In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.data.data_handling import split_dataset, make_binary, plot_label_distribution
from src.data.data_loader import load_soccernet,load_diwan_test,load_diwan_train,load_processed,load_ca12,load_reid,load_full,BinaryDataset
from src.models.filtre import ResNetBinaryClassifier,grid_filtre
from src.models.train_test import plot_history
from src.utils.helpers import save_json, count_total_parameters

In [None]:
# Seed the RNGs before the model is built (next cell) so that weight init,
# split_dataset / make_binary(balance=True) and loader shuffling are repeatable
# under Restart & Run All. torch.manual_seed seeds every device, CUDA included.
# TODO(review): if src.data helpers use numpy's global RNG, seed it here as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Classifier size hyper-parameter, forwarded to ResNetBinaryClassifier
# (presumably the base channel width — confirm in src.models.filtre).
# It is also embedded in the saved model name in the export cell.
basic_params = 16

model = ResNetBinaryClassifier(
    basic_params=basic_params,
)

# Bare last expression: Jupyter displays the parameter count as the cell output.
count_total_parameters(model)

In [None]:
# Thresholds forwarded to load_ca12 / load_reid. Their exact meaning is defined in
# src.data.data_loader (presumably score cut-offs for keeping samples — confirm there).
# thresh_ca12 is also embedded in the saved model name in the export cell.
thresh_ca12 = 0.6
max_thresh_reid = 0.6
min_thresh_reid = 0.02

# Dataset locations, relative to the notebook's working directory.
path1 = "data/diwan/test"
path2_1 = "data/diwan/train/seif_train_gt.json"
path2_2 = "data/diwan/train/skander_train_gt.json"
path3 = "data/soccernet"
path4 = "data/ca12"
path5 = "data/soccernet_reid"
path6 = "data/full_dataset"
path7 = "data/processed_dataset"

# DIWAN sources (data1, data2) are later used only as the test pool.
data1 = load_diwan_test(path1)
print(len(data1),"images dans diwan test")
# DIWAN train annotations, read from two ground-truth JSON files.
data2 = load_diwan_train([path2_1,path2_2])
print(len(data2),"images dans diwan train")
data3 = load_soccernet(path3)
print(len(data3),"images dans soccernet")
data4 = load_ca12(path4,thresh=thresh_ca12)
print(len(data4),"images dans ca12")
data5 = load_reid(path5,max_thresh=max_thresh_reid,min_thresh=min_thresh_reid)
print(len(data5),"images dans reid_soccernet")
data6 = load_full(path6)
print(len(data6),"images dans full_dataset")
# load_processed receives data6, so this load must stay after load_full
# (what it does with it is defined in src.data.data_loader — confirm).
data7 = load_processed(path7,data = data6)
print(len(data7),"images dans processed")

In [None]:
# DIWAN (data1 test + data2 train annotations) is held out entirely for testing;
# every other source goes into the training pool.
test_data = data1 + data2
train_data = (
    data3
    + data4
    + data5
    + data6
    + data7
)

In [None]:
# Convert both pools to the binary task (balance=True is forwarded to make_binary).
# NOTE: this rebinds train_data / test_data, so re-running this cell alone applies
# make_binary a second time to already-converted data — re-run the previous cell first.
train_data = make_binary(train_data, balance=True)
test_data = make_binary(test_data, balance=True)

print(f"{len(train_data)} images dans train")
print(f"{len(test_data)} images dans test")

In [None]:
# Plot the label distribution of both pools after make_binary(balance=True),
# to check visually that the balancing did what was expected.
plot_label_distribution(train_data, title="Distribution du train")
plot_label_distribution(test_data, title="Distribution du test")

In [None]:
# Split the training pool into train / validation subsets; split_ratio is
# forwarded to split_dataset (presumably the train fraction — confirm there).
split_ratio = 0.8
train, valid = split_dataset(train_data, split_ratio=split_ratio)

for split_part, caption in ((train, "train samples"), (valid, "valid samples")):
    print(len(split_part), caption)

In [None]:
# Shared preprocessing options; the test dataset further down reuses
# image_size and cut, so train/valid/test stay consistent.
cut = "topbottom"
image_size = (224, 224)

dataset_options = dict(image_size=image_size, cut=cut)
train_dataset = BinaryDataset(train, **dataset_options)
valid_dataset = BinaryDataset(valid, **dataset_options)

In [None]:
# DataLoaders
batch_size = 64
workers = 2
# Pinned (page-locked) host memory only speeds up host→GPU copies. Without an
# accelerator, DataLoader warns and ignores pin_memory=True, so only enable it on CUDA.
use_pin_memory = device.type == "cuda"
train_loader = DataLoader(
    train_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pin_memory,
)
val_loader = DataLoader(
    valid_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=use_pin_memory,
)

In [None]:
# Train the classifier; the returned history is plotted by plot_history below.
num_epochs = 50
train_history = model.train_legib(
    num_epochs=num_epochs,
    train_loader=train_loader,
    valid_loader=val_loader,
    device=device,
)

In [None]:
# Plot the history returned by model.train_legib; `fig` is saved and then
# closed in the export cell further down.
fig = plot_history(train_history)

In [None]:
# Test set uses the same preprocessing (image_size, cut) as train/valid.
test_dataset = BinaryDataset(test_data, image_size=image_size,cut=cut)

# Shuffling is kept (presumably so the grid_filtre figure below shows a mix of
# samples rather than the head of test_data), but it is driven by a dedicated,
# seeded generator so the saved grid is reproducible from run to run.
test_shuffle_seed = 0
test_loader = DataLoader(
    test_dataset,
    num_workers=workers,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(test_shuffle_seed),
    # Pinning only helps host→GPU copies; DataLoader warns and ignores it without CUDA.
    pin_memory=device.type == "cuda",
)

In [None]:
# Evaluate on the held-out DIWAN test pool. `metrics` is written with save_json and
# `cm` with .savefig in the export cell, so cm is presumably a matplotlib Figure
# (confirm in src.models.filtre).
metrics,cm = model.test_legib(test_loader, device=device)

In [None]:
# Artifact name encodes the hyper-parameters that were varied across runs.
model_name = f'filtre_thresh_{thresh_ca12}_basic_params_{basic_params}'

# Create the output folders first: on a fresh checkout they may not exist, and a
# FileNotFoundError here would lose the model trained above.
for results_subdir in ("weights", "test_metrics", "train_history"):
    (Path("results") / results_subdir).mkdir(parents=True, exist_ok=True)

# NOTE(review): torch.save(model) pickles the whole module object, so loading it
# requires src.models.filtre to be importable. Saving model.state_dict() would be
# more portable; it is kept as-is so existing loaders of these .pth files still work.
torch.save(model, f"results/weights/{model_name}.pth")
save_json(metrics, f'results/test_metrics/{model_name}.json')
cm.savefig(f'results/test_metrics/cm_{model_name}.png')
fig.savefig(f'results/train_history/{model_name}.png')
plt.close(fig)

In [None]:
# Build a figure of test samples from the classifier (filtre="all" is
# forwarded to grid_filtre; its meaning is defined in src.models.filtre).
image_grid = grid_filtre(
    filtre="all",
    test_loader=test_loader,
    model=model,
)

In [None]:
# Make sure the destination exists before saving (it may be missing on a fresh checkout).
Path("results/test_grid").mkdir(parents=True, exist_ok=True)
image_grid.savefig(f'results/test_grid/{model_name}.png')
# Release the figure so it does not stay in kernel memory.
plt.close(image_grid)