In utils/parallel.py, the function `parallel_model` is a wrapper that handles multiple computing devices.
But, if I understand it correctly, it behaves as follows:
Suppose the function `model_fn` returns k scalars (perhaps multiple losses or metrics, e.g. accuracy), i.e. a tuple (o1, o2, ..., ok), and m devices are available: d1, d2, ..., dm.
Then the return value of the function has the following shape:

- multiple return values + single device: `[(o1_d1, o2_d1, ..., ok_d1)]`, i.e. a length-1 list containing a tuple
- single return value + multiple devices: `[o1_d1, o1_d2, ..., o1_dm]`
- single return value + single device: `[o1_d1]`
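To make the inconsistency concrete, here is a minimal mock (a sketch only: `mock_parallel_model` and this `model_fn` are hypothetical names that mimic the shapes described above, not the actual code in utils/parallel.py):

```python
def model_fn(features):
    # Hypothetical model returning two scalars (e.g. two losses).
    return features * 2, features * 3

def mock_parallel_model(fn, features, devices):
    # One result per device, mimicking the shapes described above.
    results = [fn(features) for _ in devices]
    if len(devices) > 1:
        # Multiple devices: transpose into one per-device list per return value,
        # so that multiple return values can be unpacked directly.
        return tuple(list(vals) for vals in zip(*results))
    return results  # single device: length-1 list, possibly of tuples

# Multiple devices: unpacking two return values works.
loss1, loss2 = mock_parallel_model(model_fn, 1, ["d1", "d2"])
print(loss1, loss2)  # [2, 2] [3, 3]

# Single device: the return value is [(2, 3)], a length-1 list,
# so the very same unpacking raises ValueError.
try:
    loss1, loss2 = mock_parallel_model(model_fn, 1, ["d1"])
except ValueError as e:
    print("unpacking failed:", e)
```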
You see, in the first case, the return value is weird. Say my `model_fn` has 2 return values: in the multiple-device case I can use `sharded_loss1, sharded_loss2 = parallel.parallel_model(fn, features, device_list)` to catch the two losses; but if I specify only a single device from the command line, the code breaks.
Certainly, I could check `isinstance(return_value_from_parallel_model, tuple)` and decide how to deal with the return value, but that is clumsy. It would be better to return `([o1_d1], [o2_d1], ..., [ok_d1])`, i.e. a tuple of lists, in the "multiple return values + single device" case, which leads to a more consistent design.
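The proposed convention could be sketched like this (again a hypothetical wrapper, not the actual code — the point is that the branch depends on the number of return values, not the number of devices):

```python
def consistent_parallel_model(fn, features, devices):
    # Always one result per device.
    results = [fn(features) for _ in devices]
    if isinstance(results[0], tuple):
        # Multiple return values: transpose into a tuple of per-device lists,
        # regardless of how many devices there are.
        return tuple(list(vals) for vals in zip(*results))
    return results  # single return value: a list with one entry per device

def model_fn(features):
    # Hypothetical model returning two scalars.
    return features * 2, features * 3

# A single device now unpacks the same way as multiple devices:
loss1, loss2 = consistent_parallel_model(model_fn, 1, ["d1"])
print(loss1, loss2)  # [2] [3]
```

With this shape, caller code like `sharded_loss1, sharded_loss2 = ...` works unchanged whether one or many devices are specified.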
Hope I've made myself clear.
Thank you for your feedback. The current code assumes that `model_fn` returns a single scalar value, so it will not have the problem you pointed out. Our primary intention is to provide an implementation that is easy to read and modify; it is up to users to customize the code.
Yep, I know it works fine when `model_fn` returns a single scalar. I wouldn't say anything if you didn't handle multiple return values at all. But you do handle multiple return values in utils/parallel.py.