In [1]:
%matplotlib inline
import matplotlib
matplotlib.rcParams['text.usetex'] = True

import pink_utils as pu
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import astropy.units as u
from astropy.coordinates import SkyCoord
import pickle
from scipy.stats import percentileofscore
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [2]:
FIRST_PIX = 1.8*u.arcsecond  # Pixel size of FIRST survey. Square pixels
# 5 arcsec FWHM expressed in pixel units (arcsec / arcsec -> dimensionless)
FIRST_FWHM = 5*u.arcsecond / FIRST_PIX
# Gaussian sigma from FWHM: FWHM = 2*sqrt(2*ln 2) * sigma  (~2.3548; was rounded to 2.355)
FIRST_SIG = FIRST_FWHM / (2*np.sqrt(2*np.log(2)))

In [3]:
# Source catalogue plus the PINK products for the "Small" SOM run.
# Catalogue rows line up positionally with the image/transform/heatmap
# binaries (all 178859 long per the headers printed below).
df = pd.read_csv('../FIRST_F1W1_95_5_Sources.csv')
transform = pu.transform('../Small/FIRST_F1W1_95_5_Small_Transform.bin')
ed = pu.heatmap('../Small/FIRST_F1W1_95_5_Small_Similarity.bin')
som = pu.som('../Small/FIRST_F1W1_95_5_L1_SOM_Small_5.bin')
images = pu.image_binary('../FIRST_F1W1_95_5_imgs.bin')

# Per-neuron feature annotations, looked up later by neuron position.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
features_pkl = '../Small/FIRST_F1W1_95_5_L1_SOM_Small_5_Features-table.pkl'
with open(features_pkl, 'rb') as fh:
    annotations = pickle.load(fh)

100%|██████████| 178859/178859 [00:02<00:00, 75589.33it/s]


In [4]:
# Header tuple of each binary product, then the catalogue's (rows, columns).
for pink_product in (som, images, transform, ed):
    print(pink_product.file_head)
print(df.shape)

(2, 12, 12, 1, 118, 118)
(178859, 2, 167, 167)
(178859, 12, 12, 1)
(178859, 12, 12, 1)
(178859, 29)


In [5]:
def return_transform(index, pos_min, transform_source=None):
    """Return the (flip, ro) pair stored for source `index` at neuron `pos_min`.

    Parameters
    ----------
    index : int
        Row of the source in the transform file.
    pos_min : sequence of int
        Neuron position in the heatmap's (row, col, ...) order, as returned
        by `return_bmu`.
    transform_source : object, optional
        Anything with a ``transform(index=...)`` method returning a flat array
        and a ``header_info`` tuple whose ``[1:]`` gives the per-source shape
        (e.g. a ``pu.transform``). Defaults to the notebook-level ``transform``
        so existing calls behave exactly as before.

    Returns
    -------
    flip, ro
        The two components stored for that neuron (presumably flip flag and
        rotation -- confirm against the PINK transform format).
    """
    if transform_source is None:
        transform_source = transform
    rot = transform_source.transform(index=index).reshape(transform_source.header_info[1:])
    # NOTE(review): lookup is rot[pos_min[1], pos_min[0]] -- transposed relative
    # to the heatmap's (row, col) order. Presumably matches the transform file
    # layout / annotation convention; confirm before changing.
    flip, ro = rot[pos_min[1], pos_min[0]][0]
    return flip, ro

def return_bmu(index, best=True, hmap_source=None):
    """Return the position of the best (or worst) matching neuron for a source.

    Parameters
    ----------
    index : int
        Row of the source in the similarity (ED) file.
    best : bool, default True
        If True return the position of the minimum distance (best match);
        otherwise the position of the maximum distance.
    hmap_source : object, optional
        Anything with an ``ed(index=...)`` method returning the source's
        heatmap array (e.g. a ``pu.heatmap``). Defaults to the notebook-level
        ``ed`` so existing calls behave exactly as before.

    Returns
    -------
    tuple of int
        Index into the heatmap array, in its own axis order (row, col, ...).
    """
    if hmap_source is None:
        hmap_source = ed
    hmap = hmap_source.ed(index=index)
    pick = np.argmin if best else np.argmax
    return np.unravel_index(pick(hmap), hmap.shape)
    
def return_hmap_stats(index, hmap_source=None):
    """Return summary statistics of one source's heatmap alongside the heatmap.

    Parameters
    ----------
    index : int
        Row of the source in the similarity (ED) file.
    hmap_source : object, optional
        Anything with an ``ed(index=...)`` method returning the source's
        heatmap array (e.g. a ``pu.heatmap``). Defaults to the notebook-level
        ``ed`` so existing calls behave exactly as before.

    Returns
    -------
    dict
        ``{'min', 'max', 'sum'}`` of the heatmap values, plus the heatmap
        array itself under ``'hmap'``.
    """
    if hmap_source is None:
        hmap_source = ed
    hmap = hmap_source.ed(index=index)
    return {'min': np.min(hmap),
            'max': np.max(hmap),
            'sum': np.sum(hmap),
            'hmap': hmap}


In [6]:
# For each of the first `max_index + 1` sources, render a six-panel figure:
# radio and IR cutouts (full and zoomed) with the best-matching neuron's
# annotated clicks transformed onto them, the ED distribution for that
# neuron, and the source's full heatmap. One PNG per source is written.
from pathlib import Path

plt.close('all')

max_index = 2000   # last row index processed (inclusive) -- here for testing
cmap = 'Greys'
zoom_size = 50     # size argument passed to pu.zoom for the zoomed panels
out_dir = Path('Radio_Feature_Overlay')
out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories


def _overlay_clicks(ax, img, radio_clicks, ir_clicks=None, legend=False):
    """Show `img` on `ax` and overlay click offsets measured from the image centre.

    Radio clicks are drawn as red stars; IR clicks (when given) as open blue
    circles. A legend is added only when `legend` is True.
    """
    cen = np.array(img.shape) / 2
    ax.imshow(img, cmap=cmap)
    ax.plot(radio_clicks[:, 0] + cen[0],
            radio_clicks[:, 1] + cen[1],
            'r*', ms=15, label='Radio Click')
    if ir_clicks is not None:
        ax.scatter(ir_clicks[:, 0] + cen[0],
                   ir_clicks[:, 1] + cen[1],
                   facecolors='none', edgecolors='blue', s=150,
                   linewidths=5, label='IR Click')
    if legend:
        ax.legend()


for count, (index, row) in enumerate(df.iterrows()):
    # Binary-file lookups below assume the catalogue index is positional.
    if count != index:
        print('Mismatch', index, count)

    # Here for testing
    if index > max_index:
        break

    sky_pos = SkyCoord(ra=row['RA']*u.deg, dec=row['DEC']*u.deg)
    radio_img = np.arcsinh(images.get_image(index=index, channel=0))
    ir_img = np.arcsinh(images.get_image(index=index, channel=1))

    # Get rotation information
    # ---------------------------------------------------------
    bmu_pos    = return_bmu(count)
    hmap_stats = return_hmap_stats(count)
    trans_info = return_transform(count, bmu_pos)
    # ---------------------------------------------------------

    # Annotate_map_features.py script recorded positions around the
    # incorrect convention, so swap the heatmap's (row, col) order.
    key = (bmu_pos[1], bmu_pos[0]) + (0,)
    bmu = annotations[key]

    radio_feat_trans = np.array(bmu.transform_clicks(trans_info, channel=0))
    ir_feat_trans    = np.array(bmu.transform_clicks(trans_info, channel=1))

    fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(10, 6))

    # Top row: radio image (full, zoomed); bottom row: IR image (full, zoomed)
    _overlay_clicks(ax1, radio_img, radio_feat_trans)
    _overlay_clicks(ax2, pu.zoom(radio_img, zoom_size, zoom_size),
                    radio_feat_trans, ir_feat_trans, legend=True)
    _overlay_clicks(ax4, ir_img, radio_feat_trans, ir_feat_trans)
    _overlay_clicks(ax5, pu.zoom(ir_img, zoom_size, zoom_size),
                    radio_feat_trans, ir_feat_trans, legend=True)

    # Distribution of every source's ED to this neuron, and where this source
    # sits in it. ravel/item keep this working whether or not ed.data carries
    # a trailing length-1 axis (the printed header suggests it does).
    neuron_eds = np.ravel(ed.data[:, key[1], key[0]])
    this_ed = np.asarray(ed.data[index, key[1], key[0]]).item()
    score = percentileofscore(neuron_eds, this_ed)
    ax3.hist(neuron_eds, bins=50)
    # Raw f-string: LaTeX (usetex) needs a literal backslash before '%'.
    ax3.axvline(this_ed, c='black', label=rf'This source - {score:.2f}\%')
    ax3.set(xlabel=f'ED for ({key[0]}, {key[1]}) neuron')
    ax3.legend()

    # Full heatmap for this source with its best-matching neuron marked
    # (reuses the heatmap and argmin computed above instead of re-reading).
    heat = hmap_stats['hmap']
    im = ax6.imshow(heat)
    ax6.plot(bmu_pos[1], bmu_pos[0], 'ro')

    divider = make_axes_locatable(ax6)
    cax6 = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax6, label='Euclidean Distance')

    fig.suptitle(f"FIRST Position {sky_pos.to_string('hmsdms')}")
    fig.tight_layout(rect=[0, 0, 0.95, 0.95])
    fig.savefig(out_dir / f"{index}.png")
    plt.close(fig)  # avoid accumulating open figures across the loop

print('Finished')

Finished
