In [8]:
import numpy as np
import pickle
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
import neat
from neat import Config, Population, DefaultGenome, DefaultReproduction, DefaultSpeciesSet, DefaultStagnation

from pureples.shared.substrate import Substrate
from pureples.shared.create_cppn import create_cppn
from pureples.es_hyperneat.es_hyperneat import ESNetwork

# === 1) Загрузка данных ===
X_train = np.load('../X_train.npy')
y_train = np.load('../y_train.npy')
X_test  = np.load('../X_test.npy')
y_test  = np.load('../y_test.npy')

# (опционально) сэмплирование
sample_train = 10000
X_train = X_train[:sample_train]
y_train = y_train[:sample_train]

print(f"Train: {X_train.shape}, Test: {X_test.shape}")

# === 2) Субстрат ===
i_coords = [(-1.0 + i*(2.0/3.0), 1.0) for i in range(4)]
hidden_coords = []
o_coords = [(0.0, -1.0)]
substrate = Substrate(i_coords, o_coords, hidden_coords)

# === 3) Параметры ES-HyperNEAT в виде dict ===
es_params = {
    "initial_depth":      1,
    "max_depth":          3,
    "variance_threshold": 0.03,
    "band_threshold":     0.2,
    "iteration_level":    1,
    "division_threshold": 0.3,
    "max_weight":         5.0,
    "activation":         "tanh",
}

# === 4) Функция оценки геномов ===
def eval_genomes(genomes, config):
    for gid, genome in genomes:
        # 4.1 создаём CPPN
        cppn = create_cppn(
            genome,
            config,
            output_activation_function=es_params["activation"]
        )
        # 4.2 строим фенотип-сеть
        es_net = ESNetwork(substrate, cppn, es_params)
        net    = es_net.create_phenotype_network()

        # 4.3 предсказания
        preds = [int(net.activate(x)[0] > 0.5) for x in X_train]

        # 4.4 фитнесс
        genome.fitness = f1_score(y_train, preds, average='weighted')

# === 5) Настройка NEAT и запуск ===
config_path = "es_hyper_neat.cfg"
config = Config(DefaultGenome, DefaultReproduction, DefaultSpeciesSet, DefaultStagnation, config_path)

pop = Population(config)
pop.add_reporter(neat.StdOutReporter(True))
pop.add_reporter(neat.StatisticsReporter())

winner = pop.run(eval_genomes, n=50)

# === 6) Сохранение ===
with open("es_hyper_model.pkl", "wb") as f:
    pickle.dump(winner, f)
print("🏆 Победитель сохранён в es_hyper_model.pkl")

# === 7) Оценка на тесте ===
cppn_w = create_cppn(winner, config, output_activation_function=es_params["activation"])
es_net_w = ESNetwork(substrate, cppn_w, es_params)
net_w    = es_net_w.create_phenotype_network()

y_pred = [int(net_w.activate(x)[0] > 0.5) for x in X_test]

print("\n— Test metrics —")
print("Accuracy       :", accuracy_score(y_test, y_pred))
print("F1-score       :", f1_score(y_test, y_pred, average='weighted'))
print("\nClassification report:\n", classification_report(y_test, y_pred))
print("Confusion matrix:\n", confusion_matrix(y_test, y_pred))


Train: (10000, 4), Test: (128000, 4)

 ****** Running generation 0 ****** 

Population's average fitness: 0.93272 stdev: 0.00000
Best fitness: 0.93272 - size: (1, 5) - species 1 - id 1
Average adjusted fitness: 0.000
Mean genetic distance 1.086, standard deviation 0.368
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1    0   150      0.9    0.000     0
Total extinctions: 0
Generation time: 1.680 sec

 ****** Running generation 1 ****** 

Population's average fitness: 0.93272 stdev: 0.00000
Best fitness: 0.93272 - size: (1, 5) - species 1 - id 1
Average adjusted fitness: 0.000
Mean genetic distance 1.347, standard deviation 0.423
Population of 150 members in 1 species:
   ID   age  size  fitness  adj fit  stag
     1    1   150      0.9    0.000     1
Total extinctions: 0
Generation time: 1.704 sec (1.692 average)

 ****** Running generation 2 ****** 

Population's average fitness: 0.93272 stdev: 0.00000
Best fitness: 0.93272 - size: (1, 5) - spec

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
