diff --git a/include/secp256k1.h b/include/secp256k1.h index 91cdd3672f597..3c4a311a05681 100644 --- a/include/secp256k1.h +++ b/include/secp256k1.h @@ -260,12 +260,10 @@ SECP256K1_API void secp256k1_context_set_error_callback( * * Returns: a newly created scratch space. * Args: ctx: an existing context object (cannot be NULL) - * In: init_size: initial amount of memory to allocate - * max_size: maximum amount of memory to allocate + * In: max_size: maximum amount of memory to allocate */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT secp256k1_scratch_space* secp256k1_scratch_space_create( const secp256k1_context* ctx, - size_t init_size, size_t max_size ) SECP256K1_ARG_NONNULL(1); diff --git a/src/bench_ecmult.c b/src/bench_ecmult.c index 3a7bfe379c67d..52d0476a30ffb 100644 --- a/src/bench_ecmult.c +++ b/src/bench_ecmult.c @@ -154,7 +154,7 @@ int main(int argc, char **argv) { /* Allocate stuff */ data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); scratch_size = secp256k1_strauss_scratch_size(POINTS) + STRAUSS_SCRATCH_OBJECTS*16; - data.scratch = secp256k1_scratch_space_create(data.ctx, scratch_size, scratch_size); + data.scratch = secp256k1_scratch_space_create(data.ctx, scratch_size); data.scalars = malloc(sizeof(secp256k1_scalar) * POINTS); data.seckeys = malloc(sizeof(secp256k1_scalar) * POINTS); data.pubkeys = malloc(sizeof(secp256k1_ge) * POINTS); diff --git a/src/ecmult_impl.h b/src/ecmult_impl.h index 22aec462e0334..d5fb6c5b61dd2 100644 --- a/src/ecmult_impl.h +++ b/src/ecmult_impl.h @@ -525,10 +525,9 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_ecmult_context *ctx, s return 1; } - if (!secp256k1_scratch_resize(scratch, secp256k1_strauss_scratch_size(n_points), STRAUSS_SCRATCH_OBJECTS)) { + if (!secp256k1_scratch_allocate_frame(scratch, secp256k1_strauss_scratch_size(n_points), STRAUSS_SCRATCH_OBJECTS)) { return 0; } - secp256k1_scratch_reset(scratch); points = (secp256k1_gej*)secp256k1_scratch_alloc(scratch, n_points * sizeof(secp256k1_gej)); scalars = (secp256k1_scalar*)secp256k1_scratch_alloc(scratch, n_points * sizeof(secp256k1_scalar)); state.prej = (secp256k1_gej*)secp256k1_scratch_alloc(scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_gej)); @@ -543,10 +542,14 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_ecmult_context *ctx, s for (i = 0; i < n_points; i++) { secp256k1_ge point; - if (!cb(&scalars[i], &point, i+cb_offset, cbdata)) return 0; + if (!cb(&scalars[i], &point, i+cb_offset, cbdata)) { + secp256k1_scratch_deallocate_frame(scratch); + return 0; + } secp256k1_gej_set_ge(&points[i], &point); } secp256k1_ecmult_strauss_wnaf(ctx, &state, r, n_points, points, scalars, inp_g_sc); + secp256k1_scratch_deallocate_frame(scratch); return 1; } @@ -873,10 +876,9 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, } bucket_window = secp256k1_pippenger_bucket_window(n_points); - if (!secp256k1_scratch_resize(scratch, secp256k1_pippenger_scratch_size(n_points, bucket_window), PIPPENGER_SCRATCH_OBJECTS)) { + if (!secp256k1_scratch_allocate_frame(scratch, secp256k1_pippenger_scratch_size(n_points, bucket_window), PIPPENGER_SCRATCH_OBJECTS)) { return 0; } - secp256k1_scratch_reset(scratch); points = (secp256k1_ge *) secp256k1_scratch_alloc(scratch, entries * sizeof(*points)); scalars = (secp256k1_scalar *) secp256k1_scratch_alloc(scratch, entries * sizeof(*scalars)); state_space = (struct secp256k1_pippenger_state *) secp256k1_scratch_alloc(scratch, sizeof(*state_space)); @@ -896,6 +898,7 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, while (point_idx < n_points) { if (!cb(&scalars[idx], &points[idx], point_idx + cb_offset, cbdata)) { + secp256k1_scratch_deallocate_frame(scratch); return 0; } idx++; @@ -919,6 +922,7 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, for(i = 0; i < 1<data = checked_malloc(error_callback, init_size); - if (ret->data == NULL) { - free (ret); - return NULL; - } - ret->offset = 0; - ret->init_size = init_size; + memset(ret, 0, sizeof(*ret)); ret->max_size = max_size; ret->error_callback = error_callback; } @@ -33,45 +27,60 @@ static secp256k1_scratch* secp256k1_scratch_create(const secp256k1_callback* err static void secp256k1_scratch_destroy(secp256k1_scratch* scratch) { if (scratch != NULL) { - free(scratch->data); + VERIFY_CHECK(scratch->frame == 0); free(scratch); } } static size_t secp256k1_scratch_max_allocation(const secp256k1_scratch* scratch, size_t objects) { - if (scratch->max_size <= objects * ALIGNMENT) { + size_t i = 0; + size_t allocated = 0; + for (i = 0; i < scratch->frame; i++) { + allocated += scratch->frame_size[i]; + } + if (scratch->max_size - allocated <= objects * ALIGNMENT) { return 0; } - return scratch->max_size - objects * ALIGNMENT; + return scratch->max_size - allocated - objects * ALIGNMENT; } -static int secp256k1_scratch_resize(secp256k1_scratch* scratch, size_t n, size_t objects) { - n += objects * ALIGNMENT; - if (n > scratch->init_size && n <= scratch->max_size) { - void *tmp = checked_realloc(scratch->error_callback, scratch->data, n); - if (tmp == NULL) { +static int secp256k1_scratch_allocate_frame(secp256k1_scratch* scratch, size_t n, size_t objects) { + VERIFY_CHECK(scratch->frame < SECP256K1_SCRATCH_MAX_FRAMES); + + if (n <= secp256k1_scratch_max_allocation(scratch, objects)) { + n += objects * ALIGNMENT; + scratch->data[scratch->frame] = checked_malloc(scratch->error_callback, n); + if (scratch->data[scratch->frame] == NULL) { return 0; } - scratch->init_size = n; - scratch->data = tmp; + scratch->frame_size[scratch->frame] = n; + scratch->offset[scratch->frame] = 0; + scratch->frame++; + return 1; + } else { + return 0; } - return n <= scratch->max_size; +} + +static void secp256k1_scratch_deallocate_frame(secp256k1_scratch* scratch) { + VERIFY_CHECK(scratch->frame > 0); + scratch->frame -= 1; + free(scratch->data[scratch->frame]); } static void *secp256k1_scratch_alloc(secp256k1_scratch* scratch, size_t size) { void *ret; + size_t frame = scratch->frame - 1; size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - if (size + scratch->offset > scratch->init_size) { + + if (scratch->frame == 0 || size + scratch->offset[frame] > scratch->frame_size[frame]) { return NULL; } - ret = (void *) ((unsigned char *) scratch->data + scratch->offset); + ret = (void *) ((unsigned char *) scratch->data[frame] + scratch->offset[frame]); memset(ret, 0, size); - scratch->offset += size; - return ret; -} + scratch->offset[frame] += size; -static void secp256k1_scratch_reset(secp256k1_scratch* scratch) { - scratch->offset = 0; + return ret; } #endif diff --git a/src/secp256k1.c b/src/secp256k1.c index 7e3ad572c8c3c..cd0972dfaf464 100644 --- a/src/secp256k1.c +++ b/src/secp256k1.c @@ -115,11 +115,9 @@ void secp256k1_context_set_error_callback(secp256k1_context* ctx, void (*fun)(co ctx->error_callback.data = data; } -secp256k1_scratch_space* secp256k1_scratch_space_create(const secp256k1_context* ctx, size_t init_size, size_t max_size) { +secp256k1_scratch_space* secp256k1_scratch_space_create(const secp256k1_context* ctx, size_t max_size) { VERIFY_CHECK(ctx != NULL); - ARG_CHECK(max_size >= init_size); - - return secp256k1_scratch_create(&ctx->error_callback, init_size, max_size); + return secp256k1_scratch_create(&ctx->error_callback, max_size); } void secp256k1_scratch_space_destroy(secp256k1_scratch_space* scratch) { diff --git a/src/tests.c b/src/tests.c index 213281cb3ce8c..e85c46058a692 100644 --- a/src/tests.c +++ b/src/tests.c @@ -258,28 +258,31 @@ void run_scratch_tests(void) { /* Test public API */ secp256k1_context_set_illegal_callback(none, counting_illegal_callback_fn, &ecount); - scratch = secp256k1_scratch_space_create(none, 100, 10); - CHECK(scratch == NULL); - CHECK(ecount == 1); - - scratch = secp256k1_scratch_space_create(none, 100, 100); - CHECK(scratch != NULL); - CHECK(ecount == 1); - secp256k1_scratch_space_destroy(scratch); - scratch = secp256k1_scratch_space_create(none, 100, 1000); + scratch = secp256k1_scratch_space_create(none, 1000); CHECK(scratch != NULL); - CHECK(ecount == 1); + CHECK(ecount == 0); /* Test internal API */ CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); CHECK(secp256k1_scratch_max_allocation(scratch, 1) < 1000); - CHECK(secp256k1_scratch_resize(scratch, 50, 1) == 1); /* no-op */ - CHECK(secp256k1_scratch_resize(scratch, 200, 1) == 1); - CHECK(secp256k1_scratch_resize(scratch, 950, 1) == 1); - CHECK(secp256k1_scratch_resize(scratch, 1000, 1) == 0); - CHECK(secp256k1_scratch_resize(scratch, 2000, 1) == 0); + + /* Allocating 500 bytes with no frame fails */ + CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); + CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); + + /* ...but pushing a new stack frame does affect the max allocation */ + CHECK(secp256k1_scratch_allocate_frame(scratch, 500, 1 == 1)); + CHECK(secp256k1_scratch_max_allocation(scratch, 1) < 500); /* 500 - ALIGNMENT */ + CHECK(secp256k1_scratch_alloc(scratch, 500) != NULL); + CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); + + CHECK(secp256k1_scratch_allocate_frame(scratch, 500, 1) == 0); + + /* ...and this effect is undone by popping the frame */ + secp256k1_scratch_deallocate_frame(scratch); CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); + CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); /* cleanup */ secp256k1_scratch_space_destroy(scratch); @@ -2558,7 +2561,6 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e data.sc = sc; data.pt = pt; secp256k1_scalar_set_int(&szero, 0); - secp256k1_scratch_reset(scratch); /* No points to multiply */ CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, NULL, ecmult_multi_callback, &data, 0)); @@ -2590,7 +2592,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e CHECK(secp256k1_gej_is_infinity(&r)); /* Try to multiply 1 point, but scratch space is empty */ - scratch_empty = secp256k1_scratch_create(&ctx->error_callback, 0, 0); + scratch_empty = secp256k1_scratch_create(&ctx->error_callback, 0); CHECK(!ecmult_multi(&ctx->ecmult_ctx, scratch_empty, &r, &szero, ecmult_multi_callback, &data, 1)); secp256k1_scratch_destroy(scratch_empty); @@ -2816,7 +2818,7 @@ void test_ecmult_multi_pippenger_max_points(void) { int bucket_window = 0; for(; scratch_size < max_size; scratch_size+=256) { - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, scratch_size); + scratch = secp256k1_scratch_create(&ctx->error_callback, scratch_size); CHECK(scratch != NULL); n_points_supported = secp256k1_pippenger_max_points(scratch); if (n_points_supported == 0) { @@ -2824,7 +2826,8 @@ void test_ecmult_multi_pippenger_max_points(void) { continue; } bucket_window = secp256k1_pippenger_bucket_window(n_points_supported); - CHECK(secp256k1_scratch_resize(scratch, secp256k1_pippenger_scratch_size(n_points_supported, bucket_window), PIPPENGER_SCRATCH_OBJECTS)); + CHECK(secp256k1_scratch_allocate_frame(scratch, secp256k1_pippenger_scratch_size(n_points_supported, bucket_window), PIPPENGER_SCRATCH_OBJECTS)); + secp256k1_scratch_deallocate_frame(scratch); secp256k1_scratch_destroy(scratch); } CHECK(bucket_window == PIPPENGER_MAX_BUCKET_WINDOW); @@ -2866,13 +2869,13 @@ void test_ecmult_multi_batching(void) { data.pt = pt; /* Test with empty scratch space */ - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, 0); + scratch = secp256k1_scratch_create(&ctx->error_callback, 0); CHECK(!secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, 1)); secp256k1_scratch_destroy(scratch); /* Test with space for 1 point in pippenger. That's not enough because * ecmult_multi selects strauss which requires more memory. */ - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, secp256k1_pippenger_scratch_size(1, 1) + PIPPENGER_SCRATCH_OBJECTS*ALIGNMENT); + scratch = secp256k1_scratch_create(&ctx->error_callback, secp256k1_pippenger_scratch_size(1, 1) + PIPPENGER_SCRATCH_OBJECTS*ALIGNMENT); CHECK(!secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, 1)); secp256k1_scratch_destroy(scratch); @@ -2881,10 +2884,10 @@ void test_ecmult_multi_batching(void) { if (i > ECMULT_PIPPENGER_THRESHOLD) { int bucket_window = secp256k1_pippenger_bucket_window(i); size_t scratch_size = secp256k1_pippenger_scratch_size(i, bucket_window); - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, scratch_size + PIPPENGER_SCRATCH_OBJECTS*ALIGNMENT); + scratch = secp256k1_scratch_create(&ctx->error_callback, scratch_size + PIPPENGER_SCRATCH_OBJECTS*ALIGNMENT); } else { size_t scratch_size = secp256k1_strauss_scratch_size(i); - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, scratch_size + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); + scratch = secp256k1_scratch_create(&ctx->error_callback, scratch_size + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); } CHECK(secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, n_points)); secp256k1_gej_add_var(&r, &r, &r2, NULL); @@ -2900,14 +2903,14 @@ void run_ecmult_multi_tests(void) { test_secp256k1_pippenger_bucket_window_inv(); test_ecmult_multi_pippenger_max_points(); - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, 819200); + scratch = secp256k1_scratch_create(&ctx->error_callback, 819200); test_ecmult_multi(scratch, secp256k1_ecmult_multi_var); test_ecmult_multi(scratch, secp256k1_ecmult_pippenger_batch_single); test_ecmult_multi(scratch, secp256k1_ecmult_strauss_batch_single); secp256k1_scratch_destroy(scratch); /* Run test_ecmult_multi with space for exactly one point */ - scratch = secp256k1_scratch_create(&ctx->error_callback, 0, secp256k1_strauss_scratch_size(1) + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); + scratch = secp256k1_scratch_create(&ctx->error_callback, secp256k1_strauss_scratch_size(1) + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); test_ecmult_multi(scratch, secp256k1_ecmult_multi_var); secp256k1_scratch_destroy(scratch); diff --git a/src/tests_exhaustive.c b/src/tests_exhaustive.c index 70795795b6cb1..ab9779b02fc54 100644 --- a/src/tests_exhaustive.c +++ b/src/tests_exhaustive.c @@ -196,7 +196,7 @@ static int ecmult_multi_callback(secp256k1_scalar *sc, secp256k1_ge *pt, size_t void test_exhaustive_ecmult_multi(const secp256k1_context *ctx, const secp256k1_ge *group, int order) { int i, j, k, x, y; - secp256k1_scratch *scratch = secp256k1_scratch_create(&ctx->error_callback, 1024, 4096); + secp256k1_scratch *scratch = secp256k1_scratch_create(&ctx->error_callback, 4096); for (i = 0; i < order; i++) { for (j = 0; j < order; j++) { for (k = 0; k < order; k++) {