# Week 5: Classes, objects, methods and all that

## Introduction

Remember the very first line of code we ran in the first class? It went something like...  
> int_number = 56

It looks quite simple. We tell the Python interpreter (just a fancy name for the program that runs your code, like "python3" in your command line) to create a variable named int_number and assign 56 to it. A variable is just a little chunk of the computer's memory reserved to store its name and a pointer to another piece of memory where its value is. But there's a catch: in Python, 56 is not simply an integer (or a chunk of memory containing the equivalent binary number 111000). It turns out that you've been using objects since the very beginning.

In languages like C or Fortran, integers are simply small chunks of memory containing just the bare minimum information about it. In Python, everything is an object. An int is an object. When you type int_number = 56, Python instantiates an object of the class _int_ and refers to it by its name, int_number. So what are classes? And what does instantiating an object mean? Why is there a distinction between methods and functions? And what are user-defined classes good for?
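
A quick way to convince yourself that an integer is an object is to ask Python about it. The sketch below uses only built-in functions:

```python
int_number = 56

# The value is an instance of the built-in class int
print(type(int_number))             # <class 'int'>
print(isinstance(int_number, int))  # True

# Being an object, it carries methods of its own
print(int_number.bit_length())      # 6
print(bin(int_number))              # 0b111000
```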

First of all, are classes necessary for programming? In principle, no. You can live your life without ever touching a class statement. In fact, not all programming languages are object-oriented. So why learn it? Here's a (non-exhaustive) list of reasons to learn at least the basics:

- Conceptually, it can be easier to think in terms of classes. Programmers often think of objects as real-world objects, which helps to make a connection between the code and the real-world problem;
- Since everything is an object in Python... well, it helps to know where they come from;
- If you write code, you read code and use other people's code. And most libraries are built around classes, because classes are handy for keeping large projects organized. So to understand other people's code you need to understand how classes are structured.

In this lecture, we will learn the basics of object-oriented programming (OOP) in Python. We will also apply what we learn to analyze data recorded using patch-clamping in brain slices. Keep in mind that we will use classes for this analysis for pedagogical reasons. In practice, this approach may not be the best suited for this particular task.

*** The data included in this notebook was obtained in the von Gersdorff Lab at the Vollum Institute and remains unpublished. Please do not distribute it. ***

If you'd like to learn more about the recordings shown here, refer to Zemel et al., 2023. _Motor cortex analogue neurons in songbirds utilize Kv3 channels to generate ultranarrow spikes_. eLife 12:e81992.

## Interlude: Patch clamp recordings

The goal of this lecture is two-fold. We will learn about classes and we will implement a class that handles feature extraction from voltage clamp and current clamp recordings from excitatory neurons in brain slices.

We'll start by loading and taking a quick peek at the data. The first recording is from a patch-clamping experiment in cell-attached mode, where the voltage was fixed at zero. All recordings were done at a sampling rate of 71400 Hz.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

sampling_rate = 71400

cell_attached = np.load('cell_attached.npy')
t_attached = np.arange(len(cell_attached))/float(sampling_rate)

plt.plot(t_attached, cell_attached)
plt.xlabel('Time (s)')
plt.ylabel('Current (nA)')

The second recording was done after breaking into the cell's membrane, also known as whole-cell mode, while holding the current at zero.

In [None]:
whole_cell = np.load('whole_cell.npy')
t_whole = np.arange(len(whole_cell))/float(sampling_rate)

plt.plot(t_whole, whole_cell)
plt.xlabel('Time (s)')
plt.ylabel('Voltage (mV)')

Interestingly, these cells fire action potentials (APs) spontaneously at a particularly high rate. The APs themselves are also remarkably sharp and fast, a claim which we will verify by calculating AP features such as half-width, depolarization rate, and peak value. But first, classes 101.

## The class statement

The simplest class that you can create in Python is:

In [None]:
class MyFirstClass:
    pass

In the statement above, we defined a class called MyFirstClass. Note that the syntax is in some ways similar to how functions are defined, except that here we did not include any argument in parentheses after MyFirstClass. More on that later.

The "pass" statement simply tells Python to not do anything within the class scope.

We can now create objects with our new class. Note that when objects are instantiated, the syntax requires parentheses:

In [None]:
not_my_first_object = MyFirstClass()

What is contained in not_my_first_object? 

In [None]:
dir(not_my_first_object)

A lot of stuff. But what stands out is how every string above starts and ends with a double underscore, "__". These are called magic methods (strictly speaking, a few of them, like \__doc__, are plain attributes rather than methods), and they are one of the wonders of Python. But to avoid being circular, first we need to know what a method is.

## Methods

A method is a function defined within the scope of a class. In fact, the syntax for method definition is essentially the same as for functions. Let's create a slightly more interesting class and add a method to it:

In [None]:
class NeuronalRecording:
    
    def quantity(self):
        return 'Voltage (mV)'

    
N = NeuronalRecording()
N.quantity()

There are many details hiding in this simple example. First, the pass statement is gone since now we have something else (a method) inside the class definition. 

Second, the method _quantity_ requires an argument, which I called _self_. If you don't define at least one argument for a method, you'll get an error message when you call it from an object. However, when we called _quantity_ using the object _N_, I did not pass anything for _self_. Why?

You could've used any name for this argument, but _self_ is a convention that virtually everyone uses. Behind the curtains, the dot notation for method calls, e.g. N.quantity(), means the following:

In [None]:
NeuronalRecording.quantity(N)

We are actually calling the method from the class definition and passing the object we created to it! _self_ refers to the object. In practice we always use the dot notation. 
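
To see that _self_ really is the object in front of the dot, here is a minimal sketch. The class _Echo_ and its method are made up for this illustration:

```python
class Echo:  # hypothetical class, for illustration only

    def who_am_i(self):
        # Return whatever object was received as self
        return self


e = Echo()

# Both call styles hand the same object to the method
print(e.who_am_i() is e)       # True
print(Echo.who_am_i(e) is e)   # True
```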

Even in Numpy, when we call, for example:

In [None]:
fibo = np.array([0,1,1,2,3,5])

fibo.sum()

we are actually creating a numpy array object and passing it to the method _sum_. But, as we just learned, we could've done

In [None]:
np.ndarray.sum(fibo)

with the slight difference that the array class is actually called ndarray (np.array is just a function that returns an object of the type np.ndarray).

What's inside the object N? 

In [None]:
dir(N)

As expected, all the magic methods that we saw before (the methods containing two underscores), in addition to the method we defined, _quantity_.
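
If the long list gets in the way, one option is to filter out the double-underscore names. This is a small sketch where _Recording_ is a stand-in for our class, not new lecture code:

```python
class Recording:  # hypothetical stand-in for NeuronalRecording

    def quantity(self):
        return 'Voltage (mV)'


r = Recording()

# Keep only the names that do not start with a double underscore
public_names = [name for name in dir(r) if not name.startswith('__')]
print(public_names)  # ['quantity']
```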

## The \__init__ method

There is one magic method that is of absolute importance, and that's the \__init__ method. Plainly speaking, \__init__ initializes an object. It's a method that gets automatically called when you create an object. Let's try it out by redefining our NeuronalRecording class:

In [None]:
class NeuronalRecording:

    def __init__(self, vm):
        self.vm = vm
        self.spikes = []
    
    def quantity(self):
        return 'Voltage (mV)'
    
N = NeuronalRecording(vm=whole_cell)
N.quantity()

In addition to the good old _self_ argument, we also added a new argument _vm_, a placeholder for the membrane potential, in the same manner we normally do when defining functions.

Now notice the next two lines. Here's when things start getting interesting. _self.vm_ is now a variable that will be assigned to the _self_ object when we instantiate the _NeuronalRecording_ class. 

All this jargon is complicated, but what it really means is that our _NeuronalRecording_ object, _N_, now has a variable called _vm_ and its value is a numpy array.

In [None]:
print(N.vm)

And we also have a variable that was assigned an empty list by the class itself:

In [None]:
print(N.spikes)

The existence of both _vm_ and _spikes_ is conditioned on the creation of an object and subsequent call to the \__init__ method. So if you call these variables directly from the class, you will get an error:

In [None]:
NeuronalRecording.vm
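
The same point can be checked without triggering an error. The sketch below uses a made-up class _Recording_ and a short list in place of the real data:

```python
class Recording:  # hypothetical stand-in for NeuronalRecording

    def __init__(self, vm):
        self.vm = vm


# The class has no attribute vm until an object is created
print(hasattr(Recording, 'vm'))   # False

r = Recording(vm=[-70.0, -65.0, 20.0])
print(hasattr(r, 'vm'))           # True

# Instance variables live in the object's own dictionary
print(vars(r))                    # {'vm': [-70.0, -65.0, 20.0]}
```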

Most importantly, these quantities are bound to object N's scope. This is important because now we can create many _NeuronalRecording_ instances at once, each with a different set of _vm_ and _spikes_ values. For example:

In [None]:
vm1 = whole_cell[:len(whole_cell)//2]
vm2 = whole_cell[len(whole_cell)//2:]

N1 = NeuronalRecording(vm1)
N2 = NeuronalRecording(vm2)

print(N1.vm)
print(N2.vm)

You can also change these variables after initialization. Let's flip them:

In [None]:
N1.vm = vm2
N2.vm = vm1

print(N1.vm)
print(N2.vm)

One natural question at this point is whether there's a way to define a variable that exists independently of object instantiation. The answer is yes, and these are called class variables:

In [None]:
class MyThirdClass:
    this_is_a_class_variable = 1

print(MyThirdClass.this_is_a_class_variable)

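Objects can read the class variable too. But assigning to it through an object creates an instance variable that shadows the class variable for that object only: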

In [None]:
O1 = MyThirdClass()
O2 = MyThirdClass()

print(O1.this_is_a_class_variable)
print(O2.this_is_a_class_variable)

O1.this_is_a_class_variable = 2

print(O1.this_is_a_class_variable)
print(O2.this_is_a_class_variable)

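Note that _O2_ was not affected: the assignment created a new instance variable on _O1_ that hides the class variable. The class variable itself is shared; assigning to it through the class, as in MyThirdClass.this_is_a_class_variable = 3, changes the value seen by every object that has not shadowed it.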
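
The attribute lookup order (object first, then class) can be seen in a small sketch. _Counter_ is a hypothetical class used only for this illustration:

```python
class Counter:  # hypothetical class, for illustration only
    shared = 1


c1 = Counter()
c2 = Counter()

# Assigning through an instance creates an instance variable on c1 only
c1.shared = 2
print(c1.shared, c2.shared, Counter.shared)   # 2 1 1

# Assigning through the class changes what every non-shadowing object sees
Counter.shared = 3
print(c1.shared, c2.shared, Counter.shared)   # 2 3 3

# Deleting the instance variable uncovers the class variable again
del c1.shared
print(c1.shared)                              # 3
```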

## Exercise 1

To find individual spikes in whole-cell recordings, we need to locate where in the signal the membrane potential exceeds some threshold value. By eye, 0 mV seems like a reasonable guess. Since our recordings are in a numpy array, we can use the following syntax:

In [None]:
threshold = 0
spike_indicator = (whole_cell>threshold)

print(spike_indicator)

Using this idea, create a class called _NeuronalRecording_ and define a method _spike_indicator_ that takes the membrane potential _vm_ and a new variable called _threshold_ as arguments and implements the following operation:

_spike_indicator_ = True, if vm >= threshold

_spike_indicator_ = False, if vm < threshold

In [None]:
#Answer

class NeuronalRecording:
    
    def __init__(self, vm, threshold):
        self.vm = vm
        self.threshold = threshold
        self.spikes = []
        
    def spike_indicator(self, vm):        
        return (vm>=self.threshold)
        
    def quantity(self):
        return 'Voltage (mV)'

N = NeuronalRecording(whole_cell, threshold)
N.spike_indicator(whole_cell)

To check our work so far, let's plot a masked version of the recording:

In [None]:
masked = np.ma.array(whole_cell, mask=~N.spike_indicator(whole_cell))
plt.plot(t_whole, masked)

### Bonus question

How does the first line in the cell above work? You may need to read the documentation of np.ma.array:

In [None]:
np.ma.array?

## Inheritance

In our previous examples, the _NeuronalRecording_ class was designed for current clamp experiments. It even has the physical quantity _Voltage (mV)_ embedded in it. 

In voltage clamp, the quantity measured is current and the units are given in nanoamperes (nA). But aside from that difference, we can still measure properties from action potentials in much the same way. We could define two classes: one for voltage clamp and another one for current clamp, both with the same set of methods but with slightly different "quantity" methods.

A better solution is to make the problem slightly more abstract. First we create a generic _NeuronalRecording_ class that doesn't specify the measurement type:

In [None]:
class NeuronalRecording:
    
    def __init__(self, trace, threshold):
        self.trace = trace
        self.threshold = threshold
        self.spikes = []
        
    def spike_indicator(self, trace):        
        return (trace>=self.threshold)
        
    def quantity(self):
        return ''

Note how the method _quantity_ is essentially a placeholder now.

We also renamed the variable _vm_ to _trace_ to indicate that it now represents a generic time series.

Now comes the trick: inheritance. Classes can inherit from other classes. Here's a simple example of what this looks like in practice:

In [None]:
class A:
    
    def a_method(self):
        print('a')
        
class B(A):
    
    def b_method(self):
        print('b')
    
object_b = B()

object_b.a_method()
object_b.b_method()

Class B takes A as an argument and that gives B access to everything that class A can do! A is the base class of B. In fact, remember our very first class, _MyFirstClass_? It turns out that not passing any argument in the definition is equivalent to:

In [None]:
class MyFirstClass(object):
    pass

_object_ in Python is a special class. Implicitly, every class is a subclass of _object_. And classes are also objects of a metaclass... This gets a bit confusing, so we are going to stop this discussion here. The main message is that classes can be nested like Russian dolls.
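
For the curious, both statements can be checked with built-in functions. This is just a quick sketch; nothing here is needed for the rest of the lecture:

```python
class MyFirstClass:
    pass


# Every class implicitly derives from object
print(issubclass(MyFirstClass, object))   # True

# The lookup order: MyFirstClass first, then object
print(MyFirstClass.__mro__)

# Classes are themselves objects, instances of the metaclass type
print(type(MyFirstClass) is type)         # True
```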

Another neat trick in Python is that subclasses can override methods from the base class without affecting the base class whatsoever:

In [None]:
class C(A):
    
    def a_method(self):
        print('c')
        
        
object_c = C()
object_c.a_method()

object_a = A()
object_a.a_method()

## Exercise 2

Use our abstracted class _NeuronalRecording_ as the base to define two new classes: one for cell-attached voltage clamp recordings, and another for whole-cell current clamp recordings. Name these classes VoltageClampRecording and CurrentClampRecording, respectively.

In [None]:
#Answer

class VoltageClampRecording(NeuronalRecording):
    
    def quantity(self):
        return 'Current (nA)'
    
    
class CurrentClampRecording(NeuronalRecording):
    
    def quantity(self):
        return 'Voltage (mV)'

Now you might be getting a sense of why classes are so helpful for organizing code. You don't need to define a method twice just because two different classes use it. The classes can simply share the same methods!

In [None]:
whole_threshold = 0
CC = CurrentClampRecording(whole_cell, whole_threshold)

print(CC.quantity())

In [None]:
attached_threshold = 0.2
VC = VoltageClampRecording(cell_attached, attached_threshold)

print(VC.quantity())

The following functions can often be very handy to figure out where an object came from:

In [None]:
print(type(VC))
isinstance(VC, VoltageClampRecording)
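
The two functions behave differently when inheritance is involved. Here is a minimal sketch with two made-up classes, _Base_ and _Child_, standing in for our recording classes:

```python
class Base:            # hypothetical stand-in for NeuronalRecording
    pass


class Child(Base):     # hypothetical stand-in for VoltageClampRecording
    pass


obj = Child()

# type() reports the exact class the object was created from
print(type(obj) is Child)      # True
print(type(obj) is Base)       # False

# isinstance() also accepts any base class
print(isinstance(obj, Child))  # True
print(isinstance(obj, Base))   # True
```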

Finally, if you ever need to interact with the parent class from a subclass, you can use the _super_ function. Called with no arguments, _super()_ returns a proxy object that forwards method calls to the immediate parent class:

In [None]:
class A:
    def __init__(self, a):
        self.a = a

class B(A):
    def __init__(self):
        super().__init__('super().__init__ refers to class A\'s __init__ method, so variable self.a will contain this string.')
        
b = B()
b.a
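
A common use of _super_ is extending \__init__ in a subclass instead of replacing it. The sketch below uses made-up classes (_Recording_, _LabeledRecording_) and a short list in place of real data:

```python
class Recording:  # hypothetical base class
    def __init__(self, trace):
        self.trace = trace


class LabeledRecording(Recording):  # hypothetical subclass
    def __init__(self, trace, label):
        # Let the base class store the trace, then add the new attribute
        super().__init__(trace)
        self.label = label


rec = LabeledRecording([0.1, 0.3, 0.2], label='cell 1')
print(rec.trace)   # [0.1, 0.3, 0.2]
print(rec.label)   # cell 1
```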

## Segmenting spikes

Before extracting features from our action potentials we first need to segment them.

Luckily, APs can be very stereotypical, which simplifies our analysis a lot! Our first task is finding the AP peak times in both traces.

Recall that our _spike_indicator_ method returns True when the membrane potential is at or above the threshold we established and False otherwise:

In [None]:
spks = CC.spike_indicator(whole_cell)
print(spks)

This means that at the points where the threshold is crossed, the value of _spks_ changes either from True to False or from False to True. We can exploit that to find the segments where the action potential peaks are located:

In [None]:
flips = (spks[1:] != spks[:-1])
segment_indices = np.where(flips)[0]
segment_starts = segment_indices[::2]
segment_ends = segment_indices[1::2]

Here's the logic behind the code above:
- mark as True the points where the array _spks_ flips from False to True or vice versa;
- find the indices of these points using np.where;
- assuming the trace starts below threshold, the even-numbered flips (0, 2, 4, ...) are upward crossings, which mark the start of each AP segment;
- the odd-numbered flips (1, 3, 5, ...) are the downward crossings that end each segment.
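
The steps above can be traced by hand on a tiny made-up array. The values and the threshold of 0 below are chosen only for illustration:

```python
import numpy as np

# Toy trace with two excursions above a threshold of 0
toy = np.array([-1, -1, 2, 3, -1, -1, 4, -1])
spks = toy >= 0

flips = spks[1:] != spks[:-1]
segment_indices = np.where(flips)[0]
segment_starts = segment_indices[::2]
segment_ends = segment_indices[1::2]

print(segment_indices)   # [1 3 5 6]
print(segment_starts)    # [1 5]
print(segment_ends)      # [3 6]

# flips[i] compares samples i and i+1, so each index is the sample just
# before the change; the supra-threshold samples are start+1 ... end
for start, end in zip(segment_starts, segment_ends):
    print(toy[start + 1:end + 1])   # [2 3], then [4]
```

Note that each index returned by np.where points at the sample just before the crossing. When an AP spans many samples, slicing with start:end still contains the peak, but for very short excursions like the second one in this toy trace the one-sample shift matters.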

## Exercise 3

For each segment we now find the index where the membrane potential is the highest. These points should correspond to the AP peaks. This will involve a for loop. 

Calculate the spike peak indices in the whole_cell array and save them in an array called _peaks_.

In [None]:
#Answer 

peaks = []
for start, end in zip(segment_starts, segment_ends):
    peak_idx = whole_cell[start:end].argmax()+start
    peaks.append(peak_idx)

As a sanity check, let's plot the whole-cell trace again as well as the peaks:

In [None]:
plt.plot(whole_cell)
plt.plot(peaks, whole_cell[peaks], 'or')

Finally, let's fix a range of points centered at the peak times and segment those intervals out of the whole_cell array. Instead of writing a new routine, let's rewrite the previous for loop. The size of each segment will be 2048 samples.

In [None]:
segment_size = 2048

aps = []
for start, end in zip(segment_starts, segment_ends):
    peak_idx = whole_cell[start:end].argmax()+start
    ap = whole_cell[peak_idx-segment_size//2:peak_idx+segment_size//2]
    aps.append(ap)
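
One caveat: a peak closer than half a segment to either end of the recording yields a window that is too short (or empty, since a negative start index wraps around), and averaging windows of different lengths fails. A possible guard is sketched below; _cut_windows_ is a hypothetical helper, not part of the lecture code, and it works on plain lists as well as numpy arrays:

```python
def cut_windows(trace, peak_indices, size):
    """Cut fixed-size windows centered on each peak, skipping incomplete ones.

    Hypothetical helper, not part of the lecture code.
    """
    half = size // 2
    windows = []
    for peak_idx in peak_indices:
        # A negative start would wrap around and a late end would be
        # truncated, so keep only windows that fit inside the trace
        if peak_idx - half < 0 or peak_idx + half > len(trace):
            continue
        windows.append(trace[peak_idx - half:peak_idx + half])
    return windows


trace = list(range(10))
windows = cut_windows(trace, peak_indices=[1, 5, 9], size=4)
print(windows)   # [[3, 4, 5, 6]]
```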

Averaging all the APs and finally plotting:

In [None]:
avg_ap = np.mean(aps, axis=0)
t_ap = np.arange(len(avg_ap))/float(sampling_rate)

for a in aps:
    plt.plot(a, 'r', alpha=0.1)
plt.plot(avg_ap, 'k')

Just like we previously claimed, these APs are highly stereotyped.

## Putting it all together

Phew, that was hard work. But now we know we can package all this code in a class and reuse it for the cell-attached recording.

Take some time reading through the logic behind it. There are several modifications from the previous version.

In [None]:
class NeuronalRecording:

    segment_size = 2048
    sampling_rate = 71400
    
    def __init__(self, trace, threshold):
        self.trace = trace
        self.threshold = threshold
        self.spikes = self.segment(trace)

    def segment(self, trace):
        spks = self.spike_indicator(trace)
        
        flips = (spks[1:] != spks[:-1])
        segment_indices = np.where(flips)[0]
        segment_starts = segment_indices[::2]
        segment_ends = segment_indices[1::2]

        aps = []
        for start, end in zip(segment_starts, segment_ends):
            peak_idx = trace[start:end].argmax()+start
            ap = trace[peak_idx-self.segment_size//2:peak_idx+self.segment_size//2]
            aps.append(ap)
        return aps
        
    def spike_indicator(self, trace):        
        return (trace>=self.threshold)

    def independent_quantity(self):
        return 'Time (s)'
        
    def quantity(self):
        return ''

We can even go a step further and add a plotting method just to make our lives easier:

In [None]:
class NeuronalRecording:

    segment_size = 2048
    sampling_rate = 71400
    
    def __init__(self, trace, threshold):
        self.trace = trace
        self.threshold = threshold
        self.spikes = self.segment(trace)

    def segment(self, trace):
        spks = self.spike_indicator(trace)
        
        flips = (spks[1:] != spks[:-1])
        segment_indices = np.where(flips)[0]
        segment_starts = segment_indices[::2]
        segment_ends = segment_indices[1::2]

        aps = []
        for start, end in zip(segment_starts, segment_ends):
            peak_idx = trace[start:end].argmax()+start
            ap = trace[peak_idx-self.segment_size//2:peak_idx+self.segment_size//2]
            aps.append(ap)
        return aps
        
    def spike_indicator(self, trace):        
        return (trace>=self.threshold)

    def independent_quantity(self):
        return 'Time (s)'
        
    def quantity(self):
        return ''

    def plot(self):
        t_cc = np.arange(self.segment_size)/float(self.sampling_rate)
        plt.plot(t_cc, np.mean(self.spikes, axis=0))
        plt.xlabel(self.independent_quantity())
        plt.ylabel(self.quantity())

Of course, we also need to redefine our subclasses so that they inherit from the new version of the base class.

In [None]:
class VoltageClampRecording(NeuronalRecording):

    def quantity(self):
        return 'Current (nA)'
    

class CurrentClampRecording(NeuronalRecording):
    
    def quantity(self):
        return 'Voltage (mV)'

From our plots of the voltage clamp recordings, 0.2 nA seems like a reasonable threshold.

In [None]:
vc_threshold = 0.2
VC = VoltageClampRecording(cell_attached, vc_threshold)

cc_threshold = 0.0
CC = CurrentClampRecording(whole_cell, cc_threshold)

In [None]:
CC.plot()

In [None]:
VC.plot()

## Measuring features

For the grand finale, we will calculate some AP properties using whole-cell recordings:
- AP peak;
- Maximum depolarization and repolarization rates;
- AP threshold;
- AP half-width.

These are important quantities because they summarize the AP waveform compactly and often tell us the neuron's identity. 

We will describe each property first and at the end put it all together in our classes.

In [None]:
spikes = np.array(CC.spikes)

### Peak values

The AP peak is simply the maximum voltage reached by the AP. That's simple, we can use numpy's max function:

In [None]:
peaks = np.max(spikes, axis=1)
print('AP peak = %f +- %f mV'%(peaks.mean(), peaks.std()))

### Depolarization and repolarization rates

Both quantities refer to the rate of change of voltage in time. We can calculate it using the slope between each pair of points. 

_Warning: In practice, never calculate the slope in this way before filtering your data!_

In [None]:
slope = (spikes[:, 1:]-spikes[:, :-1])*sampling_rate/1000.

max_depols = np.max(slope, axis=1)
max_repols = np.min(slope, axis=1)

print('Max. depol. rate = %f +- %f V/s'%(max_depols.mean(), max_depols.std()))
print('Max. repol. rate = %f +- %f V/s'%(max_repols.mean(), max_repols.std()))
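
The unit conversion is easy to check on a made-up ramp. In the sketch below, the sampling rate of 1000 Hz and the voltage values are assumptions chosen so the answer is obvious: 1 mV per sample at 1000 samples per second is 1000 mV/s, or 1 V/s.

```python
import numpy as np

fs = 1000.0                               # hypothetical sampling rate in Hz
ramp = np.array([0.0, 1.0, 2.0, 3.0])     # voltage in mV, rising 1 mV per sample

# mV per sample times samples per second gives mV/s; dividing by 1000 gives V/s
slope_manual = (ramp[1:] - ramp[:-1]) * fs / 1000.
slope_diff = np.diff(ramp) * fs / 1000.   # np.diff computes the same differences

print(slope_manual)                            # [1. 1. 1.]
print(np.allclose(slope_manual, slope_diff))   # True
```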

### Bonus question

Compare the average slope of the action potential with the cell attached recording in the figures below. 

In [None]:
avg_slope = -np.mean(slope, axis=0)
t_slope = np.arange(len(avg_slope))/float(sampling_rate)

plt.plot(t_slope, avg_slope)
plt.ylabel('-dV/dt (V/s)')
plt.xlabel('Time (s)')

In [None]:
VC.plot()

Why are the two so similar? Is this a coincidence?

### AP Threshold

Zemel et al. (2023) define the AP threshold as the voltage at which the depolarization rate reaches 10 V/s. Instead of finding where the slope hits exactly 10 V/s, we will find where it is the _closest_ to 10 V/s. The trick is to find where the absolute value of the slope minus 10 V/s is at its minimum!

In [None]:
idxs = np.argmin(np.abs(slope-10), axis=1)

ths = []
for ap, idx in zip(spikes, idxs):
    ths.append(ap[idx])
ths = np.array(ths)

print('Threshold = %f +- %f mV'%(ths.mean(), ths.std()))
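
The nearest-value trick is easier to see on a handful of made-up numbers. The second half of the sketch illustrates a caveat: on a full waveform the slope can pass near the target more than once, so one option (an assumption here, not something taken from the lecture or the paper) is to search only up to the maximum slope.

```python
import numpy as np

target = 10.0
rates = np.array([0.5, 2.0, 7.5, 11.0, 40.0, 90.0])   # made-up slopes in V/s

idx = np.argmin(np.abs(rates - target))
print(idx, rates[idx])   # 3 11.0

# Made-up slope that rises through the target and falls back through it
wave = np.array([0.5, 2.0, 11.0, 40.0, 90.0, 30.0, 10.0, -50.0])
idx_all = np.argmin(np.abs(wave - target))
print(idx_all)           # 6, on the falling side

# Restrict the search to samples up to the maximum slope
rising = wave[:np.argmax(wave) + 1]
idx_rising = np.argmin(np.abs(rising - target))
print(idx_rising)        # 2, on the rising side
```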

### Half-width

The AP half-width is the width of the AP at the voltage midway between the threshold and the peak. First we need to find the points where this condition is met. The mid voltage is the average of the peak and the threshold, that is, their sum halved:

In [None]:
mid_volt = (peaks+ths)*0.5
mid_volt = mid_volt[:, np.newaxis]

idxs_depol = np.argmin(np.abs(spikes[:, :segment_size//2]-mid_volt), axis=1)
idxs_repol = np.argmin(np.abs(spikes[:, segment_size//2:]-mid_volt), axis=1)+segment_size//2

hws = (idxs_repol-idxs_depol)/float(sampling_rate)

print('HW = %f +- %f s'%(hws.mean(), hws.std()))
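
Here is the same calculation on a single made-up waveform, small enough to check by hand. The sampling rate, the voltage values and the threshold of 0 are assumptions for illustration only:

```python
import numpy as np

fs = 1000.0                                   # hypothetical sampling rate in Hz
ap = np.array([0., 1., 3., 5., 8., 10., 8., 5., 3., 1., 0.])   # made-up waveform
threshold_v = 0.0                             # assumed threshold voltage
peak_idx = ap.argmax()

mid_volt = 0.5 * (ap.max() + threshold_v)     # 5.0

# Nearest sample to mid_volt on each side of the peak
idx_depol = np.argmin(np.abs(ap[:peak_idx] - mid_volt))
idx_repol = np.argmin(np.abs(ap[peak_idx:] - mid_volt)) + peak_idx

half_width = (idx_repol - idx_depol) / fs
print(idx_depol, idx_repol, half_width)       # 3 7 0.004
```

The waveform crosses 5 at samples 3 and 7, so the half-width is 4 samples, or 4 ms at this toy sampling rate.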

## Final Exercise

Put everything together in the definition of the class CurrentClampRecording (the features above were computed from the whole-cell recording) and write a method that returns a summary of all AP properties.


## Stuff we skipped over

If you feel like learning more, here's a list of some of the topics we skipped over but which can be useful in practice:

- Some useful decorators, such as @property, @classmethod, and @staticmethod;
- Multiple inheritance;
- Other useful magic methods, like \__repr__, \__