Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
55dfe72
breaking up pull request
uribgp Dec 16, 2020
df914eb
breaking up pull request
uribgp Dec 16, 2020
7c2c68b
Update docs/source/concepts.rst
vfdev-5 Dec 16, 2020
6e9d82c
breaking up commits
uribgp Dec 16, 2020
3c43743
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 16, 2020
f4a5a2e
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 16, 2020
63c36fd
Apply suggestions from code review
vfdev-5 Dec 16, 2020
3e10083
ran black, copied corrections from github, added 'f' to one line
uribgp Dec 17, 2020
a9ffbbf
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 17, 2020
86d1cc7
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 17, 2020
b0e8288
metrics output f-strings
uribgp Dec 17, 2020
4a6c7d7
breaking up large pull. needs to do more in trains_logger.py
uribgp Dec 17, 2020
3287512
metadata items f-string
uribgp Dec 17, 2020
06fb7ca
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 17, 2020
f33852e
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 18, 2020
7e584af
Merge branch 'master' into format_small
uribgp Dec 18, 2020
963467f
retrigger checks
uribgp Dec 18, 2020
5beb8ae
Merge branch 'format_small' of https://github.com/uribgp/ignite
uribgp Dec 18, 2020
b810b2d
retrigger checks
uribgp Dec 18, 2020
0e010bd
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 18, 2020
8f8bdb8
retrigger checks
uribgp Dec 18, 2020
a5227c7
RemovableEventHandle
uribgp Dec 18, 2020
003956e
f-strings for rows
uribgp Dec 18, 2020
bbbaa10
Revert "f-strings for rows"
uribgp Dec 18, 2020
66f8576
combining strings
uribgp Dec 18, 2020
8a04530
Merge branch 'master' into format_small
vfdev-5 Dec 18, 2020
6da60da
schedulers
uribgp Dec 19, 2020
b83444f
Merge branch 'master' of https://github.com/uribgp/ignite into format…
uribgp Dec 19, 2020
409f53c
Merge branch 'master' into format_small
vfdev-5 Dec 19, 2020
c22856c
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 19, 2020
33c22a9
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 19, 2020
f8cbe58
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 19, 2020
d295f87
breaking up large pull
uribgp Dec 19, 2020
d17b1c8
breaking up large pull
uribgp Dec 20, 2020
9c3b7b8
autopep8 fix
uribgp Dec 20, 2020
5c1e862
Merge branch 'master' into format_small
vfdev-5 Dec 20, 2020
d9d8bc1
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 20, 2020
317a1b4
annotation
uribgp Dec 20, 2020
83415de
annotation
uribgp Dec 20, 2020
31d0946
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 20, 2020
578ac33
autopep8 fix
uribgp Dec 20, 2020
e215dcd
type ignore
uribgp Dec 21, 2020
791e751
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 21, 2020
fd14a59
autopep8 fix
uribgp Dec 21, 2020
a9bcb52
autopep8 work around
uribgp Dec 21, 2020
c150a6f
Merge branch 'format_small' of https://github.com/uribgp/ignite into …
uribgp Dec 21, 2020
0dd131a
autopep8 fix
uribgp Dec 21, 2020
f33f8fb
bypass autopep
uribgp Dec 22, 2020
99890fd
Merge branch 'master' of https://github.com/pytorch/ignite into forma…
uribgp Dec 22, 2020
b5539e0
Update ignite/metrics/epoch_metric.py
vfdev-5 Dec 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def log_training(engine):
e = engine.state.epoch
n = engine.state.max_epochs
i = engine.state.iteration
print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))
print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")

trainer.run(data_loader, max_epochs=5)

Expand Down Expand Up @@ -218,13 +218,11 @@ class TBPTT_Events(EventEnum):
# engine.state contains an attribute time_iteration, which can be accessed using engine.state.time_iteration
"""
if not (event_to_attr is None or isinstance(event_to_attr, dict)):
raise ValueError("Expected event_to_attr to be dictionary. Got {}.".format(type(event_to_attr)))
raise ValueError(f"Expected event_to_attr to be dictionary. Got {type(event_to_attr)}.")

for index, e in enumerate(event_names):
if not isinstance(e, (str, EventEnum)):
raise TypeError(
"Value at {} of event_names should be a str or EventEnum, but given {}".format(index, e)
)
raise TypeError(f"Value at {index} of event_names should be a str or EventEnum, but given {e}")
self._allowed_events.append(e)
if event_to_attr and e in event_to_attr:
State.event_to_attr[e] = event_to_attr[e]
Expand Down Expand Up @@ -271,7 +269,7 @@ def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kw
engine = Engine(process_function)

def print_epoch(engine):
print("Epoch: {}".format(engine.state.epoch))
print(f"Epoch: {engine.state.epoch}")

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

Expand Down Expand Up @@ -301,7 +299,7 @@ def execute_something():

if event_name not in self._allowed_events:
self.logger.error("attempt to add event handler to an invalid event %s.", event_name)
raise ValueError("Event {} is not a valid event for this Engine.".format(event_name))
raise ValueError(f"Event {event_name} is not a valid event for this Engine.")

event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else ()
try:
Expand Down Expand Up @@ -359,15 +357,15 @@ def remove_event_handler(self, handler: Callable, event_name: Any) -> None:

"""
if event_name not in self._event_handlers:
raise ValueError("Input event name '{}' does not exist".format(event_name))
raise ValueError(f"Input event name '{event_name}' does not exist")

new_event_handlers = [
(h, args, kwargs)
for h, args, kwargs in self._event_handlers[event_name]
if not self._compare_handlers(handler, h)
]
if len(new_event_handlers) == len(self._event_handlers[event_name]):
raise ValueError("Input handler '{}' is not found among registered event handlers".format(handler))
raise ValueError(f"Input handler '{handler}' is not found among registered event handlers")
self._event_handlers[event_name] = new_event_handlers

def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
Expand All @@ -387,7 +385,7 @@ def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:

@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
print("Epoch: {}".format(engine.state.epoch))
print(f"Epoch: {engine.state.epoch}")

@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
Expand Down Expand Up @@ -533,9 +531,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
for k in self._state_dict_user_keys:
if k not in state_dict:
raise ValueError(
"Required user state attribute '{}' is absent in provided state_dict '{}'".format(
k, state_dict.keys()
)
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
self.state.max_epochs = state_dict["max_epochs"]
self.state.epoch_length = state_dict["epoch_length"]
Expand All @@ -552,7 +548,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
if self.state.epoch_length is None:
raise ValueError(
"If epoch is provided in the state dict, epoch_length should not be None. "
"Input state_dict: {}".format(state_dict)
f"Input state_dict: {state_dict}"
)
self.state.iteration = self.state.epoch_length * self.state.epoch

Expand Down Expand Up @@ -670,18 +666,14 @@ def switch_batch(engine):
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be larger than the start epoch "
"defined in the state: {} vs {}. Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning.".format(
max_epochs, self.state.epoch
)
f"defined in the state: {max_epochs} vs {self.state.epoch}. Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, given {} vs {}".format(
epoch_length, self.state.epoch_length
)
f"Argument epoch_length should be same as in the state, given {epoch_length} vs {self.state.epoch_length}"
)

if self.state.max_epochs is None or self._is_done(self.state):
Expand All @@ -708,12 +700,10 @@ def switch_batch(engine):
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")
else:
self.logger.info(
"Engine run resuming from iteration {}, epoch {} until {} epochs".format(
self.state.iteration, self.state.epoch, self.state.max_epochs
)
f"Engine run resuming from iteration {self.state.iteration}, epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)

self.state.dataloader = data
Expand Down Expand Up @@ -843,9 +833,7 @@ def _run_once_on_dataset(self) -> float:
warnings.warn(
"Data iterator can not provide data anymore but required total number of "
"iterations to run is not reached. "
"Current iteration: {} vs Total iterations to run : {}".format(
self.state.iteration, total_iters,
)
f"Current iteration: {self.state.iteration} vs Total iterations to run : {total_iters}"
)
break

Expand Down
8 changes: 4 additions & 4 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def __init__(self) -> None:

def _append(self, event: Union[Events, CallableEventWithFilter]) -> None:
if not isinstance(event, (Events, CallableEventWithFilter)):
raise TypeError("Argument event should be Events or CallableEventWithFilter, got: {}".format(type(event)))
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
self._events.append(event)

def __getitem__(self, item: int) -> Union[Events, CallableEventWithFilter]:
Expand Down Expand Up @@ -392,15 +392,15 @@ def _update_attrs(self) -> None:

def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int:
if event_name not in State.event_to_attr:
raise RuntimeError("Unknown event name '{}'".format(event_name))
raise RuntimeError(f"Unknown event name '{event_name}'")
return getattr(self, State.event_to_attr[event_name])

def __repr__(self) -> str:
s = "State:\n"
for attr, value in self.__dict__.items():
if not isinstance(value, (numbers.Number, str)):
value = type(value)
s += "\t{}: {}\n".format(attr, value)
s += f"\t{attr}: {value}\n"
return s


Expand All @@ -424,7 +424,7 @@ class RemovableEventHandle:
engine = Engine()

def print_epoch(engine):
print("Epoch: {}".format(engine.state.epoch))
print(f"Epoch: {engine.state.epoch}")

with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
# print_epoch handler registered for a single run
Expand Down
6 changes: 3 additions & 3 deletions ignite/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: An
exception_msg = str(exc)
passed_params = list(args) + list(kwargs)
raise ValueError(
"Error adding {} '{}': "
"takes parameters {} but will be called with {}"
"({}).".format(fn, fn_description, fn_params, passed_params, exception_msg)
f"Error adding {fn} '{fn_description}': "
f"takes parameters {fn_params} but will be called with {passed_params}"
f"({exception_msg})."
)
32 changes: 14 additions & 18 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(

if to_save is not None: # for compatibility with ModelCheckpoint
if not isinstance(to_save, collections.Mapping):
raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))
raise TypeError(f"Argument `to_save` should be a dictionary, but given {type(to_save)}")

if len(to_save) < 1:
raise ValueError("No objects to checkpoint.")
Expand All @@ -261,11 +261,11 @@ def __init__(
if include_self:
if not isinstance(to_save, collections.MutableMapping):
raise TypeError(
"If `include_self` is True, then `to_save` must be mutable, but given {}.".format(type(to_save))
f"If `include_self` is True, then `to_save` must be mutable, but given {type(to_save)}."
)

if "checkpointer" in to_save:
raise ValueError("Cannot have key 'checkpointer' if `include_self` is True: {}".format(to_save))
raise ValueError(f"Cannot have key 'checkpointer' if `include_self` is True: {to_save}")

if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")
Expand All @@ -274,9 +274,7 @@ def __init__(
raise ValueError("If `score_name` is provided, then `score_function` " "should be also provided.")

if global_step_transform is not None and not callable(global_step_transform):
raise TypeError(
"global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
)
raise TypeError(f"global_step_transform should be a function, got {type(global_step_transform)} instead.")

self.to_save = to_save
self.filename_prefix = filename_prefix
Expand Down Expand Up @@ -318,9 +316,7 @@ def __call__(self, engine: Engine) -> None:

if self._check_lt_n_saved() or self._saved[0].priority < priority:

priority_str = (
"{}".format(priority) if isinstance(priority, numbers.Integral) else "{:.4f}".format(priority)
)
priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"

checkpoint = self._setup_checkpoint()

Expand Down Expand Up @@ -351,7 +347,7 @@ def __call__(self, engine: Engine) -> None:
filename = filename_pattern.format(**filename_dict)

metadata = {
"basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name),
"basename": f"{self.filename_prefix}{'_' * int(len(self.filename_prefix) > 0)}{name}",
"score_name": self.score_name,
"priority": priority,
}
Expand Down Expand Up @@ -443,7 +439,7 @@ def setup_filename_pattern(
def _check_objects(objs: Mapping, attr: str) -> None:
for k, obj in objs.items():
if not hasattr(obj, attr):
raise TypeError("Object {} should have `{}` method".format(type(obj), attr))
raise TypeError(f"Object {type(obj)} should have `{attr}` method")

@staticmethod
def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
Expand Down Expand Up @@ -488,7 +484,7 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
"""
Checkpoint._check_objects(to_load, "load_state_dict")
if not isinstance(checkpoint, collections.Mapping):
raise TypeError("Argument checkpoint should be a dictionary, but given {}".format(type(checkpoint)))
raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")

if len(kwargs) > 1 or any(k for k in kwargs.keys() if k not in ["strict"]):
warnings.warn("kwargs contains keys other than strict and these will be ignored")
Expand All @@ -506,7 +502,7 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
# multiple objects to load
for k, obj in to_load.items():
if k not in checkpoint:
raise ValueError("Object labeled by '{}' from `to_load` is not found in the checkpoint".format(k))
raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
if isinstance(obj, torch.nn.Module):
Expand Down Expand Up @@ -552,16 +548,16 @@ def _check_and_setup(dirname: str, create_dir: bool, require_empty: bool) -> Non
os.makedirs(dirname)
# Ensure that dirname exists
if not os.path.exists(dirname):
raise ValueError("Directory path '{}' is not found".format(dirname))
raise ValueError(f"Directory path '{dirname}' is not found")

if require_empty:
matched = [fname for fname in os.listdir(dirname) if fname.endswith(".pt")]
if len(matched) > 0:
raise ValueError(
"Files {} with extension '.pt' are already present "
"in the directory {}. If you want to use this "
f"Files {matched} with extension '.pt' are already present "
f"in the directory {dirname}. If you want to use this "
"directory anyway, pass `require_empty=False`."
"".format(matched, dirname)
""
)

def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
Expand Down Expand Up @@ -709,7 +705,7 @@ def last_checkpoint(self) -> Union[str, None]:

if not isinstance(self.save_handler, DiskSaver):
raise RuntimeError(
"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
f"Unable to save checkpoint, save_handler should be DiskSaver, got {type(self.save_handler)}."
)

return os.path.join(self.save_handler.dirname, self._saved[-1].filename)
Expand Down
4 changes: 1 addition & 3 deletions ignite/handlers/terminate_on_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,5 @@ def raise_error(x: Union[float, torch.Tensor]) -> None:
try:
apply_to_type(output, (numbers.Number, torch.Tensor), raise_error)
except RuntimeError:
self.logger.warning(
"{}: Output '{}' contains NaN or Inf. Stop training".format(self.__class__.__name__, output)
)
self.logger.warning(f"{self.__class__.__name__}: Output '{output}' contains NaN or Inf. Stop training")
engine.terminate()
8 changes: 4 additions & 4 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
device: Union[str, torch.device] = torch.device("cpu"),
):
if not callable(op):
raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
raise TypeError(f"Argument op should be a callable, but given {type(op)}")

self._op = op

Expand All @@ -59,7 +59,7 @@ def reset(self) -> None:

def _check_output_type(self, output: Union[float, torch.Tensor]) -> None:
if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)):
raise TypeError("Output should be a number or torch.Tensor, but given {}".format(type(output)))
raise TypeError(f"Output should be a number or torch.Tensor, but given {type(output)}")

@reinit__is_reduced
def update(self, output: Union[float, torch.Tensor]) -> None:
Expand Down Expand Up @@ -135,7 +135,7 @@ def _mean_op(a: Union[float, torch.Tensor], x: Union[float, torch.Tensor]) -> Un
def compute(self) -> Union[float, torch.Tensor]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
f"{self.__class__.__name__} must have at least one example before it can be computed."
)

return self.accumulator / self.num_examples
Expand Down Expand Up @@ -186,7 +186,7 @@ def _geom_op(a: torch.Tensor, x: Union[float, torch.Tensor]) -> torch.Tensor:
def compute(self) -> Union[float, torch.Tensor]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
f"{self.__class__.__name__} must have at least one example before it can be computed."
)

tensor = torch.exp(self.accumulator / self.num_examples)
Expand Down
12 changes: 5 additions & 7 deletions ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
raise ValueError(
"y must have shape of (batch_size, ...) and y_pred must have "
"shape of (batch_size, num_categories, ...) or (batch_size, ...), "
"but given {} vs {}.".format(y.shape, y_pred.shape)
f"but given {y.shape} vs {y_pred.shape}."
)

y_shape = y.shape
Expand Down Expand Up @@ -78,19 +78,17 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None:
num_classes = 1
else:
raise RuntimeError(
"Invalid shapes of y (shape={}) and y_pred (shape={}), check documentation."
" for expected shapes of y and y_pred.".format(y.shape, y_pred.shape)
f"Invalid shapes of y (shape={y.shape}) and y_pred (shape={y_pred.shape}), check documentation."
" for expected shapes of y and y_pred."
)
if self._type is None:
self._type = update_type
self._num_classes = num_classes
else:
if self._type != update_type:
raise RuntimeError("Input data type has changed from {} to {}.".format(self._type, update_type))
raise RuntimeError(f"Input data type has changed from {self._type} to {update_type}.")
if self._num_classes != num_classes:
raise ValueError(
"Input data number of classes has changed from {} to {}".format(self._num_classes, num_classes)
)
raise ValueError(f"Input data number of classes has changed from {self._num_classes} to {num_classes}")


class Accuracy(_BaseClassification):
Expand Down
Loading