In [87]:
import os
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

In [126]:
def viz_single_session(eid, target, model, results_dir=None, output_dir=None):
    """Plot observed / predicted / residual heatmaps for one decoding run and save a PNG.

    Reads ``<results_dir>/<eid>/<target>/<model>/comp_-1.npy``. The file is an
    object array wrapping a dict with (at least) ``'target'`` and ``'pred'`` arrays.
    The function draws three stacked heatmaps (observed, predicted, residual),
    titled with the R^2 of the flattened arrays, and writes the figure to
    ``output_dir``.

    Parameters
    ----------
    eid : str
        Session id. Used as a sub-directory name and, truncated, in the title and filename.
    target : str
        Decoded variable (e.g. ``'wheel_speed'``). Used as a sub-directory name and label.
    model : str
        Decoder name (e.g. ``'ridge'``). Used as a sub-directory name and label.
    results_dir : path-like, optional
        Root of the decoding results. Defaults to the notebook-level ``res_dir``.
    output_dir : path-like, optional
        Directory the PNG is written to; created if missing. Defaults to the
        notebook-level ``plot_dir``.
    """
    # Fall back to the notebook globals only when no explicit path is given.
    results_dir = res_dir if results_dir is None else Path(results_dir)
    output_dir = plot_dir if output_dir is None else Path(output_dir)

    path = results_dir / eid / target / model / 'comp_-1.npy'
    # NOTE(review): allow_pickle=True can execute code stored in the file;
    # only point this at trusted, locally produced result files.
    result = np.load(path, allow_pickle=True).item()  # load once, not twice
    gt, pred = result['target'], result['pred']
    r2 = r2_score(gt.flatten(), pred.flatten())

    # Subtract the mean over axis 0 (the axis counted as "#trials" in the labels below).
    y = gt - gt.mean(0)
    y_pred = pred - pred.mean(0)
    y_resid = y - y_pred

    vmin_perc, vmax_perc = 10, 90
    # Observed and predicted panels share limits taken from the prediction so they are comparable.
    vmin, vmax = np.percentile(y_pred, [vmin_perc, vmax_perc])
    # Residual panel limits are taken from the pooled values of all three arrays.
    resid_vmin, resid_vmax = np.percentile([y, y_pred, y_resid], [vmin_perc, vmax_perc])

    n_trials = len(y)
    panels = [
        (y, f'obs. {target}', vmin, vmax),
        (y_pred, f'pred. {target}', vmin, vmax),
        (y_resid, f'resid. {target}', resid_vmin, resid_vmax),
    ]

    fig, axes = plt.subplots(3, 1, figsize=(8, 6))
    for ax, (data, label, lo, hi) in zip(axes, panels):
        im = ax.imshow(data, aspect='auto', cmap='bwr', vmin=lo, vmax=hi)
        # Pass ax explicitly: before Matplotlib 3.6 colorbar() took space from the
        # *current* axes (the last subplot) rather than from the image's own axes.
        cbar = fig.colorbar(im, ax=ax, pad=0.02, shrink=.6)
        cbar.ax.tick_params(rotation=90)
        ax.set_ylabel(f"{label}\n(#trials={n_trials})")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines[['left', 'bottom', 'right', 'top']].set_visible(False)
    axes[0].set_title(f'{model} ({eid[:8]}) R2: {r2:.3f}')

    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / f'{target.split("_")[0]}_{eid[:5]}_{model}_r2_{r2:.3}.png')
    # Render the figure now, then release it. The driver loop makes dozens of
    # figures, and leaving them all open wastes memory.
    plt.show()
    plt.close(fig)

In [127]:
# Root of all decoding outputs. Override it with the SHARED_DECODING_RESULTS
# environment variable instead of editing this user-specific absolute path.
RESULTS_ROOT = Path(os.environ.get('SHARED_DECODING_RESULTS', '/home/yizi/shared_decoding/results'))
res_dir = RESULTS_ROOT / 'pc_results'                 # decoding results read per session
plot_dir = RESULTS_ROOT / 'plots' / 'residual_plots'  # residual figures are written here

# Session ids to visualise.
eids = [
    '034e726f-b35f-41e0-8d6c-a22cc32391fb',
    '09b2c4d1-058d-4c84-9fd4-97530f85baf6',
    '0a018f12-ee06-4b11-97aa-bbbff5448e9f',
    '3537d970-f515-4786-853f-23de525e110f',
    '56b57c38-2699-4091-90a8-aba35103155e'
]

In [None]:
# Save one residual figure per (session, target, decoder) combination,
# in session -> target -> decoder order.
targets = ('wheel_speed', 'motion_energy', 'pupil_diameter')
models = ('ridge', 'reduced-rank', 'multi-sess-reduced-rank', 'mlp', 'lstm')
jobs = [(e, t, m) for e in eids for t in targets for m in models]
for eid, target, model in jobs:
    viz_single_session(eid, target, model)