Skip to content

Commit

Permalink
Merge pull request #985 from frheault/convert_trks_to_hdf5
Browse files Browse the repository at this point in the history
Generate HDF5 from TRKs
  • Loading branch information
arnaudbore committed Jun 21, 2024
2 parents 4e7047d + 05b09c7 commit cb9ef7c
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 11 deletions.
27 changes: 16 additions & 11 deletions scripts/scil_tractogram_convert_hdf5_to_trk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
It can either save all connections (default), individual connections specified
with --edge_keys or connections from specific nodes specified with --node_keys.
With the option --save_empty, a label_lists, as a txt file, must be provided.
This option saves existing connections and empty connections.
With --save_empty labels_list, it will save all possible connections between
the labels in the list, including empty ones. With --save_empty but no
labels_list, it will save all connections in the hdf5 file, including empty
ones. Without --save_empty, it will save only the connections that are present
in the hdf5 file.
The output is a directory containing the thousands of connections:
out_dir/
Expand Down Expand Up @@ -59,13 +61,10 @@ def _build_arg_parser():
'interest.\nEquivalent to adding any --edge_keys '
'node_LABEL2 or LABEL2_node.')

p.add_argument('--save_empty', metavar='labels_list', dest='labels_list',
help='Save empty connections. Then, the list of possible '
'connections is \nnot found from the hdf5 but '
'inferred from labels_list, a txt file \ncontaining '
'a list of nodes saved by the decomposition script.\n'
'*If used together with edge_keys or node_keys, the '
'provided nodes must \nexist in labels_list.')
p.add_argument('--save_empty', nargs='?', metavar='labels_list',
dest='labels_list', const=True,
help='Save empty connections.\nSee script description for '
'more information on labels_list usage.')

add_verbose_arg(p)
add_overwrite_arg(p, will_delete_dirs=True)
Expand All @@ -79,21 +78,27 @@ def main():
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

# Verifications
assert_inputs_exist(parser, args.in_hdf5, args.labels_list)
check_labels = args.labels_list if isinstance(
args.labels_list, str) else None
assert_inputs_exist(parser, args.in_hdf5, check_labels)
assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
create_dir=True)

# Processing
with h5py.File(args.in_hdf5, 'r') as hdf5_file:
all_hdf5_keys = list(hdf5_file.keys())

if args.labels_list:
if isinstance(args.labels_list, str):
all_labels = np.loadtxt(args.labels_list, dtype='str')
comb_list = list(itertools.combinations(all_labels, r=2))
comb_list.extend(zip(all_labels, all_labels))
all_keys = [i[0]+'_'+i[1] for i in comb_list]
keys_origin = "the labels_list file's labels combination"
allow_empty = True
elif args.labels_list:
all_keys = all_hdf5_keys
keys_origin = "the hdf5 stored keys"
allow_empty = True
else:
all_keys = all_hdf5_keys
keys_origin = "the hdf5 stored keys"
Expand Down
108 changes: 108 additions & 0 deletions scripts/scil_tractogram_convert_trk_to_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Save connections as TRK to HDF5.
This script is useful to convert a set of connections or bundles to a single
HDF5 file. The HDF5 file will contain a group for each input file, with the
streamlines stored in the specified space and origin (keep the default if you
are going to use the connectivity scripts in scilpy).
To make a file compatible with scil_tractogram_commit.py or
scil_connectivity_compute_matrices.py you will have to follow this nomenclature
for the input files:
in_dir/
|-- LABEL1_LABEL1.trk
|-- LABEL1_LABEL2.trk
|-- [...]
|-- LABEL90_LABEL90.trk
The first label should be smaller than or equal to the second label.
Connectivity scripts in scilpy only consider the upper triangular part of the
connectivity matrix.
By default, ignores the empty connections. To save them, use --save_empty.
Note that data_per_point is never included.
"""

import argparse
import logging
import os

from dipy.io.stateful_tractogram import Space, Origin
from dipy.io.utils import is_header_compatible
import h5py

from scilpy.io.hdf5 import (construct_hdf5_header,
construct_hdf5_group_from_streamlines)
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
    """
    Build the command-line parser of the script.

    The module docstring is used as the description. RawTextHelpFormatter
    keeps its line breaks (and the directory tree it contains) intact in
    the --help output.

    Returns
    -------
    p: argparse.ArgumentParser
        Parser with the positional arguments (in_bundles, out_hdf5), the
        storage options (--stored_space, --stored_origin), the content
        options (--include_dps, --save_empty) and the verbose / overwrite
        arguments added by add_verbose_arg and add_overwrite_arg.
    """
    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    p.add_argument('in_bundles', nargs='+',
                   help='Path of the input connection(s) or bundle(s).')
    p.add_argument('out_hdf5',
                   help='Output HDF5 filename (.h5).')

    # The choices are the lowercase names of dipy's Space and Origin enums;
    # main() upper-cases them to look up the enum members.
    p.add_argument('--stored_space', choices=['rasmm', 'voxmm', 'vox'],
                   default='vox',
                   help='Space convention in which the streamlines are stored '
                        '[%(default)s].')
    p.add_argument('--stored_origin', choices=['nifti', 'trackvis'],
                   default='trackvis',
                   help='Voxel origin convention in which the streamlines are '
                        'stored [%(default)s].')

    p.add_argument('--include_dps', action='store_true',
                   help='Include the data_per_streamline metadata.')
    p.add_argument('--save_empty', action='store_true',
                   help='Save empty connections.')

    add_verbose_arg(p)
    add_overwrite_arg(p)

    return p


def main():
    """
    Convert the input TRK bundles / connections into a single HDF5 file.

    Each input file becomes one HDF5 group named after the basename of the
    file (extension removed). Streamlines are converted to the space and
    origin requested with --stored_space / --stored_origin before being
    written. Empty inputs are skipped unless --save_empty is used.

    Exits through parser.error() if an input's header is not compatible
    with the header of the first input, or if two inputs share the same
    basename (they would need the same group name).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    logging.getLogger().setLevel(logging.getLevelName(args.verbose))

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # The first input is the reference: it provides the HDF5 header and
    # every other input must have a compatible header.
    ref_sft = load_tractogram_with_reference(parser, args, args.in_bundles[0])

    # Convert STR to the Space and Origin ENUMS
    target_space = Space[args.stored_space.upper()]
    target_origin = Origin[args.stored_origin.upper()]
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        # Written once, before the loop rather than on the first iteration:
        # the first bundle may be skipped below because it is empty, and the
        # header must still end up in the file.
        construct_hdf5_header(hdf5_file, ref_sft)

        for in_bundle in args.in_bundles:
            in_basename = os.path.splitext(os.path.basename(in_bundle))[0]
            curr_sft = load_tractogram_with_reference(parser, args, in_bundle)
            if len(curr_sft) == 0 and not args.save_empty:
                logging.warning(f"Skipping {in_bundle} because it is empty")
                continue

            if not is_header_compatible(ref_sft, curr_sft):
                parser.error(f"Header of {in_bundle} is not compatible")

            # Group names come from the basenames only, so two inputs from
            # different directories can collide. Report it cleanly instead
            # of letting create_group() raise.
            if in_basename in hdf5_file:
                parser.error(f"Cannot add {in_bundle}: a group named "
                             f"{in_basename} already exists in the output "
                             "(input basenames must be unique)")

            curr_sft.to_space(target_space)
            curr_sft.to_origin(target_origin)

            group = hdf5_file.create_group(in_basename)
            dps = curr_sft.data_per_streamline if args.include_dps else {}
            construct_hdf5_group_from_streamlines(group, curr_sft.streamlines,
                                                  dps=dps)


# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
55 changes: 55 additions & 0 deletions scripts/tests/test_tractogram_convert_trk_to_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import glob
import os
import tempfile

from dipy.io.stateful_tractogram import Space, Origin
import h5py

from scilpy import SCILPY_HOME
from scilpy.io.hdf5 import reconstruct_sft_from_hdf5
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

# Module-level setup, executed once when the test module is imported.
# Fetch the 'connectivity.zip' testing archive. If the files already exist,
# this only takes 5 seconds (check md5sum).
fetch_data(get_testing_files_dict(), keys=['connectivity.zip'])
# Scratch directory in which the tests below run (they chdir into it).
tmp_dir = tempfile.TemporaryDirectory()
# Input HDF5 file, expected under SCILPY_HOME/connectivity
# (presumably extracted there by the fetch above -- TODO confirm).
in_h5 = os.path.join(SCILPY_HOME, 'connectivity', 'decompose.h5')


def test_help_option(script_runner):
    """The script must run successfully when called with --help."""
    help_run = script_runner.run(
        'scil_tractogram_convert_trk_to_hdf5.py', '--help')
    assert help_run.success


def test_execution_edge_keys(script_runner, monkeypatch):
    """
    Round trip on two connections: HDF5 -> TRK -> HDF5.

    Two edges are first extracted from the testing HDF5 as TRK files, then
    packed back into a new HDF5 (stored in voxmm / nifti), which is finally
    read again to check the group names and the streamline counts.
    """
    work_dir = os.path.expanduser(tmp_dir.name)
    monkeypatch.chdir(work_dir)

    # Step 1: extract the two edges as TRK files.
    extraction = script_runner.run(
        'scil_tractogram_convert_hdf5_to_trk.py',
        in_h5, 'save_trk/', '--edge_keys', '1_10', '1_7')
    assert extraction.success

    # Out directory should have 2 files
    assert len(glob.glob('save_trk/*')) == 2

    # Step 2: pack the TRK files back into a single HDF5.
    trk_to_hdf5_cmd = ['scil_tractogram_convert_trk_to_hdf5.py',
                       'save_trk/1_10.trk', 'save_trk/1_7.trk',
                       'two_edges.h5',
                       '--stored_space', 'voxmm',
                       '--stored_origin', 'nifti']
    packing = script_runner.run(*trk_to_hdf5_cmd)
    assert packing.success

    # Step 3: read the new HDF5 back, with the same space / origin as the
    # ones used for storage.
    with h5py.File('two_edges.h5', 'r') as hdf5_file:
        stored_keys = list(hdf5_file.keys())
        assert stored_keys == ['1_10', '1_7']

        sfts, _ = reconstruct_sft_from_hdf5(hdf5_file, stored_keys,
                                            space=Space.VOXMM,
                                            origin=Origin.NIFTI)

        assert len(sfts) == 2
        for sft in sfts:
            sft.remove_invalid_streamlines()

        first_sft, second_sft = sfts[0], sfts[1]
        assert len(first_sft) == 340
        assert len(second_sft) == 732

0 comments on commit cb9ef7c

Please sign in to comment.