From a4dfaa6f5b5f325023215b98c37bfc5c043cdb9d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 6 Oct 2022 21:56:44 -0700 Subject: [PATCH 01/36] Move C/C++ files to core --- setup.py | 16 +- torchdynamo/_eval_frame.c | 540 ------------------------------- torchdynamo/_guards.cpp | 360 --------------------- torchdynamo/allowed_functions.py | 8 +- torchdynamo/eval_frame.py | 6 +- torchdynamo/guards.py | 6 +- 6 files changed, 13 insertions(+), 923 deletions(-) delete mode 100644 torchdynamo/_eval_frame.c delete mode 100644 torchdynamo/_guards.cpp diff --git a/setup.py b/setup.py index 73f4886fb1..09b3848d7d 100755 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ ] install_requires = [ - "torch>=1.12.0", + "torch>=1.13.0", "numpy", "tabulate", "pyyaml", @@ -49,7 +49,7 @@ long_description=long_description, long_description_content_type="text/markdown", author="Jason Ansel", - author_email="jansel@fb.com", + author_email="jansel@meta.com", license="BSD-3", keywords="pytorch machine learning compilers", python_requires=">=3.7, <3.11", @@ -59,16 +59,4 @@ "torchinductor.codegen": ["*.h", "*.j2"], }, zip_safe=False, - ext_modules=[ - Extension( - "torchdynamo._eval_frame", - ["torchdynamo/_eval_frame.c"], - extra_compile_args=["-Wall"], - ), - CppExtension( - name="torchdynamo._guards", - sources=["torchdynamo/_guards.cpp"], - extra_compile_args=["-std=c++14"], - ), - ], ) diff --git a/torchdynamo/_eval_frame.c b/torchdynamo/_eval_frame.c deleted file mode 100644 index b1bb870fc3..0000000000 --- a/torchdynamo/_eval_frame.c +++ /dev/null @@ -1,540 +0,0 @@ -#define PY_SSIZE_T_CLEAN -#include -#include -#include - -// see https://bugs.python.org/issue35886 -#if PY_VERSION_HEX >= 0x03080000 -#define Py_BUILD_CORE -#include "internal/pycore_pystate.h" -#undef Py_BUILD_CORE -#endif - -#define bool char -#define false 0 -#define true 1 -#define unlikely(x) __builtin_expect((x), 0) - -#define NULL_CHECK(val) \ - if (unlikely((val) == NULL)) { \ - fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \ - PyErr_Print(); \ - abort(); \ - } else { \ - } - -#define CHECK(cond) \ - if (unlikely(!(cond))) { \ - fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \ - abort(); \ - } else { \ - } - -#ifdef TORCHDYNAMO_DEBUG - -#define DEBUG_CHECK(cond) CHECK(cond) -#define DEBUG_NULL_CHECK(val) NULL_CHECK(val) -#define DEBUG_TRACE(msg, ...) \ - fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__) -#define DEBUG_TRACE0(msg) \ - fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__) - -#else - -#define DEBUG_CHECK(cond) -#define DEBUG_NULL_CHECK(val) -#define DEBUG_TRACE(msg, ...) 
-#define DEBUG_TRACE0(msg) - -#endif - -// Flag to just run a frame normally -#define SKIP_CODE ((void *)0x1) - -static PyObject *noargs = NULL; /* cached empty tuple */ -static PyObject *dotzerokey = NULL; /* ".0" */ -static PyObject *guard_fail_hook = NULL; -static PyObject *guard_error_hook = NULL; - -size_t extra_index = -1; - -static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT; - -inline static PyObject *eval_frame_callback_get(void) { - void *result = PyThread_tss_get(&eval_frame_callback_key); - if (unlikely(result == NULL)) { - Py_RETURN_NONE; - } else { - return (PyObject *)result; - } -} - -inline static void eval_frame_callback_set(PyObject *obj) { - PyThread_tss_set(&eval_frame_callback_key, obj); -} - -static void ignored(void *obj) {} -static PyObject *_custom_eval_frame_shim(PyThreadState *tstate, - PyFrameObject *frame, int throw_flag); -static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame, - int throw_flag, PyObject *callback); -#if PY_VERSION_HEX >= 0x03090000 -static PyObject *custom_eval_frame_shim(PyThreadState *tstate, - PyFrameObject *frame, int throw_flag) { - return _custom_eval_frame_shim(tstate, frame, throw_flag); -} -#else -static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) { - PyThreadState *tstate = PyThreadState_GET(); - return _custom_eval_frame_shim(tstate, frame, throw_flag); -} -#endif - -inline static PyObject *eval_frame_default(PyThreadState *tstate, - PyFrameObject *frame, - int throw_flag) { -#if PY_VERSION_HEX >= 0x03090000 - if (tstate == NULL) { - tstate = PyThreadState_GET(); - } - return _PyEval_EvalFrameDefault(tstate, frame, throw_flag); -#else - return _PyEval_EvalFrameDefault(frame, throw_flag); -#endif -} - -inline static void enable_eval_frame_shim(PyThreadState *tstate) { -#if PY_VERSION_HEX >= 0x03090000 - if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != - &custom_eval_frame_shim) { - _PyInterpreterState_SetEvalFrameFunc(tstate->interp, - &custom_eval_frame_shim); - } -#else - if (tstate->interp->eval_frame != &custom_eval_frame_shim) { - // First call - tstate->interp->eval_frame = &custom_eval_frame_shim; - } -#endif -} - -inline static void enable_eval_frame_default(PyThreadState *tstate) { -#if PY_VERSION_HEX >= 0x03090000 - if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != - &_PyEval_EvalFrameDefault) { - _PyInterpreterState_SetEvalFrameFunc(tstate->interp, - &_PyEval_EvalFrameDefault); - } -#else - if (tstate->interp->eval_frame != &_PyEval_EvalFrameDefault) { - // First call - tstate->interp->eval_frame = &_PyEval_EvalFrameDefault; - } -#endif -} - -static inline PyObject *call_callback(PyObject *callable, PyObject *frame, - long cache_len) { - PyObject *args = Py_BuildValue("(Ol)", frame, cache_len); - NULL_CHECK(args); - PyObject *result = PyObject_CallObject(callable, args); - Py_DECREF(args); - return result; -} - -typedef struct cache_entry { - // check the guards: lambda: : bool - PyObject *check_fn; - // modified user bytecode (protected by check_fn's guards) - PyCodeObject *code; - // on a cache miss, linked list of next thing to try - struct cache_entry *next; -} CacheEntry; - -static CacheEntry *create_cache_entry(CacheEntry *next, - PyObject *guarded_code) { - CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry)); - DEBUG_NULL_CHECK(e); - e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); - NULL_CHECK(e->check_fn); - e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code"); - NULL_CHECK(e->code); - e->next 
= next; - return e; -} - -static void destroy_cache_entry(CacheEntry *e) { - if (e == NULL || e == SKIP_CODE) { - return; - } - Py_XDECREF(e->check_fn); - Py_XDECREF(e->code); - destroy_cache_entry(e->next); - free(e); -} - -#ifdef TORCHDYNAMO_DEBUG -inline static const char *name(PyFrameObject *frame) { - DEBUG_CHECK(PyUnicode_Check(frame->f_code->co_name)); - return PyUnicode_AsUTF8(frame->f_code->co_name); -} -#endif - -static void call_guard_fail_hook(PyObject *hook, CacheEntry *e, - PyObject *f_locals) { - // call debugging logic when a guard fails - PyObject *args = PyTuple_Pack(4, e->check_fn, e->code, f_locals, - (e->next == NULL ? Py_True : Py_False)); - NULL_CHECK(args); - PyObject *result = PyObject_CallObject(hook, args); - NULL_CHECK(result); - Py_DECREF(result); - Py_DECREF(args); -} - -static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) { - if (e == NULL) { - return NULL; - } - PyObject *dotzero = PyDict_GetItem(f_locals, dotzerokey); - PyObject *valid = NULL; - if (unlikely(dotzero != NULL)) { - // .0 is a special variable name used for implicit args - PyObject *args = PyTuple_Pack(1, dotzero); - NULL_CHECK(args); - valid = PyObject_Call(e->check_fn, args, f_locals); - Py_DECREF(args); - } else { - valid = PyObject_Call(e->check_fn, noargs, f_locals); - } - if (unlikely(valid == NULL)) { - PyErr_Print(); - if (guard_error_hook != NULL) { - call_guard_fail_hook(guard_error_hook, e, f_locals); - } - NULL_CHECK(valid); - } - Py_DECREF(valid); - if (valid == Py_True) { - return e->code; - } - if (unlikely(guard_fail_hook != NULL)) { - call_guard_fail_hook(guard_fail_hook, e, f_locals); - } - return lookup(e->next, f_locals); -} - -static long cache_size(CacheEntry *e) { - if (e == NULL) { - return 0; - } - return 1 + cache_size(e->next); -} - -inline static CacheEntry *get_extra(PyCodeObject *code) { - CacheEntry *extra = NULL; - _PyCode_GetExtra((PyObject *)code, extra_index, (void *)&extra); - return extra; -} - -inline static void set_extra(PyCodeObject *code, CacheEntry *extra) { - // TODO(jansel): would it be faster to bypass this? 
- _PyCode_SetExtra((PyObject *)code, extra_index, extra); -} - -inline static PyObject *eval_custom_code(PyThreadState *tstate, - PyFrameObject *frame, - PyCodeObject *code, int throw_flag) { - Py_ssize_t ncells = 0; - Py_ssize_t nfrees = 0; - Py_ssize_t nlocals_new = code->co_nlocals; - Py_ssize_t nlocals_old = frame->f_code->co_nlocals; - - if ((code->co_flags & CO_NOFREE) == 0) { - ncells = PyTuple_GET_SIZE(code->co_cellvars); - nfrees = PyTuple_GET_SIZE(code->co_freevars); - } - - DEBUG_NULL_CHECK(tstate); - DEBUG_NULL_CHECK(frame); - DEBUG_NULL_CHECK(code); - DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars)); - DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars)); - DEBUG_CHECK(nlocals_new >= nlocals_old); - - PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL); - if (shadow == NULL) { - return NULL; - } - - PyObject **fastlocals_old = frame->f_localsplus; - PyObject **fastlocals_new = shadow->f_localsplus; - - for (Py_ssize_t i = 0; i < nlocals_old; i++) { - Py_XINCREF(fastlocals_old[i]); - fastlocals_new[i] = fastlocals_old[i]; - } - - for (Py_ssize_t i = 0; i < ncells + nfrees; i++) { - Py_XINCREF(fastlocals_old[nlocals_old + i]); - fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i]; - } - - PyObject *result = eval_frame_default(tstate, shadow, throw_flag); - Py_DECREF(shadow); - return result; -} - -static PyObject *_custom_eval_frame_shim(PyThreadState *tstate, - PyFrameObject *frame, int throw_flag) { - // Shims logic into one of three states. Can probably be refactored into a - // single func, later: - // - None: disables TorchDynamo - // - False: run-only mode (reuse existing compiles) - // - Python callable(): enables TorchDynamo - PyObject *callback = eval_frame_callback_get(); - - if (callback == Py_None) { - return eval_frame_default(tstate, frame, throw_flag); - } - - return _custom_eval_frame(tstate, frame, throw_flag, callback); -} - -static PyObject *_custom_eval_frame(PyThreadState *tstate, PyFrameObject *frame, - int throw_flag, PyObject *callback) { - DEBUG_TRACE("begin %s %s %i %i %i %i", name(frame), - PyUnicode_AsUTF8(frame->f_code->co_filename), frame->f_lineno, - frame->f_lasti, frame->f_iblock, frame->f_executing); - CacheEntry *extra = get_extra(frame->f_code); - if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) { - DEBUG_TRACE("skip %s", name(frame)); - return eval_frame_default(tstate, frame, throw_flag); - } - - // TODO(jansel): investigate directly using the "fast" representation - if (PyFrame_FastToLocalsWithError(frame) < 0) { - DEBUG_TRACE("error %s", name(frame)); - return NULL; - } - - // A callback of Py_False indicates "run only" mode, the cache is checked, but - // we never compile. - if (callback == Py_False) { - DEBUG_TRACE("In run only mode %s", name(frame)); - PyCodeObject *cached_code = lookup(extra, frame->f_locals); - if (cached_code != NULL) { - // used cached version - DEBUG_TRACE("cache hit %s", name(frame)); - return eval_custom_code(tstate, frame, cached_code, throw_flag); - } else { - DEBUG_TRACE("cache miss %s", name(frame)); - return eval_frame_default(tstate, frame, throw_flag); - } - } - DEBUG_CHECK(PyDict_CheckExact(frame->f_locals)); - DEBUG_CHECK(PyDict_CheckExact(frame->f_globals)); - DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins)); - - // We don't run the current custom_eval_frame behavior for guards. - // So we temporarily set the callback to Py_None to drive the correct behavior - // in the shim. 
- eval_frame_callback_set(Py_None); - - PyCodeObject *cached_code = lookup(extra, frame->f_locals); - if (cached_code != NULL) { - // used cached version - DEBUG_TRACE("cache hit %s", name(frame)); - // Re-enable custom behavior - eval_frame_callback_set(callback); - return eval_custom_code(tstate, frame, cached_code, throw_flag); - } - // cache miss - - PyObject *result = - call_callback(callback, (PyObject *)frame, cache_size(extra)); - if (result == NULL) { - // internal exception, returning here will leak the exception into user code - // this is useful for debugging -- but we dont want it to happen outside of - // testing - return NULL; - } else if (result != Py_None) { - DEBUG_TRACE("create cache %s", name(frame)); - extra = create_cache_entry(extra, result); - Py_DECREF(result); - set_extra(frame->f_code, extra); - // Re-enable custom behavior - eval_frame_callback_set(callback); - return eval_custom_code(tstate, frame, extra->code, throw_flag); - } else { - DEBUG_TRACE("create skip %s", name(frame)); - Py_DECREF(result); - destroy_cache_entry(extra); - set_extra(frame->f_code, SKIP_CODE); - // Re-enable custom behavior - eval_frame_callback_set(callback); - return eval_frame_default(tstate, frame, throw_flag); - } -} - -static int active_dynamo_threads = 0; - -static PyObject *increment_working_threads(PyThreadState *tstate) { - active_dynamo_threads = active_dynamo_threads + 1; - if (active_dynamo_threads > 0) { - enable_eval_frame_shim(tstate); - } - Py_RETURN_NONE; -} - -static PyObject *decrement_working_threads(PyThreadState *tstate) { - if (active_dynamo_threads > 0) { - active_dynamo_threads = active_dynamo_threads - 1; - if (active_dynamo_threads == 0) { - enable_eval_frame_default(tstate); - } - } - Py_RETURN_NONE; -} - -static PyObject *set_eval_frame(PyObject *new_callback, PyThreadState *tstate) { - // Change the eval frame callback and return the old one - // - None: disables TorchDynamo - // - False: run-only mode (reuse existing compiles) - // - Python callable(): enables TorchDynamo - PyObject *old_callback = eval_frame_callback_get(); - - // owned by caller - Py_INCREF(old_callback); - - if (old_callback != Py_None && new_callback == Py_None) { - decrement_working_threads(tstate); - } else if (old_callback == Py_None && new_callback != Py_None) { - increment_working_threads(tstate); - } - - Py_INCREF(new_callback); - Py_DECREF(old_callback); - - // Set thread local callback. This will drive behavior of our shim, if/when it - // is installed. 
- eval_frame_callback_set(new_callback); - - return old_callback; -} - -static PyObject *set_eval_frame_py(PyObject *dummy, PyObject *args) { - PyObject *callback = NULL; - if (!PyArg_ParseTuple(args, "O:callback", &callback)) { - DEBUG_TRACE0("arg error"); - return NULL; - } - if (callback != Py_None && callback != Py_False && - !PyCallable_Check(callback)) { - DEBUG_TRACE0("arg error"); - PyErr_SetString(PyExc_TypeError, "expected a callable"); - return NULL; - } - DEBUG_TRACE("python enabled=%d and is run_only=%d", callback != Py_None, - callback == Py_False); - return set_eval_frame(callback, PyThreadState_GET()); -} - -static PyObject *reset_code(PyObject *dummy, PyObject *args) { - PyObject *code = NULL; - if (!PyArg_ParseTuple(args, "O:code", &code)) { - DEBUG_TRACE0("arg error"); - return NULL; - } - if (!PyCode_Check(code)) { - DEBUG_TRACE0("arg error"); - PyErr_SetString(PyExc_TypeError, "expected a code object"); - return NULL; - } - - destroy_cache_entry(get_extra((PyCodeObject *)code)); - set_extra((PyCodeObject *)code, NULL); - Py_RETURN_NONE; -} - -static PyObject *unsupported(PyObject *dummy, PyObject *args) { - // a dummy C function used in testing - PyObject *obj1 = NULL; - PyObject *obj2 = NULL; - if (!PyArg_ParseTuple(args, "OO", &obj1, &obj2)) { - return NULL; - } - Py_INCREF(obj2); - return obj2; -} - -static PyObject *skip_code(PyObject *dummy, PyObject *args) { - PyObject *obj = NULL; - if (!PyArg_ParseTuple(args, "O", &obj)) { - return NULL; - } - if (!PyCode_Check(obj)) { - PyErr_SetString(PyExc_TypeError, "expected a code object"); - return NULL; - } - set_extra((PyCodeObject *)obj, SKIP_CODE); - Py_RETURN_NONE; -} - -static PyObject *set_guard_fail_hook(PyObject *dummy, PyObject *args) { - PyObject *obj = NULL; - if (!PyArg_ParseTuple(args, "O", &obj)) { - return NULL; - } - Py_XDECREF(guard_fail_hook); - if (obj == Py_None) { - guard_fail_hook = NULL; - } else { - guard_fail_hook = obj; - Py_INCREF(guard_fail_hook); - } - Py_RETURN_NONE; -} - -static PyObject *set_guard_error_hook(PyObject *dummy, PyObject *args) { - PyObject *obj = NULL; - if (!PyArg_ParseTuple(args, "O", &obj)) { - return NULL; - } - Py_XDECREF(guard_error_hook); - if (obj == Py_None) { - guard_error_hook = NULL; - } else { - guard_error_hook = obj; - Py_INCREF(guard_error_hook); - } - Py_RETURN_NONE; -} - -static PyMethodDef _methods[] = { - {"set_eval_frame", set_eval_frame_py, METH_VARARGS, NULL}, - {"reset_code", reset_code, METH_VARARGS, NULL}, - {"unsupported", unsupported, METH_VARARGS, NULL}, - {"skip_code", skip_code, METH_VARARGS, NULL}, - {"set_guard_fail_hook", set_guard_fail_hook, METH_VARARGS, NULL}, - {"set_guard_error_hook", set_guard_error_hook, METH_VARARGS, NULL}, - {NULL, NULL, 0, NULL}}; - -static struct PyModuleDef _module = { - PyModuleDef_HEAD_INIT, "_eval_frame", - "Module containing hooks to override eval_frame", -1, _methods}; - -PyMODINIT_FUNC PyInit__eval_frame(void) { - CHECK(sizeof(unsigned long) == sizeof(void *)); - extra_index = _PyEval_RequestCodeExtraIndex(ignored); - - int result = PyThread_tss_create(&eval_frame_callback_key); - CHECK(result == 0); - - Py_INCREF(Py_None); - eval_frame_callback_set(Py_None); - - noargs = PyTuple_New(0); - dotzerokey = PyUnicode_InternFromString(".0"); - return PyModule_Create(&_module); -} diff --git a/torchdynamo/_guards.cpp b/torchdynamo/_guards.cpp deleted file mode 100644 index 6b27314f26..0000000000 --- a/torchdynamo/_guards.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#define PY_SSIZE_T_CLEAN -#include -#include -#include - 
-namespace { - -struct LocalState { - // TLS state that changes operators - c10::impl::LocalDispatchKeySet dispatch_modifier; - bool grad_mode_enabled; - - at::DispatchKeySet apply(at::DispatchKeySet ks) const { - return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_; - } - - LocalState() - : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()), - grad_mode_enabled(at::GradMode::is_enabled()) {} -}; - -class TensorCheck { -public: - TensorCheck(const LocalState &state, PyTypeObject *pt, const at::Tensor &v, - bool dynamic_shapes) - : pytype(pt), dispatch_key_(state.apply(v.key_set()).raw_repr()), - dtype_(v.dtype().toScalarType()), - requires_grad_(state.grad_mode_enabled && v.requires_grad()), - dynamic_shapes_(dynamic_shapes) { - auto ndim = v.ndimension(); - const auto &sizes = v.sizes(); - const auto &strides = v.strides(); - sizes_.reserve(ndim); - strides_.reserve(ndim); - for (auto i : c10::irange(ndim)) { - sizes_.emplace_back(sizes[i]); - strides_.emplace_back(strides[i]); - } - } - - bool check(const LocalState &state, const at::Tensor &v) { - if (dispatch_key_ != state.apply(v.key_set()).raw_repr() || - dtype_ != v.dtype().toScalarType() || - requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) { - return false; - } - auto ndim = static_cast(v.ndimension()); - if (ndim != sizes_.size()) { - return false; - } - if (!dynamic_shapes_) { - const auto &sizes = v.sizes(); - const auto &strides = v.strides(); - for (auto i : c10::irange(ndim)) { - if (sizes_[i] != sizes[i] || strides_[i] != strides[i]) { - return false; - } - } - } - return true; - } - - std::string check_verbose(const LocalState &state, const at::Tensor &v, - std::string tensor_name) { - std::stringstream fail_reason; - fail_reason << "tensor '" << tensor_name << "' "; - if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) { - // return fmt::format("tensor dispatch key mismatch. expected {}, actual - // {}", dispatch_key_, state.apply(v.key_set()).raw_repr()); - fail_reason << "dispatch key set mismatch. expected " - << c10::DispatchKeySet(c10::DispatchKeySet::RAW, - dispatch_key_) - << ", actual " << state.apply(v.key_set()); - return fail_reason.str(); - } else if (dtype_ != v.dtype().toScalarType()) { - // return fmt::format("tensor dtype mismatch. expected {}, actual {}", - // dtype_, v.dtype().toScalarType()); - fail_reason << "dtype mismatch. expected " << dtype_ << ", actual " - << v.dtype().toScalarType(); - return fail_reason.str(); - } else if (requires_grad_ != - (state.grad_mode_enabled && v.requires_grad())) { - // return fmt::format("tensor requires_grad mismatch. expected {}", - // requires_grad_); - fail_reason << "requires_grad mismatch. expected requires_grad=" - << requires_grad_; - return fail_reason.str(); - } - size_t ndim = static_cast(v.ndimension()); - if (ndim != sizes_.size()) { - // return fmt::format("tensor rank mismatch. expected {}, actual {}", - // sizes_.size(), ndim); - fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual " - << ndim; - return fail_reason.str(); - } - if (!dynamic_shapes_) { - const auto &sizes = v.sizes(); - const auto &strides = v.strides(); - for (auto i : c10::irange(ndim)) { - if (sizes_[i] != sizes[i]) { - // return fmt::format("tensor size mismatch at index {}. expected {}, - // actual {}", i, sizes_[i], sizes[i]); - fail_reason << "size mismatch at index " << i << ". 
expected " - << sizes_[i] << ", actual " << sizes[i]; - return fail_reason.str(); - } else if (strides_[i] != strides[i]) { - // return fmt::format("tensor strides mismatch at index {}. expected - // {}, actual {}", i, strides_[i]); - fail_reason << "strides mismatch at index " << i << ". expected " - << strides_[i] << ", actual " << strides[i]; - return fail_reason.str(); - } - } - } - return ""; - } - - PyTypeObject *pytype; - -private: - uint64_t dispatch_key_; // DispatchKeySet includes device/layout - at::ScalarType dtype_; - bool requires_grad_; - bool dynamic_shapes_; - std::vector sizes_; - std::vector strides_; -}; - -typedef std::vector ChecksList; - -typedef struct { - PyObject_HEAD; - ChecksList *checks; -} TensorGuards; - -static void TensorGuards_dealloc(TensorGuards *self) { - if (self->checks != NULL) { - delete self->checks; - self->checks = NULL; - } - Py_TYPE(self)->tp_free((PyObject *)self); -} - -static PyObject *TensorGuards_new(PyTypeObject *type, PyObject *args, - PyObject *kwds) { - TensorGuards *self = (TensorGuards *)type->tp_alloc(type, 0); - if (self != NULL) { - self->checks = new ChecksList(); - } - return (PyObject *)self; -} - -static int TensorGuards_init(TensorGuards *self, PyObject *args, - PyObject *kwds) { - if (!PyTuple_CheckExact(args)) { - PyErr_SetString(PyExc_TypeError, "expected tuple()"); - return -1; - } - PyObject *dynamic_shapes_py = PyDict_GetItemString(kwds, "dynamic_shapes"); - if (dynamic_shapes_py == NULL) { - PyErr_SetString(PyExc_TypeError, "missing dynamic_shapes=..."); - return -1; - } - bool dynamic_shapes = PyObject_IsTrue(dynamic_shapes_py); - - auto &checks = *self->checks; - auto len = PyTuple_GET_SIZE(args); - checks.reserve(len); - LocalState state; - for (auto i : c10::irange(len)) { - PyObject *item = PyTuple_GET_ITEM(args, i); - if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { - PyErr_SetString(PyExc_TypeError, "expected Tensor()"); - return -1; - } - checks.emplace_back(TensorCheck(state, Py_TYPE(item), - THPVariable_Unpack(item), dynamic_shapes)); - } - return 0; -} - -PyObject *TensorGuards_check(TensorGuards *self, PyObject *args) { - if (!PyTuple_CheckExact(args)) { - PyErr_SetString(PyExc_TypeError, "expected tuple()"); - return NULL; - } - auto &checks = *self->checks; - auto len = PyTuple_GET_SIZE(args); - - if (static_cast(checks.size()) != len) { - PyErr_SetString(PyExc_TypeError, "wrong length"); - return NULL; - } - - LocalState state; - - for (auto i : c10::irange(len)) { - PyObject *item = PyTuple_GET_ITEM(args, i); - if (Py_TYPE(item) != checks[i].pytype) { - Py_RETURN_FALSE; - } - if (!checks[i].check(state, THPVariable_Unpack(item))) { - Py_RETURN_FALSE; - } - } - - Py_RETURN_TRUE; -} - -PyObject *TensorGuards_check_verbose(TensorGuards *self, PyObject *args, - PyObject *kwargs) { - if (!PyTuple_CheckExact(args)) { - PyErr_SetString(PyExc_TypeError, "expected tuple()"); - return NULL; - } - auto &checks = *self->checks; - auto len = PyTuple_GET_SIZE(args); - - if (static_cast(checks.size()) != len) { - PyErr_SetString(PyExc_TypeError, "wrong length"); - return NULL; - } - - PyObject *tensor_check_names_py = - PyDict_GetItemString(kwargs, "tensor_check_names"); - if (tensor_check_names_py == NULL) { - PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg"); - return NULL; - } - - if (!PyList_Check(tensor_check_names_py)) { - PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list"); - return NULL; - } - - auto names_size = PyList_Size(tensor_check_names_py); - 
if (names_size != static_cast(checks.size())) { - PyErr_SetString(PyExc_TypeError, - "tensor_check_names should be the same size as # tensors"); - return NULL; - } - - std::vector tensor_check_names; - tensor_check_names.reserve(names_size); - for (auto i : c10::irange(names_size)) { - PyObject *value = PyList_GetItem(tensor_check_names_py, i); - if (!PyUnicode_Check(value)) { - PyErr_SetString(PyExc_TypeError, - "tensor_check_names must only contain strings"); - return NULL; - } - tensor_check_names.emplace_back(PyUnicode_AsUTF8(value)); - } - - LocalState state; - for (auto i : c10::irange(len)) { - PyObject *item = PyTuple_GET_ITEM(args, i); - if (Py_TYPE(item) != checks[i].pytype) { - std::stringstream fail_reason; - PyObject *type_str = PyObject_Str(PyObject_Type(item)); - fail_reason << "expected type of '" << tensor_check_names[i] - << "' to be a tensor type, "; - if (!type_str) { - fail_reason << "but found a different type"; - } else { - fail_reason << "' but found " << PyUnicode_AsUTF8(type_str); - } - return Py_BuildValue("s", fail_reason.str().c_str()); - } - std::string fail_reason = checks[i].check_verbose( - state, THPVariable_Unpack(item), tensor_check_names[i]); - if (fail_reason.length() > 0) { - return Py_BuildValue("s", fail_reason.c_str()); - } - } - - Py_RETURN_TRUE; -} - -static PyMethodDef TensorGuards_methods[] = { - {"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""}, - {"check_verbose", (PyCFunction)TensorGuards_check_verbose, - METH_VARARGS | METH_KEYWORDS, "verbose fail reasons for failed checks"}, - {NULL} /* Sentinel */ -}; - -static PyTypeObject TensorGuardsType = { - // NOLINTNEXTLINE - PyVarObject_HEAD_INIT(NULL, 0)}; - -static PyObject *check_type_id(PyObject *dummy, PyObject *args) { - // faster `lambda obj, expected: id(type(obj)) == expected` - PyObject *obj; - unsigned long expected; - if (!PyArg_ParseTuple(args, "Ok", &obj, &expected)) { - return NULL; - } - if (Py_TYPE(obj) == (void *)expected) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; - } -} - -static PyObject *check_obj_id(PyObject *dummy, PyObject *args) { - // faster `lambda obj, expected: id(obj) == expected` - PyObject *obj; - unsigned long expected; - if (!PyArg_ParseTuple(args, "Ok", &obj, &expected)) { - return NULL; - } - if (obj == (void *)expected) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; - } -} - -static PyMethodDef _methods[] = { - {"check_type_id", check_type_id, METH_VARARGS, NULL}, - {"check_obj_id", check_obj_id, METH_VARARGS, NULL}, - {NULL, NULL, 0, NULL}}; - -static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT, "_guards", - "Module containing checks on tensors", -1, - _methods}; - -} // namespace - -PyMODINIT_FUNC PyInit__guards(void) { - // initialize TensorGuardsType - TensorGuardsType.tp_name = "torchdynamo._guards.TensorGuards"; - TensorGuardsType.tp_basicsize = sizeof(TensorGuards); - TensorGuardsType.tp_itemsize = 0; - TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc; - TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT; - TensorGuardsType.tp_doc = "Check properties of a torch.Tensor"; - TensorGuardsType.tp_methods = TensorGuards_methods; - TensorGuardsType.tp_init = (initproc)TensorGuards_init; - TensorGuardsType.tp_new = TensorGuards_new; - - PyObject *m; - if (PyType_Ready(&TensorGuardsType) < 0) - return NULL; - - m = PyModule_Create(&_module); - if (m == NULL) - return NULL; - - Py_INCREF(&TensorGuardsType); - if (PyModule_AddObject(m, "TensorGuards", (PyObject *)&TensorGuardsType) < - 0) { - Py_DECREF(&TensorGuardsType); - 
Py_DECREF(m); - return NULL; - } - - return m; -} diff --git a/torchdynamo/allowed_functions.py b/torchdynamo/allowed_functions.py index a742782724..e2bb330c7f 100644 --- a/torchdynamo/allowed_functions.py +++ b/torchdynamo/allowed_functions.py @@ -93,6 +93,7 @@ def _disallowed_function_ids(): torch.set_rng_state, torch.autograd.profiler.profile, warnings.warn, + torch._C.dynamo.eval_frame.unsupported, ] # extract all dtypes from torch dtypes = [ @@ -123,7 +124,12 @@ def _is_allowed_module_prefix(obj): # Tensor.set_ with a Storage, and Storages cannot be traced with # AOTAutograd; so we need to graph-break. To ensure this, we inline # these functions, rather than keep them opaque-ly in the graph. - disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.") + disallowed_modules = ( + "torch.optim.", + "torch.nn.modules.rnn.", + "torch.dynamo.", + "torch._C.dynamo.", + ) allowed_modules_dot = tuple([x + "." for x in allowed_modules]) module = inspect.getmodule(obj) if module is None: diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index 11b626f95f..c94c468c6e 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -30,16 +30,12 @@ log = logging.getLogger(__name__) -try: - from . import _eval_frame -except (ModuleNotFoundError, ImportError) as e: - raise RuntimeError("run `python setup.py develop` to compile C extensions") from e - try: from torch.fx.experimental import proxy_tensor except (ModuleNotFoundError, ImportError): proxy_tensor = None +_eval_frame = torch._C.dynamo.eval_frame set_eval_frame = _eval_frame.set_eval_frame reset_code = _eval_frame.reset_code unsupported = _eval_frame.unsupported diff --git a/torchdynamo/guards.py b/torchdynamo/guards.py index ede83a67f6..7a8f6ffc71 100644 --- a/torchdynamo/guards.py +++ b/torchdynamo/guards.py @@ -23,9 +23,6 @@ from . import config from . import convert_frame from . 
import mutation_guard -from ._guards import TensorGuards -from ._guards import check_obj_id -from ._guards import check_type_id from .eval_frame import set_guard_error_hook from .eval_frame import set_guard_fail_hook from .exc import unimplemented @@ -39,6 +36,9 @@ from .utils import tuple_iterator_len log = logging.getLogger(__name__) +TensorGuards = torch._C.dynamo.guards.TensorGuards +check_obj_id = torch._C.dynamo.guards.check_obj_id +check_type_id = torch._C.dynamo.guards.check_type_id CLOSURE_VARS = collections.OrderedDict( From 260066e8715a6a190214c308d87f4c5a71d7b303 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Oct 2022 08:47:20 -0700 Subject: [PATCH 02/36] Fixing issues --- benchmarks/common.py | 4 +++- benchmarks/huggingface.py | 2 +- benchmarks/microbenchmarks/microbench.py | 2 +- benchmarks/microbenchmarks/operatorbench.py | 2 +- benchmarks/runner.py | 2 +- benchmarks/timm_models.py | 7 +++---- benchmarks/torchbench.py | 2 +- copy_to_core.sh | 17 +++++++++++++++++ setup.py | 2 -- test/inductor/opinfo_harness.py | 5 +++-- torchdynamo/allowed_functions.py | 2 +- torchdynamo/config.py | 2 +- torchdynamo/convert_frame.py | 2 +- torchdynamo/eval_frame.py | 2 +- torchdynamo/guards.py | 6 +++--- torchdynamo/skipfiles.py | 2 ++ torchdynamo/variables/misc.py | 20 +++++++++++--------- torchinductor/compile_fx.py | 3 +-- torchinductor/lowering.py | 2 +- 19 files changed, 53 insertions(+), 33 deletions(-) create mode 100755 copy_to_core.sh diff --git a/benchmarks/common.py b/benchmarks/common.py index 797319449e..f24b503eb5 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import argparse import collections import copy @@ -1613,6 +1613,8 @@ def main(runner, original_dir=None): return sys.exit(-1) if not args.devices: + import torch + if torch.cuda.is_available(): args.devices = ["cuda"] else: diff --git a/benchmarks/huggingface.py b/benchmarks/huggingface.py index 8ec08af5a5..77d5c7227e 100755 --- a/benchmarks/huggingface.py +++ b/benchmarks/huggingface.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import importlib import logging import os diff --git a/benchmarks/microbenchmarks/microbench.py b/benchmarks/microbenchmarks/microbench.py index 887f9b5eeb..ccf7b9e5ae 100755 --- a/benchmarks/microbenchmarks/microbench.py +++ b/benchmarks/microbenchmarks/microbench.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import argparse import inspect import sys diff --git a/benchmarks/microbenchmarks/operatorbench.py b/benchmarks/microbenchmarks/operatorbench.py index d7bdc6e133..3a3f954e36 100644 --- a/benchmarks/microbenchmarks/operatorbench.py +++ b/benchmarks/microbenchmarks/operatorbench.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import click import numpy as np import torch diff --git a/benchmarks/runner.py b/benchmarks/runner.py index 46de1ee2e4..c61363a266 100755 --- a/benchmarks/runner.py +++ b/benchmarks/runner.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 """ A wrapper over the benchmark infrastructure to generate commonly used commands, diff --git a/benchmarks/timm_models.py b/benchmarks/timm_models.py index b0658069d6..ff2cdb1f3c 100755 --- a/benchmarks/timm_models.py +++ b/benchmarks/timm_models.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import importlib import logging import os @@ -30,9 +30,8 @@ def pip_install(package): from timm.models import create_model TIMM_MODELS = dict() -filename = 
"timm_models_list.txt" -if os.path.exists("benchmarks"): - filename = "benchmarks/" + filename +filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") + with open(filename, "r") as fh: lines = fh.readlines() lines = [line.rstrip() for line in lines] diff --git a/benchmarks/torchbench.py b/benchmarks/torchbench.py index 7a595f871b..63af798bb4 100755 --- a/benchmarks/torchbench.py +++ b/benchmarks/torchbench.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import gc import importlib import logging diff --git a/copy_to_core.sh b/copy_to_core.sh new file mode 100755 index 0000000000..a00e9c4531 --- /dev/null +++ b/copy_to_core.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -ex + +rsync -ra ~/torchdynamo/torchdynamo/ ~/pytorch/torch/dynamo +rsync -ra ~/torchdynamo/torchinductor/ ~/pytorch/torch/inductor +rsync -ra ~/torchdynamo/test/{dynamo,inductor} ~/pytorch/test/ +rsync -ra ~/torchdynamo/benchmarks/ ~/pytorch/benchmarks/dynamo + +for DIR in ~/pytorch/test/{dynamo,inductor} ~/pytorch/benchmarks/dynamo +do + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchdynamo/torch.dynamo/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchinductor/torch.inductor/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/_torch[.]inductor/_torchinductor/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's@pytorch/torch[.]dynamo@pytorch/torchdynamo@g' +done + +lintrunner -a diff --git a/setup.py b/setup.py index 09b3848d7d..cdade4d712 100755 --- a/setup.py +++ b/setup.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import sys -from setuptools import Extension from setuptools import find_packages from setuptools import setup -from torch.utils.cpp_extension import CppExtension long_description = """ TorchDynamo is a Python-level JIT compiler designed to make unmodified diff --git a/test/inductor/opinfo_harness.py b/test/inductor/opinfo_harness.py index 100d18100a..8607758213 100644 --- a/test/inductor/opinfo_harness.py +++ b/test/inductor/opinfo_harness.py @@ -11,7 +11,8 @@ os.environ["PYTORCH_TEST_RANGE_START"] = f"{start}" os.environ["PYTORCH_TEST_RANGE_END"] = f"{end}" popen = subprocess.Popen( - ["pytest", "test/test_torchinductor_opinfo.py"], stdout=subprocess.PIPE + ["pytest", "test/inductor/test_torchinductor_opinfo.py"], + stdout=subprocess.PIPE, ) for line in popen.stdout: print(line.decode(), end="") @@ -19,6 +20,6 @@ return_code = popen.wait() if return_code: raise subprocess.CalledProcessError( - return_code, ["pytest", "test/test_torchinductor_opinfo.py"] + return_code, ["pytest", "test/inductor/test_torchinductor_opinfo.py"] ) i = end + 1 diff --git a/torchdynamo/allowed_functions.py b/torchdynamo/allowed_functions.py index e2bb330c7f..b35164009a 100644 --- a/torchdynamo/allowed_functions.py +++ b/torchdynamo/allowed_functions.py @@ -93,7 +93,7 @@ def _disallowed_function_ids(): torch.set_rng_state, torch.autograd.profiler.profile, warnings.warn, - torch._C.dynamo.eval_frame.unsupported, + torch._C._dynamo.eval_frame.unsupported, ] # extract all dtypes from torch dtypes = [ diff --git a/torchdynamo/config.py b/torchdynamo/config.py index da0ac537a4..dd9ea8d9f0 100644 --- a/torchdynamo/config.py +++ b/torchdynamo/config.py @@ -55,7 +55,7 @@ } # root folder of the project -base_dir = dirname(dirname(abspath(__file__))) +base_dir = dirname(dirname(dirname(abspath(__file__)))) # don't specialize on shapes and strides and put shape ops in graph dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1" diff --git a/torchdynamo/convert_frame.py 
b/torchdynamo/convert_frame.py index 375c923aa8..b114bc880a 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -311,7 +311,7 @@ def format_guard_failures(code): assert code in guard_failures, "TODO(whc) any other recompile reasons?" log.warning( - f"torchdynamo hit config.cache_size_limit ({config.cache_size_limit})\n" + f"torch.dynamo hit config.cache_size_limit ({config.cache_size_limit})\n" + f" function: {format_func_info(code)}\n" + f" reasons: {format_guard_failures(code)}\n" + f"to diagnose recompilation issues, see {troubleshooting_url}." diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index c94c468c6e..996f732963 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -35,7 +35,7 @@ except (ModuleNotFoundError, ImportError): proxy_tensor = None -_eval_frame = torch._C.dynamo.eval_frame +_eval_frame = torch._C._dynamo.eval_frame set_eval_frame = _eval_frame.set_eval_frame reset_code = _eval_frame.reset_code unsupported = _eval_frame.unsupported diff --git a/torchdynamo/guards.py b/torchdynamo/guards.py index 7a8f6ffc71..c7485afde1 100644 --- a/torchdynamo/guards.py +++ b/torchdynamo/guards.py @@ -36,9 +36,9 @@ from .utils import tuple_iterator_len log = logging.getLogger(__name__) -TensorGuards = torch._C.dynamo.guards.TensorGuards -check_obj_id = torch._C.dynamo.guards.check_obj_id -check_type_id = torch._C.dynamo.guards.check_type_id +TensorGuards = torch._C._dynamo.guards.TensorGuards +check_obj_id = torch._C._dynamo.guards.check_obj_id +check_type_id = torch._C._dynamo.guards.check_type_id CLOSURE_VARS = collections.OrderedDict( diff --git a/torchdynamo/skipfiles.py b/torchdynamo/skipfiles.py index 6df021a5e2..314606dbf5 100644 --- a/torchdynamo/skipfiles.py +++ b/torchdynamo/skipfiles.py @@ -194,4 +194,6 @@ def is_torch_inline_allowed(filename): def is_torch(filename): + if filename.startswith(_module_dir(torch.dynamo)): + return False return filename.startswith(_module_dir(torch)) diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py index 5aa9b3dea2..ec22980412 100644 --- a/torchdynamo/variables/misc.py +++ b/torchdynamo/variables/misc.py @@ -332,19 +332,11 @@ def __init__(self, tx, target_values, initial_values=None, **kwargs): self.mode = None def exit(self, tx, *args): - def exit_functional_autocast(mode): - mode.__exit__(None, None, None) - tx.output.graph.create_node( "call_function", exit_functional_autocast, (self.mode,), {} ) def enter(self, tx): - def enter_functional_autocast(*vals): - mode = torch.amp.autocast(*vals) - mode.__enter__() - return mode - self.mode = tx.output.graph.create_node( "call_function", enter_functional_autocast, (*self.target_values,), {} ) @@ -356,6 +348,16 @@ def fn_name(self): return "torch.amp.autocast_mode.autocast" +def enter_functional_autocast(*vals): + mode = torch.amp.autocast(*vals) + mode.__enter__() + return mode + + +def exit_functional_autocast(mode): + mode.__exit__(None, None, None) + + class AutogradProfilerContextWrapperVariable(ContextWrappingVariable): def __init__(self, target_values=None, **kwargs): super(AutogradProfilerContextWrapperVariable, self).__init__( @@ -614,7 +616,7 @@ def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): - unimplemented(f"call torchdynamo.disable() wrapped function {self.value}") + unimplemented(f"call torch.dynamo.disable() wrapped function {self.value}") else: try: 
path = inspect.getfile(self.value) diff --git a/torchinductor/compile_fx.py b/torchinductor/compile_fx.py index c54020de16..5fe5a1feb6 100644 --- a/torchinductor/compile_fx.py +++ b/torchinductor/compile_fx.py @@ -10,8 +10,6 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.utils._mode_utils import no_dispatch -from torchdynamo.utils import count_calls - from . import config from . import overrides from .debug import DebugContext @@ -28,6 +26,7 @@ aot_autograd = dynamo_optimizations.backends.aot_autograd normalize_ir = dynamo_optimizations.normalize.normalize_ir is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run +count_calls = dynamo_utils.count_calls @dataclasses.dataclass diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py index 7e93ab4af1..bb3cebba45 100644 --- a/torchinductor/lowering.py +++ b/torchinductor/lowering.py @@ -849,7 +849,7 @@ def make_fallback(kernel): assert ( kernel not in decompositions ), f"both a fallback and a decomp for same kernel: {kernel}" - if get_decompositions([kernel]): + if get_decompositions([kernel]) and kernel is not aten.cumsum: log.warning( f"make_fallback({kernel}): a decomposition exists, we should switch to it" ) From 1646d9dc7f02520f4df95fc4c51b0ba019c660bb Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Oct 2022 10:24:36 -0700 Subject: [PATCH 03/36] lints --- test/dynamo/mock_modules/mock_module2.py | 2 +- test/dynamo/mock_modules/mock_module3.py | 2 +- test/dynamo/test_aot_autograd.py | 1 + test/dynamo/test_distributed.py | 1 + test/dynamo/test_dynamic_shapes.py | 1 + test/dynamo/test_export.py | 1 + test/dynamo/test_functions.py | 1 + test/dynamo/test_global.py | 1 + test/dynamo/test_minifier.py | 1 + test/dynamo/test_misc.py | 1 + test/dynamo/test_model_output.py | 1 + test/dynamo/test_modules.py | 1 + test/dynamo/test_no_fake_tensors.py | 1 + test/dynamo/test_nops.py | 1 + test/dynamo/test_optimizations.py | 1 + test/dynamo/test_optimizers.py | 1 + test/dynamo/test_python_autograd.py | 1 + test/dynamo/test_recompile_ux.py | 1 + test/dynamo/test_replay_record.py | 1 + test/dynamo/test_repros.py | 1 + test/dynamo/test_skip_non_tensor.py | 1 + test/dynamo/test_subgraphs.py | 1 + test/dynamo/test_unspec.py | 1 + test/dynamo/test_verify_correctness.py | 1 + test/inductor/test_torchinductor.py | 1 + test/inductor/test_torchinductor_opinfo.py | 1 + torchinductor/codegen/triton_conv_delta_x.j2 | 2 +- torchinductor/codegen/triton_conv_delta_x_hwc.j2 | 2 +- torchinductor/sizevars.py | 2 +- torchinductor/utils.py | 2 +- 30 files changed, 30 insertions(+), 6 deletions(-) diff --git a/test/dynamo/mock_modules/mock_module2.py b/test/dynamo/mock_modules/mock_module2.py index a4810f1cd2..7fe8979709 100644 --- a/test/dynamo/mock_modules/mock_module2.py +++ b/test/dynamo/mock_modules/mock_module2.py @@ -14,6 +14,6 @@ def method2(self, x): def method1(x, y): - z = torch.ones(1, 1) # noqa + torch.ones(1, 1) x.append(y) return x diff --git a/test/dynamo/mock_modules/mock_module3.py b/test/dynamo/mock_modules/mock_module3.py index c01d2b6066..8af77a237a 100644 --- a/test/dynamo/mock_modules/mock_module3.py +++ b/test/dynamo/mock_modules/mock_module3.py @@ -2,6 +2,6 @@ def method1(x, y): - z = torch.ones(1, 1) # noqa + torch.ones(1, 1) x.append(y) return x diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index e4aae96378..1f6323a82e 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# 
Owner(s): ["module: dynamo"] import functools import torch diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index 36b1486390..ca60600534 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import os from unittest.mock import patch diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 838f8a2a8f..1be45e295e 100755 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from torchdynamo.testing import make_test_cls_with_patches diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 0663f87257..129c4f63af 100755 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from unittest.mock import patch import torch diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 30f15fb8cc..8b78de0bc8 100755 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import collections import functools import inspect diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 06b5939a92..9b721d0609 100755 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import torch import torchdynamo.testing diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index e714d007c9..1c7b2f3394 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import os import shutil from unittest.mock import patch diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 3f5add83ac..264551357d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import collections import copy import dataclasses diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index a0fd4c3fc1..7ed2785721 100755 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import dataclasses import unittest.mock diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 455ca359e1..aa4490d28b 100755 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from copy import deepcopy from unittest.mock import patch diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index 4128c45429..9c4e788e5a 100755 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from torchdynamo.testing import make_test_cls_with_patches from . 
import test_functions diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index c0906b55f4..b202862132 100755 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import torch import torchdynamo.testing diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py index 197fac71d9..46cd709006 100755 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import importlib import json import os diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 0a1884380e..60412307f1 100755 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import inspect import unittest diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index 8bb064bfe8..008c34fb7b 100755 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from typing import Callable from typing import Dict from typing import List diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index e25e04122c..14447e64ce 100755 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import unittest import weakref diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index 613d5b9d2a..6b06e78264 100755 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import logging import re import shutil diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 80fefdc6d6..c4ce316c4e 100755 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import collections import copy import inspect diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index d0122ddd37..2054d4544e 100755 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] from unittest.mock import patch import torch diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 81c71e2c22..d1d4dcef37 100755 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import unittest from unittest.mock import patch diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 6140fba660..323d7fb895 100755 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import functools import random import unittest diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index c6c21c8121..ea7c980806 100755 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: dynamo"] import importlib import operator import unittest diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ecd7340621..d560ed9d9c 100755 --- a/test/inductor/test_torchinductor.py +++ 
b/test/inductor/test_torchinductor.py @@ -1,4 +1,5 @@ #!/usr/bin/env pytest +# Owner(s): ["module: inductor"] import contextlib import dataclasses import functools diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index f96ef7aaba..072095c43c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -1,3 +1,4 @@ +# Owner(s): ["module: inductor"] import atexit import os from collections import defaultdict diff --git a/torchinductor/codegen/triton_conv_delta_x.j2 b/torchinductor/codegen/triton_conv_delta_x.j2 index b3c28a9b37..a7bf8ac433 100644 --- a/torchinductor/codegen/triton_conv_delta_x.j2 +++ b/torchinductor/codegen/triton_conv_delta_x.j2 @@ -172,7 +172,7 @@ def {{kernel_name}}( {% if pointwise_code %} {{ pointwise_code | indent(4, true) }} - {# + {# z = tl.load(z_ptrs, mask=mask_z) acc += z #} diff --git a/torchinductor/codegen/triton_conv_delta_x_hwc.j2 b/torchinductor/codegen/triton_conv_delta_x_hwc.j2 index 6b953d06c8..34f2c38812 100644 --- a/torchinductor/codegen/triton_conv_delta_x_hwc.j2 +++ b/torchinductor/codegen/triton_conv_delta_x_hwc.j2 @@ -191,7 +191,7 @@ def {{kernel_name}}( {% if pointwise_code %} {{ pointwise_code | indent(4, true) }} - {# + {# z = tl.load(z_ptrs, mask=mask_z) acc += z #} diff --git a/torchinductor/sizevars.py b/torchinductor/sizevars.py index b774d9128e..48de8be3e7 100644 --- a/torchinductor/sizevars.py +++ b/torchinductor/sizevars.py @@ -426,7 +426,7 @@ def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: def stride_hints(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: for v in index.free_symbols: - if v.name.startswith("indirect"): # type: ignore + if v.name.startswith("indirect"): index = sympy_subs(index, {v: 0}) result = [] for s in self.stride_vars(index, vars): diff --git a/torchinductor/utils.py b/torchinductor/utils.py index 4716dd4daa..d45a805b4a 100644 --- a/torchinductor/utils.py +++ b/torchinductor/utils.py @@ -37,7 +37,7 @@ def has_triton(): @functools.lru_cache(None) def has_torchvision_roi_align(): try: - from torchvision.ops import roi_align # noqa + from torchvision.ops import roi_align # noqa: F401 return roi_align is not None and hasattr( getattr(torch.ops, "torchvision", None), "roi_align" From f59feb0e9cf096773824db39fbdff948fcde19f3 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Oct 2022 11:04:57 -0700 Subject: [PATCH 04/36] bits --- benchmarks/training_loss.py | 0 copy_to_core.sh | 2 +- test/dynamo/test_aot_autograd.py | 1 - test/dynamo/test_aot_cudagraphs.py | 5 ----- test/dynamo/test_distributed.py | 1 - test/dynamo/test_dynamic_shapes.py | 1 - test/dynamo/test_export.py | 1 - test/dynamo/test_functions.py | 1 - test/dynamo/test_global.py | 1 - test/dynamo/test_global_declaration.py | 1 + test/dynamo/test_minifier.py | 1 - test/dynamo/test_misc.py | 1 - test/dynamo/test_model_output.py | 1 - test/dynamo/test_modules.py | 1 - test/dynamo/test_no_fake_tensors.py | 1 - test/dynamo/test_nops.py | 1 - test/dynamo/test_optimizations.py | 1 - test/dynamo/test_optimizers.py | 1 - test/dynamo/test_python_autograd.py | 1 - test/dynamo/test_recompile_ux.py | 1 - test/dynamo/test_replay_record.py | 1 - test/dynamo/test_repros.py | 1 - test/dynamo/test_skip_non_tensor.py | 1 - test/dynamo/test_subgraphs.py | 1 - test/dynamo/test_unspec.py | 1 - test/dynamo/test_verify_correctness.py | 1 - test/inductor/test_torchinductor.py | 1 - 27 files changed, 2 insertions(+), 29 deletions(-) mode change 100755 => 
100644 benchmarks/training_loss.py mode change 100755 => 100644 test/dynamo/test_aot_cudagraphs.py mode change 100755 => 100644 test/dynamo/test_dynamic_shapes.py mode change 100755 => 100644 test/dynamo/test_export.py mode change 100755 => 100644 test/dynamo/test_functions.py mode change 100755 => 100644 test/dynamo/test_global.py mode change 100755 => 100644 test/dynamo/test_model_output.py mode change 100755 => 100644 test/dynamo/test_modules.py mode change 100755 => 100644 test/dynamo/test_no_fake_tensors.py mode change 100755 => 100644 test/dynamo/test_nops.py mode change 100755 => 100644 test/dynamo/test_optimizations.py mode change 100755 => 100644 test/dynamo/test_optimizers.py mode change 100755 => 100644 test/dynamo/test_python_autograd.py mode change 100755 => 100644 test/dynamo/test_recompile_ux.py mode change 100755 => 100644 test/dynamo/test_replay_record.py mode change 100755 => 100644 test/dynamo/test_repros.py mode change 100755 => 100644 test/dynamo/test_skip_non_tensor.py mode change 100755 => 100644 test/dynamo/test_subgraphs.py mode change 100755 => 100644 test/dynamo/test_unspec.py mode change 100755 => 100644 test/dynamo/test_verify_correctness.py mode change 100755 => 100644 test/inductor/test_torchinductor.py diff --git a/benchmarks/training_loss.py b/benchmarks/training_loss.py old mode 100755 new mode 100644 diff --git a/copy_to_core.sh b/copy_to_core.sh index a00e9c4531..9ddeacea61 100755 --- a/copy_to_core.sh +++ b/copy_to_core.sh @@ -14,4 +14,4 @@ do find $DIR -name '*.py' | xargs -n1 -- sed -i 's@pytorch/torch[.]dynamo@pytorch/torchdynamo@g' done -lintrunner -a +(cd ~/pytorch && (lintrunner -a || lintrunner -a)) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 1f6323a82e..1824c66bce 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import functools diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py old mode 100755 new mode 100644 index 60b85428fd..e67afd2c83 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: cuda graphs"] import functools @@ -200,7 +199,3 @@ def fn(x): x = torch.empty(20, device="cuda:0") fn(x) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index ca60600534..dab4f2fe72 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import os from unittest.mock import patch diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py old mode 100755 new mode 100644 index 1be45e295e..9012d4e35d --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from torchdynamo.testing import make_test_cls_with_patches diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py old mode 100755 new mode 100644 index 129c4f63af..cb18a1f170 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from unittest.mock import patch diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py old mode 100755 new mode 100644 index 8b78de0bc8..2d539a8b68 --- a/test/dynamo/test_functions.py +++ 
b/test/dynamo/test_functions.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import collections import functools diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py old mode 100755 new mode 100644 index 9b721d0609..3e66c30cfd --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import torch diff --git a/test/dynamo/test_global_declaration.py b/test/dynamo/test_global_declaration.py index b0f2186bae..95995ca80a 100644 --- a/test/dynamo/test_global_declaration.py +++ b/test/dynamo/test_global_declaration.py @@ -1,3 +1,4 @@ +# Owner(s): ["module: dynamo"] import torch g_tensor_export = torch.ones(10) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 1c7b2f3394..7741cb2c57 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import os import shutil diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 264551357d..fe6e32c4bb 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import collections import copy diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py old mode 100755 new mode 100644 index 7ed2785721..a78300c744 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import dataclasses import unittest.mock diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py old mode 100755 new mode 100644 index aa4490d28b..34206b7793 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from copy import deepcopy diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py old mode 100755 new mode 100644 index 9c4e788e5a..5700b3b33c --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from torchdynamo.testing import make_test_cls_with_patches diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py old mode 100755 new mode 100644 index b202862132..7de329fb96 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import torch diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py old mode 100755 new mode 100644 index 46cd709006..7c83741903 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import importlib import json diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py old mode 100755 new mode 100644 index 60412307f1..5ab4756a50 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import inspect diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py old mode 100755 new mode 100644 index 008c34fb7b..6de772715c --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from typing import Callable from typing import Dict diff --git a/test/dynamo/test_recompile_ux.py 
b/test/dynamo/test_recompile_ux.py old mode 100755 new mode 100644 index 14447e64ce..6f12a62c38 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import unittest import weakref diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py old mode 100755 new mode 100644 index 6b06e78264..6a0a04fc63 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import logging import re diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py old mode 100755 new mode 100644 index c4ce316c4e..f2fe9f383a --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import collections import copy diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py old mode 100755 new mode 100644 index 2054d4544e..70b3b1f22c --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] from unittest.mock import patch diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py old mode 100755 new mode 100644 index d1d4dcef37..a3c321ee60 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import unittest from unittest.mock import patch diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py old mode 100755 new mode 100644 index 323d7fb895..7977ef7e60 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import functools import random diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py old mode 100755 new mode 100644 index ea7c980806..7275b3880a --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: dynamo"] import importlib import operator diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py old mode 100755 new mode 100644 index 98e621c3eb..b813490a27 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1,4 +1,3 @@ -#!/usr/bin/env pytest # Owner(s): ["module: inductor"] import contextlib import dataclasses From 67f50fd95779b71c7e9f32409b4b8b0c3b194e62 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Oct 2022 13:58:09 -0700 Subject: [PATCH 05/36] Lints in core --- Makefile | 5 ----- .../microbenchmarks/bench_conv_fusion.py | 1 + benchmarks/microbenchmarks/bench_mm_fusion.py | 1 + .../microbenchmarks/benchmark_helper.py | 3 ++- .../microbenchmarks/operator_inp_utils.py | 2 +- test/dynamo/test_dynamic_shapes.py | 17 +++++++++++----- test/dynamo/test_functions.py | 1 + test/dynamo/test_global.py | 7 +++++-- test/dynamo/test_misc.py | 20 +++++++++---------- test/dynamo/test_modules.py | 7 +++++-- test/dynamo/test_no_fake_tensors.py | 17 +++++++++++----- test/dynamo/test_repros.py | 2 +- test/inductor/test_torchinductor.py | 12 ++++++----- test/inductor/test_torchinductor_opinfo.py | 8 ++++++-- torchdynamo/bytecode_analysis.py | 2 +- torchdynamo/convert_frame.py | 6 +----- torchdynamo/eval_frame.py | 14 ++++++------- torchdynamo/optimizations/inference.py | 2 +- torchdynamo/optimizations/normalize.py | 
10 ++++------ torchdynamo/output_graph.py | 4 ++-- torchdynamo/side_effects.py | 2 +- torchdynamo/variables/functions.py | 5 +---- torchdynamo/variables/lists.py | 2 +- torchdynamo/variables/tensor.py | 9 +++++---- torchdynamo/variables/torch.py | 4 ++-- torchinductor/codegen/autotuner.py | 6 +++--- torchinductor/codegen/cpp.py | 2 +- torchinductor/codegen/triton.py | 8 ++++++-- torchinductor/graph.py | 4 ++-- torchinductor/ir.py | 17 +++++++++------- torchinductor/lowering.py | 6 +++--- torchinductor/scheduler.py | 6 +++--- torchinductor/utils.py | 4 ++-- 33 files changed, 120 insertions(+), 96 deletions(-) diff --git a/Makefile b/Makefile index f684eed00e..7954c56e08 100644 --- a/Makefile +++ b/Makefile @@ -33,16 +33,11 @@ overhead: develop format: isort $(PY_FILES) black $(PY_FILES) - ! which $(CLANG_FORMAT) >/dev/null 2>&1 || $(CLANG_FORMAT) -i $(C_FILES) lint: black --check --diff $(PY_FILES) isort --check --diff $(PY_FILES) flake8 $(PY_FILES) - mypy - ! which $(CLANG_TIDY) >/dev/null 2>&1 || $(CLANG_TIDY) $(C_FILES) -- \ - -I`python -c 'from distutils.sysconfig import get_python_inc as X; print(X())'` \ - `python -c 'from torch.utils.cpp_extension import include_paths; print(" ".join(map("-I{}".format, include_paths())))'` lint-deps: grep -E '(black|flake8|isort|click|torch|mypy)' requirements.txt | xargs $(PIP) install diff --git a/benchmarks/microbenchmarks/bench_conv_fusion.py b/benchmarks/microbenchmarks/bench_conv_fusion.py index 6124c2caaf..7310d050b7 100644 --- a/benchmarks/microbenchmarks/bench_conv_fusion.py +++ b/benchmarks/microbenchmarks/bench_conv_fusion.py @@ -1,3 +1,4 @@ +# flake8: noqa import model import torch import triton diff --git a/benchmarks/microbenchmarks/bench_mm_fusion.py b/benchmarks/microbenchmarks/bench_mm_fusion.py index b243eb4108..4de895d3fe 100644 --- a/benchmarks/microbenchmarks/bench_mm_fusion.py +++ b/benchmarks/microbenchmarks/bench_mm_fusion.py @@ -1,3 +1,4 @@ +# flake8: noqa import torch import triton from prettytable import PrettyTable diff --git a/benchmarks/microbenchmarks/benchmark_helper.py b/benchmarks/microbenchmarks/benchmark_helper.py index 06845fc20a..971d7c15c8 100644 --- a/benchmarks/microbenchmarks/benchmark_helper.py +++ b/benchmarks/microbenchmarks/benchmark_helper.py @@ -1,7 +1,8 @@ from torch.utils.benchmark import Timer -def time_with_torch_timer(fn, args, kwargs={}, iters=100): +def time_with_torch_timer(fn, args, kwargs=None, iters=100): + kwargs = kwargs or {} env = {"args": args, "kwargs": kwargs, "fn": fn} fn_call = "fn(*args, **kwargs)" diff --git a/benchmarks/microbenchmarks/operator_inp_utils.py b/benchmarks/microbenchmarks/operator_inp_utils.py index 23dc2c3d6b..5d84ea81cb 100644 --- a/benchmarks/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/microbenchmarks/operator_inp_utils.py @@ -85,7 +85,7 @@ def serialize_sparse_tensor(e): def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None): - assert False, "NYI" + raise NotImplementedError() def deserialize_tensor(size, dtype, stride=None): diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 9012d4e35d..c74572360c 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -2,11 +2,18 @@ from torchdynamo.testing import make_test_cls_with_patches -from . import test_functions -from . import test_misc -from . import test_modules -from . import test_repros -from . import test_unspec +try: + from . import test_functions + from . import test_misc + from . 
import test_modules + from . import test_repros + from . import test_unspec +except ImportError: + import test_functions + import test_misc + import test_modules + import test_repros + import test_unspec def make_dynamic_cls(cls): diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 2d539a8b68..86b93e50ef 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1,4 +1,5 @@ # Owner(s): ["module: dynamo"] +# flake8: noqa import collections import functools import inspect diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 3e66c30cfd..b706c066f0 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -4,10 +4,13 @@ import torchdynamo.testing from torchdynamo.testing import same -from . import test_global_declaration +try: + from . import test_global_declaration +except ImportError: + import test_global_declaration -class Pair(object): +class Pair(object): # noqa: B903 def __init__(self, x, y): self.x = x self.y = y diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index fe6e32c4bb..a1ae8df314 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -747,7 +747,7 @@ def fn(seq): return type(seq)([a + 1, b + 2, a + b]) args1 = [torch.randn(10), torch.randn(10)] - args2 = tuple([torch.randn(10), torch.randn(10)]) + args2 = (torch.randn(10), torch.randn(10)) correct1 = fn(args1) correct2 = fn(args2) cnts = torchdynamo.testing.CompileCounter() @@ -760,7 +760,7 @@ def fn(seq): self.assertEqual(cnts.op_count, 6) def test_setattr_mutation1(self): - class MyObj: + class MyObj: # noqa: B903 def __init__(self, a, b): self.a = a self.b = b @@ -1371,21 +1371,21 @@ def test_const_dict_variable_python_type(self): def test_builtin_subclasses_as_method_on_class_type(self): class Foo: - def __init__(name): + def __init__(self, name): self.ame_ = name def get_name(self): return "Foo " + self.name_ class Bar(Foo): - def __init__(name): + def __init__(self, name): self.name_ = name def get_name(self): return "Bar " + self.name_ class Baz(Foo): - def __init__(name): + def __init__(self, name): # noqa: B903 self.name_ = name def get_name(self): @@ -1406,21 +1406,21 @@ def fn(): def test_builtin_subclasses_as_method_on_var(self): class Foo: - def __init__(name): + def __init__(self, name): self.name_ = name def get_name(self): return "Foo " + self.name_ class Bar(Foo): - def __init__(name): + def __init__(self, name): self.name_ = name def get_name(self): return "Bar " + self.name_ class Baz(Bar): - def __init__(name): + def __init__(self, name): self.name_ = name def get_name(self): @@ -1715,7 +1715,7 @@ def foo(mod, x): def test_update_locals_and_stack_uses_shared_cache(self): def fn(x): perm = [0, 3, 5] - perm = [i for i in range(min(perm))] + perm + perm = list(range(min(perm))) + perm perm.extend(i for i in range(x.dim()) if i not in perm) return perm @@ -2529,7 +2529,7 @@ def test_generate_tensor_from_list_of_numpy_primitive_type(self): # Test sth like torch.LongTensor(list(np.int64, np.int64, ...)) def fn(): x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) - y = list((x[0], x[2], x[4])) + y = [x[0], x[2], x[4]] z = torch.LongTensor(y) return z diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 34206b7793..3ce2ed3cbd 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -14,7 +14,10 @@ from torchdynamo.mutation_guard import GenerationTracker from torchdynamo.testing import same -from . import test_functions +try: + from . 
import test_functions +except ImportError: + import test_functions class BasicModule(torch.nn.Module): @@ -485,7 +488,7 @@ def custom_add(cls, x): class SuperChildCallsClassMethod(ComplicatedSuperParent): @classmethod - def child_func(self, x): + def child_func(cls, x): x = super().custom_add(x) return x diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index 5700b3b33c..bff8c0cd50 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -1,11 +1,18 @@ # Owner(s): ["module: dynamo"] from torchdynamo.testing import make_test_cls_with_patches -from . import test_functions -from . import test_misc -from . import test_modules -from . import test_repros -from . import test_unspec +try: + from . import test_functions + from . import test_misc + from . import test_modules + from . import test_repros + from . import test_unspec +except ImportError: + import test_functions + import test_misc + import test_modules + import test_repros + import test_unspec def make_no_fake_cls(cls): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index f2fe9f383a..481b26901a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -43,7 +43,7 @@ def has_detectron2(): from detectron2.layers.mask_ops import _paste_masks_tensor_shape return _paste_masks_tensor_shape is not None - except (ImportError, ModuleNotFoundError): + except ImportError: return False diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b813490a27..6afb6cda13 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -41,7 +41,7 @@ assert get_decompositions([torch.ops.aten.trace]) # Requires functorch from torchinductor.compile_fx import compile_fx_inner -except (ImportError, ModuleNotFoundError, AssertionError) as e: +except (ImportError, AssertionError) as e: sys.stderr.write(f"{type(e)}: {e}\n") raise unittest.SkipTest("requires sympy/functorch") @@ -64,7 +64,7 @@ try: importlib.import_module("triton") HAS_CUDA = True - except (ImportError, ModuleNotFoundError): + except ImportError: pass requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") @@ -141,7 +141,7 @@ def check_model( self: TestCase, model, example_inputs, - kwargs={}, + kwargs=None, *, atol=None, rtol=None, @@ -152,6 +152,7 @@ def check_model( reference_in_float=True, assert_equal=True, ): + kwargs = kwargs or {} torchdynamo.reset() ref_inputs = example_inputs @@ -250,7 +251,7 @@ def check_model_cuda( self: TestCase, model, example_inputs, - kwargs={}, + kwargs=None, *, atol=None, rtol=None, @@ -261,6 +262,7 @@ def check_model_cuda( reference_in_float=True, assert_equal=True, ): + kwargs = kwargs or {} if hasattr(model, "to"): model = model.to("cuda") @@ -483,7 +485,7 @@ def test_indexing_join(self): class CommonTemplate: @classmethod - def install(my_cls, other_cls, suffix): + def install(my_cls, other_cls, suffix): # noqa: B902 for name, value in my_cls.__dict__.items(): if name.startswith("test_"): setattr(other_cls, f"{name}_{suffix}", value) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 7fed729e07..fb595b7051 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -20,8 +20,12 @@ import torchdynamo -from .test_torchinductor import check_model -from .test_torchinductor import check_model_cuda +try: + from .test_torchinductor import check_model + from 
.test_torchinductor import check_model_cuda +except ImportError: + from test_torchinductor import check_model + from test_torchinductor import check_model_cuda bf16 = torch.bfloat16 # not tested f64 = torch.float64 diff --git a/torchdynamo/bytecode_analysis.py b/torchdynamo/bytecode_analysis.py index 621cc2fc70..541336ba48 100644 --- a/torchdynamo/bytecode_analysis.py +++ b/torchdynamo/bytecode_analysis.py @@ -85,7 +85,7 @@ def walk(state, start): elif "STORE" in inst.opname: state.writes.add(inst.argval) else: - assert False, f"unhandled {inst.opname}" + raise NotImplementedError(f"unhandled {inst.opname}") if inst.opcode in JUMP_OPCODES: walk(may, indexof[id(inst.target)]) state = may diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index b114bc880a..6b7a5e8f05 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -162,11 +162,7 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif ( - not is_allowed(obj) - and hasattr(obj, "__dict__") - and len(getattr(obj, "__dict__")) - ): + elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) return seen_ids[obj_id] else: diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index 996f732963..38999f1e71 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -32,7 +32,7 @@ try: from torch.fx.experimental import proxy_tensor -except (ModuleNotFoundError, ImportError): +except ImportError: proxy_tensor = None _eval_frame = torch._C._dynamo.eval_frame @@ -78,7 +78,7 @@ def innermost_fn(fn): """ unaltered_fn = fn while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): - unaltered_fn = getattr(unaltered_fn, "_torchdynamo_orig_callable") + unaltered_fn = unaltered_fn._torchdynamo_orig_callable assert callable(unaltered_fn) return unaltered_fn @@ -464,9 +464,9 @@ def produce_matching(source_args, candidate_args): dict_of_source_args[id(arg.item())] ) else: - assert ( - False - ), "Dynamo input/output is not consistent with traced input/output" + raise AssertionError( + "Dynamo input/output is not consistent with traced input/output" + ) else: assert ( id(arg) in dict_of_source_args @@ -568,7 +568,7 @@ def graph_with_interpreter(*args): def assume_constant_result(fn): - setattr(fn, "_dynamo_marked_constant", True) + fn._dynamo_marked_constant = True assert ( not config.fake_tensor_propagation ), "Constant result capture is not supported with fake tensors." 
@@ -669,7 +669,7 @@ def patch(): opt.step = unwrapped_step # disable future hooking - setattr(opt.step, "hooked", True) + opt.step.hooked = True @staticmethod def suppress_torch_distributed_warnings(fn): diff --git a/torchdynamo/optimizations/inference.py b/torchdynamo/optimizations/inference.py index 64faee60cf..a07ca40de9 100644 --- a/torchdynamo/optimizations/inference.py +++ b/torchdynamo/optimizations/inference.py @@ -79,7 +79,7 @@ def record_graph_stats(gm): elif node.op in ("placeholder", "output", "get_attr"): pass else: - assert False, node.op + raise AssertionError(node.op) def check_requires_grad(gm, example_inputs): diff --git a/torchdynamo/optimizations/normalize.py b/torchdynamo/optimizations/normalize.py index 1acf7e0aa2..fa0aaee7fe 100644 --- a/torchdynamo/optimizations/normalize.py +++ b/torchdynamo/optimizations/normalize.py @@ -220,7 +220,7 @@ def expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwargs): vars[node] = graph.get_attr(f"{prefix}{node.target}") else: vars[node] = graph.node_copy(node, vars.__getitem__) - assert False + raise AssertionError("unreachable") except Exception: print(f"Error while expanding {module.__class__.__name__}") raise @@ -244,7 +244,7 @@ def short_name(gm, node: torch.fx.Node): return node.target elif node.op == "output": return "output" - assert False, node.op + raise AssertionError(node.op) def long_name(gm, node: torch.fx.Node): @@ -263,7 +263,7 @@ def long_name(gm, node: torch.fx.Node): return name elif node.op == "output": return "output" - assert False + raise AssertionError("unreachable") class Inplacifier: @@ -380,9 +380,7 @@ def run_node(self, n: torch.fx.Node): # For inplace operators, the output dtype should be equal to the # dtype of tensor being inplace modified. if n.target in IOPERATOR_REPLACEMENTS: - result = getattr(self, "call_method")( - "to", (result, n.args[0].meta["dtype"]), {} - ) + result = self.call_method("to", (result, n.args[0].meta["dtype"]), {}) for patch in patches: assert isinstance( diff --git a/torchdynamo/output_graph.py b/torchdynamo/output_graph.py index dbd35430f9..36a37d7547 100644 --- a/torchdynamo/output_graph.py +++ b/torchdynamo/output_graph.py @@ -240,7 +240,7 @@ def wrap_name(module_key): return wrap_name(name) name = f"{base}_{i}" - assert False + raise AssertionError("unreachable") def compile_subgraph( self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None @@ -503,7 +503,7 @@ def create_proxy( # append stack trace to fx node tx = current_tx if current_tx else self.root_tx - nn_module_stack = getattr(tx, "nn_module_stack") + nn_module_stack = tx.nn_module_stack if nn_module_stack: rv.node.meta["nn_module_stack"] = nn_module_stack.copy() diff --git a/torchdynamo/side_effects.py b/torchdynamo/side_effects.py index 2a5a6b25ac..ce09d5f67d 100644 --- a/torchdynamo/side_effects.py +++ b/torchdynamo/side_effects.py @@ -328,7 +328,7 @@ def codegen_update_mutated(self, cg: PyCodegen): cg(var.mutable_local.source) suffixes.append([create_instruction("STORE_ATTR", name)]) else: - assert False, type(var) + raise AssertionError(type(var)) # do all the actual mutations at the very end to handle dependencies for suffix in reversed(suffixes): diff --git a/torchdynamo/variables/functions.py b/torchdynamo/variables/functions.py index f69b79dfbe..dbf33afde4 100644 --- a/torchdynamo/variables/functions.py +++ b/torchdynamo/variables/functions.py @@ -78,10 +78,7 @@ class UserFunctionVariable(BaseUserFunctionVariable): def __init__(self, fn, is_constant=False, **kwargs): 
super(UserFunctionVariable, self).__init__(**kwargs) - if ( - hasattr(fn, "_dynamo_marked_constant") - and getattr(fn, "_dynamo_marked_constant") is True - ): + if getattr(fn, "_dynamo_marked_constant", False): # This method should be treated as a constant for the purposes of compilation self.is_constant = True else: diff --git a/torchdynamo/variables/lists.py b/torchdynamo/variables/lists.py index 99f25a98bd..613006966d 100644 --- a/torchdynamo/variables/lists.py +++ b/torchdynamo/variables/lists.py @@ -361,7 +361,7 @@ def __init__(self, items, **kwargs): elif len(items) == 3: start, stop, step = items else: - assert False + raise AssertionError() # Avoids a .item() call in the tensor slice that would attempt to get a # value out fake tensors, and which would determine the output shape of diff --git a/torchdynamo/variables/tensor.py b/torchdynamo/variables/tensor.py index ba0afb52b3..4a671bbe92 100644 --- a/torchdynamo/variables/tensor.py +++ b/torchdynamo/variables/tensor.py @@ -77,7 +77,7 @@ def run_proxy(proxy, args, kwargs, nnmodule): elif op == "call_module": assert nnmodule is not None return nnmodule(*args, **kwargs) - assert False, op + raise AssertionError(op) @classmethod def create(cls, tx, proxy, example_value=None, nnmodule=None, **options): @@ -267,9 +267,10 @@ def context(): **options, ) else: - assert ( - False - ), f"torch.* op returned non-Tensor {typestr(example_value)} {proxy.node.op} {proxy.node.target}" + raise AssertionError( + "torch.* op returned non-Tensor " + + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" + ) def __init__( self, diff --git a/torchdynamo/variables/torch.py b/torchdynamo/variables/torch.py index 5b4511869f..8986babee2 100644 --- a/torchdynamo/variables/torch.py +++ b/torchdynamo/variables/torch.py @@ -120,7 +120,7 @@ def __init__(self, value, **kwargs): # some _C functions have __self__ as a null capsule pass else: - assert False, f"{value} found with __self__ set" + raise AssertionError(f"{value} found with __self__ set") def __repr__(self): return f"TorchVariable({self.value})" @@ -206,7 +206,7 @@ def call_function( elif self.value is torch.is_complex: return ConstantVariable(args[0].dtype.is_complex, **options) else: - assert False + raise AssertionError() elif ( self.value is torch.numel and isinstance(args[0], TensorVariable) diff --git a/torchinductor/codegen/autotuner.py b/torchinductor/codegen/autotuner.py index c65838b6a9..f5da37b366 100644 --- a/torchinductor/codegen/autotuner.py +++ b/torchinductor/codegen/autotuner.py @@ -89,7 +89,7 @@ def tuned_conv( use_cuda = x.is_cuda # gen_key - key = tuple([arg for arg in id_args]) + key = tuple(id_args) key = ("conv",) + key # candidate kernels @@ -173,7 +173,7 @@ def tuned_mm( use_cuda = a.is_cuda # gen_key - key = tuple([arg for arg in id_args]) + key = tuple(id_args) key = ("mm",) + key # candidate kernels @@ -246,7 +246,7 @@ def tuned_conv_layout( ] # gen_key - key = tuple([arg for arg in id_args]) + key = tuple(id_args) key = ("conv_layout",) + key runnable_kernel = str2func(kernel) diff --git a/torchinductor/codegen/cpp.py b/torchinductor/codegen/cpp.py index af2d61c460..41c26a6d91 100644 --- a/torchinductor/codegen/cpp.py +++ b/torchinductor/codegen/cpp.py @@ -61,7 +61,7 @@ def reduction_init(reduction_type, dtype): if is_float_dtype(dtype) else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()" ) - assert False, reduction_type + raise AssertionError(reduction_type) def reduction_combine(reduction_type, var, next_value): diff --git a/torchinductor/codegen/triton.py 
b/torchinductor/codegen/triton.py index 02ba7cf33b..492a026d8c 100644 --- a/torchinductor/codegen/triton.py +++ b/torchinductor/codegen/triton.py @@ -298,8 +298,10 @@ def __init__( prefix: str, index: int, kernel: "Kernel", - pid_cache={}, + pid_cache=None, ): + if pid_cache is None: + pid_cache = {} super(IterationRangesRoot, self).__init__( name=name, var_list=[], @@ -453,7 +455,9 @@ class TritonKernel(Kernel): overrides = TritonOverrides sexpr = texpr - def __init__(self, *groups, pid_cache={}, reduction_hint=ReductionHint.DEFAULT): + def __init__(self, *groups, pid_cache=None, reduction_hint=ReductionHint.DEFAULT): + if pid_cache is None: + pid_cache = {} super(TritonKernel, self).__init__() self.numels = [V.graph.sizevars.simplify(s) for s in groups] self.range_trees = [] diff --git a/torchinductor/graph.py b/torchinductor/graph.py index bd38c68d40..2168331513 100644 --- a/torchinductor/graph.py +++ b/torchinductor/graph.py @@ -272,10 +272,10 @@ def get_attr(self, target, args, kwargs): return self.add_tensor_constant(value) def call_module(self, target, args, kwargs): - assert False + raise AssertionError() def call_method(self, target, args, kwargs): - assert False + raise AssertionError() def output(self, target, args, kwargs): result = super().output(target, args, kwargs) diff --git a/torchinductor/ir.py b/torchinductor/ir.py index d2f65c45b4..a8b7a4c94d 100644 --- a/torchinductor/ir.py +++ b/torchinductor/ir.py @@ -1154,7 +1154,7 @@ def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]: return new_size, reindex def __init__(self, data): - assert False, "use SqueezeView.create()" + raise AssertionError("use SqueezeView.create()") @dataclasses.dataclass @@ -1286,7 +1286,7 @@ def _dynamic_reshape_indexer(old_size, new_size): size_old = size_old * modulus V.graph.sizevars.guard_equals(size_new, size_old) else: - assert False + raise AssertionError() while stack_old: size_old = stack_old.pop() @@ -2021,13 +2021,16 @@ def simplify_and_reorder(x_vars, sizes, reordering_reindex=None): @staticmethod def _apply_loop_reordering( - index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=[] + index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None ): """ Shuffle the order of loops around to hopefully improve performance. 
""" from .scheduler import pick_loop_order + if priority_idx is None: + priority_idx = [] + try: strides = numpy.array( [ @@ -2417,9 +2420,9 @@ def codegen(self, wrapper): args.append(f"out={self.codegen_reference()}") wrapper.writeline(f"{self.kernel}({', '.join(args)})") - def __init__(self, layout, inputs, constant_args=(), kwargs={}, output_view=None): + def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None): super().__init__( - None, layout, self.unwrap_storage(inputs), constant_args, kwargs + None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {} ) self.output_view = output_view self.name = V.graph.register_buffer(self) @@ -2627,8 +2630,8 @@ def map_args(self): class MatrixMultiplyAdd(ExternKernelOut): - def __init__(self, layout, inputs, constant_args=(), kwargs={}, output_view=None): - super().__init__(layout, inputs, constant_args, kwargs, output_view) + def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None): + super().__init__(layout, inputs, constant_args, kwargs or {}, output_view) self.kernel = "aten.addmm.out" @classmethod diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py index bb3cebba45..7487cfc9c6 100644 --- a/torchinductor/lowering.py +++ b/torchinductor/lowering.py @@ -887,7 +887,7 @@ def bernoulli_(x, *args): # This shouldn't be called in general @register_lowering(aten._foobar) def _foobar(_): - assert False + raise AssertionError() @functools.lru_cache(1) @@ -1737,7 +1737,7 @@ def scatter_(self, dim: int, index, src, *, reduce: str = None): if reduce == "add": reduce = "sum" elif reduce == "multiply": - assert False, "TODO: multiply not supported" + raise NotImplementedError("TODO: multiply not supported") reduce = "prod" else: assert reduce is None @@ -2572,7 +2572,7 @@ def fn(idx): def avg_pool2d( x, kernel_size, - stride=[], + stride=(), padding=0, ceil_mode=False, count_include_pad=True, diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 8006278a37..5df96fb5a3 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -324,7 +324,7 @@ def allocate(self): return super().allocate() if config.inplace_buffers: - assert False, "https://github.com/pytorch/torchdynamo/issues/823" + raise AssertionError("https://github.com/pytorch/torchdynamo/issues/823") """ for read in self.read_writes.reads: input_node: BaseSchedulerNode = self.scheduler.name_to_node.get( @@ -500,7 +500,7 @@ def can_free(self): raise NotImplementedError -def pick_loop_order(stride_lengths, sizes, priority_idx=[]): +def pick_loop_order(stride_lengths, sizes, priority_idx=()): """ A heuristic to decide loop iteration orders. This has not been well tuned and may be something we should autotune. 
@@ -571,7 +571,7 @@ def __init__(self, nodes): elif isinstance(node, ir.ExternKernel): self.nodes.append(ExternKernelSchedulerNode(self, node)) else: - assert False, node + raise NotImplementedError(node) # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) for node in self.nodes: diff --git a/torchinductor/utils.py b/torchinductor/utils.py index e961afaf86..9cde2af38f 100644 --- a/torchinductor/utils.py +++ b/torchinductor/utils.py @@ -30,7 +30,7 @@ def has_triton(): import triton return triton is not None - except (ImportError, ModuleNotFoundError): + except ImportError: return False @@ -42,7 +42,7 @@ def has_torchvision_roi_align(): return roi_align is not None and hasattr( getattr(torch.ops, "torchvision", None), "roi_align" ) - except (ImportError, ModuleNotFoundError): + except ImportError: return False From ba35c1972f3a698bd5329a0991b4dfeff391c9b4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 8 Oct 2022 07:45:09 -0700 Subject: [PATCH 06/36] fix --- test/inductor/test_torchinductor_opinfo.py | 4 +++- torchdynamo/convert_frame.py | 2 +- torchdynamo/variables/misc.py | 5 ++++- torchinductor/codegen/autotuner.py | 6 +++--- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index a9bd0553a7..cced90d1a5 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -19,6 +19,7 @@ from torch.testing._internal.common_utils import suppress_warnings import torchdynamo +from torchinductor.utils import has_triton try: from .test_torchinductor import check_model @@ -571,4 +572,5 @@ def fn(*args, **kwargs): if __name__ == "__main__": torchdynamo.config.raise_on_assertion_error = True - run_tests() + if has_triton(): + run_tests() diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index 7aceb77a0c..228db99182 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -308,7 +308,7 @@ def format_guard_failures(code): assert code in guard_failures, "TODO(whc) any other recompile reasons?" log.warning( - f"torch.dynamo hit config.cache_size_limit ({config.cache_size_limit})\n" + f"{config.dynamo_import} hit config.cache_size_limit ({config.cache_size_limit})\n" + f" function: {format_func_info(code)}\n" + f" reasons: {format_guard_failures(code)}\n" + f"to diagnose recompilation issues, see {troubleshooting_url}." diff --git a/torchdynamo/variables/misc.py b/torchdynamo/variables/misc.py index 5bc9368cf1..0dff0c7152 100644 --- a/torchdynamo/variables/misc.py +++ b/torchdynamo/variables/misc.py @@ -6,6 +6,7 @@ import torch._C +from .. import config from .. import variables from ..bytecode_transformation import create_instruction from ..exc import unimplemented @@ -616,7 +617,9 @@ def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): - unimplemented(f"call torch.dynamo.disable() wrapped function {self.value}") + unimplemented( + f"call {config.dynamo_import}.disable() wrapped function {self.value}" + ) else: try: path = inspect.getfile(self.value) diff --git a/torchinductor/codegen/autotuner.py b/torchinductor/codegen/autotuner.py index f5da37b366..94b2318a02 100644 --- a/torchinductor/codegen/autotuner.py +++ b/torchinductor/codegen/autotuner.py @@ -1,7 +1,6 @@ import builtins import torch -import triton from .. 
import config from .. import triton_ops @@ -32,14 +31,15 @@ def str2func(str): class Autotuner: def __init__(self): - self.cache = dict() def _bench(self, kernel, *args, **kwargs): def kernel_call(): kernel(*args, **kwargs) - return triton.testing.do_bench(kernel_call) + from triton.testing import do_bench + + return do_bench(kernel_call) autotune = Autotuner() From 2993bd311f38d9aa5e60721047e2a570ed0b061d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 8 Oct 2022 11:39:35 -0700 Subject: [PATCH 07/36] enable tests --- copy_to_core.sh | 12 +- test/dynamo/test_aot_autograd.py | 10 +- test/dynamo/test_aot_cudagraphs.py | 6 + test/dynamo/test_distributed.py | 6 + test/dynamo/test_dynamic_shapes.py | 5 + test/dynamo/test_export.py | 12 +- test/dynamo/test_functions.py | 6 + test/dynamo/test_global.py | 6 + test/dynamo/test_minifier.py | 6 + test/dynamo/test_misc.py | 6 + test/dynamo/test_model_output.py | 6 + test/dynamo/test_modules.py | 6 + test/dynamo/test_no_fake_tensors.py | 5 + test/dynamo/test_nops.py | 6 + test/dynamo/test_optimizations.py | 6 + test/dynamo/test_optimizers.py | 5 + test/dynamo/test_python_autograd.py | 6 + test/dynamo/test_recompile_ux.py | 6 + test/dynamo/test_replay_record.py | 6 + test/dynamo/test_repros.py | 38 +- test/dynamo/test_skip_non_tensor.py | 6 + test/dynamo/test_subgraphs.py | 6 + test/dynamo/test_unspec.py | 6 + test/dynamo/test_verify_correctness.py | 6 + test/inductor/test_torchinductor.py | 13 +- torchdynamo/replay_record.py | 7 +- torchdynamo/skipfiles.py | 8 +- torchdynamo/testing.py | 26 +- torchinductor/triton_ops/autotune.py | 34 +- torchinductor/triton_ops/batched_matmul.py | 528 ++++---- torchinductor/triton_ops/conv.py | 1344 ++++++++++--------- torchinductor/triton_ops/conv1x1.py | 351 ++--- torchinductor/triton_ops/conv_perf_model.py | 11 +- torchinductor/triton_ops/matmul.py | 246 ++-- torchinductor/triton_ops/mm_perf_model.py | 9 +- 35 files changed, 1491 insertions(+), 1275 deletions(-) diff --git a/copy_to_core.sh b/copy_to_core.sh index 9ddeacea61..4745b8444d 100755 --- a/copy_to_core.sh +++ b/copy_to_core.sh @@ -1,17 +1,17 @@ #!/bin/bash set -ex -rsync -ra ~/torchdynamo/torchdynamo/ ~/pytorch/torch/dynamo -rsync -ra ~/torchdynamo/torchinductor/ ~/pytorch/torch/inductor +rsync -ra ~/torchdynamo/torchdynamo/ ~/pytorch/torch/_dynamo +rsync -ra ~/torchdynamo/torchinductor/ ~/pytorch/torch/_inductor rsync -ra ~/torchdynamo/test/{dynamo,inductor} ~/pytorch/test/ rsync -ra ~/torchdynamo/benchmarks/ ~/pytorch/benchmarks/dynamo for DIR in ~/pytorch/test/{dynamo,inductor} ~/pytorch/benchmarks/dynamo do - find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchdynamo/torch.dynamo/g' - find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchinductor/torch.inductor/g' - find $DIR -name '*.py' | xargs -n1 -- sed -i 's/_torch[.]inductor/_torchinductor/g' - find $DIR -name '*.py' | xargs -n1 -- sed -i 's@pytorch/torch[.]dynamo@pytorch/torchdynamo@g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchdynamo/torch._dynamo/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/torchinductor/torch._inductor/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's/_torch[.]_inductor/_torchinductor/g' + find $DIR -name '*.py' | xargs -n1 -- sed -i 's@pytorch/torch[.]_dynamo@pytorch/torchdynamo@g' done (cd ~/pytorch && (lintrunner -a || lintrunner -a)) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 1824c66bce..a9e6134972 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -20,7 +20,7 @@ 
class Repro(torch.nn.Module): def __init__(self): super().__init__() self.self_mod_model_lstm_lstm = torch.nn.LSTM( - 2048, 2048, num_layers=2, bidirectional=True + 64, 64, num_layers=2, bidirectional=True ) def forward(self, permute: torch.Tensor): @@ -32,7 +32,7 @@ def forward(self, permute: torch.Tensor): compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) aot_mod = torchdynamo.optimize(compiler_fn)(mod) - args = [((92, 4, 2048), (1, 188416, 92), torch.float32, "cpu", False)] + args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)] args = [ rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args @@ -71,3 +71,9 @@ def fn(x, y): aot_fn = torchdynamo.optimize(compiler_fn)(fn) aot_fn(x, y) self.assertTrue(is_safe[0]) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index e67afd2c83..57c05b83cd 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -199,3 +199,9 @@ def fn(x): x = torch.empty(20, device="cuda:0") fn(x) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index dab4f2fe72..0559bab117 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -195,3 +195,9 @@ def opt_fn(inputs): opt_outputs = opt_fn(inputs) opt_outputs.sum().backward() self.assertTrue(same(correct_outputs, opt_outputs)) + + +# TODO(jansel): debug issues running this in CI +# if __name__ == "__main__": +# from torchdynamo.testing import run_tests +# run_tests() diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index c74572360c..0985ebbdfd 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -27,3 +27,8 @@ def make_dynamic_cls(cls): DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests) DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests) DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests) + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index cb18a1f170..cbc760f588 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1255,7 +1255,7 @@ def forward(self, x): real_result = module(torch.tensor([2])) # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1])) + self.assertEqual(real_result, torch.tensor([1.0])) graph, guards = torchdynamo.export(module, torch.tensor([2])) result = graph(torch.tensor([-0.5])) @@ -1311,7 +1311,7 @@ def forward(self, x): real_result = module(torch.tensor([2])) # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1])) + self.assertEqual(real_result, torch.tensor([1.0])) graph, guards = torchdynamo.export(module, torch.tensor([2])) result = graph(torch.tensor([-0.5])) @@ -1339,7 +1339,7 @@ def forward(self, x): real_result = module(torch.tensor([2])) # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1])) + self.assertEqual(real_result, torch.tensor([1.0])) graph, guards = torchdynamo.export(module, torch.tensor([2])) result = graph(torch.tensor([-0.5])) @@ -1420,3 +1420,9 @@ def nop(x): graph, _ = torchdynamo.export( 
f, (torch.randn(5)), aten_graph=False, tracing_mode="symbolic" ) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 86b93e50ef..975d2a1abf 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -667,3 +667,9 @@ def test_list_slice_assignment(x): # return x * param # case {"b": param}: # return x / param + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index b706c066f0..3d11c8e479 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -224,3 +224,9 @@ def fn(a, b): v0, s0 = opt_fn(a, b) self.assertEqual(s0, "v0v1") reset_name() + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 7741cb2c57..4532c0c8f2 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -77,3 +77,9 @@ def inner(): inner() self.assertTrue(os.path.exists(repro_file)) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 5cf4990431..b2021529e7 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -2709,3 +2709,9 @@ def forward(self, x): fn() opt_fn = torchdynamo.optimize("eager")(fn) opt_fn() + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index a78300c744..1c4b695d99 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -157,3 +157,9 @@ def fn(obj): self.assertTrue(same(opt_fn(obj2), correct1)) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 3ce2ed3cbd..576850d1bf 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -882,3 +882,9 @@ def test_torch_static(): "Module should be transformed to an instance of BatchNorm3d.", ) self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.") + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index bff8c0cd50..8762250c70 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -26,3 +26,8 @@ def make_no_fake_cls(cls): NoFakeTensorsReproTests = make_no_fake_cls(test_repros.ReproTests) NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index 7de329fb96..efe160de24 100644 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -63,3 +63,9 @@ def test_extended_args(self): b = torch.ones(1) fn = with_debug_nops(fn) self.assertEqual(fn(a, b).sum(), 513) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py index 5aacdfa01d..49b6ce6a4a 
100644 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -200,3 +200,9 @@ def fn(a, b): optimized_fn = torchdynamo.optimize("aot_eager")(fn) res = optimized_fn(a, b) self.assertTrue(same(ref, res)) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 5ab4756a50..79f96ee20e 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -95,3 +95,8 @@ def setUpClass(cls): for opt in optimizers: setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt)) + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index 6de772715c..d0a348c937 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -286,3 +286,9 @@ def forward(a, b): self.assertTrue(same(grad1, grad2)) self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.op_count, 8) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index f6f7e69b2f..ff96feb3c7 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -196,3 +196,9 @@ def func(a, b): self.assert_single_log_contains( logs, "expected type of 'b' to be a tensor type, ' but found " ) + + +# TODO(jansel): these pass with pytest, but not with pytorch CI +# if __name__ == "__main__": +# from torchdynamo.testing import run_tests +# run_tests() diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index 6a0a04fc63..a9751f872c 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -145,3 +145,9 @@ def test_fn(x, y): self.check_replay( test_fn, torch.ones(3, 3), torch.ones(2, 2), exp_exc_name="RuntimeError" ) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests(needs="dill") diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index b1c3e0667c..96e6ea3594 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1583,10 +1583,20 @@ def fn(x): self.assertEqual(cnt.frame_count, 1) def test_relative_import(self): - def fn(x): - from .test_functions import tensor_for_import_testing + try: + from . import test_functions as _ # noqa: F401 + + def fn(x): + from .test_functions import tensor_for_import_testing + + return x * 2 * tensor_for_import_testing + + except ImportError: - return x * 2 * tensor_for_import_testing + def fn(x): + from test_functions import tensor_for_import_testing + + return x * 2 * tensor_for_import_testing x = torch.randn(10) fn(x) @@ -1596,10 +1606,20 @@ def fn(x): self.assertEqual(cnt.frame_count, 1) def test_relative_import_no_modulename(self): - def fn(x): - from . import test_functions + try: + from . import test_functions as _ # noqa: F401 + + def fn(x): + from . 
import test_functions - return x * 2 * test_functions.tensor_for_import_testing + return x * 2 * test_functions.tensor_for_import_testing + + except ImportError: + + def fn(x): + import test_functions + + return x * 2 * test_functions.tensor_for_import_testing x = torch.randn(10) fn(x) @@ -1688,3 +1708,9 @@ def forward(self, getitem_1, getitem_2, add): for (sh, st, dt, dev, rg) in args ] self.assertTrue(same_two_models(mod, opt_mod, args)) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index 70b3b1f22c..bbe2cb7b38 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -104,3 +104,9 @@ def __len__(self): fn(x) assert counter.op_count == 0 + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index a3c321ee60..a78de22375 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -525,3 +525,9 @@ def fn(a, b): return b self._common(fn, 1, 2) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 82751e6460..e404f77fe6 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -219,3 +219,9 @@ def fn(image, scale_factor): opt_fn = torchdynamo.optimize(cnts)(fn) res = opt_fn(x, scale_factor) self.assertTrue(same(ref, res)) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index 7275b3880a..191a46ebf9 100644 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -180,3 +180,9 @@ def test_ipex_bf16(self): r2 = opt_model(input) self.assertTrue(same(r1, r2.float(), tol=0.1)) self.assertEqual(r2.dtype, torch.bfloat16) + + +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 6afb6cda13..baf7acf521 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -19,7 +19,6 @@ from torchdynamo.debug_utils import same_two_models from torchdynamo.testing import rand_strided from torchdynamo.testing import same -from torchinductor.utils import timed try: import sympy @@ -36,6 +35,7 @@ from torchinductor.ir import ModularIndexing from torchinductor.sizevars import SizeVarAllocator from torchinductor.utils import has_torchvision_roi_align + from torchinductor.utils import timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -43,6 +43,8 @@ from torchinductor.compile_fx import compile_fx_inner except (ImportError, AssertionError) as e: sys.stderr.write(f"{type(e)}: {e}\n") + if __name__ == "__main__": + sys.exit(0) raise unittest.SkipTest("requires sympy/functorch") @@ -89,6 +91,7 @@ def maybe_test(*args, **kwargs): class TestCase(TorchTestCase): @classmethod def setUpClass(cls): + super().setUpClass() cls._stack = contextlib.ExitStack() cls._stack.enter_context(patch.object(config, "debug", True)) cls._stack.enter_context(patch.object(config.cpp, "min_chunk_size", 1)) @@ -96,6 +99,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls._stack.close() + 
super().tearDownClass() class ToTuple(torch.nn.Module): @@ -363,7 +367,7 @@ class SweepInputsCpuTest(SweepInputs2, TestCase): SweepInputsCpuTest.populate() -class TestIndexingSimplification(unittest.TestCase): +class TestIndexingSimplification(TorchTestCase): def test_indexing_simplification(self): sizevars = SizeVarAllocator() i0 = sympy.Symbol("i0") @@ -3803,3 +3807,8 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): ] with torch.cuda.amp.autocast(enabled=False): assert same_two_models(mod, opt_mod, args), "Dynamo failed" + + +# if __name__ == "__main__": +# from torchdynamo.testing import run_tests +# run_tests() diff --git a/torchdynamo/replay_record.py b/torchdynamo/replay_record.py index f0a6fd6d5b..1d80de5d5e 100644 --- a/torchdynamo/replay_record.py +++ b/torchdynamo/replay_record.py @@ -5,7 +5,10 @@ from typing import Any from typing import Dict -import dill +try: + import dill +except ImportError: + dill = None @dataclasses.dataclass @@ -28,10 +31,12 @@ class ExecutionRecord: code_options: Dict[str, Any] = field(default_factory=dict) def dump(self, f): + assert dill is not None, "replay_record requires `pip install dill`" dill.dump(self, f) @classmethod def load(cls, f): + assert dill is not None, "replay_record requires `pip install dill`" return dill.load(f) diff --git a/torchdynamo/skipfiles.py b/torchdynamo/skipfiles.py index d8582d67ff..f38b953fd3 100644 --- a/torchdynamo/skipfiles.py +++ b/torchdynamo/skipfiles.py @@ -30,7 +30,6 @@ import _collections_abc import _weakrefset import torch -import torch.dynamo try: import torch._prims @@ -194,7 +193,12 @@ def is_torch_inline_allowed(filename): ) +@functools.lru_cache(None) +def dynamo_dir(): + return _module_dir(importlib.import_module(config.dynamo_import)) + + def is_torch(filename): - if filename.startswith(_module_dir(torch.dynamo)): + if filename.startswith(dynamo_dir()): return False return filename.startswith(_module_dir(torch)) diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index f565c46b12..c40a944b6c 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -1,6 +1,7 @@ import contextlib import dis import functools +import importlib import logging import os.path import types @@ -8,6 +9,7 @@ from unittest.mock import patch import torch +import torch.testing._internal.common_utils from torch import fx from . 
import config @@ -29,6 +31,27 @@ log = logging.getLogger(__name__) +def run_tests(argv=None, needs=()): + from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO + from torch.testing._internal.common_utils import run_tests + + if TEST_WITH_TORCHDYNAMO: + return # cant dynamo dynamo + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda" and not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + + run_tests(argv) + + def clone_me(x): if x is None: return None @@ -198,8 +221,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None) self.assertEqual(actual.op_count, expected_ops) -# class TestCase(torch.testing._internal.common_utils.TestCase): -class TestCase(unittest.TestCase): +class TestCase(torch.testing._internal.common_utils.TestCase): @classmethod def tearDownClass(cls): cls._exit_stack.close() diff --git a/torchinductor/triton_ops/autotune.py b/torchinductor/triton_ops/autotune.py index 1d2dbdc443..4e0bf3116e 100644 --- a/torchinductor/triton_ops/autotune.py +++ b/torchinductor/triton_ops/autotune.py @@ -9,16 +9,6 @@ from typing import List import torch -import triton -from triton import Config -from triton import cdiv -from triton import heuristics -from triton import next_power_of_2 -from triton.ops.matmul import get_configs_io_bound -from triton.ops.matmul_perf_model import early_config_prune as mm_early_config_prune -from triton.runtime.jit import KernelInterface -from triton.runtime.jit import get_cuda_stream -from triton.testing import do_bench from .. import config from ..codecache import AsyncCompile @@ -30,6 +20,21 @@ log = logging.getLogger(__name__) +try: + import triton + from triton import Config + from triton import cdiv + from triton import next_power_of_2 + from triton.runtime.jit import KernelInterface + from triton.runtime.jit import get_cuda_stream +except ImportError: + cdiv = None + Config = object + get_cuda_stream = None + KernelInterface = object + next_power_of_2 = None + triton = None + class CachingAutotuner(KernelInterface): """ @@ -55,7 +60,7 @@ def precompile(self): self.launchers = AsyncCompile.map(self._precompile_config, self.configs) self.configs = None - def _precompile_config(self, cfg: triton.runtime.autotuner.Config): + def _precompile_config(self, cfg: Config): """Ahead of time compile a given autotuner config.""" torch.cuda.set_device(torch.cuda.current_device()) compile_meta = copy.deepcopy(self.meta) @@ -129,6 +134,8 @@ def kernel_call(): stream=stream, ) + from triton.testing import do_bench + return do_bench(kernel_call) def autotune_to_one_config(self, *args, **kwargs): @@ -514,6 +521,8 @@ def conv_heuristics(): def mm_heuristics(): + from triton import heuristics + mm_heuristic = heuristics( { "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, @@ -523,6 +532,9 @@ def mm_heuristics(): def mm_autotune(get_io_bound_configs=False): + from triton.ops.matmul import get_configs_io_bound + from triton.ops.matmul_perf_model import early_config_prune as mm_early_config_prune + configs = [ # basic configs for compute-bound matmuls triton.Config( diff --git a/torchinductor/triton_ops/batched_matmul.py b/torchinductor/triton_ops/batched_matmul.py index 0d6ca47312..7e7a65596b 100644 --- a/torchinductor/triton_ops/batched_matmul.py +++ b/torchinductor/triton_ops/batched_matmul.py @@ -1,276 +1,274 @@ import torch -import triton -import triton.language as tl -# from 
triton.ops.matmul_perf_model import early_config_prune -# from triton.ops.matmul_perf_model import estimate_matmul_time +from ..utils import has_triton +if has_triton(): + import triton + import triton.language as tl -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } + ) + @triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # additional configs + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=2, + num_warps=4, + ), + # additional configs for K = 64 + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + 
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=1, + num_warps=2, + ), + ], + # + get_configs_io_bound(), + key=["M", "N", "K"], + # + # key=["M", "N", "K"], + # prune_configs_by={ + # "early_config_prune": early_config_prune, + # "perf_model": estimate_matmul_time, + # "top_k": 18, + # }, + ) + @triton.jit + def _kernel( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + bid = tl.program_id(2) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + A += bid * M * K + B += bid * K * N + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K * SPLIT_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=5, - num_warps=2, - ), - # additional configs - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), 
- triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=2, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=2, - num_warps=4, - ), - # additional configs for K = 64 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=1, - num_warps=2, - ), - ], - # + get_configs_io_bound(), - key=["M", "N", "K"], - # - # key=["M", "N", "K"], - # prune_configs_by={ - # "early_config_prune": early_config_prune, - # "perf_model": estimate_matmul_time, - # "top_k": 18, - # }, -) -@triton.jit -def _kernel( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - bid = tl.program_id(2) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - A += bid * M * K - B += bid * K * N - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K * SPLIT_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) + # rematerialize rm and rn to save registers + 
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + C += bid * M * N + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(C.dtype.element_ty) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - C += bid * M * N - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - + tl.atomic_add(C, acc, mask=mask) -def bmm_out(a, b, out): - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[2] == b.shape[1], "incompatible dimensions" - B, M, K = a.shape - _, _, N = b.shape - # allocates output - c = out - # accumulator types - ACC_TYPE = ( - tl.float32 - if a.dtype in [torch.float16, torch.bfloat16, torch.float32] - else tl.int32 - ) - - # launch kernel - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - META["SPLIT_K"], - B, + def bmm_out(a, b, out): + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[2] == b.shape[1], "incompatible dimensions" + B, M, K = a.shape + _, _, N = b.shape + # allocates output + c = out + # accumulator types + ACC_TYPE = ( + tl.float32 + if a.dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 ) - # grid = lambda META: ( - # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - # META["SPLIT_K"], - # B, - # ) + # launch kernel + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + B, + ) + + # grid = lambda META: ( + # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + # META["SPLIT_K"], + # B, + # ) - # autotuner = _kernel[grid].kernel - _kernel[grid](a, b, c, M, N, K, K, 1, N, 1, N, 1, GROUP_M=8, ACC_TYPE=ACC_TYPE) - # print(autotuner.best_config) - # print(autotuner.configs_timings) + # autotuner = _kernel[grid].kernel + _kernel[grid](a, b, c, M, N, K, K, 1, N, 1, N, 1, GROUP_M=8, ACC_TYPE=ACC_TYPE) + # print(autotuner.best_config) + # print(autotuner.configs_timings) diff --git a/torchinductor/triton_ops/conv.py b/torchinductor/triton_ops/conv.py index f71ec3308e..62d7123174 100644 --- a/torchinductor/triton_ops/conv.py +++ b/torchinductor/triton_ops/conv.py @@ -1,142 +1,98 @@ import torch -import triton -import triton.language as tl - -from .autotune import conv_heuristics -from .utils import _unpack - - -@conv_heuristics() -@triton.jit -def _kernel_delta_x_hwc( - x, - w, - y, - # stride of tensor - stride_xn, - stride_xc, - stride_xh, - stride_xw, - stride_wn, - stride_wc, - stride_wh, - stride_ww, - stride_yn, - stride_yc, - stride_yh, - stride_yw, - stride_biasn, - # pointer inc for x - delta_xh_ptr, - 
delta_xw_ptr, - delta_xc_ptr, - # Tensor dimensions - BATCH, - IN_C, - IN_H, - IN_W, - KERNEL_N, - KERNEL_H, - KERNEL_W, - OUT_H, - OUT_W, - # parameters of conv - stride_h, - stride_w, - padding_h, - padding_w, - dilation_h, - dilation_w, - output_padding_h, - output_padding_w, - groups, - # Metaparameters - ACC_TYPE: tl.constexpr, - CONV1X1_NHWC: tl.constexpr, - # blocks in different dimension - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - # reduction tiling parameter for matmul - BLOCK_K: tl.constexpr, - # Super-blocking for better L2 peformance - GROUP_H: tl.constexpr, -): - """ - each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of y it should compute. - pid_nhw = tl.program_id(0) - pid_k = tl.program_id(1) - - # offset for output y - off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) - off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) - off_y_n = off_y_nhw // (OUT_H * OUT_W) - off_y_hw = off_y_nhw % (OUT_H * OUT_W) - off_y_h = off_y_hw // OUT_W + output_padding_h - off_y_w = off_y_hw % OUT_W + output_padding_w - - # offset for the initial ptr for x - off_x_n = off_y_n - off_x_h = off_y_h * stride_h - padding_h - off_x_w = off_y_w * stride_w - padding_w - off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw - off_x_crs = tl.arange(0, BLOCK_K) - - CRS = IN_C * KERNEL_H * KERNEL_W - # load inc ptr of x, upade x_ptrs - if not CONV1X1_NHWC: - delta_xh_ptrs = delta_xh_ptr + off_x_crs - delta_xw_ptrs = delta_xw_ptr + off_x_crs - delta_xc_ptrs = delta_xc_ptr + off_x_crs - delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0) - delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0) - delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0) - off_x_crs_unpacked = ( - delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc - ) - x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] - else: - x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :] - delta_xh = 0 - delta_xw = 0 - - mask_x = ( - (off_x_n < BATCH)[:, None] - & (off_x_crs < CRS)[None, :] - & (off_x_h[:, None] + delta_xh[None, :] >= 0) - & (off_x_h[:, None] + delta_xh[None, :] < IN_H) - & (off_x_w[:, None] + delta_xw[None, :] >= 0) - & (off_x_w[:, None] + delta_xw[None, :] < IN_W) - ) - - # offset for the inital ptr for w - off_w_crs = tl.arange(0, BLOCK_K) - off_w_k = off_y_k - w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn - mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] - - # ------ load x ------ - matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) - # ------ load w ------ - matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - - # ----------------------------------------------------------- - # allocate accumulator - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for crs in range(0, CRS, BLOCK_K): - - # ------ matrix multiplication ------ - acc += tl.dot(matrix_x, matrix_w) - # ------ update ptrs ------ - w_ptrs += BLOCK_K + +from ..utils import has_triton + +if has_triton(): + import triton + import triton.language as tl + + from .autotune import conv_heuristics + from .utils import _unpack + + @conv_heuristics() + @triton.jit + def _kernel_delta_x_hwc( + x, + w, + y, + # stride of tensor + stride_xn, + stride_xc, + stride_xh, + stride_xw, + stride_wn, + stride_wc, + stride_wh, + stride_ww, + stride_yn, + stride_yc, + stride_yh, + stride_yw, + stride_biasn, + # pointer inc 
for x + delta_xh_ptr, + delta_xw_ptr, + delta_xc_ptr, + # Tensor dimensions + BATCH, + IN_C, + IN_H, + IN_W, + KERNEL_N, + KERNEL_H, + KERNEL_W, + OUT_H, + OUT_W, + # parameters of conv + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + output_padding_h, + output_padding_w, + groups, + # Metaparameters + ACC_TYPE: tl.constexpr, + CONV1X1_NHWC: tl.constexpr, + # blocks in different dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + # reduction tiling parameter for matmul + BLOCK_K: tl.constexpr, + # Super-blocking for better L2 peformance + GROUP_H: tl.constexpr, + ): + """ + each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of y it should compute. + pid_nhw = tl.program_id(0) + pid_k = tl.program_id(1) + + # offset for output y + off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) + off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) + off_y_n = off_y_nhw // (OUT_H * OUT_W) + off_y_hw = off_y_nhw % (OUT_H * OUT_W) + off_y_h = off_y_hw // OUT_W + output_padding_h + off_y_w = off_y_hw % OUT_W + output_padding_w + + # offset for the initial ptr for x + off_x_n = off_y_n + off_x_h = off_y_h * stride_h - padding_h + off_x_w = off_y_w * stride_w - padding_w + off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw + off_x_crs = tl.arange(0, BLOCK_K) + + CRS = IN_C * KERNEL_H * KERNEL_W # load inc ptr of x, upade x_ptrs - off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) if not CONV1X1_NHWC: - delta_xh_ptrs += BLOCK_K - delta_xw_ptrs += BLOCK_K - delta_xc_ptrs += BLOCK_K + delta_xh_ptrs = delta_xh_ptr + off_x_crs + delta_xw_ptrs = delta_xw_ptr + off_x_crs + delta_xc_ptrs = delta_xc_ptr + off_x_crs delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0) delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0) delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0) @@ -145,7 +101,9 @@ def _kernel_delta_x_hwc( ) x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] else: - x_ptrs += BLOCK_K + x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :] + delta_xh = 0 + delta_xw = 0 mask_x = ( (off_x_n < BATCH)[:, None] @@ -155,169 +113,175 @@ def _kernel_delta_x_hwc( & (off_x_w[:, None] + delta_xw[None, :] >= 0) & (off_x_w[:, None] + delta_xw[None, :] < IN_W) ) + + # offset for the inital ptr for w + off_w_crs = tl.arange(0, BLOCK_K) + off_w_k = off_y_k + w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] - # ------ prefetch ------ + # ------ load x ------ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) # ------ load w ------ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - acc = acc.to(y.dtype.element_ty) - - # rematerialize -- this saves some registers - # offset for output y - off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) - off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) - off_y_n = off_y_nhw // (OUT_H * OUT_W) - off_y_hw = off_y_nhw % (OUT_H * OUT_W) - # consider output padding - off_y_h = off_y_hw // OUT_W + output_padding_h - off_y_w = off_y_hw % OUT_W + output_padding_w - - # y ptrs in the block of [BLOCK_M, BLOCK_N] - y_ptrs = ( - y - + off_y_n[:, None] * stride_yn - + off_y_h[:, None] * stride_yh - + off_y_w[:, None] * stride_yw - + off_y_k[None, :] * stride_yc - ) - - # out-of-bounds check - mask_y = ( - (off_y_n < BATCH)[:, None] - & (off_y_h < OUT_H + 
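
The next few added lines decode the flat pid_nhw program id into (batch, out_h, out_w) coordinates via integer division and modulo. A minimal plain-Python sketch of the same index math, for reference (the helper name decode_nhw is illustrative, not part of the kernel):

def decode_nhw(flat_idx, out_h, out_w):
    # one flat index over the flattened (N, OUT_H, OUT_W) output positions
    n = flat_idx // (out_h * out_w)
    hw = flat_idx % (out_h * out_w)
    return n, hw // out_w, hw % out_w

assert decode_nhw(0, out_h=4, out_w=5) == (0, 0, 0)
assert decode_nhw(23, out_h=4, out_w=5) == (1, 0, 3)   # 23 = 1*20 + 0*5 + 3
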
output_padding_h)[:, None] - & (off_y_w < OUT_W + output_padding_w)[:, None] - & (off_y_k < KERNEL_N)[None, :] - ) - - tl.store(y_ptrs, acc, mask=mask_y) - - return - - -@conv_heuristics() -@triton.jit -def _kernel_delta_x( - x, - w, - y, - # stride of tensor - stride_xn, - stride_xc, - stride_xh, - stride_xw, - stride_wn, - stride_wc, - stride_wh, - stride_ww, - stride_yn, - stride_yc, - stride_yh, - stride_yw, - stride_biasn, - # pointer inc for x - delta_x_ptr, - # Tensor dimensions - BATCH, - IN_C, - IN_H, - IN_W, - KERNEL_N, - KERNEL_H, - KERNEL_W, - OUT_H, - OUT_W, - # parameters of conv - stride_h, - stride_w, - padding_h, - padding_w, - dilation_h, - dilation_w, - output_padding_h, - output_padding_w, - groups, - # Metaparameters - ACC_TYPE: tl.constexpr, - CONV1X1_NHWC: tl.constexpr, - # blocks in different dimension - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - # reduction tiling parameter for matmul - BLOCK_K: tl.constexpr, - # Super-blocking for better L2 peformance - GROUP_H: tl.constexpr, -): - """ - each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of y it should compute. - pid_nhw = tl.program_id(0) - pid_k = tl.program_id(1) - - # offset for output y - off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) - off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) - off_y_n = off_y_nhw // (OUT_H * OUT_W) - off_y_hw = off_y_nhw % (OUT_H * OUT_W) - off_y_h = off_y_hw // OUT_W + output_padding_h - off_y_w = off_y_hw % OUT_W + output_padding_w - - # offset for the initial ptr for x - off_x_n = off_y_n - off_x_h = off_y_h * stride_h - padding_h - off_x_w = off_y_w * stride_w - padding_w - off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw - off_x_crs = tl.arange(0, BLOCK_K) - - CRS = IN_C * KERNEL_H * KERNEL_W - # load inc ptr of x, upade x_ptrs - if not CONV1X1_NHWC: - delta_x_ptrs = delta_x_ptr + off_x_crs - off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS) - x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] - else: - x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :] - - mask_x = ( - (off_x_n < BATCH) - & (off_x_h >= 0) - & (off_x_h < IN_H) - & (off_x_w >= 0) - & (off_x_w < IN_W) - )[:, None] & (off_x_crs < CRS)[None, :] - - # offset for the inital ptr for w - off_w_crs = tl.arange(0, BLOCK_K) - off_w_k = off_y_k - w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn - mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] - - # ------ load x ------ - matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) - # ------ load w ------ - matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - - # ----------------------------------------------------------- - # allocate accumulator - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for crs in range(0, CRS, BLOCK_K): - - # ------ matrix multiplication ------ - acc += tl.dot(matrix_x, matrix_w) - # ------ update ptrs ------ - w_ptrs += BLOCK_K + # ----------------------------------------------------------- + # allocate accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for crs in range(0, CRS, BLOCK_K): + + # ------ matrix multiplication ------ + acc += tl.dot(matrix_x, matrix_w) + # ------ update ptrs ------ + w_ptrs += BLOCK_K + # load inc ptr of x, upade x_ptrs + off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) + if not CONV1X1_NHWC: + delta_xh_ptrs += BLOCK_K + delta_xw_ptrs += BLOCK_K + delta_xc_ptrs 
+= BLOCK_K + delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0) + delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0) + delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0) + off_x_crs_unpacked = ( + delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc + ) + x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] + else: + x_ptrs += BLOCK_K + + mask_x = ( + (off_x_n < BATCH)[:, None] + & (off_x_crs < CRS)[None, :] + & (off_x_h[:, None] + delta_xh[None, :] >= 0) + & (off_x_h[:, None] + delta_xh[None, :] < IN_H) + & (off_x_w[:, None] + delta_xw[None, :] >= 0) + & (off_x_w[:, None] + delta_xw[None, :] < IN_W) + ) + mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] + # ------ prefetch ------ + # ------ load x ------ + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + # ------ load w ------ + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + + acc = acc.to(y.dtype.element_ty) + + # rematerialize -- this saves some registers + # offset for output y + off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) + off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) + off_y_n = off_y_nhw // (OUT_H * OUT_W) + off_y_hw = off_y_nhw % (OUT_H * OUT_W) + # consider output padding + off_y_h = off_y_hw // OUT_W + output_padding_h + off_y_w = off_y_hw % OUT_W + output_padding_w + + # y ptrs in the block of [BLOCK_M, BLOCK_N] + y_ptrs = ( + y + + off_y_n[:, None] * stride_yn + + off_y_h[:, None] * stride_yh + + off_y_w[:, None] * stride_yw + + off_y_k[None, :] * stride_yc + ) + + # out-of-bounds check + mask_y = ( + (off_y_n < BATCH)[:, None] + & (off_y_h < OUT_H + output_padding_h)[:, None] + & (off_y_w < OUT_W + output_padding_w)[:, None] + & (off_y_k < KERNEL_N)[None, :] + ) + + tl.store(y_ptrs, acc, mask=mask_y) + + return + + @conv_heuristics() + @triton.jit + def _kernel_delta_x( + x, + w, + y, + # stride of tensor + stride_xn, + stride_xc, + stride_xh, + stride_xw, + stride_wn, + stride_wc, + stride_wh, + stride_ww, + stride_yn, + stride_yc, + stride_yh, + stride_yw, + stride_biasn, + # pointer inc for x + delta_x_ptr, + # Tensor dimensions + BATCH, + IN_C, + IN_H, + IN_W, + KERNEL_N, + KERNEL_H, + KERNEL_W, + OUT_H, + OUT_W, + # parameters of conv + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + output_padding_h, + output_padding_w, + groups, + # Metaparameters + ACC_TYPE: tl.constexpr, + CONV1X1_NHWC: tl.constexpr, + # blocks in different dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + # reduction tiling parameter for matmul + BLOCK_K: tl.constexpr, + # Super-blocking for better L2 peformance + GROUP_H: tl.constexpr, + ): + """ + each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of y it should compute. 
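
pid_nhw and pid_k above come from the 2-D launch grid built later in _conv._call, where axis 0 tiles the flattened BATCH*OUT_H*OUT_W dimension and axis 1 tiles KERNEL_N. A small sketch of that grid computation, assuming the same BLOCK_M/BLOCK_N tiling (cdiv here is a stand-in for triton.cdiv, i.e. ceiling division):

def cdiv(a, b):
    return (a + b - 1) // b

def launch_grid(batch, out_h, out_w, kernel_n, block_m, block_n):
    # (programs along flattened N*OUT_H*OUT_W, programs along KERNEL_N)
    return (cdiv(batch * out_h * out_w, block_m), cdiv(kernel_n, block_n))

assert launch_grid(2, 8, 8, 64, block_m=128, block_n=32) == (1, 2)
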
+ pid_nhw = tl.program_id(0) + pid_k = tl.program_id(1) + + # offset for output y + off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) + off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) + off_y_n = off_y_nhw // (OUT_H * OUT_W) + off_y_hw = off_y_nhw % (OUT_H * OUT_W) + off_y_h = off_y_hw // OUT_W + output_padding_h + off_y_w = off_y_hw % OUT_W + output_padding_w + + # offset for the initial ptr for x + off_x_n = off_y_n + off_x_h = off_y_h * stride_h - padding_h + off_x_w = off_y_w * stride_w - padding_w + off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw + off_x_crs = tl.arange(0, BLOCK_K) + + CRS = IN_C * KERNEL_H * KERNEL_W # load inc ptr of x, upade x_ptrs if not CONV1X1_NHWC: - delta_x_ptrs += BLOCK_K - off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) - off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0) + delta_x_ptrs = delta_x_ptr + off_x_crs + off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS) x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] else: - off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) - x_ptrs += BLOCK_K + x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :] mask_x = ( (off_x_n < BATCH) @@ -326,417 +290,455 @@ def _kernel_delta_x( & (off_x_w >= 0) & (off_x_w < IN_W) )[:, None] & (off_x_crs < CRS)[None, :] + + # offset for the inital ptr for w + off_w_crs = tl.arange(0, BLOCK_K) + off_w_k = off_y_k + w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] - # ------ prefetch ------ + # ------ load x ------ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) # ------ load w ------ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) - acc = acc.to(y.dtype.element_ty) - - # rematerialize -- this saves some registers - # offset for output y - off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) - off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) - off_y_n = off_y_nhw // (OUT_H * OUT_W) - off_y_hw = off_y_nhw % (OUT_H * OUT_W) - # consider output padding - off_y_h = off_y_hw // OUT_W + output_padding_h - off_y_w = off_y_hw % OUT_W + output_padding_w - - # y ptrs in the block of [BLOCK_M, BLOCK_N] - y_ptrs = ( - y - + off_y_n[:, None] * stride_yn - + off_y_h[:, None] * stride_yh - + off_y_w[:, None] * stride_yw - + off_y_k[None, :] * stride_yc - ) - - # out-of-bounds check - mask_y = ( - (off_y_n < BATCH)[:, None] - & (off_y_h < OUT_H + output_padding_h)[:, None] - & (off_y_w < OUT_W + output_padding_w)[:, None] - & (off_y_k < KERNEL_N)[None, :] - ) - - tl.store(y_ptrs, acc, mask=mask_y) - - return - - -class _conv: - kernel = _kernel_delta_x_hwc - - # for the contigous order of w ptr, what"s the corresponding - # ptr changes for x in a sliding window - @staticmethod - def _delta_x_ptr_hwc( - IN_C, - KERNEL_H, - KERNEL_W, - dilation_h, - dilation_w, - stride_wc, - stride_wh, - stride_ww, - stride_xc, - stride_xh, - stride_xw, - device, - ): - # get the order of axes in w, innermost dimension outward - stride_w_3d = [stride_wc, stride_wh, stride_ww] - order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__) - window_size = IN_C * KERNEL_H * KERNEL_W - - r_window = torch.arange(0, window_size, 1, device=device) - window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W]) - window_unpack_c = window_unpack[order[0]] - window_unpack_h = window_unpack[order[1]] - window_unpack_w = window_unpack[order[2]] - r_dilation_h = dilation_h * window_unpack_h - r_dilation_w = dilation_w * window_unpack_w - r_inc = 
window_unpack_c - # delta_x = ( - # r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc - # ) - # return delta_x - return ( - r_dilation_h, - r_dilation_w, - r_inc, + # ----------------------------------------------------------- + # allocate accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for crs in range(0, CRS, BLOCK_K): + + # ------ matrix multiplication ------ + acc += tl.dot(matrix_x, matrix_w) + # ------ update ptrs ------ + w_ptrs += BLOCK_K + # load inc ptr of x, upade x_ptrs + if not CONV1X1_NHWC: + delta_x_ptrs += BLOCK_K + off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) + off_x_crs_unpacked = tl.load( + delta_x_ptrs, mask=off_x_crs < CRS, other=0 + ) + x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :] + else: + off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K) + x_ptrs += BLOCK_K + + mask_x = ( + (off_x_n < BATCH) + & (off_x_h >= 0) + & (off_x_h < IN_H) + & (off_x_w >= 0) + & (off_x_w < IN_W) + )[:, None] & (off_x_crs < CRS)[None, :] + mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :] + # ------ prefetch ------ + # ------ load x ------ + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + # ------ load w ------ + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + + acc = acc.to(y.dtype.element_ty) + + # rematerialize -- this saves some registers + # offset for output y + off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N) + off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M) + off_y_n = off_y_nhw // (OUT_H * OUT_W) + off_y_hw = off_y_nhw % (OUT_H * OUT_W) + # consider output padding + off_y_h = off_y_hw // OUT_W + output_padding_h + off_y_w = off_y_hw % OUT_W + output_padding_w + + # y ptrs in the block of [BLOCK_M, BLOCK_N] + y_ptrs = ( + y + + off_y_n[:, None] * stride_yn + + off_y_h[:, None] * stride_yh + + off_y_w[:, None] * stride_yw + + off_y_k[None, :] * stride_yc ) - @staticmethod - def _delta_x_ptr( - IN_C, - KERNEL_H, - KERNEL_W, - dilation_h, - dilation_w, - stride_wc, - stride_wh, - stride_ww, - stride_xc, - stride_xh, - stride_xw, - device, - ): - # get the order of axes in w, innermost dimension outward - stride_w_3d = [stride_wc, stride_wh, stride_ww] - order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__) - window_size = IN_C * KERNEL_H * KERNEL_W - - r_window = torch.arange(0, window_size, 1, device=device) - window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W]) - window_unpack_c = window_unpack[order[0]] - window_unpack_h = window_unpack[order[1]] - window_unpack_w = window_unpack[order[2]] - r_dilation_h = dilation_h * window_unpack_h - r_dilation_w = dilation_w * window_unpack_w - r_inc = window_unpack_c - delta_x = ( - r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc + # out-of-bounds check + mask_y = ( + (off_y_n < BATCH)[:, None] + & (off_y_h < OUT_H + output_padding_h)[:, None] + & (off_y_w < OUT_W + output_padding_w)[:, None] + & (off_y_k < KERNEL_N)[None, :] ) - return delta_x - @staticmethod - def _call( - x, - w, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ): - # Q: should we check x, w, bias dtypes? 
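
_delta_x_ptr above precomputes, for every element of the IN_C*KERNEL_H*KERNEL_W sliding window, its element offset into x as dilation_h*h*stride_xh + dilation_w*w*stride_xw + c*stride_xc. A hypothetical plain-Python sketch of such an offset table; the loop order here assumes a channels-last window for illustration, whereas the real code derives the order from the strides of w via _unpack:

def delta_x_table(in_c, kh, kw, dil_h, dil_w, stride_xh, stride_xw, stride_xc):
    table = []
    for h in range(kh):              # kernel row
        for w in range(kw):          # kernel column
            for c in range(in_c):    # innermost: channel (channels-last)
                table.append(dil_h * h * stride_xh + dil_w * w * stride_xw + c * stride_xc)
    return table

# 1x2 kernel over 3 channels of a contiguous NHWC tensor with IN_W=8
# (stride_xh=24, stride_xw=3, stride_xc=1):
assert delta_x_table(3, 1, 2, 1, 1, stride_xh=24, stride_xw=3, stride_xc=1) == [0, 1, 2, 3, 4, 5]
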
- device = x.device - # input shapes - shape_x = x.shape - shape_w = w.shape - shape_bias = bias.shape if bias is not None else None - - # indicies for the layeout - xn, xc, xh, xw = 0, 1, 2, 3 - yn, yc, yh, yw = 0, 1, 2, 3 - wn, wc, wh, ww = 0, 1, 2, 3 - - # out_channel, in_channel, kernel_height, kernel_width - kernel_size = [shape_w[wh], shape_w[ww]] - input_size = [shape_x[xh], shape_x[xw]] - assert ( - not shape_bias or shape_bias[0] == shape_w[wn] - ), f"bias shape did not match{shape_bias} != {shape_w[wn]}" - in_channel = shape_w[wc] * groups - - assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups" - assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups" - assert ( - shape_x[xc] == in_channel - ), f"in_channel did not match {shape_x[xc]} != {in_channel}" - - assert ( - len(stride) - == len(padding) - == len(dilation) - == len(output_padding) - == len(kernel_size) - == len(input_size) - ) + tl.store(y_ptrs, acc, mask=mask_y) + + return + + class _conv: + kernel = _kernel_delta_x_hwc + + # for the contigous order of w ptr, what"s the corresponding + # ptr changes for x in a sliding window + @staticmethod + def _delta_x_ptr_hwc( + IN_C, + KERNEL_H, + KERNEL_W, + dilation_h, + dilation_w, + stride_wc, + stride_wh, + stride_ww, + stride_xc, + stride_xh, + stride_xw, + device, + ): + # get the order of axes in w, innermost dimension outward + stride_w_3d = [stride_wc, stride_wh, stride_ww] + order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__) + window_size = IN_C * KERNEL_H * KERNEL_W + + r_window = torch.arange(0, window_size, 1, device=device) + window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W]) + window_unpack_c = window_unpack[order[0]] + window_unpack_h = window_unpack[order[1]] + window_unpack_w = window_unpack[order[2]] + r_dilation_h = dilation_h * window_unpack_h + r_dilation_w = dilation_w * window_unpack_w + r_inc = window_unpack_c + # delta_x = ( + # r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc + # ) + # return delta_x + return ( + r_dilation_h, + r_dilation_w, + r_inc, + ) - # output shape - shape_y = [0] * 4 - shape_y[yn] = shape_x[xn] - shape_y[yc] = shape_w[wn] - shape_y[yh] = ( - input_size[0] - + 2 * padding[0] - - dilation[0] * (kernel_size[0] - 1) - - 1 - + stride[0] - ) // stride[0] + 2 * output_padding[0] - shape_y[yw] = ( - input_size[1] - + 2 * padding[1] - - dilation[1] * (kernel_size[1] - 1) - - 1 - + stride[1] - ) // stride[1] + 2 * output_padding[1] - - BATCH = shape_x[xn] - IN_C = shape_x[xc] - IN_H = shape_x[xh] - IN_W = shape_x[xw] - KERNEL_N = shape_w[wn] - KERNEL_H = shape_w[wh] - KERNEL_W = shape_w[ww] - OUT_H = shape_y[yh] - OUT_W = shape_y[yw] - - # allocate output - y = torch.empty(shape_y, device=device, dtype=x.dtype) - - # get strides for tensors - stride_x = x.stride() - stride_w = w.stride() - stride_bias = bias.stride() if shape_bias else None - stride_biasn = stride_bias[0] if stride_bias else None - - # output layout should be the same as x - if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]: - y = y.to(memory_format=torch.channels_last) - stride_y = y.stride() - - # allocate tmp - # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C - # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype) - # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype) - # accumulator types - ACC_TYPE = ( - tl.float32 - if x.dtype in [torch.float16, torch.bfloat16, torch.float32] - else tl.int32 - ) - # 
if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1: - CONV1X1_NHWC = False - if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1: - CONV1X1_NHWC = True - # do we need delta x ptr for h, w, c dimension each or not - DELTA_X_PTR_HWC = ( - False - if ( - (padding[0] == 0 and padding[1] == 0) - or (KERNEL_H == 1 and KERNEL_W == 1) + @staticmethod + def _delta_x_ptr( + IN_C, + KERNEL_H, + KERNEL_W, + dilation_h, + dilation_w, + stride_wc, + stride_wh, + stride_ww, + stride_xc, + stride_xh, + stride_xw, + device, + ): + # get the order of axes in w, innermost dimension outward + stride_w_3d = [stride_wc, stride_wh, stride_ww] + order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__) + window_size = IN_C * KERNEL_H * KERNEL_W + + r_window = torch.arange(0, window_size, 1, device=device) + window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W]) + window_unpack_c = window_unpack[order[0]] + window_unpack_h = window_unpack[order[1]] + window_unpack_w = window_unpack[order[2]] + r_dilation_h = dilation_h * window_unpack_h + r_dilation_w = dilation_w * window_unpack_w + r_inc = window_unpack_c + delta_x = ( + r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc ) - else True - ) - if not CONV1X1_NHWC: - if DELTA_X_PTR_HWC: - delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc( + return delta_x + + @staticmethod + def _call( + x, + w, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ): + # Q: should we check x, w, bias dtypes? + device = x.device + # input shapes + shape_x = x.shape + shape_w = w.shape + shape_bias = bias.shape if bias is not None else None + + # indicies for the layeout + xn, xc, xh, xw = 0, 1, 2, 3 + yn, yc, yh, yw = 0, 1, 2, 3 + wn, wc, wh, ww = 0, 1, 2, 3 + + # out_channel, in_channel, kernel_height, kernel_width + kernel_size = [shape_w[wh], shape_w[ww]] + input_size = [shape_x[xh], shape_x[xw]] + assert ( + not shape_bias or shape_bias[0] == shape_w[wn] + ), f"bias shape did not match{shape_bias} != {shape_w[wn]}" + in_channel = shape_w[wc] * groups + + assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups" + assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups" + assert ( + shape_x[xc] == in_channel + ), f"in_channel did not match {shape_x[xc]} != {in_channel}" + + assert ( + len(stride) + == len(padding) + == len(dilation) + == len(output_padding) + == len(kernel_size) + == len(input_size) + ) + + # output shape + shape_y = [0] * 4 + shape_y[yn] = shape_x[xn] + shape_y[yc] = shape_w[wn] + shape_y[yh] = ( + input_size[0] + + 2 * padding[0] + - dilation[0] * (kernel_size[0] - 1) + - 1 + + stride[0] + ) // stride[0] + 2 * output_padding[0] + shape_y[yw] = ( + input_size[1] + + 2 * padding[1] + - dilation[1] * (kernel_size[1] - 1) + - 1 + + stride[1] + ) // stride[1] + 2 * output_padding[1] + + BATCH = shape_x[xn] + IN_C = shape_x[xc] + IN_H = shape_x[xh] + IN_W = shape_x[xw] + KERNEL_N = shape_w[wn] + KERNEL_H = shape_w[wh] + KERNEL_W = shape_w[ww] + OUT_H = shape_y[yh] + OUT_W = shape_y[yw] + + # allocate output + y = torch.empty(shape_y, device=device, dtype=x.dtype) + + # get strides for tensors + stride_x = x.stride() + stride_w = w.stride() + stride_bias = bias.stride() if shape_bias else None + stride_biasn = stride_bias[0] if stride_bias else None + + # output layout should be the same as x + if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]: + y = y.to(memory_format=torch.channels_last) + stride_y = y.stride() + + # allocate 
tmp + # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C + # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype) + # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype) + # accumulator types + ACC_TYPE = ( + tl.float32 + if x.dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 + ) + # if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1: + CONV1X1_NHWC = False + if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1: + CONV1X1_NHWC = True + # do we need delta x ptr for h, w, c dimension each or not + DELTA_X_PTR_HWC = ( + False + if ( + (padding[0] == 0 and padding[1] == 0) + or (KERNEL_H == 1 and KERNEL_W == 1) + ) + else True + ) + if not CONV1X1_NHWC: + if DELTA_X_PTR_HWC: + delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc( + IN_C, + KERNEL_H, + KERNEL_W, + dilation[0], + dilation[1], + stride_w[wc], + stride_w[wh], + stride_w[ww], + stride_x[xc], + stride_x[xh], + stride_x[xw], + device, + ) + else: + delta_x = _conv._delta_x_ptr( + IN_C, + KERNEL_H, + KERNEL_W, + dilation[0], + dilation[1], + stride_w[wc], + stride_w[wh], + stride_w[ww], + stride_x[xc], + stride_x[xh], + stride_x[xw], + device, + ) + else: + delta_x = None + delta_xh, delta_xw, delta_xc = None, None, None + + # launch kernel, 2-dim, batch*h*w, kernel + def grid(META): + return ( + triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]), + triton.cdiv(KERNEL_N, META["BLOCK_N"]), + ) + + # conv1x1 or padding==0 + if CONV1X1_NHWC or not DELTA_X_PTR_HWC: + _kernel_delta_x[grid]( + x, + w, + y, + # stride nchw for x,w,y tensor + stride_x[xn], + stride_x[xc], + stride_x[xh], + stride_x[xw], + stride_w[wn], + stride_w[wc], + stride_w[wh], + stride_w[ww], + stride_y[yn], + stride_y[yc], + stride_y[yh], + stride_y[yw], + stride_biasn, + # pointer inc for x + delta_x, + # Tensor dimensions + BATCH, IN_C, + IN_H, + IN_W, + KERNEL_N, KERNEL_H, KERNEL_W, + OUT_H, + OUT_W, + # conv parameters + stride[0], + stride[1], + padding[0], + padding[1], dilation[0], dilation[1], - stride_w[wc], - stride_w[wh], - stride_w[ww], + output_padding[0], + output_padding[1], + groups, + # Metaparameters + ACC_TYPE=ACC_TYPE, + CONV1X1_NHWC=CONV1X1_NHWC, + # BLOCK_M=128, + # BLOCK_N=32, + # BLOCK_K=32, + GROUP_H=1, + ) + # need to know ptr update for each dimension to check if + # the sliding window is out of bounds + else: + # kernel = _kernel_delta_x_hwc + _kernel_delta_x_hwc[grid]( + x, + w, + y, + # stride nchw for x,w,y tensor + stride_x[xn], stride_x[xc], stride_x[xh], stride_x[xw], - device, - ) - else: - delta_x = _conv._delta_x_ptr( + stride_w[wn], + stride_w[wc], + stride_w[wh], + stride_w[ww], + stride_y[yn], + stride_y[yc], + stride_y[yh], + stride_y[yw], + stride_biasn, + # pointer inc for x + delta_xh, + delta_xw, + delta_xc, + # Tensor dimensions + BATCH, IN_C, + IN_H, + IN_W, + KERNEL_N, KERNEL_H, KERNEL_W, + OUT_H, + OUT_W, + # conv parameters + stride[0], + stride[1], + padding[0], + padding[1], dilation[0], dilation[1], - stride_w[wc], - stride_w[wh], - stride_w[ww], - stride_x[xc], - stride_x[xh], - stride_x[xw], - device, + output_padding[0], + output_padding[1], + groups, + # Metaparameters + ACC_TYPE=ACC_TYPE, + CONV1X1_NHWC=CONV1X1_NHWC, + # BLOCK_M=128, + # BLOCK_N=32, + # BLOCK_K=32, + GROUP_H=1, ) - else: - delta_x = None - delta_xh, delta_xw, delta_xc = None, None, None - # launch kernel, 2-dim, batch*h*w, kernel - def grid(META): - return ( - triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]), - triton.cdiv(KERNEL_N, META["BLOCK_N"]), - ) + if 
bias is not None: + if len(bias.shape) == 1: + bias = bias.reshape([1, bias.shape[0], 1, 1]) + y += bias + return y - # conv1x1 or padding==0 - if CONV1X1_NHWC or not DELTA_X_PTR_HWC: - _kernel_delta_x[grid]( - x, - w, - y, - # stride nchw for x,w,y tensor - stride_x[xn], - stride_x[xc], - stride_x[xh], - stride_x[xw], - stride_w[wn], - stride_w[wc], - stride_w[wh], - stride_w[ww], - stride_y[yn], - stride_y[yc], - stride_y[yh], - stride_y[yw], - stride_biasn, - # pointer inc for x - delta_x, - # Tensor dimensions - BATCH, - IN_C, - IN_H, - IN_W, - KERNEL_N, - KERNEL_H, - KERNEL_W, - OUT_H, - OUT_W, - # conv parameters - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - output_padding[0], - output_padding[1], - groups, - # Metaparameters - ACC_TYPE=ACC_TYPE, - CONV1X1_NHWC=CONV1X1_NHWC, - # BLOCK_M=128, - # BLOCK_N=32, - # BLOCK_K=32, - GROUP_H=1, - ) - # need to know ptr update for each dimension to check if - # the sliding window is out of bounds - else: - # kernel = _kernel_delta_x_hwc - _kernel_delta_x_hwc[grid]( + @staticmethod + def forward( + x, + w, + bias, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + transposed=False, + output_padding=(0, 0), + groups=1, + ): + if groups != 1: + print(f"Do not support groups = {groups}") + return + if transposed: + print("Do not support transposed") + return _conv._call( x, w, - y, - # stride nchw for x,w,y tensor - stride_x[xn], - stride_x[xc], - stride_x[xh], - stride_x[xw], - stride_w[wn], - stride_w[wc], - stride_w[wh], - stride_w[ww], - stride_y[yn], - stride_y[yc], - stride_y[yh], - stride_y[yw], - stride_biasn, - # pointer inc for x - delta_xh, - delta_xw, - delta_xc, - # Tensor dimensions - BATCH, - IN_C, - IN_H, - IN_W, - KERNEL_N, - KERNEL_H, - KERNEL_W, - OUT_H, - OUT_W, - # conv parameters - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - output_padding[0], - output_padding[1], + bias, + stride, + padding, + dilation, + transposed, + output_padding, groups, - # Metaparameters - ACC_TYPE=ACC_TYPE, - CONV1X1_NHWC=CONV1X1_NHWC, - # BLOCK_M=128, - # BLOCK_N=32, - # BLOCK_K=32, - GROUP_H=1, ) - if bias is not None: - if len(bias.shape) == 1: - bias = bias.reshape([1, bias.shape[0], 1, 1]) - y += bias - return y - - @staticmethod - def forward( - x, - w, - bias, - stride=(1, 1), - padding=(0, 0), - dilation=(1, 1), - transposed=False, - output_padding=(0, 0), - groups=1, - ): - if groups != 1: - print(f"Do not support groups = {groups}") - return - if transposed: - print("Do not support transposed") - return _conv._call( - x, - w, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - - -conv = _conv.forward + conv = _conv.forward diff --git a/torchinductor/triton_ops/conv1x1.py b/torchinductor/triton_ops/conv1x1.py index b483c78c19..c7b79f004a 100644 --- a/torchinductor/triton_ops/conv1x1.py +++ b/torchinductor/triton_ops/conv1x1.py @@ -1,192 +1,195 @@ import torch -import triton +from ..utils import has_triton -class _conv1x1: - @staticmethod - def _call( - x, - w, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ): - # Q: should we check x, w, bias dtypes? 
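
The bias handling at the end of _conv._call above reshapes a 1-D bias to (1, C, 1, 1) so it broadcasts across an NCHW output. A quick torch check of that broadcast, with illustrative shapes:

import torch

y = torch.zeros(2, 32, 8, 8)              # NCHW output
bias = torch.arange(32.0)                 # one value per output channel
y = y + bias.reshape(1, -1, 1, 1)         # same broadcast as `y += bias` above
assert torch.equal(y[0, :, 0, 0], bias)
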
- device = x.device - # input shapes - shape_x = x.shape - shape_w = w.shape - shape_bias = bias.shape if bias is not None else None +if has_triton(): - # indicies for the layeout - xn, xc, xh, xw = 0, 1, 2, 3 - yn, yc, yh, yw = 0, 1, 2, 3 - wn, wc, wh, ww = 0, 1, 2, 3 + import triton - # out_channel, in_channel, kernel_height, kernel_width - kernel_size = [shape_w[wh], shape_w[ww]] - input_size = [shape_x[xh], shape_x[xw]] - assert ( - not shape_bias or shape_bias[0] == shape_w[wn] - ), f"bias shape did not match{shape_bias} != {shape_w[wn]}" - in_channel = shape_w[wc] * groups + class _conv1x1: + @staticmethod + def _call( + x, + w, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ): + # Q: should we check x, w, bias dtypes? + device = x.device + # input shapes + shape_x = x.shape + shape_w = w.shape + shape_bias = bias.shape if bias is not None else None - assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups" - assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups" - assert ( - shape_x[xc] == in_channel - ), f"in_channel did not match {shape_x[xc]} != {in_channel}" + # indicies for the layeout + xn, xc, xh, xw = 0, 1, 2, 3 + yn, yc, yh, yw = 0, 1, 2, 3 + wn, wc, wh, ww = 0, 1, 2, 3 - assert ( - len(stride) - == len(padding) - == len(dilation) - == len(output_padding) - == len(kernel_size) - == len(input_size) - ) + # out_channel, in_channel, kernel_height, kernel_width + kernel_size = [shape_w[wh], shape_w[ww]] + input_size = [shape_x[xh], shape_x[xw]] + assert ( + not shape_bias or shape_bias[0] == shape_w[wn] + ), f"bias shape did not match{shape_bias} != {shape_w[wn]}" + in_channel = shape_w[wc] * groups - # output shape - shape_y = [0] * 4 - shape_y[yn] = shape_x[xn] - shape_y[yc] = shape_w[wn] - shape_y[yh] = ( - input_size[0] - + 2 * padding[0] - - dilation[0] * (kernel_size[0] - 1) - - 1 - + stride[0] - ) // stride[0] + 2 * output_padding[0] - shape_y[yw] = ( - input_size[1] - + 2 * padding[1] - - dilation[1] * (kernel_size[1] - 1) - - 1 - + stride[1] - ) // stride[1] + 2 * output_padding[1] + assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups" + assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups" + assert ( + shape_x[xc] == in_channel + ), f"in_channel did not match {shape_x[xc]} != {in_channel}" - BATCH = shape_x[xn] - IN_C = shape_x[xc] - # IN_H = shape_x[xh] - # IN_W = shape_x[xw] - KERNEL_N = shape_w[wn] - KERNEL_H = shape_w[wh] - KERNEL_W = shape_w[ww] - OUT_H = shape_y[yh] - OUT_W = shape_y[yw] + assert ( + len(stride) + == len(padding) + == len(dilation) + == len(output_padding) + == len(kernel_size) + == len(input_size) + ) - assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv" - channels_last = x.stride()[1] == 1 + # output shape + shape_y = [0] * 4 + shape_y[yn] = shape_x[xn] + shape_y[yc] = shape_w[wn] + shape_y[yh] = ( + input_size[0] + + 2 * padding[0] + - dilation[0] * (kernel_size[0] - 1) + - 1 + + stride[0] + ) // stride[0] + 2 * output_padding[0] + shape_y[yw] = ( + input_size[1] + + 2 * padding[1] + - dilation[1] * (kernel_size[1] - 1) + - 1 + + stride[1] + ) // stride[1] + 2 * output_padding[1] - if padding == (0, 0): - # nchw -> nhwc - x = x.permute(0, 2, 3, 1) - # select every stride's element (for stride > 1) - x = x[:, :: stride[0], :: stride[1], :] - # 2d matrix - mat_x = x.reshape(-1, IN_C) - # 2d matrix - mat_w = w.view(KERNEL_N, IN_C) - mat_w = mat_w.permute(1, 0) - # 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N) 
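
The padding == 0 branch above treats a 1x1 convolution as a plain matmul over the channels-last view of x. A small CPU sketch of that identity using stock torch ops (shapes, strides, and tolerances chosen here purely for illustration):

import torch

x = torch.randn(2, 16, 5, 7)                      # NCHW input
w = torch.randn(32, 16, 1, 1)                     # 1x1 kernel: KERNEL_N x IN_C x 1 x 1
sh, sw = 2, 3
ref = torch.nn.functional.conv2d(x, w, stride=(sh, sw))

xs = x.permute(0, 2, 3, 1)[:, ::sh, ::sw, :]      # nchw -> nhwc, keep every stride-th pixel
mat_y = xs.reshape(-1, 16) @ w.view(32, 16).t()   # (N*OUT_H*OUT_W, KERNEL_N)
y = mat_y.view(2, xs.shape[1], xs.shape[2], 32).permute(0, 3, 1, 2)
torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)
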
- mat_y = triton.ops.matmul(mat_x, mat_w) - # mat_y = torch.empty((BATCH * OUT_H * OUT_W, KERNEL_N), device=device, dtype=x.dtype,) - y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N) - if bias is not None: - y += bias - # convert back to the original layout of y - # nhwc -> nchw - y = y.permute(0, 3, 1, 2) - if not channels_last: - y = y.to(memory_format=torch.contiguous_format) - return y + BATCH = shape_x[xn] + IN_C = shape_x[xc] + # IN_H = shape_x[xh] + # IN_W = shape_x[xw] + KERNEL_N = shape_w[wn] + KERNEL_H = shape_w[wh] + KERNEL_W = shape_w[ww] + OUT_H = shape_y[yh] + OUT_W = shape_y[yw] + + assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv" + channels_last = x.stride()[1] == 1 + + if padding == (0, 0): + # nchw -> nhwc + x = x.permute(0, 2, 3, 1) + # select every stride's element (for stride > 1) + x = x[:, :: stride[0], :: stride[1], :] + # 2d matrix + mat_x = x.reshape(-1, IN_C) + # 2d matrix + mat_w = w.view(KERNEL_N, IN_C) + mat_w = mat_w.permute(1, 0) + # 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N) + mat_y = triton.ops.matmul(mat_x, mat_w) + # mat_y = torch.empty((BATCH * OUT_H * OUT_W, KERNEL_N), device=device, dtype=x.dtype,) + y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N) + if bias is not None: + y += bias + # convert back to the original layout of y + # nhwc -> nchw + y = y.permute(0, 3, 1, 2) + if not channels_last: + y = y.to(memory_format=torch.contiguous_format) + return y - else: - y = torch.empty( - (shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]), - device=device, - dtype=x.dtype, - ) - if channels_last: - y = y.to(memory_format=torch.channels_last) - # y = bias.repeat((shape_y[yn], shape_y[yh], shape_y[yw], 1)).to(device).type(x.dtype) - # convert x to channel-last layout; - # don't care w layout since kernel size is 1 - x = x.permute(0, 2, 3, 1) - # select every stride"s element (for stride > 1) - x = x[:, :: stride[0], :: stride[1], :] - # 2d matrix - mat_x = x.view(-1, IN_C) - # 2d matrix - mat_w = w.view(KERNEL_N, IN_C) - mat_w = mat_w.permute(1, 0) - # 2d matrix y, (BATCH * (OUT_H-2*padding) * (OUT_W-2*padding), KERNEL_N) - mat_y = triton.ops.matmul(mat_x, mat_w) - mat_y = mat_y.view( - BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N - ) - # consider padding > 0 - if bias is not None: - y[ - :, - padding[0] : OUT_H - padding[0], - padding[1] : OUT_W - padding[1], - :, - ] = ( - mat_y + bias - ) - y[:, : padding[0], :, :] = bias - y[:, :, : padding[1], :] = bias - y[:, OUT_H - padding[0] :, :, :] = bias - y[:, :, OUT_W - padding[1] :, :] = bias else: - y[ - :, - padding[0] : OUT_H - padding[0], - padding[1] : OUT_W - padding[1], - :, - ] = mat_y - y[:, : padding[0], :, :] = 0 - y[:, :, : padding[1], :] = 0 - y[:, OUT_H - padding[0] :, :, :] = 0 - y[:, :, OUT_W - padding[1] :, :] = 0 - # convert back to the original layout of y - # nhwc -> nchw - y = y.permute(0, 3, 1, 2) - return y + y = torch.empty( + (shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]), + device=device, + dtype=x.dtype, + ) + if channels_last: + y = y.to(memory_format=torch.channels_last) + # y = bias.repeat((shape_y[yn], shape_y[yh], shape_y[yw], 1)).to(device).type(x.dtype) + # convert x to channel-last layout; + # don't care w layout since kernel size is 1 + x = x.permute(0, 2, 3, 1) + # select every stride"s element (for stride > 1) + x = x[:, :: stride[0], :: stride[1], :] + # 2d matrix + mat_x = x.view(-1, IN_C) + # 2d matrix + mat_w = w.view(KERNEL_N, IN_C) + mat_w = mat_w.permute(1, 0) + # 2d matrix y, (BATCH * (OUT_H-2*padding) * (OUT_W-2*padding), 
KERNEL_N) + mat_y = triton.ops.matmul(mat_x, mat_w) + mat_y = mat_y.view( + BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N + ) + # consider padding > 0 + if bias is not None: + y[ + :, + padding[0] : OUT_H - padding[0], + padding[1] : OUT_W - padding[1], + :, + ] = ( + mat_y + bias + ) + y[:, : padding[0], :, :] = bias + y[:, :, : padding[1], :] = bias + y[:, OUT_H - padding[0] :, :, :] = bias + y[:, :, OUT_W - padding[1] :, :] = bias + else: + y[ + :, + padding[0] : OUT_H - padding[0], + padding[1] : OUT_W - padding[1], + :, + ] = mat_y + y[:, : padding[0], :, :] = 0 + y[:, :, : padding[1], :] = 0 + y[:, OUT_H - padding[0] :, :, :] = 0 + y[:, :, OUT_W - padding[1] :, :] = 0 + # convert back to the original layout of y + # nhwc -> nchw + y = y.permute(0, 3, 1, 2) + return y - @staticmethod - def forward( - x, - w, - bias, - stride=(1, 1), - padding=(0, 0), - dilation=(1, 1), - transposed=False, - output_padding=(0, 0), - groups=1, - ): - if groups != 1: - print(f"Do not support groups = {groups}") - return - if transposed: - print("Do not support transposed") - return _conv1x1._call( + @staticmethod + def forward( x, w, bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + transposed=False, + output_padding=(0, 0), + groups=1, + ): + if groups != 1: + print(f"Do not support groups = {groups}") + return + if transposed: + print("Do not support transposed") + return _conv1x1._call( + x, + w, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) -conv1x1 = _conv1x1.forward + conv1x1 = _conv1x1.forward diff --git a/torchinductor/triton_ops/conv_perf_model.py b/torchinductor/triton_ops/conv_perf_model.py index e7b5ef9496..3d6013cbfa 100644 --- a/torchinductor/triton_ops/conv_perf_model.py +++ b/torchinductor/triton_ops/conv_perf_model.py @@ -1,10 +1,6 @@ import heapq import torch -import triton -import triton._C.libtriton.triton as _triton -from triton.ops.matmul_perf_model import get_dram_gbps as get_dram_gbps -from triton.ops.matmul_perf_model import get_tflops as get_tflops def estimate_conv_time( @@ -29,6 +25,11 @@ def estimate_conv_time( ): """return estimated running time in ms = max(compute, loading) + store""" + import triton + import triton._C.libtriton.triton as _triton + from triton.ops.matmul_perf_model import get_dram_gbps as get_dram_gbps + from triton.ops.matmul_perf_model import get_tflops as get_tflops + backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() dtype = x.dtype @@ -90,6 +91,8 @@ def estimate_conv_time( def early_config_prune(configs, named_args): + import triton._C.libtriton.triton as _triton + backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() cc = _triton.runtime.cc(backend, device) diff --git a/torchinductor/triton_ops/matmul.py b/torchinductor/triton_ops/matmul.py index f8100b7371..4895bcef88 100644 --- a/torchinductor/triton_ops/matmul.py +++ b/torchinductor/triton_ops/matmul.py @@ -1,135 +1,137 @@ import torch -import triton -import triton.language as tl -from .autotune import mm_autotune -from .autotune import mm_heuristics +from ..utils import has_triton +if has_triton(): -@mm_heuristics() -@mm_autotune(get_io_bound_configs=True) -@triton.jit -def _kernel( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - allow_tf32: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_M: 
tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K * SPLIT_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b, allow_tf32=allow_tf32) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(C.dtype.element_ty) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) + import triton + import triton.language as tl + from .autotune import mm_autotune + from .autotune import mm_heuristics -class _matmul_out: - kernel = _kernel + @mm_heuristics() + @mm_autotune(get_io_bound_configs=True) + @triton.jit + def _kernel( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + allow_tf32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K * SPLIT_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b, allow_tf32=allow_tf32) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = 
acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) - @staticmethod - def _call(a, b, out, allow_tf32=True): - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c = out - # accumulator types - ACC_TYPE = ( - tl.float32 - if a.dtype in [torch.float16, torch.bfloat16, torch.float32] - else tl.int32 - ) + class _matmul_out: + kernel = _kernel - # launch kernel (grid defined as using def instead of lambda to pass `make lint`) - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - META["SPLIT_K"], + @staticmethod + def _call(a, b, out, allow_tf32=True): + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = out + # accumulator types + ACC_TYPE = ( + tl.float32 + if a.dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 ) - # grid = lambda META: ( - # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - # META["SPLIT_K"], - # ) - _kernel[grid]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - allow_tf32=allow_tf32, - GROUP_M=8, - ACC_TYPE=ACC_TYPE, - ) + # launch kernel (grid defined as using def instead of lambda to pass `make lint`) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) - @staticmethod - def forward(a, b, out, allow_tf32=True): - return _matmul_out._call(a, b, out, allow_tf32) + # grid = lambda META: ( + # triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + # META["SPLIT_K"], + # ) + _kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + allow_tf32=allow_tf32, + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) + @staticmethod + def forward(a, b, out, allow_tf32=True): + return _matmul_out._call(a, b, out, allow_tf32) -matmul_out = _matmul_out.forward + matmul_out = _matmul_out.forward diff --git a/torchinductor/triton_ops/mm_perf_model.py b/torchinductor/triton_ops/mm_perf_model.py index 72bd4d2526..a9f6753b71 100644 --- a/torchinductor/triton_ops/mm_perf_model.py +++ b/torchinductor/triton_ops/mm_perf_model.py @@ -1,8 +1,4 @@ import torch -import triton -import triton._C.libtriton.triton as _triton -from triton.ops.matmul_perf_model import get_dram_gbps as get_dram_gbps -from triton.ops.matmul_perf_model import get_tflops as get_tflops def estimate_matmul_time( @@ -23,6 +19,11 @@ def estimate_matmul_time( ): """return estimated running time in ms = max(compute, loading) + store""" + import triton + import triton._C.libtriton.triton as _triton + from triton.ops.matmul_perf_model import get_dram_gbps as 
get_dram_gbps + from triton.ops.matmul_perf_model import get_tflops as get_tflops + backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() dtype = A.dtype From ad061f598788b0e3e374e62b0e710491ddcb1d31 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 8 Oct 2022 12:47:26 -0700 Subject: [PATCH 08/36] . --- test/inductor/test_torchinductor.py | 7 ++++--- torchinductor/triton_ops/autotune.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index baf7acf521..4f20733216 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3809,6 +3809,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): assert same_two_models(mod, opt_mod, args), "Dynamo failed" -# if __name__ == "__main__": -# from torchdynamo.testing import run_tests -# run_tests() +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/torchinductor/triton_ops/autotune.py b/torchinductor/triton_ops/autotune.py index 4e0bf3116e..d4a009199b 100644 --- a/torchinductor/triton_ops/autotune.py +++ b/torchinductor/triton_ops/autotune.py @@ -15,19 +15,20 @@ from ..ir import ReductionHint from ..triton_ops.mm_perf_model import estimate_matmul_time from ..utils import conditional_product +from ..utils import has_triton from .conv_perf_model import early_config_prune as conv_early_config_prune from .conv_perf_model import estimate_conv_time log = logging.getLogger(__name__) -try: +if has_triton(): import triton from triton import Config from triton import cdiv from triton import next_power_of_2 from triton.runtime.jit import KernelInterface from triton.runtime.jit import get_cuda_stream -except ImportError: +else: cdiv = None Config = object get_cuda_stream = None From 57015a366b821303dcf5c7cd188ee8516fb06e62 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 8 Oct 2022 13:17:26 -0700 Subject: [PATCH 09/36] fix --- torchdynamo/testing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index c40a944b6c..3134981c78 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -4,6 +4,7 @@ import importlib import logging import os.path +import sys import types import unittest from unittest.mock import patch @@ -49,6 +50,9 @@ def run_tests(argv=None, needs=()): except ImportError: return + if argv is None: + argv = sys.argv + run_tests(argv) From 25d768e0775cf45449865edc236d97ed9df8540d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 8 Oct 2022 21:13:56 -0700 Subject: [PATCH 10/36] . --- test/dynamo/test_replay_record.py | 29 +++++++++++++++++++++++------ torchdynamo/testing.py | 9 ++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index a9751f872c..f0037e7726 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -109,7 +109,10 @@ def test_fn(): def test_nonlocal_module_fn_call(self): # replay when we use a module # not defined in the replay env - from . import mock_modules + try: + from . 
import mock_modules + except ImportError: + import mock_modules def test_fn(): z = mock_modules.mock_module2.method1([], 2) @@ -119,7 +122,10 @@ def test_fn(): self.check_replay(test_fn, exp_exc_name="RuntimeError") def test_nonlocal_module_class(self): - from .mock_modules import mock_module2 + try: + from .mock_modules import mock_module2 + except ImportError: + from mock_modules import mock_module2 def test_fn(): z = mock_module2.Class1(1, 2) @@ -129,11 +135,22 @@ def test_fn(): self.check_replay(test_fn, exp_exc_name="TypeError") def test_local_module(self): - def test_fn(x): - from .mock_modules import mock_module3 + try: + from .mock_modules import mock_module3 as _ # noqa: F401 - z = mock_module3.method1([], torch.ones(5, 1)) - return torch.ones(2, 2) + x + z[0] + def test_fn(x): + from .mock_modules import mock_module3 + + z = mock_module3.method1([], torch.ones(5, 1)) + return torch.ones(2, 2) + x + z[0] + + except ImportError: + + def test_fn(x): + from mock_modules import mock_module3 + + z = mock_module3.method1([], torch.ones(5, 1)) + return torch.ones(2, 2) + x + z[0] self.check_replay(test_fn, torch.ones(1, 1), exp_exc_name="RuntimeError") diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 3134981c78..bbb5794496 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -4,7 +4,6 @@ import importlib import logging import os.path -import sys import types import unittest from unittest.mock import patch @@ -32,7 +31,7 @@ log = logging.getLogger(__name__) -def run_tests(argv=None, needs=()): +def run_tests(needs=()): from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO from torch.testing._internal.common_utils import run_tests @@ -49,11 +48,7 @@ def run_tests(argv=None, needs=()): importlib.import_module(need) except ImportError: return - - if argv is None: - argv = sys.argv - - run_tests(argv) + run_tests() def clone_me(x): From 30cef25178048764302350907137aab54e4aeb16 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 9 Oct 2022 08:50:57 -0700 Subject: [PATCH 11/36] . --- test/inductor/test_torchinductor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 34769ea6aa..489c3cde76 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3822,7 +3822,6 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): assert same_two_models(mod, opt_mod, args), "Dynamo failed" -if __name__ == "__main__": - from torchdynamo.testing import run_tests - - run_tests() +# if __name__ == "__main__": +# from torchdynamo.testing import run_tests +# run_tests() From 681820649aeff32998fa528836f24d80ee2c8d54 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 9 Oct 2022 14:39:24 -0700 Subject: [PATCH 12/36] . 
--- test/inductor/test_torchinductor.py | 7 ++++--- torchdynamo/allowed_functions.py | 15 ++++++++++++--- torchdynamo/testing.py | 4 ++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 489c3cde76..34769ea6aa 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3822,6 +3822,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): assert same_two_models(mod, opt_mod, args), "Dynamo failed" -# if __name__ == "__main__": -# from torchdynamo.testing import run_tests -# run_tests() +if __name__ == "__main__": + from torchdynamo.testing import run_tests + + run_tests() diff --git a/torchdynamo/allowed_functions.py b/torchdynamo/allowed_functions.py index b35164009a..1b6e9f2318 100644 --- a/torchdynamo/allowed_functions.py +++ b/torchdynamo/allowed_functions.py @@ -14,6 +14,7 @@ import numpy import torch +from torch.fx._symbolic_trace import is_fx_tracing from . import config from .utils import is_safe_constant @@ -127,8 +128,11 @@ def _is_allowed_module_prefix(obj): disallowed_modules = ( "torch.optim.", "torch.nn.modules.rnn.", - "torch.dynamo.", - "torch._C.dynamo.", + "torch._dynamo.", + "torch._C._dynamo.", + "torch._inductor.", + "torch._C.inductor.", + "torch.fx.", ) allowed_modules_dot = tuple([x + "." for x in allowed_modules]) module = inspect.getmodule(obj) @@ -152,7 +156,9 @@ def _find_torch_objects(module): for name, obj in list(module.__dict__.items()): if id(obj) not in torch_object_ids: if isinstance(obj, types.ModuleType): - if obj.__name__.startswith("torch."): + if obj.__name__.startswith("torch.") and _is_allowed_module_prefix( + obj + ): torch_object_ids[id(obj)] = f"{module.__name__}.{name}" _find_torch_objects(obj) elif _is_allowed_module_prefix(obj): @@ -167,6 +173,9 @@ def _find_torch_objects(module): if idx in torch_object_ids: del torch_object_ids[idx] + for extra in (is_fx_tracing,): + torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}" + return torch_object_ids diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index bbb5794496..7224e5c4d7 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -32,12 +32,16 @@ def run_tests(needs=()): + from torch.testing._internal.common_utils import TEST_WITH_CROSSREF from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO from torch.testing._internal.common_utils import run_tests if TEST_WITH_TORCHDYNAMO: return # cant dynamo dynamo + if TEST_WITH_CROSSREF: + return # needs __torch_function__ + if isinstance(needs, str): needs = (needs,) for need in needs: From d18aff32882b09e27c49018130938c5f20396382 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 9 Oct 2022 17:02:48 -0700 Subject: [PATCH 13/36] asan --- test/inductor/test_torchinductor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 34769ea6aa..391d366a94 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11,6 +11,7 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F +from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.common_utils import TestCase as TorchTestCase from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_unflatten @@ -3498,6 +3499,7 @@ def forward( inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) 
in inps] self.common(forward, inps, atol=1e-05, rtol=2e-05) + @unittest.skipIf(TEST_WITH_ASAN, "TODO: debug this with asan") def test_tmp_not_defined_issue2(self): def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1) From 7f7c4b0928d43c2f50517f82af564fefd16b0598 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 10 Oct 2022 16:12:02 -0700 Subject: [PATCH 14/36] . --- torchdynamo/eval_frame.py | 13 +++++++++++++ torchdynamo/testing.py | 8 +++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index 0051df70c6..080fcbdfee 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -4,6 +4,7 @@ import inspect import logging import os +import sys import threading import traceback import types @@ -365,6 +366,18 @@ def toy_example(a, b): """ if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": return _NullDecorator() + if sys.platform == "win32": + warnings.warn( + "Windows is not currently supported, " + + f"{config.dynamo_import}.optimize() will do nothing" + ) + return _NullDecorator() + if sys.version_info >= (3, 11): + warnings.warn( + "Python 3.11+ not yet supported, " + f"{config.dynamo_import}.optimize() will do nothing" + ) + return _NullDecorator() backend = get_compiler_fn(backend) diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 7224e5c4d7..7af0053f26 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -32,15 +32,13 @@ def run_tests(needs=()): + from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.common_utils import TEST_WITH_CROSSREF from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO from torch.testing._internal.common_utils import run_tests - if TEST_WITH_TORCHDYNAMO: - return # cant dynamo dynamo - - if TEST_WITH_CROSSREF: - return # needs __torch_function__ + if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF: + return # skip testing if isinstance(needs, str): needs = (needs,) From 151fc2517c8e9d7ae9c257476190c7dadca87c69 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 10 Oct 2022 18:04:32 -0700 Subject: [PATCH 15/36] . 
--- torchdynamo/testing.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 7af0053f26..e7857ab51f 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -4,6 +4,7 @@ import importlib import logging import os.path +import sys import types import unittest from unittest.mock import patch @@ -37,7 +38,12 @@ def run_tests(needs=()): from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO from torch.testing._internal.common_utils import run_tests - if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF: + if ( + TEST_WITH_TORCHDYNAMO + or IS_WINDOWS + or TEST_WITH_CROSSREF + or sys.version_info >= (3, 11) + ): return # skip testing if isinstance(needs, str): From e6f3ffbd840727b89c4a5a886ccbefd36b1b0042 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 10 Oct 2022 20:00:57 -0700 Subject: [PATCH 16/36] Filelock fix --- test/inductor/test_torchinductor.py | 2 +- torchinductor/codecache.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 847f63ba97..733bfc2b89 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3872,4 +3872,4 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): if __name__ == "__main__": from torchdynamo.testing import run_tests - run_tests() + run_tests(needs="filelock") diff --git a/torchinductor/codecache.py b/torchinductor/codecache.py index fbfeefc7bc..e9c69b1c43 100644 --- a/torchinductor/codecache.py +++ b/torchinductor/codecache.py @@ -17,7 +17,6 @@ from typing import Dict import torch -from filelock import FileLock from torch.utils import cpp_extension from . 
import config @@ -77,6 +76,8 @@ def cpp_compiler_search(search): for cxx in search: try: if cxx is None: + from filelock import FileLock + lock_dir = get_lock_dir() lock = FileLock( os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT @@ -85,7 +86,7 @@ def cpp_compiler_search(search): cxx = install_gcc_via_conda() subprocess.check_output([cxx, "--version"]) return cxx - except (subprocess.SubprocessError, FileNotFoundError): + except (subprocess.SubprocessError, FileNotFoundError, ImportError): continue raise exc.InvalidCxxCompiler() @@ -156,6 +157,8 @@ class CppCodeCache: def load(cls, source_code): key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o")) if key not in cls.cache: + from filelock import FileLock + lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: From 09039dd960a39f179b52d801cc45c970af2bbbf8 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 10 Oct 2022 20:24:23 -0700 Subject: [PATCH 17/36] filelock_skip --- test/inductor/test_torchinductor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 733bfc2b89..81e80f506c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -25,6 +25,7 @@ import sympy importlib.import_module("functorch") + importlib.import_module("filelock") from functorch.compile import config as functorch_config from torch._decomp import get_decompositions @@ -46,7 +47,7 @@ sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch") + raise unittest.SkipTest("requires sympy/functorch/filelock") HAS_CPU = False From ef1575230a53d2111c016dbca1a8b614f5847c29 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 11 Oct 2022 16:04:15 -0700 Subject: [PATCH 18/36] . --- test/inductor/test_torchinductor_opinfo.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 65851f99eb..dc24cd741c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -1,6 +1,8 @@ # Owner(s): ["module: inductor"] import atexit import os +import sys +import unittest from collections import defaultdict from enum import Enum from functools import partial @@ -22,11 +24,16 @@ from torchinductor.utils import has_triton try: - from .test_torchinductor import check_model - from .test_torchinductor import check_model_cuda -except ImportError: - from test_torchinductor import check_model - from test_torchinductor import check_model_cuda + try: + from .test_torchinductor import check_model + from .test_torchinductor import check_model_cuda + except ImportError: + from test_torchinductor import check_model + from test_torchinductor import check_model_cuda +except unittest.SkipTest: + if __name__ == "__main__": + sys.exit(0) + raise bf16 = torch.bfloat16 # not tested f64 = torch.float64 From 75f2c9e939a2a2007498dc4bcc156c1188eaf5ee Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 11 Oct 2022 21:30:07 -0700 Subject: [PATCH 19/36] . 
--- test/inductor/test_torchinductor.py | 7 ++++++- test/inductor/test_torchinductor_opinfo.py | 6 ++++-- torchdynamo/testing.py | 2 ++ torchinductor/codecache.py | 1 + 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 81e80f506c..5d85e992a6 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -58,7 +58,12 @@ CppCodeCache.load("") HAS_CPU = True -except (CalledProcessError, OSError, torchinductor.exc.InvalidCxxCompiler): +except ( + CalledProcessError, + OSError, + torchinductor.exc.InvalidCxxCompiler, + torchinductor.exc.CppCompileError, +): pass aten = torch.ops.aten diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index dc24cd741c..d42d2625b6 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -21,16 +21,18 @@ from torch.testing._internal.common_utils import suppress_warnings import torchdynamo -from torchinductor.utils import has_triton try: + from torchinductor.utils import has_triton + try: from .test_torchinductor import check_model from .test_torchinductor import check_model_cuda except ImportError: from test_torchinductor import check_model from test_torchinductor import check_model_cuda -except unittest.SkipTest: +except (unittest.SkipTest, ImportError) as e: + sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": sys.exit(0) raise diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index e7857ab51f..02a395030c 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -33,6 +33,8 @@ def run_tests(needs=()): + return # TEMPORARY: disable all tests + from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.common_utils import TEST_WITH_CROSSREF from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO diff --git a/torchinductor/codecache.py b/torchinductor/codecache.py index e9c69b1c43..378989393f 100644 --- a/torchinductor/codecache.py +++ b/torchinductor/codecache.py @@ -7,6 +7,7 @@ import re import shutil import subprocess +import sys import sysconfig import tempfile import types From 8866ec9c801f2ee8eaaef151b43750edc7a9dede Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 12 Oct 2022 12:00:07 +0000 Subject: [PATCH 20/36] Update PyTorch pin --- Makefile | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 4b3e733424..5610f8ec77 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ PIP ?= python -m pip # versions used in CI # Also update the "Install nightly binaries" section of the README when updating these -PYTORCH_VERSION ?= dev20221011 +PYTORCH_VERSION ?= dev20221012 TRITON_VERSION ?= af76c989eb4799b015f8b288ccd8421558772e56 diff --git a/README.md b/README.md index 1a6fe5bfeb..1f743966aa 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ that you have installed locally matches the PyTorch version you are running. For the command below, you will need CUDA 11.7. 
```shell -pip install --pre torch==1.14.0.dev20221011+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117 +pip install --pre torch==1.14.0.dev20221012+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117 pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python" pip install -U "git+https://github.com/pytorch/torchdynamo" ``` From 088f5619e1bb8905d507df1e228f09602bbc7391 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 09:40:18 -0700 Subject: [PATCH 21/36] Circular import --- torchinductor/codecache.py | 1 - torchinductor/scheduler.py | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchinductor/codecache.py b/torchinductor/codecache.py index 378989393f..e9c69b1c43 100644 --- a/torchinductor/codecache.py +++ b/torchinductor/codecache.py @@ -7,7 +7,6 @@ import re import shutil import subprocess -import sys import sysconfig import tempfile import types diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py index 5df96fb5a3..685b326a2a 100644 --- a/torchinductor/scheduler.py +++ b/torchinductor/scheduler.py @@ -19,9 +19,6 @@ from . import config from . import dependencies from . import ir -from .codegen.triton_template import should_use_template -from .codegen.triton_template import template_can_fuse -from .codegen.triton_template import template_codegen from .dependencies import MemoryDep from .dependencies import StarDep from .sizevars import SimplifyIndexing @@ -174,6 +171,8 @@ def can_inplace(self, read_dep: dependencies.MemoryDep): return False def allocate(self): + from .codegen.triton_template import should_use_template + if self.node.should_allocate() or should_use_template(self.node): # if self.node should allocate or # if self.node is generated by TritonKernelTemplates @@ -548,6 +547,8 @@ def get_name(self): class Scheduler: @dynamo_utils.dynamo_timed def __init__(self, nodes): + from .codegen.triton_template import should_use_template + super(Scheduler, self).__init__() self.backends = {} @@ -904,6 +905,8 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): if not self.can_fuse_vertical(node1, node2): return False if node1.is_template(): + from .codegen.triton_template import template_can_fuse + return template_can_fuse(node1, node2) return self.get_backend(device).can_fuse_vertical(node1, node2) else: # nodes don't depend on each other, but may have common reads @@ -1032,6 +1035,8 @@ def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): def codegen_template_call( self, scheduler_node: Union[FusedSchedulerNode, TemplateSchedulerNode] ): + from .codegen.triton_template import template_codegen + node, *epilogue = scheduler_node.get_nodes() node.allocate() template_codegen(self, node, epilogue) From 4a59e6ea946f5eac1c47a2ac409da1266be0af80 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 10:35:20 -0700 Subject: [PATCH 22/36] Unbound local --- benchmarks/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 977b79da9a..046ad5f98d 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1634,8 +1634,6 @@ def main(runner, original_dir=None): return sys.exit(-1) if not args.devices: - import torch - if torch.cuda.is_available(): args.devices = ["cuda"] else: From 4080e1b8058fdf2934a8c6151e1a89682f9f7d5e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 11:05:30 -0700 Subject: [PATCH 23/36] fsdp 
import error --- test/dynamo/test_distributed.py | 47 ++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index 0559bab117..c0478e8288 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -1,13 +1,12 @@ # Owner(s): ["module: dynamo"] import os +import unittest from unittest.mock import patch import pytest import torch import torch.distributed as dist from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.nn.parallel import DistributedDataParallel as DDP import torchdynamo from torchdynamo import config @@ -36,6 +35,13 @@ def compile_fn(self, gm, example_inputs): return gm +def skip_if_no_active_ddp(): + from torch.nn.parallel import DistributedDataParallel as DDP + + if not hasattr(DDP, "_get_active_ddp_module"): + raise unittest.SkipTest("requires pytorch landing in parallel") + + @pytest.mark.skip("Module hangs in PyTorch CI") class TestDistributed(torchdynamo.testing.TestCase): """ @@ -73,6 +79,8 @@ def get_model(self): @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_aot_eager(self): + from torch.nn.parallel import DistributedDataParallel as DDP + m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids) ddp_m = torchdynamo.optimize("aot_eager")(ddp_m) @@ -81,6 +89,8 @@ def test_ddp_baseline_aot_eager(self): @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_inductor(self): + from torch.nn.parallel import DistributedDataParallel as DDP + m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids) ddp_m = torchdynamo.optimize("inductor")(ddp_m) @@ -91,6 +101,8 @@ def test_ddp_baseline_inductor(self): @pytest.mark.xfail @patch.object(config, "optimize_ddp", False) def test_fsdp_baseline_aot_eager(self): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + m, inputs, correct_outputs = self.get_model() fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) fsdp_m = torchdynamo.optimize("aot_eager")(fsdp_m) @@ -101,16 +113,14 @@ def test_fsdp_baseline_aot_eager(self): @pytest.mark.skip @patch.object(config, "optimize_ddp", False) def test_fsdp_baseline_inductor(self): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + m, inputs, correct_outputs = self.get_model() fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) fsdp_m = torchdynamo.optimize("inductor")(fsdp_m) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - @pytest.mark.skipif( - not hasattr(DDP, "_get_active_ddp_module"), - reason="requires pytorch landing in parallel", - ) @patch.object(config, "optimize_ddp", True) def test_graph_split(self): """ @@ -119,6 +129,10 @@ def test_graph_split(self): the user-provided compiler is called by the DDPOptimizer which is doing the graph splitting """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() + m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) @@ -132,16 +146,15 @@ def opt_fn(inputs): self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 3) - @pytest.mark.skipif( - not hasattr(DDP, "_get_active_ddp_module"), - reason="requires pytorch landing in parallel", - ) @patch.object(config, "optimize_ddp", True) def test_graph_split_inductor(self): """ Same as 
above, but using inductor backend. We observed issues with inductor/fx interface in the past. """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) @@ -152,16 +165,15 @@ def opt_fn(inputs): opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) - @pytest.mark.skipif( - not hasattr(DDP, "_get_active_ddp_module"), - reason="requires pytorch landing in parallel", - ) @patch.object(config, "optimize_ddp", True) def test_no_split(self): """ Ensures the DDPOptimizer returns a correct, compiled module without introducing graph splits. (Based on model parmeters fitting in the bucket) """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250) @@ -175,16 +187,15 @@ def opt_fn(inputs): self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 1) - @pytest.mark.skipif( - not hasattr(DDP, "_get_active_ddp_module"), - reason="requires pytorch landing in parallel", - ) @patch.object(config, "optimize_ddp", True) def test_aot_autograd(self): """ Explicitly check AotAutograd family of compilers work, since they require example inputs propagated between graph splits. """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) From 0641c877c03079e64e9d7790d94878ce4dd4b91a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 12:06:29 -0700 Subject: [PATCH 24/36] jinja --- torchinductor/codegen/triton_template.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchinductor/codegen/triton_template.py b/torchinductor/codegen/triton_template.py index 65f32a5542..eb4bf8bb45 100644 --- a/torchinductor/codegen/triton_template.py +++ b/torchinductor/codegen/triton_template.py @@ -2,9 +2,6 @@ import os import sympy -from jinja2 import Environment -from jinja2 import FileSystemLoader -from jinja2 import StrictUndefined from .. import config from .. import ir @@ -18,6 +15,10 @@ class TritonTemplateKernel(TritonKernel): def __init__(self, node: ir.ExternKernel, *groups): + from jinja2 import Environment + from jinja2 import FileSystemLoader + from jinja2 import StrictUndefined + self.node = node self.template_name = template_dict[type(node)] env = Environment( From a7d0756d3d1878a9eadac363c0e906234300770b Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 15:04:54 -0700 Subject: [PATCH 25/36] . --- benchmarks/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/common.py b/benchmarks/common.py index 5f39c070af..0cea65a37a 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1285,6 +1285,7 @@ def compare_branches( "--diff_main called on main branch, what are you diffing?" ) + @staticmethod def maybe_fresh_cache(fn): def inner(self, *args, **kwargs): cache_minder = NullContext() From 3409d31cb81b909bf11ddbd7d7cb11259d00c283 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 15:48:19 -0700 Subject: [PATCH 26/36] . 
--- benchmarks/common.py | 54 ++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/benchmarks/common.py b/benchmarks/common.py index 0cea65a37a..29bde358ac 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -891,6 +891,33 @@ def scale(self, loss): return loss +def maybe_fresh_cache(fn): + def inner(self, *args, **kwargs): + cache_minder = NullContext() + if self.args.cold_start_latency: + cache_entries = {} + cache_minder = fresh_triton_cache(cache_entries) + + try: + with cache_minder: + return fn(self, *args, **kwargs) + finally: + dump_cache = False + if dump_cache and self.args.cold_start_latency: + output_csv( + output_filename[:-4] + "_triton_cache.csv", + ["dev", "name", "batch_size", "triton_cache"], + [ + current_device, + current_name, + current_batch_size, + cache_entries, + ], + ) + + return inner + + class BenchmarkRunner: def __init__(self): self.model_iter_fn = None @@ -1285,33 +1312,6 @@ def compare_branches( "--diff_main called on main branch, what are you diffing?" ) - @staticmethod - def maybe_fresh_cache(fn): - def inner(self, *args, **kwargs): - cache_minder = NullContext() - if self.args.cold_start_latency: - cache_entries = {} - cache_minder = fresh_triton_cache(cache_entries) - - try: - with cache_minder: - return fn(self, *args, **kwargs) - finally: - dump_cache = False - if dump_cache and self.args.cold_start_latency: - output_csv( - output_filename[:-4] + "_triton_cache.csv", - ["dev", "name", "batch_size", "triton_cache"], - [ - current_device, - current_name, - current_batch_size, - cache_entries, - ], - ) - - return inner - @maybe_fresh_cache def run_one_model( self, From bc0ebf5e73fa3519258f8a274b7ab5cfa04195e4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 12 Oct 2022 19:07:30 -0700 Subject: [PATCH 27/36] . 
--- test/dynamo/test_replay_record.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index f0037e7726..4c39508b03 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -8,6 +8,13 @@ import torchdynamo.testing +try: + import dill +except ImportError: + dill = None + +requires_dill = unittest.skipIf(dill is None, "requires dill") + class ReplayRecordTests(torchdynamo.testing.TestCase): @classmethod @@ -66,6 +73,7 @@ def get_error_name(log): "Error logs for recorded execution and replayed execution should match.", ) + @requires_dill def test_unsuccessful_inline(self): def level2(): z = torch.ones(2, 2) @@ -82,6 +90,7 @@ def level0(): self.check_replay(level0, exp_exc_name="AssertionError") + @requires_dill def test_successful_inline(self): def test_fn(): x = torch.ones(2, 2) @@ -95,6 +104,7 @@ def level1(a): self.check_replay(test_fn, exp_exc_name="RuntimeError") + @requires_dill def test_nonlocal_fn_call(self): def nonlocal_fn(x): return x + torch.ones(2, 2) @@ -106,6 +116,7 @@ def test_fn(): self.check_replay(test_fn, exp_exc_name="RuntimeError") + @requires_dill def test_nonlocal_module_fn_call(self): # replay when we use a module # not defined in the replay env @@ -121,6 +132,7 @@ def test_fn(): self.check_replay(test_fn, exp_exc_name="RuntimeError") + @requires_dill def test_nonlocal_module_class(self): try: from .mock_modules import mock_module2 @@ -134,6 +146,7 @@ def test_fn(): self.check_replay(test_fn, exp_exc_name="TypeError") + @requires_dill def test_local_module(self): try: from .mock_modules import mock_module3 as _ # noqa: F401 @@ -155,6 +168,7 @@ def test_fn(x): self.check_replay(test_fn, torch.ones(1, 1), exp_exc_name="RuntimeError") # Verfiy that we replay when we have tensor arguments to the frame being replayed + @requires_dill def test_fn_call_args(self): def test_fn(x, y): return x + y + torch.zeros(2, 2) @@ -167,4 +181,4 @@ def test_fn(x, y): if __name__ == "__main__": from torchdynamo.testing import run_tests - run_tests(needs="dill") + run_tests() From 9205e10526433ed782702a8d10c0dcff74c8c5f4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 09:02:07 -0700 Subject: [PATCH 28/36] . 
--- Makefile | 2 +- requirements.txt | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 2d9249b30d..cab94e94f7 100644 --- a/Makefile +++ b/Makefile @@ -98,7 +98,7 @@ build-deps: clone-deps (cd ../torchvision && python setup.py clean && python setup.py develop) (cd ../torchtext && python setup.py clean && python setup.py develop) (cd ../detectron2 && python setup.py clean && python setup.py develop) - # (cd ../torchbenchmark && python install.py --continue_on_fail) + (cd ../torchbenchmark && python install.py --continue_on_fail) (cd ../triton/python && python setup.py clean && python setup.py develop) make setup_lint python setup.py develop diff --git a/requirements.txt b/requirements.txt index eb3d8c85b4..aa195555c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,4 @@ numpy pytest pyyaml sympy -torch -filelock +torch>=1.12.0 From dcc079af18b3fd59d9e7c99a8081a0cd9c7c2e92 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 17:14:11 -0700 Subject: [PATCH 29/36] Enable tests --- test/inductor/test_torchinductor.py | 2 +- torchdynamo/testing.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index cd8aefcd59..a75f5519e5 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3787,7 +3787,7 @@ def fn(x, y): matmul_seen = False class TestRefMode(TorchDispatchMode): - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} nonlocal inps diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 3e22f67b6a..82baac498b 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -33,8 +33,6 @@ def run_tests(needs=()): - return # TEMPORARY: disable all tests - from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.common_utils import TEST_WITH_CROSSREF from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO From a384a7634d15be46c1055608a0b9dd9c4ddafe11 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 17:18:53 -0700 Subject: [PATCH 30/36] . --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cdade4d712..2c4183a0a3 100755 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ setup( name="torchdynamo", - version="1.13.0.dev0", + version="1.14.0.dev0", url="https://github.com/pytorch/torchdynamo", description="A Python-level JIT compiler designed to make unmodified PyTorch programs faster.", long_description=long_description, From 084562455274ac8b2de34d1c9dd64caad4c3e00b Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 19:13:13 -0700 Subject: [PATCH 31/36] . 
--- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index aa195555c6..f59601a477 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ pytest pyyaml sympy torch>=1.12.0 +filelock From 614514efd72f75ff6c15b80b7f4582188658dfaa Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 20:07:14 -0700 Subject: [PATCH 32/36] Workaround --- benchmarks/torchbench.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/torchbench.py b/benchmarks/torchbench.py index 55f289221f..b22a949440 100755 --- a/benchmarks/torchbench.py +++ b/benchmarks/torchbench.py @@ -240,6 +240,8 @@ def load_model( if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE: batch_size = USE_SMALL_BATCH_SIZE[model_name] + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" + torch.backends.__allow_nonbracketed_mutation_flag = True if is_training: benchmark = benchmark_cls( test="train", device=device, jit=False, batch_size=batch_size From 841b3bfca5d282c95ecfbdbb5d6b6376e6dda933 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 20:49:22 -0700 Subject: [PATCH 33/36] Move testcase to workaround import side effects --- test/dynamo/test_aot_autograd.py | 5 +- test/dynamo/test_aot_cudagraphs.py | 5 +- test/dynamo/test_distributed.py | 3 +- test/dynamo/test_dynamic_shapes.py | 2 +- test/dynamo/test_export.py | 5 +- test/dynamo/test_functions.py | 5 +- test/dynamo/test_global.py | 5 +- test/dynamo/test_minifier.py | 5 +- test/dynamo/test_misc.py | 5 +- test/dynamo/test_model_output.py | 7 +-- test/dynamo/test_modules.py | 5 +- test/dynamo/test_no_fake_tensors.py | 2 +- test/dynamo/test_nops.py | 5 +- test/dynamo/test_optimizations.py | 7 +-- test/dynamo/test_optimizers.py | 5 +- test/dynamo/test_python_autograd.py | 4 +- test/dynamo/test_recompile_ux.py | 3 +- test/dynamo/test_replay_record.py | 5 +- test/dynamo/test_repros.py | 5 +- test/dynamo/test_skip_non_tensor.py | 5 +- test/dynamo/test_subgraphs.py | 5 +- test/dynamo/test_unspec.py | 5 +- test/dynamo/test_verify_correctness.py | 5 +- test/inductor/test_torchinductor.py | 2 +- torchdynamo/test_case.py | 68 ++++++++++++++++++++++++++ torchdynamo/testing.py | 61 ----------------------- 26 files changed, 132 insertions(+), 107 deletions(-) create mode 100644 torchdynamo/test_case.py diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index d9f4c38da1..9e02099862 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -4,6 +4,7 @@ import torch import torchdynamo +import torchdynamo.test_case from torchdynamo.optimizations.training import is_aot_autograd_safe_to_run from torchdynamo.testing import rand_strided @@ -13,7 +14,7 @@ def compiler_safe_fn(gm, example_inputs, is_safe): return gm.forward -class AotAutogradFallbackTests(torchdynamo.testing.TestCase): +class AotAutogradFallbackTests(torchdynamo.test_case.TestCase): def test_LSTM(self): # https://github.com/pytorch/torchdynamo/issues/1147 class Repro(torch.nn.Module): @@ -133,6 +134,6 @@ def fn(x, y): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index c23437041f..21cdb3d0ee 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -7,6 +7,7 @@ import torch import torchdynamo +import torchdynamo.test_case import 
torchdynamo.testing from torchdynamo.testing import same @@ -52,7 +53,7 @@ def patch_all(ok=True): @unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") -class TestAotCudagraphs(torchdynamo.testing.TestCase): +class TestAotCudagraphs(torchdynamo.test_case.TestCase): @patch_all() def test_basic(self): def model(x, y): @@ -201,6 +202,6 @@ def fn(x): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index 8f0bdf5dd7..2356443d85 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -9,6 +9,7 @@ from torch import nn import torchdynamo +import torchdynamo.test_case from torchdynamo import config from torchdynamo.testing import same @@ -43,7 +44,7 @@ def skip_if_no_active_ddp(): @pytest.mark.skip("Module hangs in PyTorch CI") -class TestDistributed(torchdynamo.testing.TestCase): +class TestDistributed(torchdynamo.test_case.TestCase): """ Test harness initializes dist process group """ diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 0985ebbdfd..b9159c466b 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -29,6 +29,6 @@ def make_dynamic_cls(cls): DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests) if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index cbc760f588..6044f7dac7 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -5,10 +5,11 @@ import torch.utils._pytree as pytree from torch.fx.experimental.proxy_tensor import make_fx +import torchdynamo.test_case import torchdynamo.testing -class ExportTests(torchdynamo.testing.TestCase): +class ExportTests(torchdynamo.test_case.TestCase): # TODO(voz): Refactor to a shared test function. 
# The tests in this file are a little redundant, # They all take a func, run it with eager, then export it, then compare @@ -1423,6 +1424,6 @@ def nop(x): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 975d2a1abf..9d80eb52d9 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -11,6 +11,7 @@ from torch import sub from torch.nn import functional as F +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.testing import requires_static_shapes @@ -53,7 +54,7 @@ def inline_unused(x): return x + 5.6 -class FunctionTests(torchdynamo.testing.TestCase): +class FunctionTests(torchdynamo.test_case.TestCase): @make_test def test_inline_jit_annotations(x): x = inline_script_if_tracing(x) @@ -670,6 +671,6 @@ def test_list_slice_assignment(x): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 3d11c8e479..e9e571c90b 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import torch +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.testing import same @@ -43,7 +44,7 @@ def reset_name(): _name = 0 -class TestGlobals(torchdynamo.testing.TestCase): +class TestGlobals(torchdynamo.test_case.TestCase): def test_store_global_1(self): def fn(x): global g_counter @@ -227,6 +228,6 @@ def fn(a, b): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 8ade910798..719517f1db 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -6,6 +6,7 @@ import torch import torchdynamo +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.optimizations.backends import create_backend @@ -23,7 +24,7 @@ def forward(self, x): return x -class MinfierTests(torchdynamo.testing.TestCase): +class MinfierTests(torchdynamo.test_case.TestCase): def test_after_dynamo(self): @create_backend def bad_dynamo_backend(subgraph): @@ -92,6 +93,6 @@ def test_after_aot(self): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 36889ab070..b718c985ec 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -18,6 +18,7 @@ import torch.onnx.operators from torch.testing._internal.jit_utils import JitTestCase +import torchdynamo.test_case import torchdynamo.testing from torchdynamo import bytecode_transformation from torchdynamo.testing import CompileCounter @@ -32,7 +33,7 @@ def my_custom_function(x): return x + 1 -class MiscTests(torchdynamo.testing.TestCase): +class MiscTests(torchdynamo.test_case.TestCase): def test_boolarg(self): def boolarg(aa, bb, flag): if flag: @@ -2713,6 +2714,6 @@ def forward(self, x): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index 1c4b695d99..2653c09c72 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -4,6 +4,7 @@ import 
torch +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.testing import same @@ -22,7 +23,7 @@ def maybe_skip(fn): return fn -class TestHFPretrained(torchdynamo.testing.TestCase): +class TestHFPretrained(torchdynamo.test_case.TestCase): @maybe_skip def test_pretrained(self): def fn(a, tmp): @@ -38,7 +39,7 @@ def fn(a, tmp): self.assertTrue(same(ref, res)) -class TestModelOutput(torchdynamo.testing.TestCase): +class TestModelOutput(torchdynamo.test_case.TestCase): @maybe_skip def test_mo_create(self): def fn(a, b): @@ -160,6 +161,6 @@ def fn(obj): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 576850d1bf..4055337b40 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter from torch.nn.parameter import UninitializedParameter +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.eval_frame import unsupported from torchdynamo.mutation_guard import GenerationTracker @@ -605,7 +606,7 @@ def test_fn(self): return test_fn -class NNModuleTests(torchdynamo.testing.TestCase): +class NNModuleTests(torchdynamo.test_case.TestCase): test_seq = make_test(Seq()) test_basicmodule1 = make_test(BasicModule()) test_basicmodule2 = make_test(BasicModule()) @@ -885,6 +886,6 @@ def test_torch_static(): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index 8762250c70..6cc3573464 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -28,6 +28,6 @@ def make_no_fake_cls(cls): NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index efe160de24..797b1cfd5b 100644 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import torch +import torchdynamo.test_case import torchdynamo.testing from torchdynamo import eval_frame @@ -35,7 +36,7 @@ def fn3(): ) -class NopTests(torchdynamo.testing.TestCase): +class NopTests(torchdynamo.test_case.TestCase): @with_debug_nops def test1(self): self.assertEqual(fn1(1, 2), -7) @@ -66,6 +67,6 @@ def test_extended_args(self): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py index 4da4455ab3..da0978a12f 100644 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -8,6 +8,7 @@ import torch import torchdynamo +import torchdynamo.test_case from torchdynamo.optimizations import backends from torchdynamo.optimizations.analysis import has_mutation from torchdynamo.optimizations.log_args import conv_args_analysis @@ -65,7 +66,7 @@ def forward(self, x): return self.relu(self.bn(self.conv(x))) -class TestOptimizations(torchdynamo.testing.TestCase): +class TestOptimizations(torchdynamo.test_case.TestCase): def test_inplacifier(self): gm = torch.fx.symbolic_trace(Seq()) normalize(gm) @@ -184,7 +185,7 @@ def test_ipex_bf16(self): self.assertEqual(r2.dtype, torch.bfloat16) 
-class NormalizeIRTests(torchdynamo.testing.TestCase): +class NormalizeIRTests(torchdynamo.test_case.TestCase): @unittest.skipIf(not has_functorch(), "requires functorch") def test_inplace_normalize(self): def fn(a, b): @@ -203,6 +204,6 @@ def fn(a, b): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 79f96ee20e..19a9655d74 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -6,6 +6,7 @@ import torch import torchdynamo +import torchdynamo.test_case import torchdynamo.testing input = torch.ones([10, 10]) @@ -37,7 +38,7 @@ def fn(): return test_fn -class OptimizerTests(torchdynamo.testing.TestCase): +class OptimizerTests(torchdynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() @@ -97,6 +98,6 @@ def setUpClass(cls): setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt)) if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index d0a348c937..7958a23350 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -9,8 +9,8 @@ import torchdynamo from torchdynamo.testing import CompileCounter -from torchdynamo.testing import TestCase from torchdynamo.testing import same +from torchdynamo.test_case import run_tests, TestCase """ This is an example of a pure-python version of autograd implemented by @@ -289,6 +289,4 @@ def forward(a, b): if __name__ == "__main__": - from torchdynamo.testing import run_tests - run_tests() diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index ff96feb3c7..891945a043 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -6,10 +6,11 @@ import torchdynamo import torchdynamo.config +import torchdynamo.test_case import torchdynamo.testing -class RecompileUxTests(torchdynamo.testing.TestCase): +class RecompileUxTests(torchdynamo.test_case.TestCase): # TODO(whc) dynamo actualy recompiles one more time than the cache limit cache_limit = 1 diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index 4c39508b03..0995990d13 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -6,6 +6,7 @@ import torch +import torchdynamo.test_case import torchdynamo.testing try: @@ -16,7 +17,7 @@ requires_dill = unittest.skipIf(dill is None, "requires dill") -class ReplayRecordTests(torchdynamo.testing.TestCase): +class ReplayRecordTests(torchdynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() @@ -179,6 +180,6 @@ def test_fn(x, y): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 97c00336cd..0ab457abb9 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -16,6 +16,7 @@ from torch import nn from torch.nn import functional as F +import torchdynamo.test_case import torchdynamo.testing import torchdynamo.utils from torchdynamo.debug_utils import same_two_models @@ -751,7 +752,7 @@ def fn(self, tensor): return self.inner_fn(tensor.shape, (1, 2, 3)) -class ReproTests(torchdynamo.testing.TestCase): +class 
ReproTests(torchdynamo.test_case.TestCase): def test_do_paste_mask(self): torchdynamo.utils.counters.clear() opt__do_paste_mask = torchdynamo.optimize(torchdynamo.testing.CompileCounter())( @@ -1712,6 +1713,6 @@ def forward(self, getitem_1, getitem_2, add): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index bbe2cb7b38..23c109481b 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -4,10 +4,11 @@ import torch import torchdynamo +import torchdynamo.test_case from torchdynamo.testing import CompileCounter -class SkipNonTensorTests(torchdynamo.testing.TestCase): +class SkipNonTensorTests(torchdynamo.test_case.TestCase): def test_add_tensor1(self): def fn(a, b): return a + b @@ -107,6 +108,6 @@ def __len__(self): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index a78de22375..8f96a90033 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -4,6 +4,7 @@ import torch +import torchdynamo.test_case import torchdynamo.testing from torchdynamo import config from torchdynamo.testing import unsupported @@ -17,7 +18,7 @@ def indirectly_unsupported(a, b): return unsupported(a, c) -class SubGraphTests(torchdynamo.testing.TestCase): +class SubGraphTests(torchdynamo.test_case.TestCase): def _common(self, fn, frame_count, op_count): torchdynamo.reset() v1 = torch.ones(10) @@ -528,6 +529,6 @@ def fn(a, b): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index e404f77fe6..429cb53180 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torchdynamo.test_case import torchdynamo.testing from torchdynamo.testing import same @@ -52,7 +53,7 @@ class UnspecTest(cls): @patch.object(torchdynamo.config, "specialize_int_float", False) -class UnspecTests(torchdynamo.testing.TestCase): +class UnspecTests(torchdynamo.test_case.TestCase): def test_numpy_correctness(self): def fn(x, y, z): xy = [x + y, y, False] @@ -222,6 +223,6 @@ def fn(image, scale_factor): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index 31f780f8cd..cdf41f9680 100644 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -8,6 +8,7 @@ import torchdynamo import torchdynamo.config as config +import torchdynamo.test_case from torchdynamo.optimizations import backends from torchdynamo.testing import same @@ -77,7 +78,7 @@ def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: return gm -class TestVerifyCorrectness(torchdynamo.testing.TestCase): +class TestVerifyCorrectness(torchdynamo.test_case.TestCase): @patch.object(config, "verify_correctness", True) def test_example_inputs(self): def fn(a, bc, d): @@ -169,6 +170,6 @@ def test_ipex_fp32(self): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests() diff --git a/test/inductor/test_torchinductor.py 
b/test/inductor/test_torchinductor.py index a75f5519e5..7d2053d353 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4022,6 +4022,6 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): if __name__ == "__main__": - from torchdynamo.testing import run_tests + from torchdynamo.test_case import run_tests run_tests(needs="filelock") diff --git a/torchdynamo/test_case.py b/torchdynamo/test_case.py new file mode 100644 index 0000000000..6b199b410a --- /dev/null +++ b/torchdynamo/test_case.py @@ -0,0 +1,68 @@ +import contextlib +import importlib +import sys +from unittest.mock import patch + +import torch +import torch.testing + +from . import config, reset, utils + +from torch.testing._internal.common_utils import TestCase as TorchTestCase +from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import TEST_WITH_CROSSREF +from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO +from torch.testing._internal.common_utils import run_tests + + +def run_tests(needs=()): + + if ( + TEST_WITH_TORCHDYNAMO + or IS_WINDOWS + or TEST_WITH_CROSSREF + or sys.version_info >= (3, 11) + ): + return # skip testing + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda" and not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + run_tests() + + +class TestCase(TorchTestCase): + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context( + patch.object(config, "raise_on_backend_error", True) + ) + cls._exit_stack.enter_context( + patch.object(config, "raise_on_ctx_manager_usage", True) + ) + + def setUp(self): + super().setUp() + reset() + utils.counters.clear() + + def tearDown(self): + for k, v in utils.counters.items(): + print(k, v.most_common()) + reset() + utils.counters.clear() + super().tearDown() diff --git a/torchdynamo/testing.py b/torchdynamo/testing.py index 82baac498b..5555d36f0c 100644 --- a/torchdynamo/testing.py +++ b/torchdynamo/testing.py @@ -1,23 +1,19 @@ import contextlib import dis import functools -import importlib import logging import os.path -import sys import types import unittest from unittest.mock import patch import torch -import torch.testing._internal.common_utils from torch import fx from . import config from . import eval_frame from . import optimize_assert from . import reset -from . 
import utils from .bytecode_transformation import create_instruction from .bytecode_transformation import debug_checks from .bytecode_transformation import is_generator @@ -32,33 +28,6 @@ log = logging.getLogger(__name__) -def run_tests(needs=()): - from torch.testing._internal.common_utils import IS_WINDOWS - from torch.testing._internal.common_utils import TEST_WITH_CROSSREF - from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO - from torch.testing._internal.common_utils import run_tests - - if ( - TEST_WITH_TORCHDYNAMO - or IS_WINDOWS - or TEST_WITH_CROSSREF - or sys.version_info >= (3, 11) - ): - return # skip testing - - if isinstance(needs, str): - needs = (needs,) - for need in needs: - if need == "cuda" and not torch.cuda.is_available(): - return - else: - try: - importlib.import_module(need) - except ImportError: - return - run_tests() - - def clone_me(x): if x is None: return None @@ -228,36 +197,6 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None) self.assertEqual(actual.op_count, expected_ops) -class TestCase(torch.testing._internal.common_utils.TestCase): - @classmethod - def tearDownClass(cls): - cls._exit_stack.close() - super().tearDownClass() - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._exit_stack = contextlib.ExitStack() - cls._exit_stack.enter_context( - patch.object(config, "raise_on_backend_error", True) - ) - cls._exit_stack.enter_context( - patch.object(config, "raise_on_ctx_manager_usage", True) - ) - - def setUp(self): - super().setUp() - reset() - utils.counters.clear() - - def tearDown(self): - for k, v in utils.counters.items(): - print(k, v.most_common()) - reset() - utils.counters.clear() - super().tearDown() - - def dummy_fx_compile(gm: fx.GraphModule, example_inputs): return gm.forward From 739c299aa015f0b25b2dd93c29998630ad7c4ba1 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 20:59:47 -0700 Subject: [PATCH 34/36] . --- test/dynamo/test_python_autograd.py | 3 ++- torchdynamo/test_case.py | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index 7958a23350..a99a77a714 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -8,9 +8,10 @@ import torch import torchdynamo +from torchdynamo.test_case import TestCase +from torchdynamo.test_case import run_tests from torchdynamo.testing import CompileCounter from torchdynamo.testing import same -from torchdynamo.test_case import run_tests, TestCase """ This is an example of a pure-python version of autograd implemented by diff --git a/torchdynamo/test_case.py b/torchdynamo/test_case.py index 6b199b410a..01819a231f 100644 --- a/torchdynamo/test_case.py +++ b/torchdynamo/test_case.py @@ -5,17 +5,18 @@ import torch import torch.testing - -from . import config, reset, utils - -from torch.testing._internal.common_utils import TestCase as TorchTestCase from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.common_utils import TEST_WITH_CROSSREF from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase as TorchTestCase + +from . import config +from . import reset +from . 
import utils def run_tests(needs=()): + from torch.testing._internal.common_utils import run_tests if ( TEST_WITH_TORCHDYNAMO From 96fd6b179b018097f34e3c375276211b1d3abc5f Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 21:24:11 -0700 Subject: [PATCH 35/36] . --- test/dynamo/test_distributed.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index ecd99c9434..754396950e 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -210,10 +210,6 @@ def opt_fn(inputs): opt_outputs.sum().backward() self.assertTrue(same(correct_outputs, opt_outputs)) - @pytest.mark.skipif( - not hasattr(DDP, "_get_active_ddp_module"), - reason="requires pytorch landing in parallel", - ) @patch.object(config, "optimize_ddp", True) def test_custom_layer(self): """ @@ -222,6 +218,9 @@ def test_custom_layer(self): the user-provided compiler is called by the DDPOptimizer which is doing the graph splitting """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() class MyCustomLinear(torch.nn.Module): def __init__(self): From 38ef9764e0ccaf2b91766a0575647adffacd14ad Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 13 Oct 2022 21:52:03 -0700 Subject: [PATCH 36/36] Fix base_dir --- torchdynamo/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchdynamo/config.py b/torchdynamo/config.py index 468406809d..5ca3620b12 100644 --- a/torchdynamo/config.py +++ b/torchdynamo/config.py @@ -55,8 +55,6 @@ torch.onnx.is_in_onnx_export: False, } -# root folder of the project -base_dir = dirname(dirname(dirname(abspath(__file__)))) # don't specialize on shapes and strides and put shape ops in graph dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1" @@ -153,6 +151,12 @@ # How to import torchinductor, either torchinductor or torch.inductor inductor_import = dynamo_import.replace("dynamo", "inductor") +# root folder of the project +if "torch." in dynamo_import: + base_dir = dirname(dirname(dirname(abspath(__file__)))) +else: + base_dir = dirname(dirname(abspath(__file__))) + class _AccessLimitingConfig(ModuleType): def __setattr__(self, name, value):
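Taken together, the test-file changes in this series converge on a single pattern: TestCase and run_tests now come from torchdynamo.test_case, while helpers such as CompileCounter and same stay in torchdynamo.testing. A minimal sketch of that post-series layout is below; the module name, test class, function, and tensor values are illustrative only and are not taken from any file in the patches above.

# sketch_test_example.py -- assumed minimal dynamo test module after this series
import torch

import torchdynamo
import torchdynamo.test_case
import torchdynamo.testing


class ExampleTests(torchdynamo.test_case.TestCase):
    def test_add(self):
        def fn(a, b):
            return a + b

        # CompileCounter doubles as a no-op backend that counts compiled frames
        cnt = torchdynamo.testing.CompileCounter()
        opt_fn = torchdynamo.optimize(cnt)(fn)
        result = opt_fn(torch.ones(4), torch.ones(4))
        self.assertTrue(torchdynamo.testing.same(result, torch.ones(4) * 2))
        self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
    from torchdynamo.test_case import run_tests

    run_tests()

The relocated run_tests keeps the same skip behavior as the old torchdynamo.testing.run_tests: it returns early on Windows, under TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF, on Python 3.11+, or when a requirement passed via needs= (for example needs="cuda" or needs="filelock") is unavailable.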