In [1]:
import cloudknot as ck

In [2]:
# Point cloudknot at the us-east-1 AWS region before any other cloudknot call
# in this notebook (presumably later resources are created there — confirm).
ck.set_region('us-east-1')

In [3]:
import configparser
import os.path as op

In [4]:
# Read the HCP access keys from the `hcp` profile of the local AWS
# credentials file, so they never have to be typed into the notebook.
CP = configparser.ConfigParser()
# Use a context manager so the credentials file handle is closed right away
# (a bare `open(...)` passed to `read_file` is never closed explicitly).
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
# ConfigParser lower-cases option names by default, so these upper-case
# lookups match the usual lower-case keys in the credentials file.
ak = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
sk = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [5]:
def recobundles_hcp(params):
    """Run DTI, CSD, tractography and RecoBundles segmentation for one HCP subject.

    Every intermediate product (brain mask, DTI, mapping, streamlines,
    per-bundle tractograms, streamline counts) is stored under
    ``hcp.recobundles/sub-<subject>/`` on S3.  Stages whose output is already
    present there are read back instead of being recomputed.

    Parameters
    ----------
    params : tuple
        ``(subject, hcp_ak, hcp_sk)``: the subject identifier used in all
        file names, and the access key id / secret access key forwarded to
        ``afd.fetch_hcp``.  Access to the ``hcp.recobundles`` bucket itself
        uses whatever credentials ``s3fs.S3FileSystem()`` picks up by default.

    Returns
    -------
    None
        All results are written to S3 (and to the local working directory).
    """
    subject, hcp_ak, hcp_sk = params
    # Imports are function-local because this function is handed to
    # `ck.DockerImage(func=...)`; presumably it has to be self-contained and
    # the import list feeds the image requirements, so unused imports are
    # deliberately left in place — confirm before pruning.
    import pandas as pd
    import s3fs
    import logging
    import boto3
    import os.path as op
    import numpy as np
    import nibabel as nib
    import dipy.data as dpd
    import dipy.tracking.utils as dtu
    import dipy.tracking.streamline as dts
    from dipy.io.streamline import save_tractogram, load_tractogram
    from dipy.stats.analysis import afq_profile, gaussian_weights
    from dipy.io.stateful_tractogram import StatefulTractogram
    from dipy.io.stateful_tractogram import Space
    import dipy.core.gradients as dpg
    from dipy.segment.mask import median_otsu

    import AFQ.data as afd
    import AFQ.tractography as aft
    import AFQ.registration as reg
    import AFQ.dti as dti
    import AFQ.segmentation as seg
    from AFQ import api
    from AFQ import csd

    # Loggers must come from `getLogger`, not from instantiating `Logger`
    # directly; the level is raised to INFO so the progress messages below
    # are actually emitted.
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    log.info(f"Fetching HCP subject {subject}")
    afd.fetch_hcp([subject],
                  aws_access_key_id=hcp_ak,
                  aws_secret_access_key=hcp_sk)

    dwi_dir = op.join(afd.afq_home, 'HCP', 'derivatives',
                      'dmriprep', f'sub-{subject}', 'sess-01/dwi')

    hardi_fdata = op.join(dwi_dir, f"sub-{subject}_dwi.nii.gz")
    hardi_fbval = op.join(dwi_dir, f"sub-{subject}_dwi.bval")
    hardi_fbvec = op.join(dwi_dir, f"sub-{subject}_dwi.bvec")

    log.info(f"Reading data from file {hardi_fdata}")
    img = nib.load(hardi_fdata)
    log.info(f"Creating gradient table from {hardi_fbval} and {hardi_fbvec}")
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)

    fs = s3fs.S3FileSystem()
    # Common "<bucket>/sub-X/sub-X" prefix shared by every S3 key below.
    s3_prefix = f'hcp.recobundles/sub-{subject}/sub-{subject}'

    # ---- Brain mask ------------------------------------------------------
    brain_mask_fname = f'{s3_prefix}_brain_mask.nii.gz'
    if fs.exists(brain_mask_fname):
        log.info(f"Brain-mask exists. Reading from {brain_mask_fname}")
        be_img = afd.s3fs_nifti_read(brain_mask_fname)
        # `get_data()` is deprecated in nibabel; this is its documented
        # drop-in equivalent (keeps the on-disk dtype).
        brain_mask = np.asanyarray(be_img.dataobj)
    else:
        log.info("Calculating brain-mask")
        mean_b0 = np.mean(np.asanyarray(img.dataobj)[..., gtab.b0s_mask], -1)
        _, brain_mask = median_otsu(mean_b0, median_radius=4,
                                    numpass=1, autocrop=False,
                                    vol_idx=None, dilate=10)
        be_img = nib.Nifti1Image(brain_mask.astype(int),
                                 img.affine)
        log.info(f"Saving to {brain_mask_fname}")
        afd.s3fs_nifti_write(be_img, brain_mask_fname)

    # ---- DTI -------------------------------------------------------------
    fa_fname = f'{s3_prefix}_dti_FA.nii.gz'
    dti_params_fname = f'{s3_prefix}_dti.nii.gz'
    dti_meta_fname = f'{s3_prefix}_dti.json'
    if fs.exists(fa_fname):
        log.info(f"DTI already exists. Reading FA from {fa_fname}")
        log.info(f"DTI already exists. Reading params from {dti_params_fname}")
        FA_img = afd.s3fs_nifti_read(fa_fname)
        dti_params = afd.s3fs_nifti_read(dti_params_fname)
    else:
        log.info("Calculating DTI")
        dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
                                 out_dir='.', b0_threshold=50,
                                 mask=brain_mask)
        FA_img = nib.load('./dti_FA.nii.gz')
        log.info(f"Writing FA to {fa_fname}")
        afd.s3fs_nifti_write(FA_img, fa_fname)
        dti_params_img = nib.load('./dti_params.nii.gz')
        log.info(f"Writing DTI params to {dti_params_fname}")
        afd.s3fs_nifti_write(dti_params_img, dti_params_fname)
        dti_params_json = {
            "Model": "Diffusion Tensor",
            "OrientationRepresentation": "param",
            "ReferenceAxes": "xyz",
            "Parameters": {
                "FitMethod": "ols",
                "OutlierRejection": False,
            },
        }
        log.info(f"Writing DTI metadata to {dti_meta_fname}")
        afd.s3fs_write_json(dti_params_json, dti_meta_fname)

    log.info("Reading FA data from img")
    FA_data = FA_img.get_fdata()

    # ---- Registration to the MNI template ---------------------------------
    log.info("Getting the MNI template")

    MNI_T2_img = afd.s3fs_nifti_read('hcp.recobundles/mni_icbm152_t2_tal_nlin_asym_09a.nii')
    mapping_fname = f'{s3_prefix}_mapping.nii.gz'
    local_mapping_fname = './mapping.nii.gz'
    # The cache check must look at the S3 key, not at a bare local file name
    # (which `fs.exists` would interpret as a non-existent bucket path).
    if fs.exists(mapping_fname):
        log.info(f"Mapping already exists. Getting it from {mapping_fname}")
        fs.download(mapping_fname, local_mapping_fname)
        log.info("Reading mapping from './mapping.nii.gz'")
        mapping = reg.read_mapping(local_mapping_fname, img, MNI_T2_img)
    else:
        log.info("Creating mapping.")
        log.info("Calculating SyN registration.")
        # `gtab` from above is reused; the warped image is not needed here.
        _, mapping = reg.syn_register_dwi(hardi_fdata, gtab,
                                          template=MNI_T2_img)
        log.info("Writing to './mapping.nii.gz'")
        reg.write_mapping(mapping, local_mapping_fname)
        log.info(f"Uploading to {mapping_fname}")
        fs.upload(local_mapping_fname, mapping_fname)

    # Each bundle is listed once ('AF' used to appear twice).
    bundle_names = ['CST',
                    'C',
                    'F',
                    'UF',
                    'MCP',
                    'AF',
                    'CCMid',
                    'CC_ForcepsMajor',
                    'CC_ForcepsMinor',
                    'IFOF']

    bundles = api.make_bundle_dict(bundle_names=bundle_names, seg_algo="reco")

    # ---- CSD -------------------------------------------------------------
    csd_fname = f'{s3_prefix}_csd.nii.gz'
    csd_meta_fname = f'{s3_prefix}_csd.json'

    if fs.exists(csd_fname):
        log.info(f"CSD already exists. Getting it from {csd_fname}")
        csd_params = afd.s3fs_nifti_read(csd_fname)
    else:
        log.info("Calculating CSD")
        csd_params = csd.fit_csd(hardi_fdata, hardi_fbval, hardi_fbvec,
                                 out_dir='.', b0_threshold=50,
                                 mask=brain_mask)
        # NOTE(review): nothing in this function uploads the fitted CSD
        # parameters to `csd_fname`, so the cache check above stays False
        # unless something else writes that key. Confirm what
        # `csd.fit_csd` returns/writes and upload it here.

        csd_params_json = {
            "Model": "Constrained Spherical Deconvolution (CSD)",
            "ModelURL": "https://github.com/nipy/dipy/commit/abf31d15a0ee5dc0704ee03ebbba57358d540612",
            "Shells": [0, 1000, 2000, 3000],
            "Parameters": {
                "ResponseFunctionTensor": "auto",
                "SphericalHarmonicBasis": "Descoteaux",
                "NonNegativityConstraint": "hard",
                "SphericalHarmonicDegree": 8,
            },
        }

        log.info(f"Writing CSD metadata to {csd_meta_fname}")
        afd.s3fs_write_json(csd_params_json, csd_meta_fname)

    # ---- Tractography ------------------------------------------------------
    csd_streamlines_fname = f'{s3_prefix}_model-csd_track-det.trk'
    csd_streamlines_meta_fname = f'{s3_prefix}_model-csd_track-det.json'
    # One local path for download, load, save and upload, so the branches
    # cannot drift apart.
    local_streamlines_fname = './csd_streamlines.trk'
    if fs.exists(csd_streamlines_fname):
        log.info(f"Streamlines already exist. Loading from {csd_streamlines_fname}")
        fs.download(csd_streamlines_fname, local_streamlines_fname)
        # The file was saved with `bbox_valid_check=False`, so the same
        # relaxation is needed when loading it back.
        tg = load_tractogram(local_streamlines_fname, img,
                             bbox_valid_check=False)
        streamlines = tg.streamlines
    else:
        log.info("Generating streamlines")
        seed_roi = np.zeros(img.shape[:-1])
        seed_roi[FA_data > 0.4] = 1
        streamlines = aft.track(csd_params, seed_mask=seed_roi,
                                directions='det', stop_mask=FA_data,
                                stop_threshold=0.1)
        log.info(f"After tracking, there are {len(streamlines)} streamlines")
        sft = StatefulTractogram(streamlines, img, Space.RASMM)
        save_tractogram(sft, local_streamlines_fname,
                        bbox_valid_check=False)
        log.info(f"Uploading streamlines to {csd_streamlines_fname}")
        fs.upload(local_streamlines_fname, csd_streamlines_fname)
        csd_streamlines_json = {
            "Algorithm": "LocalTracking",
            "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/c04835cd4ca13d28c20bb449d6f088e656c55e57",
            "Parameters": {
                "SeedRoi": "dti_FA>0.4",
                "Directions": "det",
                "StopMask": "dti_FA<0.1",
            },
        }
        log.info(f"Writing streamlines metadata to {csd_streamlines_meta_fname}")
        afd.s3fs_write_json(csd_streamlines_json, csd_streamlines_meta_fname)

    # ---- RecoBundles segmentation -------------------------------------------
    log.info("Segmenting")
    segmentation = seg.Segmentation(algo='reco',
                                    model_clust_thr=20,
                                    reduction_thr=20)
    segmentation.segment(bundles, streamlines)
    fiber_groups = segmentation.fiber_groups

    segment_prefix = f'{s3_prefix}_model-csd_track-det_segment-recobundles'
    bundle_counts = []
    for kk in fiber_groups:
        n_streamlines = len(fiber_groups[kk])
        bundle_counts.append(n_streamlines)
        log.info(f"There are {n_streamlines} streamlines in {kk}")
        sft = StatefulTractogram(fiber_groups[kk], img, Space.RASMM)
        local_tg_fname = './%s_reco.trk' % kk
        save_tractogram(sft, local_tg_fname,
                        bbox_valid_check=False)
        tg_fname = f'{segment_prefix}_bundle-{kk}.trk'
        log.info(f"Uploading {local_tg_fname} to {tg_fname}")
        fs.upload(local_tg_fname, tg_fname)
        tg_meta_fname = f'{segment_prefix}_bundle-{kk}.json'
        tg_meta_json = {
            "Algorithm": "RecoBundles",
            "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/871c7d567e83fae5041d67802fc8ec03791a877e",
            "Parameters": {
                "model_clust_thr": 20,
                "reduction_thr": 20,
            },
        }
        log.info(f"Uploading segmentation metadata to {tg_meta_fname}")
        afd.s3fs_write_json(tg_meta_json, tg_meta_fname)

    # ---- Per-bundle streamline counts ----------------------------------------
    log.info("Saving streamline counts")
    # Built from the counts collected in the loop above (the original code
    # referenced an undefined name here).
    sl_count = pd.DataFrame(data=bundle_counts,
                            index=list(fiber_groups.keys()),
                            columns=["streamlines"])
    sl_count.to_csv("./sl_count.csv")
    sl_count_fname = f'{segment_prefix}_counts.csv'
    log.info(f"Uploading streamline counts to {sl_count_fname}")
    fs.upload("./sl_count.csv", sl_count_fname)

In [7]:
# Wrap `recobundles_hcp` in a cloudknot Docker image; pyAFQ is taken from the
# `recobundles_hcp` branch of the arokem fork on GitHub.
image = ck.DockerImage(
    func=recobundles_hcp,
    github_installs="https://github.com/arokem/pyAFQ.git@recobundles_hcp",
)

In [9]:
# Knot that runs `image` on SPOT resources (bid at 100%) with the
# AmazonS3FullAccess policy attached.
# NOTE(review): the recorded output of this cell is a CloudknotInputError
# about the S3 bucket name — presumably cloudknot's bucket configuration
# needs to be set before this call; confirm.
rb_knot = ck.Knot(
    name='recobundles_hcp',
    docker_image=image,
    pars_policies=('AmazonS3FullAccess',),
    resource_type="SPOT",
    bid_percentage=100,
)

CloudknotInputError: The requested bucket name already exists and you do not have permission to put or get objects in it.