Skip to content

Commit

Permalink
Feature/singelton (#15)
Browse files Browse the repository at this point in the history
Small improvements

Uses generators as much as possible.
PyTickerSymbols is now a singleton because there is no need to hold more than one instance
  • Loading branch information
SlashGordon committed Oct 7, 2019
1 parent 223b843 commit 7833bf2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 36 deletions.
8 changes: 8 additions & 0 deletions .theia/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.nosetestsEnabled": false,
"python.testing.pytestEnabled": true
}
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from setuptools import setup, find_packages

EXCLUDE_FROM_PACKAGES = ['test', 'test.*', 'test*']
VERSION = '1.1.8'

VERSION = '1.1.9'

with open("README.md", "r") as fh:
long_description = fh.read()
Expand Down
46 changes: 29 additions & 17 deletions src/pytickersymbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,20 @@
import os
import json

class Singleton(type):
def __init__(cls, name, bases, my_dict):
super(Singleton, cls).__init__(name, bases, my_dict)
cls.instance = None

def __call__(cls,*args,**kw):
if cls.instance is None:
cls.instance = super(Singleton, cls).__call__(*args, **kw)
return cls.instance

class PyTickerSymbols:

__metaclass__ = Singleton

def __init__(self):
self.__stocks = None
json_path = os.path.join(
Expand Down Expand Up @@ -38,8 +49,10 @@ def get_all_countries(self):
Returns all available countries
:return: list of country names
"""
items = [stock["country"] for stock in self.__stocks["companies"]]
return list(set(items))
countries = set(
map(lambda stock: stock['country'], self.__stocks["companies"])
)
return countries

def get_stocks_by_index(self, index):
"""
Expand Down Expand Up @@ -81,15 +94,10 @@ def get_stocks_by_country(self, country):
:param country: name of country
:return: list of stocks
"""

def __valid(cou, st_cou):
return isinstance(cou, str) and st_cou.lower() == cou.lower()

return [
stock
for stock in self.__stocks["companies"]
if __valid(country, stock["country"])
]
return filter(
lambda stock: isinstance(country, str) and stock["country"].lower() == country.lower(),
self.__stocks["companies"]
)

def index_to_yahoo_symbol(self, index_name):
"""
Expand All @@ -105,12 +113,16 @@ def index_to_yahoo_symbol(self, index_name):
return yahoo_symbol

def __get_items(self, key, val):
stocks = [
stock
for stock in self.__stocks["companies"]
for item in stock[key]
if isinstance(val, str) and val.lower() == item.lower()
]
stocks = filter(
lambda item: len(
list(
filter(
lambda sub_item: isinstance(val, str) and val.lower() == sub_item.lower(), item[key]
)
)
) > 0,
self.__stocks["companies"]
)
return stocks

def __get_sub_items(self, key):
Expand Down
43 changes: 25 additions & 18 deletions tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@

class TestLib(unittest.TestCase):

def test_singleton(self):
"""
Test singleton pattern
:return:
"""
self.assertTrue(id(PyTickerSymbols()) == id(PyTickerSymbols()))

def test_index(self):
"""
Test index getter
Expand All @@ -37,7 +44,7 @@ def test_encoding(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
dax = stock_data.get_stocks_by_index('DAX')
dax = list(stock_data.get_stocks_by_index('DAX'))
self.assertEqual(dax[10]['name'], 'Deutsche Börse AG')


Expand All @@ -48,7 +55,7 @@ def test_country(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
countries = stock_data.get_all_countries()
countries = list(stock_data.get_all_countries())
self.assertIsNotNone(countries)
self.assertIn("Germany", countries)
self.assertIn("Netherlands", countries)
Expand All @@ -65,7 +72,7 @@ def test_industry(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
industries = stock_data.get_all_industries()
industries = list(stock_data.get_all_industries())
self.assertIsNotNone(industries)
self.assertIn("Computer Hardware", industries)
self.assertIn("Gold", industries)
Expand All @@ -82,16 +89,16 @@ def test_stocks_by_index(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
stocks = stock_data.get_stocks_by_index(None)
stocks = list(stock_data.get_stocks_by_index(None))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_index(False)
stocks = list(stock_data.get_stocks_by_index(False))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_index(True)
stocks = list(stock_data.get_stocks_by_index(True))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_index(22)
stocks = list(stock_data.get_stocks_by_index(22))
self.assertEqual(len(stocks), 0)
for ind, ctx in [('DAX', 30), ('CAC 40', 40)]:
stocks = stock_data.get_stocks_by_index(ind)
stocks = list(stock_data.get_stocks_by_index(ind))
self.assertIsNotNone(stocks)
self.assertEqual(len(stocks), ctx)
for stock in stocks:
Expand All @@ -108,15 +115,15 @@ def test_stocks_by_country(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
stocks = stock_data.get_stocks_by_country(None)
stocks = list(stock_data.get_stocks_by_country(None))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_country(False)
stocks = list(stock_data.get_stocks_by_country(False))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_country(True)
stocks = list(stock_data.get_stocks_by_country(True))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_country(22)
stocks = list(stock_data.get_stocks_by_country(22))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_country("Israel")
stocks = list(stock_data.get_stocks_by_country("Israel"))
self.assertIsNotNone(stocks)
self.assertTrue(len(stocks) >= 1)
for stock in stocks:
Expand All @@ -132,15 +139,15 @@ def test_stocks_by_industry(self):
"""
stock_data = PyTickerSymbols()
self.assertIsNotNone(stock_data)
stocks = stock_data.get_stocks_by_industry(None)
stocks = list(stock_data.get_stocks_by_industry(None))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_industry(False)
stocks = list(stock_data.get_stocks_by_industry(False))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_industry(True)
stocks = list(stock_data.get_stocks_by_industry(True))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_industry(22)
stocks = list(stock_data.get_stocks_by_industry(22))
self.assertEqual(len(stocks), 0)
stocks = stock_data.get_stocks_by_industry("Basic Materials")
stocks = list(stock_data.get_stocks_by_industry("Basic Materials"))
self.assertIsNotNone(stocks)
for stock in stocks:
is_in_basic = False
Expand Down

0 comments on commit 7833bf2

Please sign in to comment.