In [None]:
import coord_transform as tr
import scale_model as scales

import astropy.units as u
import astropy.coordinates as c
from astropy.wcs import WCS
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from astropy.table import Table

import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as scimg
from skimage import restoration
import scipy.signal as signal
from scipy.stats import spearmanr

In [None]:
class Image_Signal():

    def __init__(self, ID):

        # ID is the galaxy name that will be displayed on plots

        self.name = ID

        maps = np.array(['7.7', 'CO', 'Ha', '21'])
        n_maps = maps.size
        empty = np.empty(n_maps, dtype=np.ndarray)

        # each column is a different map
        # rows are:           name, orig,  wcs,   smooth,detail,polar
        self.data = np.array([maps, empty, empty, empty, empty, empty])

        self.peak_prop()

        return
    

    def getscales(self, filepath):

        # calculates predicted scales from a mega table

        self.Scales = scales.Scales_Model()
        self.Scales.inst(self.name, filepath)

        return
    
    
    def readdata(self, filepath, maptype):

        # inputs a map from a fits file

        if maptype == '7.7':

            filename = get_pkg_data_filename(filepath)
            hdu = fits.open(filename)[0]
            map_i = 0
        elif maptype == 'CO':

            filename = get_pkg_data_filename(filepath)
            hdu = fits.open(filename)[0]
            map_i = 1
        elif maptype == 'Ha':

            hdulist = fits.open(filepath)
            hdu = hdulist['HA6562_FLUX']
            map_i = 2
        elif maptype == '21':

            filename = get_pkg_data_filename(filepath)
            hdu = fits.open(filename)[1]
            map_i = 3
        else:

            raise ValueError('invalid map type')
        
        data = np.array(hdu.data, dtype=np.float64)
        wcs = WCS(hdu.header)

        np.nan_to_num(data, copy=False)

        self.data[1, map_i] = data
        self.data[2, map_i] = wcs

        return
    

    def inputdata(self, image, wcs, maptype):

        # inputs a new map under a new name

        arr = self.data.copy()

        new = np.array([maptype, image, wcs, None, None, None])

        appended = np.hstack((arr, new))

        self.data = appended

        return
    

    def setgeometry(self, row_name, filepath):

        # sets the galaxy geometry for polar deprojection

        table = Table.read(filepath)

        RA = table[table['name']==row_name]['orient_ra'].value
        Dec = table[table['name']==row_name]['orient_dec'].value
        incl = table[table['name']==row_name]['orient_incl'].value
        PA = table[table['name']==row_name]['orient_posang'].value
        dist = table[table['name']==row_name]['dist'].value * 1000

        self.Galaxy = tr.Galaxy()
        self.Galaxy.setpar(RA, Dec, incl, PA, dist)

        return
    

    def picture_prop(self, Rmax, slices=90, sig_resol=100, smoothing=1000):

        # sets parameters

        # number of radial bins of polar image
        self.slices = slices
        
        # radius to which polar image is filled
        self.Rmax = Rmax
        self.bound_rad = np.linspace(0, self.Rmax, self.slices+1)

        # number of steps in the correlation
        self.N = sig_resol
        self.rollangle = (2*np.pi)/self.N*np.arange(self.N)

        self.Rshift = Rmax/slices*2
        self.roll_R = np.linspace(0, self.Rshift, self.N)

        self.smooth_strength = smoothing

        return


    def process(self, maptype, tlog=True):

        # applies image smoothing and log transform

        if maptype in self.data[0]:
            
            map_i = np.where(self.data[0]==maptype)[0][0]
            data = self.data[1, map_i]
        else:

            raise ValueError('invalid map type')
        
        processed = data.copy()

        if tlog:

            processed *= 1000/processed.max()
            processed.clip(1e-7, None, processed)
            processed = np.log(processed)

        processed.clip(0, None, processed)
            
        smoothimage = restoration.denoise_tv_bregman(processed, weight=1/self.smooth_strength, isotropic=False)
        processed -= smoothimage

        processed.clip(0, None, processed)

        self.data[3, map_i] = smoothimage
        
        self.data[4, map_i] = processed

        return
    

    def skyshow(self, maptype, smooth=None):

        # shows an image in sky coordinates

        if maptype in self.data[0]:

            map_i = np.where(self.data[0]==maptype)[0][0]
            wcs = self.data[2, map_i]
        else:

            raise ValueError('invalid map type')
        
        if smooth:
            
            # shows the smoothed image
            image = self.data[3, map_i]
        elif smooth is None:
            
            # shows the original image
            image = self.data[1, map_i]
        elif ~smooth:
            
            # shows the processed image
            image = self.data[4, map_i]

        v_max = image.mean()+3*image.std()

        plt.figure()
        plt.subplot(projection=wcs)
        plt.title(self.name)
        plt.imshow(image, vmin=0, vmax=v_max, cmap='inferno')
        plt.colorbar()
        plt.xlabel('RA')
        plt.ylabel('Dec')

        return


    def unwind(self, maptype, smooth=True):

        # deprojects an image from sky to galactic polar coordinates
        # has to run after all maps are inputted to keep the shape consistent
        # if 'smooth' is changed it will override the previous image

        if maptype in self.data[0]:
            
            map_i = np.where(self.data[0]==maptype)[0][0]
            wcs = self.data[2, map_i]
        else:

            raise ValueError('invalid map type')
        
        print('Map: '+maptype)

        if smooth:

            data = self.data[4, map_i].copy()
        else:

            data = self.data[1, map_i].copy()

        x_max = np.max([self.data[1, i].shape[1] for i in np.where(np.array([isinstance(j, np.ndarray) for j in self.data[1]]))[0]])
        y_max = np.max([self.data[1, i].shape[0] for i in np.where(np.array([isinstance(j, np.ndarray) for j in self.data[1]]))[0]])
        print('R shape =', y_max)
        print('phi shape =', x_max)

        self.R, self.phi = np.meshgrid(np.linspace(0, self.Rmax, y_max), np.linspace(0, (2*np.pi), x_max, endpoint=0))
        self.R_log = np.log10(self.R.clip(1, None))

        skymap = tr.cartesian(self.Galaxy, [self.R, (self.phi*180/np.pi)*u.deg])
        pixels = wcs.world_to_pixel(c.SkyCoord(skymap[0].value, skymap[1].value, unit='deg'))
        polar_image = scimg.map_coordinates(data.transpose(), pixels, order=3)
        print('Polar image shape =', polar_image.shape)

        self.data[5, map_i] = polar_image

        return
    

    def polarshow(self, maptype):

        # shows a deprojected image

        if maptype in self.data[0]:
            
            map_i = np.where(self.data[0]==maptype)[0][0]
            polar_image = self.data[5, map_i]
        else:

            raise ValueError('invalid map type')

        lines = np.log10(self.bound_rad[1:])[np.where(self.bound_rad >= 1)[0]-1]
        v_max = polar_image.mean()+3*polar_image.std()

        plt.figure()
        plt.title(self.name)
        plt.pcolormesh(self.phi, self.R_log, polar_image, vmin=0, vmax=v_max, cmap='inferno')
        plt.errorbar(np.ones(lines.size)*np.pi, lines, xerr=np.pi, yerr=None, fmt='.', markersize=0.1, c='w', lw=0.5, alpha=0.5)
        plt.xlim(xmax=2*np.pi)
        plt.colorbar()
        plt.xlabel('Polar angle (rad)')
        plt.ylabel('log Galactic radius (kpc)')

        return
    

    def take_slice(self, layer, maptype):

        # returns the slice of a polar image with index 'layer'
        # layer index starts from 1

        if maptype in self.data[0]:
            
            map_i = np.where(self.data[0]==maptype)[0][0]
            polar_image = self.data[5, map_i]
        else:

            raise ValueError('invalid map type')

        R_low = self.bound_rad[layer-1]
        R_high = self.bound_rad[layer]

        i_low = np.floor(R_low/self.Rmax*polar_image.shape[1]).astype(int)
        i_high = np.ceil(R_high/self.Rmax*polar_image.shape[1]).astype(int)
        image_slice = polar_image[:,i_low:i_high]

        return image_slice, [i_low, i_high]


    def correlate(self, layer, dR, map1, map2=None, method='linear'):

        # returns a correlation signal for the slice 'layer' between map1 and optional map2
        # with a radial offset dR

        if map2 is None:

            map2 = map1

        if (map1 in self.data[0]) & (map2 in self.data[0]):
            
            map_i1 = np.where(self.data[0]==map1)[0][0]
            map_i2 = np.where(self.data[0]==map2)[0][0]
            polar_image = self.data[5, map_i2]
        else:

            raise ValueError('invalid map type')
        
        image_slice, ind = self.take_slice(layer, map1)
        i_dR = np.floor(dR/self.Rmax*polar_image.shape[1]).astype(int)
        image_rshift = np.roll(polar_image, -i_dR, 1)[:,ind[0]:ind[1]]

        indices = np.ceil((np.arange(self.N)/self.N)*image_slice.shape[0]).astype(int)
        signl = np.zeros(self.N)
        if method == 'linear':

            for i in range(self.N):

                image_shift = np.roll(image_rshift, indices[i], 0)
                avg = np.multiply(image_slice, image_shift).mean()
                signl[i] = avg

            signl /= (image_slice**2).mean()
        elif method == 'spearman':

            for i in range(self.N):

                image_shift = np.roll(image_rshift, indices[i], 0)
                avg, p_ = spearmanr(image_slice, image_shift, axis=None)
                signl[i] = avg
        else:

            raise ValueError('invalid correlation method')
    
        return signl
    

    def rad_shift(self, layer, map1, map2=None, method='linear'):

        # returns a 2D correlation signal as a function of angular and radial offset
        
        signals = np.zeros((self.N,self.N))
        for i in range(self.N):

            signl = self.correlate(layer, self.roll_R[i], map1, map2, method)
            signals[i,:] = signl

        return signals
    

    def plot_fullsignal(self, layer, map1, map2=None, method='linear'):

        # plots the full correlation signal of a slice

        signals = self.rad_shift(layer, map1, map2, method)

        plt.figure()
        plt.pcolormesh(self.rollangle, self.roll_R, signals, cmap='plasma')
        plt.colorbar()
        plt.xlabel('angular offset (rad)')
        plt.ylabel('radial offset (kpc)')
        plt.title('Corrrelation signal (slice '+str(layer)+' out of '+str(self.slices)+')' + '\nR = ' + self.bound_rad[layer-1].round(3).astype(str) + ' to ' + self.bound_rad[layer].round(3).astype(str) + ' kpc')

        return


    def plot_signals_subplots(self, map1, map2=None, method='linear'):

        # plots the 2D correlation signal of every slice and puts them on one figure

        num_rows = int(np.ceil(np.sqrt(self.slices)))
        num_cols = int(np.ceil(self.slices / num_rows))
        
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 15))
        fig.suptitle(self.name + ' Signals Subplots')

        for layer, ax in enumerate(axes.flat):

            if layer >= self.slices:
                
                ax.axis('off')
                continue
            
            signals = self.rad_shift(layer+1, map1, map2, method)
            rollangle = self.rollangle
            roll_R = self.roll_R

            ax.pcolormesh(rollangle, roll_R, signals, cmap='plasma')
            ax.set_title('Slice {}'.format(layer+1), fontsize=10)
            ax.tick_params(axis='both', which='both', labelsize=5)
        
        plt.tight_layout()
        plt.show()

        return


    ###### peaks ######

    def peak_prop(self, dist=1, prom=0.05, height=0):

        # changes find_peak parameters

        self.dist_ = dist
        self.prom_ = prom
        self.height_ = height

        return
    

    def signal_peaks(self, signl):

        # finds peaks in a 1D signal

        peaks, _ = signal.find_peaks(signl, distance=self.dist_, prominence=self.prom_, height=self.height_)

        peak_values = signl[peaks]
        peak_phi = self.rollangle[peaks]

        return peak_values, peak_phi
    

    def plot_signalpeaks(self, layer, dR, map1, map2=None, method='linear'):

        # plots the signal and peaks of a radial slice with a radial offset

        signl = self.correlate(layer, dR, map1, map2, method)
        
        peak_values, peak_phi = self.signal_peaks(signl)

        slice_str = 'R = ' + str(round(self.bound_rad[layer-1], 2)) + ' to ' + str(round(self.bound_rad[layer], 2)) + ', radial offset = ' + str(dR)

        plt.figure()
        fig, ax = plt.subplots()
        plt.scatter(peak_phi, peak_values, c='r', s=15)
        plt.plot(self.rollangle, signl, c='b')
        if method == 'linear':

            ax.set_ylim(ymin=-0.05, ymax=1.05)
        elif method == 'spearman':
                
            ax.set_ylim(ymin=-1.05, ymax=1.05)
        plt.title(slice_str)
        plt.xlabel('angular offset (rad)')
        plt.ylabel('linear correlation')

        return

    
    def plot_peaks_subplots(self, dR, map1, map2=None, method='linear'):

        # plots all signals at 0 radial offset

        num_rows = int(np.ceil(np.sqrt(self.slices)))
        num_cols = int(np.ceil(self.slices / num_rows))
        
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 15))
        fig.suptitle(self.name + ' Signals Subplots')

        for layer, ax in enumerate(axes.flat):

            if layer >= self.slices:

                ax.axis('off')
                continue
            
            signl = self.correlate(layer, dR, map1, map2, method)
            
            rollangle = self.rollangle

            peak_values, peak_phi = self.signal_peaks(signl)

            ax.plot(rollangle, signl)
            ax.scatter(peak_phi, peak_values, c='r', s=10)
            ax.set_title('Slice {}'.format(layer+1), fontsize=10)# + '\nR = ' + np.round(self.bound_rad[layer], 3).astype(str) + ' to ' + np.round(self.bound_rad[layer+1], 3).astype(str) + ' kpc', fontsize=10)
            if method == 'linear':

                ax.set_ylim(ymin=-0.05, ymax=1.05)
            elif method == 'spearman':
                
                ax.set_ylim(ymin=-1.05, ymax=1.05)
            ax.tick_params(axis='both', which='both', labelsize=5)
        
        plt.tight_layout()
        plt.show()

        return


    def measure_scales(self, maptype, method='linear', cutoff=(np.pi/4), grouping=5):

        # measures scales from the correlation signal
        
        feature_length = []
        feature_strength = []
        radius = []
        for layer in range(self.slices):

            signl = self.correlate(layer, 0, maptype, None, method)

            peak_values, peak_phi = self.signal_peaks(signl)
            if peak_phi.shape[0] > 0:

                if peak_phi[0] <= cutoff:
                    
                    feature_length.append(peak_phi[0]*self.bound_rad[layer]*1000)
                    feature_strength.append(peak_values[0])
                    radius.append(self.bound_rad[layer+1])
        
        feature_length = np.array(feature_length)
        feature_strength = np.array(feature_strength)
        radius = np.array(radius)

        num_groups = len(feature_length) // grouping

        feature_reshaped = feature_length[:num_groups*grouping].reshape([num_groups, grouping])
        strength_reshaped = feature_strength[:num_groups*grouping].reshape([num_groups, grouping])
        radius_reshaped = radius[:num_groups*grouping].reshape([num_groups, grouping])

        measurement = np.average(feature_reshaped, axis=1, weights=strength_reshaped)
        uncertainty = feature_reshaped.std(axis=1) / np.sqrt(grouping)
        radii = radius_reshaped.mean(axis=1)

        return [measurement, radii], uncertainty


    def auto_scales(self, maptype, tlog=True, smooth=True, method='linear'):

        # plots scales measured with auto correlation

        self.process(maptype, tlog)
        self.unwind(maptype, smooth)

        results, uncertainty = self.measure_scales(maptype, method)
        self.Scales.plot_allscales(results, uncertainty)

        return

In [None]:
# Example run for one galaxy. The string arguments below ('Galaxy ID',
# 'mega table', 'fits file', 'sample table') are placeholders; replace them
# with the real galaxy name and file paths before running.
galaxy_ID = Image_Signal('Galaxy ID')
# predicted scales, needed by auto_scales for the final plot
galaxy_ID.getscales('mega table')
# load every map before deprojecting: unwind sizes the polar grid from the
# largest loaded image
galaxy_ID.readdata('fits file', '7.7')
galaxy_ID.readdata('fits file', 'CO')
galaxy_ID.readdata('fits file', 'Ha')
galaxy_ID.readdata('fits file', '21')
# the first argument is matched against the 'name' column of the sample table
galaxy_ID.setgeometry('galaxy_ID', 'sample table')
# fill the polar image out to a radius of 7, with default resolution settings
galaxy_ID.picture_prop(7)
# process, deproject and measure the 7.7 map, then plot the measured scales
galaxy_ID.auto_scales('7.7')

# Optional diagnostic plots (original / smoothed / detail image, polar image):
#galaxy_ID.skyshow('7.7', None)
#galaxy_ID.skyshow('7.7', True)
#galaxy_ID.skyshow('7.7', False)
#galaxy_ID.polarshow('7.7')