/
Module.cpp
298 lines (258 loc) · 9.42 KB
/
Module.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
#include <Python.h>
#include <stdbool.h>
#include <unordered_map>
#include <TH/TH.h>
#include <THC/THCCachingAllocator.h>
#include "THCP.h"
#include "ModuleSparse.cpp"
// THC library state shared by every binding in this file. It is allocated and
// initialised in THCPModule_initCuda, which also publishes its address on the
// module it is given as the `_state_cdata` attribute.
THCState *state;
////////////////////////////////////////////////////////////////////////////////
// Class pointer cache
////////////////////////////////////////////////////////////////////////////////
// Cache the CUDA storage classes from `torch_module` in the THCP*StorageClass
// globals, then run each tensor type's postInit hook.
// Returns false (with a Python error set by the failing step) on the first
// missing class or failing postInit; later entries are not attempted.
static bool THCPModule_loadClasses(PyObject *torch_module)
{
  struct StorageClassSlot {
    PyObject **target;      // global that receives the (new) reference
    const char *attr_name;  // attribute name looked up on torch_module
  };
  const StorageClassSlot storage_slots[] = {
    {&THCPDoubleStorageClass, "DoubleStorage"},
    {&THCPFloatStorageClass,  "FloatStorage"},
    {&THCPHalfStorageClass,   "HalfStorage"},
    {&THCPLongStorageClass,   "LongStorage"},
    {&THCPIntStorageClass,    "IntStorage"},
    {&THCPShortStorageClass,  "ShortStorage"},
    {&THCPCharStorageClass,   "CharStorage"},
    {&THCPByteStorageClass,   "ByteStorage"},
  };

  for (const StorageClassSlot &slot : storage_slots) {
    *slot.target = PyObject_GetAttrString(torch_module, (char*)slot.attr_name);
    if (*slot.target == NULL) {
      THPUtils_setError("couldn't load classes");
      return false;
    }
  }

  // Same order as the storages; && stops at the first hook that fails.
  return THCPDoubleTensor_postInit(torch_module) &&
         THCPFloatTensor_postInit(torch_module) &&
         THCPHalfTensor_postInit(torch_module) &&
         THCPLongTensor_postInit(torch_module) &&
         THCPIntTensor_postInit(torch_module) &&
         THCPShortTensor_postInit(torch_module) &&
         THCPCharTensor_postInit(torch_module) &&
         THCPByteTensor_postInit(torch_module);
}
////////////////////////////////////////////////////////////////////////////////
// Tensor stateless methods
////////////////////////////////////////////////////////////////////////////////
static bool THCPModule_assignStateless()
{
#define INIT_STATELESS(type) INIT_STATELESS_DETAIL(type, TH_CONCAT_2(Cuda, type))
#define INIT_STATELESS_DETAIL(type,ctype) \
stateless = PyObject_Call((PyObject*)&TH_CONCAT_2(ctype, TensorStatelessType), arg, NULL); \
if (!stateless) { \
THPUtils_setError("stateless method initialization error"); \
return false; \
} \
if (PyObject_SetAttrString(TH_CONCAT_3(THCP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
THPUtils_setError("stateless method initialization error (on assignment)");\
}
PyObject *arg = PyTuple_New(0);
PyObject *stateless;
INIT_STATELESS(Double);
INIT_STATELESS_DETAIL(Float, Cuda);
INIT_STATELESS(Half);
INIT_STATELESS(Long);
INIT_STATELESS(Int);
INIT_STATELESS(Short);
INIT_STATELESS(Char);
INIT_STATELESS(Byte);
Py_DECREF(arg);
return true;
#undef INIT_STATELESS_DETAIL
#undef INIT_STATELESS
}
////////////////////////////////////////////////////////////////////////////////
// CUDA management methods
////////////////////////////////////////////////////////////////////////////////
void THCPModule_setDevice(int device)
{
THCudaCheck(cudaSetDevice(device));
}
// Python binding: select the CUDA device given by the integer `arg`.
// Raises if `arg` is not an integer or does not fit in the `int` that
// cudaSetDevice takes (previously it was silently narrowed, so e.g. 2**32
// would have selected device 0). Returns None.
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to setDevice");
  long device = THPUtils_unpackLong(arg);
  THPUtils_assert(static_cast<long>(static_cast<int>(device)) == device,
      "invalid argument to setDevice: device index %ld out of range", device);
  THCPModule_setDevice(static_cast<int>(device));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: return the index reported by cudaGetDevice as a Python int.
PyObject * THCPModule_getDevice_wrap(PyObject *self)
{
  HANDLE_TH_ERRORS
  int current_device = 0;
  THCudaCheck(cudaGetDevice(&current_device));
  return PyLong_FromLong(current_device);
  END_HANDLE_TH_ERRORS
}
// Python binding: return the count reported by cudaGetDeviceCount as a
// Python int; a failing status goes through THCudaCheck.
PyObject * THCPModule_getDeviceCount_wrap(PyObject *self)
{
  HANDLE_TH_ERRORS
  int device_count = 0;
  THCudaCheck(cudaGetDeviceCount(&device_count));
  return PyLong_FromLong(device_count);
  END_HANDLE_TH_ERRORS
}
// Python binding: return the current THCStream pointer from the global THC
// state, encoded as an integer handle via PyLong_FromVoidPtr.
PyObject * THCPModule_getCurrentStream_wrap(PyObject *self)
{
  HANDLE_TH_ERRORS
  return PyLong_FromVoidPtr(THCState_getStream(state));
  END_HANDLE_TH_ERRORS
}
// Python binding: make the THCStream whose address is encoded in the integer
// `obj` the current stream of the global THC state. Returns None.
// Raises if `obj` is not a Python long or its value cannot be converted to a
// pointer.
PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(PyLong_Check(obj), "invalid stream");
  THCStream* stream = (THCStream *)PyLong_AsVoidPtr(obj);
  // PyLong_AsVoidPtr signals failure (e.g. overflow) by returning NULL with an
  // exception set. NULL may also be a legitimate handle (presumably the
  // default stream), so the pending error is the only reliable signal.
  if (!stream && PyErr_Occurred()) {
    return NULL;
  }
  THCState_setStream(state, stream);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: probe the driver with cudaGetDeviceCount. Returns False only
// when the probe reports cudaErrorInsufficientDriver; every other outcome,
// including other errors, returns True.
PyObject * THCPModule_isDriverSufficient(PyObject *self)
{
  int device_count;
  const cudaError_t probe = cudaGetDeviceCount(&device_count);
  return PyBool_FromLong(probe == cudaErrorInsufficientDriver ? 0 : 1);
}
// Python binding: return the value from cudaDriverGetVersion as a Python int.
// On failure, raises RuntimeError carrying the numeric error code and
// cudaGetErrorString's text.
PyObject * THCPModule_getDriverVersion(PyObject *self)
{
  int version = -1;
  const cudaError_t status = cudaDriverGetVersion(&version);
  if (status == cudaSuccess) {
    return PyLong_FromLong(static_cast<long>(version));
  }
  PyErr_Format(PyExc_RuntimeError,
               "Error calling cudaDriverGetVersion: %d %s",
               status, cudaGetErrorString(status));
  return NULL;
}
// Python binding: create an empty ByteTensor, fill it with the CUDA RNG state
// via THCRandom_getRNGState, and return it (ownership is released to the
// caller). Returns NULL if the tensor could not be created.
PyObject * THCPModule_getRNGState(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  THPByteTensorPtr rng_state = (THPByteTensor *)THPByteTensor_NewEmpty();
  if (!rng_state) {
    return NULL;
  }
  THCRandom_getRNGState(state, rng_state->cdata);
  return (PyObject *)rng_state.release();
  END_HANDLE_TH_ERRORS
}
// Python binding: restore the CUDA RNG state from a torch.ByteTensor via
// THCRandom_setRNGState. Raises for any other argument type; returns None.
PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *_new_rng_state)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPByteTensor_Check(_new_rng_state), "set_rng_state expects a "
      "torch.ByteTensor, but got %s", THPUtils_typename(_new_rng_state));
  THCRandom_setRNGState(state, ((THPByteTensor *)_new_rng_state)->cdata);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: seed the CUDA generator via THCRandom_manualSeed using the
// integer `seed`. Raises if `seed` is not an integer; returns None.
PyObject * THCPModule_manualSeed(PyObject *_unused, PyObject *seed)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, "
      "but got %s", THPUtils_typename(seed));
  const auto new_seed = THPUtils_unpackLong(seed);
  THCRandom_manualSeed(state, new_seed);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: seed via THCRandom_manualSeedAll using the integer `seed`.
// Raises if `seed` is not an integer; the message names manual_seed_all
// rather than manual_seed, so users see the function they actually called.
// Returns None.
PyObject * THCPModule_manualSeedAll(PyObject *_unused, PyObject *seed)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(seed), "manual_seed_all expected a long, "
          "but got %s", THPUtils_typename(seed));
  THCRandom_manualSeedAll(state, THPUtils_unpackLong(seed));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: call THCRandom_seed and return the seed it reports as a
// Python int.
PyObject * THCPModule_seed(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  const auto new_seed = THCRandom_seed(state);
  return PyLong_FromUnsignedLong(new_seed);
  END_HANDLE_TH_ERRORS
}
// Python binding: call THCRandom_seedAll and return the seed it reports as a
// Python int.
PyObject * THCPModule_seedAll(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  const auto new_seed = THCRandom_seedAll(state);
  return PyLong_FromUnsignedLong(new_seed);
  END_HANDLE_TH_ERRORS
}
// Python binding: return the value of THCRandom_initialSeed as a Python int.
PyObject * THCPModule_initialSeed(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  const auto initial = THCRandom_initialSeed(state);
  return PyLong_FromUnsignedLong(initial);
  END_HANDLE_TH_ERRORS
}
// Python binding: return the address of the THC state's host allocator as an
// integer handle.
PyObject * THCPModule_cudaHostAllocator(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  return PyLong_FromVoidPtr(THCState_getCudaHostAllocator(state));
  END_HANDLE_TH_ERRORS
}
// Python binding: call cudaDeviceSynchronize, route its status through
// THCudaCheck, and return None.
PyObject * THCPModule_cudaSynchronize(PyObject *_unused)
{
  HANDLE_TH_ERRORS
  const cudaError_t status = cudaDeviceSynchronize();
  THCudaCheck(status);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.cuda._sleep: forward the integer `cycles` to
// THC_sleep. Raises if `cycles` is not an integer; returns None.
PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(cycles), "torch.cuda._sleep(): expected 'int'");
  const auto n_cycles = THPUtils_unpackLong(cycles);
  THC_sleep(LIBRARY_STATE n_cycles);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding: return the compile-time CUDA_LIB_PATH macro (presumably
// supplied by the build system) as a Python string (str on Python 2, unicode
// on Python 3).
PyObject * THCPModule_getLibPath(PyObject *_unused)
{
  // Two-level expansion so CUDA_LIB_PATH is macro-expanded before being
  // stringized. The helper names are project-prefixed: `_STR` is a reserved
  // identifier (leading underscore + uppercase), and `#undef STR` could
  // clobber an STR macro defined elsewhere.
#define THCP_STRINGIFY_IMPL(x) #x
#define THCP_STRINGIFY(x) THCP_STRINGIFY_IMPL(x)
#if PY_MAJOR_VERSION == 2
  return PyString_FromString(THCP_STRINGIFY(CUDA_LIB_PATH));
#else
  return PyUnicode_FromString(THCP_STRINGIFY(CUDA_LIB_PATH));
#endif
#undef THCP_STRINGIFY
#undef THCP_STRINGIFY_IMPL
}
////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
// Allocate and initialise the global THC state, then publish build flags,
// classes and the state pointer on `torch_module`.
// Attributes set: has_magma, has_half (bools) and _state_cdata (the address
// of `state`). Returns false as soon as any step fails; the failing CPython or
// helper call is expected to have set a Python error.
bool THCPModule_initCuda(PyObject *torch_module) {
#define ASSERT_TRUE(cond) do { if (!(cond)) { return false; } } while (0)
  state = THCState_alloc();
  THCState_setDeviceAllocator(state, THCCachingAllocator_get());
  state->cudaHostAllocator = &THCCachingHostAllocator;
  THCudaInit(state);

  // Py_True/Py_False are passed as borrowed references: PyObject_SetAttrString
  // takes its own, whereas the PyBool_FromLong results used previously leaked
  // one reference each.
#ifdef USE_MAGMA
  THCMagma_init(state);
  ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_magma", Py_True) != -1);
#else
  ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_magma", Py_False) != -1);
#endif

#ifdef CUDA_HALF_TENSOR
  ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_half", Py_True) != -1);
#else
  ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_half", Py_False) != -1);
#endif

  ASSERT_TRUE(THCPModule_loadClasses(torch_module));
  ASSERT_TRUE(THCPModule_assignStateless());

  // Check the new reference and release it once the attribute holds its own.
  PyObject *state_cdata = PyLong_FromVoidPtr(state);
  ASSERT_TRUE(state_cdata);
  int rc = PyObject_SetAttrString(torch_module, "_state_cdata", state_cdata);
  Py_DECREF(state_cdata);
  ASSERT_TRUE(rc != -1);

  // TODO: register THCudaShutdown handler at exit
  return true;
#undef ASSERT_TRUE
}
// Callback for the Python side, used for additional initialization of the
// Python classes. Imports torch.cuda and runs THCPModule_initCuda on it.
// Returns True on success. If initialization fails with a Python error
// pending, returns NULL so that error propagates, instead of returning a value
// while an exception is set. Returns False if it fails without an error.
PyObject * THCPModule_initExtension(PyObject *self)
{
  PyObject *torch_module = PyImport_ImportModule("torch.cuda");
  if (!torch_module) {
    THPUtils_setError("class loader couldn't access torch module");
    return NULL;
  }
  bool ok = THCPModule_initCuda(torch_module);
  // PyImport_ImportModule returned a new reference; drop it once done.
  Py_DECREF(torch_module);
  if (!ok && PyErr_Occurred()) {
    return NULL;
  }
  return PyBool_FromLong(ok);
}