**Notebook's outline:**

- load data for `JOB` benchmark and check the degree of uniqueness in it!
- proof the usefullness of tree convolution
- build neural network with best architecture

In [4]:
import os
import sys

if 'COLAB_GPU' in os.environ:
  from google.colab import drive
  print("Hello, Colab")
  drive.mount('/content/drive')
  ROOT_PATH = "/content/drive/MyDrive/hero"
  os.environ['CLEARML_CONFIG_FILE'] = f'{ROOT_PATH}/clearml.conf'
else:
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))

EXPERIMENT_PATH = f"{ROOT_PATH}/experiments/tcnn-abilities"
sys.path.insert(0, ROOT_PATH)

In [5]:
from collections import defaultdict
from json import load, dumps, dump

from tqdm import tqdm
import random
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from src.utils import get_logical_plan, get_full_plan, get_selectivities
from src.models import binary_tree_layers as btl
from src.datasets.oracle import Oracle, OracleRequest, TIMEOUT
from src.datasets.data_config import HINTSETS, DOPS, HINTS, DEFAULT_HINTSET
from src.datasets.data_types import ExplainNode
from src.datasets.vectorization import extract_vertices_and_edges, ALL_FEATURES
from src.datasets.binary_tree_dataset import binary_tree_collate, BinaryTreeDataset, WeightedBinaryTreeDataset, weighted_binary_tree_collate

# Loading data ...

Добавь описание того, почему sample queries стал test, и что такое OOD, как это мерилось и так далее

In [6]:
job_oracle = Oracle(f"{ROOT_PATH}/data/processed/JOB")
sq_oracle = Oracle(f"{ROOT_PATH}/data/processed/sample_queries")

In [7]:
def extract_list_info(oracle, query_names):
    """ Collects dicts with info 'query_name', 'hintset', 'dop', 'vertices', 'edges', 'time'"""
    list_info = []

    for query_name in tqdm(query_names, leave=False):
        seen_logical_plans = set()
        timeouted_logical_plans_to_dops = defaultdict(set)
        timeouted_logical_plans_to_settings = defaultdict(list)
        logical_plan_to_times = defaultdict(list)
        for dop in DOPS:            
            for hintset in HINTSETS:
                custom_request = OracleRequest(query_name=query_name, hintset=hintset,dop=dop)
                custom_logical_plan = get_logical_plan(
                    query_name=query_name, 
                    oracle=oracle, 
                    hintset=hintset, 
                    dop=dop
                )
                 
                custom_time = oracle.get_execution_time(custom_request)
                if custom_time != TIMEOUT:
                    info = {"query_name": query_name, "hintset": hintset, "dop": dop}
                    time = torch.tensor(custom_time / 1000, dtype=torch.float32)
                    vertices, edges = extract_vertices_and_edges(
                        oracle.get_explain_plan(request=custom_request)
                    )
                    info.update({"time": time, "vertices": vertices, "edges": edges}) 
                    seen_logical_plans.add(custom_logical_plan)
                    list_info.append(info)
                    logical_plan_to_times[custom_logical_plan].append(time)
                else:
                    timeouted_logical_plans_to_dops[custom_logical_plan].add(dop)
                    timeouted_logical_plans_to_settings[custom_logical_plan].append((dop, hintset))

        for custom_logical_plan in timeouted_logical_plans_to_settings:             
            if custom_logical_plan in logical_plan_to_times:
                time = sum(logical_plan_to_times[custom_logical_plan]) / len(logical_plan_to_times[custom_logical_plan])
            else:
                max_def_time = 0
                for dop in timeouted_logical_plans_to_dops[custom_logical_plan]:
                    def_request = OracleRequest(query_name=query_name, hintset=0, dop=dop)
                    def_time = oracle.get_execution_time(request=def_request)
                    max_def_time = max(max_def_time, def_time)
                time = torch.tensor(2 * max_def_time / 1000, dtype=torch.float32)

            for dop, hintset in timeouted_logical_plans_to_settings[custom_logical_plan]:
                info = {"query_name": query_name, "hintset": hintset, "dop": dop}
                custom_request = OracleRequest(query_name=query_name, hintset=hintset, dop=dop)
                vertices, edges = extract_vertices_and_edges(
                    oracle.get_explain_plan(request=custom_request)
                )
                info.update({"time": time, "vertices": vertices, "edges": edges})             
                list_info.append(info)
            
    return list_info

In [8]:
job_list_info = extract_list_info(oracle=job_oracle, query_names=job_oracle.get_query_names())
job_max_possible_size = len(job_oracle.get_query_names()) * len(DOPS) * len(HINTSETS)
print(f"[JOB]: dataset size is {len(job_list_info)} / {job_max_possible_size}")

                                                 

[JOB]: dataset size is 43392 / 43392




In [9]:
job_list_vertices = [info["vertices"] for info in job_list_info]
job_list_edges = [info["edges"] for info in job_list_info]
job_list_time = [info["time"] for info in job_list_info]

In [10]:
sq_list_info = extract_list_info(oracle=sq_oracle, query_names=sq_oracle.get_query_names())
sq_max_possible_size = len(sq_oracle.get_query_names()) * len(DOPS) * len(HINTSETS)
print(f"[SQ]: dataset size is {len(sq_list_info)} / {sq_max_possible_size}")

                                               

[SQ]: dataset size is 15360 / 15360




In [11]:
sq_list_vertices = [info["vertices"] for info in sq_list_info]
sq_list_edges = [info["edges"] for info in sq_list_info]
sq_list_time = [info["time"] for info in sq_list_info]

In [12]:
job_X = set([(str(v.flatten().tolist()), str(e.flatten().tolist())) for v, e in zip(job_list_vertices, job_list_edges)])
sq_X = set([(str(v.flatten().tolist()), str(e.flatten().tolist())) for v, e in zip(sq_list_vertices, sq_list_edges)])
print(f"Around {100 * len(sq_X & job_X) / len(sq_X):0.1f} % of plans from SQ bench exists in JOB bench")

Around 23.0 % of plans from SQ bench exists in JOB bench


In [13]:
job_logical_plans = set(
    get_logical_plan(oracle=job_oracle,  query_name=query_name, hintset=hintset, dop=dop)
    for query_name in job_oracle.get_query_names() for hintset in HINTSETS for dop in DOPS
)

ood_sq_list_info, id_sq_list_info = [], []
for info in sq_list_info:
    logical_plan = get_logical_plan(oracle=sq_oracle, query_name=info["query_name"], hintset=info["hintset"], dop=info["dop"])
    if logical_plan not in job_logical_plans:
        ood_sq_list_info.append(info)
    else:
        id_sq_list_info.append(info)
assert len(ood_sq_list_info) + len(id_sq_list_info) == len(sq_list_info)

In [14]:
ood_sq_list_vertices = [info["vertices"] for info in ood_sq_list_info]
ood_sq_list_edges = [info["edges"] for info in ood_sq_list_info]
ood_sq_list_time = [info["time"] for info in ood_sq_list_info]

id_sq_list_vertices = [info["vertices"] for info in id_sq_list_info]
id_sq_list_edges = [info["edges"] for info in id_sq_list_info]
id_sq_list_time = [info["time"] for info in id_sq_list_info]

In [15]:
job_all_plans = [
    get_full_plan(oracle=job_oracle, query_name=query_name, hintset=hintset, dop=dop)
    for query_name in job_oracle.get_query_names() for hintset in HINTSETS for dop in DOPS
]
print(f"[JOB]: total # of unique plans: {len(set(job_all_plans))} / {len(job_all_plans)}")

[JOB]: total # of unique plans: 7429 / 43392


In [16]:
sq_all_plans = [
    get_full_plan(oracle=sq_oracle, query_name=query_name, hintset=hintset, dop=dop)
    for query_name in sq_oracle.get_query_names() for hintset in HINTSETS for dop in DOPS
]
print(f"[SQ]: total # of unique plans: {len(set(sq_all_plans))} / {len(sq_all_plans)}")

[SQ]: total # of unique plans: 4496 / 15360


**Observation.** We see that most of the plans are repetitive. This is the reason for switching to *weighted datasets* to speed up the learning process.


## Datasets

In [17]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device is {DEVICE}")

device is cpu


In [18]:
job_dataset = BinaryTreeDataset(job_list_vertices, job_list_edges, job_list_time, DEVICE)
job_weighted_dataset = WeightedBinaryTreeDataset(job_list_vertices, job_list_edges, job_list_time, DEVICE)
print(f"[JOB]: total # of unique plans in the weighted dataset {len(job_weighted_dataset)}")

[JOB]: total # of unique plans in the weighted dataset 7323


In [19]:
ood_sq_dataset = BinaryTreeDataset(ood_sq_list_vertices, ood_sq_list_edges, ood_sq_list_time, DEVICE)
ood_sq_weighted_dataset = WeightedBinaryTreeDataset(ood_sq_list_vertices, ood_sq_list_edges, ood_sq_list_time, DEVICE)
print(f"[SQ, OOD]: total # of unique plans in the weighted dataset {len(ood_sq_weighted_dataset)}")

id_sq_dataset = BinaryTreeDataset(id_sq_list_vertices, id_sq_list_edges, id_sq_list_time, DEVICE)
id_sq_weighted_dataset = WeightedBinaryTreeDataset(id_sq_list_vertices, id_sq_list_edges, id_sq_list_time, DEVICE)
print(f"[SQ, ID]: total # of unique plans in the weighted dataset {len(id_sq_weighted_dataset)}")

[SQ, OOD]: total # of unique plans in the weighted dataset 3409
[SQ, ID]: total # of unique plans in the weighted dataset 1021


**O**ut **O**f **D**istribution vs **I**n **D**istribution

In [20]:
def build_best_predictor(dataset):
    sum_dict = defaultdict(float)
    count_dict = defaultdict(int)

    for v, e, t in dataset:
        key = str(v.flatten().tolist()), str(e.flatten().tolist())
        sum_dict[key] += t.item()
        count_dict[key] += 1

    def predictor(v, e):
        key = str(v.flatten().tolist()), str(e.flatten().tolist())
        if key not in sum_dict:
            return None
        return sum_dict[key] / count_dict[key]
    return predictor

In [18]:
predictor = build_best_predictor(job_dataset)

In [19]:
# % of data from JOB, mse / best possible mse ...

# Baselines

Let's evaluate the worst and best possible predictors so we know what we should aim for.

In [20]:
def calculate_mse_for_best_constant_predictor(dataset):
    time_sum = 0
    count = 0

    for v, e, t in dataset:
        time_sum += t
        count += 1

    mse_loss = 0
    for v, e, t in dataset:
        t_pred = time_sum / count
        mse_loss += (t_pred - t) ** 2

    return mse_loss.item() / len(dataset)

In [21]:
def calculate_mse_for_best_possible_predictor(dataset):
    sum_dict = defaultdict(float)
    count_dict = defaultdict(int)

    for v, e, t in dataset:
        key = str(v.flatten().tolist()), str(e.flatten().tolist())
        sum_dict[key] += t.item()
        count_dict[key] += 1

    mse_loss = 0
    for v, e, t in dataset:
        key = str(v.flatten().tolist()), str(e.flatten().tolist())
        t_pred = sum_dict[key] / count_dict[key]
        mse_loss += (t_pred - t) ** 2

    return mse_loss/ len(dataset)

In [22]:
print(f"[JOB]: MSE of the best constant predictor is {calculate_mse_for_best_constant_predictor(job_dataset):0.3f}")
print(f"[JOB]: MSE of the best possible predictor is {calculate_mse_for_best_possible_predictor(job_dataset):0.3f}")

[JOB]: MSE of the best constant predictor is 135.607
[JOB]: MSE of the best possible predictor is 0.662


In [23]:
print(f"[SQ, OOD]: MSE of the best constant predictor is {calculate_mse_for_best_constant_predictor(ood_sq_dataset):0.3f}")
print(f"[SQ, OOD]: MSE of the best possible predictor is {calculate_mse_for_best_possible_predictor(ood_sq_dataset):0.3f}")

print(f"[SQ, ID]: MSE of the best constant predictor is {calculate_mse_for_best_constant_predictor(id_sq_dataset):0.3f}")
print(f"[SQ, ID]: MSE of the best possible predictor is {calculate_mse_for_best_possible_predictor(id_sq_dataset):0.3f}")

[SQ, OOD]: MSE of the best constant predictor is 3953.941
[SQ, OOD]: MSE of the best possible predictor is 0.036
[SQ, ID]: MSE of the best constant predictor is 650.152
[SQ, ID]: MSE of the best possible predictor is 0.105


# Helpers

In [21]:
def save_ckpt(model, optimizer, scheduler, epoch, path):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state, path)

In [22]:
def load_ckpt(model, ckpt_path):
    ckpt_state = torch.load(ckpt_path)
    model.load_state_dict(ckpt_state['model_state_dict'])
    optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer.load_state_dict(ckpt_state['optimizer_state_dict'])
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
    scheduler.load_state_dict(ckpt_state['scheduler_state_dict'])
    start_epoch = ckpt_state["epoch"]    
    return model, optimizer, scheduler, start_epoch

In [23]:
def get_prediction(model, q_n, hs, dop):
    request = OracleRequest(query_name=q_n, hintset=hs, dop=dop)
    explain_plan = oracle.get_explain_plan(request=request)
    vertices, edges = extract_vertices_and_edges(explain_plan)
    (vertices_batch, edges_batch), y_batch = binary_tree_collate([(vertices, edges, torch.tensor([1.0]))], max_length)                
    return float(model(vertices_batch, edges_batch).squeeze(0))

In [24]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [25]:
def get_time(oracle, query_name, hs, dop):
    request = OracleRequest(query_name=query_name, hintset=hs, dop=dop)
    explain_plan = oracle.get_execution_time(request=request)
    return explain_plan / 1000

In [26]:
def evaluate_model(model, dataloader):
    loss_sum = .0
    for ve, t in tqdm(dataloader):
        loss_sum += sum((model(*ve).squeeze(-1) - t).to("cpu") ** 2)
    print(f'MSE of given model is {loss_sum.item() / len(dataloader.dataset):0.3f}')

# Architecture searching

In [27]:
in_channels = len(ALL_FEATURES)

In [28]:
class BinaryTreeRegressor(nn.Module):
    def __init__(self, btcnn: "btl.BinaryTreeSequential", fcnn: "torch.nn.Sequential", name: "str" = "unknown"):
        super().__init__()
        self.btcnn: "btl.BinaryTreeSequential" = btcnn
        self.fcnn: "torch.nn.Sequential" = fcnn
        self.name = name

    def forward(self, vertices: "Tensor", edges: "Tensor") -> "Tensor":
        return self.fcnn(self.btcnn(vertices=vertices, edges=edges))

## The usefullness of tree convolution layers

In [29]:
no_btcnn = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

small_btcnn = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeConv(in_channels, in_channels),
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

medium_btcnn = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeConv(in_channels, 128),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(128, in_channels),
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

small_fcnn = lambda: torch.nn.Sequential(
    nn.Linear(in_channels, 64),
    nn.LeakyReLU(),
    nn.Linear(64, 1),
    nn.Softplus(),
)

medium_fcnn = lambda: torch.nn.Sequential(
    nn.Linear(in_channels, 256),
    nn.LeakyReLU(),
    nn.Linear(256, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 64),
    nn.LeakyReLU(),
    nn.Linear(64, 1),
    nn.Softplus(),
)

In [30]:
def initialize_models():
    return [
        BinaryTreeRegressor(no_btcnn(), small_fcnn(), "NoBTCNN_SmallFCNN"),
        BinaryTreeRegressor(no_btcnn(), medium_fcnn(), "NoBTCNN_MediumFCNN"),
        BinaryTreeRegressor(small_btcnn(), small_fcnn(), "SmallBTCNN_SmallFCNN"),
        BinaryTreeRegressor(small_btcnn(), medium_fcnn(), "SmallBTCNN_MediumFCNN"),
        BinaryTreeRegressor(medium_btcnn(), small_fcnn(), "MediumBTCNN_SmallFCNN"),
        BinaryTreeRegressor(medium_btcnn(), medium_fcnn(), "MediumBTCNN_MediumFCNN"),
    ]

# Training

In [31]:
max_length = max([v.shape[0] for v in job_list_vertices + ood_sq_list_vertices + id_sq_list_vertices])
print(f"The longest tree has length {max_length}")
batch_size = 32
lr = 3e-4

The longest tree has length 66


In [32]:
def generate_dataloaders(n):
    res = []
    for seed in range(42, 42+n):
        generator = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset = torch.utils.data.dataset.random_split(job_weighted_dataset, [0.8, 0.2], generator=generator)
        test_dataset = id_sq_weighted_dataset
        ood_dataset = ood_sq_weighted_dataset
        train_dataloader, val_dataloader, test_dataloader, ood_dataloader = [
            DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=lambda el: weighted_binary_tree_collate(el, max_length),
                drop_last=False
            )
            for dataset in [train_dataset, val_dataset, test_dataset, ood_dataset]
        ]        
        yield (train_dataloader, val_dataloader, test_dataloader, ood_dataloader)

In [33]:
def calculate_loss(model, optimizer, criterion, dataloader, train_mode=True):
    model.train() if train_mode else model.eval()
    running_loss, total_samples = .0, 0
    for (vertices, edges, freq), time in dataloader:
        if train_mode:
            optimizer.zero_grad()
        
        outputs = model(vertices, edges)
        weighted_loss = (freq.float().squeeze(-1) * criterion(outputs.squeeze(-1), time)).mean()
        
        if train_mode:
            weighted_loss.backward()
            optimizer.step()

        running_loss += weighted_loss.item() * vertices.size(0)
        total_samples += freq.sum()
    return running_loss / total_samples

In [37]:
def weighted_train_loop(
    model,  optimizer,  criterion, scheduler, train_dataloader, num_epochs, clearml_task, 
    start_epoch=0, metadata=None, ckpt_period=10, eval_period=10, path_to_save=None, val_dataloader=None, test_dataloader=None, ood_dataloader=None
    ):
    tqdm_desc = "Initialization"
    progress_bar = tqdm(range(start_epoch + 1, start_epoch + num_epochs + 1), desc=tqdm_desc, leave=True, position=0)

    for epoch in progress_bar:
        train_loss = calculate_loss(model, optimizer, criterion, train_dataloader)
        scheduler.step(train_loss)
        clearml_task.get_logger().report_scalar('MSE', "[train] " + model.name, iteration=epoch, value=train_loss)
        progress_bar.set_description(f'[{epoch}/{start_epoch + num_epochs}] MSE: {train_loss:.4f}')

        if val_dataloader and not epoch % eval_period:
          with torch.no_grad():
            val_loss = calculate_loss(model, optimizer, criterion, val_dataloader, train_mode=False)
            clearml_task.get_logger().report_scalar('MSE', "[val] " + model.name, iteration=epoch, value=val_loss)

        if test_dataloader and not epoch % eval_period:
          with torch.no_grad():
            test_loss = calculate_loss(model, optimizer, criterion, test_dataloader, train_mode=False)
            clearml_task.get_logger().report_scalar('MSE', "[test] " + model.name,iteration=epoch, value=test_loss)

        if ood_dataloader and not epoch % eval_period:
            with torch.no_grad():
              ood_loss = calculate_loss(model, optimizer, criterion, ood_dataloader, train_mode=False)
              clearml_task.get_logger().report_scalar('MSE', "[ood] " + model.name,iteration=epoch, value=ood_loss)            

        if path_to_save and not epoch % ckpt_period:
            save_ckpt(model, optimizer, scheduler, epoch, path_to_save)

In [49]:
!pip install git+https://github.com/allegroai/clearml
!pip install clearml-agent

In [35]:
from clearml import Task
task_for_tcnn = Task.init(project_name="hero", task_name='Test generalization')
assert task_for_tcnn is not None

ClearML Task: created new task id=3ce65a095d204e5e98eb410780f12ad3
2024-07-01 16:01:35,316 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/8c90433626c94a93bf97392993994c51/experiments/3ce65a095d204e5e98eb410780f12ad3/output/log


In [38]:
epochs = 300
metadata = {
    "data": "weighted_dataset",
    "lr": lr,
    "batch_size": batch_size
}

n_runs = 5
for run, (train_dataloader, val_dataloader, test_dataloader, ood_dataloader) in enumerate(generate_dataloaders(n_runs), start=1):
    for model in initialize_models():
        model.name = model.name + "_" + str(run)
        model.btcnn.to(DEVICE)
        model.fcnn.to(DEVICE)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)
        set_seed(2024)
        weighted_train_loop(
            model=model,
            optimizer=optimizer,
            criterion=nn.MSELoss(reduction="none"),
            scheduler=scheduler,
            train_dataloader=train_dataloader,
            num_epochs=epochs,
            clearml_task=task_for_tcnn,
            metadata=metadata,
            ckpt_period=300,
            eval_period=5,
            val_dataloader=val_dataloader,
            test_dataloader=test_dataloader,
            ood_dataloader=ood_dataloader,
            path_to_save=f"{EXPERIMENT_PATH}/models/{model.name}.pth",
        )

[1/50] MSE: 124.1291:   2%|▏         | 1/50 [00:08<07:17,  8.93s/it]


KeyboardInterrupt: 

: 

In [54]:
big_btcnn = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeConv(in_channels, 64),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(64, 128),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(128, 256),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(256, 512),
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

big_btcnn_and_layer_norm = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeConv(in_channels, 64),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(64, 128),
    btl.BinaryTreeLayerNorm(128),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(128, 256),
    btl.BinaryTreeLayerNorm(256),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(256, 512),
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

big_btcnn_and_instance_norm = lambda: btl.BinaryTreeSequential(
    btl.BinaryTreeConv(in_channels, 64),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(64, 128),
    btl.BinaryTreeInstanceNorm(128),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(128, 256),
    btl.BinaryTreeInstanceNorm(256),
    btl.BinaryTreeActivation(torch.nn.functional.leaky_relu),
    btl.BinaryTreeConv(256, 512),
    btl.BinaryTreeAdaptivePooling(torch.nn.AdaptiveMaxPool1d(1)),
)

big_fcnn = lambda: torch.nn.Sequential(
    nn.Linear(512, 256),
    nn.LeakyReLU(),
    nn.Linear(256, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 64),
    nn.LeakyReLU(),
    nn.Linear(64, 1),
    nn.Softplus(),
)

def initialize_big_models():
    return [
        BinaryTreeRegressor(big_btcnn(), big_fcnn(), "BigBTCNN_BigFCNN"),
        BinaryTreeRegressor(big_btcnn_and_layer_norm(), big_fcnn(), "BigBTCNN_BigFCNN_LayerNorm"),
        BinaryTreeRegressor(big_btcnn_and_instance_norm(), big_fcnn(), "BigBTCNN_BigFCNN_InstanceNorm"),
    ]

# Loading

In [32]:
model, optimizer, scheduler, start_epoch = load_ckpt(
    model=BinaryTreeRegressor(no_btcnn(), small_fcnn()),
    ckpt_path=f"{EXPERIMENT_PATH}/models/NoBTCNN_SmallFCNN.pth",
)
evaluate_model(model, dataloader)

2024-06-10 19:58:28,200 - clearml.model - INFO - Selected model id: 7e95f3e08d0f43a9b9dfab65748e78b2


100%|██████████| 1356/1356 [00:02<00:00, 612.14it/s]

MSE of given model is 120.312





In [33]:
def hand_on_mse(model, list_info):
    total_mse = .0
    for info in list_info:
        time, query_name, hintset, dop, vertices, edges = info["time"], info["query_name"], info["hintset"], info["dop"], info["vertices"], info["edges"]
        prediction = get_prediction(model, query_name, hintset, dop)
        total_mse += (prediction - time) ** 2
    print(f"Calculated by hand MSE {total_mse / len(list_info):0.3f}")

In [44]:
hand_on_mse(model, list_info)

Calculated by hand MSE 120.310


In [42]:
def hand_on_weighted_mse(model, freq_dataloader):
    loss_sum = .0
    for (v, e, w), t in tqdm(freq_dataloader):
        loss_sum += w.squeeze(-1).to("cpu").float() @ ((model(v, e).squeeze(-1) - t).to("cpu") ** 2)
    print(f'MSE of given model is {loss_sum.item() / len(dataloader.dataset):0.3f}')    

In [43]:
hand_on_weighted_mse(model, weighted_dataloader)

100%|██████████| 228/228 [00:00<00:00, 676.10it/s]

MSE of given model is 119.134





In [None]:
def retrain(model, metadata, ckpt_period, num_epochs, data_loader):
    model, optimizer, scheduler, start_epoch = load_ckpt(
        model=model,
        ckpt_path=f"{EXPERIMENT_PATH}/models/{model.name}.pth",
    )

    model.btcnn.to(DEVICE)
    model.fcnn.to(DEVICE)

    freq_train_loop(
        model=model,
        optimizer=optimizer,
        criterion = nn.MSELoss(reduction="none"),
        scheduler=scheduler,
        data_loader=data_loader,
        num_epochs=num_epochs,
        start_epoch=start_epoch,
        clearml_task=task_for_tcnn,
        descr="[all_plans] " + name,
        metadata=metadata,
        ckpt_period=ckpt_period,
        path_to_save=f"{EXPERIMENT_PATH}/models/{name}",
    )