In [178]:
import numpy as np
from sklearn.cross_decomposition import CCA
import matplotlib.pyplot as plt

In [179]:
num_bins = 20
bin_length = 0.05
num1 = 50
num2 = 30

In [180]:
class Neuron:
    def __init__(self, gen):
        self.firing_rate = max(0.5, gen.lognormal(0.2, 2.0))
    
    def sample_bins(self, num_bins, bin_length, gen):
        return gen.poisson(self.firing_rate * bin_length, size=num_bins)


In [181]:
def generate_neurons(num_neurons, gen):
    return [Neuron(gen) for _ in range(num_neurons)]

def generate_binned_trial_rates(neurons, num_bins, bin_length, gen):
    return np.array([neuron.sample_bins(num_bins, bin_length, gen) for neuron in neurons])

In [182]:
def fit_cca(X_, Y_, cca=None, fit=True):
    if cca is None:
        cca = CCA(n_components=1, scale=True)
    if fit:
        cca.fit(X_, Y_)
    X, Y = cca.transform(X_, Y_)
    X = X.squeeze()
    Y = Y.squeeze()
    return cca, X, Y


In [183]:
def find_correlation(X, Y):
    if len(X.shape) == 1:
        return np.corrcoef(X, Y)[0, 1]
    else:
        return np.array(
            [np.corrcoef(X[:, v], Y[:, v])[0, 1] for v in range(X.shape[1])]
        )

In [184]:
def transform(X_, Y_, cca):
    rotation_x = cca.x_rotations_
    rotation_y = cca.y_rotations_

    scale_x = X_.std(axis=0, ddof=1)
    scale_y = Y_.std(axis=0, ddof=1)
    scale_x[scale_x == 0.0] = 1
    scale_y[scale_y == 0.0] = 1
    scaled_X = (X_ - X_.mean(axis=0)) / scale_x
    scaled_Y = (Y_ - Y_.mean(axis=0)) / scale_y

    Xt = np.dot(scaled_X, rotation_x)
    Yt = np.dot(scaled_Y, rotation_y)

    return Xt.squeeze(), Yt.squeeze()

In [185]:
def find_mean_correlation(X, Y, num_trials, data_per_trial):
    correlations = []
    for i in range(num_trials):
        start = i * data_per_trial
        end = start + data_per_trial
        correlations.append(find_correlation(X[start:end], Y[start:end]))
    return np.mean(np.abs(correlations)), np.std(np.abs(correlations))
    

In [186]:
gen = np.random.default_rng(0)

neurons_r1 = generate_neurons(num1, gen)
neurons_r2 = generate_neurons(num2, gen)

binned_trial_rates_r1 = generate_binned_trial_rates(
    neurons_r1, num_bins, bin_length, gen
)
binned_trial_rates_r2 = generate_binned_trial_rates(
    neurons_r2, num_bins, bin_length, gen
)

binned_trial_r1 = generate_binned_trial_rates(neurons_r1, num_bins, bin_length, gen)
binned_trial_r2 = generate_binned_trial_rates(neurons_r2, num_bins, bin_length, gen)

only_frates_r1 = np.array(
    [generate_binned_trial_rates(neurons_r1, 1, 1, gen).squeeze() for _ in range(100)]
)
only_frates_r2 = np.array(
    [generate_binned_trial_rates(neurons_r2, 1, 1, gen).squeeze() for _ in range(100)]
)

only_frates_a = np.array(
    [generate_binned_trial_rates(neurons_r1, 1, 1, gen).squeeze() for _ in range(100)]
)
only_frates_b = np.array(
    [generate_binned_trial_rates(neurons_r2, 1, 1, gen).squeeze() for _ in range(100)]
)

per_trial_bins1 = np.concatenate(
    [generate_binned_trial_rates(neurons_r1, num_bins, bin_length, gen).T for _ in range(100)], axis=0
)
per_trial_bins2 = np.concatenate(
    [generate_binned_trial_rates(neurons_r2, num_bins, bin_length, gen).T for _ in range(100)], axis=0
)

In [187]:
# Option number 1
cca, X, Y = fit_cca(binned_trial_rates_r1.T, binned_trial_r2.T, cca=CCA(n_components=4))
print(f"Correlation: {find_correlation(X, Y)}")
cca, X, Y = fit_cca(binned_trial_r1.T, binned_trial_r2.T, cca=cca, fit=False)
print(f"Correlation: {find_correlation(X, Y)}")
    

Correlation: [1. 1. 1. 1.]
Correlation: [ 0.14734664 -0.06030661  0.2127017   0.00817044]


In [188]:
# Option number 2
cca, Xtrain, Ytrain = fit_cca(per_trial_bins1[:500], per_trial_bins2[:500], cca=CCA(n_components=1), fit=True)
cca, Xtest, Ytest = fit_cca(per_trial_bins1[500:], per_trial_bins2[500:], cca=cca, fit=False)
print(find_correlation(Xtrain, Ytrain))
for val in range(10, 60, 10):
    start = val - 10
    print(f"Train {val} : {find_correlation(Xtest[start:val], Ytest[start:val])}")
    print(f"Test {val} : {find_correlation(Xtrain[start:val], Ytrain[start:val])}")

print(find_mean_correlation(Xtrain, Ytrain, 10, 50))
print(find_mean_correlation(Xtest, Ytest, 10, 50))

0.5068078264120351
Train 10 : 0.26637604850332197
Test 10 : 0.22943212807308686
Train 20 : -0.02842341271787485
Test 20 : 0.6897167867233878
Train 30 : 0.5256872499907193
Test 30 : 0.6573712011846082
Train 40 : -0.14368334408726813
Test 40 : 0.20565943527770425
Train 50 : 0.2121171126164346
Test 50 : 0.8287783132376072
0.5048158666958559
0.12801401032163992


In [189]:
# Option number 3
cca, X, Y = fit_cca(only_frates_r1[:50], only_frates_r2[:50], cca=CCA(n_components=3))
print(find_correlation(X, Y))
cca, X, Y = fit_cca(only_frates_r1[50:], only_frates_r2[50:], cca=cca, fit=False)
print(find_correlation(X, Y))

[1. 1. 1.]
[-0.03407047  0.0858088  -0.17024104]
