Full codegen erf, erfc, erfinv, and exp #3659

Merged
wonjoolee95 merged 1 commit into master from codegen-6 on Jun 23, 2022
Conversation

wonjoolee95 (Collaborator)

Full codegen erf, erfc, erfinv, and exp

Generated LazyIr.h:

class Erf : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::erf);
  }

  Erf(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::erf),
                {self}, std::move(shapes),
                [&]() { return ErfOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Erfc : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::erfc);
  }

  Erfc(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::erfc),
                {self}, std::move(shapes),
                [&]() { return ErfcOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Erfinv : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::erfinv);
  }

  Erfinv(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::erfinv),
                {self}, std::move(shapes),
                [&]() { return ErfinvOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Exp : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::exp);
  }

  Exp(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::exp),
                {self}, std::move(shapes),
                [&]() { return ExpOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
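
Note that the generated classes only declare Lower(); the lowering itself stays hand-written. As a rough sketch of what that looks like for Erf (assuming xla::Erf from XLA's client math library and the usual ReturnOp helper; the other three ops follow the same one-liner pattern):

torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const {
  // Fetch the XLA op for the single input and emit the builder's erf.
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  return ReturnOp(xla::Erf(xla_input), loctx);
}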

Generated XLANativeFunctions.cpp:

    at::Tensor XLANativeFunctions::erf(const at::Tensor & self) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Erf>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::erf(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::erf(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<Erf>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    }

    at::Tensor XLANativeFunctions::erfc(const at::Tensor & self) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Erfc>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::erfc(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::erfc(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<Erfc>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    }

    at::Tensor XLANativeFunctions::erfinv(const at::Tensor & self) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Erfinv>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::erfinv(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::erfinv(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<Erfinv>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    }

    at::Tensor XLANativeFunctions::exp(const at::Tensor & self) {
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Exp>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::exp(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if (torch::lazy::symbolicShapeEnabled()) {
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::exp(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
            node = torch::lazy::MakeNode<Exp>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
            torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    }
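
Each wrapper computes its output shape through the shape lambda passed to the node constructor (ErfOutputShape and friends). For these element-wise ops the hand-written shape function simply forwards the input shape; a sketch, following the CoshOutputShape pattern visible in the diff below:

xla::Shape ErfOutputShape(const torch::lazy::Value& input) {
  // Element-wise op: the output shape matches the input shape.
  return GetXlaShape(input);
}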

@wonjoolee95 wonjoolee95 self-assigned this Jun 22, 2022
@@ -41,6 +41,22 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

+xla::Shape ErfOutputShape(const torch::lazy::Value& input) {
Collaborator:
We should just have a macro for all of these functions that simply return the input shape. We can do that in a separate PR.
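
A minimal sketch of what such a macro could look like (hypothetical; the macro name is made up and this is not part of this PR):

// Hypothetical macro that stamps out shape functions forwarding the
// input shape. PASSTHROUGH_SHAPE_FN is a made-up name, not in the repo.
#define PASSTHROUGH_SHAPE_FN(Name)                                 \
  xla::Shape Name##OutputShape(const torch::lazy::Value& input) {  \
    return GetXlaShape(input);                                     \
  }

PASSTHROUGH_SHAPE_FN(Erf)
PASSTHROUGH_SHAPE_FN(Erfc)
PASSTHROUGH_SHAPE_FN(Erfinv)
PASSTHROUGH_SHAPE_FN(Exp)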

@JackCaoG JackCaoG (Collaborator) left a comment:

Thanks!

@wonjoolee95 wonjoolee95 merged commit 05fa2aa into master Jun 23, 2022
@wonjoolee95 wonjoolee95 deleted the codegen-6 branch June 23, 2022 00:04