#Jax Linear Regression

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, random, tree_multimap

In [2]:
#Returns a prediction value given inputs and parameters
def pred(x, params):
  return jnp.dot(params["weights"], x) + params["bias"]

#Vectorized version of prediction file for batches of inputs (multiple rows)
multiple_preds = vmap(pred, (0, None))

#Given parameters, batch of inputs, and batch of corresponding true outputs, returns mean squared error.
def mse(params, x_multiple, y_multiple):
  print(x_multiple.shape)
  prediction = multiple_preds(x_multiple, params)
  actual = y_multiple
  return jnp.mean(jnp.multiply(prediction - actual, prediction - actual))

#Given parameters, batch of inputs, and batch of corresponding true outputs, returns R^2 value.
def score(params, x_multiple, y_multiple):
  prediction = multiple_preds(x_multiple, params)
  actual = y_multiple
  return 1 - (jnp.dot(prediction - actual, prediction - actual) / jnp.dot(actual - jnp.mean(actual), actual - jnp.mean(actual)))

In [3]:
class LinearRegression:
  
  #Given inputs and correct output, trains a one layer Linear Regression model.
  def train(self, x_data, y_data, num_steps=1000, step_size=0.01, display_info_step=100):
    dimension = x_data.shape[1]
    key = random.PRNGKey(1509)

    #Initialize parameters
    w_key, b_key = random.split(key)
    current_params = {"weights": random.normal(w_key, (dimension,)), "bias": random.normal(b_key)}

    #At each step, updates the parameters with using the gradient of mse function
    def training_step(params, x_multiple, y_multiple, step_size):
      loss_gradients = grad(mse)(params, x_multiple, y_multiple)
      return tree_multimap(lambda param, gradient: param - gradient * step_size, params, loss_gradients)
    
    #Compile training_step function with jit
    jit_training_step = jit(training_step)
    
    #Now the actual training
    for i in range(num_steps):
      current_params = jit_training_step(current_params, x_data, y_data, step_size)
      if display_info_step > 0:
        if i % display_info_step == 0:
          print(f"Step {i} R-Squared: {score(current_params, x_data, y_data)}")

    #Sets the model's coefficients and intercept properties to the final parameters
    self.coefficients = current_params["weights"]
    self.intercept = current_params["bias"]

  #Given an input, returns a prediction using the stored parameters.
  def predict(self, x, multiple=False):
    params = {"weights": self.coefficients, "bias": self.intercept}
    if multiple:
      return multiple_preds(x, params)
    else:
      return pred(x, params)

In [4]:
import numpy as np
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
import pandas as pd

In [5]:
house_data = pd.read_csv("house_train.csv")

In [6]:
def sorting(df, target):
  preparation_df = df.copy()
  preparation_df.fillna('none', inplace=True)
  dtype_list = preparation_df.dtypes.tolist()
  for idx, val in enumerate(dtype_list):
    dtype_list[idx] = str(val)
  categorical_variables = []
  for idx, val in enumerate(preparation_df.columns.tolist()):
    if dtype_list[idx] == 'object':
      categorical_variables.append(val)
  categorical_variables.append('MSSubClass')
  sort_columns = []
  for column in categorical_variables:
    sort_columns.append(preparation_df.groupby(column).mean().sort_values(by=target).index.tolist())        
  return sort_columns, categorical_variables

In [7]:
def preparation(df, sorted_columns_list, categorical_variables):
    preparation_df = df.copy()
    preparation_df.fillna('none', inplace=True)
    for column in categorical_variables:
        globals()['{}_list'.format(column)] = preparation_df[column].tolist()
        for idx, val in enumerate(sorted_columns_list[categorical_variables.index(column)]):
            for index, value in enumerate(globals()['{}_list'.format(column)]):
                if value == val:
                    globals()['{}_list'.format(column)][index]=idx
        preparation_df[column]=globals()['{}_list'.format(column)]
        preparation_df.replace('none', 0, inplace=True)
    return preparation_df

In [8]:
def correlation(df, target, corr_constant):
    preparation_df = df.copy()
    correlation_df = preparation_df.corr()
    features = correlation_df[abs(correlation_df[target])>corr_constant][[target]].index.tolist()
    return features[0:-1]

In [9]:
def variable_prep(prepped_df):
    X = preprocessing.StandardScaler().fit(prepped_df).transform(prepped_df.astype(float))
    return X

In [10]:
def train_test_prep(train_df, test_df, target, corr_constant):
    sorted_columns, categorical_variables = sorting(train_df, target)
    prepped_train_df = preparation(train_df, sorted_columns, categorical_variables)
    features = correlation(prepped_train_df, target, corr_constant)
    prepped_test_df = preparation(test_df, sorted_columns, categorical_variables)
    y_train = np.array(prepped_train_df[target])
    prepped_train_df.drop("SalePrice", 1)
    X_test = variable_prep(prepped_test_df[features])
    X_train = variable_prep(prepped_train_df[features])
    return X_train, y_train, X_test

In [11]:
x_data, y_data, _ = train_test_prep(house_data, house_data.drop("SalePrice", axis=1), "SalePrice", 0.06)

In [12]:
model = LinearRegression()
model.train(x_data, y_data)



(1460, 67)
Step 0 R-Squared: -4.616201400756836
Step 100 R-Squared: 0.7634928226470947
Step 200 R-Squared: 0.8541419506072998
Step 300 R-Squared: 0.8570235967636108
Step 400 R-Squared: 0.8576697707176208
Step 500 R-Squared: 0.858012318611145
Step 600 R-Squared: 0.8582216501235962
Step 700 R-Squared: 0.8583570718765259
Step 800 R-Squared: 0.8584475517272949
Step 900 R-Squared: 0.8585088849067688


In [13]:
print(model.coefficients)
print(model.intercept)

[-2304.588       42.10295   4200.547     1618.1536     635.2579
  2041.2012    1635.4426   12022.644     2205.7495   -1884.1558
  3727.7793    -456.05234  12091.685     6282.655     1024.4338
 -2006.2737    2310.322     4183.5156    1363.8425     175.31981
 -2696.882     8340.476     4746.1924    -141.59502   1073.7736
  4201.6626   -3491.3027    4275.466    -1118.3302    4286.069
   755.6142   -1419.5288    3192.6128    -133.51602   1013.94275
   445.72784  -1477.3622    8510.692     7479.5444   11286.721
  2873.273     1207.6199     976.81805  -3052.6152   -3019.74
  5863.918     6386.53      3142.1365    2424.504      375.67484
 -1841.8363     263.46204     55.584454  5586.49      1283.6678
   114.65431  -2226.829      -68.69399   2418.2124    -129.31169
  -336.16888   2187.6536   -8785.606     9677.463     -216.31047
  4385.724     2534.7065  ]
180920.81


In [14]:
score({"weights": model.coefficients, "bias": model.intercept}, x_data, y_data)

DeviceArray(0.85855097, dtype=float32)