-
Notifications
You must be signed in to change notification settings - Fork 66
/
wasm.h
450 lines (379 loc) · 17.6 KB
/
wasm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
// Copyright 2016-2019 Envoy Project Authors
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "include/proxy-wasm/sdk.h"
#include "include/proxy-wasm/context.h"
#include "include/proxy-wasm/exports.h"
#include "include/proxy-wasm/wasm_vm.h"
#include "include/proxy-wasm/vm_id_handle.h"
namespace proxy_wasm {
// Forward declarations of types that are only used by pointer/reference/shared_ptr in the
// declarations that follow.
class ContextBase;
class WasmHandleBase;
// Callable that produces a new WasmVm; ownership of the result is handed to the caller through
// the returned std::unique_ptr.
using WasmVmFactory = std::function<std::unique_ptr<WasmVm>()>;
// Callable that is handed a unit of work (a std::function<void()>).
// NOTE(review): presumably it runs or schedules that work on a specific thread — confirm against
// the overrides of WasmBase::callOnThreadFunction().
using CallOnThreadFunction = std::function<void(std::function<void()>)>;
// Per-capability sanitization settings stored as the mapped value of AllowedCapabilitiesMap.
// NOTE(review): how argument_list / is_allowlist are interpreted is decided by the code that
// consumes this struct, which is not part of this header — confirm there.
struct SanitizationConfig {
  // Argument names this configuration refers to.
  std::vector<std::string> argument_list;
  // Explicitly initialized so that a default-constructed instance never carries an indeterminate
  // value. The struct remains an aggregate, so `SanitizationConfig{list, flag}` keeps working.
  bool is_allowlist = false;
};
// Map from a name (looked up as the capability name by WasmBase::capabilityAllowed) to its
// sanitization settings.
using AllowedCapabilitiesMap = std::unordered_map<std::string, SanitizationConfig>;
// Wasm execution instance. Manages the host side of the Wasm interface.
//
// Holds the WasmVm plus the per-VM bookkeeping declared below: the contexts, the handles for the
// functions exported by the module, and the id generators for metrics and callouts.
// doAfterVmCallActions() calls shared_from_this(), so an instance that has queued actions must be
// owned by a std::shared_ptr.
class WasmBase : public std::enable_shared_from_this<WasmBase> {
public:
  WasmBase(std::unique_ptr<WasmVm> wasm_vm, std::string_view vm_id,
           std::string_view vm_configuration, std::string_view vm_key,
           std::unordered_map<std::string, std::string> envs,
           AllowedCapabilitiesMap allowed_capabilities);
  // Builds a Wasm from an existing base Wasm handle and a VM factory.
  // NOTE(review): presumably used for the thread-local instances mentioned next to
  // base_wasm_handle_ — the definition is not in this header, confirm there.
  WasmBase(const std::shared_ptr<WasmHandleBase> &base_wasm_handle, const WasmVmFactory &factory);
  virtual ~WasmBase();

  bool load(const std::string &code, bool allow_precompiled = false);
  bool initialize();
  void startVm(ContextBase *root_context);
  bool configure(ContextBase *root_context, std::shared_ptr<PluginBase> plugin);
  // Returns the root ContextBase or nullptr if onStart returns false.
  ContextBase *start(const std::shared_ptr<PluginBase> &plugin);

  std::string_view vm_id() const { return vm_id_; }
  std::string_view vm_key() const { return vm_key_; }
  WasmVm *wasm_vm() const { return wasm_vm_.get(); }
  ContextBase *vm_context() const { return vm_context_.get(); }

  ContextBase *getRootContext(const std::shared_ptr<PluginBase> &plugin, bool allow_closed);
  // Looks `id` up in contexts_; returns nullptr when no context is registered under that id.
  ContextBase *getContext(uint32_t id) {
    const auto it = contexts_.find(id);
    return it != contexts_.end() ? it->second : nullptr;
  }
  uint32_t allocContextId();
  // True once fail() has recorded anything other than FailState::Ok.
  bool isFailed() const { return failed_ != FailState::Ok; }
  FailState fail_state() const { return failed_; }

  const std::string &vm_configuration() const;

  const std::string &moduleBytecode() const { return module_bytecode_; }
  const std::string &modulePrecompiled() const { return module_precompiled_; }
  // Returned by const reference so that a lookup does not copy the whole map; callers that need
  // their own copy can still copy-initialize from the result.
  const std::unordered_map<uint32_t, std::string> &functionNames() const {
    return function_names_;
  }

  void timerReady(uint32_t root_context_id);
  void queueReady(uint32_t root_context_id, uint32_t token);

  void startShutdown(std::string_view plugin_key);
  void startShutdown();
  WasmResult done(ContextBase *root_context);
  void finishShutdown();

  // Proxy specific extension points.
  //
  virtual void registerCallbacks(); // Register functions called out from Wasm.
  virtual void getFunctions();      // Get functions call into Wasm.
  // Default implementation reports "unimplemented" through unimplemented() and returns an empty
  // function; embedders are expected to override it.
  virtual CallOnThreadFunction callOnThreadFunction() {
    unimplemented();
    return nullptr;
  }

  // Capability restriction (restricting/exposing the ABI).
  // An empty allowed_capabilities_ map means no restriction: every capability is allowed.
  bool capabilityAllowed(const std::string &capability_name) const {
    return allowed_capabilities_.empty() ||
           allowed_capabilities_.find(capability_name) != allowed_capabilities_.end();
  }

  // Context factories. The returned raw pointers are freshly allocated with `new`; the caller
  // takes ownership.
  virtual ContextBase *createVmContext() { return new ContextBase(this); }
  virtual ContextBase *createRootContext(const std::shared_ptr<PluginBase> &plugin) {
    return new ContextBase(this, plugin);
  }
  virtual ContextBase *createContext(const std::shared_ptr<PluginBase> &plugin) {
    return new ContextBase(this, plugin);
  }
  // Records the timer period for a root context id (inserting or overwriting the entry).
  virtual void setTimerPeriod(uint32_t root_context_id, std::chrono::milliseconds period) {
    timer_period_[root_context_id] = period;
  }

  // Support functions.
  //
  void *allocMemory(uint64_t size, uint64_t *address);
  // Allocate a null-terminated string in the VM and return the pointer to use as a call arguments.
  uint64_t copyString(std::string_view s);
  // Copy the data in 's' into the VM along with the pointer-size pair. Returns true on success.
  bool copyToPointerSize(std::string_view s, uint64_t ptr_ptr, uint64_t size_ptr);
  template <typename T> bool setDatatype(uint64_t ptr, const T &t);

  // Reports `message` through error() and then records `fail_state`; isFailed() is true afterwards
  // for any state other than FailState::Ok.
  void fail(FailState fail_state, std::string_view message) {
    error(message);
    failed_ = fail_state;
  }
  // Default error sink: writes the message followed by a newline to std::cerr.
  virtual void error(std::string_view message) { std::cerr << message << "\n"; }
  virtual void unimplemented() { error("unimplemented proxy-wasm API"); }

  AbiVersion abiVersion() const { return abi_version_; }

  const std::unordered_map<std::string, std::string> &envs() const { return envs_; }

  // Called to raise the flag which indicates that the context should stop iteration regardless of
  // returned filter status from Proxy-Wasm extensions. For example, we ignore
  // FilterHeadersStatus::Continue after a local response is sent by the host.
  void stopNextIteration(bool stop) { stop_iteration_ = stop; }
  bool isNextIterationStopped() const { return stop_iteration_; }

  // Queues `f` to be run by the next doAfterVmCallActions(). `f` is a by-value sink, so it is
  // moved into the queue rather than copied a second time.
  void addAfterVmCallAction(std::function<void()> f) {
    after_vm_call_actions_.push_back(std::move(f));
  }
  // Runs the queued actions in FIFO order. Each action is popped before it is invoked, and the
  // queue is re-checked after every call, so actions queued by an action run in the same drain.
  void doAfterVmCallActions() {
    // NB: this may be deleted by a delayed function unless prevented.
    if (!after_vm_call_actions_.empty()) {
      // Keep this object alive until the queue is drained, and access members only via `self`.
      auto self = shared_from_this();
      while (!self->after_vm_call_actions_.empty()) {
        auto f = std::move(self->after_vm_call_actions_.front());
        self->after_vm_call_actions_.pop_front();
        f();
      }
    }
  }

  // Metric ids carry the MetricType in their two low bits.
  static const uint32_t kMetricTypeMask = 0x3; // Enough to cover the 3 types.
  // Step between consecutive ids of one type; a multiple of 4 leaves the type bits untouched.
  static const uint32_t kMetricIdIncrement = 0x4;
  bool isCounterMetricId(uint32_t metric_id) const {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Counter);
  }
  bool isGaugeMetricId(uint32_t metric_id) const {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Gauge);
  }
  bool isHistogramMetricId(uint32_t metric_id) const {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Histogram);
  }
  // Each generator advances its own counter by kMetricIdIncrement and returns the new value. The
  // counters are seeded with the matching MetricType, so the type bits are preserved.
  uint32_t nextCounterMetricId() { return next_counter_metric_id_ += kMetricIdIncrement; }
  uint32_t nextGaugeMetricId() { return next_gauge_metric_id_ += kMetricIdIncrement; }
  uint32_t nextHistogramMetricId() { return next_histogram_metric_id_ += kMetricIdIncrement; }

  enum class CalloutType : uint32_t {
    HttpCall = 0,
    GrpcCall = 1,
    GrpcStream = 2,
  };
  // Callout ids use the same scheme as metric ids: CalloutType in the two low bits.
  static const uint32_t kCalloutTypeMask = 0x3; // Enough to cover the 3 types.
  // Step between consecutive ids of one type; a multiple of 4 leaves the type bits untouched.
  static const uint32_t kCalloutIncrement = 0x4;
  bool isHttpCallId(uint32_t callout_id) const {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::HttpCall);
  }
  bool isGrpcCallId(uint32_t callout_id) const {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::GrpcCall);
  }
  bool isGrpcStreamId(uint32_t callout_id) const {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::GrpcStream);
  }
  uint32_t nextHttpCallId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_http_call_id_ += kCalloutIncrement;
  }
  uint32_t nextGrpcCallId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_grpc_call_id_ += kCalloutIncrement;
  }
  uint32_t nextGrpcStreamId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_grpc_stream_id_ += kCalloutIncrement;
  }

protected:
  friend class ContextBase;
  class ShutdownHandle;

  void establishEnvironment(); // Language specific environments.

  std::string vm_id_;  // User-provided vm_id.
  std::string vm_key_; // vm_id + hash of code.
  std::unique_ptr<WasmVm> wasm_vm_;
  std::optional<Cloneable> started_from_;

  uint32_t next_context_id_ = 1;            // 0 is reserved for the VM context.
  std::shared_ptr<ContextBase> vm_context_; // Context unrelated to any specific root or stream
                                            // (e.g. for global constructors).
  std::unordered_map<std::string, std::unique_ptr<ContextBase>> root_contexts_; // Root contexts.
  std::unordered_map<std::string, std::unique_ptr<ContextBase>> pending_done_;  // Root contexts.
  std::unordered_set<std::unique_ptr<ContextBase>> pending_delete_;             // Root contexts.
  std::unordered_map<uint32_t, ContextBase *> contexts_;                 // Contains all contexts.
  std::unordered_map<uint32_t, std::chrono::milliseconds> timer_period_; // per root_id.
  std::unique_ptr<ShutdownHandle> shutdown_handle_;
  std::unordered_map<std::string, std::string>
      envs_; // environment variables passed through wasi.environ_get

  WasmCallVoid<0> _initialize_; /* WASI reactor (Emscripten v1.39.17+, Rust nightly) */
  WasmCallVoid<0> _start_;      /* WASI command (Emscripten v1.39.0+, TinyGo) */

  WasmCallWord<2> main_;
  WasmCallWord<1> malloc_;

  // Calls into the VM.
  WasmCallWord<2> validate_configuration_;
  WasmCallWord<2> on_vm_start_;
  WasmCallWord<2> on_configure_;
  WasmCallVoid<1> on_tick_;

  WasmCallVoid<2> on_context_create_;

  WasmCallWord<1> on_new_connection_;
  WasmCallWord<3> on_downstream_data_;
  WasmCallWord<3> on_upstream_data_;
  WasmCallVoid<2> on_downstream_connection_close_;
  WasmCallVoid<2> on_upstream_connection_close_;

  WasmCallWord<2> on_request_headers_abi_01_;
  WasmCallWord<3> on_request_headers_abi_02_;
  WasmCallWord<3> on_request_body_;
  WasmCallWord<2> on_request_trailers_;
  WasmCallWord<2> on_request_metadata_;

  WasmCallWord<2> on_response_headers_abi_01_;
  WasmCallWord<3> on_response_headers_abi_02_;
  WasmCallWord<3> on_response_body_;
  WasmCallWord<2> on_response_trailers_;
  WasmCallWord<2> on_response_metadata_;

  WasmCallVoid<5> on_http_call_response_;

  WasmCallVoid<3> on_grpc_receive_;
  WasmCallVoid<3> on_grpc_close_;
  WasmCallVoid<3> on_grpc_create_initial_metadata_;
  WasmCallVoid<3> on_grpc_receive_initial_metadata_;
  WasmCallVoid<3> on_grpc_receive_trailing_metadata_;

  WasmCallVoid<2> on_queue_ready_;
  WasmCallVoid<3> on_foreign_function_;

  WasmCallWord<1> on_done_;
  WasmCallVoid<1> on_log_;
  WasmCallVoid<1> on_delete_;

// X-macro over module function names: applies `_f` to each name listed. The list intentionally
// stays exactly as written; note that it does not contain on_request_headers,
// on_response_headers, on_grpc_create_initial_metadata or on_foreign_function.
// NOTE(review): those are presumably resolved separately by the user of this macro — confirm
// at its expansion site.
#define FOR_ALL_MODULE_FUNCTIONS(_f) \
  _f(validate_configuration) _f(on_vm_start) _f(on_configure) _f(on_tick) _f(on_context_create) \
  _f(on_new_connection) _f(on_downstream_data) _f(on_upstream_data) \
  _f(on_downstream_connection_close) _f(on_upstream_connection_close) _f(on_request_body) \
  _f(on_request_trailers) _f(on_request_metadata) _f(on_response_body) \
  _f(on_response_trailers) _f(on_response_metadata) _f(on_http_call_response) \
  _f(on_grpc_receive) _f(on_grpc_close) _f(on_grpc_receive_initial_metadata) \
  _f(on_grpc_receive_trailing_metadata) _f(on_queue_ready) _f(on_done) \
  _f(on_log) _f(on_delete)

  // Capabilities which are allowed to be linked to the module. If this is empty, restriction
  // is not enforced.
  AllowedCapabilitiesMap allowed_capabilities_;

  std::shared_ptr<WasmHandleBase> base_wasm_handle_;

  // Used by the base_wasm to enable non-clonable thread local Wasm(s) to be constructed.
  std::string module_bytecode_;
  std::string module_precompiled_;
  std::unordered_map<uint32_t, std::string> function_names_;

  // ABI version.
  AbiVersion abi_version_ = AbiVersion::Unknown;

  std::string vm_configuration_;
  bool stop_iteration_ = false;
  FailState failed_ = FailState::Ok; // Wasm VM fatal error.

  // Plugin Stats/Metrics
  uint32_t next_counter_metric_id_ = static_cast<uint32_t>(MetricType::Counter);
  uint32_t next_gauge_metric_id_ = static_cast<uint32_t>(MetricType::Gauge);
  uint32_t next_histogram_metric_id_ = static_cast<uint32_t>(MetricType::Histogram);

  // HTTP/gRPC callouts.
  uint32_t next_http_call_id_ = static_cast<uint32_t>(CalloutType::HttpCall);
  uint32_t next_grpc_call_id_ = static_cast<uint32_t>(CalloutType::GrpcCall);
  uint32_t next_grpc_stream_id_ = static_cast<uint32_t>(CalloutType::GrpcStream);

  // Actions to be done after the call into the VM returns.
  std::deque<std::function<void()>> after_vm_call_actions_;

  std::shared_ptr<VmIdHandle> vm_id_handle_;
};
// Callable that produces a WasmHandleBase for the given vm_id.
using WasmHandleFactory = std::function<std::shared_ptr<WasmHandleBase>(std::string_view vm_id)>;
// Callable that produces a WasmHandleBase from an existing handle (parameter 'wasm').
// NOTE(review): presumably the result is a clone of 'wasm' — confirm with the implementations
// passed to createWasm() / getOrCreateThreadLocalPlugin().
using WasmHandleCloneFactory =
std::function<std::shared_ptr<WasmHandleBase>(std::shared_ptr<WasmHandleBase> wasm)>;
// Handle which enables shutdown operations to run post deletion (e.g. post listener drain).
//
// Shares ownership of a WasmBase. When the handle is destroyed while it still references a Wasm,
// startShutdown() is invoked on that Wasm.
class WasmHandleBase : public std::enable_shared_from_this<WasmHandleBase> {
public:
  // `wasm_base` is a by-value sink: it is moved into the member to avoid an extra reference-count
  // increment/decrement pair.
  explicit WasmHandleBase(std::shared_ptr<WasmBase> wasm_base)
      : wasm_base_(std::move(wasm_base)) {}
  ~WasmHandleBase() {
    if (wasm_base_) {
      wasm_base_->startShutdown();
    }
  }

  bool canary(const std::shared_ptr<PluginBase> &plugin,
              const WasmHandleCloneFactory &clone_factory);

  // Drops this handle's reference to the Wasm. Afterwards wasm() yields an empty pointer and the
  // destructor no longer calls startShutdown().
  void kill() { wasm_base_ = nullptr; }

  // Returns a mutable reference to the held pointer; it is empty after kill().
  std::shared_ptr<WasmBase> &wasm() { return wasm_base_; }

protected:
  std::shared_ptr<WasmBase> wasm_base_;
  // NOTE(review): presumably caches canary() results per plugin key — confirm in the definition
  // of canary().
  std::unordered_map<std::string, bool> plugin_canary_cache_;
};
// Builds the key string used to identify a VM from its vm_id, configuration and code.
// NOTE(review): the exact composition is defined out of line; WasmBase documents vm_key_ as
// "vm_id + hash of code" — confirm how 'configuration' participates.
std::string makeVmKey(std::string_view vm_id, std::string_view configuration,
std::string_view code);
// Returns nullptr on failure (i.e. initialization of the VM fails).
// 'factory' creates the handle, 'clone_factory' produces handles from an existing one, and
// 'allow_precompiled' has the same name as the flag accepted by WasmBase::load().
std::shared_ptr<WasmHandleBase> createWasm(const std::string &vm_key, const std::string &code,
const std::shared_ptr<PluginBase> &plugin,
const WasmHandleFactory &factory,
const WasmHandleCloneFactory &clone_factory,
bool allow_precompiled);
// Get an existing ThreadLocal VM matching 'vm_key' or nullptr if there isn't one.
std::shared_ptr<WasmHandleBase> getThreadLocalWasm(std::string_view vm_key);
// Pairs a plugin with the Wasm handle it runs in. On destruction, shutdown of this plugin is
// started on the underlying Wasm via WasmBase::startShutdown(plugin key).
class PluginHandleBase : public std::enable_shared_from_this<PluginHandleBase> {
public:
  // Both arguments are by-value sinks and are moved into the members.
  explicit PluginHandleBase(std::shared_ptr<WasmHandleBase> wasm_handle,
                            std::shared_ptr<PluginBase> plugin)
      : plugin_(std::move(plugin)), wasm_handle_(std::move(wasm_handle)) {}
  ~PluginHandleBase() {
    if (!wasm_handle_ || !plugin_) {
      return;
    }
    // WasmHandleBase::kill() resets the handle's Wasm pointer, so it may be empty here even though
    // the handle itself is still alive; there is nothing to shut down in that case.
    const std::shared_ptr<WasmBase> &wasm = wasm_handle_->wasm();
    if (wasm) {
      wasm->startShutdown(plugin_->key());
    }
  }

  std::shared_ptr<PluginBase> &plugin() { return plugin_; }
  // Precondition: this object was constructed with a non-null Wasm handle. The returned pointer
  // itself may be empty if the handle has been kill()ed.
  std::shared_ptr<WasmBase> &wasm() { return wasm_handle_->wasm(); }

protected:
  std::shared_ptr<PluginBase> plugin_;
  std::shared_ptr<WasmHandleBase> wasm_handle_;
};
// Callable that produces a PluginHandleBase from a Wasm handle ('base_wasm') and a plugin.
using PluginHandleFactory = std::function<std::shared_ptr<PluginHandleBase>(
std::shared_ptr<WasmHandleBase> base_wasm, std::shared_ptr<PluginBase> plugin)>;
// Get an existing ThreadLocal VM matching 'vm_id' or create one using 'base_handle' by cloning or
// by using it as a template. The result is returned as a plugin handle built with
// 'plugin_factory'.
// NOTE(review): the original comment speaks of a VM matching 'vm_id' although the function returns
// a plugin handle; the exact lookup key is defined out of line — confirm there.
std::shared_ptr<PluginHandleBase> getOrCreateThreadLocalPlugin(
const std::shared_ptr<WasmHandleBase> &base_handle, const std::shared_ptr<PluginBase> &plugin,
const WasmHandleCloneFactory &clone_factory, const PluginHandleFactory &plugin_factory);
// Clear Base Wasm cache and the thread-local Wasm sandbox cache for the calling thread.
// Intended for tests only, as the name says.
void clearWasmCachesForTesting();
// Returns the VM configuration string: the one stored on the base Wasm when this instance has a
// base Wasm handle, otherwise the one stored on this instance.
inline const std::string &WasmBase::vm_configuration() const {
  if (!base_wasm_handle_) {
    return vm_configuration_;
  }
  return base_wasm_handle_->wasm()->vm_configuration_;
}
inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) {
if (!malloc_) {
return nullptr;
}
wasm_vm_->setRestrictedCallback(
true, {// logging (Proxy-Wasm)
"env.proxy_log",
// logging (stdout/stderr)
"wasi_unstable.fd_write", "wasi_snapshot_preview1.fd_write",
// time
"wasi_unstable.clock_time_get", "wasi_snapshot_preview1.clock_time_get"});
Word a = malloc_(vm_context(), size);
wasm_vm_->setRestrictedCallback(false);
if (!a.u64_) {
return nullptr;
}
auto memory = wasm_vm_->getMemory(a.u64_, size);
if (!memory) {
return nullptr;
}
*address = a.u64_;
return const_cast<void *>(reinterpret_cast<const void *>(memory.value().data()));
}
// Copies `s` into freshly allocated VM memory as a null-terminated string and returns its VM-side
// address.
//
// Returns 0 (the VM's null pointer) when `s` is empty or when the allocation fails; allocMemory()
// returns nullptr in that case, and writing through it would crash the host.
inline uint64_t WasmBase::copyString(std::string_view s) {
  if (s.empty()) {
    return 0; // nullptr
  }
  uint64_t pointer = 0;
  auto *m = static_cast<uint8_t *>(allocMemory((s.size() + 1), &pointer));
  if (m == nullptr) {
    return 0; // nullptr: nothing was allocated, so there is nothing to copy into.
  }
  memcpy(m, s.data(), s.size());
  m[s.size()] = 0; // terminator; the allocation is one byte larger than the payload.
  return pointer;
}
// Copies the bytes of `s` into newly allocated VM memory and writes the resulting address and
// length to the VM words at `ptr_ptr` and `size_ptr`.
//
// For an empty `s` nothing is allocated and address 0 / size 0 are written. Returns false when the
// allocation fails or when either word cannot be written; the size word is only attempted after
// the pointer word succeeded.
inline bool WasmBase::copyToPointerSize(std::string_view s, uint64_t ptr_ptr, uint64_t size_ptr) {
  const uint64_t length = s.size();
  uint64_t vm_address = 0;
  if (length != 0) {
    void *destination = allocMemory(length, &vm_address);
    if (destination == nullptr) {
      return false;
    }
    memcpy(destination, s.data(), length);
  }
  return wasm_vm_->setWord(ptr_ptr, Word(vm_address)) &&
         wasm_vm_->setWord(size_ptr, Word(length));
}
// Writes the object representation of `t` (sizeof(T) bytes) into VM memory at address `ptr` and
// returns the result reported by the VM's setMemory().
template <typename T> inline bool WasmBase::setDatatype(uint64_t ptr, const T &t) {
  const T *const source = &t;
  return wasm_vm_->setMemory(ptr, sizeof(T), source);
}
} // namespace proxy_wasm