[dynamo][eval frame] Make CacheEntry a PyObject (#107405)
This PR makes CacheEntry a PyObject. This is a prep PR for cache size changes. Now that CacheEntry is a Python object, we can traverse the linked list in Python and write cache size policies there. It was possible to do in C, but Python is just easier to iterate on. We call convert_frame only when we (re)compile, so a small bump in latency going from C to Python is acceptable here.
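As an illustration of the policies this unlocks, here is a minimal Python-side sketch that walks the new CacheEntry linked list to enforce a limit; the helper name and the limit of 64 are hypothetical and not part of this PR:

def exceeds_cache_size_limit(cache_entry, limit=64):
    # Walk the linked list exposed by the CacheEntry PyObject; the list is
    # terminated by None on the Python side.
    size = 0
    while cache_entry:
        size += 1
        cache_entry = cache_entry.next
    return size >= limit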

Pull Request resolved: #107405
Approved by: https://github.com/ezyang
ghstack dependencies: #106917, #107117
anijain2305 authored and pytorchmergebot committed Aug 21, 2023
1 parent 3b2c5d4 commit e201e3f
Showing 3 changed files with 155 additions and 35 deletions.
12 changes: 11 additions & 1 deletion torch/_dynamo/convert_frame.py
@@ -224,6 +224,15 @@ def is_recompilation(cache_size):
return cache_size >= 1


def compute_cache_size(cache_entry):
# Walk the linked list to calculate the cache size
running_cache_size = 0
while cache_entry:
running_cache_size += 1
cache_entry = cache_entry.next
return running_cache_size


FRAME_COUNTER = 0
FRAME_COMPILE_COUNTER: typing.Counter[int] = collections.Counter()

@@ -238,12 +247,13 @@ def convert_frame_assert(
reset_graph_break_dup_checker()

def _convert_frame_assert(
frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state
frame: types.FrameType, cache_entry, hooks: Hooks, frame_state
):
increment_frame()

code = frame.f_code

cache_size = compute_cache_size(cache_entry)
if is_recompilation(cache_size) and (
recompiles_log.isEnabledFor(logging.DEBUG) or config.error_on_recompile
):
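Not part of the diff, but as a quick sanity check: compute_cache_size only relies on the .next attribute, so it can be exercised with plain stand-in objects:

import types

# Hypothetical stand-ins that mimic the .next linkage of real CacheEntry objects.
tail = types.SimpleNamespace(next=None)
head = types.SimpleNamespace(next=tail)

assert compute_cache_size(None) == 0  # empty cache (C passes Py_None)
assert compute_cache_size(head) == 2  # two entries in the chain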
6 changes: 3 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -459,7 +459,7 @@ def first_real_inst_idx(code):

def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_size, frame_state):
def catch_errors(frame, cache_entry, frame_state):
assert frame_state is not None

if (
@@ -487,10 +487,10 @@ def catch_errors(frame, cache_size, frame_state):
ddp_optimizer.compile_fn,
hooks=hooks,
)
return hijacked_callback(frame, cache_size, hooks, frame_state)
return hijacked_callback(frame, cache_entry, hooks, frame_state)

with compile_lock, _disable_current_modes():
return callback(frame, cache_size, hooks, frame_state)
return callback(frame, cache_entry, hooks, frame_state)

catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return catch_errors
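For reference, a minimal sketch of the callback shape the C eval-frame hook invokes after this change — it now receives the head CacheEntry (or None for an empty cache) instead of an integer cache size; the function below is illustrative only:

def my_callback(frame, cache_entry, frame_state):
    # The cache size, previously computed in C and passed as an int, is now
    # derived in Python from the linked list.
    cache_size = compute_cache_size(cache_entry)
    ...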
172 changes: 141 additions & 31 deletions torch/csrc/dynamo/eval_frame.c
@@ -296,6 +296,7 @@ These two are encapsulated into a ExtraState.
// Linked list of cache entries, where each cache entry stores
// the check_fn and the torch.compile optimized python bytecode.
typedef struct cache_entry {
PyObject_HEAD
// check the guards: lambda: <locals of user function>: bool
PyObject* check_fn;
// modified user bytecode (protected by check_fn's guards)
@@ -304,10 +305,101 @@ typedef struct cache_entry {
struct cache_entry* next;
} CacheEntry;

static void cache_entry_dealloc(CacheEntry* e);

#define DECLARE_CACHE_ENTRY_ATTR(name) \
static PyObject* CacheEntry_##name(CacheEntry* self, PyObject* _noargs) { \
PyObject* res = (PyObject*)self->name; \
Py_INCREF(res); \
return res; \
}

DECLARE_CACHE_ENTRY_ATTR(check_fn)
DECLARE_CACHE_ENTRY_ATTR(code)
DECLARE_CACHE_ENTRY_ATTR(next)

static struct PyGetSetDef CacheEntry_properties[] = {
{"check_fn", (getter)CacheEntry_check_fn, NULL, NULL, NULL},
{"code", (getter)CacheEntry_code, NULL, NULL, NULL},
{"next", (getter)CacheEntry_next, NULL, NULL, NULL},
{NULL}};


static PyObject* cache_entry_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
CacheEntry *self;
self = (CacheEntry*) type->tp_alloc(type, 0);
if (self != NULL) {
// The corresponding decrefs for Py_None are in cache_entry_init.
Py_INCREF(Py_None);
self->check_fn = Py_None;
Py_INCREF(Py_None);
self->code = (PyCodeObject*)Py_None;
Py_INCREF(Py_None);
self->next = (CacheEntry*)Py_None;
}
return (PyObject*)self;
}


static int cache_entry_init(CacheEntry* self, PyObject* args, PyObject* kwds) {
PyObject* check_fn = NULL;
PyCodeObject* code = NULL;
CacheEntry* next = NULL;

static char *kwlist[] = {"check_fn", "code", "next", NULL};

int ret = PyArg_ParseTupleAndKeywords(
args, kwds, "OOO", kwlist,
&check_fn, &code, &next);

if (!ret) return -1;

if (check_fn) {
PyObject* tmp = self->check_fn;
Py_INCREF(check_fn);
self->check_fn = check_fn;
Py_XDECREF(tmp);
}

if (code) {
PyCodeObject* tmp = self->code;
Py_INCREF(code);
self->code = code;
Py_XDECREF(tmp);
}

if (next) {
CacheEntry* tmp = self->next;
Py_INCREF(next);
self->next = next;
Py_XDECREF(tmp);
}
return 0;
}

static PyTypeObject CacheEntryType = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "torch._C.dynamo.eval_frame.CacheEntryWrapper",
.tp_basicsize = sizeof(CacheEntry),
.tp_itemsize = 0,
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_new = cache_entry_new,
.tp_init = (initproc)cache_entry_init,
.tp_dealloc = (destructor)cache_entry_dealloc,
.tp_getset = CacheEntry_properties,
};

// ExtraState encapsulates CacheEntry and FrameState. ExtraState is the highest
// level of abstraction of what is stored on the extra code object. Previously,
// we saved different parts on different extra indexes. We prefer this way
// because of cleaner abstraction and faster SetExtra access.

// TODO(anijain2305) - Consider making this a PyObject. Benefits are
// 1) Modular dealloc - destroy_extra_state just becomes Py_DECREF(extra)
// 2) We can directly send the extra object to convert_frame callback. One
// data structure - easier to understand code.
// There might be some perf impact of going through a PyObject on the critical
// path, but it should not be too bad.
typedef struct {
// Cache entry for the code object
CacheEntry* cache_entry;
Expand All @@ -327,24 +419,28 @@ static CacheEntry* create_cache_entry(
// - guarded_code: Borrowed
// return
// - CacheEntry*: new reference.
CacheEntry* e = (CacheEntry*)malloc(sizeof(CacheEntry));
DEBUG_NULL_CHECK(e);
e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
NULL_CHECK(e->check_fn);
e->code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code");
NULL_CHECK(e->code);
e->next = next;
PyObject* check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); // new reference
PyCodeObject* code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code"); // new reference

// equivalent to CacheEntry(check_fn, code, next) in Python
PyObject* args = Py_BuildValue("OOO", check_fn, code, next);
CacheEntry* e = (CacheEntry*)PyObject_CallObject((PyObject*)&CacheEntryType, args); // new reference
// CacheEntry e is now the owner of the old cache entry next. This happens
// when we incref the next pointer in cache_entry_init.
Py_DECREF(next);
Py_DECREF(check_fn);
Py_DECREF(code);
Py_DECREF(args);
return e;
}

static void destroy_cache_entry(CacheEntry* e) {
if (e == NULL || e == SKIP_CODE) {
return;
}
static void cache_entry_dealloc(CacheEntry* e) {
Py_XDECREF(e->check_fn);
Py_XDECREF(e->code);
destroy_cache_entry(e->next);
free(e);
// This will recursively call cache_entry_dealloc for the next items in the
// linked list.
Py_XDECREF(e->next);
Py_TYPE(e)->tp_free((PyObject*)e);
}

/* CacheEntry helper functions ends */
@@ -409,10 +505,10 @@ inline static void destroy_extra_state(void* obj) {

ExtraState* extra = (ExtraState*)obj;
if (extra != NULL && extra != SKIP_CODE) {
CacheEntry* cache_entry = extra->cache_entry;
FrameState* frame_state = extra->frame_state;
destroy_cache_entry(cache_entry);
Py_XDECREF(frame_state);
// CPython gc will call cache_entry_dealloc on its own when the ref count
// goes to 0.
Py_XDECREF(extra->cache_entry);
Py_XDECREF(extra->frame_state);
free(extra);
}
}
@@ -457,7 +553,10 @@ inline static ExtraState* init_and_set_extra_state(PyCodeObject* code) {
CHECK(get_extra_state(code) == NULL);
ExtraState* extra_state = (ExtraState*)malloc(sizeof(ExtraState));
DEBUG_NULL_CHECK(extra_state);
extra_state->cache_entry = NULL;
// We set the last node in the linked list to Py_None. We incref Py_None here;
// the corresponding decref is in cache_entry_dealloc.
Py_INCREF(Py_None);
extra_state->cache_entry = (CacheEntry*)Py_None;
extra_state->frame_state = PyDict_New();
set_extra_state(code, extra_state);
return extra_state;
@@ -470,6 +569,8 @@
*/

PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
// TODO(anijain2305) - CacheEntry being the first class Python object might
// obviate the need of this function. Revisit.
PyObject* object;
if (!PyArg_ParseTuple(args, "O", &object)) {
return NULL;
@@ -487,7 +588,7 @@ PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
if (!outer_list) {
return NULL; // Return NULL if failed to create list
}
while (current_node != NULL && current_node != SKIP_CODE) {
while (current_node != NULL && current_node != (CacheEntry*)Py_None) {
// Creating a new Python tuple for the check_fn and code of current CacheEntry
PyObject* inner_list = PyTuple_Pack(2, current_node->check_fn, current_node->code);
int flag = PyList_Append(outer_list, inner_list); // Add the inner list to the outer list
@@ -507,7 +608,7 @@ PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
static inline PyObject* call_callback(
PyObject* callable,
THP_EVAL_API_FRAME_OBJECT* _frame,
long cache_len,
CacheEntry* cache_entry,
FrameState* frame_state) {

#if IS_PYTHON_3_11_PLUS
@@ -518,7 +619,13 @@ static inline PyObject* call_callback(
#else
PyObject* frame = Py_NewRef(_frame);
#endif
PyObject* res = PyObject_CallFunction(callable, "OlO", frame, cache_len, frame_state);

PyObject* res = PyObject_CallFunction(
callable,
"OOO",
frame,
cache_entry,
frame_state);
Py_DECREF(frame);
return res;
}
@@ -536,7 +643,7 @@ static PyObject* call_guard_fail_hook(
e->code,
f_locals,
(Py_ssize_t)index,
(e->next == NULL ? Py_True : Py_False));
(e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
}

static PyObject* call_profiler_start_hook(PyObject* name_str) {
Expand All @@ -558,7 +665,7 @@ static void call_profiler_end_hook(PyObject* record) {
// Return value: borrowed reference
// Is either Py_None or a PyCodeObject
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
if (e == NULL) {
if (e == (CacheEntry*)Py_None) {
// NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
return Py_None;
}
@@ -602,13 +709,6 @@ static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEn
return lookup(e->next, frame, e, index + 1);
}

static long cache_size(CacheEntry* e) {
if (e == NULL) {
return 0;
}
return 1 + cache_size(e->next);
}

inline static PyObject* eval_custom_code(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
@@ -880,7 +980,7 @@ static PyObject* _custom_eval_frame(
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
// that gets re-interpreted as a PyObject (which it is NOT!)
PyObject* result =
call_callback(callback, frame, cache_size(cache_entry), frame_state);
call_callback(callback, frame, cache_entry, frame_state);
if (result == NULL) {
// internal exception, returning here will leak the exception into user code
// this is useful for debugging -- but we dont want it to happen outside of
@@ -1102,5 +1202,15 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
}
#endif


if (PyType_Ready(&CacheEntryType) < 0) {
return NULL;
}
Py_INCREF(&CacheEntryType);
if (PyModule_AddObject(module, "_CacheEntry", (PyObject *) &CacheEntryType) < 0) {
Py_DECREF(&CacheEntryType);
return NULL;
}

return module;
}
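Finally, a hedged usage sketch of the newly registered type from Python — the import path is assumed from where this extension module is typically exposed, and constructing a CacheEntry by hand like this is purely illustrative:

# Assumed import path for the eval_frame extension; adjust if the binding lives elsewhere.
from torch._C._dynamo import eval_frame as _ef

entry = _ef._CacheEntry(check_fn=lambda f_locals: True,
                        code=(lambda: None).__code__,
                        next=None)
print(entry.check_fn, entry.code, entry.next)  # read-only attributes exposed via tp_getset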
