diff --git a/doc/source/io.rst b/doc/source/io.rst index e71b4134f5b9c..c4865fddb099b 100644 --- a/doc/source/io.rst +++ b/doc/source/io.rst @@ -3821,22 +3821,41 @@ outside of this range, the variable is cast to ``int16``. Reading from Stata format ~~~~~~~~~~~~~~~~~~~~~~~~~ -The top-level function ``read_stata`` will read a dta files -and return a DataFrame. Alternatively, the class :class:`~pandas.io.stata.StataReader` -can be used if more granular access is required. :class:`~pandas.io.stata.StataReader` -reads the header of the dta file at initialization. The method -:func:`~pandas.io.stata.StataReader.data` reads and converts observations to a DataFrame. +The top-level function ``read_stata`` will read a dta file and return +either a DataFrame or a :class:`~pandas.io.stata.StataReader` that can +be used to read the file incrementally. .. ipython:: python pd.read_stata('stata.dta') +.. versionadded:: 0.16.0 + +Specifying a ``chunksize`` yields a +:class:`~pandas.io.stata.StataReader` instance that can be used to +read ``chunksize`` lines from the file at a time. The ``StataReader`` +object can be used as an iterator:: + + reader = pd.read_stata('stata.dta', chunksize=1000) + for df in reader: + do_something(df) + +For more fine-grained control, use ``iterator=True`` and specify +``chunksize`` with each call to +:func:`~pandas.io.stata.StataReader.read`. + +.. ipython:: python + + reader = pd.read_stata('stata.dta', iterator=True) + chunk1 = reader.read(10) + chunk2 = reader.read(20) + Currently the ``index`` is retrieved as a column. The parameter ``convert_categoricals`` indicates whether value labels should be read and used to create a ``Categorical`` variable from them. Value labels can -also be retrieved by the function ``variable_labels``, which requires data to be -called before use (see ``pandas.io.stata.StataReader``). +also be retrieved by the function ``value_labels``, which requires :func:`~pandas.io.stata.StataReader.read` +to be called before use. The parameter ``convert_missing`` indicates whether missing value representations in Stata should be preserved. If ``False`` (the default), diff --git a/doc/source/release.rst b/doc/source/release.rst index 164e381499490..0912a11e28801 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -55,6 +55,8 @@ performance improvements along with a large number of bug fixes. Highlights include: +- Allow Stata files to be read incrementally and support long strings in Stata files (:issue:`9493`); see :ref:`here`. + See the :ref:`v0.16.0 Whatsnew ` overview or the issue tracker on GitHub for an extensive list of all API changes, enhancements and bugs that have been fixed in 0.16.0.
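A fuller sketch of the incremental pattern the documentation above describes, for readers who want a runnable starting point; the file name ``stata.dta`` and the column ``x`` are hypothetical, and any per-chunk reduction works the same way since each chunk is an ordinary DataFrame::

    import pandas as pd

    total = 0.0
    reader = pd.read_stata('stata.dta', chunksize=1000)
    for chunk in reader:
        # Each chunk is a DataFrame of up to 1000 rows; as the patch below
        # arranges, its row index continues across chunks instead of
        # restarting at 0.
        total += chunk['x'].sum()

Per-chunk aggregation like this keeps memory bounded by ``chunksize`` rather than by the size of the dta file.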
diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 0d6e554b8b474..7dd32fd00a4d2 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -22,51 +22,144 @@ from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \ zip, BytesIO +from pandas.util.decorators import Appender import pandas.core.common as com from pandas.io.common import get_filepath_or_buffer from pandas.lib import max_len_string_array, infer_dtype from pandas.tslib import NaT, Timestamp +_statafile_processing_params1 = """\ +convert_dates : boolean, defaults to True + Convert date variables to DataFrame time values +convert_categoricals : boolean, defaults to True + Read value labels and convert columns to Categorical/Factor variables""" + +_encoding_params = """\ +encoding : string, None or encoding + Encoding used to parse the files. Note that Stata doesn't + support unicode. None defaults to iso-8859-1.""" + +_statafile_processing_params2 = """\ +index : identifier of index column + identifier of column that should be used as index of the DataFrame +convert_missing : boolean, defaults to False + Flag indicating whether to convert missing values to their Stata + representations. If False, missing values are replaced with nans. + If True, columns containing missing values are returned with + object data types and missing values are represented by + StataMissingValue objects. +preserve_dtypes : boolean, defaults to True + Preserve Stata datatypes. If False, numeric data are upcast to pandas + default types for foreign data (float64 or int64) +columns : list or None + Columns to retain. Columns will be returned in the given order. None + returns all columns +order_categoricals : boolean, defaults to True + Flag indicating whether converted categorical data are ordered.""" + +_chunksize_params = """\ +chunksize : int, default None + Return a StataReader object for iteration; returns chunks with + the given number of lines""" + +_iterator_params = """\ +iterator : boolean, default False + Return a StataReader object""" + +_read_stata_doc = """Read Stata file into DataFrame + +Parameters +---------- +filepath_or_buffer : string or file-like object + Path to .dta file or object implementing a binary read() function +%s +%s +%s +%s +%s + +Returns +------- +DataFrame or StataReader + +Examples +-------- +Read a Stata dta file: +>>> df = pandas.read_stata('filename.dta') + +Read a Stata dta file in 10,000 line chunks: +>>> itr = pandas.read_stata('filename.dta', chunksize=10000) +>>> for chunk in itr: +>>>     do_something(chunk) +""" % (_statafile_processing_params1, _encoding_params, + _statafile_processing_params2, _chunksize_params, + _iterator_params) + +_data_method_doc = """Reads observations from Stata file, converting them into a DataFrame + +This is a legacy method. Use `read` in new code. + +Parameters +---------- +%s +%s + +Returns +------- +DataFrame +""" % (_statafile_processing_params1, _statafile_processing_params2) + + +_read_method_doc = """\ +Reads observations from Stata file, converting them into a DataFrame + +Parameters +---------- +nrows : int + Number of lines to read from data file; if None, read whole file. +%s +%s + +Returns +------- +DataFrame +""" % (_statafile_processing_params1, _statafile_processing_params2) + + +_stata_reader_doc = """\ +Class for reading Stata dta files.
+ +Parameters +---------- +path_or_buf : string or file-like object + Path to .dta file or object implementing a binary read() function +%s +%s +%s +%s +""" % (_statafile_processing_params1, _statafile_processing_params2, + _encoding_params, _chunksize_params) + + +@Appender(_read_stata_doc) def read_stata(filepath_or_buffer, convert_dates=True, convert_categoricals=True, encoding=None, index=None, convert_missing=False, preserve_dtypes=True, columns=None, - order_categoricals=True): - """ - Read Stata file into DataFrame + order_categoricals=True, chunksize=None, iterator=False): - Parameters - ---------- - filepath_or_buffer : string or file-like object - Path to .dta file or object implementing a binary read() functions - convert_dates : boolean, defaults to True - Convert date variables to DataFrame time values - convert_categoricals : boolean, defaults to True - Read value labels and convert columns to Categorical/Factor variables - encoding : string, None or encoding - Encoding used to parse the files. Note that Stata doesn't - support unicode. None defaults to iso-8859-1. - index : identifier of index column - identifier of column that should be used as index of the DataFrame - convert_missing : boolean, defaults to False - Flag indicating whether to convert missing values to their Stata - representations. If False, missing values are replaced with nans. - If True, columns containing missing values are returned with - object data types and missing values are represented by - StataMissingValue objects. - preserve_dtypes : boolean, defaults to True - Preserve Stata datatypes. If False, numeric data are upcast to pandas - default types for foreign data (float64 or int64) - columns : list or None - Columns to retain. Columns will be returned in the given order. None - returns all columns - order_categoricals : boolean, defaults to True - Flag indicating whether converted categorical data are ordered.
- """ - reader = StataReader(filepath_or_buffer, encoding) + reader = StataReader(filepath_or_buffer, + convert_dates=convert_dates, + convert_categoricals=convert_categoricals, + index=index, convert_missing=convert_missing, + preserve_dtypes=preserve_dtypes, + columns=columns, + order_categoricals=order_categoricals, + chunksize=chunksize, encoding=encoding) + + if iterator or chunksize: + return reader - return reader.data(convert_dates, convert_categoricals, index, - convert_missing, preserve_dtypes, columns, - order_categoricals) + return reader.read() _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] @@ -139,8 +232,10 @@ def convert_year_month_safe(year, month): if year.max() < MAX_YEAR and year.min() > MIN_YEAR: return to_datetime(100 * year + month, format='%Y%m') else: + index = getattr(year, 'index', None) return Series( - [datetime.datetime(y, m, 1) for y, m in zip(year, month)]) + [datetime.datetime(y, m, 1) for y, m in zip(year, month)], + index=index) def convert_year_days_safe(year, days): """ @@ -150,9 +245,10 @@ def convert_year_days_safe(year, days): if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR: return to_datetime(year, format='%Y') + to_timedelta(days, unit='d') else: + index = getattr(year, 'index', None) value = [datetime.datetime(y, 1, 1) + relativedelta(days=int(d)) for y, d in zip(year, days)] - return Series(value) + return Series(value, index=index) def convert_delta_safe(base, deltas, unit): """ @@ -160,18 +256,18 @@ def convert_delta_safe(base, deltas, unit): versions if the deltas satisfy restrictions required to be expressed as dates in pandas. """ + index = getattr(deltas, 'index', None) if unit == 'd': if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA: values = [base + relativedelta(days=int(d)) for d in deltas] - return Series(values) + return Series(values, index=index) elif unit == 'ms': if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA: values = [base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas] - return Series(values) + return Series(values, index=index) else: raise ValueError('format not understood') - base = to_datetime(base) deltas = to_timedelta(deltas, unit=unit) return base + deltas @@ -226,6 +322,7 @@ def convert_delta_safe(base, deltas, unit): if has_bad_values: # Restore NaT for bad values conv_dates[bad_locs] = NaT + return conv_dates @@ -717,7 +814,7 @@ def __init__(self, encoding): self.DTYPE_MAP_XML = \ dict( [ - (32768, np.string_), + (32768, np.uint8), # Keys to GSO (65526, np.float64), (65527, np.float32), (65528, np.int32), @@ -729,6 +826,7 @@ def __init__(self, encoding): self.TYPE_MAP_XML = \ dict( [ + (32768, 'L'), (65526, 'd'), (65527, 'f'), (65528, 'l'), @@ -776,7 +874,8 @@ def __init__(self, encoding): 'h': 'i2', 'l': 'i4', 'f': 'f4', - 'd': 'f8' + 'd': 'f8', + 'L': 'u8' } # Reserved words cannot be used as variable names @@ -797,42 +896,39 @@ def _decode_bytes(self, str, errors=None): else: return str - class StataReader(StataParser): - """ - Class for working with a Stata dataset. There are two possibilities for - usage: - - * The from_dta() method on the DataFrame class. - This will return a DataFrame with the Stata dataset. Note that when - using the from_dta() method, you will not have access to - meta-information like variable labels or the data label. - - * Work with this object directly. Upon instantiation, the header of the - Stata data file is read, giving you access to attributes like - variable_labels(), data_label(), nobs(), ... 
A DataFrame with the data - is returned by the read() method; this will also fill up the - value_labels. Note that calling the value_labels() method will result in - an error if the read() method has not been called yet. This is because - the value labels are stored at the end of a Stata dataset, after the - data. - - Parameters - ---------- - path_or_buf : string or file-like object - Path to .dta file or object implementing a binary read() functions - encoding : string, None or encoding - Encoding used to parse the files. Note that Stata doesn't - support unicode. None defaults to iso-8859-1. - """ + __doc__ = _stata_reader_doc - def __init__(self, path_or_buf, encoding='iso-8859-1'): + def __init__(self, path_or_buf, convert_dates=True, + convert_categoricals=True, index=None, + convert_missing=False, preserve_dtypes=True, + columns=None, order_categoricals=True, + encoding='iso-8859-1', chunksize=None): super(StataReader, self).__init__(encoding) self.col_sizes = () + + # Arguments to the reader (can be temporarily overridden in + # calls to read). + self._convert_dates = convert_dates + self._convert_categoricals = convert_categoricals + self._index = index + self._convert_missing = convert_missing + self._preserve_dtypes = preserve_dtypes + self._columns = columns + self._order_categoricals = order_categoricals + self._encoding = encoding + self._chunksize = chunksize + + # State variables for the file self._has_string_data = False self._missing_values = False - self._data_read = False + self._can_read_value_labels = False + self._column_selector_set = False self._value_labels_read = False + self._data_read = False + self._dtype = None + self._lines_read = 0 + self._native_byteorder = _set_endianness(sys.byteorder) if isinstance(path_or_buf, str): path_or_buf, encoding = get_filepath_or_buffer( @@ -917,8 +1013,6 @@ def _read_header(self): for typ in typlist: if typ <= 2045: self.typlist[i] = typ - elif typ == 32768: - raise ValueError("Long strings are not supported") else: self.typlist[i] = self.TYPE_MAP_XML[typ] i += 1 @@ -1060,9 +1156,13 @@ def _read_header(self): self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0 - """Calculate size of a data record.""" + # calculate size of a data record self.col_sizes = lmap(lambda x: self._calcsize(x), self.typlist) + # remove format details from %td + self.fmtlist = ["%td" if x.startswith("%td") else x for x in self.fmtlist] + + def _calcsize(self, fmt): return (type(fmt) is int and fmt or struct.calcsize(self.byteorder + fmt)) @@ -1070,11 +1170,7 @@ def _calcsize(self, fmt): def _null_terminate(self, s): if compat.PY3 or self._encoding is not None: # have bytes not strings, # so must decode - null_byte = b"\0" - try: - s = s[:s.index(null_byte)] - except: - pass + s = s.partition(b"\0")[0] return s.decode(self._encoding or self._default_encoding) else: null_byte = "\0" @@ -1084,30 +1180,30 @@ return s def _read_value_labels(self): + if self.format_version <= 108: + # Value labels are not supported in version 108 and earlier. + return + if self._value_labels_read: + # Don't read twice + return + if self.format_version >= 117: self.path_or_buf.seek(self.seek_value_labels) else: - if not self._data_read: - raise Exception("Data has not been read. 
Because of the " "layout of Stata files, this is necessary " "before reading value labels.") - if self._value_labels_read: - raise Exception("Value labels have already been read.") + offset = self.nobs * self._dtype.itemsize + self.path_or_buf.seek(self.data_location + offset) + self._value_labels_read = True self.value_label_dict = dict() - if self.format_version <= 108: - # Value labels are not supported in version 108 and earlier. - return - while True: if self.format_version >= 117: if self.path_or_buf.read(5) == b'</val': # <lbl> - break # end o f variable lable table + break # end of variable label table slength = self.path_or_buf.read(4) if not slength: - break # end of variable lable table (format < 117) + break # end of variable label table (format < 117) labname = self._null_terminate(self.path_or_buf.read(33)) self.path_or_buf.read(3) # padding @@ -1141,72 +1237,126 @@ def _read_strls(self): if self.path_or_buf.read(3) != b'GSO': break - v_o = struct.unpack(self.byteorder + 'L', - self.path_or_buf.read(8))[0] - typ = self.path_or_buf.read(1) + v_o = struct.unpack(self.byteorder + 'Q', self.path_or_buf.read(8))[0] + typ = struct.unpack('B', self.path_or_buf.read(1))[0] length = struct.unpack(self.byteorder + 'I', self.path_or_buf.read(4))[0] - self.GSO[v_o] = self.path_or_buf.read(length-1) - self.path_or_buf.read(1) # zero-termination + va = self.path_or_buf.read(length) + if typ == 130: + va = va[0:-1].decode(self._encoding or self._default_encoding) + self.GSO[v_o] = va + + # legacy + @Appender('DEPRECATED: ' + _data_method_doc) + def data(self, **kwargs): + + import warnings + warnings.warn("'data' is deprecated, use 'read' instead") + + if self._data_read: + raise Exception("Data has already been read.") + self._data_read = True + + return self.read(None, **kwargs) + - def data(self, convert_dates=True, convert_categoricals=True, index=None, - convert_missing=False, preserve_dtypes=True, columns=None, - order_categoricals=True): + def __iter__(self): + try: + if self._chunksize: + while True: + yield self.read(self._chunksize) + else: + yield self.read() + except StopIteration: + pass + + + def get_chunk(self, size=None): """ - Reads observations from Stata file, converting them into a dataframe + Read lines from Stata file and return as a DataFrame Parameters ---------- - convert_dates : boolean, defaults to True - Convert date variables to DataFrame time values - convert_categoricals : boolean, defaults to True - Read value labels and convert columns to Categorical/Factor - variables - index : identifier of index column - identifier of column that should be used as index of the DataFrame - convert_missing : boolean, defaults to False - Flag indicating whether to convert missing values to their Stata - representation. If False, missing values are replaced with - nans. If True, columns containing missing values are returned with - object data types and missing values are represented by - StataMissingValue objects. - preserve_dtypes : boolean, defaults to True - Preserve Stata datatypes. If False, numeric data are upcast to - pandas default types for foreign data (float64 or int64) - columns : list or None - Columns to retain. Columns will be returned in the given order. - None returns all columns - order_categoricals : boolean, defaults to True - Flag indicating whether converted categorical data are ordered. + size : int, defaults to None + Number of lines to read. Defaults to ``chunksize`` if it was + set; if both are None, reads the remainder of the file. 
Returns ------- - y : DataFrame instance + DataFrame """ - self._missing_values = convert_missing - if self._data_read: - raise Exception("Data has already been read.") - self._data_read = True - - if self.format_version >= 117: + if size is None: + size = self._chunksize + return self.read(nrows=size) + + + @Appender(_read_method_doc) + def read(self, nrows=None, convert_dates=None, + convert_categoricals=None, index=None, + convert_missing=None, preserve_dtypes=None, + columns=None, order_categoricals=None): + + # Handle empty file or chunk. If reading incrementally raise + # StopIteration. If reading the whole thing return an empty + # data frame. + if (self.nobs == 0) and (nrows is None): + self._can_read_value_labels = True + self._data_read = True + return DataFrame(columns=self.varlist) + + # Handle options + if convert_dates is None: + convert_dates = self._convert_dates + if convert_categoricals is None: + convert_categoricals = self._convert_categoricals + if convert_missing is None: + convert_missing = self._convert_missing + if preserve_dtypes is None: + preserve_dtypes = self._preserve_dtypes + if columns is None: + columns = self._columns + if order_categoricals is None: + order_categoricals = self._order_categoricals + + if nrows is None: + nrows = self.nobs + + if (self.format_version >= 117) and (self._dtype is None): + self._can_read_value_labels = True self._read_strls() + # Setup the dtype. + if self._dtype is None: + dtype = [] # Convert struct data types to numpy data type + for i, typ in enumerate(self.typlist): + if typ in self.NUMPY_TYPE_MAP: + dtype.append(('s' + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) + else: + dtype.append(('s' + str(i), 'S' + str(typ))) + dtype = np.dtype(dtype) + self._dtype = dtype + # Read data - count = self.nobs - dtype = [] # Convert struct data types to numpy data type - for i, typ in enumerate(self.typlist): - if typ in self.NUMPY_TYPE_MAP: - dtype.append(('s' + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) - else: - dtype.append(('s' + str(i), 'S' + str(typ))) - dtype = np.dtype(dtype) - read_len = count * dtype.itemsize - self.path_or_buf.seek(self.data_location) - data = np.frombuffer(self.path_or_buf.read(read_len),dtype=dtype,count=count) + dtype = self._dtype + max_read_len = (self.nobs - self._lines_read) * dtype.itemsize + read_len = nrows * dtype.itemsize + read_len = min(read_len, max_read_len) + if read_len <= 0: + # Iterator has finished, should never be here unless + # we are reading the file incrementally + self._read_value_labels() + raise StopIteration + offset = self._lines_read * dtype.itemsize + self.path_or_buf.seek(self.data_location + offset) + read_lines = min(nrows, self.nobs - self._lines_read) + data = np.frombuffer(self.path_or_buf.read(read_len), dtype=dtype, + count=read_lines) + self._lines_read += read_lines + if self._lines_read == self.nobs: + self._can_read_value_labels = True + self._data_read = True # if necessary, swap the byte order to native here if self.byteorder != self._native_byteorder: data = data.byteswap().newbyteorder() - self._data_read = True if convert_categoricals: self._read_value_labels() @@ -1217,39 +1367,22 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, data = DataFrame.from_records(data, index=index) data.columns = self.varlist - if columns is not None: - column_set = set(columns) - if len(column_set) != len(columns): - raise ValueError('columns contains duplicate entries') - unmatched = column_set.difference(data.columns) - if 
unmatched: - raise ValueError('The following columns were not found in the ' - 'Stata data set: ' + - ', '.join(list(unmatched))) - # Copy information for retained columns for later processing - dtyplist = [] - typlist = [] - fmtlist = [] - lbllist = [] - matched = set() - for i, col in enumerate(data.columns): - if col in column_set: - matched.update([col]) - dtyplist.append(self.dtyplist[i]) - typlist.append(self.typlist[i]) - fmtlist.append(self.fmtlist[i]) - lbllist.append(self.lbllist[i]) + # If index is not specified, use actual row number rather than + # restarting at 0 for each chunk. + if index is None: + ix = np.arange(self._lines_read - read_lines, self._lines_read) + data = data.set_index(ix) - data = data[columns] - self.dtyplist = dtyplist - self.typlist = typlist - self.fmtlist = fmtlist - self.lbllist = lbllist + if columns is not None: + data = self._do_select_columns(data, columns) + # Decode strings for col, typ in zip(data, self.typlist): if type(typ) is int: data[col] = data[col].apply(self._null_terminate, convert_dtype=True) + data = self._insert_strls(data) + cols_ = np.where(self.dtyplist)[0] # Convert columns (if needed) to match input type @@ -1269,7 +1402,39 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, data = DataFrame.from_items(data_formatted) del data_formatted + self._do_convert_missing(data, convert_missing) + + if convert_dates: + cols = np.where(lmap(lambda x: x in _date_formats, + self.fmtlist))[0] + for i in cols: + col = data.columns[i] + data[col] = _stata_elapsed_date_to_datetime_vec(data[col], self.fmtlist[i]) + + if convert_categoricals and self.value_label_dict: + data = self._do_convert_categoricals(data, self.value_label_dict, self.lbllist, + order_categoricals) + + if not preserve_dtypes: + retyped_data = [] + convert = False + for col in data: + dtype = data[col].dtype + if dtype in (np.float16, np.float32): + dtype = np.float64 + convert = True + elif dtype in (np.int8, np.int16, np.int32): + dtype = np.int64 + convert = True + retyped_data.append((col, data[col].astype(dtype))) + if convert: + data = DataFrame.from_items(retyped_data) + + return data + + def _do_convert_missing(self, data, convert_missing): # Check for missing values, and replace if found + for i, colname in enumerate(data): fmt = self.typlist[i] if fmt not in self.VALID_RANGE: @@ -1282,7 +1447,7 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, if not missing.any(): continue - if self._missing_values: # Replacement follows Stata notation + if convert_missing: # Replacement follows Stata notation missing_loc = np.argwhere(missing) umissing, umissing_loc = np.unique(series[missing], return_inverse=True) @@ -1301,48 +1466,72 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, data[colname] = replacement - if convert_dates: - cols = np.where(lmap(lambda x: x in _date_formats, - self.fmtlist))[0] - for i in cols: - col = data.columns[i] - data[col] = _stata_elapsed_date_to_datetime_vec(data[col], self.fmtlist[i]) + def _insert_strls(self, data): + if not hasattr(self, 'GSO') or len(self.GSO) == 0: + return data + for i, typ in enumerate(self.typlist): + if typ != 'L': + continue + data.iloc[:, i] = [self.GSO[k] for k in data.iloc[:, i]] + return data - if convert_categoricals and self.value_label_dict: - value_labels = list(compat.iterkeys(self.value_label_dict)) - cat_converted_data = [] - for col, label in zip(data, self.lbllist): - if label in value_labels: - # Explicit call with ordered=True 
- cat_data = Categorical(data[col], ordered=order_categoricals) - value_label_dict = self.value_label_dict[label] - categories = [] - for category in cat_data.categories: - if category in value_label_dict: - categories.append(value_label_dict[category]) - else: - categories.append(category) # Partially labeled - cat_data.categories = categories - cat_converted_data.append((col, cat_data)) - else: - cat_converted_data.append((col, data[col])) - data = DataFrame.from_items(cat_converted_data) + def _do_select_columns(self, data, columns): - if not preserve_dtypes: - retyped_data = [] - convert = False - for col in data: - dtype = data[col].dtype - if dtype in (np.float16, np.float32): - dtype = np.float64 - convert = True - elif dtype in (np.int8, np.int16, np.int32): - dtype = np.int64 - convert = True - retyped_data.append((col, data[col].astype(dtype))) - if convert: - data = DataFrame.from_items(retyped_data) + if not self._column_selector_set: + column_set = set(columns) + if len(column_set) != len(columns): + raise ValueError('columns contains duplicate entries') + unmatched = column_set.difference(data.columns) + if unmatched: + raise ValueError('The following columns were not found in the ' + 'Stata data set: ' + + ', '.join(list(unmatched))) + # Copy information for retained columns for later processing + dtyplist = [] + typlist = [] + fmtlist = [] + lbllist = [] + matched = set() + for i, col in enumerate(data.columns): + if col in column_set: + matched.update([col]) + dtyplist.append(self.dtyplist[i]) + typlist.append(self.typlist[i]) + fmtlist.append(self.fmtlist[i]) + lbllist.append(self.lbllist[i]) + self.dtyplist = dtyplist + self.typlist = typlist + self.fmtlist = fmtlist + self.lbllist = lbllist + self._column_selector_set = True + + return data[columns] + + + def _do_convert_categoricals(self, data, value_label_dict, lbllist, order_categoricals): + """ + Converts categorical columns to Categorical type. + """ + value_labels = list(compat.iterkeys(value_label_dict)) + cat_converted_data = [] + for col, label in zip(data, lbllist): + if label in value_labels: + # Explicit call with ordered=True + cat_data = Categorical(data[col], ordered=order_categoricals) + categories = [] + for category in cat_data.categories: + if category in value_label_dict[label]: + categories.append(value_label_dict[label][category]) + else: + categories.append(category) # Partially labeled + cat_data.categories = categories + # Wrap in a Series so the categorical data keeps the chunk's row index
+ cat_data = Series(cat_data, index=data.index) + cat_converted_data.append((col, cat_data)) + else: + cat_converted_data.append((col, data[col])) + data = DataFrame.from_items(cat_converted_data) return data def data_label(self): diff --git a/pandas/io/tests/data/stata12_117.dta b/pandas/io/tests/data/stata12_117.dta new file mode 100644 index 0000000000000..7d1d6181f53bf Binary files /dev/null and b/pandas/io/tests/data/stata12_117.dta differ diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index f896b98fddf5b..8b44be61d5f66 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -75,6 +75,8 @@ def setUp(self): self.dta20_115 = os.path.join(self.dirpath, 'stata11_115.dta') self.dta20_117 = os.path.join(self.dirpath, 'stata11_117.dta') + self.dta21_117 = os.path.join(self.dirpath, 'stata12_117.dta') + def read_dta(self, file): # Legacy default reader configuration return read_stata(file, convert_dates=True) @@ -90,11 +92,21 @@ def test_read_empty_dta(self): empty_ds2 = read_stata(path) tm.assert_frame_equal(empty_ds, empty_ds2) + def test_data_method(self): + # Minimal testing of legacy data method + reader_114 = StataReader(self.dta1_114) + with warnings.catch_warnings(record=True) as w: + parsed_114_data = reader_114.data() + + reader_114 = StataReader(self.dta1_114) + parsed_114_read = reader_114.read() + tm.assert_frame_equal(parsed_114_data, parsed_114_read) + def test_read_dta1(self): reader_114 = StataReader(self.dta1_114) - parsed_114 = reader_114.data() + parsed_114 = reader_114.read() reader_117 = StataReader(self.dta1_117) - parsed_117 = reader_117.data() + parsed_117 = reader_117.read() # Pandas uses np.nan as missing value. # Thus, all columns will be of type float, regardless of their name. expected = DataFrame([(np.nan, np.nan, np.nan, np.nan, np.nan)], @@ -152,14 +164,18 @@ def test_read_dta2(self): expected['yearly_date'] = expected['yearly_date'].astype('O') with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") parsed_114 = self.read_dta(self.dta2_114) parsed_115 = self.read_dta(self.dta2_115) parsed_117 = self.read_dta(self.dta2_117) - # 113 is buggy due ot limits date format support in Stata + # 113 is buggy due to limits of date format support in Stata # parsed_113 = self.read_dta(self.dta2_113) - # should get a warning for that format. 
- tm.assert_equal(len(w), 1) + # Remove resource warnings + w = [x for x in w if x.category is UserWarning] + + # should get warning for each call to read_dta + tm.assert_equal(len(w), 3) # buggy test because of the NaT comparison on certain platforms # Format 113 test fails since it does not support tc and tC formats @@ -215,6 +231,19 @@ def test_read_dta4(self): tm.assert_frame_equal(parsed_115, expected) tm.assert_frame_equal(parsed_117, expected) + # File containing strls + def test_read_dta12(self): + parsed_117 = self.read_dta(self.dta21_117) + expected = DataFrame.from_records( + [ + [1, "abc", "abcdefghi"], + [3, "cba", "qwertywertyqwerty"], + [93, "", "strl"], + ], + columns=['x', 'y', 'z']) + + tm.assert_frame_equal(parsed_117, expected, check_dtype=False) + def test_read_write_dta5(self): original = DataFrame([(np.nan, np.nan, np.nan, np.nan, np.nan)], columns=['float_miss', 'double_miss', 'byte_miss', @@ -858,6 +887,118 @@ def test_categorical_ordering(self): tm.assert_equal(False, parsed_115_unordered[col].cat.ordered) tm.assert_equal(False, parsed_117_unordered[col].cat.ordered) + + def test_read_chunks_117(self): + files_117 = [self.dta1_117, self.dta2_117, self.dta3_117, + self.dta4_117, self.dta14_117, self.dta15_117, + self.dta16_117, self.dta17_117, self.dta18_117, + self.dta19_117, self.dta20_117] + + for fname in files_117: + for chunksize in 1,2: + for convert_categoricals in False, True: + for convert_dates in False, True: + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + parsed = read_stata(fname, convert_categoricals=convert_categoricals, + convert_dates=convert_dates) + itr = read_stata(fname, iterator=True) + + pos = 0 + for j in range(5): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + chunk = itr.read(chunksize) + except StopIteration: + break + from_frame = parsed.iloc[pos:pos+chunksize, :] + try: + tm.assert_frame_equal(from_frame, chunk, check_dtype=False) + except AssertionError: + # datetime.datetime and pandas.tslib.Timestamp may hold + # equivalent values but fail assert_frame_equal + assert(all([x == y for x, y in zip(from_frame, chunk)])) + + pos += chunksize + + def test_iterator(self): + + fname = self.dta3_117 + + parsed = read_stata(fname) + + itr = read_stata(fname, iterator=True) + chunk = itr.read(5) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + itr = read_stata(fname, chunksize=5) + chunk = list(itr) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk[0]) + + itr = read_stata(fname, iterator=True) + chunk = itr.get_chunk(5) + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + itr = read_stata(fname, chunksize=5) + chunk = itr.get_chunk() + tm.assert_frame_equal(parsed.iloc[0:5, :], chunk) + + + def test_read_chunks_115(self): + files_115 = [self.dta2_115, self.dta3_115, self.dta4_115, + self.dta14_115, self.dta15_115, self.dta16_115, + self.dta17_115, self.dta18_115, self.dta19_115, + self.dta20_115] + + for fname in files_115: + for chunksize in 1,2: + for convert_categoricals in False, True: + for convert_dates in False, True: + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + parsed = read_stata(fname, convert_categoricals=convert_categoricals, + convert_dates=convert_dates) + itr = read_stata(fname, iterator=True, + convert_categoricals=convert_categoricals) + + pos = 0 + for j in range(5): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + chunk = itr.read(chunksize) + 
except StopIteration: + break + from_frame = parsed.iloc[pos:pos+chunksize, :] + try: + tm.assert_frame_equal(from_frame, chunk, check_dtype=False) + except AssertionError: + # datetime.datetime and pandas.tslib.Timestamp may hold + # equivalent values but fail assert_frame_equal + assert(all([x == y for x, y in zip(from_frame, chunk)])) + + pos += chunksize + + def test_read_chunks_columns(self): + fname = self.dta3_117 + columns = ['quarter', 'cpi', 'm1'] + chunksize = 2 + + parsed = read_stata(fname, columns=columns) + itr = read_stata(fname, iterator=True) + pos = 0 + for j in range(5): + try: + chunk = itr.read(chunksize, columns=columns) + except StopIteration: + # read raises StopIteration once the file is exhausted; + # it never returns None + break + from_frame = parsed.iloc[pos:pos+chunksize, :] + tm.assert_frame_equal(from_frame, chunk, check_dtype=False) + pos += chunksize + + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False)
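For reference, the heart of the incremental ``read`` in this patch is plain NumPy: compute a byte offset from a structured dtype's ``itemsize``, seek there, and decode fixed-width records with ``np.frombuffer``. A minimal self-contained sketch of that technique, using an invented record layout rather than the real Stata one::

    import io

    import numpy as np

    # Fake data file: five fixed-width records of int32 plus 8-byte string.
    dtype = np.dtype([('x', '<i4'), ('s', 'S8')])
    records = np.array([(i, b'row') for i in range(5)], dtype=dtype)
    buf = io.BytesIO(records.tobytes())

    lines_read, nrows = 2, 2                 # pretend 2 rows were already read
    buf.seek(lines_read * dtype.itemsize)    # skip the rows read so far
    chunk = np.frombuffer(buf.read(nrows * dtype.itemsize),
                          dtype=dtype, count=nrows)
    # chunk['x'] is array([2, 3], dtype=int32); StataReader.read performs the
    # same bookkeeping with self._lines_read and self._dtype.itemsize.

The same offset arithmetic explains the iterator's termination condition: once the number of lines read reaches the number of observations, the computed read length drops to zero and the reader raises ``StopIteration``.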