forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathSymBool.h
154 lines (126 loc) · 3.91 KB
/
SymBool.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#pragma once
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <optional>
#include <ostream>
#include <type_traits>
#include <utility>
namespace c10 {
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b) {}
SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_bool());
}
SymBool() : data_(false) {}
SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}
SymNodeImpl* release() && {
return std::move(ptr_).release();
}
// Only valid if is_heap_allocated()
SymNode toSymNodeImpl() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
bool expect_bool() const {
std::optional<bool> c = maybe_as_bool();
TORCH_CHECK(c.has_value());
return *c;
}
SymBool sym_and(const SymBool&) const;
SymBool sym_or(const SymBool&) const;
SymBool sym_not() const;
SymBool operator&(const SymBool& other) const {
return sym_and(other);
}
SymBool operator|(const SymBool& other) const {
return sym_or(other);
}
SymBool operator||(const SymBool& other) const {
return sym_or(other);
}
SymBool operator~() const {
return sym_not();
}
// Insert a guard for the bool to be its concrete value, and then return
// that value. Note that C++ comparison operations default to returning
// bool, so it's not so common to have to call this
bool guard_bool(const char* file, int64_t line) const;
bool expect_true(const char* file, int64_t line) const;
bool guard_size_oblivious(const char* file, int64_t line) const;
bool guard_or_false(const char* file, int64_t line) const;
bool guard_or_true(const char* file, int64_t line) const;
bool has_hint() const;
bool as_bool_unchecked() const {
return data_;
}
std::optional<bool> maybe_as_bool() const {
if (!is_heap_allocated()) {
return data_;
}
return toSymNodeImplUnowned()->constant_bool();
}
bool is_heap_allocated() const {
return ptr_;
}
private:
// TODO: optimize to union
bool data_;
SymNode ptr_;
};
// Stream insertion for SymBool. Only declared here; the definition is out of
// line.
C10_API std::ostream& operator<<(std::ostream& out, const SymBool& value);
// TORCH_SYM_CHECK(cond, ...): TORCH_CHECK on a SymBool-valued `cond`. It calls
// (cond).expect_true() with the expansion site's __FILE__/__LINE__. Any extra
// arguments are forwarded to TORCH_CHECK.
#define TORCH_SYM_CHECK(cond, ...) \
  TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)

// TORCH_SYM_INTERNAL_ASSERT(cond, ...): as TORCH_SYM_CHECK, but routed through
// TORCH_INTERNAL_ASSERT.
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
  TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)

// TORCH_MAYBE_SYM_CHECK(cond, ...): dispatches on the decayed type of `cond`.
// A c10::SymBool goes through expect_true(__FILE__, __LINE__); anything else
// goes to a plain TORCH_CHECK.
//
// Names are fully qualified (::c10::SymBool, ::std::...) because a macro body
// is resolved at its expansion site, which is usually outside namespace c10.
// An unqualified `SymBool` would not be found there.
//
// Outside a template, the discarded branch of `if constexpr` is still fully
// checked. So `cond` must be valid in both branches unless this is expanded in
// templated code where its type is dependent.
// NOTE(review): no ';' follows the TORCH_CHECK invocations below; this relies
// on TORCH_CHECK expanding to a complete statement.
#define TORCH_MAYBE_SYM_CHECK(cond, ...)                                     \
  if constexpr (::std::is_same_v<                                            \
                    ::std::decay_t<decltype(cond)>,                          \
                    ::c10::SymBool>) {                                       \
    TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)         \
  } else {                                                                   \
    TORCH_CHECK((cond), __VA_ARGS__)                                         \
  }
// Plain-bool overload: the value is already concrete, so it is returned
// unchanged. The unused location parameters mirror the SymBool overload, so
// TORCH_GUARD_SIZE_OBLIVIOUS accepts either kind of condition.
inline bool guard_size_oblivious(bool b, const char* /*file*/,
                                 int64_t /*line*/) {
  return b;
}
// SymBool overload: forwards to the SymBool member of the same name (declared
// in the class, defined out of line), passing the call-site location along.
inline bool guard_size_oblivious(const c10::SymBool& b, const char* file,
                                 int64_t line) {
  const bool decided = b.guard_size_oblivious(file, line);
  return decided;
}
// Plain-bool overload: the value is already concrete, so it is returned
// unchanged. The unused location parameters mirror the SymBool overload, so
// TORCH_GUARD_OR_FALSE accepts either kind of condition.
inline bool guard_or_false(bool b, const char* /*file*/, int64_t /*line*/) {
  return b;
}
// SymBool overload: forwards to the SymBool member of the same name (declared
// in the class, defined out of line), passing the call-site location along.
inline bool guard_or_false(const c10::SymBool& b, const char* file,
                           int64_t line) {
  const bool decided = b.guard_or_false(file, line);
  return decided;
}
// Plain-bool overload: the value is already concrete, so it is returned
// unchanged. The unused location parameters mirror the SymBool overload, so
// TORCH_GUARD_OR_TRUE accepts either kind of condition.
inline bool guard_or_true(bool b, const char* /*file*/, int64_t /*line*/) {
  return b;
}
// SymBool overload: forwards to the SymBool member of the same name (declared
// in the class, defined out of line), passing the call-site location along.
inline bool guard_or_true(const c10::SymBool& b, const char* file,
                          int64_t line) {
  const bool decided = b.guard_or_true(file, line);
  return decided;
}
// Convenience wrappers that pass the expansion site's __FILE__/__LINE__ to the
// matching guard_* overload. Overloads exist for both `bool` and
// `const c10::SymBool&`, so `cond` may be either.
//
// Fully qualified (::c10::) so that expanding these inside a namespace that
// declares its own nested `c10` cannot resolve to the wrong function.
#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
  ::c10::guard_size_oblivious((cond), __FILE__, __LINE__)
#define TORCH_GUARD_OR_FALSE(cond) \
  ::c10::guard_or_false((cond), __FILE__, __LINE__)
#define TORCH_GUARD_OR_TRUE(cond) \
  ::c10::guard_or_true((cond), __FILE__, __LINE__)
} // namespace c10