-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
expr.h
126 lines (88 loc) · 2.3 KB
/
expr.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#pragma once
#include "taichi/lang_util.h"
TLANG_NAMESPACE_BEGIN
class Expression;
class Identifier;
class ExprGroup;
class SNode;
// A lightweight, copyable handle to a node of the frontend expression tree.
//
// Copying an Expr shares the underlying Expression through a std::shared_ptr;
// the node itself is never cloned here. Assignment is declared but defined out
// of line (see operator= below), so its semantics must be checked at the
// definition rather than assumed to be a plain handle copy.
class Expr {
 public:
  // The referenced expression node; null for a default-constructed Expr.
  std::shared_ptr<Expression> expr;
  // NOTE(review): meaning is not visible in this header -- presumably marks a
  // compile-time-constant value; confirm against its users.
  bool const_value = false;
  // NOTE(review): presumably requests atomic semantics for the compound
  // assignment operators (+=, -=, ...); confirm against their definitions.
  bool atomic = false;

  // Empty handle: null expr, both flags false.
  Expr() = default;

  // Implicit conversions from scalar values (defined out of line).
  Expr(int32 x);
  Expr(int64 x);
  Expr(float32 x);
  Expr(float64 x);

  // Wraps an existing expression node. The pointer is taken by value and
  // moved into place, so an rvalue argument costs no extra refcount update.
  Expr(std::shared_ptr<Expression> expr) : expr(std::move(expr)) {
  }

  // Copy: shares o's node and copies const_value. `atomic` is NOT carried
  // over and stays false.
  // NOTE(review): confirm this asymmetry with the move constructor is
  // deliberate.
  Expr(const Expr &o) : expr(o.expr), const_value(o.const_value) {
  }

  // Move: shares (does not steal) o's node, so `o` still refers to the same
  // node afterwards, and carries over both flags. It must be noexcept:
  // otherwise std::vector<Expr> reallocation copies elements via the copy
  // constructor above, silently clearing `atomic` on every stored Expr.
  Expr(Expr &&o) noexcept
      : expr(o.expr), const_value(o.const_value), atomic(o.atomic) {
  }

  // Conversion from an Identifier (defined out of line).
  Expr(const Identifier &id);

  // Rebinds this handle to o's node. Only `expr` changes; const_value and
  // atomic keep their current values.
  void set(const Expr &o) {
    expr = o.expr;
  }

  // Pointer-like access to the underlying Expression (may be null).
  Expression *operator->() {
    return expr.get();
  }

  Expression const *operator->() const {
    return expr.get();
  }

  // Downcasts the node to T with std::dynamic_pointer_cast; yields nullptr if
  // the node is not a T. Asserts (TI_ASSERT) that expr is non-null.
  template <typename T>
  std::shared_ptr<T> cast() const {
    TI_ASSERT(expr != nullptr);
    return std::dynamic_pointer_cast<T>(expr);
  }

  // True iff the node is a T. Goes through cast<T>(), so it also asserts that
  // expr is non-null -- do not call it on an empty handle.
  template <typename T>
  bool is() const {
    return cast<T>() != nullptr;
  }

  // Copy assignment, defined out of line. Because the copy constructor, move
  // constructor and this operator are all user-declared, no move-assignment
  // operator is generated: assigning from an rvalue Expr also lands here.
  Expr &operator=(const Expr &o);

  // The members below are declared here and defined out of line.
  Expr operator[](const ExprGroup &indices) const;

  std::string serialize() const;

  void operator+=(const Expr &o);
  void operator-=(const Expr &o);
  void operator*=(const Expr &o);
  void operator/=(const Expr &o);

  Expr operator!();

  Expr eval() const;

  // Constructs a T from `args` in a fresh shared_ptr and wraps it in an Expr.
  template <typename T, typename... Args>
  static Expr make(Args &&... args) {
    return Expr(std::make_shared<T>(std::forward<Args>(args)...));
  }

  Expr parent() const;

  SNode *snode() const;

  void declare(DataType dt);

  // traceback for type checking error message
  void set_tb(const std::string &tb);

  void set_grad(const Expr &o);

  void set_attribute(const std::string &key, const std::string &value);

  std::string get_attribute(const std::string &key) const;
};
// Free functions over Expr. Only their declarations appear in this header;
// the definitions live elsewhere.
// NOTE(review): presumably each builds a new expression node rather than
// evaluating eagerly -- confirm against the definitions.

// Chooses between true_val and false_val according to cond (assumed
// ternary-select semantics; verify).
Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val);
// Unary minus.
Expr operator-(const Expr &expr);
// Unary operator~ (presumably bitwise NOT; confirm).
Expr operator~(const Expr &expr);
// Value cast of `input` to the data type `dt` (defined out of line), as
// opposed to bit_cast, which is declared separately.
Expr cast(const Expr &input, DataType dt);

// Convenience overload: the target DataType is derived from the C++ type T
// via get_data_type<T>(), then the DataType overload declared just above
// does the work.
template <typename T>
Expr cast(const Expr &input) {
  const DataType target_type = get_data_type<T>();
  return taichi::lang::cast(input, target_type);
}
// Bit cast of `input` to the data type `dt` (defined out of line).
// NOTE(review): presumably reinterprets the bit pattern without value
// conversion, as the name suggests -- confirm at the definition.
Expr bit_cast(const Expr &input, DataType dt);

// Convenience overload: the target DataType is derived from the C++ type T
// via get_data_type<T>(), then the DataType overload declared just above
// does the work.
template <typename T>
Expr bit_cast(const Expr &input) {
  const DataType target_type = get_data_type<T>();
  return taichi::lang::bit_cast(input, target_type);
}
TLANG_NAMESPACE_END