Equivalent of get_worker_info to split an IterableDataset #7667

Closed

davidaknowles opened this issue Jul 10, 2024 · 20 comments

Comments

@davidaknowles

❓ Questions and Help

I have an IterableDataset of unknown size. I would like to use something like torch.utils.data.get_worker_info to split it across the spawned xmp processes, but AFAIK there is no equivalent in xla_multiprocessing. Is there a workaround? I tried randomly subsampling on each process but this hangs for me for some reason.

@JackCaoG
Collaborator

I think you are looking for:

xla/torch_xla/runtime.py

Lines 135 to 192 in 34736f0

@requires_pjrt
def local_process_count() -> int:
  """Returns the number of processes running on this host."""
  return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1)


@requires_pjrt
def global_device_count() -> int:
  """Returns the total number of devices across all processes/hosts."""
  return len(torch_xla._XLAC._xla_get_all_devices())


@requires_pjrt
def world_size() -> int:
  """Returns the total number of processes participating in the job."""
  if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
    return 1
  return global_device_count()


@requires_pjrt
def local_device_count() -> int:
  """Returns the total number of devices on this host.

  Assumes each process has the same number of addressable devices.
  """
  return local_process_count() * addressable_device_count()


@requires_pjrt
def addressable_device_count() -> int:
  """Returns the number of devices visible to this process."""
  return torch_xla._XLAC._xla_num_devices()


@requires_pjrt
def global_ordinal() -> int:
  """Returns global ordinal of this thread within all processes.

  Global ordinal is in range [0, global_device_count). Global ordinals are not
  guaranteed to have any predictable relationship to the TPU worker ID nor are
  they guaranteed to be contiguous on each host."""
  return torch_xla._XLAC._xla_get_default_device_ordinal()


@requires_pjrt
def local_ordinal() -> int:
  """Returns local ordinal of this thread within this host.

  Local ordinal is in range [0, local_device_count)."""
  local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
  devices_per_process = addressable_device_count()
  return local_rank * devices_per_process + xla_device().index


@requires_pjrt
def process_index() -> int:
  return torch_xla._XLAC._xla_get_process_index()

For the up-to-date master API you can also check https://pytorch.org/xla/master/#module-torch_xla.runtime
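For example, here is a minimal sketch (illustrative only; ShardedIterable and source are made-up names, not part of the torch_xla API) of using these runtime functions to split an iterable stream round-robin across processes:

from torch.utils.data import IterableDataset
import torch_xla.runtime as xr

class ShardedIterable(IterableDataset):
    def __init__(self, source):
        super().__init__()
        self.source = source  # any iterable of samples

    def __iter__(self):
        rank = xr.global_ordinal()  # this process's global ordinal
        world = xr.world_size()     # number of participating processes
        for i, sample in enumerate(self.source):
            if i % world == rank:   # round-robin assignment to processes
                yield sample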

@JackCaoG
Collaborator

@will-cromar @zpcore do you know where torch.utils.data gets that info? Wondering if we can do some mapping and also support that API.

@zpcore
Collaborator

zpcore commented Jul 12, 2024

The worker attributes are set up when we initialize the dataloader:
https://github.com/pytorch/pytorch/blob/7c289c2a5c4e2233251565afadc2d95acf64b8c1/torch/utils/data/dataloader.py#L1113-L1128

Since we are using torch's dataloader:

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS.batch_size,
    sampler=train_sampler,
    drop_last=FLAGS.drop_last,
    shuffle=False if train_sampler else True,
    num_workers=FLAGS.num_workers,
    persistent_workers=FLAGS.persistent_workers,
    prefetch_factor=FLAGS.prefetch_factor)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS.test_set_batch_size,
    sampler=test_sampler,
    drop_last=FLAGS.drop_last,
    shuffle=False,
    num_workers=FLAGS.num_workers,
    persistent_workers=FLAGS.persistent_workers,
    prefetch_factor=FLAGS.prefetch_factor)
I think it should contain the worker info. I can run a test on real data to see whether it is there or not.
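As a sketch of how the two levels could combine (TwoLevelShardedDataset is an illustrative name, not an existing API; this assumes the DataLoader is created inside each XLA process), torch.utils.data.get_worker_info() still identifies the DataLoader worker within a process, while torch_xla.runtime identifies the process itself:

from torch.utils.data import IterableDataset, get_worker_info
import torch_xla.runtime as xr

class TwoLevelShardedDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        info = get_worker_info()  # per-DataLoader-worker info, None in the main process
        num_workers = info.num_workers if info is not None else 1
        worker_id = info.id if info is not None else 0
        # Flatten the two levels into a single shard index:
        # XLA process ordinal first, then DataLoader worker id.
        shard = xr.global_ordinal() * num_workers + worker_id
        num_shards = xr.world_size() * num_workers
        for i in range(self.n):
            if i % num_shards == shard:
                yield i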

@davidaknowles
Author

Thanks, those were the functions I was looking for. A cartoon version of my solution is the following:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class MyDataset(torch.utils.data.IterableDataset):

    def __init__(self):
        super().__init__()
        self.N = 100
        self.data = torch.rand(self.N, 30) 

    def __iter__(self): 
        for i in range(self.N): 
            if i % xr.world_size() == xm.get_ordinal(): 
                yield self.data[i]

def _mp_fn_(index): 

    device = xm.xla_device()
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size = 10)
    device_loader = pl.MpDeviceLoader(dataloader, device)

    for epoch in range(3): 
        mysum = torch.tensor(0., device = device) 
        for batch in device_loader: 
            mysum += batch.sum()
        sumsum = xm.all_reduce(xm.REDUCE_SUM, mysum).item()
        print(epoch, sumsum)

if __name__ == '__main__':
    xmp.spawn(_mp_fn_)  # one process per local device

This runs fine. My new issue is that on my real data it hangs when I hit the .item() call. mysum here is meant to be the total loss for the data processed on the current device, and sumsum is the total loss for the epoch (across all devices). Maybe there's a better pattern for getting the total loss?

@JackCaoG
Collaborator

Can you always do an xm.mark_step() or torch_xla.sync() before you do the .item() call? It is always recommended to flush the pending executions before accessing the value of a tensor.
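For reference, a minimal sketch of that suggestion applied to the cartoon example above (epoch_loss is an illustrative helper, not an existing API):

import torch
import torch_xla.core.xla_model as xm

def epoch_loss(device_loader, device):
    # Accumulate the per-device loss, all-reduce it, then flush pending
    # executions before reading the scalar back to the host.
    total = torch.tensor(0., device=device)
    for batch in device_loader:
        total += batch.sum()
    total = xm.all_reduce(xm.REDUCE_SUM, total)
    xm.mark_step()       # or torch_xla.sync(); flush before .item()
    return total.item()  # reads an already-materialized value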

@davidaknowles
Author

Hmm, so now it hangs on that mark_step() instead. Well, it gets past the mark_step() on one device, but the other 3 hang.

@JackCaoG
Collaborator

That's... interesting. It usually means the graph is different for each device. Can you dump the HLO following https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations? You should see multiple files.

@davidaknowles
Author

Hmm, each device could end up processing a (slightly) different number of batches; I suppose that technically makes the graph different? I'll figure out getting the HLO and report back.

@davidaknowles
Author

OK HLO files are here. LMK if anything looks suspicious! In the meantime I'll see if I can get eager mode -> compilation going with the nightly build.

@davidaknowles
Author

The nightly build's (2.5.something) torch_xla.distributed.xla_multiprocessing is only giving me access to 1 of 4 devices; is that expected?

@JackCaoG
Collaborator

JackCaoG commented Jul 19, 2024

Hmm, no, that's not expected. I am on nightly, and if I do

python examples/data_parallel/train_resnet_xla_ddp.py

I can see 4 processes

epoch: 1, step: 190, loss: 6.7231669425964355, rate: 1746.296055355192
epoch: 1, step: 190, loss: 6.705419540405273, rate: 1746.3170991653592
epoch: 1, step: 190, loss: 6.700830459594727, rate: 1745.7355188993108
epoch: 1, step: 190, loss: 6.731178283691406, rate: 1746.154144282245

(each process prints its own loss)

@JackCaoG
Collaborator

BTW, I checked your HLO; the last computation is the same:

HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[])->(f32[])}

%AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
  %x.7 = f32[] parameter(0)
  %y.8 = f32[] parameter(1)
  ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
}

ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[]) -> (f32[]) {
  %p1.2 = f32[] parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %tuple.3 = (f32[], f32[]) tuple(f32[] %p1.2, f32[] %p0.1), metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.4 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.5 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %all-reduce.10 = (f32[], f32[]) all-reduce(f32[] %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.12 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.11 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  ROOT %tuple.13 = (f32[]) tuple(f32[] %get-tuple-element.11)
}

which is just a simple all_reduce. I can't really tell why it hangs. Do you have a repo I can try on my end? The model code can just be dummy model code, or you can use one of my examples in https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py

@davidaknowles
Author

Hi @JackCaoG - I made a minimal branch of my repo here. Hopefully it's straightforward to test with the info in the README. Thanks!

@JackCaoG
Collaborator

Thanks, let me take a look tomorrow.

@JackCaoG
Collaborator

I am able to repro; let me look into it a bit.

@JackCaoG
Collaborator

One thing I realized by running

alias save_hlo="XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT='hlo' XLA_SAVE_TENSORS_FILE='/tmp/save1.hlo'"
alias cpplog="TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=\"xla_graph_executor=5,pjrt_computation_client=3\""

cpplog save_hlo PT_XLA_DEBUG=1 python train.py

is that each process runs a different number of graphs. In process 0 (by checking /tmp/save1.hlo.0) it executes 36 graphs; in process 1 it is 37.

This explains why all_reduce will hang: the number of graphs is different. Is your dataloader set up in a way that each process gets a different number of batches?

@davidaknowles
Author

Yes, definitely possible. I don't think it would differ by more than one batch, but presumably that's bad enough.

Is there an easy workaround for that, or should I just set up the dataloader to ensure an equal number of batches?

Thanks

@JackCaoG
Collaborator

JackCaoG commented Jul 29, 2024

We usually just drop the last batch so that every process executes the same thing. Each process is required to execute the same number of graphs, otherwise collective ops will get confused. For example, if TPU:1 expects TPU:0 to join an all_reduce but TPU:0 has moved on to a new graph that is trying to all_reduce a different tensor, that will either produce an incorrect result or hang forever.
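A minimal sketch of that, reusing the dataset from the cartoon example above (note that drop_last only trims the final partial batch within each process; it does not by itself equalize batch counts across processes when the per-process iterables have different lengths):

import torch

# drop_last trims the trailing partial batch on this process, so every batch
# that is emitted has exactly batch_size samples.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, drop_last=True)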

@davidaknowles
Author

Yup, makes sense. It's slightly less straightforward to do here since it's an IterableDataset where I don't know the total number of samples (or, equivalently, batches) globally. Now that I understand what the issue is, I'm sure I can figure something out.

@davidaknowles
Author

You can close this. In case others find this useful, this is my modified training loop to end the epoch as soon as any individual device is done/exhausted:

      enum_dataloader = enumerate(dataloader)
      
      while True:

          try: 
              step_i, dat = next(enum_dataloader)
              done = torch.tensor(0, dtype=torch.int32, device = self.device)
          except StopIteration:
              done = torch.tensor(1, dtype=torch.int32, device = self.device)
          
          # Synchronize the flag across all workers
          done = xm.all_reduce(xm.REDUCE_MAX, done)
          
          # If any worker is done (including me), break the loop
          if done.item() == 1:
              break 

          # remaining training code...
