@JackCaoG:

Fix #3877
Fix #3878

This PR moves eq.Scalar and eq.Tensor to full codegen. The generated code:

LazyIr

class EqScalar : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::eq);
  }

  EqScalar(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::eq), {self, other},
                std::move(shapes),
                [&]() { return EqScalarOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  // Reuse is disabled for this node type, so the ReuseNode() call in the
  // kernel below always ends up creating a fresh node.
  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class EqTensor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::eq);
  }

  EqTensor(const torch::lazy::Value& self, const torch::lazy::Value& other,
           std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::eq), {self, other},
                std::move(shapes),
                [&]() { return EqTensorOutputShape(self, other); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
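
The codegen only declares Lower and the EqScalarOutputShape/EqTensorOutputShape helpers; their bodies stay hand-written. As a rough sketch (following the pattern other comparison ops in torch_xla use, not the exact code from this PR), they would look something like:

// Sketch of the hand-written pieces that pair with the generated nodes above.
// BuildComparisonOp, InferOutputShape and GetXlaShape are existing torch_xla
// helpers; the real definitions may differ in detail.
torch_xla::XlaOpVector EqTensor::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
  return ReturnOp(BuildComparisonOp(at::aten::eq, xla_self, xla_other), loctx);
}

xla::Shape EqTensorOutputShape(const torch::lazy::Value& self,
                               const torch::lazy::Value& other) {
  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildComparisonOp(at::aten::eq, operands[0], operands[1]);
  };
  return InferOutputShape({GetXlaShape(self), GetXlaShape(other)}, shape_fn);
}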

XLANativeFunction

at::Tensor XLANativeFunctions::eq(const at::Tensor& self,
                                  const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  // Materialize the scalar operand as an IR value on the common device.
  auto node_other =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          other, *common_device);
  // Reuse a structurally identical node from the current trace if one exists.
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<EqScalar>(lazy_self->GetIrValue(), node_other);
  if (!node) {
    // Infer the output shape/dtype by running the op on meta tensors.
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::eq(self_meta, other);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::eq.Scalar(Tensor self, Scalar other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<EqScalar>(lazy_self->GetIrValue(), node_other,
                                           std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
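
With this kernel registered, calling torch::eq with a scalar on an XLA tensor just records an EqScalar node in the lazy trace; nothing is compiled or executed until the result is actually needed. A minimal usage sketch (assumes an XLA device is available and torch_xla is loaded):

// Usage sketch: eq.Scalar is traced lazily on an XLA device.
torch::Tensor a = torch::rand({2, 3}, torch::Device("xla:0"));
torch::Tensor mask = torch::eq(a, 0.5);  // records an EqScalar IR node
// The comparison is lowered and executed only when the value is used,
// e.g. when copying back to the CPU:
torch::Tensor cpu_mask = mask.to(torch::kCPU);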

at::Tensor XLANativeFunctions::eq(const at::Tensor& self,
                                  const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_other =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<EqTensor>(
      lazy_self->GetIrValue(), lazy_other->GetIrValue());
  if (!node) {
    auto self_meta = to_meta(self);
    auto other_meta = to_meta(other);
    auto out_meta = at::meta::eq(self_meta, other_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, other};
      const char* schema_str =
          "aten::eq.Tensor(Tensor self, Tensor other) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<EqTensor>(
        lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
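
For reference, a quick way to exercise the tensor overload is a cpp test along the lines of the existing cases in test/cpp/test_aten_xla_tensor.cpp (a hypothetical sketch; ForEachDevice, CopyToDevice and AllEqual are the existing test helpers):

// Hypothetical test sketch in the existing AtenXlaTensorTest style.
TEST_F(AtenXlaTensorTest, TestEqTensor) {
  torch::Tensor a = torch::rand({2, 3});
  torch::Tensor b = a.clone();
  torch::Tensor c = torch::eq(a, b);  // CPU reference result
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = CopyToDevice(b, device);
    torch::Tensor xla_c = torch::eq(xla_a, xla_b);
    AllEqual(c, xla_c);  // XLA result must match the CPU reference
  });
}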

@JackCaoG:

This is a bit funny: TestNNDeviceTypeXLA.test_Dropout2d_xla, which used to fail for pytorch/xla, started passing after this PR (I have no idea why). The test was marked as expected to fail, so the harness throws an error when it sees the test actually pass. I will rerun the CI; if it still passes, I will disable the test on our end, remove the expect_fail_xla from the pytorch end, and then re-enable the test on our end.

@JackCaoG requested a review from wonjoo-wj on August 18, 2022.
@JackCaoG:

@wonjoolee95 I think this one is ready for review too.

@wonjoo-wj left a comment:

Thanks!

@JackCaoG merged commit c91bf61 into master on Aug 18, 2022.