# TPC-DS Q1 in Python

```SQL
WITH customer_total_return AS
( SELECT
    sr_customer_sk AS ctr_customer_sk,
    sr_store_sk AS ctr_store_sk,
    sum(sr_return_amt) AS ctr_total_return
  FROM store_returns, date_dim
  WHERE sr_returned_date_sk = d_date_sk AND d_year = 2000
  GROUP BY sr_customer_sk, sr_store_sk)
SELECT c_customer_id
FROM customer_total_return ctr1, store, customer
WHERE ctr1.ctr_total_return >
  (SELECT avg(ctr_total_return) * 1.2
  FROM customer_total_return ctr2
  WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk)
  AND s_store_sk = ctr1.ctr_store_sk
  AND s_state = 'TN'
  AND ctr1.ctr_customer_sk = c_customer_sk
ORDER BY c_customer_id
LIMIT 100
```
High-level plan for the query:
1. Filter individual tables where possible (typically `O(N)` per table).
    - Operations that combine two tables (e.g. a Cartesian product, `O(N M)`) are expensive, so we want their inputs to be as small as possible.
2. Combine pairs of tables.
3. Treat each combined pair as a single table and return to step 1.

In [1]:
import pandas as pd
import numpy as np

In [2]:
def read_header(table_name):
    """Return the column names stored in ``headers/<table_name>``.

    The file is split on "\\n", one name per line.
    NOTE(review): a trailing newline in the header file yields a final empty
    name "". Presumably this lines up with a trailing '|' on each data row;
    confirm before changing the split.
    """
    with open("headers/" + table_name, "r") as f:
        return f.read().split("\n")


def read_table(table_name):
    """Load ``data/<table_name>.csv`` ('|'-delimited) as a DataFrame.

    Because explicit ``names`` are passed, pandas treats the first line of the
    file as data, not as a header row.

    Params:
        - table_name: base name shared by the header file and the data file.

    Returns:
        pandas dataframe whose columns are named by ``read_header(table_name)``.
    """
    header = read_header(table_name)
    return pd.read_csv("data/" + table_name + ".csv", names=header, delimiter='|')

In [4]:
# Load the TPC-DS tables that Q1 reads from, in the order they are named here.
store_returns, date_dim, store, customer = (
    read_table(name) for name in ("store_returns", "date_dim", "store", "customer")
)

In [5]:
# helper functions
def filtered_cartesian_product(A, B, filter_fn):
    """Returns the rows of the cartesian product of A, B for which filter_fn is True.

    Used for SELECTs from multiple tables.
    Has O(N M) runtime where N is the number of rows in A and M is the number of rows in B.

    Params:
        - A: a table (pandas dataframe object)
        - B: a table (pandas dataframe object)
        - filter_fn(a, b): takes in a potential pair of rows a in A, b in B (each a
            pandas Series) and returns true if concatenating them forms a valid new row

    Returns:
        pandas dataframe with one row per accepted (a, b) pair, ordered by A's
        rows and then B's. The first column is inserted by ``reset_index()``
        (labelled "index" unless that label is already taken) and holds the
        labels pandas assigned while stacking the rows. It is followed by A's
        columns, then B's columns. If no pair is accepted (including when A or
        B is empty), the dataframe has that same column layout and zero rows.
    """
    # Collect B's rows once instead of rescanning B for every row of A.
    b_rows = [b_row for _, b_row in B.iterrows()]
    valid_rows = [
        pd.concat([a_row, b_row])
        for _, a_row in A.iterrows()
        for b_row in b_rows
        if filter_fn(a_row, b_row)
    ]
    if not valid_rows:
        # pd.concat([]) raises, so build an empty table with the same columns instead.
        return pd.DataFrame(columns=list(A.columns) + list(B.columns)).reset_index()
    # Each concatenated row is a Series: stack them as columns, then transpose into rows.
    return pd.concat(valid_rows, axis=1).transpose().reset_index()

def merge_tables(A, B, a_column, b_column):
    """Inner-joins 2 tables where a_column and b_column are equal (SQL `WHERE a = b`).

    Keys need not be unique on either side: every matching pair of rows
    produces one output row, as in SQL. Rows whose key is null (NaN/None) are
    dropped from both inputs first. In SQL, NULL never equals anything, but
    pandas.merge would pair NaN keys with each other.

    This is *much* faster than filtered_cartesian_product. It could also be
    written as a map/reduce:
        1. map over rows of A
        2. select rows with a matching key in B
        3. return the concatenation of the row from A with each matching row from B
        4. reduce the list of concatenated rows by converting to a dataframe

    Params:
        - A: a table (pandas dataframe object)
        - B: a table (pandas dataframe object)
        - a_column: string name of column of keys in A
        - b_column: string name of column of keys in B

    Returns:
        pandas dataframe with A's columns followed by B's non-key columns.

    Note:
        - The returned dataframe does not contain a column with the label
            b_column (unless b_column == a_column); the shared key column is
            labelled a_column.
        - Non-key column names that appear in both A and B receive pandas'
            default "_x" / "_y" suffixes.
    """
    # SQL equality is never true for NULL, so rows with null keys cannot join.
    A_keyed = A[A[a_column].notna()]
    B_keyed = B[B[b_column].notna()]
    # Give B's key A's label so pd.merge can join on a single column name.
    B_renamed = B_keyed.rename(columns={b_column: a_column})
    return pd.merge(A_keyed, B_renamed, on=a_column)

## Construct `customer_total_return`
```SQL
SELECT
    sr_customer_sk AS ctr_customer_sk,
    sr_store_sk AS ctr_store_sk,
    sum(sr_return_amt) AS ctr_total_return
  FROM store_returns, date_dim
  WHERE sr_returned_date_sk = d_date_sk AND d_year = 2000
  GROUP BY sr_customer_sk, sr_store_sk
```

In [6]:
# Filter individual tables where possible: keep only the dates with d_year 2000.
date_dim_filtered = date_dim.loc[date_dim["d_year"].eq(2000)]

In [15]:
# Join the year-2000 dates with store_returns on d_date_sk = sr_returned_date_sk.
ctr_merged = merge_tables(
    date_dim_filtered,
    store_returns,
    a_column="d_date_sk",
    b_column="sr_returned_date_sk",
)

In [16]:
# GROUP BY sr_customer_sk, sr_store_sk with SUM(sr_return_amt).
ctr_grouped = ctr_merged.groupby(["sr_customer_sk", "sr_store_sk"])
# min_count=1 mirrors SQL SUM: a group whose amounts are all NULL sums to NaN
# rather than 0. That keeps it out of the later per-store AVG and makes it fail
# the ">" comparison. Calling .sum directly also avoids handing np.sum to .agg,
# which newer pandas deprecates.
# NOTE(review): SQL GROUP BY keeps NULL keys as their own group, while pandas
# groupby drops them by default. That can shift the per-store averages;
# groupby(..., dropna=False) (pandas >= 1.1) would match SQL. Confirm the target
# pandas version before switching.
ctr_summed = ctr_grouped[["sr_return_amt"]].sum(min_count=1)

In [17]:
# Move the (customer, store) group keys back into ordinary columns and rename
# everything to the CTE's column names.
customer_total_return = ctr_summed.reset_index().rename(
    columns={
        "sr_customer_sk": "ctr_customer_sk",
        "sr_store_sk": "ctr_store_sk",
        "sr_return_amt": "ctr_total_return",
    }
)
print(customer_total_return)

       ctr_customer_sk  ctr_store_sk  ctr_total_return
0                  5.0           2.0           4260.39
1                  5.0          10.0            328.85
2                  6.0           4.0           2118.22
3                 16.0           2.0           1411.41
4                 18.0           2.0              0.00
5                 18.0           4.0           2137.09
6                 19.0           4.0            253.29
7                 20.0           8.0              9.40
8                 24.0           1.0           1462.44
9                 24.0          10.0             67.22
10                26.0           1.0           1040.85
11                31.0           2.0           1771.50
12                32.0          10.0            174.80
13                35.0           1.0            536.21
14                35.0           2.0           6631.56
15                36.0           2.0            835.20
16                38.0           2.0           6609.68
17        

In [24]:
# Filter on customer_total_return:
#   ctr1.ctr_total_return > (SELECT avg(ctr_total_return) * 1.2
#                            FROM customer_total_return ctr2
#                            WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk)
ctr1, ctr2 = customer_total_return, customer_total_return


def filter_avg(r):
    """Row-wise form of the correlated subquery for a single ctr1 row ``r``.

    Rescans ctr2 on every call, so applying it to all rows is O(N^2). Kept for
    reference; the vectorized mask below evaluates the same predicate.
    """
    ctr2_filtered = ctr2[ctr2["ctr_store_sk"] == r["ctr_store_sk"]]
    avg = ctr2_filtered["ctr_total_return"].mean()
    return r["ctr_total_return"] > 1.2 * avg


# Compute each store's average once and broadcast it back to every row,
# instead of rescanning ctr2 for each row of ctr1. The means may differ from
# filter_avg's only by floating-point rounding. Rows with a NaN store key get a
# NaN average, so the comparison is False. This matches filter_avg, and SQL,
# where NULL = NULL is unknown.
store_avg = ctr2.groupby("ctr_store_sk")["ctr_total_return"].transform("mean")
ctr1_filtered = ctr1[ctr1["ctr_total_return"] > 1.2 * store_avg]

In [25]:
# Restrict store to s_state = 'TN' before it is joined.
store_filtered = store.loc[store["s_state"].eq("TN")]

In [26]:
# Join the TN stores with the filtered ctr1 rows on s_store_sk = ctr_store_sk.
store_ctr1 = merge_tables(
    store_filtered,
    ctr1_filtered,
    a_column="s_store_sk",
    b_column="ctr_store_sk",
)

In [28]:
# Attach customer attributes on ctr_customer_sk = c_customer_sk.
store_ctr1_customer = merge_tables(
    store_ctr1,
    customer,
    a_column="ctr_customer_sk",
    b_column="c_customer_sk",
)

In [32]:
# SELECT c_customer_id ... ORDER BY c_customer_id LIMIT 100
result = (
    store_ctr1_customer
    .sort_values("c_customer_id")["c_customer_id"]
    .iloc[:100]  # first 100 by position, i.e. LIMIT 100
)

In [45]:
# Summarize the year-2000 slice of date_dim. DataFrame.info() prints its report
# itself and returns None, so it is not wrapped in print() (which would also
# emit a stray "None").
date_dim[date_dim["d_year"] == 2000].info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 366 entries, 36523 to 36888
Data columns (total 29 columns):
d_date_sk              366 non-null int64
d_date_id (B)          366 non-null object
d_date                 366 non-null object
d_month_seq            366 non-null int64
d_week_seq             366 non-null int64
d_quarter_seq          366 non-null int64
d_year                 366 non-null int64
d_dow                  366 non-null int64
d_moy                  366 non-null int64
d_dom                  366 non-null int64
d_qoy                  366 non-null int64
d_fy_year              366 non-null int64
d_fy_quarter_seq       366 non-null int64
d_fy_week_seq          366 non-null int64
d_day_name             366 non-null object
d_quarter_name         366 non-null object
d_holiday              366 non-null object
d_weekend              366 non-null object
d_following_holiday    366 non-null object
d_first_dom            366 non-null int64
d_last_dom             366 non-null int64
