diff --git a/tableaudocumentapi/datasource.py b/tableaudocumentapi/datasource.py
index 924575d..0fdc3fb 100644
--- a/tableaudocumentapi/datasource.py
+++ b/tableaudocumentapi/datasource.py
@@ -12,17 +12,54 @@
from tableaudocumentapi import Field
from tableaudocumentapi.multilookup_dict import MultiLookupDict
+########
+# This is needed in order to determine if something is a string or not. It is necessary because
+# of differences between python2 (basestring) and python3 (str). If python2 support is ever
+# dropped, remove this and change the basestring references below to str
+try:
+ basestring
+except NameError:
+ basestring = str
+########
-def _mapping_from_xml(root_xml, column_xml):
- retval = Field.from_xml(column_xml)
- local_name = retval.id
- if "'" in local_name:
- local_name = sax.escape(local_name, {"'": "'"})
- xpath = ".//metadata-record[@class='column'][local-name='{}']".format(local_name)
- metadata_record = root_xml.find(xpath)
+# Lightweight (id, object) pair returned by the _column_object_from_* helpers below;
+# it unpacks like a plain 2-tuple, which is how _get_all_fields consumes it.
+_ColumnObjectReturnTuple = collections.namedtuple('_ColumnObjectReturnTupleType', ['id', 'object'])
+
+
+def _get_metadata_xml_for_field(root_xml, field_name):
+ """Find the column metadata record for ``field_name`` under ``root_xml``.
+
+ Looks for a ``metadata-record`` element of class ``column`` that has a
+ ``local-name`` child whose text equals ``field_name`` and returns the result
+ of ``root_xml.find`` for that path (presumably ``None`` when nothing
+ matches, assuming ``root_xml`` is an ElementTree element).
+ """
+ if "'" in field_name:
+ # NOTE(review): as written, the entity map replaces an apostrophe with an
+ # apostrophe, so the quote itself is left unchanged -- confirm whether a
+ # real entity (e.g. &apos;) was intended here.
+ field_name = sax.escape(field_name, {"'": "'"})
+ xpath = ".//metadata-record[@class='column'][local-name='{}']".format(field_name)
+ return root_xml.find(xpath)
+
+
+def _is_used_by_worksheet(names, field):
+    """Return True if ``field`` is used by at least one of the worksheets in ``names``.
+
+    :param names: iterable of worksheet names to look for.
+    :param field: object exposing a ``worksheets`` collection of worksheet names.
+    :return: ``True`` when any name in ``names`` appears in ``field.worksheets``,
+        ``False`` otherwise (including when ``names`` is empty).
+    """
+    # Read the property once: Field.worksheets builds a fresh list on every access.
+    worksheets = field.worksheets
+    # Test membership itself rather than the truthiness of the matching name, so a
+    # falsy-but-valid name (e.g. an empty string) still counts as a match.
+    return any(name in worksheets for name in names)
+
+
+class FieldDictionary(MultiLookupDict):
+    """Field lookup dictionary that can also be queried by worksheet usage."""
+
+    def used_by_sheet(self, name):
+        """Return the fields used by the given worksheet(s).
+
+        :param name: a single worksheet name (string) or a collection of names.
+        :return: list of the stored fields whose ``worksheets`` contain ``name``,
+            or contain any of the names when a collection is passed.
+        """
+        if isinstance(name, basestring):
+            # Single worksheet name: plain membership test against each field.
+            def is_match(field):
+                return name in field.worksheets
+        else:
+            # Collection of names: a field qualifies if any one of them uses it.
+            def is_match(field):
+                return _is_used_by_worksheet(name, field)
+
+        return [field for field in self.values() if is_match(field)]
+
+
+def _column_object_from_column_xml(root_xml, column_xml):
+ """Build a Field from a ``column`` element and, when ``root_xml`` holds a
+ metadata record for the same field id, apply that metadata to it.
+
+ Returns a ``_ColumnObjectReturnTuple`` of ``(field id, field)``.
+ """
+ field_object = Field.from_column_xml(column_xml)
+ local_name = field_object.id
+ metadata_record = _get_metadata_xml_for_field(root_xml, local_name)
if metadata_record is not None:
- retval.apply_metadata(metadata_record)
- return retval.id, retval
+ field_object.apply_metadata(metadata_record)
+ return _ColumnObjectReturnTuple(field_object.id, field_object)
+
+
+def _column_object_from_metadata_xml(metadata_xml):
+    """Build a Field purely from a ``metadata-record`` element.
+
+    Returns a ``_ColumnObjectReturnTuple`` of ``(field id, field)``.
+    """
+    field = Field.from_metadata_xml(metadata_xml)
+    return _ColumnObjectReturnTuple(id=field.id, object=field)
class ConnectionParser(object):
@@ -73,7 +110,7 @@ def __init__(self, dsxml, filename=None):
@classmethod
def from_file(cls, filename):
- "Initialize datasource from file (.tds)"
+ """Initialize datasource from file (.tds)"""
if zipfile.is_zipfile(filename):
dsxml = xfile.get_xml_from_archive(filename).getroot()
@@ -141,6 +178,16 @@ def fields(self):
return self._fields
def _get_all_fields(self):
- column_objects = (_mapping_from_xml(self._datasourceTree, xml)
- for xml in self._datasourceTree.findall('.//column'))
- return MultiLookupDict({k: v for k, v in column_objects})
+        """Collect every field of this datasource into a FieldDictionary.
+
+        Fields are first built from the ``column`` elements of the datasource tree.
+        Any column metadata record whose ``local-name`` is not among those field
+        ids is then added as a field built from the metadata record alone.
+
+        :return: ``FieldDictionary`` keyed by field id.
+        """
+        column_objects = [_column_object_from_column_xml(self._datasourceTree, xml)
+                          for xml in self._datasourceTree.findall('.//column')]
+        # A set gives constant-time membership checks in the filter below.
+        existing_fields = {x.id for x in column_objects}
+        metadata_fields = (x.text
+                           for x in self._datasourceTree.findall(".//metadata-record[@class='column']/local-name"))
+
+        # Skip records with an empty local-name: they cannot name a field, and the
+        # metadata lookup below needs a string to search for.
+        missing_fields = (x for x in metadata_fields if x and x not in existing_fields)
+        column_objects.extend(
+            _column_object_from_metadata_xml(_get_metadata_xml_for_field(self._datasourceTree, field_name))
+            for field_name in missing_fields
+        )
+
+        # Each entry is an (id, object) pair, so dict() can consume the list directly.
+        return FieldDictionary(dict(column_objects))
diff --git a/tableaudocumentapi/field.py b/tableaudocumentapi/field.py
index 8162cdb..4af648f 100644
--- a/tableaudocumentapi/field.py
+++ b/tableaudocumentapi/field.py
@@ -14,6 +14,12 @@
'aggregation', # The type of aggregation on the field (e.g Sum, Avg)
]
+# Pairs of (metadata-record child element name, Field attribute name), used to
+# populate a Field directly from a metadata record.
+_METADATA_TO_FIELD_MAP = [
+ ('local-name', 'id'),
+ ('local-type', 'datatype'),
+ ('remote-alias', 'alias')
+]
+
def _find_metadata_record(record, attrib):
element = record.find('.//{}'.format(attrib))
@@ -25,25 +31,60 @@ def _find_metadata_record(record, attrib):
class Field(object):
""" Represents a field in a datasource """
- def __init__(self, xmldata):
- for attrib in _ATTRIBUTES:
- self._apply_attribute(xmldata, attrib, lambda x: xmldata.attrib.get(x, None))
+ def __init__(self, column_xml=None, metadata_xml=None):
+ """Create a field from a ``column`` element, a ``metadata-record`` element, or both.
+
+ When ``column_xml`` is given it is the primary source and ``metadata_xml``, if
+ also given, is applied on top of it through ``apply_metadata``. Otherwise the
+ field is built from ``metadata_xml`` alone.
+
+ Raises AttributeError when neither argument is supplied.
+ """
- # All metadata attributes begin at None
+ # Initialize all the possible attributes
+ for attrib in _ATTRIBUTES:
+ setattr(self, '_{}'.format(attrib), None)
for attrib in _METADATA_ATTRIBUTES:
setattr(self, '_{}'.format(attrib), None)
+ self._worksheets = set()
+
+ if column_xml is not None:
+ self._initialize_from_column_xml(column_xml)
+ if metadata_xml is not None:
+ self.apply_metadata(metadata_xml)
+
+ elif metadata_xml is not None:
+ self._initialize_from_metadata_xml(metadata_xml)
+
+ else:
+ raise AttributeError('column_xml or metadata_xml needed to initialize field')
+
+    def _initialize_from_column_xml(self, xmldata):
+        """Populate every attribute named in ``_ATTRIBUTES`` from a ``column`` element.
+
+        Each value defaults to the element's XML attribute of the same name (``None``
+        when absent) unless ``_apply_attribute`` finds a dedicated reader for it.
+        """
+        def read_xml_attribute(attrib_name):
+            return xmldata.attrib.get(attrib_name)
+
+        for attrib in _ATTRIBUTES:
+            self._apply_attribute(xmldata, attrib, read_xml_attribute)
+
+    def _initialize_from_metadata_xml(self, xmldata):
+        """Populate this field from a ``metadata-record`` element alone.
+
+        For every ``(metadata name, field attribute)`` pair in
+        ``_METADATA_TO_FIELD_MAP`` the attribute is set from the text of the
+        matching child element, or ``None`` when the record has no such child.
+        The remaining metadata attributes are then filled in via ``apply_metadata``.
+        """
+        for metadata_name, field_name in _METADATA_TO_FIELD_MAP:
+            # Bind the element name as a default so the reader does not depend on
+            # the loop variable's value at call time.
+            def read_text(_attrib, element_name=metadata_name):
+                element = xmldata.find('.//{}'.format(element_name))
+                # A record may lack one of the mapped children; report None
+                # instead of dereferencing a missing node.
+                return None if element is None else element.text
+
+            self._apply_attribute(xmldata, field_name, read_text, read_name=metadata_name)
+        self.apply_metadata(xmldata)
+ ########################################
+ # Special Case methods for construction fields from various sources
+ # not intended for client use
+ ########################################
def apply_metadata(self, metadata_record):
+ """Set each attribute named in ``_METADATA_ATTRIBUTES`` from ``metadata_record``.
+
+ Values come from ``_find_metadata_record`` unless ``_apply_attribute`` finds a
+ dedicated ``_read_<name>`` method for the attribute.
+ """
for attrib in _METADATA_ATTRIBUTES:
self._apply_attribute(metadata_record, attrib, functools.partial(_find_metadata_record, metadata_record))
+ def add_used_in(self, name):
+ """Record that the worksheet called ``name`` uses this field.
+
+ The names are kept in a set, so recording the same name again changes nothing.
+ """
+ self._worksheets.add(name)
+
@classmethod
- def from_xml(cls, xmldata):
- return cls(xmldata)
+ def from_column_xml(cls, xmldata):
+ """Alternate constructor: build a field from a ``column`` element."""
+ return cls(column_xml=xmldata)
- def _apply_attribute(self, xmldata, attrib, default_func):
- if hasattr(self, '_read_{}'.format(attrib)):
- value = getattr(self, '_read_{}'.format(attrib))(xmldata)
+ @classmethod
+ def from_metadata_xml(cls, xmldata):
+ """Alternate constructor: build a field from a ``metadata-record`` element."""
+ return cls(metadata_xml=xmldata)
+
+ def _apply_attribute(self, xmldata, attrib, default_func, read_name=None):
+ if read_name is None:
+ read_name = attrib
+ if hasattr(self, '_read_{}'.format(read_name)):
+ value = getattr(self, '_read_{}'.format(read_name))(xmldata)
else:
value = default_func(attrib)
@@ -121,6 +162,10 @@ def default_aggregation(self):
""" The default type of aggregation on the field (e.g Sum, Avg)"""
return self._aggregation
+ @property
+ def worksheets(self):
+ """Names of the worksheets that use this field, as a new list copied from the internal set."""
+ return list(self._worksheets)
+
######################################
# Special Case handling methods for reading the values from the XML
######################################
diff --git a/tableaudocumentapi/workbook.py b/tableaudocumentapi/workbook.py
index 9e29973..fd85b3c 100644
--- a/tableaudocumentapi/workbook.py
+++ b/tableaudocumentapi/workbook.py
@@ -5,17 +5,12 @@
###############################################################################
import os
import zipfile
+import weakref
import xml.etree.ElementTree as ET
from tableaudocumentapi import Datasource, xfile
-###########################################################################
-#
-# Utility Functions
-#
-###########################################################################
-
class Workbook(object):
"""
@@ -33,6 +28,7 @@ def __init__(self, filename):
Constructor.
"""
+
self._filename = filename
# Determine if this is a twb or twbx and get the xml root
@@ -47,6 +43,12 @@ def __init__(self, filename):
self._datasources = self._prepare_datasources(
self._workbookRoot) # self.workbookRoot.find('datasources')
+ self._datasource_index = self._prepare_datasource_index(self._datasources)
+
+ self._worksheets = self._prepare_worksheets(
+ self._workbookRoot, self._datasource_index
+ )
+
###########
# datasources
###########
@@ -54,6 +56,13 @@ def __init__(self, filename):
def datasources(self):
return self._datasources
+ ###########
+ # worksheets
+ ###########
+ @property
+ def worksheets(self):
+ """Worksheet names collected by ``_prepare_worksheets`` when the workbook was constructed."""
+ return self._worksheets
+
###########
# filename
###########
@@ -95,12 +104,47 @@ def save_as(self, new_filename):
# Private API.
#
###########################################################################
- def _prepare_datasources(self, xmlRoot):
+    @staticmethod
+    def _prepare_datasource_index(datasources):
+        """Index the given datasources by their ``name`` attribute.
+
+        The result is a ``weakref.WeakValueDictionary``, so the index itself holds
+        only weak references to the datasources. When two datasources share a
+        name, the later one in ``datasources`` wins.
+        """
+        by_name = {datasource.name: datasource for datasource in datasources}
+        return weakref.WeakValueDictionary(by_name)
+
+ @staticmethod
+ def _prepare_datasources(xml_root):
+ """Build a Datasource for each child of the root's ``datasources`` element.
+
+ Returns an empty list when ``xml_root`` has no ``datasources`` child.
+ """
datasources = []
# loop through our datasources and append
- for datasource in xmlRoot.find('datasources'):
+ datasource_elements = xml_root.find('datasources')
+ if datasource_elements is None:
+ return []
+
+ for datasource in datasource_elements:
ds = Datasource(datasource)
datasources.append(ds)
return datasources
+
+    @staticmethod
+    def _prepare_worksheets(xml_root, ds_index):
+        """Collect worksheet names and record which fields each worksheet uses.
+
+        :param xml_root: root element of the workbook XML.
+        :param ds_index: mapping of datasource name to datasource object; the
+            matching entries of each datasource's ``fields`` are updated in
+            place through ``add_used_in``.
+        :return: list of worksheet names (empty when the document has no
+            ``worksheets`` element).
+        :raises KeyError: if a dependency names a datasource missing from ``ds_index``.
+        """
+        worksheets = []
+        worksheets_element = xml_root.find('.//worksheets')
+        if worksheets_element is None:
+            return worksheets
+
+        for worksheet_element in worksheets_element:
+            worksheet_name = worksheet_element.attrib['name']
+            worksheets.append(worksheet_name)  # TODO: A real worksheet object, for now, only name
+
+            for dependency in worksheet_element.findall('.//datasource-dependencies'):
+                datasource = ds_index[dependency.attrib['datasource']]
+                fields = datasource.fields
+                for column in dependency.findall('.//column'):
+                    column_name = column.attrib['name']
+                    # A dependency may list a column the datasource has no field
+                    # for; there is nothing to tag in that case, so skip it
+                    # instead of failing the whole workbook load with a KeyError.
+                    if column_name in fields:
+                        fields[column_name].add_used_in(worksheet_name)
+
+        return worksheets
diff --git a/test/assets/TABLEAU_10_TWB.twb b/test/assets/TABLEAU_10_TWB.twb
index c116bdf..aa0207f 100644
--- a/test/assets/TABLEAU_10_TWB.twb
+++ b/test/assets/TABLEAU_10_TWB.twb
@@ -1 +1,22 @@
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/assets/datasource_test.twb b/test/assets/datasource_test.twb
new file mode 100644
index 0000000..af87659
--- /dev/null
+++ b/test/assets/datasource_test.twb
@@ -0,0 +1,172 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ a
+ 130
+ [a]
+ [xy]
+ a
+ 1
+ string
+ Count
+ 255
+ true
+
+ "SQL_WVARCHAR"
+ "SQL_C_WCHAR"
+ "true"
+
+
+
+ x
+ 3
+ [x]
+ [xy]
+ x
+ 2
+ integer
+ Sum
+ 10
+ true
+
+ "SQL_INTEGER"
+ "SQL_C_SLONG"
+
+
+
+ y
+ 3
+ [y]
+ [xy]
+ y
+ 3
+ integer
+ Sum
+ 10
+ true
+
+ "SQL_INTEGER"
+ "SQL_C_SLONG"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ [datasource_test].[none:a:nk]
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ [datasource_test].[none:a:nk]
+ [datasource_test].[sum:x:qk]
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/assets/empty_workbook.twb b/test/assets/empty_workbook.twb
new file mode 100644
index 0000000..4b22c74
--- /dev/null
+++ b/test/assets/empty_workbook.twb
@@ -0,0 +1,3 @@
+
+
+
\ No newline at end of file
diff --git a/test/bvt.py b/test/bvt.py
index 1dedd57..6a7cdf8 100644
--- a/test/bvt.py
+++ b/test/bvt.py
@@ -22,6 +22,8 @@
TABLEAU_10_TDSX = os.path.join(TEST_DIR, 'assets', 'TABLEAU_10_TDSX.tdsx')
+EMPTY_WORKBOOK = os.path.join(TEST_DIR, 'assets', 'empty_workbook.twb')
+
class ConnectionParserTests(unittest.TestCase):
@@ -278,5 +280,12 @@ def test_can_open_twbx_and_save_as_changes(self):
os.unlink(new_twbx_filename)
+
+class EmptyWorkbookWillLoad(unittest.TestCase):
+    """Loading the EMPTY_WORKBOOK asset must not raise."""
+
+    def test_no_exceptions_thrown(self):
+        self.assertIsNotNone(Workbook(EMPTY_WORKBOOK))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/test/test_datasource.py b/test/test_datasource.py
index 0a2457e..bf51746 100644
--- a/test/test_datasource.py
+++ b/test/test_datasource.py
@@ -1,16 +1,24 @@
import unittest
import os.path
-from tableaudocumentapi import Datasource
+from tableaudocumentapi import Datasource, Workbook
-TEST_TDS_FILE = os.path.join(
+TEST_ASSET_DIR = os.path.join(
os.path.dirname(__file__),
- 'assets',
+ 'assets'
+)
+TEST_TDS_FILE = os.path.join(
+ TEST_ASSET_DIR,
'datasource_test.tds'
)
+TEST_TWB_FILE = os.path.join(
+ TEST_ASSET_DIR,
+ 'datasource_test.twb'
+)
+
-class DataSourceFields(unittest.TestCase):
+class DataSourceFieldsTDS(unittest.TestCase):
def setUp(self):
self.ds = Datasource.from_file(TEST_TDS_FILE)
@@ -42,3 +50,47 @@ def test_datasource_field_is_quantitative(self):
def test_datasource_field_is_ordinal(self):
self.assertTrue(self.ds.fields['[x]'].is_ordinal)
+
+
+class DataSourceFieldsTWB(unittest.TestCase):
+    """Field loading for a datasource read out of a workbook (.twb) file."""
+
+    def setUp(self):
+        workbook = Workbook(TEST_TWB_FILE)
+        self.wb = workbook
+        self.ds = workbook.datasources[0]  # Assume the first datasource in the file
+
+    def test_datasource_fields_loaded_in_workbook(self):
+        self.assertIsNotNone(self.ds.fields)
+        number_of_records = self.ds.fields.get('[Number of Records]', None)
+        self.assertIsNotNone(number_of_records)
+
+
+class DataSourceFieldsFoundIn(unittest.TestCase):
+    """``used_by_sheet`` lookups against the worksheets of the test workbook."""
+
+    def setUp(self):
+        self.wb = Workbook(TEST_TWB_FILE)
+        self.ds = self.wb.datasources[0]  # Assume the first datasource in the file
+
+    def _names_used_by(self, sheets, expected_count):
+        """Query ``used_by_sheet``, check the result size, and return the field names."""
+        actual_values = self.ds.fields.used_by_sheet(sheets)
+        self.assertIsNotNone(actual_values)
+        self.assertEqual(expected_count, len(actual_values))
+        return [x.name for x in actual_values]
+
+    def test_datasource_fields_found_in_returns_fields(self):
+        names = self._names_used_by('Sheet 1', 1)
+        self.assertIn('A', names)
+
+    def test_datasource_fields_found_in_does_not_return_fields_not_used_in_worksheet(self):
+        names = self._names_used_by('Sheet 1', 1)
+        self.assertNotIn('X', names)
+
+    def test_datasource_fields_found_in_returns_multiple_fields(self):
+        names = self._names_used_by('Sheet 2', 2)
+        self.assertIn('A', names)
+        self.assertIn('X', names)
+        self.assertNotIn('Y', names)
+
+    def test_datasource_fields_found_in_accepts_lists(self):
+        names = self._names_used_by(['Sheet 1', 'Sheet 2'], 2)
+        self.assertIn('A', names)
+        self.assertIn('X', names)
+        self.assertNotIn('Y', names)