In [None]:
import collections
import os
import pickle

import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import wandb

api = wandb.Api()

# Finished runs of train_fmri_convex.py for one dataset / subset / aggregator / PCA setting.
runs = api.runs("pmin/train_fmri_convex.py",
                {"$and":
                 [
                     {"config.dataset": {"$in": ["mst_norm_airsim"]}},
                     {"config.subset": "35"},
                     {"config.aggregator": "downsample"},
                     {"config.pca": 500},
                     {"state": "finished"},
                 ]
                 }
               )

# Each run's artifacts are downloaded into DOWNLOAD_ROOT/<run id>/ so that a run
# missing a file can never silently reuse the copy left behind by a previous run.
DOWNLOAD_ROOT = "wandb_downloads"
# The plotting cells read archives[...]['optimal_weights']; set to False to skip
# downloading/unpickling the weights when only correlations/results are needed.
LOAD_OPTIMAL_WEIGHTS = True

archives = {}

print("Found %i" % len(runs))
for run in tqdm(runs):
    # Work on a copy so the relabelling below does not mutate the wandb Run object.
    config = dict(run.config)
    if config['features'] == 'SlowFast':
        # SlowFast layers > 16 are relabelled 'SlowFast_Fast' and re-indexed from 0;
        # the remaining layers become 'SlowFast_Slow' with their index unchanged.
        if config['layer'] > 16:
            config['features'] = 'SlowFast_Fast'
            config['layer'] = config['layer'] - 17
        else:
            config['features'] = 'SlowFast_Slow'

    unique_name = f"{config['features']}_layer{int(config['layer']):02}_{config['dataset']}_{config['subset']}"
    # Keep only the first archived run for each features/layer/dataset/subset combination.
    if unique_name in archives:
        continue
    # Runs without a correlation report are not archived.
    if 'corrs_report' not in run.summary:
        continue

    run_dir = os.path.join(DOWNLOAD_ROOT, run.id)
    wanted = {'results.pkl', 'optimal_weights.pkl'} if LOAD_OPTIMAL_WEIGHTS else {'results.pkl'}
    for file in run.files():
        if file.name in wanted:
            file.download(root=run_dir, replace=True)

    results_path = os.path.join(run_dir, 'results.pkl')
    if not os.path.exists(results_path):
        tqdm.write(f"Skipping {unique_name}: run {run.id} has no results.pkl")
        continue

    # NOTE: pickle.load can execute arbitrary code; only load artifacts from trusted runs.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    archive = {'corrs_report': run.summary['corrs_report'],
               'config': config,
               'results': results}

    weights_path = os.path.join(run_dir, 'optimal_weights.pkl')
    if LOAD_OPTIMAL_WEIGHTS and os.path.exists(weights_path):
        with open(weights_path, 'rb') as f:
            archive['optimal_weights'] = pickle.load(f)

    archives[unique_name] = archive

In [None]:
from matplotlib.patches import Patch
import numpy as np
import tables

def draw_tuning_circles(ax, f, rg, ri=.4, ro=1.0, spacing=2.2, nsegments=16, cmap=None):
    """Draw a grid of ring-shaped tuning glyphs onto ``ax``.

    Each glyph is a ring split into 8 wedges. Wedge ``i`` spans angles
    ``(i - 0.5) * 2*pi/8`` to ``(i + 0.5) * 2*pi/8`` and is coloured by
    ``f[k, j, i]``. Glyph ``(k, j)`` is centred at ``(spacing * j, -spacing * k)``,
    so axis 0 of ``f`` runs top-to-bottom and axis 1 runs left-to-right.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes the wedge polygons are added to.
    f : array, shape (rows, cols, 8)
        Values used to colour the wedges.
    rg : pair of float
        ``(vmin, vmax)`` for the colour normalisation.
    ri, ro : float
        Inner and outer radius of each ring.
    spacing : float
        Distance between neighbouring glyph centres.
    nsegments : int
        Number of straight segments approximating each wedge's outer arc
        (the inner edge is a single straight chord).
    cmap : matplotlib colormap, optional
        Defaults to ``plt.cm.RdBu_r``.
    """
    assert f.shape[-1] == 8
    if cmap is None:
        cmap = plt.cm.RdBu_r
    norm = plt.Normalize(rg[0], rg[1])
    nwedges = f.shape[-1]
    wedge_width = 2 * np.pi / nwedges
    # Arc fractions 1/nsegments .. 1; the arc start is already the second polygon vertex.
    arc_steps = np.arange(1, nsegments + 1) / nsegments

    # k indexes axis 0 (rows) and j indexes axis 1 (columns), matching f[k, j, i],
    # so non-square grids are drawn correctly too.
    for k in range(f.shape[0]):
        for j in range(f.shape[1]):
            cx, cy = spacing * j, -spacing * k
            for i in range(nwedges):
                theta1 = (i - .5) * wedge_width
                theta2 = (i + .5) * wedge_width
                # Inner and outer vertex at theta2, outer arc from theta2 to theta1, then
                # the inner vertex at theta1; the polygon closes with a straight inner chord.
                points = [[cx + ri * np.cos(theta2), cy + ri * np.sin(theta2)],
                          [cx + ro * np.cos(theta2), cy + ro * np.sin(theta2)]]
                for theta in theta2 + (theta1 - theta2) * arc_steps:
                    points.append([cx + ro * np.cos(theta), cy + ro * np.sin(theta)])
                points.append([cx + ri * np.cos(theta1), cy + ri * np.sin(theta1)])
                ax.add_patch(plt.Polygon(points, facecolor=cmap(norm(f[k, j, i])),
                                         edgecolor='lightgray'))
                
def draw_supertune(y):
    """Plot a (3, 3, 3, 8) tuning array as three side-by-side panels of tuning circles.

    Panel ``i`` shows ``y[:, :, i, :]``, titled 'translation', 'spirals' and
    'deformation' respectively. It is transposed so that axis 1 of ``y`` runs down
    the rows and axis 0 runs across the columns.

    All panels share one colour range ``[median - dx, median + dx]``, where
    ``dx = y.max() - median``. Values further below the median than the maximum
    lies above it are therefore clipped to the lowest colour.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. It is also pyplot's current figure, so ``plt.suptitle`` and
        ``plt.show`` called afterwards act on it.
    """
    assert y.shape == (3, 3, 3, 8)
    fig, axes = plt.subplots(1, 3, figsize=(10, 3.3))

    median = np.median(y)
    dx = y.max() - median
    rg = [median - dx, median + dx]

    titles = ['translation', 'spirals', 'deformation']
    for i, (ax, title) in enumerate(zip(axes, titles)):
        draw_tuning_circles(ax, y[:, :, i, :].transpose((1, 0, 2)), rg)
        # 'equal' re-autoscales to the drawn patches with a 1:1 aspect; explicit
        # xlim/ylim or 'tight' set before it would be overridden anyway.
        ax.axis('equal')
        ax.axis('off')
        ax.set_title(title)
    return fig

# Recorded responses shown as the "Real data" reference panel; adjust for your machine.
REAL_DATA_PATH = '/mnt/e/data_derived/packlab-mst/ku259.h5'
# Number of entries in a full (3, 3, 3, 8) tuning array.
N_CONDITIONS = 3 * 3 * 3 * 8

with tables.open_file(REAL_DATA_PATH) as table:
    y = table.get_node('/Y_report')[:]
draw_supertune(y.reshape((3, 3, 3, 8)))
plt.suptitle('Real data')
plt.show()

for k, archive in archives.items():
    if 'optimal_weights' not in archive:
        print(f"Skipping {k}: no 'optimal_weights' stored in its archive")
        continue
    data = archive['optimal_weights']['Y_preds'].cpu().numpy().ravel()
    # Zero-pad the flattened predictions up to a full tuning array (the original code
    # appended exactly one 0). NOTE(review): confirm which condition the missing
    # prediction corresponds to.
    data = np.pad(data, (0, N_CONDITIONS - data.size))

    draw_supertune(data.reshape(3, 3, 3, 8))
    plt.suptitle(f'Modeled data, {k}', y = 1.0)
    plt.show()

In [None]:
# Reading a .mat file with PyTables requires the HDF5-based (v7.3) MATLAB format.
with tables.open_file('/mnt/e/data_derived/packlab-mst/ku259.mat') as mat:
    Yall = mat.get_node('/Yall_st')[:]

# Signal-power estimate in the style of Sahani & Linden (2003):
#   SP = (N * Var(mean response) - mean over repeats of Var(single response)) / (N - 1)
# NOTE(review): assumes axis 0 of Yall indexes repeats and axis 1 indexes stimuli —
# confirm, since MATLAB arrays read through HDF5 come back with dimensions reversed.
n_repeats = Yall.shape[0]
mean_response = Yall.mean(0)
response_power = mean_response.var()
signal_power = (n_repeats * response_power - Yall.var(1).mean()) / (n_repeats - 1)

# Ratio of the trial-averaged response power to the estimated signal power (sqrt).
np.sqrt(response_power / signal_power)

In [None]:
import matplotlib.pyplot as plt
# Keys are built in the loading cell as f"{features}_layer{layer:02}_{dataset}_{subset}",
# and only dataset 'mst_norm_airsim' with subset '35' is queried there.
example_key = 'airsim_04_layer06_mst_norm_airsim_35'
example_preds = archives[example_key]['optimal_weights']['Y_preds'].cpu().numpy()
plt.plot(example_preds)
plt.title(f'Y_preds for {example_key}')
plt.xlabel('prediction index')
plt.ylabel('predicted response')
plt.show()