Skip to content

Commit

Permalink
Merge pull request #10 from vixen-project/fix-unicode-csv
Browse files Browse the repository at this point in the history
Fix exporting and importing CSV with unicode.
  • Loading branch information
prabhuramachandran committed Jan 16, 2018
2 parents ff60fc5 + 7a69358 commit d471f5a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 48 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ mock
pytest
json_tricks>=3.0
whoosh
backports.csv
25 changes: 11 additions & 14 deletions vixen/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import csv
import datetime
import io
import json_tricks
import logging
import os
Expand All @@ -25,8 +25,10 @@
if sys.version_info[0] > 2:
unicode = str
string_types = (str,)
import csv
else:
string_types = (basestring,)
import backports.csv as csv
INT = fields.NUMERIC(numtype=int)
FLOAT = fields.NUMERIC(numtype=float)

Expand All @@ -38,7 +40,7 @@ def get_file_saved_time(path):

def _get_sample(fname):
sample = ''
with open(fname, 'r') as fp:
with io.open(fname, 'r', newline='', encoding='utf-8') as fp:
sample += fp.readline() + fp.readline()

return sample
Expand All @@ -49,7 +51,7 @@ def _get_csv_headers(fname):
sniffer = csv.Sniffer()
has_header = sniffer.has_header(sample)
dialect = sniffer.sniff(sample)
with open(fname, 'r') as fp:
with io.open(fname, 'r', newline='', encoding='utf-8') as fp:
reader = csv.reader(fp, dialect)
header = next(reader)
return has_header, header, dialect
Expand Down Expand Up @@ -345,24 +347,19 @@ def export_csv(self, fname, cols=None):

data_cols = set([x for x in cols if x in self._data])

def _format(elem):
if isinstance(elem, string_types):
return '"%s"' % elem
else:
return str(elem) if elem is not None else ""

with open_file(fname, 'w') as of:
with io.open(fname, 'w', newline='', encoding='utf-8') as of:
# Write the header.
of.write(','.join(cols) + '\n')
writer = csv.writer(of)
writer.writerow(cols)
for i in range(len(self._relpath2index)):
line = []
for col in cols:
if col in data_cols:
elem = self._data[col][i]
else:
elem = self._tag_data[col][i]
line.append(_format(elem))
of.write(','.join(line) + '\n')
line.append(elem)
writer.writerow(line)

def import_csv(self, fname):
"""Read tag information from given CSV filename.
Expand Down Expand Up @@ -397,7 +394,7 @@ def import_csv(self, fname):

count = 0
total = 0
with open(fname, 'r') as fp:
with io.open(fname, 'r', newline='', encoding='utf-8') as fp:
reader = csv.reader(fp, dialect)
next(reader) # Skip header
for record in reader:
Expand Down
75 changes: 41 additions & 34 deletions vixen/tests/test_project.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import csv
# -*- coding: utf-8 -*-
import datetime
import io
import os
from os.path import basename, join, exists
import tempfile
from textwrap import dedent
import time
import shutil
import sys

import unittest

from vixen.tests.test_directory import make_data, create_dummy_file
from vixen.project import Project, TagInfo, get_non_existing_filename, INT

if sys.version_info >= (3, 0):
import csv
else:
import backports.csv as csv


class TestProjectBase(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -119,7 +126,7 @@ def test_load_should_restore_saved_state(self):
self.assertEqual(p.get('root.txt')._mtime,
p1.get('root.txt')._mtime)

def test_export_to_csv(self):
def test_export_to_csv_with_unicode(self):
# Given
tags = [TagInfo(name='completed', type='bool'),
TagInfo(name='comment', type='string')]
Expand All @@ -128,39 +135,40 @@ def test_export_to_csv(self):
p.scan()
m = p.get('root.txt')
m.tags['completed'] = True
m.tags['comment'] = u'hello, world; foo'
m.tags['comment'] = u'hello, world; न Kévin'
out_fname = tempfile.mktemp(dir=self.root, suffix='.csv')

# When
p.export_csv(out_fname)

# Then
reader = csv.reader(open(out_fname))
cols = next(reader)
expected = [
'comment', 'completed', 'ctime', 'file_name', 'mtime', 'path',
'relpath', 'size', 'type'
]
self.assertEqual(cols, expected)
expected = {'hello.py': 'False', 'root.txt': 'True'}
data = [next(reader), next(reader), next(reader), next(reader)]
data = sorted(data, key=lambda x: x[6])
row = data[0]
self.assertEqual(basename(row[5]), 'hello.py')
self.assertEqual(row[1], 'False')
self.assertEqual(row[0], '')
row = data[1]
self.assertEqual(basename(row[5]), 'root.txt')
self.assertEqual(row[1], 'True')
self.assertEqual(row[0], u'hello, world; foo')
row = data[2]
self.assertTrue(basename(row[5]).startswith('sub'))
self.assertEqual(row[1], 'False')
self.assertEqual(row[0], '')
row = data[3]
self.assertTrue(basename(row[5]).startswith('sub'))
self.assertEqual(row[1], 'False')
self.assertEqual(row[0], '')
with io.open(out_fname, newline='', encoding='utf-8') as fp:
reader = csv.reader(fp)
cols = next(reader)
expected = [
'comment', 'completed', 'ctime', 'file_name', 'mtime', 'path',
'relpath', 'size', 'type'
]
self.assertEqual(cols, expected)
expected = {'hello.py': 'False', 'root.txt': 'True'}
data = [next(reader), next(reader), next(reader), next(reader)]
data = sorted(data, key=lambda x: x[6])
row = data[0]
self.assertEqual(basename(row[5]), u'hello.py')
self.assertEqual(row[1], u'False')
self.assertEqual(row[0], u'')
row = data[1]
self.assertEqual(basename(row[5]), u'root.txt')
self.assertEqual(row[1], u'True')
self.assertEqual(row[0], u'hello, world; न Kévin')
row = data[2]
self.assertTrue(basename(row[5]).startswith(u'sub'))
self.assertEqual(row[1], u'False')
self.assertEqual(row[0], u'')
row = data[3]
self.assertTrue(basename(row[5]).startswith(u'sub'))
self.assertEqual(row[1], u'False')
self.assertEqual(row[0], u'')

def test_refresh_updates_new_media(self):
# Given
Expand Down Expand Up @@ -285,15 +293,15 @@ def test_setting_root_extensions_limits_files(self):

def _write_csv(self, data):
fname = join(self._temp, 'data.csv')
with open(join(self._temp, 'data.csv'), 'w') as fp:
with io.open(fname, 'w', encoding='utf-8') as fp:
fp.write(data)
return fname

def test_import_csv_fails_with_bad_csv_header(self):
# Given
p = Project(name='test', path=self.root)
p.scan()
data = dedent("""\
data = dedent(u"""\
/blah/blah,1
""")
csv = self._write_csv(data)
Expand All @@ -305,7 +313,7 @@ def test_import_csv_fails_with_bad_csv_header(self):
self.assertFalse(success)

# Given
data = dedent("""\
data = dedent(u"""\
relpath,fox
root.txt,1
""")
Expand All @@ -322,7 +330,7 @@ def test_import_csv_works(self):
p = Project(name='test', path=self.root)
p.add_tags([TagInfo(name='fox', type='int')])
p.scan()
data = dedent("""\
data = dedent(u"""\
path,fox,junk
%s,2,hello
%s,1,bye
Expand All @@ -341,7 +349,6 @@ def test_import_csv_works(self):
self.assertEqual(p.get('hello.py').tags['fox'], 1)



class TestSearchMedia(TestProjectBase):
def test_query_schema_is_setup_correctly(self):
# Given
Expand Down

0 comments on commit d471f5a

Please sign in to comment.