Split the dataset into training and test sets for a fair comparison between the models.

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
import random
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from torchvision import transforms, datasets

IAM+RIMES split

In [None]:
# -----------------------------
# Path to the merged dataset
# -----------------------------
data_dir = '../../IAM+RIMES'  # IAM + RIMES already merged into one folder tree

# Image transforms: resize to 128x128, convert to tensor, then Normalize
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Parameters: upper bounds on how many writers end up in each group.
max_authorized = 200
max_unauthorized = 50

# Load the dataset (ImageFolder: one sub-folder per class, i.e. per writer).
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Group sample indices by writer label.
# The labels are read from `dataset.targets` rather than by iterating the
# dataset itself: indexing the dataset loads and transforms every image,
# which is wasted work when only the label is needed.
writer_to_indices = defaultdict(list)
for idx, label in enumerate(dataset.targets):
    writer_to_indices[label].append(idx)

# -----------------------------
# Exclude writer "000" (IAM)
# -----------------------------
# `class_to_idx` maps {folder_name: class_index}; use it to find the label
# index that belongs to folder "000", if that folder exists.
writer_name_to_idx = dataset.class_to_idx
if "000" in writer_name_to_idx:
    idx_000 = writer_name_to_idx["000"]
    # pop() with a default is a no-op when the label has no grouped samples.
    writer_to_indices.pop(idx_000, None)
    print("Writer '000' escluso dal dataset")

# Keep only writers with at least 2 samples.
filtered_writer_to_indices = {}
for w, idxs in writer_to_indices.items():
    if len(idxs) >= 2:
        filtered_writer_to_indices[w] = idxs

# Rank writers by sample count, largest first. Sorting ascending on the
# negated count keeps ties in their original order, exactly like a stable
# sort with reverse=True.
sorted_writers = sorted(
    filtered_writer_to_indices.items(),
    key=lambda entry: -len(entry[1]),
)

# Tell whether a writer belongs to IAM or RIMES
def writer_source(writer_id):
    """Return "IAM" if ``str(writer_id)`` is all digits, otherwise "RIMES".

    The rule assumes IAM writers have purely numeric ids and RIMES writers
    have non-numeric ids; adapt it if the naming scheme differs.
    Note that any non-negative integer argument is classified as "IAM",
    because its string form consists only of digits.
    """
    if str(writer_id).isdigit():
        return "IAM"
    return "RIMES"

# `sorted_writers` is keyed by the ImageFolder class index (an int), and
# str(<int>).isdigit() is always True. Passing the index straight to
# writer_source() would therefore tag every writer as "IAM" and leave the
# RIMES list empty. Resolve each index back to its folder name first:
# `dataset.classes[i]` is the folder name of class index i.
idx_to_writer_name = dataset.classes

iam_writers = [
    w for w, _ in sorted_writers
    if writer_source(idx_to_writer_name[w]) == "IAM"
]
rimes_writers = [
    w for w, _ in sorted_writers
    if writer_source(idx_to_writer_name[w]) == "RIMES"
]

# Balanced IAM + RIMES selection: take the top writers (by sample count) of
# each source. If a source has fewer writers than its quota, the authorized
# set is simply smaller than max_authorized.
n_iam = max_authorized // 2
n_rimes = max_authorized - n_iam
authorized_writers = set(iam_writers[:n_iam] + rimes_writers[:n_rimes])

# Every remaining writer is unauthorized; cap the group at max_unauthorized
# by random sampling.
unauthorized_writers = set(w for w, _ in sorted_writers if w not in authorized_writers)
if len(unauthorized_writers) > max_unauthorized:
    unauthorized_writers = set(random.sample(list(unauthorized_writers), max_unauthorized))

# -----------------------------
# Build train/test indices
# -----------------------------
train_indices = []
test_auth_samples = []
test_unauth_samples = []

# Authorized writers: one randomly chosen sample is held out for the test
# set, all the others go to training.
for writer in authorized_writers:
    indices = filtered_writer_to_indices[writer]
    random.shuffle(indices)  # in place: the stored per-writer list is reordered too
    *kept_for_train, held_out = indices
    train_indices.extend(kept_for_train)
    test_auth_samples.append(held_out)

# Unauthorized writers: a single random sample each, test set only.
for writer in unauthorized_writers:
    test_unauth_samples.append(random.choice(filtered_writer_to_indices[writer]))

# Merge both test groups and mix them.
test_indices = [*test_auth_samples, *test_unauth_samples]
random.shuffle(test_indices)

print(f"Train set: {len(train_indices)} campioni (solo autorizzati)")
print(f"Test set: {len(test_indices)} campioni (autorizzati + non)")

# -----------------------------
# Statistics
# -----------------------------
# Look labels up in `dataset.targets` instead of `dataset[idx][1]`:
# indexing the dataset loads and transforms the image, which is unnecessary
# when only the writer label is wanted.
all_labels = dataset.targets
train_writer_ids = [all_labels[idx] for idx in train_indices]
test_writer_ids = [all_labels[idx] for idx in test_indices]

# Split the test labels by the group their writer belongs to.
test_auth_ids = [i for i in test_writer_ids if i in authorized_writers]
test_unauth_ids = [i for i in test_writer_ids if i in unauthorized_writers]

# Per-writer sample counts for each subset.
train_count = Counter(train_writer_ids)
test_auth_count = Counter(test_auth_ids)
test_unauth_count = Counter(test_unauth_ids)

print("Statistiche principali")
print(f"- Utenti autorizzati: {len(authorized_writers)}")
print(f"  ↳ Campioni nel train set: {len(train_indices)}")
print(f"  ↳ Campioni autorizzati nel test set: {sum(test_auth_count.values())}")
print(f"- Utenti non autorizzati: {len(unauthorized_writers)}")
print(f"  ↳ Campioni non autorizzati nel test set: {sum(test_unauth_count.values())}")

# -----------------------------
# Save the split
# -----------------------------
# label_map: original class index -> contiguous label 0..N-1, built from the
# authorized writers only (in ascending order of their original index).
ordered_authorized = sorted(authorized_writers)
label_map = dict(zip(ordered_authorized, range(len(ordered_authorized))))

split = dict(
    train_indices=train_indices,
    test_indices=test_indices,
    label_map=label_map,
)

# Create the output folder if needed, then write the split to disk.
os.makedirs('splits', exist_ok=True)
torch.save(split, 'splits/merged_dataset_split_improved.pth')


Debug

In [None]:
import os
from PIL import Image
from collections import Counter, defaultdict

root = '../merged_dataset'

sizes = Counter()            # (width, height) -> number of sampled images with that size
by_writer_count = Counter()  # writer folder name -> number of image files
unknown_count = 0            # image files inside folders whose name contains "unknown"
sample_dim = defaultdict(list)  # writer folder name -> sizes of its sampled images
image_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

for writer in os.listdir(root):
    wdir = os.path.join(root, writer)
    if not os.path.isdir(wdir):
        continue
    files = [f for f in os.listdir(wdir) if f.lower().endswith(image_exts)]
    by_writer_count[writer] += len(files)
    if 'unknown' in writer.lower():
        unknown_count += len(files)
    # Sample only a few images per writer.
    for f in files[:3]:
        p = os.path.join(wdir, f)
        try:
            with Image.open(p) as im:
                dims = (im.width, im.height)
        except Exception as exc:
            # Best-effort scan: skip files that cannot be opened, but say so
            # instead of hiding them. Unlike a bare `except`, this does not
            # swallow KeyboardInterrupt / SystemExit.
            print(f"Immagine non leggibile, saltata: {p} ({exc})")
            continue
        sizes[dims] += 1
        sample_dim[writer].append(dims)

print('Top 10 dimensioni:', sizes.most_common(10))
print('Writer totali:', len(by_writer_count))
print('Immagini totali:', sum(by_writer_count.values()))
print('Immagini in classi *_unknown:', unknown_count)
# Number of distinct (width, height) formats whose longer side exceeds 1000 px.
big = sum(1 for wh in sizes if max(wh) > 1000)
print('Numero di formati “grandi” (max side > 1000):', big)


In [None]:
import os

data_dir = "../../IAM+RIMES"

# 1. Count the sub-folders (one per writer)
all_writers = []
for entry in os.listdir(data_dir):
    if os.path.isdir(os.path.join(data_dir, entry)):
        all_writers.append(entry)
print(f"Numero totale di writer (sottocartelle): {len(all_writers)}")

# 2. Count the regular files inside each writer folder
writer_image_counts = {
    writer: sum(
        os.path.isfile(os.path.join(data_dir, writer, name))
        for name in os.listdir(os.path.join(data_dir, writer))
    )
    for writer in all_writers
}

# 3. Overall statistics
writers_with_3plus = [w for w in writer_image_counts if writer_image_counts[w] >= 3]
print(f"Writer con almeno 3 immagini: {len(writers_with_3plus)}")
print(f"Writer con meno di 3 immagini: {len(all_writers) - len(writers_with_3plus)}")

# 4. (Optional) Show the first 20 writers with their file counts
print("\nEsempio primi 20 writer e numero di immagini:")
for w, count in list(writer_image_counts.items())[:20]:
    print(f"{w}: {count}")


In [None]:
import os

data_dir = "../../merged_dataset"

# One sub-folder per writer.
writers = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

# Writers with at least 3 samples. The threshold now matches the variable
# name and the printed message (it was `>= 4`), and only regular files are
# counted so that stray sub-directories do not inflate the number.
min_samples = 3
writers_with_3_or_more = []
for w in writers:
    wdir = os.path.join(data_dir, w)
    n_files = sum(os.path.isfile(os.path.join(wdir, f)) for f in os.listdir(wdir))
    if n_files >= min_samples:
        writers_with_3_or_more.append(w)

print(f"Numero totale writer: {len(writers)}")
print(f"Numero writer con almeno 3 campioni: {len(writers_with_3_or_more)}")
