### PyMC implementation

To implement the R2-D2 prior in PyMC, let's first simulate some data with the following characteristics:

* $p = 30$ predictors: 4 with sizeable coefficients (drawn around $\pm 2$) and 26 with coefficients tightly concentrated around zero.
* $n = 1000$
* $\sigma^2 = 3^2$
* $X_i \sim \text{Uniform}(-2, 2)$, but then scaled to have variance 1

In [None]:
# Simulate a sparse linear-regression dataset. Every random draw below pulls from
# the same `rng`, in this order: coefficients, then predictors, then the response.
rng = np.random.default_rng(121195)

n = 1000
TRUE_ALPHA = 2.5
TRUE_SIGMA = 3

# 26 coefficients centred on 0 with a tiny spread, followed by four sizeable
# ones centred on -2, 2, -2, 2.
beta_mu = [0] * 26 + [-2, 2, -2, 2]
beta_sigma = [0.05] * 26 + [0.5] * 4
TRUE_BETAS = pm.draw(pm.Normal.dist(mu=beta_mu, sigma=beta_sigma), random_seed=rng)
p = TRUE_BETAS.shape[0]

# Uniform(-2, 2) predictors, one row per observation (n draws of a length-p vector).
X = pm.draw(pm.Uniform.dist(lower=-2, upper=2, shape=p), draws=n, random_seed=rng)
# Divide each column by its sample standard deviation so every predictor has variance 1.
X_std = X / X.std(axis=0)

linear_predictor = TRUE_ALPHA + X_std @ TRUE_BETAS
y = pm.draw(pm.Normal.dist(mu=linear_predictor, sigma=TRUE_SIGMA), random_seed=rng)

The classic coefficient of determination is computed via

$$
R^2 = \frac{\mathbb{V}(\boldsymbol{x}^T \boldsymbol{\beta})}{\mathbb{V}(\boldsymbol{x}^T \boldsymbol{\beta}) + \sigma^2}
$$

with

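$$
\mathbb{V}(\boldsymbol{x}^T \boldsymbol{\beta}) = \boldsymbol{\beta}^T \text{Cov}(\boldsymbol{x}) \boldsymbol{\beta} = \boldsymbol{\beta}^T \Sigma_{\boldsymbol{x}} \boldsymbol{\beta}
$$

Since the predictors are drawn independently and scaled to unit variance, we take $\Sigma_{\boldsymbol{x}}$ to be the identity matrix.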

In [None]:
# Predictors are independent and scaled to unit variance, so Cov(x) is the identity.
Sigma = np.eye(p)
# beta^T Sigma beta: variance of the linear predictor (for a 1-D array, .T is a no-op).
mu_var = TRUE_BETAS @ Sigma @ TRUE_BETAS
TRUE_R2 = mu_var / (mu_var + TRUE_SIGMA**2)
TRUE_R2

Now, let's implement the model with the R2-D2 prior in PyMC.

In [None]:
# Named dimensions: one coordinate per predictor and one per observation.
coords = {
    "predictor": np.arange(p),
    "__obs__": np.arange(n)
}

# Linear regression with an R2-D2 prior on the coefficients.
with pm.Model(coords=coords) as model_r2d2:
    # Prior placed directly on the coefficient of determination R^2.
    R2 = pm.Beta("R2", alpha=2, beta=2)
    # Dirichlet weights (they sum to 1) that split the total prior variance across predictors.
    phi = pm.Dirichlet("phi", a=np.ones(p), dims="predictor")
    # W = R^2 / (1 - R^2): total prior variance of X_std @ beta, in units of sigma^2.
    W = pm.Deterministic("W", R2 / (1 - R2))

    sigma_squared = pm.Gamma("sigma_squared", mu=9, sigma=3) # Informative: centred on the true sigma^2 = 3^2
    # PyMC's default Normal(0, 1) prior for the intercept.
    alpha = pm.Normal("alpha")
    # beta_j ~ Normal(0, variance = phi_j * W * sigma^2). With independent, unit-variance
    # predictors these variances sum to W * sigma^2, so W / (W + 1) recovers R2 above.
    beta = pm.Normal("beta", mu=0, sigma=(phi * W * sigma_squared) ** 0.5, dims="predictor")
    mu = pm.Deterministic("mu", alpha + (X_std @ beta), dims="__obs__")

    # Gaussian likelihood; PyMC's Normal takes a standard deviation, hence the square root.
    pm.Normal("y", mu=mu, sigma=sigma_squared ** 0.5, observed=y, dims="__obs__")

model_r2d2.to_graphviz()

In [None]:
# Fit the model with NUTS. A higher-than-default target acceptance rate makes the
# sampler adapt smaller step sizes, presumably to limit divergences under this prior.
with model_r2d2:
    idata_r2d2 = pm.sample(target_accept=0.99, random_seed=121195)

In [None]:
# Posterior intervals for each coefficient, overlaid with the true values as black dots.
forest_axes = az.plot_forest(idata_r2d2, var_names="beta", combined=True)
ax = forest_axes[0]
# Reverse the ticks so the first true coefficient lines up with the top row of the
# forest plot (presumably where ArviZ draws beta[0]).
coefficient_rows = ax.get_yticks()[::-1]
ax.scatter(TRUE_BETAS, coefficient_rows, color="black", s=15, zorder=10)
ax.set(xlabel="Value", ylabel="Coefficient");

In [None]:
def get_conditional_R2(mean, variance, dim="__obs__"):
    """Return the conditional R^2 implied by a mean and a residual variance.

    Computes ``Var(mean) / (Var(mean) + variance)``, where the variance of
    ``mean`` is taken over the observation dimension ``dim``.

    Parameters
    ----------
    mean : object with a ``.var`` method
        Values of the linear predictor, e.g. the posterior ``mu`` DataArray
        whose observations live along the ``"__obs__"`` dimension.
    variance : array-like
        Residual variance (sigma^2), broadcastable against ``mean.var(dim)``.
    dim : str or int, default "__obs__"
        Dimension holding the observations. It is passed positionally to
        ``mean.var``, so a NumPy array can be used with an integer axis.

    Returns
    -------
    The conditional R^2 for every remaining coordinate (e.g. each chain and
    draw), with the same type as ``mean.var(dim)``.
    """
    mu_var = mean.var(dim)
    return mu_var / (mu_var + variance)

In [None]:
# Posterior distribution of the conditional R^2, with the true R^2 as a dashed line.
conditional_r2 = get_conditional_R2(
    idata_r2d2.posterior["mu"],
    idata_r2d2.posterior["sigma_squared"],
)
ax = az.plot_dist(conditional_r2, label="Conditional $R^2$")
ax.axvline(TRUE_R2, color="0.2", ls="--")
ax.set(xlabel="$R^2$", ylabel="Density");