Skip to content

Commit

Permalink
explicitly close the pool, explicitly close the tarwriter
Browse files Browse the repository at this point in the history
should help stability for #30 and #34
  • Loading branch information
rom1504 committed Aug 25, 2021
1 parent 581c5fd commit fb0747f
Showing 1 changed file with 61 additions and 45 deletions.
106 changes: 61 additions & 45 deletions img2dataset/downloader.py
Expand Up @@ -82,46 +82,54 @@ def resize_with_border(im, desired_size=256):
value=color)
return new_im

def webdataset_sample_writer_builder(shard_id, output_folder):
    """Open the tar archive for one shard and return a writer bound to it.

    The archive path is ``<output_folder>/<shard_id padded to 5 digits>.tar``.
    The returned callable is ``webdataset_sample_writer`` with its
    ``tarwriter`` argument already filled in.
    """
    tar_path = f"{output_folder}/{'%05d' % shard_id}.tar"
    return functools.partial(
        webdataset_sample_writer,
        tarwriter=wds.TarWriter(tar_path),
    )

def webdataset_sample_writer(img_str, key, caption, meta, tarwriter):
    """Write one sample to ``tarwriter``.

    The sample always carries ``__key__`` (``key`` zero-padded to 9 digits)
    and ``jpg`` (``img_str`` as given). ``txt`` is added when ``caption`` is
    not None, and ``json`` (``meta`` dumped with a 4-space indent) when
    ``meta`` is not None.
    """
    record = {"__key__": "%09d" % key, "jpg": img_str}

    has_caption = caption is not None
    has_meta = meta is not None

    if has_caption:
        record["txt"] = caption
    if has_meta:
        record["json"] = json.dumps(meta, indent=4)

    tarwriter.write(record)

def files_sample_writer_builder(shard_id, output_folder):
    """Create the folder for one shard and return a writer bound to it.

    The shard folder is ``<output_folder>/<shard_id padded to 5 digits>``.
    The returned callable is ``files_sample_writer`` with its
    ``output_folder`` argument set to that shard folder.
    """
    shard_name = "%05d" % shard_id
    subfolder = f"{output_folder}/{shard_name}"
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it cannot
    # fail if the folder appears between the check and the creation, and it
    # also creates a missing parent output folder.
    os.makedirs(subfolder, exist_ok=True)
    return functools.partial(files_sample_writer, output_folder=subfolder)

def files_sample_writer(img_str, key, caption, meta, output_folder):
    """Write one sample as individual files inside ``output_folder``.

    Files are named after ``key`` zero-padded to 4 digits:

    * ``<key>.jpg``  -- ``img_str`` written as raw bytes (always).
    * ``<key>.txt``  -- ``caption``, only when it is not None.
    * ``<key>.json`` -- ``meta`` dumped with a 4-space indent, only when it
      is not None.
    """
    key = "%04d" % key
    filename = f'{output_folder}/{key}.jpg'
    with open(filename, "wb") as f:
        f.write(img_str)
    if caption is not None:
        caption_filename = f'{output_folder}/{key}.txt'
        # Explicit UTF-8: captions may contain any Unicode text, and the
        # platform default encoding is not guaranteed to represent it.
        with open(caption_filename, "w", encoding="utf-8") as f:
            f.write(caption)
    if meta is not None:
        j = json.dumps(meta, indent=4)
        meta_filename = f'{output_folder}/{key}.json'
        with open(meta_filename, "w", encoding="utf-8") as f:
            f.write(j)

def one_process_downloader(row, sample_writer_builder, resizer, thread_count, save_metadata, output_folder, column_list):
class WebDatasetSampleWriter:
    """Writes the samples of one shard into a single tar archive.

    The archive is ``<output_folder>/<shard_id padded to 5 digits>.tar`` and
    is opened on construction. Call ``close()`` when done, or use the
    instance as a context manager so the archive is closed even if writing
    raises.
    """

    def __init__(self, shard_id, output_folder):
        shard_name = "%05d" % shard_id
        self.tarwriter = wds.TarWriter(f"{output_folder}/{shard_name}.tar")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the archive; returning None lets any exception
        # raised inside the ``with`` body propagate.
        self.close()

    def write(self, img_str, key, caption, meta):
        """Append one sample to the archive.

        The sample always has ``__key__`` (``key`` zero-padded to 9 digits)
        and ``jpg``; ``txt`` and ``json`` are added only when ``caption`` /
        ``meta`` are not None.
        """
        key = "%09d" % key
        sample = {
            "__key__": key,
            "jpg": img_str,
        }
        if caption is not None:
            sample["txt"] = caption
        if meta is not None:
            sample["json"] = json.dumps(meta, indent=4)
        self.tarwriter.write(sample)

    def close(self):
        """Close the underlying tar writer."""
        self.tarwriter.close()

class FilesSampleWriter:
    """Writes the samples of one shard as individual files in a folder.

    The shard folder is ``<output_folder>/<shard_id padded to 5 digits>`` and
    is created on construction if needed. ``close()`` is a no-op; it and the
    context-manager methods exist so this class can be used interchangeably
    with a writer that holds an open resource.
    """

    def __init__(self, shard_id, output_folder):
        shard_name = "%05d" % shard_id
        self.subfolder = f"{output_folder}/{shard_name}"
        # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it
        # cannot fail if the folder appears between the check and the
        # creation, and it also creates a missing parent output folder.
        os.makedirs(self.subfolder, exist_ok=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception from the ``with`` body propagate.
        self.close()

    def write(self, img_str, key, caption, meta):
        """Write one sample into the shard folder.

        Files are named after ``key`` zero-padded to 4 digits: ``.jpg``
        always (raw bytes of ``img_str``), ``.txt`` when ``caption`` is not
        None, ``.json`` (4-space indent) when ``meta`` is not None.
        """
        key = "%04d" % key
        filename = f'{self.subfolder}/{key}.jpg'
        with open(filename, "wb") as f:
            f.write(img_str)
        if caption is not None:
            caption_filename = f'{self.subfolder}/{key}.txt'
            # Explicit UTF-8: captions may contain any Unicode text, and the
            # platform default encoding is not guaranteed to represent it.
            with open(caption_filename, "w", encoding="utf-8") as f:
                f.write(caption)
        if meta is not None:
            j = json.dumps(meta, indent=4)
            meta_filename = f'{self.subfolder}/{key}.json'
            with open(meta_filename, "w", encoding="utf-8") as f:
                f.write(j)

    def close(self):
        """Do nothing: every file is closed as soon as it is written."""
        pass



def one_process_downloader(row, sample_writer_class, resizer, thread_count, save_metadata, output_folder, column_list):
shard_id, shard_to_dl = row

if save_metadata:
Expand All @@ -135,7 +143,7 @@ def one_process_downloader(row, sample_writer_builder, resizer, thread_count, sa
caption_indice = column_list.index("caption") if "caption" in column_list else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]

sample_writer = sample_writer_builder(shard_id)
sample_writer = sample_writer_class(shard_id, output_folder)
with ThreadPool(thread_count) as thread_pool:
for key, img_stream, error_message in thread_pool.imap_unordered(download_image, key_url_list):
_, sample_data = shard_to_dl[key]
Expand Down Expand Up @@ -182,7 +190,12 @@ def one_process_downloader(row, sample_writer_builder, resizer, thread_count, sa
else:
meta=None

sample_writer(img, key, sample_data[caption_indice] if caption_indice is not None else None, meta)
sample_writer.write(img, key, sample_data[caption_indice] if caption_indice is not None else None, meta)

sample_writer.close()
thread_pool.terminate()
thread_pool.join()
del thread_pool

if save_metadata:
df = pd.DataFrame(metadatas)
Expand Down Expand Up @@ -248,12 +261,12 @@ def download_one_file(url_list):
del images_to_dl

if output_format == "webdataset":
sample_writer_builder = functools.partial(webdataset_sample_writer_builder, output_folder=output_folder)
sample_writer_class = WebDatasetSampleWriter
elif output_format == "files":
sample_writer_builder = functools.partial(files_sample_writer_builder, output_folder=output_folder)
sample_writer_class = FilesSampleWriter

resizer = functools.partial(resize_image, image_size=image_size, resize_mode=resize_mode, resize_only_if_bigger=resize_only_if_bigger)
downloader = functools.partial(one_process_downloader, sample_writer_builder=sample_writer_builder, resizer=resizer, \
downloader = functools.partial(one_process_downloader, sample_writer_class=sample_writer_class, resizer=resizer, \
thread_count=thread_count, save_metadata=save_metadata, output_folder=output_folder, column_list=column_list)

total_total = 0
Expand All @@ -272,6 +285,9 @@ def download_one_file(url_list):
message+=f"failed resize={1.0*total_failed_to_resize/total_total:.2f}"
print(message+"\n" , sep=' ', end='', flush=True)
pass
process_pool.terminate()
process_pool.join()
del process_pool

if os.path.isdir(url_list):
input_files = glob.glob(url_list+"/*."+input_format)
Expand Down

0 comments on commit fb0747f

Please sign in to comment.