From 4d633913f7e5de824d37c6d1f2c06e7abde1aa24 Mon Sep 17 00:00:00 2001 From: Brian Park Date: Wed, 29 Nov 2023 17:23:23 -0500 Subject: [PATCH 1/3] feat: learned interiors for numerical calibrator --- pyproject.toml | 2 +- .../layers/numerical_calibrator.py | 32 +++++- tests/layers/test_numerical_calibrator.py | 103 ++++++++++++++++++ 3 files changed, 135 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0872a4f..cb9fb42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.mypy] -exclude = ["examples"] +exclude = ["examples", "venv"] [build-system] requires = ["poetry-core"] diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index bccc014..c447374 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -23,6 +23,8 @@ class NumericalCalibrator(ConstrainedModule): Attributes: All: `__init__` arguments. + interpolation_logits: `torch.nn.Parameter` that stores the logits representing + keypoint values when `input_keypoints_type == "learned_interior"`. kernel: `torch.nn.Parameter` that stores the piece-wise linear function weights. missing_output: `torch.nn.Parameter` that stores the output learned for any missing inputs. Only available if `missing_input_value` is provided. @@ -50,6 +52,7 @@ def __init__( monotonicity: Optional[Monotonicity] = None, kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS, projection_iterations: int = 8, + input_keypoints_type="fixed", ) -> None: """Initializes an instance of `NumericalCalibrator`. @@ -67,9 +70,14 @@ def __init__( kernel_init: Initialization scheme to use for the kernel. projection_iterations: Number of times to run Dykstra's projection algorithm when applying constraints. + input_keypoints_type: Either "fixed" or "learned_interior". If + "learned_interior", keypoints follow "input_keypoints" for + initialization but vary during training, except the first and last + keypoints. Raises: - ValueError: If `kernel_init` is invalid. + ValueError: If `kernel_init` is invalid, or if `input_keypoints_type` is + invalid. """ super().__init__() @@ -80,6 +88,7 @@ def __init__( self.monotonicity = monotonicity self.kernel_init = kernel_init self.projection_iterations = projection_iterations + self.input_keypoints_type = input_keypoints_type # Determine default output initialization values if bounds are not fully set. if output_min is not None and output_max is not None: @@ -94,6 +103,15 @@ def __init__( self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1]) self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1]) + if self.input_keypoints_type == "learned_interior": + self._keypoint_min = input_keypoints[0] + self._keypoint_range = input_keypoints[-1] - input_keypoints[0] + initial_logits = torch.from_numpy( + np.log( + (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range + ) + ).double() + self._interpolation_logits = torch.nn.Parameter(initial_logits) # First row of the kernel represents the bias. The remaining rows represent # the y-value delta compared to the previous point i.e. the segment heights. @@ -136,6 +154,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ + if self.input_keypoints_type == "learned_interior": + softmaxed_logits = torch.nn.functional.softmax( + self._interpolation_logits, dim=-1 + ) + self._lengths = softmaxed_logits * self._keypoint_range + interior_keypoints = ( + torch.cumsum(self._lengths, dim=-1) + self._keypoint_min + ) + self._interpolation_keypoints = torch.cat( + [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]] + ) + interpolation_weights = (x - self._interpolation_keypoints) / self._lengths interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0)) interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0)) diff --git a/tests/layers/test_numerical_calibrator.py b/tests/layers/test_numerical_calibrator.py index 64d6549..4830cc7 100644 --- a/tests/layers/test_numerical_calibrator.py +++ b/tests/layers/test_numerical_calibrator.py @@ -75,6 +75,35 @@ def test_initialization( assert calibrator.projection_iterations == projection_iterations +@pytest.mark.parametrize( + "input_keypoints, expected_lengths, expected_logits", + [ + ( + np.linspace(1.0, 5.0, num=5), + torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.double), + torch.from_numpy(np.log([0.25, 0.25, 0.25, 0.25])).double(), + ), + ( + np.array([0.0, 1.5, 2.0, 2.4, 3.0]), + torch.tensor([1.5, 0.5, 0.4, 0.6], dtype=torch.double), + torch.from_numpy( + np.log([1.5 / 3.0, 0.5 / 3.0, 0.4 / 3.0, 0.6 / 3.0]) + ).double(), + ), + ], +) +def test_initialization_learned_inputs( + input_keypoints, expected_lengths, expected_logits +): + """Tests logic specific to learned interior initialization.""" + calibrator = NumericalCalibrator( + input_keypoints=input_keypoints, input_keypoints_type="learned_interior" + ) + + assert torch.allclose(calibrator._lengths, expected_lengths) + assert torch.allclose(calibrator._interpolation_logits, expected_logits) + + @pytest.mark.parametrize( "input_keypoints,kernel_init,kernel_data,inputs,expected_outputs", [ @@ -170,6 +199,32 @@ def test_forward(input_keypoints, kernel_init, kernel_data, inputs, expected_out assert torch.allclose(outputs, expected_outputs) +@pytest.mark.parametrize( + "input_keypoints", + [ + (np.linspace(1, 4, num=4)), + (np.array([0.0, 1.2, 2.0, 3.7, 5.0])), + (np.linspace(1, 20, num=45)), + (np.array([0.0, 0.02, 0.023, 3.7, 5.0, 7.9, 9.9, 10.3, 12.4, 15.6, 51.2])), + ], +) +@pytest.mark.parametrize( + "x", + [(torch.tensor([[1.1]])), (torch.tensor([[1.2], [1.3]]))], +) +def test_forward_learned_inputs(input_keypoints, x): + """Tests that interior keypoints are properly reconstructed for learned inputs.""" + calibrator = NumericalCalibrator( + input_keypoints, input_keypoints_type="learned_interior" + ) + calibrator.forward(x) + assert torch.sum(calibrator._lengths) == calibrator._keypoint_range + assert torch.allclose( + torch.from_numpy(input_keypoints[:-1]).double(), + calibrator._interpolation_keypoints, + ) + + @pytest.mark.parametrize( "kernel_data,monotonicity, expected_out", [ @@ -839,3 +894,51 @@ def test_training(): assert torch.allclose( torch.absolute(keypoints_inputs), keypoints_outputs, atol=2e-2 ) + + +def test_training_learned_interior(): + """Tests that the `NumericalCalibrator` can learn interior keypoints on a PWL.""" + num_examples = 1000 + output_min, output_max = 0.0, 1.0 + training_examples = torch.linspace(output_min, output_max, num_examples)[ + :, None + ].double() + training_labels = torch.where( + training_examples < 1 / 3, + torch.zeros_like(training_examples), + torch.where( + training_examples > 2 / 3, + torch.ones_like(training_examples), + 3 * training_examples - 1, + ), + ).double() + noise = torch.randn_like(training_labels) * 0.05 + training_labels += noise + + calibrator = NumericalCalibrator( + np.array([0.0, 0.1, 0.9, 1.0]), + output_min=output_min, + output_max=output_max, + monotonicity=None, + input_keypoints_type="learned_interior", + ) + + loss_fn = torch.nn.MSELoss() + optimizer = torch.optim.Adam(calibrator.parameters(), lr=1e-3) + + train_calibrated_module( + calibrator, + training_examples, + training_labels, + loss_fn, + optimizer, + 300, # Number of epochs + num_examples // 10, # Batch size + ) + + # Test that the learned keypoints roughly match the expected ones + assert torch.allclose( + calibrator._interpolation_keypoints, + torch.tensor([0, 1 / 3, 2 / 3]).double(), + atol=0.05, + ) From d5edaa8e7848983adee7175ce5115572437defcb Mon Sep 17 00:00:00 2001 From: Brian Park Date: Wed, 29 Nov 2023 17:28:07 -0500 Subject: [PATCH 2/3] chore: updated feature request template --- .github/ISSUE_TEMPLATES/feature_request.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/ISSUE_TEMPLATES/feature_request.md b/.github/ISSUE_TEMPLATES/feature_request.md index 2bc5d5f..cf9deff 100644 --- a/.github/ISSUE_TEMPLATES/feature_request.md +++ b/.github/ISSUE_TEMPLATES/feature_request.md @@ -1,9 +1,8 @@ --- -name: Feature request -about: Suggest an idea for this project -title: "" -labels: "" -assignees: "" + +## About: +Suggest an idea for this project + --- **Is your feature request related to a problem? Please describe.** From b499b52d3cc5ec73ed120cae06c05ae7e14d2cf9 Mon Sep 17 00:00:00 2001 From: Brian Park Date: Thu, 30 Nov 2023 12:50:06 -0500 Subject: [PATCH 3/3] chore: docstring, function name, and InputKeypointType enum edits --- pytorch_lattice/enums.py | 5 ++-- .../layers/numerical_calibrator.py | 20 +++++-------- tests/layers/test_numerical_calibrator.py | 30 ++++++++++++------- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/pytorch_lattice/enums.py b/pytorch_lattice/enums.py index 3597064..06d67d1 100644 --- a/pytorch_lattice/enums.py +++ b/pytorch_lattice/enums.py @@ -33,11 +33,12 @@ class InputKeypointsType(_Enum): """The type of input keypoints to use. - FIXED: the input keypoints will be fixed during initialization. + - LEARNED: the interior keypoints will learn through training to best fit the + piecewise linear function. """ FIXED = "fixed" - # TODO: add learned interior functionality - # LEARNED = "learned_interior" + LEARNED = "learned" class NumericalCalibratorInit(_Enum): diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index c447374..6a15a6b 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -11,7 +11,7 @@ import torch from ..constrained_module import ConstrainedModule -from ..enums import Monotonicity, NumericalCalibratorInit +from ..enums import InputKeypointsType, Monotonicity, NumericalCalibratorInit class NumericalCalibrator(ConstrainedModule): @@ -23,8 +23,6 @@ class NumericalCalibrator(ConstrainedModule): Attributes: All: `__init__` arguments. - interpolation_logits: `torch.nn.Parameter` that stores the logits representing - keypoint values when `input_keypoints_type == "learned_interior"`. kernel: `torch.nn.Parameter` that stores the piece-wise linear function weights. missing_output: `torch.nn.Parameter` that stores the output learned for any missing inputs. Only available if `missing_input_value` is provided. @@ -52,7 +50,7 @@ def __init__( monotonicity: Optional[Monotonicity] = None, kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS, projection_iterations: int = 8, - input_keypoints_type="fixed", + input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED, ) -> None: """Initializes an instance of `NumericalCalibrator`. @@ -70,14 +68,12 @@ def __init__( kernel_init: Initialization scheme to use for the kernel. projection_iterations: Number of times to run Dykstra's projection algorithm when applying constraints. - input_keypoints_type: Either "fixed" or "learned_interior". If - "learned_interior", keypoints follow "input_keypoints" for - initialization but vary during training, except the first and last - keypoints. + input_keypoints_type: `InputKeypointType` of either `FIXED` or `LEARNED`. If + `LEARNED`, keypoints other than the first or last will follow + `input_keypoints` for initialization but adapt during training. Raises: - ValueError: If `kernel_init` is invalid, or if `input_keypoints_type` is - invalid. + ValueError: If `kernel_init` is invalid. """ super().__init__() @@ -103,7 +99,7 @@ def __init__( self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1]) self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1]) - if self.input_keypoints_type == "learned_interior": + if self.input_keypoints_type == InputKeypointsType.LEARNED: self._keypoint_min = input_keypoints[0] self._keypoint_range = input_keypoints[-1] - input_keypoints[0] initial_logits = torch.from_numpy( @@ -154,7 +150,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ - if self.input_keypoints_type == "learned_interior": + if self.input_keypoints_type == InputKeypointsType.LEARNED: softmaxed_logits = torch.nn.functional.softmax( self._interpolation_logits, dim=-1 ) diff --git a/tests/layers/test_numerical_calibrator.py b/tests/layers/test_numerical_calibrator.py index 4830cc7..592dbe0 100644 --- a/tests/layers/test_numerical_calibrator.py +++ b/tests/layers/test_numerical_calibrator.py @@ -4,6 +4,7 @@ import torch from pytorch_lattice import Monotonicity, NumericalCalibratorInit +from pytorch_lattice.enums import InputKeypointsType from pytorch_lattice.layers import NumericalCalibrator from ..testing_utils import train_calibrated_module @@ -92,12 +93,12 @@ def test_initialization( ), ], ) -def test_initialization_learned_inputs( +def test_initialization_learned_input_keypoints( input_keypoints, expected_lengths, expected_logits ): - """Tests logic specific to learned interior initialization.""" + """Tests logic specific to learned input keypoint initialization.""" calibrator = NumericalCalibrator( - input_keypoints=input_keypoints, input_keypoints_type="learned_interior" + input_keypoints=input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED ) assert torch.allclose(calibrator._lengths, expected_lengths) @@ -212,13 +213,15 @@ def test_forward(input_keypoints, kernel_init, kernel_data, inputs, expected_out "x", [(torch.tensor([[1.1]])), (torch.tensor([[1.2], [1.3]]))], ) -def test_forward_learned_inputs(input_keypoints, x): - """Tests that interior keypoints are properly reconstructed for learned inputs.""" +def test_forward_learned_input_keypoints(input_keypoints, x): + """Tests that learned input keypoints are properly reconstructed in forward.""" calibrator = NumericalCalibrator( - input_keypoints, input_keypoints_type="learned_interior" + input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED ) calibrator.forward(x) - assert torch.sum(calibrator._lengths) == calibrator._keypoint_range + assert ( + abs(torch.sum(calibrator._lengths).item() - calibrator._keypoint_range) < 1e-6 + ) assert torch.allclose( torch.from_numpy(input_keypoints[:-1]).double(), calibrator._interpolation_keypoints, @@ -896,8 +899,13 @@ def test_training(): ) -def test_training_learned_interior(): - """Tests that the `NumericalCalibrator` can learn interior keypoints on a PWL.""" +def test_training_learned_interior_input_keypoints(): + """Tests that `NumericalCalibrator` successfully learns interior input keypoints. + The calibrator is given a piecewise linear step function that starts at (0,0), + linearly rises from (1/3, 0) to (2/3, 1), then stays at 1. The calibrator is + initialized with inaccurate interior keypoints [0.1, 0.9] and then tested to see if + it can learn the ideal interior keypoint positions of [0.33, 0.66]. + """ num_examples = 1000 output_min, output_max = 0.0, 1.0 training_examples = torch.linspace(output_min, output_max, num_examples)[ @@ -920,7 +928,7 @@ def test_training_learned_interior(): output_min=output_min, output_max=output_max, monotonicity=None, - input_keypoints_type="learned_interior", + input_keypoints_type=InputKeypointsType.LEARNED, ) loss_fn = torch.nn.MSELoss() @@ -940,5 +948,5 @@ def test_training_learned_interior(): assert torch.allclose( calibrator._interpolation_keypoints, torch.tensor([0, 1 / 3, 2 / 3]).double(), - atol=0.05, + atol=0.02, )