In [None]:
import itertools
from ternary_weight_mlp import *
import torch
import json

# --- Configuration -----------------------------------------------------------
ALPHA = 0.5
EPOCHS = 10000
# Re-seeded before every run so each configuration starts from the same random
# state; makes results reproducible across kernel restarts and comparable
# between configurations.
SEED = 0
# A prediction counts as correct when it is below LOW_THRESHOLD for target 0
# and above HIGH_THRESHOLD for target 1.
LOW_THRESHOLD = 0.05
HIGH_THRESHOLD = 0.95

# Define hyperparameter ranges
learning_rates = [0.1, 0.01, 0.001, 0.0001]
penalties_magnitude = [0.1, 0.01, 0.001]
penalties_polarity = [0.1, 0.01, 0.001]
penalties_integers = [0.1, 0.01, 0.001]

# XOR truth table used for validation: (input pair, expected output).
XOR_CASES = [
    ([0.0, 0.0], 0.0),
    ([0.0, 1.0], 1.0),
    ([1.0, 0.0], 1.0),
    ([1.0, 1.0], 0.0),
]


def _json_fallback(obj):
    """``default=`` hook for ``json.dump``: turn tensor/numpy values into plain Python.

    The loss values are stored exactly as ``train_with_rectified_L2`` returns
    them. Their type is not visible here (TODO confirm). If they are tensors or
    numpy values, a plain ``json.dump`` would raise only after the whole grid
    has finished, discarding every result. Objects without ``tolist()`` still
    raise ``TypeError``, as ``json`` normally does.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Create a grid of all hyperparameter combinations
grid = list(itertools.product(learning_rates, penalties_magnitude, penalties_polarity, penalties_integers))

# Results list to collect all results
results = []

for lr, penalty_magnitude, penalty_polarity, penalty_integers in grid:
    torch.manual_seed(SEED)

    # dataset (the test split is unused: validation below uses XOR_CASES)
    X_train, y_train, X_test, y_test = create_torch_XOR_dataset()

    model_XOR = TwoLayerMLP()
    loss_function = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model_XOR.parameters(), lr=lr)

    try:
        train_loss, train_loss_reg = train_with_rectified_L2(model_XOR,
                                    loss_function,
                                    optimizer,
                                    X_train,
                                    y_train,
                                    no_of_epochs=EPOCHS,
                                    ALPHA=ALPHA,
                                    LAMBDA_MAGNITUDE=penalty_magnitude,
                                    LAMBDA_POLARITY=penalty_polarity,
                                    LAMBDA_INTEGERS=penalty_integers,
                                    initial_lr=lr,
                                    max_lr=lr)
    except Exception as e:
        # Best-effort sweep: report the failing configuration and move on.
        print(e)
        params = f"Training failed for parameters: lr:{lr}, penalty_magnitude:{penalty_magnitude}, penalty_polarity:{penalty_polarity}, penalty_integers:{penalty_integers}"
        print(params)
        continue  # Skip this iteration if an error occurs

    model_XOR.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = [model_XOR(torch.tensor([inputs])).item() for inputs, _ in XOR_CASES]

    accuracy_counter = 0
    for output, (_, target) in zip(outputs, XOR_CASES):
        if target == 1.0:
            accuracy_counter += output > HIGH_THRESHOLD
        else:
            accuracy_counter += output < LOW_THRESHOLD

    val_accuracy = accuracy_counter / len(XOR_CASES)

    # Append the results
    results.append({"lr": lr, "penalty_magnitude": penalty_magnitude, "penalty_polarity": penalty_polarity,
                    "penalty_integers": penalty_integers, "train_loss": train_loss, "train_loss_reg": train_loss_reg,
                    "val_accuracy": val_accuracy})

# Write the collected results to a JSON file
results_file = 'hyperparameter_search_results.json'
with open(results_file, 'w') as file:
    json.dump(results, file, indent=4, default=_json_fallback)

In [None]:
import json

# Path to the JSON file containing the results
# Path to the JSON file written by the grid-search cell
results_file = 'hyperparameter_search_results.json'


def _final_value(losses):
    """Return the last recorded loss value, or ``None`` for an empty history.

    The grid-search cell stores whatever ``train_with_rectified_L2`` returned.
    It is presumably a per-epoch list (verify against the helper), but a bare
    number is also accepted and returned unchanged instead of raising
    ``TypeError`` on ``[-1]``.
    """
    if isinstance(losses, list):
        return losses[-1] if losses else None
    return losses


# Load the results from the JSON file
with open(results_file, 'r') as file:
    results = json.load(file)

# Filter results to find only those with 100% validation accuracy
# (val_accuracy is a count divided by 4, so the exact comparison is safe).
optimal_results = [result for result in results if result['val_accuracy'] == 1.0]

# Check if there are any optimal results and print them
if optimal_results:
    print("Results with 100% validation accuracy:")
    for result in optimal_results:
        print(f"Learning Rate: {result['lr']}, "
              f"Penalty Magnitude: {result['penalty_magnitude']}, "
              f"Penalty Polarity: {result['penalty_polarity']}, "
              f"Penalty Integers: {result['penalty_integers']}, "
              f"Train Loss: {_final_value(result['train_loss'])}, "
              f"Regularized Train Loss: {_final_value(result['train_loss_reg'])}, "
              f"Validation Accuracy: {result['val_accuracy']}")
else:
    print("No results with 100% validation accuracy.")
