[SPMD] fixes bugs with AssignIrValue & ExecuteReplicated #4233
Conversation
Force-pushed from aa6a0da to 62f2179.
torch_xla/csrc/tensor.cpp (outdated):

    if (!ir_value) {
      ir_value = CreateTensorNode(CurrentXlaData(), /*read_only=*/false);
    }
    XLA_CHECK(ir_value.node != nullptr) << "Trying to access a null cursor";
Hmm, so this line is to force the sharding on the old IR before it is replaced? I am confused because we did not clear the sharding after the new ir_value is assigned.
The intent is to force the same sharding on the new IR (input), for a common call pattern like AssignIrValue(torch::lazy::Value()). Did I get it reversed?
Oh OK, I read it backwards; you're trying to populate the sharding spec onto the new IR being assigned.
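For illustration, here is a minimal Python-level sketch of the behavior discussed above, namely the sharding annotation being carried over onto the newly assigned IR value. The xs.Mesh / xs.mark_sharding calls and the 8-device mesh are assumptions about the experimental SPMD API (exact signatures may differ across torch_xla versions); the _get_xla_sharding_spec check mirrors the test change in this PR.

```python
# Illustrative sketch only; assumes an 8-device SPMD setup and the
# experimental sharding API (torch_xla.experimental.xla_sharding).
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

xt = torch.randn(8, 8).to(xm.xla_device())
# Annotate the tensor: shard the second dim across a hypothetical 1x8 mesh.
xs.mark_sharding(xt, xs.Mesh(list(range(8)), (1, 8)), (0, 1))

before = torch_xla._XLAC._get_xla_sharding_spec(xt)
xt.add_(1)      # in-place update assigns a new IR value to xt
xm.mark_step()  # executes the graph; the IR value is reset to device data
after = torch_xla._XLAC._get_xla_sharding_spec(xt)

assert before == after  # the annotation was carried onto the new IR value
```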
Force-pushed from b083799 to 63773a9.
test/test_xla_sharding.py (outdated):

    self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))

    xt.add_(1)  # inplace update
    xm.mark_step()  # resets IR value
I don't think you need the mark_step here? After the xt.add_(1), if you print the IR and HLO with

    print(torch_xla._XLAC._get_xla_tensors_text([xt]))
    print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))

you should see the sharding spec on the output?
It works either way, before or after, and checks for the sharding annotation.
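To make the exchange above concrete, here is a short sketch of both checks, continuing the hypothetical sharded tensor xt from the earlier sketch and using only the introspection helpers quoted in this thread.

```python
# Sketch only: `xt` is the hypothetical sharded tensor from the sketch above.
xt.add_(1)  # in-place update

# Before mark_step: the pending IR and HLO should already show the annotation.
print(torch_xla._XLAC._get_xla_tensors_text([xt]))
print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))

xm.mark_step()  # execute the pending graph and reset the IR value

# After mark_step: the reset IR value should still report the same spec.
print(torch_xla._XLAC._get_xla_sharding_spec(xt))
```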
Mostly LGTM; approving to unblock the merge once the tests are green. You might need to rebase.
Force-pushed from 63773a9 to a16dc28.
Force-pushed from a16dc28 to f167e30.
This fixes a couple of bugs in AssignIrValue and ExecuteReplicated for sharding, to enable mark_step() with SPMD. Note that this doesn't address sharding propagation through views, which will be handled later.