In [None]:
import wandb
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import scipy.stats as st
import pickle

from egg.palettes import palettes

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Select seaborn's 'talk' plotting context for the figures drawn below.
sns.set_context('talk')

def plot(x, y, ax=None, c='#1E88E5', text_y=40):
    """Scatter ``y`` against ``x`` with an identity line and a robust linear fit.

    A Huber regression through the origin (no intercept) is fitted to the
    points. Its slope is drawn as a solid black line from x=0 to x=50 and
    written on the axes as ``r(x) = <slope>x``. The dashed gray line is the
    identity ``y = x`` over the same range.

    Parameters
    ----------
    x, y : array-like
        One-dimensional values of equal length (pandas Series or NumPy arrays).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; the current axes are used when omitted.
    c : color
        Colour of the scatter points.
    text_y : float
        Vertical data coordinate of the slope annotation (placed at x=5).
    """
    # The notebook's import cell only brings in LinearRegression, which left
    # HuberRegressor undefined; import it here so the function is self-contained.
    from sklearn.linear_model import HuberRegressor

    if ax is None:
        ax = plt.gca()

    # np.asarray accepts Series and plain arrays alike; the regressor needs a
    # 2-D (n_samples, 1) design matrix.
    features = np.asarray(x).reshape(-1, 1)
    targets = np.asarray(y)

    huber = HuberRegressor(fit_intercept=False, epsilon=1.1)
    huber.fit(features, targets)
    coef = huber.coef_[0]

    ax.plot([0, 50], [0, 50], c='tab:gray', ls='--', lw=2)  # identity line
    ax.plot([0, 50], [0, 50 * coef], c='k', lw=2)  # fitted slope
    ax.scatter(x, y, s=30, c=c, zorder=10)
    ax.axis('equal')
    ax.text(5, text_y, f"$r(x) = {coef:.2f}x$")

def get_scores(run_id):
    """Return the logged history of run ``run_id`` in the ``sinzlab/egg`` W&B project.

    The result is whatever ``run.history()`` returns for that run.
    """
    run_path = f'sinzlab/egg/{run_id}'
    return wandb.Api().run(run_path).history()


# Task-Driven ResNet + Gaussian Readout Model

## Get the data
Get the EGG MEIs (DIMEs) and GA MEIs

For each unit, keep only the seed with the best `train` score.

In [None]:
# EGG run: for every unit keep the row with the smallest 'train' value.
history = get_scores('dxuyo5r1')
idx = history.groupby(['unit_idx'])['train'].idxmin()
dimes = history.loc[idx, ['seed', 'unit_idx', 'train', 'val', 'cross-val', 'image']]

# GA run: same selection. `idx` is reused and from here on holds row labels of
# `mei_history`, not of `history`.
# NOTE(review): the GA scores are used without a sign flip further down, while
# the EGG scores are negated -- confirm that idxmin picks the best seed here too.
mei_history = get_scores('h83eq1s8')
idx = mei_history.groupby(['unit_idx'])['train'].idxmin()
meis = mei_history.loc[idx, ['seed', 'unit_idx', 'train', 'val', 'cross-val', 'image']]

# Right join keeps every unit of the GA table; shared column names get the
# suffix _d (EGG) or _m (GA).
mei_dime = dimes.merge(meis, on='unit_idx', how='right', suffixes=['_d', '_m'])
# NOTE(review): `data_driven_corrs` is not defined in the visible part of this
# notebook; on a fresh kernel this line raises NameError unless it is created elsewhere.
mei_dime = mei_dime.merge(data_driven_corrs, left_on='unit_idx', right_on='unit_id')

## Compare activations
*Note: the EGG scores (`val_d` and `cross-val_d`) are stored negated, so they are sign-flipped before plotting and testing.*

In [None]:
# Two panels comparing GA (x axis) with EGG (y axis) activations; the EGG
# columns are sign-flipped before plotting.
tick_positions = [0, 10, 20, 30, 40, 50]
fig, (ax_cross, ax_within) = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: cross-validated activations, with labelled ticks.
plot(mei_dime['cross-val_m'], -mei_dime['cross-val_d'], ax=ax_cross)
ax_cross.set_xlim(0, 40)
ax_cross.set_ylim(0, 40)
ax_cross.set_title('Cross')
ax_cross.set_xlabel('GA')
ax_cross.set_ylabel('EGG')
ax_cross.set_yticks(tick_positions)
ax_cross.set_yticklabels(tick_positions)
ax_cross.set_xticks(tick_positions)
ax_cross.set_xticklabels(tick_positions)

# Right panel: within-model activations, same ticks but without labels.
plot(mei_dime['val_m'], -mei_dime['val_d'], ax=ax_within, c=palettes['candy']['blue'])
ax_within.set_xlim(0, 40)
ax_within.set_ylim(0, 40)
ax_within.set_title('Within')
ax_within.set_yticks(tick_positions)
ax_within.set_yticklabels([''] * 6)
ax_within.set_xticks(tick_positions)
ax_within.set_xticklabels([''] * 6)

sns.despine(fig=fig, trim=True)
fig.savefig('./activations.png', dpi=150, bbox_inches='tight')

Check the means

In [None]:
# Mean activation per condition: EGG (sign-flipped) first, GA second.
within_means = (np.mean(-mei_dime['val_d']), np.mean(mei_dime['val_m']))
print("Within:", *within_means)
cross_means = (np.mean(-mei_dime['cross-val_d']), np.mean(mei_dime['cross-val_m']))
print("Cross:", *cross_means)

Check whether the differences are significant (Wilcoxon signed-rank test)

In [None]:
# Paired Wilcoxon tests of EGG (sign-flipped) against GA activations.
within_result = st.wilcoxon(-mei_dime['val_d'], mei_dime['val_m'])
print("Within:", within_result)
cross_result = st.wilcoxon(-mei_dime['cross-val_d'], mei_dime['cross-val_m'])
print("Cross:", cross_result)

## Plot MEI examples

In [None]:
# Ratio of EGG to GA cross-validated activation (the EGG column is sign-flipped first).
mei_dime['ratio'] = -mei_dime['cross-val_d'] / mei_dime['cross-val_m']

# The former `images = history.loc[idx, ...]` line was removed: `idx` holds row
# labels of `mei_history` at this point, and the result was overwritten below
# without ever being used.

# Only relevant if the selection is switched back to random sampling
# (`mei_dime.sample(10)`); kept so the global RNG state is unchanged.
np.random.seed(1)
# Ten units with the largest ratio, skipping the single largest one.
images = mei_dime.sort_values('ratio', ascending=False).head(11).iloc[1:]

api = wandb.Api()
run_d = api.run('sinzlab/egg/dxuyo5r1')
run_m = api.run('sinzlab/egg/h83eq1s8')

# Download the logged image of each selected unit from both runs and open it.
imgs_m = []
imgs_d = []
for _, row in images.iterrows():
    # `download` returns a file object; `.name` is its local path.
    local_m = run_m.file(row['image_m']['path']).download(exist_ok=True).name
    local_d = run_d.file(row['image_d']['path']).download(exist_ok=True).name
    imgs_m.append(Image.open(local_m))
    imgs_d.append(Image.open(local_d))

In [None]:
# Concatenate each method's example images side by side into one strip.
img_d = np.hstack(imgs_d)
img_m = np.hstack(imgs_m)

# EGG strip on top, GA strip below, with no axes and no padding.
plt.figure(figsize=(18, 3))
for row_number, strip in enumerate((img_d, img_m), start=1):
    plt.subplot(2, 1, row_number)
    plt.imshow(strip)
    plt.axis(False)

plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)

plt.savefig('./diffmeis.png', dpi=150, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Save the GA strip on its own, without axes or padding.
fig = plt.gcf()
ax = fig.gca()
ax.imshow(img_m)
ax.axis(False)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
ax.margins(0, 0)
fig.savefig('./diffmei_examples.png', dpi=150, bbox_inches='tight', pad_inches=0)

## Check compute performance

In [None]:
# Mean and standard error of the logged 'time' column for both methods.
# Rows without a logged value (NaN) are dropped first: `mean()` and `std()`
# skip them, so dividing by `len(time)` would have understated the standard
# error. `Series.sem()` divides by the number of valid observations instead.
api = wandb.Api()

# EGG timing run
run = api.run('sinzlab/egg/fszcg6wz')
time = run.history()['time'].dropna()
mu_d, se_d = time.mean(), time.sem()

# GA timing run
run = api.run('sinzlab/egg/vjcc5k8r')
time = run.history()['time'].dropna()
mu_m, se_m = time.mean(), time.sem()

In [None]:
# Bar chart of mean generation time per method, with standard-error bars.
sns.set_context('talk')
plt.figure(figsize=(3, 3), dpi=150, facecolor='w')
plt.bar(
    [0, 1],
    [mu_d, mu_m],
    yerr=[se_d, se_m],
    color=[palettes['candy']['yellow'], palettes['candy']['green']],
    edgecolor='k',
    linewidth=3,
)

plt.xlim(-0.5, 1.5)

# Empty labels at the axis limits keep the trimmed spine spanning the full width.
plt.xticks([-0.5, 0, 1, 1.5], ['', 'EGG', 'GA', ''])
plt.yticks([0, 50, 100, 150, 200, 250])

# '\\pm' yields the same text as before (backslash + 'pm' for mathtext) without
# relying on the invalid escape sequence '\p'.
plt.text(0, mu_d + 5, f"{mu_d:.0f}s $\\pm$ {se_d:.2f}", horizontalalignment='center', fontsize=11)
plt.text(1, mu_m + 5, f"{mu_m:.0f}s $\\pm$ {se_m:.2f}", horizontalalignment='center', fontsize=11)

plt.ylabel('Generation time (s)')

sns.despine(trim=True)
plt.savefig('./performance.pdf', dpi=150, bbox_inches='tight')

# Data-Driven CNN + Attention Readout Model

## Get Data

In [None]:
# EGG run: for every unit keep the row with the smallest 'train' value.
dime_history = get_scores('ccoztu9h')
idx = dime_history.groupby(['unit_idx'])['train'].idxmin()
dimes = dime_history.loc[idx, ['seed', 'unit_idx', 'train', 'val', 'cross-val', 'image']]

# GA run: same selection; `idx` is reused and now refers to rows of `mei_history`.
# NOTE(review): the GA scores are used without a sign flip further down, while
# the EGG scores are negated -- confirm that idxmin picks the best seed here too.
mei_history = get_scores('jk3fgqnn')
idx = mei_history.groupby(['unit_idx'])['train'].idxmin()
meis = mei_history.loc[idx, ['seed', 'unit_idx', 'train', 'val', 'cross-val', 'image']]

# Right join keeps every unit of the GA table; shared column names get the
# suffix _d (EGG) or _m (GA). These assignments replace the `dimes`, `meis`
# and `mei_dime` tables built for the first model.
mei_dime = dimes.merge(meis, on='unit_idx', how='right', suffixes=['_d', '_m'])
# NOTE(review): `data_driven_corrs` is not defined in the visible part of this
# notebook; on a fresh kernel this line raises NameError unless it is created elsewhere.
mei_dime = mei_dime.merge(data_driven_corrs, left_on='unit_idx', right_on='unit_id')

## Compare activations

In [None]:
# Two panels comparing GA (x axis) with EGG (y axis) activations for the
# attention-readout model; the EGG columns are sign-flipped before plotting.
tick_positions = [0, 10, 20, 30, 40, 50]
fig, (ax_cross, ax_within) = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: cross-validated activations, with labelled ticks.
plot(mei_dime['cross-val_m'], -mei_dime['cross-val_d'], ax=ax_cross, text_y=40, c=palettes['candy']['pink'])
ax_cross.set_xlim(0, 50)
ax_cross.set_ylim(0, 50)
ax_cross.set_title('Cross')
ax_cross.set_xlabel('GA')
ax_cross.set_ylabel('EGG')
ax_cross.set_yticks(tick_positions)
ax_cross.set_yticklabels(tick_positions)
ax_cross.set_xticks(tick_positions)
ax_cross.set_xticklabels(tick_positions)

# Right panel: within-model activations, same ticks but without labels.
plot(mei_dime['val_m'], -mei_dime['val_d'], ax=ax_within, c=palettes['candy']['pink'], text_y=40)
ax_within.set_xlim(0, 50)
ax_within.set_ylim(0, 50)
ax_within.set_title('Within')
ax_within.set_yticks(tick_positions)
ax_within.set_yticklabels([''] * 6)
ax_within.set_xticks(tick_positions)
ax_within.set_xticklabels([''] * 6)

sns.despine(fig=fig, trim=True)
fig.savefig('./activations_attn.png', dpi=150, bbox_inches='tight')

Check the means

In [None]:
# Mean activation per condition: EGG (sign-flipped) first, GA second.
within_means = (np.mean(-mei_dime['val_d']), np.mean(mei_dime['val_m']))
print("Within:", *within_means)
cross_means = (np.mean(-mei_dime['cross-val_d']), np.mean(mei_dime['cross-val_m']))
print("Cross:", *cross_means)

Check whether the differences are significant (Wilcoxon signed-rank test)

In [None]:
# Paired Wilcoxon signed-rank tests of EGG (sign-flipped) against GA activations.
print("Within:", st.wilcoxon(-mei_dime['val_d'], mei_dime['val_m']))
# Both samples go to wilcoxon: a misplaced parenthesis previously ran a
# one-sample test on the EGG values and printed the GA mean next to it.
print("Cross:", st.wilcoxon(-mei_dime['cross-val_d'], mei_dime['cross-val_m']))

## Get examples

In [None]:
images = mei_dime
images_m = images['image_m']
images_d = images['image_d']

api = wandb.Api()
run_m = api.run(f'sinzlab/egg/jk3fgqnn')
run_d = api.run(f'sinzlab/egg/ccoztu9h')


def _fetch_image(run, entry):
    """Download the file at ``entry['path']`` from ``run`` and return it as an array."""
    local_path = run.file(entry['path']).download(exist_ok=True).name
    return np.array(Image.open(local_path))


# GA images first, then EGG images, one array per unit.
imgs_m = [_fetch_image(run_m, entry) for entry in images_m]
imgs_d = [_fetch_image(run_d, entry) for entry in images_d]

# Arrange both sets as a 9 x 10 grid of 480 x 480 images with 4 channels.
# NOTE(review): the hard-coded shape only works for exactly 90 images of that size.
imgs_d = np.array(imgs_d).reshape(9, 10, 480, 480, 4)
imgs_m = np.array(imgs_m).reshape(9, 10, 480, 480, 4)

In [None]:
# (row, column) grid positions of the hand-picked example units, in display order.
picked_cells = [(7, 3), (6, 7), (5, 0), (5, 1), (5, 8), (4, 6), (8, 0), (4, 8), (6, 5), (0, 5)]

# One horizontal strip per method, then EGG stacked above GA.
selected_d = np.hstack([imgs_d[row][col] for row, col in picked_cells])
selected_m = np.hstack([imgs_m[row][col] for row, col in picked_cells])
selected = np.vstack([selected_d, selected_m])

In [None]:
# Show the selected examples without axes or padding and save the figure.
fig = plt.figure(dpi=150, figsize=(18, 3))
ax = fig.gca()
ax.imshow(selected)

ax.axis(False)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
ax.margins(0, 0)
fig.savefig('./diffmei_attn.png', dpi=150, bbox_inches='tight', pad_inches=0)