Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pytorch-trainer' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
htahir1 committed Jan 22, 2021
2 parents 74f2d09 + 2478fe7 commit 8415a08
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions zenml/core/steps/trainer/pytorch_trainers/torch_ff_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,19 @@ def __init__(self,
self.last_activation = last_activation
self.input_units = input_units
self.output_units = output_units
super(FeedForwardTrainer, self).__init__(**kwargs)
super(FeedForwardTrainer, self).__init__(
batch_size=self.batch_size,
lr=self.lr,
epoch=self.epoch,
dropout_chance=self.dropout_chance,
loss=self.loss,
metrics=self.metrics,
hidden_layers=self.hidden_layers,
hidden_activation=self.hidden_activation,
last_activation=self.last_activation,
input_units=self.input_units,
output_units=self.output_units,
**kwargs)

def input_fn(self,
file_pattern: List[Text],
Expand Down Expand Up @@ -154,4 +166,4 @@ def run_fn(self):
path_utils.copy_dir(temp_path, self.serving_model_dir)
path_utils.rm_dir(temp_path)
else:
torch.save(model, os.path.join(self.serving_model_dir, 'model.pt'))
torch.save(model, os.path.join(self.serving_model_dir, 'model.pt'))

0 comments on commit 8415a08

Please sign in to comment.