# Writing your own Classes in Python

OOP stands for "Object-Oriented Programming".

In OOP, a **class** describes a particular object type, with its definitions and rules of operation. An **object** is an *instance* of a class. That is, when you invoke the class, you create an object of that class type.

For example: `str` is a Python class, whereas "abc" is a `str` type object. Below, I instantiate the string "abc" and check what type of object it is using the built-in Python function, `type`.

In [None]:
type('abc')

As was demonstrated in the **Python Fundamentals** notebook (Session 2), the `str` (string) class defines its own rules for how objects respond to Python syntax, such as the `*` operator and bracket indexing.

In [None]:
my_string = "abc"
print(my_string * 2) # print a string that is the original string repeated 2x
print(my_string[0]) # print the first letter in the string

**Typical classes have:**

* Attributes -- data stored on the object
* Methods -- functions defined in the class, normally called on an object (instance) of that class
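
As a quick illustration with a type that ships with Python, the sketch below reads attributes and calls a method on a `complex` number. The variable names `z` and `z_conj` are made up for this example.

```python
# A complex number is an object of the built-in `complex` class
z = complex(3, 4)

# Attributes: values stored on the object (no parentheses needed)
print(z.real, z.imag)

# Methods: functions attached to the class, called with parentheses
z_conj = z.conjugate()
print(z_conj)
```

Note the difference in syntax: attributes are accessed by name alone, while methods are called like functions.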

In [None]:
# Modules I'll use in this notebook

import numpy as np
import matplotlib.pyplot as plt

from astropy.io import ascii

# 1. Write a generic light curve class

Many scientific disciplines work with time series data. Astronomical data typically comes in the form of light, so our time series are usually "light curves" -- measurements of a star's brightness or photon count rate over time.

In this example, I'll write a generic LightCurve\* class to hold the light curve dataset all in one place. The light curve data itself will be the **attributes**. I will define a plotting **method** that plots the light curve data.

\*It is common practice to define Python classes with each word capitalized and no underscore notation. This makes it clear that you are accessing a class definition, not a function or variable (which are typically all lower-case and with underscores separating words).

Class definitions typically start with the `__init__` method. This is the "constructor" method that initializes the object. Note that all methods in the class include `self` as the first argument in the function definition. This is required so that the object has a name for referencing its own attributes.

In [None]:
class LightCurve(object):
    """
    A generic light curve
    
    Attributes
    ----------
    
    time : numpy array
    
    value : numpy array, values corresponding to each time data-point
    
    error : numpy array, error on the values (default: None)
    
    Example Usage
    -------------
    >>> x = np.arange(10)
    >>> y = np.random.normal(10, size=len(x))
    >>> LightCurve(x, y)
    """
    def __init__(self, time, value, error=None):
        self.time = time
        self.value = value
        self.error = error
    
    def plot(self, ax, **kwargs):
        """
        Plot the light curve.
        
        Inputs
        ------

        ax : matplotlib axes object, where the plot will go
        
        **kwargs are passed to matplotlib plotting functions
        """
        if self.error is not None:
            ax.errorbar(self.time, self.value, 
                        yerr=self.error, **kwargs)
        else:
            ax.scatter(self.time, self.value, **kwargs)

In the example below, I create a fake light curve data set using the numpy random number package. I use the fake data to instantiate a `LightCurve` object. Then, I plot that data using the `LightCurve.plot` method.

In [None]:
# 1000 uniformly spaced data points between time=0 and 100
my_x = np.linspace(0, 100, 1000) 

# Gaussian distribution of values with mean of 1 and std-dev of 0.3
my_y = np.random.normal(1.0, 0.3, size=len(my_x)) 

# Instantiate the LightCurve object
my_lc = LightCurve(my_x, my_y)

ax = plt.subplot(111) # initialize matplotlib axes object
my_lc.plot(ax) # plot my fake light curve data

# No plot is complete without axes labels
ax.set_xlabel('time')
ax.set_ylabel('phot / s')

In [None]:
# List the attributes and methods available on the LightCurve object
dir(my_lc)

## 1.2 Write a method to filter time

Let's think of a generic task we might apply to time series data, such as filtering it by time. In the case below, I introduce a new attribute to my `LightCurve` class, called `notice`. This will be a boolean array that specifies which time series data I care about.

Then I modify the `plot` method so that only noticed data is plotted.

In [None]:
class LightCurve(object):
    """
    A generic light curve
    
    Attributes
    ----------
    
    time : numpy array
    
    value : numpy array, values corresponding to each time data-point
    
    error : numpy array, error on the values (default: None)
    
    notice : bool array, describes which time values to keep for analysis
    
    Example Usage
    -------------
    >>> x = np.arange(10)
    >>> y = np.random.normal(10, size=len(x))
    >>> LightCurve(x, y)
    """
    def __init__(self, time, value, error=None):
        self.time = time
        self.value = value
        self.error = error
        self.notice = np.ones_like(value, dtype='bool')
    
    def notice_time(self, tmin, tmax):
        """
        Modify the `notice` attribute to filter by time.
        
        Inputs
        ------
        
        tmin : float, lower end of chosen time interval
        
        tmax : float, upper end of chosen time interval
        """
        self.notice = (self.time >= tmin) & (self.time <= tmax)
    
    def reset_notice(self):
        """
        Remove any previous filters so that the full light curve 
        dataset is included in the `notice` boolean array.
        
        Inputs
        ------
        None
        """
        self.notice = np.ones_like(self.value, dtype='bool')
    
    def plot(self, ax, **kwargs):
        """
        Plot the light curve. Only plots noticed values.
        
        Inputs
        ------

        ax : matplotlib axes object, where the plot will go
        
        **kwargs are passed to matplotlib plotting functions
        """
        if self.error is not None:
            ax.errorbar(self.time[self.notice], self.value[self.notice], 
                        yerr=self.error[self.notice], **kwargs)
        else:
            ax.scatter(self.time[self.notice], self.value[self.notice], **kwargs)

Let's try it out!

In [None]:
# First, I have to remake the fake dataset so I can make a LightCurve object 
# that has the new functionalities defined above
my_x = np.linspace(0, 100, 1000) 
my_y = np.random.normal(1.0, 0.3, size=len(my_x)) 
my_lc = LightCurve(my_x, my_y)

ax = plt.subplot(111) 

# plot the original dataset
my_lc.plot(ax, color='0.5', label='original dataset')

# filter so that we only notice datapoints between t=20 and 60
my_lc.notice_time(20, 60)

# now plot the filtered data only
my_lc.plot(ax, color='r', label='20 <= t <= 60')

ax.legend(loc='upper right')

### Exercises

* How would you calculate the mean and standard deviation of the entire dataset? How would you calculate those values for the "noticed" time values only?

* Challenge: What are some ways you can make the `notice_time` method more flexible? For example, what if you wanted to filter out $t < 20$?

* What are other types of analyses one might perform with light curves? How would you add that to the LightCurve class?
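
One possible approach to the first two exercises is sketched below. It is not the only answer. The arrays `time`, `value`, and `notice` are stand-ins for the attributes of a `LightCurve` object, and `time_mask` is a hypothetical helper name, not something defined elsewhere in this notebook.

```python
import numpy as np

# Hypothetical stand-ins for `my_lc.time`, `my_lc.value`, and `my_lc.notice`
time = np.linspace(0, 100, 1000)
value = np.random.normal(1.0, 0.3, size=len(time))
notice = (time >= 20) & (time <= 60)

# Statistics of the entire dataset
mean_all, std_all = np.mean(value), np.std(value)

# Statistics of the noticed values only: index with the boolean array
mean_sel, std_sel = np.mean(value[notice]), np.std(value[notice])

def time_mask(time, tmin=None, tmax=None):
    """Hypothetical helper: pass None to leave one side of the interval open."""
    mask = np.ones_like(time, dtype=bool)
    if tmin is not None:
        mask &= (time >= tmin)
    if tmax is not None:
        mask &= (time <= tmax)
    return mask

# Keep only t >= 20, i.e. filter out t < 20
late = time_mask(time, tmin=20)
print(mean_all, std_all, mean_sel, std_sel, late.sum())
```

Making both limits optional keeps a single method for one-sided and two-sided filters, at the cost of a slightly longer function body.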

## 1.3 Challenge: Write an outlier clipping algorithm

Sometimes there are errors in measurement (e.g., instrument anomalies or human error) that cause a few data points to appear significantly different from the broad trends. A data point that looks suspicious in this way is called an "outlier". Outliers are sometimes flagged for removal from the final scientific analysis.

One method for investigating outliers is to look for data points that are significantly far from the average value in the data set. With a light curve, that means examining the values on the y-axis. We can use numpy to calculate the mean and standard deviation (sigma) of the light curve values, then search for data points that are > N-sigma away from the mean. Removing outliers via this method is referred to as "sigma-clipping".

Write a sigma-clipping function that removes data points that are more than N sigma away from the mean, where `sigma_threshold` represents the value for N in the inequality.
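
Before wrapping the test in a class method or repeating it, a single pass of the N-sigma comparison might look like the sketch below. The toy array and variable names are invented for illustration.

```python
import numpy as np

# Hypothetical toy data: nineteen ordinary points and one obvious outlier
values = np.array([1.0] * 10 + [1.2] * 9 + [10.0])
sigma_threshold = 3.0

mean = np.mean(values)
std_dev = np.std(values)

# Distance of every point from the mean, in units of sigma
distance = np.abs(values - mean) / std_dev

# Boolean array: True where the point is within N sigma of the mean
keep = distance <= sigma_threshold
print(values[~keep])
```

Because the outlier itself inflates the mean and standard deviation, one pass can miss weaker outliers. That is the motivation for repeating the test until no more points are removed.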

In [None]:
class LightCurve(object):
    """
    A generic light curve
    
    Attributes
    ----------
    
    time : numpy array
    
    value : numpy array, values corresponding to each time data-point
    
    error : numpy array, error on the values (default: None)
    
    notice : bool array, describes which time values to keep for analysis
    
    Example Usage
    -------------
    >>> x = np.arange(10)
    >>> y = np.random.normal(10, size=len(x))
    >>> LightCurve(x, y)
    """
    def __init__(self, time, value, error=None):
        self.time = time
        self.value = value
        self.error = error
        self.notice = np.ones_like(value, dtype='bool')
    
    def notice_time(self, tmin, tmax):
        """
        Modify the `notice` attribute to filter by time.
        
        Inputs
        ------
        
        tmin : float, lower end of chosen time interval
        
        tmax : float, upper end of chosen time interval
        """
        self.notice = (self.time >= tmin) & (self.time <= tmax)
    
    def plot(self, ax, **kwargs):
        """
        Plot the light curve. Only plots noticed values.
        
        Inputs
        ------

        ax : matplotlib axes object, where the plot will go
        
        **kwargs are passed to matplotlib plotting functions
        """
        if self.error is not None:
            ax.errorbar(self.time[self.notice], self.value[self.notice], 
                        yerr=self.error[self.notice], **kwargs)
        else:
            ax.scatter(self.time[self.notice], self.value[self.notice], **kwargs)
    
    def clip_sigma(self, sigma_threshold):
        """
        Remove outliers by filtering out values that are far from the mean value of the light curve.
        
        Inputs
        ------
        
        sigma_threshold : float, data that is > `sigma_threshold` standard deviations away 
                          from the light curve mean will be filtered out of the light curve
        
        Returns
        -------
        None. The `notice` array is modified in place so that only values within 
        `sigma_threshold` standard deviations of the mean will be included.
        """
        igood = np.ones_like(self.value, dtype='bool')
        test_good_value_length = len(self.value[igood])

        continue_clipping = True
        while continue_clipping:
            
            # How many standard deviations is the data point away from the mean?
            std_dev = np.std(self.value[igood])
            mean    = np.mean(self.value[igood])
            test_distance = np.abs(self.value - mean) / std_dev
            
            # Keep values that are within threshold number of standard deviations away from mean
            igood = (test_distance <= sigma_threshold) # boolean array
            
            # Decide whether or not we are done
            test_length = len(self.value[igood])
            print('Total number of values clipped: {}'.format(len(self.value) - test_length))
            if test_length == test_good_value_length: # if the length of the array has NOT changed
                print("Done clipping")
                continue_clipping = False # we can stop clipping
            else:
                test_good_value_length = test_length # reset the test length
                print("Continue clipping")
        
        self.notice = igood

### Test it out!

In a Gaussian distribution of data points, about 0.3% of the data will be > 3-sigma away from the mean. So with 1000 Gaussian data points, I expect a few of them to be removed by my sigma-clipping algorithm if I set `sigma_threshold = 3`.
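
The expected fraction can be checked with the complementary error function from Python's standard library. This is a small worked calculation; the variable names are chosen for this sketch only.

```python
import math

# Two-sided tail probability of a Gaussian beyond n_sigma standard deviations
n_sigma = 3.0
tail_fraction = math.erfc(n_sigma / math.sqrt(2))

n_points = 1000
expected_outliers = n_points * tail_fraction
print('Fraction beyond 3 sigma: {:.4f}'.format(tail_fraction))
print('Expected outliers in {} points: {:.1f}'.format(n_points, expected_outliers))
```

Any single random draw will scatter around this expectation, so the number of clipped points will vary from run to run.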

In [None]:
# First, I have to remake the fake dataset so I can make a LightCurve object 
# that has the new functionalities defined above
my_x = np.linspace(0, 100, 1000) 
my_y = np.random.normal(1.0, 0.3, size=len(my_x)) 
my_lc = LightCurve(my_x, my_y)

ax = plt.subplot(111) 

# plot the original dataset
my_lc.plot(ax, color='r', label='original dataset')

# Remove data points that are > 3 sigma from the mean
my_lc.clip_sigma(3.0)

# now plot the filtered data only
my_lc.plot(ax, color='0.5', label='outliers removed')

ax.legend()

# 2. Write a class with inheritance

Now let's work with real data! In this case, we'll be using light curves provided by the [MAXI all sky survey](http://maxi.riken.jp/top/index.html), which monitors bright X-ray point sources across the entire sky every day. The [light curves](http://maxi.riken.jp/top/lc.html) for these X-ray sources can be downloaded from the MAXI website in ASCII text format.

In the case below, I write a subclass of the `LightCurve` class. It opens the text files I downloaded from the website and stores the 2-20 keV data as a light curve. (See the MAXI light curve [README](http://maxi.riken.jp/top/lc_readme.txt) file.)

In [None]:
class MAXILightCurve(LightCurve):
    """
    A MAXI Light Curve loaded from the data files provided on the MAXI website.
    """
    def __init__(self, filename):
        data = ascii.read(filename)
        LightCurve.__init__(self, data['col1'], data['col2'], error=data['col3'])
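
The subclass only changes how the object is constructed; every method of the parent class is inherited. The toy classes below (`Base` and `FromPairs`, both invented for this sketch) illustrate the same pattern without needing a data file. They use `super()`, which in Python 3 is a common alternative to naming the parent class explicitly.

```python
class Base:
    def __init__(self, time, value):
        self.time = time
        self.value = value

    def n_points(self):
        return len(self.time)

class FromPairs(Base):
    """Hypothetical subclass that builds the same object from (time, value) pairs."""
    def __init__(self, pairs):
        times = [p[0] for p in pairs]
        values = [p[1] for p in pairs]
        # super() looks up the parent class; here it is equivalent to
        # Base.__init__(self, times, values)
        super().__init__(times, values)

lc = FromPairs([(0, 1.0), (1, 1.2), (2, 0.9)])
print(lc.n_points())         # method inherited from Base
print(isinstance(lc, Base))  # a FromPairs object is also a Base object
```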

## 2.1 Test out your sigma clipping algorithm on some real data

I have included the light curve for Cyg X-2, an X-ray binary in the constellation of Cygnus. The code below applies the sigma-clipping algorithm you wrote in order to remove flares from the dataset.

In [None]:
cygX2_lc = MAXILightCurve('J2144+383_g_lc_1day_all.dat')

ax = plt.subplot(111)
cygX2_lc.plot(ax, color='k', marker='.', lw=0.5, alpha=0.5, ls='')
ax.set_xlabel('MJD')
ax.set_ylabel('ph cm^-2 s^-1')

# Here, I am going to print and overplot the mean and 3 x the standard deviation
# to get a sense for which data points will be clipped.
mean = np.mean(cygX2_lc.value)
stddev = np.std(cygX2_lc.value)
print("Mean flux: {:.2f} +/- {:.2f}".format(mean, stddev))

ax.axhline(mean, color='r')
ax.axhline(mean + 3 * stddev, color='r', ls='--')
ax.axhline(mean - 3 * stddev, color='r', ls='--')

In [None]:
## Now let's remove 3-sigma flares

cygX2_lc.clip_sigma(3.0)

print('')
print('Quiescent Mean Flux {:.2f} +/- {:.2f}'.format(np.mean(cygX2_lc.value[cygX2_lc.notice]), 
                                                     np.std(cygX2_lc.value[cygX2_lc.notice])))

# The new plot should show some data points removed
ax = plt.subplot(111)
cygX2_lc.plot(ax, color='k', marker='.', lw=0.5, alpha=0.5, ls='')
ax.set_xlabel('MJD')
ax.set_ylabel('ph cm^-2 s^-1')

**Thought Question:** How many data points got clipped? Does it look reasonable to you?

## 2.2 Sigma clipping on a source with strong outbursts

The X-ray binary GX 339-4 undergoes large outbursts every few years. Let's see how the sigma-clipping algorithm does when we try to remove those outbursts from the dataset.

In [None]:
gx339_lc = MAXILightCurve('J1702-487_g_lc_1day_all.dat')

ax = plt.subplot(111)
gx339_lc.plot(ax, color='k', marker='.', lw=0.5, ls='', alpha=0.5)
ax.set_xlabel('MJD')
ax.set_ylabel('ph cm^-2 s^-1')

# Here, I am going to print and overplot the mean and 3 x the standard deviation
# to get a sense for which data points will be clipped.
mean = np.mean(gx339_lc.value)
stddev = np.std(gx339_lc.value)
print("Mean flux: {:.2f} +/- {:.2f}".format(mean, stddev))

ax.axhline(mean, color='r')
ax.axhline(mean + 3. * stddev, color='r', ls='--')
ax.axhline(mean - 3. * stddev, color='r', ls='--')

In [None]:
## Now remove 3-sigma flares
gx339_lc.clip_sigma(3.0)

print('')
print('Quiescent Mean Flux {:.2f} +/- {:.2f}'.format(np.mean(gx339_lc.value[gx339_lc.notice]), 
                                                     np.std(gx339_lc.value[gx339_lc.notice])))

# The new plot should show some data points removed
ax = plt.subplot(111)
gx339_lc.plot(ax, color='k', marker='.', lw=0.5, ls='', alpha=0.5)
ax.set_xlabel('MJD')
ax.set_ylabel('ph cm^-2 s^-1')

**Thought Question:** Did your algorithm perform as planned? Is there something you would do to improve it?

**Bonus:** How would you modify the code to get the light curve data for the periods of outburst?
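
One way to approach the bonus question is to invert the boolean `notice` array. The sketch below uses small made-up arrays in place of the `gx339_lc` attributes, and the `value < 5.0` cut only stands in for whatever the clipping step produced.

```python
import numpy as np

# Hypothetical stand-ins for `gx339_lc.time`, `gx339_lc.value`, `gx339_lc.notice`
time = np.arange(10)
value = np.array([1.0, 1.1, 0.9, 8.0, 9.5, 1.0, 1.2, 0.8, 7.5, 1.0])
notice = value < 5.0   # pretend these are the points that survived clipping

# `~` flips a boolean array, selecting the points that were clipped
outburst = ~notice
outburst_time = time[outburst]
outburst_value = value[outburst]
print(outburst_time, outburst_value)
```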

# 3. Super challenge

If you look at the MAXI light curve files, you'll see that there are nine columns in total. That's because the light curves are also separated into multiple bandpasses: 2-20 keV (broad), 2-4 keV (soft), 4-10 keV (medium) and 10-20 keV (hard) X-ray light ([as described in the light curve README file](http://maxi.riken.jp/top/lc_readme.txt)).

* How would you write your `MAXILightCurve` class to load all of the light curves available in the ASCII file?

* Write a method that plots the total 2-20 keV brightness versus spectral hardness (e.g., medium/soft). This is the famous "q-diagram" used to study accretion onto black holes and neutron stars. (e.g., [http://www.sternwarte.uni-erlangen.de/proaccretion/](http://www.sternwarte.uni-erlangen.de/proaccretion/))

* Where do the flares fall on this spectral hardness ratio diagram? Where are the quiescent data points?

## Further concepts

Ask me later if you are interested in hearing about some of these concepts!

**More advanced things you can build into classes:**

* Properties -- functions that require no input and are calculated on the fly. You can access them in the same way you access attributes.
* Private methods -- class function definitions that are usually not intended to be accessed by the user. Usually these function names start with an underscore (e.g., `_my_hidden_func`) or double underscore.
* Magic methods -- special function definitions that tell Python what to do in special circumstances. For example, defining a `__getitem__` method for your class will tell Python what to do when you subscript that object with bracket notation (e.g., `x[0]`).
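
The three ideas above can be seen together in a minimal sketch. `TinyCurve` and its method names are hypothetical and are not part of the `LightCurve` class defined earlier.

```python
import numpy as np

class TinyCurve:
    """Hypothetical toy class, not part of the notebook's LightCurve."""
    def __init__(self, time, value):
        self.time = np.asarray(time)
        self.value = np.asarray(value)

    @property
    def mean(self):
        # Property: computed on the fly, accessed like an attribute
        return self._average(self.value)

    def _average(self, x):
        # "Private" helper: the leading underscore is only a convention
        return float(np.mean(x))

    def __len__(self):
        # Magic method: called by len(obj)
        return len(self.time)

    def __getitem__(self, i):
        # Magic method: called by obj[i]
        return self.time[i], self.value[i]

tc = TinyCurve([0, 1, 2], [1.0, 2.0, 3.0])
print(tc.mean)   # no parentheses needed
print(len(tc))
print(tc[0])
```

Python does not enforce privacy: `tc._average(...)` can still be called from outside the class, so the underscore is a signal to other users rather than a restriction.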