In [1]:
import numpy as onp
import pandas as pd

from jax import numpy as np, ops, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

In [2]:
train_df = pd.read_csv("train.csv")
train_df.info()
train_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB


Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [3]:
for col in ["Pclass", "Sex", "SibSp", "Parch", "Embarked"]:
    print(train_df[col].value_counts(), end="\n\n")

3    491
1    216
2    184
Name: Pclass, dtype: int64

male      577
female    314
Name: Sex, dtype: int64

0    608
1    209
2     28
4     18
3     16
8      7
5      5
Name: SibSp, dtype: int64

0    678
1    118
2     80
5      5
3      5
4      4
6      1
Name: Parch, dtype: int64

S    644
C    168
Q     77
Name: Embarked, dtype: int64



### extract Title from name to impute Age

In [4]:
d = train_df.copy()
d["Title"] = d.Name.str.split(", ").str.get(1).str.split(" ").str.get(0)
print(d.Title.value_counts())
d.groupby("Title").Age.mean()

Mr.          517
Miss.        182
Mrs.         125
Master.       40
Dr.            7
Rev.           6
Major.         2
Mlle.          2
Col.           2
Sir.           1
the            1
Ms.            1
Don.           1
Capt.          1
Mme.           1
Jonkheer.      1
Lady.          1
Name: Title, dtype: int64


Title
Capt.        70.000000
Col.         58.000000
Don.         40.000000
Dr.          42.000000
Jonkheer.    38.000000
Lady.        48.000000
Major.       48.500000
Master.       4.574167
Miss.        21.773973
Mlle.        24.000000
Mme.         24.000000
Mr.          32.368090
Mrs.         35.898148
Ms.          28.000000
Rev.         43.166667
Sir.         49.000000
the          33.000000
Name: Age, dtype: float64

### prepare data

In [5]:
d = train_df.copy()
d.Embarked.fillna("S", inplace=True)
d["Title"] = d.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).apply(
    lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
title_cat = pd.CategoricalDtype(categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True)
age_mean, age_std = d.Age.mean(), d.Age.std()
embarked_cat = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
data = dict(age=d.Age.pipe(lambda x: (x - age_mean) / age_std).values,
            pclass=d.Pclass.values - 1,
            title=d.Title.astype(title_cat).cat.codes.values,
            sex=(d.Sex == "male").astype(int).values,
            sibsp=d.SibSp.clip(0, 1).values,
            parch=d.Parch.clip(0, 2).values,
            embarked=d.Embarked.astype(embarked_cat).cat.codes.values,
            survived=d.Survived.values)

### modelling: logistic regression of survived w.r.t. age

In [6]:
def model(age, pclass, title, sex, sibsp, parch, embarked, survived=None):
    b_pclass = numpyro.sample("b_Pclass", dist.Normal(0, 1), sample_shape=(3,))
    b_title = numpyro.sample("b_Title", dist.Normal(0, 1), sample_shape=(5,))
    b_sex = numpyro.sample("b_Sex", dist.Normal(0, 1), sample_shape=(2,))
    b_sibsp = numpyro.sample("b_SibSp", dist.Normal(0, 1), sample_shape=(2,))
    b_parch = numpyro.sample("b_Parch", dist.Normal(0, 1), sample_shape=(3,))
    b_embarked = numpyro.sample("b_Embarked", dist.Normal(0, 1), sample_shape=(3,))

    # impute age by Title
    age_mu = numpyro.sample("age_mu", dist.Normal(0, 1), sample_shape=(5,))
    age_mu = age_mu[title]
    age_sigma = numpyro.sample("age_sigma", dist.Normal(0, 1), sample_shape=(5,))
    age_sigma = age_sigma[title]
    age_isnan = onp.isnan(age)
    age_nanidx = onp.nonzero(age_isnan)[0]
    if survived is not None:
        age_impute = numpyro.param("age_impute", np.zeros(age_isnan.sum()))
    else:
        age_impute = numpyro.sample("age_impute", dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]))
    age = ops.index_update(age, age_nanidx, age_impute)
    numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)

    a = numpyro.sample("a", dist.Normal(0, 1))
    b_age = numpyro.sample("b_Age", dist.Normal(0, 1))
    logits = a + b_age * age

    logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex] \
        + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]
    if survived is None:
        probs = expit(logits)
        numpyro.sample("probs", dist.Delta(probs))
    numpyro.sample("survived", dist.Bernoulli(logits=logits), obs=survived)

### sampling

In [7]:
mcmc = MCMC(NUTS(model), 1000, 1000)
mcmc.run(random.PRNGKey(0), **data)
mcmc.print_summary()

sample: 100%|██████████| 2000/2000 [00:26<00:00, 76.46it/s, 63 steps of size 7.21e-02. acc. prob=0.93]  



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
              a      0.10      0.82      0.12     -1.29      1.37    969.65      1.00
  age_impute[0]      0.24      0.83      0.21     -1.03      1.62   2056.88      1.00
  age_impute[1]     -0.11      0.89     -0.12     -1.66      1.32   2457.11      1.00
  age_impute[2]      0.38      0.83      0.39     -0.94      1.69   2011.40      1.00
  age_impute[3]      0.25      0.83      0.25     -1.24      1.48   1854.72      1.00
  age_impute[4]     -0.69      0.87     -0.65     -2.04      0.76   3498.57      1.00
  age_impute[5]      0.23      0.91      0.22     -1.28      1.66   3290.19      1.00
  age_impute[6]      0.41      0.80      0.42     -0.83      1.75   2142.73      1.00
  age_impute[7]     -0.65      0.90     -0.65     -2.19      0.75   2303.03      1.00
  age_impute[8]     -0.13      0.90     -0.17     -1.58      1.34   3406.48      1.00
  age_impute[9]      0.23      0.87      0.24     -1.

### posterior predictive

In [8]:
posterior = mcmc.get_samples().copy()
survived = data.pop("survived")
survived_probs = Predictive(model, posterior).get_samples(random.PRNGKey(1), **data)["probs"]
((survived_probs.mean(axis=0) >= 0.5).astype(np.uint8) == survived).sum() / survived.shape[0]

DeviceArray(0.8260382, dtype=float32)

This is a pretty good result using a linear regression model.

### submission

In [9]:
test_df = pd.read_csv("test.csv")
d = test_df.copy()
d["Title"] = d.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).apply(
    lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
test_data = dict(age=d.Age.pipe(lambda x: (x - age_mean) / age_std).values,
                 pclass=d.Pclass.values - 1,
                 title=d.Title.astype(title_cat).cat.codes.values,
                 sex=(d.Sex == "male").astype(int).values,
                 sibsp=d.SibSp.clip(0, 1).values,
                 parch=d.Parch.clip(0, 2).values,
                 embarked=d.Embarked.astype(embarked_cat).cat.codes.values)

posterior.pop("age_impute", None)
survived_probs = Predictive(model, posterior).get_samples(random.PRNGKey(2), **test_data)["probs"]
d["Survived"] = (survived_probs.mean(axis=0) >= 0.5).astype(np.uint8)
d[["PassengerId", "Survived"]].to_csv("submission.csv", index=False)