# Lab 3: Calcium demixing and deconvolution

**STATS320: Machine Learning Methods for Neural Data Analysis**

_Stanford University. Winter, 2021._

---

**Team Name:** _Your team name here_

**Team Members:** _Names of everyone on your team here_

*Due: 11:59pm Thursday, Feb 4, 2021 via GradeScope (see below)*

---

In this lab you'll write your own code for demixing and deconvolving calcium imaging videos. Demixing refers to the problem of identifying potentially overlapping neurons in the video and separating their fluorescence traces. Deconvolving refers to taking those traces and finding the times of spiking activity, which produce exponentially decaying transients in fluorescence. We'll frame it as a constrained and (partially) non-negative matrix factorization problem, inspired by the CNMF model of Pnevmatikakis et al, 2016, which is implemented in [CaImAn](https://github.com/flatironinstitute/CaImAn) (Giovannucci et al, 2019). More details and further references are in the course notes. We'll use [CVXpy](https://www.cvxpy.org/) to solve the convex optimization problems at the heart of this approach.

**References**
- Pnevmatikakis, Eftychios A., Daniel Soudry, Yuanjun Gao, Timothy A. Machado, Josh Merel, David Pfau, Thomas Reardon, et al. 2016. “Simultaneous Denoising, Deconvolution, and Demixing of Calcium Imaging Data.” Neuron 89 (2): 285–99.
[link](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4881387/)
- Giovannucci, Andrea, Johannes Friedrich, Pat Gunn, Jérémie Kalfon, Brandon L. Brown, Sue Ann Koay, Jiannis Taxidis, et al. 2019. “CaImAn an Open Source Tool for Scalable Calcium Imaging Data Analysis.” eLife. [link](http://dx.doi.org/10.7554/eLife.38173)





# Environment Setup

In [None]:
import numpy as np
import scipy.sparse
from scipy.signal import butter, sosfilt
from scipy.stats import norm
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max

# we'll use CVXpy to solve convex optimization problems
import cvxpy as cvx

# plotting stuff
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.patches import Circle
import seaborn as sns

# helpers
from tqdm.auto import trange
from copy import deepcopy
import warnings


## Download example data

This demo data was contributed by Sue Ann Koay and David Tank (Princeton University). 
It is also used in the CaImAn demo notebook. 
We used CaImAn and NoRMCorre to correct for motion artifacts.

In [None]:
%%capture
! wget -nc https://www.dropbox.com/s/3gl299gbw1mavcl/data.npy

# Load the data and permute it so that time is the last axis,
# as in the notes. 
data = np.load("data.npy").transpose(1, 2, 0)
height, width, num_frames = data.shape

# Set some constants 
FPS = 30                        # frames per second in the movie
NEURON_WIDTH = 10               # approximate width (in pixels) of a neuron
GCAMP_TIME_CONST_SEC = 0.300    # reasonable guess for calcium decay time const.

In [None]:
#@title Helper functions for movies and plotting { display-mode: "form" }
from matplotlib import animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64

# Set some plotting defaults
sns.set_context("talk")

# initialize a color palette for plotting
palette = sns.xkcd_palette(["windows blue",
                            "red",
                            "medium green",
                            "dusty purple",
                            "orange",
                            "amber",
                            "clay",
                            "pink",
                            "greyish"])

_VIDEO_TAG = """<video controls>
 <source src="data:video/x-m4v;base64,{0}" type="video/mp4">
 Your browser does not support the video tag.
</video>"""

def _anim_to_html(anim, fps=20):
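    # Encode the animation as an H.264 video (requires ffmpeg) and cache the
    # base64-encoded bytes on the animation object so it is only rendered once.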
    if not hasattr(anim, '_encoded_video'):
        with NamedTemporaryFile(suffix='.mp4') as f:
            anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264'])
            video = open(f.name, "rb").read()
        anim._encoded_video = base64.b64encode(video)

    return _VIDEO_TAG.format(anim._encoded_video.decode('ascii'))

def _display_animation(anim, fps=30, start=0, stop=None):
    plt.close(anim._fig)
    return HTML(_anim_to_html(anim, fps=fps))

def play(movie, fps=FPS, speedup=1, fig_height=6):
    # First set up the figure, the axis, and the plot element we want to animate
    Py, Px, T = movie.shape
    fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px/Py, fig_height))
    im = plt.imshow(movie[..., 0], interpolation='None', cmap=plt.cm.gray)
    tx = plt.text(0.75, 0.05, 't={:.3f}s'.format(0), 
                  color='white',
                  fontdict=dict(size=12),
                  horizontalalignment='left',
                  verticalalignment='center', 
                  transform=ax.transAxes)
    plt.axis('off')

    def animate(i):
        im.set_data(movie[..., i * speedup])
        tx.set_text("t={:.3f}s".format(i * speedup / fps))
        return im, 

    # call the animator.  blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, 
                                   frames=T // speedup, 
                                   interval=1, 
                                   blit=True)
    plt.close(anim._fig)

    # return an HTML video snippet
    print("Preparing animation. This may take a minute...")
    return HTML(_anim_to_html(anim, fps=fps))

def plot_problem_1d(local_correlations, filtered_correlations, peaks):
    def _plot_panel(ax, im, title):
        h = ax.imshow(im, cmap="Greys_r")
        ax.set_title(title)
        ax.set_xlim(0, width)
        ax.set_ylim(height, 0)
        ax.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

    fig, axs = plt.subplots(1, 3, figsize=(15, 6))
    _plot_panel(axs[0], local_correlations, "local correlations")
    _plot_panel(axs[1], filtered_correlations, "filtered correlations")
    _plot_panel(axs[2], local_correlations, "candidate neurons")

    # Draw circles around the peaks
    for n, yx in enumerate(peaks):
        y, x = yx
        axs[2].add_patch(Circle((x, y), 
                                radius=NEURON_WIDTH/2, 
                                facecolor='none', 
                                edgecolor='red', 
                                linewidth=1))
        
        axs[2].text(x, y, "{}".format(n),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontdict=dict(size=10, weight="bold"),
                    color='r')

def plot_problem_2(traces, denoised_traces, amplitudes):
    num_neurons, num_frames = traces.shape

    # Plot the traces and our denoised estimates
    scale = np.percentile(traces, 99.5, axis=1, keepdims=True)
    offset = -np.arange(num_neurons)

    # Plot points at the time frames where the (normalized) amplitudes are > 0.05
    sparse_amplitudes = amplitudes.copy() / scale
    sparse_amplitudes = np.isclose(sparse_amplitudes, 0, atol=0.05).astype(float)
    sparse_amplitudes[sparse_amplitudes == 1] = np.nan

    plt.figure(figsize=(12, 8))
    plt.plot((traces / scale).T + offset , color=palette[0], lw=1, alpha=0.5)
    plt.plot((denoised_traces / scale).T + offset, color=palette[0], lw=2)
    plt.plot((sparse_amplitudes).T + offset, color=palette[1], marker='o', markersize=2)
    plt.xlabel("time (frames)")
    plt.xlim(0, num_frames)
    plt.ylabel("neuron")
    plt.yticks(-np.arange(num_neurons, step=5), labels=np.arange(num_neurons, step=5))
    plt.ylim(-num_neurons, 2)
    plt.title("raw and denoised fluorescence traces")


def plot_problem_3(flat_data, params, latents, hypers, plot_bkgd=True, indices=None):
    U = params["footprints"].reshape(-1, height, width)
    u0 = params["bkgd_footprint"].reshape(height, width)
    C = latents["traces"]
    c0 = latents["bkgd_trace"]
    N, T = C.shape

    if indices is None: indices = np.arange(N)
    

    def _plot_factor(footprint, trace, title):
        fig, ax1 = plt.subplots(1, 1, figsize=(12, 6))
        vlim = abs(footprint).max()
        h = ax1.imshow(footprint, vmin=-vlim, vmax=vlim, cmap="RdBu")
        ax1.set_title(title)
        ax1.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax1)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

        ax2 = divider.append_axes("right", size="150%", pad="75%")
        ts = np.arange(T) / FPS
        ax2.plot(ts, trace, color=palette[0], lw=2)
        ax2.set_xlabel("time (sec)")
        ax2.set_ylabel("fluorescence trace")
        
    if plot_bkgd:
        _plot_factor(u0, c0, "background")

    for n in indices:            
        _plot_factor(U[n], C[n], "neuron {}".format(n))
        

## Movie of the data

It takes a minute to render the animation...

In [None]:
# Play the motion corrected movie.
play(data, speedup=5)

# Part 1: Initialization

## Problem 1a: Estimate the noise at each pixel and standardize

We'll use a simple heuristic to estimate the noise. With slow calcium responses, most of the high frequency content (e.g. above 8Hz) should be noise.  Since white Gaussian noise (i.e. noise that is independent across time) has a flat spectrum (we didn't prove this but it's a useful fact to know!), the standard deviation of the high frequency signal should tell us the noise at lower frequencies as well. 

In this problem, use `butter` and `sosfilt` to high-pass filter the data at 8Hz with a 10th-order Butterworth filter. (Recall Lab 1.) Then use `np.std` with the `axis` keyword argument to get the standard deviation over time for each pixel.

Finally, standardize the data by dividing each pixel by its standard deviation. 
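
As a rough illustration of this heuristic on synthetic data (not the lab data), the sketch below high-pass filters a 1D signal made of a slow 0.5 Hz sinusoid plus white Gaussian noise. The sampling rate `fs`, the signal, and all variable names are assumptions made for the example. The high-passed signal only retains the noise power above the cutoff, so its standard deviation comes out somewhat below the true noise level, but the slow signal no longer contributes to it.

```python
import numpy as np
from scipy.signal import butter, sosfilt

rng = np.random.default_rng(0)
fs = 30.0                        # sampling rate in Hz (assumed for this toy example)
t = np.arange(20000) / fs
true_sigma = 2.0

# Slow "signal" plus white Gaussian noise
slow_signal = 10 * np.sin(2 * np.pi * 0.5 * t)
x = slow_signal + true_sigma * rng.standard_normal(t.size)

# 10th-order Butterworth high-pass filter with an 8 Hz cutoff
sos = butter(10, 8.0, btype="highpass", fs=fs, output="sos")
high_freq = sosfilt(sos, x)

sigma_hat = high_freq.std()
print("true sigma: {:.2f}, high-pass estimate: {:.2f}".format(true_sigma, sigma_hat))
```

For the movie, the same filter would be applied along the time axis of the 3D array, which `sosfilt` supports through its `axis` argument.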

In [None]:
###
# High-pass filter the data at 8Hz using a Butterworth filter.
# That should filter out the calcium transients and give a 
# reasonable estimate of the noise. 
# 
# YOUR CODE BELOW
sos = butter(..., output='sos')
noise = sosfilt(...)
sigmas = ...
assert sigmas.shape == (height, width)
#
###

# Plot the noise standard deviation for each pixel
plt.imshow(sigmas, vmin=0)
plt.axis("off")
plt.title("Estimated noise per pixel")
plt.colorbar(label="noise std. deviation")

# Standardize the data by dividing each frame by the standard deviation
std_data = data / sigmas[:, :, None]

# Check that we got the same answer
assert np.allclose(sigmas.mean(), 23.4768, atol=1e-4)

## Problem 1b: Find peaks in the local correlation matrix 


**Step 1**
To find candidate neurons, look for places in the image where nearby pixels are highly correlated with one another. 

The correlation between pixels $(i,j)$ and $(k,\ell)$ is
\begin{align}
\rho_{ijk\ell} = \frac{1}{T} \sum_{t=1}^T z_{ijt} z_{k\ell t},
\end{align}
where 
\begin{align}
z_{ijt} = \frac{y_{ijt} - \bar{y}_{ij}}{\sigma_{ij}}
\end{align}
denotes the z-scored data. Here $y_{ijt}$ is the fluorescence at pixel $(i,j)$ and frame $t$, $\bar{y}_{ij}$ is the average fluorescence at that pixel over time, and $\sigma_{ij}$ is the standard deviation of fluorescence in that pixel. You've already computed the noise level $\sigma_{ij}$ for each pixel, and you computed $y_{ijt} / \sigma_{ij}$ in Problem 1a. To compute $z$, simply subtract the mean of the standardized data. 

Now define the local correlation at pixel $(i,j)$ to be the average correlation with its neighbors to the north, south, east, and west:
\begin{align}
\bar{\rho}_{ij} = \tfrac{1}{4} \left(\rho_{ij,i-1,j} +  \rho_{ij,i+1,j} + \rho_{ij,i,j+1} + \rho_{ij,i,j-1}\right).
\end{align}
If $(i,j)$ is a border cell, assume the correlation with out-of-bounds neighbors is zero.

**Step 2**
Use the `gaussian_filter` function with a standard deviation `sigma=NEURON_WIDTH/4` to smooth the local correlations.

**Step 3**
Find peaks in the smoothed local correlations using `peak_local_max`,
which we imported from the `skimage.feature` package. Set a `min_distance` 
of 2 and play with the `threshold_abs` to get 30 neurons, which we think is a reasonable estimate.
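
The local correlation from Step 1 can be sketched on a toy movie. This is a minimal illustration under stated assumptions: the movie is synthetic (independent noise plus one shared signal in a 4x4 block), the names `movie`, `rho_vert`, and `rho_horz` are made up for the example, and the z-score here uses each pixel's total standard deviation rather than the noise estimate from Problem 1a.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, T = 12, 12, 2000            # toy movie size (assumed)

# Independent unit-variance noise everywhere, plus a shared signal in a 4x4 block
movie = rng.standard_normal((H, W, T))
shared = 3.0 * rng.standard_normal(T)
movie[4:8, 4:8, :] += shared

# z-score each pixel over time
z = (movie - movie.mean(axis=2, keepdims=True)) / movie.std(axis=2, keepdims=True)

# Correlations between vertically and horizontally adjacent pixels
rho_vert = np.mean(z[1:, :, :] * z[:-1, :, :], axis=2)    # shape (H-1, W)
rho_horz = np.mean(z[:, 1:, :] * z[:, :-1, :], axis=2)    # shape (H, W-1)

# Average over the four neighbors; out-of-bounds neighbors contribute zero
local_corr = np.zeros((H, W))
local_corr[1:, :] += rho_vert     # neighbor to the north
local_corr[:-1, :] += rho_vert    # neighbor to the south
local_corr[:, 1:] += rho_horz     # neighbor to the west
local_corr[:, :-1] += rho_horz    # neighbor to the east
local_corr /= 4

print(local_corr[5:7, 5:7].round(2))   # inside the block: high correlation
print(local_corr[:2, :2].round(2))     # far from the block: near zero
```

Because each pairwise correlation is symmetric, one array of vertical correlations serves both the north and south terms, and likewise for east and west.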

In [None]:
###
# Compute the zscored data by subtracting the mean of 
# the standardized data.
#
# YOUR CODE BELOW

# First z-score the data
zscored_data = ...

# Then compute the local correlation by summing correlations with 
# neighboring pixels
local_correlations = np.zeros((height, width))
local_correlations[1:,  :] += ... # N
local_correlations[:-1, :] += ... # S
local_correlations[:, :-1] += ... # E
local_correlations[:,  1:] += ... # W
local_correlations /= 4

# Smooth the local correlations with a Gaussian filter of width 1/4
# the width of a typical neuron. 
filtered_correlations = gaussian_filter(...)
        
# Finally, find peaks in the smoothed local correlations using 
# `peak_local_max`. Set a `min_distance` of 2 and play with the 
# `threshold_abs` to get 30 neurons, which we think is a reasonable estimate.
min_distance = 2
threshold_abs = ...
peaks = peak_local_max(...)

num_neurons = len(peaks)
print("Found", num_neurons, "candidate neurons")
#
###

plot_problem_1d(local_correlations, filtered_correlations, peaks)
assert num_neurons == 30

## Problem 1c [Short Answer]: Explain this heuristic

Why are peaks in the local correlations indicative of neurons? Why did you filter the correlations? What would happen if you didn't use the Gaussian filter, or you used a Gaussian filter of a larger width? 

_Answer below this line_

---


## Problem 1d: Initialize the footprints

Initialize the footprints as
\begin{align}
u_{nij} \propto \mathcal{N}\left(\begin{bmatrix}i \\ j \end{bmatrix} \,\bigg|\, \begin{bmatrix} \mu_{n,i} \\ \mu_{n,j} \end{bmatrix}, \frac{w^2}{4^2} I \right)
\end{align}
where $\mu_{n} \in \mathbb{R}^2$ is the location of the peak for neuron $n$ and $w$ is the width of a typical neuron. 

There's a simple trick to compute the footprints: convolve a Gaussian filter with a matrix that is zeros everywhere except for a one at the location of the peak. The `gaussian_filter` function with `sigma` set to 1/4 the neuron width will do this for you.

Finally, normalize the footprints so that $\|u_n\|=1$.
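
The one-hot trick can be sketched for a single hypothetical peak. The image size, peak location, and variable names below are assumptions chosen for illustration; they are not taken from the lab data.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

H, W = 20, 30             # toy image size (assumed)
peak = (8, 12)            # (row, column) of a hypothetical peak
neuron_width = 10         # plays the role of NEURON_WIDTH

# Image that is zero everywhere except for a one at the peak
one_hot = np.zeros((H, W))
one_hot[peak] = 1.0

# Smoothing the one-hot image yields a Gaussian bump centered on the peak
footprint = gaussian_filter(one_hot, sigma=neuron_width / 4)

# Normalize to unit Euclidean norm
footprint /= np.linalg.norm(footprint)
```

The bump's maximum stays at the peak location, and the final division enforces $\|u_n\| = 1$.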

In [None]:
###
# Initialize the spatial footprints for each neuron.
# For each neuron, apply a Gaussian filter to a one-hot 
# matrix with a one in the peak position for that neuron.
# 
# YOUR CODE BELOW
footprints = np.zeros((num_neurons, height, width))
for neuron in range(num_neurons):
    ...
    footprints[neuron] = gaussian_filter(...)
    
# Scale the footprints to be unit norm
footprints /= ...

#
###

# Check that they're unit norm
assert np.allclose(np.linalg.norm(footprints, axis=(1,2)), 1.0)

# Plot the superimposed footprints
plt.imshow(footprints.sum(axis=0), cmap="Greys_r")
plt.axis("off")
plt.title("superimposed footprints")
_ = plt.colorbar()

## Problem 1e: Initialize the background
Set the spatial background factor $u_0$ equal to the **median of the standardized data** and set the temporal background factor to $c_{0} = 1_T$. The median should be more robust to the large spikes than the mean is. Then normalize by dividing $u_0$ by its norm $\|u_0\|_2$ and multiplying $c_0$ by $\|u_0\|_2$. 
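
Two points in this paragraph can be checked on toy data: the median is less affected by rare large transients than the mean, and moving the norm of $u_0$ into $c_0$ leaves the product $u_0 c_0^\top$ unchanged. The sketch below uses a synthetic flattened movie; the sizes, the baseline, and the transient pattern are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
P, T = 50, 400                                # toy sizes (assumed)
baseline = rng.uniform(1.0, 2.0, size=P)      # hypothetical per-pixel background
Y = baseline[:, None] + 0.1 * rng.standard_normal((P, T))
Y[:, ::40] += 20.0                            # occasional large transients

# The median stays near the baseline; the mean is pulled up by the transients
median_err = np.abs(np.median(Y, axis=1) - baseline).max()
mean_err = np.abs(Y.mean(axis=1) - baseline).max()
print("max error of median: {:.3f}, of mean: {:.3f}".format(median_err, mean_err))

# Rescaling moves the norm of u0 into c0 without changing the product
u0 = np.median(Y, axis=1)
c0 = np.ones(T)
scale = np.linalg.norm(u0)
u0_unit, c0_scaled = u0 / scale, c0 * scale
same_product = np.allclose(np.outer(u0, c0), np.outer(u0_unit, c0_scaled))
```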

In [None]:
###
# Initialize the background footprint and trace
bkgd_trace = np.ones(num_frames)
bkgd_footprint = ...

# rescale so that the spatial background has norm 1
scale = ...
bkgd_footprint /= scale
bkgd_trace *= scale
#
###

# Plot the background factor
plt.imshow(bkgd_footprint)
plt.axis("off")
plt.title("background footprint $u_0$")
plt.colorbar()

assert np.isclose(bkgd_footprint.mean(), 0.0056, atol=1e-4)

## Initialize the traces

We'll initialize the traces for Part 2 by computing the residual, projecting it onto each footprint in order, and updating the residual by subtracting off each neuron's contribution.

If we've done a good job initializing, the traces should show clear spikes and the noise should be roughly in the range $[-3, +3]$ since the data is standardized to have standard deviation 1.

In [None]:
residual = std_data - np.einsum('ij,t->ijt', bkgd_footprint, bkgd_trace)
traces = np.zeros((num_neurons, num_frames))
for n in trange(num_neurons):
    traces[n] = np.einsum('ij,ijt->t', footprints[n], residual)
    residual -= np.einsum('ij,t->ijt', footprints[n], traces[n])

# Plot trace for a single neuron
n = 16
plt.plot(traces[n], label="trace")
plt.hlines([-3, 3], 0, num_frames, 
        colors='r', linestyles=':', zorder=10, 
        label="noise level")
plt.legend(loc="upper left")
plt.xlim(0, num_frames)
plt.xlabel("time (frames)")
plt.ylabel("fluorescence")
plt.title("neuron {}".format(n))

# check that we got the same answer using the parameters from parts 1a-1e.
assert np.isclose(traces[16].mean(), 2.3430)

# Part 2: Deconvolving spikes from calcium traces

In this part you'll use [CVXpy](https://www.cvxpy.org/) to deconvolve the calcium traces by solving a convex optimization problem. CVX is a "Python-embedded modeling language for convex optimization problems," as the website says.  It provides an easy-to-use interface for translating convex optimization problems into code and easy access to a variety of underlying solvers. The key objects are:
- `cvx.Variable` objects, which specify the variables you wish to optimize with respect to,
- `cvx.Minimize` objects, which let you specify the objective you wish to minimize,
- `cvx.Problem` objects, which combine an objective and a set of constraints. 

CVX also has lots of helper functions like
- `cvx.sum_squares`, which computes the sum of squares of an array, and
- `cvx.norm`, which computes norms of the specified order.

The following example is modified from the CVXpy homepage, linked above. It solves a least-squares problem with box constraints and compares the constrained and unconstrained solutions.

In [None]:
# A simple CVX example...

# Problem data.
np.random.seed(1)
A = np.random.randn(30, 20)
b = np.random.randn(30)

# Construct the problem.
x = cvx.Variable(20)
objective = cvx.Minimize(cvx.sum_squares(A @ x - b))
constraints = [0 <= x, x <= 1]
prob = cvx.Problem(objective, constraints)

# The optimal objective value is returned by `prob.solve()`.
# The optimal value for x is stored in `x.value`.
result = prob.solve(verbose=False)

# Plot the constrained optimum vs the unconstrained.
plt.fill_between([0, 19], 0, 1, color='k', alpha=0.1, hatch='x', label="constraint set")
plt.plot(x.value, '-o', label=r"$0 \leq x \leq 1$")
plt.plot(np.linalg.lstsq(A, b, rcond=None)[0], '-', marker='.', label="unconstrained")
plt.xlim(0, 19)
plt.ylim(-1, 1.0)
plt.xlabel("$n$")
plt.ylabel(r"$x_n^\star$")
plt.legend(loc="lower right", fontsize=10)

## Problem 2a: Solve the convex optimization problem in dual form with CVX

In this part of the lab you'll use CVXpy to maximize the log joint probability in its **dual form:**

\begin{align}
    \hat{c}_n, \hat{b}_n = \text{arg min}_{c_n, b_n} \; \|G c_n\|_1 
    \quad \text{subject to } \quad 
    \|\mu_n - c_n - b_n\|_2^2 &\leq \theta^2, \; G c_n \geq 0,
\end{align}

where $\mu_n = u_n^\top R_n \in \mathbb{R}^T$ is the target for neuron $n$ and 

\begin{align}
    G &= 
    \begin{bmatrix}
    1             &               &        &        \\
    -e^{-1/\tau} & 1             &        &        \\
    0             & -e^{-1/\tau} & 1      &        \\
                  & 0             & \ddots & \ddots \\
    \end{bmatrix}
\end{align}
is the first order difference matrix. The spike amplitudes (i.e. jumps in the fluorescence) are given by $a_n = G c_n$, so you can think about the optimization problem as minimizing the $L^1$ norm of the jumps subject to a non-negativity constraint and an upper bound on the $L^2$ norm of the difference between the target $\mu_n$ and the trace $c_n$. 

**Note** that this is a slight modification of the problem presented in class:
1. Here we've added a bias term $b_n$, which will be helpful in cases where the target has a nonzero baseline.  Accounting for this possibility will lead to more robust estimates of the calcium traces.
2. In class we presented the constraint $\|\mu_n - c_n - b_n\|_2 \leq \theta$. CVXpy does a much better job at solving these "second order cone programs," so in practice that's what you should do! For this problem, however, you'll square both sides, as written in the constraint above. Squaring doesn't change the constraint set, but it will make it easier to compare to the "primal" form you'll work with in Problems 2c and 2d. 
3. There was a slight typo in my notes and slides that had the constraint $c_n \geq 0$ rather than $G c_n \geq 0$.  The former allows for positive and negative jumps (it just penalizes their absolute value), whereas the latter only allows positive jumps.

We argued that a reasonable guess for the norm threshold is $\theta = (1+\epsilon) \sigma \sqrt{T}$.  For large $T$ and good estimates of the target, we should be able to set $\epsilon$ pretty small.  Here, we'll use a fairly liberal upper bound and set $\epsilon = 1$ since we're working with a short dataset and a poor initial guess.
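
The reasoning behind this threshold is that a vector of $T$ independent Gaussian noise samples with standard deviation $\sigma$ has Euclidean norm close to $\sigma \sqrt{T}$. A quick numerical check on synthetic noise (the values of `T` and `sigma` are assumptions for the example):

```python
import numpy as np

rng = np.random.default_rng(0)
T, sigma = 2000, 1.0                # toy trace length and noise level (assumed)
noise = sigma * rng.standard_normal(T)

noise_norm = np.linalg.norm(noise)
ratio = noise_norm / (sigma * np.sqrt(T))
print("||noise|| / (sigma * sqrt(T)) = {:.3f}".format(ratio))   # close to 1 for large T
```

If the target were exactly the true trace plus noise, $\epsilon$ only needs to absorb these small fluctuations; a poor estimate of the target leaves extra residual, which is why a larger $\epsilon$ is used here.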

One of the great things about CVXpy is that it **works with SciPy's sparse matrices.** For example, you can use `scipy.sparse.diags` to construct the $G$ matrix. Under the hood, the solver will leverage the sparsity to run in linear time.
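
To see what $G$ does, here is a small sketch that builds it with `scipy.sparse.diags` for a short toy trace and checks that $G c$ recovers the jumps. The length `T`, the decay constant `tau` (in frames), and the hand-picked `spikes` vector are assumptions for illustration only.

```python
import numpy as np
import scipy.sparse

T, tau = 6, 9.0                     # toy length and decay constant in frames (assumed)
gamma = np.exp(-1.0 / tau)

# 1 on the diagonal, -exp(-1/tau) on the first lower diagonal
G = scipy.sparse.diags([np.ones(T), -gamma * np.ones(T - 1)],
                       offsets=[0, -1], format="csr")

# Simulate a trace with two jumps that decay exponentially in between
spikes = np.array([0.0, 2.0, 0.0, 0.0, 1.0, 0.0])
c = np.zeros(T)
for t in range(T):
    previous = c[t - 1] if t > 0 else 0.0
    c[t] = gamma * previous + spikes[t]

# Applying G undoes the exponential decay and returns the jumps
amplitudes = G @ c
print(np.round(amplitudes, 6))
```

Since $G$ has only two nonzero diagonals, storing it as a sparse matrix keeps memory and matrix-vector products linear in $T$.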

In [None]:
def deconvolve(trace, 
               noise_std=1.0, 
               epsilon=1.0,
               tau=GCAMP_TIME_CONST_SEC * FPS,
               full_output=False,
               verbose=False):
    r"""Deconvolve a noisy calcium trace (aka "target") by solving
    the convex optimization problem described above.

    Parameters
    ----------
    trace: a T numpy array containing the noisy trace (aka target).
    noise_std: scalar noise standard deviation $\sigma$
    epsilon: extra slack for the norm constraint. 
        (Typically > 0 and certainly > -1)
    tau: the time constant of the calcium indicator decay.
    full_output: if True, return a dictionary with the deconvolved 
        trace and a bunch of extra info, otherwise just return the trace.
    verbose: flag to pass to the CVX solver to print more info.
    """
    assert trace.ndim == 1
    T = len(trace)

    ###
    # YOUR CODE BELOW

    # Initialize the variable we're optimizing over
    c = cvx.Variable(...)
    b = cvx.Variable(...)

    # Create the sparse matrix G with 1 on the diagonal and -e^{-1/\tau} on the first lower diagonal
    G = ...

    # set the threshold to (1+\epsilon) \sigma \sqrt{T}
    theta = ...

    # Define the objective function
    objective = cvx.Minimize(...)
    
    # Set the constraints. 
    # PUT THE NORM CONSTRAINT FIRST, THEN THE NON-NEGATIVITY CONSTRAINT!
    constraints = [..., ...]

    # Construct the problem
    prob = cvx.Problem(..., ...)
    #
    ###

    # Solve the optimization problem. 
    try:
        # First try the default solver then revert to SCS if it fails.
        result = prob.solve(verbose=verbose)
    except Exception as e:
        print("Default solver failed with exception:")
        print(e)
        print("Trying 'solver=SCS' instead.")
        # if this still fails we give up!
        result = prob.solve(verbose=verbose, solver="SCS")

    # Make sure the result is finite (i.e. it found a feasible solution)
    if np.isinf(result): 
        raise Exception("solver failed to find a feasible solution!")

    # Package complete results into a dict
    all_results = dict(
        trace=c.value,
        baseline=b.value,
        result=result,
        amplitudes=G @ c.value,
        lagrange_multiplier=constraints[0].dual_value[0]
    )
    assert np.size(constraints[0].dual_value) == 1, \
        "Make sure your first constraint is on the norm of the residual."

    return all_results if full_output else c.value

# Solve the deconvolution problem for one neuron
n = 16              # this neuron has particularly high SNR
noise_std = 1.0     # \sigma is 1 since we standardized the data
epsilon = 1.0       # start with a generous tolerance of 2 \sigma \sqrt{T} (i.e. \epsilon = 1)
dual_results = deconvolve(traces[n], 
                          noise_std=noise_std, 
                          epsilon=epsilon,
                          full_output=True, 
                          verbose=True)

# Plot 
plt.plot(traces[n], color=palette[0], lw=1, alpha=0.5, label="raw")
plt.plot(dual_results["trace"] + dual_results["baseline"], 
         color=palette[0], lw=2, label="deconvolved")
plt.legend(loc="upper left")
plt.xlim(0, num_frames)
plt.xlabel("time (frames)")
plt.ylabel("fluorescence")
_ = plt.title("neuron {}".format(n))

# Check your answer
assert np.isclose(dual_results["result"], 563.4, atol=1e-1)

## Plot solutions as a function of $\epsilon$ (and hence of $\theta$)

Compute and plot the solutions (in separate figures) for a range of $\epsilon$ values. 

In [None]:
epsilons = [0, 0.25, 0.5, 0.75, 1, 2, 5, 10]
for epsilon in epsilons:
    # deconvolve with this epsilon
    dual_results = deconvolve(traces[n], 
                              noise_std=noise_std, 
                              epsilon=epsilon,
                              full_output=True, 
                              verbose=False)
    
    # Plot 
    plt.figure()
    plt.plot(traces[n], color=palette[0], lw=1, alpha=0.5, label="raw")
    plt.plot(dual_results["trace"] + dual_results["baseline"], 
            color=palette[0], lw=2, label="deconvolved")
    plt.legend(loc="upper left", fontsize=10)
    plt.xlim(0, num_frames)
    plt.xlabel("time (frames)")
    plt.ylabel("fluorescence")
    _ = plt.title(r"neuron {} $\epsilon$ = {:.2f}".format(n, epsilon))


## Problem 2b [Short Answer]: Explain these results

How does the solution change as you increase $\epsilon$ and thereby increase $\theta$? Why?

_Answer below this line_

---

## Problem 2c [Math]: Relate the dual form to the primal

Replacing the upper bound on the squared norm in Problem 2a with its Lagrangian, we obtain the following "primal" form of the problem:

\begin{align}
    \hat{c}_n, \hat{b}_n = \text{arg min}_{c_n, b_n} \; \eta (\|\mu_n - c_n - b_n\|_2^2 - \theta^2) + \|G c_n\|_1 
    \quad \text{subject to } \quad  G c_n \geq 0,
\end{align}

where $\eta$ is the Lagrange multiplier.

**Show** that this is equivalent to maximizing the log joint (with a baseline $b_n$) 

\begin{align}
\hat{c}_n, \hat{b}_n = \text{arg max}_{c_n, b_n} \mathcal{L}(c_n, b_n) &= -\frac{1}{2\sigma^2} \|\mu_n - c_n - b_n\|_2^2 - \lambda_n\|G c_n\|_1 \quad \text{subject to } \quad  G c_n \geq 0
\end{align}

by **solving for the value of $\lambda_n$** (in terms of $\eta$ and $\sigma$) that makes these problems equivalent. 

_Answer below this line_

---


## Problem 2d: Solve the problem in primal form with $\lambda_n$ set to match the dual

Solve the primal problem with CVX using the amplitude rate hyperparameter $\lambda_n$ that you solved for in Problem 2c and the optimal Lagrange multiplier $\eta$ output in Problem 2a.
```
dual_results["lagrange_multiplier"]   # this is \eta
```

In [None]:
def deconvolve_primal(trace, 
                      amplitude_rate,
                      noise_std=1.0, 
                      tau=GCAMP_TIME_CONST_SEC * FPS,
                      verbose=True,
                      full_output=False):
    r"""Deconvolve a noisy calcium trace (aka "target") by solving
    the convex optimization problem in the primal form.

    Parameters
    ----------
    trace: a T numpy array containing the noisy trace.
    amplitude_rate: non-negative rate (inverse scale) parameter $\lambda$
    noise_std: scalar noise standard deviation $\sigma$
    tau: the time constant of the calcium indicator decay.
    full_output: if True, return a dictionary with the deconvolved 
        trace and a bunch of extra info, otherwise just return the trace.
    verbose: flag to pass to the CVX solver to print more info.
    """
    assert trace.ndim == 1
    T = len(trace)

    ###
    # YOUR CODE BELOW

    # Initialize the variable we're optimizing over
    c = cvx.Variable(...)
    b = cvx.Variable(...)

    # Create the sparse matrix G with 1 on the diagonal and -e^{-1/\tau} on the first lower diagonal
    G = ...

    # Define the objective function
    objective = cvx.Minimize(...)
    constraints = [...]
    prob = cvx.Problem(...)
    #
    ###
    
    # Solve the optimization problem
    result = prob.solve(verbose=verbose)
    if np.isinf(result): 
        raise Exception("solver failed to find a feasible solution!")

    all_results = dict(
        trace=c.value,
        baseline=b.value,
        result=result,
        amplitudes=G @ c.value
    )
    return all_results if full_output else c.value


# Solve the deconvolution problem in the dual form
n = 16              # this neuron has particularly high SNR
noise_std = 1.0     # \sigma is 1 since we standardized the data
epsilon = 1.0       # start with a generous tolerance of 2 \sigma \sqrt{T} (i.e. \epsilon = 1)
dual_results = deconvolve(traces[n], 
                          noise_std=noise_std, 
                          epsilon=epsilon,
                          full_output=True, 
                          verbose=True)


###
# Convert the optimal Lagrange multiplier returned in Problem 2a
# to a hyperparameter $\lambda_n$ that sets the rate (inverse scale)
# of the exponential prior on spike amplitudes. The multiplier `eta` is in 
# `dual_results['lagrange_multiplier']` and \sigma is set by `noise_std`.
#
# YOUR CODE BELOW
amplitude_rate = ...
###

# Solve the problem in primal form
primal_results = deconvolve_primal(traces[n], 
                                   amplitude_rate=amplitude_rate, 
                                   verbose=True, 
                                   full_output=True)

# Plot raw, primal, and dual optimal trace for neuron n
plt.plot(traces[n], color=palette[0], lw=1, alpha=0.5, label="raw")
plt.plot(dual_results["trace"] + dual_results["baseline"],
         color=palette[0], ls='-', lw=2, label="dual")
plt.plot(primal_results["trace"] + primal_results["baseline"], 
         color=palette[1], ls='-', lw=1, label="primal")
plt.legend(loc="upper left")
plt.xlim(0, num_frames)
plt.xlabel("time (frames)")
plt.ylabel("fluorescence")
plt.title("neuron {}".format(n))

# Make sure the traces are the same!
primal_diff = abs(dual_results["trace"] - primal_results["trace"]).max()
print("max absolute difference between primal and dual traces: {:.4f}".format(primal_diff))
assert np.allclose(dual_results["trace"], primal_results["trace"], atol=1e-1)


## Compute all deconvolved traces and plot them

In [None]:
# Deconvolve each trace and concatenate the results
deconvolved_traces = np.zeros_like(traces)
amplitudes = np.zeros_like(traces)
for neuron in trange(num_neurons):
    try:
        all_results = deconvolve(traces[neuron], epsilon=1.0, full_output=True)
    except Exception as e:
        print("Failed to extract trace for neuron {}".format(neuron))
        all_results = deconvolve(traces[neuron], epsilon=1.0, verbose=True, full_output=True)
        raise
    deconvolved_traces[neuron] = all_results["trace"]
    amplitudes[neuron] = all_results["amplitudes"]

plot_problem_2(traces, deconvolved_traces, amplitudes)

# Part 3: Demix and deconvolve the calcium imaging video

In this part you'll write the updates for MAP estimation in the constrained non-negative matrix factorization model. 

As in the notes and slides, we will operate on the **flattened** data and residuals by raveling the frames into 1d vectors.

**Note** that unlike CNMF (Pnevmatikakis et al, 2016), we're not going to constrain the footprints to be non-negative. Instead, we'll just assume they are normalized, since that's a bit easier to implement and it makes a clearer connection to the spike sorting algorithms from Lab 2. It would be a simple extension to enforce non-negativity, and the course notes describe how.
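
Flattening works because raveling the two pixel dimensions into one turns the sum over neurons into an ordinary matrix product $U^\top C$. The toy check below (random arrays with assumed sizes and hypothetical names) compares the movie-shaped reconstruction with the flattened one.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H, W, T = 3, 4, 5, 7                     # toy sizes (assumed)
footprints_3d = rng.standard_normal((N, H, W))
traces_toy = rng.standard_normal((N, T))

# Movie-shaped reconstruction: sum over n of u_n[i, j] * c_n[t]
recon_3d = np.einsum("nij,nt->ijt", footprints_3d, traces_toy)

# Flattened version: U is (N, P) with P = H * W, so U^T C is (P, T)
U = footprints_3d.reshape(N, -1)
recon_flat = U.T @ traces_toy

matches = np.allclose(recon_flat, recon_3d.reshape(-1, T))
print("flattened and movie-shaped reconstructions match:", matches)
```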

## Flatten the pixel dimensions and package the parameters

In [None]:
flat_data = std_data.reshape(-1, num_frames)
flat_footprints = footprints.reshape(num_neurons, -1)
flat_bkgd_footprint = bkgd_footprint.reshape(-1)

# The latent variables are the traces (they grow with time).
init_latents = dict(
    traces=np.zeros((num_neurons, num_frames)),
    bkgd_trace=bkgd_trace,
)

# The parameters are the spatial components 
# (they don't grow with the length of the data).
init_params = dict(
    footprints=flat_footprints,
    bkgd_footprint=flat_bkgd_footprint
)

# The hyperparameters specify the number of neurons,
# the noise standard deviation ($\sigma = 1$ since we standardized the data),
# the prior variance of the background trace (something really large),
# and the tolerance for our norm constraint ($\epsilon$).
hypers = dict(
    num_neurons=num_neurons,
    noise_std=1.0,
    bkgd_trace_var=1e6,
    epsilon=1.0,
)

## Problem 3a: Write a function to compute the log likelihood given the residual

The log likelihood is 
\begin{align}
\log p(Y \mid U, C, u_0, c_0, \sigma^2) &= 
\sum_{p=1}^P \sum_{t=1}^T \log \mathcal{N}(y_{pt} \mid \sum_{n=1}^N u_{np} c_{nt} + u_{0p} c_{0t}, \sigma^2) \\
&= -\frac{PT}{2} \log (2\pi \sigma^2) -\frac{1}{2\sigma^2} \left\|Y - U^\top C - u_0 c_0^\top \right\|_F^2
\end{align}
Write a function to compute the log likelihood given the precomputed residual $R = Y - U^\top C - u_0 c_0^\top $.
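
The two lines of the equation above can be checked against each other numerically. The sketch below does this on a random toy residual, using `scipy.stats.norm` for the sum of Gaussian log densities; the sizes `P`, `T` and the value of `sigma` are arbitrary choices for illustration, not values from the lab.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
P, T, sigma = 3, 4, 0.5            # toy sizes and noise level (assumptions)
R = rng.normal(size=(P, T))        # stand-in for a precomputed residual

# First line of the equation: sum of independent Gaussian log densities.
# Subtracting the mean from y gives the residual, so we evaluate at R with mean 0.
ll_sum = norm.logpdf(R, loc=0.0, scale=sigma).sum()

# Second line: closed form using the squared Frobenius norm of the residual.
ll_closed = (-0.5 * P * T * np.log(2 * np.pi * sigma**2)
             - 0.5 * np.sum(R**2) / sigma**2)

assert np.isclose(ll_sum, ll_closed)
```

Note that the function in the next cell additionally divides by the number of entries in the residual, so it reports an average log likelihood per pixel and frame.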


In [None]:
def log_likelihood_residual(residual, noise_std):
    r"""Evaluate the log likelihood of the data
    given the precomputed residual $Y - U^T C - u_0 c_0^T$.

    Parameters
    ----------
    residual: a PxT numpy array containing the residual noise
        after subtracting the neuron and background contributions.

    noise_std: scalar per-pixel standard deviation $\sigma$

    Returns
    -------
    The log likelihood divided by the number of entries in the residual (P * T).
    """
    ### 
    # YOUR CODE BELOW
    ll = ...
    ###
    return ll / residual.size
    
# check it on the flat data (as if C and c_0 were zero)
assert np.isclose(log_likelihood_residual(flat_data, hypers["noise_std"]), -4.6867, atol=1e-4)

## Problem 3b: Optimize a trace

Optimize a single neuron's trace using the `deconvolve` function you wrote in Problem 2a. The target is $\mu_n = u_n^\top R_n$ where $R_n$ is the residual for this neuron. The residual is given as input to this function.

**Note:** In your final version, make sure you have `verbose=False` so that the final code doesn't print a bunch of unnecessary stuff.
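
To see why $\mu_n = u_n^\top R_n$ is a sensible target, here is a small sketch on synthetic data. It builds a toy residual as a rank-one signal plus noise and projects it onto a unit-norm footprint. All names and sizes here are hypothetical, and the toy trace is just rectified noise rather than a realistic calcium trace.

```python
import numpy as np

rng = np.random.default_rng(0)
P, T = 50, 200                               # toy numbers of pixels and frames

u = rng.normal(size=P)
u /= np.linalg.norm(u)                       # unit-norm footprint
c = np.maximum(rng.normal(size=T), 0.0)      # toy non-negative trace
noise = 0.1 * rng.normal(size=(P, T))
R = np.outer(u, c) + noise                   # residual = this neuron's signal + noise

# Project every frame of the residual onto the footprint: mu = u^T R.
mu = u @ R
assert mu.shape == (T,)

# Since u has unit norm, mu = c + u^T noise, so mu is a noisy copy of the trace.
assert np.allclose(mu, c + u @ noise)
assert np.max(np.abs(mu - c)) < 0.5
```

The projection collapses the P x T residual into a single length-T signal, which is the kind of input a one-dimensional deconvolution routine expects.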

In [None]:
def _update_trace(neuron, residual, params, latents, hypers):
    """Update a single neuron's trace by calling your `deconvolve` function.

    Parameters
    ----------
    neuron: integer index of which neuron to update
    residual: a PxT numpy array containing the residual for this neuron.
    params: dictionary with keys ['footprints', 'bkgd_footprint']
    latents: dictionary with keys ['traces', 'bkgd_trace']
    hypers: dictionary with keys ['num_neurons', 'epsilon', 'noise_std', 'bkgd_trace_var']
    """
    ###
    # YOUR CODE BELOW
    ...
    trace = ...
    #
    ###
    assert np.all(np.isfinite(trace))
    return trace


## Problem 3c: Optimize a footprint

Optimize a single neuron's footprint as follows,
\begin{align}
u_n = \frac{R_n c_n}{\|R_n c_n\|}
\end{align}
where $R_n$ is the given residual and $c_n$ is the neuron's trace.
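
One way to see where this update comes from: for a unit-norm $u_n$, the squared error $\|R_n - u_n c_n^\top\|_F^2$ depends on $u_n$ only through $-2\, u_n^\top R_n c_n$, so the best unit vector is the one most aligned with $R_n c_n$. The sketch below checks this numerically on random toy arrays (the sizes and variable names are assumptions made for illustration).

```python
import numpy as np

rng = np.random.default_rng(0)
P, T = 30, 100
R = rng.normal(size=(P, T))          # toy residual
c = rng.normal(size=T)               # toy trace

v = R @ c                            # shape (P,)
u_star = v / np.linalg.norm(v)       # the normalized update from the equation above
assert np.isclose(np.linalg.norm(u_star), 1.0)

# No random unit vector achieves a larger inner product with R c (Cauchy-Schwarz).
best = u_star @ v
for _ in range(1000):
    u = rng.normal(size=P)
    u /= np.linalg.norm(u)
    assert u @ v <= best + 1e-12
```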

In [None]:
def _update_footprint(neuron, residual, params, latents, hypers):
    """Update a single neuron's footprint.

    Parameters
    ----------
    neuron: integer index of which neuron to update
    residual: a PxT numpy array containing the residual for this neuron.
    params: dictionary with keys ['footprints', 'bkgd_footprint']
    latents: dictionary with keys ['traces', 'bkgd_trace']
    hypers: dictionary with keys ['num_neurons', 'epsilon', 'noise_std', 'bkgd_trace_var']
    """
    ###
    # YOUR CODE BELOW
    ...
    footprint = ...
    #
    ###
    assert np.all(np.isfinite(footprint))
    return footprint


## Problem 3d: Optimize the background

Optimize the background trace by projecting the residual onto the background footprint and shrinking the result slightly,
\begin{align}
c_0 = \left(\frac{\varsigma_0^2}{\sigma^2 + \varsigma_0^2}\right) u_0^\top R_0 
\end{align}
where $R_0 = Y - U^\top C$ is the background residual and $\varsigma_0^2$ is the prior variance on the background trace. (See the course notes for a derivation.)

Update the background footprint by setting it to,
\begin{align}
u_0 = \frac{R_0 c_0}{\|R_0 c_0\|}
\end{align}
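
As a quick worked example of how little the background trace is shrunk, the sketch below plugs in the values stored in `hypers` earlier in this part ($\sigma = 1$ and $\varsigma_0^2 = 10^6$). The variable names are chosen for illustration.

```python
noise_std = 1.0           # sigma, matching hypers["noise_std"]
bkgd_trace_var = 1e6      # varsigma_0^2, matching hypers["bkgd_trace_var"]

# Fraction of the projected residual that survives the shrinkage.
shrink = bkgd_trace_var / (noise_std**2 + bkgd_trace_var)
print("shrink factor: {:.8f}".format(shrink))

# With such a large prior variance the factor is within 2e-6 of one.
assert 0.999998 < shrink < 1.0
```

A smaller prior variance would pull the background trace more strongly toward zero.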

In [None]:
def _update_bkgd_trace(residual, params, latents, hypers):
    """Update the background trace $c_0$.
    
    Parameters
    ----------
    residual: a PxT numpy array containing the residual for the background.
    params: dictionary with keys ['footprints', 'bkgd_footprint']
    latents: dictionary with keys ['traces', 'bkgd_trace']
    hypers: dictionary with keys ['num_neurons', 'epsilon', 'noise_std', 'bkgd_trace_var']
    """
    ###
    # YOUR CODE BELOW
    ...
    shrink_factor = ...
    target = ...
    #
    ###
    return shrink_factor * target

def _update_bkgd_footprint(residual, params, latents, hypers):
    """Update the background footprint $u_0$.

    Parameters
    ----------
    residual: a PxT numpy array containing the residual for the background.
    params: dictionary with keys ['footprints', 'bkgd_footprint']
    latents: dictionary with keys ['traces', 'bkgd_trace']
    hypers: dictionary with keys ['num_neurons', 'epsilon', 'noise_std', 'bkgd_trace_var']
    """
    ###
    # YOUR CODE BELOW
    ...
    bkgd_footprint = ...
    #
    ###
    return bkgd_footprint

## Putting it all together

Now we'll put these steps together into the MAP estimation algorithm. It's very similar to what you implemented in Lab 2. It amounts to:
- Initialize the residual $R = Y - U^\top C - u_0 c_0^\top$
- Repeat until convergence:
    - For each neuron $n=1,\ldots,N$:
        - Update the residual to $R = R + u_n c_n^\top$
        - Update the trace $c_n$ by applying your `deconvolve` function from Problem 2a to the target $\mu_n = u_n^\top R$
        - Update the footprint to $u_n = \frac{R c_n}{\|R c_n\|}$
        - Downdate the residual to $R = R - u_n c_n^\top$ using the new footprint and trace
    - Update the background:
        - Update the residual to $R = R + u_0 c_0^\top$
        - Set the background trace to $c_0 = \frac{\varsigma_0^2}{\sigma^2 + \varsigma_0^2} u_0^\top R$ where $\varsigma_0^2$ is the prior variance of the background trace. (We will set it to be very large so that we barely shrink the background trace.)
        - Set the background footprint to $u_0 = \frac{R c_0}{\|R c_0\|}$
        - Downdate the residual to $R = R - u_0 c_0^\top$ using the new background footprint and trace.
    - Compute the log likelihood using the residual $R$
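
The update/downdate bookkeeping in the steps above can be sanity-checked on toy data. The sketch below replaces the real trace and footprint updates with random stand-ins (an assumption made purely to keep the example self-contained) and confirms that the running residual always matches one recomputed from scratch.

```python
import numpy as np

rng = np.random.default_rng(0)
N, P, T = 3, 20, 40                    # toy numbers of neurons, pixels, frames
Y = rng.normal(size=(P, T))
U = rng.normal(size=(N, P))            # footprints, one per row
C = rng.normal(size=(N, T))            # traces, one per row
u0 = rng.normal(size=P)                # background footprint
c0 = rng.normal(size=T)                # background trace

def full_residual(Y, U, C, u0, c0):
    """Recompute R = Y - U^T C - u_0 c_0^T from scratch."""
    return Y - U.T @ C - np.outer(u0, c0)

R = full_residual(Y, U, C, u0, c0)

for n in range(N):
    R += np.outer(U[n], C[n])          # add back neuron n's current contribution
    # Stand-ins for the real updates; any new values work for this check.
    C[n] = rng.normal(size=T)
    U[n] = rng.normal(size=P)
    R -= np.outer(U[n], C[n])          # subtract the updated contribution
    assert np.allclose(R, full_residual(Y, U, C, u0, c0))
```

Each rank-one update costs on the order of $PT$ operations, whereas recomputing the full residual costs on the order of $NPT$, which is why the algorithm maintains $R$ incrementally.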


In [None]:
def map_estimate(flat_data,
                 init_latents,
                 init_params,
                 hypers,
                 num_iters=10,
                 tol=2e-4):
    """Fit the CNMF model via coordinate ascent.
    """
    
    # make a fancy reusable progress bar for the inner loops over neurons.
    outer_pbar = trange(num_iters)
    inner_pbar = trange(hypers["num_neurons"])
    inner_pbar.set_description("updating neurons")

    # make a copy of the model rather than overwriting the inputs
    latents = deepcopy(init_latents)
    params = deepcopy(init_params)

    # initialize the residual
    residual = flat_data.copy()
    residual -= params["footprints"].T @ latents["traces"]
    residual -= np.outer(params["bkgd_footprint"], latents["bkgd_trace"])

    # track log likelihoods over iterations
    lls = [log_likelihood_residual(residual, hypers["noise_std"])]
    outer_pbar.set_description("LL: {:.4f}".format(lls[-1]))

    # coordinate ascent
    for itr in outer_pbar:
        
        # update neurons one at a time
        inner_pbar.reset()
        for n in range(hypers["num_neurons"]):
            # update the residual (add $u_n c_n^\top$)
            residual += np.outer(params["footprints"][n], latents["traces"][n])
    
            # update the trace and footprint with the residual
            latents["traces"][n] = _update_trace(n, residual, params, latents, hypers)
            params["footprints"][n] = _update_footprint(n, residual, params, latents, hypers)
            
            # downdate the residual (subtract $u_n c_n^\top$)
            residual -= np.outer(params["footprints"][n], latents["traces"][n])

            # step the progress bar
            inner_pbar.update()

        # update the background
        residual += np.outer(params["bkgd_footprint"], latents["bkgd_trace"])
        latents["bkgd_trace"] = _update_bkgd_trace(residual, params, latents, hypers)
        params["bkgd_footprint"] = _update_bkgd_footprint(residual, params, latents, hypers)
        residual -= np.outer(params["bkgd_footprint"], latents["bkgd_trace"])
        
        # compute the log likelihood 
        lls.append(log_likelihood_residual(residual, hypers["noise_std"]))
        outer_pbar.set_description("LL: {:.4f}".format(lls[-1]))
        
        # check for convergence
        if abs(lls[-1] - lls[-2]) < tol:
            print("Convergence detected!")
            break
    
    return np.array(lls), latents, params

## Fit it!

This should take about 5 minutes.

In [None]:
# Fit it!
lls, latents, params = map_estimate(flat_data, init_latents, init_params, hypers)

# Plot the log likelihoods
plt.plot(lls, '-o')
plt.xlabel("Iteration")
plt.xlim(-.1, len(lls) - .9)
plt.ylabel("Log Likelihood")
plt.grid(True)

## Plot the inferred footprints and traces

In [None]:
warnings.filterwarnings("ignore") # Yes, we know we're creating a lot of figures...
plot_problem_3(flat_data, params, latents, hypers)

## Problem 3e: Make a movie of the data, reconstruction, and residual

Show a movie with the data, reconstruction, and residual side by side. If all goes well, the reconstruction should show a nice, clean movie of spiking neurons and the residual should mostly look like white noise. In practice, you'll probably still see some evidence of neurons in the residual, suggesting that the model still isn't perfect.

In [None]:
###
# Reconstruct the data and compute the residual.
# Then make a movie of the data, reconstruction, and residual 
# side-by-side.
#
# YOUR CODE BELOW
flat_recon = ...
flat_residual = ...

# Reshape into (height x width x frames) image stacks and concatenate along axis=1.
movie = np.concatenate([...], axis=1)
#
###

# Play the movie
play(movie, speedup=5)

# Part 4: Discussion

Hopefully you were successful in separating the neurons from the background and noise! Let's take a minute to reflect on the model and results. Write a couple of paragraphs in response to the following prompts.

- We mentioned a few times that actual CNMF implementations also constrain the footprints to be non-negative. Without this constraint, you probably found in the plots above (before Problem 3e) that some of these footprints contain negative values. Why is this unrealistic and what are the consequences of omitting this constraint?
- What were the bottlenecks in the MAP estimation algorithm? Do you think implementing this model on a GPU, like we did in Lab 2, could help? Why or why not?
- You probably noticed that the background has lots of rings in it, like little Cheerios. Why do you think that is?
- We assumed that all neurons share the same time constant $\tau$. Is that reasonable? How could you learn per-neuron time constants?
- Do you think we can infer the number of underlying action potentials from the amplitude of the jumps in the calcium traces? 

_Answer below this line_

---

# Submission instructions

Download your notebook in .ipynb format and use the following command to convert it to PDF:
```
jupyter nbconvert --to pdf lab3_teamname.ipynb
```
If you're using Anaconda for package management, you can install `nbconvert` with
```
conda install -c anaconda nbconvert
```
Upload your .ipynb and .pdf files to Gradescope. 

**Only one submission per team!**