# Logistic Regression (Wisconsin Diagnostic Breast Cancer)
Logistic regression minimizes the **negative log-likelihood**, also called the **binary cross-entropy loss**. For binary classification:

$$
\min_{\theta} \ \mathcal{L}(\theta) = - \sum_{i=1}^{n} \left[ y_i \log \sigma(x_i^\top \theta) + (1 - y_i) \log (1 - \sigma(x_i^\top \theta)) \right]
$$

Where:

* $x_i \in \mathbb{R}^d$ is the input vector for example $i$
* $y_i \in \{0, 1\}$ is the true label
* $\theta \in \mathbb{R}^d$ are the weights
* $\sigma(z) = \frac{1}{1 + e^{-z}}$ is the sigmoid function

A common simplification is to absorb the bias into $\theta$: append a constant 1 to each feature vector, so that $\theta \in \mathbb{R}^{d+1}$ and its last component acts as the bias.
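
As a minimal sketch of the loss with the bias absorbed into the weights, the block below evaluates the summed negative log-likelihood on a tiny made-up dataset. The names `add_bias_column`, `nll`, `X_toy`, `y_toy`, and `theta_toy` are illustrative assumptions and do not appear elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

def add_bias_column(X):
    # Append a constant 1 to every row so the last weight acts as the bias.
    return jnp.hstack([X, jnp.ones((X.shape[0], 1))])

def nll(theta, X, y):
    # Negative log-likelihood, summed over examples as in the formula.
    # log_sigmoid(z) = log(sigma(z)) and log_sigmoid(-z) = log(1 - sigma(z)).
    z = add_bias_column(X) @ theta
    return -jnp.sum(y * jax.nn.log_sigmoid(z) + (1 - y) * jax.nn.log_sigmoid(-z))

# Toy data, made up for illustration: 3 examples with 2 features each.
X_toy = jnp.array([[0.5, -1.0], [1.5, 2.0], [-0.3, 0.8]])
y_toy = jnp.array([0.0, 1.0, 1.0])
theta_toy = jnp.zeros(3)  # 2 weights + 1 bias

loss_at_zero = nll(theta_toy, X_toy, y_toy)
```

With all weights at zero every predicted probability is 0.5, so the loss should equal $3 \log 2$, which makes a convenient sanity check.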

In [23]:
# Import libraries
import pandas as pd
import jax
import jax.numpy as jnp

# Dataset overview:

Relevant to our purposes is the following information from `data/wdbc.names`:

```
5. Number of instances: 569 

6. Number of attributes: 32 (ID, diagnosis, 30 real-valued input features)

7. Attribute information

1) ID number
2) Diagnosis (M = malignant, B = benign)
3-32)

Ten real-valued features are computed for each cell nucleus:

	a) radius (mean of distances from center to points on the perimeter)
	b) texture (standard deviation of gray-scale values)
	c) perimeter
	d) area
	e) smoothness (local variation in radius lengths)
	f) compactness (perimeter^2 / area - 1.0)
	g) concavity (severity of concave portions of the contour)
	h) concave points (number of concave portions of the contour)
	i) symmetry 
	j) fractal dimension ("coastline approximation" - 1)

Several of the papers listed above contain detailed descriptions of
how these features are computed. 

The mean, standard error, and "worst" or largest (mean of the three
largest values) of these features were computed for each image,
resulting in 30 features.  For instance, field 3 is Mean Radius, field
13 is Radius SE, field 23 is Worst Radius.

All feature values are recoded with four significant digits.

8. Missing attribute values: none

9. Class distribution: 357 benign, 212 malignant
```
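
The field layout quoted above (field 3 is Mean Radius, field 13 is Radius SE, field 23 is Worst Radius) implies that the ten means come first, then the ten standard errors, then the ten "worst" values. A small sketch of that mapping follows; the helper `field` and the short label `max` for "worst" are assumptions made for illustration.

```python
features = ['radius', 'texture', 'perimeter', 'area', 'smoothness',
            'compactness', 'concavity', 'concave_points', 'symmetry',
            'fractal_dimension']
measurements = ['mean', 'se', 'max']  # 'max' stands in for "worst"

# Fields 1 and 2 are the ID and the diagnosis. The measurement varies slowest:
# all ten means, then all ten standard errors, then all ten worst values.
field_names = ['id_number', 'diagnosis'] + [
    f'{feature}_{measurement}'
    for measurement in measurements
    for feature in features
]

def field(n):
    # wdbc.names numbers fields from 1, Python lists from 0.
    return field_names[n - 1]

print(field(3), field(13), field(23))
```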

The following cell turns this description into descriptive column names for a pandas DataFrame.

In [24]:
import itertools

df = pd.read_csv('data/wdbc.data', header=None)
features = ['radius', 'texture', 'perimeter', 'area', 'smoothness', 'compactness', 'concavity', 'concave_points', 'symmetry', 'fractal_dimension']
measurement = ['mean', 'se', 'max']
# Fields 3-12 hold the ten means, 13-22 the standard errors, and 23-32 the
# "worst" values, so the measurement must vary slowest in the column names.
df.columns = ['id_number', 'diagnosis'] + [f'{f}_{m}' for m, f in itertools.product(measurement, features)]
df

,id_number,diagnosis,radius_mean,texture_mean,perimeter_mean,area_mean,smoothness_mean,compactness_mean,concavity_mean,concave_points_mean,...,radius_max,texture_max,perimeter_max,area_max,smoothness_max,compactness_max,concavity_max,concave_points_max,symmetry_max,fractal_dimension_max
0,842302,M,17.99,10.38,122.80,1001.0,0.11840,0.27760,0.30010,0.14710,...,25.380,17.33,184.60,2019.0,0.16220,0.66560,0.7119,0.2654,0.4601,0.11890
1,842517,M,20.57,17.77,132.90,1326.0,0.08474,0.07864,0.08690,0.07017,...,24.990,23.41,158.80,1956.0,0.12380,0.18660,0.2416,0.1860,0.2750,0.08902
2,84300903,M,19.69,21.25,130.00,1203.0,0.10960,0.15990,0.19740,0.12790,...,23.570,25.53,152.50,1709.0,0.14440,0.42450,0.4504,0.2430,0.3613,0.08758
3,84348301,M,11.42,20.38,77.58,386.1,0.14250,0.28390,0.24140,0.10520,...,14.910,26.50,98.87,567.7,0.20980,0.86630,0.6869,0.2575,0.6638,0.17300
4,84358402,M,20.29,14.34,135.10,1297.0,0.10030,0.13280,0.19800,0.10430,...,22.540,16.67,152.20,1575.0,0.13740,0.20500,0.4000,0.1625,0.2364,0.07678
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,926424,M,21.56,22.39,142.00,1479.0,0.11100,0.11590,0.24390,0.13890,...,25.450,26.40,166.10,2027.0,0.14100,0.21130,0.4107,0.2216,0.2060,0.07115
565,926682,M,20.13,28.25,131.20,1261.0,0.09780,0.10340,0.14400,0.09791,...,23.690,38.25,155.00,1731.0,0.11660,0.19220,0.3215,0.1628,0.2572,0.06637
566,926954,M,16.60,28.08,108.30,858.1,0.08455,0.10230,0.09251,0.05302,...,18.980,34.12,126.70,1124.0,0.11390,0.30940,0.3403,0.1418,0.2218,0.07820
567,927241,M,20.60,29.33,140.10,1265.0,0.11780,0.27700,0.35140,0.15200,...,25.740,39.42,184.60,1821.0,0.16500,0.86810,0.9387,0.2650,0.4087,0.12400


In [74]:
def shuffle_df(df):
    return df.sample(frac=1, random_state=42).reset_index(drop=True)

def split_df(df, frac):
    idx = int(len(df) * frac)
    return df[:idx], df[idx:]

def preprocess_df(df):
    labels = (df['diagnosis'] == 'M')
    features = df.drop(['id_number', 'diagnosis'], axis=1)
    return features, labels

train, test = split_df(shuffle_df(df), 0.8)
feats, labels = preprocess_df(train)

In [75]:
feats, labels = preprocess_df(train)
X_train, y_train = jnp.array(feats.values), jnp.array(labels.values)

In [76]:
def create_params(X_train):
    key = jax.random.key(42)

    mu = jnp.mean(X_train, axis=0)
    sigma = jnp.std(X_train, axis=0)
    
    key, subkey = jax.random.split(key)
    theta = jax.random.normal(subkey, (X_train.shape[1]+1,))
    del subkey
    
    params = (mu, sigma, theta)
    
    return params
params = create_params(X_train)

In [100]:
def fwd(params, X):
    mu, sigma, theta = params
    X_norm = (X - mu) / sigma  # standardize with the training-set statistics
    X_norm = jnp.hstack([X_norm, jnp.ones((X.shape[0], 1))])  # append 1s so theta[-1] is the bias
    logits = X_norm @ theta
    return logits

# fwd(params, X_train)

def predict(params, X):
    return jax.nn.sigmoid(fwd(params, X))

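def loss(theta, mu, sigma, X, y):
    params = mu, sigma, theta
    y = 2*y - 1 # remap labels from {0, 1} to {-1, +1}
    logits = fwd(params, X)
    # With y in {-1, +1}, the per-example cross-entropy is -log(sigmoid(y * logit))
    # = log(1 + exp(-y * logit)); logaddexp(0, .) computes this without overflow.
    data_loss = jnp.mean(jnp.logaddexp(0, -y * logits))
    # Mean rather than sum over examples, plus a small L2 penalty on theta (bias included).
    return data_loss + 1e-4 * jnp.sum(theta ** 2)

# loss(theta, mu, sigma, X_train, y_train)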
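
The `logaddexp` expression in `loss` is the binary cross-entropy in disguise. For $y = 1$ the per-example loss is $-\log \sigma(z) = \log(1 + e^{-z})$, and for $y = 0$ it is $-\log(1 - \sigma(z)) = \log(1 + e^{z})$. Writing $s = 2y - 1 \in \{-1, +1\}$, both cases become $\log(1 + e^{-s z})$, which is `logaddexp(0, -s * z)`. The sketch below checks this numerically; `bce`, `softplus_form`, and the sample values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def bce(logits, y01):
    # Textbook binary cross-entropy with labels in {0, 1}.
    p = jax.nn.sigmoid(logits)
    return -(y01 * jnp.log(p) + (1 - y01) * jnp.log(1 - p))

def softplus_form(logits, y01):
    # Same quantity with labels remapped to {-1, +1}:
    # -log sigma(s * z) = log(1 + exp(-s * z)) = logaddexp(0, -s * z).
    s = 2 * y01 - 1
    return jnp.logaddexp(0.0, -s * logits)

# Sample logits and labels, made up for illustration.
logits = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0])

max_gap = jnp.max(jnp.abs(bce(logits, labels) - softplus_form(logits, labels)))
```

The two forms agree up to floating-point error on moderate logits. The `logaddexp` form is preferable for large logits, where `jnp.log(1 - p)` can reach `log(0)`.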

In [101]:
grad_loss = jax.jit(jax.grad(loss))

@jax.jit
def train(params):
    mu, sigma, theta = params
    lr = 0.0001

    def body(i, theta):
        theta -= lr * grad_loss(theta, mu, sigma, X_train, y_train)

        def do_print(_):
            jax.debug.print("step {i}, loss: {l}", i=i, l=loss(theta, mu, sigma, X_train, y_train))
            return None

        _ = jax.lax.cond(i % 100000 == 0, do_print, lambda _: None, operand=None)
        return theta

    theta = jax.lax.fori_loop(0, 10000000, body, theta)
    return theta

theta = train(create_params(X_train))

step 0, loss: 1.7558537721633911
step 100000, loss: 0.09227238595485687
step 200000, loss: 0.07621347159147263
step 300000, loss: 0.0691460371017456
step 400000, loss: 0.06489294767379761
step 500000, loss: 0.061977189034223557
step 600000, loss: 0.0598461739718914
step 700000, loss: 0.058205634355545044
step 800000, loss: 0.056896310299634933
step 900000, loss: 0.055809326469898224
step 1000000, loss: 0.05489137023687363
step 1100000, loss: 0.05408978834748268
step 1200000, loss: 0.0533839613199234
step 1300000, loss: 0.052760522812604904
step 1400000, loss: 0.0522017665207386
step 1500000, loss: 0.05171068385243416
step 1600000, loss: 0.05127420276403427
step 1700000, loss: 0.05086661875247955
step 1800000, loss: 0.05048423260450363
step 1900000, loss: 0.05012185499072075
step 2000000, loss: 0.04978455230593681
step 2100000, loss: 0.04946058616042137
step 2200000, loss: 0.04915091022849083
step 2300000, loss: 0.04885657876729965
step 2400000, loss: 0.04857207089662552
step 2500000, l
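
Each iteration of `body` is a plain gradient-descent step, `theta -= lr * grad`. As a hedged sanity check, the sketch below compares `jax.grad` against the closed-form gradient of a mean logistic loss with an L2 penalty, $\frac{1}{n} X^\top (\sigma(X\theta) - y) + 2\lambda\theta$, on random toy data. The names `toy_loss`, `closed_form_grad`, and the toy arrays are assumptions for illustration, and the bias column and standardization are omitted for brevity.

```python
import jax
import jax.numpy as jnp

LAMBDA = 1e-4  # same L2 coefficient as the notebook's loss

def toy_loss(theta, X, y):
    # Mean logistic loss with labels in {0, 1}, plus an L2 penalty.
    s = 2 * y - 1
    return jnp.mean(jnp.logaddexp(0.0, -s * (X @ theta))) + LAMBDA * jnp.sum(theta ** 2)

def closed_form_grad(theta, X, y):
    # Gradient of toy_loss: X^T (sigma(X theta) - y) / n + 2 * lambda * theta.
    residual = jax.nn.sigmoid(X @ theta) - y
    return X.T @ residual / X.shape[0] + 2 * LAMBDA * theta

key_x, key_theta = jax.random.split(jax.random.key(0))
X_toy = jax.random.normal(key_x, (8, 3))
y_toy = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
theta_toy = jax.random.normal(key_theta, (3,))

auto_grad = jax.grad(toy_loss)(theta_toy, X_toy, y_toy)
grad_gap = jnp.max(jnp.abs(auto_grad - closed_form_grad(theta_toy, X_toy, y_toy)))

# One gradient-descent step, mirroring the update in the training loop.
lr = 0.1
theta_next = theta_toy - lr * auto_grad
```

Because the loss is convex and smooth, a sufficiently small step such as this one should lower `toy_loss`.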

In [108]:
mu, sigma, _ = params  # normalization statistics computed in create_params
loss(theta, mu, sigma, X_train, y_train)

Array(0.03968935, dtype=float32)

In [110]:
feats, labels = preprocess_df(test)
X_test, y_test = jnp.array(feats.values), jnp.array(labels.values)

In [111]:
params = [mu, sigma, theta]
y_pred = predict(params, X_test)
y_pred

Array([9.99999881e-01, 9.99999881e-01, 3.51411509e-05, 9.99995708e-01,
       1.00000000e+00, 5.97806029e-05, 3.86891770e-07, 3.05701105e-05,
       1.00000000e+00, 3.50525681e-07, 9.99999404e-01, 1.00000000e+00,
       9.99983668e-01, 8.49106610e-02, 5.14403544e-02, 4.29154234e-03,
       9.99707639e-01, 5.18027155e-05, 5.42503141e-04, 9.99883533e-01,
       1.00000000e+00, 9.92534518e-01, 9.99998212e-01, 4.54434864e-02,
       9.99103963e-01, 9.47632536e-04, 1.40104760e-07, 4.80581395e-04,
       1.49442954e-03, 8.56768456e-04, 1.37061816e-05, 1.12894350e-06,
       1.00000000e+00, 1.00000000e+00, 9.99176800e-01, 1.00000000e+00,
       1.22511283e-05, 2.98149139e-02, 2.36257416e-04, 2.01730698e-01,
       1.00000000e+00, 5.28299093e-01, 9.98454452e-01, 4.25819871e-06,
       9.21360606e-06, 1.01508795e-05, 3.20190354e-03, 1.80602800e-02,
       9.99873519e-01, 1.00000000e+00, 9.92593467e-01, 7.68189636e-07,
       2.57860538e-05, 2.64904816e-02, 1.36944860e-01, 3.06835791e-05,
      

In [112]:
loss(theta, mu, sigma, X_test, y_test)

Array(0.12718126, dtype=float32)

In [114]:
def accuracy(y_pred, y_act):
    # Threshold the predicted probabilities at 0.5 and compare with the boolean labels.
    return jnp.mean((y_pred > 0.5) == y_act)

accuracy(y_pred, y_test)

Array(0.94736844, dtype=float32)
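
An accuracy of 0.94736844 on 114 test examples corresponds to 108 correct predictions. Since the classes are imbalanced and the two kinds of error are not equally costly in a diagnostic setting, it can help to count false positives and false negatives separately. The sketch below shows one way to do that on made-up values; `accuracy_and_errors` and the toy arrays are illustrative assumptions, not results from this dataset.

```python
import jax.numpy as jnp

def accuracy_and_errors(probs, labels, threshold=0.5):
    # Turn probabilities into hard predictions, then compare with boolean labels.
    preds = probs > threshold
    correct = preds == labels
    false_pos = jnp.sum(preds & ~labels)  # predicted malignant, actually benign
    false_neg = jnp.sum(~preds & labels)  # predicted benign, actually malignant
    return jnp.mean(correct), int(false_pos), int(false_neg)

# Toy probabilities and labels (True = malignant), made up for illustration.
probs_toy = jnp.array([0.99, 0.02, 0.60, 0.40, 0.10])
labels_toy = jnp.array([True, False, False, True, False])

acc, fp, fn = accuracy_and_errors(probs_toy, labels_toy)
```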

In [115]:
theta = train(params)

step 0, loss: 0.03968934714794159
step 100000, loss: 0.039634015411138535
step 200000, loss: 0.03957916051149368
step 300000, loss: 0.039524853229522705
step 400000, loss: 0.03947202116250992
step 500000, loss: 0.0394204743206501
step 600000, loss: 0.03936953470110893
step 700000, loss: 0.0393189862370491
step 800000, loss: 0.03926884010434151
step 900000, loss: 0.03922383114695549
step 1000000, loss: 0.039185721427202225
step 1100000, loss: 0.039147742092609406
step 1200000, loss: 0.03911023586988449
step 1300000, loss: 0.03907322883605957
step 1400000, loss: 0.03903794661164284
step 1500000, loss: 0.0390034094452858
step 1600000, loss: 0.03896983340382576
step 1700000, loss: 0.03893661126494408
step 1800000, loss: 0.03890376165509224
step 1900000, loss: 0.03887614607810974
step 2000000, loss: 0.03885504975914955
step 2100000, loss: 0.03883611783385277
step 2200000, loss: 0.0388176366686821
step 2300000, loss: 0.0387994684278965
step 2400000, loss: 0.03878159448504448
step 2500000, lo

In [116]:
params = [mu, sigma, theta]
y_pred = predict(params, X_test)
accuracy(y_pred, y_test)

Array(0.94736844, dtype=float32)

In [117]:
X_test.shape

(114, 30)

In [121]:
109/114

0.956140350877193