Skip to content

Commit

Permalink
Multiple bugfixes:
Browse files Browse the repository at this point in the history
1. Concatenate prediction batches in cpu() to avoid exhausting GPU memory
2. loader.dataset.num_inputs and loader.dataset.num_targets do not exist in standard PyTorch datasets; temporarily replace them with 1
3. Add missing self._loss_weights in fit_loader
4. Extract (missing) losses return value from fit_loader
5. Move batches to GPU if required in predict() and predict_loader() methods
  • Loading branch information
recastrodiaz committed Jul 23, 2017
1 parent cc95c6c commit 19e8021
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions torchsample/modules/module_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,22 +315,22 @@ def fit_loader(self,
"""
self.model.train(mode=True)
# ----------------------------------------------------------------------
num_inputs = loader.dataset.num_inputs
num_targets = loader.dataset.num_targets
num_inputs = 1#loader.dataset.num_inputs
num_targets = 1#loader.dataset.num_targets
len_inputs = len(loader.dataset)
batch_size = loader.batch_size

if val_loader is not None:
num_val_inputs = val_loader.dataset.num_inputs
num_val_targets = val_loader.dataset.num_targets
num_val_inputs = 1#val_loader.dataset.num_inputs
num_val_targets = 1#val_loader.dataset.num_targets
if (num_inputs != num_val_inputs) or (num_targets != num_val_targets):
raise ValueError('num_inputs != num_val_inputs or num_targets != num_val_targets')
has_val_data = val_loader is not None
num_batches = int(math.ceil(len_inputs / batch_size))
# ----------------------------------------------------------------------

fit_helper = _get_helper(self, num_inputs, num_targets)
fit_loss_fn = fit_helper.get_partial_loss_fn(self._loss_fn)
fit_loss_fn = fit_helper.get_partial_loss_fn(self._loss_fn, self._loss_weights)
fit_forward_fn = fit_helper.get_partial_forward_fn(self.model)

with TQDM() as pbar:
Expand Down Expand Up @@ -373,7 +373,7 @@ def fit_loader(self,
# ---------------------------------------------
self._optimizer.zero_grad()
output_batch = fit_forward_fn(input_batch)
loss = fit_loss_fn(output_batch, target_batch)
loss, losses = fit_loss_fn(output_batch, target_batch)
loss.backward()
self._optimizer.step()
# ---------------------------------------------
Expand Down Expand Up @@ -420,8 +420,8 @@ def predict(self,
for batch_idx in range(num_batches):
input_batch, _ = predict_helper.grab_batch(batch_idx, batch_size, inputs, None, volatile=True)
if cuda_device >= 0:
inputs = predict_helper.move_to_cuda(cuda_device, inputs)
output_batch = pred_forward_fn(input_batch)
input_batch, _ = predict_helper.move_to_cuda(cuda_device, input_batch)
output_batch = predict_helper.move_to_cpu(pred_forward_fn(input_batch))

if batch_idx == 0:
len_outputs = 1 if not _is_tuple_or_list(output_batch) else len(output_batch)
Expand Down Expand Up @@ -455,7 +455,9 @@ def predict_loader(self,
loader_iter = iter(loader)
for batch_idx in range(num_batches):
input_batch, _ = predict_helper.grab_batch_from_loader(loader_iter, volatile=True)
output_batch = pred_forward_fn(input_batch)
if cuda_device >= 0:
input_batch, _ = predict_helper.move_to_cuda(cuda_device, input_batch)
output_batch = predict_helper.move_to_cpu(pred_forward_fn(input_batch))

if batch_idx == 0:
len_outputs = 1 if not _is_tuple_or_list(output_batch) else len(output_batch)
Expand Down Expand Up @@ -525,7 +527,7 @@ def evaluate_loader(self,
num_batches = int(math.ceil(len_inputs / batch_size))

evaluate_helper = _get_helper(self, num_inputs, num_targets)
eval_loss_fn = evaluate_helper.get_partial_loss_fn(self._loss_fn)
eval_loss_fn = evaluate_helper.get_partial_loss_fn(self._loss_fn, self._loss_weights)
eval_forward_fn = evaluate_helper.get_partial_forward_fn(self.model)
eval_logs= {'val_loss': 0.}
loader_iter = iter(loader)
Expand All @@ -543,7 +545,7 @@ def evaluate_loader(self,

self._optimizer.zero_grad()
output_batch = eval_forward_fn(input_batch)
loss = eval_loss_fn(output_batch, target_batch)
loss, losses = eval_loss_fn(output_batch, target_batch)

samples_seen += batch_size
eval_logs['val_loss'] = (samples_seen*eval_logs['val_loss'] + loss.data[0]*batch_size) / (samples_seen+batch_size)
Expand Down Expand Up @@ -643,6 +645,8 @@ def _get_helper(trainer, num_inputs, num_targets):


class SingleInput_SingleTarget_Helper(object):
def move_to_cpu(self, output):
    """Return a host (CPU) copy of the single model output tensor."""
    cpu_output = output.cpu()
    return cpu_output
def move_to_cuda(self, cuda_device, inputs, targets):
inputs = inputs.cuda(cuda_device)
targets = targets.cuda(cuda_device)
Expand Down Expand Up @@ -678,6 +682,8 @@ def get_partial_loss_fn(self, loss_fn, loss_weights):


class SingleInput_MultiTarget_Helper(object):
def move_to_cpu(self, outputs):
    """Return host (CPU) copies of the multiple model output tensors.

    Returns a list rather than a lazy generator expression: callers in
    predict()/predict_loader() take len() of and index the batch outputs,
    which a generator does not support (and a generator would be consumed
    after one pass). This also matches MultiInput_MultiTarget_Helper,
    which returns a list.
    """
    return [output.cpu() for output in outputs]
def move_to_cuda(self, cuda_device, inputs, targets):
inputs = inputs.cuda(cuda_device)
targets = [target_.cuda(cuda_device) for target_ in targets]
Expand Down Expand Up @@ -711,6 +717,8 @@ def get_partial_loss_fn(self, loss_fn, loss_weights):
return functools.partial(self.calculate_loss, loss_fn=loss_fn, loss_weights=loss_weights)

class MultiInput_SingleTarget_Helper(object):
def move_to_cpu(self, outputs):
    """Return a host (CPU) copy of the single model output tensor.

    Bug fix: the original body referenced an undefined name ``output``
    (the parameter is ``outputs``), raising NameError whenever this
    helper was used with cuda_device >= 0. A single-target helper
    receives one output tensor, so it is moved directly.
    """
    return outputs.cpu()
def move_to_cuda(self, cuda_device, inputs, targets):
inputs = [input_.cuda(cuda_device) for input_ in inputs]
targets = targets.cuda(cuda_device)
Expand Down Expand Up @@ -742,6 +750,8 @@ def get_partial_loss_fn(self, loss_fn, loss_weights):
return functools.partial(self.calculate_loss, loss_fn=loss_fn)

class MultiInput_MultiTarget_Helper(object):
def move_to_cpu(self, outputs):
    """Copy every model output tensor in the batch to host (CPU) memory."""
    moved = []
    for out in outputs:
        moved.append(out.cpu())
    return moved
def move_to_cuda(self, cuda_device, inputs, targets):
inputs = [input_.cuda(cuda_device) for input_ in inputs]
targets = [target_.cuda(cuda_device) for target_ in targets]
Expand Down Expand Up @@ -776,6 +786,11 @@ def get_partial_loss_fn(self, loss_fn, loss_weights):
return functools.partial(self.calculate_loss, loss_fn=loss_fn, loss_weights=loss_weights)

class SingleInput_NoTarget_Helper(object):
def move_to_cpu(self, outputs):
    """Move the model output(s) to CPU.

    Accepts either a single tensor or a tuple of tensors (a model may
    return multiple outputs even with no targets) and preserves the
    container type.
    """
    if not isinstance(outputs, tuple):
        return outputs.cpu()
    return tuple(out.cpu() for out in outputs)
def move_to_cuda(self, cuda_device, inputs, targets=None):
inputs = inputs.cuda(cuda_device)
return inputs, None
Expand All @@ -802,6 +817,11 @@ def get_partial_loss_fn(self, loss_fn, loss_weights):
return functools.partial(self.calculate_loss, loss_fn=loss_fn)

class MultiInput_NoTarget_Helper(object):
def move_to_cpu(self, outputs):
    """Move the model output(s) to CPU.

    Handles both a lone output tensor and a tuple of output tensors,
    returning the same container shape it was given.
    """
    if isinstance(outputs, tuple):
        return tuple(item.cpu() for item in outputs)
    return outputs.cpu()
def move_to_cuda(self, cuda_device, inputs, targets=None):
inputs = [input_.cuda(cuda_device) for input_ in inputs]
return inputs, None
Expand Down

0 comments on commit 19e8021

Please sign in to comment.