-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
python_torch_functions.cpp
638 lines (563 loc) · 25.3 KB
/
python_torch_functions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
// ${generated_comment}
// Python bindings for torch.* functions implemented through ATen.
//
// The functions are bound as static methods on a class
// torch._C._VariableFunctions which is also aliased as Variable._torch
// and also copied into 'torch' module.
#include <Python.h>
// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/Dtype.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/tensor_layouts.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/cuda_lazy_init.h"
#include <ATen/ATen.h>
#include <functional>
#include <initializer_list>
#include <stdexcept>
#include <utility>
using at::Tensor;
using at::Device;
using at::Layout;
using at::Scalar;
using at::ScalarType;
using at::Backend;
using at::OptionalDeviceGuard;
using at::DeviceGuard;
using at::TensorOptions;
using at::IntArrayRef;
using at::Generator;
using at::TensorList;
using at::Dimname;
using at::DimnameList;
using at::ArrayRef;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
static PyObject* THPVariableFunctionsModule = NULL;
static void check_out_type_matches(Tensor result,
ScalarType scalarType, bool scalarType_is_none,
c10::optional<at::Layout> layout,
const Device& device, bool device_is_none) {
if (scalarType_is_none && !layout && device_is_none) { // common case
return;
}
if (!scalarType_is_none && result.scalar_type() != scalarType) {
AT_ERROR(
"dtype ", scalarType,
" does not match dtype of out parameter (", result.scalar_type(), ")");
}
auto scalarType_arg = scalarType_is_none ? result.scalar_type() : scalarType;
auto device_type_arg = device_is_none ? result.device().type() : device.type();
if (result.scalar_type() != scalarType_arg) {
AT_ERROR(
"scalar type ", scalarType_arg,
" does not match scalar type of out parameter (", result.scalar_type(), ")");
}
if (layout && result.layout() != *layout) {
AT_ERROR(
"layout ", *layout,
" does not match layout of out parameter (", result.layout(), ")");
}
if (result.device().type() != device_type_arg) {
AT_ERROR(
"device type ", device_type_arg,
" does not match device type of out parameter (", result.device().type(), ")");
}
}
// dispatch_arange overloads: thin wrappers that release the Python GIL (via
// pybind11::gil_scoped_release) for the duration of the underlying call.
//
// arange(end) written into the caller-supplied `result` via at::arange_out.
inline Tensor dispatch_arange(Scalar end, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::arange_out(result, end);
}
// arange(end) producing a new tensor described by `options`.
// maybe_initialize_cuda(options) is deliberately called before the GIL is
// released (presumably it may need to interact with Python -- TODO confirm).
inline Tensor dispatch_arange(Scalar end, const TensorOptions& options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::arange(end, options);
}
// arange(start, end, step) written into the caller-supplied `result`.
inline Tensor dispatch_arange(Scalar start, Scalar end, Scalar step, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::arange_out(result, start, end, step);
}
// arange(start, end, step) producing a new tensor described by `options`;
// same call ordering as the single-argument `options` overload.
inline Tensor dispatch_arange(Scalar start, Scalar end, Scalar step, const TensorOptions& options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::arange(start, end, step, options);
}
// Python binding for torch.arange.
//
// Two signatures are accepted (see the parser strings): arange(end) and
// arange(start, end, step=1). For each one, either a new tensor is created
// from the dtype/layout/device/pin_memory/requires_grad keywords or, when
// `out` is supplied, the values are written into `out` after validating that
// any explicit dtype/layout/device keywords agree with it. Passing
// `pin_memory=True` together with `out` is rejected.
// Arguments carrying a __torch_function__ override are forwarded to
// handle_torch_function instead.
static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<9> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  switch (r.idx) {
    case 0: {
      // Slots: 0=end 1=out 2=dtype 3=layout 4=device 5=pin_memory 6=requires_grad
      if (!r.isNone(1)) {
        TORCH_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible");
        check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3),
                               r.device(4), r.isNone(4));
        auto written = dispatch_arange(r.scalar(0), r.tensor(1));
        return wrap(written.set_requires_grad(r.toBool(6)));
      }
      const auto stop = r.scalar(0);
      // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X); the
      // optional accessor is used here so a dtype is only set when supplied.
      const c10::optional<ScalarType> requested_dtype = r.scalartypeOptional(2);
      const auto options = TensorOptions()
          .dtype(requested_dtype)
          .device(r.device(4))
          .layout(r.layout(3))
          .requires_grad(r.toBool(6))
          .pinned_memory(r.toBool(5));
      return wrap(dispatch_arange(stop, options));
    }
    case 1: {
      // Slots: 0=start 1=end 2=step 3=out 4=dtype 5=layout 6=device
      //        7=pin_memory 8=requires_grad
      if (!r.isNone(3)) {
        TORCH_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible");
        check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5),
                               r.device(6), r.isNone(6));
        auto written = dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3));
        return wrap(written.set_requires_grad(r.toBool(8)));
      }
      const auto first = r.scalar(0);
      const auto stop = r.scalar(1);
      const auto stride = r.scalar(2);
      // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X); the
      // optional accessor is used here so a dtype is only set when supplied.
      const c10::optional<ScalarType> requested_dtype = r.scalartypeOptional(4);
      const auto options = TensorOptions()
          .dtype(requested_dtype)
          .device(r.device(6))
          .layout(r.layout(5))
          .requires_grad(r.toBool(8))
          .pinned_memory(r.toBool(7));
      return wrap(dispatch_arange(first, stop, stride, options));
    }
    default:
      break;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// range(start, end, step) written into the caller-supplied `result` via
// at::range_out. The GIL is released first, then a device guard is installed
// for whatever device `result` lives on (device_of(result)).
inline Tensor dispatch_range(Scalar start, Scalar end, Scalar step, Tensor result) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(result));
return at::range_out(result, start, end, step);
}
// range(start, end, step) producing a new tensor described by `options`.
// maybe_initialize_cuda(options) runs while the GIL is still held; the GIL is
// then released and a DeviceGuard for options.device() is installed before
// calling torch::range.
inline Tensor dispatch_range(Scalar start, Scalar end, Scalar step, const TensorOptions& options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
DeviceGuard device_guard(options.device());
return torch::range(start, end, step, options);
}
// Python binding for the deprecated torch.range.
//
// Every call first issues a UserWarning (stack level 1) recommending
// torch.arange; if issuing the warning fails (for example because warnings
// are configured as errors), the pending Python exception is propagated.
// With `out` supplied, explicit dtype/layout/device keywords are validated
// against it and the values are written into `out`; otherwise a new tensor
// is created from the keywords.
static PyObject * THPVariable_range(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
  });

  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx != 0) {
    Py_RETURN_NONE;
  }

  const int warn_status = PyErr_WarnEx(
      PyExc_UserWarning,
      "torch.range is deprecated and will be removed in a future release "
      "because its behavior is inconsistent with Python's range builtin. "
      "Instead, use torch.arange, which produces values in [start, end).",
      1);
  if (warn_status != 0) {
    throw python_error();
  }

  // Slots: 0=start 1=end 2=step 3=out 4=dtype 5=layout 6=device 7=requires_grad
  if (!r.isNone(3)) {
    check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4),
                           r.layout(5), r.device(6), r.isNone(6));
    auto written = dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3));
    return wrap(written.set_requires_grad(r.toBool(7)));
  }

  const auto options = TensorOptions()
      .dtype(r.scalartype(4))
      .device(r.device(6))
      .layout(r.layout(5))
      .requires_grad(r.toBool(7));
  return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options));
  END_HANDLE_TH_ERRORS
}
// dispatch_full overloads: thin wrappers that release the Python GIL around
// at::full / at::full_out.
//
// New tensor of shape `size` filled with `fill_val`, described by `options`.
// maybe_initialize_cuda(options) is called before the GIL is released.
inline Tensor dispatch_full(
IntArrayRef size,
Scalar fill_val,
const TensorOptions& options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return at::full(size, fill_val, options);
}
// Same as above, additionally forwarding optional dimension `names` to the
// names-taking at::full overload.
inline Tensor dispatch_full(
IntArrayRef size,
Scalar fill_val,
c10::optional<DimnameList> names,
const TensorOptions& options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return at::full(size, fill_val, names, options);
}
// Fills the caller-supplied `result` via at::full_out.
inline Tensor dispatch_full(
IntArrayRef size,
Scalar fill_val,
Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::full_out(result, size, fill_val);
}
// Python binding for torch.full.
//
// Signature 0 takes an optional `out` tensor, signature 1 takes optional
// dimension `names`; all remaining arguments occupy the same slots in both.
// With `out` supplied, `pin_memory=True` is rejected and explicit
// dtype/layout/device keywords are validated against `out` before filling it.
// requires_grad is applied to the returned tensor on every path.
// Arguments carrying a __torch_function__ override are forwarded to
// handle_torch_function instead.
static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  // Slots common to both signatures:
  //   0=size 1=fill_value 3=dtype 4=layout 5=device 6=pin_memory 7=requires_grad
  // Slot 2 is `out` for signature 0 and `names` for signature 1.
  auto shape = r.intlist(0);
  auto fill = r.scalar(1);
  const auto options = TensorOptions{}
      .dtype(r.scalartypeOptional(3))
      .layout(r.layout(4))
      .device(r.device(5))
      .pinned_memory(r.toBool(6));
  const bool slot2_is_none = r.isNone(2);

  switch (r.idx) {
    case 0: {
      if (!slot2_is_none) {
        // full.out: validate the out tensor against the other keywords.
        auto result = r.tensor(2);
        TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible");
        check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4),
                               r.device(5), r.isNone(5));
        return wrap(dispatch_full(shape, fill, result).set_requires_grad(r.toBool(7)));
      }
      // full
      return wrap(dispatch_full(shape, fill, options).set_requires_grad(r.toBool(7)));
    }
    case 1: {
      // full.names
      if (slot2_is_none) {
        return wrap(dispatch_full(shape, fill, c10::nullopt, options).set_requires_grad(r.toBool(7)));
      }
      // Converts c10::optional<std::vector<...>> to c10::optional<ArrayRef<...>>.
      // `owned_names` must stay alive while `names` (a view into it) is in use.
      auto owned_names = r.toDimnameListOptional(2);
      c10::optional<DimnameList> names(*owned_names);
      return wrap(dispatch_full(shape, fill, names, options).set_requires_grad(r.toBool(7)));
    }
    default:
      break;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// dispatch_randint overloads: thin wrappers that release the Python GIL
// around the randint call. They vary along three axes: with or without a
// `low` bound, with or without an explicit `generator`, and writing into an
// existing `result` (at::randint_out) versus creating a new tensor from
// `options` (torch::randint). Every `options` overload calls
// maybe_initialize_cuda(options) before the GIL is released.
//
// (high, size, generator) -> result
inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional<Generator> generator, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::randint_out(result, high, size, generator);
}
// (high, size, generator) -> new tensor
inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional<Generator> generator, const TensorOptions & options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::randint(high, size, generator, options);
}
// (high, size) -> result
inline Tensor dispatch_randint(int64_t high, IntArrayRef size, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::randint_out(result, high, size);
}
// (high, size) -> new tensor
inline Tensor dispatch_randint(int64_t high, IntArrayRef size, const TensorOptions & options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::randint(high, size, options);
}
// (low, high, size, generator) -> result
inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional<Generator> generator, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::randint_out(result, low, high, size, generator);
}
// (low, high, size, generator) -> new tensor
inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional<Generator> generator, const TensorOptions & options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::randint(low, high, size, generator, options);
}
// (low, high, size) -> result
inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, Tensor result) {
pybind11::gil_scoped_release no_gil;
return at::randint_out(result, low, high, size);
}
// (low, high, size) -> new tensor
inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, const TensorOptions & options) {
torch::utils::maybe_initialize_cuda(options);
pybind11::gil_scoped_release no_gil;
return torch::randint(low, high, size, options);
}
// Python binding for torch.randint.
//
// Two signatures are accepted: randint(high, size) and
// randint(low, high, size). When `out` is supplied, explicit
// dtype/layout/device keywords are validated against it and the samples are
// written into `out`; otherwise a new tensor is created whose dtype falls back
// to at::ScalarType::Long when no dtype keyword is given.
// Arguments carrying a __torch_function__ override are forwarded to
// handle_torch_function instead.
static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
    "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
  }, /*traceable=*/false);

  ParsedArgs<9> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  switch (r.idx) {
    case 0: {
      // Slots: 0=high 1=size 2=generator 3=out 4=dtype 5=layout 6=device
      //        7=requires_grad
      if (!r.isNone(3)) {
        check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4),
                               r.layout(5), r.device(6), r.isNone(6));
        auto sampled = dispatch_randint(r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3));
        return wrap(sampled.set_requires_grad(r.toBool(7)));
      }
      const auto upper = r.toInt64(0);
      const auto shape = r.intlist(1);
      const auto gen = r.generator(2);
      // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X);
      // randint instead falls back to Long when no dtype is passed.
      const auto result_dtype = r.scalartypeWithDefault(4, at::ScalarType::Long);
      const auto target_device = r.device(6);
      const auto options = TensorOptions()
          .dtype(result_dtype)
          .device(target_device)
          .layout(r.layout(5))
          .requires_grad(r.toBool(7));
      return wrap(dispatch_randint(upper, shape, gen, options));
    }
    case 1: {
      // Slots: 0=low 1=high 2=size 3=generator 4=out 5=dtype 6=layout
      //        7=device 8=requires_grad
      if (!r.isNone(4)) {
        check_out_type_matches(r.tensor(4), r.scalartype(5), r.isNone(5),
                               r.layout(6), r.device(7), r.isNone(7));
        auto sampled = dispatch_randint(r.toInt64(0), r.toInt64(1), r.intlist(2), r.generator(3), r.tensor(4));
        return wrap(sampled.set_requires_grad(r.toBool(8)));
      }
      const auto lower = r.toInt64(0);
      const auto upper = r.toInt64(1);
      const auto shape = r.intlist(2);
      const auto gen = r.generator(3);
      // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X);
      // randint instead falls back to Long when no dtype is passed.
      const auto result_dtype = r.scalartypeWithDefault(5, at::ScalarType::Long);
      const auto target_device = r.device(7);
      const auto options = TensorOptions()
          .dtype(result_dtype)
          .device(target_device)
          .layout(r.layout(6))
          .requires_grad(r.toBool(8));
      return wrap(dispatch_randint(lower, upper, shape, gen, options));
    }
    default:
      break;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.as_tensor.
//
// Implemented directly on the Python objects so that torch.as_tensor can be
// built from arbitrarily nested Python data - list, tuple, np array, scalar,
// etc. Emits a tracer constructor warning, then defers to
// torch::utils::as_tensor using the current default dispatch key and default
// scalar type.
static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
  const auto default_key = torch::tensors::get_default_dispatch_key();
  const auto default_dtype = torch::tensors::get_default_scalar_type();
  return THPVariable_Wrap(torch::utils::as_tensor(default_key, default_dtype, args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.from_numpy (single positional argument, METH_O).
//
// Implemented on the Python object here because PyObject is currently not
// natively declarable; see ATen/native/README.md for more context.
// Emits a tracer constructor warning, then wraps the tensor produced by
// torch::utils::tensor_from_numpy.
static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR);
  auto converted = torch::utils::tensor_from_numpy(arg);
  return THPVariable_Wrap(std::move(converted));
  END_HANDLE_TH_ERRORS
}
// nonzero dispatch helpers. Each releases the Python GIL and then installs a
// device guard for the device of `self` before calling into ATen.
//
// Returns self.nonzero() as a new tensor.
static Tensor dispatch_nonzero(const Tensor & self) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
return self.nonzero();
}
// Writes the result into the caller-supplied `out` via at::nonzero_out.
// Note the guard is keyed on `self`, not on `out`.
static Tensor dispatch_nonzero(const Tensor & self, Tensor out) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
return at::nonzero_out(out, self);
}
// Returns self.nonzero_numpy(), a vector of tensors (used by the Python
// binding when as_tuple=True).
static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
return self.nonzero_numpy();
}
static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs);
// Python binding for torch.sparse_coo_tensor.
//
// Emits a tracer constructor warning, then defers to
// torch::utils::sparse_coo_tensor_ctor with the current default dispatch key
// and default scalar type, passing the raw Python args/kwargs through.
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
  const auto default_key = torch::tensors::get_default_dispatch_key();
  const auto default_dtype = torch::tensors::get_default_scalar_type();
  return THPVariable_Wrap(
      torch::utils::sparse_coo_tensor_ctor(default_key, default_dtype, args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._sparse_coo_tensor_unsafe.
//
// Emits a tracer constructor warning, then defers to
// torch::utils::_sparse_coo_tensor_unsafe_ctor with the current default
// dispatch key and default scalar type, passing the raw Python args/kwargs
// through.
static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
  const auto default_key = torch::tensors::get_default_dispatch_key();
  const auto default_dtype = torch::tensors::get_default_scalar_type();
  return THPVariable_Wrap(
      torch::utils::_sparse_coo_tensor_unsafe_ctor(default_key, default_dtype, args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.tensor.
//
// Implemented directly on the Python objects so that torch.tensor can be built
// from arbitrarily nested Python data - list, tuple, np array, scalar, etc.
// Emits a tracer constructor warning, then defers to
// torch::utils::tensor_ctor using the current default dispatch key and
// default scalar type.
static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
  const auto default_key = torch::tensors::get_default_dispatch_key();
  const auto default_dtype = torch::tensors::get_default_scalar_type();
  return THPVariable_Wrap(torch::utils::tensor_ctor(default_key, default_dtype, args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.get_device(input): returns the wrapped result of
// input.get_device(). Falls back to returning None if the parser reports a
// signature index other than 0.
static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "get_device(Tensor input)",
  }, /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  const auto input = r.tensor(0);
  return wrap(input.get_device());
  END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs);
// generated forward declarations start here
${py_forwards}
// Adapter that turns a TypeError raised by `Func` into a NotImplemented
// return value. Used to implement binary arithmetic operators.
//
// - If Func succeeds, its result is returned unchanged.
// - If Func fails with a TypeError, the error is cleared and a new reference
//   to Py_NotImplemented is returned.
// - If Func fails with any other exception, NULL is returned and the
//   exception is left pending.
template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) {
  PyObject* const result = Func(self, args, kwargs);
  if (result) {
    return result;
  }
  if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
    return result;  // still NULL; some other exception remains set
  }
  PyErr_Clear();
  Py_INCREF(Py_NotImplemented);
  return Py_NotImplemented;
}
// Method table for the torch._C._VariableFunctionsClass type (it is installed
// as tp_methods of THPVariableFunctions). Every entry is a static method.
//
// XXX: ops that are bound here are not exposed to the C++ api nor the JIT.
// Any new ops added here should be accompanied with a comment why they are not
// being registered through native_functions.yaml, and be tagged cpp / JIT
//
// Several entries are aliases onto another binding: "dsmm" and "spmm" both map
// to THPVariable_mm, "hsmm" maps to THPVariable_hspmm, and "saddmm" maps to
// THPVariable_sspaddmm. "from_numpy" is the only METH_O (single-argument)
// entry; all others take varargs + keywords.
// ${py_method_defs} is a template placeholder (presumably expanded by the code
// generator with the generated bindings); the table ends with a {NULL}
// sentinel.
static PyMethodDef torch_functions[] = {
{"arange", (PyCFunction)(void(*)(void))THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"as_tensor", (PyCFunction)(void(*)(void))THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
{"full", (PyCFunction)(void(*)(void))THPVariable_full, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", (PyCFunction)(void(*)(void))THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", (PyCFunction)(void(*)(void))THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_coo_tensor_unsafe", (PyCFunction)(void(*)(void))THPVariable__sparse_coo_tensor_unsafe, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_coo_tensor_args", (PyCFunction)(void(*)(void))THPVariable__validate_sparse_coo_tensor_args, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"numel", (PyCFunction)(void(*)(void))THPVariable_numel, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
${py_method_defs}
{NULL}
};
// Python type object named "torch._C._VariableFunctionsClass". Its only
// populated slots are tp_name, tp_flags (Py_TPFLAGS_DEFAULT) and tp_methods
// (the torch_functions table); every other slot is left zero/NULL.
// The initializer is positional, so the order of entries must match the
// PyTypeObject field order of the CPython version being built against
// (note the tp_vectorcall_offset slot) -- keep the slot comments in sync.
static PyTypeObject THPVariableFunctions = {
PyVarObject_HEAD_INIT(NULL, 0)
"torch._C._VariableFunctionsClass", /* tp_name */
0, /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_vectorcall_offset */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
NULL, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
torch_functions, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
0 /* tp_new */
};
// Registers the torch function bindings on `module`.
//
// Finalizes the THPVariableFunctions type, exposes it as
// `module._VariableFunctionsClass`, creates one instance of it, stores that
// instance in the file-level THPVariableFunctionsModule pointer (used when
// dispatching to __torch_function__ overrides) and exposes it as
// `module._VariableFunctions`.
//
// Throws python_error if any CPython call fails; the Python exception set by
// that call is left pending.
void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }

  // PyModule_AddObject steals a reference on success, so give it one to take.
  // (A single INCREF suffices; the type object is statically allocated.)
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    throw python_error();
  }

  // PyType_GenericNew returns a new reference, or NULL with an exception set.
  THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  if (!THPVariableFunctionsModule) {
    throw python_error();
  }

  // PyModule_AddObject steals a reference. Take an extra one first so that the
  // file-level THPVariableFunctionsModule pointer owns its own reference and
  // cannot dangle if the module attribute is later deleted or replaced.
  Py_INCREF(THPVariableFunctionsModule);
  if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    throw python_error();
  }
}
// generated methods start here
${py_methods}
// Python binding for torch.nonzero(input, *, as_tuple=False, out=None).
//
// - as_tuple=False, no out: returns dispatch_nonzero(input).
// - as_tuple=False, out given: writes into `out` and returns it.
// - as_tuple=True: returns the result of dispatch_nonzero_numpy(input);
//   combining it with `out` is rejected.
// Arguments carrying a __torch_function__ override are forwarded to
// handle_torch_function instead.
static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
  });

  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  const auto as_tuple = r.toBool(1);
  const auto has_out = !r.isNone(2);

  if (!as_tuple) {
    if (!has_out) {
      return wrap(dispatch_nonzero(r.tensor(0)));
    }
    return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
  }

  TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True");
  return wrap(dispatch_nonzero_numpy(r.tensor(0)));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.numel(input): returns the wrapped result of
// input.numel(). Arguments carrying a __torch_function__ override are
// forwarded to handle_torch_function instead. Falls back to returning None if
// the parser reports a signature index other than 0.
static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "numel(Tensor input)",
  }, /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  const auto input = r.tensor(0);
  return wrap(input.numel());
  END_HANDLE_TH_ERRORS
}
}} // namespace torch::autograd