Learned Interiors functionality for numerical calibrators #9

Merged · 4 commits · Dec 6, 2023
Changes from all commits
9 changes: 4 additions & 5 deletions .github/ISSUE_TEMPLATES/feature_request.md
@@ -1,9 +1,8 @@
---
name: Feature request
about: Suggest an idea for this project
title: ""
labels: ""
assignees: ""

## About:
Suggest an idea for this project

---

**Is your feature request related to a problem? Please describe.**
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -72,7 +72,7 @@ skip-magic-trailing-comma = false
 line-ending = "auto"

 [tool.mypy]
-exclude = ["examples"]
+exclude = ["examples", "venv"]

 [build-system]
 requires = ["poetry-core"]
5 changes: 3 additions & 2 deletions pytorch_lattice/enums.py
@@ -33,11 +33,12 @@ class InputKeypointsType(_Enum):
"""The type of input keypoints to use.

- FIXED: the input keypoints will be fixed during initialization.
- LEARNED: the interior keypoints will learn through training to best fit the
piecewise linear function.
"""

FIXED = "fixed"
# TODO: add learned interior functionality
# LEARNED = "learned_interior"
LEARNED = "learned"


class NumericalCalibratorInit(_Enum):
Expand Down
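
The new enum value is wired through `NumericalCalibrator` in the layer diff below. A minimal usage sketch, assuming this branch is installed; the import paths mirror the ones used in the tests:

```python
import numpy as np

from pytorch_lattice.enums import InputKeypointsType
from pytorch_lattice.layers import NumericalCalibrator

# Interior keypoints start at the given positions but are free to move
# during training when LEARNED is selected.
calibrator = NumericalCalibrator(
    input_keypoints=np.linspace(0.0, 1.0, num=5),
    input_keypoints_type=InputKeypointsType.LEARNED,
)
```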
28 changes: 27 additions & 1 deletion pytorch_lattice/layers/numerical_calibrator.py
@@ -11,7 +11,7 @@
 import torch

 from ..constrained_module import ConstrainedModule
-from ..enums import Monotonicity, NumericalCalibratorInit
+from ..enums import InputKeypointsType, Monotonicity, NumericalCalibratorInit


 class NumericalCalibrator(ConstrainedModule):
@@ -50,6 +50,7 @@ def __init__(
         monotonicity: Optional[Monotonicity] = None,
         kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS,
         projection_iterations: int = 8,
+        input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED,
     ) -> None:
         """Initializes an instance of `NumericalCalibrator`.

@@ -67,6 +68,9 @@ def __init__(
             kernel_init: Initialization scheme to use for the kernel.
             projection_iterations: Number of times to run Dykstra's projection
                 algorithm when applying constraints.
+            input_keypoints_type: `InputKeypointsType` of either `FIXED` or `LEARNED`. If
+                `LEARNED`, keypoints other than the first or last will follow
+                `input_keypoints` for initialization but adapt during training.

         Raises:
             ValueError: If `kernel_init` is invalid.
@@ -80,6 +84,7 @@
         self.monotonicity = monotonicity
         self.kernel_init = kernel_init
         self.projection_iterations = projection_iterations
+        self.input_keypoints_type = input_keypoints_type

         # Determine default output initialization values if bounds are not fully set.
         if output_min is not None and output_max is not None:
@@ -94,6 +99,15 @@

         self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1])
         self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1])
+        if self.input_keypoints_type == InputKeypointsType.LEARNED:
+            self._keypoint_min = input_keypoints[0]
+            self._keypoint_range = input_keypoints[-1] - input_keypoints[0]
+            initial_logits = torch.from_numpy(
+                np.log(
+                    (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range
+                )
+            ).double()
+            self._interpolation_logits = torch.nn.Parameter(initial_logits)

         # First row of the kernel represents the bias. The remaining rows represent
         # the y-value delta compared to the previous point i.e. the segment heights.
@@ -136,6 +150,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
         """
+        if self.input_keypoints_type == InputKeypointsType.LEARNED:
+            softmaxed_logits = torch.nn.functional.softmax(
+                self._interpolation_logits, dim=-1
+            )
+            self._lengths = softmaxed_logits * self._keypoint_range
+            interior_keypoints = (
+                torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
+            )
+            self._interpolation_keypoints = torch.cat(
+                [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
+            )

         interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
         interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
         interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0))
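
The two blocks above form a round trip: initialization stores the logits as the log of the normalized segment lengths, and since softmax(log(p)) = p for any p that sums to 1, the forward pass reproduces the original keypoints exactly until training moves the logits. A standalone sketch of that invariant (it mirrors the diff's logic rather than calling the library):

```python
import numpy as np
import torch

input_keypoints = np.array([0.0, 1.5, 2.0, 2.4, 3.0])
keypoint_min = input_keypoints[0]
keypoint_range = input_keypoints[-1] - input_keypoints[0]

# Initialization: logits are the log of the normalized segment lengths.
logits = torch.from_numpy(
    np.log((input_keypoints[1:] - input_keypoints[:-1]) / keypoint_range)
).double()

# Forward: softmax recovers the normalized lengths; cumsum rebuilds the keypoints.
lengths = torch.nn.functional.softmax(logits, dim=-1) * keypoint_range
interior = torch.cumsum(lengths, dim=-1) + keypoint_min
keypoints = torch.cat([torch.tensor([keypoint_min]).double(), interior[:-1]])

assert torch.allclose(keypoints, torch.from_numpy(input_keypoints[:-1]))
```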
111 changes: 111 additions & 0 deletions tests/layers/test_numerical_calibrator.py
@@ -4,6 +4,7 @@
 import torch

 from pytorch_lattice import Monotonicity, NumericalCalibratorInit
+from pytorch_lattice.enums import InputKeypointsType
 from pytorch_lattice.layers import NumericalCalibrator

 from ..testing_utils import train_calibrated_module
@@ -75,6 +76,35 @@ def test_initialization(
     assert calibrator.projection_iterations == projection_iterations


+@pytest.mark.parametrize(
+    "input_keypoints, expected_lengths, expected_logits",
+    [
+        (
+            np.linspace(1.0, 5.0, num=5),
+            torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.double),
+            torch.from_numpy(np.log([0.25, 0.25, 0.25, 0.25])).double(),
+        ),
+        (
+            np.array([0.0, 1.5, 2.0, 2.4, 3.0]),
+            torch.tensor([1.5, 0.5, 0.4, 0.6], dtype=torch.double),
+            torch.from_numpy(
+                np.log([1.5 / 3.0, 0.5 / 3.0, 0.4 / 3.0, 0.6 / 3.0])
+            ).double(),
+        ),
+    ],
+)
+def test_initialization_learned_input_keypoints(
+    input_keypoints, expected_lengths, expected_logits
+):
+    """Tests logic specific to learned input keypoint initialization."""
+    calibrator = NumericalCalibrator(
+        input_keypoints=input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED
+    )
+
+    assert torch.allclose(calibrator._lengths, expected_lengths)
+    assert torch.allclose(calibrator._interpolation_logits, expected_logits)


 @pytest.mark.parametrize(
     "input_keypoints,kernel_init,kernel_data,inputs,expected_outputs",
     [
@@ -170,6 +200,34 @@ def test_forward(input_keypoints, kernel_init, kernel_data, inputs, expected_outputs):
     assert torch.allclose(outputs, expected_outputs)


+@pytest.mark.parametrize(
+    "input_keypoints",
+    [
+        (np.linspace(1, 4, num=4)),
+        (np.array([0.0, 1.2, 2.0, 3.7, 5.0])),
+        (np.linspace(1, 20, num=45)),
+        (np.array([0.0, 0.02, 0.023, 3.7, 5.0, 7.9, 9.9, 10.3, 12.4, 15.6, 51.2])),
+    ],
+)
+@pytest.mark.parametrize(
+    "x",
+    [(torch.tensor([[1.1]])), (torch.tensor([[1.2], [1.3]]))],
+)
+def test_forward_learned_input_keypoints(input_keypoints, x):
+    """Tests that learned input keypoints are properly reconstructed in forward."""
+    calibrator = NumericalCalibrator(
+        input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED
+    )
+    calibrator.forward(x)
+    assert (
+        abs(torch.sum(calibrator._lengths).item() - calibrator._keypoint_range) < 1e-6
+    )
+    assert torch.allclose(
+        torch.from_numpy(input_keypoints[:-1]).double(),
+        calibrator._interpolation_keypoints,
+    )


 @pytest.mark.parametrize(
     "kernel_data,monotonicity, expected_out",
     [
@@ -839,3 +897,56 @@ def test_training():
     assert torch.allclose(
         torch.absolute(keypoints_inputs), keypoints_outputs, atol=2e-2
     )


+def test_training_learned_interior_input_keypoints():
+    """Tests that `NumericalCalibrator` successfully learns interior input keypoints.
+    The calibrator is given a piecewise linear step function that starts at (0,0),
+    linearly rises from (1/3, 0) to (2/3, 1), then stays at 1. The calibrator is
+    initialized with inaccurate interior keypoints [0.1, 0.9] and then tested to see if
+    it can learn the ideal interior keypoint positions of [0.33, 0.66]."""
+    num_examples = 1000
+    output_min, output_max = 0.0, 1.0
+    training_examples = torch.linspace(output_min, output_max, num_examples)[
+        :, None
+    ].double()
+    training_labels = torch.where(
+        training_examples < 1 / 3,
+        torch.zeros_like(training_examples),
+        torch.where(
+            training_examples > 2 / 3,
+            torch.ones_like(training_examples),
+            3 * training_examples - 1,
+        ),
+    ).double()
+    noise = torch.randn_like(training_labels) * 0.05
+    training_labels += noise
+
+    calibrator = NumericalCalibrator(
+        np.array([0.0, 0.1, 0.9, 1.0]),
+        output_min=output_min,
+        output_max=output_max,
+        monotonicity=None,
+        input_keypoints_type=InputKeypointsType.LEARNED,
+    )
+
+    loss_fn = torch.nn.MSELoss()
+    optimizer = torch.optim.Adam(calibrator.parameters(), lr=1e-3)
+
+    train_calibrated_module(
+        calibrator,
+        training_examples,
+        training_labels,
+        loss_fn,
+        optimizer,
+        300,  # Number of epochs
+        num_examples // 10,  # Batch size
+    )
+
+    # Test that the learned keypoints roughly match the expected ones.
+    assert torch.allclose(
+        calibrator._interpolation_keypoints,
+        torch.tensor([0, 1 / 3, 2 / 3]).double(),
+        atol=0.02,
+    )
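
Because `_interpolation_logits` is registered as a `torch.nn.Parameter`, it is returned by `calibrator.parameters()` and optimized alongside the kernel, so no optimizer changes are needed. A quick sanity check along those lines (a sketch, not part of this PR):

```python
import numpy as np
import torch

from pytorch_lattice.enums import InputKeypointsType
from pytorch_lattice.layers import NumericalCalibrator

calibrator = NumericalCalibrator(
    np.array([0.0, 0.1, 0.9, 1.0]),
    input_keypoints_type=InputKeypointsType.LEARNED,
)
# Gradients should flow through the softmax re-parameterization to the logits.
calibrator.forward(torch.tensor([[0.5]]).double()).sum().backward()
assert calibrator._interpolation_logits.grad is not None
```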