In [1]:
import pandas as pd

In [2]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier

In [3]:
# STEP 1. 读取并探索数据

# 读取数据
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')

In [4]:
# 了解数据表的基本情况
print(train_data.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None


In [5]:
# 了解数据表的统计情况
print(train_data.describe())

       PassengerId    Survived      Pclass         Age       SibSp  \
count   891.000000  891.000000  891.000000  714.000000  891.000000   
mean    446.000000    0.383838    2.308642   29.699118    0.523008   
std     257.353842    0.486592    0.836071   14.526497    1.102743   
min       1.000000    0.000000    1.000000    0.420000    0.000000   
25%     223.500000    0.000000    2.000000   20.125000    0.000000   
50%     446.000000    0.000000    3.000000   28.000000    0.000000   
75%     668.500000    1.000000    3.000000   38.000000    1.000000   
max     891.000000    1.000000    3.000000   80.000000    8.000000   

            Parch        Fare  
count  891.000000  891.000000  
mean     0.381594   32.204208  
std      0.806057   49.693429  
min      0.000000    0.000000  
25%      0.000000    7.910400  
50%      0.000000   14.454200  
75%      0.000000   31.000000  
max      6.000000  512.329200  


In [6]:
# 计算离散型变量的统计特征
print(train_data.describe(include=['O']))

                          Name   Sex  Ticket Cabin Embarked
count                      891   891     891   204      889
unique                     891     2     681   147        3
top     Honkanen, Miss. Eliina  male  347082    G6        S
freq                         1   577       7     4      644


In [7]:
# 查看首5行
print(train_data.head())

   PassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   
3            4         1       1   
4            5         0       3   

                                                Name     Sex   Age  SibSp  \
0                            Braund, Mr. Owen Harris    male  22.0      1   
1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
2                             Heikkinen, Miss. Laina  female  26.0      0   
3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   
4                           Allen, Mr. William Henry    male  35.0      0   

   Parch            Ticket     Fare Cabin Embarked  
0      0         A/5 21171   7.2500   NaN        S  
1      0          PC 17599  71.2833   C85        C  
2      0  STON/O2. 3101282   7.9250   NaN        S  
3      0            113803  53.1000  C123        S  
4      0            373450   8.0500   NaN        S  


In [8]:
# 查看末5行
print(train_data.tail())

     PassengerId  Survived  Pclass                                      Name  \
886          887         0       2                     Montvila, Rev. Juozas   
887          888         1       1              Graham, Miss. Margaret Edith   
888          889         0       3  Johnston, Miss. Catherine Helen "Carrie"   
889          890         1       1                     Behr, Mr. Karl Howell   
890          891         0       3                       Dooley, Mr. Patrick   

        Sex   Age  SibSp  Parch      Ticket   Fare Cabin Embarked  
886    male  27.0      0      0      211536  13.00   NaN        S  
887  female  19.0      0      0      112053  30.00   B42        S  
888  female   NaN      1      2  W./C. 6607  23.45   NaN        S  
889    male  26.0      0      0      111369  30.00  C148        C  
890    male  32.0      0      0      370376   7.75   NaN        Q  


In [9]:
# STEP 2. 清洗数据，处理缺失值

# 使用平均年龄来填充年龄中的 NaN 值
train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
test_data['Age'].fillna(test_data['Age'].mean(), inplace=True)

In [10]:
# 使用票价的均值填充票价中的 NaN 值
test_data['Fare'].fillna(test_data['Fare'].mean(), inplace=True)

In [11]:
# 对 train_data 的 Embarked 的不同取值计数，找出港口的众数
print(train_data['Embarked'].value_counts())

S    644
C    168
Q     77
Name: Embarked, dtype: int64


In [12]:
# 使用港口的众数(S)来填充登录港口的 NaN 值
train_data['Embarked'].fillna('S', inplace=True)
test_data['Embarked'].fillna('S', inplace=True)

In [13]:
# STEP 3. 选择特征，特征向量化

# 选择特征
features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']

In [14]:
# 抽取特征
x_train = train_data[features]
x_test = test_data[features]
y_train = train_data['Survived']

In [15]:
# 特征向量化：将分类数据转为数值数据(原数值数据不变)

# 初始化 DictVectorizer 特征抽取器(不产生稀疏矩阵)
dict_vec = DictVectorizer(sparse=False)

In [16]:
# 转换(先转为字典，再进行转换)
x_train = dict_vec.fit_transform(x_train.to_dict(orient='record'))
x_test = dict_vec.transform(x_test.to_dict(orient='record'))

In [17]:
# 输出各个维度的特征含义
print(dict_vec.feature_names_)

['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']


In [18]:
# STEP 4. 建立决策树模型并训练

# 构造 ID3 决策树(entropy:ID3; gini:CART)
clf = DecisionTreeClassifier(criterion='entropy')

In [19]:
# 训练决策树
clf.fit(x_train, y_train)

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')

In [20]:
# STEP 5. 进行预测，写入文件

# 决策树预测
y_pred = clf.predict(x_test)

In [21]:
# 写入 gender_submission.csv
output = pd.DataFrame({'PassengerId': test_data['PassengerId'], 'Survived': y_pred})
output.to_csv('gender_submission.csv', index=False)