Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use epochs instead of batch_num. Log current epoch number #95

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion openfl-workspace/keras_nlp/src/nlp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_valid_data_size(self):
"""
return self.X_valid[0].shape[0]

# TODO: the first param should be self; it should be added or renamed
@staticmethod
def _batch_generator(X1, X2, y, idxs, batch_size, num_batches):
"""
Generate batch of data.
Expand Down
3 changes: 1 addition & 2 deletions openfl-workspace/tf_cnn_histology/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,4 @@ tasks:
batch_size: 32
epochs: 1
metrics:
- loss
num_batches: 1
- loss
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ train:
function : train_batches
kwargs :
batch_size : 32
num_batches : 1
metrics :
- loss
epochs : 1
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/tasks_torch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ train:
kwargs :
metrics :
- loss
epochs : 1
12 changes: 7 additions & 5 deletions openfl/federated/task/runner_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
else:
self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)

def train(self, col_name, round_num, input_tensor_dict, metrics, num_batches=None, **kwargs):
def train(self, col_name, round_num, input_tensor_dict,
metrics, epochs=1, batch_size=1, **kwargs):
"""
Perform the training for a specified number of epochs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring should be updated since num_batches is not specified anymore.


Expand All @@ -77,10 +78,11 @@ def train(self, col_name, round_num, input_tensor_dict, metrics, num_batches=Non

# rebuild model with updated weights
self.rebuild_model(round_num, input_tensor_dict)

results = self.train_iteration(self.data_loader.get_train_loader(num_batches),
metrics=metrics,
**kwargs)
for epoch in range(epochs):
self.logger.info(f'Run {epoch} epoch of {round_num} round')
results = self.train_iteration(self.data_loader.get_train_loader(batch_size),
metrics=metrics,
**kwargs)

# output metric tensors (scalar)
origin = col_name
Expand Down
12 changes: 7 additions & 5 deletions openfl/federated/task/runner_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def validate(self, col_name, round_num, input_tensor_dict,
return output_tensor_dict, {}

def train_batches(self, col_name, round_num, input_tensor_dict,
num_batches=None, use_tqdm=False, **kwargs):
num_batches=None, use_tqdm=False, epochs=1, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could num_batches be removed here as well?

"""Train batches.

Train the model on the requested number of batches.
Expand All @@ -150,10 +150,12 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
# set to "training" mode
self.train()
self.to(self.device)
loader = self.data_loader.get_train_loader(num_batches)
if use_tqdm:
loader = tqdm.tqdm(loader, desc='train epoch')
metric = self.train_epoch(loader)
for epoch in range(epochs):
self.logger.info(f'Run {epoch} epoch of {round_num} round')
loader = self.data_loader.get_train_loader(num_batches)
if use_tqdm:
loader = tqdm.tqdm(loader, desc='train epoch')
metric = self.train_epoch(loader)
# Output metric tensors (scalar)
origin = col_name
tags = ('trained',)
Expand Down
13 changes: 4 additions & 9 deletions openfl/federated/task/runner_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)

def train_batches(self, col_name, round_num, input_tensor_dict,
num_batches, use_tqdm=False, **kwargs):
epochs=1, use_tqdm=False, **kwargs):
"""
Perform the training for a specified number of epochs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring should be updated since num_batches is not specified anymore.


Expand All @@ -107,22 +107,17 @@ def train_batches(self, col_name, round_num, input_tensor_dict,
self.rebuild_model(round_num, input_tensor_dict)

tf.keras.backend.set_learning_phase(True)

losses = []
batch_num = 0

while batch_num < num_batches:
for epoch in range(epochs):
self.logger.info(f'Run {epoch} epoch of {round_num} round')
# get iterator for batch draws (shuffling happens here)
gen = self.data_loader.get_train_loader(batch_size)
if use_tqdm:
gen = tqdm.tqdm(gen, desc='training epoch')

for (X, y) in gen:
if batch_num >= num_batches:
break
else:
losses.append(self.train_batch(X, y))
batch_num += 1
losses.append(self.train_batch(X, y))

# Output metric tensors (scalar)
origin = col_name
Expand Down