In [1]:
import cloudknot as ck

In [2]:
# Select the AWS region ('us-east-1') that cloudknot uses in this session;
# presumably this applies to the Knot and batch jobs created below.
ck.set_region('us-east-1')

In [3]:
def afq_hcp(params):
    """Run CSD tractography and AFQ bundle segmentation for one HCP subject.

    Fetches the subject's preprocessed HCP diffusion and anatomical data,
    then builds (or reuses from S3) a white-matter mask, a DTI fit, a CSD fit
    and a whole-brain deterministic tractogram. Every one of those products
    is cached under ``hcp.recobundles/sub-<subject>/`` and reused when it
    already exists there.

    Segmentation into the default AFQ bundles and per-bundle cleaning are
    always re-run. For each bundle the cleaned streamlines (.trk), their
    indices into the whole-brain tractogram (.npy) and a JSON metadata file
    are uploaded. Finally a CSV of per-bundle streamline counts is uploaded.
    Scratch files are written to the current working directory.

    Parameters
    ----------
    params : tuple
        ``(subject, hcp_ak, hcp_sk)``: the HCP subject ID, plus the AWS
        access key / secret key used to fetch that subject's HCP data.

    Returns
    -------
    None
        All results are written to S3.
    """
    subject, hcp_ak, hcp_sk = params
    # Imports are made inside the function on purpose: afq_hcp is handed to
    # cloudknot (``ck.Knot(func=afq_hcp)``) and executed remotely, so it is
    # presumably required to be self-contained. Keep new imports here too.
    import logging
    import os.path as op

    import nibabel as nib
    import numpy as np
    import pandas as pd
    import s3fs
    import dipy.core.gradients as dpg
    import dipy.tracking.streamline as dts
    import dipy.tracking.utils as dtu
    from dipy.io.stateful_tractogram import Space, StatefulTractogram
    from dipy.io.streamline import load_tractogram, save_tractogram

    import AFQ.data as afd
    import AFQ.dti as dti
    import AFQ.registration as reg
    import AFQ.segmentation as seg
    import AFQ.tractography as aft
    from AFQ import api
    from AFQ import csd

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    # Processing parameters. They are defined once so that the values used in
    # the computations and the values recorded in the JSON metadata agree.
    b0_threshold = 50
    seed_fa_threshold = 0.4
    stop_fa_threshold = 0.1
    prob_threshold = 10
    clean_threshold = 3

    log.info(f"Fetching HCP subject {subject}")
    afd.fetch_hcp([subject],
                  profile_name=False,
                  aws_access_key_id=hcp_ak,
                  aws_secret_access_key=hcp_sk)

    subject_dir = op.join(afd.afq_home, 'HCP', 'derivatives',
                          'dmriprep', f'sub-{subject}')
    dwi_dir = op.join(subject_dir, 'sess-01/dwi')
    anat_dir = op.join(subject_dir, 'sess-01/anat')

    hardi_fdata = op.join(dwi_dir, f"sub-{subject}_dwi.nii.gz")
    hardi_fbval = op.join(dwi_dir, f"sub-{subject}_dwi.bval")
    hardi_fbvec = op.join(dwi_dir, f"sub-{subject}_dwi.bvec")

    log.info(f"Reading data from file {hardi_fdata}")
    img = nib.load(hardi_fdata)
    log.info(f"Creating gradient table from {hardi_fbval} and {hardi_fbvec}")
    # The table itself is not used below. The call is kept so that unreadable
    # bval/bvec files fail here, before any expensive work.
    dpg.gradient_table(hardi_fbval, hardi_fbvec)

    # Every derived product is cached under this S3 prefix.
    bucket_name = 'hcp.recobundles'
    s3_prefix = f'{bucket_name}/sub-{subject}/sub-{subject}'
    tracking_prefix = f'{s3_prefix}_model-csd_track-det'
    fs = s3fs.S3FileSystem()

    # ---- White-matter mask (on the DWI grid) ----
    wm_mask_fname = f'{s3_prefix}_wm_mask.nii.gz'
    if fs.exists(wm_mask_fname):
        log.info(f"WM mask exists. Reading from {wm_mask_fname}")
        wm_img = afd.s3fs_nifti_read(wm_mask_fname)
        # get_data() is deprecated and removed in nibabel 5. Reading the data
        # object directly returns the stored integer dtype, as get_data did.
        wm_mask = np.asanyarray(wm_img.dataobj)
    else:
        log.info("Calculating WM segmentation")
        # aparc+aseg label values treated as white matter.
        wm_labels = [250, 251, 252, 253, 254, 255, 41, 2, 16, 77]
        seg_img = nib.load(op.join(anat_dir, f"sub-{subject}_aparc+aseg.nii.gz"))
        seg_data_orig = seg_img.get_fdata()
        # 1 where a voxel carries any WM label, else 0. This gives the same
        # result as summing one boolean volume per (distinct) label, without
        # stacking len(wm_labels) full-size copies of the volume.
        wm_mask = np.isin(seg_data_orig, wm_labels).astype(int)

        # Resample to the DWI grid. Only the first DWI volume is used as the
        # reference, so read just that volume instead of the full 4D series
        # (which would otherwise stay in memory for the rest of the job).
        dwi_ref = np.asanyarray(img.dataobj[..., 0], dtype=np.float64)
        wm_mask = np.round(reg.resample(wm_mask,
                                        dwi_ref,
                                        seg_img.affine,
                                        img.affine)).astype(int)

        # Recent nibabel warns on (or rejects) int64 NIfTI data unless a dtype
        # is given explicitly. int32 holds this small-integer mask losslessly.
        wm_img = nib.Nifti1Image(wm_mask.astype(np.int32), img.affine)
        log.info(f"Saving to {wm_mask_fname}")
        afd.s3fs_nifti_write(wm_img, wm_mask_fname)

    # ---- DTI ----
    fa_fname = f'{s3_prefix}_dti_FA.nii.gz'
    dti_params_fname = f'{s3_prefix}_dti.nii.gz'
    dti_meta_fname = f'{s3_prefix}_dti.json'
    if fs.exists(fa_fname):
        log.info(f"DTI already exists. Reading FA from {fa_fname}")
        log.info(f"DTI already exists. Reading params from {dti_params_fname}")
        FA_img = afd.s3fs_nifti_read(fa_fname)
        dti_params = afd.s3fs_nifti_read(dti_params_fname)
    else:
        log.info("Calculating DTI")
        # fit_dti writes (at least) dti_FA / dti_params into out_dir; both
        # files are loaded back from there below.
        dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
                                 out_dir='.', b0_threshold=b0_threshold,
                                 mask=wm_mask)
        FA_img = nib.load('./dti_FA.nii.gz')
        log.info(f"Writing FA to {fa_fname}")
        afd.s3fs_nifti_write(FA_img, fa_fname)
        dti_params_img = nib.load('./dti_params.nii.gz')
        log.info(f"Writing DTI params to {dti_params_fname}")
        afd.s3fs_nifti_write(dti_params_img, dti_params_fname)
        dti_params_json = {
            "Model": "Diffusion Tensor",
            "OrientationRepresentation": "param",
            "ReferenceAxes": "xyz",
            "Parameters": {
                "FitMethod": "ols",
                "OutlierRejection": False,
            },
        }
        log.info(f"Writing DTI metadata to {dti_meta_fname}")
        afd.s3fs_json_write(dti_params_json, dti_meta_fname)

    log.info("Reading FA data from img")
    FA_data = FA_img.get_fdata()

    # ---- CSD ----
    csd_fname = f'{s3_prefix}_csd.nii.gz'
    csd_meta_fname = f'{s3_prefix}_csd.json'
    if fs.exists(csd_fname):
        log.info(f"CSD already exists. Getting it from {csd_fname}")
        csd_params = afd.s3fs_nifti_read(csd_fname)
    else:
        log.info("Calculating CSD")
        # fit_csd returns the path of the file it wrote, whereas the cached
        # branch yields a nibabel image. aft.track below receives either one.
        # NOTE(review): confirm aft.track accepts both.
        csd_params = csd.fit_csd(hardi_fdata, hardi_fbval, hardi_fbvec,
                                 out_dir='.', b0_threshold=b0_threshold,
                                 mask=wm_mask)
        afd.s3fs_nifti_write(nib.load(csd_params), csd_fname)

        csd_params_json = {
            "Model": "Constrained Spherical Deconvolution (CSD)",
            "ModelURL": "https://github.com/nipy/dipy/commit/abf31d15a0ee5dc0704ee03ebbba57358d540612",
            "Shells": [0, 1000, 2000, 3000],
            "Parameters": {
                "ResponseFunctionTensor": "auto",
                "SphericalHarmonicBasis": "Descoteaux",
                "NonNegativityConstraint": "hard",
                "SphericalHarmonicDegree": 8,
            },
        }
        log.info(f"Writing CSD metadata to {csd_meta_fname}")
        afd.s3fs_json_write(csd_params_json, csd_meta_fname)

    # ---- Whole-brain deterministic tractography ----
    csd_streamlines_fname = f'{tracking_prefix}.trk'
    csd_streamlines_meta_fname = f'{tracking_prefix}.json'
    local_streamlines_fname = './csd_streamlines.trk'
    if fs.exists(csd_streamlines_fname):
        log.info(f"Streamlines already exist. Loading from {csd_streamlines_fname}")
        fs.download(csd_streamlines_fname, local_streamlines_fname)
        # The cached file is saved below with bbox_valid_check=False, so it may
        # contain points outside the volume. Loading it with the default
        # (strict) check would raise instead of reusing the cache.
        tg = load_tractogram(local_streamlines_fname, img,
                             bbox_valid_check=False)
        streamlines = tg.streamlines
    else:
        log.info("Generating streamlines")
        # Seed in white-matter voxels whose FA exceeds the seed threshold.
        seed_roi = np.zeros(img.shape[:-1])
        seed_roi[FA_data > seed_fa_threshold] = 1
        seed_roi[wm_mask < 1] = 0
        streamlines = aft.track(csd_params, seed_mask=seed_roi,
                                directions='det', stop_mask=FA_data,
                                stop_threshold=stop_fa_threshold)
        log.info(f"After tracking, there are {len(streamlines)} streamlines")
        sft = StatefulTractogram(streamlines, img, Space.RASMM)
        save_tractogram(sft, local_streamlines_fname,
                        bbox_valid_check=False)
        log.info(f"Uploading streamlines to {csd_streamlines_fname}")
        fs.upload(local_streamlines_fname, csd_streamlines_fname)
        # NOTE(review): "SeedRoi" does not mention that seeds are also
        # restricted to the WM mask.
        csd_streamlines_json = {
            "Algorithm": "LocalTracking",
            "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/c04835cd4ca13d28c20bb449d6f088e656c55e57",
            "Parameters": {
                "SeedRoi": f"dti_FA>{seed_fa_threshold}",
                "Directions": "det",
                "StopMask": f"dti_FA<{stop_fa_threshold}",
            },
        }
        log.info(f"Writing streamlines metadata to {csd_streamlines_meta_fname}")
        afd.s3fs_json_write(csd_streamlines_json, csd_streamlines_meta_fname)

    # Streamlines are handled in RAS+ mm above (Space.RASMM). Convert them to
    # voxel coordinates of the DWI volume for segmentation.
    streamlines = dts.Streamlines(
        dtu.transform_tracking_output(streamlines,
                                      np.linalg.inv(img.affine)))

    # ---- Bundle segmentation and cleaning ----
    log.info("Segmenting")

    # Use the default for waypoint ROI
    bundles = api.make_bundle_dict()

    segmentation = seg.Segmentation(b0_threshold=b0_threshold,
                                    prob_threshold=prob_threshold,
                                    return_idx=True)
    segmentation.segment(bundles,
                         streamlines,
                         fdata=hardi_fdata,
                         fbval=hardi_fbval,
                         fbvec=hardi_fbvec)

    fiber_groups = segmentation.fiber_groups

    bundle_prefix = f'{tracking_prefix}_segment-afq'
    sl_count = []
    for kk in fiber_groups:
        log.info(f"Cleaning {kk}")
        len_before = len(fiber_groups[kk]['sl'])
        log.info(f"Before cleaning there are {len_before} streamlines")
        new_fibers, idx_in_bundle = seg.clean_fiber_group(
            fiber_groups[kk]['sl'],
            return_idx=True,
            clean_threshold=clean_threshold)

        log.info(f"After cleaning there are {len(new_fibers)} streamlines")
        # Map indices within the bundle back to the whole-brain tractogram.
        idx_in_global = fiber_groups[kk]['idx'][idx_in_bundle]

        sl_count.append(len(new_fibers))
        log.info(f"There are {sl_count[-1]} streamlines in {kk}")
        # Back from voxel coordinates to RAS+ mm for saving.
        sft = StatefulTractogram(
            dtu.transform_tracking_output(new_fibers, img.affine),
            img, Space.RASMM)

        local_tg_fname = f'./{kk}_reco.trk'
        save_tractogram(sft, local_tg_fname,
                        bbox_valid_check=False)
        tg_fname = f'{bundle_prefix}_bundle-{kk}.trk'
        log.info(f"Uploading {local_tg_fname} to {tg_fname}")
        fs.upload(local_tg_fname, tg_fname)
        tg_meta_fname = f'{bundle_prefix}_bundle-{kk}.json'
        tg_meta_json = {
            "Algorithm": "AFQ",
            "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/f0f486d",
            "Parameters": {
                "clean_threshold": clean_threshold,
                "prob_threshold": prob_threshold,
            },
        }

        log.info(f"Uploading segmentation metadata to {tg_meta_fname}")
        afd.s3fs_json_write(tg_meta_json, tg_meta_fname)

        np.save('bundle_idx.npy', idx_in_global)
        idx_fname = f'{bundle_prefix}_bundle-{kk}_idx.npy'
        log.info(f"Uploading bundle indices to {idx_fname}")
        fs.upload('bundle_idx.npy', idx_fname)

    log.info("Saving streamline counts")
    # One row per bundle, in the same order in which the bundles were cleaned.
    counts_df = pd.DataFrame(data=sl_count, index=list(fiber_groups),
                             columns=["streamlines"])
    counts_df.to_csv("./sl_count.csv")
    sl_count_fname = f'{bundle_prefix}_counts.csv'
    log.info(f"Uploading streamline counts to {sl_count_fname}")
    fs.upload("./sl_count.csv", sl_count_fname)

In [4]:
import configparser
import os.path as op

In [5]:
# Read the HCP data-access keys from the "hcp" profile of the local AWS
# credentials file, so they never have to be hard-coded in the notebook.
# The file is opened in a `with` block so the handle is closed after parsing.
# ConfigParser lower-cases option names by default, so the upper-case keys
# below match the usual lower-case `aws_access_key_id` entries.
CP = configparser.ConfigParser()
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
ak = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
sk = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [6]:
# cloudknot configuration for running afq_hcp on AWS Batch.
# - image_github_installs pins pyAFQ to commit f0f486d, the same commit that
#   afq_hcp records as "AlgorithmURL" in its segmentation metadata.
# - memory is presumably in MiB (~64 GB, matching the knot name); confirm
#   against the cloudknot docs.
knot_kwargs = dict(
    name='afq_hcp-64gb-191101-27',
    func=afq_hcp,
    image_github_installs="https://github.com/arokem/pyAFQ.git@f0f486d",
    pars_policies=('AmazonS3FullAccess',),
    resource_type="SPOT",
    bid_percentage=100,
    memory=64000,
)
afq_knot = ck.Knot(**knot_kwargs)

In [7]:
# HCP subject IDs to process, in submission order.
hcp_subjects = [
    100408, 100307, 100610, 101006, 101107, 101309, 101410,
    101915, 102008, 102109, 102311, 102513, 100206, 970764,
    971160, 972566, 973770, 978578, 979984, 983773, 984472,
]
# Each job argument is the (subject, access key, secret key) tuple that
# afq_hcp unpacks on its first line.
# NOTE(review): the HCP secret key is shipped with every job's input;
# consider supplying it through the job environment/role instead.
inputs = [(subject_id, ak, sk) for subject_id in hcp_subjects]

In [8]:
# Run afq_hcp once per element of `inputs`. The status output further down
# shows this is submitted as a single array job with one child per input.
# afq_hcp returns None (results go to S3), so `ft` only tracks completion
# (presumably a future; confirm in the cloudknot docs).
ft = afq_knot.map(inputs)

In [9]:
# Tabulate the jobs submitted through this knot (job ID, name, status).
afq_knot.view_jobs()

Job ID              Name                        Status   
---------------------------------------------------------
a3f9cf77-7504-4c22-a3ba-946a10400867        afq-hcp-64gb-191101-27-0        SUBMITTED


In [10]:
# Handle to the first job submitted through this knot (the array job above).
j0 = afq_knot.jobs[0]

In [13]:
# Full status record of the job. For an array job, 'arrayProperties' ->
# 'statusSummary' counts child jobs per state, and 'size' is the number of children.
j0.status

{'status': 'PENDING',
 'statusReason': None,
 'attempts': [],
 'arrayProperties': {'statusSummary': {'STARTING': 1,
   'FAILED': 0,
   'RUNNING': 18,
   'SUCCEEDED': 0,
   'RUNNABLE': 2,
   'SUBMITTED': 0,
   'PENDING': 0},
  'size': 21}}

In [12]:
# Deliberately left commented out so that "Run All" cannot tear down the knot.
# Uncomment only when finished: clobber() presumably deletes the AWS resources
# cloudknot created for afq_knot -- confirm before running.
# afq_knot.clobber()