In [11]:
import wandb


In [12]:
# Ranking configuration: the W&B project to query and the summary metric
# used to rank its runs (lower is better).
project_name = "middelman/saint_rossmann_mse"
metric_name = "orig_valid_rmse"

# Public-API client, then every run logged under that project.
api = wandb.Api()
saint_runs = api.runs(path=project_name)

In [13]:
import math


def _metric_rank_key(run):
    """Sort key that ranks a W&B run by ``metric_name`` (lower is better).

    Runs whose summary lacks the metric, stores ``None``, or holds a NaN or
    non-numeric value are mapped to ``+inf`` so they sort last. Returning the
    raw value would let a single NaN break the total ordering ``sorted``
    relies on (NaN comparisons are always False), and ``None`` would raise
    a TypeError when compared with floats.
    """
    value = run.summary.get(metric_name)
    try:
        value = float(value)
    except (TypeError, ValueError):
        return math.inf
    return math.inf if math.isnan(value) else value


# All project runs, best (lowest metric) first.
sorted_saint_runs = sorted(saint_runs, key=_metric_rank_key)

In [14]:
# ID of the first (lowest-metric) entry of `sorted_saint_runs`.
top_project_run = sorted_saint_runs[0]
top_project_run.id

'48cdbuke'

In [29]:
# Get the best run from a sweep: the run with the lowest `metric_name`.
import math

# Built from `project_name` so the sweep is looked up in the same project as the runs above.
sweep_id = f"{project_name}/50p3fsjs"

sweep = api.sweep(sweep_id)


def _sweep_metric_or_inf(run):
    """Return the run's ``metric_name`` as a float, or ``+inf`` if it is missing, None, NaN or non-numeric.

    Mapping unusable values to ``+inf`` keeps them from being chosen as the
    best run and avoids NaN/None comparisons when taking the minimum.
    """
    value = run.summary.get(metric_name)
    try:
        value = float(value)
    except (TypeError, ValueError):
        return math.inf
    return math.inf if math.isnan(value) else value


# `min` finds the best run in one pass instead of sorting the whole sweep;
# like `sorted(...)[0]`, it returns the first run among ties.
best_run = min(sweep.runs, key=_sweep_metric_or_inf, default=None)
if best_run is None or math.isinf(_sweep_metric_or_inf(best_run)):
    raise ValueError(f"Sweep {sweep_id!r} has no run with a finite {metric_name!r}")

print(best_run.summary[metric_name])


623.9197387695312


In [30]:
# Pretty-print the hyperparameters of `best_run` (the best run of the sweep);
# several of these values are passed to the SAINT constructor further down.
import pprint

config_printer = pprint.PrettyPrinter()
config_printer.pprint(best_run.config)

{'active_log': True,
 'attention_dropout': 0.45981794326954567,
 'attention_heads': 7,
 'attentiontype': 'colrow',
 'batchsize': 258,
 'cont_embeddings': 'MLP',
 'dset_seed': 42,
 'embedding_size': 43,
 'epochs': 113,
 'ff_dropout': 0.36833501402934365,
 'final_mlp_style': 'sep',
 'lam0': 0.5,
 'lam1': 10,
 'lam2': 1,
 'lam3': 10,
 'lr': 0.00031024476896993704,
 'mask_prob': 0,
 'mixup_lam': 0.3,
 'nce_temp': 0.7,
 'optimizer': 'AdamW',
 'pretrain': False,
 'pretrain_epochs': 50,
 'pt_aug': [],
 'pt_aug_lam': 0.1,
 'pt_projhead_style': 'diff',
 'pt_tasks': ['contrastive', 'denoising'],
 'run_name': 'Rossmann',
 'savemodelroot': '/home/coenraadmiddel/Documents/RossmannStoreSales/SAINT/saint/bestmodels/regression/',
 'scheduler': 'cosine',
 'set_seed': 43,
 'ssl_avail_y': 0,
 'task': 'regression',
 'train_mask_prob': 0,
 'transformer_depth': 1,
 'vision_dset': False}


# Get the data

In [None]:
import pandas as pd
import numpy as np

# Preprocessed Rossmann training table (absolute, machine-specific path).
TRAIN_PATH = r"/home/coenraadmiddel/Documents/RossmannStoreSales/TabNet/tabnet/train_processed.parquet"
# Seed for the fallback train/valid/test split, so that split is reproducible.
SPLIT_SEED = 42

# Model input columns; "Sales" (target) and "Set" (split label) are appended after these.
feature_columns = [
    "Store",
    "DayOfWeek",
    "Promo",
    "StateHoliday",
    "SchoolHoliday",
    "StoreType",
    "Assortment",
    "CompetitionDistance",
    "Promo2SinceWeek",
    "Promo2SinceYear",
    "Year",
    "Month",
    "Day",
    "WeekOfYear",
    "CompetitionOpen",
    "PromoOpen",
    "IsPromoMonth",
]

print("Reading the data...")
train = pd.read_parquet(TRAIN_PATH)
print("Read:", train.shape)

# The split indices below are used positionally (`X_all.iloc[...]` and plain indexing
# of the NumPy array `y_all`), which is only correct when the index is 0..n-1.
# Reset it unconditionally so labels and positions coincide.
train = train.reset_index(drop=True)

# Fall back to a seeded random 80/10/10 split when the file carries no "Set" column.
# This must run before the column selection below, which requires "Set" and would
# otherwise raise KeyError before this branch could ever execute.
if "Set" not in train.columns:
    rng = np.random.default_rng(SPLIT_SEED)
    train["Set"] = rng.choice(
        ["train", "valid", "test"], p=[0.8, 0.1, 0.1], size=(train.shape[0],)
    )

# Select only a couple of columns. "Sales" and "Set" stay last, so the column positions
# in `train` match those in the feature matrix X_all.
train = train[feature_columns + ["Sales", "Set"]]

train_indices = train[train.Set == "train"].index
valid_indices = train[train.Set == "valid"].index
test_indices = train[train.Set == "test"].index


# Year, Month, Day and WeekOfYear are not listed here, so they are treated as continuous.
categorical_columns = [
    "Store",
    "DayOfWeek",
    "Promo",
    "StateHoliday",
    "SchoolHoliday",
    "StoreType",
    "Assortment",
    "IsPromoMonth",
]


# Split x and y. The target is log1p(Sales); invert predictions with np.expm1.
X_all, y_all = train.drop(columns=["Sales", "Set"]), np.log1p(train[["Sales"]].values)

# Observed-value mask for X_all: 1 where a value is present, 0 where it is NaN.
temp = X_all.fillna("MissingValue")
nan_mask = temp.ne("MissingValue").astype(int)

X_train = X_all.iloc[train_indices]
X_test = X_all.iloc[test_indices]
X_valid = X_all.iloc[valid_indices]

y_train = y_all[train_indices]
y_test = y_all[test_indices]
y_valid = y_all[valid_indices]

In [None]:
# Cast the categorical features to pandas `category` dtype (this mutates `train`).
train[categorical_columns] = train[categorical_columns].astype('category')

# Column positions are taken from the feature matrix the model actually consumes
# (X_train) rather than from `train`, whose extra "Sales"/"Set" columns only
# happen to sit at the end.
cat_idxs = [X_train.columns.get_loc(c) for c in categorical_columns if c in X_train.columns]
# Number of distinct categories per categorical feature, over the full dataset.
cat_dims = [len(train[c].cat.categories) for c in categorical_columns if c in X_train.columns]
# Every remaining feature position is continuous.
cat_idx_set = set(cat_idxs)
cont_idxs = [i for i in range(X_train.shape[1]) if i not in cat_idx_set]


In [None]:

# Load the SAINT model: rebuild the architecture from the best run's hyperparameters,
# then load the saved weights and switch to inference mode.
from models import SAINT
import torch

# Saved state_dict for the model. The path looks like a W&B artifact download directory.
# NOTE(review): the artifact version (v87) is hard-coded, not derived from `best_run`;
# confirm it was produced by the same run whose config is used below.
MODEL_PATH = '/home/coenraadmiddel/Documents/RossmannStoreSales/SAINT/saint/artifacts/SAINT_model:v87/bestmodel.pth'

model = SAINT(
    # NOTE(review): the upstream SAINT training script prepends a 1 to cat_dims
    # (for the CLS token) before building the model. If this checkpoint was trained
    # that way, load_state_dict below will report size mismatches. Confirm against
    # the training code.
    categories = cat_dims,
    num_continuous = len(cont_idxs),
    dim = best_run.config['embedding_size'],
    dim_out = 1,
    depth = best_run.config['transformer_depth'],
    heads = best_run.config['attention_heads'],
    attn_dropout = best_run.config['attention_dropout'],
    ff_dropout = best_run.config['ff_dropout'],
    mlp_hidden_mults = (4, 2),
    cont_embeddings = best_run.config['cont_embeddings'],
    attentiontype = best_run.config['attentiontype'],
    final_mlp_style = best_run.config['final_mlp_style'],
    y_dim = 1,
)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# The model above is built on CPU, so the weights end up there either way.
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
