# Calculate the RMS values for a given dataset and fit a Gaussian to them


In [None]:
from tkp.db.model import Image
import tkp.db
import math
import numpy as np
from scipy.optimize import leastsq
from collections import defaultdict
import matplotlib
%matplotlib inline
from matplotlib import pyplot

### Settings

In [None]:
import os

# Database connection settings. Each value can be overridden through an
# environment variable so credentials need not be hard-coded in the notebook;
# the literals below are only fall-back defaults.
host = os.environ.get('TKP_DBHOST', 'localhost')
port = int(os.environ.get('TKP_DBPORT', 5432))
user = os.environ.get('TKP_DBUSER', 'gijs')
password = os.environ.get('TKP_DBPASSWORD', 'gijs')
database = os.environ.get('TKP_DBNAME', 'gijs')

# Analysis settings.
dataset_id = 1  # only images with this dataset_id are queried
sigma = 4       # threshold lines are drawn at mean +/- sigma * fitted sd

### Helper functions

In [None]:
def norm2(x, mean, sd):
    """
    Evaluate a normalised Gaussian probability density at the points x.

    Used both as the model inside the least-squares residuals and to build
    the fitted curve that is plotted.

    args:
        x: sample positions (sequence or numpy array)
        mean: centre of the Gaussian
        sd: standard deviation of the Gaussian

    returns:
        a numpy float array with the same shape as x
    """
    # Vectorised over x instead of a Python loop; float dtype also avoids
    # int64 overflow when x holds large integers.
    x = np.asarray(x, dtype=float)
    return 1.0 / (sd * np.sqrt(2 * np.pi)) * np.exp(-(x - mean) ** 2 / (2 * sd ** 2))


def guess_p(x):
    """
    Estimate starting parameters for the Gaussian least-squares fit.

    args:
        x: the data values (RMS values)

    returns:
        [centre, spread, amplitude] where centre is the median of x,
        spread is sqrt(|mean(x**2) - median**2|) (close to the standard
        deviation when mean and median agree) and amplitude is sqrt(len(x)).
        An empty input gives [0, 0, 0].
    """
    count = len(x)
    if not count:
        return [0, 0, 0]
    centre = np.median(x)
    centre_sq = centre * centre
    spread = abs(sum(value * value - centre_sq for value in x))
    return [centre, math.sqrt(spread / count), math.sqrt(count)]


def res(p, y, x):
    """
    Residuals of observed values against a scaled Gaussian model.

    args:
        p: the three fit parameters (mean, sd, amplitude)
        y: observed values (histogram counts when called from rms_histogram)
        x: positions at which y was observed

    returns:
        y minus amplitude * norm2(x, mean, sd)
    """
    mean, width, amplitude = p
    model = amplitude * norm2(x, mean, width)
    return y - model


def rms_histogram(x, sigma=8, name='rms_plot'):
    """
    Plot a histogram of RMS values with a fitted Gaussian and threshold lines.

    A Gaussian (see norm2) is least-squares fitted to the 50-bin histogram of
    x, starting from the estimate returned by guess_p. The fitted curve is
    drawn in red. Dashed black lines mark mean - sigma * sd and
    mean + sigma * sd.

    args:
        x: an array of RMS values
        sigma: distance of the threshold lines from the fitted mean, in units
            of the fitted standard deviation
        name: label for the x axis

    returns:
        the matplotlib Figure containing the plot

    raises:
        ValueError: if x is empty (there is nothing to fit)
    """
    import warnings

    if len(x) == 0:
        raise ValueError("cannot fit a Gaussian to an empty set of RMS values")

    counts, edges = np.histogram(x, bins=50)
    # Fit the model at the bin centres rather than at the bin edges.
    centres = (edges[:-1] + edges[1:]) / 2.
    params, ier = leastsq(res, guess_p(x), args=(counts, centres))
    if ier not in (1, 2, 3, 4):
        # leastsq reports failure through ier; the thresholds drawn below are
        # then unreliable, so make that visible instead of plotting silently.
        warnings.warn("Gaussian fit for %r did not converge (ier=%s)" % (name, ier))
    mean, sd, amplitude = params
    fit = amplitude * norm2(centres, mean, sd)
    # The model is unchanged when sd and amplitude both flip sign, so use |sd|
    # to keep the lower/upper thresholds in the right order.
    half_width = abs(sd) * sigma
    lower, upper = mean - half_width, mean + half_width

    fig = pyplot.figure(figsize=(5, 5))
    # Reuse the fitted bin edges so the drawn histogram is exactly the fitted one.
    pyplot.hist(x, bins=edges, histtype='stepfilled')
    pyplot.plot(centres, fit, 'r-', linewidth=3)
    for cut in (upper, lower):
        pyplot.axvline(x=cut, linewidth=2, color='k', linestyle='--')
    # x holds linear (not log10) RMS values, so matplotlib's default tick
    # labels are kept rather than relabelling ticks as powers of ten.
    pyplot.xlabel(name)
    pyplot.ylabel('Number of images')
    return fig

### Set up the database connection and fetch the images

In [None]:
# Open a connection to the configured PostgreSQL database and load every
# image belonging to the selected dataset.
db = tkp.db.Database(
    engine='postgresql',
    host=host,
    port=port,
    user=user,
    password=password,
    database=database,
)
db.connect()
session = db.Session()
image_query = session.query(Image).filter(Image.dataset_id == dataset_id)
images = image_query.all()
print("number of images: %s" % len(images))

### Prepare data and render plots

In [None]:
# Collect the RMS value of every image in a single pass: all_rms holds all of
# them, and freqs groups them by the central frequency of the image's band.
freqs = defaultdict(list)
all_rms = []
for image in images:
    rms = image.rms_qc
    all_rms.append(rms)
    freqs[image.band.freq_central].append(rms)

# Same values as all_rms (previously rebuilt by a second pass over images);
# kept as its own list in case later cells refer to it.
rms_values = list(all_rms)

name = 'RMS plot for all frequencies'
rms_histogram(all_rms, sigma=sigma, name=name)

# One histogram per central frequency.
for freq, values in freqs.items():
    name = 'RMS plot for frequency %s' % freq
    rms_histogram(values, sigma=sigma, name=name)
    