Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, WIP] More tracking options #1128

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions AFQ/definitions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _arglist_to_string(args, get_attr=None):
def name_from_path(path):
file_name = op.basename(path) # get file name
file_name = drop_extension(file_name) # remove extension
if "-" in file_name:
file_name = file_name.split("-")[-1] # get suffix if exists
if "_" in file_name:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we even using this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParticipantAFQ uses this function to generate names for scalars that the user inputs (such as a t1 image from another pipeline). I should make it so the user can also define a name, and we call this function only if a name is not provided.

file_name = file_name.split("_")[-1] # get suffix if exists
return file_name


Expand Down
76 changes: 56 additions & 20 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
from AFQ.tasks.utils import get_default_args
from AFQ.definitions.image import ScalarImage

from dipy.reconst.csdeconv import (
ConstrainedSphericalDeconvModel,
auto_response_ssst)
from dipy.direction import BootDirectionGetter
from dipy.reconst.shm import CsaOdfModel
from dipy.reconst.gqi import GeneralizedQSamplingModel
import dipy.data as dpd

try:
from trx.trx_file_memmap import TrxFile
has_trx = True
Expand Down Expand Up @@ -111,17 +119,49 @@ def streamlines(data_imap, seed, stop,

# get odf_model
odf_model = this_tracking_params["odf_model"]
if odf_model == "DTI":
params_file = data_imap["dti_params"]
elif odf_model == "CSD":
params_file = data_imap["csd_params"]
elif odf_model == "DKI":
params_file = data_imap["dki_params"]
elif odf_model == "GQ":
params_file = data_imap["gq_params"]
if this_tracking_params["directions"] == "boot":
# in this case, the params img is just used as a reference
params_file = data_imap["b0"]
# Note: you will also likely get on the order of a few
# streamlines per second, ie, ~10k an hour when using
# this bootstrapping implementation
# TODO: we need someone to get the model params from the user
if odf_model == "CSD": #
response, _ = auto_response_ssst(
data_imap["gtab"], data_imap["data"])
model = ConstrainedSphericalDeconvModel(
data_imap["gtab"], response)
elif odf_model == "GQ":
model = GeneralizedQSamplingModel(data_imap["gtab"])
elif odf_model == "CSA":
model = CsaOdfModel(data_imap["gtab"], 6)
else:
raise NotImplementedError((
"Bootstrap direction getter currently not implemented "
"for DTI, DKI"))
sphere = this_tracking_params["sphere"]
if sphere is None:
sphere = dpd.default_sphere
this_tracking_params["directions"] = BootDirectionGetter.from_data(
data_imap["data"],
model,
max_angle=this_tracking_params["max_angle"],
sphere=sphere,
sh_order=6)
else:
raise TypeError((
f"The ODF model you gave ({odf_model}) was not recognized"))
if odf_model == "DTI":
params_file = data_imap["dti_params"]
elif odf_model == "CSD":
params_file = data_imap["csd_params"]
elif odf_model == "DKI":
params_file = data_imap["dki_params"]
elif odf_model == "GQ":
params_file = data_imap["gq_params"]
elif odf_model == "CSA":
params_file = data_imap["csa_params"]
else:
raise TypeError((
f"The ODF model you gave ({odf_model}) was not recognized"))

# get masks
this_tracking_params['seed_mask'] = nib.load(seed).get_fdata()
Expand All @@ -132,21 +172,17 @@ def streamlines(data_imap, seed, stop,

is_trx = this_tracking_params.get("trx", False)

start_time = time()
sft = aft.track(
params_file,
**this_tracking_params)
if is_trx:
start_time = time()
dtype_dict = {'positions': np.float16, 'offsets': np.uint32}
lazyt = aft.track(params_file, **this_tracking_params)
sft = TrxFile.from_lazy_tractogram(
lazyt,
sft,
seed,
dtype_dict=dtype_dict)
n_streamlines = len(sft)

else:
start_time = time()
sft = aft.track(params_file, **this_tracking_params)
sft.to_vox()
n_streamlines = len(sft.streamlines)
n_streamlines = len(sft)

return sft, _meta_from_tracking_params(
tracking_params, start_time,
Expand Down
50 changes: 30 additions & 20 deletions AFQ/tractography/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import dipy.data as dpd
from dipy.align import resample
from dipy.direction import (DeterministicMaximumDirectionGetter,
ProbabilisticDirectionGetter)
from dipy.direction import (
PTTDirectionGetter,
DeterministicMaximumDirectionGetter,
ProbabilisticDirectionGetter)
from dipy.tracking.direction_getter import DirectionGetter
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.tracking.stopping_criterion import (ThresholdStoppingCriterion,
CmcStoppingCriterion,
Expand All @@ -32,11 +35,12 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
Parameters
----------
params_file : str, nibabel img.
Full path to a nifti file containing CSD spherical harmonic
coefficients, or nibabel img with model params.
Nifti image or a full path to a nifti file containing
CSD spherical harmonic coefficients.
directions : str
How tracking directions are determined.
One of: {"det" | "prob"}
Must be a direction getter from dipy, or one of:
{"det" | "prob" | "ptt"}
Default: "prob"
max_angle : float, optional.
The maximum turning angle in each step. Default: 30
Expand Down Expand Up @@ -92,7 +96,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
maxlen: int, optional
The miminal length (mm) in a streamline. Default: 250
odf_model : str, optional
One of {"DTI", "CSD", "DKI"}. Defaults to use "DTI"
One of {"DTI", "CSD", "DKI", "GQ", "CSA"}. Defaults to use "DTI"
tracker : str, optional
Which strategy to use in tracking. This can be the standard local
tracking ("local") or Particle Filtering Tracking ([Girard2014]_).
Expand Down Expand Up @@ -122,7 +126,6 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,

model_params = params_img.get_fdata()
odf_model = odf_model.upper()
directions = directions.lower()

# We need to calculate the size of a voxel, so we can transform
# from mm to voxel units:
Expand All @@ -137,22 +140,29 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
if sphere is None:
sphere = dpd.default_sphere

logger.info("Getting Directions...")
if directions == "det":
dg = DeterministicMaximumDirectionGetter
elif directions == "prob":
dg = ProbabilisticDirectionGetter
if isinstance(directions, DirectionGetter):
dg = directions
elif isinstance(directions, str):
directions = directions.lower()
if directions == "det":
dg = DeterministicMaximumDirectionGetter
elif directions == "prob":
dg = ProbabilisticDirectionGetter
elif directions == "ptt":
dg = PTTDirectionGetter
if odf_model == "DTI" or odf_model == "DKI":
evals = model_params[..., :3]
evecs = model_params[..., 3:12].reshape(
params_img.shape[:3] + (3, 3))
odf = tensor_odf(evals, evecs, sphere)
dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
elif odf_model == "CSD" or odf_model == "GQ" or odf_model == "CSA":
dg = dg.from_shcoeff(
model_params, max_angle=max_angle,
sphere=sphere)
else:
raise ValueError(f"Unrecognized direction '{directions}'.")

if odf_model == "DTI" or odf_model == "DKI":
evals = model_params[..., :3]
evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
odf = tensor_odf(evals, evecs, sphere)
dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
elif odf_model == "CSD" or odf_model == "GQ":
dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)

if tracker == "local":
if stop_mask is None:
stop_mask = np.ones(params_img.shape[:3])
Expand Down