Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loss monitoring to training #27

Merged
merged 4 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ scipy>=1.2.1
wrapt>=1.11.1
h5py>=2.9
cloudpickle>=0.8.1
tensorboardX>=2.5

102 changes: 91 additions & 11 deletions tensorlayerx/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-

from collections.abc import Iterable
from tensorboardX import SummaryWriter
from tensorlayerx.nn.core.common import _save_weights, _load_weights, \
_save_standard_weights_dict, _load_standard_weights_dict
from .utils import WithLoss, WithGradPD, WithGradMS, WithGradTF, TrainOneStepWithPD, \
Expand Down Expand Up @@ -88,33 +89,37 @@ def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, **kwargs
self.all_weights = network.all_weights
self.train_weights = self.network.trainable_weights

def train(self, n_epoch, train_dataset=None, test_dataset=False, print_train_batch=False, print_freq=5):
def train(self, n_epoch, train_dataset=None, test_dataset=False, print_train_batch=False, print_freq=5, loss_monitor=False):
if not isinstance(train_dataset, Iterable):
raise Exception("Expected type in (train_dataset, Iterable), but got {}.".format(type(train_dataset)))

if tlx.BACKEND == 'tensorflow':
self.tf_train(
n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset,
loss_monitor=loss_monitor
)
elif tlx.BACKEND == 'mindspore':
self.ms_train(
n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset,
loss_monitor=loss_monitor
)
elif tlx.BACKEND == 'paddle':
self.pd_train(
n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset,
loss_monitor=loss_monitor
)
elif tlx.BACKEND == 'torch':
self.th_train(
n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset,
loss_monitor=loss_monitor
)

def eval(self, test_dataset):
Expand Down Expand Up @@ -263,8 +268,12 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):

def tf_train(
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
print_freq, test_dataset
print_freq, test_dataset, loss_monitor
):
if loss_monitor:
writer = SummaryWriter('loss_file/monitor')
train_batch, test_batch = 0, 0

for epoch in range(n_epoch):
start_time = time.time()

Expand All @@ -290,11 +299,17 @@ def tf_train(
train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
n_iter += 1

if loss_monitor:
train_batch += 1
writer.add_scalar('train_loss', tlx.ops.convert_to_numpy(train_loss / n_iter), train_batch)
writer.add_scalar('train_acc', train_acc / n_iter, train_batch)

if print_train_batch:
print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
print(" train loss: {}".format(train_loss / n_iter))
print(" train acc: {}".format(train_acc / n_iter))


if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
print(" train loss: {}".format(train_loss / n_iter))
Expand All @@ -315,16 +330,31 @@ def tf_train(
else:
val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
n_iter += 1

if loss_monitor:
test_batch += 1
writer.add_scalar('val_loss', tlx.ops.convert_to_numpy(val_loss / n_iter), test_batch)
writer.add_scalar('val_acc', val_acc / n_iter, test_batch)

print(" val loss: {}".format(val_loss / n_iter))
print(" val acc: {}".format(val_acc / n_iter))

if loss_monitor:
writer.export_scalars_to_json("./all_scalars.json")
writer.close()

def ms_train(
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
print_freq, test_dataset
print_freq, test_dataset, loss_monitor
):
net_with_criterion = WithLoss(network, loss_fn)
train_network = GradWrap(net_with_criterion, network.trainable_weights)
train_network.set_train()

if loss_monitor:
writer = SummaryWriter('loss_file/monitor')
train_batch, test_batch = 0, 0

for epoch in range(n_epoch):
start_time = time.time()
train_loss, train_acc, n_iter = 0, 0, 0
Expand All @@ -343,6 +373,11 @@ def ms_train(
train_acc += np.mean((P.Equal()(P.Argmax(axis=1)(output), y_batch).asnumpy()))
n_iter += 1

if loss_monitor:
train_batch += 1
writer.add_scalar('train_loss', train_loss / n_iter, train_batch)
writer.add_scalar('train_acc', train_acc / n_iter, train_batch)

if print_train_batch:
print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
print(" train loss: {}".format(train_loss / n_iter))
Expand All @@ -368,13 +403,27 @@ def ms_train(
else:
val_acc += np.mean((P.Equal()(P.Argmax(axis=1)(_logits), y_batch).asnumpy()))
n_iter += 1

if loss_monitor:
test_batch += 1
writer.add_scalar('val_loss', val_loss / n_iter, test_batch)
writer.add_scalar('val_acc', val_acc / n_iter, test_batch)

print(" val loss: {}".format(val_loss / n_iter))
print(" val acc: {}".format(val_acc / n_iter))

if loss_monitor:
writer.export_scalars_to_json("./all_scalars.json")
writer.close()

def pd_train(
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
print_freq, test_dataset
print_freq, test_dataset, loss_monitor
):
if loss_monitor:
writer = SummaryWriter('loss_file/monitor')
train_batch, test_batch = 0, 0

for epoch in range(n_epoch):
start_time = time.time()

Expand All @@ -397,6 +446,11 @@ def pd_train(
train_acc += pd.metric.accuracy(output, y_batch)
n_iter += 1

if loss_monitor:
train_batch += 1
writer.add_scalar('train_loss', train_loss / n_iter, train_batch)
writer.add_scalar('train_acc', train_acc / n_iter, train_batch)

if print_train_batch:
print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
print(" train loss: {}".format(train_loss / n_iter))
Expand All @@ -422,15 +476,27 @@ def pd_train(
else:
val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
n_iter += 1

if loss_monitor:
test_batch += 1
writer.add_scalar('val_loss', val_loss / n_iter, test_batch)
writer.add_scalar('val_acc', val_acc / n_iter, test_batch)

print(" val loss: {}".format(val_loss / n_iter))
print(" val acc: {}".format(val_acc / n_iter))

if loss_monitor:
writer.export_scalars_to_json("./all_scalars.json")
writer.close()

def th_train(
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
print_freq, test_dataset
print_freq, test_dataset, loss_monitor
):
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# network = network.to(device)
if loss_monitor:
writer = SummaryWriter('loss_file/monitor')
train_batch, test_batch = 0, 0

for epoch in range(n_epoch):
start_time = time.time()

Expand All @@ -451,6 +517,11 @@ def th_train(
train_acc += (output.argmax(1) == y_batch).type(torch.float).mean().item()
n_iter += 1

if loss_monitor:
train_batch += 1
writer.add_scalar('train_loss', tlx.ops.convert_to_numpy(train_loss / n_iter), train_batch)
writer.add_scalar('train_acc', train_acc / n_iter, train_batch)

if print_train_batch:
print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
print(" train loss: {}".format(train_loss / n_iter))
Expand All @@ -476,9 +547,18 @@ def th_train(
else:
val_acc += (_logits.argmax(1) == y_batch).type(torch.float).mean().item()
n_iter += 1

if loss_monitor:
test_batch += 1
writer.add_scalar('val_loss', tlx.ops.convert_to_numpy(val_loss / n_iter), test_batch)
writer.add_scalar('val_acc', val_acc / n_iter, test_batch)

print(" val loss: {}".format(val_loss / n_iter))
print(" val acc: {}".format(val_acc / n_iter))

if loss_monitor:
writer.export_scalars_to_json("./all_scalars.json")
writer.close()

class WithGrad(object):
"""Module that returns the gradients.
Expand Down