In [None]:
# Hyperparameter search routine for repeated training with altered CONSTANTS
# (Assumes previous code with train_model, validate, data loaders, etc. is available)

# --- Best-so-far bookkeeping, updated by the search loop below ---
best_model_overall_path = 'best_model_overall.pth'  # file the winning weights are saved to
best_hyperparams = {}                               # winning PATIENCE / EPOCHS_* values
best_val_acc_overall = 0.0                          # a configuration must strictly exceed this to be saved

# --- Search grid: the loop below trains every combination of these lists ---
epochs_fine_tune_list = [50, 60]   # candidate epoch counts for the fine-tuning phase
epochs_initial_list = [50, 60]     # candidate epoch counts for the initial (head-only) phase
patience_list = [10, 7]            # candidate PATIENCE values (presumably early-stopping patience in train_model)

# Loop over all combinations of hyperparameters.
# For each combination: build a fresh pretrained ResNet-152 with a new head,
# train the head only, fine-tune layer3/layer4/fc, validate, and keep the
# weights of the best-scoring configuration in best_model_overall_path.
import itertools


def _make_optimizer_and_scheduler(net, num_epochs):
    """Build AdamW over net's currently trainable parameters plus a cosine LR schedule.

    Args:
        net: Model whose parameters with requires_grad=True will be optimized.
        num_epochs: Passed as T_max to CosineAnnealingLR.

    Returns:
        Tuple of (optimizer, scheduler).
    """
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    opt = optim.AdamW(trainable_params, lr=LR, weight_decay=WEIGHT_DECAY)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs, eta_min=1e-6)
    return opt, sched


# The loss does not depend on the hyperparameters being searched, so build it once.
if ACTIVATE_WEIGHTS_TENSOR:
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
else:
    criterion = nn.CrossEntropyLoss()

# itertools.product yields the combinations in the same order as three nested
# for-loops (patience outermost, fine-tune epochs innermost).
for patience, init_epochs, fine_tune_epochs in itertools.product(
        patience_list, epochs_initial_list, epochs_fine_tune_list):
    # Rebind the module-level constants for this configuration.
    # NOTE(review): this only influences training if train_model reads
    # PATIENCE / EPOCHS_* as globals at call time -- confirm against its definition.
    PATIENCE = patience
    EPOCHS_INITIAL = init_epochs
    EPOCHS_FINE_TUNE = fine_tune_epochs

    print(f"\nTraining with PATIENCE={PATIENCE}, EPOCHS_INITIAL={EPOCHS_INITIAL}, EPOCHS_FINE_TUNE={EPOCHS_FINE_TUNE}")

    # ------------------------------------------------------------------
    # 1) Build a fresh model for this hyperparameter configuration
    # ------------------------------------------------------------------
    # Alternative backbones that were tried:
    # model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    # model = models.convnext_large(weights=models.ConvNeXt_Large_Weights.DEFAULT)
    # NOTE(review): the head replacement and the 'layer3'/'layer4'/'fc' name
    # filters below assume a ResNet-style model with an `fc` attribute.
    model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
    # Read the head's input width from the backbone instead of hard-coding it,
    # so swapping in another ResNet variant does not silently mismatch.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 1024),
        nn.ReLU(),
        nn.Dropout(DROPOUT_PROB),
        nn.Linear(1024, len(le.classes_))
    )
    model = model.to(device)

    # Freeze all layers except fc for initial training
    for name, param in model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = False

    # ------------------------------------------------------------------
    # 2) Initial phase: optimize only the new head
    # ------------------------------------------------------------------
    optimizer, scheduler = _make_optimizer_and_scheduler(model, EPOCHS_INITIAL)
    history_initial = train_model(model, train_loader, val_loader, criterion, optimizer,
                                  scheduler, num_epochs=EPOCHS_INITIAL, phase='Initial')

    # Load best initial-phase model.
    # NOTE(review): assumes train_model rewrites this file on every call; if it
    # does not, a checkpoint from an earlier configuration would be loaded here.
    # map_location keeps the load working if the file was saved from another device.
    model.load_state_dict(torch.load('best_model_initial.pth', map_location=device))

    # ------------------------------------------------------------------
    # 3) Fine-tuning: only layer3, layer4, and fc are trainable
    # ------------------------------------------------------------------
    for name, param in model.named_parameters():
        param.requires_grad = any(layer in name for layer in ('layer3', 'layer4', 'fc'))

    # The set of trainable parameters changed, so the optimizer and scheduler are rebuilt.
    optimizer, scheduler = _make_optimizer_and_scheduler(model, EPOCHS_FINE_TUNE)
    history_fine = train_model(model, train_loader, val_loader, criterion, optimizer,
                               scheduler, num_epochs=EPOCHS_FINE_TUNE, phase='Fine-tune')

    # Load best fine-tuned model (same stale-file caveat as the initial phase).
    model.load_state_dict(torch.load('best_model_fine-tune.pth', map_location=device))

    # ------------------------------------------------------------------
    # 4) Validate final model and update the best model if applicable
    # ------------------------------------------------------------------
    val_loss, val_acc = validate(model, val_loader, criterion)
    print(f"Hyperparams: PATIENCE={PATIENCE}, EPOCHS_INITIAL={EPOCHS_INITIAL}, EPOCHS_FINE_TUNE={EPOCHS_FINE_TUNE}")
    print(f"Final Validation Accuracy: {val_acc:.4f}")

    # Strict comparison: on a tie the earlier configuration is kept.
    if val_acc > best_val_acc_overall:
        best_val_acc_overall = val_acc
        best_hyperparams = {'PATIENCE': PATIENCE, 'EPOCHS_INITIAL': EPOCHS_INITIAL, 'EPOCHS_FINE_TUNE': EPOCHS_FINE_TUNE}
        torch.save(model.state_dict(), best_model_overall_path)
        print("==> New best model saved!")
                
print("\nHyperparameter search completed!")
print(f"Best Hyperparameters: {best_hyperparams}")
print(f"Best Overall Validation Accuracy: {best_val_acc_overall:.4f}")

# ------------------------------------------------------------------
# 5) Load the best model for inference and print overall accuracy
# ------------------------------------------------------------------
# best_hyperparams is only filled in when the search wrote a checkpoint, so an
# empty dict means there is nothing trustworthy to load: the file is either
# missing or left over from an earlier run.
if not best_hyperparams:
    raise RuntimeError(
        "No configuration improved on the initial validation accuracy; "
        f"'{best_model_overall_path}' was not written by this search."
    )

# Every configuration uses the same architecture, so the last model built by
# the loop can receive the winning weights. map_location keeps the load
# working if the checkpoint was saved from a different device.
model.load_state_dict(torch.load(best_model_overall_path, map_location=device))
print("Best model loaded for inference.")

# Optionally, run a final validation to print the accuracy again
final_val_loss, final_val_acc = validate(model, val_loader, criterion)
print(f"Final Model Validation Accuracy: {final_val_acc:.4f}")

# Make sure dropout is disabled for inference, whatever mode validate() left the model in.
model.eval()

# Now you can use the 'predict_grade' function with this best model for inference.
