Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DPConvCNP #15

Draft
wants to merge 44 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
a420224
added dpsetconv and privacy accounting
stratisMarkou Oct 25, 2022
1633c2a
adding privacy accounting functionality
stratisMarkou Oct 28, 2022
051c123
changes for running experiments
stratisMarkou Nov 1, 2022
4e0008f
adding requirements
stratisMarkou Nov 1, 2022
7e6e543
changes for DP Setconv
stratisMarkou Jan 8, 2023
436605a
pushing changes for variable-lengthscale training
stratisMarkou Feb 2, 2023
e951686
mixture of distributions working
stratisMarkou Feb 8, 2023
04d1a7d
small changes to predefined and experiment util plots
stratisMarkou Feb 14, 2023
878cbbf
changing hypers for GP comparison
stratisMarkou Feb 14, 2023
2f55c04
amortising noise level
stratisMarkou Feb 20, 2023
1ae7886
added messy plotting code and MLP amortisation
stratisMarkou Feb 28, 2023
32cb2c5
add epsilon=3 to plotting
stratisMarkou Mar 1, 2023
84e1f0d
made MLP simpler
stratisMarkou Mar 1, 2023
a9ac945
added separate optimisers for encoder setconv and rest of model
stratisMarkou Mar 22, 2023
ea3bfaa
changes to train script
stratisMarkou Apr 4, 2023
e485beb
add fake scale logging
stratisMarkou Apr 4, 2023
3442fd3
add fake scale
stratisMarkou Apr 4, 2023
7542e72
debugging changes
stratisMarkou Apr 4, 2023
d9ffeb7
resolve conflicts
stratisMarkou Apr 4, 2023
7481357
fix forgotten conflict
stratisMarkou Apr 4, 2023
afbb424
minor changes
stratisMarkou May 31, 2023
88d700d
changes to train.py
stratisMarkou May 31, 2023
311a4b9
Merge branch 'main' into dpconvcnp
stratisMarkou May 31, 2023
49eee90
add tensorboard logging
stratisMarkou May 31, 2023
884a81b
changes to setconv and train script
stratisMarkou Jun 2, 2023
dcc3be7
Replacing stheno with gpytorch in DPSetConv
Jun 2, 2023
6291fa9
doing some cleanup
stratisMarkou Jun 4, 2023
cd1cd60
finish comments and docstrings
stratisMarkou Jun 4, 2023
e37dfc2
more cleanup and piping
stratisMarkou Jun 4, 2023
a828729
couple more comments
stratisMarkou Jun 4, 2023
532bfe0
arg piping works
stratisMarkou Jun 4, 2023
fec6d43
changes to train script
stratisMarkou Jun 4, 2023
794f383
fix bugs
stratisMarkou Jun 4, 2023
21badab
fix parameter args
stratisMarkou Jun 4, 2023
76f3a1d
add some more changes
stratisMarkou Jun 12, 2023
63e52fe
SGD to Adam and experiment name
stratisMarkou Jun 14, 2023
5bce65c
add command to train
stratisMarkou Jun 14, 2023
e17d013
change to comment
stratisMarkou Jun 14, 2023
6e771f7
update train.py comments
stratisMarkou Jun 15, 2023
dabea18
add clipping
stratisMarkou Jun 15, 2023
b3aef69
update comments
stratisMarkou Jun 15, 2023
617c28a
Update train.py
stratisMarkou Jun 15, 2023
13366b4
split batch into two
stratisMarkou Jun 29, 2023
1a591e4
Merge branch 'dpconvcnp-batch-size' of https://github.com/wesselb/neu…
stratisMarkou Jun 29, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ cover
*.swp
.vscode/*
venv-np
_experiments
_experiments*
*.ipynb_checkpoints
*__pycache__
_*
60 changes: 49 additions & 11 deletions experiment/data/gp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from functools import partial
from itertools import product
import numpy as np

import torch

Expand All @@ -18,6 +20,7 @@ def setup(name, args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval,
config["unet_strides"] = (2,) * 6
config["conv_receptive_field"] = 4
config["margin"] = 0.1

if args.dim_x == 1:
config["points_per_unit"] = 64
elif args.dim_x == 2:
Expand All @@ -38,14 +41,18 @@ def setup(name, args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval,

gen_train = nps.construct_predefined_gens(
torch.float32,
seed=10,
seed=1, # 10
batch_size=args.batch_size,
num_tasks=num_tasks_train,
dim_x=args.dim_x,
dim_y=args.dim_y,
pred_logpdf=False,
pred_logpdf_diag=False,
device=device,
dp_epsilon_range=config["dp_epsilon_range"],
dp_log10_delta_range=config["dp_log10_delta_range"],
min_log10_scale=config["min_log10_scale"],
max_log10_scale=config["max_log10_scale"],
mean_diff=config["mean_diff"],
)[name]

Expand All @@ -59,13 +66,45 @@ def setup(name, args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval,
pred_logpdf=True,
pred_logpdf_diag=True,
device=device,
dp_epsilon_range=config["dp_epsilon_range"],
dp_log10_delta_range=config["dp_log10_delta_range"],
min_log10_scale=config["min_log10_scale"],
max_log10_scale=config["max_log10_scale"],
mean_diff=config["mean_diff"],
)[name]

#def gens_eval():
# return [
# (
# eval_name,
# nps.construct_predefined_gens(
# torch.float32,
# seed=30, # Use yet another seed!
# batch_size=args.batch_size,
# num_tasks=num_tasks_eval,
# dim_x=args.dim_x,
# dim_y=args.dim_y,
# pred_logpdf=True,
# pred_logpdf_diag=True,
# device=device,
# dp_epsilon_range=config["dp_epsilon_range"],
# dp_log10_delta_range=config["dp_log10_delta_range"],
# x_range_context=x_range_context,
# x_range_target=x_range_target,
# mean_diff=config["mean_diff"],
# )[args.data],
# )
# for eval_name, x_range_context, x_range_target in [
# ("interpolation in training range", (-2, 2), (-2, 2)),
# ("interpolation beyond training range", (2, 6), (2, 6)),
# ("extrapolation beyond training range", (-2, 2), (2, 6)),
# ]
# ]

def gens_eval():
return [
(
eval_name,
f"Using scale = {10**log_10_scale:.3f}, epsilon = {fixed_epsilon}",
nps.construct_predefined_gens(
torch.float32,
seed=30, # Use yet another seed!
Expand All @@ -76,16 +115,13 @@ def gens_eval():
pred_logpdf=True,
pred_logpdf_diag=True,
device=device,
x_range_context=x_range_context,
x_range_target=x_range_target,
mean_diff=config["mean_diff"],
)[args.data],
dp_epsilon_range=(fixed_epsilon, fixed_epsilon),
dp_log10_delta_range=config["dp_log10_delta_range"],
min_log10_scale=log_10_scale,
max_log10_scale=log_10_scale,
)["scale-mix-eq"], # [args.data],
)
for eval_name, x_range_context, x_range_target in [
("interpolation in training range", (-2, 2), (-2, 2)),
("interpolation beyond training range", (2, 6), (2, 6)),
("extrapolation beyond training range", (-2, 2), (2, 6)),
]
for fixed_epsilon, log_10_scale in product([9.], np.log10(np.array([0.10, 0.15, 0.20, 0.25, 0.35, 0.5, 1.0])))
]

return gen_train, gen_cv, gens_eval
Expand All @@ -96,6 +132,8 @@ def gens_eval():
"weakly-periodic",
"mix-eq",
"mix-matern",
"scale-mix-eq",
"scale-mix-matern",
"mix-weakly-periodic",
"sawtooth",
"mixture",
Expand Down
48 changes: 39 additions & 9 deletions experiment/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def visualise_1d(model, gen, *, path, config, predict):
nps.AggregateInput(
*((x[None, None, :], i) for i in range(config["dim_y"]))
),
epsilon=batch["epsilon"],
delta=batch["delta"],
)

plt.figure(figsize=(8, 6 * config["dim_y"]))
Expand All @@ -66,13 +68,13 @@ def visualise_1d(model, gen, *, path, config, predict):
s=20,
)

plt.scatter(
nps.batch_xt(batch, i)[0, 0],
nps.batch_yt(batch, i)[0],
label="Target",
style="test",
s=20,
)
# plt.scatter(
# nps.batch_xt(batch, i)[0, 0],
# nps.batch_yt(batch, i)[0],
# label="Target",
# style="test",
# s=20,
# )

# Plot prediction.
err = 1.96 * B.sqrt(var[i][0, 0])
Expand All @@ -97,8 +99,14 @@ def visualise_1d(model, gen, *, path, config, predict):
)

# Plot prediction by ground truth.
if hasattr(gen, "kernel") and config["dim_y"] == 1:
f = stheno.GP(gen.kernel)
if (hasattr(gen, "kernel") or hasattr(gen, "kernel_type")) and config["dim_y"] == 1:

if hasattr(gen, "kernel_type"):
f = stheno.GP(gen.kernel_type().stretch(batch["scale"]))

else:
f = stheno.GP(gen.kernel)

# Make sure that everything is of `float64`s and on the GPU.
noise = B.to_active_device(B.cast(torch.float64, gen.noise))
xc = B.cast(torch.float64, nps.batch_xc(batch, 0)[0, 0])
Expand All @@ -107,13 +115,33 @@ def visualise_1d(model, gen, *, path, config, predict):
# Compute posterior GP.
f_post = f | (f(xc, noise), yc)
mean, lower, upper = f_post(x).marginal_credible_bounds()

lower = mean - 1.96 * (((mean - lower)/1.96)**2. + noise)**0.5
upper = mean + 1.96 * (((upper - mean)/1.96)**2. + noise)**0.5

plt.plot(x, mean, label="Truth", style="pred2")
plt.plot(x, lower, style="pred2")
plt.plot(x, upper, style="pred2")

for x_axvline in plot_config["axvline"]:
plt.axvline(x_axvline, c="k", ls="--", lw=0.5)

nps.batch_yt(batch, i)[0],

N = nps.batch_yc(batch, i)[0].shape[0]
ell = 0.25 if 'scale' not in batch else batch['scale'].detach().cpu().numpy()[0]
epsilon = batch['epsilon'][i].numpy()[0]
delta = batch['delta'][i].numpy()[0]

plt.gca().set_title(
f"$N = {N:.0f}$ " + \
f"$\\ell = {ell:.3f}$ " + \
f"$\\epsilon = {epsilon:.2f}$ " + \
f"$N\\ell\\epsilon \\approx {N*ell*epsilon:.0f}$ " + \
f"$\\delta = {delta:.3f}$",
fontsize=24,
)

plt.xlim(B.min(x), B.max(x))
tweak()

Expand Down Expand Up @@ -149,6 +177,8 @@ def visualise_2d(model, gen, *, path, config, predict):
nps.AggregateInput(
*((x_list[None, :, :], i) for i in range(config["dim_y"]))
),
epsilon=batch["epsilon"],
delta=batch["delta"],
num_samples=2,
)

Expand Down
66 changes: 56 additions & 10 deletions neuralprocesses/architectures/convgnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,14 @@ def _convgnp_construct_encoder_setconvs(
dim_yc,
disc,
dtype=None,
learnable_scale=True,
use_dp=False,
dp_learn_params=True,
dp_amortise_params=True,
dp_use_noise_channels=False,
dp_y_bound=None,
dp_t=None,
init_factor=1,
encoder_scales_learnable=True,
):
# Initialise scale.
if encoder_scales is not None:
Expand All @@ -51,13 +57,33 @@ def _convgnp_construct_encoder_setconvs(
# Ensure that there is one for every context set.
if not isinstance(encoder_scales, (tuple, list)):
encoder_scales = (encoder_scales,) * len(dim_yc)
# Construct set convs.
return nps.Parallel(
*(
nps.SetConv(s, dtype=dtype, learnable=encoder_scales_learnable)
for s in encoder_scales

if use_dp:
# Construct DP set convs.
return nps.Parallel(
*(
nps.DPSetConv(
s,
y_bound=dp_y_bound,
t=dp_t,
dp_learn_params=dp_learn_params,
dp_amortise_params=dp_amortise_params,
learnable_scale=learnable_scale,
dp_use_noise_channels=dp_use_noise_channels,
dtype=dtype,
)
for s in encoder_scales
)
)

else:
# Construct set convs.
return nps.Parallel(
*(
nps.SetConv(s, dtype=dtype, learnable=learnable_scale)
for s in encoder_scales
)
)
)


def _convgnp_assert_form_contexts(nps, dim_yc):
Expand Down Expand Up @@ -116,7 +142,7 @@ def construct_convgnp(
dim_lv=0,
lv_likelihood="het",
encoder_scales=None,
encoder_scales_learnable=True,
encoder_scale_learnable=True,
decoder_scale=None,
decoder_scale_learnable=True,
aux_t_mlp_layers=(128,) * 3,
Expand All @@ -125,6 +151,12 @@ def construct_convgnp(
transform=None,
dtype=None,
nps=nps,
use_dp=False,
dp_learn_params=True,
dp_amortise_params=False,
dp_use_noise_channels=True,
dp_y_bound=None,
dp_t=None,
):
"""A Convolutional Gaussian Neural Process.

Expand Down Expand Up @@ -275,7 +307,7 @@ def construct_convgnp(
in_channels = dim_lv
out_channels = conv_out_channels # These must be equal!
else:
in_channels = conv_in_channels
in_channels = 2 * conv_in_channels if use_dp and dp_use_noise_channels else conv_in_channels
out_channels = conv_out_channels # These must be equal!
if "unet" in conv_arch:
if dim_lv > 0:
Expand Down Expand Up @@ -352,6 +384,14 @@ def construct_convgnp(
margin=margin,
dim=dim_x,
)

if use_dp and divide_by_density:

raise ValueError(
f"If use_dp=True, divide_by_density should be False, "
f"found {divide_by_density}."
)


# Construct model.
model = nps.Model(
Expand All @@ -366,7 +406,13 @@ def construct_convgnp(
dim_yc,
disc,
dtype,
encoder_scales_learnable=encoder_scales_learnable,
use_dp=use_dp,
learnable_scale=encoder_scale_learnable,
dp_learn_params=dp_learn_params,
dp_amortise_params=dp_amortise_params,
dp_use_noise_channels=dp_use_noise_channels,
dp_y_bound=dp_y_bound,
dp_t=dp_t,
),
_convgnp_optional_division_by_density(nps, divide_by_density, epsilon),
nps.Concatenate(),
Expand Down
33 changes: 33 additions & 0 deletions neuralprocesses/coders/setconv/privacy_accounting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from scipy import special
import scipy.optimize as optim

def delta(epsilon, sens_per_sigma):
    """Compute delta for given epsilon and sensitivity per noise standard deviation for the Gaussian mechanism.

    With `s = sens_per_sigma` and `Phi` the standard normal CDF, this evaluates

        delta = Phi(s / 2 - epsilon / s) - exp(epsilon) * Phi(-s / 2 - epsilon / s),

    which is algebraically the same quantity as
    `0.5 * (erfc((epsilon - mu) / (2 sqrt(mu))) - exp(epsilon) * erfc((epsilon + mu) / (2 sqrt(mu))))`
    with `mu = s**2 / 2`.

    Args:
        epsilon (float): Privacy parameter epsilon.
        sens_per_sigma (float): Sensitivity per noise standard deviation.

    Returns:
        float: Delta. Zero whenever `sens_per_sigma` is not positive.
    """
    if sens_per_sigma <= 0:
        return 0.0

    half_sens = sens_per_sigma / 2
    eps_per_sens = epsilon / sens_per_sigma

    term1 = special.ndtr(half_sens - eps_per_sens)
    # Fold `exp(epsilon)` into the exponent together with the log-CDF. Computing
    # `np.exp(epsilon)` on its own overflows for large `epsilon` while the CDF
    # factor underflows to zero, and `inf * 0` would produce `nan`.
    term2 = np.exp(epsilon + special.log_ndtr(-half_sens - eps_per_sens))
    return term1 - term2

def find_sens_per_sigma(epsilon, delta_bound, upper_bound=20):
    """Find the required sensitivity per noise standard deviation for (epsilon, delta)-DP with Gaussian mechanism.

    Solves `delta(epsilon, sens_per_sigma) == delta_bound` for `sens_per_sigma`
    with Brent's method on the bracket `[0, upper_bound]`. If `upper_bound` is
    too small to bracket the root, it is doubled a bounded number of times
    before the root finder is called.

    Args:
        epsilon (float): Privacy parameter epsilon.
        delta_bound (float): Target delta. Must satisfy `0 <= delta_bound < 1`.
        upper_bound (float, optional): Upper bound guess on sensitivity per sigma. Defaults to 20.

    Returns:
        float: The required sensitivity per noise standard deviation.

    Raises:
        ValueError: If `delta_bound` is outside `[0, 1)`, or if no bracket
            containing the root could be found.
    """
    if not 0 <= delta_bound < 1:
        raise ValueError(
            f"delta_bound must lie in [0, 1), found {delta_bound}."
        )

    def excess(sens_per_sigma):
        # Positive once `sens_per_sigma` is large enough to exceed the target.
        return delta(epsilon, sens_per_sigma) - delta_bound

    # `brentq` requires a sign change over the bracket. The lower end gives
    # `-delta_bound <= 0`, so the upper end must be non-negative. Only grow the
    # bracket when the caller's guess does not already satisfy this, so that
    # results are unchanged whenever the original bracket was valid.
    max_doublings = 32
    bracket_hi = upper_bound
    for _ in range(max_doublings):
        if not excess(bracket_hi) < 0:
            break
        bracket_hi *= 2

    return optim.brentq(excess, 0, bracket_hi)
Loading