# JAX Regression Example - Interactive API

Run this Jupyter notebook in a virtual environment.

In [None]:
!pip install jax==0.3.13 jaxlib==0.3.10 -q

GPU version of JAX (optional): uncomment the install command below that matches the pre-installed CUDA and cuDNN versions.

In [None]:
# !pip install --upgrade pip # Be careful when upgrading pip: it may cause package-dependency problems during OpenFL workflow execution.

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: these wheels are only available on Linux.
# !pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA >= 11.4 and cuDNN >= 8.2
# !pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5
# !pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
# Without either of the flags below, JAX/XLA raised a CUDA_OUT_OF_MEMORY exception.
# By default, JAX/XLA pre-allocates 90% of the GPU memory at start-up.
# NOTE(review): per JAX's GPU memory-allocation docs, these variables must be set before
# JAX initializes its GPU backend, which is why this cell runs before `import jax`.

# Restrict the maximum GPU memory pre-allocation to 50%.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.5

# OR

# Set XLA_PYTHON_CLIENT_PREALLOCATE to false to allocate GPU memory incrementally, as needed.
# Note that this can still end up taking the entire GPU over time.
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


In [None]:
# Mandatory imports for Federation

import jax
import jax.numpy as jnp
import time

In [None]:
# Both MSE functions below compute the same quantity (the mean squared error of X @ W against the targets); either can be used.

# Calculate MSE approach 1: one vectorised expression over the whole batch.
def mse_loss_function1(W, X, y):
    """Return the mean squared error of the linear prediction ``X @ W`` against ``y``."""
    residuals = jnp.dot(X, W) - y
    return jnp.square(residuals).mean()

# Calculate MSE approach 2: per-sample squared error, batched with jax.vmap.
def mse_loss_function2(W, X, Y):
    """Return the mean squared error of ``X @ W`` against ``Y``.

    The squared error is written for a single sample and mapped over the rows of
    ``X`` / entries of ``Y`` with ``jax.vmap``; the results are then averaged.
    """
    def _sample_sq_error(features, target):
        # ``features``: one row of X (vmap maps over axis 0); ``target``: matching entry of Y.
        diff = target - jnp.dot(features, W)
        return jnp.inner(diff, diff)

    per_sample = jax.vmap(_sample_sq_error)(X, Y)
    return jnp.mean(per_sample, axis=0)

# Weight update: one gradient-descent step. This function is not compiled by itself;
# LinearRegression.fit wraps it with jax.jit, so repeated calls reuse the compiled version.
def update(W, x, y, lr):
    """Return ``W`` moved one step of size ``lr`` against the gradient of ``mse_loss_function1``."""
    grad_W = jax.grad(mse_loss_function1)(W, x, y)
    return W - lr * grad_W

In [None]:
class LinearRegression:
    """Linear model without an intercept term: predictions are ``X @ weights``.

    Trained with full-batch gradient descent on the mean squared error
    (``mse_loss_function1``) through the module-level ``update`` step.
    """

    def __init__(self, n_feat: int) -> None:
        # One weight per input feature, initialised to ones.
        self.weights = jnp.ones(n_feat)

    def mse(self, X, y) -> float:
        """Mean squared error of the current weights on ``(X, y)``.

        The value is a 0-d JAX array; use ``float(...)`` for a Python float.
        """
        return mse_loss_function1(self.weights, X, y)

    def predict(self, X):
        """Return the predictions ``X @ weights``."""
        return jnp.dot(X, self.weights)

    def fit(self, X, Y, n_epochs: int, learning_rate: float, silent: bool = False) -> None:
        """Run ``n_epochs`` full-batch gradient-descent steps, updating ``self.weights``.

        Args:
            X: Feature matrix, one row per sample.
            Y: Targets, one per row of ``X``.
            n_epochs: Number of gradient-descent steps.
            learning_rate: Step size passed to ``update``.
            silent: If True, skip the per-epoch loss log. The initial loss and the
                elapsed time are printed regardless.
        """
        # Speed up weight updates with consecutive calls to the jitted `update` function.
        update_weights = jax.jit(update)

        # Move the data to the device once instead of on every jitted call.
        X, Y = jnp.asarray(X), jnp.asarray(Y)

        # Log about 10 times per run. max(1, ...) avoids a modulo by zero
        # (ZeroDivisionError) when n_epochs < 10, even when silent=True.
        log_every = max(1, n_epochs // 10)

        start_time = time.time()
        print('Training Loss at start: ', self.mse(X, Y))
        for i in range(n_epochs):
            self.weights = update_weights(self.weights, X, Y, learning_rate)
            if not silent and i % log_every == 0:
                print(str(i), 'Training Loss: ', self.mse(X, Y))

        # JAX dispatches work asynchronously: wait for the final weights so the
        # elapsed time covers the computation and not just its dispatch.
        jnp.asarray(self.weights).block_until_ready()
        print("--- %s seconds ---" % (time.time() - start_time))

    

# JAX Linear Regression with federation

## Connect to a Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Use the same identifier that was used in the signed certificate.
client_id = 'frontend'
director_node_fqdn = 'localhost'
director_port = 50050

# NOTE(review): tls=False presumably disables TLS for the director connection; acceptable for
# this local demo (director on localhost), but confirm and enable TLS for real deployments.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)

In [None]:
# Fetch the shard registry from the federation; the bare last expression displays it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

### Initialize Data Interface

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

class LinearRegressionDataSet(DataInterface):
    """Data interface exposing a collaborator's local 'train' and 'val' datasets to the FL tasks."""

    def __init__(self, **kwargs):
        """Keep any keyword arguments on the instance as ``self.kwargs``."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """The local shard descriptor (set by the Envoy)."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """Attach the local shard descriptor and load its 'train' and 'val' datasets.

        Invoked while a collaborator is being initialized; the Envoy supplies the
        local shard descriptor, which defines the per-collaborator data.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = shard_descriptor.get_dataset('train')
        self.val_set = shard_descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """Return the training dataset; given to tasks whose contract includes an optimizer."""
        return self.train_set

    def get_valid_loader(self, **kwargs):
        """Return the validation dataset; given to tasks without an optimizer in their contract."""
        return self.val_set

    def get_train_data_size(self):
        """Size (``len``) of the training dataset, reported for aggregation."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Size (``len``) of the validation dataset, reported for aggregation."""
        return len(self.val_set)


lin_reg_dataset = LinearRegressionDataSet()

### Define Model Interface

In [None]:
framework_adapter = 'custom_adapter.CustomFrameworkAdapter'

# Number of input features, passed as LinearRegression's `n_feat` argument.
# It must be the same as `sample_shape` in `director_config.yaml`.
N_FEATURES = 1

fed_model = LinearRegression(N_FEATURES)
MI = ModelInterface(model=fed_model, optimizer=None, framework_plugin=framework_adapter)

# Save the initial model state: a separate, untrained model with the same initial weights.
initial_model = LinearRegression(N_FEATURES)

### Register Tasks
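We need a trick when reporting metrics: OpenFL decides which model is the best based on an *increasing* metric, so a lower-is-better metric such as the validation MSE must be transformed (e.g. negated) before it is reported.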

In [None]:
TI = TaskInterface()

# Default hyperparameters, passed to the `train` task as keyword arguments.
@TI.add_kwargs(lr=0.01, epochs=101)
@TI.register_fl_task(model='my_model', data_loader='train_data',
                     device='device', optimizer='optimizer')
def train(my_model, train_data, optimizer, device, lr, epochs):
    """Fit `my_model` on the local training data and report its training MSE.

    The last column of `train_data` is the target; the other columns are features.
    `optimizer` and `device` are part of the task contract but are unused here.
    """
    X, Y = train_data[:, :-1], train_data[:, -1]
    my_model.fit(X, Y, epochs, lr, silent=False)
    return {'train_MSE': my_model.mse(X, Y)}


@TI.register_fl_task(model='my_model', data_loader='val_data', device='device')
def validate(my_model, val_data, device):
    """Report the validation error as a higher-is-better metric.

    OpenFL picks the best model by an *increasing* metric, while MSE is
    lower-is-better, so the MSE is negated before being reported.
    """
    X, Y = val_data[:, :-1], val_data[:, -1]
    return {'validation_negative_MSE': -my_model.mse(X, Y)}

### Run the federation

In [None]:
# Create an FL experiment on the `federation` object created in the "Connect to a Federation" section.
experiment_name = 'jax_linear_regression_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# Start the experiment with the model interface (MI), registered tasks (TI) and data interface, for 2 rounds.
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=lin_reg_dataset,
                    rounds_to_train=2)

In [None]:
# Stream the metrics returned by the registered `train` and `validate` tasks.
fl_experiment.stream_metrics()

# JAX Linear Regression without federation (Optional Simulation)

In [None]:
!pip install matplotlib scikit-learn -q

In [None]:
# Imports for running the JAX Linear Regression example without OpenFL.

import matplotlib.pyplot as plt

%matplotlib inline
from matplotlib.pylab import rcParams
# Default figure size (width, height) in inches for the plots below.
rcParams['figure.figsize'] = 7, 5

# NOTE(review): `make_jaxpr` is imported here, but the jaxpr cell calls `jax.make_jaxpr`; one of the two is redundant.
from jax import make_jaxpr
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

#### Simple Linear Regression
<img src="https://www.analyticsvidhya.com/wp-content/uploads/2016/01/eq5-1.png" width="500">



In [None]:
# Synthetic regression problem with a single feature: 100 noisy samples, fixed seed for reproducibility.
X, y = make_regression(
    n_samples=100,
    n_features=1,
    noise=10,
    random_state=42,
)

# Hold out 25% of the samples for testing (scikit-learn's default split);
# from here on, X and y hold the training portion.
X, X_test, y, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

Visualize the data distribution: the training split first, then the held-out test split.

In [None]:
# Scatter plot of the training split (single feature vs. target).
fig, ax = plt.subplots()
ax.scatter(X, y)
ax.set(title='Training split', xlabel='feature', ylabel='target');

In [None]:
# Scatter plot of the held-out test split (single feature vs. target).
fig, ax = plt.subplots()
ax.scatter(X_test, y_test)
ax.set(title='Test split', xlabel='feature', ylabel='target');

In [None]:
# JAX logical execution plan: print the jaxpr (JAX's intermediate representation) of one
# `update` step, traced with this dataset's shapes and an initial all-ones weight vector.
print(jax.make_jaxpr(update)(jnp.ones(X.shape[1]), X, y, 0.01))

In [None]:
# Hyperparameters for the local (non-federated) run.
lr = 0.01
epochs = 101

# X.shape -> (n_samples, n_features): one weight per feature.
lr_model = LinearRegression(X.shape[1])

print(f"Initial Test MSE: {lr_model.mse(X_test, y_test)}")

# silent=False logs the training loss roughly every 10% of the epochs.
lr_model.fit(X, y, n_epochs=epochs, learning_rate=lr, silent=False)

print(f"Final Test MSE: {lr_model.mse(X_test, y_test)}")
print(f"Final parameters: {lr_model.weights}")