In [10]:
import numpy as np
import torch
import sys

sys.path.append("../../")
from vi_rnn.vae import VAE
from vi_rnn.train import train_VAE
from vi_rnn.datasets import SineWave, Oscillations_Poisson
from torch.utils.data import DataLoader
from py_rnn.model import RNN
from vi_rnn.utils import *
from py_rnn.train import train_rnn
from py_rnn.train import save_rnn, load_rnn
import matplotlib.pyplot as plt
from vi_rnn.saving import save_model
from py_rnn.default_params import get_default_params
from vi_rnn.datasets import Basic_dataset_with_trials
%matplotlib inline

In [11]:
train_teacher = False  # load already trained teacher model
data_dir = "../../data/"  # store inferred model
model_dir = "../../models/checker/"  # store teacher RNN
cuda = True  # toggle if GPU is available

In [12]:
# vars(task)

In [13]:
import scipy.io
dataset = scipy.io.loadmat(data_dir + 'dataset.mat')

dataset = dataset['dataset']
# Accessing the fields 
data_all = dataset[0, 0]['data']
inputs = dataset[0, 0]['input']

data_all = data_all.astype(np.float32)
inputs = inputs.astype(np.float32)

In [14]:
n_trials, dim_x, seq_len = data_all.shape

# split into train and eval
train_inds = np.full((n_trials,), False)
train_inds[np.random.choice(np.arange(n_trials), size=500, replace=False)] = True


data_train = data_all[train_inds]
data_eval = data_all[~train_inds]
stim_train = inputs[train_inds]
stim_eval = inputs[~train_inds]


In [15]:
stim_eval.shape

(262, 2, 1500)

In [20]:

# initialise a dataset class
task_params = {
    "name": "checker_spikes",
    "dur": 1500,  # we will sample pseudo trials of "dur" timesteps during training
    "n_trials": data_all.shape[0],  # every epoch consists of 256 psuedo trials
    "n_neurons": data_all.shape[1],
    "out": "currents",
    "non_lin": torch.nn.ReLU(),
    "obs_rectify": "softplus",    
    "w": 0.1,
    "R_z": 0.2,
    "Bias": -3,
    "B": 4  
}
# task = Basic_dataset_with_trials(
#     task_params=task_params,
#     data=data_train,
#     data_eval=data_eval,
#     stim=stim_train,  # you could additionally pass input / stimuli like this
#     stim_eval=stim_eval,
# )

task = Basic_dataset_with_trials(
    task_params=task_params,
    data=data_all,
    data_eval=data_all,
    stim=inputs,  # you could additionally pass input / stimuli like this
    stim_eval=inputs,
)

TypeError: Cannot interpret '0' as a data type

## Create a VAE RNN setup

In [9]:
# Initialise VI / student setup

dim_z = 2
dim_N = 40
dim_x = task_params["n_neurons"]
bs = 10
cuda = True
n_epochs = 1500
wandb = False
# initialise encoder
enc_params = {
    "init_kernel_sizes": [21, 11, 1],
    "nonlinearity": "gelu",
    "n_channels": [64, 64],
    "init_scale": 0.1,
    "constant_var": False,
    "padding_mode": "circular",
    "padding_location": "causal",
}


# initialise prior
rnn_params = {
    "transition": "low_rank",
    "observation": "one_to_one",
    "train_noise_z": True,
    "train_noise_z_t0": True,
    "init_noise_z": 0.1,
    "init_noise_z_t0": 1,
    "noise_z": "diag",
    "noise_z_t0": "diag",
    "identity_readout": True,
    "activation": "relu",
    "decay": 0.9,
    "readout_from": task_params["out"],
    "train_obs_bias": True,
    "train_obs_weights": True,
    "train_neuron_bias": True,
    "weight_dist": "uniform",
    "weight_scaler": 1,  # /dim_N,
    "initial_state": "trainable",
    "obs_nonlinearity": task_params["obs_rectify"],
    "obs_likelihood": "Poisson",
    "simulate_input": False,
}


training_params = {
    "lr": 1e-3,
    "lr_end": 1e-5,
    "grad_norm": 0,
    "n_epochs": n_epochs,
    "eval_epochs": 50,
    "batch_size": bs,
    "cuda": cuda,
    "smoothing": 20,
    "freq_cut_off": 10000,
    "k": 64,
    "loss_f": "smc",
    "resample": "systematic",  # , multinomial or none"
    "run_eval": True,
    "smooth_at_eval": False,
    "t_forward": 0,
    "init_state_eval": "posterior_sample",
}


VAE_params = {
    "dim_x": dim_x,
    "dim_z": dim_z,
    "dim_N": dim_N,
    "enc_architecture": "CNN",
    "enc_params": enc_params,
    "rnn_params": rnn_params,
}
vae = VAE(VAE_params)

using uniform init
using causal circular padding


In [10]:
# enc_params = {
#     "kernel_sizes": [21, 11, 1],  # kernel sizes of the CNN
#     "padding_mode": "constant",  # padding mode of the CNN (e.g., "circular", "constant", "reflect")
#     "nonlinearity": "gelu",  # "leaky_relu" or "gelu"
#     "n_channels": [
#         64,
#         64,
#     ],  # number of channels in the CNN (last one will be equal to dim_z)
#     "init_scale": 0.1,  # initial scale of the noise predicted by the encoder
#     "constant_var": False,  # whether or not to use a constant variance (as opposed to a data-dependent variance)
#     "padding_location": "acausal",
# }  # padding location of the CNN ("causal", "acausal", or "windowed")


# rnn_params = {
#     # transition and observation
#     "transition": "low_rank",  # "low_rank" or "full_rank" RNN
#     "observation": "one_to_one",  # "one_to_one" mapping between RNN and observed units or "affine" mapping from the latents
#     # observation settings
#     "readout_from": "currents",  # readout from the RNN activity before / after applying the non-linearty by setting this to "currents" / "rates" respectively.
#     "train_obs_bias": True,  # whether or not to train a bias term in the observation model
#     "train_obs_weights": True,  # whether or not train the weights of the observation model
#     "obs_nonlinearity": "softplus",  # can be used to rectify the output (e.g., when using Poisson observations, use "softplus")
#     "obs_likelihood": "Poisson",  # observation likelihood model ("Gauss" or "Poisson")
#     # transition settings
#     "activation": "relu",  # set the nonlinearity to "clipped_relu, "relu", "tanh" or "identity"
#     "decay": 0.9,  # initial decay constant, scalar between 0 and 1
#     "train_neuron_bias": True,  # train a bias term for every neuron
#     "weight_dist": "uniform",  # weight distribution ("uniform" or "gauss")
#     "initial_state": "trainable",  # initial state ("trainable", "zero", or "bias")
#     "simulate_input": False,  # set to True when using time-varying inputs
#     # noise covariances settings
#     "train_noise_z": True,  # whether or not to train the transition noise scale
#     "train_noise_z_t0": True,  # whether or not to train the initial state noise scale
#     "init_noise_z": 0.1,  # initial scale of the transition noise
#     "init_noise_z_t0": 0.1,  # initial scale of the initial state noise
#     "noise_z": "diag",  # transition noise covariance type ("full", "diag" or "scalar"), set to "full" when using the optimal proposal
#     "noise_z_t0": "diag",  # initial state noise covariance type ("full", "diag" or "scalar"), set to "full" when using the optimal proposal
# }


# VAE_params = {
#     "dim_x": data_all.shape[1],  # observation dimension (number of units in the data)
#     "dim_z": 2,  # latent dimension / rank of the RNN
#     "dim_N": data_all.shape[1],  # amount of units in the RNN (can generally be different then the observation dim)
#     "dim_u": inputs.shape[1],  # input stimulus dimension
#     "enc_architecture": "CNN",  # encoder architecture (not trained when using linear Gauss observations)
#     "enc_params": enc_params,  # encoder paramaters
#     "rnn_params": rnn_params,  # parameters of the RNN
# }

# # initialise the VAE
# vae = VAE(VAE_params)

In [10]:
training_params = {
    "lr": 1e-3,  # learning rate start
    "lr_end": 1e-5,  # learning rate end (with exponential decay)
    "n_epochs": 2,  # number of epochs to train
    "grad_norm": 0,  # gradient clipping above certain norm (if this is set to >0)
    "batch_size": 10,  # batch size
    "cuda": True,  # train on GPU
    "k": 64,  # number of particles to use
    "loss_f": "smc",  # use regular variational SMC ("smc"), or use the optimal ("opt_smc")
    "resample": "systematic",  # type of resampling "systematic", "multinomial" or "none"
    "run_eval": False,  # run an evaluation setup during training (requires additional parameters)
    "t_forward": 0,  # timesteps to predict without using the encoder
}

In [9]:
# n_epochs = 1500


# training_params = {
#     "lr": 1e-3,
#     "lr_end": 1e-5,
#     "grad_norm": 0,
#     "n_epochs": n_epochs,
#     "eval_epochs": 50,
#     "batch_size": 10,
#     "cuda": cuda,
#     "smoothing": 20,
#     "freq_cut_off": 10000,
#     "k": 64,
#     "loss_f": "smc",
#     "resample": "systematic",  # , multinomial or none"
#     "run_eval": True,
#     "smooth_at_eval": False,
#     "t_forward": 0,
#     "init_state_eval": "posterior_sample",
# }


In [9]:
vars(task)

{'task_params': {'name': 'checker_spikes',
  'dur': 1500,
  'n_trials': 762,
  'n_neurons': 105,
  'out': 'currents',
  'non_lin': ReLU(),
  'obs_rectify': 'softplus',
  'w': 0.1,
  'R_z': 0.2,
  'Bias': -3,
  'B': 4},
 'data': tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
        

## Train the VAE

In [8]:
# Train
wandb = False
train_VAE(vae, training_params, task, sync_wandb=wandb, out_dir=model_dir, fname=None)

Training on : cuda
Learning rate decay factor 0.10000000000000002


RuntimeError: einsum(): subscript u has size 2 for operand 1 which does not broadcast with previously seen size 0

In [None]:
save_model(vae, training_params, task_params, name=model_dir + "Sine_40_1000_new")

In [2]:
import numpy as np
import torch
import sys

sys.path.append("../../")
from torch.utils.data import DataLoader
from evaluation.calc_stats import *
from vi_rnn.datasets import *
from vi_rnn.saving import load_model
from py_rnn.train import load_rnn
from vi_rnn.generate import generate
import matplotlib.pyplot as plt
from vi_rnn.utils import *
import matplotlib as mpl
from fixed_points.find_fixed_points_analytic import find_fixed_points_analytic
from fixed_points.stability import PL_Jacobian
from matplotlib.markers import MarkerStyle
from matplotlib.colors import colorConverter as cc
import copy

%matplotlib inline

In [4]:
# load student model
name = "../../models/checker/checker_spikes_CNN_low_rank_one_to_one_dim_z_2_date_2025_05_05_T_15_33_14"
vae, training_params, task_params = load_model(
    name=name, load_encoder=False, backward_compat=True
)

vae = orthogonalise_network(vae)
rotation_matrix = torch.tensor([[0, -1], [1, 0]], dtype=torch.float32)
# vae = rotate_basis_vectors(vae,rotation_matrix)
tau, pV, pU, pB, pI = get_loadings(vae)

no out nonlinearity found, setting to identity
using uniform init
using acausal constant padding


In [13]:


# run the student RNN
with torch.no_grad():
    # run the student RNN
    Z, _, data_gen, rates_gen = generate(
        vae,
        u=task.stim,
        x=task.data,
        initial_state=torch.zeros(task.data.shape[0], 2, 1),
        k=1,
        dur=task_params["dur"],
    )
    Zn = Z.cpu().detach().numpy()[:, :, :, 0]
    data_gen = data_gen.cpu().detach().numpy()[:, :, :, 0]
    z0 = torch.zeros(task_params["n_trials"], 2, 1)

# # convert to numpy
# Z_gen = Z_gen[:,:,:,0].cpu().detach().numpy()
# data_gen = data_gen[:,:,:,0].cpu().detach().numpy()

ValueError: Expected parameter rate (Tensor of shape (500, 105, 1500, 1)) of distribution Poisson(rate: torch.Size([500, 105, 1500, 1])) to satisfy the constraint GreaterThanEq(lower_bound=0.0), but found invalid values:
tensor([[[[-7.8844e-03],
          [-6.9717e-03],
          [-2.1381e-02],
          ...,
          [ 3.6363e-02],
          [ 5.6391e-02],
          [ 4.7347e-03]],

         [[-7.9009e-03],
          [-5.3322e-03],
          [ 2.1492e-03],
          ...,
          [-1.4999e-01],
          [-1.8373e-01],
          [-1.1727e-01]],

         [[-7.8008e-03],
          [-1.1526e-02],
          [-2.3338e-02],
          ...,
          [-4.0403e-02],
          [ 1.0329e-02],
          [-9.0301e-02]],

         ...,

         [[-7.5644e-03],
          [ 2.4649e-02],
          [-5.4558e-02],
          ...,
          [-6.2531e-01],
          [-7.2382e-01],
          [-6.5345e-01]],

         [[-8.0107e-03],
          [-2.6297e-02],
          [ 3.2298e-02],
          ...,
          [-3.5448e-01],
          [-3.2412e-01],
          [-3.0396e-01]],

         [[-7.8454e-03],
          [-1.3057e-03],
          [-1.1619e-02],
          ...,
          [ 2.8337e-01],
          [ 2.5256e-01],
          [ 2.9228e-01]]],


        [[[-7.8844e-03],
          [ 1.8348e-03],
          [-8.5490e-03],
          ...,
          [-1.2472e-01],
          [-1.4848e-01],
          [-1.2777e-01]],

         [[-7.9009e-03],
          [-2.1310e-02],
          [ 1.7038e-03],
          ...,
          [ 9.4211e-02],
          [ 1.2785e-01],
          [ 1.1075e-01]],

         [[-7.8008e-03],
          [ 1.2462e-02],
          [-2.1960e-02],
          ...,
          [ 3.6356e-02],
          [-1.4442e-02],
          [ 1.1865e-02]],

         ...,

         [[-7.5644e-03],
          [-2.8889e-02],
          [ 7.1552e-02],
          ...,
          [-3.0008e-02],
          [ 2.9831e-02],
          [ 8.6983e-02]],

         [[-8.0107e-03],
          [-7.4731e-03],
          [-4.9666e-02],
          ...,
          [ 6.7771e-01],
          [ 6.7226e-01],
          [ 6.1842e-01]],

         [[-7.8454e-03],
          [-1.7068e-02],
          [ 9.5938e-03],
          ...,
          [-4.4050e-01],
          [-4.1629e-01],
          [-4.1374e-01]]],


        [[[-7.8844e-03],
          [ 1.4593e-02],
          [ 2.3425e-02],
          ...,
          [-9.4696e-02],
          [-1.1366e-01],
          [-1.1983e-01]],

         [[-7.9009e-03],
          [-3.3414e-02],
          [-3.8768e-02],
          ...,
          [ 7.1724e-02],
          [ 9.2907e-02],
          [ 1.0027e-01]],

         [[-7.8008e-03],
          [ 3.0978e-02],
          [ 3.9349e-02],
          ...,
          [ 9.1242e-02],
          [ 5.9028e-02],
          [ 4.7855e-02]],

         ...,

         [[-7.5644e-03],
          [-7.7383e-03],
          [ 3.3940e-02],
          ...,
          [-1.3807e-01],
          [-1.4097e-01],
          [-1.3771e-01]],

         [[-8.0107e-03],
          [-3.3105e-02],
          [-6.5338e-02],
          ...,
          [ 7.4172e-01],
          [ 7.6453e-01],
          [ 7.6970e-01]],

         [[-7.8454e-03],
          [-1.8538e-02],
          [-1.3704e-02],
          ...,
          [-4.6073e-01],
          [-4.5237e-01],
          [-4.4874e-01]]],


        ...,


        [[[-7.8844e-03],
          [-2.9675e-02],
          [-3.5647e-02],
          ...,
          [-2.0682e-01],
          [-2.0479e-01],
          [-1.9202e-01]],

         [[-7.9009e-03],
          [ 2.4574e-02],
          [ 2.0738e-02],
          ...,
          [ 1.7345e-03],
          [-1.2489e-02],
          [-2.4164e-02]],

         [[-7.8008e-03],
          [-5.6778e-02],
          [-5.1474e-02],
          ...,
          [ 6.6567e-03],
          [ 2.7683e-02],
          [ 4.5570e-02]],

         ...,

         [[-7.5644e-03],
          [ 6.1803e-02],
          [-3.3035e-02],
          ...,
          [ 2.7358e-03],
          [-1.0383e-01],
          [-7.8797e-02]],

         [[-8.0107e-03],
          [-2.0768e-02],
          [ 3.6750e-02],
          ...,
          [ 4.3413e-01],
          [ 4.8897e-01],
          [ 4.6125e-01]],

         [[-7.8454e-03],
          [ 1.7499e-02],
          [-1.9774e-04],
          ...,
          [-5.3369e-01],
          [-5.5772e-01],
          [-5.5835e-01]]],


        [[[-7.8844e-03],
          [ 8.2335e-03],
          [ 4.3732e-03],
          ...,
          [ 1.7150e-01],
          [ 1.5824e-01],
          [ 1.5769e-01]],

         [[-7.9009e-03],
          [-3.7114e-02],
          [-2.2520e-02],
          ...,
          [-1.5889e-01],
          [-1.3749e-01],
          [-1.2796e-01]],

         [[-7.8008e-03],
          [ 3.6060e-02],
          [ 1.4386e-02],
          ...,
          [ 1.7872e-01],
          [ 1.4651e-01],
          [ 1.3247e-01]],

         ...,

         [[-7.5644e-03],
          [-1.0529e-01],
          [-1.3976e-02],
          ...,
          [-1.4192e+00],
          [-1.3623e+00],
          [-1.2827e+00]],

         [[-8.0107e-03],
          [ 2.6301e-02],
          [-1.8310e-02],
          ...,
          [ 3.1139e-01],
          [ 2.9577e-01],
          [ 2.5370e-01]],

         [[-7.8454e-03],
          [-3.6638e-02],
          [-1.5044e-02],
          ...,
          [ 2.1899e-01],
          [ 2.3758e-01],
          [ 2.5508e-01]]],


        [[[-7.8844e-03],
          [ 9.0576e-03],
          [-2.5025e-02],
          ...,
          [ 1.2979e-01],
          [ 1.6275e-01],
          [ 1.5967e-01]],

         [[-7.9009e-03],
          [-2.7540e-02],
          [ 1.4040e-02],
          ...,
          [-1.2109e-01],
          [-1.5186e-01],
          [-1.6329e-01]],

         [[-7.8008e-03],
          [ 2.2029e-02],
          [-4.1026e-02],
          ...,
          [ 3.4524e-02],
          [ 8.1615e-02],
          [ 9.8247e-02]],

         ...,

         [[-7.5644e-03],
          [-1.1349e-02],
          [ 1.4777e-02],
          ...,
          [-5.9795e-01],
          [-5.3877e-01],
          [-6.7216e-01]],

         [[-8.0107e-03],
          [-2.4967e-02],
          [-7.7623e-04],
          ...,
          [-2.5235e-01],
          [-3.2100e-01],
          [-2.4606e-01]],

         [[-7.8454e-03],
          [-1.6696e-02],
          [ 5.1157e-03],
          ...,
          [ 3.4195e-01],
          [ 3.3913e-01],
          [ 3.1172e-01]]]])

In [None]:
# Get all fixed points of student RNN
D_list, D_inds, z_list, n_inverses = find_fixed_points_analytic(
    np.array([tau, tau]), pV, pU, 0, pB
)

In [None]:
# extract phase planes

u_in = np.zeros(2)
xlims = 3
ylims = 3
# X, Y, uGT, vGT, normGT = extract_phase_plane_rnn(rnn_reaching, xlims, ylims, inp=u_in)
X, Y, u, v, norm = extract_phase_plane_vae(vae, xlims, ylims, inp=None)

In [None]:
# Make panel c

# marker style
dot_s = 10
dot_z = 100
dot_ew = 0.4
dot_s_st = 20
dot_fill = "gainsboro"



T1 = 0



with mpl.rc_context(fname="matplotlibrc"):
    fig, ax = plt.subplots(2, 2, figsize=(2, 2), dpi=200)


    # Student
    ax[1, 0].imshow(
        norm,
        extent=[-xlims, xlims, -ylims, ylims],
        origin="lower",
        cmap="bone",
        vmax=np.max(norm),
        aspect="auto",
    )
    ax[1, 0].streamplot(
        X, Y, u, v, color="lavender", density=0.5, linewidth=0.5, arrowsize=0.5
    )
    ax[1, 0].set_box_aspect(1)
    ax[1, 0].spines[["right", "top"]].set_visible(False)
    ax[1, 0].set_xlim(-xlims, xlims)
    ax[1, 0].set_ylim(-ylims, ylims)
    ax[1, 0].set_xticks([-xlims, xlims])
    ax[1, 0].set_yticks([-ylims, ylims])
    ax[1, 0].set_xlabel(r"$z_1$")
    ax[1, 0].set_ylabel(r"$z_2$")


    ax[1, 1].set_box_aspect(1)
    ax[1, 1].axis("off")

    # for legend
    incl_stable = False
    incl_saddle = False
    incl_unstable = False
    """"""
    # Calculate stability of fixed points and plots
    for z in z_list:
        e, vec = np.linalg.eig(PL_Jacobian(pV, pU, -pB, np.diag(np.ones(2) * tau), z))
        if abs(e[0]) > 1 and abs(e[1]) > 1:
            if incl_unstable:
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c=dot_fill,
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o"),
                    zorder=dot_z,
                )
            else:
                incl_unstable = True
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c=dot_fill,
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o"),
                    zorder=dot_z,
                    label="unstable",
                )
        elif abs(e[0]) < 1 and abs(e[1]) < 1:
            if incl_stable:
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c="black",
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o"),
                    zorder=dot_z,
                )
            else:
                incl_stable = True
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c="black",
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o"),
                    zorder=dot_z,
                    label="stable",
                )
        else:  # saddle
            if incl_saddle:
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c=dot_fill,
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o", fillstyle="right"),
                    zorder=dot_z - 1,
                )
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c="black",
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o", fillstyle="left"),
                    zorder=dot_z - 1,
                )
            else:
                incl_saddle = True
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c=dot_fill,
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o", fillstyle="right"),
                    zorder=dot_z - 1,
                    label="saddle",
                )
                ax[1, 0].scatter(
                    z[0],
                    z[1],
                    s=dot_s,
                    c="black",
                    edgecolor="black",
                    lw=dot_ew,
                    marker=MarkerStyle("o", fillstyle="left"),
                    zorder=dot_z - 1,
                    label="saddle",
                )
    ax[1, 0].legend(title="fixed points", loc="upper right", bbox_to_anchor=(2, 1.1))
