Skip to content

Commit

Permalink
DEV: Add tz kwarg to get_datetime.
Browse files Browse the repository at this point in the history
  • Loading branch information
ssanderson committed Oct 23, 2014
1 parent 820115f commit 026ec99
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
58 changes: 57 additions & 1 deletion tests/test_algorithm.py
Expand Up @@ -15,7 +15,9 @@
import datetime
from datetime import timedelta
from mock import MagicMock
from nose_parameterized import parameterized
from six.moves import range
from textwrap import dedent
from unittest import TestCase

import numpy as np
Expand Down Expand Up @@ -67,7 +69,11 @@
record_variables,
)

from zipline.utils.test_utils import drain_zipline, assert_single_position
from zipline.utils.test_utils import (
assert_single_position,
drain_zipline,
to_utc,
)

from zipline.sources import (SpecificEquityTrades,
DataFrameSource,
Expand Down Expand Up @@ -729,6 +735,56 @@ def handle_data(context, data):
self.assertIsNot(output, None)


class TestGetDatetime(TestCase):

@parameterized.expand(
[
('default', None,),
('utc', 'UTC',),
('us_east', 'US/Eastern',),
]
)
def test_get_datetime(self, name, tz):

algo = dedent(
"""
import pandas as pd
from zipline.api import get_datetime
def initialize(context):
context.tz = {tz} or 'UTC'
context.first_bar = True
def handle_data(context, data):
if context.first_bar:
dt = get_datetime(context.tz)
if dt.tz.zone != context.tz:
raise ValueError("Mismatched Zone")
elif dt.tz_convert("US/Eastern").hour != 9:
raise ValueError("Mismatched Hour")
elif dt.tz_convert("US/Eastern").minute != 31:
raise ValueError("Mismatched Minute")
context.first_bar = False
""".format(tz=repr(tz))
)

start = to_utc('2014-01-02 9:31')
end = to_utc('2014-01-03 9:31')
source = RandomWalkSource(
start=start,
end=end,
)
sim_params = factory.create_simulation_parameters(
data_frequency='minute'
)
algo = TradingAlgorithm(
script=algo,
sim_params=sim_params,
)
algo.run(source)
self.assertFalse(algo.first_bar)


class TestTradingControls(TestCase):

def setUp(self):
Expand Down
4 changes: 3 additions & 1 deletion zipline/algorithm.py
Expand Up @@ -746,13 +746,15 @@ def on_dt_changed(self, dt):
self.blotter.set_date(dt)

@api_method
def get_datetime(self):
def get_datetime(self, tz=None):
"""
Returns a copy of the datetime.
"""
date_copy = copy(self.datetime)
assert date_copy.tzinfo == pytz.utc, \
"Algorithm should have a utc datetime"
if tz is not None:
date_copy = date_copy.tz_convert(tz)
return date_copy

def set_transact(self, transact):
Expand Down

0 comments on commit 026ec99

Please sign in to comment.