In [1]:
import tensorflow as tf
import numpy as np
import pymysql
from datetime import date, timedelta


DB_IP = '192.168.1.210'
DB_USER = 'root'
DB_PWD = '1234'
DB_SCH = 'data'
DB_ENC = 'utf8mb4'
LIMIT_FILTER = 0.70

INPUT_VEC_SIZE = LSTM_SIZE = 7
TIME_STEP_SIZE = 60
LABEL_SIZE = 3
EVALUATE_SIZE = 3
LSTM_DEPTH = 4

BATCH_SIZE = 15000
TRAIN_CNT = 100

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

In [2]:
conn = pymysql.connect(host=DB_IP, user=DB_USER, password=DB_PWD, db=DB_SCH, charset=DB_ENC)

def get_codedates(code, limit):    
    query = "SELECT date FROM data.daily_stock WHERE code = %s AND date <= %s ORDER BY date ASC"
    cursor = conn.cursor()
    cursor.execute(query, (code, limit))
    code_dates = list()        
    dates = cursor.fetchall()
    conn.close()
    for date in dates:
        code_dates.append((code, date[0]))
    return code_dates

def get_items(code, date, limit):
    conn = pymysql.connect(host=DB_IP, user=DB_USER, password=DB_PWD, db=DB_SCH, charset=DB_ENC)
    query = "SELECT open, high, low, close, volume, hold_foreign, st_purchase_inst FROM data.daily_stock WHERE code = %s AND date >= %s ORDER BY date ASC LIMIT %s"
    cursor = conn.cursor()
    cursor.execute(query, (code, date, limit))
    items = cursor.fetchall()
    conn.close()
    return items
    
def get_codes():
    query = "SELECT DISTINCT code FROM data.daily_stock"
    try :
        cursor = conn.cursor()
        cursor.execute(query)
    except :
        set_connect()
        cursor = conn.cursor()
        cursor.execute(query)
    return cursor.fetchall()
    
    
def set_connect():
    print('reset connect')
    conn = pymysql.connect(host=DB_IP, user=DB_USER, password=DB_PWD, db=DB_SCH, charset=DB_ENC)
    
    

In [3]:
def model(code, X, W, B, lstm_size):
    XT = tf.transpose(X, [1, 0, 2]) 
    XR = tf.reshape(XT, [-1, lstm_size])
    X_split = tf.split(0, TIME_STEP_SIZE, XR)
    with tf.variable_scope(code, reuse=False):
        cell = tf.nn.rnn_cell.GRUCell(lstm_size)
        cell = tf.nn.rnn_cell.DropoutWrapper(cell = cell, output_keep_prob = 0.5)
        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * LSTM_DEPTH, state_is_tuple = True)

    outputs, _states = tf.nn.rnn(cell, X_split, dtype=tf.float32)

    return tf.matmul(outputs[-1], W) + B, cell.state_size # State size to initialize the stat


In [4]:
def read_series_datas(code_dates):
    EXPECT = 3
    X = list()
    Y = list()
    for code_date in code_dates:
        items = get_items(code_date[0], code_date[1], TIME_STEP_SIZE + EVALUATE_SIZE)
  
        if len(items) < (EVALUATE_SIZE + TIME_STEP_SIZE):
            break
        X.append(np.array(items[:TIME_STEP_SIZE]))

        st_purchase_inst = items[-(EVALUATE_SIZE + 1)][EXPECT]
        if st_purchase_inst == 0:
            continue
        for i in range(EVALUATE_SIZE, len(items) - EVALUATE_SIZE):
            eval_inst = items[i][EXPECT]
            eval_bef = items[EVALUATE_SIZE-i][EXPECT]
            if eval_bef < eval_inst:
                eval_bef = eval_inst           
        
        if (eval_bef - st_purchase_inst) / st_purchase_inst < -0.02: #percent ? cnt ? 
            Y.append((0., 0., 1.))
        elif (eval_bef - st_purchase_inst) / st_purchase_inst > 0.03:
            Y.append((1., 0., 0.))
        else:
            Y.append((0., 1., 0.))


    arrX = np.array(X)    
    meanX = np.mean(arrX, axis = 0)
    stdX = np.std(arrX, axis = 0)
    norX = (arrX - meanX) / stdX
    norY = np.array(Y)
    return norX, norY


In [5]:
def read_datas(code_dates):    
    np.random.seed()
    np.random.shuffle(code_dates)

    trX = list()
    trY = list()
    trX, trY = read_series_datas(code_dates)
    teX, teY = read_series_datas(code_dates)

    return trX, trY, teX, teY

In [6]:
def analyze(code, limit):  
    code_dates = get_codedates(code, limit)
    tf.reset_default_graph()    
    last = code_dates[-1][1]
    trX, trY, teX, teY = read_datas(code_dates)
    if (len(trX) == 0):
        return None

    X = tf.placeholder(tf.float32, [None, TIME_STEP_SIZE, INPUT_VEC_SIZE])
    Y = tf.placeholder(tf.float32, [None, LABEL_SIZE])

    W = init_weights([LSTM_SIZE, LABEL_SIZE])
    B = init_weights([LABEL_SIZE])

    py_x, state_size = model(code, X, W, B, LSTM_SIZE)

    loss = tf.nn.softmax_cross_entropy_with_logits(py_x, Y)
    cost = tf.reduce_mean(loss)
    train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
    predict_op = tf.argmax(py_x, 1)

    # Launch the graph in a session
    analyzed = None
    with tf.Session() as sess:
        # you need to initialize all variables
        tf.global_variables_initializer().run()

        for loop in range(TRAIN_CNT):
            for start, end in zip(range(0, len(trX), BATCH_SIZE), range(BATCH_SIZE, len(trX)+1, BATCH_SIZE)):
                sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

            test_indices = np.arange(len(teX))
            org = teY[test_indices]
            res = sess.run(predict_op, feed_dict={X: teX[test_indices], Y: teY[test_indices]})
            
            if loop == TRAIN_CNT-1 :
                result = np.mean(np.argmax(org, axis=1) == res)                
                analyzed = {"stock":code, "per":round(result, 2), "date":limit}
                print(analyzed)
    return analyzed
                    
# 이전데이터로 확인
# 1,2,3 
# 결과후 저장

In [7]:
limit = '2017-02-14'
codes = get_codes() #
filtered = list()
for code in codes : 
    analyzed = analyze(code[0], limit)
    if analyzed is not None and analyzed["per"] > LIMIT_FILTER:
        filtered.append(analyzed)
print(filtered)


{'per': 0.88, 'date': '2017-02-14', 'stock': 'A000030'}
{'per': 0.28999999999999998, 'date': '2017-02-14', 'stock': 'A000050'}
{'per': 0.20000000000000001, 'date': '2017-02-14', 'stock': 'A000070'}
{'per': 0.33000000000000002, 'date': '2017-02-14', 'stock': 'A000080'}
{'per': 0.34999999999999998, 'date': '2017-02-14', 'stock': 'A000100'}




{'per': 1.0, 'date': '2017-02-14', 'stock': 'A000120'}
{'per': 0.47999999999999998, 'date': '2017-02-14', 'stock': 'A000140'}
{'per': 0.45000000000000001, 'date': '2017-02-14', 'stock': 'A000150'}
{'per': 0.46999999999999997, 'date': '2017-02-14', 'stock': 'A000210'}
{'per': 1.0, 'date': '2017-02-14', 'stock': 'A000230'}
{'per': 0.46999999999999997, 'date': '2017-02-14', 'stock': 'A000240'}
{'per': 0.17999999999999999, 'date': '2017-02-14', 'stock': 'A000270'}
{'per': 0.33000000000000002, 'date': '2017-02-14', 'stock': 'A000640'}
{'per': 1.0, 'date': '2017-02-14', 'stock': 'A000660'}
{'per': 0.12, 'date': '2017-02-14', 'stock': 'A000670'}
{'per': 0.23000000000000001, 'date': '2017-02-14', 'stock': 'A000720'}
{'per': 0.17999999999999999, 'date': '2017-02-14', 'stock': 'A000810'}
{'per': 0.14999999999999999, 'date': '2017-02-14', 'stock': 'A000880'}
{'per': 0.17000000000000001, 'date': '2017-02-14', 'stock': 'A000990'}
{'per': 1.0, 'date': '2017-02-14', 'stock': 'A001040'}
{'per': 0.88, 

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  keepdims=keepdims)
  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)
  ret = ret.dtype.type(ret / rcount)


{'per': 0.46000000000000002, 'date': '2017-02-14', 'stock': 'A008060'}
{'per': 0.0, 'date': '2017-02-14', 'stock': 'A008490'}
{'per': 0.67000000000000004, 'date': '2017-02-14', 'stock': 'A008560'}
{'per': 0.40999999999999998, 'date': '2017-02-14', 'stock': 'A008770'}
{'per': 0.050000000000000003, 'date': '2017-02-14', 'stock': 'A008930'}
{'per': 0.14000000000000001, 'date': '2017-02-14', 'stock': 'A009150'}
{'per': 0.13, 'date': '2017-02-14', 'stock': 'A009240'}
{'per': 0.11, 'date': '2017-02-14', 'stock': 'A009290'}
{'per': 0.56000000000000005, 'date': '2017-02-14', 'stock': 'A009420'}
{'per': 0.46000000000000002, 'date': '2017-02-14', 'stock': 'A009540'}
{'per': 0.40999999999999998, 'date': '2017-02-14', 'stock': 'A009830'}
{'per': 0.25, 'date': '2017-02-14', 'stock': 'A010060'}
{'per': 0.23999999999999999, 'date': '2017-02-14', 'stock': 'A010120'}
{'per': 0.01, 'date': '2017-02-14', 'stock': 'A010130'}
{'per': 0.55000000000000004, 'date': '2017-02-14', 'stock': 'A010140'}
{'per': 0.

OperationalError: (2003, "Can't connect to MySQL server on '192.168.1.210' ([Errno 110] Connection timed out)")