In [1]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=6'

import jax
import jax.numpy as jnp

from functions import *
from primitives import weighted_quantile
from primitives import effective_sample_size

import numpy as np

from matplotlib import pyplot as plt
    
import scipy

import pandas as pd
from tqdm import tqdm
import arviz

import pickle

apply_sel = lambda data, cc: [item.iloc[cc,] for item in data]

base_file_pres2007 = "../data/source/presidentielle_2007/"
base_file_pres2022 = "../data/source/presidentielle_2022/"
base_file_legi2024 = "../data/source/legislatives_2024/"
base_file_save = "../data/results/inference_"
device_parallel = True

# base_file_legi2024 = ""
#base_file_pres2022 = ""
# base_file_save =  "inference_"
#device_parallel = False


# Synthetic data

## Figure 1 – Global comparison

In [10]:
est1 = jax.jit(lambda *args: estimate_small(10, *args))
size_bur =100
est_func = est1
K = 100
true_sampling, I = make_probability_X(size_mat=3)
n = jnp.ones(K)*size_bur
true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)
pars_func = lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=K, axis=0) # fixed pars func 

# Definition dens_exact
find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
find_notilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=0.))
estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_func))

dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)

In [16]:
def make_sim(size_bur, est_func):
    K = 100
    true_sampling, I = make_probability_X(size_mat=3, base=1.)
    n = jnp.ones(K)*size_bur
    true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)
    pars_func = lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=K, axis=0) # fixed pars func 

    # Definition dens_exact
    find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
    find_notilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=0.))
    estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_func))
    
    dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)
    dens_exact_notilt = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_notilt, estimate_prob)

    val_pars = jnp.repeat(inv_softmax1(true_sampling)[None,:], repeats=1000, axis=0)
    keys = jax.random.split(jax.random.key(1), val_pars.shape[0])

    Q_tilt = jax.vmap(dens_exact, (0, 0))(val_pars, keys)[2]
    Q_notilt = jax.vmap(dens_exact_notilt, (0, 0))(val_pars, keys)[2]
    
    return np.array([np.std(Q_notilt, 0).mean()/Q_tilt.mean(), np.std(Q_tilt, 0).mean()/Q_tilt.mean()])
    #return np.array([np.mean(Q_notilt<0), np.mean(Q_tilt<0)])

est1 = jax.jit(lambda *args: estimate_small(10, *args))
est2 = jax.jit(lambda *args: estimate_medium(10, *args))

res_small = np.array([make_sim(10, est1),
 make_sim(100, est1),
 make_sim(1000, est1)])

res_medium = np.array([make_sim(10, est2),
              make_sim(100, est2),
              make_sim(1000, est2)])

In [19]:
tt = np.hstack([res_small, res_medium])
res = pd.DataFrame(tt, 
             index=[10, 100, 1000], columns=["Uniform, no tilt", "Uniform, tilt", "Normal, no tilt", "Normal, tilt"]).map(lambda x: f"{x:.{2}e}")
#res.to_latex(buf="../docs/figures/comparison_standard_error.tex")
res.to_csv("../data/for_figures/comparison_standard_error.csv")

## Figure 2 – Comparison given the size of cells

En fonction de la taille des cellules.

In [None]:

true_sampling, I = make_probability_X(size_mat=3, base=1.)
n = jnp.arange(5, 51, step=5)
true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)
Nsim = 40000

# Approche uniforme 

pars_func = jax.jit(lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=n.shape[0], axis=0))
find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
est1 = jax.jit(lambda *args: estimate_small(Nsim, *args))
estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est1))
dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)

val_pars = jnp.repeat(inv_softmax1(true_sampling)[None,:], repeats=500, axis=0)
keys = jax.random.split(jax.random.key(1), val_pars.shape[0])
Q_small, s_small,_ = jax.vmap(dens_exact, (0, 0))(val_pars, keys)

x_small = n
y_small = np.array(Q_small.std(0))
y_small[s_small.mean(0)<1] = jnp.nan

# Approche moyenne

#n = jnp.array([1000])
true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)

find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
est2 = jax.jit(lambda *args: estimate_medium(Nsim, *args))
estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est2))
pars_func = jax.jit(lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=n.shape[0], axis=0))
dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)

val_pars = jnp.repeat(inv_softmax1(true_sampling)[None,:], repeats=500, axis=0)
keys = jax.random.split(jax.random.key(2), val_pars.shape[0])
Q_medium, s_medium,_ = jax.vmap(dens_exact, (0, 0))(val_pars, keys)

x_medium = n#[s_medium.mean(0)==1]
y_medium = np.array(Q_medium.std(0))#[:,s_medium.mean(0)==1].std(0)
y_medium[s_medium.mean(0)<1] = jnp.nan

In [None]:
fig, ax = plt.subplots(1,2, figsize=(12/2.54, 8/2.54))
ax[0].plot(x_small, y_small, c="black")
ax[0].set_xlabel("Parameter $n$ of $\mathcal{M}(n,p)$")
ax[0].set_ylabel("Standard error of log-likelihood")
ax[0].set_title("(1) Uniform (tilt)")
ax[0].set_xticks(np.array([5,10,20,30,40,50]))

s = s_small.mean(0)
ax[0].fill_between(range(len(s)), 0, jnp.max(y_small), where=(s < 1), alpha=0.5)

ax[1].set_xlabel("Parameter $n$ of $\mathcal{M}(n,p)$")
ax[1].plot(x_medium, y_medium,c="black")
ax[1].set_title("(2) Normal (tilt)")

s = s_medium.mean(0)
ax[1].fill_between(range(len(s)), 0, jnp.max(y_medium), where=(s < 1), alpha=0.5)
ax[1].set_xticks(np.array([5,10,20,30,40,50]))

plt.style.use("ggplot")
plt.tight_layout()
#plt.savefig("../docs/figures/uniform_vs_normal.pdf")

## Figure 3 – Relation to size of table, and to asymetry of p


In [None]:
def test_asymetry(base, make_prob_func, size_bur):
    K = 200
    true_sampling, I = make_prob_func(size_mat=3, base=base)
    n = jnp.ones(K) * size_bur
    true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)

    Nsim = 1000

    define_functions(ecological_full, jnp.zeros_like(ecological))

    pars_func = jax.jit(lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=n.shape[0], axis=0))
    est_medium = jax.jit(lambda *args: estimate_medium(Nsim, *args))
    find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
    estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_medium))

    dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)

    Q = jax.vmap(dens_exact)(inv_softmax1(true_sampling)[None,:].repeat(axis=0,repeats=200), jax.random.split(jax.random.key(1), 200))[0]
    return Q.std(0).mean()

In [None]:
size_bur = 3000
resX = jnp.array([test_asymetry(1., make_probability_X, size_bur), 
 test_asymetry(2., make_probability_X, size_bur), 
 test_asymetry(5., make_probability_X, size_bur),
 test_asymetry(10., make_probability_X, size_bur)])

resL = jnp.array([test_asymetry(1., make_probability_I, size_bur), 
 test_asymetry(2., make_probability_I, size_bur), 
 test_asymetry(5., make_probability_I, size_bur),
 test_asymetry(10., make_probability_I, size_bur)])


In [None]:
dt = pd.DataFrame({"ratio biggest/smallest": [1, 4, 25, 100],
    "X":resX,"I":resL, "X_relative":resX/resX[0], "I_relative":resL/resL[0]})
dt["X"] = dt["X"].map(lambda x: f"{x:.{2}e}")
dt["I"] = dt["I"].map(lambda x: f"{x:.{2}e}")
dt["X_relative"] = dt["X_relative"].map(lambda x: str(round(x,1)))
dt["I_relative"] = dt["I_relative"].map(lambda x: str(round(x,1)))
dt.index = ["1", "2", "5", "10"]
dt.to_latex("../docs/figures/sparsity.tex")

In [None]:
dt = pd.DataFrame({"Type":np.repeat(np.array(["X", "I"]), 4),
                   "Asymetry":np.tile(np.array(["1", "2", "5", "10"]), 2),
                   "Sd":np.concatenate([resX, resL])})

dt.to_csv("../data/for_figures/sparsity.csv")

## Figure 3B

In [3]:
import copy
K = 200
size_bur = 100
true_sampling, I = make_probability_I(size_mat=3, base=2.)

n = jnp.ones(K) * size_bur
true_data, _, ecological_full_or, AB = construct_synthetic(true_sampling, I, n = n)
AB, A_full = get_constraint_matrix(I)
ecological_data = copy.deepcopy(ecological_full_or)

In [23]:
AB = A_full[[0,1,3,4],:]
A_bis = A_full[[2,0,1,3,4,5],:][[0,1,3,4],:]
ecological_data_bis = copy.deepcopy([ecological_data[0][:,np.array([2,0,1])], ecological_data[1]])

In [34]:
funcs1 = define_functions(ecological_data, None, Nsim=100,A=AB,pars_function=pars_func_fixed)
res1 = jax.vmap(funcs1["dens_exact"])(jnp.ones(funcs1["sizes"]["size_pars"])[None,].repeat(3000,0), jax.random.split(jax.random.key(2), 3000))

funcs2 = define_functions(ecological_data_bis, None, Nsim=100,A=A_bis,pars_function=pars_func_fixed)
res2 = jax.vmap(funcs2["dens_exact"])(jnp.ones(funcs2["sizes"]["size_pars"])[None,].repeat(3000,0), jax.random.split(jax.random.key(2), 3000))

8
8


In [77]:
A_ter = AB.copy()
base = .1
A_ter = A_ter.at[1,].set(AB[1,]+base*AB[0,])

ecological_data_ter = copy.deepcopy(ecological_data)
ecological_data_ter[0][1,] = ecological_data_ter[0][0,] + base*ecological_data_ter[0][1,]

In [78]:
def define_density(logprior, estimate_function, Nsim, I, ecological_full, ecological, context, n, A, pars_function):

    pars_func = lambda pars: pars_function(pars, ecological, ecological_full, context, n, I)
    find_tilt = lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=2000, lr=1e-8)
    est_func = lambda *args: estimate_function(Nsim, *args)
    estimate_prob = lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_func)
    dens_exact = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)[0].sum() + logprior(pars)
    logpost_diag =  lambda pars, key: marginal_likelihood(key, pars, pars_func, find_tilt, estimate_prob)

    return dens_exact, logpost_diag

def find_optimal_tilt(pis, ecological, n, A, Niter, lr):
    tilt = jnp.zeros((ecological.shape[0], A.shape[0]))
    
    body_func = lambda tilt: jax.vmap(body_func_gd, in_axes=(0, 0, 0, 0, None, None))(tilt, pis, ecological, n, A, lr)

    def loop_body(iter_idx, tilt):
        return body_func(tilt)

    final_tilt = jax.lax.fori_loop(0, Niter, loop_body, tilt)
    return final_tilt

def body_func_gd_(tilt, pis, obs, n, A, lr):
    grad_ = (jax.grad(lambda t, p, n, A: cumulant_multinom(A.T @ t, p, n))(tilt, pis, n, A) - obs) + (obs < 0.1) * jax.grad(lambda tilt: 1e-3*(tilt**2).sum())(tilt)
    return tilt - lr @ (grad_)

body_func_gd = jax.jit(body_func_gd_)

In [79]:
funcs3 = define_functions(ecological_data_ter, None, Nsim=100,A=A_ter,pars_function=pars_func_fixed)
res3 = jax.vmap(funcs3["dens_exact"])(jnp.ones(funcs3["sizes"]["size_pars"])[None,].repeat(3000,0), jax.random.split(jax.random.key(2), 3000))
print(res1.std(), res2.std(), res3.std())

8
0.08788322 0.08788322 0.09816043


## Figure 4 – Tail behavior 

In [None]:
true_sampling, I = make_probability_X(size_mat=3, base=1.)
false_sampling, _ = make_probability_X(size_mat=3, base=3.)

K = 200
n = jnp.ones(K) * 1000
_, ecological_true, _, A = construct_synthetic(true_sampling, I, n = n)
_, ecological_false, _, A = construct_synthetic(false_sampling, I, n = n)
Nsim = 1000

pars_func = jax.jit(lambda pars: jnp.repeat(softmax1(pars)[None,:], repeats=n.shape[0], axis=0))
est_func = jax.jit(lambda *args: estimate_medium(Nsim, *args))
estimate_prob_true = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological_true, n, A, est_func))
estimate_prob_false = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological_false, n, A, est_func))

find_true_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological_true, n, A, Niter=5, lr=1.))
find_true_notilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological_true, n, A, Niter=5, lr=0.))
find_false_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological_false, n, A, Niter=5, lr=1.))
find_false_notilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological_false, n, A, Niter=5, lr=0.))

dens_true_tilt = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_true_tilt, estimate_prob_true)
dens_true_notilt = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_true_notilt, estimate_prob_true)
dens_false_tilt = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_false_tilt, estimate_prob_false)
dens_false_notilt = lambda pars, key: marginal_likelihood(key, pars, pars_func, find_false_notilt, estimate_prob_false)

val_pars = jnp.repeat(inv_softmax1(true_sampling)[None,:], repeats=200, axis=0)

keys = jax.random.split(jax.random.key(3), val_pars.shape[0])

Q_true_tilt = jax.vmap(dens_true_tilt, (0, 0))(val_pars, keys)[0].std(0).mean()
Q_true_notilt = jax.vmap(dens_true_notilt, (0, 0))(val_pars, keys)[0].std(0).mean()
Q_false_tilt = jax.vmap(dens_false_tilt, (0, 0))(val_pars, keys)[0].std(0).mean()
Q_false_notilt = jax.vmap(dens_false_notilt, (0, 0))(val_pars, keys)[0].std(0).mean()


In [None]:
dt = pd.DataFrame(
    {"Value":[Q_true_tilt, Q_true_notilt, Q_false_tilt, Q_false_notilt], 
     "Tilting": ["Tilting", "No tilting", "Tilting", "No tilting"],
     "Observation": ["Close to Mode", "Close to Mode", "Tail", "Tail"]})
dt.to_csv("../data/for_figures/tail_behavior.csv")

## Figure A1 – Importance sampling ESS calculation

In [None]:

def figureA1(size_bur, base, Nsim, make_prob_func,size_mat=3):
    K = 200
    Nmax = 200
    true_sampling, I = make_prob_func(size_mat=size_mat, base = base)
    n = jnp.ones(K)*size_bur
    true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)
    ""
    init_state = jnp.ones(A.shape[1]-1)
    func1 = lambda pars: dens_approx_constant(pars, ecological, n, A)
    chain = get_approximate(init_state, Nmax, ecological, n, A, doublings=10, dens_approx_func=func1)

    # Define functions
    est_medium = jax.jit(lambda *args: estimate_medium(Nsim, *args))
    find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
    estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_medium))

    dens_exact = lambda pars, key: marginal_likelihood_fixed(key, pars, K, find_tilt, estimate_prob)

    P = jax.vmap(dens_approx_constant, (0, None, None, None))(chain[100:,], ecological, n, A)
    Q = jax.vmap(dens_exact)(chain[100:,], jax.random.split(jax.random.key(0), chain[100:,].shape[0]))[0].sum(1)
    w = jnp.exp(Q-P - jnp.max(Q-P))
    ESS = (w.sum()**2)/(w**2).sum()

    return ESS

In [None]:
figureA1(100, 5., 5000, make_probability_X, 4)

In [None]:
dd1a = jnp.array([figureA1(50, 2., 5000, make_probability_I),
                  figureA1(1000, 2., 5000, make_probability_I)])

dd1b = jnp.array([figureA1(50, 5., 5000, make_probability_I),
                  figureA1(1000, 5., 5000, make_probability_I)])

dd2a = jnp.array([figureA1(50, 2., 5000, make_probability_X),
                  figureA1(1000, 2., 5000, make_probability_X)])

dd2b = jnp.array([figureA1(50, 5., 5000, make_probability_X),
                  figureA1(1000, 5., 5000, make_probability_X)])

In [None]:
# 
res = pd.DataFrame({"value":jnp.concatenate([dd1a, dd1b, dd2a, dd2b]),
    "n":np.repeat(np.array([50, 1000])[None,:], 4, 0).flatten(),
"multiplier":np.repeat(np.repeat(np.array(["Sparsity multiplier 2", "Sparsity multiplier 5"]), 2, 0)[None,:], 2, 0).flatten(),
"Type":np.repeat(np.array(["I-shape", "X-shape"]), 4)})
res.to_csv("../data/for_figures/evaluate_ESS.csv")

## Figure A2 – Idea of conditional probability

In [None]:
K = 200
size_bur = 150
Nsim = 100
true_sampling, I = make_probability_I(size_mat=3, base = 3.)
n = jnp.ones(K)*size_bur
true_data, ecological, ecological_full, A = construct_synthetic(true_sampling, I, n = n)
true_cond_sampling = inv_softmax2(true_sampling.reshape(I) / true_sampling.reshape(I).sum(1)[:,None])

est_medium = jax.jit(lambda *args: estimate_medium(Nsim, *args))
find_tilt = jax.jit(lambda pis: find_optimal_tilt(pis, ecological, n, A, Niter=5, lr=1.))
estimate_prob = jax.jit(lambda keys, tilt, pis: jax.vmap(tilted_estimate, (0, 0, 0, 0, 0, None, None))(keys, tilt, pis, ecological, n, A, est_medium))

dens_exact_cond = lambda pars, key: marginal_likelihood_margleft(key, pars, n, ecological_full, I, find_tilt, estimate_prob)
dens_exact = lambda pars, key: marginal_likelihood_fixed(key, pars, K, find_tilt, estimate_prob)

res_norm = jax.vmap(dens_exact)(inv_softmax1(true_sampling)[None,:].repeat(100,0), jax.random.split(jax.random.key(1), 100))[0].sum(1)
res_cond = jax.vmap(dens_exact_cond)(true_cond_sampling[None,:].repeat(100,0), jax.random.split(jax.random.key(1),100))[0].sum(1)

dd = pd.DataFrame({"Method": ["Fixed probability", "Conditional Probability"], "Standard Error of the log-likelihood":[res_norm.std(), res_cond.std()]})
dd.iloc[:,1] = dd.iloc[:,1].map(lambda x: f"{x:.{2}e}")

dd.to_latex("../docs/figures/conditional_probability.tex")

# Example on true data

## Version prsidential 2007

### Running the analysis

In [2]:
## Data pre-processing ##

rdt1 = pd.read_csv(base_file_pres2007 + "p2007_t1.csv")
rdt2 = pd.read_csv(base_file_pres2007 + "p2007_t2.csv")

## Selecting on rows ## 
otherT1 = rdt1.loc[:,["besa", "vill", "buff", "voyn", "lagu", "bove", "niho", "schi", "blanc"]].sum(1)
otherT2 = rdt2.loc[:,["blanc","abstention"]].sum(1)

crit_full = (otherT1 + rdt1["abstention"] < rdt1["total"]) & (otherT2 < rdt2["total"])
cc_full = np.arange(rdt1.shape[0])[crit_full]
list_idf = [75, 92, 93, 94, 77, 78, 91, 95]
cc_idf = np.arange(rdt1.shape[0])[crit_full & (rdt1["code_dept"].isin(list_idf))]

## Selecting on columns ##

# Small version
rdt1b = rdt1.loc[:,["sark", "roya", "bayr", "lepe", "abstention"]].copy()
rdt1b["other"] = otherT1

rdt2b = rdt2.loc[:,["sark", "roya"]].copy()
rdt2b["other"] = otherT2

data_small = [rdt1b, rdt2b]

# Large version

rdt1b = rdt1.loc[:,["sark", "roya", "bayr", "lepe"]].copy()
rdt1b["other_left"] = rdt1.loc[:,["besa", "buff", "voyn", "lagu", "bove", "schi"]].sum(1)
rdt1b["other_right"] = rdt1.loc[:,["vill", "niho"]].sum(1)
rdt1b["other"] =  rdt1.loc[:,["abstention","blanc"]].sum(1)

rdt2b = rdt2.loc[:,["roya", "sark"]].copy()
rdt2b["other"] = otherT2


data_large = [rdt1b, rdt2b]

In [None]:
# Analysis of Gaussianity on IDF #

full_inference(data=apply_sel(data_small, cc_idf), save_as=base_file_save+"pres2007_idf_fixed",
                     type_model = "fixed", step3_Nmax=1000, step3_Nchains=5, step3_parallel = True,
                     step3_type="IS", try_gaussian=False)

full_inference(data=apply_sel(data_small, cc_idf), save_as=base_file_save+"pres2007_idf_fixed_gaussian",
                     type_model = "fixed", step3_Nmax=1000, step3_Nchains=5, step3_parallel = True,
                     step3_type="IS", try_gaussian=True)

# Small version #

print("Full")

full_inference(data=apply_sel(data_small, cc_full), save_as=base_file_save+"pres2007_full",
                     type_model = "margleft", step3_Nmax=1000, step3_Nchains=5, step3_parallel = True,
                     step3_type="IS", try_gaussian=False)


### Figure 1 - Gaussian approximation evaluation

In [3]:
with open("../data/results/inference_pres2007_idf_fixed", "rb") as file:
    res_true = pickle.load(file)

with open("../data/results/inference_pres2007_idf_fixed_gaussian", "rb") as file:
    res_gaus = pickle.load(file)

In [9]:
cols1 = ["Sarkozy (right)", "Royal (left)", "Bayrou (center)", "Abstention", "Le Pen (far-right)", "Other"]
cols2 = ["Royal (left)", 'Sarkozy (right)', "Other"]

dt_true = pd.melt(make_posterior_pretty(res_true["step3"]["quants"], cols1, cols2))
dt_true["type"] = "True model"
dt_true["vote_t2"] = np.tile(cols1, dt_true.shape[0]//len(cols1))
dt_true.columns = ["vote_t2", "confidence", "value", "model", "vote_t1"]

dt_gaus = pd.melt(make_posterior_pretty(res_gaus["step3b"]["quants"], cols1, cols2))
dt_gaus["type"] = "Gaussian model"
dt_gaus["vote_t2"] = np.tile(cols1, dt_gaus.shape[0]//len(cols1))
dt_gaus.columns = ["vote_t2", "confidence", "value", "model", "vote_t1"]


In [10]:
dt = pd.concat([dt_true, dt_gaus])
dt.columns = ["vote_t2", "confidence", "value", "model", "vote_t1"]
dt.to_csv("../data/for_figures/comparison_gaussian.csv")

### Figure 2  - All France, all candidates

In [6]:
with open("../data/results/inference_pres2007_full", "rb") as file:
    res_full = pickle.load(file)

rows = ["Sarkozy", "Royal", "Bayrou", "Le Pen", "Abstention", "Other"]
cols = ["Sarkozy", "Royal", "Other"]
df = make_posterior_pretty(res_full["step3"]["quants"], rows, cols)

df2 = pd.melt(df)
df2["vote_t2"] = np.tile(rows, df2.shape[0]//len(rows))
df2.columns.set_names([""], inplace=True)
df2.columns = ["vote_t1", "quantile", "value", "vote_t2"]
df2.to_csv("../data/for_figures/full_france_small.csv")

df3 =df.xs('.5', level=1, axis=1)
df3.columns.set_names([""], inplace=True)
df3.to_latex("../docs/figures/full_france.tex",float_format="%.2f")

## Version legislatives 2024 

### Running the analysis

In [2]:
rdt1 = pd.read_csv(base_file_legi2024+"l2024_T1.csv")
rdt2 = pd.read_csv(base_file_legi2024+"l2024_T2.csv")

codex1 = pd.read_csv(base_file_legi2024+"l2024_codex1.csv")
codex2 = pd.read_csv(base_file_legi2024+"l2024_codex2.csv")

situations = pd.read_csv(base_file_legi2024+"l2024_situation_circos.csv").rename(columns={"situation_code":"code"})
situations = situations[situations["code"] > 0].copy()

name_circs_tab = pd.read_csv(base_file_legi2024+"l2024_name_circos.csv")
name_circs_tab = name_circs_tab.set_index(name_circs_tab["code_circ"])
name_circs = dict(name_circs_tab["name_circ"])
name_circs_tab = name_circs_tab.drop(labels=("code_circ"), axis=1)
name_circs_tab = name_circs_tab.reset_index()

list_circo1 = list(situations[situations["code"] > 0]["code_circ"])
num_circo = 268

def get_data(num_circo, name="name_long"):

    # Cand 1
    rdt1b = rdt1[rdt1["code_circ"]==num_circo].reset_index()
    nt1 = rdt1b.loc[0,"n_cand"]
    cand1 = ["abstention", "blanc_nul"]
    cand1.extend(["V"+str(i) for i in range(1, nt1+1)])

    rdt1c = rdt1b.loc[:,cand1]
    to_keep1 = (rdt1c.sum(0) / rdt1c.sum().sum() > 0.05)

    rdt1d = rdt1c.loc[:,to_keep1].copy()
    rdt1d["other"] = rdt1c.loc[:,~to_keep1].sum(1)

    codex1b = codex1[codex1["code_circ"]==num_circo].copy()
    codex1b.index = rdt1c.columns[2:]
    codex1c = codex1b.loc[to_keep1]

    # Cand 2
    rdt2b = rdt2[rdt2["code_circ"]==num_circo].reset_index()
    nt2 = rdt2b.loc[0,"n_cand"]
    cand2 = ["abstention", "blanc_nul"]
    cand2.extend(["V"+str(i) for i in range(1, nt2+1)])
    rdt2c = rdt2b.loc[:,cand2]
    to_keep2 = (rdt2c.sum(0) / rdt2c.sum().sum() > 0.05)

    rdt2d = rdt2c.loc[:,to_keep2].copy()
    rdt2d["other"] = rdt2c.loc[:,~to_keep2].sum(1)

    codex2b = codex2[codex2["code_circ"]==num_circo].copy()
    codex2b.index = rdt2c.columns[2:]
    codex2c = codex2b.loc[to_keep2]

    data = [rdt1d.astype(int), rdt2d.astype(int)]
    codex = [codex1c, codex2c]

    cc1 = []
    for item in data[0].columns:
        if item[0]=="V": cc1.append(codex[0].loc[item,name])
        else: cc1.append(item)

    cc2 = []
    for item in data[1].columns:
        if item[0]=="V": cc2.append(codex[1].loc[item,name])
        else: cc2.append(item)

    data[0].columns = cc1
    data[1].columns = cc2

    return data, codex

def full_inference_remote(num_circo):
    return full_inference(get_data(num_circo)[0], "legis2024_circo_"+str(num_circo), type="margleft", scale_step3="medium", Nmax_step1 = 200, 
                   Nmax_step2=1000, Nmax_step2b = 1000, Nmax_step3b=5000, Nmax_step3c=500, 
                   base_file=base_file_save, use_qmc=False, save_full_posterior=False, type_machine = type_machine, correct = True)

list_circo = list(situations[situations["code"] > 0]["code_circ"])

In [None]:
current_circo = dict()
i = 0
for num_circo in list_circo1:
    current_circo[num_circo] = full_inference_remote(num_circo)

    i += 1
    if i % 20 == 0:
        with open(base_file_save+"legi2024_circos_"+str(num_circo), "wb") as file:
            pickle.dump(current_circo)

        current_circo = dict()

import time

### Figures 

In [3]:
# Circos 1 à 4: simple joint

with open(base_file_save+"legi2024_circos_all1", "rb") as file: res1 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all2", "rb") as file: res2 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all3", "rb") as file: res3 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all4", "rb") as file: res4 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all5", "rb") as file: res5 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all6", "rb") as file: res6 = pickle.load(file)
with open(base_file_save+"legi2024_circos_all7", "rb") as file: res7 = pickle.load(file)

circos = (res1 | res2 | res3 | res4 | res5 | res6 | res7)
circos = {circo_id:circos[circo_id] for circo_id in circos.keys() if circo_id in name_circs.keys()}

sit1 = situations[situations["code"]==1]["code_circ"]
sit2 = situations[situations["code"]==2]["code_circ"]
sit3 = situations[situations["code"]==3]["code_circ"]
sit4 = situations[situations["code"]==4]["code_circ"]

codex1["name_long2"] = codex1["name"] + " " + codex1["orientation"] 
codex2["name_long2"] = codex2["name"] + " " + codex2["orientation"]

pretty_graphs = dict()
pretty_graphs_full = dict()
for num_circo in circos.keys():
    pretty_graphs[num_circo] = make_posterior_pretty(circos[num_circo]["quants"], get_data(num_circo, "bloc")[0][0].columns, get_data(num_circo, "bloc")[0][1].columns)

for num_circo in circos.keys():
    pretty_graphs_full[num_circo] = make_posterior_pretty(circos[num_circo]["quants"], get_data(num_circo, "name_long2")[0][0].columns, get_data(num_circo, "name_long2")[0][1].columns)


### Figure 8

In [12]:
num = 0
ll1 = []
ll2 = []
ll3 = []
ll4 = []
for num in sit1:
    item = pretty_graphs[num].loc["center","left"][".5"]
    if not np.isnan(item).any() and np.array(item).reshape(1,-1).shape[1] == 1:
        ll1.append((num, 1, item))
    
for num in sit2:
    item = pretty_graphs[num].loc["center","left"][".5"]
    if not np.isnan(item).any() and np.array(item).reshape(1,-1).shape[1] == 1:
        ll2.append((num, 2, item))

for num in sit3:
    item = pretty_graphs[num].loc["left","center"][".5"]
    if not np.isnan(item).any() and np.array(item).reshape(1,-1).shape[1] == 1:
        ll3.append((num, 3, item))

for num in sit4:
    item = pretty_graphs[num].loc["left","center"][".5"]
    if not np.isnan(item).any() and np.array(item).reshape(1,-1).shape[1] == 1:
        ll4.append((num, 4, item))

outcomes = pd.concat([pd.DataFrame(ll1), pd.DataFrame(ll2), pd.DataFrame(ll3), pd.DataFrame(ll4)])
outcomes.columns = ["code_circ", "situation", "prob"]
outcomes = pd.merge(outcomes, name_circs_tab, on = "code_circ")
outcomes = pd.merge(outcomes, situations.loc[:,["code_circ", "share_far_right", "share_strongest"]])
outcomes.to_csv("../data/for_figures/outcomes_all_constituencies.csv")
outcomes.to_csv("../data/for_figures/constituencies.csv")

### Table for individual candidates

In [6]:
dt1 = pretty_graphs_full[83]
dt1.columns.set_names(['', ''], inplace=True)
dt1.to_latex("../docs/figures/borne.tex", float_format=lambda x: f"{x:.2f}")

dt2 = pretty_graphs_full[171]
dt2.columns.set_names(['', ''], inplace=True)
dt2.to_latex("../docs/figures/ruffin.tex", float_format=lambda x: f"{x:.2f}")

dt3 = pretty_graphs_full[489]
dt3.columns.set_names(['', ''], inplace=True)
dt3.to_latex("../docs/figures/darmanin.tex", float_format=lambda x: f"{x:.2f}")

dt4 = pretty_graphs_full[274]
dt4.columns.set_names(['', ''], inplace=True)
dt4.to_latex("../docs/figures/leaument.tex", float_format=lambda x: f"{x:.2f}")
# remove _

### Figure 9

In [7]:
rr = pd.read_csv("../data/source/legislatives_2024/l2024_circos_FP.csv")
rr.columns = ["code_circ", "type_FP"]
rr["code_circ2"] = rr["code_circ"].str.slice(0,2).astype(int).astype(str) + "_" + \
rr["code_circ"].str.slice(3,5).astype(int).astype(str)

outcomes["code_circ2"] = outcomes["code_dept"].astype(int).astype(str) + "_" + \
outcomes["number_circ"].astype(int).astype(str)

outcomes2 = pd.merge(outcomes, rr.loc[:,["type_FP",	"code_circ2"]], on  = "code_circ2")
outcomes2.to_csv("../data/for_figures/FP_cases.csv")

## Version presidential 2022 

### Inference

In [8]:
rdt1 = pd.read_csv(base_file_pres2022+"p2022_t1_net.csv")
rdt2 = pd.read_csv(base_file_pres2022+"p2022_t2_net.csv")
context = pd.read_csv(base_file_pres2022+"p2022_context_net.csv")

rdt1b = rdt1.loc[:,["Macron", "Le Pen", "Melenchon"]]
rdt1b["other"] = rdt1["Hidalgo"] + rdt1["Roussel"] + rdt1["Poutou"] + rdt1["Arthaud"] + rdt1["Jadot"] + \
rdt1["Lassalle"] + rdt1["Dupont-Aignan"] + rdt1["Pecresse"] + rdt1["abstention"] + rdt1["blanc_nul"] + rdt1["Zemmour"]

rdt2b = rdt2.loc[:,["Macron", "Le Pen"]]
rdt2b["other"] = rdt2["abstention"] + rdt2["blanc_nul"]

contextb = context.loc[:,["density"]]
dens_mean = contextb["density"].mean()
dens_std = contextb["density"].std()
contextb["density"] = (contextb["density"] - dens_mean)/dens_std


otherT1 = rdt1["total"] -  rdt1.loc[:,["Macron", "Le Pen", "Melenchon"]].sum(1)
otherT2 = rdt2["total"] - rdt2.loc[:,["Macron", "Le Pen"]].sum(1)
crit_full = (otherT1 < rdt1["total"]) & (otherT2 < rdt2["total"])

data = [rdt1b, rdt2b, contextb]
crit_full = ~np.isnan(context["density"])
cc = np.arange(rdt1.shape[0])[crit_full]

data2 = apply_sel(data, cc)
type = "covariate1"

  rdt1 = pd.read_csv(base_file_pres2022+"p2022_t1_net.csv")


In [9]:
# Before launching, add the criterions like in 2007 (see l13 above)

#full_inference(data=data2, save_as=base_file_save+"pres2022_covariate", type_model = "covariate1", 
#               step2_lr=1e-3, step3_Nmax=1000, step3_Nchains=5, step3_parallel = True,
#                     step3_type="IS", try_gaussian=False)

full_inference(data=data2, save_as=base_file_save+"pres2022_full", type_model = "margleft", 
               step2_lr=1e-2, step3_Nmax=1000, step3_Nchains=5, step3_parallel = True,
                     step3_type="IS", try_gaussian=False)


Step 1


  0%|          | 0/3000 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:

list_idf = [75, 92, 93, 94, 77, 78, 91, 95]

for item in list_idf:
    print(item)
    cc_dep = np.arange(rdt1.shape[0])[crit_full & (rdt1["code_dept"] == item)]
    full_inference(data=apply_sel(data, cc_dep), study_name="pres2024_dep"+str(item), scale_step3="medium")


### Figure 1

In [None]:
list_idf = [75, 92, 93, 94, 77, 78, 91, 95]
df_list = []

for item in list_idf:
    with open("../data/results/inference_pres2022_dep"+str(item), "rb") as file:
        res_dep = pickle.load(file)
    rows = ["Macron (center)", "Le Pen (far-right)", "Melenchon (far-left)", "Other"]
    cols = ["Macron (center)", "Le Pen (far-right)", "Other"]
    df = make_posterior_pretty(res_dep["step3"]["quants"], rows, cols)

    df2 = pd.melt(df.loc[["Melenchon (far-left)"]])
    df2["start_candidate"] = np.tile(["Melenchon (far-left)",], 9)
    df2["department"] = item

    df_list.append(df2)

res = pd.concat(df_list)

res.columns = ["vote_t2", "quantile", "value", "vote_t1", "department"]
res.to_csv("../data/for_figures/idf_departments.csv")

### Figure 2

In [3]:
with open("../data/results/inference_pres2022_covariate", "rb") as file:
    res1 = pickle.load(file)

with open("../data/results/inference_pres2022_full", "rb") as file:
    res0 = pickle.load(file)

In [12]:
w1 = res1["step3"]["weights"][0] - res1["step3"]["weights"][1] 
B1 = jnp.log(jnp.exp(w1 - w1.max()).mean()) + w1.max()

w0 = res0["step3"]["weights"][0] - res0["step3"]["weights"][1] 
B0 = jnp.log(jnp.exp(w0 - w0.max()).mean()) + w0.max()

(B1-B0)/jnp.log(10)

Array(62058.402, dtype=float32)

In [None]:
quants = jnp.quantile(res1["step3"]["posterior"], q=jnp.array([.05, 0.5, 0.95]), axis=0).T
quants = quants[jnp.array([4,5,12,13]),]
df = pd.DataFrame(quants, columns=[0.05, 0.5, 0.95])
#df["Option"] = ["Le Pen (far-right)", "Abstention / other", "Le Pen (far-right)", "Abstention / other"]
#df["Coefficient"] = ["Intercept", "Intercept", "Density", "Density"]
#df.to_csv("../data/for_figures/density.csv")

df = pd.DataFrame(np.concatenate([quants[:2], quants[2:]], axis=1), columns=["i.05","i.5", "i.95","s.05","s.5", "s.95"])
df["Option"] = ["Le Pen (far right)", "Abstention / other"]
df.to_csv("../data/for_figures/density.csv")

In [None]:
nspace = 100
space = jnp.linspace(-2,2,nspace)

pred1 = (res1["step3"]["posterior"][:,4].reshape(-1,1) + res1["step3"]["posterior"][:,12].reshape(-1,1) @ space.reshape(1, -1)) # Melenchon Le Pen
pred2 = (res1["step3"]["posterior"][:,5].reshape(-1,1) + res1["step3"]["posterior"][:,13].reshape(-1,1) @ space.reshape(1, -1)) # Melenchon Abstention
ww = jnp.ones(pred1.shape[0])/pred1.shape[0] # Set true weight here

res = jnp.concat([pred1[:,:,None], pred2[:,:,None]], axis=2)
res2 = (softmax2(res.reshape((-1,2))).reshape((-1,nspace,3)))

res3 = jax.vmap(weighted_quantile, (1, None, None))(res2.reshape(-1,nspace*3), jnp.array([.05,.5,.95]), ww)
res4 = res3.reshape((nspace,3,3))

res5 = pd.DataFrame(res3, columns=[".05",".5",".95"])
res5["grille"] = jnp.exp(np.repeat(space,3)*dens_std+dens_mean)
res5["candidate"] = np.tile(["Macron (center)", "Le Pen (far-right)", "Other"], reps=nspace)
res5.to_csv("../data/for_figures/density.csv")


In [None]:
res2.shape

(3000, 100, 3)

### Figure for presentation