# Linear Regression (California Housing Data (1990))

The standard linear regression minimization problem is:

**Objective**: minimize the sum of squared residuals

**Formula**:

$$
\min_{\boldsymbol{w}, b} \sum_{i=1}^{n} (y_i - (\boldsymbol{w}^\top \boldsymbol{x}_i + b))^2
$$

Or in matrix form (with bias absorbed):

$$
\min_{\boldsymbol{\theta}} \|\boldsymbol{X}\boldsymbol{\theta} - \boldsymbol{y}\|^2_2
$$

Where:

* $\boldsymbol{x}_i$ ∈ ℝᵈ is the input vector for sample $i$
* $y_i$ ∈ ℝ is the target
* $\boldsymbol{w}$ ∈ ℝᵈ is the weight vector
* $b$ ∈ ℝ is the bias
* $\boldsymbol{X}$ ∈ ℝⁿˣᵈ is the design matrix (rows are $\boldsymbol{x}_i^\top$)
* $\boldsymbol{\theta} = \begin{bmatrix} \boldsymbol{w} \\ b \end{bmatrix}$ if bias is included via augmented input

Solution (if XᵀX invertible):

$$
\boldsymbol{\theta}^* = (\boldsymbol{X}^\top \boldsymbol{X})^{-1} \boldsymbol{X}^\top \boldsymbol{y}
$$


In [383]:
# Import libs
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd

In [343]:
# Load data
df = pd.read_csv('data/housing.csv')
df

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value,ocean_proximity
0,-122.23,37.88,41.0,880.0,129.0,322.0,126.0,8.3252,452600.0,NEAR BAY
1,-122.22,37.86,21.0,7099.0,1106.0,2401.0,1138.0,8.3014,358500.0,NEAR BAY
2,-122.24,37.85,52.0,1467.0,190.0,496.0,177.0,7.2574,352100.0,NEAR BAY
3,-122.25,37.85,52.0,1274.0,235.0,558.0,219.0,5.6431,341300.0,NEAR BAY
4,-122.25,37.85,52.0,1627.0,280.0,565.0,259.0,3.8462,342200.0,NEAR BAY
...,...,...,...,...,...,...,...,...,...,...
20635,-121.09,39.48,25.0,1665.0,374.0,845.0,330.0,1.5603,78100.0,INLAND
20636,-121.21,39.49,18.0,697.0,150.0,356.0,114.0,2.5568,77100.0,INLAND
20637,-121.22,39.43,17.0,2254.0,485.0,1007.0,433.0,1.7000,92300.0,INLAND
20638,-121.32,39.43,18.0,1860.0,409.0,741.0,349.0,1.8672,84700.0,INLAND


In [344]:
df.describe()

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value
count,20640.0,20640.0,20640.0,20640.0,20433.0,20640.0,20640.0,20640.0,20640.0
mean,-119.569704,35.631861,28.639486,2635.763081,537.870553,1425.476744,499.53968,3.870671,206855.816909
std,2.003532,2.135952,12.585558,2181.615252,421.38507,1132.462122,382.329753,1.899822,115395.615874
min,-124.35,32.54,1.0,2.0,1.0,3.0,1.0,0.4999,14999.0
25%,-121.8,33.93,18.0,1447.75,296.0,787.0,280.0,2.5634,119600.0
50%,-118.49,34.26,29.0,2127.0,435.0,1166.0,409.0,3.5348,179700.0
75%,-118.01,37.71,37.0,3148.0,647.0,1725.0,605.0,4.74325,264725.0
max,-114.31,41.95,52.0,39320.0,6445.0,35682.0,6082.0,15.0001,500001.0


In [345]:
# Preprocess data
def preprocess_df(df, shift=None, scale=None, category_map=None):
    if shift is None:
        shift = {}
    if scale is None:
        scale = {}
    if category_map is None:
        category_map = {}
    dfs = []
    
    for column, s in df.to_dict('series').items():
        # one hot encode categorical, rescale numerics
        if pd.api.types.is_string_dtype(s):
            if column not in category_map:
                category_map[column] = {c: i for i, c in enumerate(s.unique())}
    
            cat2idx = category_map[column]
            idxs = jnp.array([cat2idx.get(val, -1) for val in s])  # -1 for unknowns
            one_hot = jax.nn.one_hot(idxs, num_classes=len(cat2idx))
    
            # wrap in DataFrame for concat, with column names like "column_val1"
            one_hot_df = pd.DataFrame(
                one_hot,
                columns=[f"{column}_{c}" for c in cat2idx],
                index=s.index
            )
            dfs.append(one_hot_df)
        elif pd.api.types.is_numeric_dtype(s):
            if column not in shift:
                shift[column] = s.mean()
            if column not in scale:
                scale[column] = s.std(ddof=0)
            df = pd.DataFrame((s - shift[column]) / scale[column]).fillna(0)
            dfs.append(df)
        else:
            raise TypeError(f"unhandled dtype: {s.dtype}")
    
    df = pd.concat(dfs, axis=1)
    return df, shift, scale, category_map

def extract_labels(df, label_column):
    y = df[label_column]
    X = df.drop(label_column, axis=1)
    return X, y

def train_test_split(data, frac):
    shuffled = data.sample(frac=1, random_state=42).reset_index(drop=True)
    idx = int(len(shuffled) * frac)
    return data[:idx], data[idx:]

train, test = train_test_split(df, 0.8)
train, shift, scale, category_map = preprocess_df(train)
X_train, y_train = extract_labels(train, 'median_house_value')

In [346]:
X_train, y_train = jnp.array(X_train.values), jnp.array(y_train.values)

In [347]:
X_train = jnp.hstack([X_train, jnp.ones((X_train.shape[0], 1))])

In [348]:
@jax.jit
def fwd(theta, X):
    return X @ theta

@jax.jit
def loss(theta):
    errvec = (fwd(theta, X_train) - y_train)
    loss = jnp.dot(errvec, errvec)
    return loss

In [349]:
grad_loss = jax.grad(loss, argnums=0)

In [355]:
key = jax.random.key(0)
key, subkey = jax.random.split(key)
theta = jax.random.normal(subkey, shape=(X_train.shape[1]))
del subkey

@jax.jit
def train_loop(theta, lr=0.00001):
    return theta - lr * grad_loss(theta)

for i in range(100000):
    theta = train_loop(theta)
    if i % 5000 == 0:
        print(loss(theta))

40561.14
6182.0547
6170.965
6167.6206
6166.5967
6166.3105
6166.215
6166.187
6166.173
6166.1694
6166.1655
6166.1733
6166.176
6166.1763
6166.175
6166.175
6166.175
6166.1753
6166.1743
6166.175


In [373]:
def predict(theta, X):
    X_test, _, _, _ = preprocess_df(X, shift, scale, category_map)
    X_test = jnp.array(X_test.values)
    X_test = jnp.hstack([X_test, jnp.ones((X_test.shape[0], 1))])
    print(X_test.shape)
    print(X_train.shape)
    return fwd(theta, X_test) * scale['median_house_value'] + shift['median_house_value']

X_test, y_test = extract_labels(test, 'median_house_value')
shift['median_house_value']

np.float64(202067.03131056202)

In [375]:
y_hat = predict(theta, X_test)

(4128, 14)
(16512, 14)


In [376]:
pd.DataFrame(dict(y_hat=y_hat, y_test=y_test))

Unnamed: 0,y_hat,y_test
16512,112027.117188,165600.0
16513,131039.703125,126100.0
16514,93311.359375,94400.0
16515,105716.671875,91900.0
16516,114775.171875,124300.0
...,...,...
20635,29481.109375,78100.0
20636,54390.031250,77100.0
20637,37562.687500,92300.0
20638,48222.812500,84700.0


In [381]:
rmse = jnp.sqrt(jnp.mean((y_hat - y_test.values) ** 2))

In [382]:
rmse

Array(66458.1, dtype=float32)