This notebook contains sample code that loads neural networks trained on the TR360 optic flow dataset and computes their accuracy on the test set. This code accompanies the paper:

Layton, OW & Steinmetz, ST (2024). Accuracy optimized neural networks do not effectively model optic flow tuning in brain area MSTd. *Frontiers in Neuroscience*.

Note: This code assumes that the optic flow datasets and the CNN model available on Hugging Face have been downloaded.

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf

from losses import CircularLoss, MSE
from weight_initializers import GlorotUniformNonNegative
from acc import mae, mse

np.set_printoptions(suppress=True)

# Automatically reload external modules
%load_ext autoreload
%autoreload 2

Load in the TR360 dataset (in `.npz` format). The dataset comes in two parts: one for the back plane environment and one for the ground plane environment.

Note: For the MLP, the optic flow features should be flattened before proceeding. For example, the shape of the TR360 test set should be `(3015, 450)`.
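
As a minimal sketch of the flattening step, the snippet below reshapes a placeholder array with the test set's sample count into the `(3015, 450)` shape quoted above. The 15x15 spatial grid with 2 flow channels is an assumption inferred from 450 = 15 · 15 · 2, and `flatten_flow` is a hypothetical helper that is not part of the accompanying code.

```python
import numpy as np

def flatten_flow(samples):
    """Flatten each optic flow sample (Iy, Ix, n_chans) into a 1D feature vector."""
    n_samples = samples.shape[0]
    return samples.reshape(n_samples, -1)

# Placeholder standing in for the real test set
# (assumed shape: 3015 samples of 15x15 flow vectors with 2 channels)
dummy_test = np.zeros((3015, 15, 15, 2), dtype=np.float32)
flat_test = flatten_flow(dummy_test)
print(flat_test.shape)
```

With the real data, the same reshape would be applied to `ds_test` before passing it to the MLP.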

In [None]:
# This example assumes the .npz files are located in datasets/TR360
ds_subfolder = 'datasets'
ds_name = 'TR360'
ds_path = os.path.join(ds_subfolder, ds_name)

# Load in the back plane env
ds_back = np.load(os.path.join(ds_path, 'optic_flow_ds_backplane.npz'), allow_pickle=True)
ds_back = ds_back['data']
# Load in the ground plane env
ds_ground = np.load(os.path.join(ds_path, 'optic_flow_ds_groundplane.npz'), allow_pickle=True)
ds_ground = ds_ground['data']

# Combine the samples into one ndarray
ds = np.vstack([ds_back, ds_ground])
# Each sample contains a single frame, so select it and drop the frame axis
ds = ds[:, 0]
N, Iy, Ix, n_chans = ds.shape
print(f'Dataset has shape {ds.shape}')

# Shuffle the dataset
np.random.seed(0)
inds = np.arange(N)
np.random.shuffle(inds)
ds = ds[inds]
# The test set is the last 25% of samples (last 3015 samples)
ds_test = ds[-3015:]
print(f'Test set has shape {ds_test.shape}')

Load in the TR360 self-motion labels.

In [None]:
# Headers we wish to extract from the labels CSV file
labels2load = ['obs_heading_x', 'obs_heading_y', 'obs_rot_x', 'obs_rot_y', 'obs_rot_z', 'obs_rot_r']

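# Load in back plane labels, keeping only the columns listed in labels2load
labels_back = pd.read_csv(os.path.join(ds_path, 'labels_backplane.csv'), usecols=labels2load)
# Load in ground plane labels, keeping only the columns listed in labels2load
labels_ground = pd.read_csv(os.path.join(ds_path, 'labels_groundplane.csv'), usecols=labels2load)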
# Join them into one labels DataFrame
labels = pd.concat([labels_back, labels_ground], axis=0)

# Rotation vector is stored in normalized format. Combine rotation rate with direction.
labels.obs_rot_x = labels.obs_rot_r * labels.obs_rot_x
labels.obs_rot_y = labels.obs_rot_r * labels.obs_rot_y
labels.obs_rot_z = labels.obs_rot_r * labels.obs_rot_z

# Extract the labels that the network will predict
output_labels = ['obs_heading_x', 'obs_heading_y', 'obs_rot_x', 'obs_rot_y', 'obs_rot_z']
labels = labels.loc[:, output_labels]
# Shuffle the labels in the same order as samples
labels = labels.iloc[inds]
# Record each label's min/max across the training set. The net is trained on normalized labels,
# so these are needed to recover the original scale of the net predictions
train_label_mins = labels.iloc[:-3015].min(axis=0).to_numpy()
train_label_maxs = labels.iloc[:-3015].max(axis=0).to_numpy()

# Get test set labels
test_labels = labels.iloc[-3015:]
print(f'Test set labels have shape {test_labels.shape}')
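
The labels stay aligned with the optic flow samples only because both are reordered with the same index array `inds`. The toy sketch below illustrates the idea with made-up data; `toy_samples`, `toy_labels`, and the test size of 2 are illustrative only and do not come from the dataset.

```python
import numpy as np
import pandas as pd

# Toy data: sample i carries the value i, and its label row also stores i
toy_samples = np.arange(8).reshape(8, 1)
toy_labels = pd.DataFrame({'label': np.arange(8)})

# One seeded permutation, applied to both the samples and the labels
np.random.seed(0)
toy_inds = np.arange(len(toy_samples))
np.random.shuffle(toy_inds)
toy_samples = toy_samples[toy_inds]
toy_labels = toy_labels.iloc[toy_inds]

# Hold out the last rows of each as the test split
n_test = 2
toy_test_samples = toy_samples[-n_test:]
toy_test_labels = toy_labels.iloc[-n_test:]

# Every held-out sample still sits next to its own label
aligned = np.array_equal(toy_test_samples[:, 0], toy_test_labels['label'].to_numpy())
print(aligned)
```

If the samples and labels were shuffled with different permutations, the held-out pairs would no longer match and the accuracy figures would be meaningless.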


Load in the pretrained model in `.keras` format. The "CNN" model is used for this example.

In [None]:
# This example assumes the .keras file is located in models/cnn
model_subfolder = 'models'
model = 'cnn'
model_filename = 'cnn_TR_15x15_6k_analytic'
model_full_path = os.path.join(model_subfolder, model, model_filename)

net = tf.keras.models.load_model(filepath=model_full_path + '.keras',
                                 custom_objects={'CircularLoss': CircularLoss(exp=1),
                                                 'MSE': MSE(exp=2),
                                                 'GlorotUniformNonNegative': GlorotUniformNonNegative()})
# summary() prints the architecture itself and returns None, so no print() is needed
net.summary()

Next, obtain the model predictions on the TR360 test set.

In [None]:
# Stack the network outputs column-wise into a single array of predictions
pred_labels_norm = np.hstack(net(ds_test))
# Predictions are on the normalized scale, so invert the normalization to recover the original label units
pred_labels = (pred_labels_norm + 0.5) * (train_label_maxs - train_label_mins) + train_label_mins
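
The rescaling above implies that each label was normalized to the range [-0.5, 0.5] using its training-set minimum and maximum. The sketch below round-trips made-up values through that assumed forward transform and the inverse used in this notebook. `normalize` and `denormalize` are hypothetical helper names, and the forward transform is inferred from the inverse rather than taken from the training code.

```python
import numpy as np

def normalize(labels, mins, maxs):
    """Assumed forward transform: map [min, max] onto [-0.5, 0.5] per label."""
    return (labels - mins) / (maxs - mins) - 0.5

def denormalize(labels_norm, mins, maxs):
    """Inverse transform, matching the expression used in this notebook."""
    return (labels_norm + 0.5) * (maxs - mins) + mins

# Made-up per-label ranges and label values (two labels, two samples)
toy_mins = np.array([-180.0, -5.0])
toy_maxs = np.array([180.0, 5.0])
toy_labels = np.array([[-180.0, 0.0],
                       [90.0, 5.0]])

toy_norm = normalize(toy_labels, toy_mins, toy_maxs)
recovered = denormalize(toy_norm, toy_mins, toy_maxs)
print(toy_norm)
print(np.allclose(recovered, toy_labels))
```

Broadcasting applies each label's own minimum and maximum along the columns, which is why `train_label_mins` and `train_label_maxs` are stored as one value per label.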

Compute the test MAE and MSE.

In [None]:
print('MAE:')
print(mae(pred_labels, test_labels, circ_correction=True))
print('\nMSE:')
print(mse(pred_labels, test_labels, circ_correction=True))
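
The `circ_correction` flag is not defined in this notebook. One plausible reading for angular labels is that errors are wrapped around the circle, so that a prediction of 359° against a target of 1° counts as a 2° error rather than 358°. The sketch below illustrates that wrapping under the assumption that the angles are in degrees; `circular_mae` is a hypothetical function and is not the implementation in `acc.py`.

```python
import numpy as np

def circular_mae(pred_deg, true_deg, period=360.0):
    """Mean absolute error after wrapping differences into [-period/2, period/2)."""
    diff = np.asarray(pred_deg, dtype=float) - np.asarray(true_deg, dtype=float)
    wrapped = (diff + period / 2) % period - period / 2
    return np.mean(np.abs(wrapped))

# Made-up angles: the first pair straddles the 0/360 degree boundary
toy_pred = np.array([359.0, 10.0])
toy_true = np.array([1.0, 20.0])

naive_err = np.mean(np.abs(toy_pred - toy_true))
wrapped_err = circular_mae(toy_pred, toy_true)
print(f'Naive MAE: {naive_err}, wrapped MAE: {wrapped_err}')
```

Without the wrapping, a single pair near the boundary dominates the average even though the two directions are almost identical.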