In [1]:
import pandas as pd
import os
import torch
from model.database_util import *

In [2]:
# Parameters cell (tagged as 'parameters')
QUERYID = 1
query1_ts = '2024-12-03-20.44.01.238859'

In [3]:
# Base directory name
base_dir = 'job_queries'

In [4]:


# File paths for the query explanation components
file_paths = {
    'operator': f'{base_dir}/EXPLAIN_OPERATOR.csv',
    'stream': f'{base_dir}/EXPLAIN_STREAM.csv',
    'predicate': f'{base_dir}/EXPLAIN_PREDICATE.csv'
}


In [5]:
max_node = 30
rel_pos_max = 20

# Extract the operator ids and types

In [6]:
import pandas as pd
import os

def extract_operator_ids_and_types(explain_operator_path, timestamp):
    """
    Extract rows from a CSV file where the EXPLAIN_TIME column matches a given timestamp.

    Parameters:
        explain_operator_path (str): Path to the CSV file.
        timestamp: The timestamp value to filter rows by.

    Returns:
        pd.DataFrame: A DataFrame containing the matching rows with EXPLAIN_TIME, OPERATOR_ID, and OPERATOR_TYPE.
    """
    # Check if the file exists
    if not os.path.exists(explain_operator_path):
        raise FileNotFoundError(f"The file '{explain_operator_path}' does not exist.")
    
    # Read the CSV file
    try:
        df = pd.read_csv(explain_operator_path)
    except Exception as e:
        raise ValueError(f"Error reading the file: {e}")

    # Check if the required columns exist
    required_columns = ['EXPLAIN_TIME', 'OPERATOR_ID', 'OPERATOR_TYPE']
    if not all(col in df.columns for col in required_columns):
        raise KeyError(f"The file must contain the columns: {required_columns}")

    # Filter rows where the EXPLAIN_TIME column matches the timestamp value
    df_ops = df[df['EXPLAIN_TIME'] == timestamp][required_columns]

    # Warn if no rows match the timestamp
    if df_ops.empty:
        print(f"No rows found for timestamp '{timestamp}'.")

    return df_ops


# try:
#     result = extract_operator_ids_and_types(explain_operator_path, query1_ts)
#     print(result)
# except Exception as e:
#     print(f"Error: {e}")


# Extract the Operators accessing objects

In [7]:
import pandas as pd

def extract_stream_table(explain_stream_path, timestamp):
    """
    Process the EXPLAIN_STREAM CSV file to extract relevant rows, concatenate schema and name, and prepare the output.

    Parameters:
        explain_stream_path (str): Path to the EXPLAIN_STREAM CSV file.
        timestamp (str): The timestamp value to filter rows by.

    Returns:
        pd.DataFrame: A DataFrame containing EXPLAIN_TIME, OPERATOR_ID, and TABLE columns.
    """
    # Read the CSV file
    df = pd.read_csv(explain_stream_path)

    # Filter for the given EXPLAIN_TIME value and non-NaN OBJECT_NAME
    filtered_df = df.loc[(df['EXPLAIN_TIME'] == timestamp) & (df['OBJECT_NAME'].notna())].copy()

    # Concatenate OBJECT_SCHEMA and OBJECT_NAME to create a new column, removing any extra spaces
    filtered_df['TABLE'] = (
        filtered_df['OBJECT_SCHEMA'].str.strip().fillna('') +
        '.' +
        filtered_df['OBJECT_NAME'].str.strip()
    )

    # Select and rename the TARGET_ID column to OPERATOR_ID
    df_table = filtered_df[['EXPLAIN_TIME', 'SOURCE_ID', 'TARGET_ID', 'TABLE']].copy()
    df_table.rename(columns={'TARGET_ID': 'OPERATOR_ID'}, inplace=True)

    return df_table


# Example usage:
# explain_stream_path = f'{base_dir}/EXPLAIN_STREAM.csv'

# query1_ts = '2024-12-03-20.44.07.528574'  # Replace with your timestamp
# result = extract_stream_table(explain_stream_path, query1_ts)
# print(result)


# Extract local predicates

In [8]:
import pandas as pd

def extract_local_predicates(explain_predicate_path, timestamp):
    """
    Extract and process local predicates (WHERE/HAVING conditions) from the EXPLAIN_PREDICATE CSV file.

    Parameters:
        explain_predicate_path (str): Path to the EXPLAIN_PREDICATE CSV file.
        timestamp (str): The timestamp value to filter rows by.

    Returns:
        pd.DataFrame: A DataFrame with grouped local predicates (PREDICATE1, PREDICATE2, PREDICATE3) 
                      and their corresponding FILTER_FACTOR columns (FILTER_FACTOR1, FILTER_FACTOR2, FILTER_FACTOR3) 
                      for each EXPLAIN_TIME and OPERATOR_ID.
    """
    # Define the columns to read from the file
    columns_to_read = ['EXPLAIN_TIME', 'OPERATOR_ID', 'HOW_APPLIED', 'PREDICATE_TEXT', 'FILTER_FACTOR']

    # Read the CSV file, selecting only relevant columns
    df = pd.read_csv(explain_predicate_path, usecols=columns_to_read)

    # Strip whitespace from EXPLAIN_TIME and HOW_APPLIED columns
    df['EXPLAIN_TIME'] = df['EXPLAIN_TIME'].str.strip()
    df['HOW_APPLIED'] = df['HOW_APPLIED'].str.strip()

    # Filter rows where EXPLAIN_TIME matches the timestamp and HOW_APPLIED is 'SARG'
    filtered_df = df.loc[(df['EXPLAIN_TIME'] == timestamp) & (df['HOW_APPLIED'] == 'SARG')].copy()

    # Remove only outer parentheses from the PREDICATE_TEXT column
    def remove_outer_parentheses(predicate_text):
        # Check if the text starts and ends with parentheses
        if isinstance(predicate_text, str) and predicate_text.startswith('(') and predicate_text.endswith(')'):
            # Remove only the outermost parentheses
            return predicate_text[1:-1]
        return predicate_text

    filtered_df['PREDICATE_TEXT'] = filtered_df['PREDICATE_TEXT'].apply(remove_outer_parentheses)

    # Select the relevant columns for the result
    df_predicates = filtered_df[['EXPLAIN_TIME', 'OPERATOR_ID', 'PREDICATE_TEXT', 'FILTER_FACTOR']]

    # Group by EXPLAIN_TIME and OPERATOR_ID and apply transformation
    df_local_predicates = df_predicates.groupby(['EXPLAIN_TIME', 'OPERATOR_ID']).agg(
        PREDICATE1=('PREDICATE_TEXT', lambda x: x.iloc[0] if len(x) > 0 else None),
        PREDICATE2=('PREDICATE_TEXT', lambda x: x.iloc[1] if len(x) > 1 else None),
        PREDICATE3=('PREDICATE_TEXT', lambda x: x.iloc[2] if len(x) > 2 else None),
        FILTER_FACTOR1=('FILTER_FACTOR', lambda x: x.iloc[0] if len(x) > 0 else None),
        FILTER_FACTOR2=('FILTER_FACTOR', lambda x: x.iloc[1] if len(x) > 1 else None),
        FILTER_FACTOR3=('FILTER_FACTOR', lambda x: x.iloc[2] if len(x) > 2 else None),
    ).reset_index()
    
    # Replace NaN values in FILTER_FACTOR columns with 1.0
    filter_factor_columns = ['FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3']
    df_local_predicates[filter_factor_columns] = df_local_predicates[filter_factor_columns].fillna(1.0)

    return df_local_predicates


# Example usage:
# explain_predicate_path = 'SimpleQueriesSQ12c/EXPLAIN_PREDICATE.csv'
# query1_ts = '2024-11-28 12:00:00'  # Replace with your timestamp
# explain_predicate_path = f'{base_dir}/EXPLAIN_PREDICATE.csv'
# result = extract_local_predicates(explain_predicate_path, query1_ts)
# print(result)


In [9]:
# result.info()

In [10]:
# result.head(3)

In [11]:
# result[['OPERATOR_ID', 'PREDICATE1', 'PREDICATE2', 
#        'FILTER_FACTOR1', 'FILTER_FACTOR2']]

In [12]:
pd.set_option('display.max_colwidth', None)

# Parsing the PREDICATEs and adding PARSED_PREDICATE columns

In [13]:
import re

def parse_predicate(predicate):
    """
    Parse a predicate string into a tuple (col, op, val).
    Supports various predicate types including IN, LIKE, IS NULL, and comparisons.
    """
    print('Original predicate:', predicate)
    if predicate is None or pd.isna(predicate):
        return None

    predicate = predicate.strip()

    # Patterns for different predicate types
    patterns = {
        'basic': r'(?P<col>[\w\.]+)\s*(?P<op>=|<=|>=|<|>)\s*(?P<val>.+)',
        'in': r'(?P<col>[\w\.]+)\s+IN\s+\((?P<val>.+?)\)',
        'like': r'(?P<col>[\w\.]+)\s+LIKE\s+(?P<val>.+)',
        'is': r'(?P<col>[\w\.]+)\s+IS\s+(?P<val>NULL|NOT\s+NULL)'
    }

    # Match and parse based on patterns
    for key, pattern in patterns.items():
        match = re.match(pattern, predicate, flags=re.IGNORECASE)
        if match:
            print(f"Matched pattern: {key} | Groups: {match.groupdict()}")
            col = match.group('col').strip()
            
            col_parts = col.split('.')
            if len(col_parts) == 3:
                col = '.'.join(col_parts[1:])
            
            op = match.group('op')
            val = match.group('val').strip()

            # Handle special cases
            if key == 'in':
                # Split and clean IN clause values
                val = [
                    f"'{v.strip()}'" for v in val.split(',')
                ]
            elif key == 'basic' and val.startswith("'") and val.endswith("'"):
                # Clean single-quoted values for basic predicates
                val = val.strip("'")

            return (col, op, val)

    # Log failed predicates
    print(f"Failed to match predicate: {predicate}")
    return None


In [14]:
explain_predicate_path = f'{base_dir}/EXPLAIN_PREDICATE.csv'
# query1_ts = '2024-10-08-04.51.06.406285' # not working

# Use the functions to extract features
df_pred = extract_local_predicates(explain_predicate_path, query1_ts)

  df_local_predicates[filter_factor_columns] = df_local_predicates[filter_factor_columns].fillna(1.0)


In [15]:
df_pred

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3
0,2024-12-03-20.44.01.238859,4,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0
1,2024-12-03-20.44.01.238859,9,Q9.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0
2,2024-12-03-20.44.01.238859,12,Q5.NOTE IS NULL,,,0.487515,1.0,1.0
3,2024-12-03-20.44.01.238859,14,Q1.PRODUCTION_YEAR <= 2020,1950 <= Q1.PRODUCTION_YEAR,,0.971485,0.894801,1.0
4,2024-12-03-20.44.01.238859,17,Q7.KEYWORD = 'superhero',,,7e-06,1.0,1.0
5,2024-12-03-20.44.01.238859,18,Q8.KIND = 'production companies',,,0.25,1.0,1.0
6,2024-12-03-20.44.01.238859,19,Q6.LINK LIKE '%follows%',,,0.083333,1.0,1.0


In [16]:
import re

def parse_predicate(predicate):
    print('Original predicate:', predicate)
    if predicate is None or pd.isna(predicate):
        return None

    # Match basic predicate (col, op, val)
    basic_pattern = r'(?P<col>[\w\.]+)\s*(?P<op><=|>=|<|>|=|!=|LIKE)\s*(?P<val>.+)'
    # Match IN predicate
    in_pattern = r'(?P<col>[\w\.]+)\s+IN\s+\((?P<val>.+?)\)'

    # Try to match basic predicate
    match = re.match(basic_pattern, predicate.strip())
    if match:
        col = match.group('col')
        op = match.group('op')
        val = match.group('val')
        return (col, op, val)

    # Try to match IN predicate
    match = re.match(in_pattern, predicate.strip(), flags=re.IGNORECASE)
    if match:
        col = match.group('col')
        op = 'IN'
        val = match.group('val')

        # Debugging: Log the raw values before processing
        print('Raw IN values:', val)

        # Split by commas and clean spaces inside quotes
        val_list = [
            f'{v.strip().strip('\'').strip()}' if v.strip().startswith("'") else v.strip()
            for v in val.split(',')
        ]

        # Debugging: Log the cleaned values after processing
        print('Cleaned IN values:', val_list)

        return (col, op, val_list)

    return None

predicate = "AHNAF.MOVIE_INFO.INFO IN (' Sweden ', 'Norway ', ' Germany', ' Denmark ')"
result = parse_predicate(predicate)
print('Parsed result:', result)


Original predicate: AHNAF.MOVIE_INFO.INFO IN (' Sweden ', 'Norway ', ' Germany', ' Denmark ')
Raw IN values: ' Sweden ', 'Norway ', ' Germany', ' Denmark '
Cleaned IN values: ['Sweden', 'Norway', 'Germany', 'Denmark']
Parsed result: ('AHNAF.MOVIE_INFO.INFO', 'IN', ['Sweden', 'Norway', 'Germany', 'Denmark'])


In [17]:
df_pred

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3
0,2024-12-03-20.44.01.238859,4,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0
1,2024-12-03-20.44.01.238859,9,Q9.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0
2,2024-12-03-20.44.01.238859,12,Q5.NOTE IS NULL,,,0.487515,1.0,1.0
3,2024-12-03-20.44.01.238859,14,Q1.PRODUCTION_YEAR <= 2020,1950 <= Q1.PRODUCTION_YEAR,,0.971485,0.894801,1.0
4,2024-12-03-20.44.01.238859,17,Q7.KEYWORD = 'superhero',,,7e-06,1.0,1.0
5,2024-12-03-20.44.01.238859,18,Q8.KIND = 'production companies',,,0.25,1.0,1.0
6,2024-12-03-20.44.01.238859,19,Q6.LINK LIKE '%follows%',,,0.083333,1.0,1.0


In [18]:
df_pred

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3
0,2024-12-03-20.44.01.238859,4,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0
1,2024-12-03-20.44.01.238859,9,Q9.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0
2,2024-12-03-20.44.01.238859,12,Q5.NOTE IS NULL,,,0.487515,1.0,1.0
3,2024-12-03-20.44.01.238859,14,Q1.PRODUCTION_YEAR <= 2020,1950 <= Q1.PRODUCTION_YEAR,,0.971485,0.894801,1.0
4,2024-12-03-20.44.01.238859,17,Q7.KEYWORD = 'superhero',,,7e-06,1.0,1.0
5,2024-12-03-20.44.01.238859,18,Q8.KIND = 'production companies',,,0.25,1.0,1.0
6,2024-12-03-20.44.01.238859,19,Q6.LINK LIKE '%follows%',,,0.083333,1.0,1.0


In [19]:
import re
import pandas as pd

def parse_predicate(predicate):
    """
    Parses a SQL predicate string to extract the column, operator, and value.

    Args:
        predicate (str): The predicate string to parse.

    Returns:
        tuple: A tuple containing the column, operator, and value, or None if parsing fails.
    """
    if predicate is None or pd.isna(predicate):
        return None

    # Patterns for different predicate types
    basic_pattern = r'(?P<col>[\w\.]+)\s*(?P<op><=|>=|<|>|=|!=|LIKE)\s*(?P<val>.+)'
    in_pattern = r'(?P<col>[\w\.]+)\s+IN\s*\((?P<val>.*)\)'
    is_null_pattern = r'(?P<col>[\w\.]+)\s+IS\s+(?P<val>NULL|NOT\s+NULL)'

    # Remove parentheses around the predicate if present
    predicate = predicate.strip()

    # Match IS NULL / IS NOT NULL predicates
    match = re.match(is_null_pattern, predicate, flags=re.IGNORECASE)
    if match:
        col = match.group('col')
        op = 'IS'
        val = match.group('val')
        return (col, op, val)

    # Match IN predicates (e.g., Q4.INFO IN ('Sweden', 'Norway'))
    match = re.match(in_pattern, predicate, flags=re.IGNORECASE)
    if match:
        col = match.group('col')
        op = 'IN'
        val = match.group('val')

        # Extract single-quoted strings and strip spaces inside each value
        val_list = [v.strip() for v in re.findall(r"'(.*?)'", val)]
        if val_list:
            return (col, op, val_list)

        # Fall back to splitting by commas and stripping spaces
        val_list = [v.strip() for v in val.split(',') if v.strip()]
        return (col, op, val_list)

    # Match basic predicates (e.g., Q1.PRODUCTION_YEAR <= 2020)
    match = re.match(basic_pattern, predicate)
    if match:
        col = match.group('col')
        op = match.group('op')
        val = match.group('val')
        return (col, op, val)

    # Return None if no pattern matches
    return None


In [20]:
# Function to transform predicate to (COLUMN, OPERATOR, VALUE) format
def transform_predicate(predicate):
    if isinstance(predicate, str):
        # Regex to match predicates in the form (VALUE OPERATOR COLUMN)
        match = re.match(r'(\d+|\'[^\']*\'|\"[^\"]*\")\s*(<=|>=|<|>)\s*([A-Za-z_][A-Za-z0-9_\.]*)', predicate)
        if match:
            value, operator, column = match.groups()
            # Reverse the operation and switch operator
            reversed_operator = {'<=': '>=', '>=': '<=', '<': '>', '>': '<'}[operator]
            return f'{column} {reversed_operator} {value}'
    return predicate  # Return unchanged if not a string or no match

# Function to apply transformation only if at least one of the predicates is not NaN
def apply_transformation(row):
    if pd.notna(row['PREDICATE1']) or pd.notna(row['PREDICATE2']) or pd.notna(row['PREDICATE3']):
        row['PREDICATE1'] = transform_predicate(row['PREDICATE1'])
        row['PREDICATE2'] = transform_predicate(row['PREDICATE2'])
        row['PREDICATE3'] = transform_predicate(row['PREDICATE3'])
    return row

# Extract Join Key

In [21]:
import pandas as pd

def extract_join_keys(explain_predicate_path, timestamp):
    """
    Extract join keys from the EXPLAIN_PREDICATE CSV file where HOW_APPLIED is 'JOIN'.

    Parameters:
        explain_predicate_path (str): Path to the EXPLAIN_PREDICATE CSV file.
        timestamp (str): The timestamp value to filter rows by.

    Returns:
        pd.DataFrame: A DataFrame containing EXPLAIN_TIME, OPERATOR_ID, and JOIN_KEY columns.
    """
    # Define the columns to read from the file
    columns_to_read = ['EXPLAIN_TIME', 'OPERATOR_ID', 'HOW_APPLIED', 'PREDICATE_TEXT']

    # Read the CSV file, selecting only relevant columns
    df = pd.read_csv(explain_predicate_path, usecols=columns_to_read)

    # Strip whitespace from EXPLAIN_TIME and HOW_APPLIED columns
    df['EXPLAIN_TIME'] = df['EXPLAIN_TIME'].str.strip()
    df['HOW_APPLIED'] = df['HOW_APPLIED'].str.strip()

    # Filter rows where EXPLAIN_TIME matches the timestamp and HOW_APPLIED is 'JOIN'
    filtered_df = df[(df['EXPLAIN_TIME'] == timestamp) & (df['HOW_APPLIED'] == 'JOIN')].copy()

    # Remove parentheses from the PREDICATE_TEXT column and assign to JOIN_KEY
    filtered_df.loc[:, 'JOIN_KEY'] = filtered_df['PREDICATE_TEXT'].str.replace(r'[\(\)]', '', regex=True)

    # Select the relevant columns for the result
    df_join = filtered_df[['EXPLAIN_TIME', 'OPERATOR_ID', 'JOIN_KEY']]

    # Check if there are matching rows
    if df_join.empty:
        print("No matching rows found.")
        return pd.DataFrame()  # Return an empty DataFrame

    return df_join

# Example usage:
# explain_predicate_path = 'SimpleQueriesSQ12c/EXPLAIN_PREDICATE.csv'
# query1_ts = '2024-11-28 12:00:00'  # Replace with your timestamp
# result = extract_join_keys(explain_predicate_path, query1_ts)
# print(result)


# Combining Node features

In [22]:
import pandas as pd

# Define the file paths and timestamp value
# File paths
explain_operator_path = f'{base_dir}/EXPLAIN_OPERATOR.csv'
explain_stream_path = f'{base_dir}/EXPLAIN_STREAM.csv'
explain_predicate_path = f'{base_dir}/EXPLAIN_PREDICATE.csv'
# query1_ts = '2024-10-08-04.51.06.406285' # not working

# Use the functions to extract features
df_ops = extract_operator_ids_and_types(explain_operator_path, query1_ts)
df_join = extract_join_keys(explain_predicate_path, query1_ts)
df_table = extract_stream_table(explain_stream_path, query1_ts)
df_pred = extract_local_predicates(explain_predicate_path, query1_ts)

# Apply the transformation row by row
df_pred = df_pred.apply(apply_transformation, axis=1)


  df_local_predicates[filter_factor_columns] = df_local_predicates[filter_factor_columns].fillna(1.0)


In [23]:
df_pred

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3
0,2024-12-03-20.44.01.238859,4,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0
1,2024-12-03-20.44.01.238859,9,Q9.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0
2,2024-12-03-20.44.01.238859,12,Q5.NOTE IS NULL,,,0.487515,1.0,1.0
3,2024-12-03-20.44.01.238859,14,Q1.PRODUCTION_YEAR <= 2020,Q1.PRODUCTION_YEAR >= 1950,,0.971485,0.894801,1.0
4,2024-12-03-20.44.01.238859,17,Q7.KEYWORD = 'superhero',,,7e-06,1.0,1.0
5,2024-12-03-20.44.01.238859,18,Q8.KIND = 'production companies',,,0.25,1.0,1.0
6,2024-12-03-20.44.01.238859,19,Q6.LINK LIKE '%follows%',,,0.083333,1.0,1.0


In [24]:
df_pred.columns

Index(['EXPLAIN_TIME', 'OPERATOR_ID', 'PREDICATE1', 'PREDICATE2', 'PREDICATE3',
       'FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3'],
      dtype='object')

In [25]:
df_pred.head()

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3
0,2024-12-03-20.44.01.238859,4,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0
1,2024-12-03-20.44.01.238859,9,Q9.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0
2,2024-12-03-20.44.01.238859,12,Q5.NOTE IS NULL,,,0.487515,1.0,1.0
3,2024-12-03-20.44.01.238859,14,Q1.PRODUCTION_YEAR <= 2020,Q1.PRODUCTION_YEAR >= 1950,,0.971485,0.894801,1.0
4,2024-12-03-20.44.01.238859,17,Q7.KEYWORD = 'superhero',,,7e-06,1.0,1.0


In [26]:
df_ops.columns

Index(['EXPLAIN_TIME', 'OPERATOR_ID', 'OPERATOR_TYPE'], dtype='object')

In [27]:
df_ops.shape

(19, 3)

In [28]:
df_table.columns

Index(['EXPLAIN_TIME', 'SOURCE_ID', 'OPERATOR_ID', 'TABLE'], dtype='object')

In [29]:
df_join.columns

Index(['EXPLAIN_TIME', 'OPERATOR_ID', 'JOIN_KEY'], dtype='object')

In [30]:
# Merge DataFrames step-by-step
merged_df = pd.merge(df_ops, df_pred, on=['EXPLAIN_TIME', 'OPERATOR_ID'], how='outer')  # Include all rows
merged_df = pd.merge(merged_df, df_join, on=['EXPLAIN_TIME', 'OPERATOR_ID'], how='outer')
merged_df = pd.merge(merged_df, df_table, on=['EXPLAIN_TIME', 'OPERATOR_ID'], how='outer')
# Ensure all FILTER_FACTOR columns are filled
filter_columns = ['FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3']
merged_df[filter_columns] = merged_df[filter_columns].fillna(1.0)

df = merged_df.copy()


In [31]:
import pandas as pd
import re

# Sample DataFrame provided
# df = pd.DataFrame({
#     'EXPLAIN_TIME': ['2024-10-08-04.48.13.056756', '2024-10-08-04.48.13.056756', '2024-10-08-04.48.13.056756', '2024-10-08-04.48.13.056756'],
#     'OPERATOR_ID': [1, 2, 3, 4],
#     'OPERATOR_TYPE': ['RETURN', 'HSJOIN', 'TBSCAN', 'TBSCAN'],
#     'TABLE': [None, None, 'TPCDS.CUSTOMER', 'TPCDS.DATE_DIM2'],
#     'PREDICATE1': [None, None, None, 'Q2.D_YEAR >= 1958'],
#     'PREDICATE2': [None, None, None, 'Q2.D_MOY = 12'],
#     'PREDICATE3': [None, None, None, None],
#     'JOIN_KEY': [None, 'Q2.D_DATE_SK = Q1.C_FIRST_SHIPTO_DATE_SK', None, None]
# })

# Function to extract the alias from the predicate (before the period)
def extract_alias(predicate):
    if isinstance(predicate, str):
        match = re.search(r'([A-Za-z0-9_]+)\.', predicate)
        if match:
            return match.group(1)  # Return the part before the period (alias)
    return None

# Filter rows where TABLE is not None and at least one PREDICATE column has a value
filtered_df = df[(df['TABLE'].notna()) & (df[['PREDICATE1', 'PREDICATE2', 'PREDICATE3']].notna().any(axis=1))]

# Create a new DataFrame with EXPLAIN_TIME, TABLE, and ALIAS
rows = []
for _, row in filtered_df.iterrows():
    predicates = [row['PREDICATE1'], row['PREDICATE2'], row['PREDICATE3']]
    alias = None
    for predicate in predicates:
        alias = extract_alias(predicate)
        if alias:
            break  # Once an alias is found, no need to check other predicates
    rows.append({'EXPLAIN_TIME': row['EXPLAIN_TIME'], 'TABLE': row['TABLE'], 'ALIAS': alias})

# Create the new dataframe
df_table_alias = pd.DataFrame(rows)

# Display the result
print(df_table_alias)


                 EXPLAIN_TIME                  TABLE ALIAS
0  2024-12-03-20.44.01.238859       AHNAF.MOVIE_INFO    Q4
1  2024-12-03-20.44.01.238859     AHNAF.COMPANY_NAME    Q9
2  2024-12-03-20.44.01.238859  AHNAF.MOVIE_COMPANIES    Q5
3  2024-12-03-20.44.01.238859            AHNAF.TITLE    Q1
4  2024-12-03-20.44.01.238859          AHNAF.KEYWORD    Q7
5  2024-12-03-20.44.01.238859     AHNAF.COMPANY_TYPE    Q8
6  2024-12-03-20.44.01.238859        AHNAF.LINK_TYPE    Q6


# extract optimized statement

In [32]:
if not os.path.isfile('db2.ipynb'):
    os.system('wget https://raw.githubusercontent.com/IBM/db2-jupyter/master/db2.ipynb')

%run db2.ipynb

  firstCommand = "(?:^\s*)([a-zA-Z]+)(?:\s+.*|$)"
  pattern = "\?\*[0-9]+"


         Install itables if you want to enable scrolling of result sets.
Db2 Extensions Loaded. Version: 2024-09-16


In [33]:
from dotenv import dotenv_values

In [34]:
db2creds = dotenv_values('db2con.env')
%sql CONNECT CREDENTIALS db2creds

Connection successful. TESTDB @ localhost 


In [35]:
%sql SET CURRENT SCHEMA AHNAF

Command completed.


In [36]:
explain_stmt = %sql select STATEMENT_TEXT from EXPLAIN_STATEMENT where EXPLAIN_TIME = '{query1_ts}' AND EXPLAIN_LEVEL = 'S'

In [37]:
explain_stmt['STATEMENT_TEXT'][0]

'SELECT Q9.NAME AS "COMPANY_NAME", Q6.LINK AS "LINK_TYPE", Q1.TITLE AS "WESTERN_FOLLOW_UP" FROM AHNAF.TITLE AS Q1, AHNAF.MOVIE_LINK AS Q2, AHNAF.MOVIE_KEYWORD AS Q3, AHNAF.MOVIE_INFO AS Q4, AHNAF.MOVIE_COMPANIES AS Q5, AHNAF.LINK_TYPE AS Q6, AHNAF.KEYWORD AS Q7, AHNAF.COMPANY_TYPE AS Q8, AHNAF.COMPANY_NAME AS Q9 WHERE (Q5.MOVIE_ID = Q4.MOVIE_ID) AND (Q3.MOVIE_ID = Q4.MOVIE_ID) AND (Q2.MOVIE_ID = Q4.MOVIE_ID) AND (Q4.MOVIE_ID = Q1.ID) AND (Q5.COMPANY_ID = Q9.ID) AND (Q5.COMPANY_TYPE_ID = Q8.ID) AND (Q3.KEYWORD_ID = Q7.ID) AND (Q6.ID = Q2.LINK_TYPE_ID) AND (Q1.PRODUCTION_YEAR <= 2020) AND (1950 <= Q1.PRODUCTION_YEAR) AND Q5.NOTE IS NULL AND (Q6.LINK LIKE \'%follows%\') AND (Q7.KEYWORD = \'superhero\') AND (Q8.KIND = \'production companies\') AND (Q9.COUNTRY_CODE = \'[jp]\') AND Q4.INFO IN (\'Sweden               \', \'Norway               \', \'Germany              \', \'Denmark              \', \'Swedish              \', \'Denish               \', \'Norwegian            \', \'German    

In [38]:
import re

# Sample SQL text (multiline)
sql_text = explain_stmt['STATEMENT_TEXT'][0]

# Function to extract the FROM clause, handling multiline SQL
def extract_from_clause(sql_text):
    # Extract the FROM clause (up to WHERE, GROUP BY, etc.), handling newlines and extra spaces
    from_clause_match = re.search(r'\bFROM\b\s+(.+?)\s*(\bWHERE\b|\bGROUP BY\b|\bORDER BY\b|$)', sql_text, re.IGNORECASE | re.DOTALL)
    if from_clause_match:
        from_clause = from_clause_match.group(1).strip()
        return from_clause
    return None

# Function to extract table names and their aliases from the FROM clause
def extract_table_aliases(sql_text):
    # Extract the FROM clause
    from_clause = extract_from_clause(sql_text)
    
    if from_clause:
        print(f"Captured FROM clause: {from_clause}")  # Debugging step to print captured FROM clause

        # Find all table names and aliases in the form "table AS alias" or "table alias"
        table_alias_pairs = re.findall(r'([A-Za-z0-9_\.]+)\s+(?:AS\s+)?([A-Za-z0-9_]+)', from_clause, re.IGNORECASE)

        # Flip the dictionary: aliases as keys, table names as values
        alias_table_dict = {pair[1]: pair[0] for pair in table_alias_pairs}
        return alias_table_dict
    return {}

# Extract table aliases (flipped)
alias_table_dict = extract_table_aliases(sql_text)

# Display the extracted alias-table dictionary
print("Extracted alias-table dictionary:", alias_table_dict)


Captured FROM clause: AHNAF.TITLE AS Q1, AHNAF.MOVIE_LINK AS Q2, AHNAF.MOVIE_KEYWORD AS Q3, AHNAF.MOVIE_INFO AS Q4, AHNAF.MOVIE_COMPANIES AS Q5, AHNAF.LINK_TYPE AS Q6, AHNAF.KEYWORD AS Q7, AHNAF.COMPANY_TYPE AS Q8, AHNAF.COMPANY_NAME AS Q9
Extracted alias-table dictionary: {'Q1': 'AHNAF.TITLE', 'Q2': 'AHNAF.MOVIE_LINK', 'Q3': 'AHNAF.MOVIE_KEYWORD', 'Q4': 'AHNAF.MOVIE_INFO', 'Q5': 'AHNAF.MOVIE_COMPANIES', 'Q6': 'AHNAF.LINK_TYPE', 'Q7': 'AHNAF.KEYWORD', 'Q8': 'AHNAF.COMPANY_TYPE', 'Q9': 'AHNAF.COMPANY_NAME'}


In [39]:
df.head()

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,Q2.MOVIE_ID = Q4.MOVIE_ID,,
3,2024-12-03-20.44.01.238859,4,TBSCAN,"Q4.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,Q6.ID = Q2.LINK_TYPE_ID,,


In [40]:
pd.set_option('display.max_colwidth', None)

In [41]:
import pandas as pd
import re

# # Sample alias-table dictionary extracted from SQL
# alias_table_dict = {
#     'Q1': 'TPCDS.CUSTOMER',
#     'Q2': 'TPCDS.DATE_DIM2'
# }

# Function to replace all aliases in the predicate with the corresponding table names using alias_table_dict
def replace_alias_with_table(predicate, alias_table_dict):
    if isinstance(predicate, str):
        # Replace all aliases in the string with the corresponding table names
        for alias, table_name in alias_table_dict.items():
            predicate = re.sub(fr'\b{alias}\b\.', f'{table_name}.', predicate)
    return predicate

# Apply the replacement to each predicate column
for index, row in df.iterrows():
    df.at[index, 'PREDICATE1'] = replace_alias_with_table(row['PREDICATE1'], alias_table_dict)
    df.at[index, 'PREDICATE2'] = replace_alias_with_table(row['PREDICATE2'], alias_table_dict)
    df.at[index, 'PREDICATE3'] = replace_alias_with_table(row['PREDICATE3'], alias_table_dict)
    df.at[index, 'JOIN_KEY'] = replace_alias_with_table(row['JOIN_KEY'], alias_table_dict)

# Display the updated DataFrame
print(df)


                  EXPLAIN_TIME  OPERATOR_ID OPERATOR_TYPE  \
0   2024-12-03-20.44.01.238859            1        RETURN   
1   2024-12-03-20.44.01.238859            2        TQ       
2   2024-12-03-20.44.01.238859            3        HSJOIN   
3   2024-12-03-20.44.01.238859            4        TBSCAN   
4   2024-12-03-20.44.01.238859            5        HSJOIN   
5   2024-12-03-20.44.01.238859            6        HSJOIN   
6   2024-12-03-20.44.01.238859            7        TBSCAN   
7   2024-12-03-20.44.01.238859            8        HSJOIN   
8   2024-12-03-20.44.01.238859            9        TBSCAN   
9   2024-12-03-20.44.01.238859           10        HSJOIN   
10  2024-12-03-20.44.01.238859           11        HSJOIN   
11  2024-12-03-20.44.01.238859           12        TBSCAN   
12  2024-12-03-20.44.01.238859           13        HSJOIN   
13  2024-12-03-20.44.01.238859           14        TBSCAN   
14  2024-12-03-20.44.01.238859           15        HSJOIN   
15  2024-12-03-20.44.01.

In [42]:
df.columns

Index(['EXPLAIN_TIME', 'OPERATOR_ID', 'OPERATOR_TYPE', 'PREDICATE1',
       'PREDICATE2', 'PREDICATE3', 'FILTER_FACTOR1', 'FILTER_FACTOR2',
       'FILTER_FACTOR3', 'JOIN_KEY', 'SOURCE_ID', 'TABLE'],
      dtype='object')

In [43]:
# Creating new columns for parsed predicates
for i in range(1, 4):
    print(i)
    pred_col = f'PREDICATE{i}'
    parsed_col = f'PARSED_PRED{i}'
    df[parsed_col] = df[pred_col].apply(parse_predicate)

# Display the updated dataframe
df

1
2
3


Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,,,,
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,,,,
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_INFO.MOVIE_ID,,,,,
3,2024-12-03-20.44.01.238859,4,TBSCAN,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO,"(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])",,
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,,,,,
5,2024-12-03-20.44.01.238859,6,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,,,,,
6,2024-12-03-20.44.01.238859,7,TBSCAN,,,,1.0,1.0,1.0,,-1.0,AHNAF.MOVIE_LINK,,,
7,2024-12-03-20.44.01.238859,8,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_ID = AHNAF.COMPANY_NAME.ID,,,,,
8,2024-12-03-20.44.01.238859,9,TBSCAN,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0,,-1.0,AHNAF.COMPANY_NAME,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,
9,2024-12-03-20.44.01.238859,10,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_TYPE_ID = AHNAF.COMPANY_TYPE.ID,,,,,


# Calculating Tree Node heights

In [44]:
# Read the CSV file
file_path = f'{base_dir}/EXPLAIN_STREAM.csv'
df_stream = pd.read_csv(file_path)

# Filter for the given EXPLAIN_TIME value and non-NaN OBJECT_NAME
df_stream = df_stream.loc[(df_stream['EXPLAIN_TIME'] == query1_ts)].copy()

In [45]:
df_stream.columns

Index(['EXPLAIN_REQUESTER', 'EXPLAIN_TIME', 'SOURCE_NAME', 'SOURCE_SCHEMA',
       'SOURCE_VERSION', 'EXPLAIN_LEVEL', 'STMTNO', 'SECTNO', 'STREAM_ID',
       'SOURCE_TYPE', 'SOURCE_ID', 'TARGET_TYPE', 'TARGET_ID', 'OBJECT_SCHEMA',
       'OBJECT_NAME', 'STREAM_COUNT', 'COLUMN_COUNT', 'PREDICATE_ID',
       'COLUMN_NAMES', 'PMID', 'SINGLE_NODE', 'PARTITION_COLUMNS',
       'SEQUENCE_SIZES', 'OBJECT_TENANTID'],
      dtype='object')

In [46]:
df_stream = df_stream[['SOURCE_ID', 'TARGET_ID']]

In [47]:
df_stream.head()

Unnamed: 0,SOURCE_ID,TARGET_ID
3942,-1,4
3943,4,3
3944,-1,7
3945,7,6
3946,-1,9


In [48]:
df_stream

Unnamed: 0,SOURCE_ID,TARGET_ID
3942,-1,4
3943,4,3
3944,-1,7
3945,7,6
3946,-1,9
3947,9,8
3948,-1,12
3949,12,11
3950,-1,14
3951,14,13


In [49]:

import pandas as pd
import numpy as np

# Step 1: Generate the adjacency matrix (from previous code)
# data = {
#     'SOURCE_ID': [-1, 3, -1, 4, 2],
#     'TARGET_ID': [3, 2, 4, 2, 1]
# }
# df = pd.DataFrame(data)

# Extract all unique nodes from SOURCE_ID and TARGET_ID
node_ids = set(df_stream['SOURCE_ID']).union(set(df_stream['TARGET_ID'])) - {-1}
node_ids = sorted(node_ids)

# Map node ids to indices for the adjacency matrix
node_to_index = {node: idx for idx, node in enumerate(node_ids)}

# Initialize the adjacency matrix with zeros
n = len(node_ids)
adj_matrix = np.zeros((n, n), dtype=int)

# Fill the adjacency matrix based on the parent-child relationships
for _, row in df_stream.iterrows():
    source = row['SOURCE_ID']
    target = row['TARGET_ID']
    
    if source != -1:
        parent_idx = node_to_index[target]
        child_idx = node_to_index[source]
        adj_matrix[parent_idx][child_idx] = 1

# Convert to DataFrame for better display
adj_matrix_df = pd.DataFrame(adj_matrix, index=node_ids, columns=node_ids)
print("Adjacency Matrix:")
print(adj_matrix_df)


Adjacency Matrix:
    1   2   3   4   5   6   7   8   9   10  11  12  13  14  15  16  17  18  19
1    0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
2    0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
3    0   0   0   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0
4    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
5    0   0   0   0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   1
6    0   0   0   0   0   0   1   1   0   0   0   0   0   0   0   0   0   0   0
7    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
8    0   0   0   0   0   0   0   0   1   1   0   0   0   0   0   0   0   0   0
9    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
10   0   0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   1   0
11   0   0   0   0   0   0   0   0   0   0   0   1   1   0   0   0   0   0   0
12   0   0   0   0   0   0   0   0

In [50]:
# Reset the index of the adjacency matrix to be 0-based
adj_matrix_df.index = adj_matrix_df.index - 1  # Convert row indices to 0-based
adj_matrix_df.columns = adj_matrix_df.columns - 1  # Convert column indices to 0-based

print("Adjacency Matrix with 0-Based Index:")
print(adj_matrix_df)

Adjacency Matrix with 0-Based Index:
    0   1   2   3   4   5   6   7   8   9   10  11  12  13  14  15  16  17  18
0    0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
1    0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
2    0   0   0   1   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0
3    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
4    0   0   0   0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   1
5    0   0   0   0   0   0   1   1   0   0   0   0   0   0   0   0   0   0   0
6    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
7    0   0   0   0   0   0   0   0   1   1   0   0   0   0   0   0   0   0   0
8    0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
9    0   0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   1   0
10   0   0   0   0   0   0   0   0   0   0   0   1   1   0   0   0   0   0   0
11   0   0   0 

# Find the BFS sequence

In [51]:
import numpy as np
from collections import deque

def bfs_from_adj_matrix(adj_matrix, start_node):
    """
    Perform BFS on a graph represented by an adjacency matrix.

    Args:
        adj_matrix (np.ndarray): Adjacency matrix of the graph.
        start_node (int): Starting node for BFS.

    Returns:
        list: List of nodes in the order they are visited.
    """
    num_nodes = adj_matrix.shape[0]  # Number of nodes
    visited = [False] * num_nodes    # Track visited nodes
    queue = deque([start_node])      # Initialize queue with the start node
    visited[start_node] = True       # Mark the start node as visited

    bfs_order = []  # List to store BFS traversal order

    while queue:
        node = queue.popleft()       # Dequeue the next node
        bfs_order.append(node)       # Record the visited node

        # Check all neighbors (columns in the row of this node)
        for neighbor, is_connected in enumerate(adj_matrix[node]):
            if is_connected and not visited[neighbor]:  # If connected and not visited
                queue.append(neighbor)                  # Enqueue the neighbor
                visited[neighbor] = True                # Mark as visited

    return bfs_order

# Convert adj_matrix_df to NumPy array
adj_matrix_np = adj_matrix_df.values

# Perform BFS starting from node 0
bfs_order = bfs_from_adj_matrix(adj_matrix_np, start_node=0)
print("BFS Row Index Order:", bfs_order)


BFS Row Index Order: [0, 1, 2, 3, 4, 5, 18, 6, 7, 8, 9, 10, 17, 11, 12, 13, 14, 15, 16]


In [52]:
# Convert 0-based indexed adjacency matrix to adj_list
adj_list = []

# Iterate over the adjacency matrix
for parent_idx, row in enumerate(adj_matrix):  # parent_idx is the row index
    for child_idx, is_child in enumerate(row):  # child_idx is the column index
        if is_child:  # If there's a connection (value is 1)
            adj_list.append((parent_idx, child_idx))  # Use 0-based indices directly

adj_list_tensor = torch.LongTensor(np.array(adj_list))

print("Adjacency List:")
print(adj_list_tensor)


Adjacency List:
tensor([[ 0,  1],
        [ 1,  2],
        [ 2,  3],
        [ 2,  4],
        [ 4,  5],
        [ 4, 18],
        [ 5,  6],
        [ 5,  7],
        [ 7,  8],
        [ 7,  9],
        [ 9, 10],
        [ 9, 17],
        [10, 11],
        [10, 12],
        [12, 13],
        [12, 14],
        [14, 15],
        [14, 16]])


In [53]:
# Step 1: Perform BFS to get the order of nodes
bfs_order = bfs_from_adj_matrix(adj_matrix, start_node=0)  # Adjust start_node as needed

# Step 2: Create a mapping from BFS order to indices for sorting
bfs_index_map = {node: i for i, node in enumerate(bfs_order)}

# Step 3: Reorder adj_list based on BFS order
adj_list = [
    (parent, child)
    for parent in bfs_order
    for child, is_child in enumerate(adj_matrix[parent])
    if is_child  # Keep only valid edges
]

# Step 4: Convert to tensor if needed
adj_list_tensor = torch.LongTensor(np.array(adj_list))

print("Reordered Adjacency List:")
print(adj_list_tensor)

Reordered Adjacency List:
tensor([[ 0,  1],
        [ 1,  2],
        [ 2,  3],
        [ 2,  4],
        [ 4,  5],
        [ 4, 18],
        [ 5,  6],
        [ 5,  7],
        [ 7,  8],
        [ 7,  9],
        [ 9, 10],
        [ 9, 17],
        [10, 11],
        [10, 12],
        [12, 13],
        [12, 14],
        [14, 15],
        [14, 16]])


In [54]:
edge_index = adj_list_tensor.t()

In [55]:
edge_index

tensor([[ 0,  1,  2,  2,  4,  4,  5,  5,  7,  7,  9,  9, 10, 10, 12, 12, 14, 14],
        [ 1,  2,  3,  4,  5, 18,  6,  7,  8,  9, 10, 17, 11, 12, 13, 14, 15, 16]])

In [56]:
from model.database_util import floyd_warshall_rewrite

In [57]:
df

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,,,,
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,,,,
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_INFO.MOVIE_ID,,,,,
3,2024-12-03-20.44.01.238859,4,TBSCAN,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO,"(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])",,
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,,,,,
5,2024-12-03-20.44.01.238859,6,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,,,,,
6,2024-12-03-20.44.01.238859,7,TBSCAN,,,,1.0,1.0,1.0,,-1.0,AHNAF.MOVIE_LINK,,,
7,2024-12-03-20.44.01.238859,8,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_ID = AHNAF.COMPANY_NAME.ID,,,,,
8,2024-12-03-20.44.01.238859,9,TBSCAN,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0,,-1.0,AHNAF.COMPANY_NAME,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,
9,2024-12-03-20.44.01.238859,10,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_TYPE_ID = AHNAF.COMPANY_TYPE.ID,,,,,


In [58]:
bfs_order

[0, 1, 2, 3, 4, 5, 18, 6, 7, 8, 9, 10, 17, 11, 12, 13, 14, 15, 16]

In [59]:
# Reorder the DataFrame rows based on BFS order
df = df.iloc[bfs_order].reset_index(drop=True)

df

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,,,,
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,,,,
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_INFO.MOVIE_ID,,,,,
3,2024-12-03-20.44.01.238859,4,TBSCAN,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO,"(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])",,
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,,,,,
5,2024-12-03-20.44.01.238859,6,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,,,,,
6,2024-12-03-20.44.01.238859,19,TBSCAN,AHNAF.LINK_TYPE.LINK LIKE '%follows%',,,0.083333,1.0,1.0,,-1.0,AHNAF.LINK_TYPE,"(AHNAF.LINK_TYPE.LINK, LIKE, '%follows%')",,
7,2024-12-03-20.44.01.238859,7,TBSCAN,,,,1.0,1.0,1.0,,-1.0,AHNAF.MOVIE_LINK,,,
8,2024-12-03-20.44.01.238859,8,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_ID = AHNAF.COMPANY_NAME.ID,,,,,
9,2024-12-03-20.44.01.238859,9,TBSCAN,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0,,-1.0,AHNAF.COMPANY_NAME,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,


In [60]:
df.shape
N = df.shape[0]

In [61]:
if len(edge_index) == 0:
            shortest_path_result = np.array([[0]])
            path = np.array([[0]])
            adj = torch.tensor([[0]]).bool()
else:
            adj = torch.zeros([N,N], dtype=torch.bool)
            adj[edge_index[0,:], edge_index[1,:]] = True
            
            shortest_path_result = floyd_warshall_rewrite(adj.numpy())
        
rel_pos = torch.from_numpy((shortest_path_result)).long()

In [62]:
rel_pos

tensor([[ 0,  1,  2,  3,  3,  4,  5,  5,  6,  6,  7,  8,  8,  9,  9, 10, 10,  7,
          4],
        [60,  0,  1,  2,  2,  3,  4,  4,  5,  5,  6,  7,  7,  8,  8,  9,  9,  6,
          3],
        [60, 60,  0,  1,  1,  2,  3,  3,  4,  4,  5,  6,  6,  7,  7,  8,  8,  5,
          2],
        [60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60,  0,  1,  2,  2,  3,  3,  4,  5,  5,  6,  6,  7,  7,  4,
          1],
        [60, 60, 60, 60, 60,  0,  1,  1,  2,  2,  3,  4,  4,  5,  5,  6,  6,  3,
         60],
        [60, 60, 60, 60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60, 60, 60, 60,  0,  1,  1,  2,  3,  3,  4,  4,  5,  5,  2,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60, 60,  0,  1,  2,  2,  3,  3,  4,  4,  1,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60, 60, 60,  

In [63]:
def pad_attn_bias_unsqueeze(x, padlen):
    xlen = x.size(0)
    if xlen < padlen:
        new_x = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float('-inf'))
        new_x[:xlen, :xlen] = x
        new_x[xlen:, :xlen] = 0
        x = new_x
    return x.unsqueeze(0)

In [64]:
attn_bias = torch.zeros([N+1,N+1], dtype=torch.float)
attn_bias[1:, 1:][rel_pos >= rel_pos_max] = float('-inf')

In [65]:
attn_bias = pad_attn_bias_unsqueeze(attn_bias, max_node + 1)

In [66]:
def pad_rel_pos_unsqueeze(x, padlen):
    x = x + 1
    xlen = x.size(0)
    if xlen < padlen:
        new_x = x.new_zeros([padlen, padlen], dtype=x.dtype)
        new_x[:xlen, :xlen] = x
        x = new_x
    return x.unsqueeze(0)

In [67]:
rel_pos.shape

torch.Size([19, 19])

In [68]:
rel_pos

tensor([[ 0,  1,  2,  3,  3,  4,  5,  5,  6,  6,  7,  8,  8,  9,  9, 10, 10,  7,
          4],
        [60,  0,  1,  2,  2,  3,  4,  4,  5,  5,  6,  7,  7,  8,  8,  9,  9,  6,
          3],
        [60, 60,  0,  1,  1,  2,  3,  3,  4,  4,  5,  6,  6,  7,  7,  8,  8,  5,
          2],
        [60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60,  0,  1,  2,  2,  3,  3,  4,  5,  5,  6,  6,  7,  7,  4,
          1],
        [60, 60, 60, 60, 60,  0,  1,  1,  2,  2,  3,  4,  4,  5,  5,  6,  6,  3,
         60],
        [60, 60, 60, 60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60, 60, 60, 60,  0,  1,  1,  2,  3,  3,  4,  4,  5,  5,  2,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60,  0, 60, 60, 60, 60, 60, 60, 60, 60, 60,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60, 60,  0,  1,  2,  2,  3,  3,  4,  4,  1,
         60],
        [60, 60, 60, 60, 60, 60, 60, 60, 60, 60,  

# creating attn_bias tensor

In [69]:
attn_bias = torch.zeros([N+1,N+1], dtype=torch.float)
attn_bias[1:, 1:][rel_pos >= rel_pos_max] = float('-inf')
#attn_bias = pad_attn_bias_unsqueeze(attn_bias, max_node + 1)

In [70]:
attn_bias = pad_attn_bias_unsqueeze(attn_bias, max_node + 1)
rel_pos = pad_rel_pos_unsqueeze(rel_pos, max_node)

In [71]:
attn_bias.shape

torch.Size([1, 31, 31])

In [72]:
rel_pos.shape

torch.Size([1, 30, 30])

# Computing longest path for each node

In [73]:
# Step 6: Define a function to compute the longest path from a node to a leaf
def longest_path_from_node(node_idx, adj_matrix, memo):
    # If this node's longest path is already computed, return the cached result
    if node_idx in memo:
        return memo[node_idx]
    
    # If the node is a leaf node (no outgoing edges)
    if np.sum(adj_matrix[node_idx]) == 0:
        memo[node_idx] = 0  # Longest path from a leaf is 0
        return 0
    
    # Otherwise, recursively calculate the longest path from this node to a leaf
    max_length = 0
    for child_idx in range(len(adj_matrix)):
        if adj_matrix[node_idx][child_idx] == 1:  # There's a connection to this child
            max_length = max(max_length, 1 + longest_path_from_node(child_idx, adj_matrix, memo))
    
    memo[node_idx] = max_length  # Memoize the result
    return memo[node_idx]

memo = {}
longest_paths = {}

for node in node_ids:
    node_idx = node_to_index[node]
    longest_paths[node] = longest_path_from_node(node_idx, adj_matrix, memo) + 1

# Step 8: Display the longest path for each node
print("\nLongest Path from Each Node to a Leaf:")
for node, length in longest_paths.items():
    print(f"Node {node}: Longest path = {length}")


Longest Path from Each Node to a Leaf:
Node 1: Longest path = 11
Node 2: Longest path = 10
Node 3: Longest path = 9
Node 4: Longest path = 1
Node 5: Longest path = 8
Node 6: Longest path = 7
Node 7: Longest path = 1
Node 8: Longest path = 6
Node 9: Longest path = 1
Node 10: Longest path = 5
Node 11: Longest path = 4
Node 12: Longest path = 1
Node 13: Longest path = 3
Node 14: Longest path = 1
Node 15: Longest path = 2
Node 16: Longest path = 1
Node 17: Longest path = 1
Node 18: Longest path = 1
Node 19: Longest path = 1


In [74]:
type(longest_paths)

dict

In [75]:
longest_paths.keys()

dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

In [76]:
# Add a new column 'height' based on the matching OPERATOR_ID
df['HEIGHT'] = df['OPERATOR_ID'].map(longest_paths)

In [77]:
df.shape

(19, 16)

In [78]:
df

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,HEIGHT
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,,,,,11
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,,,,,10
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_INFO.MOVIE_ID,,,,,,9
3,2024-12-03-20.44.01.238859,4,TBSCAN,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO,"(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])",,,1
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,,,,,,8
5,2024-12-03-20.44.01.238859,6,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,,,,,,7
6,2024-12-03-20.44.01.238859,19,TBSCAN,AHNAF.LINK_TYPE.LINK LIKE '%follows%',,,0.083333,1.0,1.0,,-1.0,AHNAF.LINK_TYPE,"(AHNAF.LINK_TYPE.LINK, LIKE, '%follows%')",,,1
7,2024-12-03-20.44.01.238859,7,TBSCAN,,,,1.0,1.0,1.0,,-1.0,AHNAF.MOVIE_LINK,,,,1
8,2024-12-03-20.44.01.238859,8,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_ID = AHNAF.COMPANY_NAME.ID,,,,,,6
9,2024-12-03-20.44.01.238859,9,TBSCAN,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0,,-1.0,AHNAF.COMPANY_NAME,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,,1


# Creating a tensor of heights

In [79]:
heights = df['HEIGHT'].values

In [80]:
heights

array([11, 10,  9,  1,  8,  7,  1,  1,  6,  1,  5,  4,  1,  1,  3,  1,  2,
        1,  1])

In [81]:
torch.LongTensor(heights).size()

torch.Size([19])

In [82]:
heights = pad_1d_unsqueeze(torch.LongTensor(heights), max_node)


In [83]:
heights

tensor([[12, 11, 10,  2,  9,  8,  2,  2,  7,  2,  6,  5,  2,  2,  4,  2,  3,  2,
          2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [84]:
heights.shape

torch.Size([1, 30])

In [85]:
N

19

In [86]:
heights

tensor([[12, 11, 10,  2,  9,  8,  2,  2,  7,  2,  6,  5,  2,  2,  4,  2,  3,  2,
          2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

# Look up the `SORT_SHRHEAP_TOP`

In [87]:
### Step 1: look up the `APPL_ID`, 'ACTIVITY_ID', 'UOW_ID'
# Define the file path and the timestamp value
file_path = f'{base_dir}/success.csv'  # Replace with your CSV file path

# Read the CSV file, selecting only relevant columns
# column_headers = ['QUERYID', 'APPL_ID', 'UOW_ID', 'ACTIVITY_ID', 'EXPLAIN_TIME', 'SORT_SHRHEAP_TOP', 'QUERY']
# df_success = pd.read_csv(file_path, header=None)
df_success = pd.read_csv(file_path, dtype={
    "QUERYID": int,
    "EXPLAIN_TIME": str,
    "SORT_SHRHEAP_TOP": float,
    "QUERY": str
})

#df_success.columns = column_headers

In [88]:
df_success.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2332 entries, 0 to 2331
Data columns (total 7 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   QUERYID           2332 non-null   int64  
 1   APPL_ID           2332 non-null   object 
 2   UOW_ID            2332 non-null   int64  
 3   ACTIVITY_ID       2332 non-null   int64  
 4   EXPLAIN_TIME      2332 non-null   object 
 5   SORT_SHRHEAP_TOP  2332 non-null   float64
 6   QUERY             2332 non-null   object 
dtypes: float64(1), int64(3), object(3)
memory usage: 127.7+ KB


In [89]:
df_success.shape

(2332, 7)

In [90]:
query1_ts

'2024-12-03-20.44.01.238859'

In [91]:
df_success.head(1)

Unnamed: 0,QUERYID,APPL_ID,UOW_ID,ACTIVITY_ID,EXPLAIN_TIME,SORT_SHRHEAP_TOP,QUERY
0,1,*LOCAL.db2inst1.241204044349,4,1,2024-12-03-20.43.44.200877,66844.0,"SELECT cn.name AS company_name, lt.link AS link_type,t.title AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='distributors' AND k.keyword ='number-in-title' AND lt.link LIKE '%referenced in%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'CANADA', 'Netherlands', 'Brazil', 'UK', 'Belgium', 'Finland', 'Hungary', 'Estonia', 'worldwide', 'Australia', 'Spain', 'France', 'Germany', 'Japan', 'Hungary', 'Sweden', 'Columbia', 'Slovenia', 'Israel', 'Venezuela', 'Finland', 'Nigeria', 'Philippines', 'New Zealand', 'Ireland', 'Romania', 'Non-USA', 'Bulgaria', 'Argentina', 'Malaysia', 'Singapore', 'Turkey', 'Sri Lanka', 'Italy', 'Indonesia', 'South Korea', 'Vietnam', 'Slovakia', 'Czech Republic', 'China', 'Portugal', 'Greece', 'Republic of Macedonia', 'Serbia', 'Jamaica', 'Switzerland', 'Yugoslavia', 'Mexico', 'Austria', 'Russia') AND t.production_year BETWEEN 1950 AND 2020 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;"


In [92]:
df_query1 = df_success[df_success['EXPLAIN_TIME'] == query1_ts]

In [93]:
df_query1

Unnamed: 0,QUERYID,APPL_ID,UOW_ID,ACTIVITY_ID,EXPLAIN_TIME,SORT_SHRHEAP_TOP,QUERY
9,10,*LOCAL.db2inst1.241204044349,76,1,2024-12-03-20.44.01.238859,58513.0,"SELECT cn.name AS company_name, lt.link AS link_type,t.title AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code ='[jp]' AND ct.kind ='production companies' AND k.keyword ='superhero' AND lt.link LIKE '%follows%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'CANADA', 'Netherlands', 'Brazil', 'UK', 'Belgium', 'Finland', 'Hungary', 'Estonia', 'worldwide', 'Australia', 'Spain', 'France', 'Germany', 'Japan', 'Hungary', 'Sweden', 'Columbia', 'Slovenia', 'Israel', 'Venezuela', 'Finland', 'Nigeria', 'Philippines', 'New Zealand', 'Ireland', 'Romania', 'Non-USA', 'Bulgaria', 'Argentina', 'Malaysia', 'Singapore', 'Turkey', 'Sri Lanka', 'Italy', 'Indonesia', 'South Korea', 'Vietnam', 'Slovakia', 'Czech Republic', 'China', 'Portugal', 'Greece', 'Republic of Macedonia', 'Serbia', 'Jamaica', 'Switzerland', 'Yugoslavia', 'Mexico', 'Austria', 'Russia') AND t.production_year BETWEEN 1950 AND 2020 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;"


In [94]:
df_query1.shape

(1, 7)

## List the final set of dataframes, which I will convert to tensors

In [95]:
df_query1.head()

Unnamed: 0,QUERYID,APPL_ID,UOW_ID,ACTIVITY_ID,EXPLAIN_TIME,SORT_SHRHEAP_TOP,QUERY
9,10,*LOCAL.db2inst1.241204044349,76,1,2024-12-03-20.44.01.238859,58513.0,"SELECT cn.name AS company_name, lt.link AS link_type,t.title AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code ='[jp]' AND ct.kind ='production companies' AND k.keyword ='superhero' AND lt.link LIKE '%follows%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'CANADA', 'Netherlands', 'Brazil', 'UK', 'Belgium', 'Finland', 'Hungary', 'Estonia', 'worldwide', 'Australia', 'Spain', 'France', 'Germany', 'Japan', 'Hungary', 'Sweden', 'Columbia', 'Slovenia', 'Israel', 'Venezuela', 'Finland', 'Nigeria', 'Philippines', 'New Zealand', 'Ireland', 'Romania', 'Non-USA', 'Bulgaria', 'Argentina', 'Malaysia', 'Singapore', 'Turkey', 'Sri Lanka', 'Italy', 'Indonesia', 'South Korea', 'Vietnam', 'Slovakia', 'Czech Republic', 'China', 'Portugal', 'Greece', 'Republic of Macedonia', 'Serbia', 'Jamaica', 'Switzerland', 'Yugoslavia', 'Mexico', 'Austria', 'Russia') AND t.production_year BETWEEN 1950 AND 2020 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id;"


In [96]:
# get the SORT_SHRHEAP_TOP
raw_costs = df_query1['SORT_SHRHEAP_TOP'].values.tolist()
QUERYID = df_query1['QUERYID'].values[0]

print(raw_costs)

[58513.0]


In [97]:
type(raw_costs)

list

In [98]:
type(raw_costs[0])

float

In [99]:
import torch
from model.util import Normalizer

# cost_norm = Normalizer(1, 100)
# cost_norm = Normalizer(-3.61192, 12.290855)
#cost_norm = Normalizer(5, 2611)
cost_norm = Normalizer(8.26, 11.12)
cost_labels = torch.from_numpy(cost_norm.normalize_labels(raw_costs))

In [100]:
cost_labels

tensor([0.9500], dtype=torch.float64)

In [101]:
QUERYID

10

In [102]:
# get the node features
df

Unnamed: 0,EXPLAIN_TIME,OPERATOR_ID,OPERATOR_TYPE,PREDICATE1,PREDICATE2,PREDICATE3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,SOURCE_ID,TABLE,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,HEIGHT
0,2024-12-03-20.44.01.238859,1,RETURN,,,,1.0,1.0,1.0,,,,,,,11
1,2024-12-03-20.44.01.238859,2,TQ,,,,1.0,1.0,1.0,,,,,,,10
2,2024-12-03-20.44.01.238859,3,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_INFO.MOVIE_ID,,,,,,9
3,2024-12-03-20.44.01.238859,4,TBSCAN,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')",,,0.029517,1.0,1.0,,-1.0,AHNAF.MOVIE_INFO,"(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])",,,1
4,2024-12-03-20.44.01.238859,5,HSJOIN,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,,,,,,8
5,2024-12-03-20.44.01.238859,6,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,,,,,,7
6,2024-12-03-20.44.01.238859,19,TBSCAN,AHNAF.LINK_TYPE.LINK LIKE '%follows%',,,0.083333,1.0,1.0,,-1.0,AHNAF.LINK_TYPE,"(AHNAF.LINK_TYPE.LINK, LIKE, '%follows%')",,,1
7,2024-12-03-20.44.01.238859,7,TBSCAN,,,,1.0,1.0,1.0,,-1.0,AHNAF.MOVIE_LINK,,,,1
8,2024-12-03-20.44.01.238859,8,HSJOIN,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_ID = AHNAF.COMPANY_NAME.ID,,,,,,6
9,2024-12-03-20.44.01.238859,9,TBSCAN,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,0.024702,1.0,1.0,,-1.0,AHNAF.COMPANY_NAME,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,,1


In [103]:
df.shape

(19, 16)

In [104]:
df.columns

Index(['EXPLAIN_TIME', 'OPERATOR_ID', 'OPERATOR_TYPE', 'PREDICATE1',
       'PREDICATE2', 'PREDICATE3', 'FILTER_FACTOR1', 'FILTER_FACTOR2',
       'FILTER_FACTOR3', 'JOIN_KEY', 'SOURCE_ID', 'TABLE', 'PARSED_PRED1',
       'PARSED_PRED2', 'PARSED_PRED3', 'HEIGHT'],
      dtype='object')

In [105]:
df_node_feat = df[['OPERATOR_ID', 'OPERATOR_TYPE', 'TABLE', 'PREDICATE1', 'PREDICATE2', 'PREDICATE3', 'PARSED_PRED1', 'PARSED_PRED2', 'PARSED_PRED3', 'FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3', 'JOIN_KEY', 'HEIGHT']]

In [106]:
df_node_feat.sample(3)

Unnamed: 0,OPERATOR_ID,OPERATOR_TYPE,TABLE,PREDICATE1,PREDICATE2,PREDICATE3,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,HEIGHT
5,6,HSJOIN,,,,,,,,1.0,1.0,1.0,AHNAF.MOVIE_LINK.MOVIE_ID = AHNAF.MOVIE_COMPANIES.MOVIE_ID,7
10,10,HSJOIN,,,,,,,,1.0,1.0,1.0,AHNAF.MOVIE_COMPANIES.COMPANY_TYPE_ID = AHNAF.COMPANY_TYPE.ID,5
9,9,TBSCAN,AHNAF.COMPANY_NAME,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,,0.024702,1.0,1.0,,1


In [107]:
# Generate the bitmap indicating NaN (0) or non-NaN (1) in PREDICATE1, PREDICATE2, and PREDICATE3
bitmap = df_node_feat[['PREDICATE1', 'PREDICATE2', 'PREDICATE3']].notna().astype(int).values.tolist()

# Display the bitmap
print(bitmap)

df_node_feat['PREDICATE_MASK'] = bitmap

[[0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 0], [1, 1, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0]]


In [108]:
df_node_feat.sample(5)

Unnamed: 0,OPERATOR_ID,OPERATOR_TYPE,TABLE,PREDICATE1,PREDICATE2,PREDICATE3,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,HEIGHT,PREDICATE_MASK
0,1,RETURN,,,,,,,,1.0,1.0,1.0,,11,"[0, 0, 0]"
12,18,TBSCAN,AHNAF.COMPANY_TYPE,AHNAF.COMPANY_TYPE.KIND = 'production companies',,,"(AHNAF.COMPANY_TYPE.KIND, =, 'production companies')",,,0.25,1.0,1.0,,1,"[1, 0, 0]"
9,9,TBSCAN,AHNAF.COMPANY_NAME,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,,0.024702,1.0,1.0,,1,"[1, 0, 0]"
18,17,TBSCAN,AHNAF.KEYWORD,AHNAF.KEYWORD.KEYWORD = 'superhero',,,"(AHNAF.KEYWORD.KEYWORD, =, 'superhero')",,,7e-06,1.0,1.0,,1,"[1, 0, 0]"
4,5,HSJOIN,,,,,,,,1.0,1.0,1.0,AHNAF.LINK_TYPE.ID = AHNAF.MOVIE_LINK.LINK_TYPE_ID,8,"[0, 0, 0]"


In [109]:
df_node_feat.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19 entries, 0 to 18
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   OPERATOR_ID     19 non-null     int64  
 1   OPERATOR_TYPE   19 non-null     object 
 2   TABLE           9 non-null      object 
 3   PREDICATE1      7 non-null      object 
 4   PREDICATE2      1 non-null      object 
 5   PREDICATE3      0 non-null      object 
 6   PARSED_PRED1    7 non-null      object 
 7   PARSED_PRED2    1 non-null      object 
 8   PARSED_PRED3    0 non-null      object 
 9   FILTER_FACTOR1  19 non-null     float64
 10  FILTER_FACTOR2  19 non-null     float64
 11  FILTER_FACTOR3  19 non-null     float64
 12  JOIN_KEY        8 non-null      object 
 13  HEIGHT          19 non-null     int64  
 14  PREDICATE_MASK  19 non-null     object 
dtypes: float64(3), int64(2), object(10)
memory usage: 2.4+ KB


In [110]:
df_node_feat.sample(3)

Unnamed: 0,OPERATOR_ID,OPERATOR_TYPE,TABLE,PREDICATE1,PREDICATE2,PREDICATE3,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,HEIGHT,PREDICATE_MASK
14,13,HSJOIN,,,,,,,,1.0,1.0,1.0,AHNAF.TITLE.ID = AHNAF.MOVIE_KEYWORD.MOVIE_ID,3,"[0, 0, 0]"
9,9,TBSCAN,AHNAF.COMPANY_NAME,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',,,"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",,,0.024702,1.0,1.0,,1,"[1, 0, 0]"
12,18,TBSCAN,AHNAF.COMPANY_TYPE,AHNAF.COMPANY_TYPE.KIND = 'production companies',,,"(AHNAF.COMPANY_TYPE.KIND, =, 'production companies')",,,0.25,1.0,1.0,,1,"[1, 0, 0]"


In [111]:
# Function to collect predicate and parsed predicate pairs
def gather_predicate_info(row):
    gathered_data = []
    for i in range(1, 4):  # Loop through PREDICATE1, PREDICATE2, PREDICATE3
        predicate_col = f"PREDICATE{i}"
        parsed_col = f"PARSED_PRED{i}"
        if pd.notna(row[predicate_col]):  # Only include non-None values
            gathered_data.append({'PREDICATE': row[predicate_col], 'PARSED_PRED': row[parsed_col]})
    return gathered_data

# Apply the function to collect all predicates and parsed predicates
all_predicates = []
df_node_feat.apply(lambda row: all_predicates.extend(gather_predicate_info(row)), axis=1)

# Create a new DataFrame from the collected data
df_predicate_pairs = pd.DataFrame(all_predicates)

# Display the new DataFrame
#print(df_predicate_pairs)

In [112]:
df_predicate_pairs

Unnamed: 0,PREDICATE,PARSED_PRED
0,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')","(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])"
1,AHNAF.LINK_TYPE.LINK LIKE '%follows%',"(AHNAF.LINK_TYPE.LINK, LIKE, '%follows%')"
2,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')"
3,AHNAF.COMPANY_TYPE.KIND = 'production companies',"(AHNAF.COMPANY_TYPE.KIND, =, 'production companies')"
4,AHNAF.MOVIE_COMPANIES.NOTE IS NULL,"(AHNAF.MOVIE_COMPANIES.NOTE, IS, NULL)"
5,AHNAF.TITLE.PRODUCTION_YEAR <= 2020,"(AHNAF.TITLE.PRODUCTION_YEAR, <=, 2020)"
6,AHNAF.TITLE.PRODUCTION_YEAR >= 1950,"(AHNAF.TITLE.PRODUCTION_YEAR, >=, 1950)"
7,AHNAF.KEYWORD.KEYWORD = 'superhero',"(AHNAF.KEYWORD.KEYWORD, =, 'superhero')"


In [113]:
# Function to transform PARSED_PRED into a sentence
def construct_sentence(parsed_pred):
    if not isinstance(parsed_pred, tuple) or len(parsed_pred) != 3:
        return None  # Skip invalid entries

    col, op, val = parsed_pred
    # Convert list values in 'val' to a string representation
    if isinstance(val, list):
        val = ", ".join(map(str, val))

    # Construct the sentence
    return f"'col: {col} [SEP] op: {op} [SEP] val: {val}'"

# Apply the transformation to the DataFrame
df_predicate_pairs['SENTENCE'] = df_predicate_pairs['PARSED_PRED'].apply(construct_sentence)

In [114]:
df_predicate_pairs.sample(3)

Unnamed: 0,PREDICATE,PARSED_PRED,SENTENCE
2,AHNAF.COMPANY_NAME.COUNTRY_CODE = '[jp]',"(AHNAF.COMPANY_NAME.COUNTRY_CODE, =, '[jp]')",'col: AHNAF.COMPANY_NAME.COUNTRY_CODE [SEP] op: = [SEP] val: '[jp]''
1,AHNAF.LINK_TYPE.LINK LIKE '%follows%',"(AHNAF.LINK_TYPE.LINK, LIKE, '%follows%')",'col: AHNAF.LINK_TYPE.LINK [SEP] op: LIKE [SEP] val: '%follows%''
7,AHNAF.KEYWORD.KEYWORD = 'superhero',"(AHNAF.KEYWORD.KEYWORD, =, 'superhero')",'col: AHNAF.KEYWORD.KEYWORD [SEP] op: = [SEP] val: 'superhero''


In [115]:
from sentence_transformers import SentenceTransformer

# Load the pre-trained embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings for each sentence
df_predicate_pairs['EMBEDDING'] = df_predicate_pairs['SENTENCE'].apply(lambda sentence: model.encode(sentence))

2025-01-07 18:29:37.103485: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [116]:
df_predicate_pairs.head(1)

Unnamed: 0,PREDICATE,PARSED_PRED,SENTENCE,EMBEDDING
0,"AHNAF.MOVIE_INFO.INFO IN ('Sweden ', 'Norway ', 'Germany ', 'Denmark ', 'Swedish ', 'Denish ', 'Norwegian ', 'German ', 'USA ', 'CANADA ', 'Netherlands ', 'Brazil ', 'UK ', 'Belgium ', 'Finland ', 'Hungary ', 'Estonia ', 'worldwide ', 'Australia ', 'Spain ', 'France ', 'Japan ', 'Columbia ', 'Slovenia ', 'Israel ', 'Venezuela ', 'Nigeria ', 'Philippines ', 'New Zealand ', 'Ireland ', 'Romania ', 'Non-USA ', 'Bulgaria ', 'Argentina ', 'Malaysia ', 'Singapore ', 'Turkey ', 'Sri Lanka ', 'Italy ', 'Indonesia ', 'South Korea ', 'Vietnam ', 'Slovakia ', 'Czech Republic ', 'China ', 'Portugal ', 'Greece ', 'Republic of Macedonia', 'Serbia ', 'Jamaica ', 'Switzerland ', 'Yugoslavia ', 'Mexico ', 'Austria ', 'Russia ')","(AHNAF.MOVIE_INFO.INFO, IN, [Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia])","'col: AHNAF.MOVIE_INFO.INFO [SEP] op: IN [SEP] val: Sweden, Norway, Germany, Denmark, Swedish, Denish, Norwegian, German, USA, CANADA, Netherlands, Brazil, UK, Belgium, Finland, Hungary, Estonia, worldwide, Australia, Spain, France, Japan, Columbia, Slovenia, Israel, Venezuela, Nigeria, Philippines, New Zealand, Ireland, Romania, Non-USA, Bulgaria, Argentina, Malaysia, Singapore, Turkey, Sri Lanka, Italy, Indonesia, South Korea, Vietnam, Slovakia, Czech Republic, China, Portugal, Greece, Republic of Macedonia, Serbia, Jamaica, Switzerland, Yugoslavia, Mexico, Austria, Russia'","[0.0013051259, 0.01390308, -0.15022178, -0.04022137, 0.014927813, 0.019273622, -0.004296431, -0.029138565, 0.024275826, -0.027220987, 0.052932236, -0.035772923, 0.017127478, 0.06413056, -0.10257447, -0.054812163, -0.01584304, 0.02342822, -0.00075711793, -0.038751286, -0.034652725, 0.018303558, 0.081806526, 0.036038738, 0.054716084, -0.030117575, 0.0632168, 0.019856347, -0.04327227, -0.077914156, 0.0034970422, 0.03455585, -0.039521392, -0.002988534, 0.029415105, 0.008228485, -0.07196365, -0.05202048, -0.14010213, 0.019418249, 0.0051503642, -0.041509032, 0.04295384, 0.026789637, 0.08137868, -0.023035113, -0.033882502, -0.06630058, 0.05297076, 0.017825892, -0.08627354, 0.010711092, -0.03906606, -0.011549604, 0.0041380622, -0.03863845, 0.0749612, -0.04092702, 0.009877817, -0.022378935, -0.02488432, -0.0075904382, 0.035363127, 0.0038767548, 0.004661521, 0.0042444225, -0.0013261894, -0.028766463, -0.12508853, -0.04720793, -0.027496427, -0.023173368, -0.024765609, -0.03452944, 0.013580754, -0.021035306, 0.018876484, -0.003618749, 0.07107111, -0.0042440416, 0.098549195, -0.09517976, -0.065362476, -0.028853985, 0.041683476, -0.019839354, 0.04743606, -0.04983741, 0.04683033, 0.035424177, -0.09295083, 0.012213284, 0.041074853, 0.04425942, -0.015838023, -0.014565695, 0.051665775, 0.13402733, -0.03248349, 0.03985524, ...]"


In [117]:
# Create the dictionary
predicate_embedding_dict = dict(zip(df_predicate_pairs['PREDICATE'], df_predicate_pairs['EMBEDDING']))

# Display the dictionary
for predicate, embedding in predicate_embedding_dict.items():
    print(f"Predicate: {predicate}")
    print(f"Embedding: {embedding}")
    print()


Predicate: AHNAF.MOVIE_INFO.INFO IN ('Sweden               ', 'Norway               ', 'Germany              ', 'Denmark              ', 'Swedish              ', 'Denish               ', 'Norwegian            ', 'German               ', 'USA                  ', 'CANADA               ', 'Netherlands          ', 'Brazil               ', 'UK                   ', 'Belgium              ', 'Finland              ', 'Hungary              ', 'Estonia              ', 'worldwide            ', 'Australia            ', 'Spain                ', 'France               ', 'Japan                ', 'Columbia             ', 'Slovenia             ', 'Israel               ', 'Venezuela            ', 'Nigeria              ', 'Philippines          ', 'New Zealand          ', 'Ireland              ', 'Romania              ', 'Non-USA              ', 'Bulgaria             ', 'Argentina            ', 'Malaysia             ', 'Singapore            ', 'Turkey               ', 'Sri Lanka            ', 'Italy       

In [118]:
# Get the embedding dimension from the pretrained model
sample_embedding = model.encode(["test"])[0]  # Generate an example embedding
embedding_dim = sample_embedding.shape[0]     # Determine the dimension

# Define a zero embedding vector
zero_embedding = np.zeros(embedding_dim)

# Updated function to fetch embeddings
def get_embedding(predicate, embedding_dict):
    if predicate is not None and predicate in embedding_dict:
        return embedding_dict[predicate]
    return zero_embedding  # Return a zero vector if predicate is None or not in the dictionary

# Add EMBEDDING1, EMBEDDING2, and EMBEDDING3 columns to df_node_feat
df_node_feat['EMBEDDING1'] = df_node_feat['PREDICATE1'].apply(lambda pred: get_embedding(pred, predicate_embedding_dict))
df_node_feat['EMBEDDING2'] = df_node_feat['PREDICATE2'].apply(lambda pred: get_embedding(pred, predicate_embedding_dict))
df_node_feat['EMBEDDING3'] = df_node_feat['PREDICATE3'].apply(lambda pred: get_embedding(pred, predicate_embedding_dict))

# Display the updated DataFrame
print(df_node_feat[['PREDICATE1', 'EMBEDDING1', 'PREDICATE2', 'EMBEDDING2', 'PREDICATE3', 'EMBEDDING3']])

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

In [119]:
df_node_feat.sample(1)

Unnamed: 0,OPERATOR_ID,OPERATOR_TYPE,TABLE,PREDICATE1,PREDICATE2,PREDICATE3,PARSED_PRED1,PARSED_PRED2,PARSED_PRED3,FILTER_FACTOR1,FILTER_FACTOR2,FILTER_FACTOR3,JOIN_KEY,HEIGHT,PREDICATE_MASK,EMBEDDING1,EMBEDDING2,EMBEDDING3
14,13,HSJOIN,,,,,,,,1.0,1.0,1.0,AHNAF.TITLE.ID = AHNAF.MOVIE_KEYWORD.MOVIE_ID,3,"[0, 0, 0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]"


In [120]:
df_node_feat.columns

Index(['OPERATOR_ID', 'OPERATOR_TYPE', 'TABLE', 'PREDICATE1', 'PREDICATE2',
       'PREDICATE3', 'PARSED_PRED1', 'PARSED_PRED2', 'PARSED_PRED3',
       'FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3', 'JOIN_KEY',
       'HEIGHT', 'PREDICATE_MASK', 'EMBEDDING1', 'EMBEDDING2', 'EMBEDDING3'],
      dtype='object')

In [121]:
# Helper function for label encoding with NaN values set to 0
def label_encode_with_nan(value, encoder, counter_start=1):
    if pd.isna(value):
        return 0, counter_start  # Return 0 for NaN and do not increment the counter
    if value not in encoder:
        encoder[value] = counter_start
        counter_start += 1
    return encoder[value], counter_start

### Load or Initialize Encoding Dictionaries

In [122]:
import pickle
import os

# Paths for dictionary storage
dictionary_dir = "./dictionaries"
os.makedirs(dictionary_dir, exist_ok=True)
op_type_file = os.path.join(dictionary_dir, "op_type_encoder.pkl")
table_file = os.path.join(dictionary_dir, "table_encoder.pkl")
join_key_file = os.path.join(dictionary_dir, "join_key_encoder.pkl")

# Load or initialize dictionaries
def load_or_initialize(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'rb') as f:
            return pickle.load(f)
    return {}

op_type_encoder = load_or_initialize(op_type_file)
table_encoder = load_or_initialize(table_file)
join_key_encoder = load_or_initialize(join_key_file)

### Perform Encoding

In [123]:
# Counters for label encoding
op_type_counter, table_counter, join_key_counter = len(op_type_encoder) + 1, len(table_encoder) + 1, len(join_key_encoder) + 1
encoded_operator_type = []
encoded_table = []
encoded_join_key = []

for op, tbl, jk in zip(df_node_feat['OPERATOR_TYPE'], df_node_feat['TABLE'], df_node_feat['JOIN_KEY']):
    encoded_op, op_type_counter = label_encode_with_nan(op, op_type_encoder, op_type_counter)
    encoded_operator_type.append(encoded_op)
    
    encoded_tbl, table_counter = label_encode_with_nan(tbl, table_encoder, table_counter)
    encoded_table.append(encoded_tbl)
    
    encoded_jk, join_key_counter = label_encode_with_nan(jk, join_key_encoder, join_key_counter)
    encoded_join_key.append(encoded_jk)

df_node_feat['ENCODED_OPERATOR_TYPE'] = encoded_operator_type
df_node_feat['ENCODED_TABLE'] = encoded_table
df_node_feat['ENCODED_JOIN_KEY'] = encoded_join_key

### Save Dictionaries


In [124]:
# Save updated dictionaries back to files
with open(op_type_file, 'wb') as f:
    pickle.dump(op_type_encoder, f)
with open(table_file, 'wb') as f:
    pickle.dump(table_encoder, f)
with open(join_key_file, 'wb') as f:
    pickle.dump(join_key_encoder, f)

### Display Encoding Results and Dictionaries

In [125]:
# Inspect encoded columns and dictionaries
print(df_node_feat[['OPERATOR_TYPE', 'TABLE', 'JOIN_KEY', 'ENCODED_OPERATOR_TYPE', 'ENCODED_TABLE', 'ENCODED_JOIN_KEY']])
print("\nOPERATOR_TYPE Encoding Dictionary:", op_type_encoder)
print("TABLE Encoding Dictionary:", table_encoder)
print("JOIN_KEY Encoding Dictionary:", join_key_encoder)

   OPERATOR_TYPE                  TABLE  \
0         RETURN                    NaN   
1         TQ                        NaN   
2         HSJOIN                    NaN   
3         TBSCAN       AHNAF.MOVIE_INFO   
4         HSJOIN                    NaN   
5         HSJOIN                    NaN   
6         TBSCAN        AHNAF.LINK_TYPE   
7         TBSCAN       AHNAF.MOVIE_LINK   
8         HSJOIN                    NaN   
9         TBSCAN     AHNAF.COMPANY_NAME   
10        HSJOIN                    NaN   
11        HSJOIN                    NaN   
12        TBSCAN     AHNAF.COMPANY_TYPE   
13        TBSCAN  AHNAF.MOVIE_COMPANIES   
14        HSJOIN                    NaN   
15        TBSCAN            AHNAF.TITLE   
16        HSJOIN                    NaN   
17        TBSCAN    AHNAF.MOVIE_KEYWORD   
18        TBSCAN          AHNAF.KEYWORD   

                                                         JOIN_KEY  \
0                                                             NaN   
1

In [126]:
df_node_feat[['ENCODED_OPERATOR_TYPE', 'ENCODED_TABLE', 'ENCODED_JOIN_KEY', 'PREDICATE_MASK']]

Unnamed: 0,ENCODED_OPERATOR_TYPE,ENCODED_TABLE,ENCODED_JOIN_KEY,PREDICATE_MASK
0,1,0,0,"[0, 0, 0]"
1,2,0,0,"[0, 0, 0]"
2,3,0,11,"[0, 0, 0]"
3,4,2,0,"[1, 0, 0]"
4,3,0,6,"[0, 0, 0]"
5,3,0,5,"[0, 0, 0]"
6,4,6,0,"[1, 0, 0]"
7,4,7,0,"[0, 0, 0]"
8,3,0,3,"[0, 0, 0]"
9,4,3,0,"[1, 0, 0]"


# Making a tensor of node features

In [127]:
df_node_feat.columns

Index(['OPERATOR_ID', 'OPERATOR_TYPE', 'TABLE', 'PREDICATE1', 'PREDICATE2',
       'PREDICATE3', 'PARSED_PRED1', 'PARSED_PRED2', 'PARSED_PRED3',
       'FILTER_FACTOR1', 'FILTER_FACTOR2', 'FILTER_FACTOR3', 'JOIN_KEY',
       'HEIGHT', 'PREDICATE_MASK', 'EMBEDDING1', 'EMBEDDING2', 'EMBEDDING3',
       'ENCODED_OPERATOR_TYPE', 'ENCODED_TABLE', 'ENCODED_JOIN_KEY'],
      dtype='object')

In [128]:
len(df_node_feat)

19

In [129]:
tensor_data = []
for _, row in df_node_feat.iterrows():
    # Flatten all values into a single list
    flat_row = np.concatenate([
        [row['ENCODED_OPERATOR_TYPE']],
        [row['ENCODED_JOIN_KEY']],
        [row['ENCODED_TABLE']],
        row['PREDICATE_MASK'],
        row['EMBEDDING1'],  # Directly use the embedding column
        row['EMBEDDING2'],  # Directly use the embedding column
        row['EMBEDDING3'],  # Directly use the embedding column
        [row['FILTER_FACTOR1']],
        [row['FILTER_FACTOR2']],
        [row['FILTER_FACTOR3']]
    ])
    tensor_data.append(flat_row)

# Convert the list of concatenated rows into a tensor
x = torch.tensor(tensor_data, dtype=torch.float32)
x = x.unsqueeze(0)


In [130]:
x.shape

torch.Size([1, 19, 1161])

In [131]:
def pad_2d_unsqueeze(x, padlen):
    # dont know why add 1, comment out first
#    x = x + 1 # pad id = 0
    _, xlen, xdim = x.size()
    if xlen < padlen:
        new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) + 1
        new_x[:xlen, :] = x
        x = new_x
    return x.unsqueeze(0)

def pad_1d_unsqueeze(x, padlen):
    x = x + 1 # pad id = 0
    xlen = x.size(0)
    if xlen < padlen:
        new_x = x.new_zeros([padlen], dtype=x.dtype)
        new_x[:xlen] = x
        x = new_x
    return x.unsqueeze(0)

In [132]:
x = pad_2d_unsqueeze(x, max_node)

In [133]:
x.shape

torch.Size([1, 30, 1161])

In [134]:
heights.shape

torch.Size([1, 30])

In [135]:
attn_bias.shape

torch.Size([1, 31, 31])

In [136]:
attn_bias

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf,
          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf,
          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., -inf, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf,
          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf,
          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., -inf, -inf, -inf, 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf,
       

## Define Model

In [137]:
class Args:
    # bs = 1024
    # SQ: smaller batch size
    bs = 1
    #lr = 0.001
    lr = 0.001
    # epochs = 200
    epochs = 50
    clip_size = 50
    embed_size = 64
    pred_hid = 128
    ffn_dim = 128
    head_size = 12
    n_layers = 8
    dropout = 0.1
    sch_decay = 0.6
    # device = 'cuda:0'
    device = 'cpu'
    newpath = 'job_queries_training'
    to_predict = 'cost'
args = Args()

import os
if not os.path.exists(args.newpath):
    os.makedirs(args.newpath)

In [138]:
from model.model import QueryFormer

model = QueryFormer(emb_size = args.embed_size ,ffn_dim = args.ffn_dim, head_size = args.head_size, \
                 dropout = args.dropout, n_layers = args.n_layers, \
                 use_sample = False, use_hist = False, \
                 pred_hid = args.pred_hid
                )

In [139]:
from model.dataset import PlanTreeDataset

In [140]:
cost_labels

tensor([0.9500], dtype=torch.float64)

In [141]:
raw_costs

[58513.0]

In [142]:
type(x)

torch.Tensor

In [143]:
type(rel_pos)

torch.Tensor

In [144]:
type(attn_bias)

torch.Tensor

In [145]:
type(heights)

torch.Tensor

In [146]:
type(cost_labels)

torch.Tensor

In [147]:
type(raw_costs)

list

In [148]:
x.shape

torch.Size([1, 30, 1161])

In [149]:
attn_bias.shape

torch.Size([1, 31, 31])

In [150]:
type(x)

torch.Tensor

In [151]:
rel_pos.shape

torch.Size([1, 30, 30])

In [152]:
heights.shape

torch.Size([1, 30])

In [153]:
cost_labels.shape

torch.Size([1])

In [154]:
type(raw_costs)

list

In [155]:
raw_costs = torch.from_numpy(np.array(raw_costs))

In [156]:
type(raw_costs)

torch.Tensor

In [157]:
raw_costs

tensor([58513.], dtype=torch.float64)

# Storing a collection of tensors for the given plan

In [158]:
# Directory where the file will be saved
save_dir = "./job_queries/tensors"
# Ensure the directory exists
os.makedirs(save_dir, exist_ok=True)

# Dynamically construct the filename
filename = os.path.join(save_dir, f"query_{QUERYID}_{query1_ts}.pt")

In [159]:
print(f'x.shape: {x.shape}')
print(f'rel_pos.shape: {rel_pos.shape}')
print(f'attn_bias.shape: {attn_bias.shape}')
print(f'heights.shape: {heights.shape}')
print(f'cost_labels.shape: {cost_labels.shape}')
print(f'raw_costs.shape: {raw_costs.shape}')

x.shape: torch.Size([1, 30, 1161])
rel_pos.shape: torch.Size([1, 30, 30])
attn_bias.shape: torch.Size([1, 31, 31])
heights.shape: torch.Size([1, 30])
cost_labels.shape: torch.Size([1])
raw_costs.shape: torch.Size([1])


In [160]:
# Save tensors as a dictionary
tensor_collection = {
    "x": x,
    "rel_pos": rel_pos,
    "attn_bias": attn_bias,
    "heights": heights,
    "cost_labels": cost_labels,
    "raw_costs": raw_costs,
}

torch.save(tensor_collection, filename)
print(f"Tensors saved successfully as {filename}!")

Tensors saved successfully as ./job_queries/tensors/query_10_2024-12-03-20.44.01.238859.pt!


In [161]:
# Load the saved tensor collection
loaded_tensors = torch.load("tensor_collection.pt")

# Access each tensor
x_loaded = loaded_tensors["x"]
rel_pos_loaded = loaded_tensors["rel_pos"]
attn_bias_loaded = loaded_tensors["attn_bias"]
heights_loaded = loaded_tensors["heights"]
cost_labels_loaded = loaded_tensors["cost_labels"]
raw_costs_loaded = loaded_tensors["raw_costs"]

# Print the shapes to verify
print(f"x shape: {x_loaded.shape}")
print(f"rel_pos shape: {rel_pos_loaded.shape}")
print(f"attn_bias shape: {attn_bias_loaded.shape}")
print(f"heights shape: {heights_loaded.shape}")
print(f"cost_labels shape: {cost_labels_loaded.shape}")
print(f"raw_costs shape: {raw_costs_loaded.shape}")

x shape: torch.Size([1, 30, 1161])
rel_pos shape: torch.Size([1, 30, 30])
attn_bias shape: torch.Size([1, 31, 31])
heights shape: torch.Size([1, 30])
cost_labels shape: torch.Size([1])
raw_costs shape: torch.Size([1])


In [162]:
QUERYID

10