diff --git a/.github/ISSUE_TEMPLATES/feature_request.md b/.github/ISSUE_TEMPLATES/feature_request.md index 2bc5d5f..cf9deff 100644 --- a/.github/ISSUE_TEMPLATES/feature_request.md +++ b/.github/ISSUE_TEMPLATES/feature_request.md @@ -1,9 +1,8 @@ --- -name: Feature request -about: Suggest an idea for this project -title: "" -labels: "" -assignees: "" + +## About: +Suggest an idea for this project + --- **Is your feature request related to a problem? Please describe.** diff --git a/pyproject.toml b/pyproject.toml index 0872a4f..cb9fb42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.mypy] -exclude = ["examples"] +exclude = ["examples", "venv"] [build-system] requires = ["poetry-core"] diff --git a/pytorch_lattice/enums.py b/pytorch_lattice/enums.py index 3597064..06d67d1 100644 --- a/pytorch_lattice/enums.py +++ b/pytorch_lattice/enums.py @@ -33,11 +33,12 @@ class InputKeypointsType(_Enum): """The type of input keypoints to use. - FIXED: the input keypoints will be fixed during initialization. + - LEARNED: the interior keypoints will learn through training to best fit the + piecewise linear function. 
""" FIXED = "fixed" - # TODO: add learned interior functionality - # LEARNED = "learned_interior" + LEARNED = "learned" class NumericalCalibratorInit(_Enum): diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index bccc014..6a15a6b 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -11,7 +11,7 @@ import torch from ..constrained_module import ConstrainedModule -from ..enums import Monotonicity, NumericalCalibratorInit +from ..enums import InputKeypointsType, Monotonicity, NumericalCalibratorInit class NumericalCalibrator(ConstrainedModule): @@ -50,6 +50,7 @@ def __init__( monotonicity: Optional[Monotonicity] = None, kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS, projection_iterations: int = 8, + input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED, ) -> None: """Initializes an instance of `NumericalCalibrator`. @@ -67,6 +68,9 @@ def __init__( kernel_init: Initialization scheme to use for the kernel. projection_iterations: Number of times to run Dykstra's projection algorithm when applying constraints. + input_keypoints_type: `InputKeypointsType` of either `FIXED` or `LEARNED`. + If `LEARNED`, keypoints other than the first or last will follow + `input_keypoints` for initialization but adapt during training. Raises: ValueError: If `kernel_init` is invalid. @@ -80,6 +84,7 @@ def __init__( self.monotonicity = monotonicity self.kernel_init = kernel_init self.projection_iterations = projection_iterations + self.input_keypoints_type = input_keypoints_type # Determine default output initialization values if bounds are not fully set. 
if output_min is not None and output_max is not None: @@ -94,6 +99,15 @@ def __init__( self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1]) self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1]) + if self.input_keypoints_type == InputKeypointsType.LEARNED: + self._keypoint_min = input_keypoints[0] + self._keypoint_range = input_keypoints[-1] - input_keypoints[0] + initial_logits = torch.from_numpy( + np.log( + (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range + ) + ).double() + self._interpolation_logits = torch.nn.Parameter(initial_logits) # First row of the kernel represents the bias. The remaining rows represent # the y-value delta compared to the previous point i.e. the segment heights. @@ -136,6 +150,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ + if self.input_keypoints_type == InputKeypointsType.LEARNED: + softmaxed_logits = torch.nn.functional.softmax( + self._interpolation_logits, dim=-1 + ) + self._lengths = softmaxed_logits * self._keypoint_range + interior_keypoints = ( + torch.cumsum(self._lengths, dim=-1) + self._keypoint_min + ) + self._interpolation_keypoints = torch.cat( + [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]] + ) + interpolation_weights = (x - self._interpolation_keypoints) / self._lengths interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0)) interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0)) diff --git a/tests/layers/test_numerical_calibrator.py b/tests/layers/test_numerical_calibrator.py index 64d6549..592dbe0 100644 --- a/tests/layers/test_numerical_calibrator.py +++ b/tests/layers/test_numerical_calibrator.py @@ -4,6 +4,7 @@ import torch from pytorch_lattice import Monotonicity, NumericalCalibratorInit +from pytorch_lattice.enums import InputKeypointsType from pytorch_lattice.layers import NumericalCalibrator 
from ..testing_utils import train_calibrated_module @@ -75,6 +76,35 @@ def test_initialization( assert calibrator.projection_iterations == projection_iterations +@pytest.mark.parametrize( + "input_keypoints, expected_lengths, expected_logits", + [ + ( + np.linspace(1.0, 5.0, num=5), + torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.double), + torch.from_numpy(np.log([0.25, 0.25, 0.25, 0.25])).double(), + ), + ( + np.array([0.0, 1.5, 2.0, 2.4, 3.0]), + torch.tensor([1.5, 0.5, 0.4, 0.6], dtype=torch.double), + torch.from_numpy( + np.log([1.5 / 3.0, 0.5 / 3.0, 0.4 / 3.0, 0.6 / 3.0]) + ).double(), + ), + ], +) +def test_initialization_learned_input_keypoints( + input_keypoints, expected_lengths, expected_logits +): + """Tests logic specific to learned input keypoint initialization.""" + calibrator = NumericalCalibrator( + input_keypoints=input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED + ) + + assert torch.allclose(calibrator._lengths, expected_lengths) + assert torch.allclose(calibrator._interpolation_logits, expected_logits) + + @pytest.mark.parametrize( "input_keypoints,kernel_init,kernel_data,inputs,expected_outputs", [ @@ -170,6 +200,34 @@ def test_forward(input_keypoints, kernel_init, kernel_data, inputs, expected_out assert torch.allclose(outputs, expected_outputs) +@pytest.mark.parametrize( + "input_keypoints", + [ + (np.linspace(1, 4, num=4)), + (np.array([0.0, 1.2, 2.0, 3.7, 5.0])), + (np.linspace(1, 20, num=45)), + (np.array([0.0, 0.02, 0.023, 3.7, 5.0, 7.9, 9.9, 10.3, 12.4, 15.6, 51.2])), + ], +) +@pytest.mark.parametrize( + "x", + [(torch.tensor([[1.1]])), (torch.tensor([[1.2], [1.3]]))], +) +def test_forward_learned_input_keypoints(input_keypoints, x): + """Tests that learned input keypoints are properly reconstructed in forward.""" + calibrator = NumericalCalibrator( + input_keypoints, input_keypoints_type=InputKeypointsType.LEARNED + ) + calibrator.forward(x) + assert ( + abs(torch.sum(calibrator._lengths).item() - 
calibrator._keypoint_range) < 1e-6 + ) + assert torch.allclose( + torch.from_numpy(input_keypoints[:-1]).double(), + calibrator._interpolation_keypoints, + ) + + @pytest.mark.parametrize( "kernel_data,monotonicity, expected_out", [ @@ -839,3 +897,56 @@ def test_training(): assert torch.allclose( torch.absolute(keypoints_inputs), keypoints_outputs, atol=2e-2 ) + + +def test_training_learned_interior_input_keypoints(): + """Tests that `NumericalCalibrator` successfully learns interior input keypoints. + The calibrator is given a piecewise linear step function that starts at (0, 0), + linearly rises from (1/3, 0) to (2/3, 1), then stays at 1. The calibrator is + initialized with inaccurate interior keypoints [0.1, 0.9] and then tested to see if + it can learn the ideal interior keypoint positions of [1/3, 2/3]. + """ + num_examples = 1000 + output_min, output_max = 0.0, 1.0 + training_examples = torch.linspace(output_min, output_max, num_examples)[ + :, None + ].double() + training_labels = torch.where( + training_examples < 1 / 3, + torch.zeros_like(training_examples), + torch.where( + training_examples > 2 / 3, + torch.ones_like(training_examples), + 3 * training_examples - 1, + ), + ).double() + noise = torch.randn_like(training_labels) * 0.05 + training_labels += noise + + calibrator = NumericalCalibrator( + np.array([0.0, 0.1, 0.9, 1.0]), + output_min=output_min, + output_max=output_max, + monotonicity=None, + input_keypoints_type=InputKeypointsType.LEARNED, + ) + + loss_fn = torch.nn.MSELoss() + optimizer = torch.optim.Adam(calibrator.parameters(), lr=1e-3) + + train_calibrated_module( + calibrator, + training_examples, + training_labels, + loss_fn, + optimizer, + 300, # Number of epochs + num_examples // 10, # Batch size + ) + + # Test that the learned keypoints roughly match the expected ones + assert torch.allclose( + calibrator._interpolation_keypoints, + torch.tensor([0, 1 / 3, 2 / 3]).double(), + atol=0.02, + )