In [6]:
""" 
Author: p.wagner@bhvi.org / p.wagner@unsw.edu.au 
image registration of anterior segment projection from several scans 

Purpose: 
average rnfl layer thincness across scan sets [per participant, eye, scan_type] 
- apply image registration rigid or translational 
and save data as csv for further analyses 


"""
import pandas as pd
import numpy as np
import os.path
from pathlib import Path
from pystackreg import StackReg
from skimage import io
import sys
sys.path.append(r'C:\Users\p.wagner\Documents\Python Scripts\oct_data_analyses_helpers')
from oct_helpers_lib import OctDataAccess as get_px_meta
from oct_helpers_lib import TopconSegmentationData as TSD

import cv2
import glob
from PIL import Image
import plotly.graph_objects as go
import matplotlib.pyplot as plt

path_oct = r'E:\studyIII\OCT_data'
path_logbook = r'C:\Users\p.wagner\Documents\phd\stud_III\participants'
fn_logbook = 'participant_log_studyIII_V0.2.xlsx' 
fp_fn_logbook = os.path.join(path_logbook, fn_logbook)
thickness_csv = 'DEFAULT_3D_All_Thickness.csv'


# check if path_oct is available 
if not os.path.isdir(path_oct):
    print('OCT data path NOT available')

In [7]:
# scratchy 
scan_type = 'OCT_initial_OD'
px_id = 8
px_meta_data = get_px_meta(fp_fn_logbook, [px_id,], [scan_type,], path_oct)
px_meta_data.log_master.loc['im_reg_type', px_meta_data.log_master.loc['px_id']==px_id][0]

'rbt'

In [8]:
px_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 23, 26]
px_ids = [16,]

for px_id in px_ids: 
    scan_types_all = ['OCT_initial_OD', 'OCT_post_OD', 'OCT_initial_OS', 'OCT_post_OS']
    
    px_meta_data_all = get_px_meta(fp_fn_logbook, [px_id, ], scan_types_all, path_oct)
    # read thickness data file    
    oct_data_fn_fp = os.path.join(px_meta_data_all.subject_rec_fp[0], thickness_csv)
    oct_thicness_data_all = pd.read_csv(oct_data_fn_fp, dtype=str, names=list(range(0,513, 1)), low_memory=False)
    
    
    # check data quality for each eye and asign best quality image as ref image
    # # OD
    scan_types_OD = ['OCT_initial_OD', 'OCT_post_OD', ]
    df = pd.DataFrame(columns=['quality', 'fp',])
    px_meta_data_OD = get_px_meta(fp_fn_logbook, [px_id, ], scan_types_OD, path_oct)
    for idx, fp in enumerate(px_meta_data_OD.oct_scans_fp):
        df = df.append(pd.DataFrame({'quality': [px_meta_data_OD.get_quality(oct_thicness_data_all, px_meta_data_OD.oct_scans_ids[idx])],
                                'fp': [fp]}))
        df = df.reset_index(drop=True)     
            # insert text on images here
        df_sorted = df.sort_values(by=['quality'], ascending=False).reset_index(drop=True)  
        #load reference image
        best_q_im_OD = io.imread(os.path.join(df_sorted.fp[0], 'images_Sequence', 'images_Sequence_Proj_Iowa.tif'))    
    # # OS 
    scan_types_OS = ['OCT_initial_OS', 'OCT_post_OS', ]
    df = pd.DataFrame(columns=['quality', 'fp',])
    px_meta_data_OS = get_px_meta(fp_fn_logbook, [px_id, ], scan_types_OS, path_oct)
    for idx, fp in enumerate(px_meta_data_OS.oct_scans_fp):
        df = df.append(pd.DataFrame({'quality': [px_meta_data_OS.get_quality(oct_thicness_data_all, px_meta_data_OS.oct_scans_ids[idx])],
                                'fp': [fp]}))
        df = df.reset_index(drop=True)     
        # insert text on images here
        df_sorted = df.sort_values(by=['quality'], ascending=False).reset_index(drop=True)  
        #load reference image
        best_q_im_OS = io.imread(os.path.join(df_sorted.fp[0], 'images_Sequence', 'images_Sequence_Proj_Iowa.tif')) 
    
    # collect some meta data 
    scans_meta_columns = ['px_id', 'scan_type', 'oct_scans_ids', 'oct_scan_quality', 'mean_std',
                      'std_std', 'max_maxmin', 'min_maxmin']
    scans_meta_data = pd.DataFrame(columns=scans_meta_columns) 
    
    
    # apply averageing to all images 
    
    for scan_type in scan_types_all:
        px_meta_data = get_px_meta(fp_fn_logbook, [px_id,], [scan_type,], path_oct)


        all_rnfl = []
        for scan_id, fp in zip(px_meta_data.oct_scans_ids, px_meta_data.oct_scans_fp):
            oct_scan_data = TSD(oct_thicness_data_all, scan_id)

#             im =plt.imread(os.path.join(fp, 'images_Sequence_Proj_Iowa_ONH_ident2.png'))
#             # find blue pixels (ONH annotation)
#             idxs = np.where(np.all(im == [0, 0 , 1], axis=-1))

            z_data = oct_scan_data.rnfl.values.astype(float)
#             # # choroid data == 0 is impossible -> disregard but account during averaging across scans  
#             z_data[idxs[0], idxs[1]] = 0
            
            print(scan_id)
            # # which eye need to be registered 
            if scan_type.split('_')[-1] == 'OD':
                ref_im = best_q_im_OD
            if scan_type.split('_')[-1] == 'OS':
                ref_im = best_q_im_OS
            
            # # calculate image regression 
            out_rot_im = np.zeros((256, 512, 3), dtype=float)
            # #load "moved" image
            mov_im = io.imread(os.path.join(fp, 'images_Sequence', 'images_Sequence_Proj_Iowa.tif'))
            
            if px_meta_data.log_master.loc['im_reg_type', px_meta_data.log_master.loc['px_id']==px_id][0] =='rbt':
                # # Rigid Body transformation
                sr = StackReg(StackReg.RIGID_BODY)
                out_rot = sr.register_transform(ref_im, mov_im)            
                print('rbt')
            if px_meta_data.log_master.loc['im_reg_type', px_meta_data.log_master.loc['px_id']==px_id][0] =='tran':
                #Translational transformation
                # applied to px_id [3, 13
                sr = StackReg(StackReg.TRANSLATION)
                out_rot = sr.register_transform(ref_im, mov_im)
                print('tran')
            
            # # apply image regression
            print(sr.get_matrix())
            

            z_data[np.isnan(z_data)] = 0
            print(z_data)
            z_data_reg = sr.transform(z_data)            
            all_rnfl.append(z_data_reg)
        
        all_rnfl = np.array(all_rnfl)

        all_rnfl_mean = np.nanmean(all_rnfl, axis=0)
        all_rnfl_std = np.nanstd(all_rnfl, axis=0)
        all_rnfl_max = np.nanmax(all_rnfl, axis=0)
        all_rnfl_min = np.nanmin(all_rnfl, axis=0)

        # save all_choroids_mean for further analyses
        rnfl_output_fn = os.path.join(px_meta_data.subject_rec_fp[0], 
                                          scan_type.split('_')[-1] + '_' + scan_type.split('_')[-2] + 
                                          '_rnfl_mean_reg.csv')
        df = pd.DataFrame(all_rnfl_mean)
        df.to_csv(rnfl_output_fn)
        
#         all_rnfl_mean.astype(int).shape
        quality_all = []
        for scan_id in px_meta_data.oct_scans_ids:
            quality_all.append(px_meta_data.get_quality(oct_thicness_data_all, scan_id))
        
        TSD.create_rnfl_avg_figures(all_rnfl_mean, px_id, quality_all, scan_type, 
                        px_meta_data.oct_scans_ids, 
                        px_meta_data.subject_rec_fp[0], path_oct)
        
        
        # gather scan meta data 
        scans_meta_data = scans_meta_data.append(pd.DataFrame({'px_id':px_meta_data.px_ids,
                                     'scan_type': px_meta_data.scan_types,
                                     'oct_scans_ids': str(px_meta_data.oct_scans_ids), 
                                     'oct_scan_quality': str(quality_all), 
                                     'mean_std': np.nanmean(all_rnfl_std),
                                     'std_std': np.nanstd(all_rnfl_std),
                                     'max_maxmin': np.nanmax(all_rnfl_max - all_rnfl_min),
                                     'min_maxmin': np.nanmin(all_rnfl_max - all_rnfl_min)
                                    }))
            
            
        del all_rnfl_mean, all_rnfl_std, all_rnfl_max, all_rnfl_min, z_data_reg, df
    
    # save scan meta data 
    scans_meta_output_fn = os.path.join(px_meta_data.subject_rec_fp[0], 
                                    'scans_summery_rnfl.csv')
    scans_meta_data.to_csv(scans_meta_output_fn)
    
    del oct_thicness_data_all
    
    
print('finished')

44340
tran
[[ 1.          0.         -3.017285  ]
 [ 0.          1.         -7.62781997]
 [ 0.          0.          1.        ]]
[[23.28 25.86 25.86 ... 31.04 31.04 31.04]
 [15.52 12.93 12.93 ... 38.8  36.21 38.8 ]
 [25.86 25.86 25.86 ... 36.21 36.21 36.21]
 ...
 [23.28 25.86 23.28 ... 38.8  41.38 41.38]
 [18.1  18.1  18.1  ... 38.8  43.97 43.97]
 [36.21 41.38 43.97 ... 46.56 43.97 43.97]]
44341
tran
[[ 1.          0.         -2.69175438]
 [ 0.          1.         -7.6228755 ]
 [ 0.          0.          1.        ]]
[[18.1  18.1  18.1  ... 41.38 41.38 41.38]
 [15.52 15.52 15.52 ... 43.97 43.97 41.38]
 [15.52 15.52 15.52 ... 43.97 43.97 43.97]
 ...
 [ 5.17  2.59  2.59 ... 33.62 33.62 33.62]
 [ 0.    0.    0.   ... 33.62 33.62 33.62]
 [ 5.17  5.17  5.17 ... 28.45 31.04 31.04]]
44343
tran
[[ 1.          0.         -3.07680889]
 [ 0.          1.         -6.82671141]
 [ 0.          0.          1.        ]]
[[20.69 20.69 20.69 ... 41.38 41.38 41.38]
 [25.86 25.86 25.86 ... 36.21 38.8  38.8 ]

In [9]:
z_data.shape

(256, 512)

In [10]:
all_rnfl.shape

(4, 256, 512)

In [2]:
# # im_output_fn = scan_types[0].split('_')[-1] + '_' + scan_types[0].split('_')[-2] +'_mean_choroid_thickness.png'


# # # display choroid values 
# fig = go.Figure(data=[go.Surface(z=abs(-np.array(z_data)))])
# title_name = ('PX_id: ' + str(px_id)  + ', Scan Nr: ' + str(scan_id) + 
#               ', Quality: '+  str(oct_scan_data.scan_quality)) 
# fig.update_layout(title=title_name, autosize=False,
#                   scene_camera_eye=dict(x=0, y=-3, z=10),
#                   width=900, height=900,
#                   margin=dict(l=50, r=50, b=50, t=50))
# fig.update_layout(scene = dict(xaxis = dict(nticks=10, range=[0, 150],),
#                                yaxis = dict(nticks=10, range=[0, 150],),
#                                zaxis = dict(nticks=10, range=[-1,150],),),
#                   scene_aspectmode='manual', 
#                   scene_aspectratio=dict(x=9, y=9, z=10)
#                   )

# fig.show()
# # fig.write_image(os.path.join(px_meta_data.subject_rec_fp[0], 'images', im_output_fn))

In [3]:
# def create_rnfl_avg_figures(mean, px_id, quality, scan_type, scan_ids, fp, fp_oct):
#     # creating figure choroid thickness figures
#     # direct view
#     im_output_fn_d = (str(px_id)  + scan_type.split('_')[-1] + '_' + scan_type.split('_')[-2] +
#                       '_rnfl_thickness_all.png')
#     # # display choroid values
#     fig1 = go.Figure(data=[go.Surface(z=abs(-all_rnfl_mean))])
#     title_name = ('PX_id: ' + str(px_id)  + ', Scan Nr: ' + str(scan_ids) + 
#                   ', Quality: '+  str(quality)) 
#     fig1.update_layout(title=title_name, autosize=False,
#                       scene_camera_eye=dict(x=0, y=-10, z=20),
#                       width=900, height=900,
#                       margin=dict(l=50, r=50, b=50, t=50))
#     fig1.update_layout(scene = dict(xaxis = dict(nticks=10, range=[0,512],),
#                                    yaxis = dict(nticks=10, range=[0,256],),
#                                    zaxis = dict(nticks=10, range=[0,300],),),
#                       scene_aspectmode='manual', 
#                       scene_aspectratio=dict(x=12, y=9, z=10)
#                       )

#     fig1.write_image(os.path.join(fp_oct, 'images\\rnfl', im_output_fn_d))
#     fig1.write_image(os.path.join(fp, 'images', im_output_fn_d))

#     im_output_fn_i = (str(px_id)  + scan_type.split('_')[-1] + '_' + scan_type.split('_')[-2] +
#                       '_rnfl_thickness_macula.png')

#     # indirect view
#     fig2 = go.Figure(data=[go.Surface(z=abs(-all_rnfl_mean[50:200, 175:325]))])
#     title_name = ('PX_id: ' + str(px_id)  + ', Scan Nr: ' + str(scan_ids) + 
#                   ', Quality: '+  str(quality)) 
#     fig2.update_layout(title=title_name, autosize=False,
#                       scene_camera_eye=dict(x=0, y=-3, z=10),
#                       width=900, height=900,
#                       margin=dict(l=50, r=50, b=50, t=50))
#     fig2.update_layout(scene = dict(xaxis = dict(nticks=10, range=[0, 150],),
#                                    yaxis = dict(nticks=10, range=[0, 150],),
#                                    zaxis = dict(nticks=10, range=[-1,150],),),
#                       scene_aspectmode='manual', 
#                       scene_aspectratio=dict(x=9, y=9, z=10)
#                       )

#     fig2.write_image(os.path.join(fp_oct, 'images\\rnfl', im_output_fn_i))
#     fig2.write_image(os.path.join(fp, 'images', im_output_fn_i))
    


In [4]:
# create_rnfl_avg_figures(all_rnfl_mean, px_id, quality_all, scan_type, 
#                         px_meta_data.oct_scans_ids, 
#                         px_meta_data.subject_rec_fp[0], path_oct)

In [76]:
px_meta_data.oct_scans_ids

['44117', '44118', '44119']

In [None]:
all_rnfl.shape
