@wonjoo-wj commented on Jul 29, 2022

Codegen relu and relu_ #3797


LazyIr.h:

class Relu : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::relu);
  }

  Relu(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::relu),
                {self}, std::move(shapes),
                [&]() { return ReluOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
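
The generated header only declares Lower; its body lives in the hand-written lowering file. As a minimal sketch (not verbatim from this PR), assuming the existing BuildRelu helper from torch_xla/csrc/elementwise.h and the usual ReturnOp pattern:

torch_xla::XlaOpVector Relu::Lower(LoweringContext* loctx) const {
  // Fetch the XLA op produced for the input operand, then emit relu
  // (max(x, 0)) via the existing element-wise builder.
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  return ReturnOp(BuildRelu(xla_input), loctx);
}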

XLANativeFunctions.cpp:

at::Tensor XLANativeFunctions::relu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<Relu>(lazy_self->GetIrValue());
  if (!node) {
    auto shapes = torch::lazy::compute_shape_relu(self);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self};
      const char* schema_str = "aten::relu(Tensor self) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Relu>(lazy_self->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
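
The node constructor above references ReluOutputShape for shape inference. A rough sketch of that helper, assuming the InferOutputShape convention used by the other XLA shape functions (names here are illustrative, not copied from this PR):

xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    // Shape inference runs the same lowering symbolically.
    return BuildRelu(operands[0]);
  };
  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}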

@wonjoo-wj marked this pull request as ready for review on July 29, 2022 17:06
Diff context for the review comment below:

TEST_F(TensorTest, TestRelu) {
@wonjoo-wj (Author) commented:
This codegen removes the relu function in the tensor_methods, so would it be okay to remove these tests here?

A collaborator replied:
It is OK.

@wonjoo-wj self-assigned this Jul 29, 2022
@wonjoo-wj linked an issue Jul 29, 2022 that may be closed by this pull request
@JackCaoG left a comment:

Thanks!

@wonjoo-wj merged commit 11602fc into master on Aug 1, 2022
Linked issue (may be closed by this pull request): Codegen relu and relu_