Skip to content

Commit

Permalink
Merge pull request #9 from ControlAI/numcal_input_keypoints
Browse files Browse the repository at this point in the history
Learned Interiors functionality for numerical calibrators
  • Loading branch information
willbakst committed Dec 6, 2023
2 parents ef6a4ef + b499b52 commit b34f5f2
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 9 deletions.
9 changes: 4 additions & 5 deletions .github/ISSUE_TEMPLATES/feature_request.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
---
name: Feature request
about: Suggest an idea for this project
title: ""
labels: ""
assignees: ""

## About:
Suggest an idea for this project

---

**Is your feature request related to a problem? Please describe.**
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ skip-magic-trailing-comma = false
line-ending = "auto"

[tool.mypy]
exclude = ["examples"]
exclude = ["examples", "venv"]

[build-system]
requires = ["poetry-core"]
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lattice/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ class InputKeypointsType(_Enum):
"""The type of input keypoints to use.
- FIXED: the input keypoints will be fixed during initialization.
- LEARNED: the interior keypoints will learn through training to best fit the
piecewise linear function.
"""

FIXED = "fixed"
# TODO: add learned interior functionality
# LEARNED = "learned_interior"
LEARNED = "learned"


class NumericalCalibratorInit(_Enum):
Expand Down
28 changes: 27 additions & 1 deletion pytorch_lattice/layers/numerical_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch

from ..constrained_module import ConstrainedModule
from ..enums import Monotonicity, NumericalCalibratorInit
from ..enums import InputKeypointsType, Monotonicity, NumericalCalibratorInit


class NumericalCalibrator(ConstrainedModule):
Expand Down Expand Up @@ -50,6 +50,7 @@ def __init__(
monotonicity: Optional[Monotonicity] = None,
kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS,
projection_iterations: int = 8,
input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED,
) -> None:
"""Initializes an instance of `NumericalCalibrator`.
Expand All @@ -67,6 +68,9 @@ def __init__(
kernel_init: Initialization scheme to use for the kernel.
projection_iterations: Number of times to run Dykstra's projection
algorithm when applying constraints.
input_keypoints_type: `InputKeypointType` of either `FIXED` or `LEARNED`. If
`LEARNED`, keypoints other than the first or last will follow
`input_keypoints` for initialization but adapt during training.
Raises:
ValueError: If `kernel_init` is invalid.
Expand All @@ -80,6 +84,7 @@ def __init__(
self.monotonicity = monotonicity
self.kernel_init = kernel_init
self.projection_iterations = projection_iterations
self.input_keypoints_type = input_keypoints_type

# Determine default output initialization values if bounds are not fully set.
if output_min is not None and output_max is not None:
Expand All @@ -94,6 +99,15 @@ def __init__(

self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1])
self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1])
if self.input_keypoints_type == InputKeypointsType.LEARNED:
self._keypoint_min = input_keypoints[0]
self._keypoint_range = input_keypoints[-1] - input_keypoints[0]
initial_logits = torch.from_numpy(
np.log(
(input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range
)
).double()
self._interpolation_logits = torch.nn.Parameter(initial_logits)

# First row of the kernel represents the bias. The remaining rows represent
# the y-value delta compared to the previous point i.e. the segment heights.
Expand Down Expand Up @@ -136,6 +150,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
"""
if self.input_keypoints_type == InputKeypointsType.LEARNED:
softmaxed_logits = torch.nn.functional.softmax(
self._interpolation_logits, dim=-1
)
self._lengths = softmaxed_logits * self._keypoint_range
interior_keypoints = (
torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
)
self._interpolation_keypoints = torch.cat(
[torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
)

interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0))
Expand Down
111 changes: 111 additions & 0 deletions tests/layers/test_numerical_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from pytorch_lattice import Monotonicity, NumericalCalibratorInit
from pytorch_lattice.enums import InputKeypointsType
from pytorch_lattice.layers import NumericalCalibrator

from ..testing_utils import train_calibrated_module
Expand Down Expand Up @@ -75,6 +76,35 @@ def test_initialization(
assert calibrator.projection_iterations == projection_iterations


@pytest.mark.parametrize(
"input_keypoints, expected_lengths, expected_logits",
[
(
np.linspace(1.0, 5.0, num=5),
torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.double),
torch.from_numpy(np.log([0.25, 0.25, 0.25, 0.25])).double(),
),
(
np.array([0.0, 1.5, 2.0, 2.4, 3.0]),
torch.tensor([1.5, 0.5, 0.4, 0.6], dtype=torch.double),
torch.from_numpy(
np.log([1.5 / 3.0, 0.5 / 3.0, 0.4 / 3.0, 0.6 / 3.0])
).double(),
),
],
)
def test_initialization_learned_input_keypoints(
input_keypoints, expected_lengths, expected_logits
):
"""Tests logic specific to learned input keypoint initialization."""
calibrator = NumericalCalibrator(
input_keypoints=input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED
)

assert torch.allclose(calibrator._lengths, expected_lengths)
assert torch.allclose(calibrator._interpolation_logits, expected_logits)


@pytest.mark.parametrize(
"input_keypoints,kernel_init,kernel_data,inputs,expected_outputs",
[
Expand Down Expand Up @@ -170,6 +200,34 @@ def test_forward(input_keypoints, kernel_init, kernel_data, inputs, expected_out
assert torch.allclose(outputs, expected_outputs)


@pytest.mark.parametrize(
"input_keypoints",
[
(np.linspace(1, 4, num=4)),
(np.array([0.0, 1.2, 2.0, 3.7, 5.0])),
(np.linspace(1, 20, num=45)),
(np.array([0.0, 0.02, 0.023, 3.7, 5.0, 7.9, 9.9, 10.3, 12.4, 15.6, 51.2])),
],
)
@pytest.mark.parametrize(
"x",
[(torch.tensor([[1.1]])), (torch.tensor([[1.2], [1.3]]))],
)
def test_forward_learned_input_keypoints(input_keypoints, x):
"""Tests that learned input keypoints are properly reconstructed in forward."""
calibrator = NumericalCalibrator(
input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED
)
calibrator.forward(x)
assert (
abs(torch.sum(calibrator._lengths).item() - calibrator._keypoint_range) < 1e-6
)
assert torch.allclose(
torch.from_numpy(input_keypoints[:-1]).double(),
calibrator._interpolation_keypoints,
)


@pytest.mark.parametrize(
"kernel_data,monotonicity, expected_out",
[
Expand Down Expand Up @@ -839,3 +897,56 @@ def test_training():
assert torch.allclose(
torch.absolute(keypoints_inputs), keypoints_outputs, atol=2e-2
)


def test_training_learned_interior_input_keypoints():
"""Tests that `NumericalCalibrator` successfully learns interior input keypoints.
The calibrator is given a piecewise linear step function that starts at (0,0),
linearly rises from (1/3, 0) to (2/3, 1), then stays at 1. The calibrator is
initialized with inaccurate interior keypoints [0.1, 0.9] and then tested to see if
it can learn the ideal interior keypoint positions of [0.33, 0.66].
"""
num_examples = 1000
output_min, output_max = 0.0, 1.0
training_examples = torch.linspace(output_min, output_max, num_examples)[
:, None
].double()
training_labels = torch.where(
training_examples < 1 / 3,
torch.zeros_like(training_examples),
torch.where(
training_examples > 2 / 3,
torch.ones_like(training_examples),
3 * training_examples - 1,
),
).double()
noise = torch.randn_like(training_labels) * 0.05
training_labels += noise

calibrator = NumericalCalibrator(
np.array([0.0, 0.1, 0.9, 1.0]),
output_min=output_min,
output_max=output_max,
monotonicity=None,
input_keypoints_type=InputKeypointsType.LEARNED,
)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(calibrator.parameters(), lr=1e-3)

train_calibrated_module(
calibrator,
training_examples,
training_labels,
loss_fn,
optimizer,
300, # Number of epochs
num_examples // 10, # Batch size
)

# Test that the learned keypoints roughly match the expected ones
assert torch.allclose(
calibrator._interpolation_keypoints,
torch.tensor([0, 1 / 3, 2 / 3]).double(),
atol=0.02,
)

0 comments on commit b34f5f2

Please sign in to comment.