Merged
2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
@@ -537,7 +537,7 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         if not isinstance(checkpoint, collections.Mapping):
             raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")

-        if len(kwargs) > 1 or any(k for k in kwargs.keys() if k not in ["strict"]):
+        if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
             warnings.warn("kwargs contains keys other than strict and these will be ignored")

         is_state_dict_strict = kwargs.get("strict", True)
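Reviewer note, not part of the diff: iterating a dict already yields its keys, so dropping the explicit .keys() call is a pure simplification. A quick check:

kwargs = {"strict": False, "unexpected": 1}  # hypothetical example values
assert list(kwargs.keys()) == [k for k in kwargs] == ["strict", "unexpected"]
# a key other than "strict" still triggers the warning branch
assert any(k for k in kwargs if k not in ["strict"])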
2 changes: 1 addition & 1 deletion ignite/metrics/accumulation.py
@@ -58,7 +58,7 @@ def reset(self) -> None:
         self.num_examples = 0

     def _check_output_type(self, output: Union[float, torch.Tensor]) -> None:
-        if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)):
+        if not isinstance(output, (numbers.Number, torch.Tensor)):
             raise TypeError(f"Output should be a number or torch.Tensor, but given {type(output)}")

     @reinit__is_reduced
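Reviewer note, not part of the diff: isinstance accepts a tuple of types, so the single call is equivalent to the two chained checks. A quick check:

import numbers

import torch

for output in (1, 2.5, torch.tensor([1.0])):
    assert isinstance(output, (numbers.Number, torch.Tensor))
# matches the old two-call form on a rejected input
assert not (isinstance("text", numbers.Number) or isinstance("text", torch.Tensor))
assert not isinstance("text", (numbers.Number, torch.Tensor))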
2 changes: 1 addition & 1 deletion ignite/metrics/loss.py
@@ -38,7 +38,7 @@ def __init__(
         self,
         loss_fn: Callable,
         output_transform: Callable = lambda x: x,
-        batch_size: Callable = lambda x: len(x),
+        batch_size: Callable = len,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(Loss, self).__init__(output_transform, device=device)
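Reviewer note, not part of the diff: the builtin len is itself a one-argument callable, so it can replace the wrapping lambda as a default. A quick check:

batch_size = len
assert batch_size([1, 2, 3]) == (lambda x: len(x))([1, 2, 3]) == 3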
2 changes: 1 addition & 1 deletion ignite/utils.py
@@ -263,7 +263,7 @@ def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> Callable:
             warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
             return func(*args, **kwargs)

-        appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if len(reasons) else "")
+        appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if len(reasons) > 0 else "")

         for reason in reasons:
             appended_doc += "\n\t- " + reason
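Reviewer note, not part of the diff: len(reasons) and len(reasons) > 0 are equivalent in a boolean context; the new form only states the intent explicitly. A sketch of the docstring assembly this line feeds, with a hypothetical version string:

deprecated_in = "0.4.2"  # hypothetical
for reasons in ((), ("use new_fn instead",)):
    appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if len(reasons) > 0 else "")
    for reason in reasons:
        appended_doc += "\n\t- " + reason
    print(repr(appended_doc))
# '.. deprecated:: 0.4.2'
# '.. deprecated:: 0.4.2\n\n\t\n\t- use new_fn instead'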
1 change: 0 additions & 1 deletion tests/ignite/contrib/handlers/test_mlflow_logger.py
@@ -340,7 +340,6 @@ def test_mlflow_bad_metric_name_handling(dirname):

 @pytest.fixture
 def no_site_packages():
-    import sys

     mlflow_client_modules = {}
     for k in sys.modules:
2 changes: 0 additions & 2 deletions tests/ignite/contrib/handlers/test_visdom_logger.py
@@ -936,8 +936,6 @@ def update_fn(engine, batch):

 @pytest.fixture
 def no_site_packages():
-    import sys
-
     import visdom  # noqa: F401

     visdom_module = sys.modules["visdom"]
2 changes: 0 additions & 2 deletions tests/ignite/distributed/test_auto.py
@@ -217,8 +217,6 @@ def test_auto_methods_xla():


 def test_dist_proxy_sampler():
-    import torch
-    from torch.utils.data import WeightedRandomSampler

     weights = torch.ones(100)
     weights[:50] += 1
20 changes: 8 additions & 12 deletions tests/ignite/engine/test_deterministic.py
@@ -7,7 +7,8 @@
 import pytest
 import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader
+from torch.optim import SGD
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler

 import ignite.distributed as idist
 from ignite.engine import Events
@@ -92,8 +93,6 @@ def test_reproducible_batch_sampler_wrong_input():


 def test_reproducible_batch_sampler():
-    import torch
-    from torch.utils.data import DataLoader

     data = list(range(100))
     dataloader = DataLoader(data, batch_size=12, num_workers=0, shuffle=True, drop_last=True)
@@ -599,11 +598,12 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):

 def test_concepts_snippet_resume():

-    import torch
-    from torch.utils.data import DataLoader
+    # Commented imports required in the snippet
+    # import torch
+    # from torch.utils.data import DataLoader

-    from ignite.engine import DeterministicEngine
-    from ignite.utils import manual_seed
+    # from ignite.engine import DeterministicEngine
+    # from ignite.utils import manual_seed

     seen_batches = []
     manual_seed(seed=15)
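Reviewer note, a minimal sketch of the pattern this snippet tests, assuming a toy dataset; not the test body itself. manual_seed plus DeterministicEngine is what makes the batch order reproducible across a stop/resume boundary:

import torch
from torch.utils.data import DataLoader

from ignite.engine import DeterministicEngine
from ignite.utils import manual_seed

manual_seed(15)
data = torch.arange(100).reshape(25, 4)
loader = DataLoader(data, batch_size=5, shuffle=True)

def train_step(engine, batch):
    pass  # training logic would go here

trainer = DeterministicEngine(train_step)
trainer.run(loader, max_epochs=2)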
@@ -663,10 +663,7 @@ def _test_gradients_on_resume(
     dirname, device, with_dropout=True, with_dataaugs=True, data_size=24, batch_size=4, save_iter=None, save_epoch=None
 ):

-    debug = True
-
-    from torch.optim import SGD
-    from torch.utils.data import DataLoader
+    debug = False

     def random_train_data_loader(size):
         d = AugmentedData(torch.rand(size, 3, 32, 32), enabled=with_dataaugs)
@@ -820,7 +817,6 @@ def test_gradients_on_resume_on_cuda(dirname):

 def test_engine_with_dataloader_no_auto_batching():
     # tests https://github.com/pytorch/ignite/issues/941
-    from torch.utils.data import BatchSampler, DataLoader, RandomSampler

     data = torch.rand(64, 4, 10)
     data_loader = DataLoader(
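Reviewer note, a minimal sketch (with assumed shapes) of the no-auto-batching pattern from https://github.com/pytorch/ignite/issues/941 that the truncated call above sets up: batch_size=None disables automatic batching, and a BatchSampler passed as the plain sampler makes each fetched element a ready-made batch.

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

data = torch.rand(64, 4, 10)
loader = DataLoader(
    data,
    batch_size=None,  # each sampler element is already a full batch
    sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True),
)
for batch in loader:
    assert batch.shape == (8, 4, 10)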
22 changes: 8 additions & 14 deletions tests/ignite/engine/test_engine.py
@@ -702,8 +702,7 @@ def restart_iter():

 def test_faq_inf_iterator_with_epoch_length():
     # Code snippet from FAQ
-
-    import torch
+    # import torch

     torch.manual_seed(12)

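Reviewer note, a minimal sketch of the FAQ pattern under test, with assumed tensor shapes; not the test body itself. An infinite iterator can drive Engine.run as long as epoch_length bounds each epoch:

import torch

from ignite.engine import Engine

torch.manual_seed(12)

def infinite_iterator(batch_size):
    while True:
        yield torch.rand(batch_size, 3, 32, 32)

def train_step(trainer, batch):
    pass  # training logic would go here

trainer = Engine(train_step)
trainer.run(infinite_iterator(batch_size=4), epoch_length=5, max_epochs=2)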
@@ -727,8 +726,7 @@ def train_step(trainer, batch):

 def test_faq_inf_iterator_no_epoch_length():
     # Code snippet from FAQ
-
-    import torch
+    # import torch

     torch.manual_seed(12)

@@ -756,8 +754,7 @@ def stop_training():

 def test_faq_fin_iterator_unknw_size():
     # Code snippet from FAQ
-
-    import torch
+    # import torch

     torch.manual_seed(12)

@@ -782,9 +779,8 @@ def restart_iter():
     assert trainer.state.epoch == 5
     assert trainer.state.iteration == 5 * 11

-    # # # # #
-
-    import torch
+    # Code snippet from FAQ
+    # import torch

     torch.manual_seed(12)

@@ -808,8 +804,7 @@ def val_step(evaluator, batch):

 def test_faq_fin_iterator():
     # Code snippet from FAQ
-
-    import torch
+    # import torch

     torch.manual_seed(12)

@@ -836,9 +831,8 @@ def restart_iter():
     assert trainer.state.epoch == 5
     assert trainer.state.iteration == 5 * size

-    # # # # #
-
-    import torch
+    # Code snippet from FAQ
+    # import torch

     torch.manual_seed(12)
