
[ENH] MDNRegressor (Mixture Density Network) #796

Merged: fkiraly merged 13 commits into sktime:main from joshdunnlime:mdn on Mar 10, 2026

Conversation

@joshdunnlime
Contributor

@joshdunnlime joshdunnlime commented Mar 5, 2026

Reference Issues/PRs

No issue opened.

What does this implement/fix? Explain your changes.

A new regressor implementing a Mixture Density Network (MDN) as per Bishop 1994, with noise regularisation as per Rothfuss 2019.
It also supports optionally passing PyTorch activation functions and optimizers.

It implements a fully vectorised NormalMixture distribution in which each row's weights are individually learnt and applied. It also implements a custom vectorised bisection for fast _ppf method calls.
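
To illustrate the idea of a row-wise mixture with a vectorised bisection, here is a minimal NumPy sketch. It is not the code in this PR: the function names `mixture_cdf` and `mixture_ppf`, the array layout (one row per prediction, one column per component), and the fixed bracket of 10 standard deviations around the component means are all assumptions made for illustration.

```python
import math

import numpy as np

# Elementwise error function built from the standard library.
_erf = np.vectorize(math.erf)


def std_normal_cdf(z):
    """Standard normal CDF, applied elementwise."""
    return 0.5 * (1.0 + _erf(z / math.sqrt(2.0)))


def mixture_cdf(x, weights, means, sds):
    """CDF of n independent normal mixtures, one per row.

    x has shape (n,); weights, means and sds have shape (n, k).
    """
    z = (x[:, None] - means) / sds
    return np.sum(weights * std_normal_cdf(z), axis=1)


def mixture_ppf(p, weights, means, sds, n_iter=60):
    """Invert the mixture CDF for all rows at once by bisection.

    Assumption: the bracket of 10 standard deviations around the
    extreme components is wide enough for the requested p.
    """
    lo = np.min(means - 10.0 * sds, axis=1)
    hi = np.max(means + 10.0 * sds, axis=1)
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        below = mixture_cdf(mid, weights, means, sds) < p
        # Rows whose CDF at mid is below p move their lower bound up,
        # all other rows move their upper bound down.
        lo = np.where(below, mid, lo)
        hi = np.where(below, hi, mid)
    return 0.5 * (lo + hi)


# Two rows, each with its own weights over two components.
weights = np.array([[0.5, 0.5], [1.0, 0.0]])
means = np.array([[-1.0, 1.0], [2.0, 5.0]])
sds = np.ones((2, 2))
medians = mixture_ppf(np.array([0.5, 0.5]), weights, means, sds)
```

Every bisection step evaluates the CDF for all rows in a single array operation, so the loop runs over iterations rather than over rows.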

Does your contribution introduce a new dependency? If yes, which one?

Yes, soft dependencies: PyTorch and pytorch-optimizer.

What should a reviewer concentrate their feedback on?

The NormalMixture distribution. This adds a slightly opinionated design choice to the API.

Did you add any tests for the change?

Yes. Standard library parameter tests for the distribution and the estimator. An additional test for the estimator is coming.

Any other comments?

PR checklist

For all contributions
  • I've added myself to the list of contributors with any new badges I've earned :-)
    How to: add yourself to the all-contributors file in the skpro root directory (not the CONTRIBUTORS.md). Common badges: code - fixing a bug, or adding code logic. doc - writing or improving documentation or docstrings. bug - reporting or diagnosing a bug (get this plus code if you also fixed the bug in the PR). maintenance - CI, test framework, release.
    See here for full badge reference
  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
For new estimators
  • I've added the estimator to the API reference - in docs/source/api_reference/taskname.rst, follow the pattern.
  • I've added one or more illustrative usage examples to the docstring, in a pydocstyle compliant Examples section.
  • If the estimator relies on a soft dependency, I've set the python_dependencies tag and ensured
    dependency isolation, see the estimator dependencies guide.

Collaborator

@fkiraly fkiraly left a comment


Thanks! Nice!

  • question: why implement a NormalMixture when Mixture is available and should be capable of representing mixture of normals by a composition of Normal and Mixture?
  • if we add NormalMixture, it should also be added in the API reference for distribution

Comment thread pyproject.toml Outdated
@joshdunnlime
Contributor Author

My understanding of Mixture is that it only has global weights for each sub-distribution. MDNRegressor could use this; however, it would need a new instance of Mixture for each prediction row, which would be computationally crippling. NormalMixture has a performant, vectorised implementation of most of the class methods, so it is actually pretty quick!

You could argue that we could call it Normal(Row|Instance|HeteroWeighted)Mixture but I think it is documented clearly enough to avoid confusion.

Add NormalMixture to distributions list. Also split and rename to silverman and scott, as this is the more Pythonic convention.
Initial testing shows ISJ provides improved convergence and final NLL scores. Bandwidths added as a separate module as this will be very useful for other kernel based estimators.
Collaborator

@fkiraly fkiraly left a comment


Ah, I see. Yes, that makes sense.

Do you want to open an issue that wishes for an abstract mixture that can have row-wise different mixture weights? That could be extended from Mixture, if weights are not just a list but a matrix.
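
As a rough illustration of the "list versus matrix" distinction, the NumPy sketch below evaluates a mixture density with weights of shape `(k,)` (global) or `(n, k)` (row-wise). This is not skpro's `Mixture` API; the function `mixture_pdf` and its array layout are assumptions made for this example only.

```python
import numpy as np


def normal_pdf(x, mean, sd):
    """Normal density, applied elementwise with broadcasting."""
    z = (x - mean) / sd
    return np.exp(-0.5 * z**2) / (sd * np.sqrt(2.0 * np.pi))


def mixture_pdf(x, weights, means, sds):
    """Density of a normal mixture evaluated row by row.

    x has shape (n,); means and sds have shape (n, k).
    weights is either (k,) for global weights shared by all rows,
    or (n, k) for weights that differ from row to row.
    """
    weights = np.asarray(weights, dtype=float)
    comp = normal_pdf(x[:, None], means, sds)  # shape (n, k)
    # Broadcasting covers both the (k,) and the (n, k) case.
    return np.sum(weights * comp, axis=1)


x = np.array([0.0, 1.0, -2.0])
means = np.tile([-1.0, 2.0], (3, 1))
sds = np.tile([1.0, 0.5], (3, 1))

global_pdf = mixture_pdf(x, [0.3, 0.7], means, sds)
rowwise_pdf = mixture_pdf(
    x, np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]), means, sds
)
```

With a weight matrix, each row can put its mass on different components, which is what an MDN produces when the weights are predicted per sample.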

@joshdunnlime
Contributor Author

I did consider the following:

  1. Extending Mixture to include "instance" and "global" weights.
  2. Create a generic InstanceWeightedMixture class.
  3. A base class BaseInstanceWeightedMixture.

The downsides were:

  1. Far too much work. IMHO these are two quite different classes, to the point that I think this would just be a wrapper around a Mixture and an InstanceWeightedMixture.

  2. An InstanceWeightedMixture which accepts a mixture of different dists comes with a lot of complications. The primary one is that the NormalMixture has very well defined (and mostly exact) methods for all of the standard methods. My understanding of weighted mixtures of distributions stops at normals, so I have no idea what is feasible here. The vectorisation is clearly very well defined for normal distributions in scipy, and assumptions on the support for root-finding in _ppf were fairly straightforward. Implementation-wise, I believe it is worth keeping NormalMixture even if we have 3), as the speed-ups and exactness are significant.

  3. Happy to open an issue. My knowledge of anything beyond normal mixtures of dists is effectively zero, so I doubt I'd be up to a PR.

@joshdunnlime
Contributor Author

Just as a side note: The natural extension to MDNs is probably something Flow/Kernel based, as opposed to extending this with non-Normal dists.

Collaborator

@fkiraly fkiraly left a comment


Should be fine now, but one request:

The estimator MDNRegressor does not actually depend on pytorch-optimizer, except if the string "SOAP" is passed. But that only does aliasing, and since you are doing soft dependency checking there, there is no intrinsic dependency. Hence I would remove pytorch-optimizer from the dependency set.

I would also suggest replacing the try/except used for dependency checking with _check_soft_dependencies (severity="none").
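
The pattern being requested can be sketched with the standard library alone. The code below is a stand-in, not the `_check_soft_dependencies` helper itself: `package_available` and `resolve_optimizer_alias` are hypothetical names, and returning a dotted-path string for the "SOAP" alias is an assumption made to keep the example free of third-party imports.

```python
from importlib.util import find_spec


def package_available(name):
    """Return True if a top-level package can be found, without importing it."""
    return find_spec(name) is not None


def resolve_optimizer_alias(optimizer):
    """Hypothetical helper: resolve a string alias to an optimizer path.

    Only the "SOAP" alias needs the optional package, so the check
    happens here rather than in the estimator's dependency set.
    """
    if optimizer == "SOAP":
        if not package_available("pytorch_optimizer"):
            raise ModuleNotFoundError(
                "optimizer='SOAP' requires the optional package "
                "pytorch-optimizer"
            )
        return "pytorch_optimizer.SOAP"
    # Any other value is passed through unchanged.
    return optimizer
```

A boolean check like this replaces a try/except around the import, and the error is raised only on the code path that actually needs the optional package.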


XGBoostLSS

Neural conditional density estimation
Collaborator


how about "deep learning based regressors" instead?

@joshdunnlime
Contributor Author

Thanks for updating. I'll do that optimizer package stuff.
I'm also adding some further improvements to the model which seem to help with training speed and stability.

@fkiraly
Collaborator

fkiraly commented Mar 9, 2026

btw, other topic, did we want to have the discussion on the differentiable transformers sometime? Have not seen you in the meetups, and I really think we need a sync meet on that (at least once). If the meetup times are not convenient, feel free to start a scheduling discussion on the discord.

@fkiraly fkiraly added the enhancement, module:regression, and implementing algorithms labels Mar 9, 2026
@joshdunnlime
Contributor Author

joshdunnlime commented Mar 10, 2026

I've added the ngem loss (https://arxiv.org/html/2602.10602v1) to the PR below on my fork. It seems to show significant benefits at lower lr, particularly for CRPS. I am happy to merge it into this mdn branch, but further tidying is required. See the comments in the PR:
joshdunnlime#1

It would be best to hold off and merge this all together to avoid minor breaking changes, but I am unclear on how to handle the ngem_lr defaults, as ngem really requires a lower lr to perform well. Comments very much welcome.

Yes, apologies there. I will reach out on Discord regarding the Diff Transform class.

Collaborator

@fkiraly fkiraly left a comment


Yes, makes sense - adding losses can be done in a separate PR.

@fkiraly fkiraly merged commit 6ae07b6 into sktime:main Mar 10, 2026
53 checks passed
fkiraly pushed a commit that referenced this pull request May 19, 2026
#### Reference Issues/PRs

No Issue but referenced
[here](#796 (comment)).


#### What does this implement/fix? Explain your changes.

Mixture density networks typically use the negative log-likelihood
(nll) objective, which can suffer from slow convergence and mode
collapse. The natural gradient expectation maximization (ngem) objective
can achieve up to 10x faster convergence while adding almost zero
computational overhead, and scales well to high-dimensional data where
nll fails.

[Learning Mixture Density via Natural Gradient Expectation
Maximization](https://arxiv.org/html/2602.10602v1)
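
For reference, the standard nll objective mentioned above can be sketched in NumPy as follows. This shows only the nll baseline, not the ngem update from the paper; the function name `mixture_nll` and the parameterisation by logits and log standard deviations are assumptions made for this example.

```python
import numpy as np


def mixture_nll(y, logits, means, log_sds):
    """Mean negative log-likelihood of a row-wise normal mixture.

    y has shape (n,); logits, means and log_sds have shape (n, k).
    """
    # Log-softmax over components gives the log mixture weights.
    log_w = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)
    z = (y[:, None] - means) / np.exp(log_sds)
    # Log density of each normal component.
    log_comp = -0.5 * z**2 - log_sds - 0.5 * np.log(2.0 * np.pi)
    # Log-sum-exp over components keeps the computation stable.
    log_lik = np.logaddexp.reduce(log_w + log_comp, axis=1)
    return -np.mean(log_lik)
```

Working in log space avoids underflow when a sample lies far from every component.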

#### What should a reviewer concentrate their feedback on?

~General API structure for adding ngem specific learning rate. ngem
typically performs better with lower lr, so adding an automated lr
multiplier has been considered. Is this designed correctly?~

I have removed `ngem_lr_scaling` as it obscures true `lr` and sets a bad
standard if we wish to add more objectives that also benefit different
magnitudes of scaling.

#### Did you add any tests for the change?

Existing MDN tests should still pass. Specific ngem tests have also been added.