Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

Loading…

ENH: Add class to write Stata binary dta files #526

Merged
merged 17 commits into from

1 participant

Skipper Seabold
Skipper Seabold
Owner

This needs a little polishing but works for now for the below simple case. Should round trip the same.

dta = sm.datasets.macrodata.load().data
dtype = dta.dtype
dta = dta.astype(np.dtype([('year', 'i8'),('quarter', 'i4')] + dtype.descr[2:]))
writer = StataWriter("./try_dta.dta", dta)
writer.write_file()
dta2 = genfromdta('./try_dta.dta')

TODO:

  • Docstrings
  • Test on Python 3
  • Test with different encoding for strings - not handled yet
  • Test with object arrays
  • Test with datetime types - will need to convert to Stata time integer
  • Add support for pandas objects
  • Handle missing values
Skipper Seabold
Owner

All the TODOs are addressed and tested except Python 3 and unicode. This is ready to merge if we're happy with the API and tests pass on Python 3. I'm not entirely happy with the API, but I only really see this being useful for pandas objects and structured arrays. Since I seriously doubt anyone is going from structured array to dta, I'll focus on getting the API right with a PR to pandas. I'll also fix the unicode support there too and stick it back in here.

Skipper Seabold
Owner

Would like to go ahead and merge this. It's not perfect and I will probably continue to work on it incrementally - testing in production..., but I could use it for work right now. If this is not a great reason to merge, then I could figure something else out.

Skipper Seabold jseabold merged commit 25c8b7c into from
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
This page is out of date. Refresh to see the latest.
599 statsmodels/iolib/foreign.py
View
@@ -10,11 +10,14 @@
numpy.lib.io
"""
-from struct import unpack, calcsize
+from struct import unpack, calcsize, pack
+import datetime
import sys
import numpy as np
from numpy.lib._iotools import _is_string_like, easy_dtype
from statsmodels.compatnp.py3k import asbytes
+import statsmodels.tools.data as data_util
+from pandas import isnull
def is_py3():
@@ -24,6 +27,126 @@ def is_py3():
return False
PY3 = is_py3()
+_date_formats = ["%tc", "%tC", "%td", "%tw", "%tm", "%tq", "%th", "%ty"]
+
def _datetime_to_stata_elapsed(date, fmt):
    """
    Convert from datetime to SIF. http://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    date : datetime.datetime
        The date to convert to the Stata Internal Format given by fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty

    Returns
    -------
    int
        Elapsed time in the requested SIF, except for the unsupported
        %tC format where `date` is returned unchanged (with a warning).
    """
    if not isinstance(date, datetime.datetime):
        raise ValueError("date should be datetime.datetime format")
    stata_epoch = datetime.datetime(1960, 1, 1)
    if fmt in ["%tc", "tc"]:
        # milliseconds since the epoch; use floor division so the result
        # stays an integer on Python 3 as well as Python 2
        delta = date - stata_epoch
        return (delta.days * 86400000 + delta.seconds * 1000 +
                delta.microseconds // 1000)
    elif fmt in ["%tC", "tC"]:
        from warnings import warn
        warn("Stata Internal Format tC not supported.")
        return date
    elif fmt in ["%td", "td"]:
        return (date - stata_epoch).days
    elif fmt in ["%tw", "tw"]:
        # Stata-style weeks: 52 per year plus whole weeks elapsed in the
        # date's own year (not ISO calendar weeks)
        return (52 * (date.year - stata_epoch.year) +
                (date - datetime.datetime(date.year, 1, 1)).days // 7)
    elif fmt in ["%tm", "tm"]:
        return (12 * (date.year - stata_epoch.year) + date.month - 1)
    elif fmt in ["%tq", "tq"]:
        return 4 * (date.year - stata_epoch.year) + (date.month - 1) // 3
    elif fmt in ["%th", "th"]:
        return 2 * (date.year - stata_epoch.year) + int(date.month > 6)
    elif fmt in ["%ty", "ty"]:
        return date.year
    else:
        raise ValueError("fmt %s not understood" % fmt)
+
def _stata_elapsed_date_to_datetime(date, fmt):
    """
    Convert from SIF to datetime. http://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    date : int
        The Stata Internal Format date to convert to datetime according to fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty

    Examples
    --------
    >>> _stata_elapsed_date_to_datetime(52, "%tw")
    datetime.datetime(1961, 1, 1, 0, 0)

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1 yearly
    date - ty
        years since 0000
    """
    #NOTE: we could run into overflow / loss of precision situations here
    # casting to int, but I'm not sure what to do. datetime won't deal with
    # numpy types and numpy datetime isn't mature enough / we can't rely on
    # pandas version > 0.7.1
    date = int(date)
    stata_epoch = datetime.datetime(1960, 1, 1)
    # floor division (//) is used throughout so negative elapsed values
    # round toward the epoch consistently on Python 2 and 3
    if fmt in ["%tc", "tc"]:
        # stdlib timedelta handles pure millisecond offsets exactly; no
        # need for dateutil.relativedelta here
        return stata_epoch + datetime.timedelta(microseconds=date * 1000)
    elif fmt in ["%tC", "tC"]:
        from warnings import warn
        warn("Encountered %tC format. Leaving in Stata Internal Format.")
        return date
    elif fmt in ["%td", "td"]:
        return stata_epoch + datetime.timedelta(int(date))
    elif fmt in ["%tw", "tw"]:  # does not count leap days - 7 days is a week
        year = datetime.datetime(stata_epoch.year + date // 52, 1, 1)
        day_delta = (date % 52) * 7
        return year + datetime.timedelta(int(day_delta))
    elif fmt in ["%tm", "tm"]:
        year = stata_epoch.year + date // 12
        month_delta = (date % 12) + 1
        return datetime.datetime(year, month_delta, 1)
    elif fmt in ["%tq", "tq"]:
        year = stata_epoch.year + date // 4
        month_delta = (date % 4) * 3 + 1
        return datetime.datetime(year, month_delta, 1)
    elif fmt in ["%th", "th"]:
        year = stata_epoch.year + date // 2
        month_delta = (date % 2) * 6 + 1
        return datetime.datetime(year, month_delta, 1)
    elif fmt in ["%ty", "ty"]:
        if date > 0:
            return datetime.datetime(date, 1, 1)
        else:  # don't do negative years bc can't mix dtypes in column
            raise ValueError("Year 0 and before not implemented")
    else:
        raise ValueError("Date fmt %s not understood" % fmt)
+
### Helper classes for StataReader ###
class _StataMissingValue(object):
@@ -166,6 +289,9 @@ class StataReader(object):
[(251, np.int16),(252, np.int32),(253, int),
(254, np.float32), (255, np.float64)])
TYPE_MAP = range(251)+list('bhlfd')
+ #NOTE: technically, some of these are wrong. there are more numbers
+ # that can be represented. it's the 27 ABOVE and BELOW the max listed
+ # numeric data type in [U] 12.2.2 of the 11.2 manual
MISSING_VALUES = { 'b': (-127,100), 'h': (-32767, 32740), 'l':
(-2147483647, 2147483620), 'f': (-1.701e+38, +1.701e+38), 'd':
(-1.798e+308, +8.988e+307) }
@@ -359,11 +485,9 @@ def _parse_header(self, file_object):
encoding) for i in range(nvar)]
# ignore expansion fields
-# When reading, read five bytes; the last four bytes now tell you the size of
-# the next read, which you discard. You then continue like this until you
-# read 5 bytes of zeros.
-# TODO: The way I read this is that they both should be zero, but that's
-# not what we get.
+ # When reading, read five bytes; the last four bytes now tell you the
+ # size of the next read, which you discard. You then continue like
+ # this until you read 5 bytes of zeros.
while True:
data_type = unpack(byteorder+'b', self._file.read(1))[0]
@@ -420,8 +544,422 @@ def _next(self):
self._file.read(self._col_size(i))),
range(self._header['nvar']))
-def genfromdta(fname, missing_flt=-999., missing_str="", encoding=None,
- pandas=False):
def _open_file_binary_write(fname, encoding):
    """
    Open `fname` for binary writing, or pass through an open file-like.

    Parameters
    ----------
    fname : str or file-like
        Path to open, or any object with a ``write`` method, which is
        returned unchanged (caller is responsible for having opened it in
        binary mode).
    encoding : str
        Unused for the stream itself; kept for interface compatibility.
        Text is encoded by the writer before being written.
    """
    if hasattr(fname, 'write'):
        # assume the caller opened the buffer appropriately
        return fname
    # NOTE: binary-mode open() takes no encoding argument -- passing one
    # raises ValueError on Python 3, so open plainly on both versions
    return open(fname, "wb")
+
def _set_endianness(endianness):
    """
    Normalize an endianness specifier to the struct-module character.

    Accepts "<"/"little" and ">"/"big" (case-insensitive); anything else
    raises ValueError.
    """
    aliases = {"<": "<", "little": "<", ">": ">", "big": ">"}
    key = endianness.lower()
    if key not in aliases:  # pragma : no cover
        raise ValueError("Endianness %s not understood" % endianness)
    return aliases[key]
+
def _dtype_to_stata_type(dtype):
    """
    Converts dtype types to stata types. Returns the byte of the given
    ordinal. See TYPE_MAP and comments for an explanation. This is also
    explained in the dta spec.
    1 - 244 are strings of this length
    251 - chr(251) - for int8 and int16, byte
    252 - chr(252) - for int32, int
    253 - chr(253) - for int64, long
    254 - chr(254) - for float32, float
    255 - chr(255) - double, double

    If there are dates to convert, then dtype will already have the correct
    type inserted.
    """
    #TODO: expand to handle datetime to integer conversion
    #NOTE: np.bytes_ is the same object as np.string_ on old NumPy, and
    # np.string_ was removed in NumPy 2.0, so test against np.bytes_
    if dtype.type == np.bytes_:  # might have to coerce objects here
        return chr(dtype.itemsize)
    elif dtype == np.float64:
        return chr(255)
    elif dtype == np.float32:
        return chr(254)
    elif dtype == np.int64:
        return chr(253)
    elif dtype == np.int32:
        return chr(252)
    elif dtype == np.int8 or dtype == np.int16:  # ok to assume bytes?
        return chr(251)
    else:  # pragma : no cover
        raise ValueError("Data type %s not currently understood. "
                         "Please report an error to the developers." % dtype)
+
def _dtype_to_default_stata_fmt(dtype):
    """
    Maps numpy dtype to stata's default format for this type. Not terribly
    important since users can change this in Stata. Semantics are

    string  -> "%DDs" where DD is the length of the string
    float64 -> "%10.0g"
    float32 -> "%9.0g"
    int64   -> "%9.0g"
    int32   -> "%8.0g"
    int16   -> "%8.0g"
    int8    -> "%8.0g"
    """
    #NOTE: the docstring above now matches what the code actually returns
    # (int32 gets "%8.0g", and strings are "%DDs" not "%sDD")
    #TODO: expand this to handle a default datetime format?
    if dtype.type == np.bytes_:  # might have to coerce objects here
        return "%" + str(dtype.itemsize) + "s"
    elif dtype == np.float64:
        return "%10.0g"
    elif dtype == np.float32:
        return "%9.0g"
    elif dtype == np.int64:
        return "%9.0g"
    elif dtype == np.int32:
        return "%8.0g"
    elif dtype == np.int8 or dtype == np.int16:  # ok to assume bytes?
        return "%8.0g"
    else:  # pragma : no cover
        raise ValueError("Data type %s not currently understood. "
                         "Please report an error to the developers." % dtype)
+
def _pad_bytes(name, length):
    """
    Take a char string and pad it with null bytes until it is `length`
    characters long. Strings already at or beyond `length` are returned
    unchanged.
    """
    return name.ljust(length, "\x00")
+
def _default_names(nvar):
    """
    Returns default Stata names v1, v2, ... vnvar
    """
    names = []
    for idx in range(nvar):
        names.append("v%d" % (idx + 1))
    return names
+
def _convert_datetime_to_stata_type(fmt):
    """
    Converts from one of the stata date formats to a type in TYPE_MAP
    """
    # accept each base format with or without the leading "%"
    valid = set()
    for base in ("tc", "td", "tw", "tm", "tq", "th", "ty"):
        valid.add(base)
        valid.add("%" + base)
    if fmt in valid:
        return np.float64  # Stata expects doubles for SIFs
    raise ValueError("fmt %s not understood" % fmt)
+
def _maybe_convert_to_int_keys(convert_dates, varlist):
    """
    Resolve user-supplied convert_dates keys to integer column indices.

    Parameters
    ----------
    convert_dates : dict
        Maps a column name or integer index to a Stata date format, with
        or without a leading "%". The input dict is not modified.
    varlist : list of str
        Variable names used to resolve name keys to indices.

    Returns
    -------
    dict
        Maps int column index -> "%"-prefixed format string.

    Raises
    ------
    ValueError
        If a key is neither a name in `varlist` nor an int.
    """
    new_dict = {}
    for key in convert_dates:
        fmt = convert_dates[key]
        if not fmt.startswith("%"):  # make sure proper fmts
            fmt = "%" + fmt
        if key in varlist:
            new_dict[varlist.index(key)] = fmt
        else:
            if not isinstance(key, int):
                # NOTE: fixed message typo ("convery" -> "convert")
                raise ValueError("convert_dates key is not in varlist "
                                 "and is not an int")
            new_dict[key] = fmt
    return new_dict
+
class StataWriter(object):
    """
    A class for writing Stata binary dta files from array-like objects

    Parameters
    ----------
    fname : file path or buffer
        Where to save the dta file.
    data : array-like
        Array-like input to save. Pandas objects are also accepted.
    convert_dates : dict
        Dictionary mapping column of datetime types to the stata internal
        format that you want to use for the dates. Options are
        'tc', 'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either a
        number or a name.
    encoding : str
        Default is latin-1. Note that Stata does not support unicode.
    byteorder : str
        Can be ">", "<", "little", or "big". The default is None which uses
        `sys.byteorder`

    Returns
    -------
    writer : StataWriter instance
        The StataWriter instance has a write_file method, which will
        write the file to the given `fname`.

    Examples
    --------
    >>> writer = StataWriter('./data_file.dta', data)
    >>> writer.write_file()

    Or with dates

    >>> writer = StataWriter('./date_data_file.dta', date, {2 : 'tw'})
    >>> writer.write_file()
    """
    #type          code
    #--------------------
    #str1        1 = 0x01
    #str2        2 = 0x02
    #...
    #str244    244 = 0xf4
    #byte      251 = 0xfb  (sic)
    #int       252 = 0xfc
    #long      253 = 0xfd
    #float     254 = 0xfe
    #double    255 = 0xff
    #--------------------
    #NOTE: the byte type seems to be reserved for categorical variables
    # with a label, but the underlying variable is -127 to 100
    # we're going to drop the label and cast to int
    DTYPE_MAP = dict(zip(range(1,245), ['a' + str(i) for i in range(1,245)]) + \
                    [(251, np.int16),(252, np.int32),(253, int),
                     (254, np.float32), (255, np.float64)])
    TYPE_MAP = range(251)+list('bhlfd')
    # sentinel values Stata uses for missing numeric data, keyed by the
    # struct type character (see TYPE_MAP)
    MISSING_VALUES = { 'b': 101,
                       'h': 32741,
                       'l' : 2147483621,
                       'f': 1.7014118346046923e+38,
                       'd': 8.98846567431158e+307}

    def __init__(self, fname, data, convert_dates=None, encoding="latin-1",
                 byteorder=None):
        self._convert_dates = convert_dates
        # attach nobs, nvar, data, varlist, typlist
        if data_util._is_using_pandas(data, None):
            self._prepare_pandas(data)
        elif data_util._is_array_like(data, None):
            data = np.asarray(data)
            if data_util._is_structured_ndarray(data):
                self._prepare_structured_array(data)
            else:
                if convert_dates is not None:
                    raise ValueError("Not able to convert dates in a plain"
                                     " ndarray.")
                self._prepare_ndarray(data)
        else:  # pragma : no cover
            raise ValueError("Type %s for data not understood" % type(data))

        if byteorder is None:
            byteorder = sys.byteorder
        self._byteorder = _set_endianness(byteorder)
        self._encoding = encoding
        self._file = _open_file_binary_write(fname, encoding)

    def _prepare_structured_array(self, data):
        # collect metadata (nobs, nvar, names, types, formats) from a
        # numpy structured array
        self.nobs = len(data)
        self.nvar = len(data.dtype)
        self.data = data
        self.datarows = iter(data)
        dtype = data.dtype
        descr = dtype.descr
        if dtype.names is None:
            # NOTE: fixed NameError -- was `_default_names(nvar)` with an
            # undefined local `nvar`
            varlist = _default_names(self.nvar)
        else:
            varlist = dtype.names

        # check for datetime and change the type
        convert_dates = self._convert_dates
        if convert_dates is not None:
            convert_dates = _maybe_convert_to_int_keys(convert_dates,
                                                       varlist)
            self._convert_dates = convert_dates
            for key in convert_dates:
                descr[key] = (
                    descr[key][0],
                    _convert_datetime_to_stata_type(convert_dates[key])
                )
            dtype = np.dtype(descr)

        self.varlist = varlist
        self.typlist = [_dtype_to_stata_type(dtype[i])
                        for i in range(self.nvar)]
        self.fmtlist = [_dtype_to_default_stata_fmt(dtype[i])
                        for i in range(self.nvar)]
        # set the given format for the datetime cols
        if convert_dates is not None:
            for key in convert_dates:
                self.fmtlist[key] = convert_dates[key]

    def _prepare_ndarray(self, data):
        # collect metadata from a plain homogeneous ndarray; 1d input is
        # treated as a single column
        if data.ndim == 1:
            data = data[:,None]
        self.nobs, self.nvar = data.shape
        self.data = data
        self.datarows = iter(data)
        #TODO: this should be user settable
        dtype = data.dtype
        self.varlist = _default_names(self.nvar)
        self.typlist = [_dtype_to_stata_type(dtype) for i in range(self.nvar)]
        self.fmtlist = [_dtype_to_default_stata_fmt(dtype)
                        for i in range(self.nvar)]

    def _prepare_pandas(self, data):
        #NOTE: we might need a different API / class for pandas objects so
        # we can set different semantics - handle this with a PR to pandas.io
        class DataFrameRowIter(object):
            def __init__(self, data):
                self.data = data

            def __iter__(self):
                for i, row in data.iterrows():
                    yield row

        data = data.reset_index()
        self.datarows = DataFrameRowIter(data)
        self.nobs, self.nvar = data.shape
        self.data = data
        self.varlist = data.columns.tolist()
        dtypes = data.dtypes
        convert_dates = self._convert_dates
        if convert_dates is not None:
            convert_dates = _maybe_convert_to_int_keys(convert_dates,
                                                       self.varlist)
            self._convert_dates = convert_dates
            for key in convert_dates:
                new_type = _convert_datetime_to_stata_type(convert_dates[key])
                dtypes[key] = np.dtype(new_type)
        self.typlist = [_dtype_to_stata_type(dt) for dt in dtypes]
        self.fmtlist = [_dtype_to_default_stata_fmt(dt) for dt in dtypes]
        # set the given format for the datetime cols
        if convert_dates is not None:
            for key in convert_dates:
                self.fmtlist[key] = convert_dates[key]

    def write_file(self):
        """Write the complete dta file: header, descriptors, variable
        labels, expansion fields, and the data records."""
        self._write_header()
        self._write_descriptors()
        self._write_variable_labels()
        # write 5 zeros for expansion fields
        self._file.write(_pad_bytes("", 5))
        if self._convert_dates is None:
            self._write_data_nodates()
        else:
            self._write_data_dates()
        #self._write_value_labels()

    def _write_header(self, data_label=None, time_stamp=None):
        byteorder = self._byteorder
        # ds_format - just use 114
        self._file.write(pack("b", 114))
        # byteorder
        self._file.write(byteorder == ">" and "\x01" or "\x02")
        # filetype
        self._file.write("\x01")
        # unused
        self._file.write("\x00")
        # number of vars, 2 bytes
        self._file.write(pack(byteorder+"h", self.nvar)[:2])
        # number of obs, 4 bytes
        self._file.write(pack(byteorder+"i", self.nobs)[:4])
        # data label 81 bytes, char, null terminated
        if data_label is None:
            self._file.write(self._null_terminate(_pad_bytes("", 80),
                             self._encoding))
        else:
            self._file.write(self._null_terminate(_pad_bytes(data_label[:80],
                             80), self._encoding))
        # time stamp, 18 bytes, char, null terminated
        # format dd Mon yyyy hh:mm
        if time_stamp is None:
            time_stamp = datetime.datetime.now()
        # NOTE: fixed check -- `datetime` is the module; the type is
        # datetime.datetime (isinstance against a module raises TypeError)
        elif not isinstance(time_stamp, datetime.datetime):
            raise ValueError("time_stamp should be datetime type")
        self._file.write(self._null_terminate(
                         time_stamp.strftime("%d %b %Y %H:%M"),
                         self._encoding))

    def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
                           fmtlist=None, lbllist=None):
        nvar = self.nvar
        # typlist, length nvar, format byte array
        for typ in self.typlist:
            self._file.write(typ)

        # varlist, length 33*nvar, char array, null terminated
        for name in self.varlist:
            name = self._null_terminate(name, self._encoding)
            name = _pad_bytes(name[:32], 33)
            self._file.write(name)

        # srtlist, 2*(nvar+1), int array, encoded by byteorder
        srtlist = _pad_bytes("", (2*(nvar+1)))
        self._file.write(srtlist)

        # fmtlist, 49*nvar, char array
        for fmt in self.fmtlist:
            self._file.write(_pad_bytes(fmt, 49))

        # lbllist, 33*nvar, char array
        #NOTE: this is where you could get fancy with pandas categorical type
        for i in range(nvar):
            self._file.write(_pad_bytes("", 33))

    def _write_variable_labels(self, labels=None):
        # one 81-byte (blank) label per variable
        nvar = self.nvar
        if labels is None:
            for i in range(nvar):
                self._file.write(_pad_bytes("", 81))

    def _write_data_nodates(self):
        data = self.datarows
        byteorder = self._byteorder
        TYPE_MAP = self.TYPE_MAP
        typlist = self.typlist
        for row in data:
            for i,var in enumerate(row):
                typ = ord(typlist[i])
                if typ <= 244:  # we've got a string
                    # NOTE: fixed padding -- string fields occupy exactly
                    # `typ` bytes in a record, so pad to the declared
                    # width, not len(var) + 1
                    if len(var) < typ:
                        var = _pad_bytes(var, typ)
                    self._file.write(var)
                else:
                    self._file.write(pack(byteorder+TYPE_MAP[typ], var))

    def _write_data_dates(self):
        convert_dates = self._convert_dates
        data = self.datarows
        byteorder = self._byteorder
        TYPE_MAP = self.TYPE_MAP
        MISSING_VALUES = self.MISSING_VALUES
        typlist = self.typlist
        for row in data:
            for i,var in enumerate(row):
                typ = ord(typlist[i])
                #NOTE: If anyone finds this terribly slow, there is
                # a vectorized way to convert dates, see genfromdta for going
                # from int to datetime and reverse it. will copy data though
                if i in convert_dates:
                    var = _datetime_to_stata_elapsed(var, self.fmtlist[i])
                if typ <= 244:  # we've got a string
                    if isnull(var):
                        var = ""  # missing string
                    # NOTE: fixed padding -- pad to the declared field
                    # width `typ`, not len(var) + 1
                    if len(var) < typ:
                        var = _pad_bytes(var, typ)
                    self._file.write(var)
                else:
                    if isnull(var):  # this only matters for floats
                        # NOTE: fixed lookup -- MISSING_VALUES is keyed by
                        # the struct char ('b','h','l','f','d'), not the
                        # integer ordinal, so translate through TYPE_MAP
                        var = MISSING_VALUES[TYPE_MAP[typ]]
                    self._file.write(pack(byteorder+TYPE_MAP[typ], var))

    def _null_terminate(self, s, encoding):
        null_byte = asbytes('\x00')
        if PY3:
            # NOTE: fixed order -- encode the text first, then append the
            # null byte; appending bytes to str raises TypeError on Py3
            return s.encode(encoding) + null_byte
        else:
            return s + null_byte
+
+def genfromdta(fname, missing_flt=-999., encoding=None, pandas=False,
+ convert_dates=True):
"""
Returns an ndarray or DataFrame from a Stata .dta file.
@@ -432,18 +970,19 @@ def genfromdta(fname, missing_flt=-999., missing_str="", encoding=None,
missing_flt : numeric
The numeric value to replace missing values with. Will be used for
any numeric value.
- missing_str : str
- The string to replace missing values with for string variables.
encoding : string, optional
Used for Python 3 only. Encoding to use when reading the .dta file.
Defaults to `locale.getpreferredencoding`
pandas : bool
Optionally return a DataFrame instead of an ndarray
+ convert_dates : bool
+ If convert_dates is True, then Stata formatted dates will be converted
+ to datetime types according to the variable's format.
Notes
------
- Date types will be returned as their numeric value in Stata. A date
- parser is not written yet.
+ The tC Stata Internal Format for dates is not handled. These values
+ will be returned in SIF even if convert_dates is True.
"""
if isinstance(fname, basestring):
fhd = StataReader(open(fname, 'rb'), missing_values=False,
@@ -463,33 +1002,51 @@ def genfromdta(fname, missing_flt=-999., missing_str="", encoding=None,
nobs = header['nobs']
numvars = header['nvar']
varnames = header['varlist']
+ fmtlist = header['fmtlist']
dataname = header['data_label']
labels = header['vlblist'] # labels are thrown away unless DataArray
# type is used
data = np.zeros((nobs,numvars))
stata_dta = fhd.dataset()
- # key is given by np.issctype
- convert_missing = {
- True : missing_flt,
- False : missing_str}
-
dt = np.dtype(zip(varnames, types))
data = np.zeros((nobs), dtype=dt) # init final array
for rownum,line in enumerate(stata_dta):
# doesn't handle missing value objects, just casts
# None will only work without missing value object.
- if None in line:# and not remove_comma:
+ if None in line:
for i,val in enumerate(line):
+ #NOTE: This will only be scalar types because missing strings
+ # are empty not None in Stata
if val is None:
- line[i] = convert_missing[np.issctype(types[i])]
+ line[i] = missing_flt
data[rownum] = tuple(line)
- #TODO: make it possible to return plain array if all 'f8' for example
if pandas:
from pandas import DataFrame
- return DataFrame.from_records(data)
+ data = DataFrame.from_records(data)
+ if convert_dates:
+ cols = np.where(map(lambda x : x in _date_formats, fmtlist))[0]
+ for col in cols:
+ i = col
+ col = data.columns[col]
+ data[col] = data[col].apply(_stata_elapsed_date_to_datetime,
+ args=(fmtlist[i],))
+ elif convert_dates:
+ #date_cols = np.where(map(lambda x : x in _date_formats,
+ # fmtlist))[0]
+ # make the dtype for the datetime types
+ cols = np.where(map(lambda x : x in _date_formats, fmtlist))[0]
+ dtype = data.dtype.descr
+ dtype = [(dt[0], object) if i in cols else dt for i,dt in
+ enumerate(dtype)]
+ data = data.astype(dtype) # have to copy
+ for col in cols:
+ def convert(x):
+ return _stata_elapsed_date_to_datetime(x, fmtlist[col])
+ data[data.dtype.names[col]] = map(convert,
+ data[data.dtype.names[col]])
return data
def savetxt(fname, X, names=None, fmt='%.18e', delimiter=' '):
BIN  statsmodels/iolib/tests/data_missing.dta
View
Binary file not shown
9 statsmodels/iolib/tests/gen_dates.do
View
@@ -0,0 +1,9 @@
* Build the time_series_examples.dta fixture: read the SIF integer values
* from stata_dates.csv and attach one of each Stata date display format,
* so the reader/writer date conversions can be tested against Stata itself.
* NOTE(review): the insheet path below is machine-specific -- adjust it
* before re-running this script.
insheet using "/home/skipper/statsmodels/statsmodels-skipper/statsmodels/iolib/tests/stata_dates.csv"
format datetime_c %tc
format datetime_big_c %tC
format date %td
format weekly_date %tw
format monthly_date %tm
format quarterly_date %tq
format half_yearly_date %th
format yearly_date %ty
BIN  statsmodels/iolib/tests/results/time_series_examples.dta
View
Binary file not shown
3  statsmodels/iolib/tests/stata_dates.csv
View
@@ -0,0 +1,3 @@
+datetime_c,datetime_big_c,date,weekly_date,monthly_date,quarterly_date,half_yearly_date,yearly_date
+1479597200000,1479596223000,18282,2601,600,58,100,2010
+-14200000,-1479590,-2282,-601,-60,-18,-10,2
145 statsmodels/iolib/tests/test_foreign.py
View
@@ -2,10 +2,18 @@
Tests for iolib/foreign.py
"""
+from StringIO import StringIO
+from datetime import datetime
+
from numpy.testing import *
import numpy as np
import statsmodels.api as sm
import os
+from statsmodels.iolib.foreign import (StataWriter, genfromdta,
+ _datetime_to_stata_elapsed, _stata_elapsed_date_to_datetime)
+from statsmodels.datasets import macrodata
+from pandas import DataFrame, isnull
+import pandas.util.testing as ptesting
# Test precisions
DECIMAL_4 = 4
@@ -13,29 +21,150 @@
def test_genfromdta():
    # Compare genfromdta output against the macrodata results module
    # (originally generated with genfromtxt).
    #Test genfromdta vs. results/macrodta.npy created with genfromtxt.
    #NOTE: Stata handles data very oddly. Round tripping from csv to dta
    # to ndarray 2710.349 (csv) -> 2510.2491 (stata) -> 2710.34912109375
    # (dta/ndarray)
    curdir = os.path.dirname(os.path.abspath(__file__))
    #res2 = np.load(curdir+'/results/macrodata.npy')
    #res2 = res2.view((float,len(res2[0])))
    from results.macrodata import macrodata_result as res2
    res1 = genfromdta(curdir+'/../../datasets/macrodata/macrodata.dta')
    #res1 = res1.view((float,len(res1[0])))
    assert_array_equal(res1 == res2, True)
def test_genfromdta_pandas():
    # genfromdta(..., pandas=True) should reproduce macrodata.load_pandas()
    from pandas.util.testing import assert_frame_equal
    dta = macrodata.load_pandas().data
    curdir = os.path.dirname(os.path.abspath(__file__))
    res1 = sm.iolib.genfromdta(curdir+'/../../datasets/macrodata/macrodata.dta',
            pandas=True)
    res1 = res1.astype(float)
    assert_frame_equal(res1, dta)
def test_stata_writer_structured():
    # round-trip a structured array: StataWriter -> buffer -> genfromdta
    buf = StringIO()
    dta = macrodata.load().data
    dtype = dta.dtype
    # cast year/quarter to integer types the writer maps to Stata ints
    dta = dta.astype(np.dtype([('year', 'i8'),
                               ('quarter', 'i4')] + dtype.descr[2:]))
    writer = StataWriter(buf, dta)
    writer.write_file()
    buf.seek(0)
    dta2 = genfromdta(buf)
    assert_array_equal(dta, dta2)
+
def test_stata_writer_array():
    # round-trip a plain (homogeneous) ndarray; writer assigns default
    # names v1..v14, so rename the reference frame's columns to match
    buf = StringIO()
    dta = macrodata.load().data
    dta = DataFrame.from_records(dta)
    dta.columns = ["v%d" % i for i in range(1,15)]
    writer = StataWriter(buf, dta.values)
    writer.write_file()
    buf.seek(0)
    dta2 = genfromdta(buf)
    dta = dta.to_records(index=False)
    assert_array_equal(dta, dta2)
+
def test_missing_roundtrip():
    # NaN / inf / empty string should round-trip through StataWriter and
    # come back as missing via genfromdta(missing_flt=...)
    buf = StringIO()
    dta = np.array([(np.nan, np.inf, "")],
                   dtype=[("double_miss", float), ("float_miss", np.float32),
                          ("string_miss", "a1")])
    writer = StataWriter(buf, dta)
    writer.write_file()
    buf.seek(0)
    dta = genfromdta(buf, missing_flt=np.nan)
    assert_(isnull(dta[0][0]))
    assert_(isnull(dta[0][1]))
    assert_(dta[0][2] == "")

    # NOTE: fixed fixture path -- resolve relative to this file instead of
    # the current working directory, matching the other tests in this module
    curdir = os.path.dirname(os.path.abspath(__file__))
    dta = genfromdta(os.path.join(curdir, "data_missing.dta"),
                     missing_flt=-999)
    assert_(np.all([dta[0][i] == -999 for i in range(5)]))
+
def test_stata_writer_pandas():
    # round-trip a DataFrame; reset_index on the left side because the
    # writer serializes the index as a column
    buf = StringIO()
    dta = macrodata.load().data
    dtype = dta.dtype
    #as of 0.9.0 pandas only supports i8 and f8
    dta = dta.astype(np.dtype([('year', 'i8'),
                               ('quarter', 'i8')] + dtype.descr[2:]))
    dta = DataFrame.from_records(dta)
    writer = StataWriter(buf, dta)
    writer.write_file()
    buf.seek(0)
    dta2 = genfromdta(buf)
    ptesting.assert_frame_equal(dta.reset_index(), DataFrame.from_records(dta2))
+
def test_stata_writer_unicode():
    # make sure to test with characters outside the latin-1 encoding
    # TODO: placeholder -- unicode support is not implemented yet (see the
    # PR discussion); fill in once the writer handles non-latin-1 text
    pass
+
def test_genfromdta_datetime():
    # expected datetimes for the two rows of time_series_examples.dta;
    # the %tC column (index 1) is left in SIF because tC is unsupported
    results = [(datetime(2006, 11, 19, 23, 13, 20), 1479596223000,
        datetime(2010, 1, 20), datetime(2010, 1, 8), datetime(2010, 1, 1),
        datetime(1974, 7, 1), datetime(2010, 1, 1), datetime(2010, 1, 1)),
        (datetime(1959, 12, 31, 20, 3, 20), -1479590, datetime(1953, 10, 2),
        datetime(1948, 6, 10), datetime(1955, 1, 1), datetime(1955, 7, 1),
        datetime(1955, 1, 1), datetime(2, 1, 1))]
    # NOTE: fixed fixture path -- resolve relative to this file instead of
    # the current working directory, matching the other tests in this module
    curdir = os.path.dirname(os.path.abspath(__file__))
    fname = os.path.join(curdir, "results", "time_series_examples.dta")
    dta = genfromdta(fname)
    assert_array_equal(dta[0].tolist(), results[0])
    assert_array_equal(dta[1].tolist(), results[1])

    dta = genfromdta(fname, pandas=True)
    assert_array_equal(dta.irow(0).tolist(), results[0])
    assert_array_equal(dta.irow(1).tolist(), results[1])
+
def test_date_converters():
    # each SIF format should round-trip exactly:
    # int -> _stata_elapsed_date_to_datetime -> _datetime_to_stata_elapsed
    cases = {
        "tc": [-1479597200000, -1e6, -1e5, -100, 1e5, 1e6, 1479597200000],
        "td": [-1e5, -1200, -800, -365, -50, 0, 50, 365, 800, 1200, 1e5],
        "tw": [-1e4, -1e2, -53, -52, -51, 0, 51, 52, 53, 1e2, 1e4],
        "tm": [-1e4, -1e3, -100, -13, -12, -11, 0, 11, 12, 13, 100, 1e3,
               1e4],
        "tq": [-100, -50, -5, -4, -3, 0, 3, 4, 5, 50, 100],
        "th": [-50, 40, 30, 10, 3, 2, 1, 0, 1, 2, 3, 10, 30, 40, 50],
        "ty": [1, 50, 500, 1000, 1500, 1975, 2075],
    }
    for fmt in cases:
        for val in cases[fmt]:
            assert_equal(_datetime_to_stata_elapsed(
                _stata_elapsed_date_to_datetime(val, fmt), fmt), val)
+
def test_datetime_roundtrip():
    # datetime column written with the %tm SIF should round-trip through
    # the writer and reader, for both a structured array and a DataFrame
    dta = np.array([(1, datetime(2010, 1, 1), 2),
                    (2, datetime(2010, 2, 1), 3),
                    (4, datetime(2010, 3, 1), 5)],
                   dtype=[('var1', float), ('var2', object), ('var3', float)])
    buf = StringIO()
    writer = StataWriter(buf, dta, {"var2" : "tm"})
    writer.write_file()
    buf.seek(0)
    dta2 = genfromdta(buf)
    assert_equal(dta, dta2)

    # same round trip via pandas; drop the serialized index column before
    # comparing
    dta = DataFrame.from_records(dta)
    buf = StringIO()
    writer = StataWriter(buf, dta, {"var2" : "tm"})
    writer.write_file()
    buf.seek(0)
    dta2 = genfromdta(buf, pandas=True)
    ptesting.assert_frame_equal(dta, dta2.drop('index', axis=1))
+
if __name__ == "__main__":
import nose
Something went wrong with that request. Please try again.