In [4]:
#using afids/afids-auto/afids-auto-train/workflow/scripts/reg_qc.py script

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import base64
import os
import re
from glob import glob
from io import BytesIO, StringIO
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from nilearn import plotting
from nilearn.datasets import load_mni152_template
from svgutils.compose import Unit
from svgutils.transform import GroupElement, SVGFigure, fromstring



In [2]:

def svg2str(display_object, dpi):
    """Render the figure behind a nilearn display as SVG and return the markup.

    Parameters
    ----------
    display_object
        Object exposing ``frame_axes.figure`` with a ``savefig`` method.
    dpi
        Resolution forwarded to ``savefig``.

    Returns
    -------
    str
        Everything ``savefig`` wrote, as a single string.
    """
    figure = display_object.frame_axes.figure
    # Black face and edge colours so the saved canvas has a black background.
    savefig_options = {
        "dpi": dpi,
        "format": "svg",
        "facecolor": "k",
        "edgecolor": "k",
    }
    with StringIO() as svg_buffer:
        figure.savefig(svg_buffer, **savefig_options)
        return svg_buffer.getvalue()


def extract_svg(display_object, dpi=250):
    """Return only the ``<svg ...>...</svg>`` element of a rendered display.

    The display is serialised with ``svg2str``. The first ``height`` and
    ``width`` attributes are removed, a ``preserveAspectRatio`` attribute is
    inserted in front of the first ``viewBox``, and everything before the
    opening ``<svg `` tag or after the last ``</svg>`` is discarded.

    Parameters
    ----------
    display_object
        Display passed through to ``svg2str``.
    dpi : int, optional
        Resolution passed through to ``svg2str`` (default 250).

    Returns
    -------
    str
        The trimmed SVG markup.
    """
    image_svg = svg2str(display_object, dpi)

    # Sizes may be fractional (e.g. height="345.6pt"), so accept a decimal
    # point; otherwise the root attribute is skipped and a later
    # integer-valued attribute could be removed instead.
    image_svg = re.sub(r' height="[0-9.]+[a-z]*"', "", image_svg, count=1)
    image_svg = re.sub(r' width="[0-9.]+[a-z]*"', "", image_svg, count=1)
    # The attribute name must be spelled exactly "preserveAspectRatio" to be
    # recognised by SVG renderers.
    image_svg = re.sub(
        " viewBox", ' preserveAspectRatio="xMidYMid meet" viewBox', image_svg, count=1
    )
    start_tag = "<svg "
    start_idx = image_svg.find(start_tag)
    end_tag = "</svg>"
    end_idx = image_svg.rfind(end_tag)

    # rfind gives the start index of the substr. We want this substr
    # included in our return value so we add its length to the index.
    end_idx += len(end_tag)

    return image_svg[start_idx:end_idx]


def clean_svg(fg_svgs, bg_svgs, ref=0):
    """Stack background and foreground SVG panels into one animated SVG.

    Every panel is scaled to the width of panel ``ref`` (indexed over
    ``bg_svgs + fg_svgs``). Background panels are stacked vertically; the
    foreground panels are stacked again from the top so they overlay the
    background. When foreground panels are given, the two stacks are wrapped
    in groups with classes ``background-svg`` and ``foreground-svg`` and a
    CSS animation that fades the foreground in and out is inserted.

    Parameters
    ----------
    fg_svgs : list
        Parsed SVG figures (objects with ``getroot()`` and a ``root`` element
        carrying a ``viewBox``) drawn on top.
    bg_svgs : list
        Parsed SVG figures drawn underneath.
    ref : int, optional
        Index of the panel whose width defines the output width (default 0).

    Returns
    -------
    list of str
        Lines of the composed SVG, without a leading ``<?xml`` declaration.
    """
    svgs = bg_svgs + fg_svgs
    roots = [f.getroot() for f in svgs]

    # Panel sizes come from the third and fourth viewBox entries (width, height)
    sizes = []
    for f in svgs:
        viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
        width = int(viewbox[2])
        height = int(viewbox[3])
        sizes.append((width, height))

    # Number of panels forming the background stack. This used to be
    # len([bg_svgs]), which is always 1 regardless of how many background
    # panels were passed. The floor of 1 keeps the previous result when
    # bg_svgs is empty.
    nsvgs = max(len(bg_svgs), 1)

    sizes = np.array(sizes)

    # Calculate the scale to fit all widths
    width = sizes[ref, 0]
    scales = width / sizes[:, 0]
    heights = sizes[:, 1] * scales

    # Compose the views panel: total size is the width of the reference
    # element and the sum of the background heights
    fig = SVGFigure(Unit(f"{width}px"), Unit(f"{heights[:nsvgs].sum()}px"))

    yoffset = 0
    for i, r in enumerate(roots):
        r.moveto(0, yoffset, scale_x=scales[i])
        if i == (nsvgs - 1):
            # Last background panel placed: restart at the top so the
            # foreground stack overlays the background stack
            yoffset = 0
        else:
            yoffset += heights[i]

    # Group background and foreground panels in two groups
    if fg_svgs:
        newroots = [
            GroupElement(roots[:nsvgs], {"class": "background-svg"}),
            GroupElement(roots[nsvgs:], {"class": "foreground-svg"}),
        ]
    else:
        newroots = roots

    fig.append(newroots)
    # Drop the fixed size so the SVG scales with its container
    fig.root.attrib.pop("width", None)
    fig.root.attrib.pop("height", None)
    fig.root.set("preserveAspectRatio", "xMidYMid meet")

    # Serialise through a temporary file and read the result back as lines
    with TemporaryDirectory() as tmpdirname:
        out_file = Path(tmpdirname) / "tmp.svg"
        fig.save(str(out_file))
        # Post processing
        svg = out_file.read_text().splitlines()

    # Remove <?xml... line
    if svg[0].startswith("<?xml"):
        svg = svg[1:]

    # Add styles for the flicker animation; the uuid suffix gives each
    # composed SVG its own keyframes name
    if fg_svgs:
        svg.insert(
            2,
            """\
<style type="text/css">
@keyframes flickerAnimation%s { 0%% {opacity: 1;} 100%% { opacity:0; }}
.foreground-svg { animation: 1s ease-in-out 0s alternate none infinite running flickerAnimation%s;}
.foreground-svg:hover { animation-play-state: running;}
</style>"""
            % tuple([uuid4()] * 2),
        )

    return svg


def sorted_nicely(data, reverse=False):
    """Sort strings in natural order, comparing digit runs as integers.

    Parameters
    ----------
    data : iterable of str
        Strings to sort.
    reverse : bool, optional
        Sort in descending order when true (default False).

    Returns
    -------
    list of str
        A new sorted list.
    """

    def natural_key(item):
        # Splitting on a captured group keeps the digit runs in the result,
        # so each key alternates text chunks and integer chunks.
        return [
            int(chunk) if chunk.isdigit() else chunk
            for chunk in re.split("([0-9]+)", item)
        ]

    return sorted(data, key=natural_key, reverse=reverse)



In [52]:


    
# Slice positions (six cuts per row) used for each display axis.
_QC_CUT_COORDS = {
    "x": (-60, -40, 0, 20, 40, 60),
    "y": (-40, -20, 0, 20, 40, 60),
    "z": (-40, -20, 0, 20, 40, 60),
}
# Extra keyword arguments shared by every plot_anat call (nilearn "dim" setting).
_QC_PLOT_ARGS = {"dim": -0.5}
# Resolution passed to extract_svg for every panel.
_QC_SVG_DPI = 300


def _load_float32(path):
    """Load a NIfTI file and rebuild it with its voxel data cast to float32."""
    img = nib.load(path)
    return nib.Nifti1Image(
        img.get_fdata().astype(np.float32),
        header=img.header,
        affine=img.affine,
    )


def _slice_row_svg(img, display_mode):
    """Plot one row of slices of ``img`` along ``display_mode`` and parse it as SVG."""
    display = plotting.plot_anat(
        img,
        display_mode=display_mode,
        draw_cross=False,
        cut_coords=_QC_CUT_COORDS[display_mode],
        **_QC_PLOT_ARGS,
    )
    try:
        return fromstring(extract_svg(display, _QC_SVG_DPI))
    finally:
        # Always release the figure, even if serialisation fails
        display.close()


def _flicker_rows(fg_img, bg_img):
    """Return the x, y and z flicker SVG strings of ``fg_img`` over ``bg_img``."""
    rows = []
    for display_mode in ("x", "y", "z"):
        fg_svgs = [_slice_row_svg(fg_img, display_mode)]
        bg_svgs = [_slice_row_svg(bg_img, display_mode)]
        # clean_svg overlays the two rows and adds the fade animation
        rows.append("\n".join(clean_svg(fg_svgs, bg_svgs)))
    return rows


def output_html(gad_img_list, nongad_rigid_img, nongad_affine_img, output_html):
    """Write an HTML registration-QC report with flickering overlays.

    For every subject, three rows of slices (x, y and z axes) are rendered
    with the rigidly transformed nongad image fading over the gad image,
    followed by the same three rows for the affinely transformed nongad image.

    The contour overlays that used to be rendered here were never written to
    the report, so they are no longer generated; the HTML is unchanged.

    Parameters
    ----------
    gad_img_list : list of str
        Paths of the gad images. The text before the first underscore of each
        file name is used as the subject heading.
    nongad_rigid_img : list of str
        Paths of the rigid-transformed nongad images, in the same order.
    nongad_affine_img : list of str
        Paths of the affine-transformed nongad images, in the same order.
    output_html : str
        Path of the HTML file to write. Inside this function the name shadows
        the function itself.

    Raises
    ------
    ValueError
        If the three lists do not have the same length, because images are
        paired by position.
    """
    if not (len(gad_img_list) == len(nongad_rigid_img) == len(nongad_affine_img)):
        raise ValueError(
            "Image lists must have the same length: "
            f"{len(gad_img_list)} gad, {len(nongad_rigid_img)} rigid, "
            f"{len(nongad_affine_img)} affine"
        )

    html_list = []
    for gad_path, rigid_path, affine_path in zip(
        gad_img_list, nongad_rigid_img, nongad_affine_img
    ):
        isub = os.path.basename(gad_path).split("_")[0]

        # The gad image is the background of both comparisons: load it once
        gad_img = _load_float32(gad_path)

        final_svg_rigid_x, final_svg_rigid_y, final_svg_rigid_z = _flicker_rows(
            _load_float32(rigid_path), gad_img
        )
        final_svg_affine_x, final_svg_affine_y, final_svg_affine_z = _flicker_rows(
            _load_float32(affine_path), gad_img
        )

        html_list.append(f"""
                <center>
                    <h1 style="font-size:42px">{isub}</h1>
                    <h3 style="font-size:42px">Rigid transformation: Nongad to Gad space</h3>
                    <p>{final_svg_rigid_x}</p>
                    <p>{final_svg_rigid_y}</p>
                    <p>{final_svg_rigid_z}</p>
                    <h1 style="font-size:42px">{isub}</h1>
                    <h3 style="font-size:42px">Affine transformation: Nongad to Gad space</h3>
                    <p>{final_svg_affine_x}</p>
                    <p>{final_svg_affine_y}</p>
                    <p>{final_svg_affine_z}</p>
                    <hr style="height:4px;border-width:0;color:black;background-color:black;margin:30px;">
                </center>"""
        )
        print(f"Done {isub}")

    html_string = "".join(html_list)
    message = f"""<html>
            <head></head>
            <body>{html_string}</body>
            </html>"""

    with open(output_html, "w") as fid:
        fid.write(message)

if __name__ == "__main__":
    # Single place to change when the derivatives tree lives elsewhere.
    derivatives_dir = Path("/home/ROBARTS/fogunsanya/graham/scratch/degad/derivatives")
    greedy_dir = derivatives_dir / "greedy"

    # Each list is sorted so that entries at the same position are paired
    # by output_html; this relies on the three globs matching the same
    # subject folders.
    # gad image paths
    input_gad_dirs = sorted(
        glob(str(greedy_dir / "*" / "*_acq-gad_resampled_T1w.nii.gz"))
    )
    # rigid-transformed nongad image paths
    input_nongad_rigid = sorted(
        glob(str(greedy_dir / "*" / "*rigid_resliced_T1w.nii.gz"))
    )
    # affine-transformed nongad image paths
    input_nongad_affine = sorted(
        glob(str(greedy_dir / "*" / "*affine_resliced_T1w.nii.gz"))
    )
    output_html_file = str(derivatives_dir / "registration_QC.html")
    output_html(input_gad_dirs, input_nongad_rigid, input_nongad_affine, output_html_file)

   

Done sub-P001
Done sub-P002
Done sub-P003
Done sub-P004
Done sub-P005
Done sub-P006
Done sub-P007
Done sub-P008
Done sub-P009
Done sub-P010
Done sub-P011
Done sub-P012
Done sub-P013
Done sub-P014
Done sub-P015
Done sub-P016
Done sub-P017
Done sub-P018
Done sub-P019
Done sub-P020
Done sub-P021
Done sub-P022
Done sub-P023
Done sub-P024
Done sub-P025
Done sub-P026
Done sub-P027
Done sub-P028
Done sub-P029
Done sub-P030
Done sub-P031
Done sub-P032
Done sub-P033
Done sub-P034
Done sub-P035
Done sub-P036
Done sub-P037
Done sub-P038
Done sub-P039
Done sub-P040
Done sub-P041
Done sub-P042
Done sub-P043
Done sub-P044
Done sub-P045
Done sub-P046
Done sub-P047
Done sub-P048
Done sub-P049
Done sub-P050
Done sub-P051
Done sub-P052
Done sub-P053
Done sub-P054
Done sub-P055
