New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extract kernel to a separate object using delegation #303
base: master
Are you sure you want to change the base?
Conversation
pyqg/model.py
Outdated
self.kernel = kernel(nz, ny, nx, | ||
q_parameterization, | ||
uv_parameterization, | ||
ntd, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One issue -- ntd
(number of threads) is not going to be an attribute that every kernel needs (or accepts). We could get around this by passing an optional kernel_kwargs
argument but that seems questionable.
Normally, with dependency injection, we pass in a fully initialized object, but here we're passing in a class. That also seems questionable, but it does simplify the interface. Hmm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do you imagine the api will work once we have multiple kernels? How will the user specify which kernel they want to use at runtime?
I would prefer to be able to write
model = QGModel(kernel="jax", kernel_kwargs={"foo": "special jax argument"})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I was thinking something like that -- it would support both
model = pyqg.QGModel(kernel="jax", kernel_kwargs={"foo": "special jax argument"})
or
model = pyqg.QGModel(kernel=pyqg.JaxKernel, kernel_kwargs={"foo": "special jax argument"})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
However, I do think this syntax is a bit weird, since it's more normal to do dependency injection with initialized objects:
model = pyqg.QGModel(kernel=pyqg.CythonFFTWKernel(fftw_num_threads=2)
model = pyqg.QGModel(kernel=pyqg.JaxKernel(foo='bar'))
The problem with this is that we'd need to start passing nx
/ny
/nz
/other arguments to the kernel rather than the model, which is could lead to redundancies and definitely isn't backwards compatible.
Two options if we want to use delegate-property:
Edit: I see you have gone for option 1. 👍 |
Sorry for all the draft-no-draft conversions -- keep forgetting to re-compile before running tests :) Current issue is just figuring out how the Cython kernel inherits from a Python class. |
It looks like this is actually impossible! We might still eventually add inheritance, but just for Python kernels (e.g. numpy and jax). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good. 👍
It is also turning up some weird and inconsistent patterns within the pyqg code. Can we use this as a chance to clean things up a bit?
A few comments...
pyqg/model.py
Outdated
kernel=kernels.CythonFFTWKernel, | ||
kernel_kwargs={}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These need to be documented in the docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've done the bare minimum to document this now (some description in the pyqg.Model
docstring), but I'm realizing that for scalability, we probably need to add kernels to the API and give them their own documentation (which users would reference to figure out kernel_kwargs
).
pyqg/model.py
Outdated
def _invert(self): | ||
self.kernel.invert() | ||
|
||
def _do_advection(self): | ||
self.kernel.do_advection() | ||
|
||
def _do_friction(self): | ||
self.kernel.do_friction() | ||
|
||
def _do_q_subgrid_parameterization(self): | ||
self.kernel.do_q_subgrid_parameterization() | ||
|
||
def _do_uv_subgrid_parameterization(self): | ||
self.kernel.do_uv_subgrid_parameterization() | ||
|
||
def _forward_timestep(self): | ||
self.kernel.forward_timestep() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels like a code smell. Can we avoid wrapping all these methods? Could we use delegate here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that's better. The issue is that, having removed underscores from the corresponding kernel
method names, then we end up needing to remove underscores from the pyqg.Model
method names as well, which means they're now part of the public API (and it also breaks some of my code which calls the private methods anyway 😬).
A terse and backwards-compatible alternative I spent a little while starting to implement was updating delegate
to support a prefix option:
@delegate(*private_kernel_attrs, to='kernel', prefix='_')
@delegate(*public_kernel_attrs, to='kernel')
class Model:
# ...
This is relatively simple in principle, though it requires significant changes to the delegate
decorator to allow calls to stack / prevent one call from overwriting the other. Because of that, I opted for the smelly solution. However, with your comment, I'm considering updating the decorator again...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, I updated it, and also made a pull request to the original library in case they want to allow this functionality (dscottboggs/python-delegate#1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I can see, I think this can be resolved by composing most of the kernel functions. I think Kernel._forward_timestep
is the only one that couldn't be, atm, but only because _calc_diagnostics
is in the way.
I'm really wary of solving the Kernel/Model inheritance problem with delegation, it takes away a lot of the guarantees I need to be able to implement Jax #241 in a reasonable way. Basically Kernel would become part of Model, which means I can't know if the Kernel arrays are being messed with outside of Kernel. I think if we compose all the kernel functions, that would alleviate this problem on my end, but I think passing Kernel to Model as an argument is probably a better design option.
self._U = U | ||
self._V = V |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why you moved this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needed to get moved up because we can't set Ubg
before calling super
, since Ubg
is now a property of the kernel. However, super
calls _initialize_background
which then calls _initialize_stretching_matrix
which requires Ubg
to be defined. So, following QGModel
, I moved the actual setting of Ubg
to _initialize_background
(which makes sense). The issue is that we need some intermediate variable to store the array, so we end up with this 😐
Doesn't appear to be supported, and it's not too much code. Maybe this is a mistake, though.
This pull requests makes
pyqg.Model
a top-level class, invokingpyqg.PseudoSpectralKernel
by dependency injection rather than inheritance.This change will allow us to pass in alternative kernels (e.g. jax or numpy without Cython), addressing #240.
Note that as currently implemented, I'm using the delegate-property package, which adds a dependency. Because it's is a fairly short library, I could copy it over if we think it might cause problems, or at least lock the version inI ended up removing this requirement right now because the package doesn't appear to be included in conda, so CI was failing.requirements.txt
.