lint with flake8-comprehensions
ppwwyyxx committed Jan 9, 2020
1 parent a995070 commit 552c2b3
Showing 16 changed files with 33 additions and 33 deletions.
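The changes below apply the rewrites suggested by flake8-comprehensions: set()/dict() calls wrapped around comprehensions or pair lists become set/dict comprehensions, and lists built only to feed sum()/min()/max()/all()/sorted() become plain generator expressions. A rough, self-contained sketch of the before/after patterns (illustrative names only, not code from this repository):

import operator

pairs = [("b", 2), ("a", 1)]
chunks = [[1, 2], [3], [4, 5, 6]]

# Before: forms that flake8-comprehensions flags
names_old = set([k for k, _ in pairs])            # set() around a list comprehension
by_key_old = dict([(k, v) for k, v in pairs])     # dict() around a list of pairs
total_old = sum([len(c) for c in chunks])         # list built only to be summed
ordered_old = sorted([(k, v) for k, v in pairs], key=operator.itemgetter(1))

# After: comprehensions and generator expressions, same results
names_new = {k for k, _ in pairs}
by_key_new = {k: v for k, v in pairs}
total_new = sum(len(c) for c in chunks)
ordered_new = sorted(((k, v) for k, v in pairs), key=operator.itemgetter(1))

assert names_old == names_new
assert by_key_old == by_key_new
assert total_old == total_new
assert ordered_old == ordered_new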
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
@@ -17,7 +17,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install flake8
+ pip install flake8 flake8-comprehensions flake8-bugbear
flake8 --version
- name: Lint
run: |
4 changes: 2 additions & 2 deletions scripts/dump-model-params.py
@@ -108,8 +108,8 @@ def guess_inputs(input_dir):
if len(set(var_to_dump)) != len(var_to_dump):
logger.warn("TRAINABLE and MODEL variables have duplication!")
var_to_dump = list(set(var_to_dump))
- globvarname = set([k.name for k in tf.global_variables()])
- var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
+ globvarname = {k.name for k in tf.global_variables()}
+ var_to_dump = {k.name for k in var_to_dump if k.name in globvarname}

for name in var_to_dump:
assert name in dic, "Variable {} not found in the model!".format(name)
2 changes: 1 addition & 1 deletion tensorpack/callbacks/monitor.py
@@ -425,7 +425,7 @@ def __init__(self, enable_step=False, enable_epoch=True,
def compile_regex(rs):
if rs is None:
return None
- rs = set([re.compile(r) for r in rs])
+ rs = {re.compile(r) for r in rs}
return rs

self._whitelist = compile_regex(whitelist)
6 changes: 3 additions & 3 deletions tensorpack/contrib/keras.py
@@ -49,20 +49,20 @@ def __call__(self, *input_tensors):
"""
reuse = tf.get_variable_scope().reuse

- old_trainable_names = set([x.name for x in tf.trainable_variables()])
+ old_trainable_names = {x.name for x in tf.trainable_variables()}
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])

def post_process_model(model):
- added_trainable_names = set([x.name for x in tf.trainable_variables()])
+ added_trainable_names = {x.name for x in tf.trainable_variables()}
restore_collection(trainable_backup)

for v in model.weights:
# In Keras, the collection is not respected and could contain non-trainable vars.
# We put M.weights into the collection instead.
if v.name not in old_trainable_names and v.name in added_trainable_names:
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
- new_trainable_names = set([x.name for x in tf.trainable_variables()])
+ new_trainable_names = {x.name for x in tf.trainable_variables()}

for n in added_trainable_names:
if n not in new_trainable_names:
6 changes: 3 additions & 3 deletions tensorpack/dataflow/common.py
@@ -431,7 +431,7 @@ def __init__(self, df_lists):
"""
super(RandomChooseData, self).__init__()
if isinstance(df_lists[0], (tuple, list)):
- assert sum([v[1] for v in df_lists]) == 1.0
+ assert sum(v[1] for v in df_lists) == 1.0
self.df_lists = df_lists
else:
prob = 1.0 / len(df_lists)
@@ -512,7 +512,7 @@ def reset_state(self):
d.reset_state()

def __len__(self):
- return sum([len(x) for x in self.df_lists])
+ return sum(len(x) for x in self.df_lists)

def __iter__(self):
for d in self.df_lists:
@@ -565,7 +565,7 @@ def __len__(self):
"""
Return the minimum size among all.
"""
- return min([len(k) for k in self.df_lists])
+ return min(len(k) for k in self.df_lists)

def __iter__(self):
itrs = [k.__iter__() for k in self.df_lists]
2 changes: 1 addition & 1 deletion tensorpack/dataflow/format.py
@@ -45,7 +45,7 @@ def __init__(self, filename, data_paths, shuffle=True):
logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths]
lens = [len(k) for k in self.dps]
- assert all([k == lens[0] for k in lens])
+ assert all(k == lens[0] for k in lens)
self._size = lens[0]
self.shuffle = shuffle

4 changes: 2 additions & 2 deletions tensorpack/graph_builder/distributed.py
@@ -230,7 +230,7 @@ def _shadow_model_variables(shadow_vars):
list of (shadow_model_var, local_model_var) used for syncing.
"""
G = tf.get_default_graph()
- curr_shadow_vars = set([v.name for v in shadow_vars])
+ curr_shadow_vars = {v.name for v in shadow_vars}
model_vars = tf.model_variables()
shadow_model_vars = []
for v in model_vars:
@@ -346,7 +346,7 @@ def strip_port(s):
return s[:-2]
return s
local_vars = tf.local_variables()
- local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
+ local_var_by_name = {strip_port(v.name): v for v in local_vars}
ops = []
nr_shadow_vars = len(self._shadow_vars)
for v in self._shadow_vars:
10 changes: 5 additions & 5 deletions tensorpack/graph_builder/training.py
@@ -63,7 +63,7 @@ def basename(x):
return re.sub('tower[0-9]+/', '', x.op.name)

if len(set(nvars)) != 1:
- names_per_gpu = [set([basename(k[1]) for k in grad_and_vars]) for grad_and_vars in grad_list]
+ names_per_gpu = [{basename(k[1]) for k in grad_and_vars} for grad_and_vars in grad_list]
inters = copy.copy(names_per_gpu[0])
for s in names_per_gpu:
inters &= s
@@ -247,11 +247,11 @@ def build(self, grad_list, get_opt_fn):

DataParallelBuilder._check_grad_list(grad_list)

- dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
+ dtypes = {x[0].dtype.base_dtype for x in grad_list[0]}
dtypes_nccl_supported = [tf.float32, tf.float64]
if get_tf_version_tuple() >= (1, 8):
dtypes_nccl_supported.append(tf.float16)
- valid_for_nccl = all([k in dtypes_nccl_supported for k in dtypes])
+ valid_for_nccl = all(k in dtypes_nccl_supported for k in dtypes)
if self._mode == 'nccl' and not valid_for_nccl:
logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
self._mode = 'cpu'
@@ -314,8 +314,8 @@ def get_post_init_ops():
"""
# literally all variables, because it's better to sync optimizer-internal variables as well
all_vars = tf.global_variables() + tf.local_variables()
- var_by_name = dict([(v.name, v) for v in all_vars])
- trainable_names = set([x.name for x in tf.trainable_variables()])
+ var_by_name = {v.name: v for v in all_vars}
+ trainable_names = {x.name for x in tf.trainable_variables()}
post_init_ops = []

def log_failure(name, reason):
4 changes: 2 additions & 2 deletions tensorpack/graph_builder/utils.py
@@ -25,7 +25,7 @@ def _replace_global_by_local(kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
- collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
+ collections = {tf.GraphKeys.GLOBAL_VARIABLES}
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
@@ -343,7 +343,7 @@ def compute_strategy(self, grads):
logger.info("Skip GradientPacker due to too few gradients.")
return False
# should have the same dtype
- dtypes = set([g.dtype for g in grads])
+ dtypes = {g.dtype for g in grads}
if len(dtypes) != 1:
logger.info("Skip GradientPacker due to inconsistent gradient types.")
return False
2 changes: 1 addition & 1 deletion tensorpack/input_source/input_source.py
@@ -471,7 +471,7 @@ def _setup(self, input_signature):
self._spec = input_signature
if self._dataset is not None:
types = self._dataset.output_types
- spec_types = tuple([k.dtype for k in input_signature])
+ spec_types = tuple(k.dtype for k in input_signature)
assert len(types) == len(spec_types), \
"Dataset and input signature have different length! {} != {}".format(
len(types), len(spec_types))
2 changes: 1 addition & 1 deletion tensorpack/models/conv2d.py
@@ -100,7 +100,7 @@ def Conv2D(
filter_shape = kernel_shape + [in_channel / split, out_channel]
stride = shape4d(strides, data_format=data_format)

- kwargs = dict(data_format=data_format)
+ kwargs = {"data_format": data_format}
if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)

4 changes: 2 additions & 2 deletions tensorpack/tfutils/model_utils.py
@@ -42,14 +42,14 @@ def describe_trainable_vars():
data.append([get_op_tensor_name(v.name)[0], shape, ele, v.device, v.dtype.base_dtype.name])
headers = ['name', 'shape', '#elements', 'device', 'dtype']

- dtypes = list(set([x[4] for x in data]))
+ dtypes = list({x[4] for x in data})
if len(dtypes) == 1 and dtypes[0] == "float32":
# don't log the dtype if all vars are float32 (default dtype)
for x in data:
del x[4]
del headers[4]

- devices = set([x[3] for x in data])
+ devices = {x[3] for x in data}
if len(devices) == 1:
# don't log the device if all vars on the same device
for x in data:
4 changes: 2 additions & 2 deletions tensorpack/tfutils/varmanip.py
@@ -150,7 +150,7 @@ def dump_session_params(path):
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
# TODO dedup
assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
- gvars = set([k.name for k in tf.global_variables()])
+ gvars = {k.name for k in tf.global_variables()}
var = [v for v in var if v.name in gvars]
result = {}
for v in var:
@@ -167,7 +167,7 @@ def save_chkpt_vars(dic, path):
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger.info("Variables to save to {}:".format(path))
- keys = sorted(list(dic.keys()))
+ keys = sorted(dic.keys())
logger.info(pprint.pformat(keys))

assert not path.endswith('.npy')
4 changes: 2 additions & 2 deletions tensorpack/utils/viz.py
@@ -91,8 +91,8 @@ def _pad_channel(plist):

plist = _pad_channel(plist)
shapes = [x.shape for x in plist]
- ph = max([s[0] for s in shapes])
- pw = max([s[1] for s in shapes])
+ ph = max(s[0] for s in shapes)
+ pw = max(s[1] for s in shapes)

ret = np.zeros((len(plist), ph, pw, 3), dtype=plist[0].dtype)
ret[:, :, :] = bgcolor
8 changes: 4 additions & 4 deletions tests/benchmark-serializer.py
@@ -39,11 +39,11 @@ def benchmark_serializer(dumps, loads, data, num):

def display_results(name, results):
logger.info("Encoding benchmark for {}:".format(name))
- data = sorted([(x, y[0]) for x, y in results], key=operator.itemgetter(1))
+ data = sorted(((x, y[0]) for x, y in results), key=operator.itemgetter(1))
print(tabulate(data, floatfmt='.5f'))

logger.info("Decoding benchmark for {}:".format(name))
- data = sorted([(x, y[1]) for x, y in results], key=operator.itemgetter(1))
+ data = sorted(((x, y[1]) for x, y in results), key=operator.itemgetter(1))
print(tabulate(data, floatfmt='.5f'))


@@ -64,8 +64,8 @@ def fake_json_data():
pellentesque quis sollicitudin id, adipiscing.
""" * 100,
'list': list(range(100)) * 500,
- 'dict': dict((str(i), 'a') for i in range(50000)),
- 'dict2': dict((i, 'a') for i in range(50000)),
+ 'dict': {str(i): 'a' for i in range(50000)},
+ 'dict2': {i: 'a' for i in range(50000)},
'int': 3000,
'float': 100.123456
}
2 changes: 1 addition & 1 deletion tox.ini
@@ -1,7 +1,7 @@
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
- ignore = E265,E741,E742,E743,W504,W605
+ ignore = E265,E741,E742,E743,W504,W605,C408
exclude = .git,
__init__.py,
setup.py,
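(C408 is flake8-comprehensions' "unnecessary dict/list/tuple call" check, e.g. dict(a=1) rather than {"a": 1}; keeping it in the ignore list leaves keyword-style dict(...) construction permitted elsewhere in the codebase.)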
