# Model Generation

In [None]:
# Reload imported modules (e.g. edited pynas sources) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import configparser
import pandas as pd

import torch
import pytorch_lightning as pl

from pynas.core.population import Population
from datasets.RawClassifier.loader import RawClassifierDataModule

# Define dataset module.
# NOTE(review): absolute, machine-specific path — adjust to your environment.
root_dir = '/Data_large/marine/PythonProjects/OtherProjects/lpl-PyNas/data/RawClassifier'
dm = RawClassifierDataModule(root_dir, batch_size=4, num_workers=2, transform=None)

# Experiment configuration. ConfigParser.read() silently skips files it cannot
# open and returns the list of files it actually parsed, so check the result here
# instead of failing later with an opaque KeyError on config['GA'].
config = configparser.ConfigParser()
if not config.read('config.ini'):
    raise FileNotFoundError("config.ini not found in the current working directory")
def setting(cfg=None):
    """Apply notebook-wide display and reproducibility settings.

    Args:
        cfg: Parsed ``configparser.ConfigParser`` to read from. Defaults to the
            module-level ``config`` loaded from ``config.ini``.

    Returns:
        The GA logs directory, read from option ``logs_dir_GA`` of section ``[GA]``.

    Raises:
        KeyError: if section ``[GA]`` or option ``logs_dir_GA`` is missing.
        configparser.NoSectionError / configparser.NoOptionError: if
            ``[Computation] seed`` is missing.
        ValueError: if ``[Computation] seed`` is not an integer.
    """
    cfg = config if cfg is None else cfg
    # Show full cell contents when displaying DataFrames.
    pd.set_option('display.max_colwidth', None)
    # Logging
    logs_directory = str(cfg['GA']['logs_dir_GA'])
    # Torch stuff
    seed = cfg.getint(section='Computation', option='seed')
    pl.seed_everything(seed=seed, workers=True)  # For reproducibility
    torch.set_float32_matmul_precision("medium")  # to make lightning happy
    return logs_directory


# Assign the result so the notebook does not echo the returned path.
logs_directory = setting()

In [None]:
# Model parameters
max_layers = 3
max_iter = config.getint('GA', 'max_iterations')
# GA parameters, all read from the [GA] section of config.ini
n_individuals = config.getint('GA', 'population_size')
mating_pool_cutoff = config.getfloat('GA', 'mating_pool_cutoff')
mutation_probability = config.getfloat('GA', 'mutation_probability')

# Use the configured population size rather than a hard-coded value, so
# [GA] population_size actually takes effect.
pop = Population(n_individuals=n_individuals, max_layers=max_layers, dm=dm, max_parameters=400_000)

In [None]:
# Create the first generation. NOTE(review): presumably samples random architectures within
# max_layers / max_parameters given to Population above — confirm in Population.initial_poll.
pop.initial_poll()

In [None]:
# Train every individual of the current generation.
# NOTE(review): batch_size=32 here differs from the DataModule's batch_size=4 — confirm which
# value train_generation actually uses.
pop.train_generation(task='classification', lr=0.001, epochs=15, batch_size=32)

In [None]:
# Produce the next generation. Both GA rates come from the [GA] section of config.ini,
# rather than hard-coding mutation_probability here and ignoring the configured value.
pop.evolve(mating_pool_cutoff=mating_pool_cutoff, mutation_probability=mutation_probability, k_best=1, n_random=3)

### Load Dataframe Method

The `load_dataframe` method of the `Population` class retrieves the results and evaluation metrics
stored during the training and evolution of the models.
Calling `pop.load_dataframe(9)` is expected to load data saved during that process (e.g., performance metrics,
loss values, or architectural configurations).

This data can then be used for analysis, visualization, or further processing, giving insight
into the models' training dynamics and the overall evolutionary process. Make sure that the
index passed to `load_dataframe` (here, 9) corresponds to the set of results you intend to load.

In [None]:
# Load the stored results at index 9. NOTE(review): confirm what the index refers to
# (e.g. a generation number) and that it matches the run you want to inspect.
pop.load_dataframe(9)

# Inference

Run inference with the evaluated and saved model: we load the traced PyTorch model (`.pt`) and execute it on a dummy input.

In [None]:
# Load the saved TorchScript model and test it with a dummy input.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

save_path = "model_and_architecture.pt"
# map_location remaps the stored tensors onto `device`, so a model saved on
# GPU can still be loaded on a CPU-only machine.
loaded_model = torch.jit.load(save_path, map_location=device)
loaded_model.eval()

# Random input with a leading dimension of size 1 (presumably the batch
# dimension — assumes dm.input_shape excludes it; TODO confirm), moved once to
# the model's device.
example_input = torch.randn(1, *dm.input_shape).to(device)

with torch.no_grad():
    output = loaded_model(example_input)
print("Output from the loaded model:", output)