In [2]:
import numpy as np

def initialize_population(pop_size, weight_shape):
    return np.random.uniform(-1, 1, (pop_size, *weight_shape))

def evaluate_fitness(wolves, validation_data, model, true_labels):
    fitness = []
    for wolf in wolves:
        model.set_weights(wolf)
        predictions = model.predict(validation_data)
        error = np.mean(np.abs(predictions - true_labels))
        fitness.append(error)
    return np.array(fitness)

def identify_best_wolves(fitness, wolves):
    sorted_indices = np.argsort(fitness)
    alpha_wolf = wolves[sorted_indices[0]]
    beta_wolf = wolves[sorted_indices[1]]
    delta_wolf = wolves[sorted_indices[2]]
    return alpha_wolf, beta_wolf, delta_wolf

def update_positions(wolves, alpha_wolf, beta_wolf, delta_wolf, a, weight_shape):
    new_wolves = []
    for wolf in wolves:
        A1 = 2 * np.random.rand(*weight_shape) - 1
        C1 = 2 * np.random.rand(*weight_shape)
        D_alpha = np.abs(C1 * alpha_wolf - wolf)
        X1 = alpha_wolf - A1 * D_alpha

        A2 = 2 * np.random.rand(*weight_shape) - 1
        C2 = 2 * np.random.rand(*weight_shape)
        D_beta = np.abs(C2 * beta_wolf - wolf)
        X2 = beta_wolf - A2 * D_beta

        A3 = 2 * np.random.rand(*weight_shape) - 1
        C3 = 2 * np.random.rand(*weight_shape)
        D_delta = np.abs(C3 * delta_wolf - wolf)
        X3 = delta_wolf - A3 * D_delta

        new_wolf = (X1 + X2 + X3) / 3
        new_wolves.append(new_wolf)

    return np.array(new_wolves)

def handle_boundaries(wolves, weight_range=(-1, 1)):
    wolves = np.clip(wolves, weight_range[0], weight_range[1])
    return wolves

def gwo_neural_network_training(model, pop_size, weight_shape, max_iter, validation_data, true_labels):
    wolves = initialize_population(pop_size, weight_shape)
    fitness = evaluate_fitness(wolves, validation_data, model, true_labels)
    alpha_wolf, beta_wolf, delta_wolf = identify_best_wolves(fitness, wolves)

    iteration = 0
    while iteration < max_iter:
        a = 2 - iteration * (2 / max_iter)

        wolves = update_positions(wolves, alpha_wolf, beta_wolf, delta_wolf, a, weight_shape)
        wolves = handle_boundaries(wolves)

        fitness = evaluate_fitness(wolves, validation_data, model, true_labels)
        alpha_wolf, beta_wolf, delta_wolf = identify_best_wolves(fitness, wolves)

        if iteration % 100 == 0 or iteration == max_iter - 1:
            print(f"Iteration {iteration}, Best fitness (error): {fitness.min():.6f}")

        iteration += 1

    print(f"Training finished. Best fitness (error): {fitness.min():.6f}")
    return alpha_wolf

class DummyModel:
    def __init__(self, weight_shape):
        self._weights = np.random.rand(*weight_shape)

    def set_weights(self, weights):
        self._weights = weights

    def predict(self, data):
        return np.dot(data, self._weights[:data.shape[1]])

if __name__ == "__main__":
    X_val = np.random.rand(10, 10)
    y_val = np.random.rand(10)

    model = DummyModel(weight_shape=(10, 10))

    best_weights = gwo_neural_network_training(
        model,
        pop_size=30,
        weight_shape=(10, 10),
        max_iter=1000,
        validation_data=X_val,
        true_labels=y_val
    )

    print("Best weights found:")
    print(best_weights)

Iteration 0, Best fitness (error): 0.417305
Iteration 100, Best fitness (error): 0.083484
Iteration 200, Best fitness (error): 0.078399
Iteration 300, Best fitness (error): 0.078788
Iteration 400, Best fitness (error): 0.079485
Iteration 500, Best fitness (error): 0.080415
Iteration 600, Best fitness (error): 0.075132
Iteration 700, Best fitness (error): 0.081623
Iteration 800, Best fitness (error): 0.074713
Iteration 900, Best fitness (error): 0.083141
Iteration 999, Best fitness (error): 0.078595
Training finished. Best fitness (error): 0.078595
Best weights found:
[[-1.57683995e-05 -1.01772613e-04  4.42427759e-10 -1.64546189e-07
   1.34784890e-07  1.58945676e-06  2.37136844e-05  1.39270896e-01
  -2.49529174e-05  1.32984491e-05]
 [ 3.37327759e-05  1.38706597e-03  6.03539763e-02  1.11630838e-06
  -4.67697023e-06  7.24433673e-05  3.56604326e-06  2.13160190e-01
  -1.62859066e-05  1.40508969e-05]
 [ 6.34926060e-01 -3.09951084e-07  7.96765302e-07  1.18006491e-07
   2.63247163e-03  1.32190