In [50]:
from data_tools.api import *
from utilscht.Data import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from volat_calcu import *
import pymysql
%config InlineBackend.figure_format ='retina'
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"


import os
from getpass import getpass

# Connection settings for the `wind` MySQL database.
# Every field can be overridden through an environment variable; host, user and
# database name fall back to the values this notebook has always used.
# The password is deliberately NOT stored in the notebook: it is read from
# WIND_DB_PASSWORD, and when that variable is unset or empty the user is prompted
# for it (the prompt needs an interactive session, so set the variable for
# unattended runs).
DB_INFO = dict(host=os.environ.get('WIND_DB_HOST', '192.168.1.234'),
               user=os.environ.get('WIND_DB_USER', 'winduser'),
               password=os.environ.get('WIND_DB_PASSWORD') or getpass('Password for the wind database: '),
               db=os.environ.get('WIND_DB_NAME', 'wind'))

# DictCursor makes every fetched row a dict keyed by column name.
conn = pymysql.connect(**DB_INFO, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor)

In [2]:
from data_tools.api import trade_days

# Private copy of the imported trading calendar, used by the lookup helpers below.
# NOTE(review): the helpers call np.searchsorted on it, so it is assumed to be
# sorted ascending and to hold dates that stringify as 'YYYYMMDD' -- confirm
# against data_tools.api.trade_days.
trade_dates_all = trade_days.copy()
def get_prev_n_trade_date(trade_date, n):
    """Return, as a string, the calendar entry ``n`` positions before ``trade_date``.

    ``trade_date`` may be anything whose ``str()`` starts with a date, with or
    without dashes ('2019-03-05', '20190305', a Timestamp, ...); only the first
    10 characters are used and dashes are removed before the lookup.

    The position of ``trade_date`` is its left insertion point in
    ``trade_dates_all``, so a date missing from the calendar is treated like the
    first calendar date after it.

    Raises:
        AssertionError: fewer than ``n`` calendar entries precede the position.
    """
    key = str(trade_date)[:10].replace('-', '')
    insert_at = np.searchsorted(trade_dates_all, key)
    assert insert_at >= n
    earlier = trade_dates_all[insert_at - n]
    return str(earlier)

def get_next_n_trade_date(trade_date, n=1):
    """Return, as a string, the ``n``-th calendar entry after ``trade_date``.

    ``trade_date`` is normalised the same way as in ``get_prev_n_trade_date``
    (first 10 characters of ``str()``, dashes removed). The lookup starts from
    the right insertion point in ``trade_dates_all``, i.e. the first entry
    strictly greater than ``trade_date``.

    When the requested entry lies beyond the end of the calendar, the last
    calendar entry is returned instead of raising.
    """
    key = str(trade_date)[:10].replace('-', '')
    target = np.searchsorted(trade_dates_all, key, side='right') + n - 1
    if target >= len(trade_dates_all):
        # Ran off the end of the calendar: clamp to the final known trade date.
        return str(trade_dates_all[-1])
    return str(trade_dates_all[target])

In [3]:
def get_volat(stock_code,start_date,end_date):
    """Compute ``get_cross_count`` for one stock over ``[start_date, end_date]``.

    Loads the 1-minute closes of the requested window and of the trade date just
    before ``start_date``, rescales both with the daily ``adj_factor`` so that
    they are expressed relative to the second row of the daily factor table,
    and passes the two close arrays (previous day first) to ``get_cross_count``.

    Returns 0 when either minute-bar query comes back empty; otherwise whatever
    ``get_cross_count`` returns.

    NOTE(review): the second row of the daily factor table is presumably
    ``start_date`` itself (the table starts on the previous trade date) --
    confirm for stocks that did not trade on every day of the window.
    """
    minute_bars = get_stk_bar(stock_code, start_date, end_date, freq='1m', fields=["close"])
    prev_day = get_prev_n_trade_date(start_date, 1)
    prev_minute_bars = get_stk_bar(stock_code, prev_day, prev_day, freq='1m', fields=["close"])

    if not len(minute_bars) or not len(prev_minute_bars):
        return 0

    # Daily adjustment factors from the previous trade date through end_date,
    # with the index turned into a "date" column so it can be merged on.
    daily_factors = get_stk_bar(stock_code, prev_day, end_date, fields=["adj_factor"])
    daily_factors = daily_factors.reset_index().rename(columns={"index": "date"})

    # Tag every minute bar with its calendar day (first 10 chars of the timestamp).
    minute_bars["date"] = minute_bars.index
    minute_bars["date"] = minute_bars["date"].apply(lambda stamp: pd.to_datetime(str(stamp)[0:10]))
    minute_bars = pd.merge(minute_bars, daily_factors, on="date", how="left")

    factor_by_row = daily_factors.adj_factor.values
    minute_bars["close"] = minute_bars.close * minute_bars.adj_factor / factor_by_row[1]
    prev_minute_bars["close"] = prev_minute_bars.close * factor_by_row[0] / factor_by_row[1]

    return get_cross_count(prev_minute_bars["close"].values, minute_bars["close"].values)
    
    

In [27]:
def _daily_volatility_series(stock_code, day_keys, cache):
    """Return a float Series of single-day ``get_volat`` values, indexed by date.

    ``cache`` maps a day key to its already computed value, so each day is
    evaluated at most once per stock no matter how many figures need it.
    """
    for day in day_keys:
        if day not in cache:
            cache[day] = get_volat(stock_code, day, day)
    return pd.Series([cache[day] for day in day_keys],
                     index=pd.to_datetime(day_keys), dtype=float)


def _scale_volatility_axis(axis, volatility):
    """Give the volatility axis head-room (2.5x the max) when the max is usable."""
    top = volatility.max() * 2.5
    # An empty / all-NaN / all-zero series has no meaningful upper limit.
    if np.isfinite(top) and top > 0:
        axis.set_ylim(0, top)


def _save_price_volatility_figure(pdf, prices, vola_daily, vola_ma5, title, markers=None):
    """Draw close price (left axis) against volatility (right axis) and add it to ``pdf``.

    ``markers`` is an optional ``(dates, prices)`` pair drawn as red dots on the
    price axis under a single legend entry.
    """
    fig = plt.figure(figsize=(10, 7.5))
    try:
        price_ax = fig.add_subplot(111)
        price_ax.tick_params(axis='x', labelrotation=90)
        price_ax.plot(prices, color='blue', label='close_price')
        if markers is not None and len(markers[0]) > 0:
            price_ax.scatter(markers[0], markers[1], color='red', label="date_now")
        price_ax.legend()
        price_ax.set_ylabel("close price")

        vola_ax = price_ax.twinx()
        vola_ax.plot(vola_daily, color='green', label='volat_daily')
        vola_ax.plot(vola_ma5, color='red', label='volat_ma5')
        vola_ax.legend()
        vola_ax.set_ylabel("volatility")
        _scale_volatility_axis(vola_ax, vola_daily)
        vola_ax.set_title(title)
        vola_ax.grid()

        pdf.savefig(fig)
        plt.show()
    finally:
        # Figures created in a loop are kept alive by pyplot until closed.
        plt.close(fig)


def _save_table_figure(pdf, table_df, title, figsize):
    """Render ``table_df`` (rounded to 4 decimals) as a table page in ``pdf``."""
    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(111)
        ax.axis('off')
        ax.table(cellText=table_df.round(4).values, colLabels=table_df.columns, bbox=[0, 0, 1, 1])
        ax.set_title(title)
        pdf.savefig(fig)
        plt.show()
    finally:
        plt.close(fig)


def get_stock_volat_interval(stock_code,start_date,end_date):
    '''
    Scan one stock for "rally, partial pull-back, then consolidation" set-ups and
    write diagnostic figures and tables to
    stk_price_volat_plotting_v2/volatlity_interval_<stock_code>.pdf
    (the directory must already exist).

    Every trade date in the test window except the last 10 is examined. A date
    qualifies when all of the following hold:
      * its adjusted close is at least 30;
      * the highest close reached since the 40-trade-day low (pre_high) is more
        than 8% above that low (pre_low);
      * the close has given back strictly between 20% and 35% of the
        pre_low -> pre_high move;
      * every close from 3 to 7 trade dates later stays inside
        [pre_high - 0.35 * (pre_high - pre_low), pre_high];
      * its (pre_high, pre_low) price pair differs from the previous qualifying date's.

    params:

    stock_code: stock code
    start_date: first date of the test window, str such as '20190101'
    end_date: last date of the test window, str such as '20191201'

    output:
    Plots, per qualifying date: (1) close price and daily volatility from
    20190101 onwards, (2) the price around the date annotated with
    pre_low (low of the previous 40 trade dates), pre_high (high reached after
    that low), now (the date itself), vola_begin (3 trade dates after now) and
    vola_end (the last trade date before the price leaves the band, at most 10
    trade dates after the 7-day check window and never past the fetched price
    history), (3) a table with one row per candidate interval end and the columns
    date, interval_begin, interval_end, vola_interval_mean (get_volat from
    interval_begin to interval_end divided by the number of trade dates),
    vola_pre (same measure from max(pre_low date, pre_high date - 15 trade
    dates) to the pre_high date), date_pre_high, date_pre_low.
    After the loop a summary figure and a summary table (last row of every
    qualifying date) are appended.

    returns:
    None when no date qualifies, otherwise the tuple (stock_code, "success").
    '''
    df_close_price=get_stk_bar(stock_code,start=start_date,end=end_date,fields=["adj_close","adj_factor"])
    # Rescale the adjusted close by the most recent adjustment factor.
    adj_close_price=df_close_price.adj_close/df_close_price.adj_factor.values[-1]
    # Last date with a price, as 'YYYYMMDD'; used to keep look-ahead inside the data.
    last_priced_date=str(adj_close_price.index[-1])[0:10].replace('-','')

    trade_dates=[str(i) for i in get_trade_dates(start_date,end_date)]

    last_pre_high=0
    last_pre_low=0
    signal_frames=[]        # one table per qualifying date, concatenated once at the end
    daily_vola_cache={}     # day key -> single-day volatility, shared by all figures
    begin_plot_date='20190101'

    test_volatility=PdfPages('stk_price_volat_plotting_v2//volatlity_interval_{}.pdf'.format(stock_code))
    try:
        for date in trade_dates[:-10]:

            # Price filter.
            price_now=adj_close_price.loc[date]
            if price_now<30:
                continue

            # Locate the low of the previous 40 trade dates and the high that followed it.
            # idxmin/idxmax return the index label (a date), which is what the
            # slicing and annotations below need; np.argmin/np.argmax would
            # return integer positions.
            date_start=get_prev_n_trade_date(date,40)
            lookback=adj_close_price.loc[date_start:date]
            price_pre_low=lookback.min()
            date_pre_low=lookback.idxmin()
            rebound=adj_close_price.loc[date_pre_low:date]
            price_pre_high=rebound.max()
            date_pre_high=rebound.idxmax()

            # Rally must exceed 8% and the pull-back must lie strictly within (20%, 35%).
            if price_pre_high/price_pre_low-1<=0.08:
                continue
            pullback=(price_pre_high-price_now)/(price_pre_high-price_pre_low)
            if pullback<=0.2 or pullback>=0.35:
                continue

            # The close must stay inside the band from 3 to 7 trade dates later.
            date_judge_begin=get_next_n_trade_date(date,3)
            date_judge_end=get_next_n_trade_date(date,7)
            upper_bound=price_pre_high
            lower_bound=price_pre_high-(price_pre_high-price_pre_low)*0.35
            left_band=False
            for dt in get_trade_dates(date_judge_begin,date_judge_end):
                price_date=adj_close_price.loc[str(dt)]
                if price_date>upper_bound or price_date<lower_bound:
                    left_band=True
                    break
            if left_band:
                continue

            # Skip a date whose previous high and low repeat the last accepted pair.
            if price_pre_high==last_pre_high and price_pre_low==last_pre_low:
                continue
            last_pre_high=price_pre_high
            last_pre_low=price_pre_low

            # Volatility of the run-up; computed only now because it is expensive
            # and irrelevant for dates rejected above.
            date_pre_low_key=str(date_pre_low)[0:10].replace('-','')
            date_pre_high_key=str(date_pre_high)[0:10].replace('-','')
            date_pre_begin=max(date_pre_low_key,get_prev_n_trade_date(date_pre_high,15))
            vola_pre=get_volat(stock_code,date_pre_begin,date_pre_high_key) \
                        /len(get_trade_dates(date_pre_begin,date_pre_high_key))

            # End of the consolidation: at most 10 trade dates after the check
            # window, never beyond the fetched prices (both operands are
            # 'YYYYMMDD' strings, so min() compares chronologically), and cut
            # short on the day before the band is broken.
            date_vola_end=min(get_next_n_trade_date(date_judge_end,10),last_priced_date)
            for dt in get_trade_dates(date_judge_end,date_vola_end)[1:]:
                price_date=adj_close_price.loc[str(dt)]
                if price_date>upper_bound or price_date<lower_bound:
                    date_vola_end=str(get_previous_trade_date(dt))
                    break

            # Average volatility from vola_begin to every candidate interval end.
            interval_end_dates=get_trade_dates(date_judge_end,date_vola_end)
            vola_post_ls=[get_volat(stock_code,date_judge_begin,dt)/len(get_trade_dates(date_judge_begin,dt))
                          for dt in interval_end_dates]

            # Figure 1: price and daily volatility from begin_plot_date to 20 trade dates past vola_end.
            plot_end=get_next_n_trade_date(date_vola_end,20)
            history_keys=[str(i)[0:10].replace('-','')
                          for i in adj_close_price.loc[begin_plot_date:plot_end].index]
            volatility_series=_daily_volatility_series(stock_code,history_keys,daily_vola_cache)
            volatility_ma5=volatility_series.rolling(5).mean().dropna()  # 5-day moving average
            _save_price_volatility_figure(test_volatility,
                                          adj_close_price.loc[pd.to_datetime(history_keys)],
                                          volatility_series,volatility_ma5,
                                          'close price & volatility for {}'.format(stock_code))

            # Figure 2: annotated price around the signal.
            plotting_price=adj_close_price.loc[get_prev_n_trade_date(date,60):plot_end]
            detail_index=adj_close_price.loc[date_pre_low:plot_end].index
            # reindex (not .loc) so days missing from a series become gaps instead of a KeyError.
            detail_vola=volatility_series.reindex(detail_index)
            detail_ma5=volatility_ma5.reindex(detail_index)
            # Timestamps for the string dates so they land correctly on the date axis.
            now_ts=pd.to_datetime(date)
            begin_ts=pd.to_datetime(date_judge_begin)
            end_ts=pd.to_datetime(date_vola_end)
            price_judge_begin=adj_close_price.loc[date_judge_begin]
            price_vola_end=adj_close_price.loc[date_vola_end]

            fig=plt.figure(figsize=(10,7.5))
            try:
                ax=fig.add_subplot(111)
                ax.tick_params(axis='x',labelrotation=90)
                ax.plot(plotting_price,color='blue',label='close_price')
                ax.legend()
                arrow_style=dict(facecolor='blue', arrowstyle='->',connectionstyle='arc3')
                ax.annotate("pre_low:\n"+str(date_pre_low)[0:10],xy=(date_pre_low,price_pre_low),
                             xytext=(date_pre_low,price_pre_low*0.96),arrowprops=arrow_style)
                ax.annotate("pre_high:\n"+str(date_pre_high)[0:10],xy=(date_pre_high,price_pre_high),
                             xytext=(date_pre_high-pd.Timedelta('30d'),price_pre_high*0.98),arrowprops=arrow_style)
                ax.annotate("now:\n"+str(date),xy=(now_ts,price_now),
                             xytext=(now_ts,price_now*0.92),arrowprops=arrow_style)
                ax.annotate("vola_begin:\n"+str(date_judge_begin),xy=(begin_ts,price_judge_begin),
                             xytext=(begin_ts,price_judge_begin*1.05),arrowprops=arrow_style)
                ax.annotate("vola_end:\n"+str(date_vola_end),xy=(end_ts,price_vola_end),
                             xytext=(end_ts,price_vola_end*0.91),arrowprops=arrow_style)
                ax.set_ylabel("close price")
                ax2=ax.twinx()
                ax2.plot(detail_vola,color='green',label='volat_daily')
                ax2.plot(detail_ma5,color='red',label='volat_ma5')
                ax2.set_ylabel("volatility")
                _scale_volatility_axis(ax2,detail_vola)
                ax2.legend()
                ax2.set_title("close price for {}".format(stock_code))
                ax2.grid()
                test_volatility.savefig(fig)
                plt.show()
            finally:
                plt.close(fig)

            # Table: one row per candidate interval end.
            df_vola=pd.DataFrame(vola_post_ls,columns=['vola_interval_mean'],index=interval_end_dates)
            df_vola["vola_pre"]=vola_pre
            df_vola['date']=date
            df_vola["interval_begin"]=date_judge_begin
            df_vola['date_pre_high']=str(date_pre_high)[0:10]
            df_vola['date_pre_low']=str(date_pre_low)[0:10]
            df_vola=df_vola.reset_index().rename(columns={"index":"interval_end"})
            df_vola=df_vola.set_index(["date","interval_begin","interval_end"]).reset_index()
            signal_frames.append(df_vola)
            _save_table_figure(test_volatility,df_vola,"interval_volat for {}".format(stock_code),(8,5))

        # Nothing qualified: the (empty) PDF is closed by the finally clause.
        if not signal_frames:
            return None

        # Summary table: keep the last (longest) interval of every qualifying date.
        df_vola_summary=(pd.concat(signal_frames,ignore_index=True)
                           .groupby("date",sort=False).tail(1)
                           .reset_index(drop=True))

        # Summary figure over the whole test window, with the qualifying dates marked.
        summary_keys=[str(i)[0:10].replace('-','')
                      for i in adj_close_price.loc[start_date:end_date].index]
        volatility_series=_daily_volatility_series(stock_code,summary_keys,daily_vola_cache)
        volatility_ma5=volatility_series.rolling(5).mean().dropna()  # 5-day moving average
        signal_keys=df_vola_summary["date"].tolist()
        signal_prices=[adj_close_price.loc[key] for key in signal_keys]
        _save_price_volatility_figure(test_volatility,
                                      adj_close_price.loc[pd.to_datetime(summary_keys)],
                                      volatility_series,volatility_ma5,
                                      'close price volatility summary for {}'.format(stock_code),
                                      markers=(pd.to_datetime(signal_keys),signal_prices))

        _save_table_figure(test_volatility,df_vola_summary,
                           "interval_volat_summary for {}".format(stock_code),(8,4))

        return (stock_code,"success")
    finally:
        # Always finalise the PDF, including when a data or plotting call raises.
        test_volatility.close()

In [None]:
# Snapshot of AShareEODDerivativeIndicator for trade date 20190531.
# The whole statement is upper-cased before it is sent, table name included.
sql = """select S_INFO_WINDCODE,S_DQ_MV from AShareEODDerivativeIndicator where TRADE_DT={}"""
sql = sql.format("20190531").upper()
mktv_df = pd.read_sql_query(sql, conn)

# Initial universe: codes whose S_DQ_MV exceeds 3,000,000 (in the table's own units).
stock_pool_init = mktv_df.loc[mktv_df['S_DQ_MV'] > 3000000, "S_INFO_WINDCODE"].to_list()

# Codes removed from the universe (bank stocks, going by the variable name).
code_ls_bank = [
    '000001.SZ', '002142.SZ', '002807.SZ', '002839.SZ',
    '002936.SZ', '002948.SZ', '002958.SZ', '002966.SZ',
    '600000.SH', '600015.SH', '600016.SH', '600036.SH',
    '600908.SH', '600919.SH', '600926.SH', '600928.SH',
    '601009.SH', '601077.SH', '601128.SH', '601166.SH',
    '601169.SH', '601229.SH', '601288.SH', '601328.SH',
    '601398.SH', '601577.SH', '601658.SH', '601818.SH',
    '601838.SH', '601860.SH', '601916.SH', '601939.SH',
    '601988.SH', '601997.SH', '601998.SH', '603323.SH',
]

# From here on stock_pool_init is a set, so its iteration order is arbitrary.
stock_pool_init = set(stock_pool_init).difference(code_ls_bank)

In [28]:
# Run the scan for every stock in the pool on 8 loky worker processes.
# Each entry of `results` is what get_stock_volat_interval returned for one stock.
results = Parallel(n_jobs=8, verbose=5, backend='loky', batch_size='auto')(
    delayed(get_stock_volat_interval)(code, '20190101', '20191210')
    for code in stock_pool_init
)

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:    1.8s
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed:  6.3min
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed: 18.6min
[Parallel(n_jobs=8)]: Done 186 out of 186 | elapsed: 29.5min finished
