Skip to content

Commit

Permalink
update test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Cuizi7 committed Sep 7, 2018
1 parent 77d64ce commit e0f76d8
Showing 1 changed file with 52 additions and 1 deletion.
53 changes: 52 additions & 1 deletion tests/api/test_api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"end_date": "2016-12-31",
"frequency": "1d",
"accounts": {
"stock": 1000000
"stock": 1000000,
"future": 1000000,
}
},
"extra": {
Expand Down Expand Up @@ -293,3 +294,53 @@ def handle_bar(_, bar_dict):
field, getattr(snapshot, field), field, getattr(bar, field)
)
return handle_bar


@as_test_strategy()
def test_get_position():

def assert_position(pos, obid, dir, today_quantity, old_quantity, avg_price):
assert pos.order_book_id == obid
assert pos.direction == dir, "Direction of {} is expected to be {} instead of {}".format(
pos.order_book_id, dir, pos.direction
)
assert pos.today_quantity == today_quantity
assert pos.old_quantity == old_quantity
assert pos.quantity == (today_quantity + old_quantity)
assert pos.avg_price == avg_price

def init(context):
context.counter = 0
context.expected_avg_price = None

def handle_bar(context, bar_dict):
context.counter += 1

if context.counter == 1:
order_shares("000001.XSHE", 300)
context.expected_avg_price = bar_dict["000001.XSHE"].close
elif context.counter == 5:
order_shares("000001.XSHE", -100)
elif context.counter == 10:
sell_open("RB1701", 5)
context.expected_avg_price = bar_dict["RB1701"].close
elif context.counter == 15:
buy_close("RB1701", 2)

if 1 <= context.counter < 5:
pos = get_positions()[0]
assert_position(pos, "000001.XSHE", POSITION_DIRECTION.LONG, 300, 0, context.expected_avg_price)
elif 5 <= context.counter < 10:
pos = get_position("000001.XSHE", POSITION_DIRECTION.LONG)
assert_position(pos, "000001.XSHE", POSITION_DIRECTION.LONG, 200, 0, context.expected_avg_price)
elif context.counter == 10:
pos = get_position("RB1701", POSITION_DIRECTION.SHORT)
assert_position(pos, "RB1701", POSITION_DIRECTION.SHORT, 5, 0, context.expected_avg_price)
elif 10 < context.counter < 15:
pos = get_position("RB1701", POSITION_DIRECTION.SHORT)
assert_position(pos, "RB1701", POSITION_DIRECTION.SHORT, 0, 5, context.expected_avg_price)
elif context.counter >= 15:
pos = get_position("RB1701", POSITION_DIRECTION.SHORT)
assert_position(pos, "RB1701", POSITION_DIRECTION.SHORT, 0, 3, context.expected_avg_price)

return init, handle_bar

0 comments on commit e0f76d8

Please sign in to comment.