Skip to content

Commit

Permalink
Add label encoding in MiniImagenet
Browse files Browse the repository at this point in the history
  • Loading branch information
tristandeleu committed May 7, 2018
1 parent 19d1be7 commit 9c1523b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions datasets.py
Expand Up @@ -58,17 +58,25 @@ def __init__(self, root, train=False, valid=False, test=False,
next(reader) # Skip the header
for line in reader:
self._data.append(tuple(line))
self._fit_label_encoding()

def __getitem__(self, index):
filename, label = self._data[index]
image = pil_loader(os.path.join(self.image_folder, filename))
label = self._label_encoder[label]
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
label = self.target_transform(label)

return image, label

def _fit_label_encoding(self):
_, labels = zip(*self._data)
unique_labels = set(labels)
self._label_encoder = dict((label, idx)
for (idx, label) in enumerate(unique_labels))

def _check_exists(self):
return (os.path.exists(self.image_folder)
and os.path.exists(self.split_filename))
Expand Down
2 changes: 1 addition & 1 deletion prior_miniimagenet.py
Expand Up @@ -102,7 +102,7 @@ def main(args):
best_loss = -1.
for epoch in range(args.num_epochs):
train(train_loader, model, prior, optimizer, args, writer)
loss, _ = test(valid_loader, model, prior, args, writer)
loss = test(valid_loader, model, prior, args, writer)

# reconstruction = generate_samples(fixed_images, model, args)
# grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
Expand Down

0 comments on commit 9c1523b

Please sign in to comment.