Skip to content

Commit

Permalink
Merge pull request #17 from preinaj/save-model
Browse files Browse the repository at this point in the history
Add support for save model offline
  • Loading branch information
pedrolarben committed Dec 18, 2021
2 parents fbc663f + 192d919 commit e788e15
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions ADLStream/adlstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def set_weights(self, w):
def get_weights(self):
return self.weights

def get_shape(self):
X = self.x_train
return np.asarray(X).shape

def add(self, x, y=None):
with self.data_lock:
if x is not None:
Expand Down Expand Up @@ -253,6 +257,9 @@ def __init__(
self.train_gpu_index = train_gpu_index
self.predict_gpu_index = predict_gpu_index
self.log_file = log_file
self.x_shape = None
self.output_size = None
self.weights = None

self.manager = ADLStreamManager()

Expand Down Expand Up @@ -437,4 +444,22 @@ def run(self):
process_predict.join()
process_evaluator.join()

self.x_shape = context.get_shape()
self.output_size = context.get_output_size()
self.weights = context.get_weights()

self.manager.shutdown()

def get_model(self):
from ADLStream.models import create_model

model = create_model(
self.model_architecture,
self.x_shape,
self.output_size,
self.model_loss,
self.model_optimizer,
**self.model_parameters
)
model.set_weights(self.weights)
return model

0 comments on commit e788e15

Please sign in to comment.