In [1]:
# Based on the afids/afids-auto/afids-auto-train/workflow/scripts/reg_qc.py script.

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import base64
import os
import re
from glob import glob
from io import BytesIO, StringIO
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from nilearn import plotting
from nilearn.datasets import load_mni152_template
from svgutils.compose import Unit
from svgutils.transform import GroupElement, SVGFigure, fromstring



In [3]:

def svg2str(display_object, dpi):
    """Render the figure behind a nilearn display as SVG markup.

    Parameters
    ----------
    display_object
        Object exposing ``frame_axes.figure`` with a ``savefig`` method
        (e.g. what ``nilearn.plotting.plot_anat`` returns).
    dpi
        Resolution forwarded to ``savefig``.

    Returns
    -------
    str
        The complete SVG document written by ``savefig``.
    """
    figure = display_object.frame_axes.figure
    # Black face and edge colours are forwarded to savefig unchanged.
    save_options = {"dpi": dpi, "format": "svg", "facecolor": "k", "edgecolor": "k"}

    with StringIO() as svg_buffer:
        figure.savefig(svg_buffer, **save_options)
        return svg_buffer.getvalue()


def extract_svg(display_object, dpi=250):
    """Return the bare ``<svg ...>...</svg>`` element of a nilearn display.

    The display is serialized with ``svg2str``. The first ``width`` and
    ``height`` attributes are dropped so the ``viewBox`` alone drives scaling,
    a ``preserveAspectRatio`` attribute is added in front of the ``viewBox``,
    and everything before ``<svg `` (XML declaration, doctype) and after the
    last ``</svg>`` is discarded.

    Parameters
    ----------
    display_object
        Display object accepted by ``svg2str``.
    dpi : int, optional
        Resolution forwarded to ``svg2str``.

    Returns
    -------
    str
        The SVG element, from ``<svg `` through the closing ``</svg>``.
    """
    image_svg = svg2str(display_object, dpi)

    # Sizes can be fractional (e.g. ``460.8pt``), so the pattern has to accept
    # a decimal point; an integer-only pattern silently leaves them in place.
    image_svg = re.sub(r' height="[0-9.]+[a-z]*"', "", image_svg, count=1)
    image_svg = re.sub(r' width="[0-9.]+[a-z]*"', "", image_svg, count=1)
    # The attribute name must be spelled exactly ``preserveAspectRatio``;
    # a misspelled attribute is ignored by SVG renderers.
    image_svg = re.sub(
        " viewBox", ' preserveAspectRatio="xMidYMid meet" viewBox', image_svg, count=1
    )
    start_tag = "<svg "
    start_idx = image_svg.find(start_tag)
    end_tag = "</svg>"
    end_idx = image_svg.rfind(end_tag)

    # rfind gives the start index of the substr. We want this substr
    # included in our return value so we add its length to the index.
    end_idx += len(end_tag)

    return image_svg[start_idx:end_idx]


def clean_svg(fg_svgs, bg_svgs, ref=0):
    """Compose background and foreground SVG panels into one flickering SVG.

    Background panels are stacked vertically; foreground panels are stacked
    the same way starting again from the top, so they lie over the background.
    When foreground panels are given, the two sets are wrapped in groups with
    classes ``background-svg`` and ``foreground-svg``, and a CSS animation is
    inserted that fades the foreground group in and out.

    Parameters
    ----------
    fg_svgs : list or None
        Foreground SVG figures (objects with ``getroot()`` and a ``root``
        element carrying a ``viewBox``). Note that the foreground comes FIRST
        in the signature. ``None`` is treated as an empty list.
    bg_svgs : list
        Background SVG figures, same kind of objects.
    ref : int, optional
        Index (into background + foreground) of the panel whose width all
        other panels are scaled to.

    Returns
    -------
    list of str
        Lines of the composed SVG document, without the ``<?xml`` line.
    """
    if fg_svgs is None:
        fg_svgs = []

    # Background first: indices [0, nsvgs) are background, the rest foreground.
    svgs = bg_svgs + fg_svgs
    roots = [f.getroot() for f in svgs]
    # Number of background panels. This must be len(bg_svgs); wrapping the
    # list in another list would always yield 1 and break multi-panel layouts.
    nsvgs = len(bg_svgs)

    # Query the (integer-truncated) width and height of each panel.
    sizes = []
    for f in svgs:
        viewbox = [float(v) for v in f.root.get("viewBox").split()]
        sizes.append((int(viewbox[2]), int(viewbox[3])))
    sizes = np.array(sizes)

    # Calculate the scale to fit all widths
    width = sizes[ref, 0]
    scales = width / sizes[:, 0]
    heights = sizes[:, 1] * scales

    # Compose the views panel: total size is the reference width and the
    # sum of the background heights.
    fig = SVGFigure(Unit(f"{width}px"), Unit(f"{heights[:nsvgs].sum()}px"))

    yoffset = 0
    for i, r in enumerate(roots):
        r.moveto(0, yoffset, scale_x=scales[i])
        if i == (nsvgs - 1):
            # Last background panel placed: restart at the top so the
            # foreground panels overlay the background ones.
            yoffset = 0
        else:
            yoffset += heights[i]

    # Group background and foreground panels in two groups
    if fg_svgs:
        newroots = [
            GroupElement(roots[:nsvgs], {"class": "background-svg"}),
            GroupElement(roots[nsvgs:], {"class": "foreground-svg"}),
        ]
    else:
        newroots = roots

    fig.append(newroots)
    # Drop the fixed size so the document scales with its container.
    fig.root.attrib.pop("width", None)
    fig.root.attrib.pop("height", None)
    fig.root.set("preserveAspectRatio", "xMidYMid meet")

    # Serialize through a temporary file and read the text back as lines.
    with TemporaryDirectory() as tmpdirname:
        out_file = Path(tmpdirname) / "tmp.svg"
        fig.save(str(out_file))
        # Post processing
        svg = out_file.read_text().splitlines()

    # Remove <?xml... line
    if svg[0].startswith("<?xml"):
        svg = svg[1:]

    # Add styles for the flicker animation. The same uuid is used for both
    # placeholders so the keyframes name is unique per composed SVG.
    if fg_svgs:
        svg.insert(
            2,
            """\
<style type="text/css">
@keyframes flickerAnimation%s { 0%% {opacity: 1;} 100%% { opacity:0; }}
.foreground-svg { animation: 1s ease-in-out 0s alternate none infinite running flickerAnimation%s;}
.foreground-svg:hover { animation-play-state: running;}
</style>"""
            % tuple([uuid4()] * 2),
        )

    return svg


def sorted_nicely(data, reverse=False):
    """Sort strings in natural order, comparing digit runs as integers.

    Each string is split on runs of ``0-9``; digit chunks are compared
    numerically and the remaining chunks as text, so ``"a2"`` sorts before
    ``"a10"``.

    Parameters
    ----------
    data : iterable of str
        Strings to sort.
    reverse : bool, optional
        Sort in descending order when true.

    Returns
    -------
    list of str
        A new sorted list.
    """

    def natural_key(item):
        chunks = re.split("([0-9]+)", item)
        return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]

    return sorted(data, key=natural_key, reverse=reverse)



In [5]:


    
def _load_float32_image(path):
    """Load a NIfTI file and rebuild it as a ``Nifti1Image`` with float32 data."""
    img = nib.load(path)
    return nib.Nifti1Image(
        img.get_fdata().astype(np.float32),
        header=img.header,
        affine=img.affine,
    )


def _render_slice_row(img, display_mode, cut_coords, dpi=300, **plot_kwargs):
    """Plot one row of slices of ``img`` and return it as a parsed SVG figure."""
    display = plotting.plot_anat(
        img,
        display_mode=display_mode,
        draw_cross=False,
        cut_coords=cut_coords,
        **plot_kwargs,
    )
    try:
        return fromstring(extract_svg(display, dpi))
    finally:
        # Always release the matplotlib figure, even if SVG extraction fails.
        display.close()


def _compose_flicker_rows(fg_img, bg_img, cuts_by_axis, **plot_kwargs):
    """Return one flickering SVG string per (display_mode, cut_coords) pair."""
    rows = []
    for display_mode, cut_coords in cuts_by_axis:
        fg_svgs = [_render_slice_row(fg_img, display_mode, cut_coords, **plot_kwargs)]
        bg_svgs = [_render_slice_row(bg_img, display_mode, cut_coords, **plot_kwargs)]
        # clean_svg overlays the foreground on the background and adds the
        # flicker animation; it returns the SVG as a list of lines.
        rows.append("\n".join(clean_svg(fg_svgs, bg_svgs)))
    return rows


def output_html(gad_img_list, gad_stripped_3, gad_stripped_4, output_html):
    """Write an HTML QC report comparing skull-stripped and full gad images.

    For every entry of ``gad_img_list`` the report contains two sections, one
    for the image from ``gad_stripped_3`` and one for the image from
    ``gad_stripped_4`` at the same position. Each section shows three rows of
    slices (display modes ``x``, ``y`` and ``z``) in which the stripped image
    flickers over the full gad image.

    Parameters
    ----------
    gad_img_list : sequence of str
        Paths of the full gad images. The text before the first ``_`` of the
        file name is used as the section title.
    gad_stripped_3 : sequence of str
        Paths of the stripped images shown under the "b = 3" heading,
        paired with ``gad_img_list`` by position.
    gad_stripped_4 : sequence of str
        Paths of the stripped images shown under the "b = 4" heading,
        paired with ``gad_img_list`` by position.
    output_html : str
        Path of the HTML file to write. (The parameter shadows the function
        name inside the body; it is kept for backward compatibility.)

    Raises
    ------
    ValueError
        If either stripped list has fewer entries than ``gad_img_list``.
    """
    # Fail before any rendering starts rather than with an IndexError after
    # several subjects have already been processed.
    n_subjects = len(gad_img_list)
    for label, paths in (
        ("gad_stripped_3", gad_stripped_3),
        ("gad_stripped_4", gad_stripped_4),
    ):
        if len(paths) < n_subjects:
            raise ValueError(
                f"{label} has {len(paths)} entries but gad_img_list has {n_subjects}"
            )

    # nilearn ``dim`` factor controlling image brightness (typical range -2 to 2).
    plot_args_ref = {"dim": -0.5}
    # Six cuts per display mode; the x row uses a different set of coordinates.
    cuts_by_axis = (
        ("x", (-60, -40, 0, 20, 40, 60)),
        ("y", (-40, -20, 0, 20, 40, 60)),
        ("z", (-40, -20, 0, 20, 40, 60)),
    )

    html_list = []
    # zip stops at the end of gad_img_list; extra stripped entries are ignored.
    for gad_path, strip_3_path, strip_4_path in zip(
        gad_img_list, gad_stripped_3, gad_stripped_4
    ):
        isub = os.path.basename(gad_path).split("_")[0]  # e.g. 'sub-O005'

        # The full gad image is the background of both sections; load it once.
        gad_img = _load_float32_image(gad_path)
        gad_strip_3 = _load_float32_image(strip_3_path)
        gad_strip_4 = _load_float32_image(strip_4_path)

        # NOTE: contour-overlay PNGs used to be rendered and base64-encoded
        # here but were never written to the report; that work is omitted.
        final_stripped_3_x, final_stripped_3_y, final_stripped_3_z = (
            _compose_flicker_rows(gad_strip_3, gad_img, cuts_by_axis, **plot_args_ref)
        )
        final_stripped_4_x, final_stripped_4_y, final_stripped_4_z = (
            _compose_flicker_rows(gad_strip_4, gad_img, cuts_by_axis, **plot_args_ref)
        )

        html_list.append(f"""
                <center>
                    <h1 style="font-size:42px">{isub}</h1>
                    <h3 style="font-size:42px">Skull stripped Gad with b = 3 border parameter</h3>
                    <p>{final_stripped_3_x}</p>
                    <p>{final_stripped_3_y}</p>
                    <p>{final_stripped_3_z}</p>
                    <h1 style="font-size:42px">{isub}</h1>
                    <h3 style="font-size:42px">Skull stripped Gad with b = 4 border parameter</h3>
                    <p>{final_stripped_4_x}</p>
                    <p>{final_stripped_4_y}</p>
                    <p>{final_stripped_4_z}</p>
                    <hr style="height:4px;border-width:0;color:black;background-color:black;margin:30px;">
                </center>"""
        )
        print(f"done {isub}")

    html_string = "".join(html_list)
    message = f"""<html>
            <head></head>
            <body>{html_string}</body>
            </html>"""

    with open(output_html, "w") as fid:
        fid.write(message)

if __name__ == "__main__":
    # Build an HTML report to QC skull-stripped images with border parameters
    # 3 and 4 (check first whether 3 is sufficient; if not, 4 is selected).
    derivatives_dir = "/home/ROBARTS/fogunsanya/graham/scratch/degad/derivatives"

    # Stripped gad images for border parameter b = 3 and b = 4, sorted by path.
    input_gad_stripped_3 = sorted(
        glob(derivatives_dir + "/synthstrip/*/*gad_stripped_b3_T1w.nii.gz")
    )
    input_gad_stripped_4 = sorted(
        glob(derivatives_dir + "/synthstrip/*/*gad_stripped_b4_T1w.nii.gz")
    )

    # Gad images whose corresponding nongad images underwent a rigid transform.
    gad_rigid_t1 = glob(
        derivatives_dir + "/passing_dataset/rigid/*/*_acq-gad_resampled_T1w.nii.gz"
    )
    # Gad images whose corresponding nongad images underwent an affine transform.
    gad_affine_t1 = glob(
        derivatives_dir + "/passing_dataset/affine/*/*_acq-gad_resampled_T1w.nii.gz"
    )

    # All QC'ed gad image paths, ordered by the subject prefix of the file
    # name (the text before the first underscore).
    input_gad = sorted(
        gad_rigid_t1 + gad_affine_t1,
        key=lambda path: os.path.basename(path).split("_")[0],
    )

    output_html_file = derivatives_dir + "/skull_strip_QC.html"
    output_html(
        input_gad,
        input_gad_stripped_3,
        input_gad_stripped_4,
        output_html_file,
    )

   

done sub-P002
done sub-P005
done sub-P007
done sub-P009
done sub-P010
done sub-P014
done sub-P017
done sub-P018
done sub-P019
done sub-P030
done sub-P031
done sub-P032
done sub-P034
done sub-P035
done sub-P038
done sub-P040
done sub-P044
done sub-P045
done sub-P046
done sub-P051
done sub-P052
done sub-P053
done sub-P055
