Skip to content

Moving the Influence test helper #1484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/influence/_core/test_arnoldi_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
_top_eigen,
_unflatten_params_factory,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
Expand All @@ -31,6 +28,9 @@
is_gpu,
UnpackDataset,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch import Tensor
from torch.utils.data import DataLoader

Expand Down
16 changes: 8 additions & 8 deletions tests/influence/_core/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
TracInCPFast,
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch.utils.data import DataLoader


Expand All @@ -32,13 +32,13 @@ class TestTracInDataLoader(BaseTest):
# `comprehension((reduction, constr, unpack_inputs) for
# generators(generator(unpack_inputs in [False, True] if ),
# generators(generator((reduction, constr) in
# [("none", tests.helpers.influence.common.DataInfluenceConstructor
# [("none", captum.testing.helpers.influence.common.DataInfluenceConstructor
# (captum.influence._core.tracincp.TracInCP)),
# ("sum", tests.helpers.influence.common.DataInfluenceConstructor
# ("sum", captum.testing.helpers.influence.common.DataInfluenceConstructor
# (captum.influence._core.tracincp_fast_rand_proj.TracInCPFast)), ("sum",
# tests.helpers.influence.common.DataInfluenceConstructor(captum.influence._core.
# captum.testing.helpers.influence.common.DataInfluenceConstructor(captum.influence._core.
# tracincp_fast_rand_proj.TracInCPFastRandProj)), ("sum",
# tests.helpers.influence.common.DataInfluenceConstructor(
# captum.testing.helpers.influence.common.DataInfluenceConstructor(
# captum.influence._core.tracincp_fast_rand_proj.TracInCPFastRandProj,
# $parameter$name = "TracInCPFastRandProj_1DProj",
# $parameter$projection_dim = 1))] if ))))`
Expand Down
8 changes: 4 additions & 4 deletions tests/influence/_core/test_naive_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
_functional_call,
_unflatten_params_factory,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
Expand All @@ -25,6 +22,9 @@
Linear,
UnpackDataset,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual
from torch.utils.data import DataLoader

# TODO: for some unknow reason, this test does not work
Expand Down
6 changes: 3 additions & 3 deletions tests/influence/_core/test_tracin_aggregate_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import torch.nn as nn
from captum.influence._core.tracincp import TracInCP
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
)
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from torch.utils.data import DataLoader


Expand Down
8 changes: 4 additions & 4 deletions tests/influence/_core/test_tracin_intermediate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
TracInCPFast,
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch.utils.data import DataLoader


Expand Down
12 changes: 6 additions & 6 deletions tests/influence/_core/test_tracin_k_most_influential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
import torch
import torch.nn as nn
from captum.influence._core.tracincp import TracInCP

from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
Expand All @@ -19,6 +15,10 @@
is_gpu,
)

from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual


class TestTracInGetKMostInfluential(BaseTest):
param_list: List[
Expand Down Expand Up @@ -76,7 +76,7 @@ class TestTracInGetKMostInfluential(BaseTest):
)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func()`
# `captum.testing.helpers.influence.common.build_test_name_func()`
# to decorator factory `parameterized.parameterized.expand`.
@parameterized.expand(
param_list,
Expand Down
16 changes: 8 additions & 8 deletions tests/influence/_core/test_tracin_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
TracInCPFast,
TracInCPFastRandProj,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_isSorted,
_wrap_model_in_dataparallel,
build_test_name_func,
Expand All @@ -25,6 +22,9 @@
IdentityDataset,
RangeDataset,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch import Tensor


Expand Down Expand Up @@ -142,7 +142,7 @@ def _test_tracin_regression_setup(
param_list.append((reduction, constructor, mode, dim, use_gpu))

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func
# `captum.testing.helpers.influence.common.build_test_name_func
# ($parameter$args_to_skip = ["reduction"])` to decorator factory
# `parameterized.parameterized.expand`.
@parameterized.expand(
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_tracin_regression(
)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func()`
# `captum.testing.helpers.influence.common.build_test_name_func()`
# to decorator factory `parameterized.parameterized.expand`.
@parameterized.expand(
[
Expand Down Expand Up @@ -350,7 +350,7 @@ def _test_tracin_identity_regression_setup(
return dataset, net

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func()`
# `captum.testing.helpers.influence.common.build_test_name_func()`
# to decorator factory `parameterized.parameterized.expand`
@parameterized.expand(
[
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_tracin_identity_regression(
)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func()`
# `captum.testing.helpers.influence.common.build_test_name_func()`
# to decorator factory `parameterized.parameterized.expand`.
@parameterized.expand(
[
Expand Down
8 changes: 4 additions & 4 deletions tests/influence/_core/test_tracin_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from captum.influence._core.influence_function import NaiveInfluenceFunction
from captum.influence._core.tracincp import TracInCP, TracInCPBase
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
GPU_SETTING_LIST,
is_gpu,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from torch.utils.data import DataLoader


Expand Down
10 changes: 5 additions & 5 deletions tests/influence/_core/test_tracin_show_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import torch.nn as nn
from captum.influence._core.tracincp import TracInCP
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from torch.utils.data import DataLoader


Expand Down Expand Up @@ -69,9 +69,9 @@ def _check_error_msg_multiplicity(
# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `comprehension((reduction, constr, mode) for
# generators(generator((reduction, constr) in
# [("none", tests.helpers.influence.common.DataInfluenceConstructor
# [("none", captum.testing.helpers.influence.common.DataInfluenceConstructor
# (captum.influence._core.tracincp.TracInCP)),
# ("sum", tests.helpers.influence.common.DataInfluenceConstructor
# ("sum", captum.testing.helpers.influence.common.DataInfluenceConstructor
# (captum.influence._core.tracincp_fast_rand_proj.TracInCPFast))] if ),
# generators(generator(mode in ["self influence by checkpoints",
# "self influence by batches", "influence", "k-most"] if ))))`
Expand Down
8 changes: 4 additions & 4 deletions tests/influence/_core/test_tracin_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import torch.nn as nn
from captum.influence._core.tracincp import TracInCP
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast

from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
)

from parameterized import parameterized
from tests.helpers import BaseTest


class TestTracinValidator(BaseTest):

Expand Down
10 changes: 5 additions & 5 deletions tests/influence/_core/test_tracin_xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import torch.nn as nn
import torch.nn.functional as F
from captum.influence._core.tracincp import TracInCP
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual
from tests.helpers.influence.common import (
from captum.testing.helpers.influence.common import (
_wrap_model_in_dataparallel,
BasicLinearNet,
BinaryDataset,
build_test_name_func,
DataInfluenceConstructor,
)
from parameterized import parameterized
from tests.helpers import BaseTest
from tests.helpers.basic import assertTensorAlmostEqual


class TestTracInXOR(BaseTest):
Expand Down Expand Up @@ -225,7 +225,7 @@ def _test_tracin_xor_setup(
)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `tests.helpers.influence.common.build_test_name_func($parameter$args_to_skip
# `captum.testing.helpers.influence.common.build_test_name_func($parameter$args_to_skip
# = ["reduction"])` to decorator factory `parameterized.parameterized.expand`.
@parameterized.expand(
parametrized_list,
Expand Down
Loading