## Graph construction for table entries
In this example, we construct graphs from the entries (rows) of an example tabular dataset. The dataset is the Wine Poland dataset, which contains information about wines on the Polish market.

In [1]:
# Set the current working directory and import packages
import os
from pathlib import Path
os.chdir(Path().cwd().parent)

import json
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from src.carte_table_to_graph import Table2GraphTransformer
from configs.directory import config_directory

In [2]:
# Define necessary functions

# Load data
def _load_data(data_name):
    data_pd_dir = f"{config_directory['data_singletable']}/{data_name}/raw.parquet"
    data_pd = pd.read_parquet(data_pd_dir)
    data_pd.fillna(value=np.nan, inplace=True)
    config_data_dir = f"{config_directory['data_singletable']}/{data_name}/config_data.json"
    with open(config_data_dir) as config_file:
        config_data = json.load(config_file)
    return data_pd, config_data

# Set train/test split given the random state
def _set_split(data, data_config, num_train, random_state):
    target_name = data_config["target_name"]
    X = data.drop(columns=target_name)
    y = data[target_name]
    y = np.array(y)

    if data_config["repeated"]:
        entity_name = data_config["entity_name"]
    else:
        entity_name = np.arange(len(y))

    groups = np.array(data.groupby(entity_name).ngroup())
    num_groups = len(np.unique(groups))
    gss = GroupShuffleSplit(
        n_splits=1,
        test_size=int(num_groups - num_train),
        random_state=random_state,
    )
    idx_train, idx_test = next(iter(gss.split(X=y, groups=groups)))

    X_train, X_test = X.iloc[idx_train], X.iloc[idx_test]
    y_train, y_test = y[idx_train], y[idx_test]

    return X_train, X_test, y_train, y_test
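
The split above is group-aware: rows that share an entity never end up on both sides, and `test_size` is set to `num_groups - num_train` so that exactly `num_train` groups remain for training. When the data has no repeated entities, every row is its own group, so this reduces to picking `num_train` rows. The sketch below illustrates the idea on hypothetical toy data. It is a pure-Python stand-in, not scikit-learn's `GroupShuffleSplit`, and the helper `group_split` and the toy values are assumptions made only for illustration.

```python
import random


def group_split(groups, num_train, seed):
    """Split row indices so that no group appears on both sides.

    Pure-Python stand-in for the idea behind a grouped shuffle split;
    it does not reproduce scikit-learn's exact shuffling.
    """
    unique_groups = sorted(set(groups))
    rng = random.Random(seed)
    rng.shuffle(unique_groups)
    # Keep exactly num_train groups for training; the rest go to the test set.
    train_groups = set(unique_groups[:num_train])
    idx_train = [i for i, g in enumerate(groups) if g in train_groups]
    idx_test = [i for i, g in enumerate(groups) if g not in train_groups]
    return idx_train, idx_test


# Hypothetical toy data: six rows, four entities; "a" and "c" are repeated.
toy_groups = ["a", "a", "b", "c", "c", "d"]
idx_train, idx_test = group_split(toy_groups, num_train=2, seed=1)

train_entities = {toy_groups[i] for i in idx_train}
test_entities = {toy_groups[i] for i in idx_test}
print("train entities:", train_entities)
print("test entities:", test_entities)
```

Because whole groups are assigned to one side, the number of training rows can exceed `num_train` when entities repeat; only the number of training groups is fixed.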

To prepare the dataset, we
- set basic specifications
- load data and set train/test split
- generate a graph for each table entry (row) using the Table2GraphTransformer

In [5]:
# Run the table2graph to get the necessary inputs for CARTE

# Set basic specifications
data_name = "wina_pl"      # Name of the data
num_train = 128     # Train-size
random_state = 1    # Random_state

# Load data and set train/test split
data, data_config = _load_data(data_name)
X_train_, X_test_, y_train, y_test = _set_split(
    data,
    data_config,
    num_train,
    random_state=random_state,
)

# Generate one graph per row; fit on the training rows only
preprocessor = Table2GraphTransformer()
X_train = preprocessor.fit_transform(X_train_, y=y_train)
X_test = preprocessor.transform(X_test_)

In [17]:
# Original data
print("Original Data:\n", X_train_.iloc[0])

# Graph data
print("\nGraph Data:\n", X_train[0])

Original Data:
 name                   Achillée Crémant Soléra AOC Crémant d'Alsace NV
country                                                         France
region                                                          Alsace
appellation                                       Cremant d'Alsace AOC
vineyard                                                      Achillée
vintage                                                            NaN
volume                                                           750.0
ABV                                                               13.5
serving_temperature                                                  9
wine_type                                                          NaN
taste                                                              dry
style                                                          average
vegan                                                            False
natural                                                      

The result is a list of graph objects which can be used as inputs for the neural network in CARTE.

Each row is transformed into a graph data object with node features (x), an edge index (the graph structure), edge features, and the target y (not visible in the test set).

This data point has 13 non-missing columns (out of 15). The resulting graph therefore contains 14 node features (13 columns and the center node) and 26 edge features (13 columns and 13 self-loops), as the graph is directed.
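
The counting above can be reproduced with a few lines of arithmetic. The sketch below is not part of CARTE: the helper `count_graph_parts` and the toy row are hypothetical, and the code simply follows the counting described in the text (one node per non-missing column plus a center node, and two edge features per non-missing column).

```python
import math


def count_graph_parts(row):
    """Count non-missing columns, nodes, and edge features for one row."""

    def is_missing(value):
        return value is None or (isinstance(value, float) and math.isnan(value))

    n_present = sum(not is_missing(v) for v in row.values())
    n_nodes = n_present + 1          # non-missing columns plus the center node
    n_edge_features = 2 * n_present  # two edge features per non-missing column
    return n_present, n_nodes, n_edge_features


# Hypothetical row with 15 columns, 2 of which are missing.
toy_row = {f"col_{i}": i for i in range(13)}
toy_row["vintage"] = float("nan")
toy_row["wine_type"] = None

print(count_graph_parts(toy_row))  # (13, 14, 26)
```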
