diff --git a/hawc/apps/common/exports.py b/hawc/apps/common/exports.py new file mode 100644 index 0000000000..6693a06ad2 --- /dev/null +++ b/hawc/apps/common/exports.py @@ -0,0 +1,220 @@ +import pandas as pd +from django.db.models import QuerySet + +from .helper import FlatExport + + +class ModelExport: + """Model level export module for use in Exporter class.""" + + def __init__( + self, + key_prefix: str = "", + query_prefix: str = "", + include: tuple[str, ...] | None = None, + exclude: tuple[str, ...] | None = None, + ): + """Instantiate an exporter instance for a given django model. + + Args: + key_prefix (str, optional): The model name to prepend to data frame columns. + query_prefix (str, optional): The model prefix in the ORM. + include (tuple | None, optional): If included, only these items are added. + exclude (tuple | None, optional): If specified, items are removed from base. + """ + self.key_prefix = key_prefix + "-" if key_prefix else key_prefix + self.query_prefix = query_prefix + "__" if query_prefix else query_prefix + self.include = (key_prefix + field for field in include) if include else tuple() + self.exclude = (key_prefix + field for field in exclude) if exclude else tuple() + + @property + def value_map(self) -> dict: + """Value map of column names to ORM field names. + + This caches the result from get_value_map and applies any prefixes + to the column names and ORM field names. It is also filtered down + in compliance with any include/exclude parameters. + + Returns: + dict: Value map + """ + if hasattr(self, "_value_map"): + return self._value_map + + value_map = self.get_value_map() + # add key prefix + if self.key_prefix: + value_map = {self.key_prefix + k: v for k, v in value_map.items()} + # add query prefix + if self.query_prefix: + value_map = {k: self.query_prefix + v for k, v in value_map.items()} + # handle any includes + if self.include: + value_map = {k: v for k, v in value_map.items() if k in self.include} + # handle any excludes + if self.exclude: + value_map = {k: v for k, v in value_map.items() if k not in self.exclude} + + self._value_map = value_map + return self._value_map + + @property + def annotation_map(self) -> dict: + """Annotation map of annotated names to ORM expressions. + + This caches the result from get_annotation_map and applies any + query_prefix to the annotated names. It is also filtered down + in compliance with any include/exclude parameters. + + Returns: + dict: Annotation map + """ + if hasattr(self, "_annotation_map"): + return self._annotation_map + + annotation_map = self.get_annotation_map(self.query_prefix) + # add query prefix + if self.query_prefix: + annotation_map = {self.query_prefix + k: v for k, v in annotation_map.items()} + # handle any includes/excludes + if self.include or self.exclude: + annotation_map = { + k: v for k, v in annotation_map.items() if k in self.value_map.values() + } + + self._annotation_map = annotation_map + return self._annotation_map + + def get_value_map(self) -> dict: + """Value map of column names to ORM field names. + + This should be overridden by any subclass where applicable. + Prefixes and include/exclude should not be handled in this method; + they are handled by the value_map property. + + Returns: + dict: Value map + """ + return {} + + def get_annotation_map(self, query_prefix: str) -> dict: + """Annotation map of annotated names to ORM expressions. + + This should be overridden by any subclass where applicable. + query_prefix for the annotated names and any include/exclude parameters + are handled by the annotation_map property. + query_prefix should still be used in the custom ORM expression + values though, since there is no way to apply that through the + annotation_map property. + + Returns: + dict: Annotation map + """ + return {} + + def get_column_name(self, name: str) -> str: + """Get column name with key_prefix applied. + + Args: + name (str): Column name + + Returns: + str: Column name with prefix + """ + return f"{self.key_prefix}{name}" + + def prepare_qs(self, qs: QuerySet) -> QuerySet: + """Prepare the queryset for export. + + This includes applying any annotations if they exist. + + Args: + qs (QuerySet): Queryset to prepare + + Returns: + QuerySet: Prepared queryset + """ + if self.annotation_map: + return qs.annotate(**self.annotation_map) + return qs + + def prepare_df(self, df: pd.DataFrame) -> pd.DataFrame: + """Prepare the dataframe for export. + + This should be overridden by any subclass where applicable. + Any data manipulations that couldn't be done by the ORM + should be done in this method. + + Args: + df (pd.DataFrame): Dataframe to manipulate + + Returns: + pd.DataFrame: Manipulated dataframe + """ + return df + + def get_df(self, qs: QuerySet) -> pd.DataFrame: + """Get dataframe export from queryset. + + Args: + qs (QuerySet): Queryset + + Returns: + pd.DataFrame: Dataframe + """ + qs = self.prepare_qs(qs) + df = pd.DataFrame( + data=qs.values_list(*self.value_map.values()), columns=list(self.value_map.keys()) + ) + return self.prepare_df(df) + + +class Exporter: + """Data export for querysets. + + This class runs multiple ModelExports on a queryset + and outputs a dataframe through the get_df method. + """ + + def build_modules(self) -> list[ModelExport]: + """ModelExport instances to use for exporter. + + This should be overridden by any subclass. + A key_prefix and query_prefix should be given to + each ModelExport so that the column names don't clash + and the ORM correctly navigates relationships. + + Returns: + list[ModelExport]: List of ModelExports to build export with + """ + raise NotImplementedError() + + def get_df(self, qs: QuerySet) -> pd.DataFrame: + """Get dataframe export from queryset. + + Args: + qs (QuerySet): Queryset + + Returns: + pd.DataFrame: Dataframe + """ + self._modules = self.build_modules() + for module in self._modules: + qs = module.prepare_qs(qs) + values = [value for module in self._modules for value in module.value_map.values()] + keys = [key for module in self._modules for key in module.value_map.keys()] + df = pd.DataFrame(data=qs.values_list(*values), columns=keys) + for module in self._modules: + df = module.prepare_df(df) + return df + + @classmethod + def flat_export(cls, qs: QuerySet, filename: str) -> FlatExport: + """Return an instance of a FlatExport. + + Args: + qs (QuerySet): the initial QuerySet + filename (str): the filename for the export + """ + df = cls().get_df(qs) + return FlatExport(df=df, filename=filename) diff --git a/hawc/apps/common/models.py b/hawc/apps/common/models.py index 78b548be75..6d2e281efc 100644 --- a/hawc/apps/common/models.py +++ b/hawc/apps/common/models.py @@ -10,8 +10,8 @@ from django.core.exceptions import ObjectDoesNotExist, SuspiciousOperation from django.core.files.storage import FileSystemStorage from django.db import IntegrityError, connection, models, router, transaction -from django.db.models import Case, CharField, Choices, Q, QuerySet, URLField, Value, When -from django.db.models.functions import Coalesce +from django.db.models import Case, CharField, Choices, Q, QuerySet, TextField, URLField, Value, When +from django.db.models.functions import Coalesce, Concat from django.template.defaultfilters import slugify as default_slugify from django.utils.html import strip_tags from treebeard.mp_tree import MP_Node @@ -534,6 +534,30 @@ def sql_display(name: str, Choice: type[Choices]) -> Case: ) +def sql_format(format_str: str, *field_params) -> Concat: + """Create an ORM expression to simulate a format string. + + Args: + format_str (str): Format string. Any {} present in the string + will be replaced by field_params. + + Returns: + Concat: An expression that generates a string + """ + value_params = format_str.split("{}") + if format_str.count("{}") != len(field_params): + raise ValueError("field params must be equal to value params.") + replace_num = len(field_params) + concat_args = [] + for i in range(replace_num): + if value_params[i]: + concat_args.append(Value(value_params[i])) + concat_args.append(field_params[i]) + if remainder := "".join(value_params[replace_num:]): + concat_args.append(Value(remainder)) + return Concat(*concat_args, output_field=TextField()) + + def replace_null(field: str, replacement: str = ""): """Replace null values with a replacement string diff --git a/hawc/apps/epiv2/api.py b/hawc/apps/epiv2/api.py index eb721de8ff..9066f3336b 100644 --- a/hawc/apps/epiv2/api.py +++ b/hawc/apps/epiv2/api.py @@ -40,8 +40,8 @@ def export(self, request, pk): .published_only(published_only) .complete() ) - exporter = exports.EpiFlatComplete(qs, filename=f"{assessment}-epi") - return Response(exporter.build_export()) + exporter = exports.EpiV2Exporter.flat_export(qs, filename=f"{assessment}-epi") + return Response(exporter) @action( detail=True, diff --git a/hawc/apps/epiv2/exports.py b/hawc/apps/epiv2/exports.py index 46def8fd2a..ddf8cddc03 100644 --- a/hawc/apps/epiv2/exports.py +++ b/hawc/apps/epiv2/exports.py @@ -1,7 +1,233 @@ +import pandas as pd +from django.db.models import CharField, F, Func, Value + from hawc.apps.common.helper import FlatFileExporter -from hawc.apps.study.models import Study -from . import models +from ..common.exports import Exporter, ModelExport +from ..common.models import sql_display, sql_format, str_m2m, to_display_array +from ..study.exports import StudyExport +from . import constants + + +class DesignExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "url": "url", + "summary": "summary", + "study_name": "study_name", + "study_design": "study_design_display", + "source": "source_display", + "age_profile": "age_profile_string", + "age_description": "age_description", + "sex": "sex_display", + "race": "race", + "participant_n": "participant_n", + "years_enrolled": "years_enrolled", + "years_followup": "years_followup", + "countries": "countries__name", + "region": "region", + "criteria": "criteria", + "susceptibility": "susceptibility", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + def get_annotation_map(self, query_prefix): + return { + "url": sql_format("/epidemiology/design/{}/", query_prefix + "pk"), # hardcoded URL + "study_design_display": sql_display( + query_prefix + "study_design", constants.StudyDesign + ), + "source_display": sql_display(query_prefix + "source", constants.Source), + "age_profile_string": Func( + F(query_prefix + "age_profile"), + Value(", "), + Value(""), + function="array_to_string", + output_field=CharField(max_length=256), + ), + "sex_display": sql_display(query_prefix + "sex", constants.Sex), + "countries__name": str_m2m(query_prefix + "countries__name"), + } + + def prepare_df(self, df): + df.loc[:, self.get_column_name("age_profile")] = to_display_array( + df[self.get_column_name("age_profile")], constants.AgeProfile, ", " + ) + return df + + +class ChemicalExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "name": "name", + "DTSXID": "dsstox__dtxsid", + "created": "created", + "last_updated": "last_updated", + } + + +class ExposureExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "name": "name", + "measurement_type": "measurement_type_string", + "biomonitoring_matrix": "biomonitoring_matrix_display", + "biomonitoring_source": "biomonitoring_source_display", + "measurement_timing": "measurement_timing", + "exposure_route": "exposure_route_display", + "measurement_method": "measurement_method", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + def get_annotation_map(self, query_prefix): + return { + "measurement_type_string": Func( + F(query_prefix + "measurement_type"), + Value(", "), + Value(""), + function="array_to_string", + output_field=CharField(max_length=256), + ), + "biomonitoring_matrix_display": sql_display( + query_prefix + "biomonitoring_matrix", # todo fix default display "?" + constants.BiomonitoringMatrix, + ), + "biomonitoring_source_display": sql_display( + query_prefix + "biomonitoring_source", # todo fix default display "?" + constants.BiomonitoringSource, + ), + "exposure_route_display": sql_display( + query_prefix + "exposure_route", constants.ExposureRoute + ), + } + + +class ExposureLevelExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "name": "name", + "sub_population": "sub_population", + "median": "median", + "mean": "mean", + "variance": "variance", + "variance_type": "variance_type_display", + "units": "units", + "ci_lcl": "ci_lcl", + "percentile_25": "percentile_25", + "percentile_75": "percentile_75", + "ci_ucl": "ci_ucl", + "ci_type": "ci_type_display", + "negligible_exposure": "negligible_exposure", + "data_location": "data_location", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + def get_annotation_map(self, query_prefix): + return { + "variance_type_display": sql_display( + query_prefix + "variance_type", constants.VarianceType + ), + "ci_type_display": sql_display( + query_prefix + "ci_type", constants.ConfidenceIntervalType + ), + } + + +class OutcomeExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "system": "system_display", + "effect": "effect", + "effect_detail": "effect_detail", + "endpoint": "endpoint", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + def get_annotation_map(self, query_prefix): + return { + "system_display": sql_display(query_prefix + "system", constants.HealthOutcomeSystem), + } + + +class AdjustmentFactorExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "name": "name", + "description": "description", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + +class DataExtractionExport(ModelExport): + def get_value_map(self): + return { + "pk": "pk", + "sub_population": "sub_population", + "outcome_measurement_timing": "outcome_measurement_timing", + "effect_estimate_type": "effect_estimate_type", + "effect_estimate": "effect_estimate", + "ci_lcl": "ci_lcl", + "ci_ucl": "ci_ucl", + "ci_type": "ci_type_display", + "units": "units", + "variance_type": "variance_type_display", + "variance": "variance", + "n": "n", + "p_value": "p_value", + "significant": "significant_display", + "group": "group", + "exposure_rank": "exposure_rank", + "exposure_transform": "exposure_transform", + "outcome_transform": "outcome_transform", + "confidence": "confidence", + "data_location": "data_location", + "effect_description": "effect_description", + "statistical_method": "statistical_method", + "comments": "comments", + "created": "created", + "last_updated": "last_updated", + } + + def get_annotation_map(self, query_prefix): + return { + "ci_type_display": sql_display( + query_prefix + "ci_type", constants.ConfidenceIntervalType + ), + "variance_type_display": sql_display( + query_prefix + "variance_type", constants.VarianceType + ), + "significant_display": sql_display(query_prefix + "significant", constants.Significant), + } + + +class EpiV2Exporter(Exporter): + def build_modules(self) -> list[ModelExport]: + return [ + StudyExport("study", "design__study"), + DesignExport("design", "design"), + ChemicalExport("chemical", "exposure_level__chemical"), + ExposureExport("exposure", "exposure_level__exposure_measurement"), + ExposureLevelExport("exposure_level", "exposure_level"), + OutcomeExport("outcome", "outcome"), + AdjustmentFactorExport("adjustment_factor", "factors"), + DataExtractionExport("data_extraction", ""), + ] class EpiFlatComplete(FlatFileExporter): @@ -10,46 +236,5 @@ class EpiFlatComplete(FlatFileExporter): epidemiological meta-result study type from scratch. """ - def _get_header_row(self): - header = [] - header.extend(Study.flat_complete_header_row()) - header.extend(models.Design.flat_complete_header_row()) - header.extend(models.Chemical.flat_complete_header_row()) - header.extend(models.Exposure.flat_complete_header_row()) - header.extend(models.ExposureLevel.flat_complete_header_row()) - header.extend(models.Outcome.flat_complete_header_row()) - header.extend(models.DataExtraction.flat_complete_header_row()) - header.extend(models.AdjustmentFactor.flat_complete_header_row()) - return header - - def get_optimized_queryset(self): - return self.queryset.select_related( - "exposure_level__exposure_measurement", - "exposure_level__chemical__dsstox", - "factors", - "outcome", - "design__study", - ).prefetch_related("design__countries") - - def _get_data_rows(self): - rows = [] - identifiers_df = Study.identifiers_df(self.queryset, "design__study_id") - n_col_factors = len(models.AdjustmentFactor.flat_complete_header_row()) - for obj in self.get_optimized_queryset(): - row = [] - row.extend( - Study.flat_complete_data_row( - obj.design.study.get_json(json_encode=False), identifiers_df - ) - ) - row.extend(obj.design.flat_complete_data_row()) - row.extend(obj.exposure_level.chemical.flat_complete_data_row()) - row.extend(obj.exposure_level.exposure_measurement.flat_complete_data_row()) - row.extend(obj.exposure_level.flat_complete_data_row()) - row.extend(obj.outcome.flat_complete_data_row()) - row.extend(obj.flat_complete_data_row()) - row.extend( - obj.factors.flat_complete_data_row() if obj.factors else [None] * n_col_factors - ) - rows.append(row) - return rows + def build_df(self) -> pd.DataFrame: + return EpiV2Exporter().get_df(self.queryset) diff --git a/hawc/apps/epiv2/models.py b/hawc/apps/epiv2/models.py index 7b2ce5e628..74fe7f14a7 100644 --- a/hawc/apps/epiv2/models.py +++ b/hawc/apps/epiv2/models.py @@ -110,55 +110,6 @@ def get_age_profile_display(self): def __str__(self): return f"{self.summary}" - @staticmethod - def flat_complete_header_row(): - return ( - "design-pk", - "design-url", - "design-summary", - "design-study_name", - "design-study_design", - "design-source", - "design-age_profile", - "design-age_description", - "design-sex", - "design-race", - "design-participant_n", - "design-years_enrolled", - "design-years_followup", - "design-countries", - "design-region", - "design-criteria", - "design-susceptibility", - "design-comments", - "design-created", - "design-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.get_absolute_url(), - self.summary, - self.study_name, - self.get_study_design_display(), - self.get_source_display(), - self.get_age_profile_display(), - self.age_description, - self.get_sex_display(), - self.race, - self.participant_n, - self.years_enrolled, - self.years_followup, - "|".join(el.name for el in self.countries.all()), - self.region, - self.criteria, - self.susceptibility, - self.comments, - self.created, - self.last_updated, - ) - class Chemical(models.Model): objects = managers.ChemicalManager() @@ -199,25 +150,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "chemical-pk", - "chemical-name", - "chemical-DTSXID", - "chemical-created", - "chemical-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.name, - self.dsstox.dtxsid if self.dsstox else None, - self.created, - self.last_updated, - ) - class Exposure(models.Model): objects = managers.ExposureManager() @@ -282,37 +214,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "exposure-pk", - "exposure-name", - "exposure-measurement_type", - "exposure-biomonitoring_matrix", - "exposure-biomonitoring_source", - "exposure-measurement_timing", - "exposure-exposure_route", - "exposure-measurement_method", - "exposure-comments", - "exposure-created", - "exposure-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.name, - ", ".join(self.measurement_type), - self.get_biomonitoring_matrix_display(), - self.get_biomonitoring_source_display(), - self.measurement_timing, - self.get_exposure_route_display(), - self.measurement_method, - self.comments, - self.created, - self.last_updated, - ) - class ExposureLevel(models.Model): objects = managers.ExposureLevelManager() @@ -423,51 +324,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "exposure_level-pk", - "exposure_level-name", - "exposure_level-sub_population", - "exposure_level-median", - "exposure_level-mean", - "exposure_level-variance", - "exposure_level-variance_type", - "exposure_level-units", - "exposure_level-ci_lcl", - "exposure_level-percentile_25", - "exposure_level-percentile_75", - "exposure_level-ci_ucl", - "exposure_level-ci_type", - "exposure_level-negligible_exposure", - "exposure_level-data_location", - "exposure_level-comments", - "exposure_level-created", - "exposure_level-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.name, - self.sub_population, - self.median, - self.mean, - self.variance, - self.get_variance_type_display(), - self.units, - self.ci_lcl, - self.percentile_25, - self.percentile_75, - self.ci_ucl, - self.get_ci_type_display(), - self.negligible_exposure, - self.data_location, - self.comments, - self.created, - self.last_updated, - ) - class Outcome(models.Model): objects = managers.OutcomeManager() @@ -520,31 +376,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "outcome-pk", - "outcome-system", - "outcome-effect", - "outcome-effect_detail", - "outcome-endpoint", - "outcome-comments", - "outcome-created", - "outcome-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.get_system_display(), - self.effect, - self.effect_detail, - self.endpoint, - self.comments, - self.created, - self.last_updated, - ) - class AdjustmentFactor(models.Model): objects = managers.AdjustmentFactorManager() @@ -585,27 +416,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "adjustment_factor-pk", - "adjustment_factor-name", - "adjustment_factor-description", - "adjustment_factor-comments", - "adjustment_factor-created", - "adjustment_factor-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.name, - self.description, - self.comments, - self.created, - self.last_updated, - ) - class DataExtraction(models.Model): objects = managers.DataExtractionManager() @@ -738,65 +548,6 @@ def clone(self): self.save() return self - @staticmethod - def flat_complete_header_row(): - return ( - "data_extraction-pk", - "data_extraction-sub_population", - "data_extraction-outcome_measurement_timing", - "data_extraction-effect_estimate_type", - "data_extraction-effect_estimate", - "data_extraction-ci_lcl", - "data_extraction-ci_ucl", - "data_extraction-ci_type", - "data_extraction-units", - "data_extraction-variance_type", - "data_extraction-variance", - "data_extraction-n", - "data_extraction-p_value", - "data_extraction-significant", - "data_extraction-group", - "data_extraction-exposure_rank", - "data_extraction-exposure_transform", - "data_extraction-outcome_transform", - "data_extraction-confidence", - "data_extraction-data_location", - "data_extraction-effect_description", - "data_extraction-statistical_method", - "data_extraction-comments", - "data_extraction-created", - "data_extraction-last_updated", - ) - - def flat_complete_data_row(self): - return ( - self.pk, - self.sub_population, - self.outcome_measurement_timing, - self.effect_estimate_type, - self.effect_estimate, - self.ci_lcl, - self.ci_ucl, - self.get_ci_type_display(), - self.units, - self.get_variance_type_display(), - self.variance, - self.n, - self.p_value, - self.get_significant_display(), - self.group, - self.exposure_rank, - self.exposure_transform, - self.outcome_transform, - self.confidence, - self.data_location, - self.effect_description, - self.statistical_method, - self.comments, - self.created, - self.last_updated, - ) - reversion.register(Design, follow=("countries",)) reversion.register(Chemical) diff --git a/hawc/apps/study/exports.py b/hawc/apps/study/exports.py new file mode 100644 index 0000000000..3e6f1a4a8b --- /dev/null +++ b/hawc/apps/study/exports.py @@ -0,0 +1,60 @@ +import numpy as np +import pandas as pd +from django.db.models import Q + +from ..common.exports import ModelExport +from ..common.models import sql_display, sql_format, str_m2m +from ..lit.constants import ReferenceDatabase +from .constants import CoiReported + + +class StudyExport(ModelExport): + def get_value_map(self): + return { + "id": "id", + "hero_id": "hero", + "pubmed_id": "pmid", + "doi": "doi", + "url": "url", + "short_citation": "short_citation", + "full_citation": "full_citation", + "coi_reported": "coi_reported_display", + "coi_details": "coi_details", + "funding_source": "funding_source", + "bioassay": "bioassay", + "epi": "epi", + "epi_meta": "epi_meta", + "in_vitro": "in_vitro", + "eco": "eco", + "study_identifier": "study_identifier", + "contact_author": "contact_author", + "ask_author": "ask_author", + "summary": "summary", + "editable": "editable", + "published": "published", + } + + def get_annotation_map(self, query_prefix): + return { + "pmid": str_m2m( + query_prefix + "identifiers__unique_id", + filter=Q(**{query_prefix + "identifiers__database": ReferenceDatabase.PUBMED}), + ), + "hero": str_m2m( + query_prefix + "identifiers__unique_id", + filter=Q(**{query_prefix + "identifiers__database": ReferenceDatabase.HERO}), + ), + "doi": str_m2m( + query_prefix + "identifiers__unique_id", + filter=Q(**{query_prefix + "identifiers__database": ReferenceDatabase.DOI}), + ), + "coi_reported_display": sql_display(query_prefix + "coi_reported", CoiReported), + "url": sql_format("/study/{}/", query_prefix + "pk"), # hardcoded URL + } + + def prepare_df(self, df): + for key in [self.get_column_name("pubmed_id"), self.get_column_name("hero_id")]: + df[key] = pd.to_numeric(df[key], errors="coerce") + for key in [self.get_column_name("doi")]: + df[key] = df[key].replace("", np.nan) + return df diff --git a/tests/hawc/apps/common/test_models.py b/tests/hawc/apps/common/test_models.py index 2295eeed68..1a8aa942c1 100644 --- a/tests/hawc/apps/common/test_models.py +++ b/tests/hawc/apps/common/test_models.py @@ -2,6 +2,7 @@ # use concrete implementations to test from hawc.apps.animal.models import DoseGroup, Experiment +from hawc.apps.common.models import sql_format from hawc.apps.lit.models import ReferenceFilterTag _nested_names = [ @@ -40,3 +41,20 @@ def test_get_order_by(self): assert Experiment.objects._get_order_by() == ("id",) assert DoseGroup._meta.ordering == ("dose_units", "dose_group_id") assert DoseGroup.objects._get_order_by() == ("dose_units", "dose_group_id") + + +def test_sql_format(): + assert str(sql_format("/left/{}", "foo")) == "Concat(ConcatPair(Value('/left/'), F(foo)))" + assert str(sql_format("{}/right", "foo")) == "Concat(ConcatPair(F(foo), Value('/right')))" + assert ( + str(sql_format("/test/{}/here/", "foo")) + == "Concat(ConcatPair(Value('/test/'), ConcatPair(F(foo), Value('/here/'))))" + ) + assert ( + str(sql_format("/a/{}/b/{}/c/", "foo", "bar")) + == "Concat(ConcatPair(Value('/a/'), ConcatPair(F(foo), ConcatPair(Value('/b/'), ConcatPair(F(bar), Value('/c/'))))))" + ) + + for case in ["/too-few/", "{}", "/too-many/{}/{}/"]: + with pytest.raises(ValueError): + sql_format(case, "foo")