<a href="https://colab.research.google.com/github/sg879/IIBProject/blob/main/Test/Test_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys

from google.colab import drive

# Mount Google Drive at /content/drive. If it is already mounted, Colab only
# prints a notice (see the saved output below) instead of remounting.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%%capture
!pip install blackcellmagic

# Creating Test Data

## Import and Load Packages

In [3]:
# Load the blackcellmagic extension installed by the pip cell above.
%load_ext blackcellmagic

In [4]:
import seaborn as sns
import matplotlib as mpl

# Single source of truth for the LaTeX preamble. The same macros (\norm,
# colour myred, ...) are needed by both the pgf backend and usetex text
# rendering, so both rcParams below are set from this one string.
LATEX_PREAMBLE = "\n".join(
    [
        r"\usepackage{bm}",
        r"\usepackage{mathtools}",
        r"\DeclarePairedDelimiter\norm{\lVert}{\rVert}",
        # Make \norm default to the auto-sized (starred) delimiter form.
        r"\makeatletter",
        r"\let\oldnorm\norm",
        r"\def\norm{\@ifstar{\oldnorm}{\oldnorm*}}",
        r"\makeatother",
        r"\usepackage[dvipsnames]{xcolor}",
        r"\definecolor{myred}{RGB}{205, 108, 46}",
    ]
)

# Start from matplotlib defaults so re-running this cell is idempotent.
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.use("pgf")
# NOTE(review): text.usetex=True and the pgf backend both need a LaTeX
# installation providing pdflatex on the runtime; confirm one is installed,
# otherwise text rendering / saving will fail.
mpl.rcParams.update(
    {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
        "pgf.preamble": LATEX_PREAMBLE,
        "text.latex.preamble": LATEX_PREAMBLE,
        "font.serif": ["Computer Modern Roman"],
    }
)

import matplotlib.font_manager
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import Divider, Size

In [5]:
# Display figures inline in the notebook.
# NOTE(review): this selects the inline backend for display, which presumably
# supersedes the mpl.use("pgf") call in the configuration cell; saving figures
# to .pgf files should still work via savefig's format dispatch. Confirm.
%matplotlib inline

In [6]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, random
import jax.lax as lax

## Function Definitions

In [7]:
@jit
def squared_exp(i, j, tau, ell, sigma_f):
  """Squared-exponential kernel between two entries of ``tau``.

  Args:
    i, j: integer indices into ``tau``.
    tau: array of time points.
    ell: length scale (same units as ``tau``).
    sigma_f: signal standard deviation.

  Returns:
    ``sigma_f**2 * exp(-(tau[i] - tau[j])**2 / (2 * ell**2))``.
  """
  sq_dist = (tau[i] - tau[j]) ** 2
  denom = 2 * ell**2.0
  amplitude = sigma_f**2.0
  return amplitude * jnp.exp(-sq_dist / denom)

In [8]:
def cov_squared_exp(tau, ell, sigma_f):
  """Squared-exponential (RBF) covariance matrix over a set of time points.

  Entry ``(i, j)`` is ``sigma_f**2 * exp(-(tau[i] - tau[j])**2 / (2 * ell**2))``,
  i.e. the same kernel that ``squared_exp`` evaluates for a single pair.

  Args:
    tau: 1-D array of k time points.
    ell: length scale, in the same units as ``tau``.
    sigma_f: signal standard deviation; the diagonal equals ``sigma_f**2``.

  Returns:
    A symmetric ``(k, k)`` float32 array (``(0, 0)`` for empty ``tau``).
  """
  tau = jnp.asarray(tau)
  # All pairwise differences at once via broadcasting: (k, 1) - (1, k).
  # This fills BOTH triangles, so the matrix is symmetric as a covariance
  # must be, and avoids copying the whole array once per element as
  # repeated `.at[i, j].set(...)` calls outside jit would.
  diff = tau[:, None] - tau[None, :]
  covmat = sigma_f**2.0 * jnp.exp(-diff**2 / (2 * ell**2.0))
  return covmat.astype(jnp.float32)

## Create Input and Filter

### Generate the JAX PRNG key

In [9]:
# Seed JAX's PRNG and split it in one step. `key` is kept for any future
# splits, while the two entries of `subkeys` are used by the input and filter
# cells below. Splitting up front means later edits cannot perturb them.
key, *subkeys = random.split(random.PRNGKey(0), num=3)

### Make random input spike train

In [10]:
# Get subkey
subkeyx = subkeys[0]

# Set up time indexing
Input_mt = 0.1
tbin_l = 0.005
K = jnp.floor(Input_mt/tbin_l).astype(jnp.int32)
timex = jnp.linspace(0.0, Input_mt, K+1)

inpx = random.randint(subkeyx, (K,1), 0, 2)

### Generate filter shape

In [11]:
# Subkey reserved for the filter.
subkeyf = subkeys[1]

# Time indexing: M bins of width tbin_l over the filter duration Filter_mt.
Filter_mt = 0.05
M = jnp.floor(Filter_mt/tbin_l).astype(jnp.int32)
# Filter time points, one per bin. The covariance must be built on this
# M-point axis (not on the K+1-point input axis `timex`) so that it is
# M x M and matches the length-M mean below.
timef = tbin_l * jnp.arange(M)

# True mean and covariance, both over the M filter bins.
muf_true = jnp.ones((M,1))
# NOTE(review): ell=2.0 is in the units of timef, which only spans
# ~Filter_mt; with ell far larger than that span the prior is almost
# constant. Confirm the intended length scale / units.
covf_true = cov_squared_exp(timef, 2.0, 5.0)

# NOTE(review): f_true is drawn as random 0/1 values and does not use
# muf_true / covf_true. Confirm whether it should instead be sampled from
# N(muf_true, covf_true).
f_true = random.randint(subkeyf, (M,1), 0, 2)