Skip to content

Commit

Permalink
Merge branch 'xinclude' of https://github.com/aptivate/sunburnt into …
Browse files Browse the repository at this point in the history
…aptivate-xinclude

Conflicts:
	sunburnt/sunburnt.py
  • Loading branch information
tow committed Oct 28, 2015
2 parents 92e8dbb + 30befeb commit 5494d56
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 21 deletions.
24 changes: 14 additions & 10 deletions sunburnt/schema.py
Expand Up @@ -130,7 +130,7 @@ def __init__(self, name, indexed=None, stored=None, required=False, multiValued=
elif self.name.endswith("*"):
self.wildcard_at_start = False
else:
raise SolrError("Dynamic fields must have * at start or end of name (field %s)" %
raise SolrError("Dynamic fields must have * at start or end of name (field %s)" %
self.name)

def match(self, name):
Expand All @@ -142,7 +142,7 @@ def match(self, name):

def normalize(self, value):
""" Normalize the given value according to the field type.
This method does nothing by default, returning the given value
as is. Child classes may override this method as required.
"""
Expand Down Expand Up @@ -181,7 +181,7 @@ def from_solr(self, value):
try:
return unicode(value)
except UnicodeError:
raise SolrError("%s could not be coerced to unicode (field %s)" %
raise SolrError("%s could not be coerced to unicode (field %s)" %
(value, self.name))


Expand All @@ -196,7 +196,7 @@ def normalize(self, value):
elif value.lower() == "false":
return False
else:
raise ValueError("sorry, I only understand simple boolean strings (field %s)" %
raise ValueError("sorry, I only understand simple boolean strings (field %s)" %
self.name)
return bool(value)

Expand All @@ -206,7 +206,7 @@ def from_user_data(self, value):
try:
return str(value)
except (TypeError, ValueError):
raise SolrError("Could not convert data to binary string (field %s)" %
raise SolrError("Could not convert data to binary string (field %s)" %
self.name)

def to_solr(self, value):
Expand All @@ -221,7 +221,7 @@ def normalize(self, value):
try:
v = self.base_type(value)
except (OverflowError, TypeError, ValueError):
raise SolrError("%s is invalid value for %s (field %s)" %
raise SolrError("%s is invalid value for %s (field %s)" %
(value, self.__class__, self.name))
if v < self.min or v > self.max:
raise SolrError("%s out of range for a %s (field %s)" %
Expand Down Expand Up @@ -413,10 +413,14 @@ def Q(self, *args, **kwargs):
return q

def schema_parse(self, f):
try:
schemadoc = lxml.etree.parse(f)
except lxml.etree.XMLSyntaxError, e:
raise SolrError("Invalid XML in schema:\n%s" % e.args[0])
# hack as we might pass in an already parsed doc
if hasattr(f, 'getroot'):
schemadoc = f
else:
try:
schemadoc = lxml.etree.parse(f)
except lxml.etree.XMLSyntaxError, e:
raise SolrError("Invalid XML in schema:\n%s" % e.args[0])

field_type_classes = {}
for field_type_node in schemadoc.xpath("/schema/types/fieldType|/schema/types/fieldtype|/schema/fieldType|/schema/fieldtype"):
Expand Down
81 changes: 74 additions & 7 deletions sunburnt/sunburnt.py
@@ -1,8 +1,10 @@
from __future__ import absolute_import

from os import path
from lxml import etree
import cStringIO as StringIO
from itertools import islice
import time, urllib, urlparse
import shutil, tempfile, time, urllib, urlparse
import warnings

from .http import ConnectionError, wrap_http_connection
Expand Down Expand Up @@ -153,7 +155,7 @@ def mlt(self, params, content=None):


class SolrInterface(object):
remote_schema_file = "admin/file/?file=schema.xml"

def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1,
max_length_get_url=MAX_LENGTH_GET_URL, format='xml'):
self.conn = SolrConnection(url, http_connection, mode, retry_timeout, max_length_get_url, format)
Expand All @@ -163,17 +165,82 @@ def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_tim
raise ValueError("Unsupported format '%s': allowed are %s" %
(format, ','.join(allowed_formats)))
self.format = format
self.file_cache = {}
self.init_schema()

def make_file_url(self, filename):
    """Return the URL used to request ``filename`` from the server.

    The ``admin/file/?file=`` handler path is resolved against
    ``self.conn.url`` and ``filename`` is appended verbatim as the value
    of the ``file`` query parameter.
    """
    handler_url = urlparse.urljoin(self.conn.url, 'admin/file/?file=')
    return handler_url + filename

def get_file(self, filename):
    """Fetch ``filename`` from the server and return it as a file-like object.

    The body of a successful (200) response is stored in
    ``self.file_cache`` keyed by ``filename``, so a file that has already
    been retrieved is not requested again. Every call returns a fresh
    ``StringIO`` over the cached contents.

    Returns ``None`` when the server answers 404; nothing is cached in
    that case.

    Raises ``EnvironmentError`` for any other non-200 status.
    """
    if filename not in self.file_cache:
        response = self.conn.request('GET', self.make_file_url(filename))
        if response.status_code == 404:
            return None
        if response.status_code != 200:
            # Include the response body so the failure can be diagnosed.
            raise EnvironmentError("Couldn't retrieve schema document from server - received status code %s\n%s" % (response.status_code, response.content))
        self.file_cache[filename] = response.content
    return StringIO.StringIO(self.file_cache[filename])

def save_file_cache(self, dirname):
    """Write every cached file into the directory ``dirname``.

    Each key of ``self.file_cache`` is joined onto ``dirname`` to form the
    target path and the corresponding value is written as the file's
    contents. ``dirname`` must already exist.
    """
    for filename, contents in self.file_cache.items():
        # The context manager closes each handle as soon as it is written,
        # so the data is flushed before anything reads the directory.
        with open(path.join(dirname, filename), 'w') as cached_file:
            cached_file.write(contents)

def get_xinclude_list_for_file(self, filename):
    """Return the XInclude elements found directly under the root of ``filename``.

    Only immediate children of the document root are considered, because
    ``findall`` is given a plain tag name rather than a descendant path.

    Returns ``None`` when ``get_file`` could not retrieve the file,
    otherwise a (possibly empty) list of ``include`` elements in the
    ``http://www.w3.org/2001/XInclude`` namespace.
    """
    file_contents = self.get_file(filename)
    if file_contents is None:
        return None
    # Parse the handle we already hold instead of asking get_file again.
    root = etree.parse(file_contents).getroot()
    return root.findall('{http://www.w3.org/2001/XInclude}include')

def get_file_and_included_files(self, filename):
    """Return ``filename`` followed by every file it pulls in via XInclude.

    Includes are followed recursively, depth first, in document order.
    Included files that cannot be retrieved are left out of the result.

    Returns an empty list when ``filename`` itself cannot be retrieved.

    Two malformed inputs are tolerated rather than crashing:

    * an ``include`` element without an ``href`` attribute is ignored;
    * an include that points back at a file currently being expanded
      (a circular include) is listed but not expanded again, so the
      recursion terminates while the result still shows that includes
      are present.
    """
    def collect(name, ancestors):
        # ``ancestors`` is the chain of files whose includes are currently
        # being expanded; it is what lets us recognise a circular include.
        include_nodes = self.get_xinclude_list_for_file(name)
        if include_nodes is None:
            # The file could not be retrieved at all.
            return []
        found = [name]
        chain = ancestors + (name,)
        for include_node in include_nodes:
            href = include_node.get('href')
            if not href:
                continue
            if href in chain:
                found.append(href)
                continue
            found += collect(href, chain)
        return found

    return collect(filename, ())

def get_parsed_schema_file_with_xincludes(self, filename):
    """Return the parsed schema document ``filename`` with XIncludes resolved.

    The file and everything it includes are fetched first. If there are no
    includes, the document is parsed straight from the fetched contents.
    Otherwise all cached files are written to a temporary directory so
    that ``xinclude()`` can resolve the relative ``href`` values from
    disk; the directory is removed again afterwards, whether or not
    parsing succeeded.

    Raises ``EnvironmentError`` when ``filename`` itself cannot be
    retrieved, and ``SolrError`` when any of the XML is not well formed.
    """
    try:
        file_list = self.get_file_and_included_files(filename)
        if not file_list:
            # this means we didn't even find the top level file
            raise EnvironmentError("Couldn't retrieve schema document from server - received status code 404\n")
        if len(file_list) == 1:
            # there are no xincludes, we can do this the easy way
            schemadoc = etree.parse(self.get_file(filename))
        else:
            # save all contents to files, then load from file and xinclude
            dirname = tempfile.mkdtemp()
            try:
                self.save_file_cache(dirname)
                schemadoc = etree.parse(path.join(dirname, filename))
                schemadoc.xinclude()
            finally:
                shutil.rmtree(dirname)
    except etree.XMLSyntaxError as e:
        raise SolrError("Invalid XML in schema:\n%s" % e.args[0])
    return schemadoc

def init_schema(self):
    """Build ``self.schema`` from a schema document.

    A document supplied as ``self.schemadoc`` is used as is. Otherwise
    ``schema.xml`` is loaded through
    ``get_parsed_schema_file_with_xincludes``, which also resolves any
    XIncludes it contains. That loader is the only place the remote
    schema is requested, so no separate fetch is made here.
    """
    if self.schemadoc:
        schemadoc = self.schemadoc
    else:
        schemadoc = self.get_parsed_schema_file_with_xincludes('schema.xml')
    self.schema = SolrSchema(schemadoc, format=self.format)

def add(self, docs, chunk=100, **kwargs):
Expand Down
98 changes: 94 additions & 4 deletions sunburnt/test_sunburnt.py
Expand Up @@ -12,7 +12,7 @@

from .sunburnt import SolrInterface

from nose.tools import assert_equal
from nose.tools import assert_equal, assert_in

debug = False

Expand Down Expand Up @@ -97,6 +97,9 @@ def xml_response(self):


class MockConnection(object):

file_dict = {'schema.xml': schema_string}

class MockStatus(object):
def __init__(self, status):
self.status = status
Expand All @@ -117,8 +120,12 @@ def request(self, uri, method='GET', body=None, headers=None):
body=body or '',
headers=headers or {})

if method == 'GET' and u.path.endswith('/admin/file/') and params.get("file") == ["schema.xml"]:
return self.MockStatus(200), schema_string
if method == 'GET' and u.path.endswith('/admin/file/'):
filename = params.get("file")[0]
if filename in self.file_dict:
return self.MockStatus(200), self.file_dict[filename]
else:
return self.MockStatus(404), None

rc = self._handle_request(u, params, method, body, headers)
if rc is not None:
Expand Down Expand Up @@ -174,7 +181,7 @@ def _handle_request(self, uri_obj, params, method, body, headers):
slice(5, 6, None),
slice(0, 5, 2),
slice(3, 6, 2),
slice(5, None, -1),
slice(5, None, -1),
slice(None, 0, -1),
# out of range but ok
slice(0, 12, None),
Expand Down Expand Up @@ -284,3 +291,86 @@ def check_mlt_query(i, o, E):
def test_mlt_queries():
for i, o, E in mlt_query_tests:
yield check_mlt_query, i, o, E

# Top-level schema fixture. It declares only defaultSearchField and uniqueKey
# inline; everything else is referenced through three xi:include elements,
# each carrying an empty xi:fallback. The third href
# (schema_not_available.xml) is not listed in
# XincludeMockConnection.file_dict.
schema_string_with_xinclude = \
"""<schema name="timetric" version="1.1">
<xi:include href="schema_extra_types.xml" xmlns:xi="http://www.w3.org/2001/XInclude">
<xi:fallback/>
</xi:include>
<!-- Following is a dynamic way to include other fields, added by other contrib modules -->
<xi:include href="schema_extra_fields.xml" xmlns:xi="http://www.w3.org/2001/XInclude">
<xi:fallback/>
</xi:include>
<xi:include href="schema_not_available.xml" xmlns:xi="http://www.w3.org/2001/XInclude">
<xi:fallback/>
</xi:include>
<defaultSearchField>text_field</defaultSearchField>
<uniqueKey>int_field</uniqueKey>
</schema>"""

# Fixture registered as 'schema_extra_types.xml' in
# XincludeMockConnection.file_dict: a <types> element declaring twelve
# field types.
schema_fieldtypes_to_be_included = \
"""<types>
<fieldType name="string" class="solr.StrField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="text" class="solr.TextField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="boolean" class="solr.BoolField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="int" class="solr.IntField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="sint" class="solr.SortableIntField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="long" class="solr.LongField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="slong" class="solr.SortableLongField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="float" class="solr.FloatField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="sfloat" class="solr.SortableFloatField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="double" class="solr.DoubleField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="sdouble" class="solr.SortableDoubleField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="date" class="solr.DateField" sortMissingLast="true" omitNorms="true"/>
</types>"""

# Fixture registered as 'schema_extra_fields.xml' in
# XincludeMockConnection.file_dict: a <fields> element declaring twelve
# fields, the count asserted by test_schema_with_xinclude_gets_assembled.
schema_fields_to_be_included = \
"""<fields>
<field name="string_field" required="true" type="string" multiValued="true"/>
<field name="text_field" required="true" type="text"/>
<field name="boolean_field" required="false" type="boolean"/>
<field name="int_field" required="true" type="int"/>
<field name="sint_field" type="sint"/>
<field name="long_field" type="long"/>
<field name="slong_field" type="slong"/>
<field name="float_field" type="float"/>
<field name="sfloat_field" type="sfloat"/>
<field name="double_field" type="double"/>
<field name="sdouble_field" type="sdouble"/>
<field name="date_field" type="date"/>
</fields>"""


class XincludeMockConnection(MockConnection):
    """Mock connection whose schema.xml pulls in extra files via XInclude.

    Only the two include targets listed here are registered alongside
    schema.xml; any other file name is absent from ``file_dict``.
    """

    file_dict = {
        # top-level document containing the xi:include elements
        'schema.xml': schema_string_with_xinclude,
        # include targets
        'schema_extra_types.xml': schema_fieldtypes_to_be_included,
        'schema_extra_fields.xml': schema_fields_to_be_included,
    }


def test_schema_with_xinclude_gets_assembled():
    """All fields from the included fields file end up in the schema."""
    connection = XincludeMockConnection()
    interface = SolrInterface("http://test.example.com/", http_connection=connection)
    field_count = len(interface.schema.fields)
    assert_equal(12, field_count)


def test_schema_file_cache_gets_filled():
    """Every file served while loading the schema is kept in file_cache."""
    connection = XincludeMockConnection()
    interface = SolrInterface("http://test.example.com/", http_connection=connection)
    expected_cache = (
        ('schema.xml', schema_string_with_xinclude),
        ('schema_extra_types.xml', schema_fieldtypes_to_be_included),
        ('schema_extra_fields.xml', schema_fields_to_be_included),
    )
    for cached_name, expected_contents in expected_cache:
        assert_equal(expected_contents, interface.file_cache[cached_name])


def test_all_xincludes_found():
    """schema.xml exposes three include elements; the fields file has none."""
    connection = XincludeMockConnection()
    interface = SolrInterface("http://test.example.com/", http_connection=connection)

    top_level_includes = interface.get_xinclude_list_for_file('schema.xml')
    assert_equal(3, len(top_level_includes))

    leaf_includes = interface.get_xinclude_list_for_file('schema_extra_fields.xml')
    assert_equal(0, len(leaf_includes))


def test_get_file_and_included_files_list_includes_all_required_files():
    """The file list holds schema.xml plus the two retrievable includes."""
    connection = XincludeMockConnection()
    interface = SolrInterface("http://test.example.com/", http_connection=connection)
    file_list = interface.get_file_and_included_files('schema.xml')
    assert_equal(3, len(file_list))
    required_files = ('schema.xml', 'schema_extra_fields.xml', 'schema_extra_types.xml')
    for required in required_files:
        assert_in(required, file_list)

0 comments on commit 5494d56

Please sign in to comment.