Skip to content

Commit

Permalink
fix pylint errors, unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed May 31, 2018
1 parent 10ea1aa commit 9086d1a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Expand Up @@ -323,7 +323,7 @@ ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
ignored-classes=optparse.Values,thread._local,_thread._local, Context

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
Expand Down
4 changes: 2 additions & 2 deletions pymc4/model.py
@@ -1,4 +1,3 @@
import tensorflow as tf
import threading

__all__ = ["Model"]
Expand All @@ -19,7 +18,7 @@ def __exit__(self, typ, value, traceback):


class Model(Context):
def __new__(cls, *args, **kwargs):
def __new__(cls, **kwargs):
instance = super(Model, cls).__new__(cls)
if kwargs.get('model') is not None:
instance.parent = kwargs.get('model')
Expand All @@ -32,6 +31,7 @@ def __new__(cls, *args, **kwargs):
def __init__(self, name="", model=None, ):
self.name = name
self.named_vars = {}
self.parent = model

@property
def model(self):
Expand Down
20 changes: 10 additions & 10 deletions pymc4/random_variable.py
Expand Up @@ -7,19 +7,19 @@
class RandomVariable(ed.RandomVariable):

def __init__(
self,
distribution,
sample_shape=(),
value=None,
name="RV"
):
self,
name,
distribution,
sample_shape=(),
value=None,
):
self.model = Model.get_context()
self.name = name

super(RandomVariable, self).__init__(
distribution,
sample_shape,
value,
)
distribution,
sample_shape,
value,
)

self.model.add_random_variable(self)
17 changes: 8 additions & 9 deletions pymc4/sample.py
@@ -1,9 +1,8 @@
import tensorflow as tf
from pymc4 import Model
import numpy as np
import xarray as xr
from pymc4 import Model
import tqdm
from pymc4 import Model

__all__ = ["sample"]

Expand All @@ -14,17 +13,17 @@ def sample(draws=1000, tune=500, as_xarray=True):
model = Model.get_context()
array = []
with tf.Session() as sess:
for i in tqdm.trange(draws+tune):
for _ in tqdm.trange(draws+tune):

# Sampling methods are applied here.
# Directly using tensorflow's default sampling method for now
array.append([i.eval() for i in model.named_vars.values()])
if as_xarray:
return (xr.DataArray(
data=array[tune:],
dims=("Val", "RV"),
coords={"RV": list(model.named_vars.keys())}
)
)
data=array[tune:],
dims=("Val", "RV"),
coords={"RV": list(model.named_vars.keys())}
)
)
else:
return np.array(array[Tune])
return np.array(array[tune:])

0 comments on commit 9086d1a

Please sign in to comment.