diff --git a/AFQ/definitions/utils.py b/AFQ/definitions/utils.py index 1629821fb..deda3dc80 100644 --- a/AFQ/definitions/utils.py +++ b/AFQ/definitions/utils.py @@ -72,8 +72,8 @@ def _arglist_to_string(args, get_attr=None): def name_from_path(path): file_name = op.basename(path) # get file name file_name = drop_extension(file_name) # remove extension - if "-" in file_name: - file_name = file_name.split("-")[-1] # get suffix if exists + if "_" in file_name: + file_name = file_name.split("_")[-1] # get suffix if exists return file_name diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 2c7a04d68..804c28baa 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -12,6 +12,14 @@ from AFQ.tasks.utils import get_default_args from AFQ.definitions.image import ScalarImage +from dipy.reconst.csdeconv import ( + ConstrainedSphericalDeconvModel, + auto_response_ssst) +from dipy.direction import BootDirectionGetter +from dipy.reconst.shm import CsaOdfModel +from dipy.reconst.gqi import GeneralizedQSamplingModel +import dipy.data as dpd + try: from trx.trx_file_memmap import TrxFile has_trx = True @@ -111,17 +119,49 @@ def streamlines(data_imap, seed, stop, # get odf_model odf_model = this_tracking_params["odf_model"] - if odf_model == "DTI": - params_file = data_imap["dti_params"] - elif odf_model == "CSD": - params_file = data_imap["csd_params"] - elif odf_model == "DKI": - params_file = data_imap["dki_params"] - elif odf_model == "GQ": - params_file = data_imap["gq_params"] + if this_tracking_params["directions"] == "boot": + # in this case, the params img is just used as a reference + params_file = data_imap["b0"] + # Note: you will also likely get on the order of a few + # streamlines per second, ie, ~10k an hour when using + # this bootstrapping implementation + # TODO: we need someone to get the model params from the user + if odf_model == "CSD": # + response, _ = auto_response_ssst( + data_imap["gtab"], data_imap["data"]) + model = ConstrainedSphericalDeconvModel( + data_imap["gtab"], response) + elif odf_model == "GQ": + model = GeneralizedQSamplingModel(data_imap["gtab"]) + elif odf_model == "CSA": + model = CsaOdfModel(data_imap["gtab"], 6) + else: + raise NotImplementedError(( + "Bootstrap direction getter currently not implemented " + "for DTI, DKI")) + sphere = this_tracking_params["sphere"] + if sphere is None: + sphere = dpd.default_sphere + this_tracking_params["directions"] = BootDirectionGetter.from_data( + data_imap["data"], + model, + max_angle=this_tracking_params["max_angle"], + sphere=sphere, + sh_order=6) else: - raise TypeError(( - f"The ODF model you gave ({odf_model}) was not recognized")) + if odf_model == "DTI": + params_file = data_imap["dti_params"] + elif odf_model == "CSD": + params_file = data_imap["csd_params"] + elif odf_model == "DKI": + params_file = data_imap["dki_params"] + elif odf_model == "GQ": + params_file = data_imap["gq_params"] + elif odf_model == "CSA": + params_file = data_imap["csa_params"] + else: + raise TypeError(( + f"The ODF model you gave ({odf_model}) was not recognized")) # get masks this_tracking_params['seed_mask'] = nib.load(seed).get_fdata() @@ -132,21 +172,17 @@ def streamlines(data_imap, seed, stop, is_trx = this_tracking_params.get("trx", False) + start_time = time() + sft = aft.track( + params_file, + **this_tracking_params) if is_trx: - start_time = time() dtype_dict = {'positions': np.float16, 'offsets': np.uint32} - lazyt = aft.track(params_file, **this_tracking_params) sft = TrxFile.from_lazy_tractogram( - lazyt, + sft, seed, dtype_dict=dtype_dict) - n_streamlines = len(sft) - - else: - start_time = time() - sft = aft.track(params_file, **this_tracking_params) - sft.to_vox() - n_streamlines = len(sft.streamlines) + n_streamlines = len(sft) return sft, _meta_from_tracking_params( tracking_params, start_time, diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 1bfedc5d7..cdb385a16 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -6,8 +6,11 @@ import dipy.data as dpd from dipy.align import resample -from dipy.direction import (DeterministicMaximumDirectionGetter, - ProbabilisticDirectionGetter) +from dipy.direction import ( + PTTDirectionGetter, + DeterministicMaximumDirectionGetter, + ProbabilisticDirectionGetter) +from dipy.tracking.direction_getter import DirectionGetter from dipy.io.stateful_tractogram import StatefulTractogram, Space from dipy.tracking.stopping_criterion import (ThresholdStoppingCriterion, CmcStoppingCriterion, @@ -32,11 +35,12 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, Parameters ---------- params_file : str, nibabel img. - Full path to a nifti file containing CSD spherical harmonic - coefficients, or nibabel img with model params. + Nifti image or a full path to a nifti file containing + CSD spherical harmonic coefficients. directions : str How tracking directions are determined. - One of: {"det" | "prob"} + Must be a direction getter from dipy, or one of: + {"det" | "prob" | "ptt"} Default: "prob" max_angle : float, optional. The maximum turning angle in each step. Default: 30 @@ -92,7 +96,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, maxlen: int, optional The miminal length (mm) in a streamline. Default: 250 odf_model : str, optional - One of {"DTI", "CSD", "DKI"}. Defaults to use "DTI" + One of {"DTI", "CSD", "DKI", "GQ", "CSA"}. Defaults to use "DTI" tracker : str, optional Which strategy to use in tracking. This can be the standard local tracking ("local") or Particle Filtering Tracking ([Girard2014]_). @@ -122,7 +126,6 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, model_params = params_img.get_fdata() odf_model = odf_model.upper() - directions = directions.lower() # We need to calculate the size of a voxel, so we can transform # from mm to voxel units: @@ -137,22 +140,29 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, if sphere is None: sphere = dpd.default_sphere - logger.info("Getting Directions...") - if directions == "det": - dg = DeterministicMaximumDirectionGetter - elif directions == "prob": - dg = ProbabilisticDirectionGetter + if isinstance(directions, DirectionGetter): + dg = directions + elif isinstance(directions, str): + directions = directions.lower() + if directions == "det": + dg = DeterministicMaximumDirectionGetter + elif directions == "prob": + dg = ProbabilisticDirectionGetter + elif directions == "ptt": + dg = PTTDirectionGetter + if odf_model == "DTI" or odf_model == "DKI": + evals = model_params[..., :3] + evecs = model_params[..., 3:12].reshape( + params_img.shape[:3] + (3, 3)) + odf = tensor_odf(evals, evecs, sphere) + dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere) + elif odf_model == "CSD" or odf_model == "GQ" or odf_model == "CSA": + dg = dg.from_shcoeff( + model_params, max_angle=max_angle, + sphere=sphere) else: raise ValueError(f"Unrecognized direction '{directions}'.") - if odf_model == "DTI" or odf_model == "DKI": - evals = model_params[..., :3] - evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3)) - odf = tensor_odf(evals, evecs, sphere) - dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere) - elif odf_model == "CSD" or odf_model == "GQ": - dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere) - if tracker == "local": if stop_mask is None: stop_mask = np.ones(params_img.shape[:3])