diff --git a/generate_features.py b/generate_features.py index cccdddd..910e71e 100644 --- a/generate_features.py +++ b/generate_features.py @@ -56,7 +56,8 @@ def data_generator(filenames, image_directory, batch_size=64): @click.option("--encoder", "-e", default="VGG19", required=False, type=click.STRING) @click.option("--layer-name", "-l", default="block5_conv4", required=False, type=click.STRING) @click.option("--output-folder", "-o", default=".", required=False, type=click.Path(exists=True, file_okay=False, dir_okay=True)) -def cmd(data_path, encoder, layer_name, output_folder): +@click.option("--batch-size", "-b", default=64, required=False, type=click.INT) +def cmd(data_path, encoder, layer_name, output_folder, batch_size): # create data directory if it does not exist os.makedirs(data_path, exist_ok=True) @@ -73,14 +74,14 @@ def cmd(data_path, encoder, layer_name, output_folder): with h5py.File(os.path.join(output_folder, "image.features.train.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5: index = 0 - for batch in encode_features(model, filenames_train, os.path.join(data_path, "train2017")): + for batch in encode_features(model, filenames_train, os.path.join(data_path, "train2017"), batch_size=batch_size): for item in batch: h5.create_dataset(str(index), data=item, compression="lzf") index += 1 with h5py.File(os.path.join(output_folder, "image.features.val.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5: index = 0 - for batch in encode_features(model, filenames_val, os.path.join(data_path, "val2017")): + for batch in encode_features(model, filenames_val, os.path.join(data_path, "val2017"), batch_size=batch_size): for item in batch: h5.create_dataset(str(index), data=item, compression="lzf") index += 1 diff --git a/utility/__init__.py b/utility/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utility/coco.py b/utility/coco.py index 42cecf9..13887bb 100644 --- a/utility/coco.py +++ b/utility/coco.py @@ -37,8 +37,8 @@ import json import os -import download -from cache import cache +import utility.download as download +from utility.cache import cache ######################################################################## diff --git a/utility/utility.py b/utility/utility.py index 84b3420..e1fa6ef 100644 --- a/utility/utility.py +++ b/utility/utility.py @@ -1,6 +1,5 @@ import numpy as np -import hickle -import coco +import utility.coco as coco import h5py def load_validation_data(maximum_caption_length):