# Mixture of Experts (MoE) for Logistic Regression

In this notebook, we will explore how to use the Mixture of Experts (MoE) framework to combine pre-trained 
generalized linear models on a simple two-dimensional dataset. We will compare the MoE model to a simple logistic model that follows the `MultiPRS` framework (i.e. it uses the predictions of all pre-trained models as covariates). 

To start, let's import the libraries and modules needed for this notebook.


In [8]:
import numpy as np
import pandas as pd
import sys
from sklearn.linear_model import LogisticRegression
sys.path.append("../model/")
from PRSDataset import PRSDataset
from moe import MoEPRS
from baseline_models import MultiPRS
import matplotlib.pyplot as plt
import seaborn as sns

## The Data

After importing the necessary libraries and modules, we will create a simple two-dimensional dataset 
that displays heterogeneity in the relationship between the input variables `X` and the binary output variable `y`. To achieve this, we will assume that there are two underlying decision boundaries between "cases" and "controls", and that the boundary that applies depends on the position of the input point `X`:

In [5]:
import numpy as np
from scipy.special import expit
import matplotlib.pyplot as plt

# Set the random seed for reproducibility
np.random.seed(0)

# Number of points
N = 2000

# Generate random points in polar coordinates
r = np.random.normal(5, 1, N)
theta = np.random.uniform(0, np.pi, N)

# Convert to Cartesian coordinates
x1 = r * np.cos(theta)
x2 = r * np.sin(theta)

# Add some noise
x1 += np.random.normal(0, 0.5, N)
x2 += np.random.normal(0, 0.5, N)

# Assign labels: points inside the diamond |x1| + |x2| <= 4 are cases (1), the rest are controls (0)
labels = np.where(np.abs(x2) + np.abs(x1) > 4, 0, 1)
markers = ['o', '+']

# Plot the points
plt.figure(figsize=(8, 8))

plt.scatter(x1[labels == 0], x2[labels==0], marker=markers[0], label='Control')
plt.scatter(x1[labels == 1], x2[labels==1], marker=markers[1], label='Case')
plt.xlabel('x1')
plt.ylabel('x2')
plt.legend()
plt.show()


The plot above shows the generated data. As we can see, the relationship between `X` and `y` cannot be 
easily modeled with a single linear decision boundary. However, it can be well approximated locally by generalized 
linear models (within each of the two halves of the domain, split along `x1`).
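
To make the "locally linear" claim concrete, the sketch below checks that the labelling rule `|x1| + |x2| > 4` is equivalent to two straight-line boundaries, one on each side of `x1 = 0`. It uses freshly sampled toy points restricted to `x2 >= 0` (an assumption for illustration; the notebook's noisy data also contains a few points with `x2 < 0`).

```python
import numpy as np

# Toy inputs in the upper half-plane; the ranges are assumptions for illustration.
rng = np.random.default_rng(0)
x1 = rng.uniform(-6, 6, 1000)
x2 = rng.uniform(0, 6, 1000)

# Labelling rule from the notebook: inside the diamond -> case (1), outside -> control (0).
labels = np.where(np.abs(x2) + np.abs(x1) > 4, 0, 1)

# Equivalent piecewise rule: one linear boundary for x1 < 0, another for x1 >= 0.
left = x1 < 0
piecewise = np.where(left, x2 - x1 <= 4, x2 + x1 <= 4).astype(int)

print(np.array_equal(labels, piecewise))  # True
```

Each side has its own linear boundary, which is why one linear model per half can do well where a single global one cannot.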

## Pre-trained Linear Models

To make this setup more comparable to the polygenic score (PGS) setting, let's assume that we are given access to a number 
of pre-trained linear models. Let's further assume that, similar to GWAS/PGS settings, those models were pre-trained on 
different subsets or domains of the data. To do that, we will implement the following procedure:

1. Split the data into two subsets (here, we split on the `x1` feature at `x1 = 0.5`).
2. Train a logistic model on each half of the data.

In [11]:
# ----------------------------------------------------------
# Inference:

# 1) split the data heuristically at x1=0.5

train_x = np.vstack([x1, x2]).T

train_data_1 = train_x[x1 < 0.5, :]
train_label_1 = labels[x1 < 0.5]

train_data_2 = train_x[x1 >= 0.5, :]
train_label_2 = labels[x1 >= 0.5]

# 2) Infer two sets of parameters

lrm = LogisticRegression(fit_intercept=False, penalty=None)
b1_hat = lrm.fit(train_data_1, train_label_1).coef_.T
b2_hat = lrm.fit(train_data_2, train_label_2).coef_.T

print("Inferred b1:", b1_hat)
print("Inferred b2:", b2_hat)

# 3) Get expert predictions on entire dataset:

exp1_pred = train_x.dot(b1_hat)
exp2_pred = train_x.dot(b2_hat)


expert_predictions = np.concatenate([exp1_pred.reshape(-1, 1),
                        exp2_pred.reshape(-1, 1)], axis=1)

# 4) plot expert predictions:

plt.scatter(train_x[:, 0], train_x[:, 1], c=expit(exp1_pred).flatten(), vmin=0., vmax=1.)
plt.title("Expert 1 predictions (trained on x1 < 0.5)")
plt.xlabel("x1")
plt.ylabel("x2")
plt.colorbar()
plt.show()
plt.scatter(train_x[:, 0], train_x[:, 1], c=expit(exp2_pred).flatten(), vmin=0., vmax=1.)
plt.title("Expert 2 predictions (trained on x1 >= 0.5)")
plt.xlabel("x1")
plt.ylabel("x2")
plt.colorbar()
plt.show()
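
One way to quantify what the two plots show is to measure each expert's accuracy separately on each side of the split. The helper below is a minimal sketch: `accuracy_by_region` is a hypothetical name that is not part of the notebook's modules, and the toy arrays are made-up values rather than the notebook's data.

```python
import numpy as np

def accuracy_by_region(logits, y, x1, threshold=0.5):
    """Accuracy of thresholded logits on each side of x1 = threshold."""
    # A logit above 0 corresponds to a predicted probability above 0.5.
    pred = (np.asarray(logits).ravel() > 0).astype(int)
    correct = pred == np.asarray(y)
    left = np.asarray(x1) < threshold
    return correct[left].mean(), correct[~left].mean()

# Toy example: the logits are right on the left half and wrong on the right half.
toy_x1 = np.array([-2.0, -1.0, 1.0, 2.0])
toy_y = np.array([1, 0, 1, 0])
toy_logits = np.array([3.0, -2.0, -1.0, 4.0])

acc_left, acc_right = accuracy_by_region(toy_logits, toy_y, toy_x1)
print(acc_left, acc_right)  # 1.0 0.0
```

In the notebook, the equivalent call would be `accuracy_by_region(exp1_pred, labels, x1)` and likewise for `exp2_pred`.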

## Meta models

Now that we've generated the pre-trained models, we can use them to fit a meta (or ensemble) model. A meta PRS model takes the predictions of simpler models and combines them so that the combined model fits the data better than any of the simpler models individually. Before we explore the meta PRS models, 
let's aggregate our data in a `PRSDataset` object:

In [15]:
# Aggregate the data + model predictions in a dataframe:

df = pd.DataFrame({'x1': x1, 'x2': x2, 'Model_1': exp1_pred.flatten(), 'Model_2': exp2_pred.flatten(), 
                   'y': labels})

df.head()

In [17]:
# Create a PRSDataset object:

prs_data = PRSDataset(df, 
                      phenotype_col='y', # Specify the target column (equivalent to phenotype)
                      prs_cols=['Model_1', 'Model_2'], # Specify the column names containing pre-trained model predictions
                      covariates_cols=['x1', 'x2']) # Specify the column names containing the covariates (if any)


### MultiPRS

First, we will explore the MultiPRS model formulation, which takes the individual predictions of the pre-trained
models and combines them linearly to get the final prediction. This can be done with the help of the `MultiPRS` model class:

In [25]:
# Create the MultiPRS model:
multi_prs = MultiPRS(prs_data,
                     expert_cols=['Model_1', 'Model_2'], # Specify the column names containing pre-trained model predictions
                     covariates_cols=['x1', 'x2'])

# Fit the MultiPRS model:
multi_prs.fit()

# Generate predictions (probabilities)
multi_prs_pred = multi_prs.predict()

# Plot the predictions:
plt.scatter(train_x[:, 0], train_x[:, 1], c=multi_prs_pred, vmin=0., vmax=1.)
plt.title("MultiPRS predictions")
plt.xlabel("x1")
plt.ylabel("x2")
plt.colorbar()
plt.show()

As we can see, the MultiPRS model struggles to fit this heterogeneous data well. This is because it has to minimize the error with a single logistic model across the entire domain. This is where the Mixture of Experts (MoE) model shines.
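
The sketch below illustrates why a single logistic model is limiting here. It assumes that a MultiPRS-style model is a logistic regression on the expert scores plus the covariates, as described in the introduction; the internals of the `MultiPRS` class are not shown in this notebook. All coefficients are made-up values. Because each expert score is itself linear in `(x1, x2)`, the combined logit collapses to one linear function of `(x1, x2)`, and therefore to one straight decision boundary.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))            # toy inputs (x1, x2)

# Hypothetical expert coefficients (no intercept, like the experts above).
b1 = np.array([0.8, -0.5])
b2 = np.array([-0.6, -0.4])
s1, s2 = X @ b1, X @ b2                # expert linear scores

# Hypothetical meta-model weights: intercept, expert weights, covariate weights.
w0, w1, w2 = 0.3, 1.5, -0.7
c = np.array([0.2, 0.1])

# Logit of the linear combination of expert scores and covariates.
meta_logit = w0 + w1 * s1 + w2 * s2 + X @ c

# The same logit written as a single linear model in (x1, x2).
effective_coef = w1 * b1 + w2 * b2 + c
collapsed_logit = w0 + X @ effective_coef

print(np.allclose(meta_logit, collapsed_logit))  # True
```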

### Mixture of Experts (MoE)

The Mixture of Experts (MoE) model is a more flexible meta PRS model that can fit the data better than the MultiPRS model. The MoE model formulates the prediction as a weighted sum of the predictions of the pre-trained models, where the weights are determined by a gating network:

$$
y_{MoE}(i) = \sum_{k=1}^{K} g_k(x_i) \hat{y}_k(i)
$$

Here, $g_k(x_i)$ is the output of the gating model for the $i$-th observation and the $k$-th pre-trained model. In our case, the gating model is a linear model that takes the covariates as input and outputs the weights for the pre-trained models (i.e. it performs the equivalent of softmax regression). The weights are then used to combine the predictions of the pre-trained models to get the final prediction. The important thing to realize here is that the weights of the gating model depend on the input, which allows the MoE to select the best model for each input domain.
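
As a minimal sketch of the formula above, the code below computes the gate weights with a softmax over a linear function of the inputs and then mixes the expert predictions. It assumes that the experts' logits are converted to probabilities before mixing; whether `MoEPRS` mixes probabilities or linear scores internally is not shown in this notebook. The function names and all numeric values are hypothetical.

```python
import numpy as np

def softmax(z):
    # Subtract the row-wise maximum for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def moe_predict(X, expert_logits, gate_W, gate_b):
    """Return the mixture prediction and the gate weights g_k(x_i)."""
    gates = softmax(X @ gate_W + gate_b)       # shape (N, K), rows sum to 1
    expert_probs = sigmoid(expert_logits)      # shape (N, K)
    return (gates * expert_probs).sum(axis=1), gates

# Toy example: the gate favors expert 1 when x1 < 0 and expert 2 when x1 > 0.
X = np.array([[-3.0, 2.0], [3.0, 2.0]])
gate_W = np.array([[-2.0, 2.0], [0.0, 0.0]])   # rows: x1, x2; columns: experts
gate_b = np.zeros(2)
expert_logits = np.array([[2.0, -2.0], [-2.0, 2.0]])

y_moe, gates = moe_predict(X, expert_logits, gate_W, gate_b)
print(gates.round(3))
print(y_moe.round(3))
```

In this toy case each expert is only accurate in its own region, and the input-dependent gate routes each point to the expert that is accurate there.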

Here's how the MoE model can be fit to the data using the `MoEPRS` class from the `moe` module:

In [32]:
# Create the MoE model:
moe_model = MoEPRS(prs_data,
                   expert_cols=['Model_1', 'Model_2'], # Specify the column names containing pre-trained model predictions
                   gate_input_cols=['x1', 'x2'], # Specify the input columns for the gating model
                   expert_add_intercept=False,
                   gate_add_intercept=True)

# Fit the MoE model:
moe_model.fit()

# Generate the predictions (probabilities)
moe_pred = moe_model.predict()

# Plot the predictions:
plt.scatter(train_x[:, 0], train_x[:, 1], c=moe_pred, vmin=0., vmax=1.)
plt.title("MoE predictions")
plt.xlabel("x1")
plt.ylabel("x2")
plt.colorbar()
plt.show()

This looks a bit better! The MoE model is able to fit the data by selectively using the pre-trained models 
in the regions where they fit the data well. To see what the gating model is doing, 
let's plot its weights as a function of the input variables `X`:

In [36]:
plt.scatter(train_x[:, 0], train_x[:, 1], c=moe_model.predict_proba()[:, 0].flatten(), vmin=0., vmax=1.)
plt.title("Gate weight for model 1")
plt.xlabel('x1')
plt.ylabel('x2')
plt.colorbar()
plt.show()
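
To complement the plot with numbers, one could average the gate weights on each side of the split used to train the experts. The helper below is a sketch: `mean_gate_by_region` is a hypothetical name, and it assumes the gate weights come as an `(N, K)` array, as `moe_model.predict_proba()` is used in the cell above. The toy values are made up.

```python
import numpy as np

def mean_gate_by_region(gate_weights, x1, threshold=0.5):
    """Average gate weight per expert on each side of x1 = threshold."""
    gate_weights = np.asarray(gate_weights)
    left = np.asarray(x1) < threshold
    return gate_weights[left].mean(axis=0), gate_weights[~left].mean(axis=0)

# Toy gate weights for two experts.
toy_x1 = np.array([-3.0, -1.0, 2.0, 4.0])
toy_gates = np.array([[0.9, 0.1],
                      [0.7, 0.3],
                      [0.2, 0.8],
                      [0.0, 1.0]])

left_mean, right_mean = mean_gate_by_region(toy_gates, toy_x1)
print(left_mean, right_mean)
```

In the notebook, the equivalent call would be `mean_gate_by_region(moe_model.predict_proba(), x1)`.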