In [None]:
# Use the interactive `notebook` backend so that mouse clicks on the figure
# can be handled by the `on_press` callback defined further below.
%matplotlib notebook
# NOTE(review): this option configures the *inline* backend; with the
# `notebook` backend selected above it presumably has no effect — confirm.
%config InlineBackend.figure_format='retina'

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as colors

from astropy import units
from matplotlib.patches import Circle

from hsr4hci.utils.masking import get_predictor_mask, get_sausage_mask
from hsr4hci.utils.units import set_units_for_instrument

In [None]:
# Instrument configuration: pixel scale and the size of lambda / D.
# These are registered with hsr4hci's unit helper; `draw()` below relies on
# them when converting 'lambda_over_d' quantities to pixels via `.to('pixel')`.
PIXSCALE = units.Quantity(0.0271, 'arcsec / pixel')
LAMBDA_OVER_D = units.Quantity(0.096, 'arcsec')

set_units_for_instrument(pixscale=PIXSCALE, lambda_over_d=LAMBDA_OVER_D)

In [None]:
def get_cmap(color):
    """
    Build a two-entry colormap for overlaying a binary mask.

    The low entry is white with alpha 0 (fully transparent) and the high
    entry is `color` with alpha 1 (fully opaque), so only the "on" pixels of
    a mask are visible when it is drawn with `imshow`.

    Args:
        color: Any matplotlib color specification for the "on" pixels.

    Returns:
        A `matplotlib.colors.ListedColormap` with two RGBA entries.
    """
    base = colors.ListedColormap(['white', color])

    # Sample the base colormap at every entry to get an (N, 4) RGBA array,
    # then overwrite the alpha channel with a 0 -> 1 ramp.
    rgba = base(np.arange(base.N))
    rgba[:, 3] = np.linspace(0, 1, base.N)

    return colors.ListedColormap(rgba)

def get_subsampling_mask(mask_size):
    """
    Return a checkerboard pattern of the given shape.

    An element is 1 where the sum of its indices is odd and 0 where it is
    even (so the element at the origin is 0).

    Args:
        mask_size: Shape of the mask, e.g. (n_rows, n_cols).

    Returns:
        An integer numpy array of shape `mask_size` containing 0s and 1s.
    """
    index_grids = np.indices(mask_size)
    index_sum = index_grids.sum(axis=0)
    return np.remainder(index_sum, 2)

def draw(position, use_subsampling=False):
    """
    Clear the current figure and plot the predictor and exclusion masks for
    the pixel at `position`, together with the number of selected predictors.

    Predictor pixels are drawn in green, excluded pixels in red, and the
    target pixel is marked with a blue cross. A dashed circle marks the
    separation of the target pixel from the image centre; a gray circle
    marks a separation of 0.7 arcsec.

    Args:
        position: Target pixel as a (row, column) index, i.e., (y, x) in the
            displayed image (it is plotted at x=position[1], y=position[0]).
        use_subsampling: If True, additionally restrict the predictor mask
            to the checkerboard pattern from `get_subsampling_mask()`.
            Defaults to False (no subsampling).
    """

    plt.clf()

    mask_size = (81, 81)
    field_rotation = units.Quantity(90, 'degree')
    annulus_width = units.Quantity(1.0, 'lambda_over_d')
    radius_position = units.Quantity(4.00, 'lambda_over_d')
    radius_mirror_position = units.Quantity(2.00, 'lambda_over_d')
    minimum_distance = units.Quantity(1.00, 'lambda_over_d')

    # Get the mask that selects all potential predictor pixels
    predictor_mask = get_predictor_mask(mask_size=mask_size,
                                        position=position,
                                        annulus_width=annulus_width,
                                        radius_position=radius_position,
                                        radius_mirror_position=radius_mirror_position)
    if use_subsampling:
        predictor_mask = np.logical_and(predictor_mask,
                                        get_subsampling_mask(mask_size))

    # Get exclusion mask (i.e., pixels we must not use as predictors); the
    # opening angle of the "sausage" is twice the field rotation
    exclusion_radius = minimum_distance.to('pixel').value
    opening_angle = 2 * field_rotation.to('degree').value
    exclusion_mask = get_sausage_mask(mask_size=mask_size,
                                      position=position,
                                      radius=exclusion_radius,
                                      opening_angle=opening_angle)

    # Compute number of predictor pixels that are not excluded
    selection_mask = np.logical_and(np.logical_not(exclusion_mask),
                                    predictor_mask)
    n_predictors = np.sum(selection_mask)

    # Overlay both masks. Fixing vmin / vmax maps False -> transparent and
    # True -> opaque even when a mask is constant; with autoscaling, an
    # all-True mask would be normalized to the transparent end of the cmap.
    plt.imshow(predictor_mask, origin='lower', cmap=get_cmap('green'),
               vmin=0, vmax=1)
    plt.imshow(exclusion_mask, origin='lower', cmap=get_cmap('red'),
               vmin=0, vmax=1)

    # Mark the target pixel; imshow shows element (row, col) at (x=col, y=row)
    plt.plot(position[1], position[0], 'x', ms=6, color='blue', zorder=99)

    # Circle centres are given in (x, y) data coordinates, i.e., (col, row).
    # NOTE(review): the centre uses the n / 2 convention of the original
    # code; confirm this matches the centre convention of hsr4hci's masks.
    center_xy = (mask_size[1] / 2, mask_size[0] / 2)

    # Plot a circle at the separation of the position
    separation = np.hypot(position[0] - mask_size[0] / 2,
                          position[1] - mask_size[1] / 2)
    plt.gca().add_artist(Circle(center_xy, separation, ls='--',
                                facecolor='none', edgecolor='black',
                                zorder=99))

    # Plot a circle at 0.7 arcsec
    plt.gca().add_artist(Circle(center_xy,
                                units.Quantity(0.7, 'arcsec').to('pixel').value,
                                facecolor='none', alpha=0.5, edgecolor='gray',
                                zorder=99))

    plt.gca().set_position([0, 0, 1, 1], which='both')

    plt.gcf().suptitle(f'position: {position} | n_predictors: {n_predictors}')

    # Explicitly request a redraw so the update also appears when this
    # function is invoked from the mouse-click callback
    plt.gcf().canvas.draw_idle()
    plt.show()

    
def on_press(event):
    """
    Mouse-click callback: redraw the masks for the clicked pixel.

    Clicks outside the axes (where matplotlib sets `event.inaxes`,
    `event.xdata` and `event.ydata` to None) are ignored.

    The click coordinates are rounded to the nearest pixel. With `imshow`'s
    default extent, element (row, col) covers [col - 0.5, col + 0.5] x
    [row - 0.5, row + 0.5]. Truncating with `int()` would select the wrong
    pixel for clicks in the upper/right half of a pixel.

    Args:
        event: A matplotlib `MouseEvent` from 'button_press_event'.
    """
    if event.inaxes is None or event.xdata is None or event.ydata is None:
        return

    position = (int(np.floor(event.ydata + 0.5)),
                int(np.floor(event.xdata + 0.5)))
    draw(position)

# Create a dedicated 5x5-inch figure and attach the click handler to it.
# Using a fresh figure (instead of `plt.gcf()`) keeps this cell idempotent:
# re-running it does not attach a second handler to an already-open figure,
# which would make every click trigger two redraws.
figure = plt.figure(figsize=(5, 5))
click_cid = figure.canvas.mpl_connect('button_press_event', on_press)

# Initial target pixel as (row, column), shown before the first click
position = (50, 60)
draw(position)