diff --git a/requirements-dev.txt b/requirements-dev.txt index 2057be4..3a9d630 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,5 @@ pytest pytest-cov flake8==3.8.3 autopep8==1.5.3 -freezegun==0.3.15 \ No newline at end of file +freezegun==0.3.15 +black==19.10b0 \ No newline at end of file diff --git a/setup.py b/setup.py index ab38d00..f0b16b1 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ from setuptools import setup, find_packages EXCLUDE_FROM_PACKAGES = ['test', 'test.*', 'test*'] -VERSION = '1.0.16' +VERSION = '1.0.17' with open("README.md", "r") as fh: long_description = fh.read() diff --git a/src/pystockdb/db/schema/stocks.py b/src/pystockdb/db/schema/stocks.py index bafc0ca..c030e27 100644 --- a/src/pystockdb/db/schema/stocks.py +++ b/src/pystockdb/db/schema/stocks.py @@ -106,6 +106,7 @@ class Tag(db.Entity): GOG = 'google' USD = 'USD' EUR = 'EUR' + RUB = 'RUB' IDX = 'index' ICA = 'incomeanalysis' ICF = 'incomefacts' diff --git a/src/pystockdb/tools/base.py b/src/pystockdb/tools/base.py index 16dd7c1..cb723ca 100644 --- a/src/pystockdb/tools/base.py +++ b/src/pystockdb/tools/base.py @@ -12,8 +12,16 @@ from pony.orm import commit, db_session, core from pytickersymbols import PyTickerSymbols -from pystockdb.db.schema.stocks import (Index, Item, PriceItem, Stock, Symbol, - Tag, Type, db) +from pystockdb.db.schema.stocks import ( + Index, + Item, + PriceItem, + Stock, + Symbol, + Tag, + Type, + db, +) from pystockdb.tools.data_crawler import DataCrawler @@ -70,64 +78,91 @@ def __add_stock_to_index(self, index, stock_info): stock_in_db = Stock.get(name=stock_info['name']) if stock_in_db: self.logger.info( - 'Add stock {}:{} to index.'.format(index.name, - stock_in_db.name) + 'Add stock {}:{} to index.'.format( + index.name, stock_in_db.name + ) ) index.stocks.add(stock_in_db) else: self.logger.info( - 'Add stock {}:{} to db'.format(index.name, - stock_info[Type.SYM]) + 'Add stock {}:{} to db'.format( + index.name, stock_info[Type.SYM] + ) ) # create stock - stock = Stock(name=stock_info['name'], - price_item=PriceItem(item=Item())) + stock = Stock( + name=stock_info['name'], price_item=PriceItem(item=Item()) + ) # add symbols yao = Tag.get(name=Tag.YAO) gog = Tag.get(name=Tag.GOG) usd = Tag.get(name=Tag.USD) eur = Tag.get(name=Tag.EUR) + rub = Tag.get(name=Tag.RUB) for symbol in stock_info['symbols']: if Tag.GOG in symbol and symbol[Tag.GOG] != '-': - self.__create_symbol(stock, Tag.GOG, gog, symbol, eur, usd) + self.__create_symbol( + stock, Tag.GOG, gog, symbol, eur, usd, rub + ) if Tag.YAO in symbol and symbol[Tag.YAO] != '-': - self.__create_symbol(stock, Tag.YAO, yao, symbol, eur, usd) + self.__create_symbol( + stock, Tag.YAO, yao, symbol, eur, usd, rub + ) index.stocks.add(stock) # connect stock with industry and country # country name = stock_info['country'] - country = Tag.select(lambda t: t.name == name and - t.type.name == Type.REG).first() + country = Tag.select( + lambda t: t.name == name and t.type.name == Type.REG + ).first() country.items.add(stock.price_item.item) # industry indus = stock_info['industries'] - industries = Tag.select(lambda t: t.name in indus and - t.type.name == Type.IND) + industries = Tag.select( + lambda t: t.name in indus and t.type.name == Type.IND + ) for industry in industries: industry.items.add(stock.price_item.item) @db_session - def __create_symbol(self, stock, my_tag, my_tag_item, symbol, eur, usd): + def __create_symbol( + self, stock, my_tag, my_tag_item, symbol, eur, usd, rub + ): if my_tag in symbol and symbol[my_tag] != '-': - cur = eur if symbol[my_tag].startswith('FRA') or \ - symbol[my_tag].endswith('.F') else usd - item = Item() - item.add_tags([my_tag_item, cur]) + cur = None + if symbol[my_tag].startswith('FRA:') or symbol[my_tag].endswith( + '.F' + ): + cur = eur + elif ( + symbol[my_tag].startswith('NYSE:') + or symbol[my_tag].startswith('OTCMKTS:') + or symbol[my_tag].startswith('NASDAQ:') + or ('.' not in symbol[my_tag] and ':' not in symbol[my_tag]) + ): + cur = usd + elif symbol[my_tag].startswith('MCX:') or symbol[my_tag].endswith( + '.ME' + ): + cur = rub + + if cur: + item = Item() + item.add_tags([my_tag_item, cur]) if Symbol.get(name=symbol[my_tag]): self.logger.warning( 'Symbol {} is related to more than one' ' stock.'.format(symbol[my_tag]) ) else: - stock.price_item.symbols.create(item=item, - name=symbol[my_tag]) + stock.price_item.symbols.create(item=item, name=symbol[my_tag]) @db_session def download_historicals(self, symbols, start, end): if not (start and end): return False crawler = DataCrawler() - chunks = [symbols[x:x + 50] for x in range(0, len(symbols), 50)] + chunks = [symbols[x : x + 50] for x in range(0, len(symbols), 50)] for chunk in chunks: ids = [symbol.name for symbol in chunk] if ids is None: @@ -136,8 +171,9 @@ def download_historicals(self, symbols, start, end): series = crawler.get_series_stack(ids, start=start, end=end) for symbol in chunk: self.logger.debug( - 'Add prices for {} from {} until {}.'.format(symbol.name, - start, end) + 'Add prices for {} from {} until {}.'.format( + symbol.name, start, end + ) ) for value in series[symbol.name]: symbol.prices.create(**value) @@ -151,9 +187,10 @@ def __insert_initial_data(self): industry_type = Type(name=Type.IND) Type(name=Type.MSC).add_tags([Tag.IDX]) Type(name=Type.SYM).add_tags([Tag.YAO, Tag.GOG]) - Type(name=Type.CUR).add_tags([Tag.USD, Tag.EUR]) - Type(name=Type.FDM).add_tags([Tag.ICA, Tag.ICF, Tag.REC, - Tag.ICO, Tag.BLE, Tag.CSH]) + Type(name=Type.CUR).add_tags([Tag.USD, Tag.EUR, Tag.RUB]) + Type(name=Type.FDM).add_tags( + [Tag.ICA, Tag.ICF, Tag.REC, Tag.ICO, Tag.BLE, Tag.CSH] + ) Type(name=Type.FIL) Type(name=Type.ICR) Type(name=Type.ARG) diff --git a/src/pystockdb/tools/create.py b/src/pystockdb/tools/create.py index 9de6d83..3a4aff4 100644 --- a/src/pystockdb/tools/create.py +++ b/src/pystockdb/tools/create.py @@ -36,7 +36,7 @@ def build(self): # add indices and stocks to db if not self.indices_list: return 0 - if not all(cur in [Tag.EUR, Tag.USD] for cur in self.currencies) \ + if not all(cur in [Tag.EUR, Tag.USD, Tag.RUB] for cur in self.currencies) \ and self.prices: self.logger.warning( 'Currency {} is not supported.'.format(self.currencies) diff --git a/src/pystockdb/tools/update.py b/src/pystockdb/tools/update.py index d01e8cb..293e9bb 100644 --- a/src/pystockdb/tools/update.py +++ b/src/pystockdb/tools/update.py @@ -15,8 +15,15 @@ from pony.orm import commit, db_session, select -from pystockdb.db.schema.stocks import (Data, DataItem, Item, Price, PriceItem, - Tag, Symbol) +from pystockdb.db.schema.stocks import ( + Data, + DataItem, + Item, + Price, + PriceItem, + Tag, + Symbol, +) from pystockdb.tools.base import DBBase from pystockdb.tools.fundamentals import Fundamentals from pystockdb.tools import ALL_SYMBOLS @@ -56,11 +63,13 @@ def update_prices(self): for symbol in set(self.symbols): if len(prices) == 0: # download initial data - price_filtered.append([ - datetime.datetime.now() - - timedelta(days=365*self.history), - Symbol.get(name=symbol)] - ) + price_filtered.append( + [ + datetime.datetime.now() + - timedelta(days=365 * self.history), + Symbol.get(name=symbol), + ] + ) else: for price in prices: if symbol == price[1].name: @@ -80,20 +89,27 @@ def update_prices(self): end = datetime.datetime.now() if start.date() >= end.date(): continue - self.download_historicals(update[key], - start=start.strftime('%Y-%m-%d'), - end=end.strftime('%Y-%m-%d')) + self.download_historicals( + update[key], + start=start.strftime('%Y-%m-%d'), + end=end.strftime('%Y-%m-%d'), + ) commit() @db_session def update_fundamentals(self): """Updates all fundamentals of stocks """ - # select stock if first google symbol - stocks = list(select((pit.stock, sym.name) for pit in PriceItem - for sym in pit.symbols - if (Tag.GOG in sym.item.tags.name) and - sym.id == min(pit.symbols.id))) + # At the moment Fundamental client supports only usd symbols + stocks = list( + select( + (pit.stock, sym.name) + for pit in PriceItem + for sym in pit.symbols + if (Tag.GOG in sym.item.tags.name) + and (Tag.USD in sym.item.tags.name) + ) + ) # filter specific stocks if not all if ALL_SYMBOLS not in self.symbols: stocks_filtered = [] @@ -115,7 +131,7 @@ def update_fundamentals(self): if len(stock) != 1: self.logger.warning( 'Can not download fundamentals for {}'.format(ticker) - ) + ) continue stock = stock[0] ica = fundamentals.get_income_analysis(tickers[ticker]) @@ -124,8 +140,14 @@ def update_fundamentals(self): ble = fundamentals.get_balance(tickers[ticker]) ico = fundamentals.get_income(tickers[ticker]) csh = fundamentals.get_cash_flow(tickers[ticker]) - for val in [(ica, Tag.ICA), (ifc, Tag.ICF), (rec, Tag.REC), - (ble, Tag.BLE), (ico, Tag.ICO), (csh, Tag.CSH)]: + for val in [ + (ica, Tag.ICA), + (ifc, Tag.ICF), + (rec, Tag.REC), + (ble, Tag.BLE), + (ico, Tag.ICO), + (csh, Tag.CSH), + ]: # hash stock name, tag and data m = hashlib.sha256() m.update(val[1].encode('UTF-8')) @@ -136,7 +158,10 @@ def update_fundamentals(self): # only add data if not exist if Data.get(hash=shahash) is None: tag = Tag.get(name=val[1]) - obj = Data(data=val[0], hash=shahash, - data_item=DataItem(item=Item(tags=[tag]))) + obj = Data( + data=val[0], + hash=shahash, + data_item=DataItem(item=Item(tags=[tag])), + ) stock.data_items.add(obj.data_item) commit() diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..3b0b893 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +import sys + +sys.path.insert(0, 'src') \ No newline at end of file diff --git a/tests/test_crawler.py b/tests/test_crawler.py index 36ad6f2..9057011 100644 --- a/tests/test_crawler.py +++ b/tests/test_crawler.py @@ -15,8 +15,8 @@ class TestCrawler(unittest.TestCase): - SYMBOLS = ['FRA:ADS', 'FRA:BAYN', 'FRA:BMW', 'FRA:ADP', 'NASDAQ:BBBY'] - SYMBOLS_Y = ['ADS.F', 'BAYN.F', 'BMW.F', 'ADP.F'] + SYMBOLS = ['OTCMKTS:ADDDF', 'OTCMKTS:BAYZF', 'OTCMKTS:BMWYY', 'NASDAQ:ADP', 'NASDAQ:BBBY'] + SYMBOLS_Y = ['ADDDF', 'BAYZF', 'BMWYY', 'ADP'] def test_fundamentals(self): """ diff --git a/tests/test_db.py b/tests/test_db.py index 90c29bc..9ce65b8 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -46,7 +46,7 @@ def test_1_create(self): create = CreateAndFillDataBase(config, logger) self.assertEqual(create.build(), 0) config['indices'] = ['DAX'] - config['currencies'] = ['RUB'] + config['currencies'] = ['CAD'] create = CreateAndFillDataBase(config, logger) self.assertEqual(create.build(), -1) config['currencies'] = ['EUR', 'USD'] @@ -74,54 +74,9 @@ def test_1_create(self): self.assertEqual(stocks, 30) prices = list(select(max(p.date) for p in Price)) self.assertEqual(len(prices), 1) - self.assertEqual(prices[0].strftime('%Y-%m-%d'), '2019-01-14') + self.assertEqual(prices[0].strftime('%Y-%m-%d'), '2019-01-11') - @db_session - def test_2_dbbase(self): - config = { - 'db_args': { - 'provider': 'sqlite', - 'filename': 'database_create.sqlite', - 'create_db': False - } - } - logger = logging.getLogger('test') - dbbase = DBBase(config, logger) - ind = Index.get(name='test123') - if ind: - ind.delete() - sym = Symbol.get(name='test123') - if sym: - sym.delete() - self.assertRaises(NotImplementedError, dbbase.build) - self.assertFalse(dbbase.download_historicals(None, None, None)) - - # override pytickersymbols - def get_stocks_by_index(name): - stock = { - 'name': 'adidas AG', - 'symbol': 'ADS', - 'country': 'Germany', - 'indices': ['DAX', 'test123'], - 'industries': [], - 'symbols': [] - } - return [stock] - - def index_to_yahoo_symbol(name): - return 'test123' - - dbbase.ticker_symbols.get_stocks_by_index = get_stocks_by_index - dbbase.ticker_symbols.index_to_yahoo_symbol = index_to_yahoo_symbol - dbbase.add_indices_and_stocks(['test123']) - ads = Stock.select( - lambda s: 'test123' in s.indexs.name - ).first() - self.assertNotEqual(ads, None) - Index.get(name='test123').delete() - Symbol.get(name='test123').delete() - - def test_3_update(self): + def test_2_update(self): """ Test database client update :return: @@ -159,7 +114,7 @@ def test_3_update(self): update = UpdateDataBaseStocks(config, logger) update.build() - def test_4_sync(self): + def test_3_sync(self): """Tests sync tool """ logger = logging.getLogger('test') @@ -180,7 +135,7 @@ def test_4_sync(self): self.assertEqual(stocks, 70) @db_session - def test_5_query_data(self): + def test_4_query_data(self): ifx = Stock.select( lambda s: 'IFX.F' in s.price_item.symbols.name ).first() @@ -201,7 +156,7 @@ def test_5_query_data(self): self.assertIsInstance(rat, float) self.assertIsInstance(eps, float) - def test_6_create_flat(self): + def test_5_create_flat(self): """ Test flat create :return: @@ -250,5 +205,50 @@ def test_6_create_flat(self): data_ctx = select(d for d in Data).count() self.assertEqual(data_ctx, 12) + @db_session + def test_6_dbbase(self): + config = { + 'db_args': { + 'provider': 'sqlite', + 'filename': 'database_create.sqlite', + 'create_db': False + } + } + logger = logging.getLogger('test') + dbbase = DBBase(config, logger) + ind = Index.get(name='test123') + if ind: + ind.delete() + sym = Symbol.get(name='test123') + if sym: + sym.delete() + self.assertRaises(NotImplementedError, dbbase.build) + self.assertFalse(dbbase.download_historicals(None, None, None)) + + # override pytickersymbols + def get_stocks_by_index(name): + stock = { + 'name': 'adidas AG', + 'symbol': 'ADS', + 'country': 'Germany', + 'indices': ['DAX', 'test123'], + 'industries': [], + 'symbols': [] + } + return [stock] + + def index_to_yahoo_symbol(name): + return 'test123' + + dbbase.ticker_symbols.get_stocks_by_index = get_stocks_by_index + dbbase.ticker_symbols.index_to_yahoo_symbol = index_to_yahoo_symbol + dbbase.add_indices_and_stocks(['test123']) + ads = Stock.select( + lambda s: 'test123' in s.indexs.name + ).first() + self.assertNotEqual(ads, None) + Index.get(name='test123').delete() + Symbol.get(name='test123').delete() + if __name__ == "__main__": unittest.main()