In [None]:
def add_new_author(new_author_data_path, model, num_classes, train_loader, val_loader, epochs=10):
    """
    Extend the classifier with the new author(s) found on disk and fine-tune it.

    The final ``model.fc`` layer is widened so that the existing class outputs
    keep their learned weights and new output units are appended after them.
    Samples loaded from ``new_author_data_path`` are labelled
    ``num_classes + <ImageFolder index>`` so they never collide with the
    labels of existing authors.

    Args:
        new_author_data_path (str): Root folder passed to ``ImageFolder``; each
            sub-folder is treated as one new author.
        model (torch.nn.Module): Pre-trained model exposing a linear ``fc``
            head. It is modified in place.
        num_classes (int): Current number of classes in the model.
        train_loader (DataLoader | None): Batches of existing training data,
            replayed every epoch so the old authors are not forgotten.
            Skipped when ``None``.
        val_loader (DataLoader | None): Validation batches; accuracy is
            printed after every epoch. Skipped when ``None``.
        epochs (int): Number of fine-tuning epochs.

    Returns:
        torch.nn.Module: The same ``model`` object, fine-tuned.
    """
    # Load new author's data
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    new_author_dataset = ImageFolder(root=new_author_data_path, transform=transform)
    new_author_loader = DataLoader(new_author_dataset, batch_size=32, shuffle=True)

    # One new output unit per sub-folder found by ImageFolder.
    num_new_classes = len(new_author_dataset.classes)

    # Widen the classification head while keeping what was already learned:
    # a brand-new nn.Linear would discard the weights of every existing class.
    old_head = model.fc
    has_bias = old_head.bias is not None
    new_head = nn.Linear(old_head.in_features, num_classes + num_new_classes, bias=has_bias)
    kept = min(num_classes, old_head.out_features)
    with torch.no_grad():
        new_head.weight[:kept].copy_(old_head.weight[:kept])
        if has_bias:
            new_head.bias[:kept].copy_(old_head.bias[:kept])
    model.fc = new_head
    model.to(device)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # (loader, label offset) pairs. ImageFolder numbers its classes from 0, so
    # new-author labels are shifted past the existing classes.
    # NOTE(review): assumes train_loader/val_loader yield (images, labels) with
    # labels already in [0, num_classes) and images preprocessed like above.
    sources = [(new_author_loader, num_classes)]
    if train_loader is not None:
        sources.append((train_loader, 0))

    # Fine-tuning
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        # Train on both new author and existing data
        for loader, label_offset in sources:
            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device) + label_offset

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_batches += 1

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/max(num_batches, 1):.4f}")

        if val_loader is not None:
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    predicted = model(images).argmax(dim=1)
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)
            if total:
                print(f"Epoch [{epoch+1}/{epochs}], Val Acc: {correct/total:.4f}")

    return model


In [None]:
# Path to new author's data
new_author_data_path = "/path/to/new/author/data"

# Add new author to the system
model = add_new_author(
    new_author_data_path=new_author_data_path,
    model=model,
    num_classes=num_classes,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10
)

# Keep the class count in sync with the enlarged classifier head; otherwise
# re-running this cell (or any later code using num_classes) would work from
# a stale value.
num_classes = model.fc.out_features

# Save the updated model
torch.save(model.state_dict(), 'handwriting_model_updated.pth')


In [None]:
# Collect and visualize correctly classified test images.
max_examples = 10  # the 2 x 5 subplot grid below holds at most 10 images

model.eval()
correct_images, correct_labels, predicted_labels = [], [], []

with torch.no_grad():
    for images, labels in test_load:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # Keep only the samples the model classified correctly.
        mask = (predicted == labels)
        correct_images.extend(images[mask].cpu())
        correct_labels.extend(labels[mask].cpu())
        predicted_labels.extend(predicted[mask].cpu())
        if len(correct_images) >= max_examples:  # enough images to display
            break

# The loader may run out before max_examples correct predictions are found,
# so only plot what was actually collected.
num_shown = min(max_examples, len(correct_images))

if num_shown == 0:
    print("No correct predictions found to display.")
else:
    plt.figure(figsize=(12, 6))
    for i in range(num_shown):
        plt.subplot(2, 5, i + 1)
        img = correct_images[i].permute(1, 2, 0)  # Convert from CxHxW to HxWxC
        # Denormalize; assumes the loader normalized with mean 0.5 / std 0.5
        # (TODO confirm). Clamp so float rounding stays inside imshow's [0, 1].
        img = (img * 0.5 + 0.5).clamp(0, 1)
        if img.shape[-1] == 1:
            # imshow rejects HxWx1 arrays; show single-channel images as 2-D.
            plt.imshow(img.squeeze(-1).numpy(), cmap='gray')
        else:
            plt.imshow(img.numpy())
        plt.title(f"Label: {correct_labels[i].item()}\nPred: {predicted_labels[i].item()}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()