In [2]:
from scipy.io import loadmat
import numpy as np

In [225]:
def getDAtrend(DAtrend, t1, t2, data_type='smoothed'):
    """
    Python version of MATLAB getDAtrend.
    
    Parameters:
        DAtrend: loaded MATLAB struct (as a list or array of objects)
        t1, t2: trial window indices (integers)
        dataType: 'raw' or 'smooth' (default: 'raw')
    Returns:
        stats: double array of size (numAnimals, 1)
    """

    # Determine dataType field
    if 'smooth' in data_type:
        key = 'slopeMap_smoothed'
    else:
        key = 'slopeMap_raw'

    stats = np.full(len(DAtrend), np.nan)

    for a in range(len(DAtrend)):
        try:
            field_data = DAtrend[a][key]
            if t1 <= 0: t1_final = field_data.shape[0] + t1
            else: t1_final = t1
            if t2 <= 0: t2_final = field_data.shape[1] + t2
            else: t2_final = t2

            if (t1_final > field_data.shape[0] or t2_final > field_data.shape[1]
                or t1_final <= 0 or t2_final <= 0 or t1 > t2_final):
                stats[a] = np.nan
            else:
                stats[a] = field_data[t1_final-1, t2_final-1]  # MATLAB is 1-based, Python is 0-based
        except Exception as e:
            print(f"Error accessing data for animal {a},: {e}")
            stats[a] = np.nan

    return stats


def getDAvsEImap(DAvsEImap, key='smoothed', direction='reverse', nTrials=50):
    """ 
    Parameters:
        DAvsEImap: loaded MATLAB struct (as a list or array of objects)
        key: 'smoothed' or 'raw' (default: 'smoothed')
    Returns:
        stats: double array of size (numAnimals, 1)
    """

    map_data = DAvsEImap[key][0]
    nTrials = map_data.shape[0] // 2

    if direction == 'reverse':
        late_ticks = -np.flip(np.arange(1, nTrials + 1))  # flip(1:nTrials)
        late_idx = map_data.shape[0] + late_ticks

    sub_map = map_data[np.ix_(late_idx, late_idx)]
    return sub_map

def remove_nan_values(DA, EI):
    """
    Remove NaN values from x_vals and y_vals.
    Returns cleaned x_vals and y_vals.
    """
    # Convert to NumPy arrays
    DA = np.asarray(DA).flatten()
    EI = np.asarray(EI).flatten()

    # Remove NaNs
    valid = ~np.isnan(DA) & ~np.isnan(EI)
    DA_clean = DA[valid]
    EI_clean = EI[valid]

    return DA_clean, EI_clean

In [228]:
DAtrend = loadmat('/Users/shunli/Desktop/manim_projects/DAtrend_manim.mat')
DAtrend = DAtrend['DAtrend_manim'].flatten()

DAvsEImap = loadmat('/Users/shunli/Desktop/manim_projects/DAvsEImap_manim.mat')
DAvsEImap = DAvsEImap['DAvsEImap_manim'].flatten()
true_map = getDAvsEImap(DAvsEImap,key='smoothed',nTrials=50)

animalEI_mat = loadmat('/Users/shunli/Desktop/manim_projects/animalEIpeaks.mat')
animalEI = animalEI_mat['animalEIindex_peaks']

In [229]:
i, j = 22, 48

stats = getDAtrend(DAtrend, t1=i-50, t2=j-50, data_type='smooth')

DA_clean, EI_clean = remove_nan_values(stats, animalEI)
print(f"Cleaned DA: {DA_clean}")
print(f"Cleaned EI: {EI_clean}")

# Fit stats vs animalEI
slope, intercept = np.polyfit(DA_clean, EI_clean, 1)
print(f"Slope: {slope}, Intercept: {intercept}")

print(true_map[i,j])

Cleaned DA: [-0.0895962  -0.01293639  0.00716079 -0.00413119 -0.03489123  0.0210576
 -0.08637815 -0.08766801  0.06588019 -0.01109956 -0.02133382  0.01208358
 -0.01340857 -0.01819567 -0.01429446  0.02872162  0.01213604  0.02215556
 -0.00906644  0.00611441  0.06259907  0.02441325 -0.00559765  0.05251934
  0.00849867 -0.03202014 -0.06645395  0.03945756]
Cleaned EI: [ 0.09239674  0.03965995 -0.42226602 -0.08940776 -0.18445254  0.25400178
 -0.22965684 -0.56444827  0.78973906 -0.00993123 -0.27625927 -0.74943796
 -0.72454838 -0.20890144 -0.12799995 -0.0646866  -0.2379009  -0.37428537
 -0.01373143  0.05648903  0.46759495  0.16417774  0.03385522  0.25977921
  0.33743203  0.45291995 -0.02182647  0.31510599]
Slope: 3.5831176468657593, Intercept: -0.018558535905992112
-0.4933087306273851


In [233]:
true_map[i-50-1,j-50-1]

np.float64(3.5831176468657566)