Commit 332ee71

chore: add weights clipping
RomanBredehoft committed May 28, 2024
1 parent 29d1665 commit 332ee71
Showing 3 changed files with 10 additions and 5 deletions.
docs/advanced_examples/LogisticRegressionTraining.ipynb (2 changes: 1 addition & 1 deletion)

@@ -372,7 +372,7 @@
"\n",
"batch_size = model_concrete_partial.batch_size\n",
"\n",
"classes = np.unique(y2_train)\n",
"classes = np.unique(y_train)\n",
"\n",
"# Go through the training batches\n",
"accuracy_scores = []\n",
src/concrete/ml/sklearn/_fhe_training_utils.py (9 changes: 7 additions & 2 deletions)

@@ -1,6 +1,6 @@
"""Utility functions for FHE training."""

from typing import Tuple
from typing import List, Optional, Tuple

import numpy
import torch
@@ -29,7 +29,12 @@ class LogisticRegressionTraining(torch.nn.Module):
     The forward function iterates SGD a given number of times.
     """
 
-    def __init__(self, iterations: int = 1, learning_rate: float = 1.0, fit_bias: bool = True):
+    def __init__(
+        self,
+        iterations: int = 1,
+        learning_rate: float = 1.0,
+        fit_bias: bool = True,
+    ):
         """Instantiate the model.
 
         Args:
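For context, a minimal sketch of what such an SGD-iterating forward pass can look like for logistic regression. SGDLogisticSketch is a hypothetical stand-in, not the actual LogisticRegressionTraining implementation; the tensor shapes follow the (1, n_features, 1) weight and (1, 1, 1) bias layout used further down in linear_model.py:

import torch

class SGDLogisticSketch(torch.nn.Module):
    """Hypothetical stand-in for LogisticRegressionTraining."""

    def __init__(self, iterations: int = 1, learning_rate: float = 1.0, fit_bias: bool = True):
        super().__init__()
        self.iterations = iterations
        self.learning_rate = learning_rate
        self.fit_bias = fit_bias

    def forward(self, x, y, weights, bias):
        # x: (1, n_samples, n_features), y: (1, n_samples, 1)
        # weights: (1, n_features, 1), bias: (1, 1, 1)
        for _ in range(self.iterations):
            logits = x @ weights + (bias if self.fit_bias else 0.0)
            error = torch.sigmoid(logits) - y

            # Gradient of the logistic loss, averaged over the batch
            grad_w = x.transpose(-1, -2) @ error / x.shape[-2]
            weights = weights - self.learning_rate * grad_w

            if self.fit_bias:
                grad_b = error.mean(dim=-2, keepdim=True)
                bias = bias - self.learning_rate * grad_b

        return weights, bias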
src/concrete/ml/sklearn/linear_model.py (4 changes: 2 additions & 2 deletions)

@@ -253,7 +253,7 @@ def __init__(

         if self.parameters_range is None:
             raise ValueError(
-                "Setting 'parameter_range' is mandatory if FHE training is enabled "
+                "Setting 'parameters_range' is mandatory if FHE training is enabled "
                 f"({fit_encrypted=}). Got {parameters_range=}"
             )

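This hunk corrects the error message to match the actual parameter name, parameters_range. As a usage sketch (the range value is illustrative, and the constructor form assumed here may vary between Concrete ML versions), FHE training expects that range to be set:

from concrete.ml.sklearn import SGDClassifier

# parameters_range bounds the weight/bias values during encrypted training;
# (-1.0, 1.0) is an illustrative choice, not a recommendation
model = SGDClassifier(fit_encrypted=True, parameters_range=(-1.0, 1.0))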
@@ -358,6 +358,7 @@ def _get_training_quantized_module(
         # Enable the underlying FHE circuit to be composed with itself
         # This feature is used in order to be able to iterate in the clear n times without having
         # to encrypt/decrypt the weight/bias values between each loop
+        # configuration = Configuration(composable=True, detect_overflow_in_simulation=False)
         configuration = Configuration(composable=True)
 
         # Compile the model using the compile set
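For intuition, the composition feature can be sketched with a toy Concrete Python circuit, where a single encrypted counter stands in for the weight/bias state that the training circuit carries across iterations (assumes a concrete-python version with composition support):

from concrete import fhe

@fhe.compiler({"counter": "encrypted"})
def step(counter):
    return (counter + 1) % 16

# composable=True lets the circuit's encrypted outputs be fed back as inputs
circuit = step.compile(range(16), composable=True)

state = circuit.encrypt(0)
for _ in range(5):
    state = circuit.run(state)  # state stays encrypted between iterations

assert circuit.decrypt(state) == 5

This is exactly what the training loop relies on: the weights and bias produced by one FHE iteration are reused as encrypted inputs of the next, with no decryption in between.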
@@ -446,7 +447,6 @@ def _fit_encrypted(
f" was: {self.classes_}"
)


n_samples, n_features = X.shape
weight_shape = (1, n_features, 1)
bias_shape = (1, 1, 1)
