Skip to content

Commit

Permalink
Merge branch 'livetestsfixes'
Browse files Browse the repository at this point in the history
  • Loading branch information
rienafairefr committed May 3, 2017
2 parents 93cce92 + f979e6f commit f6f51ea
Show file tree
Hide file tree
Showing 17 changed files with 228 additions and 184 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ ynab.conf
*.log
*.env
nosetests*.xml
tests/*.json
testscripts/*.json
7 changes: 5 additions & 2 deletions pynYNAB/ClientFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ def __init__(self, client):


class nYnabClientFactory(object):
def __init__(self, engine_url='sqlite://'):
def __init__(self, engine_url='sqlite://', engine=None):
self.engine_url = engine_url
self.engine = create_engine(engine_url)
if engine is None:
self.engine = create_engine(engine_url)
else:
self.engine = engine

Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
Expand Down
76 changes: 47 additions & 29 deletions pynYNAB/ObjClient.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from abc import abstractproperty,ABCMeta

from pynYNAB.schema import fromapi_conversion_functions_table

LOG = logging.getLogger(__name__)


Expand All @@ -20,14 +23,26 @@ def __init__(self, obj, client):
self.client = client
self.connection = client.connection
self.session = client.session
self.server_entities = {}
self.synced = False

def update_from_api_changed_entitydicts(self, changed_entitydicts):
def update_from_api_changed_entitydicts(self, changed_entitydicts, update_keys=None):
if update_keys is None:
update_keys = list(self.obj.listfields.keys())
else:
update_keys = [k for k in update_keys if k in self.obj.listfields]
modified_entitydicts = {}
for name in changed_entitydicts:
for listfield_name in update_keys:
newlist = []
for entitydict in changed_entitydicts[name]:
newlist.append(self.obj.listfields[name].from_apidict(entitydict))
modified_entitydicts[name] = newlist
if changed_entitydicts[listfield_name] is not None:
for entitydict in changed_entitydicts[listfield_name]:
newlist.append(self.obj.listfields[listfield_name].from_apidict(entitydict))
modified_entitydicts[listfield_name] = newlist
for scalarfield_name in self.obj.scalarfields:
if scalarfield_name in changed_entitydicts:
typ = self.obj.scalarfields[scalarfield_name]
conversion_function = fromapi_conversion_functions_table.get(typ, lambda t, x: x)
modified_entitydicts[scalarfield_name] = conversion_function(typ, changed_entitydicts[scalarfield_name])
self.update_from_changed_entities(modified_entitydicts)

def update_from_changed_entitydict(self, changed_entitiydicts):
Expand Down Expand Up @@ -79,18 +94,19 @@ def update_from_changed_entities(self, changed_entities):
self.session.commit()
pass

def update_from_sync_data(self, sync_data):
self.update_from_api_changed_entitydicts(sync_data['changed_entities'])
def update_from_sync_data(self, sync_data, update_keys=None):
self.update_from_api_changed_entitydicts(sync_data['changed_entities'],update_keys)

def sync(self):

def sync(self, update_keys=None):
if self.connection is None:
return
sync_data = self.get_sync_data_obj()

self.client.server_entities = sync_data['changed_entities']
self.server_entities[self.opname] = sync_data['changed_entities']
LOG.debug('server_knowledge_of_device ' + str(sync_data['server_knowledge_of_device']))
LOG.debug('current_server_knowledge ' + str(sync_data['current_server_knowledge']))
self.update_from_sync_data(sync_data)
self.update_from_sync_data(sync_data,update_keys)
self.session.commit()
self.obj.clear_changed_entities()

Expand All @@ -109,8 +125,9 @@ def sync(self):

LOG.debug('current_device_knowledge %s' % self.obj.knowledge.current_device_knowledge)
LOG.debug('device_knowledge_of_server %s' % self.obj.knowledge.device_knowledge_of_server)
self.synced = True

def push(self):
def push(self, update_from_sync_data=True, update_keys=None):
changed_entities = self.obj.get_changed_apidict()
request_data = dict(starting_device_knowledge=self.client.starting_device_knowledge,
ending_device_knowledge=self.client.ending_device_knowledge,
Expand All @@ -125,24 +142,25 @@ def validate():
sync_data = self.connection.dorequest(request_data, self.opname)
LOG.debug('server_knowledge_of_device ' + str(sync_data['server_knowledge_of_device']))
LOG.debug('current_server_knowledge ' + str(sync_data['current_server_knowledge']))
self.update_from_sync_data(sync_data)
validate()

server_knowledge_of_device = sync_data['server_knowledge_of_device']
current_server_knowledge = sync_data['current_server_knowledge']

change = current_server_knowledge - self.obj.knowledge.device_knowledge_of_server
if change > 0:
LOG.debug('Server knowledge has gone up by ' + str(
change) + '. We should be getting back some entities from the server')
if self.obj.knowledge.current_device_knowledge < server_knowledge_of_device:
if self.obj.knowledge.current_device_knowledge != 0:
LOG.error('The server knows more about this device than we know about ourselves')
self.obj.knowledge.current_device_knowledge = server_knowledge_of_device
self.obj.knowledge.device_knowledge_of_server = current_server_knowledge

LOG.debug('current_device_knowledge %s' % self.obj.knowledge.current_device_knowledge)
LOG.debug('device_knowledge_of_server %s' % self.obj.knowledge.device_knowledge_of_server)
if update_from_sync_data:
self.update_from_sync_data(sync_data, update_keys)
validate()

server_knowledge_of_device = sync_data['server_knowledge_of_device']
current_server_knowledge = sync_data['current_server_knowledge']

change = current_server_knowledge - self.obj.knowledge.device_knowledge_of_server
if change > 0:
LOG.debug('Server knowledge has gone up by ' + str(
change) + '. We should be getting back some entities from the server')
if self.obj.knowledge.current_device_knowledge < server_knowledge_of_device:
if self.obj.knowledge.current_device_knowledge != 0:
LOG.error('The server knows more about this device than we know about ourselves')
self.obj.knowledge.current_device_knowledge = server_knowledge_of_device
self.obj.knowledge.device_knowledge_of_server = current_server_knowledge

LOG.debug('current_device_knowledge %s' % self.obj.knowledge.current_device_knowledge)
LOG.debug('device_knowledge_of_server %s' % self.obj.knowledge.device_knowledge_of_server)
else:
validate()

Expand Down
9 changes: 2 additions & 7 deletions pynYNAB/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@
configfile = os.path.join(myAppdir, configfile)

LOG = logging.getLogger(__name__)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
LOG.addHandler(ch)

parser = configargparse.getArgumentParser('pynYNAB', default_config_files=[configfile],
add_env_var_help=True,
Expand All @@ -39,12 +32,14 @@
parser.add_argument('--budgetname', metavar='BudgetName', type=str, required=False,
help='The nYNAB budget to use')


class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)


class MainCommands(object):
def __init__(self):
parser = argparse.ArgumentParser(
Expand Down
6 changes: 3 additions & 3 deletions pynYNAB/schema/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def init_internal_db(self):
self.Session = sessionmaker(bind=self.engine)
self.session = self.Session()

def sync(self):
def sync(self, update_keys=None):
LOG.debug('Client.sync')

self.catalogClient.sync()
self.catalogClient.sync(update_keys)
self.select_budget(self.budget_name)
self.budgetClient.sync()
self.budgetClient.sync(update_keys)

if self.budget_version_id is None and self.budget_name is not None:
raise BudgetNotFound()
Expand Down
8 changes: 5 additions & 3 deletions pynYNAB/schema/Entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import re

from sqlalchemy import ForeignKey
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.sql.sqltypes import Enum as sqlaEnum, String
from sqlalchemy.sql.sqltypes import Enum as sqlaEnum, String, DateTime
from aenum import Enum
from sqlalchemy import Boolean
from sqlalchemy import Column
Expand Down Expand Up @@ -87,7 +86,7 @@ def listfields(self):
@property
def scalarfields(self):
scalarcolumns = self.__table__.columns
return {k: scalarcolumns[k].type.__class__.__name__ for k in scalarcolumns.keys() if k != 'parent_id'}
return {k: scalarcolumns[k].type.__class__ for k in scalarcolumns.keys() if k != 'parent_id' and k != 'knowledge_id'}



Expand Down Expand Up @@ -154,6 +153,7 @@ def init_scalar(target, value, dict_):
re_date = re.compile(r'\d{4}[\/ .-]\d{2}[\/.-]\d{2}')



def date_from_api(columntype, string):
result = re_date.search(string)
if result is not None:
Expand All @@ -162,12 +162,14 @@ def date_from_api(columntype, string):

fromapi_conversion_functions_table = {
Date: date_from_api,
DateTime: lambda t,x: datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%f'),
AmountType: lambda t, x: float(x) / 1000,
sqlaEnum: lambda t, x: t.enum_class[x]
}

toapi_conversion_functions_table = {
Date: lambda t, x: x.strftime('%Y-%m-%d'),
DateTime: lambda t, x: x.strftime('%Y-%m-%dT%H:%M:%S.%f'),
AmountType: lambda t, x: int(float(x) * 1000),
sqlaEnum: lambda t, x: x._name_
}
Expand Down
1 change: 1 addition & 0 deletions pynYNAB/schema/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Transaction(Base, BudgetEntity):
check_number = Column(String)
cleared = Column(String, default='Uncleared')
credit_amount = Column(AmountType)
credit_amount_adjusted = Column(Boolean)
date = Column(Date)
date_entered_from_schedule = Column(Date)
entities_account_id = Column(ForeignKey('account.id'))
Expand Down
2 changes: 2 additions & 0 deletions pynYNAB/schema/catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import String
from sqlalchemy.ext.declarative import declared_attr
Expand All @@ -21,6 +22,7 @@ def parent(self):

class CatalogBudget(Base, CatalogEntity):
budget_name = Column(String)
created_at = Column(DateTime)


class User(Base, CatalogEntity):
Expand Down
34 changes: 0 additions & 34 deletions test_live/test_scaling.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,3 @@
# -*- coding: utf-8 -*-
import unittest
from pynYNAB.ClientFactory import clientfromargs
from pynYNAB.__main__ import parser
from pynYNAB.schema.Entity import toapi_conversion_functions_table, fromapi_conversion_functions_table
from pynYNAB.schema.types import AmountType

test_budget_name = 'Test Budget - Dont Remove'


# this test cases expect that
# a budget named "Test Budget" exists
# in it, there is an account named "Account"
# in it there is a transaction date 27/01/2017 that has inflow == 12.34 € with a memo "TEST TRANSACTION"
# We check the API dict we fetch when syncing budget
class LiveTestBudget(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(LiveTestBudget, self).__init__(*args, **kwargs)
self.transaction = None
self.client = None

def setUp(self):
args = parser.parse_known_args()[0]
args.budgetname = test_budget_name
self.client = clientfromargs(args, sync=False)
self.client.catalogClient.sync()
self.client.select_budget(test_budget_name)

def test_api_scaling_is_ok(self):
sync_data = self.client.budgetClient.get_sync_data_obj()
server_entities = sync_data['changed_entities']
transactions = server_entities['be_transactions']
amount = None
for transaction in transactions:
if transaction['memo'] == 'TEST TRANSACTION':
amount = fromapi_conversion_functions_table[AmountType](AmountType,transaction['amount'])
self.assertEqual(12.34,amount)
82 changes: 36 additions & 46 deletions test_live/testlive.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import random
import unittest
from datetime import datetime, timedelta

from pynYNAB import KeyGenerator
from pynYNAB.schema.Entity import AccountTypes
from pynYNAB.schema.budget import Transaction, Account, Subtransaction, Payee
from pynYNAB.ClientFactory import clientfromargs
from pynYNAB.__main__ import parser
from pynYNAB.schema import DictDiffer
from pynYNAB.schema.budget import Transaction
from test_live.common import CommonLive
from test_live.common import needs_account

Expand All @@ -27,54 +26,45 @@ def test_add_deletetransaction(self):
self.reload()
self.assertNotIn(transaction, self.client.budget.be_transactions)

def test_roundtrip(self):
# 1. syncs data from server
# 2. gets the pushed changed_entities that would be pushed as if all entities were modified
# 3 the pushed changed_entities should be strictly identical to the changed_entities that was received

def get_changed_entities_current(obj):
current_map = obj.getmaps()
return {k: list(v.values()) if isinstance(v,dict) else v for k,v in current_map.items()}

def clean_id_tombstoned(ce):
returnvalue = {}
for k, value in ce.items():
if k == 'is_tombstone' or k == 'id':
continue
class LiveTests2(unittest.TestCase):
def test_roundtrip(self):
args = parser.parse_known_args()[0]

if isinstance(value,list):
returnvalue[k] = list(set(v for v in value if not v.is_tombstone))
return returnvalue
# 1. gets sync data from server
# 2. tests that to_api(from_api(data)) is the same thing

server_catalog_changed_entities = clean_id_tombstoned(self.client.server_entities['syncCatalogData'])
server_budget_changed_entities = clean_id_tombstoned(self.client.server_entities['syncBudgetData'])

pushed_catalog_changed_entities = clean_id_tombstoned(get_changed_entities_current(self.client.catalog))
pushed_budget_changed_entities = clean_id_tombstoned(get_changed_entities_current(self.client.budget))
client = clientfromargs(args, sync=False)
sync_data = client.catalogClient.get_sync_data_obj()
budget_version_id = next(d['id'] for d in sync_data['changed_entities']['ce_budget_versions'] if
d['version_name'] == args.budgetname)
client.budget_version_id = budget_version_id

self.checkEqual(pushed_catalog_changed_entities.keys(), server_catalog_changed_entities.keys(),
'catalog changed entities roundtrip keys %s and %s not equal')
self.checkEqual(pushed_budget_changed_entities.keys(), server_budget_changed_entities.keys(),
'budget changed entities roundtrip keys not equal')
for objclient in (client.budgetClient, client.catalogClient):
sync_data = objclient.get_sync_data_obj()
server_changed_entities = sync_data['changed_entities']

for key in server_catalog_changed_entities:
self.checkEqual(pushed_catalog_changed_entities[key], server_catalog_changed_entities[key],
'catalog changed entities roundtrip value for key %s not equal' % key)
for key in server_budget_changed_entities:
self.checkEqual(pushed_budget_changed_entities[key], server_budget_changed_entities[key],
'budget changed entities roundtrip value for key %s not equal' % key)
for key in server_changed_entities:
if key in objclient.obj.listfields:
if len(server_changed_entities[key]) == 0:
continue
obj_dict = server_changed_entities[key][0]
typ = objclient.obj.listfields[key]
obj_dict2 = typ.from_apidict(obj_dict).get_apidict()

@staticmethod
def checkEqual(l1, l2, msg):
ll1 = list(l1)
ll2 = list(l2)
try:
if len(ll1) == len(ll2) and len(set(ll1) - set(ll2)) == 0:
return True
except TypeError as e:
pass
raise AssertionError(msg)
diff = DictDiffer(obj_dict2, obj_dict)
for k in diff.changed():
AssertionError('changed {}: {}->{}'.format(k, obj_dict[k], obj_dict2[k]))
for k in diff.removed():
AssertionError('removed {}: {}'.format(k, obj_dict[k]))
for k in diff.added():
AssertionError('added {}: {}'.format(k, obj_dict2[k]))
elif key in objclient.obj.scalarfields:
obj_dict2 = objclient.obj.from_apidict(server_changed_entities).get_apidict()
if server_changed_entities[key] != obj_dict2[key]:
AssertionError('changed {}: {}->{}'.format(key, server_changed_entities[key], obj_dict2[key]))


if __name__ == "__main__":
unittest.main()
unittest.main()
Loading

0 comments on commit f6f51ea

Please sign in to comment.