In [100]:
import csv
import numpy as np

# Load the training and test datasets
def load_titanic_dataset(filename):
    with open(filename, 'r') as data_file:
        reader = csv.reader(data_file)
        data = [x for x in reader]
        # remove header row
        data = data[1:]
    return data

# PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
train_data = load_titanic_dataset('train.csv')

# Extract pclass, sex, age, fare features
X = np.array([
    [float((x[2])), x[4], int(float(x[5])), int(float(x[9]))]  # floats -> ints reduces data noise
    for x in train_data
    if x[5] != ''  # throwaway if missing age
])

# Encode sex as int
for d in X:
    if d[1] == 'male':
        d[1] = 0
    else:
        d[1] = 1


# Extract survival target
y = np.array([bool(int(x[1])) for x in train_data if x[5] != ''])

print X[1:5]
print y[1:5]

[['1.0' '1' '38' '71']
 ['3.0' '1' '26' '7']
 ['1.0' '1' '35' '53']
 ['3.0' '0' '35' '8']]
[ True  True  True False]


In [124]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score

# Build the classifier
kf = KFold(n_splits=10, shuffle=True)
model = RandomForestClassifier()
scores = [
    model.fit(X[train], y[train]).score(X[test], y[test])
    for train, test in kf.split(X)
]
    
print("Mean 10-fold cross validation accuracy: %s" % np.mean(scores))


Mean 10-fold cross validation accuracy: 0.809487480438


In [125]:
# Create model with all training data
classifier = model.fit(X, y)

# Predict values for the test set
test_data = load_titanic_dataset('test.csv')
test_data = np.array([
    [int(float(x[1])), x[3], int(float(x[4])), int(float(x[8]))]  # floats -> ints reduces data noise
    for x in test_data
    if x[4] != '' and x[8] != ''  # throwaway if missing age or sex
])

# Encode sex as int
for d in test_data:
    if d[1] == 'male':
        d[1] = 0
    else:
        d[1] = 1

print test_data[1:5]
print classifier.predict(test_data[1:5])


[['3' '1' '47' '7']
 ['2' '0' '62' '9']
 ['3' '0' '27' '8']
 ['3' '1' '22' '12']]
[False False False  True]
