Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix train_model.py bug, fix Analysis module import #39

Merged
merged 3 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paltas/Analysis/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse, os, sys, glob, math
from importlib import import_module
import tensorflow as tf
from . import dataset_generation, loss_functions, conv_models
from paltas.Analysis import dataset_generation, loss_functions, conv_models
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras import optimizers
import pandas as pd
Expand Down Expand Up @@ -162,7 +162,7 @@ def main():
kwargs_detector=None,log_learning_params=log_learning_params)

# If some of the parameters need to be extracted as inputs, do that.
if params_as_inputs is not None:
if params_as_inputs:
tf_dataset_t = dataset_generation.generate_params_as_input_dataset(
tf_dataset_t,params_as_inputs,all_params+log_learning_params)
tf_dataset_v = dataset_generation.generate_params_as_input_dataset(
Expand All @@ -187,10 +187,10 @@ def main():
loss_function))

# Load the model
if model_type == 'xresnet34' and params_as_inputs is None:
if model_type == 'xresnet34' and not params_as_inputs:
model = conv_models.build_xresnet34(img_size,num_outputs,
train_only_head=train_only_head)
elif model_type == 'xresnet34' and params_as_inputs is not None:
elif model_type == 'xresnet34' and params_as_inputs:
model = conv_models.build_xresnet34_fc_inputs(img_size,num_outputs,
len(params_as_inputs),train_only_head=False)
else:
Expand Down
8 changes: 7 additions & 1 deletion paltas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
__version__ = '0.1.1'

# Analysis is not imported by default because it required tensorflow.

try:
import tensorflow as tf
del tf
except ImportError:
print("paltas.Analysis disabled since tensorflow is missing")
else:
from . import Analysis
from . import Configs
from . import Sampling
from . import Sources
Expand Down
2 changes: 1 addition & 1 deletion test/analysis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def test_plot_coverage(self):

y_pred = np.random.normal(size=(batch_size,num_params))
y_true = np.random.normal(size=(batch_size,num_params))
std_pred = np.random.normal(size=(batch_size,num_params))
std_pred = np.abs(np.random.normal(size=(batch_size,num_params)))

Analysis.posterior_functions.plot_coverage(y_pred,y_true,std_pred,
parameter_names,block=False)
Expand Down