Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

ENH: improve DataFrame read_csv / to_csv for Index/MultiIndex #151

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+113 −27
Split
View
@@ -431,8 +431,9 @@ def from_csv(cls, path, header=0, delimiter=',', index_col=0):
header : int, default 0
Row to use at header (skip prior rows)
delimiter : string, default ','
- index_col : int, default 0
- Column to use for index
+ index_col : int or sequence, default 0
+ Column to use for index. If a sequence is given, a MultiIndex
+ is used.
Notes
-----
@@ -482,8 +483,10 @@ def to_csv(self, path, nanRep='', cols=None, header=True,
Write out column names
index : boolean, default True
Write row names (index)
- index_label : string, default None
- Column label for index column if desired
+ index_label : string or sequence, default None
+ Column label for index column(s) if desired. If None is given, and
+ `header` and `index` are True, then the index names are used. A
+ sequence should be given if the DataFrame uses MultiIndex.
mode : Python write mode, default 'wb'
"""
f = open(path, mode)
@@ -494,15 +497,25 @@ def to_csv(self, path, nanRep='', cols=None, header=True,
series = self._series
if header:
joined_cols = ','.join([str(c) for c in cols])
- if index and index_label:
- f.write('%s,%s' % (index_label, joined_cols))
+ if index:
+ # should write something for index label
+ if index_label is None:
+ index_label = getattr(self.index, 'names', ['index'])
+ elif not isinstance(index_label, (list, tuple, np.ndarray)):
+ # given a string for a DF with Index
+ index_label = [index_label]
+ f.write('%s,%s' % (",".join(index_label), joined_cols))
else:
f.write(joined_cols)
f.write('\n')
+ nlevels = getattr(self.index, 'nlevels', 1)
for idx in self.index:
if index:
- f.write(str(idx))
+ if nlevels == 1:
+ f.write(str(idx))
+ else: # handle MultiIndex
+ f.write(",".join([str(i) for i in idx]))
for i, col in enumerate(cols):
val = series[col].get(idx)
if isnull(val):
View
@@ -9,7 +9,7 @@
import numpy as np
-from pandas.core.index import Index
+from pandas.core.index import Index, MultiIndex
from pandas.core.frame import DataFrame
def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
@@ -27,9 +27,9 @@ def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
Row to use for the column labels of the parsed DataFrame
skiprows : list-like
Row numbers to skip (0-indexed)
- index_col : int, default 0
+ index_col : int or sequence, default 0
Column to use as the row labels of the DataFrame. Pass None if there is
- no such column
+ no such column. If a sequence is given, a MultiIndex is used.
na_values : list-like, default None
List of additional strings to recognize as NA/NaN
date_parser : function
@@ -65,7 +65,7 @@ def read_csv(filepath_or_buffer, sep=None, header=0, skiprows=None, index_col=0,
sniffed = csv.Sniffer().sniff(sample)
dia.delimiter = sniffed.delimiter
f.seek(0)
-
+
reader = csv.reader(f, dialect=dia)
if skiprows is not None:
@@ -92,9 +92,9 @@ def read_table(filepath_or_buffer, sep='\t', header=0, skiprows=None,
Row to use for the column labels of the parsed DataFrame
skiprows : list-like
Row numbers to skip (0-indexed)
- index_col : int, default 0
+ index_col : int or sequence, default 0
Column to use as the row labels of the DataFrame. Pass None if there is
- no such column
+ no such column. If a sequence is given, a MultiIndex is used.
na_values : list-like, default None
List of additional strings to recognize as NA/NaN
date_parser : function
@@ -107,7 +107,7 @@ def read_table(filepath_or_buffer, sep='\t', header=0, skiprows=None,
-------
parsed : DataFrame
"""
- return read_csv(filepath_or_buffer, sep, header, skiprows,
+ return read_csv(filepath_or_buffer, sep, header, skiprows,
index_col, na_values, date_parser, names)
def _simple_parser(lines, colNames=None, header=0, indexCol=0,
@@ -149,27 +149,43 @@ def _simple_parser(lines, colNames=None, header=0, indexCol=0,
# no index column specified, so infer that's what is wanted
if indexCol is not None:
- if indexCol == 0 and len(content[0]) == len(columns) + 1:
- index = zipped_content[0]
- zipped_content = zipped_content[1:]
+ if np.isscalar(indexCol):
+ if indexCol == 0 and len(content[0]) == len(columns) + 1:
+ index = zipped_content[0]
+ zipped_content = zipped_content[1:]
+ else:
+ index = zipped_content.pop(indexCol)
+ columns.pop(indexCol)
+ else: # given a list of index
+ idx_names = []
+ index = []
+ for idx in indexCol:
+ idx_names.append(columns[idx])
+ index.append(zipped_content[idx])
+ # remove index items from content and columns; don't pop inside the loop
+ for i in range(len(indexCol)):
+ columns.remove(idx_names[i])
+ zipped_content.remove(index[i])
+
+
+ if np.isscalar(indexCol):
+ if parse_dates:
+ index = _try_parse_dates(index, parser=date_parser)
+ index = Index(_maybe_convert_int(np.array(index, dtype=object)))
else:
- index = zipped_content.pop(indexCol)
- columns.pop(indexCol)
-
- if parse_dates:
- index = _try_parse_dates(index, parser=date_parser)
-
- index = _maybe_convert_int(np.array(index, dtype=object))
+ index = MultiIndex.from_arrays(_maybe_convert_int_mindex(index,
+ parse_dates, date_parser),
+ names=idx_names)
else:
- index = np.arange(len(content))
+ index = Index(np.arange(len(content)))
if len(columns) != len(zipped_content):
raise Exception('wrong number of columns')
data = dict(izip(columns, zipped_content))
data = _floatify(data, na_values=na_values)
data = _convert_to_ndarrays(data)
- return DataFrame(data=data, columns=columns, index=Index(index))
+ return DataFrame(data=data, columns=columns, index=index)
def _floatify(data_dict, na_values=None):
"""
@@ -218,6 +234,20 @@ def _maybe_convert_int(arr):
return arr
+def _maybe_convert_int_mindex(index, parse_dates, date_parser):
+ if len(index) == 0:
+ return index
+
+ for i in range(len(index)):
+ try:
+ int(index[i][0])
+ index[i] = map(int, index[i])
+ except ValueError:
+ if parse_dates:
+ index[i] = _try_parse_dates(index[i], date_parser)
+
+ return index
+
def _convert_to_ndarrays(dct):
result = {}
for c, values in dct.iteritems():
View
@@ -13,7 +13,8 @@
import pandas.core.datetools as datetools
from pandas.core.index import NULL_INDEX
-from pandas.core.api import (DataFrame, Index, Series, notnull, isnull)
+from pandas.core.api import (DataFrame, Index, Series, notnull, isnull,
+ MultiIndex)
from pandas.util.testing import (assert_almost_equal,
assert_series_equal,
@@ -1462,6 +1463,48 @@ def test_to_csv_from_csv(self):
os.remove(path)
+ def test_to_csv_multiindex(self):
+ path = '__tmp__'
+
+ frame = self.frame
+ old_index = frame.index
+ new_index = MultiIndex.from_arrays(np.arange(len(old_index)*2).reshape(2,-1))
+ frame.index = new_index
+ frame.to_csv(path, header=False)
+ frame.to_csv(path, cols=['A', 'B'])
+
+
+ # round trip
+ frame.to_csv(path)
+
+ df = DataFrame.from_csv(path, index_col=[0,1])
+
+ assert_frame_equal(frame, df)
+ self.frame.index = old_index # needed if setUp becomes a classmethod
+
+ # try multiindex with dates
+ tsframe = self.tsframe
+ old_index = tsframe.index
+ new_index = [old_index, np.arange(len(old_index))]
+ tsframe.index = MultiIndex.from_arrays(new_index)
+
+ tsframe.to_csv(path, index_label = ['time','foo'])
+ recons = DataFrame.from_csv(path, index_col=[0,1])
+ assert_frame_equal(tsframe, recons)
+
+ # do not load index
+ tsframe.to_csv(path)
+ recons = DataFrame.from_csv(path, index_col=None)
+ np.testing.assert_equal(len(recons.columns), len(tsframe.columns) + 2)
+
+ # no index
+ tsframe.to_csv(path, index=False)
+ recons = DataFrame.from_csv(path, index_col=None)
+ assert_almost_equal(recons.values, self.tsframe.values)
+ self.tsframe.index = old_index # needed if setUp becomes a classmethod
+
+ os.remove(path)
+
def test_info(self):
io = StringIO()
self.frame.info(buf=io)