Skip to content

Commit bf921cb

Browse files
author
xtt
committed
fix bugs of dataloader
1 parent 8064d2d commit bf921cb

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
3636

3737
self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
3838

39-
np.random.shuffle(self.list_sample)
39+
self.if_shuffled = False
4040
if max_sample > 0:
4141
self.list_sample = self.list_sample[0:max_sample]
4242
self.num_sample = len(self.list_sample)
@@ -69,6 +69,11 @@ def _get_sub_batch(self):
6969
return batch_records
7070

7171
def __getitem__(self, index):
72+
# NOTE: random shuffle for the first time. shuffle in __init__ is useless
73+
if not self.if_shuffled:
74+
np.random.shuffle(self.list_sample)
75+
self.if_shuffled = True
76+
7277
# get sub-batch candidates
7378
batch_records = self._get_sub_batch()
7479

train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def train(segmentation_module, iterator, optimizers, history, epoch, args):
5252
tic = time.time()
5353

5454
# update average loss and acc
55-
ave_total_loss.update(loss.item())
56-
ave_acc.update(acc.item()*100)
55+
ave_total_loss.update(loss.data[0])
56+
ave_acc.update(acc.data[0]*100)
5757

5858
# calculate accuracy, and display
5959
if i % args.disp_iter == 0:
@@ -67,8 +67,8 @@ def train(segmentation_module, iterator, optimizers, history, epoch, args):
6767

6868
fractional_epoch = epoch - 1 + 1. * i / args.epoch_iters
6969
history['train']['epoch'].append(fractional_epoch)
70-
history['train']['loss'].append(loss.item())
71-
history['train']['acc'].append(acc.item())
70+
history['train']['loss'].append(loss.data[0])
71+
history['train']['acc'].append(acc.data[0])
7272

7373
# adjust learning rate
7474
cur_iter = i + (epoch - 1) * args.epoch_iters
@@ -157,6 +157,8 @@ def main(args):
157157

158158
# create loader iterator
159159
iterator_train = iter(loader_train)
160+
from IPython import embed
161+
embed()
160162

161163
# load nets into gpu
162164
if args.num_gpus > 1:

0 commit comments

Comments
 (0)