In [4]:
import pymongo
import pandas as pd
import pickle
import datetime
import time
import gzip
import lzma
import pytz


import pymongo
import io
import pandas as pd
import pickle
import datetime
import time
import gzip
import lzma
import pytz
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np


def DB(host, db_name, user, passwd, version=3):
    """Build a :class:`DBObj` connected to ``db_name`` on ``host``.

    Args:
        host: MongoDB host (optionally ``host:port``), inserted verbatim
            into the connection URI.
        db_name: Database to open; also used as the authentication
            database unless ``user`` is ``'admin'`` or ``'root'``.
        user: Raw (unescaped) user name.
        passwd: Raw (unescaped) password.
        version: Serialization version handed to :class:`DBObj`.

    Returns:
        A ``DBObj`` bound to the resulting URI.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote_plus

    auth_db = 'admin' if user in ('admin', 'root') else db_name
    # Credentials must be percent-escaped inside a MongoDB URI, otherwise
    # characters such as '@', ':', '/' or '%' make the URI invalid.
    uri = 'mongodb://%s:%s@%s/?authSource=%s' % (
        quote_plus(user), quote_plus(passwd), host, auth_db)
    return DBObj(uri, db_name=db_name, version=version)


class DBObj(object):
    """Store and load pandas DataFrames in MongoDB as serialized segments.

    Each stored document has the keys ``ver`` (serialization version),
    ``data`` (serialized bytes), ``date``, ``symbol`` and ``start`` (row
    offset of the segment inside its (date, symbol) group).
    """

    def __init__(self, uri, symbol_column='skey', db_name='white_db', version=3):
        """Open a client for ``uri`` and select database ``db_name``.

        Args:
            uri: MongoDB connection URI.
            symbol_column: DataFrame column used to group rows by symbol.
            db_name: Name of the database to use.
            version: Serialization version used by :meth:`write`
                (1 = gzip+pickle, 2 = lzma+pickle, 3 = parquet/ZSTD).
        """
        self.db_name = db_name
        self.uri = uri
        self.client = pymongo.MongoClient(self.uri)
        self.db = self.client[self.db_name]
        # Maximum number of rows serialized into one document.
        self.chunk_size = 20000
        self.symbol_column = symbol_column
        self.date_column = 'date'
        self.version = version

    def parse_uri(self, uri):
        """Split a ``mongodb://user:password@example.com`` URI into tokens.

        The scheme prefix and surrounding slashes are removed, then the
        remainder is split wherever a ``:`` or ``@`` occurred.
        """
        stripped = uri.strip().replace('mongodb://', '').strip('/')
        return stripped.replace(':', ' ').replace('@', ' ').split(' ')

    def drop_table(self, table_name):
        """Drop the collection ``table_name``."""
        self.db.drop_collection(table_name)

    def rename_table(self, old_table, new_table):
        """Rename collection ``old_table`` to ``new_table``."""
        self.db[old_table].rename(new_table)

    def write(self, table_name, df):
        """Replace the stored segments for every (date, symbol) in ``df``.

        Existing documents of each (date, symbol) pair present in ``df``
        are deleted before the new segments are inserted. An empty ``df``
        is a no-op.

        Raises:
            Exception: if ``df`` has no date column.
        """
        if len(df) == 0:
            return
        if self.date_column not in df.columns:
            raise Exception('DataFrame should contain date column')

        collection = self.db[table_name]
        collection.create_index([('date', pymongo.ASCENDING), ('symbol', pymongo.ASCENDING)], background=True)
        collection.create_index([('symbol', pymongo.ASCENDING), ('date', pymongo.ASCENDING)], background=True)

        for date, symbol, sub_df in self._iter_groups(df):
            collection.delete_many({'date': date, 'symbol': symbol})
            self.write_single(collection, date, symbol, sub_df)

    def _iter_groups(self, df):
        """Yield ``(date_str, symbol_int, sub_df)`` for each group of ``df``.

        Dates are always strings and symbols always plain ``int`` so that
        the stored keys match what :meth:`build_query` produces.
        """
        dates = df[self.date_column]
        if len(dates.unique()) > 1:
            grouped = df.groupby([self.date_column, self.symbol_column])
            for (date, symbol), sub_df in grouped:
                yield str(date), int(symbol), sub_df
        else:
            date = str(dates.iloc[0])
            # Group by the bare column name: grouping by a one-element
            # list yields 1-tuple keys on pandas >= 2.0.
            for symbol, sub_df in df.groupby(self.symbol_column):
                yield date, int(symbol), sub_df

    def write_single(self, collection, date, symbol, df):
        """Insert ``df`` into ``collection`` in chunks of ``chunk_size`` rows."""
        version = self.version
        total = len(df)
        for start in range(0, total, self.chunk_size):
            end = min(start + self.chunk_size, total)
            segment = df.iloc[start:end]
            collection.insert_one({
                'ver': version,
                'data': self.ser(segment, version),
                'date': date,
                'symbol': symbol,
                'start': start,
            })

    def build_query(self, start_date=None, end_date=None, symbol=None):
        """Build the MongoDB filter for a date range and symbol selection.

        Args:
            start_date, end_date: Inclusive bounds. Accepts an 8-character
                ``YYYYMMDD`` string, an integer (Python or numpy), or a
                ``datetime.date``/``datetime.datetime`` (subclasses
                included).
            symbol: A single symbol or a list/tuple/numpy array of
                symbols; every value is converted with ``int``. A falsy
                value (``None``, ``0``, empty sequence) adds no symbol
                filter.

        Returns:
            The filter dict; empty when no criterion was given.

        Raises:
            Exception: on a date string that is not 8 characters long or
                a date of an unsupported type.
        """
        def parse_date(x):
            if isinstance(x, str):
                if len(x) != 8:
                    raise Exception("`date` must be YYYYMMDD format")
                return x
            if isinstance(x, datetime.date):
                # datetime.datetime is a subclass of datetime.date.
                return x.strftime("%Y%m%d")
            if isinstance(x, (int, np.integer)):
                return parse_date(str(int(x)))
            raise Exception("invalid `date` type: " + str(type(x)))

        query = {}

        date_range = {}
        if start_date is not None:
            date_range['$gte'] = parse_date(start_date)
        if end_date is not None:
            date_range['$lte'] = parse_date(end_date)
        if start_date is not None or end_date is not None:
            query['date'] = date_range

        if isinstance(symbol, np.ndarray):
            # A numpy array has no unambiguous truth value; use a list.
            symbol = symbol.tolist()
        if symbol:
            if isinstance(symbol, (list, tuple)):
                query['symbol'] = {'$in': [int(x) for x in symbol]}
            else:
                query['symbol'] = int(symbol)

        return query

    def delete(self, table_name, start_date=None, end_date=None, symbol=None):
        """Delete matching documents; refuses to run with an empty filter."""
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot delete the whole table')
            return None

        collection.delete_many(query)

    def read(self, table_name, start_date=None, end_date=None, symbol=None):
        """Load matching segments and concatenate them into one DataFrame.

        Segments are ordered by (symbol, date, start) before concatenation.

        Returns:
            The concatenated DataFrame, or ``None`` when the filter is
            empty or nothing matched.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot read the whole table')
            return None

        segs = []
        for doc in collection.find(query):
            doc['data'] = self.deser(doc['data'], doc['ver'])
            segs.append(doc)
        if not segs:
            return None
        segs.sort(key=lambda seg: (seg['symbol'], seg['date'], seg['start']))
        return pd.concat([seg['data'] for seg in segs], ignore_index=True)

    def read_raw(self, table_name, start_date=None, end_date=None, symbol=None):
        """Return the cursor of matching documents without deserializing.

        Returns ``None`` when the filter is empty.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot read the whole table')
            return None

        return collection.find(query)

    def list_tables(self):
        """Return the names of the collections in the database."""
        # collection_names() was removed in PyMongo 4. Check the class,
        # not the instance: Database.__getattr__ returns a Collection for
        # any unknown attribute name.
        if hasattr(type(self.db), 'list_collection_names'):
            return self.db.list_collection_names()
        return self.db.collection_names()

    def list_dates(self, table_name, start_date=None, end_date=None, symbol=None):
        """Return the sorted distinct ``date`` values of matching documents."""
        collection = self.db[table_name]
        if start_date is None:
            start_date = '00000000'
        if end_date is None:
            end_date = '99999999'
        query = self.build_query(start_date, end_date, symbol)
        dates = {doc['date'] for doc in collection.find(query, {"date": 1, '_id': 0})}
        return sorted(dates)

    @staticmethod
    def _widen_32bit_ints(df):
        """Return ``df`` with int32/uint32 columns cast to 64-bit.

        The input frame is never modified; it is returned unchanged when
        no column needs widening.
        """
        casts = {}
        for col_name in df.columns:
            dtype = df[col_name].dtype
            if dtype == np.int32:
                casts[col_name] = np.int64
            elif dtype == np.uint32:
                casts[col_name] = np.uint64
        return df.astype(casts) if casts else df

    def ser(self, s, version):
        """Serialize DataFrame ``s`` to bytes using format ``version``.

        Raises:
            Exception: for an unknown ``version``.
        """
        pickle_protocol = 4
        if version == 1:
            return gzip.compress(pickle.dumps(s, protocol=pickle_protocol), compresslevel=2)
        elif version == 2:
            return lzma.compress(pickle.dumps(s, protocol=pickle_protocol), preset=1)
        elif version == 3:
            # Original author's note: 32-bit numbers need more space than
            # 64-bit ones in parquet, hence the widening. It is done on a
            # copy so the caller's frame keeps its dtypes.
            tbl = pa.Table.from_pandas(self._widen_32bit_ints(s))
            buf = io.BytesIO()
            pq.write_table(tbl, buf, use_dictionary=False, compression='ZSTD', compression_level=0)
            return buf.getvalue()
        else:
            raise Exception('unknown version')

    def deser(self, s, version):
        """Deserialize bytes ``s`` written by :meth:`ser` with ``version``.

        Raises:
            Exception: for an unknown ``version``.
        """
        # SECURITY NOTE(review): versions 1 and 2 unpickle the stored
        # bytes; pickle can execute arbitrary code, so only read
        # collections whose writers are trusted.
        if version == 1:
            return pickle.loads(gzip.decompress(s))
        elif version == 2:
            return pickle.loads(lzma.decompress(s))
        elif version == 3:
            return pq.read_table(io.BytesIO(s), use_threads=False).to_pandas()
        else:
            raise Exception('unknown version')


def patch_pandas_pickle():
    """Register a stub ``pandas.core.internals.managers`` on pandas < 0.24.

    On such versions, if no module of that name is loaded yet, a stub
    module exposing ``BlockManager`` is placed in ``sys.modules``.
    Presumably this lets pickles that reference the newer module path be
    loaded by an older pandas (NOTE(review): confirm against the data).
    On pandas >= 0.24, or when the version string cannot be parsed, the
    function does nothing.
    """
    import re
    import sys
    from types import ModuleType

    # Compare numerically: as strings, '0.9' < '0.24' would be False.
    parsed = re.match(r'(\d+)\.(\d+)', pd.__version__)
    if parsed is None:
        return
    if (int(parsed.group(1)), int(parsed.group(2))) >= (0, 24):
        return

    from pandas.core.internals import BlockManager
    pkg_name = 'pandas.core.internals.managers'
    if pkg_name not in sys.modules:
        m = ModuleType(pkg_name)
        m.BlockManager = BlockManager
        sys.modules[pkg_name] = m


patch_pandas_pickle()

def dailyDB(host, db_name, user, passwd):
    """Return the pymongo database ``db_name`` on ``host``.

    Args:
        host: MongoDB host (optionally ``host:port``), inserted verbatim
            into the connection URI.
        db_name: Database to open; also used as the authentication
            database unless ``user`` is ``'admin'`` or ``'root'``.
        user: Raw (unescaped) user name.
        passwd: Raw (unescaped) password.

    Returns:
        ``client[db_name]`` for a newly created ``MongoClient``.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote_plus

    auth_db = 'admin' if user in ('admin', 'root') else db_name
    # Credentials must be percent-escaped inside a MongoDB URI, otherwise
    # characters such as '@', ':', '/' or '%' make the URI invalid.
    url = 'mongodb://%s:%s@%s/?authSource=%s' % (
        quote_plus(user), quote_plus(passwd), host, auth_db)
    client = pymongo.MongoClient(url, maxPoolSize=None)
    return client[db_name]

def read_stock_daily(db, name, start_date=None, end_date=None, skey=None, index_name=None, interval=None, col=None, return_sdi=True):
    """Load rows of collection ``name`` as a DataFrame sorted by (date, skey).

    Args:
        db: Mapping-like database handle; ``db[name]`` must provide
            ``find(query, projection)``.
        name: Collection name.
        start_date, end_date: Optional inclusive bounds on ``date``,
            passed to the query unchanged.
        skey: Optional list of ``skey`` values (``$in`` filter).
        index_name: Optional list of ``index_name`` values (``$in``
            filter).
        interval: Accepted but not used by this function.
        col: Optional list of fields to return; ``None`` returns every
            field except ``_id``.
        return_sdi: When ``col`` is given, also return ``skey`` and
            ``date``.

    Returns:
        The sorted DataFrame, or an empty DataFrame when nothing matched.
    """
    sort_keys = ['date', 'skey']
    collection = db[name]

    # Build projection
    prj = {'_id': 0}
    helper_cols = []
    if col is not None:
        wanted = (['skey', 'date'] + list(col)) if return_sdi else list(col)
        for col_name in wanted:
            prj[col_name] = 1
        if wanted:
            # The sort below needs these fields even when the caller did
            # not ask for them; fetch them and drop them after sorting.
            helper_cols = [key for key in sort_keys if key not in prj]
            for key in helper_cols:
                prj[key] = 1

    # Build query
    query = {}
    if skey is not None:
        query['skey'] = {'$in': skey}
    if index_name is not None:
        query['index_name'] = {'$in': index_name}
    date_range = {}
    if start_date is not None:
        date_range['$gte'] = start_date
    if end_date is not None:
        date_range['$lte'] = end_date
    if date_range:
        query['date'] = date_range

    # Load data
    df = pd.DataFrame.from_records(collection.find(query, prj))
    if df.empty:
        return pd.DataFrame()
    df = df.sort_values(by=sort_keys)
    if helper_cols:
        df = df.drop(columns=helper_cols)
    return df

def read_memb_daily(db, name, start_date=None, end_date=None, skey=None, index_id=None, interval=None, col=None, return_sdi=True):
    """Load rows of collection ``name`` sorted by (date, index_id, skey).

    Args:
        db: Mapping-like database handle; ``db[name]`` must provide
            ``find(query, projection)``.
        name: Collection name.
        start_date, end_date: Optional inclusive bounds on ``date``,
            passed to the query unchanged.
        skey: Optional list of ``skey`` values (``$in`` filter).
        index_id: Optional list of ``index_id`` values (``$in`` filter).
        interval: Optional list of ``interval`` values (``$in`` filter).
        col: Optional list of fields to return; ``None`` returns every
            field except ``_id``.
        return_sdi: When ``col`` is given, also return ``skey``, ``date``
            and ``index_id``.

    Returns:
        The sorted DataFrame, or an empty DataFrame when nothing matched.
    """
    sort_keys = ['date', 'index_id', 'skey']
    collection = db[name]

    # Build projection
    prj = {'_id': 0}
    helper_cols = []
    if col is not None:
        wanted = (['skey', 'date', 'index_id'] + list(col)) if return_sdi else list(col)
        for col_name in wanted:
            prj[col_name] = 1
        if wanted:
            # The sort below needs these fields even when the caller did
            # not ask for them; fetch them and drop them after sorting.
            helper_cols = [key for key in sort_keys if key not in prj]
            for key in helper_cols:
                prj[key] = 1

    # Build query
    query = {}
    if skey is not None:
        query['skey'] = {'$in': skey}
    if index_id is not None:
        query['index_id'] = {'$in': index_id}
    if interval is not None:
        query['interval'] = {'$in': interval}
    date_range = {}
    if start_date is not None:
        date_range['$gte'] = start_date
    if end_date is not None:
        date_range['$lte'] = end_date
    if date_range:
        query['date'] = date_range

    # Load data
    df = pd.DataFrame.from_records(collection.find(query, prj))
    if df.empty:
        return pd.DataFrame()
    df = df.sort_values(by=sort_keys)
    if helper_cols:
        df = df.drop(columns=helper_cols)
    return df



import pandas as pd
import random
import numpy as np
import glob
import pickle
import os
import datetime
import time
# Use the fully qualified option name: the bare "max_columns" pattern is
# ambiguous on pandas versions that also define other *max_columns options.
pd.set_option("display.max_columns", 200)

year = "2020"
startDate = '20200723'
endDate = '20200731'
database_name = 'com_md_eq_cn'
# SECURITY NOTE(review): credentials are hard-coded in this notebook.
# Rotate them and supply MD_DB_USER / MD_DB_PASSWORD through the
# environment; the literals remain only as a backward-compatible fallback.
user = os.environ.get("MD_DB_USER", "zhenyuy")
password = os.environ.get("MD_DB_PASSWORD", "bnONBrzSMGoE")

startTm = datetime.datetime.now()
db1 = DB("192.168.10.178", database_name, user, password)
db2 = dailyDB("192.168.10.178", database_name, user, password)

# Symbols found without md_snapshot_mbd data are recorded here as
# parallel lists of dates and security ids.
save = {}
save['date'] = []
save['secid'] = []

# The dates to process are the dates on which symbol 2000001 has order data.
mdOrderLog = db1.read('md_order', start_date=startDate, end_date=endDate, symbol=[2000001])
# DBObj.read returns None when nothing matches; treat that as "no dates".
datelist = mdOrderLog['date'].unique() if mdOrderLog is not None else []

# CSV whose '证券代码' column holds a 6-digit code with a two-letter suffix
# and whose '上市日期' column holds a date with 4-digit year, 2-digit month
# and 2-digit day at fixed positions.
ss = pd.read_csv('/mnt/ShareWithServer/result/shangshi.csv')
sec_code = ss['证券代码']
code_num = sec_code.str[:6].astype(int)
# 'SZ' codes are offset by 2000000, all other codes by 1000000.
ss['skey'] = np.where(sec_code.str[-2:] == 'SZ', code_num + 2000000, code_num + 1000000)
list_date = ss['上市日期']
ss['date'] = (list_date.str[:4] + list_date.str[5:7] + list_date.str[8:10]).astype(int)
print(datetime.datetime.now() - startTm)

startTm = datetime.datetime.now()

# Columns that must agree between an md_snapshot_l2 row and an
# md_snapshot_mbd row for the two to be considered the same snapshot.
cols = ['skey', 'date', 'cum_volume', 'prev_close', 'open', 'close', 'cum_trades_cnt', 'bid10p', 'bid9p',
        'bid8p', 'bid7p', 'bid6p', 'bid5p', 'bid4p', 'bid3p', 'bid2p', 'bid1p', 'ask1p', 'ask2p',
        'ask3p', 'ask4p', 'ask5p', 'ask6p', 'ask7p', 'ask8p', 'ask9p', 'ask10p', 'bid10q', 'bid9q',
        'bid8q', 'bid7q', 'bid6q', 'bid5q', 'bid4q', 'bid3q', 'bid2q', 'bid1q', 'ask1q', 'ask2q', 'ask3q',
        'ask4q', 'ask5q', 'ask6q', 'ask7q', 'ask8q', 'ask9q', 'ask10q', 'bid10n', 'bid9n', 'bid8n',
        'bid7n', 'bid6n', 'bid5n', 'bid4n', 'bid3n', 'bid2n', 'bid1n', 'ask1n', 'ask2n', 'ask3n',
        'ask4n', 'ask5n', 'ask6n', 'ask7n', 'ask8n', 'ask9n', 'ask10n', 'total_bid_quantity', 'total_ask_quantity']
# Columns shown when a snapshot cannot be matched.
diag_cols = ['skey', 'date', 'time', 'cum_volume', 'close', 'bid1p', 'bid2p', 'bid1q', 'bid2q',
             'ask1p', 'ask2p', 'ask1q', 'ask2q']

# The membership query does not depend on the loop date, so run it once.
memb = read_memb_daily(db2, 'index_memb', index_id=[1000852], start_date=20170901, end_date=20201203)
index_members = memb['skey'].unique() if not memb.empty else np.array([], dtype=int)
index_members = set(index_members[index_members > 2000000])

for d in datelist:
    print(d)
    orders = db1.read('md_order', start_date=str(d), end_date=str(d))
    if orders is None:
        continue
    sl2 = orders['skey'].unique()
    # Symbols with order data that are not index members.
    sl1 = list(set(sl2) - index_members)
    if not sl1:
        # An empty symbol list would make db1.read drop the symbol filter
        # and load every symbol of the day.
        continue
    data1 = db1.read('md_snapshot_l2', start_date=str(d), end_date=str(d), symbol=sl1)
    if data1 is None:
        continue
    sl1 = data1['skey'].unique()
    op = read_stock_daily(db2, 'mdbar1d_tr', start_date=int(d), end_date=int(d))

    for s in sl1:
        mbd = db1.read('md_snapshot_mbd', start_date=str(d), end_date=str(d), symbol=s)
        if mbd is None:
            # Missing data is not recorded when the CSV date for this
            # symbol equals the loop date; every other case is recorded.
            listing_dates = ss.loc[ss['skey'] == s, 'date']
            listed_today = (not listing_dates.empty) and listing_dates.iloc[0] == d
            if not listed_today:
                save['date'].append(d)
                save['secid'].append(s)
                print(s)
            continue

        if mbd.shape[1] != 82:
            print('mdb data column unupdated')
            print(s)

        try:
            op1 = op[op['skey'] == s]['open'].iloc[0]
            mbd_open = mbd[mbd['cum_volume'] > 0]['open'].iloc[0]
        except (KeyError, IndexError):
            print('%s have no information in mdbar1d_tr' % str(s))
        else:
            if mbd_open != op1:
                print('%s open price differs between md_snapshot_mbd and mdbar1d_tr' % str(s))

        l2 = data1[data1['skey'] == s]
        mbd1 = mbd.drop_duplicates(cols, keep='first')
        mbd = mbd1[cols + ['ApplSeqNum']]
        if 'ApplSeqNum' in l2.columns:
            # The value is re-derived from the merge below.
            l2 = l2.drop(columns='ApplSeqNum')
        rl2 = pd.merge(l2, mbd, on=cols, how='left')

        unmatched = rl2[(rl2['ApplSeqNum'].isnull()) & (rl2['cum_volume'] > 0) & (rl2['time'] <= 145655000000)]
        if unmatched.shape[0] != 0:
            print(unmatched[diag_cols])
            print(mbd1.tail(1)[diag_cols])

        # Rows without a match get ApplSeqNum -1.
        rl2['ApplSeqNum'] = rl2['ApplSeqNum'].fillna(-1).astype('int32')
        # Explicit raise instead of `assert`: this check guards a database
        # write and must not disappear under `python -O`.
        if rl2.shape[0] != l2.shape[0]:
            raise AssertionError('merge changed the row count for %s on %s' % (str(s), str(d)))
        db1.write('md_snapshot_l2', rl2)
    print(datetime.datetime.now() - startTm)

0:00:00.235335
20200731
0:02:29.272143
0:02:30.667752
0:02:30.912918
0:02:31.077298
0:02:31.298398
0:02:32.358810
0:02:33.333960
0:02:33.541797
0:02:33.839165
0:02:35.243848
0:02:35.569649
0:02:36.050669
0:02:36.676787
0:02:37.055009
0:02:37.814460
0:02:38.111622
0:02:38.664936
0:02:39.146421
0:02:39.568104
0:02:39.907206
0:02:40.136499
0:02:40.624781
0:02:41.141561
0:02:41.509048
0:02:41.931955
0:02:43.581386
0:02:45.165457
0:02:45.467338
0:02:45.948489
0:02:47.605692
0:02:48.260022
0:02:49.111580
0:02:49.427957
0:02:49.753639
0:02:51.768898
0:02:53.119835
0:02:53.565859
0:02:54.462210
0:02:55.122250
0:02:56.317728
0:02:57.861167
0:02:58.838972
0:02:59.760308
0:03:00.656348
0:03:00.867851
0:03:01.047019
0:03:01.206737
0:03:01.600097
0:03:02.038084
0:03:02.430631
0:03:02.725588
0:03:03.057653
0:03:03.761339
0:03:03.931251
0:03:04.862090
0:03:05.802473
0:03:06.640136
0:03:06.940377
0:03:07.217257
0:03:07.864444
0:03:09.004292
0:03:09.307689
0:03:09.577174
0:03:10.222854
0:03:10.471408
0

0:07:56.824804
0:07:57.042910
0:07:57.164994
0:07:57.609548
0:07:57.777648
0:07:58.094787
0:07:59.284302
0:07:59.807900
0:08:00.397357
0:08:00.854647
0:08:01.160860
0:08:01.527212
0:08:01.751054
0:08:02.066049
0:08:02.351516
0:08:02.993558
0:08:03.422798
0:08:03.969249
0:08:04.659233
0:08:05.381484
0:08:05.871167
0:08:06.821335
0:08:07.239805
0:08:07.837646
0:08:08.114703
0:08:08.540482
0:08:08.760748
0:08:08.998829
0:08:09.245780
0:08:10.326330
0:08:10.532160
0:08:10.963143
0:08:11.318352
0:08:11.449724
0:08:11.978237
0:08:12.175623
0:08:12.478310
0:08:12.747814
0:08:12.883344
0:08:13.285853
0:08:14.188658
0:08:14.889616
0:08:15.250842
0:08:15.999758
0:08:16.356050
0:08:17.740236
0:08:18.100865
0:08:18.309775
0:08:18.609359
0:08:18.819781
0:08:19.154303
0:08:19.485362
0:08:19.923935
0:08:20.243974
0:08:20.608136
0:08:20.978566
0:08:21.302588
0:08:21.923045
0:08:22.271337
0:08:22.908797
0:08:23.248303
0:08:24.196965
0:08:24.759580
0:08:25.188041
0:08:26.023591
0:08:26.311893
0:08:26.61

0:11:56.973904
0:11:57.267224
0:11:57.573889
0:11:58.118020
0:11:58.581196
0:11:59.758571
0:12:00.091319
0:12:00.604971
0:12:00.833822
0:12:01.200805
0:12:01.730424
0:12:02.252185
0:12:02.631548
0:12:03.104292
0:12:03.838787
0:12:04.082091
0:12:04.387151
0:12:04.722997
0:12:04.932949
0:12:05.169077
0:12:05.518202
0:12:05.726057
0:12:05.942866
0:12:06.255108
0:12:06.798511
0:12:07.410327
0:12:07.906603
0:12:08.247337
0:12:08.524797
0:12:08.899104
0:12:09.189013
0:12:09.495117
0:12:09.978904
0:12:10.287228
0:12:10.828735
0:12:11.045058
0:12:11.308534
0:12:11.574501
0:12:12.046745
0:12:12.332607
0:12:12.638402
0:12:13.199404
0:12:13.453547
0:12:13.907301
0:12:14.275413
0:12:14.621750
0:12:14.920311
0:12:15.207802
0:12:15.578259
0:12:16.364717
0:12:16.889737
0:12:17.465652
0:12:17.722903
0:12:18.025835
0:12:18.502503
0:12:18.927322
0:12:20.606650
0:12:20.840680
0:12:21.133673
0:12:21.364252
0:12:21.674730
0:12:22.360117
0:12:23.474068
0:12:24.248100
0:12:24.874590
0:12:25.320915
0:12:25.83