-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactoring OPVI to support Normalizing Flows #2306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
faca8aa
beginning of refactoring, simple tests pass for MeanField
ferrine 12e6173
more tests pass
ferrine 1be7aeb
mv normal refactor
ferrine ae3ff31
mv normal refactor
ferrine 241b303
refactor tests for VI
ferrine a7467da
more tests
ferrine 02996f2
refactor rest code
ferrine 9e53c05
refactor tests
ferrine 03b6d21
refactor tests
ferrine 3ef423e
use yield fixtures
ferrine 289a559
refactor tests, make them faster
ferrine 027ef00
fix tests
ferrine 33fdca8
fix tests duration
ferrine a180066
fix deterministic for Empirical, more tests
ferrine 7d4562a
change mean shape for Empirical
ferrine 64ee80a
support multitrace in Empirical
ferrine 5737626
fix convergence tests
ferrine 020c428
do not duplicate tests
ferrine 009b095
some small changes in docs+code
ferrine 8708772
update LDA notebook
ferrine 19aeec9
scale_cost to minibatch refactor
ferrine 80e7e47
get feedback and refactor code
ferrine a4a04a4
stein refactor and typo fix
ferrine 5b983a7
test value for stein
ferrine File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,37 @@ | ||
import theano | ||
import numpy as np | ||
import theano | ||
import pymc3 as pm | ||
import pytest | ||
|
||
|
||
class DataSampler(object): | ||
""" | ||
Not for users | ||
""" | ||
def __init__(self, data, batchsize=50, random_seed=42, dtype='floatX'): | ||
self.dtype = theano.config.floatX if dtype == 'floatX' else dtype | ||
self.rng = np.random.RandomState(random_seed) | ||
self.data = data | ||
self.n = batchsize | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def __next__(self): | ||
idx = (self.rng | ||
.uniform(size=self.n, | ||
low=0.0, | ||
high=self.data.shape[0] - 1e-16) | ||
.astype('int64')) | ||
return np.asarray(self.data[idx], self.dtype) | ||
|
||
next = __next__ | ||
|
||
|
||
@pytest.fixture(scope="session", autouse=True) | ||
@pytest.fixture(scope="function", autouse=True) | ||
def theano_config(): | ||
config = theano.configparser.change_flags(compute_test_value='raise') | ||
with config: | ||
yield | ||
|
||
|
||
@pytest.fixture(scope='function') | ||
def strict_float32(): | ||
@pytest.fixture(scope='function', autouse=True) | ||
def exception_verbosity(): | ||
config = theano.configparser.change_flags( | ||
warn_float64='raise', | ||
floatX='float32') | ||
exception_verbosity='high') | ||
with config: | ||
yield | ||
|
||
|
||
@pytest.fixture('session', params=[ | ||
np.random.uniform(size=(1000, 10)) | ||
]) | ||
def datagen(request): | ||
return DataSampler(request.param) | ||
@pytest.fixture(scope='function', autouse=False) | ||
def strict_float32(): | ||
if theano.config.floatX == 'float32': | ||
config = theano.configparser.change_flags( | ||
warn_float64='raise') | ||
with config: | ||
yield | ||
else: | ||
yield | ||
|
||
|
||
@pytest.fixture('function', autouse=False) | ||
def seeded_test(): | ||
# TODO: use this instead of SeededTest | ||
np.random.seed(42) | ||
pm.set_tt_rng(42) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't
if vars:
one less indent here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it should not, I use if there for the case of empty vars. or else I get exception from flatten