Allow slicing with custom step size #59

Merged: 5 commits, Aug 28, 2023

Changes from all commits
12 changes: 10 additions & 2 deletions src/hdf5_hl.d.ts
@@ -21,6 +21,14 @@ export declare type Dtype = string | {
};
export type { Metadata, Filter, CompoundMember, CompoundTypeMetadata, EnumTypeMetadata };
declare type TypedArray = Int8Array | Uint8Array | Uint8ClampedArray | Int16Array | Uint16Array | Int32Array | Uint32Array | BigInt64Array | BigUint64Array | Float32Array | Float64Array;
+/**
+ * Describes an array slice.
+ * `[]` - all data
+ * `[i0]` - select all data starting from the index `i0`
+ * `[i0, i1]` - select all data in the range `i0` to `i1` (`i1` exclusive)
+ * `[i0, i1, s]` - select every `s`-th value in the range `i0` to `i1`
+ **/
+declare type Slice = [] | [number | null] | [number | null, number | null] | [number | null, number | null, number | null];
export declare type GuessableDataTypes = TypedArray | number | number[] | string | string[];
declare enum OBJECT_TYPE {
DATASET = "Dataset",
@@ -116,8 +124,8 @@ export declare class Dataset extends HasAttrs {
    get filters(): Filter[];
    get value(): OutputData;
    get json_value(): JSONCompatibleOutputData;
-    slice(ranges: Array<Array<number>>): OutputData;
-    write_slice(ranges: Array<Array<number>>, data: any): void;
+    slice(ranges: Slice[]): OutputData;
+    write_slice(ranges: Slice[], data: any): void;
    to_array(): string | number | boolean | JSONCompatibleOutputData[];
    resize(new_shape: number[]): number;
    _value_getter(json_compatible?: boolean): OutputData;
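For orientation (not part of the diff): a minimal sketch of how the retyped slice API reads at a call site. The file name and dataset path here are hypothetical, and per the count arithmetic in hdf5_hl.ts below, stop indices are exclusive.

import h5wasm from "h5wasm";

await h5wasm.ready;
const f = new h5wasm.File("data.h5", "r"); // hypothetical file
const dset = f.get("some_dataset");        // hypothetical path
if (dset instanceof h5wasm.Dataset) {
  dset.slice([[]]);              // all data along the first axis
  dset.slice([[2]]);             // from index 2 to the end
  dset.slice([[2, 8]]);          // indices 2..7 (stop is exclusive)
  dset.slice([[2, 8, 3]]);       // every 3rd value: indices 2 and 5
  dset.slice([[null, null, 2]]); // every 2nd value along the whole axis
}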
52 changes: 33 additions & 19 deletions src/hdf5_hl.ts
@@ -388,6 +388,15 @@ const TypedArray_to_dtype = new Map([
  ['Float64Array', '<d']
])

+/**
+ * Describes an array slice.
+ * `[]` - all data
+ * `[i0]` - select all data starting from the index `i0`
+ * `[i0, i1]` - select all data in the range `i0` to `i1` (`i1` exclusive)
+ * `[i0, i1, s]` - select every `s`-th value in the range `i0` to `i1`
+ **/
+type Slice = [] | [number | null] | [number | null, number | null] | [number | null, number | null, number | null];

export type GuessableDataTypes = TypedArray | number | number[] | string | string[];

function guess_dtype(data: GuessableDataTypes): string {
@@ -786,6 +795,17 @@ export class File extends Group {
  }
}

+const calculateHyperslabParams = (shape: number[], ranges: Slice[]) => {
+  const strides = shape.map((s, i) => BigInt(ranges?.[i]?.[2] ?? 1));
+  const count = shape.map((s, i) => {
+    const N = BigInt(Math.min(s, ranges?.[i]?.[1] ?? s) - Math.max(0, ranges?.[i]?.[0] ?? 0));
+    const st = strides[i];
+    // number of selected elements: ceil(N / st), via BigInt ceiling division
+    return N / st + ((N % st) + st - 1n) / st;
+  });
+  const offset = shape.map((s, i) => BigInt(Math.min(s, Math.max(0, ranges?.[i]?.[0] ?? 0))));
+  return { strides, count, offset };
+}

export class Dataset extends HasAttrs {
  private _metadata?: Metadata;

@@ -831,20 +851,18 @@
    return this._value_getter(true) as JSONCompatibleOutputData;
  }

-  slice(ranges: Array<Array<number>>) {
+  slice(ranges: Slice[]) {
    // interpret ranges as [start, stop] or [start, stop, step], with one per dim.
-    let metadata = this.metadata;
+    const metadata = this.metadata;
    // if auto_refresh is on, getting the metadata has triggered a refresh of the dataset_id;
    const { shape } = metadata;
-    let ndims = shape.length;
-    let count = shape.map((s, i) => BigInt(Math.min(s, ranges?.[i]?.[1] ?? s) - Math.max(0, ranges?.[i]?.[0] ?? 0)));
-    let offset = shape.map((s, i) => BigInt(Math.min(s, Math.max(0, ranges?.[i]?.[0] ?? 0))));
-    let total_size = count.reduce((previous, current) => current * previous, 1n);
-    let nbytes = metadata.size * Number(total_size);
-    let data_ptr = Module._malloc(nbytes);
-    var processed;
+    const { strides, count, offset } = calculateHyperslabParams(shape, ranges);
+    const total_size = count.reduce((previous, current) => current * previous, 1n);
+    const nbytes = metadata.size * Number(total_size);
+    const data_ptr = Module._malloc(nbytes);
+    let processed;
    try {
-      Module.get_dataset_data(this.file_id, this.path, count, offset, BigInt(data_ptr));
+      Module.get_dataset_data(this.file_id, this.path, count, offset, strides, BigInt(data_ptr));
      let data = Module.HEAPU8.slice(data_ptr, data_ptr + nbytes);
      processed = process_data(data, metadata, false);
    } finally {
@@ -856,26 +874,22 @@
    return processed;
  }

-  write_slice(ranges: Array<Array<number>>, data: any) {
+  write_slice(ranges: Slice[], data: any) {
    // interpret ranges as [start, stop] or [start, stop, step], with one per dim.
    let metadata = this.metadata;
    if (metadata.vlen) {
      throw new Error("writing to a slice of vlen dtype is not implemented");
    }
-    // if auto_refresh is on, getting the metadata has triggered a refresh of the dataset_id;
    const { shape } = metadata;
-    let ndims = shape.length;
-    let count = shape.map((s, i) => BigInt(Math.min(s, ranges?.[i]?.[1] ?? s) - Math.max(0, ranges?.[i]?.[0] ?? 0)));
-    let offset = shape.map((s, i) => BigInt(Math.min(s, Math.max(0, ranges?.[i]?.[0] ?? 0))));
-    let total_size = count.reduce((previous, current) => current * previous, 1n);
-    let nbytes = metadata.size * Number(total_size);
+    // if auto_refresh is on, getting the metadata has triggered a refresh of the dataset_id;
+    const { strides, count, offset } = calculateHyperslabParams(shape, ranges);

    const { data: prepared_data, shape: guessed_shape } = prepare_data(data, metadata, count);
    let data_ptr = Module._malloc((prepared_data as Uint8Array).byteLength);
    Module.HEAPU8.set(prepared_data as Uint8Array, data_ptr);

    try {
-      Module.set_dataset_data(this.file_id, this.path, count, offset, BigInt(data_ptr));
+      Module.set_dataset_data(this.file_id, this.path, count, offset, strides, BigInt(data_ptr));
    }
    finally {
      Module._free(data_ptr);
@@ -906,7 +920,7 @@
    let data_ptr = Module._malloc(nbytes);
    let processed: OutputData;
    try {
-      Module.get_dataset_data(this.file_id, this.path, null, null, BigInt(data_ptr));
+      Module.get_dataset_data(this.file_id, this.path, null, null, null, BigInt(data_ptr));
      let data = Module.HEAPU8.slice(data_ptr, data_ptr + nbytes);
      processed = process_data(data, metadata, json_compatible);
    } finally {
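As a sanity check on the arithmetic in calculateHyperslabParams (a standalone sketch, not code from this PR): for shape [10] and slice [3, 9, 3], the selection should be indices 3 and 6, matching the 1-D test case further below.

const shape = [10];
const ranges: (number | null)[][] = [[3, 9, 3]];

const strides = shape.map((s, i) => BigInt(ranges?.[i]?.[2] ?? 1));
const offset = shape.map((s, i) => BigInt(Math.min(s, Math.max(0, ranges?.[i]?.[0] ?? 0))));
const count = shape.map((s, i) => {
  const N = BigInt(Math.min(s, ranges?.[i]?.[1] ?? s) - Math.max(0, ranges?.[i]?.[0] ?? 0)); // 6 indices in [3, 9)
  const st = strides[i];
  return N / st + ((N % st) + st - 1n) / st; // BigInt ceiling division: ceil(6 / 3) = 2
});

console.log(offset, strides, count); // [ 3n ] [ 3n ] [ 2n ] -> selects indices 3 and 6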
14 changes: 8 additions & 6 deletions src/hdf5_util.cc
@@ -478,7 +478,7 @@ val get_dataset_filters(hid_t loc_id, const std::string& dataset_name_string)
    return filters;
}

-int read_write_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, uint64_t rwdata_uint64, bool write=false)
+int read_write_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, val stride_out, uint64_t rwdata_uint64, bool write=false)
{
    hid_t ds_id;
    hid_t dspace;
@@ -510,8 +510,10 @@ int read_write_dataset_data(hid_t loc_id, const std::string& dataset_name_string
    {
        std::vector<hsize_t> count = vecFromJSArray<hsize_t>(count_out);
        std::vector<hsize_t> offset = vecFromJSArray<hsize_t>(offset_out);
+        std::vector<hsize_t> strides = vecFromJSArray<hsize_t>(stride_out);
+
        memspace = H5Screate_simple(count.size(), &count[0], nullptr);
-        status = H5Sselect_hyperslab(dspace, H5S_SELECT_SET, &offset[0], NULL, &count[0], NULL);
+        status = H5Sselect_hyperslab(dspace, H5S_SELECT_SET, &offset[0], &strides[0], &count[0], NULL);
        status = H5Sselect_all(memspace);
    }
    else
@@ -535,14 +537,14 @@ int read_write_dataset_data(hid_t loc_id, const std::string& dataset_name_string
    return (int)status;
}

-int get_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, uint64_t rdata_uint64)
+int get_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, val stride_out, uint64_t rdata_uint64)
{
-    return read_write_dataset_data(loc_id, dataset_name_string, count_out, offset_out, rdata_uint64, false);
+    return read_write_dataset_data(loc_id, dataset_name_string, count_out, offset_out, stride_out, rdata_uint64, false);
}

-int set_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, uint64_t wdata_uint64)
+int set_dataset_data(hid_t loc_id, const std::string& dataset_name_string, val count_out, val offset_out, val stride_out, uint64_t wdata_uint64)
{
-    return read_write_dataset_data(loc_id, dataset_name_string, count_out, offset_out, wdata_uint64, true);
+    return read_write_dataset_data(loc_id, dataset_name_string, count_out, offset_out, stride_out, wdata_uint64, true);
}

int reclaim_vlen_memory(hid_t loc_id, const std::string& object_name_string, const std::string& attribute_name_string, uint64_t rdata_uint64)
4 changes: 2 additions & 2 deletions src/hdf5_util_helpers.d.ts
@@ -82,8 +82,8 @@ export interface H5Module extends EmscriptenModule {
    H5Z_FILTER_MAX: 65535;
    create_group(file_id: bigint, name: string): number;
    create_vlen_str_dataset(file_id: bigint, dset_name: string, prepared_data: any, shape: bigint[], maxshape: (bigint | null)[], chunks: bigint[] | null, type: number, size: number, signed: boolean, vlen: boolean): number;
-    get_dataset_data(file_id: bigint, path: string, count: bigint[] | null, offset: bigint[] | null, rdata_ptr: bigint): number;
-    set_dataset_data(file_id: bigint, path: string, count: bigint[] | null, offset: bigint[] | null, wdata_ptr: bigint): number;
+    get_dataset_data(file_id: bigint, path: string, count: bigint[] | null, offset: bigint[] | null, strides: bigint[] | null, rdata_ptr: bigint): number;
+    set_dataset_data(file_id: bigint, path: string, count: bigint[] | null, offset: bigint[] | null, strides: bigint[] | null, wdata_ptr: bigint): number;
    refresh_dataset(file_id: bigint, path: string): number;
    resize_dataset(file_id: bigint, path: string, new_size: bigint[]): number;
    get_dataset_metadata(file_id: bigint, path: string): Metadata;
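Assuming an initialized Module (wired up as in hdf5_hl.ts above), a direct call through the updated binding might look like this sketch; file_id is assumed to be an open file handle, and "/overwrite-1d" refers to the 1-D test dataset below. Per the C++ above, passing null for count, offset, and strides still reads the whole dataset, which keeps _value_getter working unchanged.

// every 2nd element of a 1-D float32 dataset of length 10
const count = [5n];   // selected elements per dim
const offset = [0n];  // start index per dim
const strides = [2n]; // step per dim
const nbytes = 5 * 4; // 5 float32 values
const ptr = Module._malloc(nbytes);
try {
  Module.get_dataset_data(file_id, "/overwrite-1d", count, offset, strides, BigInt(ptr));
  const bytes = Module.HEAPU8.slice(ptr, ptr + nbytes);
  // ...decode bytes according to the dataset metadata
} finally {
  Module._free(ptr);
}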
68 changes: 68 additions & 0 deletions test/overwrite_dataset.mjs
@@ -41,10 +41,78 @@ async function overwrite_datasets() {

}

+async function overwrite_datasets_strides() {
+
+  await h5wasm.ready;
+  const PATH = join(".", "test", "tmp");
+  const FILEPATH = join(PATH, "overwrite_dataset_strides.h5");
+
+  if (!(existsSync(PATH))) {
+    mkdirSync(PATH);
+  }
+
+  let write_file = new h5wasm.File(FILEPATH, "w");
+
+  const dset1D = write_file.create_dataset({
+    name: "overwrite-1d",
+    data: [0,1,2,3,4,5,6,7,8,9],
+    shape: [10],
+    dtype: "<f4"
+  });
+
+  // read slices 1D
+  assert.deepEqual([...dset1D.slice([[null, null, 2]])], [0,2,4,6,8]);
+  assert.deepEqual([...dset1D.slice([[1, null, 2]])], [1,3,5,7,9]);
+  assert.deepEqual([...dset1D.slice([[3, 7, 2]])], [3,5]);
+  assert.deepEqual([...dset1D.slice([[null, null, 3]])], [0,3,6,9]);
+  assert.deepEqual([...dset1D.slice([[1, null, 3]])], [1,4,7]);
+  assert.deepEqual([...dset1D.slice([[3, 9, 3]])], [3,6]);
+  assert.deepEqual([...dset1D.slice([[null, null, 100]])], [0]);
+
+  // write slices 1D
+  dset1D.write_slice([[3, 9, 3]], [-1,-2])
+  assert.deepEqual([...dset1D.value].map(Number), [0,1,2,-1,4,5,-2,7,8,9]);
+  dset1D.write_slice([[null, 5, 2]], [-3,-4,-5])
+  assert.deepEqual([...dset1D.value].map(Number), [-3,1,-4,-1,-5,5,-2,7,8,9]);
+
+  const dset2D = write_file.create_dataset({
+    name: "overwrite-2d",
+    data: [1,2,3,4,5,6,7,8,9],
+    shape: [3,3],
+    dtype: "<f4"
+  });
+
+  // read slices 2D
+  assert.deepEqual([...dset2D.slice([[null, null, 2], [null, null, null]])], [1,2,3,7,8,9]);
+  assert.deepEqual([...dset2D.slice([[null, null, 2], [null, null, 2]])], [1,3,7,9]);
+  assert.deepEqual([...dset2D.slice([[null, null, 3], [null, null, 2]])], [1,3]);
+  assert.deepEqual([...dset2D.slice([[1, null, 2], [null, null, null]])], [4,5,6]);
+  assert.deepEqual([...dset2D.slice([[1, null, 2], [null, null, 2]])], [4,6]);
+  assert.deepEqual([...dset2D.slice([[1, null, 2], [1, null, 2]])], [5]);
+  assert.deepEqual([...dset2D.slice([[null, null, 100], [null, null, 100]])], [1]);
+
+  // write slices 2D
+  dset2D.write_slice([[1, null, 2], [1, null, 2]], [-1])
+  assert.deepEqual([...dset2D.value].map(Number), [1,2,3,4,-1,6,7,8,9]);
+  dset2D.write_slice([[null, null, 2], [null, null, 2]], [-2,-3,-4,-5])
+  assert.deepEqual([...dset2D.value].map(Number), [-2,2,-3,4,-1,6,-4,8,-5]);
+
+  write_file.flush();
+  write_file.close();
+
+  // cleanup file when finished:
+  unlinkSync(FILEPATH);
+
+}

export const tests = [
  {
    description: "Overwrite slices of existing dataset",
    test: overwrite_datasets
  },
+  {
+    description: "Overwrite slices of existing dataset using strides",
+    test: overwrite_datasets_strides
+  }
]
export default tests;