In [None]:
import cv2
import holoviews as hv
from holoviews import opts
import pydicom
import param
import panel.widgets as pnw
import panel as pn
from pathlib import Path
pn.extension()
hv.extension('bokeh')

In [None]:
# Root folder of the local dataset (machine-specific; adjust for your checkout).
data_dir = '/Users/neshdev/radpathfusion/data/'
# Two candidate inputs are listed; the second assignment overrides the first,
# so the prostate slice is the one actually loaded.
pathology_src = f'{data_dir}/lung/pathology/LungFCP-01-0006__d5.tiff'
pathology_src = f'{data_dir}/pathology/aaa 0051 BSlice.tif'

# Upper bounds of the width/height sliders (pixels).
MAX_WIDTH = 2000
MAX_HEIGHT = 2000

# Lower bounds of the width/height sliders (pixels).
MIN_WIDTH = 128
MIN_HEIGHT = 128

# Initial slider values (pixels).
DEFAULT_WIDTH = 256
DEFAULT_HEIGHT = 256

# Size of the rendered plot, independent of the image resolution.
VISUAL_WIDTH = 600
VISUAL_HEIGHT = 600

pathology_img = cv2.imread(pathology_src)
if pathology_img is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError(f'could not read pathology image: {pathology_src}')
# cv2.imread yields BGR channel order, while the notebook renders the image
# with hv.RGB; convert once here so every later display shows true colours.
pathology_img = cv2.cvtColor(pathology_img, cv2.COLOR_BGR2RGB)
pathology_img.shape

In [None]:
# Fixed-size working copy of the pathology image; dsize is (width, height).
pathology_img_scaled = cv2.resize(
    src=pathology_img,
    dsize=(2000, 1500),
    interpolation=cv2.INTER_CUBIC,
)
pathology_img_scaled.shape

In [None]:
class PathologyVisualizer:
    """Interactive Panel viewer for a pathology image.

    Displays the image beside width/height sliders, an interpolation selector
    and two rotate buttons. The image as currently configured by the widgets
    is exposed through ``moving_image``.
    """

    # Option names of the ``interpolation`` select widget -> OpenCV resize flags.
    interpolation_map = {
            "INTER_LINEAR" : cv2.INTER_LINEAR,
            "INTER_CUBIC"  : cv2.INTER_CUBIC,
            "INTER_AREA"   : cv2.INTER_AREA
    }

    # NOTE: class attributes, so these two sliders are created once and shared
    # by every PathologyVisualizer instance.
    width_wig  = pnw.IntSlider(name='width', value=DEFAULT_WIDTH, start=MIN_WIDTH, end=MAX_WIDTH)
    height_wig  = pnw.IntSlider(name='height', value=DEFAULT_HEIGHT, start=MIN_HEIGHT, end=MAX_HEIGHT)

    def __init__(self, pathology_img_scaled):
        """Store a private copy of the image and create the per-instance widgets.

        Args:
            pathology_img_scaled: image array; copied so rotations never touch
                the caller's array.
        """
        self.pathology_img_scaled = pathology_img_scaled.copy()
        # Replaced by view(); created here so a rotate callback fired before
        # view() has an object to work with instead of raising AttributeError.
        self.image_ = pn.Column()

        self.rotate_left  = pnw.Button(name='rotate left', width=50)
        self.rotate_left.on_click(self.rotate_counter_clockwise_fn)
        self.rotate_right = pnw.Button(name='rotate right', width=50)
        self.rotate_right.on_click(self.rotate_clockwise_fn)
        self.interpolation = pnw.Select(name='interpolation', options=['INTER_AREA','INTER_CUBIC','INTER_LINEAR'], value='INTER_LINEAR')

    @classmethod
    def _resolve_interpolation(cls, interpolation):
        """Return an OpenCV flag for a widget option name or a raw integer flag."""
        if isinstance(interpolation, int):
            # Already an OpenCV flag (e.g. the default of load_pathology).
            return interpolation
        # Unknown names fall back to the bilinear flag, never to a string.
        return cls.interpolation_map.get(interpolation, cv2.INTER_LINEAR)

    def _bound_view(self):
        """Return ``load_pathology`` bound to the size and interpolation widgets."""
        return pn.depends(self.width_wig, self.height_wig, self.interpolation)(self.load_pathology)

    def _rotate(self, rotate_code):
        """Rotate the stored image and swap in a freshly bound plot."""
        self.pathology_img_scaled = cv2.rotate(self.pathology_img_scaled, rotate_code)
        # No widget value changed, so the bound function would not re-run on
        # its own; replacing the pane re-renders the image.
        if len(self.image_):
            self.image_.pop(0)
        self.image_.append(self._bound_view())

    def rotate_clockwise_fn(self, event):
        """Button callback: rotate the image 90 degrees clockwise."""
        self._rotate(cv2.ROTATE_90_CLOCKWISE)

    def rotate_counter_clockwise_fn(self, event):
        """Button callback: rotate the image 90 degrees counter-clockwise."""
        self._rotate(cv2.ROTATE_90_COUNTERCLOCKWISE)

    def load_pathology(self, width=500, height=500, interpolation=cv2.INTER_LINEAR):
        """Resize the stored image and wrap it in an ``hv.RGB`` element.

        Args:
            width: target width in pixels.
            height: target height in pixels.
            interpolation: a key of ``interpolation_map`` or an OpenCV
                interpolation flag.

        Returns:
            An ``hv.RGB`` with bounds ``(0, 0, width, height)``.
        """
        interpolation_ = self._resolve_interpolation(interpolation)

        dst = cv2.resize(self.pathology_img_scaled,(width,height),interpolation=interpolation_)
        # Kept on the instance so the last rendered array can be inspected.
        self._dst = dst
        # NOTE(review): hv.RGB reads the channels as R, G, B; arrays straight
        # from cv2.imread are BGR - confirm the input has been converted.
        return hv.RGB(dst,bounds=(0,0,width,height)).opts(width=VISUAL_WIDTH,height=VISUAL_HEIGHT)

    @property
    def moving_image(self):
        """The stored image resized according to the current widget values."""
        interpolation_ = self._resolve_interpolation(self.interpolation.value)

        dst = cv2.resize(self.pathology_img_scaled,(self.width_wig.value,self.height_wig.value),interpolation=interpolation_)
        return dst

    def view(self):
        """Build the Panel layout: image on the left, controls on the right."""
        scaling   = pn.Column(self.width_wig, self.height_wig)
        rotation = pn.Row(self.rotate_left,self.rotate_right)
        widgets = pn.Column(scaling,rotation, self.interpolation)
        image_ = pn.Column()
        # Remembered so the rotate callbacks can replace the rendered pane.
        self.image_ = image_
        image_.append(self._bound_view())
        image = pn.Row(image_, widgets)
        return image

# Build the pathology viewer and mark its layout as servable.
# NOTE(review): the viewer receives `pathology_img`, not the
# `pathology_img_scaled` array computed in the previous cell - confirm intended.
pv = PathologyVisualizer(pathology_img)
view = pv.view()
view.servable()

In [None]:
class MriVisualizer:
    """Interactive Panel viewer for the DICOM files of one directory.

    Displays one slice at a time beside a slice-number slider, width/height
    sliders and an interpolation selector. The slice as currently configured
    by the widgets is exposed through ``fixed_image``.
    """

    # Option names of the ``interpolation`` select widget -> OpenCV resize flags.
    interpolation_map = {
            "INTER_LINEAR" : cv2.INTER_LINEAR,
            "INTER_CUBIC"  : cv2.INTER_CUBIC,
            "INTER_AREA"   : cv2.INTER_AREA
    }

    # NOTE: class attributes, so these two sliders are created once and shared
    # by every MriVisualizer instance.
    width_wig  = pnw.IntSlider(name='width', value=DEFAULT_WIDTH, start=MIN_WIDTH, end=MAX_WIDTH)
    height_wig  = pnw.IntSlider(name='height', value=DEFAULT_HEIGHT, start=MIN_HEIGHT, end=MAX_HEIGHT)

    def __init__(self, mri_path):
        """Collect the files of ``mri_path`` and create the per-instance widgets.

        Args:
            mri_path: directory (str or Path) holding the DICOM files.

        Raises:
            ValueError: if the directory contains no regular, non-hidden file.
        """
        p = Path(mri_path)
        # Skip sub-directories and dotfiles (e.g. .DS_Store), which pydicom
        # cannot read. Sorted by path; presumably this matches slice order -
        # TODO confirm against the file naming scheme.
        self.dcms = sorted(f for f in p.iterdir() if f.is_file() and not f.name.startswith('.'))
        if not self.dcms:
            raise ValueError(f'no files found in {p}')
        # 1-based slice number, translated to a list index in _read_resized.
        self.index_wig  = pnw.IntSlider(name='index', value=1, start=1, end=len(self.dcms))
        self.interpolation = pnw.Select(name='interpolation', options=['INTER_AREA','INTER_CUBIC','INTER_LINEAR'], value='INTER_LINEAR')

    @classmethod
    def _resolve_interpolation(cls, interpolation):
        """Return an OpenCV flag for a widget option name or a raw integer flag."""
        if isinstance(interpolation, int):
            return interpolation
        # Unknown names fall back to the bilinear flag, never to a string.
        return cls.interpolation_map.get(interpolation, cv2.INTER_LINEAR)

    def _read_resized(self, i, width, height, interpolation):
        """Read slice number ``i`` (1-based) and resize it to ``(width, height)``."""
        if not 1 <= i <= len(self.dcms):
            raise IndexError(f'slice number {i} is outside 1..{len(self.dcms)}')
        dataset = pydicom.dcmread(self.dcms[i - 1])
        src = dataset.pixel_array
        return cv2.resize(src,(width,height),interpolation=self._resolve_interpolation(interpolation))

    def load_mri(self, i,width,height,interpolation):
        """Render one slice as a grayscale ``hv.Image``.

        Args:
            i: 1-based slice number, as produced by ``index_wig``.
            width: target width in pixels.
            height: target height in pixels.
            interpolation: a key of ``interpolation_map`` or an OpenCV
                interpolation flag.

        Returns:
            An ``hv.Image`` with bounds ``(0, 0, width, height)``.

        Raises:
            IndexError: if ``i`` is outside ``1..len(self.dcms)``.
        """
        dst = self._read_resized(i, width, height, interpolation)
        return hv.Image(dst,bounds=(0,0,width,height)).opts(width=VISUAL_WIDTH,height=VISUAL_HEIGHT,cmap='gray')

    @property
    def fixed_image(self):
        """The slice selected by the widgets, resized per the current widget values."""
        return self._read_resized(self.index_wig.value,
                                  self.width_wig.value,
                                  self.height_wig.value,
                                  self.interpolation.value)

    def view(self):
        """Build the Panel layout: image on the left, controls on the right."""
        scaling   = pn.Column(self.width_wig, self.height_wig)
        widgets = pn.Column(scaling,self.index_wig, self.interpolation)
        image_ = pn.Column()
        self.image_ = image_
        image_.append(pn.depends(self.index_wig, self.width_wig,self.height_wig,self.interpolation)(self.load_mri))
        image = pn.Row(image_, widgets)
        return image

# Directory holding one DICOM series. Two candidates are listed; the second
# assignment overrides the first, so the prostate T2 series is the one shown.
mri_folder = f'{data_dir}/lung/Lung-Fused-CT-Pathology/LungFCP-01-0006/1.3.6.1.4.1.14519.5.2.1.5826.1402.276617219738701618517421717712/1.3.6.1.4.1.14519.5.2.1.5826.1402.301955554831918779561627517251'
# Derived from data_dir rather than repeating the absolute path, so moving the
# dataset only requires changing data_dir.
mri_folder = str(
    Path(data_dir)
    / 'Prostate Fused-MRI-Pathology'
    / 'aaa0051'
    / '07-02-2000-PELVISPROSTATE-97855'
    / '4.000000-T2 AXIAL SM FOV-36207'
)
mv = MriVisualizer(mri_folder)
view = mv.view()
view

In [None]:
# d = pydicom.dcmread('/Users/neshdev/radpathfusion/data/lung/Lung-Fused-CT-Pathology/LungFCP-01-0006/1.3.6.1.4.1.14519.5.2.1.5826.1402.276617219738701618517421717712/1.3.6.1.4.1.14519.5.2.1.5826.1402.161058093796067066996687821432/1-015.dcm')
# hv.Image(d.pixel_array)

In [None]:
from pathlib import Path
# Every directory below the lung dataset root, in a stable (sorted) order so
# that indexing into lung_dcm_dirs gives the same directory on every run.
path = Path(data_dir) / 'lung' / 'Lung-Fused-CT-Pathology'
lung_dcm_dirs = sorted(p for p in path.rglob('*') if p.is_dir())
col = pn.Column()
for p in lung_dcm_dirs:
    # Only directories holding more than 10 files are treated as image series;
    # sub-directories are not counted because they are not readable slices.
    if sum(1 for f in p.iterdir() if f.is_file()) > 10:
        v = MriVisualizer(p)
        view = v.view()
        view.show()
        col.append(str(p))
        col.append(view)

In [None]:
# Viewer for the first directory found by the scan above.
# NOTE(review): this rebinds `mv` (previously built from `mri_folder`), so the
# later cells that read `mv.fixed_image` use this directory instead.
# NOTE(review): `lung_dcm_dirs[0]` has not passed the ">10 entries" check used
# in the loop above - confirm that it actually contains DICOM files.
mv = MriVisualizer(lung_dcm_dirs[0])
mview = mv.view()
mview

In [None]:
import holoviews as hv
from holoviews import opts
from PIL import Image
import numpy as np
from holoviews import streams
from holoviews.streams import Pipe, Buffer
from holoviews.plotting.links import DataLink
import pandas as pd
hv.extension('bokeh')

In [None]:
class LandmarkSelector:
    """Side-by-side landmark picker for a fixed and a moving image.

    Each image gets a point-draw overlay plus a linked, editable table. The
    picked points are exposed as paired pixel coordinates through
    ``associated_points``.
    """

    # One fill colour per landmark slot, so the n-th point of the fixed image
    # can be matched visually with the n-th point of the moving image.
    _LANDMARK_COLORS = ['green', 'blue', 'red', 'yellow', 'pink',
                        'gray', 'orange', 'white', 'purple', 'brown']

    def __init__(self, fixed_img, moving_img):
        """Store both images; the point streams are created by ``view``.

        Args:
            fixed_img: 2-D array, rendered with ``hv.Image``.
            moving_img: 3-D array (rows, cols, channels), rendered with ``hv.RGB``.
        """
        self.fixed_img = fixed_img
        self.moving_img = moving_img
        self.fixed_stream = None
        self.moving_stream = None

    def _landmark_wig(self, image):
        """Overlay a point-draw layer on ``image`` and attach a linked table.

        Returns:
            ``(layout, stream)`` where ``stream`` is the ``PointDraw`` stream
            holding the picked points.
        """
        points_source = hv.Points([])
        stream = streams.PointDraw(data=points_source.columns(),
                                   num_objects=len(self._LANDMARK_COLORS),
                                   source=points_source,
                                   empty_value='black',
                                   styles={'fill_color': list(self._LANDMARK_COLORS)})

        table = hv.Table(points_source, ['x', 'y'])
        DataLink(points_source, table)

        wig = ((image * points_source) + table).opts(
            opts.Points(active_tools=['point_draw'], show_grid=True,
                        marker='s', size=10, tools=['hover', 'crosshair', 'undo']),
            opts.Table(editable=True)
        )
        return wig, stream

    def _landmark_moving_wig(self, arr):
        """Build the picker for the moving (RGB) image and remember its stream."""
        # numpy shape is (rows, cols, ...) == (height, width, ...); bounds are
        # (left, bottom, right, top), so width goes on x and height on y.
        height, width = arr.shape[:2]
        image = hv.RGB(arr, bounds=(0, 0, width, height))
        wig, self.moving_stream = self._landmark_wig(image)
        return wig

    def _landmark_fixed_wig(self, arr):
        """Build the picker for the fixed (grayscale) image and remember its stream."""
        height, width = arr.shape[:2]
        image = hv.Image(arr, bounds=(0, 0, width, height)).opts(cmap='gray')
        wig, self.fixed_stream = self._landmark_wig(image)
        return wig

    @staticmethod
    def _pixel_points(stream, img, label):
        """Convert the points of ``stream`` to integer (column, row) lists.

        Plot y grows upward from the bottom bound while array rows grow
        downward from the top, so the row index is ``height - y``.
        """
        if stream is None:
            raise RuntimeError(f'no {label} landmarks yet: call view() first')
        height = img.shape[0]
        cols = [int(x) for x in stream.data['x']]
        rows = [int(height - y) for y in stream.data['y']]
        return cols, rows

    def associated_points(self):
        """Return the picked landmarks as paired pixel coordinates.

        The n-th point drawn on the moving image is paired with the n-th point
        drawn on the fixed image.

        Returns:
            DataFrame with integer columns ``moving_x``, ``moving_y``,
            ``fixed_x``, ``fixed_y``; x is the column and y the row of the
            pixel (origin at the top-left), i.e. OpenCV's convention.

        Raises:
            RuntimeError: if ``view`` has not been called yet.
            ValueError: if the two images have a different number of points.
        """
        moving_x, moving_y = self._pixel_points(self.moving_stream, self.moving_img, 'moving')
        fixed_x, fixed_y = self._pixel_points(self.fixed_stream, self.fixed_img, 'fixed')
        if len(moving_x) != len(fixed_x):
            raise ValueError(
                f'landmark count mismatch: {len(moving_x)} moving vs {len(fixed_x)} fixed'
            )
        df = pd.DataFrame(
            {
                "moving_x" : moving_x,
                "moving_y" : moving_y,
                "fixed_x"  : fixed_x,
                "fixed_y"  : fixed_y
            }
        )
        return df

    def view(self):
        """Return the two pickers in a two-column layout with independent tools and axes."""
        fixed = self._landmark_fixed_wig(self.fixed_img)
        moving = self._landmark_moving_wig(self.moving_img)
        layout = (fixed + moving).opts(merge_tools=False, shared_axes=False)
        layout.cols(2)
        return layout

In [None]:
# Open the landmark picker on the current MRI slice (fixed) and the current
# pathology image (moving). Both properties are evaluated once, here, so
# widget changes made afterwards require re-running this cell.
landmarks = LandmarkSelector(mv.fixed_image, pv.moving_image)
landmarks.view()

In [None]:
# Inspect the landmark pairs picked so far (run after drawing points above).
landmarks.associated_points()

In [None]:
class Tps:
    """Landmark-driven registration: affine pre-alignment followed by a thin-plate spline."""

    def __init__(self, df):
        """Extract the paired landmarks from ``df``.

        Args:
            df: DataFrame with columns ``moving_x``, ``moving_y``, ``fixed_x``,
                ``fixed_y``; row i pairs one moving point with one fixed point.
        """
        moving = df[['moving_x','moving_y']].to_numpy(dtype=np.float32)
        fixed = df[['fixed_x','fixed_y']].to_numpy(dtype=np.float32)
        # Stored as (1, N, 2), the layout passed to the OpenCV calls in warp().
        self.moving = moving.reshape(1,-1,2)
        self.fixed = fixed.reshape(1,-1,2)
        # Point i of one set corresponds to point i of the other.
        self.matches = [cv2.DMatch(i,i,0) for i in range(self.moving.shape[1])]

    def warp(self, fixed_img, moving_img):
        """Warp ``moving_img`` onto the grid of ``fixed_img``.

        An affine transform estimated from the landmarks is applied first; the
        remaining misalignment is handled by a thin-plate spline.
        ``cv2.getAffineTransform`` is not used because it accepts exactly
        three point pairs.

        Args:
            fixed_img: reference image; only its height and width are used.
            moving_img: image to be warped.

        Returns:
            The warped image, sized like ``fixed_img``.

        Raises:
            ValueError: if fewer than three landmark pairs are available or no
                affine transform could be estimated from them.
        """
        n_points = self.moving.shape[1]
        if n_points < 3:
            raise ValueError(f'at least 3 landmark pairs are required, got {n_points}')

        t, _inliers = cv2.estimateAffine2D(self.moving, self.fixed)
        if t is None:
            raise ValueError('could not estimate an affine transform from the landmarks')

        # numpy shape is (rows, cols) == (height, width); warpAffine expects
        # its output size as (width, height).
        h, w = fixed_img.shape[:2]
        a_image = cv2.warpAffine(moving_img, t, (w, h))

        # Move the landmarks along with the image so the spline only has to
        # correct what the affine step left over.
        new_moving_landmarks = cv2.transform(self.moving, t)

        tpst = cv2.createThinPlateSplineShapeTransformer()
        tpst.estimateTransformation(self.fixed, new_moving_landmarks,self.matches)

        return tpst.warpImage(a_image)
        
        

In [None]:
# Fit the landmark-based transform and warp the pathology image onto the MRI
# slice; `m` is the warped image shown by the transparency viewer below.
tps = Tps(landmarks.associated_points())
m = tps.warp(mv.fixed_image, pv.moving_image)

In [None]:
class TransparencyViewer:
    """Overlay of a fixed image and a warped image with one opacity slider each."""

    # Opacity controls; defined on the class, hence shared by all instances.
    fixed_alpha_wig = pnw.FloatSlider(name='fixed_alpha', value=1, start=0, end=1)
    warp_alpha_wig = pnw.FloatSlider(name='warp_alpha', value=.5, start=0, end=1)

    def __init__(self, fixed, warped):
        """Keep references to the two images to be overlaid."""
        self.fixed, self.warped = fixed, warped

    def visualize_images(self,fixed_alpha, warp_alpha):
        """Return the grayscale fixed image overlaid by the RGB warped image.

        Args:
            fixed_alpha: opacity applied to the fixed image.
            warp_alpha: opacity applied to the warped image.
        """
        base_layer = hv.Image(self.fixed).opts(cmap='gray', alpha=fixed_alpha)
        top_layer = hv.RGB(self.warped).opts(alpha=warp_alpha)
        overlay = base_layer * top_layer
        return overlay.opts(width=800,height=600)

    def view(self):
        """Return a column with the overlay on top and the two sliders below."""
        overlay_fn = pn.depends(self.fixed_alpha_wig, self.warp_alpha_wig)(self.visualize_images)
        sliders = pn.Row(self.fixed_alpha_wig,self.warp_alpha_wig)
        return pn.Column(overlay_fn, sliders)
        
# Blend the MRI slice with the warped pathology image `m` to judge the
# registration. `mv.fixed_image` is re-read here from the current widget values.
tv = TransparencyViewer(mv.fixed_image,m)
tv.view()