In [None]:
def evaluate_model(model, loader, threshold=0.5):
    """Print binary-classification metrics for ``model`` on ``loader`` and plot a confusion matrix.

    The model is put in eval mode (and left there) and run under
    ``torch.no_grad()``. Each output becomes a hard 0/1 prediction via
    ``output > threshold``. Outputs are compared to the threshold as-is, which
    presumes the model already emits probabilities (e.g. ends in a sigmoid)
    -- TODO confirm against the model definition.

    Prints accuracy, F1 and sklearn's classification report, then shows a
    seaborn heatmap of the confusion matrix.

    Args:
        model: torch module, called as ``model(images)``.
        loader: iterable of ``(images, labels)`` tensor batches.
        threshold: decision threshold. Defaults to 0.5 (the value that used to
            be hard-coded), so existing calls behave exactly as before; pass
            the value returned by ``find_best_threshold`` to evaluate at the
            tuned cut-off.
            NOTE: the comparison here is strict (``>``), whereas sklearn's
            ``precision_recall_curve`` counts ``score >= threshold`` as
            positive, so a sample scoring exactly ``threshold`` is classified
            differently here than in that curve.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    pred_batches, label_batches = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            preds = (outputs > threshold).float()
            pred_batches.append(preds.cpu().numpy().ravel())
            # Labels are only consumed on the CPU by sklearn, so there is no
            # need to copy them to ``device`` and back.
            label_batches.append(labels.cpu().numpy().ravel())
    if not pred_batches:
        raise ValueError("evaluate_model: loader yielded no batches")
    y_pred = np.concatenate(pred_batches)
    y_true = np.concatenate(label_batches)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    print(f"Validation Accuracy: {acc:.4f}")
    print(f"Validation F1 Score: {f1:.4f}")
    print("Classification Report:\n", classification_report(y_true, y_pred))

    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
def find_best_threshold(model, loader):
    """Return the decision threshold that maximizes F1 on ``loader``.

    Runs ``model`` in eval mode (and leaves it there) under ``torch.no_grad()``,
    collects its raw outputs as scores, and sweeps every candidate threshold
    produced by sklearn's ``precision_recall_curve``. The outputs are used
    directly as scores, presumably probabilities from a sigmoid head -- TODO
    confirm; if the model emits logits, the returned threshold is in logit
    space.

    The chosen threshold is one of the observed scores; its reported F1
    corresponds to the rule ``score >= threshold`` (sklearn's convention).

    Args:
        model: torch module, called as ``model(images)``.
        loader: iterable of ``(images, labels)`` tensor batches.

    Returns:
        The best threshold (a NumPy scalar taken from the curve's thresholds).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            prob_batches.append(outputs.cpu().numpy().ravel())
            # Labels are only consumed on the CPU by sklearn; skip the
            # device round-trip.
            label_batches.append(labels.cpu().numpy().ravel())
    if not prob_batches:
        raise ValueError("find_best_threshold: loader yielded no batches")
    all_probs = np.concatenate(prob_batches)
    all_labels = np.concatenate(label_batches)

    precisions, recalls, thresholds = precision_recall_curve(all_labels, all_probs)
    # precision/recall have len(thresholds) + 1 entries: the final point
    # (precision=1, recall=0) has no associated threshold. Drop it so that
    # f1_scores[i] is the F1 obtained at thresholds[i]; otherwise argmax could
    # land on that extra index and overrun `thresholds`.
    precisions, recalls = precisions[:-1], recalls[:-1]
    # The epsilon avoids 0/0 where precision and recall are both zero.
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    best_index = int(np.argmax(f1_scores))
    best_threshold = thresholds[best_index]
    print(f"Best threshold by F1: {best_threshold:.4f}, F1 Score: {f1_scores[best_index]:.4f}")
    return best_threshold

In [None]:
# Tune the decision threshold on the validation split, then apply that same
# cut-off when producing test-set predictions.
optimal_threshold = find_best_threshold(model, val_loader)
test_predictions = run_test_predictions(
    model,
    threshold=optimal_threshold,
)

In [None]:
# One row per test prediction; the column order must match the tuples
# produced by run_test_predictions -- TODO confirm it yields (image_id, label).
submission_df = pd.DataFrame(
    data=test_predictions,
    columns=["image_id", "label"],
)