The parallelism tutorial mentions this code to forward the attributes of a DataParallel object to its wrapped module:
class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        return getattr(self.module, name)
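This, however, leads to a recursion error. nn.Module keeps submodules in self._modules rather than in the instance __dict__, so looking up self.module inside the override falls through to the same __getattr__ again, which looks up self.module again, and so on. I think it should be: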
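class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves parameters, buffers and
            # submodules, including self.module itself
            return super().__getattr__(name)
        except AttributeError:
            # everything else is forwarded to the wrapped module
            return getattr(self.module, name)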
As was discussed here.
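
For anyone who wants to reproduce this without a GPU or even a PyTorch install, here is a minimal sketch of the mechanism. It assumes a simplified stand-in for nn.Module that only mimics the one relevant behaviour (submodules are stored in `_modules` and resolved through `__getattr__`). The names `Module`, `Wrapper`, `BrokenWrapper`, `FixedWrapper`, `Net`, `hidden_size` and `broken_recurses` are made up for illustration and are not PyTorch API.

```python
class Module:
    """Simplified stand-in for torch.nn.Module (illustration only)."""

    def __init__(self):
        # Bypass our own __setattr__ so _modules lands in the instance __dict__.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Submodules go into _modules, not into the instance __dict__.
        if isinstance(value, Module):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


class Wrapper(Module):
    """Stand-in for nn.DataParallel: holds the wrapped model as self.module."""

    def __init__(self, module):
        super().__init__()
        self.module = module


class BrokenWrapper(Wrapper):
    def __getattr__(self, name):
        # self.module is not in __dict__, so this lookup re-enters __getattr__.
        return getattr(self.module, name)


class FixedWrapper(Wrapper):
    def __getattr__(self, name):
        try:
            # Let the base class resolve submodules such as self.module first.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped module for everything else.
            return getattr(self.module, name)


class Net(Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = 16  # plain attribute we want to reach via the wrapper


def broken_recurses():
    """Return True if the tutorial-style override hits a RecursionError."""
    try:
        BrokenWrapper(Net()).hidden_size
    except RecursionError:
        return True
    return False


if __name__ == "__main__":
    print(broken_recurses())
    print(FixedWrapper(Net()).hidden_size)
```

With the fixed version, an attribute that exists on neither the wrapper nor the wrapped module still raises a normal AttributeError, so `hasattr` keeps working as expected.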