In [741]:
import numpy as np
import os
import torch
import torch.nn as nn
import time
import pandas as pd
from scipy.stats import pearsonr

In [742]:
from model.util import Normalizer
from model.database_util import get_hist_file, get_job_table_sample, collator
from model.model import QueryFormer
from model.database_util import Encoding, Batch
from model.dataset import PlanTreeDataset

In [743]:
data_path = './data/imdb/'

In [744]:
class Args:
    pass

In [745]:
hist_file = get_hist_file('histograms.csv')
cost_norm = Normalizer(-3.61192, 12.290855)

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  hist_file['freq'][i] = freq_np
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, 

In [746]:
hist_file.head()

Unnamed: 0,table,column,bins,table_column,freq
0,nation,n_nationkey,"[0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, ...",n.n_nationkey,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, ..."
1,nation,n_name,"[24309792, 73180151, 122050510, 127186136, 128...",n.n_name,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, ..."
2,nation,n_regionkey,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, ...",n.n_regionkey,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, ..."
3,nation,n_comment,"[9944096, 15206094, 20468093, 65980060, 115151...",n.n_comment,"[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, ..."
4,region,r_regionkey,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ...",r.r_regionkey,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [747]:
encoding_ckpt = torch.load('checkpoints/tpch_encoding.pt')
encoding = encoding_ckpt['encoding']
checkpoint = torch.load('checkpoints/cost_model.pt', map_location='cpu')

  encoding_ckpt = torch.load('checkpoints/tpch_encoding.pt')
  checkpoint = torch.load('checkpoints/cost_model.pt', map_location='cpu')


In [748]:
from model.util import seed_everything
seed_everything()

In [749]:
args = checkpoint['args']


In [750]:
model = QueryFormer(emb_size = args.embed_size ,ffn_dim = args.ffn_dim, head_size = args.head_size, \
                 dropout = args.dropout, n_layers = args.n_layers, \
                 use_sample = True, use_hist = True, \
                 pred_hid = args.pred_hid
                )

62


In [751]:
pretrained_dict = checkpoint['model']
model_dict = model.state_dict()

# Filter out weights in pretrained_dict that don't match in size with model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}

# Update the model's state dict with the matching pretrained weights
model_dict.update(pretrained_dict)

# Load the updated state dict into the model
model.load_state_dict(model_dict)


<All keys matched successfully>

In [752]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
_ = model.to(device).eval()

In [753]:
to_predict = 'cost'

In [754]:
methods = {
    'get_sample' : get_job_table_sample,
    'encoding': encoding,
    'cost_norm': cost_norm,
    'hist_file': hist_file,
    'model': model,
    'device': device,
    'bs': 512,
}

In [755]:
import json

def print_plan(plan):
    # pretty print the json object

    # the following code parses the json string into a dictionary
    json_parsed = json.loads(plan)
    json_pretty = json.dumps(json_parsed, indent=4)
    print(json_pretty)

    with open('output.json', 'w') as f:
        f.write(json_pretty)



In [756]:
def print_qerror(preds_unnorm, labels_unnorm):
    qerror = []
    for i in range(len(preds_unnorm)):
        print("Predicted: {}, Actual: {}".format(preds_unnorm[i], labels_unnorm[i]))
        if preds_unnorm[i] > float(labels_unnorm[i]):
            qerror.append(preds_unnorm[i] / float(labels_unnorm[i]))
        else:
            qerror.append(float(labels_unnorm[i]) / float(preds_unnorm[i]))

    e_50, e_90 = np.median(qerror), np.percentile(qerror,90)    
    e_mean = np.mean(qerror)
    print("Median: {}".format(e_50))
    print("90th percentile: {}".format(e_90))
    print("Mean: {}".format(e_mean))
    return 

def get_corr(ps, ls): # unnormalised
    ps = np.array(ps)
    ls = np.array(ls)
    corr, _ = pearsonr(np.log(ps), np.log(ls))
    
    return corr

In [757]:
def evaluate(model, ds, bs, norm, device):
    model.eval()
    cost_predss = np.empty(0)

    with torch.no_grad():
        for i in range(0, len(ds), bs):
            batch, batch_labels = collator(list(zip(*[ds[j] for j in range(i,min(i+bs, len(ds)) ) ])))


            batch = batch.to(device)

            cost_preds, _ = model(batch)

            cost_preds = cost_preds.squeeze()

            cost_predss = np.append(cost_predss, cost_preds.cpu().detach().numpy())


    print_qerror(norm.unnormalize_labels(cost_predss), ds.costs)

    if len(cost_predss) > 2:  
        corr = get_corr(norm.unnormalize_labels(cost_predss), ds.costs)
        print('Corr: ', corr)

    return 

In [758]:
def eval_workload(workload, methods):

    get_table_sample = methods['get_sample']

    workload_file_name = 'tpch'
    output_file_name = '{}_output.csv'.format(workload)

    table_sample = get_table_sample(workload_file_name)

    plan_df = pd.read_csv('query_plans.csv')
    print_plan(plan_df['json'][0])
    workload_csv = pd.read_csv('tpch.csv',sep='#',header=None)
    workload_csv.columns = ['table','join','predicate','cardinality']

    workload_csv.to_csv(output_file_name, index=False)

    
    ds = PlanTreeDataset(plan_df, workload_csv, \
        methods['encoding'], methods['hist_file'], methods['cost_norm'], \
        methods['cost_norm'], 'cost', table_sample)
    
    

    evaluate(methods['model'], ds, methods['bs'], methods['cost_norm'], methods['device'])
    return 

In [759]:
eval_workload('tpch', methods)

Loaded queries with len  2
Loaded bitmaps
{
    "Plan": {
        "Node Type": "Aggregate",
        "Strategy": "Sorted",
        "Partial Mode": "Finalize",
        "Parallel Aware": false,
        "Async Capable": false,
        "Startup Cost": 184469.24,
        "Total Cost": 184471.2,
        "Plan Rows": 6,
        "Plan Width": 236,
        "Actual Startup Time": 3645.51,
        "Actual Total Time": 3646.542,
        "Actual Rows": 4,
        "Actual Loops": 1,
        "Group Key": [
            "l_returnflag",
            "l_linestatus"
        ],
        "Plans": [
            {
                "Node Type": "Gather Merge",
                "Parent Relationship": "Outer",
                "Parallel Aware": false,
                "Async Capable": false,
                "Startup Cost": 184469.24,
                "Total Cost": 184470.64,
                "Plan Rows": 12,
                "Plan Width": 236,
                "Actual Startup Time": 3645.496,
                "Actual Total 

In [760]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [761]:
bins = hist_file.loc[hist_file['table_column']=='l.l_discount','bins'].item()
print(encoding.idx2col)

{0: 'c.c_custkey', 1: 'c.c_name', 2: 'c.c_address', 3: 'c.c_nationkey', 4: 'c.c_phone', 5: 'c.c_acctbal', 6: 'c.c_mktsegment', 7: 'c.c_comment', 61: 'NA', 8: 'l.l_orderkey', 9: 'l.l_partkey', 10: 'l.l_suppkey', 11: 'l.l_linenumber', 12: 'l.l_quantity', 13: 'l.l_extendedprice', 14: 'l.l_discount', 15: 'l.l_tax', 16: 'l.l_returnflag', 17: 'l.l_linestatus', 18: 'l.l_shipdate', 19: 'l.l_commitdate', 20: 'l.l_receiptdate', 21: 'l.l_shipinstruct', 22: 'l.l_shipmode', 23: 'l.l_comment', 24: 'n.n_nationkey', 25: 'n.n_name', 26: 'n.n_regionkey', 27: 'n.n_comment', 28: 'o.o_orderkey', 29: 'o.o_custkey', 30: 'o.o_orderstatus', 31: 'o.o_totalprice', 32: 'o.o_orderdate', 33: 'o.o_orderpriority', 34: 'o.o_clerk', 35: 'o.o_shippriority', 36: 'o.o_comment', 37: 'p.p_partkey', 38: 'p.p_name', 39: 'p.p_mfgr', 40: 'p.p_brand', 41: 'p.p_type', 42: 'p.p_size', 43: 'p.p_container', 44: 'p.p_retailprice', 45: 'p.p_comment', 46: 'ps.ps_partkey', 47: 'ps.ps_suppkey', 48: 'ps.ps_availqty', 49: 'ps.ps_supplycost