Skip to content

Commit

Permalink
Require np ~1.18.3 to avoid deprecation warning, fix object array iss…
Browse files Browse the repository at this point in the history
…ues with konverted model
  • Loading branch information
sshane committed Jul 13, 2020
1 parent 318eaa8 commit a5c66e0
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 38 deletions.
5 changes: 3 additions & 2 deletions konverter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def build_konverted_model(self):
model_builder = {'imports': ['import numpy as np'],
'functions': [],
'load_weights': [],
'model': ['def predict(x):']}
'model': ['def predict(x):',
'x = np.array(x, dtype=np.float32)']} # convert input to float32

# add section to load model weights and biases
model_builder['load_weights'].append(f'wb = np.load(\'{self.output_file}_weights.npz\', allow_pickle=True)')
Expand Down Expand Up @@ -140,7 +141,7 @@ def save_model(self, model_builder):

wb = list(zip(*wb))
gbmse = list(zip(*gbmse))
kwargs = {'wb': np.array(wb, dtype=np.object)}
kwargs = {'wb': np.array(wb)}

if Layers.BatchNormalization.name in support.layer_names(self.layers):
kwargs['gbmse'] = np.array(gbmse, dtype=np.object)
Expand Down
2 changes: 1 addition & 1 deletion konverter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import konverter
from konverter.utils.general import success, info, warning, error, COLORS, color_logo, blue_grad

KONVERTER_VERSION = "v0.2" # fixme: unify this
KONVERTER_VERSION = "v0.2.1" # fixme: unify this
KONVERTER_LOGO_COLORED = color_logo(KONVERTER_VERSION)

class KonverterCLI:
Expand Down
2 changes: 1 addition & 1 deletion konverter/utils/konverter_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_layer_info(self, layer):
layer_class.info.weights = np.array(wb[0])
layer_class.info.biases = np.array(wb[1])
elif len(wb) == 3 and layer_class.name in self.recurrent_layers:
layer_class.info.weights = np.array(wb[:2], dtype=np.object) # input and recurrent weights
layer_class.info.weights = np.array(wb[:2]) # input and recurrent weights
layer_class.info.biases = np.array(wb[-1])
layer_class.info.returns_sequences = layer.return_sequences
layer_class.info.is_recurrent = True
Expand Down
59 changes: 27 additions & 32 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "keras-konverter"
version = "0.2"
version = "0.2.1"
description = "A tool to convert simple Keras models to pure Python + NumPy"
readme = "README.md"
repository = "https://github.com/ShaneSmiskol/Konverter"
Expand All @@ -13,7 +13,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.6" # 3.6 to 3.7 is okay, or 3.8 if you're using a beta version of tf
numpy = "^1.18.3"
numpy = "~1.18.3"

[tool.poetry.dev-dependencies]
tensorflow = "^2.1.0"
Expand Down

0 comments on commit a5c66e0

Please sign in to comment.