In [18]:
import json
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

from neo import *
from database_env import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def get_times(path, env_config, time=False):
    """Load every saved plan in a directory and evaluate it with the environment.

    Parameters
    ----------
    path : str or Path
        Directory to scan. Only files matching ``*sql.json`` are read. Names are
        expected to look like ``12b.sql.json``; a stem without a trailing letter
        and a numeric prefix (e.g. ``x.sql.json``) raises ``ValueError`` while sorting.
    env_config : dict
        Passed unchanged to ``DataBaseEnv``.
    time : bool, optional
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    tuple of (dict, dict)
        ``(plans, times)``, both keyed by the file name minus ``.json``
        (e.g. ``"12b.sql"``) and filled in (number, letter) order.
        ``plans`` holds the loaded ``Plan`` objects. ``times`` holds
        ``env.reward()`` for complete plans (presumably the measured latency)
        and ``np.nan`` for incomplete ones.
    """
    env = DataBaseEnv(env_config)

    def _job_order(entry):
        # entry = ("12b.sql", <Path>) -> (12, "b")
        stem = entry[0].split(".")[0]
        return int(stem[:-1]), stem[-1]

    named_files = sorted(
        ((plan_file.name[:-5], plan_file) for plan_file in Path(path).glob("*sql.json")),
        key=_job_order,
    )

    plans, times = {}, {}
    for query_name, plan_file in named_files:
        plan = Plan()
        plan.load(plan_file)
        plans[query_name] = plan
        env.plan = plan
        # Only complete plans can be evaluated; others are recorded as missing.
        times[query_name] = env.reward() if env.plan.is_complete else np.nan
    return plans, times

In [12]:
# Load the environment config (JSON) and the agent/training config (YAML),
# apply notebook-level overrides, and build the Neo agent and algorithm.
with open('postgres_env_config.json', "r") as f:
    env_config = json.load(f)

with open('config_latency.yml', 'r') as file:
    # safe_load only constructs standard YAML types, whereas FullLoader also
    # resolves some !!python/* tags. Assumes this config uses no such tags;
    # TODO confirm if loading starts failing.
    d = yaml.safe_load(file)

# Notebook overrides of values read from the YAML file.
d['neo_args']['device'] = 'cuda:0'
# d['net_args']['pretrained_path'] = False
d['neo_args']['selectivity'] = False
# create_agent
agent = NeoAgent(TransNet(**d['net_args']), collate_fn=collate, device=d['neo_args']['device'])
alg = Neo(agent, env_config, d['neo_args'], d['train_args'], [], {})

In [13]:
# Checkpoints to compare: the same run number (1.pt) under each data-split directory.
paths = [
    f'runs/random_dataset/{split}/1.pt'
    for split in ('q_split', 'sp_split', 'random_val', 'random_train_val')
]

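## Generate

Load each checkpoint in `paths` and save one generated plan per query to `<checkpoint dir>/generated/<query>.json`.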

In [16]:
# For every checkpoint: load its weights, then generate and save one plan per query.
for p in paths:
    path = Path(p)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=d['neo_args']['device'])
    # strict=False silently skips mismatched keys. Report them so that a partially
    # loaded (e.g. architecture-mismatched) checkpoint is not evaluated unnoticed.
    incompatible = alg.agent.net.load_state_dict(state_dict, strict=False)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        tqdm.write(f"{p}: missing keys {missing}, unexpected keys {unexpected}")
    alg.agent.net.eval()  # eval mode for generation (affects layers such as dropout)
    save_path = path.parent / 'generated'
    save_path.mkdir(exist_ok=True)
    for q in tqdm(env_config['db_data']):
        plan = alg.generate_plan(q, num=1)
        plan.save(save_path / (q + ".json"))

100%|██████████| 113/113 [04:08<00:00,  2.20s/it]
100%|██████████| 113/113 [04:13<00:00,  2.24s/it]
100%|██████████| 113/113 [03:48<00:00,  2.02s/it]
100%|██████████| 113/113 [03:50<00:00,  2.04s/it]


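## Evaluate

Evaluate every generated plan with `get_times` `N` times per checkpoint, average the per-query results, and save the table to `runs/random_dataset/latency.csv`.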

In [21]:
# Evaluate each checkpoint's generated plans N times, average the per-query
# results from get_times (NaN = incomplete/missing plan), and save a
# query x checkpoint table to latency.csv.
N = 3                 # evaluation repetitions per checkpoint
N_TEST_QUERIES = 30   # queries drawn into `test_set` (listed after all others)
SEED = 0              # makes the `test_set` draw, and hence the row order, reproducible
save_path = Path('runs/random_dataset')


def _job_sort_key(query):
    """Sort key for JOB-style query names: '12b.sql' -> (12, 'b')."""
    stem = query.split(".")[0]
    return int(stem[:-1]), stem[-1]


def _mean_ignoring_missing(values):
    """Mean of `values` with None/NaN entries dropped; NaN if nothing is left."""
    valid = [v for v in values if v is not None and not pd.isna(v)]
    return np.mean(valid) if valid else np.nan


all_queries = list(env_config['db_data'])
# np.random.choice samples WITH replacement by default, which could yield duplicates
# and fewer than N_TEST_QUERIES distinct test queries. Draw distinct ones from a seeded RNG.
test_set = np.random.default_rng(SEED).choice(
    all_queries, size=min(N_TEST_QUERIES, len(all_queries)), replace=False
)
job_queries = {k: v[-1] for k, v in env_config['db_data'].items()}

# checkpoint path -> list of N {query: time} dicts, one per repetition
runs = defaultdict(list)
for rep in range(N):
    for p in tqdm(paths):
        _, times = get_times(Path(p).parent / "generated", env_config, True)
        runs[p].append(times)

# checkpoint path -> {query: mean over repetitions}. get_times marks failures with
# NaN (not None), so NaN must be filtered explicitly. `.get` tolerates a query
# that is absent from some repetitions.
data = {
    p: {q: _mean_ignoring_missing([run.get(q) for run in run_list]) for q in run_list[0]}
    for p, run_list in runs.items()
}

test_lookup = set(test_set)
# Non-test queries first, then the test queries; each group in JOB (number, letter) order.
sorted_queries = sorted(job_queries, key=lambda q: (q in test_lookup, _job_sort_key(q)))
sorted_data = {
    name: np.array([means.get(q, np.nan) for q in sorted_queries])
    for name, means in data.items()
}

# Build the frame directly from sorted_data. Do not reuse the name `d`: it holds
# the YAML config whose d['neo_args'] the plan-generation loop reads.
df = pd.DataFrame(sorted_data, index=sorted_queries)
df.to_csv(save_path / "latency.csv")
df

100%|██████████| 4/4 [21:06<00:00, 316.51s/it]
100%|██████████| 4/4 [23:11<00:00, 347.96s/it]
100%|██████████| 4/4 [20:03<00:00, 300.77s/it]


Unnamed: 0,runs/random_dataset/q_split/1.pt,runs/random_dataset/sp_split/1.pt,runs/random_dataset/random_val/1.pt,runs/random_dataset/random_train_val/1.pt
1a.sql,499.388000,516.096667,469.258667,467.546000
1b.sql,457.443333,438.333000,397.829000,432.308333
1c.sql,434.729000,404.772333,384.476333,419.626667
1d.sql,472.831667,439.687000,414.417000,423.868000
2a.sql,709.147000,788.007667,701.848000,723.138333
...,...,...,...,...
25a.sql,3163.609667,2983.052000,3053.919333,3132.844000
26c.sql,3918.443000,3865.290667,3866.488667,7248.639000
29c.sql,4209.297000,4526.218333,4050.516000,4168.188333
31c.sql,5044.078333,3900.053000,3971.598667,4276.200000
