
convert forward return to tensor in FeatureAblation #1049

Closed
aobo-y wants to merge 2 commits

Conversation

@aobo-y (Contributor) commented Oct 21, 2022

FeatureAblation (and FeaturePermutation) never clearly defined the type of forward's return value. According to the documentation, a tensor is the only acceptable type:

The forward function can either return a scalar per example or a tensor of a fixed sized tensor (or scalar value) for the full batch

However, returning a single int or float is a common use case we already support (ref #1047 (comment)). But our code did not explicitly raise an error for unexpected types, so other types like list, tuple, or numpy.ndarray may pass silently or fail in unexpected places with confusing error messages, e.g., when the return type ends up being used as a torch.dtype:

dtype=attrib_type,
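For illustration only (a hypothetical snippet, not code from this PR), this is roughly how an unsupported return type can surface as an unrelated dtype error:

  import torch

  forward_output = [0.5, 0.5]          # forward returned a Python list
  attrib_type = type(forward_output)   # <class 'list'>
  # Reusing that type as a torch dtype fails far away from the real cause:
  torch.zeros(3, dtype=attrib_type)    # raises TypeError about dtype, not about the list return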

This PR explicitly asserts the return type and converts everything into a tensor. The assertion and conversion are done in a new private _run_forward wrapper rather than in _run_forward from the global utils, which is shared by many other classes. I will update the others progressively and eventually push the logic into the shared _run_forward.
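A minimal sketch of what such a wrapper could do (names and details here are illustrative, not the actual diff):

  import torch
  from torch import Tensor

  def _run_forward_to_tensor(forward_output):  # hypothetical helper name
      # Tensors pass through unchanged.
      if isinstance(forward_output, Tensor):
          return forward_output
      # Only Python int/float scalars are otherwise accepted; list, tuple,
      # numpy.ndarray, etc. fail loudly here instead of deep inside ablation.
      assert isinstance(forward_output, (int, float)), (
          f"forward_func must return a tensor, int, or float; got {type(forward_output)}"
      )
      # Convert the scalar to a 0-dim tensor, letting the Python type pick the dtype.
      return torch.tensor(forward_output, dtype=type(forward_output))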

# our tests expect int -> torch.int64, float -> torch.float64
# but this may actually depend on the machine
# ref: https://docs.python.org/3.10/library/stdtypes.html#typesnumeric
return torch.tensor(forward_output, dtype=output_type)
aobo-y (Contributor, Author) commented:

I inherited our original logic of passing Python types as torch dtypes, like dtype=float. But this is not an officially documented operation. Existing tests assume it must equal dtype=torch.float64: https://github.com/pytorch/captum/blob/5f878af6a7/tests/attr/test_feature_ablation.py#L429

But this may be machine dependent (https://docs.python.org/3.10/library/stdtypes.html#typesnumeric):

Floating point numbers are usually implemented using double in C; information about the precision and internal representation of floating point numbers for the machine on which your program is running is available in sys.float_info

Two other alternatives are:

  • explicitly map the Python type to a torch dtype, e.g. float -> torch.float64 (see the sketch below)
  • do not set dtype and rely on torch's default dtype (float32)
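A sketch of the first alternative, for comparison (illustrative only, not part of this PR):

  import torch

  # Explicit mapping from Python scalar types to torch dtypes.
  _SCALAR_DTYPES = {int: torch.int64, float: torch.float64}

  def _scalar_to_tensor(forward_output):
      # Falls back to dtype=None (torch's defaults) for anything unmapped.
      return torch.tensor(forward_output, dtype=_SCALAR_DTYPES.get(type(forward_output)))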

vivekmig (Contributor) replied:

This is an interesting point! I looked into it a bit; this functionality seems to have been added in pytorch/pytorch#21215.
It looks like the type mapping is done explicitly on the C++ side using the PyObject type, so it shouldn't be affected by the machine's internal float representation. This is the mapping logic:

  PyObject *obj = args[i];
  if (obj == (PyObject*)&PyFloat_Type) {
    return at::ScalarType::Double;
  }
  if (obj == (PyObject*)&PyBool_Type) {
    return at::ScalarType::Bool;
  }
  if (obj == (PyObject*)&PyLong_Type
#if PY_MAJOR_VERSION == 2
      || obj == (PyObject*)&PyInt_Type
#endif
  ) {
    return at::ScalarType::Long;
  }

So if float is set as dtype, this would be passed through the Python / C++ bindings as PyFloat_Type, which should always correspond to ScalarType::Double / torch.float64. The tests in the original PR also verify this mapping.
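A quick check of that behavior from Python (assuming a recent torch release):

  import torch

  # Builtin Python types used as dtype map to fixed torch dtypes,
  # independent of the platform's C float size:
  assert torch.tensor(1.0, dtype=float).dtype == torch.float64
  assert torch.tensor(1, dtype=int).dtype == torch.int64
  assert torch.tensor(True, dtype=bool).dtype == torch.bool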

aobo-y (Contributor, Author) replied:

Thanks for the deep dive @vivekmig!
Then I will just add a comment referring to the mapping, also as a caveat.
After all, it is not a documented torch usage and may see breaking changes someday.

vivekmig (Contributor) replied:

Makes sense, sounds good!

@facebook-github-bot (Contributor) commented:

@aobo-y has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@vivekmig (Contributor) left a review:

Looks great, thanks! Just one note on the dtype comments.


@@ -601,3 +593,20 @@ def _find_output_mode(
            feature_mask is None
            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
        )

    def _run_forward(self, *args, **kwargs) -> Tensor:
        forward_output = _run_forward(*args, **kwargs)
vivekmig (Contributor) commented:

nit: It seems a bit confusing to see both the instance method and the original utility method named _run_forward; you could consider renaming this one slightly, but either way is fine.

@facebook-github-bot (Contributor) commented:

@aobo-y has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
