In [51]:
import pymongo
import pandas as pd
import pickle
import datetime
import time
import gzip
import lzma
import pytz

class DB(object):
    """Store and load pandas DataFrames as compressed, pickled MongoDB segments.

    Every stored document carries the serialized slice (``data``), the
    serialization version (``ver``), the ``date`` as a string, the ``symbol``
    as an int and the row offset of the slice (``start``).
    """

    def __init__(self, uri, symbol_column='skey'):
        """Connect to the ``white_db`` database.

        Args:
            uri: Connection string shaped like ``mongodb://user:password@host``;
                the host part may carry a ``:port``.
            symbol_column: Name of the DataFrame column holding the symbol key.

        Raises:
            ValueError: if ``uri`` has no ``user:password@`` section.
        """
        self.db_name = 'white_db'
        user, passwd, host = self.parse_uri(uri)
        # admin/root accounts authenticate against the ``admin`` database.
        auth_db = 'admin' if user in ('admin', 'root') else self.db_name
        # The credentials are re-inserted verbatim, exactly as they were given.
        self.uri = 'mongodb://%s:%s@%s/?authSource=%s' % (user, passwd, host, auth_db)

        self.client = pymongo.MongoClient(self.uri)
        self.db = self.client[self.db_name]
        self.chunk_size = 20000  # maximum number of rows per stored segment
        self.symbol_column = symbol_column
        self.date_column = 'date'

    def parse_uri(self, uri):
        """Split ``mongodb://user:password@host[:port]`` into its parts.

        The host is everything after the last ``@`` (so a port stays attached
        to it) and the user is everything before the first ``:``.

        Returns:
            list: ``[user, password, host]``.

        Raises:
            ValueError: if the ``user:password@`` section is missing.
        """
        rest = uri.strip()
        if rest.startswith('mongodb://'):
            rest = rest[len('mongodb://'):]
        rest = rest.strip('/')
        credentials, at_sign, host = rest.rpartition('@')
        user, colon, passwd = credentials.partition(':')
        if not at_sign or not colon:
            raise ValueError('uri must look like mongodb://user:password@host')
        return [user, passwd, host]

    def drop_table(self, table_name):
        """Drop the collection named ``table_name``."""
        self.db.drop_collection(table_name)

    def rename_table(self, old_table, new_table):
        """Rename the collection ``old_table`` to ``new_table``."""
        self.db[old_table].rename(new_table)

    def write(self, table_name, df):
        """Write ``df`` into ``table_name``, one (date, symbol) group at a time.

        Existing documents of each (date, symbol) pair found in ``df`` are
        deleted before the new segments are inserted. The delete and the
        inserts are separate operations, not a transaction. An empty ``df`` is
        a no-op.

        Raises:
            Exception: if ``df`` has no ``date`` column.
        """
        if len(df) == 0:
            return

        if self.date_column not in df.columns:
            raise Exception('DataFrame should contain date column')

        collection = self.db[table_name]
        collection.create_index([('date', pymongo.ASCENDING), ('symbol', pymongo.ASCENDING)], background=True)
        collection.create_index([('symbol', pymongo.ASCENDING), ('date', pymongo.ASCENDING)], background=True)

        if len(df[self.date_column].unique()) > 1:
            groups = [
                (str(date), int(symbol), sub_df)
                for (date, symbol), sub_df in df.groupby([self.date_column, self.symbol_column])
            ]
        else:
            single_date = str(df[self.date_column].iloc[0])
            # Group by the bare column name: grouping by a one-element list
            # yields 1-tuples as keys on recent pandas. The key is converted to
            # a plain int, like in the multi-date branch, so both branches
            # store the same BSON type.
            groups = [
                (single_date, int(symbol), sub_df)
                for symbol, sub_df in df.groupby(self.symbol_column)
            ]

        for seg_date, seg_symbol, sub_df in groups:
            collection.delete_many({'date': seg_date, 'symbol': seg_symbol})
            self.write_single(collection, seg_date, seg_symbol, sub_df)

    def write_single(self, collection, date, symbol, df):
        """Insert ``df`` into ``collection`` in slices of ``self.chunk_size`` rows."""
        version = 1
        for start in range(0, len(df), self.chunk_size):
            # iloc keeps the slice positional whatever the index type is.
            df_seg = df.iloc[start:start + self.chunk_size]
            seg = {'ver': version, 'data': self.ser(df_seg, version), 'date': date, 'symbol': symbol, 'start': start}
            collection.insert_one(seg)

    def build_query(self, start_date=None, end_date=None, symbol=None):
        """Build the MongoDB filter for a date range and/or symbol selection.

        Args:
            start_date, end_date: Inclusive bounds. Each may be an 8-character
                ``YYYYMMDD`` string, an integer (Python or numpy), or a
                ``datetime.date`` / ``datetime.datetime`` (subclasses included).
            symbol: A single symbol, or a list / tuple / set / array-like of
                symbols. ``None``, ``0``, ``''`` and empty collections add no
                symbol filter.

        Returns:
            dict: The filter; empty when no criterion was given.

        Raises:
            ValueError: if a date string is not 8 characters long.
            TypeError: if a date has an unsupported type.
        """
        import numbers

        def parse_date(x):
            if isinstance(x, str):
                if len(x) != 8:
                    raise ValueError("`date` must be YYYYMMDD format")
                return x
            if isinstance(x, (datetime.datetime, datetime.date)):
                return x.strftime("%Y%m%d")
            if isinstance(x, numbers.Integral):
                return parse_date(str(int(x)))
            raise TypeError("invalid `date` type: " + str(type(x)))

        def is_collection(x):
            # ndim >= 1 covers numpy arrays and pandas Series / Index.
            return isinstance(x, (list, tuple, set, frozenset)) or getattr(x, 'ndim', 0) >= 1

        query = {}

        date_range = {}
        if start_date is not None:
            date_range['$gte'] = parse_date(start_date)
        if end_date is not None:
            date_range['$lte'] = parse_date(end_date)
        if date_range:
            query['date'] = date_range

        if symbol is not None:
            if is_collection(symbol):
                symbols = [int(x) for x in symbol]
                if symbols:
                    query['symbol'] = {'$in': symbols}
            elif symbol:
                query['symbol'] = int(symbol)

        return query

    def delete(self, table_name, start_date=None, end_date=None, symbol=None):
        """Delete the documents of ``table_name`` matching the given criteria.

        With no criterion at all nothing is deleted: a message is printed and
        ``None`` is returned.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot delete the whole table')
            return None

        collection.delete_many(query)

    def read(self, table_name, start_date=None, end_date=None, symbol=None):
        """Load the matching segments of ``table_name`` as one DataFrame.

        Segments are ordered by (symbol, date, start) and concatenated with a
        fresh index.

        Returns:
            The concatenated DataFrame, or ``None`` when no criterion was given
            (a message is printed) or no document matched.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot read the whole table')
            return None

        segs = []
        for doc in collection.find(query):
            doc['data'] = self.deser(doc['data'], doc['ver'])
            segs.append(doc)
        segs.sort(key=lambda doc: (doc['symbol'], doc['date'], doc['start']))
        return pd.concat([doc['data'] for doc in segs], ignore_index=True) if segs else None

    def list_tables(self):
        """Return the names of the collections in the database."""
        # Look the method up on the class: attribute access on a pymongo
        # Database instance returns a Collection for unknown names.
        if hasattr(type(self.db), 'list_collection_names'):
            return self.db.list_collection_names()
        return self.db.collection_names()  # older PyMongo releases

    def list_dates(self, table_name, start_date=None, end_date=None, symbol=None):
        """Return the sorted distinct dates stored in ``table_name``."""
        collection = self.db[table_name]
        if start_date is None:
            start_date = '00000000'
        if end_date is None:
            end_date = '99999999'
        query = self.build_query(start_date, end_date, symbol)
        # Let the server de-duplicate instead of streaming every document.
        return sorted(set(collection.distinct('date', query)))

    def ser(self, s, version):
        """Pickle ``s`` and compress it (version 1: gzip, version 2: lzma).

        Raises:
            ValueError: for any other ``version``.
        """
        if version == 1:
            return gzip.compress(pickle.dumps(s), compresslevel=2)
        elif version == 2:
            return lzma.compress(pickle.dumps(s), preset=1)
        else:
            raise ValueError('unknown version')

    def deser(self, s, version):
        """Decompress and unpickle a payload produced by :meth:`ser`.

        NOTE(review): ``pickle.loads`` can execute arbitrary code, so this must
        only be used on data written by trusted producers.

        Raises:
            ValueError: for an unknown ``version``.
        """
        if version == 1:
            raw = gzip.decompress(s)
        elif version == 2:
            raw = lzma.decompress(s)
        else:
            raise ValueError('unknown version')
        return pickle.loads(raw)


import getpass
import os

startDate = 20170901
endDate = 20170901
targetStockLs = [1600000]

# Keep the database password out of the notebook: take it from the
# environment, or prompt for it when the variable is not set.
db_password = os.environ.get('WHITE_DB_PASSWORD') or getpass.getpass('white_db password: ')
db = DB("mongodb://user_rw:%s@192.168.10.223" % db_password)

# read the trades of the target stocks over [startDate, endDate]
mdData = db.read('trade', start_date=startDate, end_date=endDate, symbol=targetStockLs)

In [52]:
# Display the frame loaded by db.read in the previous cell.
mdData

Unnamed: 0,skey,date,time,clockAtArrival,datetime,ApplSeqNum,trade_type,trade_flag,trade_price,trade_qty,BidApplSeqNum,OfferApplSeqNum
0,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8933,1,0,12.68,200,85921,38698
1,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8934,1,0,12.68,1000,85921,132453
2,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8935,1,0,12.68,2400,102982,132453
3,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8936,1,0,12.68,800,9295,132453
4,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8937,1,0,12.68,500,112205,143172
...,...,...,...,...,...,...,...,...,...,...,...,...
15009,1600000,20170901,145959410000,1504249199410000,2017-09-01 14:59:59.410,2489077,1,1,12.78,1000,3913004,3554816
15010,1600000,20170901,145959530000,1504249199530000,2017-09-01 14:59:59.530,2489123,1,1,12.78,9300,3913042,3554816
15011,1600000,20170901,145959530000,1504249199530000,2017-09-01 14:59:59.530,2489124,1,1,12.78,500,3913042,3559846
15012,1600000,20170901,145959530000,1504249199530000,2017-09-01 14:59:59.530,2489125,1,1,12.78,4900,3913042,3572104


In [53]:
import pymongo
import pandas as pd
import pickle
import datetime
import time
import gzip
import lzma
import pytz


def DB(host, db_name, user, passwd):
    """Build the connection URI and return a ``DBObj`` bound to ``db_name``.

    Args:
        host: Server address, optionally with ``:port``.
        db_name: Database to use; also the authentication database unless the
            user is ``admin`` or ``root``, which authenticate against ``admin``.
        user: Plain (not percent-escaped) user name.
        passwd: Plain (not percent-escaped) password.

    Returns:
        DBObj: a wrapper connected with the generated URI.
    """
    from urllib.parse import quote_plus

    auth_db = 'admin' if user in ('admin', 'root') else db_name
    # Reserved characters such as ':', '/', '@' or '%' in the credentials must
    # be percent-escaped, otherwise the resulting URI is invalid.
    uri = 'mongodb://%s:%s@%s/?authSource=%s' % (quote_plus(user), quote_plus(passwd), host, auth_db)
    return DBObj(uri, db_name=db_name)


class DBObj(object):
    """Store and load pandas DataFrames as compressed, pickled MongoDB segments.

    Every stored document carries the serialized slice (``data``), the
    serialization version (``ver``), the ``date`` as a string, the ``symbol``
    as an int and the row offset of the slice (``start``).
    """

    def __init__(self, uri, symbol_column='skey', db_name='white_db'):
        """Connect to database ``db_name`` using the ready-made ``uri``.

        Args:
            uri: Complete MongoDB connection string, passed to the client as is.
            symbol_column: Name of the DataFrame column holding the symbol key.
            db_name: Name of the database to operate on.
        """
        self.db_name = db_name
        self.uri = uri
        self.client = pymongo.MongoClient(self.uri)
        self.db = self.client[self.db_name]
        self.chunk_size = 20000  # maximum number of rows per stored segment
        self.symbol_column = symbol_column
        self.date_column = 'date'

    def parse_uri(self, uri):
        """Split ``mongodb://user:password@host[:port]`` into its parts.

        The host is everything after the last ``@`` (so a port stays attached
        to it) and the user is everything before the first ``:``.

        Returns:
            list: ``[user, password, host]``.

        Raises:
            ValueError: if the ``user:password@`` section is missing.
        """
        rest = uri.strip()
        if rest.startswith('mongodb://'):
            rest = rest[len('mongodb://'):]
        rest = rest.strip('/')
        credentials, at_sign, host = rest.rpartition('@')
        user, colon, passwd = credentials.partition(':')
        if not at_sign or not colon:
            raise ValueError('uri must look like mongodb://user:password@host')
        return [user, passwd, host]

    def drop_table(self, table_name):
        """Drop the collection named ``table_name``."""
        self.db.drop_collection(table_name)

    def rename_table(self, old_table, new_table):
        """Rename the collection ``old_table`` to ``new_table``."""
        self.db[old_table].rename(new_table)

    def write(self, table_name, df):
        """Write ``df`` into ``table_name``, one (date, symbol) group at a time.

        Existing documents of each (date, symbol) pair found in ``df`` are
        deleted before the new segments are inserted. The delete and the
        inserts are separate operations, not a transaction. An empty ``df`` is
        a no-op.

        Raises:
            Exception: if ``df`` has no ``date`` column.
        """
        if len(df) == 0:
            return

        if self.date_column not in df.columns:
            raise Exception('DataFrame should contain date column')

        collection = self.db[table_name]
        collection.create_index([('date', pymongo.ASCENDING), ('symbol', pymongo.ASCENDING)], background=True)
        collection.create_index([('symbol', pymongo.ASCENDING), ('date', pymongo.ASCENDING)], background=True)

        if len(df[self.date_column].unique()) > 1:
            groups = [
                (str(date), int(symbol), sub_df)
                for (date, symbol), sub_df in df.groupby([self.date_column, self.symbol_column])
            ]
        else:
            single_date = str(df[self.date_column].iloc[0])
            # Group by the bare column name: grouping by a one-element list
            # yields 1-tuples as keys on recent pandas. The key is converted to
            # a plain int, like in the multi-date branch, so both branches
            # store the same BSON type.
            groups = [
                (single_date, int(symbol), sub_df)
                for symbol, sub_df in df.groupby(self.symbol_column)
            ]

        for seg_date, seg_symbol, sub_df in groups:
            collection.delete_many({'date': seg_date, 'symbol': seg_symbol})
            self.write_single(collection, seg_date, seg_symbol, sub_df)

    def write_single(self, collection, date, symbol, df):
        """Insert ``df`` into ``collection`` in slices of ``self.chunk_size`` rows."""
        version = 1
        for start in range(0, len(df), self.chunk_size):
            # iloc keeps the slice positional whatever the index type is.
            df_seg = df.iloc[start:start + self.chunk_size]
            seg = {'ver': version, 'data': self.ser(df_seg, version), 'date': date, 'symbol': symbol, 'start': start}
            collection.insert_one(seg)

    def build_query(self, start_date=None, end_date=None, symbol=None):
        """Build the MongoDB filter for a date range and/or symbol selection.

        Args:
            start_date, end_date: Inclusive bounds. Each may be an 8-character
                ``YYYYMMDD`` string, an integer (Python or numpy), or a
                ``datetime.date`` / ``datetime.datetime`` (subclasses included).
            symbol: A single symbol, or a list / tuple / set / array-like of
                symbols. ``None``, ``0``, ``''`` and empty collections add no
                symbol filter.

        Returns:
            dict: The filter; empty when no criterion was given.

        Raises:
            ValueError: if a date string is not 8 characters long.
            TypeError: if a date has an unsupported type.
        """
        import numbers

        def parse_date(x):
            if isinstance(x, str):
                if len(x) != 8:
                    raise ValueError("`date` must be YYYYMMDD format")
                return x
            if isinstance(x, (datetime.datetime, datetime.date)):
                return x.strftime("%Y%m%d")
            if isinstance(x, numbers.Integral):
                return parse_date(str(int(x)))
            raise TypeError("invalid `date` type: " + str(type(x)))

        def is_collection(x):
            # ndim >= 1 covers numpy arrays and pandas Series / Index.
            return isinstance(x, (list, tuple, set, frozenset)) or getattr(x, 'ndim', 0) >= 1

        query = {}

        date_range = {}
        if start_date is not None:
            date_range['$gte'] = parse_date(start_date)
        if end_date is not None:
            date_range['$lte'] = parse_date(end_date)
        if date_range:
            query['date'] = date_range

        if symbol is not None:
            if is_collection(symbol):
                symbols = [int(x) for x in symbol]
                if symbols:
                    query['symbol'] = {'$in': symbols}
            elif symbol:
                query['symbol'] = int(symbol)

        return query

    def delete(self, table_name, start_date=None, end_date=None, symbol=None):
        """Delete the documents of ``table_name`` matching the given criteria.

        With no criterion at all nothing is deleted: a message is printed and
        ``None`` is returned.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot delete the whole table')
            return None

        collection.delete_many(query)

    def read(self, table_name, start_date=None, end_date=None, symbol=None):
        """Load the matching segments of ``table_name`` as one DataFrame.

        Segments are ordered by (symbol, date, start) and concatenated with a
        fresh index.

        Returns:
            The concatenated DataFrame, or ``None`` when no criterion was given
            (a message is printed) or no document matched.
        """
        collection = self.db[table_name]

        query = self.build_query(start_date, end_date, symbol)
        if not query:
            print('cannot read the whole table')
            return None

        segs = []
        for doc in collection.find(query):
            doc['data'] = self.deser(doc['data'], doc['ver'])
            segs.append(doc)
        segs.sort(key=lambda doc: (doc['symbol'], doc['date'], doc['start']))
        return pd.concat([doc['data'] for doc in segs], ignore_index=True) if segs else None

    def list_tables(self):
        """Return the names of the collections in the database."""
        # Look the method up on the class: attribute access on a pymongo
        # Database instance returns a Collection for unknown names.
        if hasattr(type(self.db), 'list_collection_names'):
            return self.db.list_collection_names()
        return self.db.collection_names()  # older PyMongo releases

    def list_dates(self, table_name, start_date=None, end_date=None, symbol=None):
        """Return the sorted distinct dates stored in ``table_name``."""
        collection = self.db[table_name]
        if start_date is None:
            start_date = '00000000'
        if end_date is None:
            end_date = '99999999'
        query = self.build_query(start_date, end_date, symbol)
        # Let the server de-duplicate instead of streaming every document.
        return sorted(set(collection.distinct('date', query)))

    def ser(self, s, version):
        """Pickle ``s`` (protocol 4) and compress it (version 1: gzip, 2: lzma).

        Raises:
            ValueError: for any other ``version``.
        """
        pickle_protocol = 4
        if version == 1:
            return gzip.compress(pickle.dumps(s, protocol=pickle_protocol), compresslevel=2)
        elif version == 2:
            return lzma.compress(pickle.dumps(s, protocol=pickle_protocol), preset=1)
        else:
            raise ValueError('unknown version')

    def deser(self, s, version):
        """Decompress and unpickle a payload produced by :meth:`ser`.

        NOTE(review): ``pickle.loads`` can execute arbitrary code, so this must
        only be used on data written by trusted producers.

        Raises:
            ValueError: for an unknown ``version``.
        """
        if version == 1:
            raw = gzip.decompress(s)
        elif version == 2:
            raw = lzma.decompress(s)
        else:
            raise ValueError('unknown version')
        return pickle.loads(raw)


def patch_pandas_pickle():
    """Register a stand-in ``pandas.core.internals.managers`` module on old pandas.

    When the installed pandas is older than 0.24 and that module path is not
    already in ``sys.modules``, a module exposing ``BlockManager`` is created
    and registered under that name. Presumably this lets old pandas load
    pickles that reference the newer module path.
    """
    import re

    # Compare numerically: as strings, '0.9' < '0.24' would be False.
    version = tuple(int(part) for part in re.findall(r'\d+', pd.__version__)[:2])
    if version < (0, 24):
        import sys
        from types import ModuleType
        from pandas.core.internals import BlockManager
        pkg_name = 'pandas.core.internals.managers'
        if pkg_name not in sys.modules:
            m = ModuleType(pkg_name)
            m.BlockManager = BlockManager
            sys.modules[pkg_name] = m
patch_pandas_pickle()


# database_name = 
# user = 
# password = 

# db = white_db.DB("192.168.10.223", database_name, user, password)
# data = db.read(table, ***)

In [36]:
# Delete every 'md_snapshot_l2' segment dated within [startDate, endDate]; no
# symbol filter is passed, so all symbols in that range are removed.
# NOTE(review): destructive, and it uses whatever `db`, `startDate` and
# `endDate` currently hold in the kernel — check them before re-running.
db.delete('md_snapshot_l2', start_date=startDate, end_date=endDate)

In [58]:
# Widen the date window, then load 'md_trade' for that range without any
# symbol filter and display the result.
startDate, endDate = 20170101, 20200102
data = db.read(
    'md_trade',
    start_date=startDate,
    end_date=endDate,
)
data

Unnamed: 0,skey,date,time,clockAtArrival,datetime,ApplSeqNum,trade_type,trade_flag,trade_price,trade_qty,BidApplSeqNum,OfferApplSeqNum
0,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8933,1,0,12.68,200,85921,38698
1,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8934,1,0,12.68,1000,85921,132453
2,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8935,1,0,12.68,2400,102982,132453
3,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8936,1,0,12.68,800,9295,132453
4,1600000,20170901,92502030000,1504229102030000,2017-09-01 09:25:02.030,8937,1,0,12.68,500,112205,143172
...,...,...,...,...,...,...,...,...,...,...,...,...
102032,2000001,20200102,150000000000,1577948400000000,2020-01-02 15:00:00.000,17674007,1,0,16.87,600,17644436,17668613
102033,2000001,20200102,150000000000,1577948400000000,2020-01-02 15:00:00.000,17674008,1,0,16.87,1300,17645108,17668613
102034,2000001,20200102,150000000000,1577948400000000,2020-01-02 15:00:00.000,17674009,1,0,16.87,500,17646130,17668613
102035,2000001,20200102,150000000000,1577948400000000,2020-01-02 15:00:00.000,17674010,1,0,16.87,1984,17646716,17668613


In [54]:
import getpass
import os

database_name = 'com_md_eq_cn'
user = "zhenyuy"
# Keep the password out of the notebook: take it from the environment, or
# prompt for it when the variable is not set.
password = os.environ.get('COM_MD_DB_PASSWORD') or getpass.getpass('MongoDB password for %s: ' % user)

db = DB("192.168.10.223", database_name, user, password)
# Store the trades loaded earlier into the 'md_trade' collection.
db.write('md_trade', mdData)