Equivalent of get_worker_info to split an IterableDataset #7667
I feel like you are looking for Lines 135 to 192 in 34736f0
For the up-to-date master API you can also check https://pytorch.org/xla/master/#module-torch_xla.runtime
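A rough sketch of the kind of runtime queries that comment points at (function names taken from the linked torch_xla.runtime docs; the contiguous-shard helper itself is made up for illustration):

```python
import torch_xla.runtime as xr

def shard_bounds(num_items: int):
    """Split num_items contiguously across the spawned processes."""
    world = xr.world_size()      # total number of participating processes
    rank = xr.global_ordinal()   # this process's index in [0, world)
    per_proc = num_items // world
    start = rank * per_proc
    # The last process picks up any remainder.
    end = num_items if rank == world - 1 else start + per_proc
    return start, end
```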
@will-cromar @zpcore do you know where
The worker attributes are set up when we initialize the dataloader. Since we are using torch's dataloader: xla/test/test_train_mp_imagenet.py Lines 235 to 252 in 1651e76
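A hedged sketch of that kind of setup for a map-style dataset (this paraphrases the general pattern rather than quoting the linked test file; the dummy dataset and batch size are invented for illustration). torch.utils.data.DistributedSampler does the per-process split using the world size and ordinal:

```python
import torch
import torch.utils.data as data
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Stand-in dataset purely for illustration.
train_dataset = data.TensorDataset(torch.rand(1000, 30))

train_sampler = data.DistributedSampler(
    train_dataset,
    num_replicas=xr.world_size(),  # one shard per spawned process
    rank=xm.get_ordinal(),         # this process reads only its own shard
    shuffle=True,
)
train_loader = data.DataLoader(
    train_dataset, batch_size=32, sampler=train_sampler, drop_last=True
)
```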
Thanks, those were the functions I was looking for. A cartoon version of my solution is the following:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


class MyDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()
        self.N = 100
        self.data = torch.rand(self.N, 30)

    def __iter__(self):
        # Each process keeps only the sample indices that map to its ordinal.
        for i in range(self.N):
            if i % xr.world_size() == xm.get_ordinal():
                yield self.data[i]


def _mp_fn_(index):
    device = xm.xla_device()
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    device_loader = pl.MpDeviceLoader(dataloader, device)
    for epoch in range(3):
        mysum = torch.tensor(0., device=device)
        for batch in device_loader:
            mysum += batch.sum()
        sumsum = xm.all_reduce(xm.REDUCE_SUM, mysum).item()
        print(epoch, sumsum)


if __name__ == '__main__':
    # Launch added for completeness (not part of the original snippet).
    xmp.spawn(_mp_fn_)
```

This runs fine... my new issue is on my real data when I hit the
can you always do a
Hmm, so now it hangs on that
That's... interesting. It usually means the graph is different for each device. Can you dump the HLO following https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations? You should see multiple files.
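For reference, a hedged sketch of setting those debugging variables from Python; the variable names are assumed from the linked guide, the dump path is a placeholder, and exporting them in the launching shell is the more typical route:

```python
import os

# Names assumed from the troubleshooting guide linked above; double-check against it.
os.environ["XLA_IR_DEBUG"] = "1"            # annotate the lazy IR with Python frame info
os.environ["XLA_HLO_DEBUG"] = "1"           # carry those annotations into the HLO metadata
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"  # dump the graphs as HLO text
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_debug"  # placeholder path for the dump files

# Import torch_xla only after the variables are set.
import torch_xla.core.xla_model as xm
```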
Hmm, each device could end up processing a (slightly) different number of batches; I suppose that technically makes the graph different? I'll figure out getting the HLO and report back.
OK, the HLO files are here. LMK if anything looks suspicious! In the meantime I'll see if I can get eager mode -> compilation going with the nightly build.
Nightly build's (2.5.something)
Hmm, no, that's not expected. I am on nightly, and if I do

I can see 4 processes (each process prints its own loss).
BTW, I checked your HLO; the last computation is the same

which is just a simple
Thanks, let me take a look tmr.
I am able to repro; let me look into it a bit.
Yes, definitely possible. I don't think it would differ by more than 1 batch, but presumably that's bad enough. Is there an easy workaround for that, or should I just set up the dataloader to ensure an equal number of batches? Thanks
We usually just drop the last batch to make every process execute the same. Each process is required to execute the same number of graphs, otherwise the collective ops will get confused. For example
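A minimal sketch of that workaround for a map-style dataset (sizes invented for illustration); for the IterableDataset in this thread, the analogous fix is to make every process yield the same number of batches, or to break out early as in the next comment:

```python
import torch

# 1003 samples with batch_size=10: without drop_last the final batch would have
# only 3 samples, which changes the graph shape and forces a recompile.
dataset = torch.utils.data.TensorDataset(torch.rand(1003, 30))
loader = torch.utils.data.DataLoader(dataset, batch_size=10, drop_last=True)
assert all(batch.shape[0] == 10 for (batch,) in loader)  # every remaining batch is full-sized
```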
Yup, makes sense. It's slightly less straightforward to do here since it's an IterableDataset.
You can close this. In case others find this useful, this is my modified training loop to end the epoch when any of the individual devices is done/exhausted:

```python
enum_dataloader = enumerate(dataloader)
while True:
    try:
        step_i, dat = next(enum_dataloader)
        done = torch.tensor(0, dtype=torch.int32, device=self.device)
    except StopIteration:
        done = torch.tensor(1, dtype=torch.int32, device=self.device)
    # Synchronize the flag across all workers
    done = xm.all_reduce(xm.REDUCE_MAX, done)
    # If any worker is done (including me), break the loop
    if done.item() == 1:
        break
    # remaining training code...
```
❓ Questions and Help

I have an IterableDataset of unknown size. I would like to use something like torch.utils.data.get_worker_info to split it across the spawned xmp processes, but AFAIK there is no equivalent in xla_multiprocessing. Is there a workaround? I tried randomly subsampling on each process, but this hangs for me for some reason.
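For anyone else landing here, a minimal sketch (class name and structure invented for illustration) that combines torch.utils.data.get_worker_info() for DataLoader workers inside a process with the ordinal/world-size split used in the solution above:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class ShardedIterable(torch.utils.data.IterableDataset):
    def __init__(self, source):
        super().__init__()
        self.source = source  # any (possibly unsized) iterable of samples

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        num_workers = worker.num_workers if worker is not None else 1
        worker_id = worker.id if worker is not None else 0
        # One global shard per (process, dataloader worker) pair.
        stride = xr.world_size() * num_workers
        offset = xm.get_ordinal() * num_workers + worker_id
        for i, sample in enumerate(self.source):
            if i % stride == offset:
                yield sample
```

As discussed above, if the stream length isn't a multiple of the stride, processes can end up with different batch counts, so pair this with an equal-length guarantee or the early-exit loop from the thread.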