In [None]:
import itertools
import functools
import logging
import pandas as pd
from frozendict import frozendict

from octo_spork.simplify import simplify
from octo_spork.expressions import And, Or, Not, Le, Lt, Ge, Gt, Attribute
from octo_spork.logic import to_dnf, to_cnf

In [None]:
def map_query_df(df, query):
    ''' Pandas engine implementation applying a query to a dataframe.

    Returns a boolean mask (a Series aligned with ``df.index``) selecting the
    rows that satisfy ``query``, suitable for ``df[mask]``.

    Supported expression types (``query.expr``):
      * 'le', 'ge', 'lt', 'gt' -- compare column ``query.attribute.name``
        against ``query.value``;
      * 'and', 'or' -- combine the masks of ``query.clauses`` elementwise.
        An empty 'and' selects every row and an empty 'or' selects none.

    Raises NotImplementedError for any other expression type instead of
    silently returning None (which would make ``df[None]`` fail confusingly
    further down the line).
    TODO implement NOT and test things. '''
    import operator

    comparisons = {
        'le': operator.le,
        'ge': operator.ge,
        'lt': operator.lt,
        'gt': operator.gt,
    }
    if query.expr in comparisons:
        return comparisons[query.expr](df[query.attribute.name], query.value)
    if query.expr in ('and', 'or'):
        masks = [map_query_df(df, clause) for clause in query.clauses]
        if not masks:
            # functools.reduce without an initial value raises on an empty
            # sequence; fall back to the identity of each connective.
            return pd.Series(query.expr == 'and', index=df.index)
        combine = operator.and_ if query.expr == 'and' else operator.or_
        return functools.reduce(combine, masks)
    raise NotImplementedError(
        'map_query_df: unsupported expression {!r}'.format(query.expr))

def query_df(df, query):
    ''' Filter ``df`` by ``query`` using the mask built by map_query_df.

    A ``None`` query means "no filter": the very same dataframe object is
    handed back unchanged. '''
    return df if query is None else df[map_query_df(df, query)]

def decompose(query, cached_query):
    ''' Split ``query`` relative to a cached query.

    Returns ``(intersection, remainder)``: the intersection
    (``query AND cached_query``) is what can be served by filtering the
    cached data, and the remainder (``query AND NOT cached_query``) is what
    the cache cannot supply. Both are passed through to_dnf/simplify/to_dnf. '''
    def normalise(expr):
        return to_dnf(simplify(to_dnf(expr)))

    intersection = normalise(And([query, cached_query]))
    remainder = normalise(And([query, Not(cached_query)]))
    return intersection, remainder

def resolve_cache_steps(cache, query):
    ''' Step through cached queries, sequentially removing contributions.

    For each ``(cached_query, cached_data)`` entry, the still-unresolved part
    of ``query`` is decomposed against it. When the intersection is not
    ``False``, this yields ``(filtered cached rows, new remainder)``. The
    remainder carries over to the next entry, and iteration stops as soon as
    it becomes ``False`` (nothing left to resolve).

    Iterates ``cache.items()`` lazily, so the cache must not be modified
    while this generator is being consumed. '''
    remaining = query
    for cached_query, cached_data in cache.items():
        hit, remaining_after = decompose(remaining, cached_query)
        if hit is not False:
            logging.warning('Use cache: {}'.format(repr(cached_query)))
            yield query_df(cached_data, hit), remaining_after
        remaining = remaining_after
        if remaining is False:
            return

def resolve_cache(cache, query):
    ''' Resolve as much as possible over the cache.

    Returns ``(datasets, remainder)``. ``datasets`` lists the partial frames
    served from the cache, in cache iteration order. ``remainder`` is the
    remainder reported by the last cache step that produced data, or
    ``query`` itself when no cached entry contributed anything. '''
    steps = list(resolve_cache_steps(cache, query))
    datasets = [partial for partial, _ in steps]
    remainder = steps[-1][1] if steps else query
    return datasets, remainder

def resolve_update_cache(cache, query, remote):
    ''' Resolve over the cache, run the remainder query on the given remote,
    adding its result to the cache.

    Returns the concatenation of every cached partial result plus the freshly
    fetched remainder (if any). The fetched frame is stored in ``cache``
    under the normalised remainder query.

    If nothing was served from the cache and nothing remains to fetch (e.g.
    the query is self-contradictory), an empty ``pd.DataFrame()`` is
    returned. Note that this empty frame carries no columns. Previously this
    case crashed in ``pd.concat([])``. '''
    datasets, remainder = resolve_cache(cache, query)
    # Same normalisation chain decompose() applies, so cache keys take the
    # form later decompose() calls compare against.
    remainder = to_dnf(simplify(to_dnf(remainder)))
    if remainder is False:
        if not datasets:
            # pd.concat raises ValueError on an empty list.
            return pd.DataFrame()
        return pd.concat(datasets)
    fetched = remote.get(remainder)
    cache[remainder] = fetched
    if not datasets:
        return fetched
    return pd.concat(datasets + [fetched])

def to_recordset(df):
    ''' Return the rows of ``df`` as a set of hashable column->value mappings.

    Row order and the dataframe index are discarded, and identical rows
    collapse into one element. '''
    return {frozendict(record) for _, record in df.iterrows()}

def results_equal(df1, df2):
    ''' Check whether two dataframes hold the same rows, ignoring row order,
    index labels and column order.

    Rows are compared as a multiset, not a set: if one frame contains a row
    twice (e.g. the same record assembled from two overlapping cache pieces),
    the frames are reported as different. A set-based comparison would
    silently collapse such duplicates. '''
    import collections

    def record_counts(df):
        # frozenset of (column, value) pairs: hashable, column-order free.
        return collections.Counter(
            frozenset(row.items()) for _, row in df.iterrows())

    return record_counts(df1) == record_counts(df2)

In [None]:
_full_df = pd.DataFrame(
    columns=['a', 'b', 'c', 'ind'],
    data=[
        (x, y, z, ind) for ind, (x, y, z) in
        enumerate(itertools.product(range(10), range(10), range(10)))])


class RemoteDataSource(object):
    ''' Static data source ('remote' part): columns, dataframe, etc

    Stands in for a real backend. Each ``get`` logs the query it received
    (so remote hits are visible) and filters the in-memory ``_full_df``. '''

    def __init__(self):
        # One Attribute per column, exposed as self.a, self.b, self.c.
        for name in 'abc':
            setattr(self, name, Attribute(name))

    def get(self, query):
        message = 'Use remote: {}'.format(repr(query))
        logging.warning(message)
        return query_df(_full_df, query)


class Dataset(object):
    ''' The interface part allowing for querying. Indexing with a query returns a
    new Dataset with the same remote, cache, and info, with the query appended to
    any existing query constants (joined with And to get intersection).

    Columns are exposed as attributes (``data.a``), so filters can be written
    as ``data[Ge(data.a, 2)]``. Calling ``get()`` runs the accumulated query
    through the shared cache and the remote. '''

    def __init__(self, remote, cache, desc, columns, query=None):
        self._remote = remote
        # The same cache object is handed to every Dataset derived via
        # indexing, so all filtered views read from and add to one cache.
        self._cache = cache
        self._desc = desc
        self._columns = {c: Attribute(c) for c in columns}
        self._query = query

    def __repr__(self):
        return '{}\nColumns: {}\nFilter: {}'.format(
            self._desc, ', '.join(self._columns.keys()),
            repr(self._query))

    def __getattr__(self, col):
        # Only called when normal lookup fails. Read _columns straight from
        # the instance dict: on an instance whose __init__ has not run (copy,
        # unpickling, __new__), ``self._columns`` would itself miss and
        # re-enter __getattr__ until RecursionError.
        columns = self.__dict__.get('_columns', {})
        if col in columns:
            return columns[col]
        raise AttributeError('\'{}\' object has no attribute \'{}\''.format(
            self.__class__.__name__, col))

    def __getitem__(self, query):
        # Narrow the existing filter: new rows must satisfy both queries.
        if self._query is not None:
            query = And([self._query, query])
        return Dataset(
            remote=self._remote, cache=self._cache, desc=self._desc,
            columns=self._columns.keys(), query=query)

    def get(self):
        ''' Run the current query. '''
        return resolve_update_cache(
            cache=self._cache, query=self._query, remote=self._remote)

# Root dataset backed by an empty cache; every filtered view created by
# indexing it shares that same cache dict.
data = Dataset(
    desc='Really simple 3D data source.',
    columns=['a', 'b', 'c'],
    cache={},
    remote=RemoteDataSource())

In [None]:
# Cache is still empty, so the whole query goes to the remote.
data1 = data[Ge(data.b, 8)]
data1 = data1[Le(data.c, 0)]
data1 = data1[And([Ge(data.a, 2), Le(data.a, 4)])]
print(data1)
data1.get()

In [None]:
# Identical query to data1: answered entirely from the cache.
data2 = data[Ge(data.b, 8)]
data2 = data2[Le(data.c, 0)]
data2 = data2[And([Ge(data.a, 2), Le(data.a, 4)])]
print(data2)
data2.get()

In [None]:
# Range on a widened to [2, 6]: the cached [2, 4] part is reused and the
# remote is asked only for the uncached rest.
data3 = data[Ge(data.b, 8)]
data3 = data3[Le(data.c, 0)]
data3 = data3[And([Ge(data.a, 2), Le(data.a, 6)])]
print(data3)
data3.get()

In [None]:
# Range on a widened on both sides to [1, 7]: the cache covers the middle,
# and one remote call fetches the uncached edges.
data4 = data[Ge(data.b, 8)]
data4 = data4[Le(data.c, 0)]
data4 = data4[And([Ge(data.a, 1), Le(data.a, 7)])]
print(data4)
data4.get()

In [None]:
# Verify each cache-assembled result against a direct pull from the source.
remote = RemoteDataSource()
for assembled in (data1, data2, data3, data4):
    logging.warning(results_equal(assembled.get(), remote.get(assembled._query)))