Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from ignite.engine.engine import Engine, State, Events
from ignite.utils import convert_tensor

__all__ = [
'create_supervised_trainer',
'create_supervised_evaluator',
'Engine',
'Events'
]


def _prepare_batch(batch, device=None, non_blocking=False):
"""Prepare batch for training: pass to a device with options.
Expand Down
6 changes: 6 additions & 0 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@

from ignite._utils import _to_hours_mins_secs

__all__ = [
'Engine',
'Events',
'State'
]


class EventWithFilter:

Expand Down
5 changes: 5 additions & 0 deletions ignite/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
__all__ = [
'NotComputableError'
]


class NotComputableError(RuntimeError):
"""
Exception class to raise if Metric cannot be computed.
Expand Down
10 changes: 10 additions & 0 deletions ignite/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.terminate_on_nan import TerminateOnNan

__all__ = [
'ModelCheckpoint',
'Checkpoint',
'DiskSaver',
'Timer',
'EarlyStopping',
'TerminateOnNan',
'global_step_from_engine'
]


def global_step_from_engine(engine):
"""Helper method to setup `global_step_transform` function using another engine.
Expand Down
5 changes: 5 additions & 0 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

from ignite.engine import Events

__all__ = [
'Checkpoint',
'ModelCheckpoint'
]


class Checkpoint:
"""Checkpoint handler can be used to periodically save and load objects which have attribute
Expand Down
4 changes: 4 additions & 0 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from ignite.engine import Engine

__all__ = [
'EarlyStopping'
]


class EarlyStopping:
"""EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
Expand Down
4 changes: 4 additions & 0 deletions ignite/handlers/terminate_on_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from ignite.utils import apply_to_type

__all__ = [
'TerminateOnNan'
]


class TerminateOnNan:
"""TerminateOnNan handler can be used to stop the training if the `process_function`'s output
Expand Down
4 changes: 4 additions & 0 deletions ignite/handlers/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from ignite.engine import Events

__all__ = [
'Timer'
]


class Timer:
""" Timer object can be used to measure (average) time between events.
Expand Down
24 changes: 24 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,27 @@
from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU, DiceCoefficient
from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage
from ignite.metrics.fbeta import Fbeta

__all__ = [
'Metric',
'Accuracy',
'Loss',
'MetricsLambda',
'MeanAbsoluteError',
'MeanPairwiseDistance',
'MeanSquaredError',
'ConfusionMatrix',
'TopKCategoricalAccuracy',
'Average',
'DiceCoefficient',
'EpochMetric',
'Fbeta',
'GeometricAverage',
'IoU',
'mIoU',
'Precision',
'Recall',
'RootMeanSquaredError',
'RunningAverage',
'VariableAccumulation'
]
6 changes: 6 additions & 0 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

import torch

__all__ = [
'VariableAccumulation',
'GeometricAverage',
'Average'
]


class VariableAccumulation(Metric):
"""Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable.
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

import torch

__all__ = [
'Accuracy'
]


class _BaseClassification(Metric):

Expand Down
10 changes: 10 additions & 0 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'ConfusionMatrix',
'mIoU',
'IoU',
'DiceCoefficient',
'cmAccuracy',
'cmPrecision',
'cmRecall'
]


class ConfusionMatrix(Metric):
"""Calculates confusion matrix for multi-class data.
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from ignite.metrics.metric import Metric

__all__ = [
'EpochMetric'
]


class EpochMetric(Metric):
"""Class for metrics that should be computed on the entire output history of a model.
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/fbeta.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ignite.metrics import Precision, Recall

__all__ = [
'Fbeta'
]


def Fbeta(beta, average=True, precision=None, recall=None, output_transform=None, device=None):
"""Calculates F-beta score
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from ignite.metrics import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'Loss'
]


class Loss(Metric):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from ignite.metrics.metric import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'MeanAbsoluteError'
]


class MeanAbsoluteError(Metric):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/mean_pairwise_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from ignite.metrics.metric import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'MeanPairwiseDistance'
]


class MeanPairwiseDistance(Metric):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from ignite.metrics.metric import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'MeanSquaredError'
]


class MeanSquaredError(Metric):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

from ignite.engine import Events

__all__ = [
'Metric'
]


class Metric(metaclass=ABCMeta):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/metrics_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from ignite.metrics.metric import Metric, reinit__is_reduced
from ignite.engine import Events

__all__ = [
'MetricsLambda'
]


class MetricsLambda(Metric):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from ignite.utils import to_onehot
from ignite.metrics.metric import reinit__is_reduced

__all__ = [
'Precision'
]


class _BasePrecisionRecall(_BaseClassification):

Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from ignite.utils import to_onehot
from ignite.metrics.metric import reinit__is_reduced

__all__ = [
'Recall'
]


class Recall(_BasePrecisionRecall):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from ignite.metrics.mean_squared_error import MeanSquaredError

__all__ = [
'RootMeanSquaredError'
]


class RootMeanSquaredError(MeanSquaredError):
"""
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = [
'RunningAverage'
]


class RunningAverage(Metric):
"""Compute running average of a metric or the output of process function.
Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/top_k_categorical_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced

__all__ = [
'TopKCategoricalAccuracy'
]


class TopKCategoricalAccuracy(Metric):
"""
Expand Down
8 changes: 8 additions & 0 deletions ignite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

import torch

__all__ = [
'convert_tensor',
'apply_to_tensor',
'apply_to_type',
'to_onehot',
'setup_logger'
]


def convert_tensor(input_, device=None, non_blocking=False):
"""Move tensors to relevant device."""
Expand Down