Skip to content

Commit

Permalink
Merge pull request #12 from zurutech/doc-fix
Browse files Browse the repository at this point in the history
Improve documentation and increase code quality
  • Loading branch information
galeone committed Jul 19, 2019
2 parents d9294e5 + 77d73c6 commit 47eef53
Show file tree
Hide file tree
Showing 13 changed files with 773 additions and 372 deletions.
29 changes: 17 additions & 12 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -335,24 +335,29 @@ function-naming-style=snake_case
#function-rgx=

# Good variable names which should always be accepted, separated by a comma
good-names=i,
j,
k,
ex,
Run,
_,
G,
good-names=_,
D,
Dx,
E,
Gz,
ex,
Ex,
Dx,
f,
fn
fp,
z,
G,
g,
Gz,
h,
i,
j,
k,
o,
op,
Run,
s,
x,
y,
s,
op
z,

# Include a hint for the correct naming format with invalid-name
include-naming-hint=yes
Expand Down
3 changes: 2 additions & 1 deletion ashpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from . import metrics
from . import models
from . import trainers
from . import types
from .modes import LogEvalMode

__version__ = "1.0.1"
__version__ = "1.0.2"
__url__ = "https://github.com/zurutech/ashpy"
__author__ = "Machine Learning Team @ Zuru Tech"
__email__ = "ml@zuru.tech"
77 changes: 53 additions & 24 deletions ashpy/contexts/base_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,39 @@
handle information transfer.
"""

from typing import List

import tensorflow as tf

from ashpy.metrics import Metric
from ashpy.modes import LogEvalMode


class BaseContext:
r"""
:py:class:`ashpy.contexts.base_context.BaseContext` provides an interface for all contexts to inherit from.
"""
""":py:class:`ashpy.contexts.base_context.BaseContext` provides an interface for all contexts to inherit from."""

def __init__(
self,
metrics=None,
dataset=None,
log_eval_mode=LogEvalMode.TEST,
metrics: List[Metric] = None,
dataset: tf.data.Dataset = None,
log_eval_mode: LogEvalMode = LogEvalMode.TEST,
global_step=tf.Variable(0, name="global_step", trainable=False, dtype=tf.int64),
ckpt=None,
):
r"""
:py:class:`ashpy.contexts.base_context.BaseContext`
ckpt: tf.train.Checkpoint = None,
) -> None:
"""
Initialize :py:class:`ashpy.contexts.base_context.BaseContext`.
Args:
metrics ([:py:class:`ashpy.metrics.metric.Metric`]): list of :py:class:`ashpy.metrics.metric.Metric` objects.
metrics (:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`]): List of
:py:class:`ashpy.metrics.metric.Metric` objects.
dataset (:py:class:`tf.data.Dataset`): The dataset to use, that
contains everything needed to use the model in this context.
log_eval_mode: models' mode to use when evaluating and logging.
global_step (:py:class:`tf.Variable`): keeps track of the training steps.
ckpt (:py:class:`tf.train.Checkpoint`): checkpoint to use to keep track of models status.
log_eval_mode (:py:class:`ashpy.modes.LogEvalMode`): Models' mode to use when
evaluating and logging.
global_step (:py:class:`tf.Variable`): Keeps track of the training steps.
ckpt (:py:class:`tf.train.Checkpoint`): Checkpoint to use to keep track of
models status.
"""
self._distribute_strategy = tf.distribute.get_strategy()
self._metrics = metrics if metrics else []
Expand All @@ -65,32 +70,56 @@ def _validate_metrics(self):
"Metric " + str(metric) + " is not a ashpy.metrics.Metric"
)

def measure_metrics(self):
def measure_metrics(self) -> None:
"""Measure the metrics."""
for metric in self._metrics:
metric.update_state(self)

def model_selection(self):
def model_selection(self) -> None:
"""Use the metrics to perform model selection."""
for metric in self._metrics:
metric.model_selection(self._ckpt)

@property
def log_eval_mode(self):
"""Model(s) mode."""
def log_eval_mode(self) -> LogEvalMode:
"""
Model(s) mode.
Returns:
:py:class:`ashpy.modes.LogEvalMode`.
"""
return self._log_eval_mode

@property
def dataset(self):
"""Return dataset."""
def dataset(self) -> tf.data.Dataset:
"""
Return dataset.
Returns:
:py:class:`tf.data.Dataset`.
"""
return self._dataset

@property
def metrics(self):
"""Return the metrics."""
def metrics(self) -> List[Metric]:
"""
Return the metrics.
Returns:
:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`].
"""
return self._metrics

@property
def global_step(self):
"""Return the global_step."""
def global_step(self) -> tf.Variable:
"""
Return the global_step.
Returns:
:py:class:`tf.Variable`.
"""
return self._global_step
50 changes: 33 additions & 17 deletions ashpy/contexts/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,71 @@
# limitations under the License.

"""Classifier Context."""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

import tensorflow as tf

from ashpy.contexts.base_context import BaseContext
from ashpy.metrics import Metric
from ashpy.modes import LogEvalMode

if TYPE_CHECKING:
from ashpy.losses.executor import Executor


class ClassifierContext(BaseContext):
r""":py:class:`ashpy.contexts.classifier.ClassifierContext` provides the standard functions to test a classifier."""

def __init__(
self,
classifier_model=None,
loss=None,
dataset=None,
metrics=None,
log_eval_mode=LogEvalMode.TEST,
global_step=tf.Variable(0, name="global_step", trainable=False, dtype=tf.int64),
ckpt=None,
):
classifier_model: tf.keras.Model = None,
loss: Executor = None,
dataset: tf.data.Dataset = None,
metrics: List[Metric] = None,
log_eval_mode: LogEvalMode = LogEvalMode.TEST,
global_step: tf.Variable = tf.Variable(
0, name="global_step", trainable=False, dtype=tf.int64
),
ckpt: tf.train.Checkpoint = None,
) -> None:
r"""
Instantiate the :py:class:`ashpy.contexts.classifier.ClassifierContext` context.
Args:
classifier_model (:py:class:`tf.keras.Model`): A :py:class:`tf.keras.Model`
model.
loss (callable): loss function, format f(y_true, y_pred)
loss (:py:class:`ashpy.losses.Executor`): Loss function, format f(y_true, y_pred).
dataset (:py:class:`tf.data.Dataset`): The test dataset.
metrics: List of python objects (of Metric class) with which to measure
training and validation data performances.
log_eval_mode: models' mode to use when evaluating and logging.
global_step: tf.Variable that keeps track of the training steps.
ckpt (:py:class:`tf.train.Checkpoint`): checkpoint to use to keep track of models status.
metrics (:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`]): List of
:py:class:`ashpy.metrics.metric.Metric` with which to measure training
and validation data performances.
log_eval_mode (:py:obj:`ashpy.modes.LogEvalMode`): Models' mode to use when
evaluating and logging.
global_step (:py:obj:`tf.Variable`): tf.Variable that keeps track of the
training steps.
ckpt (:py:class:`tf.train.Checkpoint`): checkpoint to use to keep track of
models status.
"""
super().__init__(metrics, dataset, log_eval_mode, global_step, ckpt)
self._classifier_model = classifier_model
self._loss = loss

@property
def loss(self):
def loss(self) -> Optional[Executor]:
"""Return the loss value."""
return self._loss

@property
def classifier_model(self):
def classifier_model(self) -> tf.keras.Model:
r"""
Return the Model Object.
Returns:
:py:class:`tf.keras.Model`
:py:class:`tf.keras.Model`.
"""
return self._classifier_model

0 comments on commit 47eef53

Please sign in to comment.