Skip to content

Commit

Permalink
Allow generic fields in switches
Browse files Browse the repository at this point in the history
  • Loading branch information
realVinayak committed May 16, 2023
1 parent de94ffd commit db7f014
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 64 deletions.
2 changes: 1 addition & 1 deletion specifyweb/stored_queries/execution.py
Expand Up @@ -621,5 +621,5 @@ def build_query(session, collection, user, tableid, field_specs,
where = reduce(sql.and_, (p for ps in predicates_by_field.values() for p in ps))
query = query.filter(where)

logger.debug("query: %s", query.query)
logger.warning("query: %s", query.query)
return query.query, order_by_exprs
163 changes: 100 additions & 63 deletions specifyweb/stored_queries/format.py
Expand Up @@ -18,7 +18,8 @@

from specifyweb.context.app_resource import get_app_resource
from specifyweb.context.remote_prefs import get_remote_prefs
from specifyweb.specify.models import datamodel, Spappresourcedata, Splocalecontainer, Splocalecontaineritem
from specifyweb.specify.models import datamodel, Spappresourcedata, \
Splocalecontainer, Splocalecontaineritem
from specifyweb.specify.datamodel import Field, Relationship, Table
from specifyweb.stored_queries.queryfield import QueryField

Expand All @@ -34,19 +35,23 @@
Agent_model = datamodel.get_table('Agent')
Spauditlog_model = datamodel.get_table('SpAuditLog')


class ObjectFormatter(object):
def __init__(self, collection, user, replace_nulls):
formattersXML, _ = get_app_resource(collection, user, 'DataObjFormatters')
formattersXML, _ = get_app_resource(collection, user,
'DataObjFormatters')
self.formattersDom = ElementTree.fromstring(formattersXML)
self.date_format = get_date_format()
self.date_format_year = MYSQL_TO_YEAR.get(self.date_format)
self.date_format_month = MYSQL_TO_MONTH.get(self.date_format)
self.collection = collection
self.replace_nulls = replace_nulls

def getFormatterDef(self, specify_model: Table, formatter_name) -> Optional[Element]:
def getFormatterDef(self, specify_model: Table, formatter_name) -> Optional[
Element]:
def lookup(attr: str, val: str) -> Optional[Element]:
return self.formattersDom.find('format[@%s=%s]' % (attr, quoteattr(val)))
return self.formattersDom.find(
'format[@%s=%s]' % (attr, quoteattr(val)))

def getFormatterFromSchema() -> Element:
try:
Expand All @@ -61,14 +66,17 @@ def getFormatterFromSchema() -> Element:
return formatter_name and lookup('name', formatter_name)

return (formatter_name and lookup('name', formatter_name)) \
or getFormatterFromSchema() \
or lookup('class', specify_model.classname)
or getFormatterFromSchema() \
or lookup('class', specify_model.classname)

def getAggregatorDef(self, specify_model: Table, aggregator_name) -> Optional[Element]:
def getAggregatorDef(self, specify_model: Table, aggregator_name) -> \
Optional[Element]:
def lookup(attr: str, val: str) -> Optional[Element]:
return self.formattersDom.find('aggregators/aggregator[@%s=%s]' % (attr, quoteattr(val)))
return self.formattersDom.find(
'aggregators/aggregator[@%s=%s]' % (attr, quoteattr(val)))

return (aggregator_name and lookup('name', aggregator_name)) \
or lookup('class', specify_model.classname)
or lookup('class', specify_model.classname)

def catalog_number_is_numeric(self):
return self.collection.catalognumformatname == 'CatalogNumberNumeric'
Expand All @@ -93,106 +101,131 @@ def pseudo_sprintf(self, format, expr):
else:
return format

def objformat(self, query: QueryConstruct, orm_table: SQLTable, formatter_name) -> Tuple[QueryConstruct, blank_nulls]:
def objformat(self, query: QueryConstruct, orm_table: SQLTable,
formatter_name) -> Tuple[QueryConstruct, blank_nulls]:
logger.info('formatting %s using %s', orm_table, formatter_name)
specify_model = datamodel.get_table(inspect(orm_table).class_.__name__, strict=True)
specify_model = datamodel.get_table(inspect(orm_table).class_.__name__,
strict=True)
formatterNode = self.getFormatterDef(specify_model, formatter_name)
if formatterNode is None:
logger.warn("no dataobjformatter for %s", specify_model)
return query, literal(_("<Formatter not defined.>"))
logger.debug("using dataobjformatter: %s", ElementTree.tostring(formatterNode))
logger.debug("using dataobjformatter: %s",
ElementTree.tostring(formatterNode))

def case_value_convert(value: Optional[str]) -> Optional[str]:
def case_value_convert(value: Optional[str]) -> Optional[str]:
return value

switchNode = formatterNode.find('switch')
single = switchNode.attrib.get('single', 'true') == 'true'
if not single:
sp_control_field = specify_model.get_field(switchNode.attrib['field'])
sp_control_field = specify_model.get_field(
switchNode.attrib['field'])
if sp_control_field.type == 'java.lang.Boolean':
def case_value_convert(value): return value == 'true'

def make_expr(query: QueryConstruct, fieldNode: Element) -> Tuple[QueryConstruct, blank_nulls]:
path = fieldNode.text.split('.')
path.insert(0, inspect(orm_table).class_.__name__)
def make_expr(query: QueryConstruct, fieldNodeText, fieldNodeAttrib) -> Tuple[
QueryConstruct, blank_nulls]:
path = fieldNodeText.split('.')
path = [inspect(orm_table).class_.__name__, *path]
formatter_field_spec = QueryFieldSpec.from_path(path)
if formatter_field_spec.is_relationship():
logger.warning('gets here')
if formatter_field_spec.get_field().type != 'one-to-many':
query, table, model, specify_field = formatter_field_spec.build_join(query, formatter_field_spec.join_path)
formatter_name = fieldNode.attrib.get('formatter', None)
query, table, model, specify_field = query.build_join(
specify_model, orm_table,
formatter_field_spec.join_path)
formatter_name = fieldNodeAttrib.get('formatter', None)
query, expr = self.objformat(query, table, formatter_name)
else:
query, orm_model, table, field = formatter_field_spec.build_join(query, formatter_field_spec.join_path[:-1])
aggregator_name = fieldNode.attrib.get('aggregator', None)
query, orm_model, table, field = query.build_join(
specify_model, orm_table,
formatter_field_spec.join_path)
aggregator_name = fieldNodeAttrib.get('aggregator', None)
expr = query.objectformatter.aggregate(query,
formatter_field_spec.get_field(),
orm_model,
aggregator_name)
formatter_field_spec.get_field(),
orm_model,
aggregator_name)
else:
expr = self._fieldformat(formatter_field_spec.get_field(), getattr( orm_table, formatter_field_spec.get_field().name))
query, table, model, specify_field = query.build_join(
specify_model, orm_table, formatter_field_spec.join_path)
expr = self._fieldformat(formatter_field_spec.get_field(),
getattr(table, specify_field.name))

if 'format' in fieldNode.attrib:
expr = self.pseudo_sprintf(fieldNode.attrib['format'], expr)
if 'format' in fieldNodeAttrib:
expr = self.pseudo_sprintf(fieldNodeAttrib['format'], expr)

if 'sep' in fieldNode.attrib:
expr = concat(fieldNode.attrib['sep'], expr)
if 'sep' in fieldNodeAttrib:
expr = concat(fieldNodeAttrib['sep'], expr)
return query, blank_nulls(expr)

def make_case(query: QueryConstruct, caseNode: Element) -> Tuple[QueryConstruct, Optional[str], blank_nulls]:
def make_case(query: QueryConstruct, caseNode: Element) -> Tuple[
QueryConstruct, Optional[str], blank_nulls]:
field_exprs = []
for node in caseNode.findall('field'):
query, expr = make_expr(query, node)
query, expr = make_expr(query, node.text, node.attrib)
field_exprs.append(expr)

expr = concat(*field_exprs) if len(field_exprs) > 1 else field_exprs[0]
return query, case_value_convert(caseNode.attrib.get('value', None)), expr
expr = concat(*field_exprs) if len(field_exprs) > 1 else \
field_exprs[0]
return query, case_value_convert(
caseNode.attrib.get('value', None)), expr

cases = []
for caseNode in switchNode.findall('fields'):
query, value, expr = make_case(query, caseNode)
cases.append((value, expr))

if not cases:
logger.warn("dataobjformatter for %s contains switch clause no fields", specify_model)
logger.warn(
"dataobjformatter for %s contains switch clause no fields",
specify_model)
return query, literal(_("<Formatter not defined.>"))

if single:
value, expr = cases[0]
else:
control_field = getattr(orm_table, switchNode.attrib['field'])
expr = case(cases, control_field)

query, formatted = make_expr(query, switchNode.attrib['field'], {})
expr = case(cases, formatted)
logger.warning(expr)
return query, blank_nulls(expr)

def aggregate(self, query: QueryConstruct, field: Union[Field, Relationship] , rel_table: SQLTable, aggregator_name) -> ScalarSelect:

logger.info('aggregating field %s on %s using %s', field, rel_table, aggregator_name)
def aggregate(self, query: QueryConstruct,
field: Union[Field, Relationship], rel_table: SQLTable,
aggregator_name) -> ScalarSelect:

logger.info('aggregating field %s on %s using %s', field, rel_table,
aggregator_name)
specify_model = datamodel.get_table(field.relatedModelName, strict=True)
aggregatorNode = self.getAggregatorDef(specify_model, aggregator_name)
if aggregatorNode is None:
logger.warn("aggregator is not defined")
return literal(_("<Aggregator not defined.>"))
logger.debug("using aggregator: %s", ElementTree.tostring(aggregatorNode))
logger.debug("using aggregator: %s",
ElementTree.tostring(aggregatorNode))
formatter_name = aggregatorNode.attrib.get('format', None)
separator = aggregatorNode.attrib.get('separator', ',')
order_by = aggregatorNode.attrib.get('orderfieldname', '')

orm_table = getattr(models, field.relatedModelName)
order_by = [getattr(orm_table, order_by)] if order_by != '' else []

join_column = list(inspect(getattr(orm_table, field.otherSideName)).property.local_columns)[0]
join_column = list(inspect(
getattr(orm_table, field.otherSideName)).property.local_columns)[0]
subquery = QueryConstruct(
collection=query.collection,
objectformatter=self,
query=orm.Query([]).select_from(orm_table) \
.filter(join_column == getattr(rel_table, rel_table._id)) \
.correlate(rel_table)
.filter(join_column == getattr(rel_table, rel_table._id)) \
.correlate(rel_table)
)
subquery, formatted = self.objformat(subquery, orm_table, formatter_name)
subquery, formatted = self.objformat(subquery, orm_table,
formatter_name)
aggregated = blank_nulls(group_concat(formatted, separator, *order_by))
return subquery.query.add_column(aggregated).as_scalar()

def fieldformat(self, query_field: QueryField, field: blank_nulls) -> blank_nulls:
def fieldformat(self, query_field: QueryField,
field: blank_nulls) -> blank_nulls:
field_spec = query_field.fieldspec
if field_spec.get_field() is not None:
if field_spec.is_temporal() and field_spec.date_part == "Full Date":
Expand All @@ -216,25 +249,29 @@ def _dateformat(self, specify_field, field):
prec_fld = getattr(field.class_, specify_field.name + 'Precision', None)

format_expr = \
case({2: self.date_format_month, 3: self.date_format_year}, prec_fld, else_=self.date_format) \
if prec_fld is not None \
else self.date_format
case({2: self.date_format_month, 3: self.date_format_year},
prec_fld, else_=self.date_format) \
if prec_fld is not None \
else self.date_format

return func.date_format(field, format_expr)

def _fieldformat(self, specify_field: Field, field: Union[InstrumentedAttribute, Extract]):
def _fieldformat(self, specify_field: Field,
field: Union[InstrumentedAttribute, Extract]):
if specify_field.type == "java.lang.Boolean":
return field != 0

if specify_field.type in ("java.lang.Integer", "java.lang.Short"):
return field

if specify_field is CollectionObject_model.get_field('catalogNumber') \
and self.catalog_number_is_numeric():
return cast(field, types.Numeric(65)) # 65 is the mysql max precision
and self.catalog_number_is_numeric():
return cast(field,
types.Numeric(65)) # 65 is the mysql max precision

return field


def get_date_format() -> str:
match = re.search(r'ui\.formatting\.scrdateformat=(.+)', get_remote_prefs())
date_format = match.group(1).strip() if match is not None else 'yyyy-MM-dd'
Expand Down Expand Up @@ -298,24 +335,24 @@ def get_date_format() -> str:
}

LDLM_TO_MYSQL = {
"MM dd yy": "%m %d %y",
"MM dd yy": "%m %d %y",
"MM dd yyyy": "%m %d %Y",
"MM-dd-yy": "%m-%d-%y",
"MM-dd-yy": "%m-%d-%y",
"MM-dd-yyyy": "%m-%d-%Y",
"MM.dd.yy": "%m.%d.%y",
"MM.dd.yy": "%m.%d.%y",
"MM.dd.yyyy": "%m.%d.%Y",
"MM/dd/yy": "%m/%d/%y",
"MM/dd/yy": "%m/%d/%y",
"MM/dd/yyyy": "%m/%d/%Y",
"dd MM yy": "%d %m %y",
"dd MM yy": "%d %m %y",
"dd MM yyyy": "%d %m %Y",
"dd MMM yyyy":"%d %b %Y",
"dd-MM-yy": "%d-%m-%y",
"dd MMM yyyy": "%d %b %Y",
"dd-MM-yy": "%d-%m-%y",
"dd-MM-yyyy": "%d-%m-%Y",
"dd-MMM-yyyy":"%d-%b-%Y",
"dd.MM.yy": "%d.%m.%y",
"dd-MMM-yyyy": "%d-%b-%Y",
"dd.MM.yy": "%d.%m.%y",
"dd.MM.yyyy": "%d.%m.%Y",
"dd.MMM.yyyy":"%d.%b.%Y",
"dd/MM/yy": "%d/%m/%y",
"dd.MMM.yyyy": "%d.%b.%Y",
"dd/MM/yy": "%d/%m/%y",
"dd/MM/yyyy": "%d/%m/%Y",
"dd/MMM/yyy": "%d/%b/%Y",
"yyyy MM dd": "%Y %m %d",
Expand Down

0 comments on commit db7f014

Please sign in to comment.