diff --git a/posenet/converter/tfjs2python.py b/posenet/converter/tfjs2python.py index 9745b2d..649fb43 100755 --- a/posenet/converter/tfjs2python.py +++ b/posenet/converter/tfjs2python.py @@ -50,9 +50,8 @@ def load_variables(chkpoint, base_dir=BASE_DIR): download(chkpoint, base_dir) assert os.path.exists(manifest_path) - f = open(manifest_path) - variables = json.load(f) - f.close() + with open(manifest_path) as f: + variables = json.load(f) # with tf.variable_scope(None, 'MobilenetV1'): for x in variables: diff --git a/posenet/converter/wget.py b/posenet/converter/wget.py index 51f1d96..9de0a5d 100644 --- a/posenet/converter/wget.py +++ b/posenet/converter/wget.py @@ -1,6 +1,7 @@ import urllib.request import posixpath import json +import zlib import os from posenet.converter.config import load_config @@ -12,8 +13,18 @@ def download_file(checkpoint, filename, base_dir): + output_path = os.path.join(base_dir, checkpoint, filename) url = posixpath.join(GOOGLE_CLOUD_STORAGE_DIR, checkpoint, filename) - urllib.request.urlretrieve(url, os.path.join(base_dir, checkpoint, filename)) + req = urllib.request.Request(url) + response = urllib.request.urlopen(req) + if response.info().get('Content-Encoding') == 'gzip': + data = zlib.decompress(response.read(), zlib.MAX_WBITS | 32) + else: + # this path not tested since gzip encoding default on google server + # may need additional encoding/text handling if hit in the future + data = response.read() + with open(output_path, 'wb') as f: + f.write(data) def download(checkpoint, base_dir='./weights/'): @@ -22,9 +33,8 @@ def download(checkpoint, base_dir='./weights/'): os.makedirs(save_dir) download_file(checkpoint, 'manifest.json', base_dir) - - f = open(os.path.join(save_dir, 'manifest.json'), 'r') - json_dict = json.load(f) + with open(os.path.join(save_dir, 'manifest.json'), 'r') as f: + json_dict = json.load(f) for x in json_dict: filename = json_dict[x]['filename']