In [None]:
import random

import dill
import pandas as pd
import torch
from torch import nn
import numpy as np
from scipy.stats import lognorm
from fairnessdatasets import SouthGerman
import seaborn as sns
import matplotlib.pyplot as plt

from probspecs import TabularInputSpace
from probspecs.distributions import *

from experiments.input_spaces import south_german_input_space

# Seed the random number generators of torch, numpy, and Python's `random`
# module so that repeated runs of the notebook draw the same random values.
torch.manual_seed(865470788748102)
np.random.seed(9931567)
random.seed(333161767645837)

In [None]:
# Load the SouthGerman dataset twice from the same root directory:
# once with the default settings and once with `raw=True`.
datasets_root = "../../.datasets"
dataset = SouthGerman(root=datasets_root, download=True)
dataset_raw = SouthGerman(root=datasets_root, raw=True, download=True)

# Tabular view of the raw data. The extra "dataset" column tags these rows as
# training data so that they can be told apart from generated samples later.
dataset_df = pd.DataFrame(data=dataset_raw.data, columns=dataset_raw.columns)
dataset_df["dataset"] = "training"
dataset_df

# Visualize the Dataset


In [None]:
%%capture --no-stdout --no-display
# One panel per dataset variable; `personal_status_sex`, `duration`, `age`,
# and `amount` each span two columns of the mosaic.
mosaic_layout = [
    ["status", "credit_history", "purpose", "savings"],
    ["employment_duration", "installment_rate", "other_debtors", "present_residence"],
    ["property", "other_installment_plans", "housing", "number_credits"],
    ["job", "people_liable", "telephone", "foreign_worker"],
    ["personal_status_sex", "personal_status_sex", "duration", "duration"],
    ["age", "age", "amount", "amount"],
]
fig, axes = plt.subplot_mosaic(mosaic_layout, figsize=(15, 20))

# `age`, `duration`, and `amount` get a regular binned histogram;
# all other variables get one bar per distinct value.
binned_vars = ("age", "duration", "amount")
per_value_options = dict(discrete=True, shrink=0.8, multiple="dodge", legend=False)
for var in SouthGerman.variables:
    extra_options = {} if var in binned_vars else per_value_options
    g = sns.histplot(
        dataset_df,
        x=var,
        stat="percent",
        common_norm=False,
        ax=axes[var],
        **extra_options,
    )
    g.set(title=var, xlabel=None)

In [None]:
%%capture --no-stdout --no-display
# Heatmap of the pairwise correlation coefficients between the columns of the
# raw dataset. `data_corrcoef` is reused by the comparison plots further below.
fig = plt.figure(figsize=(7, 5))

data_corrcoef = np.corrcoef(dataset_raw.data.T.numpy())
heatmap_options = dict(
    vmin=-1.0,
    vmax=1.0,
    square=True,
    cmap="RdBu",
    xticklabels=SouthGerman.variables,
    yticklabels=SouthGerman.variables,
)
_ = sns.heatmap(data_corrcoef, **heatmap_options)

# Base Population Model
We fit a Bayesian network to the dataset to obtain a base population model.
We postulate the following structure of the Bayesian network:
- We pose `age` and `foreign_worker` as root variables.
- We introduce a categorical latent variable `gender`. We use historical data to model the distribution of `gender` explicitly. 
  However, since no data on non-binary individuals is available to us, we only model two genders.
- We introduce a categorical latent variable `marital_status`.
- We introduce a categorical latent variable `bg` with 7 values intended for 
  modelling personality factors.
- `personal_status_sex` is influenced by `marital_status`, `gender`.
- `people_liable` is influenced by `marital_status` and `gender`.
- `job` is influenced by `bg`, `gender` and `foreign_worker`.
- `employment_duration` is influenced by `bg`, `gender`, `age`, and `foreign_worker`.
- A second categorical latent variable `income` with 5 values is intended to model
  income classes and is influenced by `job`, `employment_duration`, `foreign_worker`.
  Due to technical limitations, we cannot model direct gender-based discrimination.
- `status` is influenced by `bg` and `income` (assumption: payment morale only indirectly influenced by gender).
- `savings` is influenced by `bg`, `income`.
- `credit_history` is influenced by `bg` and `income`.
- `purpose` is influenced by `bg`, `income`.
- `amount` is influenced by `purpose`.
- `duration` is influenced by `amount`.
- `installment_rate` is influenced by `amount` and `income`.
- `other_debtors` is influenced by `amount` and `bg`.
- `present_residence` is influenced by `age` and `bg`.
- `property` is influenced by `income`.
- `other_installment_plans` is influenced by `income` and `bg`.
- `housing` is influenced by `income`.
- `number_credits` is influenced by `income` and `bg`.
- `telephone` is influenced by `income` and `bg`.

## Gender and Marital Status
The `SouthGerman` dataset doesn't contain a variable for gender or marital status.
Therefore, we introduce both as latent variables that influence `personal_status_sex`.

Interesting background on the gender-related economic situation in West Germany in the 1970s
is contained in the report of the European Commission's Expert Group on Gender and Employment (EGGE) 
on the unadjusted and adjusted gender pay gap for Germany by Friederike Maier.
https://documents.manchester.ac.uk/display.aspx?DocID=50202
Concretely, we use the reported 1977 distributions for West Germany.

The EGGE reports the fraction of full-time employed women as 24.63% in 1977.
This is in contrast to the surplus of women in the general population in West Germany.
Women are reported to make up 52.6% of the overall population in West Germany in 1975 
according to the United Nations (https://population.un.org/wpp/).

For marital status, we did not find data split by gender.
Therefore, we model marital status independent of gender and use the
1970/1971 distribution from Statista for West Germany.
https://de.statista.com/statistik/daten/studie/1059366/umfrage/zahl-der-einwohner-nach-familienstand-in-deutschland/
The distribution is: 24 million people single, 
30.3 million married people (only age group 15-64), 5.2 million widowed, 
1.1 million separated.
According to the United Nations (https://population.un.org/wpp/), the combined
population of West and East Germany in 1970 was 78.2 million.
With the total population reported by Statista for West and East Germany
being 77.7 million, we assume the missing 1.1 million individuals are
married individuals outside the 15-64 age group.
The population of West Germany makes up 78.0% of the total German population
in 1970 according to Statista.
Therefore, we attribute 78% of the 1.1 million additional married individuals
to West Germany.
Overall, we obtain 24 mil (39.0%) single, 31.2 mil (50.7%) married, 
5.2 mil (8.5%) widowed and 1.1 mil (1.8%) separated.

The `SouthGerman` dataset does not account for separated or widowed women
in its `personal_status_sex` variable.
We assume that separated and widowed women are split among the other
`personal_status_sex` groups containing women, but more likely to report to be
married.



In [None]:
def empirical_frequencies(var, num_vals=None):
    """
    Compute the relative frequency of every value in a column of `dataset_df`.

    :param var: The name of the column of `dataset_df` to summarize.
    :param num_vals: If given, every value in `range(num_vals)` that does not
        occur in the column is added to the result with frequency 0.0.
    :return: A `pandas.Series` that maps each value to its relative frequency,
        sorted by value in ascending order.
    """
    frequencies = dataset_df[var].value_counts(normalize=True)
    if num_vals is not None:
        for i in range(num_vals):
            if i not in frequencies.index:
                frequencies[i] = 0.0
    # Sort after adding the missing values: new entries are appended at the end,
    # so sorting first would leave the result out of order.
    return frequencies.sort_index(ascending=True)


def random_weights(size):
    """
    Draw a random tensor of the given size whose entries sum to one.

    :param size: The size passed on to `torch.rand`.
    :return: Uniform random values from `torch.rand`, divided by their sum.
    """
    unnormalized = torch.rand(size)
    total = unnormalized.sum()
    return unnormalized / total

In [None]:
# Factory that collects the node definitions of the following cells.
# The Bayesian network itself is built later by calling `create()` on it.
bayes_net_factory = BayesianNetwork.Factory()

## `age` and `foreign_worker`
The root variables `age` and `foreign_worker` are fit directly to the
frequencies in the dataset.


In [None]:
# Relative frequency of each distinct `age` value in the dataset.
age_frequencies = empirical_frequencies("age")
age_frequencies

In [None]:
# Root node `age`: the event space consists of the age values that occur in the
# dataset and the (unconditional) distribution uses their empirical frequencies.
age_node = bayes_net_factory.new_node("age", replace=True)
age_node.discrete_event_space(*age_frequencies.index.tolist())
age_node.set_conditional_probability(
    {}, Categorical(age_frequencies.tolist(), values=age_frequencies.index.tolist())
)

# Age groups used as parent conditions by the nodes that depend on `age`.
# Presumably each pair is an inclusive (lower, upper) bound -- TODO confirm.
age_groups = [(19.0, 29.0), (30.0, 59.0), (60.0, 75.0)]

In [None]:
# Root node `foreign_worker`: one-hot encoded with two values. The unconditional
# distribution uses the empirical frequencies from the dataset.
foreign_worker_node = bayes_net_factory.new_node("foreign_worker", replace=True)
foreign_worker_node.one_hot_event_space(2)
frequencies = empirical_frequencies("foreign_worker")
foreign_worker_node.set_conditional_probability(
    {}, CategoricalOneHot(frequencies.tolist())
)

## `gender` and `marital_status`
We initialize the distributions with the historical values, but allow adapting these values to accommodate differences in the general population and the population applying for credits at the bank.


In [None]:
# no data on other genders available
# values: Female, Male
gender_size = 2
gender_node = bayes_net_factory.new_node("gender", replace=True)
# latent variable: the dataset has no `gender` column
gender_node.hidden = True
gender_node.discrete_event_space(0.0, 1.0)
# Initial probabilities in the order (female, male). 52.6% is the share of women
# in the West German population quoted in the text above.
gender_node.set_conditional_probability({}, Categorical([0.526, 0.474]))

In [None]:
# Latent root node `marital_status`, initialized with the historical shares
# derived in the text above (39.0% / 50.7% / 8.5% / 1.8%).
marital_status_node = bayes_net_factory.new_node("marital_status", replace=True)
marital_status_node.hidden = True
# values: single, married, widowed, separated
marital_status_node.discrete_event_space(0.0, 1.0, 2.0, 3.0)
marital_status_node.set_conditional_probability(
    {}, Categorical([0.390, 0.507, 0.085, 0.018])
)

## `personal_status_sex`
The values of `personal_status_sex` are:
- `male: divorced/separated`
- `female: non-single or male: single`
- `male: married/widowed`
- `female: single`

Here, we freeze the distributions to keep the latent nodes `gender` and `marital_status` aligned with the values we'd like them to represent.


In [None]:
# `personal_status_sex` is determined by the latent `gender` (0.0 = female,
# 1.0 = male) and `marital_status` (0.0 = single, 1.0 = married, 2.0 = widowed,
# 3.0 = separated) nodes.
# The one-hot positions presumably follow the value list in the text above:
#   0 = male: divorced/separated, 1 = female: non-single or male: single,
#   2 = male: married/widowed,    3 = female: single   -- TODO confirm ordering.
# All distributions are created with `frozen=True`; per the text above, this is
# meant to keep them unchanged so that the latent nodes retain their meaning.
pers_status_sex_node = bayes_net_factory.new_node("personal_status_sex", replace=True)
pers_status_sex_node.set_parents(gender_node, marital_status_node)
pers_status_sex_node.one_hot_event_space(4)
# we do not model gender disagreeing with sex, due to unavailability of data
# and missing information on the historical circumstances.
pers_status_sex_node.set_conditional_probability(  # female single
    {gender_node: 0.0, marital_status_node: 0.0},
    CategoricalOneHot([0.0, 0.0, 0.0, 1.0], frozen=True),
)
pers_status_sex_node.set_conditional_probability(  # female married
    {gender_node: 0.0, marital_status_node: 1.0},
    CategoricalOneHot([0.0, 1.0, 0.0, 0.0], frozen=True),
)
# Widowed/separated women have no category of their own: 90% are assigned to
# "female: non-single" and 10% to "female: single".
# The pair ([2.0], [3.0]) presumably denotes the range of marital_status values
# from 2.0 to 3.0 -- TODO confirm against the probspecs API.
pers_status_sex_node.set_conditional_probability(  # female widowed/separated
    {gender_node: 0.0, marital_status_node: ([2.0], [3.0])},
    CategoricalOneHot([0.0, 0.9, 0.0, 0.1], frozen=True),
)
pers_status_sex_node.set_conditional_probability(  # male single
    {gender_node: 1.0, marital_status_node: 0.0},
    CategoricalOneHot([0.0, 1.0, 0.0, 0.0], frozen=True),
)
pers_status_sex_node.set_conditional_probability(  # male married/widowed
    {gender_node: 1.0, marital_status_node: ([1.0], [2.0])},
    CategoricalOneHot([0.0, 0.0, 1.0, 0.0], frozen=True),
)
pers_status_sex_node.set_conditional_probability(  # male separated
    {gender_node: 1.0, marital_status_node: 3.0},
    CategoricalOneHot([1.0, 0.0, 0.0, 0.0], frozen=True),
)

## `bg`
A categorical variable with probabilities that are estimated by fitting.
We make `bg` dependent on `marital_status`, as we'd like to model correlations between these variables.


In [None]:
# Latent node `bg` with `bg_size` values and randomly initialized probabilities.
# NOTE(review): the accompanying text states that `bg` depends on
# `marital_status`, but no parents are set here, so `bg` is a root node.
# Either the text or this cell is out of date -- TODO confirm which.
bg_size = 7
bg_node = bayes_net_factory.new_node("bg", replace=True)
bg_node.hidden = True
bg_node.discrete_event_space(*list(range(bg_size)))
bg_node.set_conditional_probability({}, Categorical(random_weights(bg_size)))


## `people_liable`
Depends on `marital_status` and `gender`.


In [None]:
# `people_liable` (two one-hot values) gets one randomly initialized
# distribution for every combination of `gender` and `marital_status`.
node = bayes_net_factory.new_node("people_liable", replace=True)
node.set_parents(gender_node, marital_status_node)
node.one_hot_event_space(2)
for i in range(gender_size):
    for j in range(4):  # the four marital_status values 0.0 - 3.0
        node.set_conditional_probability(
            {gender_node: float(i), marital_status_node: float(j)},
            CategoricalOneHot(random_weights(2)),
        )

## `job`


In [None]:
# `job` (four one-hot values) gets one randomly initialized distribution for
# every combination of `bg`, `foreign_worker`, and `gender`.
job_node = bayes_net_factory.new_node("job", replace=True)
job_node.set_parents(bg_node, gender_node, foreign_worker_node)
job_node.one_hot_event_space(4)
for i in range(bg_size):  # bg
    for j in range(2):  # foreign worker
        for k in range(gender_size):  # gender
            job_node.set_conditional_probability(
                {
                    bg_node: float(i),
                    gender_node: float(k),
                    # one-hot vector selecting foreign_worker value j
                    foreign_worker_node: [float(l == j) for l in range(2)],
                },
                CategoricalOneHot(random_weights(4)),
            )

## `employment_duration`


In [None]:
# `employment_duration` (five one-hot values) gets one randomly initialized
# distribution for every combination of `bg`, age group, `foreign_worker`,
# and `gender`.
empl_dur_node = bayes_net_factory.new_node("employment_duration", replace=True)
empl_dur_node.set_parents(bg_node, gender_node, age_node, foreign_worker_node)
empl_dur_node.one_hot_event_space(5)
for i in range(bg_size):
    for age_group in age_groups:
        for j in range(2):  # foreign worker
            for k in range(gender_size):  # gender
                empl_dur_node.set_conditional_probability(
                    {
                        bg_node: [float(i)],
                        gender_node: [float(k)],
                        age_node: age_group,
                        # One-hot vector selecting foreign_worker value j.
                        # The comprehension variable is named `l`, as in the
                        # other cells, so it does not shadow the gender index k.
                        foreign_worker_node: [float(l == j) for l in range(2)],
                    },
                    CategoricalOneHot(random_weights(5)),
                )

## `income`
Another latent variable, dependent on `job`, `employment_duration`,
and `foreign_worker`.


In [None]:
# Latent node `income` with `income_size` values. It gets one randomly
# initialized distribution for every combination of `job` (4 values),
# `employment_duration` (5 values), and `foreign_worker` (2 values).
income_size = 5
income_node = bayes_net_factory.new_node("income", replace=True)
income_node.hidden = True
income_node.set_parents(job_node, empl_dur_node, foreign_worker_node)
income_node.discrete_event_space(*list(range(income_size)))
for i in range(4):  # job
    for j in range(5):  # employment duration
        for k in range(2):  # foreign worker
            income_node.set_conditional_probability(
                {
                    # the parents are one-hot encoded, so each condition is a
                    # one-hot vector selecting value i / j / k
                    job_node: [float(l == i) for l in range(4)],
                    empl_dur_node: [float(l == j) for l in range(5)],
                    foreign_worker_node: [float(l == k) for l in range(2)],
                },
                Categorical(random_weights(income_size)),
            )

## `status`, `savings`, `credit_history`, `purpose`
All depend on `bg` and `income`.


In [None]:
# `status`, `savings`, `credit_history`, and `purpose` each get one randomly
# initialized distribution for every combination of the latent variables `bg`
# and `income`, as laid out in the structure description of the network.
# (Previously the parents were `job` and `employment_duration`, which
# contradicted the documented structure.)
nodes = bayes_net_factory.new_nodes(
    "status", "savings", "credit_history", "purpose", replace=True
)
sizes = [4, 5, 5, 11]  # number of one-hot values of each node
for node, size in zip(nodes, sizes):
    node.set_parents(income_node, bg_node)
    node.one_hot_event_space(size)
    for i in range(bg_size):
        for j in range(income_size):
            node.set_conditional_probability(
                {bg_node: float(i), income_node: float(j)},
                CategoricalOneHot(random_weights(size)),
            )
# `purpose` is needed as the parent of `amount`
purpose_node = nodes[-1]

## `amount`
depends on `purpose`


In [None]:
# The observed credit amounts and the bounds of the event space of `amount`.
amount_data = dataset_df["amount"]
amount_lower = 250.0
amount_upper = 20_000.0

In [None]:
# `amount` is a continuous node on [amount_lower, amount_upper] with one
# distribution per value of `purpose`.
amount_node = bayes_net_factory.new_node("amount", replace=True)
amount_node.set_parents(purpose_node)
amount_node.continuous_event_space(amount_lower, amount_upper)

for i in range(11):  # purpose
    # We fit this here, because truncnorm mixtures are hard to fit otherwise.
    # NOTE(review): every purpose is fitted on the full `amount_data`, not on the
    # rows with that purpose; only the seed differs between the eleven fits.
    # Presumably this only serves as an initialization -- TODO confirm.
    distribution = MixtureModel.fit_truncnorm_mixture(
        amount_data,
        (amount_lower, amount_upper),
        n_components=3,
        n_init=1,
        seed=202404190311 + i,
    )
    distribution = AsInteger.wrap(distribution)
    amount_node.set_conditional_probability(
        {purpose_node: [float(k == i) for k in range(11)]}, distribution
    )

# Amount ranges used as parent conditions by the nodes that depend on `amount`.
amount_classes = [(amount_lower, 1250.0), (1250.0, 5000.0), (5000.0, amount_upper)]

## `duration`


In [None]:
# Relative frequency of each distinct `duration` value in the dataset.
# Below, only the index (the set of occurring values) is used.
duration_frequencies = empirical_frequencies("duration")
duration_frequencies

In [None]:
# `duration` takes the duration values that occur in the dataset and gets one
# randomly initialized distribution per amount class.
dur_node = bayes_net_factory.new_node("duration", replace=True)
dur_node.set_parents(amount_node)
dur_node.discrete_event_space(*duration_frequencies.index.tolist())
for amount_group in amount_classes:
    distribution = Categorical(
        random_weights(len(duration_frequencies)),
        values=duration_frequencies.index.tolist(),
    )
    dur_node.set_conditional_probability({amount_node: amount_group}, distribution)

## `installment_rate`


In [None]:
# `installment_rate` (four one-hot values) gets one randomly initialized
# distribution for every combination of amount class and `income`.
install_rate_node = bayes_net_factory.new_node("installment_rate", replace=True)
install_rate_node.set_parents(amount_node, income_node)
install_rate_node.one_hot_event_space(4)
for amount_group in amount_classes:
    for i in range(income_size):
        install_rate_node.set_conditional_probability(
            {amount_node: amount_group, income_node: float(i)},
            CategoricalOneHot(random_weights(4)),
        )

## `other_debtors`


In [None]:
# `other_debtors` (three one-hot values) gets one randomly initialized
# distribution for every combination of amount class and `bg`.
other_debtors = bayes_net_factory.new_node("other_debtors", replace=True)
other_debtors.set_parents(amount_node, bg_node)
other_debtors.one_hot_event_space(3)
for amount_group in amount_classes:
    for i in range(bg_size):
        other_debtors.set_conditional_probability(
            {amount_node: amount_group, bg_node: float(i)},
            CategoricalOneHot(random_weights(3)),
        )

## `present_residence`


In [None]:
# `present_residence` (four one-hot values) gets one randomly initialized
# distribution for every combination of age group and `bg`.
residence_node = bayes_net_factory.new_node("present_residence", replace=True)
residence_node.set_parents(age_node, bg_node)
residence_node.one_hot_event_space(4)
for age_group in age_groups:
    for i in range(bg_size):
        residence_node.set_conditional_probability(
            {age_node: age_group, bg_node: float(i)},
            CategoricalOneHot(random_weights(4)),
        )

## `property`, `housing`
both depend on `income`.


In [None]:
# `property` (four one-hot values) and `housing` (three one-hot values) each get
# one randomly initialized distribution per `income` value.
nodes = bayes_net_factory.new_nodes("property", "housing", replace=True)
sizes = (4, 3)
for node, size in zip(nodes, sizes):
    node.set_parents(income_node)
    node.one_hot_event_space(size)
    for i in range(income_size):
        node.set_conditional_probability(
            {income_node: float(i)}, CategoricalOneHot(random_weights(size))
        )

## `other_installment_plans`, `number_credits`, `telephone`
all depend on `bg` and `income`.


In [None]:
# `other_installment_plans`, `number_credits`, and `telephone` each get one
# randomly initialized distribution for every combination of `bg` and `income`.
nodes = bayes_net_factory.new_nodes(
    "other_installment_plans", "number_credits", "telephone", replace=True
)
sizes = (3, 4, 2)  # number of one-hot values of each node
for node, size in zip(nodes, sizes):
    node.set_parents(income_node, bg_node)
    node.one_hot_event_space(size)
    # The empirical frequencies that were computed here before were never used
    # (the distributions are initialized randomly), so they are no longer computed.
    for i in range(bg_size):
        for j in range(income_size):
            node.set_conditional_probability(
                {bg_node: float(i), income_node: float(j)},
                CategoricalOneHot(random_weights(size)),
            )

## Build and fit the population model.

In [None]:
# Order the nodes as the dataset columns followed by the four latent variables,
# then build the base Bayesian network from the factory.
base_variables = tuple(dataset_raw.columns) + (
    "gender",
    "marital_status",
    "bg",
    "income",
)
bayes_net_factory.reorder_nodes(base_variables)
base_bayes_net = bayes_net_factory.create()

In [None]:
# Describe the sample space of the Bayesian network including the latent
# variables. The attribute types, bounds, and values of the dataset variables
# are taken from `south_german_input_space`; the latent variables are added
# as integer attributes.
var_types = {
    var: south_german_input_space.attribute_type(var) for var in SouthGerman.variables
}
full_base_var_types = var_types | {
    "gender": TabularInputSpace.AttributeType.INTEGER,
    "marital_status": TabularInputSpace.AttributeType.INTEGER,
    "bg": TabularInputSpace.AttributeType.INTEGER,
    "income": TabularInputSpace.AttributeType.INTEGER,
}
# variables without a list of values in `SouthGerman.variables` are numeric
integer_ranges = {
    var: south_german_input_space.attribute_bounds(var)
    for var in SouthGerman.variables
    if SouthGerman.variables[var] is None
}
full_integer_ranges = integer_ranges | {
    "gender": (0, gender_size - 1),
    "marital_status": (0, 3),
    "bg": (0, bg_size - 1),
    "income": (0, income_size - 1),
}
# variables with a list of values in `SouthGerman.variables` are categorical
categorical_values = {
    var: south_german_input_space.attribute_values(var)
    for var in SouthGerman.variables
    if SouthGerman.variables[var] is not None
}
full_base_bayes_net_input_space = TabularInputSpace(
    base_variables,
    data_types=full_base_var_types,
    continuous_ranges={},
    integer_ranges=full_integer_ranges,
    categorical_values=categorical_values,
)
# input space of the dataset itself (without latent variables)
dataset_input_space = south_german_input_space

### Simulate some data before fitting to see the improvement.


In [None]:
# Number of samples that `sample_bayes` draws from a Bayesian network.
n = 10000


def sample_bayes(bayes_net, input_space, seed, num_samples=None):
    """
    Sample from a Bayesian network and combine the samples with the dataset.

    One-hot encoded attributes of the samples are converted to the index of
    their largest entry, so that the samples are comparable with `dataset_df`.

    :param bayes_net: The Bayesian network to sample from.
    :param input_space: The input space describing the layout of the samples.
        Its `encoding_layout` maps an attribute name either to a single column
        index or to a mapping whose values are the attribute's one-hot columns.
    :param seed: The seed passed on to `bayes_net.sample`.
    :param num_samples: The number of samples to draw.
        Defaults to the notebook-level setting `n`.
    :return: A data frame containing the samples (with "generated" in the
        "dataset" column) followed by the rows of `dataset_df`.
    """
    if num_samples is None:
        num_samples = n
    generated_data = bayes_net.sample(num_samples, seed=seed)
    generated_raw = {}
    subspace_layout = input_space.encoding_layout
    for var in input_space.attribute_names:
        cols = subspace_layout[var]
        if isinstance(cols, int):
            # a single column: take the values as they are
            generated_raw[var] = generated_data[:, cols]
        else:
            # several one-hot columns: decode to the index of the selected value
            values_one_hot = generated_data[:, list(cols.values())]
            values = np.argmax(values_one_hot, axis=1)
            generated_raw[var] = values

    generated_df = pd.DataFrame(generated_raw)
    generated_df["dataset"] = "generated"
    return pd.concat([generated_df, dataset_df])

In [None]:
%%capture --no-stdout --no-display
# Sample including the latent variables, then switch `include_hidden` off again.
base_bayes_net.include_hidden = True
df = sample_bayes(base_bayes_net, full_base_bayes_net_input_space, seed=202404252023)
base_bayes_net.include_hidden = False

# Same layout as for the dataset plot, plus a final row for the latent variables.
latent_vars = ["gender", "marital_status", "bg", "income"]
mosaic_layout = [
    ["status", "credit_history", "purpose", "savings"],
    ["employment_duration", "installment_rate", "other_debtors", "present_residence"],
    ["property", "other_installment_plans", "housing", "number_credits"],
    ["job", "people_liable", "telephone", "foreign_worker"],
    ["personal_status_sex", "personal_status_sex", "duration", "duration"],
    ["age", "age", "amount", "amount"],
    ["gender", "marital_status", "bg", "income"],
]
fig, axes = plt.subplot_mosaic(mosaic_layout, figsize=(15, 24))

# `age`, `duration`, and `amount` get a regular binned histogram;
# all other variables get one pair of bars per distinct value.
binned_vars = ("age", "duration", "amount")
per_value_options = dict(discrete=True, shrink=0.8, multiple="dodge", legend=False)
for var in list(SouthGerman.variables) + latent_vars:
    extra_options = {} if var in binned_vars else per_value_options
    g = sns.histplot(
        df,
        x=var,
        hue="dataset",
        stat="percent",
        common_norm=False,
        ax=axes[var],
        **extra_options,
    )
    g.set(title=var, xlabel=None)

In [None]:
%%capture --no-stdout --no-display
# Compare the correlations of the generated samples with those of the training
# data (`data_corrcoef` from the dataset visualization) and show the difference.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

is_generated = df["dataset"] == "generated"
generated_df_ = df[is_generated][list(SouthGerman.variables)]
pop_model_corrcoef = np.corrcoef(generated_df_.to_numpy().T)
diff = pop_model_corrcoef - data_corrcoef

panels = (
    ("Population Model", pop_model_corrcoef),
    ("Training Data", data_corrcoef),
    ("Difference", diff),
)
for ax, (title, corrcoef) in zip(axes, panels):
    _ = sns.heatmap(
        corrcoef,
        vmin=-1.0,
        vmax=1.0,
        square=True,
        cmap="RdBu",
        xticklabels=SouthGerman.variables,
        yticklabels=SouthGerman.variables,
        ax=ax,
    )
    _ = ax.set_title(title)

The distributions all look good, since they are all fitted to the data, but the correlations are missing.

### Fitting
Expect fitting to run for up to two days.


In [None]:
# Show the parameters of the network before fitting.
pd.Series(base_bayes_net.parameters.detach())

In [None]:
# True: fit the network in the next cell. False: load a stored network instead.
do_fit = True  # remember, fitting may take days

In [None]:
if do_fit:
    # Use the unnormalized values for `duration`, `age`, and `amount` from dataset_raw, but the one-hot encoded
    # values of all other variables from dataset.
    # Work on a copy: the columns are overwritten in place below, which would
    # otherwise also change `dataset.data` for every later use of `dataset`.
    data = dataset.data.clone()
    raw_columns = list(dataset_raw.columns)
    for i, col in enumerate(dataset.columns):
        if col in ("age", "duration", "amount"):
            raw_i = raw_columns.index(col)
            data[:, i] = dataset_raw.data[:, raw_i]

    def callback(res):
        """Print the current objective value reported by the optimizer."""
        # for some reason scipy also calls the callback with numpy arrays
        if hasattr(res, "fun"):
            print(f"Current Likelihood: {res.fun:.4f}")

    base_bayes_net.fit(
        data,
        method="SLSQP",
        options={"eps": 1e-4, "maxiter": 250, "iprint": 2},  # > 24h
        callback=callback,
    )
    # base_bayes_net.fit(
    #     data,
    #     method="L-BFGS-B",
    #     options={"eps": 1e-4, "iprint": 99},  # > 2h
    # )
else:
    # NOTE: `torch.load` unpickles arbitrary objects; only load trusted files.
    # NOTE(review): the factory file is written with `pickle_module=dill` at the
    # end of this notebook -- confirm that loading without it works as intended.
    # NOTE(review): the node variables created above (e.g. `gender_node`) are not
    # rebound to the nodes of the loaded factory -- verify that later changes to
    # them (such as `gender_node.hidden = False`) take effect on this path.
    base_bayes_net = torch.load(
        "../../resources/south_german/base_population_model.pyt"
    )
    bayes_net_factory = torch.load(
        "../../resources/south_german/base_bayes_net_factory.pyt"
    )

In [None]:
# Show the parameters of the network after fitting (or loading).
pd.Series(base_bayes_net.parameters.detach())

In [None]:
%%capture --no-stdout --no-display
# Sample including the latent variables, then switch `include_hidden` off again.
base_bayes_net.include_hidden = True
df = sample_bayes(base_bayes_net, full_base_bayes_net_input_space, seed=202404252023)
base_bayes_net.include_hidden = False

# Same layout as for the dataset plot, plus a final row for the latent variables.
latent_vars = ["gender", "marital_status", "bg", "income"]
mosaic_layout = [
    ["status", "credit_history", "purpose", "savings"],
    ["employment_duration", "installment_rate", "other_debtors", "present_residence"],
    ["property", "other_installment_plans", "housing", "number_credits"],
    ["job", "people_liable", "telephone", "foreign_worker"],
    ["personal_status_sex", "personal_status_sex", "duration", "duration"],
    ["age", "age", "amount", "amount"],
    ["gender", "marital_status", "bg", "income"],
]
fig, axes = plt.subplot_mosaic(mosaic_layout, figsize=(15, 24))

# `age`, `duration`, and `amount` get a regular binned histogram;
# all other variables get one pair of bars per distinct value.
binned_vars = ("age", "duration", "amount")
per_value_options = dict(discrete=True, shrink=0.8, multiple="dodge", legend=False)
for var in list(SouthGerman.variables) + latent_vars:
    extra_options = {} if var in binned_vars else per_value_options
    g = sns.histplot(
        df,
        x=var,
        hue="dataset",
        stat="percent",
        common_norm=False,
        ax=axes[var],
        **extra_options,
    )
    g.set(title=var, xlabel=None)

In [None]:
%%capture --no-stdout --no-display
# Compare the correlations of the generated samples with those of the training
# data (`data_corrcoef` from the dataset visualization) and show the difference.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

is_generated = df["dataset"] == "generated"
generated_df_ = df[is_generated][list(SouthGerman.variables)]
pop_model_corrcoef = np.corrcoef(generated_df_.to_numpy().T)
diff = pop_model_corrcoef - data_corrcoef

panels = (
    ("Population Model", pop_model_corrcoef),
    ("Training Data", data_corrcoef),
    ("Difference", diff),
)
for ax, (title, corrcoef) in zip(axes, panels):
    _ = sns.heatmap(
        corrcoef,
        vmin=-1.0,
        vmax=1.0,
        square=True,
        cmap="RdBu",
        xticklabels=SouthGerman.variables,
        yticklabels=SouthGerman.variables,
        ax=ax,
    )
    _ = ax.set_title(title)

## Export the Population Model
To feed the samples of the population model to a neural network, we still need to apply
z-score normalization.
We do this using a linear layer that is applied to the samples of the population model.

Also, we make `gender` visible for fairness verification, but discard the gender value for the input of the neural network.


In [None]:
# Make `gender` a visible variable and rebuild the network from the factory.
# NOTE(review): if the factory was loaded from disk in the fitting cell,
# `gender_node` may not be the node object held by that factory -- verify.
gender_node.hidden = False
# This maintains the parameters of the distributions.
bayes_net = bayes_net_factory.create()

# Input space of the exported network: the dataset variables plus `gender`.
bayes_net_visible_variables = tuple(dataset_raw.columns) + ("gender",)
bayes_net_var_types = var_types | {"gender": TabularInputSpace.AttributeType.INTEGER}
bayes_net_integer_ranges = integer_ranges | {"gender": (0.0, gender_size - 1)}
bayes_net_input_space = TabularInputSpace(
    bayes_net_visible_variables,
    data_types=bayes_net_var_types,
    continuous_ranges={},
    integer_ranges=bayes_net_integer_ranges,
    categorical_values=categorical_values,
)

In [None]:
# first transformation: drop gender
# create an identity matrix with several extra all-zero columns at the end
# Build the selection matrix: as many rows as the dataset input space has
# dimensions and as many columns as the Bayesian network input space has.
# Ones on the main diagonal copy the leading dimensions and the remaining
# all-zero columns discard the extra ones.
selection_shape = dataset_input_space.input_shape + bayes_net_input_space.input_shape
weight = torch.zeros(selection_shape)
i = torch.arange(weight.size(0))
weight[i, i] = 1.0

num_outputs, num_inputs = weight.size(0), weight.size(1)
drop_extra_vars = nn.Linear(num_inputs, num_outputs, bias=False)
with torch.no_grad():
    drop_extra_vars.weight = nn.Parameter(weight, requires_grad=False)

In [None]:
# Second transformation: z-score normalization of the numeric variables,
# implemented as a linear layer with a diagonal weight matrix and a bias.
# `mean` and `std` are indexed by the raw column index `i`, while the weight
# and bias are indexed by the encoded index `w_i`, which advances by the number
# of values for each categorical (one-hot encoded) variable.
mean = dataset_raw.data.mean(dim=0)
std = dataset_raw.data.std(dim=0)
weight = torch.zeros(
    south_german_input_space.input_shape + dataset_input_space.input_shape
)
bias = torch.zeros(south_german_input_space.input_shape)
w_i = 0
for i, (var, vals) in enumerate(SouthGerman.variables.items()):
    if vals is not None:  # categorical
        # one-hot columns are passed through unchanged
        for _ in range(len(vals)):
            weight[w_i, w_i] = 1.0
            w_i += 1
    else:  # continuous/integer
        # we calculate: (x - mean) / std = x/std - mean/std
        weight[w_i, w_i] = 1 / std[i]
        # The bias belongs to the encoded position `w_i` (the same position as
        # the weight). Indexing it with the raw column index `i` shifted the
        # mean of a different feature whenever a categorical variable precedes
        # this variable.
        bias[w_i] = -mean[i] / std[i]
        w_i += 1
normalize = nn.Linear(weight.size(1), weight.size(0), bias=True)
with torch.no_grad():
    normalize.weight = nn.Parameter(weight, requires_grad=False)
    normalize.bias = nn.Parameter(bias, requires_grad=False)

# Full transformation from network samples to neural network inputs.
bayes_net_transform = nn.Sequential(drop_extra_vars, normalize)

In [None]:
# Store the network, its input space, and the transformation to neural network
# inputs together in one file. The objects are pickled using `dill`.
torch.save(
    (bayes_net, bayes_net_input_space, bayes_net_transform),
    "../../resources/south_german/bayes_net_population_model.pyt",
    pickle_module=dill,
)

Also save the factory to facilitate experimenting with the Bayesian network without re-running the fitting.


In [None]:
# Store the factory (pickled using `dill`). The fitting cell loads this file
# when `do_fit` is False.
torch.save(
    bayes_net_factory,
    "../../resources/south_german/base_bayes_net_factory.pyt",
    pickle_module=dill,
)