In [None]:
# Interactive matplotlib backend: required so that the figure created below
# receives mouse clicks (via `mpl_connect('button_press_event', ...)`).
# Use `%matplotlib inline` instead for static, non-interactive plots.
# NOTE(review): the `notebook` (nbagg) backend is not supported in
# JupyterLab / Notebook 7; there, `%matplotlib widget` (ipympl) replaces it.
%matplotlib notebook
# Only takes effect with the inline backend; it does nothing while the
# `notebook` backend above is active.
%config InlineBackend.figure_format='retina'

In [None]:
import os 

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as colors

from astropy import units
from matplotlib.patches import Circle

from hsr4hci.utils.fits import read_fits
from hsr4hci.utils.masking import get_predictor_mask, get_sausage_mask
from hsr4hci.utils.units import set_units_for_instrument

In [None]:
# Instrument parameters: pixel scale and resolution element (lambda / D).
# Registering them with hsr4hci is presumably what makes conversions such as
# `units.Quantity(0.7, 'arcsec').to('pixel')` in draw() work -- confirm.
PIXSCALE = units.Quantity(0.0271, 'arcsec / pixel')
LAMBDA_OVER_D = units.Quantity(0.096, 'arcsec')

set_units_for_instrument(pixscale=PIXSCALE, lambda_over_d=LAMBDA_OVER_D)

In [None]:
# Directory containing the results of the experiment. Set the environment
# variable HSR4HCI_RESULTS_DIR to point elsewhere instead of editing the
# hard-coded (machine-specific) default path.
results_dir = os.environ.get(
    'HSR4HCI_RESULTS_DIR',
    '/Users/timothy/Desktop/hsr4hci/experiments/playground/local/results/round_000/',
)
coefficients = read_fits(os.path.join(results_dir, 'coefficients.fits'))

# draw() indexes `coefficients[:, row, col, ...]` and averages over axis 0 to
# obtain a 2D image for imshow, which requires a 5D array.
assert coefficients.ndim == 5, \
    f'Expected 5D coefficients array, got shape {coefficients.shape}'

coefficients.shape

In [None]:
def draw(position, split_idx=0):
    """
    Redraw the current figure with the coefficient map of the pixel at
    `position`, averaged over the first axis of the global `coefficients`.

    Args:
        position: (row, column) index into axes 1 and 2 of `coefficients`.
            The row is plotted on the y-axis and the column on the x-axis
            of the displayed image.
        split_idx: Currently unused -- the coefficients are averaged over
            axis 0 rather than indexed with it. Kept so that existing calls
            (e.g. `draw(..., split_idx=4)`) keep working.
    """

    plt.clf()

    # Average over axis 0 (presumably the splits -- TODO confirm) to obtain
    # a 2D map of coefficients for the selected pixel
    coef = np.mean(coefficients[:, position[0], position[1], ...], axis=0)

    # Entries that are not predictors hold NaN; count the remaining ones
    n_predictors = np.count_nonzero(~np.isnan(coef))

    # Symmetric color scale from the 99th percentile of |coef|, using only
    # finite entries. Replacing NaNs by 0 would pull the percentile towards 0,
    # and a signed percentile can be negative, which makes vmin > vmax and
    # causes matplotlib to raise a ValueError when rendering.
    finite_abs = np.abs(coef[np.isfinite(coef)])
    limit = np.percentile(finite_abs, 99) if finite_abs.size > 0 else 0.0
    plt.imshow(coef, origin='lower', cmap='RdBu_r',
               vmin=-1.5*limit, vmax=1.5*limit)

    # Add X for current position (x = column, y = row)
    plt.plot(position[1], position[0], 'x', ms=6, color='blue', zorder=99)

    # Plot a circle with a radius of 0.7 arcsec around the image center.
    # imshow places the center of pixel i at coordinate i, so the center of
    # an axis with n pixels lies at (n - 1) / 2. Matplotlib expects the
    # center as (x, y), i.e., (column, row).
    n_rows, n_cols = coef.shape
    circle = plt.Circle(((n_cols - 1) / 2, (n_rows - 1) / 2),
                        units.Quantity(0.7, 'arcsec').to('pixel').value,
                        facecolor='none', edgecolor='gray', zorder=99)
    plt.gca().add_artist(circle)

    plt.gca().set_position([0, 0, 1, 1], which='both')
    plt.gcf().suptitle(f'position: {position} | '
                       f'n_predictors: {n_predictors} |'
                       f' scale: ({-1.5*limit:.2f}, {1.5*limit:.2f})',
                       fontsize=8)

    plt.show()

    
def on_press(event):
    """
    Matplotlib `button_press_event` callback: redraw the coefficient map for
    the pixel under the mouse cursor.

    Clicks outside the axes (where `event.xdata` / `event.ydata` are None)
    are ignored.

    Args:
        event: The matplotlib mouse event passed by `mpl_connect`.
    """

    if event.inaxes is None or event.xdata is None or event.ydata is None:
        return

    # imshow centers pixel i on data coordinate i (it spans [i - 0.5, i + 0.5]),
    # so round instead of truncating to find the pixel under the cursor. Clip
    # so that clicks on the outer half-pixel border stay valid indices into
    # axes 1 and 2 of `coefficients` (which draw() indexes with the position).
    n_rows, n_cols = coefficients.shape[1:3]
    row = int(np.clip(np.round(event.ydata), 0, n_rows - 1))
    col = int(np.clip(np.round(event.xdata), 0, n_cols - 1))

    position_ = (row, col)
    draw(position_)

# Size the current figure and register on_press so that clicking on a pixel
# redraws the coefficient map for that pixel
figure = plt.gcf()
figure.set_size_inches(5, 5, forward=True)
figure.canvas.mpl_connect('button_press_event', on_press)

# Initial view; note that draw() does not use `split_idx`
draw(position=(30, 30), split_idx=4)