Add support for Solr's JSON API
For large result sets, the JSON response can be parsed and interpreted much more
efficiently than the equivalent XML response.

TODO:
 * Support JSON for queries other than /select
 * Support "interestingTerms" in MLT queries
evansd committed Nov 11, 2013
1 parent 4ad86c2 commit 808768c
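
As a rough usage sketch (illustrative, not part of the commit), the new format keyword added in the diff below lets callers opt into JSON parsing when constructing the interface; the URL and query term here are placeholders:

import sunburnt

# format='json' requests wt=json from Solr and routes parsing through
# the new SolrResponse.from_json(); the default remains 'xml'.
si = sunburnt.SolrInterface("http://localhost:8983/solr/", format='json')

# Queries are built and executed as before; only the wire format and
# the response parser change.
response = si.query("pycon").execute()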
Showing 2 changed files with 103 additions and 6 deletions.
92 changes: 90 additions & 2 deletions sunburnt/schema.py
@@ -8,6 +8,11 @@
from lxml.builder import E
import lxml.etree

try:
import simplejson as json
except ImportError:
import json

from .dates import datetime_from_w3_datestring
from .strings import RawString, SolrString, WildcardString

@@ -400,9 +405,10 @@ class SolrSchema(object):
'solr.GeoHashField':SolrPoint2Field,
}

def __init__(self, f):
def __init__(self, f, format='xml'):
"""initialize a schema object from a
filename or file-like object."""
self.format = format
self.fields, self.dynamic_fields, self.default_field_name, self.unique_key \
= self.schema_parse(f)
self.default_field = self.fields[self.default_field_name] \
@@ -518,7 +524,10 @@ def make_delete(self, docs, query):
return SolrDelete(self, docs, query)

def parse_response(self, msg):
return SolrResponse.from_xml(self, msg)
if self.format == 'json':
return SolrResponse.from_json(self, msg)
else:
return SolrResponse.from_xml(self, msg)

def parse_result_doc(self, doc, name=None):
if name is None:
@@ -535,6 +544,26 @@ def parse_result_doc(self, doc, name=None):
raise SolrError("unexpected field found in result (field name: %s)" % name)
return name, SolrFieldInstance.from_solr(field_class, doc.text or '').to_user_data()

def parse_result_doc_json(self, doc):
result = {}
for name, value in doc.viewitems():
field_class = self.match_field(name)
if field_class is None and name == "score":
field_class = SolrScoreField()
elif field_class is None:
raise SolrError("unexpected field found in result (field name: %s)" % name)
if isinstance(value, list):
parsed_value = [self.cast_value(field_class, v) for v in value]
else:
parsed_value = self.cast_value(field_class, value)
result[name] = parsed_value
return result

def cast_value(self, field_class, value):
if isinstance(field_class, SolrUnicodeField):
return value
return SolrFieldInstance.from_solr(field_class, value).to_user_data()


class SolrUpdate(object):
ADD = E.add
@@ -646,6 +675,25 @@ def from_response(cls, response):
facet_counts_dict = dict(response.get("facet_counts", {}))
return SolrFacetCounts(**facet_counts_dict)

@classmethod
def from_response_json(cls, response):
try:
facet_counts_dict = response['facet_counts']
except KeyError:
return SolrFacetCounts()
facet_fields = {}
for facet_field, facet_values in facet_counts_dict['facet_fields'].viewitems():
facets = []
# Change each facet list from [a, 1, b, 2, c, 3 ...] to
# [(a, 1), (b, 2), (c, 3) ...]
for n, value in enumerate(facet_values):
if n&1 == 0:
name = value
else:
facets.append((name, value))
facet_fields[facet_field] = facets
facet_counts_dict['facet_fields'] = facet_fields
return SolrFacetCounts(**facet_counts_dict)

class SolrResponse(object):
@classmethod
@@ -686,6 +734,36 @@ def from_xml(cls, schema, xmlmsg):
self.interesting_terms = value
return self

@classmethod
def from_json(cls, schema, jsonmsg):
self = cls()
self.schema = schema
self.original_json = jsonmsg
doc = json.loads(jsonmsg)
details = doc['responseHeader']
for attr in ["QTime", "params", "status"]:
setattr(self, attr, details.get(attr))
if self.status != 0:
raise ValueError("Response indicates an error")
self.result = SolrResult.from_json(schema, doc['response'])
self.facet_counts = SolrFacetCounts.from_response_json(doc)
self.highlighting = dict((k, dict(v))
for k, v in details.get("highlighting", ()))
self.more_like_these = dict((k, SolrResult.from_json(schema, v))
for (k, v) in doc['moreLikeThis'].viewitems())
if len(self.more_like_these) == 1:
self.more_like_this = self.more_like_these.values()[0]
else:
self.more_like_this = None
# can be computed by MoreLikeThisHandler
#termsNodes = doc.xpath("/response/*[@name='interestingTerms']")
#if len(termsNodes) == 1:
# _, value = value_from_node(termsNodes[0])
#else:
# value = None
self.interesting_terms = None
return self

def __str__(self):
return str(self.result)

@@ -707,6 +785,16 @@ def from_xml(cls, schema, node):
self.docs = [schema.parse_result_doc(n) for n in node.xpath("doc")]
return self

@classmethod
def from_json(cls, schema, node):
self = cls()
self.schema = schema
self.name = 'response'
self.numFound = int(node['numFound'])
self.start = int(node['start'])
self.docs = [schema.parse_result_doc_json(n) for n in node["docs"]]
return self

def __str__(self):
return "%(numFound)s results found, starting at #%(start)s\n\n" % self.__dict__ + str(self.docs)

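For context (not part of the commit), the JSON parsing above assumes Solr's standard wt=json response layout. The sketch below uses invented values to show the shape that from_json() and from_response_json() walk, including the flat [term, count, term, count, ...] facet lists that the pairing loop converts into tuples:

# Illustrative only: a minimal wt=json response of the shape the new
# parsing code expects (all values invented).
sample = {
    "responseHeader": {"status": 0, "QTime": 3,
                       "params": {"q": "*:*", "wt": "json"}},
    "response": {"numFound": 2, "start": 0,
                 "docs": [{"id": "1", "name": "first"},
                          {"id": "2", "name": "second"}]},
    "facet_counts": {
        # Solr returns each facet as a flat [term, count, ...] list.
        "facet_fields": {"name": ["first", 1, "second", 1]},
    },
}

# The same pairing performed by from_response_json(), written out:
flat = sample["facet_counts"]["facet_fields"]["name"]
pairs = zip(flat[0::2], flat[1::2])   # [("first", 1), ("second", 1)]
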
17 changes: 13 additions & 4 deletions sunburnt/sunburnt.py
@@ -13,7 +13,7 @@
# Jetty default is 4096; Tomcat default is 8192; picking 2048 to be conservative.

class SolrConnection(object):
def __init__(self, url, http_connection, retry_timeout, max_length_get_url):
def __init__(self, url, http_connection, retry_timeout, max_length_get_url, format):
if http_connection:
self.http_connection = http_connection
else:
@@ -25,6 +25,7 @@ def __init__(self, url, http_connection, retry_timeout, max_length_get_url):
self.mlt_url = self.url + "mlt/"
self.retry_timeout = retry_timeout
self.max_length_get_url = max_length_get_url
self.format = format

def request(self, *args, **kwargs):
try:
@@ -98,6 +99,8 @@ def url_for_update(self, commit=None, commitWithin=None, softCommit=None, optimi
return self.update_url

def select(self, params):
if self.format == 'json':
params.append(('wt', 'json'))
qs = urllib.urlencode(params)
url = "%s?%s" % (self.select_url, qs)
if len(url) > self.max_length_get_url:
@@ -141,13 +144,19 @@ class SolrInterface(object):
readable = True
writeable = True
remote_schema_file = "admin/file/?file=schema.xml"
def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1, max_length_get_url=MAX_LENGTH_GET_URL):
self.conn = SolrConnection(url, http_connection, retry_timeout, max_length_get_url)
def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1,
max_length_get_url=MAX_LENGTH_GET_URL, format='xml'):
self.conn = SolrConnection(url, http_connection, retry_timeout, max_length_get_url, format)
self.schemadoc = schemadoc
if mode == 'r':
self.writeable = False
elif mode == 'w':
self.readable = False
allowed_formats = ('xml', 'json')
if format not in allowed_formats:
raise ValueError("Unsupported format '%s': allowed are %s" %
(format, ','.join(allowed_formats)))
self.format = format
self.init_schema()

def init_schema(self):
@@ -159,7 +168,7 @@ def init_schema(self):
if r.status != 200:
raise EnvironmentError("Couldn't retrieve schema document from server - received status code %s\n%s" % (r.status, c))
schemadoc = StringIO.StringIO(c)
self.schema = SolrSchema(schemadoc)
self.schema = SolrSchema(schemadoc, format=self.format)

def add(self, docs, chunk=100, **kwargs):
if not self.writeable:
