Skip to content

Commit

Permalink
Update image classif eval epoch impl for augments
Browse files Browse the repository at this point in the history
  • Loading branch information
plstcharles committed Oct 27, 2018
1 parent 39caf2f commit 02c13e0
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/thelper/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def _train_epoch(self, model, epoch, iter, dev, optimizer, loader, metrics, writ
input, label = self._to_tensor(sample)
optimizer.zero_grad()
label = self._upload_tensor(label, dev)
if isinstance(input, list):
if isinstance(input, list): # training samples got augmented, we need to backprop in multiple steps
if not input:
raise AssertionError("cannot train with empty post-augment sample lists")
if not self.warned_no_shuffling_augments:
Expand Down Expand Up @@ -906,9 +906,20 @@ def _eval_epoch(self, model, epoch, iter, dev, loader, metrics, writer=None):
epoch_size = len(loader)
for idx, sample in enumerate(loader):
input, label = self._to_tensor(sample)
input = self._upload_tensor(input, dev)
label = self._upload_tensor(label, dev)
pred = model(input)
if isinstance(input, list): # evaluation samples got augmented, we need to get the mean prediction
if not input:
raise AssertionError("cannot eval with empty post-augment sample lists")
preds = None
for input_idx in range(len(input)):
pred = model(self._upload_tensor(input[input_idx], dev))
if preds is None:
preds = torch.unsqueeze(pred.clone(), 0)
else:
preds = torch.cat((preds, torch.unsqueeze(pred, 0)), 0)
pred = torch.mean(preds, dim=0)
else:
pred = model(self._upload_tensor(input, dev))
if metrics:
meta = {key: sample[key] for key in self.meta_keys}
for metric in metrics.values():
Expand Down

0 comments on commit 02c13e0

Please sign in to comment.