In [37]:
def _validate_relation(row, entry_syntax):

    Lprefix, Rprefix, r = row

    rl = ['r','l']
    
    if (r[0][0].casefold() not in rl) or (r[1][0].casefold() not in rl):
        raise ValueError('r l error')
        return False
    
    series_labels = {'r':entry_syntax.loc[entry_syntax['entry_prefix'] == Rprefix],
                     'l':entry_syntax.loc[entry_syntax['entry_prefix'] == Lprefix]}

    for locator in r[:2]:
        if len(locator) > 1:
            this_label = locator[1:]
            labels = series_labels[locator[0].casefold()]['item_label'].array
            if this_label.isdigit():
                if not str(int(this_label)) in labels:
                    raise ValueError('numerical label error')
                    return False
            elif this_label not in labels:
                print(locator)
                print(labels)
                raise ValueError('alphanumeric label error')
                return False

    if len(r) == 4:
        all1 = ['all', '1']
        specs = r[3].split(':')
        if (specs[0] not in all1) or (specs[1] not in all1):
            raise ValueError('all or 1 error')
            return False

        error_text = ('relation {} includes a list specification for an '
                      'item_label that has no list_delimiter in the label'
                      'map.').format(r)

        for i in [0, 1]:
            locator = r[i]
            if (r[3] == '1:1') or (specs[i] == 'all'):
                if (len(locator) == 1):
                    raise ValueError('list with no item locator')
                    return False
                labels = series_labels[locator[0].casefold()]
                has_delimiter = pd.Series(labels['list_delimiter'].notna(),
                                          index=labels['item_label'])
                if not has_delimiter[locator[1:]]:
                    raise ValueError(error_text)

    return True


def _escape_regex_metachars(s):
    s = s.replace( "\\", "\\\\")
    metachars = '.^$*+?{}[]|()'
    for metachar in [c for c in metachars if c in s]:
        s = s.replace(metachar, f'\\{metachar}')
    return s


def _strip_csv_comments(column, pattern):

    column = column.str.split(pat=pattern, expand=True)
    return column[0]


def _replace_escaped_comment_chars(column, comment_char, pattern):

    return column.replace(to_replace=pattern,
                          value=comment_char,
                          regex=True)


def _normalize_shorthand_input(df,
                               comment_char,
                               fill_cols=(),
                               drop_na=()):
    '''
    Remove comments and empty rows from raw shorthand input and
    optionally forward fill or drop rows with missing values.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw shorthand input, e.g. from pandas.read_csv. Its column labels
        must include (ignoring case) "left_entry", "right_entry",
        "link_tags_or_override" and "reference".

    comment_char : str
        Single character that starts a comment. Occurrences preceded by
        a backslash are escaped: they do not start a comment and are
        replaced by the bare character in the output.

    fill_cols : label or list-like of labels, default ()
        Columns whose missing values are forward filled.

    drop_na : label or list-like of labels, default ()
        Rows with a missing value in any of these columns are dropped.

    Returns
    -------
    pandas.DataFrame
        A new DataFrame; the frame passed in is not modified in place.

    Raises
    ------
    ValueError
        If a required column label is missing.
    '''
    if not iterable_not_string(fill_cols):
        fill_cols = [fill_cols]
    if not iterable_not_string(drop_na):
        drop_na = [drop_na]

    required_cols = ['left_entry',
                     'right_entry',
                     'link_tags_or_override',
                     'reference']
    present_cols = {str(c).casefold() for c in df.columns}
    if not all(c in present_cols for c in required_cols):
        raise ValueError('shorthand csv files must have a header whose first '
                         'four column labels must be (ignoring case and list '
                         'order):\n >> ["left_entry", "right_entry", '
                         '"link_tags_or_override", "reference"]')

    # The escaped form of the comment character is only for building
    # regular expressions. The bare character is still needed for
    # un-escaping, where re.sub treats backslashes in the replacement
    # string as escapes.
    escaped_char = _escape_regex_metachars(comment_char)
    unescaped_comment_regex = r"(?<!\\)[{}]".format(escaped_char)
    escaped_comment_regex = fr"(\\{escaped_char})"
    comment_repl = comment_char.replace('\\', '\\\\')

    # A comment may contain the csv delimiter, which spills the rest of
    # the comment into the following cells. Mask every cell to the right
    # of the first cell in its row that contains a comment, including
    # later cells that contain comment characters themselves.
    has_comment = df.apply(
        lambda col: col.str.contains(unescaped_comment_regex, na=False)
    )
    comments_so_far = has_comment.astype(int).cumsum(axis=1)
    commented_out = comments_so_far.shift(1, axis=1, fill_value=0) > 0
    # can't use pd.StringDtype() throughout because it currently doesn't allow
    # construction with null types other than pd.NA. This will likely change
    # soon (https://github.com/pandas-dev/pandas/pull/41412)
    df = df.mask(commented_out)

    # split cells at the first unescaped comment character and keep the
    # uncommented part
    row_has_comment = has_comment.any(axis=1)
    if row_has_comment.any():
        df.loc[row_has_comment, :] = df.loc[row_has_comment].apply(
            _strip_csv_comments,
            args=(unescaped_comment_regex,)
        )

    # drop rows that began with comments
    df = df.mask(df == '')
    df = df.dropna(how='all')

    # replace escaped comment characters with bare comment characters
    df = df.apply(_replace_escaped_comment_chars,
                  args=(comment_repl, escaped_comment_regex))

    # optionally forward fill missing values
    for column in fill_cols:
        df.loc[:, column] = df.loc[:, column].ffill()

    # optionally drop lines missing values
    df = df.dropna(subset=drop_na)

    return df


def get_single_value(df, column_label, none_ok=False, group_key=None):
    '''
    Return the one distinct non-missing value found in a column.

    Parameters
    ----------
    df : pandas.DataFrame
    column_label : label
        Column to inspect.
    none_ok : bool, default False
        If True, return None when the column has no non-missing values
        instead of raising.
    group_key : optional
        Mentioned in error messages (e.g. the groupby key of ``df``).

    Returns
    -------
    object or None
        The single distinct value, or None if there is none and
        ``none_ok`` is True.

    Raises
    ------
    ValueError
        If the column holds more than one distinct value, or holds none
        and ``none_ok`` is False.
    '''
    if group_key is None:
        location = 'in column {}'.format(column_label)
    else:
        location = 'in column {}, group key {}'.format(column_label,
                                                       group_key)

    # value_counts skips missing values
    distinct = df[column_label].value_counts()

    if len(distinct) > 1:
        raise ValueError('found multiple values {}'.format(location))
    if len(distinct) == 1:
        return distinct.index[0]
    if none_ok:
        return None
    raise ValueError('no values found {}'.format(location))


def _validate_entry_prefix_group(group):
    '''
    Check the consistency of the entry syntax rows sharing one entry
    prefix.

    Meant to be applied through
    ``entry_rules.groupby('entry_prefix').apply(...)``, which sets
    ``group.name`` to the entry prefix quoted in the error messages.

    Parameters
    ----------
    group : pandas.DataFrame
        Entry syntax rows with the same value in column "entry_prefix".

    Raises
    ------
    ValueError
        If any of the following holds:

        * the group has more than one entry node type;
        * entry node types and item link types are not both present or
          both absent;
        * a row breaks one of the per-row rules checked below;
        * two rows share an item label.
    '''
    prefix = group.name

    # at most one distinct entry node type per entry prefix
    get_single_value(group,
                     'entry_node_type',
                     none_ok=True,
                     group_key=prefix)

    entry_node_type_na = group['entry_node_type'].isna()
    link_type_na = group['item_link_type'].isna()

    # entry node types and item link types must be present together
    if entry_node_type_na.all() and (~link_type_na).any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'There are no values in column "entry_node_'
                         'type" but column "item_link_type" contains '
                         'values.'.format(prefix))
    if link_type_na.all() and (~entry_node_type_na).any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'There are no values in column "item_link_type" '
                         'but column "entry_node_type" contains '
                         'values.'.format(prefix))

    # a row without an item node type cannot have an item link type
    item_node_type_na = group['item_node_type'].isna()
    if group.loc[item_node_type_na, 'item_link_type'].notna().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'If any row has no value in column "item_node_'
                         'type" then column "item_link_type" must also '
                         'have no value in that row.'.format(prefix))

    # list-valued items need an item node type
    list_delimiter_not_na = group['list_delimiter'].notna()
    if group.loc[list_delimiter_not_na, 'item_node_type'].isna().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'If any row has a value in column "list_'
                         'delimiter" then column "item_node_type" must '
                         'also contain a value in that row.'
                         .format(prefix))

    # prefixed items need a set of prefixes and take their node and link
    # types from those prefixes
    has_item_prefix = group.loc[group['item_prefix_separator'].notna()]
    if has_item_prefix['item_prefixes'].isna().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'If any row has a value in column "item_prefix_'
                         'separator" then column "item_prefixes" '
                         'must also contain a value in that row.'
                         .format(prefix))
    if has_item_prefix['item_node_type'].notna().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'If any row has a value in column "item_prefix_'
                         'separator" then column "item_node_type" '
                         'cannot contain values in that row.'
                         .format(prefix))
    if has_item_prefix['item_link_type'].notna().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'If any row has a value in column "item_prefix_'
                         'separator" then column "item_link_type" '
                         'cannot contain values in that row.'
                         .format(prefix))

    # applies to every entry prefix, not only to those with prefixed items
    if group['item_label'].duplicated().any():
        raise ValueError('Error parsing labels for entry_prefix {}. '
                         'Rows with the same entry prefix must have '
                         'different item labels.'
                         .format(prefix))


def validate_entry_rules(entry_rules):
    '''
    Validate an entry syntax (label map) DataFrame.

    Checks that:

    * columns "item_separator" and "default_entry_prefix" each hold
      exactly one distinct value;
    * every row has a value in column "item_label";
    * each group of rows sharing an "entry_prefix" passes
      _validate_entry_prefix_group.

    Raises
    ------
    ValueError
        If any check fails.
    '''
    for single_valued_column in ('item_separator', 'default_entry_prefix'):
        get_single_value(entry_rules, single_valued_column)

    missing_label = entry_rules['item_label'].isna()
    if missing_label.any():
        raise ValueError('Error parsing entry rules. All rows must have a '
                         'value in column "item_label"')

    by_prefix = entry_rules.groupby('entry_prefix')
    by_prefix.apply(_validate_entry_prefix_group)

In [38]:
from bibliograph.util import iterable_not_string
# NOTE: The triple-quoted string below holds an earlier version of
# _expand_grouped_entries that is kept only for reference. It is a bare
# string expression, so none of the code inside it runs. The live
# implementation is the _expand_grouped_entries defined further below.
# Consider deleting this string and relying on version control instead.
'''
def _expand_grouped_entries(group,
                            entry_syntax,
                            item_separator,
                            default_prefix,
                            na_values):

    expanded = group.str.rstrip(item_separator)

    if pd.isna(group.name):
        grp_prefix = default_prefix
        
    else:
        grp_prefix = group.name
        slice_start = len(grp_prefix + item_separator)
        expanded = expanded.str.slice(start=slice_start)
    
    unescaped_sep_regex = r"(?<!\\)[{}]".format(']['.join(item_separator))
    escaped_sep_regex = fr"(\\{item_separator})"

    expanded = expanded.str.split(pat=unescaped_sep_regex, expand=True)
    expanded = expanded.replace(to_replace=escaped_sep_regex,
                                value=item_separator,
                                regex=True)
    for na in na_values:
        expanded = expanded.replace(to_replace=na, value=pd.NA)
                                                      
    expanded.columns = [str(c) for c in expanded.columns]

    csv_rows = expanded.index.get_level_values(0)
    csv_cols = expanded.index.get_level_values(1)
    grp_pfix = [grp_prefix]*len(expanded)
    expanded.index = pd.MultiIndex.from_arrays((csv_rows,
                                                csv_cols,
                                                grp_pfix))

    entry_labels = entry_syntax.query('entry_prefix == @grp_prefix')

    row_with_item_prefix = entry_labels['item_node_type'].isna()

    if row_with_item_prefix.any():

        prefixed_items = entry_labels.loc[row_with_item_prefix]
        prefixed_item_labels = prefixed_items['item_label'].array
        prefixes = entry_labels.loc[~entry_labels['item_label'].str.isdigit(),
                                    'item_label']
        
        disagged = expanded[prefixed_item_labels].stack()
        
        def item_prefix_splitter(group):

            item_label = str(group.name)
            this_item = prefixed_items.query('item_label == @item_label')
            separator = get_single_value(this_item, 'item_prefix_separator')
            default_pfix = get_single_value(this_item, 'item_prefixes')
            default_pfix = default_pfix.split()[0]

            item_pfixes_w_sep = [p + separator for p in prefixes]
            default_pfix = default_pfix + separator

            prefixes_regex = '|'.join([f'^({p})' for p in item_pfixes_w_sep])
            has_no_pfix = ~group.str.match(prefixes_regex)
            pfixed = group.copy()
            pfixed.loc[has_no_pfix] = pfixed.loc[has_no_pfix] \
                                            .apply(lambda x: default_pfix + x)
                                 
            return pfixed.str.split(separator, n=1, expand=True)

        disagged = disagged.groupby(level=3).apply(item_prefix_splitter)
        disagged.index = disagged.index \
                                 .set_levels([grp_prefix], level=2) \
                                 .droplevel(3)
        disagged = disagged.pivot(columns=0)
        disagged.columns = disagged.columns.get_level_values(1)
        
        unprefixed_item_labels = [label for label in entry_labels['item_label']
                                  if label.isdigit()]
        unprefixed_item_labels = [label for label in unprefixed_item_labels 
                                        if label in expanded.columns
                                        and label not in prefixed_item_labels]

        expanded = expanded[unprefixed_item_labels]
        
        return pd.concat([expanded, disagged], axis='columns')

    else:
        return expanded
'''
def item_prefix_splitter(item_grp, prefixed_items):
    '''
    Split item prefixes off a group of prefixed item strings.

    Items that do not start with one of the item's expected prefixes
    followed by its item prefix separator get the default prefix, which
    is the first whitespace-separated value in the item's
    "item_prefixes" entry. The expected prefixes come from that same
    entry. An earlier version received them as a separate
    "expected_prefixes" argument.

    Parameters
    ----------
    item_grp : pandas.Series
        Item strings for a single item label. ``item_grp.name`` is that
        item label, as set by groupby.

    prefixed_items : pandas.DataFrame
        Entry syntax rows for prefixed items. Must contain the columns
        "item_label", "item_prefix_separator" and "item_prefixes".

    Returns
    -------
    pandas.DataFrame
        Column 0 holds each item's prefix and column 1 the rest of the
        item string. The index is that of ``item_grp``.

    Raises
    ------
    ValueError
        Via get_single_value, if the item does not have exactly one
        prefix separator and one "item_prefixes" value.
    '''
    # entry syntax row(s) for the current item label
    item_syntax = prefixed_items.loc[
        prefixed_items['item_label'] == item_grp.name
    ]
    item_pfx_separator = get_single_value(
        item_syntax,
        'item_prefix_separator'
    )

    # the first listed prefix is the default for unprefixed items
    expected_prefixes = get_single_value(
        item_syntax,
        'item_prefixes'
    ).split()
    default_pfix = expected_prefixes[0] + item_pfx_separator

    # Escape the prefixes and separator so that regex metacharacters in
    # the entry syntax are matched literally.
    esc_separator = _escape_regex_metachars(item_pfx_separator)
    prefixes_regex = '|'.join(
        f'^({_escape_regex_metachars(p)}{esc_separator})'
        for p in expected_prefixes
    )
    has_no_pfix = ~item_grp.str.match(prefixes_regex, na=False)

    # Add the default prefix to unprefixed items. Work on a copy so the
    # grouped object is not mutated during iteration.
    pfixed = item_grp.copy()
    pfixed.loc[has_no_pfix] = default_pfix + pfixed.loc[has_no_pfix]

    # split once on the (escaped) prefix separator into two columns
    return pfixed.str.split(esc_separator, n=1, expand=True)


def _expand_grouped_entries(entry_grp,
                            entry_syntax,
                            item_separator,
                            default_prefix,
                            na_values):
    '''
    Takes a pandas.Series of stacked strings representing shorthand
    entries and expands them into their component items according to
    rules defined by the entry syntax.

    Helper function for parse_entries.

    Parameters
    ----------
    entry_grp : pandas.Series
        A group of entries generated by pandas.Series.groupby. This
        series has a multiindex generated by pandas.DataFrame.stack

    entry_syntax : pandas.DataFrame
        A dataframe containing item node types and item labels for each
        type of entry.

    item_separator : str
        A string separating items within an entry string.

    default_prefix : str
        Entries with no prefix will be interpreted as having the default
        prefix.

    na_values : list-like
        Items within an entry whose values are in na_values will be
        replaced with pandas.NA

    Returns
    -------
    pandas.DataFrame
        A DataFrame the same length as the input with columns defined
        by the item_label column in entry_syntax.
    '''
    # strip any trailing separators from each entry this operation
    # creates a copy of the input group, so we're free to mutate the
    # copy
    expanded = entry_grp.str.rstrip(item_separator)

    if pd.isna(entry_grp.name):
        grp_prefix = default_prefix

    else:
        grp_prefix = entry_grp.name
        slice_start = len(grp_prefix + item_separator)
        expanded = expanded.str.slice(start=slice_start)

    # regular expressions to locate bare and escaped item separators
    unescaped_sep_regex = r"(?<!\\)[{}]".format(']['.join(item_separator))
    escaped_sep_regex = fr"(\\{item_separator})"
    # split on bare item separators and expand into a dataframe
    expanded = expanded.str.split(pat=unescaped_sep_regex, expand=True)
    # replace escaped item separators with the bare value
    expanded = expanded.replace(to_replace=escaped_sep_regex,
                                value=item_separator,
                                regex=True)
    # replace missing values with pd.NA
    for na in na_values:
        expanded = expanded.replace(to_replace=na, value=pd.NA)

    # column values were integers after str.split(pat, expand=True)
    # above but they need to be strings because string-valued item
    # labels are allowed
    expanded.columns = [str(c) for c in expanded.columns]

    # the group has a multiindex whose values are rows and columns in
    # the csv input file we're parsing
    # csv_rows = expanded.index.get_level_values(0)
    # csv_cols = expanded.index.get_level_values(1)
    # grp_pfix = [grp_prefix]*len(expanded)
    # expanded.index = pd.MultiIndex.from_arrays(
    #     (csv_rows, csv_cols, grp_pfix)
    # )

    # group prefixes are required to sort out entry relations later, so
    # add an index level for the group prefix
    expanded['grp_prefix'] = [grp_prefix]*len(expanded)
    expanded = expanded.set_index('grp_prefix', append=True)

    # get item labels and node types for this entry prefix
    item_labels = entry_syntax.query('entry_prefix == @grp_prefix')

    # items with no node type in the entry syntax are prefixed to
    # indicate which node type they correspond to
    item_is_prefixed = item_labels['item_node_type'].isna()

    if item_is_prefixed.any():

        prefixed_items = item_labels.loc[item_is_prefixed]
        labels_of_prefixed_items = prefixed_items['item_label'].array
        # We can identify item prefixes because they are not allowed to
        # be numeric strings. Item labels (as opposed to prefixes) must
        # be numeric strings becase entries are string-valued delimited
        # lists of items, so the item labels are positional indexes in
        # the list.
        # expected_prefixes = item_labels.loc[
        #    ~item_labels['item_label'].str.isdigit(),
        #    'item_label'
        # ]

        # stack the prefixed items into a series
        disagged = expanded[labels_of_prefixed_items].stack()

        # Split the prefixes off of the stacked items and expand into a
        # dataframe
        disagged = disagged.groupby(level=3).apply(
            item_prefix_splitter,
            prefixed_items
        )

        # disagged.index = disagged.index.set_levels([grp_prefix], level=2)
        # drop the item labels from the multiindex so the disaggregated
        # items align with the index of the entry group
        disagged.index = disagged.index.droplevel(3)

        # pivot the disaggregated items to create a dataframe with
        # columns for each item prefix
        disagged = disagged.pivot(columns=0)
        disagged.columns = disagged.columns.get_level_values(1)

        # unprefixed_item_labels = [label for label in item_labels['item_label']
        #                           if label.isdigit()]
        # get labels of items that are not prefixed and present in this
        # dataset
        unprefixed_item_labels = [
            label for label in item_labels['item_label']
            if label.isdigit()
            and label in expanded.columns
            and label not in labels_of_prefixed_items
        ]

        # select only the unprefixed item labels
        expanded = expanded[unprefixed_item_labels]
        # concatenate the unprefixed and prefixed items and return
        return pd.concat([expanded, disagged], axis='columns')

    else:
        # there were no prefixed items, so return the expanded items
        # directly
        return expanded

In [39]:
import pandas as pd
from bibliograph.util import iterable_not_string

def _set_StringDtype(df):
    # can't use pd.StringDtype() throughout because it currently doesn't allow
    # construction with null types other than pd.NA. This will likely change
    # soon (https://github.com/pandas-dev/pandas/pull/41412)
    df_cols = df.columns
    return df.astype(dict(zip(df_cols, [pd.StringDtype()]*len(df_cols))))

def parse_shorthand_entries(entries, entry_syntax, space_placeholder, na_values):

    '''
    Parse a Series of shorthand entry strings into entry rows and item
    rows.

    Parameters
    ----------
    entries : pandas.Series
        Shorthand entry strings. NOTE(review): the index is presumably
        the two-level (csv row, csv column) index produced by stacking
        the normalized shorthand DataFrame; confirm against the caller.

    entry_syntax : pandas.DataFrame
        Entry syntax (label map), e.g. as checked by
        validate_entry_rules.

    space_placeholder : str
        Text that stands for a space inside items. Occurrences not
        preceded by a backslash are replaced with ' '. Backslash-escaped
        occurrences are turned back into the placeholder text.

    na_values : str or list-like of str
        Item values to replace with pandas.NA.

    Returns
    -------
    pandas.DataFrame
        Item rows and entry rows concatenated and sorted by index:

        * item rows have the columns "string", "node_type" and
          "link_type";
        * entry rows have "string", "node_tags" and "node_type", and a
          missing value in the "item_label" index level.

    Notes
    -----
    The caller's ``entries`` object is not modified in place. The first
    statement that uses it rebinds the local name to the result of
    ``entries.str.split``.

    NOTE(review): several user-supplied strings are regex-escaped below
    and the escaped form is then also used as a literal value:

    * as index values (default prefix);
    * as a replacement string (space placeholder);
    * as exact-match values in _expand_grouped_entries (na_values).

    Each such use only gives the intended value when the string has no
    regex metacharacters.
    '''
    # escape user-supplied strings so they can be embedded in regexes
    space_placeholder = _escape_regex_metachars(space_placeholder)
    if iterable_not_string(na_values):
        na_values = [_escape_regex_metachars(v) for v in na_values]
    else:
        na_values = [_escape_regex_metachars(na_values)]

    default_entry_prefix = get_single_value(entry_syntax, 'default_entry_prefix')
    default_entry_prefix = _escape_regex_metachars(default_entry_prefix)
    item_separator = get_single_value(entry_syntax, 'item_separator')
    item_separator = _escape_regex_metachars(item_separator)

    # every known entry prefix followed by the item separator
    pfixes_w_sep = [p + item_separator 
                    for p in entry_syntax['entry_prefix'].drop_duplicates()]

    # Use a regex to split tags off the input strings. Tags follow an
    # unescaped item separator that is followed by a space.
    # NOTE(review): no n= limit is given, so a second separator + space
    # creates extra columns beyond "node_tags".
    tag_sep_regex = r"(?<!\\)[{}][ ]".format(']['.join(item_separator))
    entries = entries.str.split(pat=tag_sep_regex, expand=True)

    # when no entry has tags, the split yields a single column
    if len(entries.columns) == 1:
        entries = entries.rename(columns={0:'string'})
        entries['node_tags'] = pd.NA
    else:
        entries = entries.rename(columns={0:'string', 1:'node_tags'})
    
    # Record explicit entry prefixes. Rows without one are left missing
    # in the new "entry_prefix" column.
    # NOTE(review): if no entry has an explicit prefix, `prefixes` may
    # have no column 0 here; verify this path.
    prefixes_regex = '|'.join([f'^({p})' for p in pfixes_w_sep])
    has_prefix = entries['string'].str.match(prefixes_regex)
    prefixes = entries.loc[has_prefix, 'string'].str.split(item_separator,
                                                           expand=True)
    entries.loc[has_prefix, 'entry_prefix'] = prefixes[0].array
    
    # Expand each entry-prefix group into its items. dropna=False keeps
    # the group of unprefixed entries (NA key), which
    # _expand_grouped_entries assigns to the default prefix.
    expanded = entries['string'].groupby(by=entries['entry_prefix'],
                                         dropna=False,
                                         group_keys=False)
    expanded = expanded.apply(_expand_grouped_entries,
                              entry_syntax,
                              item_separator,
                              default_entry_prefix,
                              na_values)
   
    # join tag strings back onto values in the 'string' column to recover
    # the original input strings
    where_tags = entries['node_tags'].notna()
    entries.loc[where_tags, 'string'] = entries.loc[where_tags, 'string'] \
                                        + item_separator + ' ' \
                                        + entries.loc[where_tags, 'node_tags']
    
    # lookup table from entry prefix to entry node type
    entry_node_types = entry_syntax.loc[:, ['entry_prefix', 'entry_node_type']]
    entry_node_types = entry_node_types.drop_duplicates()
    entry_node_types = pd.Series(entry_node_types['entry_node_type'].array,
                                 index=entry_node_types['entry_prefix'].array)

    # Fill in the default prefix for unprefixed entries.
    # NOTE(review): whether this assignment also updates
    # entries['entry_prefix'] (which becomes an index level below)
    # depends on pandas view semantics. Under Copy-on-Write it does not.
    # Assigning through entries.loc directly would be unambiguous.
    entry_prefixes =  entries.loc[:, 'entry_prefix']
    entry_prefixes.loc[entry_prefixes.isna()] = default_entry_prefix

    entries['node_type'] = entry_node_types[entry_prefixes].array
    #entries['node_type'] = entries['node_type'].astype(pd.StringDtype())
    # can't use pd.StringDtype() throughout because it currently doesn't allow
    # construction with null types other than pd.NA. This will likely change
    # soon (https://github.com/pandas-dev/pandas/pull/41412)

    # entry rows get index levels for the entry prefix and a missing
    # item label, matching the item rows built below
    entries = entries.set_index('entry_prefix', append=True, drop=True)
    entries['item_label'] = pd.NA
    entries = entries.set_index('item_label', append=True, drop=True)

    # empty items are missing values
    expanded = expanded.mask(expanded=='', pd.NA)

    # regexes for bare and backslash-escaped space placeholders
    plchldr_regex = r"(?<!\\)({})".format(space_placeholder)
    esc_plchldr_regex = fr"(\\{space_placeholder})"
    
    #expanded = expanded.apply(lambda x: x.str.replace(pat=plchldr_regex, 
    #                                                  repl=' ',
    #                                                 regex=True))
    expanded = expanded.replace(to_replace=plchldr_regex,
                                value=' ',
                                regex=True)
    #expanded = expanded.apply(lambda x: x.str.replace(pat=esc_plchldr_regex, 
    #                                                  repl=space_placeholder,
    #                                                  regex=True))
    expanded = expanded.replace(to_replace=esc_plchldr_regex,
                                value=space_placeholder,
                                regex=True)

    # one row per item; the item label becomes the last index level
    expanded = expanded.stack()

    # item node and link types indexed by (entry_prefix, item_label)
    item_label_idx = pd.MultiIndex.from_arrays((entry_syntax['entry_prefix'],
                                                entry_syntax['item_label']))
    item_types = pd.DataFrame({'node_type':entry_syntax['item_node_type'].array,
                               'link_type':entry_syntax['item_link_type'].array},
                              index=item_label_idx)


    #expanded.name = 'string'
    #expanded = pd.DataFrame(expanded)
    #cols = ['node_type', 'item_link_type']
    
    #expanded[cols] = item_types.loc[expanded.index.droplevel([0,1])].to_numpy()
    # Look up each item's types by its (entry prefix, item label) index
    # levels, then realign them with the items' full index.
    entry_item_idx = expanded.index.droplevel([0, 1])
    item_types = item_types.loc[entry_item_idx].set_index(expanded.index)
    expanded = pd.concat([expanded.rename('string'), item_types], axis=1)
    #expanded = _set_StringDtype(expanded)

    return pd.concat([expanded, entries]).sort_index().fillna(pd.NA)
    #expanded.append(entries).sort_index().fillna(pd.NA)


In [40]:
import bibliograph as bg

# Nullable unsigned 32-bit integer dtype. The same value is assigned to
# this name again in a later cell.
string_id_dtype = pd.UInt32Dtype()

def _create_id_map(domain, map_na = False, **kwargs):
    '''
    Maps distinct values in a domain to a range of integers.  Additional
    keyword arguments are passed to the pandas.Series constructor when
    the map series is created.

    Parameters
    ----------
    domain : list-like (coercible to pandas.Series)
        Arbitrary set of values to map. May contain duplicates.

    map_na : bool, default False
        Map nan values to an integer.

    Returns
    -------
    pandas.Series
        Series whose length is the number of distinct values in the 
        input domain.

    Examples
    --------
    >>> import pandas as pd
    >>> dom = ['a', 'a', 'b', pd.NA, 'f', 'b']
    >>> _create_id_map(dom, dtype=pd.UInt32Dtype())

    a    0
    b    1
    f    2
    dtype: UInt32
    
    >>> _create_id_map(dom, map_na=True, dtype=pd.UInt32Dtype())
    
    a       0
    b       1
    <NA>    2
    f       3
    dtype: UInt32
    '''
    # check if domain object has a str attribute like a pandas.Series
    # and convert if not
    try:
        assert domain.str
        # make a copy so we can mutate one (potentially large) object
        # instead of creating additional references
        domain = domain.copy()
    except AttributeError:
        domain = pd.Series(domain)

    if not map_na:
        domain = domain.loc[~domain.isna()]
        
    value_is_distinct = ~domain.duplicated()
    num_distinct = value_is_distinct.value_counts()[True]

    id_map = pd.Series(
        range(num_distinct), 
        index = domain.loc[value_is_distinct].array,
        **kwargs
    )

    return id_map

def _extend_id_map(domain,
                   existing_domain,
                   existing_range = None,
                   map_na = False,
                   **kwargs):
    '''
    Map distinct values in a domain which are not in an existing domain
    to integers that do not overlap with an existing range. Additional
    keyword arguments are passed to the pandas.Series constructor when
    the map series is created.

    Parameters
    ----------
    domain : list-like (coercible to pandas.Series)
        Arbitrary set of values to map. May contain duplicates.

    existing_domain : list-like
        Set of values already present in a map from values to integers.
        When ``existing_range`` is None this must be a pandas object,
        because its ``.index`` is used as the existing range.

    existing_range : list-like or None, default None
        Range of integers already present in the range of a map. If
        None, assume existing_domain.index contains the existing range.

    map_na : bool, default False
        Map nan values to an integer.

    Returns
    -------
    pandas.Series
        Map from each new distinct value to a new integer id (as
        produced by bibliograph.non_intersecting_sequence). Empty if
        every value of ``domain`` is already in ``existing_domain``.

    Examples
    --------
    >>> existing_domain = pd.Series(['a', 'a', 'b', 'f', 'b'])
    >>> new_domain = ['a', 'b', 'z', pd.NA]
    >>> _extend_id_map(new_domain,
    ...                existing_domain,
    ...                dtype = pd.UInt16Dtype())
    z    5
    dtype: UInt16

    >>> _extend_id_map(new_domain,
    ...                existing_domain,
    ...                dtype = pd.UInt16Dtype(),
    ...                map_na = True)
    z       5
    <NA>    6
    dtype: UInt16
    '''
    # Accept any list-like. An explicit isinstance check is used rather
    # than `assert`, which is stripped when Python runs with -O. The
    # input is never mutated, so no defensive copy is needed.
    if not isinstance(domain, pd.Series):
        domain = pd.Series(domain)

    if not map_na:
        domain = domain.loc[domain.notna()]

    domain_is_new = ~domain.isin(existing_domain)

    if not domain_is_new.any():
        # nothing to add
        return pd.Series([], index=[], **kwargs)

    new_values = domain.loc[domain_is_new].drop_duplicates()

    if existing_range is None:
        existing_range = existing_domain.index

    new_ids = bg.non_intersecting_sequence(len(new_values), existing_range)

    return pd.Series(new_ids, index=new_values, **kwargs)

def _split_item_lists(item_group, entry_syntax, prefix_id_map, item_id_map):
    '''
    THIS FUNCTION MUTATES ITS FIRST ARGUMENT
    '''

    if pd.isna(item_group.name[1]):
        return item_group
    else:
        item_label = item_group.name[1]

    item_label = item_id_map.loc[item_id_map == item_label].index[0]
    
    if pd.isna(item_label):
        return item_group
    
    prefix = item_group.name[0]
    prefix = prefix_id_map.loc[prefix_id_map == prefix].index[0]
    label_row = entry_syntax.query('entry_prefix == @prefix') \
                            .query('item_label == @item_label') \
                            .squeeze()

    delimiter = label_row['list_delimiter']

    if pd.notna(delimiter):
        item_group.loc[:, 'string'] = item_group['string'].str.split(delimiter)
        #item_group['string'] = item_group['string'].astype(pd.StringDtype())
        # can't use pd.StringDtype() throughout because it currently doesn't 
        # allow construction with null types other than pd.NA. This will
        # likely change soon (https://github.com/pandas-dev/pandas/pull/41412)
        return item_group
    else:
        return item_group


def _make_item_list_indexes(index_group):
    '''
    Number the rows produced by exploding one original row.

    A group of exactly one row was not a list, so its position is <NA>.
    Otherwise the rows are numbered 0, 1, 2, ... in order. The result has
    a default RangeIndex and is cast to the module-level ``type_id_dtype``.
    '''
    n_rows = len(index_group)
    if n_rows != 1:
        positions = pd.Series(range(n_rows))
    else:
        positions = pd.Series([pd.NA])
    return positions.astype(type_id_dtype)

In [41]:
import pandas as pd

# Nullable integer dtypes for the compact IDs built in this notebook:
# type_id_dtype (UInt8) is used for prefix/label/link/node-type maps and
# csv column IDs; string_id_dtype (UInt32) for csv rows and string IDs.
type_id_dtype = pd.UInt8Dtype()
string_id_dtype = pd.UInt32Dtype()

# Label map describing the shorthand syntax; only rows whose `context` is
# "manual" are used here.
entry_syntax = pd.read_csv('./bibliograph/resources/label_map_new.csv')
entry_syntax = entry_syntax.query('context == "manual"')

# Defined elsewhere; presumably raises on malformed rules — TODO confirm.
validate_entry_rules(entry_syntax)

########################################
# Load a shorthand file and normalize it
########################################

# Reader/normalizer settings. fill_cols and drop_na are passed through to
# _normalize_shorthand_input; their exact semantics live there — TODO
# confirm (names suggest forward-filling left_entry and dropping rows
# with no right_entry).
skiprows = 2
comment_char = '#'
fill_cols = 'left_entry'
drop_na = 'right_entry'

# read_csv's own comment handling is disabled; comment_char is applied by
# _normalize_shorthand_input instead.
data = pd.read_csv('./bibliograph/test_data/manual_annotation.shnd',
                   skiprows = skiprows,
                   #header = 
                   #comment = comment_char,
                   skipinitialspace = True)
#data = _set_StringDtype(data)
# can't use pd.StringDtype() throughout because it currently doesn't 
# allow construction with null types other than pd.NA. This will
# likely change soon (https://github.com/pandas-dev/pandas/pull/41412)
data = _normalize_shorthand_input(data,
                                  comment_char=comment_char,
                                  fill_cols=fill_cols,
                                  drop_na=drop_na)

# Get any metadata for links between entries
entry_link_has_override_or_tags = data['link_tags_or_override'].notna()
# NOTE(review): rows_w_link_override_or_tags is only bound when some row
# has link tags/overrides, and it is not used in the cells shown here.
if entry_link_has_override_or_tags.any():
    rows_w_link_override_or_tags = data.loc[entry_link_has_override_or_tags] \
                                       .index

#has_three_entries = data['reference'].notna()
#if has_three_entries.any():
#    rows_w_three_entries = data.loc[has_three_entries].index

###################################################
# Stack shorthand entries and delete original input
###################################################

# replace text column labels with integers so we compute on integer
# indexes
# csv_column_id_map maps column name -> integer ID; assigning its values
# as the new column labels assumes _create_id_map keeps the input order —
# TODO confirm.
csv_column_id_map = _create_id_map(list(data.columns), dtype = pd.UInt8Dtype())
data.columns = csv_column_id_map.array    

# Collapse the three entry columns into one Series of entry strings
# indexed by (csv row, integer csv column ID).
entry_cols = csv_column_id_map[['left_entry', 'right_entry', 'reference']]
stacked_entries = data[entry_cols].stack()
del(data)

###################################################
# Parse distinct shorthand entry strings
###################################################

# Store the csv row indexes for duplicates so we can reconstruct assertions
# about entries later. Each row of `dplct_entries` pairs a duplicated
# entry's own csv position (entry_csv_row, entry_csv_col) with the csv
# position of the first occurrence of the same string (string_csv_row,
# string_csv_col). Only first occurrences are kept in distinct_entries.

entry_is_duplicated = stacked_entries.duplicated()

if entry_is_duplicated.any():

    # stacked_entries is a Series, so select with a plain boolean mask
    distinct_entries = stacked_entries.loc[~entry_is_duplicated]

    dplct_entries = stacked_entries.loc[entry_is_duplicated]

    # string value -> (csv row, csv col) tuple of its first occurrence
    distinct_map = pd.Series(
        distinct_entries.index.to_flat_index(),
        index = distinct_entries.array
    )
    dplct_entries = dplct_entries.map(distinct_map)
    dplct_entries = pd.DataFrame(
        tuple(dplct_entries.array),
        columns = ['string_csv_row', 'string_csv_col'],
        index = dplct_entries.index
    )

    del distinct_map

    # the duplicates' own (row, col) index levels become entry_csv_* columns
    dplct_entries = dplct_entries.reset_index()
    dplct_entries = dplct_entries.rename(
        columns = {'level_0':'entry_csv_row',
                   'level_1':'entry_csv_col'}
    )

    dplct_entries = dplct_entries.astype({
        'entry_csv_row': string_id_dtype,
        'entry_csv_col': type_id_dtype,
        'string_csv_row': string_id_dtype,
        'string_csv_col': type_id_dtype
    })

else:

    distinct_entries = stacked_entries

    # Later code (e.g. _get_prefix_pairs) reads dplct_entries
    # unconditionally, so define an empty frame with the same columns and
    # dtypes instead of leaving the name unbound.
    dplct_entries = pd.DataFrame({
        'entry_csv_row': pd.Series(dtype=string_id_dtype),
        'entry_csv_col': pd.Series(dtype=type_id_dtype),
        'string_csv_row': pd.Series(dtype=string_id_dtype),
        'string_csv_col': pd.Series(dtype=type_id_dtype),
    })

del stacked_entries

# Trim surrounding whitespace from each distinct entry string before parsing
distinct_entries = distinct_entries.str.strip()

# Parse every distinct entry with the syntax rules. The meaning of
# space_placeholder='|' and na_values='x' is defined by
# parse_shorthand_entries — TODO confirm there.
parsed_shnd = parse_shorthand_entries(
    distinct_entries,
    entry_syntax,
    space_placeholder='|',
    na_values='x'
)
# Move the index levels into columns and give them descriptive names.
parsed_shnd = parsed_shnd.reset_index()
parsed_shnd = parsed_shnd.rename(
    columns={'level_0':'csv_row',
             'level_1':'csv_col',
             'grp_prefix':'entry_prefix',
             'level_3':'item_label'}
)

# Compact nullable integer dtypes for the csv position columns
csv_idx_dtypes = {'csv_row':string_id_dtype,
                  'csv_col':pd.UInt8Dtype()}
parsed_shnd = parsed_shnd.astype(csv_idx_dtypes)
#for c in ['entry_prefix', 'item_label']:
#    parsed_shnd[c] = parsed_shnd[c].astype(pd.StringDtype())

In [42]:
# Inspect the parsed shorthand table (cell output shown below)
parsed_shnd

Unnamed: 0,csv_row,csv_col,entry_prefix,item_label,string,node_type,link_type,node_tags
0,1,0,wrk,0,asmith_bwu,actors,author,
1,1,0,wrk,1,1999,dates,published,
2,1,0,wrk,3,101,works,volume,
3,1,0,wrk,4,803,works,page,
4,1,0,wrk,5,xxx,identifiers,doi,
5,1,0,wrk,s,bams,works,supertitle,
6,1,0,wrk,,asmith_bwu__1999__bams__101__803__xxx,works,,
7,1,1,wrk,0,asmith_bwu,actors,author,
8,1,1,wrk,1,1998,dates,published,
9,1,1,wrk,3,100,works,volume,


In [43]:
####################################################
# Map integer IDs for entry prefixes and item labels
####################################################
# Each *_id_map is a Series from text value (index) to integer ID (value);
# the text columns are replaced by their IDs.
entry_pfx_id_map = _create_id_map(
    parsed_shnd['entry_prefix'],
    dtype = type_id_dtype
)

parsed_shnd['entry_prefix'] = parsed_shnd['entry_prefix'].map(entry_pfx_id_map)

item_label_id_map = _create_id_map(
    parsed_shnd['item_label'],
    dtype = type_id_dtype
)

parsed_shnd['item_label'] = parsed_shnd['item_label'].map(item_label_id_map)

#######################################
# Split strings that are lists of items
#######################################
# dropna=False keeps the whole-entry rows (missing item_label) as a group.
parsed_shnd = parsed_shnd.groupby(by=['entry_prefix', 'item_label'],
                                  dropna=False)
parsed_shnd = parsed_shnd.apply(_split_item_lists,
                                entry_syntax,
                                entry_pfx_id_map,
                                item_label_id_map)

# One row per list element; exploded rows share their original index label.
parsed_shnd = parsed_shnd.explode('string')

# Number the elements of each exploded list (single items get <NA>).
# NOTE(review): the groupby result is ordered by sorted index label and is
# assigned positionally via .array, so this assumes parsed_shnd's index is
# sorted ascending here — TODO confirm for the pandas version in use.
item_list_position = pd.Series(parsed_shnd.index).groupby(parsed_shnd.index)
item_list_position = item_list_position.apply(_make_item_list_indexes)
parsed_shnd['item_list_position'] = item_list_position.array

parsed_shnd = parsed_shnd.reset_index(drop=True)

######################################################
# Map string and node IDs, separate strings from nodes
######################################################
string_id_map = _create_id_map(parsed_shnd['string'], dtype = string_id_dtype)
parsed_shnd['string_id'] = parsed_shnd['string'].map(string_id_map)

del(string_id_map)

# assume distinct strings are distinct nodes
parsed_shnd['node_id'] = parsed_shnd['string_id'].array

# Lookup table of distinct strings; the text itself is then dropped from
# the main table, which keeps only IDs.
strings = parsed_shnd.loc[:, ['string_id', 'node_id', 'string']]
strings = strings.drop_duplicates()

parsed_shnd = parsed_shnd.drop('string', axis=1)

################
# Get link types
################
link_types = _create_id_map(
    parsed_shnd['link_type'],
    dtype = type_id_dtype
)

parsed_shnd['link_type_id'] = parsed_shnd['link_type'].map(link_types)
parsed_shnd = parsed_shnd.drop('link_type', axis=1)

# Turn the map Series into a two-column table. reset_index on the unnamed
# Series yields columns 'index' (text) and 0 (ID).
link_types = link_types.reset_index()
link_types = link_types[[0, 'index']]
link_types = link_types.rename(
    columns = {0: 'link_type_id', 'index': 'link_type'}
)

################
# Get node types
################
node_types = _create_id_map(parsed_shnd['node_type'], dtype = type_id_dtype)

parsed_shnd['node_type_id'] = parsed_shnd['node_type'].map(node_types)
parsed_shnd = parsed_shnd.drop('node_type', axis=1)

# Same Series -> (id, text) table conversion as for link types.
node_types = node_types.reset_index()
node_types = node_types[[0, 'index']]
node_types = node_types.rename(
    columns = {0: 'node_type_id', 'index': 'node_type'}
)

################
# Get node tags
################
node_tags = parsed_shnd.loc[parsed_shnd['node_tags'].notna(),
                            ['node_id', 'node_tags']]

parsed_shnd = parsed_shnd.drop('node_tags', axis=1)

# NOTE(review): the next expression's result is discarded; only the cell's
# last expression is displayed by the notebook.
parsed_shnd.loc[~parsed_shnd['link_type_id'].isna()]
print(link_types)
parsed_shnd.iloc[:30]

   link_type_id   link_type
0             0      author
1             1   published
2             2      volume
3             3        page
4             4         doi
5             5  supertitle
6             6       title
7             7      funder


Unnamed: 0,csv_row,csv_col,entry_prefix,item_label,item_list_position,string_id,node_id,link_type_id,node_type_id
0,1,0,0,0.0,0.0,0,0,0.0,0
1,1,0,0,0.0,1.0,1,1,0.0,0
2,1,0,0,1.0,,2,2,1.0,1
3,1,0,0,2.0,,3,3,2.0,2
4,1,0,0,3.0,,4,4,3.0,2
5,1,0,0,4.0,,5,5,4.0,3
6,1,0,0,5.0,,6,6,5.0,2
7,1,0,0,,,7,7,,2
8,1,1,0,0.0,0.0,0,0,0.0,0
9,1,1,0,0.0,1.0,1,1,0.0,0


In [44]:
import pandas as pd
# Scratch cell comparing ways to append one row to a DataFrame.
a = pd.DataFrame({'a': [1, 2, 3], 'g': ['alpha', 'beta', 'gamma']})
# 1) a Series turned into a one-row frame by transposing
b = pd.Series({'a': 4, 'g': 'delta'})
print(pd.concat([a, pd.DataFrame(b).T]))

# 2) a one-row DataFrame built directly
b = pd.DataFrame({'a': [4], 'g': ['delta']})
print(pd.concat([a, b]))

# 3) a positional list zipped onto a's column labels
b = [a['a'].max() + 1, 'delta']
b = pd.DataFrame(
    dict(
        zip(
            a.columns,
            [[e] for e in b]
        )
    )
)
pd.concat([a, b])

   a      g
0  1  alpha
1  2   beta
2  3  gamma
0  4  delta
   a      g
0  1  alpha
1  2   beta
2  3  gamma
0  4  delta


Unnamed: 0,a,g
0,1,alpha
1,2,beta
2,3,gamma
0,4,delta


In [46]:
import bibliograph as bg

###############################################################
# SETUP ENTRY RELATIONS
###############################################################
relations = pd.read_csv('./bibliograph/resources/entry_relations.csv',
                        skipinitialspace=True)

# TODO: implement relations file validation
# _validate_entry_relations(relations)

# Select relations with entry prefix pairs that exist in this data set
prefix_cols = ['left_entry_prefix', 'right_entry_prefix']

# Filter on each prefix column, then replace the prefix text with its
# integer ID.
# NOTE(review): assigning into the frame returned by .loc may emit
# SettingWithCopyWarning; a .copy() after the filter would silence it.
for column in prefix_cols:
    relations = relations.loc[relations[column].isin(entry_pfx_id_map.index)]
    relations[column] = relations[column].map(entry_pfx_id_map)

# Get any new link types in these entry relations
# _extend_id_map returns a Series from each link type not yet in
# link_types to a fresh ID outside the existing link_type_id values.
new_link_types = _extend_id_map(relations['link_type'],
                                link_types['link_type'],
                                link_types['link_type_id'],
                                dtype = type_id_dtype)
new_link_types = pd.DataFrame({'link_type_id':new_link_types.array,
                               'link_type':new_link_types.index})
# The row index is not reset here, so labels may repeat; only columns are
# used below.
link_types = pd.concat([link_types, new_link_types])

# Replace link type text with its integer ID
relations['link_type_id'] = relations['link_type'].map(
    pd.Series(
        link_types['link_type_id'].array,
        index = link_types['link_type']
    )
)

relations = relations.drop('link_type', axis=1)

# Parse the position codes from the entry relations
def _make_node_locators(rltn_column):
    '''
    Parse one position-code column of the module-level ``relations`` frame.

    A position code is a side letter followed by an optional item label
    (for example ``'R0'`` or ``'L'``). Not called in the visible cells.

    Returns
    -------
    tuple of pandas.Series
        ``(csv_col, item_label)``. ``csv_col`` is True where the code
        starts with ``'R'``, False otherwise, and missing where the code
        is missing. ``item_label`` is the rest of the code, or missing
        when that rest is empty.
    '''
    codes = relations[rltn_column]

    label_part = codes.str.slice(1)
    label_part = label_part.mask(label_part == '')

    starts_with_r = codes.str.slice(0, 1).eq('R')
    starts_with_r = starts_with_r.mask(codes.isna())

    return starts_with_r, label_part

# Split each position code (e.g. 'R0', 'L') into a side letter column
# '<prefix>csv_col' and an item label column '<prefix>item_label' (missing
# when the code has no label), then drop the original code column.
columns_and_prefixes = {'source_position':'src_',
                        'target_position':'tgt_',
                        'dflt_ref_position':'ref_'}
for column, prefix in columns_and_prefixes.items():
    position_indicator = relations[column]
    relations[prefix + 'csv_col'] = position_indicator.str.slice(0, 1).array
    item_label = position_indicator.str.slice(1)
    item_label = item_label.mask(item_label == '')
    relations[prefix + 'item_label'] = item_label
    relations = relations.drop(column, axis=1)

# Validate item labels in the relations file and then map them
item_label_cols = [col for col in relations.columns if 'item_label' in col]
item_labels = relations[item_label_cols]
item_labels = item_labels.stack()

# Any label that _extend_id_map would have to create is absent from the
# parsed input data, which is treated as an error.
new_item_labels = _extend_id_map(item_labels,
                                 item_label_id_map.index,
                                 item_label_id_map.array,
                                 dtype = type_id_dtype)

if not new_item_labels.empty:
    raise ValueError('Found the following item labels in the entry relations '
                     'file which are expected from the entry prefixes but do '
                     'not exist in the input data:\n{}'
                     .format(new_item_labels))

for column in item_label_cols:
    relations[column] = relations[column].map(item_label_id_map)

relations

Unnamed: 0,left_entry_prefix,right_entry_prefix,list_mode,link_type_id,src_csv_col,src_item_label,tgt_csv_col,tgt_item_label,ref_csv_col,ref_item_label
0,0,0,,8,L,,R,,L,
1,0,3,1:1,9,L,0.0,R,0.0,L,
2,0,1,,7,R,0.0,L,,L,
3,0,1,m:m,7,R,0.0,L,0.0,L,
4,0,1,,7,R,1.0,L,,L,
5,0,1,m:m,7,R,1.0,L,0.0,L,
6,0,2,m:m,10,L,,R,0.0,L,
7,0,4,,11,L,,R,0.0,L,
8,0,5,,12,L,,R,0.0,,


In [49]:
def _make_entry_assertions(grp, dropna=False):
    lt_isna = grp['link_type_id'].isna()
    lt_notna = ~lt_isna
    if lt_notna.any():
        src = grp.loc[lt_isna, 'string_id'].squeeze()
        return pd.Series([src]*len(grp))
    else:
        return pd.Series([pd.NA]*len(grp))


# For every entry (csv row, csv column), pair the entry's own string with
# each of its linked item strings.
assertions = parsed_shnd.groupby(by=['csv_row', 'csv_col'], group_keys=False)
assertions = assertions.apply(_make_entry_assertions)
# NOTE(review): positional relabel; assumes the apply result comes back in
# parsed_shnd's row order (groups ordered like the rows) — TODO confirm.
assertions.index = parsed_shnd.index
has_link = ~parsed_shnd['link_type_id'].isna()
# Sanity check: both selections should have the same length.
print(list(map(len, [assertions.loc[has_link], parsed_shnd.loc[has_link]])))

# One row per linked item: entry string -> item string, with its link type.
assertions = pd.DataFrame(
    {'src_string_id': assertions.loc[has_link].array,
     'tgt_string_id': parsed_shnd.loc[has_link, 'string_id'],
     'link_type_id': parsed_shnd.loc[has_link, 'link_type_id']}
)
assertions

[40, 40]


Unnamed: 0,src_string_id,tgt_string_id,link_type_id
0,7,0,0
1,7,1,0
2,7,2,1
3,7,3,2
4,7,4,3
5,7,5,4
6,7,6,5
8,12,0,0
9,12,1,0
10,12,8,1


In [50]:
###############################################################
# MAKE ASSERTIONS
###############################################################

def _make_assertions_from_entry_relations():
    '''
    Build assertions implied by the entry relations table.

    For each csv row, the left and right entries are paired by entry
    prefix and matched against the module-level ``relations`` table. Each
    relation row says which side ('L' or 'R') and which item label supply
    the source, target and reference strings of a link, plus the link
    type ID.

    Takes no arguments. It reads the module-level ``parsed_shnd``,
    ``dplct_entries``, ``csv_column_id_map`` and ``relations``, which
    must already be defined.

    Returns
    -------
    pandas.DataFrame
        Columns ``src_string_id``, ``tgt_string_id``, ``ref_string_id``
        and ``link_type_id``. Rows from non-list and '1:1' relations come
        first, then rows from '1:m', 'm:1' and 'm:m' relations.
    '''


    def _get_prefix_pairs(side):
        '''
        Reconstruct prefix pairs, including duplicates not present in
        the parsed shorthand

        side is 'left' or 'right'. Returns one row per entry in that csv
        column with string_csv_row/col (where the distinct string was
        parsed), entry_csv_row/col (where this entry sits) and
        '<side>_entry_prefix'.
        '''

        csv_col = csv_column_id_map[side + '_entry']
        label = side + '_entry_prefix'
        relevant_cols = ['csv_row', 'csv_col', 'entry_prefix', 'item_label']

        # Whole-entry rows carry no item label; keep those parsed from
        # this side's csv column.
        entries = parsed_shnd[relevant_cols].query('item_label.isna()') \
                                            .query('csv_col == @csv_col')
        entries = entries.drop(['item_label'], axis=1)

        entries = entries.rename(
            columns = {'entry_prefix': label,
                       'csv_row': 'string_csv_row',
                       'csv_col': 'string_csv_col'}
        )
        
        # Attach a prefix to each duplicate on this side by merging on the
        # string_csv_* position of its first occurrence. how='right' keeps
        # every duplicate; those whose first occurrence is in the other
        # column get a missing prefix here (see _copy_cross_duplicates).
        on_side_dplcts = dplct_entries.query('entry_csv_col == @csv_col')
        on_side_dplcts = entries.merge(on_side_dplcts, how='right')

        all_pfxs = pd.concat([entries, on_side_dplcts])
        
        # "string" csv indexes locating distinct strings are different
        # from "entry" csv indexes, which locate entries with
        # potentially duplicate string values. Copy the string csv
        # indexes to fill in missing data after merger with duplicates.
        missing_entry = all_pfxs['entry_csv_row'].isna()

        string_csv_cols = ['string_csv_row', 'string_csv_col']
        string_csv_idx = all_pfxs.loc[missing_entry, string_csv_cols]

        entry_csv_cols = ['entry_csv_row', 'entry_csv_col']
        all_pfxs.loc[missing_entry, entry_csv_cols] = string_csv_idx.to_numpy()

        return all_pfxs
        

    def _copy_cross_duplicates(L_prefixes, R_prefixes):
        '''
        THIS FUNCTION MUTATES BOTH ARGUMENTS

        gets prefixes for duplicate entries in one column whose distinct
        string values are present in the other column
        '''

        L_missing_pfix = L_prefixes['left_entry_prefix'].isna()
        L_subset = ['string_csv_col', 'string_csv_row', 'left_entry_prefix']

        R_missing_pfix = R_prefixes['right_entry_prefix'].isna()
        R_subset = ['string_csv_col', 'string_csv_row', 'right_entry_prefix']

        # copy values from left to right
        # how='right' keeps to_fill's row order, so the result is assigned
        # positionally (assumes to_copy has one row per string position).
        to_copy = L_prefixes[L_subset].drop_duplicates()
        to_fill = R_prefixes.loc[R_missing_pfix, R_subset]

        cross_dplcts = to_copy.merge(to_fill,
                                     on=['string_csv_row', 'string_csv_col'],
                                     how='right')
        cross_dplcts = cross_dplcts['left_entry_prefix'].array

        R_prefixes.loc[R_missing_pfix, 'right_entry_prefix'] = cross_dplcts

        # copy values from right to left
        to_copy = R_prefixes[R_subset].drop_duplicates()
        to_fill = L_prefixes.loc[L_missing_pfix, L_subset]

        cross_dplcts = to_copy.merge(to_fill,
                                     on=['string_csv_row', 'string_csv_col'],
                                     how='right')
        cross_dplcts = cross_dplcts['right_entry_prefix'].array

        L_prefixes.loc[L_missing_pfix, 'left_entry_prefix'] = cross_dplcts

    L_prefixes = _get_prefix_pairs('left')
    R_prefixes = _get_prefix_pairs('right')

    _copy_cross_duplicates(L_prefixes, R_prefixes)

    L_prefixes = L_prefixes.drop('entry_csv_col', axis=1)
    R_prefixes = R_prefixes.drop('entry_csv_col', axis=1)

    L_prefixes = L_prefixes.rename(
        columns = {'string_csv_row':'L_str_csv_row',
                   'string_csv_col':'L_str_csv_col'}
    )

    R_prefixes = R_prefixes.rename(
        columns = {'string_csv_row':'R_str_csv_row',
                   'string_csv_col':'R_str_csv_col'}
    )

    # Inner merge on the shared columns; after the renames above that is
    # presumably only 'entry_csv_row', pairing left and right entries of
    # the same csv row.
    prefix_pairs = L_prefixes.merge(R_prefixes)


    def _merge_relation_component_w_parsed_data(component_prefix,
                                                relations_selector,
                                                links=False):
        '''
        Resolve one relation component ('src', 'tgt' or 'ref') to strings.

        Joins prefix_pairs with the selected relations, picks the left or
        right string position named by '<component_prefix>_csv_col', and
        left-merges with parsed_shnd on (csv_row, csv_col, item_label) to
        get string_id and item_list_position. Rows whose component csv
        column is missing are dropped. If links is True, link_type_id is
        carried along.
        '''

        relevant_relations_cols = ['left_entry_prefix',
                                   'right_entry_prefix',
                                   component_prefix + '_csv_col',
                                   component_prefix + '_item_label']
        
        if links:
            relevant_relations_cols.append('link_type_id')
        
        selected_relations = relations.loc[relations_selector]

        component = prefix_pairs.merge(
            selected_relations[relevant_relations_cols],
            on = ['left_entry_prefix', 'right_entry_prefix']
        )

        component_csv_col = component_prefix + '_csv_col'
        component = component.loc[~component[component_csv_col].isna()]

        csv_labels = ['csv_row', 'csv_col']
        is_L = (component[component_csv_col] == 'L')
        is_R = (component[component_csv_col] == 'R')

        left_indexes = component[['L_str_csv_row', 'L_str_csv_col']].loc[is_L]
        left_indexes.columns = csv_labels
        right_indexes = component[['R_str_csv_row', 'R_str_csv_col']].loc[is_R]
        right_indexes.columns = csv_labels

        # Re-interleave L and R rows in component order. The positional
        # assignment below assumes every remaining row is 'L' or 'R' and
        # that component's index is sorted.
        csv_indexes = pd.concat([left_indexes, right_indexes]).sort_index()
        
        component[['csv_row', 'csv_col']] = csv_indexes.to_numpy()
        del(csv_indexes)
        
        component = component.drop(
            ['left_entry_prefix',
             'right_entry_prefix',
             'L_str_csv_row',
             'L_str_csv_col',
             'R_str_csv_row',
             'R_str_csv_col',
             component_csv_col],
            axis=1
        )

        component = component.rename(
            columns = {component_prefix + '_item_label': 'item_label'}
        )

        shnd_cols = ['csv_row',
                     'csv_col',
                     'item_label',
                     'item_list_position',
                     'string_id']

        component = component.merge(parsed_shnd[shnd_cols], 
                                    on=['csv_row', 'csv_col', 'item_label'],
                                    how='left')

        return component


    ##########################################################
    # Get assertion string IDs for relations whose sources and 
    # targets are matched one to one
    ##########################################################
    relation_is_nonlist = relations['list_mode'].isna()
    relation_is_one_to_one = (relations['list_mode'] == '1:1')
    relation_selector = relation_is_nonlist | relation_is_one_to_one

    sources = _merge_relation_component_w_parsed_data(
        'src',
        relation_selector,
        links=True
    )
    sources = sources[['entry_csv_row', 'string_id', 'link_type_id']]
    sources = sources.rename(
        columns = {'string_id': 'src_string_id'}
    )

    targets = _merge_relation_component_w_parsed_data(
        'tgt',
        relation_selector
    )
    targets = targets['string_id'].rename('tgt_string_id')

    # Sources and targets are paired row by row below, so at least their
    # lengths must agree (row order correspondence is assumed).
    if len(sources) != len(targets):
        raise ValueError('Length mismatch between sources and targets for '
                         'one-to-one relations')

    references = _merge_relation_component_w_parsed_data(
        'ref',
        relation_selector
    )
    references = references[['entry_csv_row', 'string_id']]
    references = references.rename(
        columns = {'string_id': 'ref_string_id'}
    )

    # NOTE(review): references are merged per csv row, not per relation,
    # and are not deduplicated here. A row matched by several relations
    # may therefore multiply assertion rows — TODO confirm intended.
    one_to_one_assertions = pd.concat([sources, targets], axis=1)
    one_to_one_assertions = one_to_one_assertions.merge(
        references,
        on = 'entry_csv_row',
        how = 'left'
    )
    
    ##########################################################
    # Get assertion string IDs for relations whose sources and 
    # targets are not matched one to one
    ##########################################################
    relation_is_one_to_many = (relations['list_mode'] == '1:m')
    relation_is_many_to_one = (relations['list_mode'] == 'm:1')
    relation_is_many_to_many = (relations['list_mode'] == 'm:m')

    relation_selector = relation_is_one_to_many | relation_is_many_to_one
    relation_selector = relation_selector | relation_is_many_to_many

    sources = _merge_relation_component_w_parsed_data(
        'src',
        relation_selector,
        links=True
    )
    sources = sources[['entry_csv_row', 'string_id', 'link_type_id']]
    sources = sources.drop_duplicates()
    sources = sources.rename(
        columns = {'string_id': 'src_string_id'}
    )

    targets = _merge_relation_component_w_parsed_data(
        'tgt',
        relation_selector
    )
    targets = targets[['entry_csv_row', 'string_id']].drop_duplicates()
    targets = targets.rename(
        columns = {'string_id': 'tgt_string_id'}
    )

    references = _merge_relation_component_w_parsed_data(
        'ref',
        relation_selector
    )
    references = references[['entry_csv_row', 'string_id']].drop_duplicates()
    references = references.rename(
        columns = {'string_id': 'ref_string_id'}
    )

    # Cross product of sources x targets x references within each csv row.
    # NOTE(review): the join key is only entry_csv_row, so sources of one
    # list relation can pair with targets of another relation in the same
    # row. The inner merge with references also drops rows that have no
    # reference — TODO confirm intended.
    many_to_many_assertions = sources.merge(targets, on='entry_csv_row') \
                                     .merge(references, on='entry_csv_row')
    many_to_many_assertions = many_to_many_assertions.drop(
        'entry_csv_row',
        axis=1
    )
    
    assertions = pd.concat([one_to_one_assertions, many_to_many_assertions])

    assertions = assertions[
        ['src_string_id',
         'tgt_string_id',
         'ref_string_id',
         'link_type_id']
    ]
    
    return assertions

# Combine relation-derived assertions with the entry-internal assertions
# built in the previous cell. The latter have no ref_string_id column, so
# those rows get a missing reference after concat.
a = _make_assertions_from_entry_relations()
assertions = pd.concat([a, assertions])
    
# Number the assertions 0..n-1 in a new 'assertion_id' column.
assertions = assertions.reset_index(drop=True).reset_index()
assertions = assertions.rename(columns = {'index': 'assertion_id'})
assertions

Unnamed: 0,assertion_id,src_string_id,tgt_string_id,ref_string_id,link_type_id
0,0,7,12,7,8
1,1,52,7,52,8
2,2,7,18,7,8
3,3,7,22,7,8
4,4,7,25,7,8
...,...,...,...,...,...
59,59,52,49,,1
60,60,52,50,,2
61,61,52,16,,3
62,62,52,51,,4


In [8]:
import pandas as pd
from bibliograph.util import non_intersecting_sequence

def normalize_types(to_norm, template, strict=True, continue_idx=True):
    '''
    Create an object from to_norm that can be concatenated with template
    such that the concatenated object and its index will have the same
    dtypes as the template.

    If template is pandas.DataFrame and to_norm is dict-like or
    list-like, convert each element of to_norm to a pandas.Series with
    dtypes that conform to the dtypes of the template dataframe and
    concatenate them together as columns in a dataframe.

    If template is pandas.Series, convert to_norm to a Series with the
    template's dtype. If to_norm is dict-like, its keys become the output
    index. If to_norm is list-like, the output index is
    range(len(to_norm)). In either case a generated integer index is used
    instead when strict and continue_idx are both True and both the keys
    (for dict-like input) and template.index are numeric. The output index
    is always cast to template.index.dtype.

    "Dict-like" means any object with an ``items`` attribute (dict,
    pandas.Series, pandas.DataFrame); anything else is list-like.

    Parameters
    ----------
    to_norm : dict-like or list-like
        A set of objects to treat as columns of an output
        pandas.DataFrame and whose types will be coerced.

    template : pandas.DataFrame or pandas.Series
        Output dtypes will conform to template.dtypes and the output
        index dtype will be template.index.dtype

    strict : bool, default True
        If True and to_norm is dict-like, only objects whose keys are
        column labels in the template will be included as columns in the
        output dataframe.

        If True, to_norm is list-like, and template has N columns,
        include only the first N elements of to_norm in the output
        dataframe.

        If False and to_norm is dict-like, normalize dtypes for objects
        whose keys are column labels in the template and include every
        other element of to_norm as a column with dtype inferred by the
        pandas.Series constructor.

        If False, to_norm is list-like, and template has N columns,
        normalize dtypes for the first N elements of to_norm and include
        the other elements of to_norm as columns with dtypes inferred by
        the pandas.Series constructor. Labels for the extra columns in
        the output dataframe will be integers counting from N.

    continue_idx : bool, default True
        If True and template has a numerical index, the index of the
        returned object is the integer sequence produced by
        ``non_intersecting_sequence`` from the output's own index (or
        length) and template.index. Presumably these integers are absent
        from template.index, filling gaps in and/or extending it. Otherwise
        the index built from to_norm is kept. For a Series template the
        sequence is generated only when strict is also True.

    Returns
    -------
    pandas.DataFrame or pandas.Series
        A DataFrame when template is a DataFrame. It has as many rows as
        the longest element of to_norm. If to_norm is dict-like and strict
        is True, it includes only objects in to_norm whose keys are also
        column labels in the template. If to_norm is list-like and strict
        is True, it has the same width as the template. If strict is
        False, it has one column for each element in to_norm.

        A Series when template is a Series.

    Raises
    ------
    ValueError
        If template is a DataFrame, to_norm is list-like, and to_norm has
        fewer elements than template has columns.
    '''
    # dict, pandas.Series and pandas.DataFrame all expose .items(). Use
    # hasattr rather than `assert to_norm.items`, which `python -O` strips.
    is_dict_like = hasattr(to_norm, 'items')

    tmplt_columns = getattr(template, 'columns', None)

    if tmplt_columns is None:
        # Template is not dataframe-like: treat it as a 1D object with
        # attributes dtype and index, and return a Series.
        templt_idx_is_nmrc = pd.api.types.is_numeric_dtype(template.index)
        make_new_index = strict and continue_idx and templt_idx_is_nmrc

        if is_dict_like:
            pairs = list(to_norm.items())
            keys = pd.Index([k for k, _ in pairs])
            values = [v for _, v in pairs]
            # Judge the keys by their inferred Index dtype rather than by
            # the first key. dict keys are not subscriptable, and one-letter
            # string keys such as 'f' can be read as numpy dtype codes.
            norm_keys_are_nmrc = pd.api.types.is_numeric_dtype(keys)
            if make_new_index and norm_keys_are_nmrc:
                index = non_intersecting_sequence(keys, template.index)
            else:
                index = keys
        else:
            values = to_norm
            if make_new_index:
                index = non_intersecting_sequence(len(to_norm),
                                                  template.index)
            else:
                index = range(len(to_norm))

        index = pd.Index(index, dtype=template.index.dtype)
        return pd.Series(values, index=index, dtype=template.dtype)

    num_tmplt_columns = len(tmplt_columns)
    extra_columns = {}

    if is_dict_like:
        if not strict:
            extra_columns = {
                k: v for k, v in to_norm.items() if k not in tmplt_columns
            }
        matched = [(k, v) for k, v in to_norm.items() if k in tmplt_columns]
    else:
        num_elements = len(to_norm)
        if num_elements < num_tmplt_columns:
            raise ValueError(
                'If to_norm is list-like, to_norm must have at least '
                'as many elements as there are columns in template.'
            )
        if not strict:
            # extra elements are labelled with integers counting from N
            extra_columns = dict(zip(
                range(num_tmplt_columns, num_elements),
                to_norm[num_tmplt_columns:]
            ))
        # pair the first N elements with the template's column labels
        matched = list(zip(tmplt_columns, to_norm[:num_tmplt_columns]))

    columns = [
        pd.Series(v, dtype=template[k].dtype, name=k) for k, v in matched
    ]
    columns += [pd.Series(v, name=k) for k, v in extra_columns.items()]

    new_df = pd.concat(columns, axis='columns')

    if continue_idx and pd.api.types.is_numeric_dtype(template.index.dtype):
        index = non_intersecting_sequence(new_df.index, template.index)
    else:
        # Keep the index built from to_norm (previously left unbound here).
        index = new_df.index

    new_df.index = pd.Index(index, dtype=template.index.dtype)
    return new_df

# Scratch check of normalize_types with nullable integer templates.
# NOTE(review): the recorded output below (a dtype listing) does not match
# this cell's last expression, so it appears to come from an earlier
# version of the cell.
a = pd.DataFrame({'a':[1,2,3],'b':[10,20,30]})
a = a.astype({'a':pd.UInt8Dtype(), 'b':pd.UInt8Dtype()})
a.index = a.index.astype(pd.UInt8Dtype())

# DataFrame template with a dict-like (Series) row
b = pd.Series({'a':10,'b':100})
#print(b.dtype)
b = normalize_types(b, a)
#pd.concat([a, b]).dtypes

# Series template with a list; strict=False keeps range(len(d)) as index
c = pd.Series([1,2,3], dtype=pd.UInt16Dtype())
d = [4]
d = normalize_types(d, c, strict=False)
pd.concat([c, d])


a    UInt8
b    UInt8
dtype: object

In [15]:
# Scratch check of which dtypes survive a column-wise concat.
a = pd.DataFrame({'a':[1,2,3], 'b':['alpha', 'beta', 'gamma']})
# NOTE(review): result discarded, so `a` keeps int64/object dtypes (as the
# output below shows); assign it back if the cast was meant to apply.
a.astype({'a':pd.UInt16Dtype(), 'b':pd.StringDtype()})
print(a['b'].astype(pd.StringDtype()))
c = pd.Series([10, 20, 30], dtype=pd.UInt8Dtype(), name='ID')

pd.concat([a, c], axis='columns').dtypes


0    alpha
1     beta
2    gamma
Name: b, dtype: string


a      int64
b     object
ID     UInt8
dtype: object