In [None]:
import os
import datetime
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import scipy.optimize
import scipy.interpolate

import pydicom

In [None]:
# Directory that is searched below for the 'iView*.xlsx' record and the '*.dcm' images.
# NOTE(review): absolute, machine-specific Windows path — adjust before running elsewhere.
data_root = r'S:\Physics\Programming\data\MVISO'

In [None]:
# Locate the record spreadsheet and the DICOM images under data_root.
# glob() returns matches in arbitrary, OS-dependent order, so the matches are
# sorted to make the chosen spreadsheet and the image order reproducible.
record_candidates = sorted(glob(os.path.join(data_root, 'iView*.xlsx')))
if not record_candidates:
    # Fail with a readable message instead of a bare IndexError on `[0]`.
    raise FileNotFoundError(
        "No 'iView*.xlsx' record spreadsheet found in {!r}".format(data_root))
data_record = record_candidates[0]

dicom_files = np.array(sorted(glob(os.path.join(data_root, '*.dcm'))))
if dicom_files.size == 0:
    # Without images every later cell would fail in a less obvious place.
    raise FileNotFoundError(
        "No '*.dcm' files found in {!r}".format(data_root))

In [None]:
# Load the record spreadsheet; the first four rows are skipped before the
# header row is read.
record = pd.read_excel(data_record, skiprows=4)
timestamps_initial = record['Datetime']

# Only rows that carry a timestamp are kept. Each column of interest is pulled
# out as a plain NumPy array restricted to those rows.
has_timestamp = timestamps_initial.notnull()
timestamps, gantry, colimator, turntable, beam = (
    record.loc[has_timestamp, column].values
    for column in ('Datetime', 'Gantry', 'Col', 'TT', 'Energy')
)

In [None]:
# Read every DICOM file into a pydicom dataset, stored in an array so it can be
# fancy-indexed later. `dcmread` is the supported reader name; the original
# `read_file` was a deprecated alias that no longer exists in pydicom 3.
# force=True asks pydicom to parse files even when the file-meta header is missing.
datasets = np.array([
    pydicom.dcmread(dicom_file, force=True)
    for dicom_file in dicom_files
])

In [None]:
# One datetime64 per image, parsed from the concatenated AcquisitionDate and
# AcquisitionTime attributes (the format requires a fractional-seconds part).
acquisition_datetimes = np.array(
    [
        datetime.datetime.strptime(
            image.AcquisitionDate + image.AcquisitionTime,
            '%Y%m%d%H%M%S.%f',
        )
        for image in datasets
    ],
    dtype=np.datetime64,
)

In [None]:
# Pairwise absolute time gap: record rows run along axis 0, images along axis 1.
# A (row, image) pair counts as a match when the gap is under two seconds.
time_gap = np.abs(acquisition_datetimes[np.newaxis, :] - timestamps[:, np.newaxis])
diff_map = time_gap < np.timedelta64(2, 's')
timestamp_index, acquisition_index = np.nonzero(diff_map)

In [None]:
# Sanity checks on the matching: no image index may appear twice, and the
# number of matches must equal the number of images.
assert np.unique(acquisition_index).size == acquisition_index.size
assert acquisition_index.size == acquisition_datetimes.size

In [None]:
# Put the images and the record columns into matched order, so that entry k of
# every array below belongs to the same matched (record row, image) pair.
datasets, dicom_files = datasets[acquisition_index], dicom_files[acquisition_index]
timestamps, gantry, colimator, turntable, beam = (
    logged[timestamp_index]
    for logged in (timestamps, gantry, colimator, turntable, beam)
)

# Re-derive the acquisition times from the reordered datasets and repeat the
# two-second matching as a self-check.
acquisition_datetimes = np.array(
    [
        datetime.datetime.strptime(
            image.AcquisitionDate + image.AcquisitionTime,
            '%Y%m%d%H%M%S.%f',
        )
        for image in datasets
    ],
    dtype=np.datetime64,
)

diff_map = (
    np.abs(acquisition_datetimes[np.newaxis, :] - timestamps[:, np.newaxis])
    < np.timedelta64(2, 's')
)
timestamp_index, acquisition_index = np.nonzero(diff_map)

# After reordering, every match must sit on the diagonal of diff_map.
assert (timestamp_index == acquisition_index).all()

In [None]:
# Gather the pixel data of every image into one array (explicitly copied).
pixel_arrays = np.array(
    [image.pixel_array for image in datasets],
    copy=True,
)

# Rescale by 2**16 (= 65536) and invert, so larger raw values become smaller.
# presumably the raw data is 16-bit, giving values in (0, 1] — TODO confirm
pixel_arrays = 1 - pixel_arrays / 65536

In [None]:
# Coordinate of each pixel along one image axis: 1024 samples from -128 to
# 127.75 in steps of 0.25.
# presumably mm on a 1024 x 1024 panel — TODO confirm
axis_distance = np.arange(-512, 512) / 4

initial_mask_distance = 20  # mm

# Index window covering -initial_mask_distance .. +initial_mask_distance.
# `last` is the first index strictly beyond the window, so the slice end is exclusive.
first = np.flatnonzero(axis_distance >= -initial_mask_distance)[0]
last = np.flatnonzero(axis_distance > initial_mask_distance)[0]

mask = slice(first, last)

axis_distance = axis_distance[mask]

# Crop every image to the same central square window.
masked_arrays = np.array([
    frame[first:last, first:last]
    for frame in pixel_arrays
])

In [None]:
# Display the cropped coordinate axis as a quick visual check of the window.
axis_distance

In [None]:
# Degree-1 (kx = ky = 1) spline over the first cropped image only.
# The spline's first coordinate runs along array axis 0 and its second along
# axis 1, which is why later cells call `interpolation.ev(y, x)`.
interpolation = scipy.interpolate.RectBivariateSpline(
    axis_distance,
    axis_distance,
    masked_arrays[0],
    kx=1,
    ky=1,
)

In [None]:
# Nominal geometry used by the field-centre and ball-bearing searches below.
square_field_side_length = 20  # mm

penumbra_width = 3  # mm; width of the band sampled across each field edge
ball_bearing_diameter = 8 # mm

In [None]:
# Offsets across a field edge: 11 samples spanning one penumbra width, centred on the edge.
penumbra_range = np.linspace(-penumbra_width/2, penumbra_width/2, 11)
# Offsets along a field edge: 51 samples covering half of the field side length, centred.
half_field_range = np.linspace(-square_field_side_length/4, square_field_side_length/4, 51)

def get_sum_of_square_penumbra_flip(centre_x, centre_y, interpolation_func):
    """Score how asymmetric two opposite field edges are about ``centre_x``.

    The image is sampled on a grid whose x positions straddle the two edges
    expected at ``centre_x -/+ square_field_side_length/2`` (offsets taken from
    ``penumbra_range``) and whose y positions are ``centre_y + half_field_range``.
    The sampled grid is compared with its own mirror image along x.

    Parameters
    ----------
    centre_x, centre_y : float
        Candidate field-centre coordinates.
    interpolation_func : callable
        Called as ``interpolation_func(y_values, x_values)`` with two flat
        arrays of equal length; must return the image value at each point.

    Returns
    -------
    Sum of squared differences between the sampled grid and its x-mirrored
    copy; zero when the two edge profiles mirror each other exactly.
    """
    half_side = square_field_side_length / 2
    x_lookup = np.concatenate([
        centre_x - half_side + penumbra_range,
        centre_x + half_side + penumbra_range,
    ])
    y_lookup = centre_y + half_field_range

    xx_lookup, yy_lookup = np.meshgrid(x_lookup, y_lookup)
    grid_shape = xx_lookup.shape

    # The interpolator is fed flat coordinate arrays; restore the grid afterwards.
    sampled = interpolation_func(yy_lookup.ravel(), xx_lookup.ravel())
    sampled = np.reshape(sampled, grid_shape)

    # Reversing the columns pairs each left-edge sample with its right-edge twin.
    mirrored = sampled[:, ::-1]
    return np.sum(np.square(sampled - mirrored))
    
    
def get_sum_of_square_penumbra_flip_transpose(centre_x, centre_y, interpolation_func):
    """Same asymmetry score as ``get_sum_of_square_penumbra_flip``, for the other edge pair.

    The two centre coordinates are swapped and the interpolator is wrapped so
    that its two arguments are swapped as well. The existing routine is thereby
    reused along the other image axis.
    """
    return get_sum_of_square_penumbra_flip(
        centre_y,
        centre_x,
        lambda first, second: interpolation_func(second, first),
    )


def get_sum_of_square_both_penumbra_flips(centre_x, centre_y, interpolation):
    """Total edge-asymmetry score for a candidate field centre.

    ``interpolation`` must expose an ``ev(y, x)`` method. The result is the sum
    of the mirror score along x and the transposed mirror score along y.
    """
    evaluate = interpolation.ev

    along_x = get_sum_of_square_penumbra_flip(centre_x, centre_y, evaluate)
    along_y = get_sum_of_square_penumbra_flip_transpose(centre_x, centre_y, evaluate)
    return along_x + along_y


def create_penumbra_minimisation(interpolation):
    """Return an objective ``f(centre)`` bound to ``interpolation``.

    ``centre`` is indexed as ``(y, x)``: element 0 is passed on as ``centre_y``
    and element 1 as ``centre_x``.
    """
    def to_minimise(centre):
        centre_y, centre_x = centre[0], centre[1]
        return get_sum_of_square_both_penumbra_flips(centre_x, centre_y, interpolation)

    return to_minimise

In [None]:
# Resample the spline on a finer grid (step dx) for display.
dx = 0.05
interpolated_distances = np.arange(-initial_mask_distance, initial_mask_distance + dx, dx)

xx, yy = np.meshgrid(interpolated_distances, interpolated_distances)
xx_flat, yy_flat = xx.ravel(), yy.ravel()

# ev() takes the axis-0 coordinate first, hence (y, x).
interpolated_image_flat = interpolation.ev(yy_flat, xx_flat)
interpolated_image = interpolated_image_flat.reshape(xx.shape)

In [None]:
def show_image(pixel_array):
    """Draw ``pixel_array`` on the ``interpolated_distances`` grid.

    The colour limits are fixed to [0, 1], a colour bar is added, and both axes
    are set to equal scaling.
    """
    grid = interpolated_distances
    plt.pcolormesh(grid, grid, pixel_array, clim=[0, 1])
    plt.colorbar()
    plt.axis('equal')
    
# Preview the resampled first image.
show_image(interpolated_image)

In [None]:
def show_image_with_square(image, centre, edge_length):
    """Show ``image`` with a square outline of side ``edge_length`` drawn on it.

    ``centre`` is indexed as ``(y, x)``. The outline is plotted first, then the
    image is drawn and the figure is shown.
    """
    y, x = centre[0], centre[1]
    half = edge_length / 2
    left, right = x - half, x + half
    bottom, top = y - half, y + half

    # Closed outline: the last vertex repeats the first.
    plt.plot(
        [left, left, right, right, left],
        [bottom, top, top, bottom, bottom],
        'k', lw=2
    )

    show_image(image)
    plt.show()

In [None]:
def create_print_func(image_to_search):
    """Return a three-argument reporter that prints its arguments and plots the centre.

    The returned function prints ``centre``, ``f`` and ``accepted`` on separate
    lines. It then shows ``image_to_search`` with a square of side
    ``square_field_side_length`` drawn at ``centre``.
    """
    def print_fun(centre, f, accepted):
        for value in (centre, f, accepted):
            print(value)

        show_image_with_square(image_to_search, centre, square_field_side_length)

    return print_fun


# Objective for the field-centre search, bound to the spline of the first image.
to_minimise = create_penumbra_minimisation(interpolation)
# Progress reporter; it is not referenced again in the visible cells.
# presumably intended as a basinhopping callback — TODO confirm
print_fun = create_print_func(interpolated_image)

In [None]:
# Hand-entered candidate centre, indexed (y, x) like every `centre` in this notebook.
# presumably a result copied from an earlier optimisation run — TODO confirm
centre = [0.86680572, -0.04818984]

# Show the nominal field outline at that centre and print its asymmetry score.
show_image_with_square(interpolated_image, centre, square_field_side_length)
print(to_minimise(centre))

In [None]:
# Search for the field centre by minimising the edge-asymmetry score.
# The starting guess [2, 0] is indexed (y, x). basinhopping takes random steps,
# so a fixed seed is passed to make Restart & Run All give the same
# field_centre every time; T, niter and stepsize are unchanged.
results = scipy.optimize.basinhopping(
    to_minimise, [2, 0], T=1, niter=5, stepsize=1, seed=0)
field_centre = results.x

print(field_centre)

# Overlay a square of side 18 on the fitted centre for visual confirmation.
# NOTE(review): 18 is a hard-coded side length, smaller than the nominal
# square_field_side_length — confirm the intended value.
plt.figure(figsize=(15,15))
show_image_with_square(interpolated_image, np.array(field_centre), 18)

In [None]:
# Display the full optimiser result object returned by basinhopping.
results

In [None]:
# Outline the central 80% of the nominal field around the fitted centre.
# check_points uses the same 0.8 factor for its window.
plt.figure(figsize=(10,10))
show_image_with_square(interpolated_image, field_centre, square_field_side_length*0.8)

In [None]:
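# NOTE: disabled experiment, kept for reference only. It would set every pixel of
# interpolated_image outside the central 80% of the field (around field_centre) to 1.
# The last line refers to `initial_centre`, which is not defined in the visible
# cells, so this block cannot be re-enabled as-is.
# ballbearing_find_mask = (
#     (xx < field_centre[1] - square_field_side_length*0.8 / 2) |
#     (xx > field_centre[1] + square_field_side_length*0.8 / 2) |
#     (yy < field_centre[0] - square_field_side_length*0.8 / 2) |
#     (yy > field_centre[0] + square_field_side_length*0.8 / 2)
# )

# interpolated_image[ballbearing_find_mask] = 1

# show_image_with_square(interpolated_image, initial_centre, square_field_side_length*0.8)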

In [None]:
def plot_circle_at_bb(bb_centre):
    """Plot a black circle of diameter ``ball_bearing_diameter`` at ``bb_centre``.

    ``bb_centre`` is indexed as ``(y, x)``.
    """
    radius = ball_bearing_diameter / 2
    angle = np.linspace(0, 2*np.pi)

    plt.plot(
        radius * np.sin(angle) + bb_centre[1],
        radius * np.cos(angle) + bb_centre[0],
        'k', lw=2
    )
    

    
# Preview: a ball-bearing outline at the hand-picked position (y=1.5, x=0),
# drawn together with the central 80% field outline around the fitted centre.
plot_circle_at_bb([1.5,0])
show_image_with_square(interpolated_image, field_centre, square_field_side_length*0.8)

In [None]:
def create_points_to_check(bb_diameter=None, points_per_ring=21):
    """Build a sampler of weighted points on concentric rings around a centre.

    Parameters
    ----------
    bb_diameter : float, optional
        Ball-bearing diameter. Defaults to the notebook-level
        ``ball_bearing_diameter`` when omitted.
    points_per_ring : int, optional
        Number of sample points on every ring (default 21).

    Returns
    -------
    callable
        ``points_to_check(bb_centre)`` with ``bb_centre`` indexed ``(y, x)``.
        It returns three flat arrays ``x, y, weight`` of equal length, ordered
        ring by ring from the smallest ring outwards.

    Raises
    ------
    ValueError
        If ``points_per_ring`` is smaller than 1.
    """
    if points_per_ring < 1:
        raise ValueError("points_per_ring must be at least 1")
    if bb_diameter is None:
        # Fall back to the notebook-level constant.
        bb_diameter = ball_bearing_diameter

    dtheta = 2*np.pi / points_per_ring
    # linspace(endpoint=False) yields exactly points_per_ring angles. A
    # float-step arange may yield one more, which would no longer match the
    # number of weights per ring.
    t = np.linspace(0, 2*np.pi, points_per_ring, endpoint=False)

    # Rings at 10%, 20%, ... 90% of the ball-bearing diameter.
    diameters = bb_diameter * np.arange(0.1, 1, 0.1)
    radii = diameters / 2

    # Each successive ring is rotated by a further fifth of the angular step.
    ring_phase = np.arange(len(diameters)) * dtheta / 5
    angles = t[np.newaxis, :] + ring_phase[:, np.newaxis]

    # One weight per ring, cos(arcsin(d / D)), repeated for each point on it.
    ring_weight = np.cos(np.arcsin(diameters / bb_diameter))
    weight = np.repeat(ring_weight, points_per_ring)

    sin_part = radii[:, np.newaxis] * np.sin(angles)
    cos_part = radii[:, np.newaxis] * np.cos(angles)

    def points_to_check(bb_centre):
        x = (sin_part + bb_centre[1]).ravel()
        y = (cos_part + bb_centre[0]).ravel()

        # A copy is returned so callers cannot alter the cached weights.
        return x, y, weight.copy()

    return points_to_check

    
# Build the sampler once; check_points below looks it up by this global name.
points_to_check = create_points_to_check()
# Preview the sample pattern for a centre at the origin.
x, y, weight = points_to_check([0,0])

plt.plot(x, y, '.')

In [None]:
def check_points(bb_centre, field_centre, interpolation):
    """Mean weighted image value over the sample points placed at ``bb_centre``.

    Points come from the notebook-level ``points_to_check`` sampler. Each image
    value (``interpolation.ev(y, x)``) is multiplied by its point weight. Any
    point outside the central 80% window of the field around ``field_centre``
    has its result replaced by 1 before averaging. Both centres are indexed
    ``(y, x)``.
    """
    x, y, weight = points_to_check(bb_centre)
    results = weight * interpolation.ev(y, x)

    half_window = square_field_side_length*0.8 / 2
    centre_y, centre_x = field_centre[0], field_centre[1]

    point_outside_of_field_centre = np.logical_or.reduce([
        x < centre_x - half_window,
        x > centre_x + half_window,
        y < centre_y - half_window,
        y > centre_y + half_window,
    ])
    results[point_outside_of_field_centre] = 1

    return results.mean()

def create_circle_to_minimise(field_centre, interpolation):
    """Return an objective ``f(bb_centre)`` bound to ``field_centre`` and ``interpolation``.

    The returned callable forwards to ``check_points``.
    """
    return lambda bb_centre: check_points(bb_centre, field_centre, interpolation)
    

# check_points([0.25060408, -1.80120831], field_centre, interpolation)

In [None]:
# Objective for the ball-bearing search, bound to the fitted field centre and
# the spline of the first image.
circle_to_minimise = create_circle_to_minimise(field_centre, interpolation)

# basinhopping takes random steps, so a fixed seed is passed to make
# Restart & Run All give the same ball-bearing centre every time.
# The starting guess [0, 0] is indexed (y, x); T, niter and stepsize are unchanged.
bb_results = scipy.optimize.basinhopping(
    circle_to_minimise, [0, 0], T=0.1, niter=5, stepsize=1, seed=0)
bb_results

In [None]:
# Final overlay: the fitted ball-bearing outline and the nominal field outline
# around the fitted field centre, drawn on the resampled image.
plt.figure(figsize=(10,10))

plot_circle_at_bb(bb_results.x)
show_image_with_square(interpolated_image, field_centre, square_field_side_length)