# AI心电图

**项目组成员：** 高冲、杨建民、杨大圣、王安、李方、章鱼刘哥、醉霓裳

# 1 数据准备
首先了解训练集和测试集数据的基本情况。
## 1.1 数据导入

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import resample
# Show which files are available in the input directory.
print(os.listdir("../input"))

# The CSVs have no header row, so pandas labels the columns 0..N by position.
mit_test_data = pd.read_csv("../input/mitbih_test.csv", header=None)
mit_train_data = pd.read_csv("../input/mitbih_train.csv", header=None)

# DataFrame.info() writes its report to stdout itself and returns None;
# wrapping it in print() would emit a stray "None" after each summary.
print("MIT test dataset")
mit_test_data.info()
print("MIT train dataset")
mit_train_data.info()

## 1.2 了解训练集样本情况
从标签列（第 187 列）的取值可以确定这是一个 5 分类问题；此外，各类别的样本数量分布非常不平衡。
 

In [None]:
# Number of training beats per class label (column 187).
label_counts = mit_train_data[187].value_counts()
label_counts

了解训练样本数据的基本情况

In [None]:
# Per-column summary statistics of the training set.
train_summary = mit_train_data.describe()
train_summary

## 1.3 探索性数据分析（EDA）

In [None]:
# Raw training matrix -> beat samples (every column but the last) and
# integer class labels (the last column).
M = mit_train_data.values
X, y = M[:, :-1], M[:, -1].astype(int)

In [None]:
# Row indices of the training beats belonging to each of the five classes 0-4.
C0, C1, C2, C3, C4 = (np.flatnonzero(y == label) for label in range(5))

In [None]:
# Sample index -> time, using 8 ms between samples; dividing by 1000 puts x in seconds.
x = np.arange(0, 187) * 8 / 1000

plt.figure(figsize=(14, 10))
# Plot the first training beat of each class. Index the row directly instead
# of X[C, :][0], which copies the whole class subset just to take one row.
for class_indices, class_name in zip((C0, C1, C2, C3, C4), ("N", "S", "V", "F", "Q")):
    plt.plot(x, X[class_indices[0], :], label="Cat. " + class_name)
plt.legend()
plt.title("1-beat ECG for every category", fontsize=30)
plt.ylabel("Amplitude", fontsize=15)
# x is in seconds (samples * 8 / 1000), so the axis unit is s, not ms.
plt.xlabel("Time (s)", fontsize=15)
plt.show()

In [None]:
# Bar chart of how many training beats fall into each class (label column 187);
# the trailing semicolon suppresses the notebook's echo of the Axes object.
sns.countplot(data=mit_train_data, x=187);

> 

## 1.4 解决数据不平衡问题
通过数据增强扩充少数类样本（TODO：目前代码只对类别 3，即 F 类做了增强），使用两种随机变换：
- 伸展（stretch）：随机拉伸或压缩心拍的时间轴
- 放大（amplify）：随机改变心拍的幅值


In [None]:
def stretch(x, length=187):
    """Randomly stretch or compress a beat along the time axis.

    The beat is resampled to a random length within roughly +/-1/6 of
    ``length``. It is then zero-padded (if shorter) or truncated (if longer)
    back to exactly ``length`` samples.

    Args:
        x: 1-D array holding one beat.
        length: Output length; defaults to the 187 samples per beat used here.

    Returns:
        1-D float array of ``length`` samples.
    """
    new_len = int(length * (1 + (random.random() - 0.5) / 3))
    resampled = resample(x, new_len)
    if new_len < length:
        out = np.zeros(shape=(length,))
        out[:new_len] = resampled
        return out
    return resampled[:length]


def amplify(x):
    """Randomly rescale a beat's amplitude in a value-dependent way.

    Draws alpha uniformly from [-0.5, 0.5). Each sample is multiplied by
    ``1 + alpha * (1 - x)``, so the scaling depends on the sample's value.
    """
    alpha = random.random() - 0.5
    factor = -alpha * x + (1 + alpha)
    return x * factor


def augment(x, n_copies=3):
    """Return ``n_copies`` randomly augmented versions of one beat.

    Each copy is, with probability about 1/3 each, one of:
    - a stretched beat,
    - an amplified beat,
    - a beat that is stretched and then amplified.

    Args:
        x: 1-D array holding one 187-sample beat.
        n_copies: Number of augmented copies to produce.

    Returns:
        Array of shape (n_copies, 187). Every row is a real augmented beat;
        the buffer is sized exactly, so no all-zero rows are emitted.
    """
    result = np.zeros(shape=(n_copies, 187))
    for i in range(n_copies):
        # Draw ONE number per copy. Re-drawing inside the elif would skew the
        # split to roughly 33% / 44% / 23%.
        r = random.random()
        if r < 0.33:
            new_y = stretch(x)
        elif r < 0.66:
            new_y = amplify(x)
        else:
            new_y = amplify(stretch(x))
        result[i, :] = new_y
    return result

# Oversample class 3 ("Cat. F"): each class-3 beat contributes its augmented
# copies, which are appended to the training arrays with label 3.
result = np.apply_along_axis(augment, axis=1, arr=X[C3]).reshape(-1, 187)
classe = np.full(result.shape[0], 3, dtype=int)
X, y = np.vstack([X, result]), np.hstack([y, classe])

## 2 基础模型设计 

In [None]:
from keras import backend as K
 
def f1(y_true, y_pred):
    """Batch-wise F1 score as a Keras metric.

    Both tensors are clipped to [0, 1] and rounded, so a soft prediction
    counts as positive once it rounds to 1. Counts are taken over every
    element of the batch:
    - precision = hits / predicted positives
    - recall    = hits / actual positives
    - F1 is their harmonic mean.
    K.epsilon() guards every division against zero denominators.
    """
    def _count_positive(t):
        # Number of entries that round to 1 after clipping into [0, 1].
        return K.sum(K.round(K.clip(t, 0, 1)))

    hits = _count_positive(y_true * y_pred)
    prec = hits / (_count_positive(y_pred) + K.epsilon())
    rec = hits / (_count_positive(y_true) + K.epsilon())

    return 2 * ((prec * rec) / (prec + rec + K.epsilon()))

In [None]:
from keras.models import Sequential
from keras.layers import Dense, Activation,BatchNormalization,Dropout

# Baseline MLP: 187 input samples -> 50 -> 30 hidden units (ReLU, 20% dropout)
# -> 5-way softmax, one output per class.
model = Sequential()

# `init=` is the Keras 1 keyword; Keras 2 names it `kernel_initializer=`.
model.add(Dense(50, input_dim=187, kernel_initializer='normal', activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(30, kernel_initializer='normal', activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(5, activation='softmax'))

model.summary()



## 2.1 模型训练

In [None]:
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['acc', f1])

N_CLASSES = 5

# categorical_crossentropy and the f1 metric both operate on one-hot targets,
# but y holds integer class ids 0-4; one-hot encode them.
y_onehot = np.eye(N_CLASSES)[y]

# Keras takes the validation_split slice from the END of the arrays before it
# shuffles. The augmented class-3 beats were appended to the end of X/y, so
# without a permutation the validation set would be mostly synthetic class-3
# beats. Permute once (fixed seed for reproducibility) before fitting.
# NOTE(review): an augmented copy can still land on the other side of the
# split from its source beat; splitting before augmenting avoids that leakage.
perm = np.random.RandomState(42).permutation(len(X))
X_fit, y_fit = X[perm], y_onehot[perm]

# class_weight must be a {class_index: weight} dict; 'auto' is not a
# documented value. Use inverse-frequency ("balanced") weights instead.
class_counts = np.bincount(y, minlength=N_CLASSES)
class_weight = {
    k: float(len(y) / (N_CLASSES * count))
    for k, count in enumerate(class_counts)
    if count > 0
}

history = model.fit(X_fit, y_fit, validation_split=0.2, epochs=100,
                    shuffle=True, class_weight=class_weight)

## 2.2 确定模型 Baseline

**Baseline: 97.3**（TODO：注明该数值对应的指标和数据集，例如验证集准确率 %）