Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
sherstpasha committed Mar 9, 2024
1 parent 37ff028 commit dc7b788
Show file tree
Hide file tree
Showing 7 changed files with 1,984 additions and 1,983 deletions.
7 changes: 5 additions & 2 deletions src/thefittest/base/_gpnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def train_net_structure(

net = optimizer.get_fittest()["phenotype"].copy()

return net
return net, optimizer._stats


class BaseGPNN(BaseEstimator, metaclass=ABCMeta):
Expand Down Expand Up @@ -278,6 +278,9 @@ def array_like_to_numpy_X_y(
y = np.array(y, dtype=np.float64)
return X, y

def get_net(self) -> Net:
return self.net_

@staticmethod
def genotype_to_phenotype_tree(
tree: Tree, n_variables: int, n_outputs: int, output_activation: str, offset: bool
Expand Down Expand Up @@ -402,7 +405,7 @@ def fit(self, X: ArrayLike, y: ArrayLike):
task_type = "regression"
n_outputs = 1

self.net_ = train_net_structure(
self.net_, self.optimizer_stats_ = train_net_structure(
uniset=uniset,
X_train=X_train,
y_train=y_train,
Expand Down
4 changes: 2 additions & 2 deletions src/thefittest/base/_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def train_net_weights(
optimizer = weights_optimizer(**weights_optimizer_args)
optimizer.fit()

phenotype = optimizer.get_fittest()["phenotype"]
net = optimizer.get_fittest()["phenotype"]

return phenotype, optimizer._stats
return net, optimizer._stats


class BaseMLPEA(BaseEstimator, metaclass=ABCMeta):
Expand Down
Loading

0 comments on commit dc7b788

Please sign in to comment.