In [None]:
import os, sys
import torch
import numpy as np
from scipy import ndimage
from skimage.filters import butterworth
import skimage.filters as filters
from skimage.filters import threshold_otsu

import pingouin as pg
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../")
from models.network_hierarchical_recurrent   import NetworkHierarchicalRecurrent
from models.network_hierarchical_feedforward import NetworkHierarchicalFeedforward
from data.dataset                            import data_loader

from plotting_functions import *

# Loader used for the "in-distribution" evaluation: 'validation' split, batches of
# 200, fixed order (shuffle=False).
# NOTE(review): the dataset directory is an empty placeholder and must be filled in.
indist_data_loader = data_loader(
    '', # dir path
    split='validation',
    batch_size=200,
    shuffle=False
)

# Loader used for the "out-of-distribution" evaluation.
# NOTE(review): this reads the 'train' split — presumably of a different dataset
# directory than the one above; confirm once the paths are filled in.
outdist_data_loader = data_loader(
    '', # dir path
    split='train',
    batch_size=200,
    shuffle=False
)

# Plotting functions

In [None]:
def plot_broken_bar (x_label, mean, error, facecolor, offset=-1):
    """Draw a bar chart with a broken y-axis and display it.

    The same bars are drawn on two stacked axes. The upper axis is limited to
    the range of ``mean[offset:]`` (plus/minus the largest error in that slice),
    the lower axis runs from ``min(mean)`` to ``max(mean[:offset])``; both are
    then widened to the same span so the two panels share a scale.

    Args:
        x_label: tick labels, one per bar.
        mean: bar heights.
        error: error-bar lengths, same length as ``mean``.
        facecolor: bar colour passed to ``Axes.bar``.
        offset: slice index; ``mean[offset:]`` sets the range of the upper axis.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Note:
        A later cell of this notebook defines ``plot_broken_bar`` again (other
        margins, returns the figure instead of showing it). That definition
        replaces this one as soon as its cell has run.
    """
    x = np.arange(len(x_label))

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.bar(x, mean, yerr=error, facecolor=facecolor)
    ax2.bar(x, mean, yerr=error, facecolor=facecolor)
    ax2.set_xticks(x)
    ax2.set_xticklabels(x_label, rotation=45)

    format_plot(ax1, fontsize=20)
    format_plot(ax2, fontsize=20)

    top_err = max(error[offset:])
    ax1.set_ylim(min(mean[offset:])-top_err-0.03, max(mean[offset:])+top_err+0.03)
    ax2.set_ylim(min(mean)-0.05, max(mean[:offset])+0.05)

    # Hide the facing spines and the upper ticks so the two axes read as one plot
    ax1.spines['bottom'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_xticks([])
    ax1.tick_params(labeltop=False)
    ax2.xaxis.tick_bottom()

    # Widen the narrower axis symmetrically so both panels cover the same y-span
    ax_selec = [(ax, ax.get_ylim()) for ax in [ax1, ax2]]
    max_delta = max([lmax-lmin for _, (lmin, lmax) in ax_selec])
    for ax, (lmin, lmax) in ax_selec:
        pad = (max_delta-(lmax-lmin))/2
        ax.set_ylim(lmin-pad, lmax+pad)

    # Diagonal break marks on the left edge (axes coordinates)
    d = .015
    kwargs = dict(transform=ax1.transAxes, color='k', clip_on=False)
    ax1.plot((-d, +d), (-d, +d), **kwargs)
    kwargs.update(transform=ax2.transAxes)
    ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)

    fig.set_size_inches(4,4)
    plt.show()

# Model path data

In [None]:
# Paths to trained model checkpoints here
#
# Both dictionaries are passed to get_mse_by_model / get_mse_by_model_bar, which
# iterate them as {name: {L1: folder}}: the outer key becomes the label of the
# result, and every inner item maps an L1 value to a checkpoint folder that is
# expanded to f'{root}{folder}/{epoch}-epochs_model.pt'.
# The inner dictionaries are empty placeholders; those functions take np.min over
# the per-checkpoint results, so each key needs at least one entry before running.

# Outer keys are numeric because they are later used directly as x-values
# in the "Beta" plots.
model_beta = {
    0   : {},
    0.05: {},
    0.1 : {},
    0.25: {},
    0.5 : {}
}
# Outer keys are used as bar labels; insertion order is the plotting order.
model_architecture = {
    'Hierarchical': {},
    'Vanilla RNN' : {},
    'Feedforward' : {}
}

# Next-frame prediction

In [None]:
def filter_mse (data):
    """Keep only the samples whose mean squared frame-to-frame change exceeds 1.

    ``data[0]`` is differenced along axis 1, squared and averaged over axes 2
    and 1, which leaves one score per entry along axis 0. ``data[0]`` and
    ``data[1]`` are both indexed with the boolean mask ``score > 1``.

    Returns:
        Tuple ``(data[0][mask], data[1][mask])``.
    """
    frame_delta = np.diff(data[0], axis=1)
    motion = np.square(frame_delta).mean(axis=2).mean(axis=1)
    keep = motion > 1

    return data[0][keep], data[1][keep]

def low_pass (data, freq):
    """Apply a Butterworth low-pass filter to every frame in ``data``.

    ``data`` is walked three levels deep. Each innermost tensor is detached,
    moved to the CPU, reshaped to 20x20, filtered with
    ``butterworth(..., cutoff_frequency_ratio=freq, high_pass=False)`` and
    flattened again.

    Returns:
        A list with one tensor per top-level element of ``data``.

    Note:
        Another ``low_pass`` is defined directly after this one in the same
        cell and replaces it when the cell runs.
    """
    def _smooth (frame):
        image = frame.detach().cpu().numpy().reshape(20, 20)
        return butterworth(image, cutoff_frequency_ratio=freq, high_pass=False).reshape(-1)

    filtered = []
    for d in data:
        rows = [[_smooth(i) for i in c] for c in d]
        filtered.append(torch.tensor(rows))

    return filtered

def low_pass (data, freq):
    """Binarise ``data`` with one Otsu threshold per second-level element.

    For every ``c`` in every ``d`` of ``data`` a threshold is computed with
    ``threshold_otsu(c)`` and each ``f`` in ``c`` is replaced by ``f > thresh``.
    The nested result is converted to a single tensor, whose shape is printed.

    Args:
        data: nested iterable (iterated as ``data -> d -> c -> f``).
        freq: unused; kept so the signature matches the Butterworth
            ``low_pass`` defined just before this one, which this definition
            replaces.

    Returns:
        ``torch.tensor`` built from the nested comparison results.
    """
    binarised = []
    for d in data:
        rows = []
        for c in d:
            thresh = threshold_otsu(c)
            rows.append([f>thresh for f in c])
        binarised.append(rows)

    # Build the tensor once. The original wrote `print(tensor).shape`, which
    # raises AttributeError because print() returns None.
    out = torch.tensor(binarised)
    print(out.shape)

    return out

In [None]:
def get_mse (model, data_loader, max_batches=20):
    """Evaluate ``model`` on the first batches of ``data_loader``.

    For each batch the model is called on ``data[0]`` under ``torch.no_grad()``
    and ``model.get_loss(out, data)`` is queried; the ``'mse0'`` entry of the
    returned loss components is recorded (one value per batch).

    Args:
        model: object providing ``eval()``, ``__call__`` and ``get_loss``.
        data_loader: iterable of batches; ``batch[0]`` is the model input.
        max_batches: number of batches to evaluate (default 20, the value that
            used to be hard-coded); ``None`` evaluates the whole loader.

    Returns:
        ``(mean, sem, values)`` where ``values`` is the list of per-batch
        ``'mse0'`` values and ``sem`` is ``std / sqrt(len(values))``, i.e. the
        standard error is taken over batches, not over individual samples.
    """
    mse_arr = []

    model.eval()
    for batch_n, data in enumerate(data_loader):
        # Stop before announcing a batch that will not be evaluated
        if max_batches is not None and batch_n >= max_batches:
            break

        if batch_n%5==0:
            print('\tStarting batch', batch_n)

        y = data[0]

        with torch.no_grad():
            out = model(y)
            _, loss_components = model.get_loss(out, data)

        mse_arr.append(loss_components['mse0'].tolist())

    return np.mean(mse_arr), np.std(mse_arr)/(len(mse_arr)**0.5), mse_arr

def get_mse_copy_last_frame (data_loader):
    """MSE of the "copy the previous frame" baseline over a whole loader.

    For every batch, ``data[0]`` is treated as ``(sample, time, feature)`` and
    each frame is compared with the frame before it. The squared difference is
    averaged over time and features, giving one value per sample.

    Args:
        data_loader: iterable of batches; ``batch[0]`` is a tensor indexed as
            ``[sample, time, feature]``.

    Returns:
        ``(mean, sem, values)`` with one entry in ``values`` per sample and
        ``sem = std / sqrt(len(values))``.
        Unlike ``get_mse`` there is no cap on the number of batches, and the
        standard error is taken over samples rather than batches.
    """
    n_features = 400  # only the first 400 features of each frame are compared
    mse_arr = []

    for batch_n, data in enumerate(data_loader):
        #data = filter_mse(data)

        if batch_n%5==0:
            print('\tStarting batch', batch_n)

        # Slice both operands. The original sliced only the "previous frame"
        # side, which fails to broadcast when the feature axis is wider than 400.
        frames = data[0][:, :, :n_features]
        sq_err = ((frames[:, 1:]-frames[:, :-1])**2).detach().cpu().numpy()

        mse_arr += sq_err.mean(axis=1).mean(axis=1).tolist()

    return np.mean(mse_arr), np.std(mse_arr)/(len(mse_arr)**0.5), mse_arr

def get_mse_by_model (model_data, data_loader, plot_by_L1=False, root='', epoch=2000):
    """Evaluate every checkpoint in ``model_data`` and keep the best per model.

    Args:
        model_data: ``{model_name: {L1: folder}}``. Each folder is expanded to
            ``f'{root}{folder}/{epoch}-epochs_model.pt'``.
        data_loader: loader passed to ``get_mse`` and to
            ``get_mse_copy_last_frame``.
        plot_by_L1: if True, plot MSE against L1 for every model.
        root: prefix prepended to every checkpoint folder.
        epoch: epoch number used in the checkpoint file name.

    Returns:
        Four parallel lists ``(names, means, sems, raw_values)``. For each
        model the entry with the lowest mean MSE across its checkpoints is
        kept. A final ``'Copy frame'`` entry holds the copy-last-frame baseline.

    Raises:
        ValueError: if a model has no checkpoints listed.
    """
    model_nm_arr = []
    model_mn_arr = []
    model_er_arr = []
    model_rw_arr = []

    for model_name, model_paths in model_data.items():
        if not model_paths:
            # np.min over an empty list would fail with an unhelpful message
            raise ValueError(f'No checkpoints listed for model {model_name!r}')

        L1_arr = []
        mn_arr = []
        er_arr = []
        rw_arr = []

        for L1, folder in model_paths.items():
            file = f'{root}{folder}/{epoch}-epochs_model.pt'
            print(file)

            load_kwargs = dict(model_path=file, device='cpu', plot_loss_history=False)
            if 'feedforward' in file:
                model, _, _ = NetworkHierarchicalFeedforward.load(**load_kwargs)
            else:
                try:
                    model, _, _ = NetworkHierarchicalRecurrent.load(**load_kwargs)
                except Exception:
                    # Same fallback as before: a checkpoint the recurrent class
                    # cannot load is retried as a feedforward model. Catching
                    # Exception (not a bare except) lets Ctrl-C through.
                    model, _, _ = NetworkHierarchicalFeedforward.load(**load_kwargs)

            mn, er, rw = get_mse(model, data_loader)
            L1_arr.append(L1)
            mn_arr.append(mn)
            er_arr.append(er)
            rw_arr.append(rw)

        if plot_by_L1:
            plt.errorbar(L1_arr, mn_arr, yerr=er_arr)
            plt.xlabel('L1')
            plt.ylabel('MSE')
            format_plot()
            plt.show()

        best = int(np.argmin(mn_arr))
        model_nm_arr.append(model_name)
        model_mn_arr.append(mn_arr[best])
        model_er_arr.append(er_arr[best])
        model_rw_arr.append(rw_arr[best])

    copy_mn, copy_er, copy_rw = get_mse_copy_last_frame (data_loader)
    model_nm_arr.append('Copy frame')
    model_mn_arr.append(copy_mn)
    model_er_arr.append(copy_er)
    model_rw_arr.append(copy_rw)

    return model_nm_arr, model_mn_arr, model_er_arr, model_rw_arr

In [None]:
# Side length (in pixels) of the square frames used throughout this notebook
FRAME_SIZE = 20

def get_COM (im, rotation):
    """Centre of mass of a frame, projected along the rotated column axis.

    The frame is reshaped to ``FRAME_SIZE x FRAME_SIZE`` and raised to the 4th
    power (which makes all weights non-negative and emphasises the largest
    magnitudes). The weighted mean row and column indices are computed, then
    rotated by ``rotation`` degrees about the frame centre; only the rotated
    column coordinate is returned.

    Args:
        im: array with ``FRAME_SIZE**2`` elements.
        rotation: angle in degrees.

    Returns:
        Scalar rotated column coordinate.

    Raises:
        ZeroDivisionError: from ``np.average`` if the frame is all zeros.
    """
    im = im.reshape(FRAME_SIZE, FRAME_SIZE)**4
    centre = FRAME_SIZE/2

    # Centre-relative weighted mean index along each axis
    row = np.average(np.arange(0, im.shape[0]), weights=im.mean(axis=1)) - centre
    col = np.average(np.arange(0, im.shape[1]), weights=im.mean(axis=0)) - centre

    theta = np.deg2rad(rotation)

    return np.cos(theta) * col - np.sin(theta) * row + centre


def whiten_and_filter_image (im_to_filt):
    """Filter a 2-D image in the Fourier domain with ``rho * exp(-rho^2 / 2s^2)``.

    ``rho`` is the radial frequency index and ``s = 0.7 * max(Nx, Ny) / 2``.
    The linear ramp boosts high frequencies and zeroes the DC component; the
    Gaussian factor attenuates the highest frequencies.

    Args:
        im_to_filt: 2-D array of shape ``(Nx, Ny)``.

    Returns:
        Real-valued array of the same shape.
    """
    Nx, Ny = im_to_filt.shape
    spectrum = np.fft.fftshift(np.fft.fft2(im_to_filt))

    # Frequency index of every bin after fftshift (zero frequency at N//2).
    # For even sizes this equals arange(-N/2, N/2), so results are unchanged;
    # for odd sizes it keeps rho == 0 on the DC bin.
    fx = np.arange(Nx) - Nx//2
    fy = np.arange(Ny) - Ny//2
    fx, fy = np.meshgrid(fx, fy, indexing='ij')

    rho = np.sqrt(fx**2 + fy**2)
    filtf = rho*np.exp(-0.5*(rho/(0.7*max(Nx, Ny)/2))**2)

    # ifftshift is the exact inverse of fftshift (the two coincide only for
    # even sizes, which is why the previous fftshift call worked there)
    return np.real(np.fft.ifft2(np.fft.ifftshift(filtf*spectrum)))

def get_bar_stimuli (rotate=0, w=3):
    """Build a sequence of ``FRAME_SIZE`` frames showing a bar that shifts by one pixel per frame.

    Each frame is drawn on a ``2*FRAME_SIZE`` square canvas as a band of ``w``
    columns set to 1, passed through ``whiten_and_filter_image``, rotated by
    ``rotate`` degrees (canvas shape kept), and cropped to the central
    ``FRAME_SIZE x FRAME_SIZE`` window. The whole sequence is then normalised
    to zero mean and unit standard deviation.

    Args:
        rotate: rotation angle in degrees passed to ``ndimage.rotate``.
        w: bar width in columns.

    Returns:
        Float tensor with one flattened frame per row.
    """
    margin = FRAME_SIZE//2
    frames = []

    for shift in range(FRAME_SIZE):
        canvas = np.zeros((FRAME_SIZE*2, FRAME_SIZE*2))
        start = margin+shift
        canvas[:, start:start+w] = 1

        canvas = whiten_and_filter_image(canvas)
        canvas = ndimage.rotate(canvas, rotate, reshape=False)

        window = canvas[margin:-FRAME_SIZE//2, margin:-FRAME_SIZE//2]
        frames.append(window.reshape(-1))

    stack = np.array(frames)
    stack = (stack-np.mean(stack))/np.std(stack)

    return torch.from_numpy(stack).float()

def get_mse_bar_copy_last_frame (warmup=10, tsteps=9):
    """Copy-frame baseline on the moving-bar stimuli.

    For rotations 0, 45, ..., 315 degrees, frame ``warmup`` of the bar
    sequence is repeated ``tsteps`` times and compared with frames
    ``warmup .. warmup+tsteps-1``; the squared error is averaged per frame.

    NOTE(review): the first compared frame is the held frame itself, so its
    error is always 0 — confirm this is the intended baseline.
    NOTE(review): ``get_mse_bar_rnn`` steps the rotation by 22.5 degrees, not
    45 — confirm the comparison is meant to use different angle sets.

    Returns:
        ``(mean, sem, values)`` with ``sem = std / sqrt(len(values))``.
    """
    mse_arr = []

    for rotation in np.arange(0, 360, 45):
        frames = get_bar_stimuli(rotation)

        target = frames[warmup:warmup+tsteps]
        held   = frames[warmup:warmup+1].repeat((tsteps, 1))

        sq_err = ((target-held)**2).detach().cpu().numpy()
        mse_arr.extend(sq_err.mean(axis=1).tolist())

    mean = np.mean(mse_arr)
    sem  = np.std(mse_arr)/(len(mse_arr)**0.5)

    return mean, sem, mse_arr

def get_mse_bar_rnn (model, hierarchical=False, warmup=10, tsteps=9, plot=False):
    """Closed-loop multi-frame prediction of a recurrent model on moving bars.

    For rotations 0, 22.5, ..., 337.5 degrees the model is stepped through the
    bar sequence. While ``t+1 < warmup`` it receives the real frame; from then
    on it receives its own previous output. At every closed-loop step the
    output is compared with frame ``t+1``.

    Args:
        model: object with ``eval()``, ``rnn(input, hidden)`` returning two
            values (the first is used as the next hidden state) and ``fc``.
            The first ``FRAME_SIZE**2`` outputs of ``fc`` are the predicted frame.
        hierarchical: selects the initial hidden-state width,
            ``FRAME_SIZE**2 * 2 * 3`` if True, else ``FRAME_SIZE**2 * 2``.
        warmup: step from which the model is fed its own output. Must be at
            least 2, otherwise there is no previous output to feed back.
        tsteps: number of closed-loop steps scored.
        plot: if True, show target (top row) and prediction (bottom row) for
            every rotation. This replaces a block that was disabled with
            ``if False:``.

    Returns:
        ``(mse_mean, mse_sem, mse_values, com_mean, com_sem, com_values)``;
        ``com_*`` is the absolute difference of ``get_COM`` between target and
        prediction.

    NOTE(review): the feedforward and copy-frame bar evaluations in this
    notebook use a 45 degree rotation step instead of 22.5.
    """
    model.eval()
    mse_arr = []
    com_arr = []

    n_pixels = FRAME_SIZE*FRAME_SIZE
    hidden_size = n_pixels*2*3 if hierarchical else n_pixels*2

    for rotation in np.arange(0, 360, 22.5):
        x = get_bar_stimuli(rotation)
        o = []

        targ = []
        pred = []

        h = [torch.zeros((1, hidden_size))]

        for t, x_t in enumerate(x):
            if t+1>=warmup+tsteps:
                break

            closed_loop = t+1>=warmup
            # After warmup the network only sees its own previous prediction
            rnn_input = o[-1] if closed_loop else x_t

            h_t, _ = model.rnn(rnn_input.unsqueeze(0), h[-1])
            o_t    = model.fc(h_t)[:, :n_pixels]

            h.append(h_t)
            o.append(o_t[0])

            if closed_loop:
                targ.append(x[t+1].detach().numpy())
                pred.append(o_t[0].detach().numpy())

        mse_arr += [np.mean((tg-pr)**2) for tg, pr in zip(targ, pred)]
        com_arr += [np.abs(get_COM(tg, rotation)-get_COM(pr, rotation)) for tg, pr in zip(targ, pred)]

        if plot:
            # squeeze=False keeps axs two-dimensional even when tsteps == 1
            fig, axs = plt.subplots(nrows=2, ncols=tsteps, dpi=150, figsize=[tsteps, 2], squeeze=False)
            for i, (tg, pr) in enumerate(zip(targ, pred)):
                axs[0, i].imshow(tg.reshape(FRAME_SIZE, FRAME_SIZE))
                axs[0, i].set_xticks([])
                axs[0, i].set_yticks([])
                axs[1, i].imshow(pr.reshape(FRAME_SIZE, FRAME_SIZE))
                axs[1, i].set_xticks([])
                axs[1, i].set_yticks([])

                # targ/pred only hold post-warmup frames, so i is already the
                # offset (the old `i >= warmup` test could never be true)
                axs[1, i].set_xlabel(f't+{i}')

            plt.tight_layout()
            plt.show()

    mse_mn = np.mean(mse_arr)
    mse_er = np.std(mse_arr)/(len(mse_arr)**0.5)

    com_mn = np.mean(com_arr)
    com_er = np.std(com_arr)/(len(com_arr)**0.5)

    return mse_mn, mse_er, mse_arr, com_mn, com_er, com_arr

def get_mse_bar_ff (model, warmup=10, tsteps=9, plot=False):
    """Closed-loop multi-frame prediction of a feedforward model on moving bars.

    For rotations 0, 45, ..., 315 degrees the model is first run on the
    ``warmup`` leading frames. Then, ``tsteps`` times, the model is run on its
    own accumulated outputs and the last time step of the result is appended.
    The newly appended frame is compared with frame ``warmup+t+1``.

    Args:
        model: object with ``eval()``, a ``hidden_units_groups`` sequence and
            ``__call__``; ``model(x)[0][0]`` is indexed as
            ``[batch, time, ...]``. Only the first hidden-units group is used
            during the evaluation; the attribute is restored afterwards.
        warmup: number of real frames given to the model.
        tsteps: number of closed-loop steps scored; ``warmup+tsteps`` must stay
            below the sequence length (``FRAME_SIZE`` frames).
        plot: if True, show target (top row) and prediction (bottom row) for
            every rotation. This replaces a block that was disabled with
            ``if False:``.

    Returns:
        ``(mse_mean, mse_sem, mse_values, com_mean, com_sem, com_values)``.

    NOTE(review): the recurrent bar evaluation in this notebook steps the
    rotation by 22.5 degrees instead of 45.
    """
    model.eval()

    # Evaluate with the first group only, but do not leave the model modified
    all_groups = model.hidden_units_groups
    model.hidden_units_groups = all_groups[:1]

    mse_arr = []
    com_arr = []

    try:
        for rotation in np.arange(0, 360, 45):
            x = get_bar_stimuli(rotation).unsqueeze(0)

            targ = []
            pred = []

            p_t = model(x[:, :warmup])[0][0]

            for t in range(tsteps):
                p_t1 = model(p_t)[0][0][:, -1:]
                p_t = torch.cat([p_t, p_t1], dim=1)

                # Score the frame that was just predicted. The original read
                # p_t[0, 0], i.e. the first time step, which never changes
                # as new predictions are appended at the end.
                pred.append(p_t[0, -1].detach().numpy())
                targ.append(x[:, warmup+t+1].detach().numpy())

            mse_arr += [np.mean((tg-pr)**2) for tg, pr in zip(targ, pred)]
            com_arr += [np.abs(get_COM(tg, rotation)-get_COM(pr, rotation)) for tg, pr in zip(targ, pred)]

            if plot:
                # squeeze=False keeps axs two-dimensional even when tsteps == 1
                fig, axs = plt.subplots(nrows=2, ncols=tsteps, dpi=150, figsize=[tsteps, 2], squeeze=False)
                for i, (tg, pr) in enumerate(zip(targ, pred)):
                    axs[0, i].imshow(tg.reshape(FRAME_SIZE, FRAME_SIZE))
                    axs[0, i].set_xticks([])
                    axs[0, i].set_yticks([])
                    axs[1, i].imshow(pr.reshape(FRAME_SIZE, FRAME_SIZE))
                    axs[1, i].set_xticks([])
                    axs[1, i].set_yticks([])

                    # targ/pred only hold post-warmup frames, so i is already
                    # the offset (the old `i >= warmup` test was never true)
                    axs[1, i].set_xlabel(f't+{i}')

                plt.tight_layout()
                plt.show()
    finally:
        model.hidden_units_groups = all_groups

    mse_mn = np.mean(mse_arr)
    mse_er = np.std(mse_arr)/(len(mse_arr)**0.5)

    com_mn = np.mean(com_arr)
    com_er = np.std(com_arr)/(len(com_arr)**0.5)

    return mse_mn, mse_er, mse_arr, com_mn, com_er, com_arr

def get_mse_by_model_bar (model_data, plot_by_L1=False, copy_frame=True, root='', epoch=2000):
    """Run the moving-bar evaluation on every checkpoint and keep the best per model.

    Args:
        model_data: ``{model_name: {L1: folder}}``. Each folder is expanded to
            ``f'{root}{folder}/{epoch}-epochs_model.pt'``.
        plot_by_L1: if True, plot the bar MSE against L1 for every model.
        copy_frame: if True, append the copy-frame baseline to the name and
            MSE lists. The centre-of-mass lists get no baseline entry, so they
            are one element shorter in that case.
        root: prefix prepended to every checkpoint folder.
        epoch: epoch number used in the checkpoint file name.

    Returns:
        ``(names, mse_mean, mse_sem, mse_raw, com_mean, com_sem, com_raw)``.
        For each model the MSE entries come from the checkpoint with the
        lowest mean MSE and the centre-of-mass entries from the checkpoint with
        the lowest mean centre-of-mass error (these can be different checkpoints).

    Raises:
        ValueError: if a model has no checkpoints listed.
    """
    model_name_arr     = []
    model_mse_mn_arr   = []
    model_mse_er_arr   = []
    model_mse_raw_arr  = []
    model_com_mn_arr   = []
    model_com_er_arr   = []
    model_com_raw_arr  = []

    for model_name, model_paths in model_data.items():
        if not model_paths:
            # np.min over an empty list would fail with an unhelpful message
            raise ValueError(f'No checkpoints listed for model {model_name!r}')

        L1_arr      = []
        mse_mn_arr  = []
        mse_er_arr  = []
        mse_raw_arr = []
        com_mn_arr  = []
        com_er_arr  = []
        com_raw_arr = []

        for L1, folder in model_paths.items():
            file = f'{root}{folder}/{epoch}-epochs_model.pt'
            print(file)

            load_kwargs = dict(model_path=file, device='cpu', plot_loss_history=False)
            if 'feedforward' in file:
                model, _, _ = NetworkHierarchicalFeedforward.load(**load_kwargs)
            else:
                try:
                    model, _, _ = NetworkHierarchicalRecurrent.load(**load_kwargs)
                except Exception:
                    # Same fallback as before: a checkpoint the recurrent class
                    # cannot load is retried as a feedforward model. Catching
                    # Exception (not a bare except) lets Ctrl-C through.
                    model, _, _ = NetworkHierarchicalFeedforward.load(**load_kwargs)

            # Feedforward models are recognised by their `ih0` attribute
            if hasattr(model, 'ih0'):
                mse_mn, mse_er, mse_raw, com_mn, com_er, com_raw = get_mse_bar_ff(model)
            else:
                is_hierarchical = model.rnn.weight_ih_l0.shape[0] == FRAME_SIZE*FRAME_SIZE*2*3
                mse_mn, mse_er, mse_raw, com_mn, com_er, com_raw = get_mse_bar_rnn(model, hierarchical=is_hierarchical)

            L1_arr.append(L1)
            mse_mn_arr.append(mse_mn)
            mse_er_arr.append(mse_er)
            mse_raw_arr.append(mse_raw)
            com_mn_arr.append(com_mn)
            com_er_arr.append(com_er)
            com_raw_arr.append(com_raw)

        if plot_by_L1:
            # The original referenced undefined names (mse_arr / err_arr) here
            plt.errorbar(L1_arr, mse_mn_arr, yerr=mse_er_arr)
            plt.xlabel('L1')
            plt.ylabel('MSE')
            format_plot()
            plt.show()

        best_mse = int(np.argmin(mse_mn_arr))
        best_com = int(np.argmin(com_mn_arr))

        model_name_arr.append(model_name)
        model_mse_mn_arr.append(mse_mn_arr[best_mse])
        model_mse_er_arr.append(mse_er_arr[best_mse])
        model_mse_raw_arr.append(mse_raw_arr[best_mse])
        # The centre-of-mass error and raw values must come from the
        # centre-of-mass lists (they were previously read from the MSE lists)
        model_com_mn_arr.append(com_mn_arr[best_com])
        model_com_er_arr.append(com_er_arr[best_com])
        model_com_raw_arr.append(com_raw_arr[best_com])

    if copy_frame:
        copy_mn, copy_er, copy_rw = get_mse_bar_copy_last_frame (warmup=10, tsteps=9)
        model_name_arr.append('Copy frame')
        model_mse_mn_arr.append(copy_mn)
        model_mse_er_arr.append(copy_er)
        model_mse_raw_arr.append(copy_rw)

    return model_name_arr, model_mse_mn_arr, model_mse_er_arr, model_mse_raw_arr, model_com_mn_arr, model_com_er_arr, model_com_raw_arr

In [None]:
# Next-frame MSE of every architecture on the in-distribution loader.
# The returned lists end with the 'Copy frame' baseline.
model_name_arr, indist_mn_arr, indist_er_arr, indist_raw_arr = \
    get_mse_by_model (model_architecture, indist_data_loader, plot_by_L1=False)

# Same evaluation on the out-of-distribution loader (model_name_arr is overwritten
# with the same labels).
model_name_arr, outdist_mn_arr, outdist_er_arr, outdist_raw_arr = \
    get_mse_by_model (model_architecture, outdist_data_loader, plot_by_L1=False)

# Multi-frame moving-bar evaluation; the three centre-of-mass outputs are discarded.
bar_model_name_arr, bar_mn_arr, bar_er_arr, bar_raw_arr, _, _, _ = \
    get_mse_by_model_bar (model_architecture)

In [None]:
def plot_broken_bar (x_label, mean, error, facecolor, offset=-1):
    """Build a bar chart with a broken y-axis and return the figure.

    The same bars are drawn on two stacked axes. The upper axis is limited to
    the range of ``mean[offset:]`` (using the largest error in that slice), the
    lower axis runs from ``min(mean)`` to ``max(mean[:offset])``; both are then
    widened to the same span so the two panels share a scale.

    This definition replaces the ``plot_broken_bar`` defined in an earlier cell
    of this notebook (different margins and figure size; that one calls
    ``plt.show()`` and returns nothing).

    Args:
        x_label: tick labels, one per bar.
        mean: bar heights.
        error: error-bar lengths, same length as ``mean``.
        facecolor: bar colour passed to ``Axes.bar``.
        offset: slice index; ``mean[offset:]`` sets the range of the upper axis.

    Returns:
        The matplotlib figure (not shown; the caller decides when to display it).
    """
    x = np.arange(len(x_label))

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.bar(x, mean, yerr=error, facecolor=facecolor)
    ax2.bar(x, mean, yerr=error, facecolor=facecolor)
    ax2.set_xticks(x)
    ax2.set_xticklabels(x_label, rotation=45)

    format_plot(ax1, fontsize=20)
    format_plot(ax2, fontsize=20)

    top_err = max(error[offset:])
    ax1.set_ylim(min(mean[offset:])-top_err-0.08, max(mean[offset:])+top_err+0.0)
    ax2.set_ylim(min(mean)-0.075, max(mean[:offset])+0.025)

    # Hide the facing spines and the upper ticks so the two axes read as one plot
    ax1.spines['bottom'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.set_xticks([])
    ax1.tick_params(labeltop=False)
    ax2.xaxis.tick_bottom()

    # Widen the narrower axis symmetrically so both panels cover the same y-span
    ax_selec = [(ax, ax.get_ylim()) for ax in [ax1, ax2]]
    max_delta = max([lmax-lmin for _, (lmin, lmax) in ax_selec])
    for ax, (lmin, lmax) in ax_selec:
        pad = (max_delta-(lmax-lmin))/2
        ax.set_ylim(lmin-pad, lmax+pad)

    # Diagonal break marks on the left edge (axes coordinates)
    d = .015
    kwargs = dict(transform=ax1.transAxes, color='k', clip_on=False)
    ax1.plot((-d, +d), (-d, +d), **kwargs)
    kwargs.update(transform=ax2.transAxes)
    ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)

    fig.set_size_inches(3,4)

    return fig

# Next-frame MSE per architecture, in-distribution loader. With the default
# offset=-1 the last bar ('Copy frame') sets the range of the upper panel.
fig = plot_broken_bar (model_name_arr, indist_mn_arr, indist_er_arr, 'tab:gray')
plt.show()

# Next-frame MSE per architecture, out-of-distribution loader.
fig = plot_broken_bar (model_name_arr, outdist_mn_arr, outdist_er_arr, 'tab:gray')
plt.show()

# Multi-frame moving-bar MSE; offset=-2 puts the last two bars on the upper panel.
fig = plot_broken_bar (bar_model_name_arr, bar_mn_arr, bar_er_arr, 'tab:gray', offset=-2)
plt.show()

In [None]:
# Moving-bar evaluation across the beta sweep. The keys of model_beta come back as
# bar_model_name_arr and are used as numeric x-values below; no copy-frame baseline.
bar_model_name_arr, bar_mse_mn, bar_mse_er, bar_mse_raw, bar_com_mn, bar_com_er, _ = \
    get_mse_by_model_bar (model_beta, copy_frame=False)

# Multi-frame MSE against beta; the dashed line marks the value of the first entry
# (beta = 0 in model_beta) as a reference level.
fig = plt.figure()
plt.errorbar(bar_model_name_arr, bar_mse_mn, yerr=bar_mse_er, c='black')
plt.plot([0, bar_model_name_arr[-1]], [bar_mse_mn[0], bar_mse_mn[0]], '--', c='black')
plt.xlabel('Beta')
plt.ylabel('Multi-frame MSE')
format_plot(fontsize=20)
fig.set_size_inches(3, 4)
plt.show()

# Centre-of-mass error against beta, same layout as the plot above.
fig = plt.figure()
plt.errorbar(bar_model_name_arr, bar_com_mn, yerr=bar_com_er, c='black')
plt.plot([0, bar_model_name_arr[-1]], [bar_com_mn[0], bar_com_mn[0]], '--', c='black')
plt.xlabel('Beta')
plt.ylabel('Centre-of-mass MAE')
format_plot(fontsize=20)
fig.set_size_inches(3, 4)
plt.show()

# Both metrics on one figure with twin y-axes (MSE left, centre-of-mass right).
fig = plt.figure()
ax1 = plt.gca()
ax2 = ax1.twinx()

ax1.errorbar(bar_model_name_arr, bar_mse_mn, yerr=bar_mse_er, c='tab:blue')
ax2.errorbar(bar_model_name_arr, bar_com_mn, yerr=bar_com_er, c='tab:orange')
ax1.set_xlabel('Beta')
ax1.set_ylabel('Multi-frame MSE', color='tab:blue')
ax2.set_ylabel('Centre-of-mass MAE', color='tab:orange')

format_plot(ax1, fontsize=20)
format_plot(ax2, fontsize=20)
# Re-enable the right spine after format_plot so the second y-axis is drawn
# NOTE(review): presumably format_plot hides it — confirm in plotting_functions
ax2.spines['right'].set_visible(True)
fig.set_size_inches(3, 4)
plt.show()

In [None]:
# Next-frame MSE across the beta sweep on both loaders.
# NOTE: this reuses the names indist_* / outdist_* / bar_model_name_arr from the
# architecture comparison, so cells that plot those names show whichever
# evaluation ran last.
bar_model_name_arr, indist_mn_arr, indist_er_arr, indist_raw_arr = \
    get_mse_by_model (model_beta, indist_data_loader, plot_by_L1=False)

bar_model_name_arr, outdist_mn_arr, outdist_er_arr, outdist_raw_arr = \
    get_mse_by_model (model_beta, outdist_data_loader, plot_by_L1=False)

# get_mse_by_model appends a 'Copy frame' baseline as the last element of every
# list; drop it so only the beta entries remain.
# This cell is not idempotent on its own: re-running only the slicing lines would
# remove one more entry each time.
bar_model_name_arr = bar_model_name_arr[:-1]
indist_mn_arr = indist_mn_arr[:-1]
indist_er_arr = indist_er_arr[:-1]
indist_raw_arr = indist_raw_arr[:-1]
outdist_mn_arr = outdist_mn_arr[:-1]
outdist_er_arr = outdist_er_arr[:-1]
outdist_raw_arr = outdist_raw_arr[:-1]

In [None]:
# Next-frame MSE against beta on the in-distribution loader; the dashed grey line
# marks the value of the first entry as a reference level.
fig = plt.figure()
plt.errorbar(bar_model_name_arr, indist_mn_arr, yerr=indist_er_arr, c='black')
plt.plot([0, bar_model_name_arr[-1]], [indist_mn_arr[0], indist_mn_arr[0]], '--', c='tab:gray')
plt.xlabel('Beta')
plt.ylabel('Next-frame MSE')  # label typo fixed (was 'Nex-frame MSE')
format_plot(fontsize=20)
fig.set_size_inches(4, 4)
plt.show()

# Same plot for the out-of-distribution loader.
fig = plt.figure()
plt.errorbar(bar_model_name_arr, outdist_mn_arr, yerr=outdist_er_arr, c='black')
plt.plot([0, bar_model_name_arr[-1]], [outdist_mn_arr[0], outdist_mn_arr[0]], '--', c='tab:gray')
plt.xlabel('Beta')
plt.ylabel('Next-frame MSE')  # label typo fixed (was 'Nex-frame MSE')
format_plot(fontsize=20)
fig.set_size_inches(4, 4)
plt.show()

