## Layered Searchlight Simulation

This notebook uses the layered searchlight model to analyze a simulated spatial dataset, specified in `examples/bin/simulation_dataloader.py`. The dataloader for the full MTurk1 project data is a drop-in replacement for this simulation dataloader.

In [None]:
from neurotools import decoding, embed, geometry, util
from bin.simulation_dataloader import SimulationDataloader
import torch
import numpy as np
from matplotlib import pyplot as plt
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"


We define a data loader that embeds information allowing us to distinguish between 12 classes at 3 locations in a 16 x 16 x 16 space. This 16 x 16 x 16 space will be called the "full space" and is analogous to the voxel space in MRI data. The 12 classes are laid out in a circle on a low-dimensional manifold in this space, with some degree of random noise added. We plot an example of what this low-dimensional embedding of the classes might look like in a 3D feature space, and then the locations of class representations in a 2D slice of the full space. Locations that contain only representations of set A are shown in teal, and locations that contain joint representations of sets A and B are shown in yellow.

We expect to be able to decode A in all three of the above locations. A classifier trained on A should also be able to decode examples drawn from B in the bottom right, yellow-colored region.

In order to operate in as close a setting to the full fMRI decoding task as possible, we also impose the constraint that only some classes should be directly compared to each other. This is generally because only data for some classes was collected under comparable conditions. We define allowed comparisons via a pairwise matrix, with ones at the index of allowed comparisons, and zeroes elsewhere.

In [None]:
# create pairwise:
main_light = [1, 3, 5]
other_light = [2, 4, 6]
main_dark = [7, 9, 11]
other_dark = [8, 10, 12]
# we only want comparisons within a set and within luminance levels
# pair_weights[t] is the 12 x 12 mask of comparisons allowed for class t + 1
pair_weights = torch.zeros((12, 12, 12))
for item_set in [main_light, other_light, main_dark, other_dark]:
    rows = torch.zeros((12, 12))
    cols = torch.zeros((12, 12))
    ind = torch.tensor(item_set) - 1
    cols[:, ind] = 1
    rows[ind, :] = 1
    weights = torch.logical_and(cols, rows).float()
    for t in item_set:
        pair_weights[t - 1] = weights
pm = pair_weights
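
As a sanity check on the structure of this mask, here is a minimal NumPy sketch that rebuilds it and counts the allowed comparisons. It assumes the same four item sets as the cell above; the name `masks` is illustrative and is not part of `neurotools`.

```python
import numpy as np

# Same four item sets as the cell above (1-indexed class labels).
item_sets = [[1, 3, 5], [2, 4, 6], [7, 9, 11], [8, 10, 12]]
n_classes = 12

# masks[t] is the 12 x 12 matrix of comparisons allowed when class t + 1 is the target.
masks = np.zeros((n_classes, n_classes, n_classes))
for item_set in item_sets:
    ind = np.array(item_set) - 1
    block = np.zeros((n_classes, n_classes))
    block[np.ix_(ind, ind)] = 1  # ones only where row and column are both in the set
    masks[ind] = block  # every class in the set shares the same mask

# Each class may only be compared with the members of its own set (itself included).
candidates_per_class = [int(masks[t, t].sum()) for t in range(n_classes)]
print(candidates_per_class)
```

Every class ends up with three candidates, which is consistent with the 1/3 lower color limit used for the accuracy maps later in the notebook.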

Another feature of the searchlight decoder is that it can pool the searchlight results over different subregions of the full input. This is useful for generating quantitative measures, such as accuracy, in specific regions. We can define regions via an atlas, which is constructed below. The atlas labels two regions, roi_1 and roi_2, on a zero-valued background, and we plot a slice through it. We expect cross decoding of B from A only in roi_2.

In [None]:
atlas = np.zeros((16, 16, 16))
atlas[:8, :12, :] = 1 # Blue
atlas[8:, 4:, :] = 2 # cyan
lookup = {1: "roi_1", 2:"roi_2"}
plt.imshow(atlas[:, :, 8])
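
As a rough illustration of what pooling over an atlas means, the sketch below averages a voxelwise map within each labelled region. It uses a plain unweighted mean over a random stand-in map. The decoder's own pooling is weighted by per-spot confidence and may differ in other ways, and `pool_by_roi` is a hypothetical helper, not part of `neurotools`.

```python
import numpy as np

def pool_by_roi(value_map, atlas, lookup):
    """Unweighted mean of value_map inside each atlas label (hypothetical helper)."""
    return {name: float(value_map[atlas == label].mean())
            for label, name in lookup.items()}

# Same atlas layout as the cell above.
atlas = np.zeros((16, 16, 16))
atlas[:8, :12, :] = 1
atlas[8:, 4:, :] = 2
lookup = {1: "roi_1", 2: "roi_2"}

# Stand-in for a voxelwise accuracy map; random values, for illustration only.
rng = np.random.default_rng(0)
fake_map = rng.uniform(size=atlas.shape)

pooled = pool_by_roi(fake_map, atlas, lookup)
roi_sizes = {name: int((atlas == label).sum()) for label, name in lookup.items()}
print(roi_sizes)
```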

In [None]:
n_classes = 12
vdl = SimulationDataloader(difficulty=0, seed=8, num_examples=100, batch_size=70)
vdl.plot_circle_embedding()
vdl.plot_templates()

Now we need to define some variables to initialize the layered searchlight model!

- `BASE_KERNEL_SIZE` is the edge length of the cube of data sampled by each filter in each layer.
- `N_LAYERS` is the total number of layers.

Note that the amount of input data available to each searchlight spot is a cube whose size is fully determined by these two parameters. In this case, each searchlight spot "sees" a 5x5x5 cube of input data.
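
The 5x5x5 figure follows from the usual receptive field arithmetic for stacked convolutions. The sketch below assumes every layer uses the same kernel size with stride 1 and no dilation. That is an assumption about the model rather than something stated here, but it reproduces the 5x5x5 cube. `receptive_field` is a hypothetical helper.

```python
def receptive_field(kernel_size, n_layers, stride=1):
    """Edge length of the input seen by one output unit of stacked convolutions.

    Assumes every layer uses the same kernel size and stride, with no dilation.
    """
    field, jump = 1, 1
    for _ in range(n_layers):
        field += (kernel_size - 1) * jump  # each layer widens the view
        jump *= stride                     # spacing between neighbouring units
    return field

edge = receptive_field(kernel_size=2, n_layers=4)
print(f"each spot sees a {edge} x {edge} x {edge} cube")
```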

In [None]:
BASE_KERNEL_SIZE = 2
N_LAYERS = 4
xdecoder = decoding.ROISearchlightDecoder(
    atlas, lookup, set_names=("a", "b"), in_channels=1, n_classes=12, spatial=(16, 16, 16),
    nonlinear=True, device=dev, base_kernel_size=BASE_KERNEL_SIZE, n_layers=N_LAYERS,
    dropout_prob=0.5, seed=8, share_conv=False, pairwise_comp=pm)
print("Total model parameters:", xdecoder.get_model_size())

Now we'll initialize our simulation dataloader for training the layered searchlight. We'll set the difficulty parameter to 2, which adds noise and affine jitter to the classes. This is a very low level of noise, so classification will be fairly easy. In the cells below, calling `vdl.batch_iterator("A", ...)` gives an iterator over batches of examples from set A, and `vdl.batch_iterator("B", ...)` gives an iterator over set B. We will train on A, then later test cross decoding on B.

In [None]:
# scaling factor on amount of noise and degree of affine jitter
difficulty = 2
vdl = SimulationDataloader(difficulty=difficulty, seed=8, num_examples=120, batch_size=70)

The searchlight decoder's `fit` method expects an iterator that returns tuples of numpy arrays `(stimulus: <batch, channels, width, height, depth>, target_class: <batch>)`. This is the form of the iterators returned from the SimulationDataloader object (called `vdl` above).
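
To make the expected layout concrete, here is a toy generator that yields tuples in that shape. It is only a sketch: `toy_batch_iterator` is a hypothetical stand-in for the dataloader, the stimuli are pure noise, and the `float32` dtype and zero-based integer labels are assumptions rather than documented requirements.

```python
import numpy as np

def toy_batch_iterator(n_batches, batch_size=70, n_classes=12, seed=0):
    """Yield (stimulus, target_class) tuples in the layout described above."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        # <batch, channels, width, height, depth>
        stimulus = rng.normal(size=(batch_size, 1, 16, 16, 16)).astype(np.float32)
        # <batch>, one integer label per example (zero-based here, an assumption)
        target = rng.integers(0, n_classes, size=batch_size)
        yield stimulus, target

stimulus, target = next(toy_batch_iterator(n_batches=2))
print(stimulus.shape, target.shape)
```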

The decoder's `train_searchlight` and `train_predictors` state flags can be set to determine the training routine. Training the searchlight means attempting to predict class identity as well as possible at every location in the full space. Training the predictor sets the weights on each searchlight spot based on the confidence of its predictions. These weights will be important for determining accuracy and getting representational geometries over whole ROIs. Generally, the searchlight should be trained first, though the two could also be trained in tandem.

In [None]:
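# Stage 1: train the searchlight for set "a" with predictor training switched off.
xdecoder.train_searchlight("a", True)
xdecoder.train_predictors("a", False)
s_lh = xdecoder.fit(vdl.batch_iterator("A", 1500), lr=.02)
# Stage 2: switch off searchlight training and train the predictor weights.
xdecoder.train_searchlight("a", False)
xdecoder.train_predictors("a", True)
p_lh = xdecoder.fit(vdl.batch_iterator("A", 1000), lr=.02)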

fig, ax = plt.subplots(2)
ax[0].plot(s_lh)
ax[1].plot(p_lh)

The model prints the sum of the difficulty-reweighted cross entropy loss at every batch, along with the value of the regularization and the average accuracy over all searchlight spots.

Now we have a trained model. First we look at the confidence (i.e. the weights) over the full space. Note that they align with where there is signal from set A.

In [None]:
sal_map = xdecoder.get_saliancy()
plt.imshow(sal_map[:, :, 8], vmin=0, vmax=.07)

By running the searchlight's `predict` method, we get maps of searchlight accuracy over the whole space on a validation set, as well as accuracy measures for each ROI and combined over the full space. We can initialize a simulation dataloader with a new seed to get new examples for the validation set.

In [None]:
vdl = SimulationDataloader(difficulty=difficulty, seed=16, num_examples=100, batch_size=70)
roi_accs, acc_map, gradient_map = xdecoder.predict(vdl.batch_iterator("A", 50))
print(roi_accs)

We can now plot the accuracy map over the space from the searchlight.

In [None]:
plt.imshow(acc_map[:, :, 6:10].mean(axis=2), vmin=.33, vmax=.55)
print("Min Acc:", np.min(acc_map).squeeze())
print("Max Acc:", np.max(acc_map).squeeze())

So we're pretty good at predicting the exemplars of set A on a validation set. But what about cross decoding? Let's try predicting on exemplars of set B by requesting batches with `vdl.batch_iterator("B", 50)`.

In [None]:
roi_accs, acc_map, gradient_map = xdecoder.predict(vdl.batch_iterator("B", 50))
print(roi_accs)

Performance is a bit worse, and we can see that it is only greater than chance in roi 2, the region that overlaps the bottom right-hand side. Let's look at the map of cross decoding and check that it looks as expected.

In [None]:
plt.imshow(acc_map[:, :, 6:10].mean(axis=2), vmin=1/3, vmax=.46)
print("Min Acc:", np.min(acc_map).squeeze())
print("Max Acc:", np.max(acc_map).squeeze())

So now we know that we can train a model on Set A and accurately classify in three regions that have significant Set A signal. We can then attempt to cross decode set B with the same trained model, and show that exemplars of Set B can be cross decoded from Set A in one of these regions.

Our final question is whether we can recover the initial geometry of the class embeddings from the model's latent space. Calling the searchlight's `get_latent` method gives us a distance matrix between the classes for each searchlight spot, computed with the requested `metric` (`"pearson"` in the cell below), and also gives the distance matrices over each ROI, weighted by the confidence at each spot. We can do this for both Set A and Set B.

In [None]:
rdm, roi_rdm = xdecoder.get_latent(vdl.batch_iterator("B", 50), metric="pearson", voxelwise=True)

The rdm is returned as the upper triangle of the distance matrix between the 12 classes. However, we know a ground truth geometry (a circle) only for two sets of 6 classes, so we need to break up the rdm with some utility functions. For now, we will focus on the RDM for roi 2, where we saw significant cross decoding.

In [None]:
r_rdm = roi_rdm["roi_2"]  # roi 2 is where cross decoding was observed
square_rdm = util.triu_to_square(torch.from_numpy(r_rdm), 12).numpy()
square_rdm = square_rdm[:, 6:, 6:].squeeze()
triu_ind = np.triu_indices(6, 1)
sub_r_rdm = square_rdm[triu_ind]
print(sub_r_rdm)
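
The reshaping in this cell can be sketched in plain NumPy on toy values. This assumes the upper triangle excludes the diagonal and is stored in row-major order, which may not match the layout used by `util.triu_to_square`. `triu_to_square_np` is a hypothetical helper written for this illustration.

```python
import numpy as np

def triu_to_square_np(triu_values, n):
    """Rebuild a symmetric n x n matrix from its strict upper triangle."""
    square = np.zeros((n, n))
    rows, cols = np.triu_indices(n, k=1)
    square[rows, cols] = triu_values
    return square + square.T  # mirror into the lower triangle, diagonal stays 0

n_classes = 12
# 12 classes give 12 * 11 / 2 = 66 pairwise distances; toy values 1..66.
toy_triu = np.arange(1, n_classes * (n_classes - 1) // 2 + 1, dtype=float)
square = triu_to_square_np(toy_triu, n_classes)

# Keep only the block for the last 6 classes, then flatten it again.
sub_square = square[6:, 6:]
sub_triu = sub_square[np.triu_indices(6, k=1)]
print(sub_triu.shape)
```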

We can now initialize an MDS algorithm to determine the optimal geometry given the RDM.

_Note: depending on random initialization, MDS may occasionally plateau in a local minimum. If the loss stabilizes at a value greater than about 10000, you should rerun._

In [None]:
mds = embed.MDScale(n=6, embed_dims=2, initialization="pca")
embedding = mds.fit_transform(sub_r_rdm)
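
For intuition about what MDS recovers, the sketch below runs classical (Torgerson) MDS on the distances between six points on an ideal circle. This is a swapped-in closed-form technique, not the iterative optimization that `embed.MDScale` appears to use, and `classical_mds` is a hypothetical helper.

```python
import numpy as np

def classical_mds(dist, dims=2):
    """Classical MDS: eigendecompose the double-centered squared distances."""
    n = dist.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (dist ** 2) @ centering
    eigvals, eigvecs = np.linalg.eigh(gram)
    top = np.argsort(eigvals)[::-1][:dims]  # largest eigenvalues first
    return eigvecs[:, top] * np.sqrt(np.clip(eigvals[top], 0, None))

# Ground truth: six points evenly spaced on a unit circle.
angles = 2 * np.pi * np.arange(6) / 6
points = np.stack([np.cos(angles), np.sin(angles)], axis=1)
dist = np.linalg.norm(points[:, None] - points[None, :], axis=-1)

coords = classical_mds(dist)
# The embedding is only defined up to rotation and reflection,
# so compare pairwise distances rather than raw coordinates.
recovered = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
print(np.allclose(recovered, dist))
```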

After the MDS has finished we can plot the embedding, with some color coding indicating expected order.

In [None]:
colors = ["#eb8ba7", "#caa65a", "#8dba7b", "#56bccd", "#9da3fe", "#e385f0"]
plt.scatter(embedding[:, 0], embedding[:, 1], c=colors)

This is the expected order! We can also quantify this by taking the Pearson correlation of the rank order of the distances between the points in an ideal circle (accounting for ties) with the rank order of the distances from the model, i.e. the tie-corrected Spearman correlation of the distance matrices.
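
The tie-corrected Spearman correlation described here can be sketched in NumPy as the Pearson correlation of average ranks. This illustrates the statistic only, not the internals of `geometry.circle_corr`. The helper names are hypothetical, and the ideal circle is represented by the number of steps between classes around the ring.

```python
import numpy as np

def average_ranks(values):
    """Rank values from 1, giving tied values the mean of the ranks they span."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="stable")
    ranks = np.empty(len(values))
    ranks[order] = np.arange(1, len(values) + 1)
    for v in np.unique(values):
        tied = values == v
        ranks[tied] = ranks[tied].mean()
    return ranks

def circle_step_distances(n):
    """Upper-triangle distances between n points on a ring, in steps."""
    i, j = np.triu_indices(n, k=1)
    steps = np.abs(i - j)
    return np.minimum(steps, n - steps).astype(float)

def spearman_rho(a, b):
    """Pearson correlation of the tie-aware ranks of a and b."""
    return float(np.corrcoef(average_ranks(a), average_ranks(b))[0, 1])

ideal = circle_step_distances(6)
# Chord lengths on a unit circle are a monotone function of the step count,
# so they should agree perfectly with the ideal ranking.
chords = 2 * np.sin(np.pi * ideal / 6)
rho_perfect = spearman_rho(ideal, chords)
print(rho_perfect)
```

With six classes there are only three distinct ideal distances across the 15 pairs, so ties are the norm rather than the exception, which is why the tie correction matters.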

In [None]:
if isinstance(sub_r_rdm, np.ndarray):
    sub_r_rdm = torch.from_numpy(sub_r_rdm)
rho = geometry.circle_corr(sub_r_rdm.unsqueeze(0), 6, metric="rho")

This correlation is high, indicating that the rdm from the model strongly agrees with the expected circular ordering.

The error bars of this rho tell us our confidence in the value given the model. However, we also need to consider the possibility that a dissimilarity structure matching color space could occur at random. The simulation below establishes the upper 95% quantile of the distribution of correlations between random rankings and the color space ranking, for both the light and dark samples.

In [None]:
print(rho)
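# Null distribution: score random RDMs against the ideal circle.
n = 2  # number of random RDMs averaged per draw
samp = []
for i in range(1000):
    s1 = [geometry.circle_corr(torch.rand_like(sub_r_rdm).unsqueeze(0), 6, metric="rho").detach().numpy() for _ in range(n)]
    samp.append(np.stack(s1).sum(axis=0) / n)
samp = np.stack(samp)
print(samp.mean())  # mean of the null distribution
print(np.quantile(samp, .95))  # upper 95% quantile of the null distribution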