# STaSy: Score-based Tabular Data Synthesis

In [1]:
#@title Install Git repository% 
%cd baselines
!git clone https://github.com/JayoungKim408/STaSy
%cd STaSy

/home/bigdyl/economics/AIFinLab/baselines
fatal: destination path 'STaSy' already exists and is not an empty directory.
/home/bigdyl/economics/AIFinLab/baselines/STaSy


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
#@title Install required libraries

!pip install --upgrade pip
!pip install --upgrade setuptools
!pip install ml_collections


Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
[33mDEPRECATION: Python 2.7 reached the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 is no longer maintained. pip 21.0 will drop support for Python 2.7 in January 2021. More details about Python 2 support in pip can be found at https://pip.pypa.io/en/latest/development/release-process/#python-2-support pip 21.0 will remove support for this functionality.[0m
Defaulting to user installation because normal site-packages is not writeable
Requirement already up-to-date: pip in /home/bigdyl/.local/lib/python2.7/site-packages (20.3.4)
Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
[33mDEPRECATION: Python 2.7 reached the end of its l

In [3]:
#@title Import packages for score-based generative model

import numpy as np
# import tensorflow as tf
import pandas as pd
from models import ncsnpp_tabular
import losses
import likelihood
import sampling as sampling_
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
from torch.utils.data import DataLoader
import evaluation
import sde_lib
from absl import flags
import torch
from utils import save_checkpoint, restore_checkpoint, apply_activate
import collections
from torch.utils import tensorboard
import os
from ml_collections import config_flags, config_dict

2024-06-02 09:11:27.934670: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [28]:
#@title Load configuration
config = config_dict.ConfigDict()
config.workdir = "stasy"
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

config.training = training = config_dict.ConfigDict()
training.batch_size = 900 # 1000
training.epoch = 1 # 20
training.likelihood_weighting = False
training.continuous = True
training.reduce_mean = False
training.eps = 1e-05
training.loss_weighting = False
training.spl = True
training.lambda_ = 0.5
training.sde = 'vesde'
training.n_iters = 100000
training.tolerance = 1e-01  # 1e-3
training.hutchinson_type = "Rademacher"
training.retrain_type = "median"
training.eps_iters = 1
training.fine_tune_epochs = 1

config.sampling = sampling = config_dict.ConfigDict()
sampling.n_steps_each = 1
sampling.noise_removal = False
sampling.probability_flow = True
sampling.snr = 0.16
sampling.method = 'ode'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'none'

config.data = data = config_dict.ConfigDict()
data.centered = False
data.uniform_dequantization = False
data.dataset = "iris" # shoppers
data.image_size = 7  # 77

config.model = model = config_dict.ConfigDict()
model.nf = 64
model.hidden_dims = (256, 512, 1024, 1024, 512, 256)
model.conditional = True
model.embedding_type = 'fourier'
model.fourier_scale = 16
model.layer_type = 'concatsquash'
model.name = 'ncsnpp_tabular'
model.scale_by_sigma = False
model.ema_rate = 0.9999
model.activation = 'elu'
model.sigma_min = 0.01
model.sigma_max = 10.
model.num_scales = 50
model.alpha0 = 0.3
model.beta0 = 0.95

config.optim = optim = config_dict.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 1e-2 # 2e-3
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 1000 # 5000
optim.grad_clip = 1.



In [29]:
#@title Build a score mode network and dataset

score_model = mutils.create_model(config)
num_params = sum(p.numel() for p in score_model.parameters())

# Build data iterators
train_ds, eval_ds, (transformer, meta) = datasets.get_dataset(config,
                                            uniform_dequantization=config.data.uniform_dequantization)

train_iter = DataLoader(train_ds, batch_size=config.training.batch_size)
# eval_iter = iter(DataLoader(eval_ds, batch_size=config.eval.batch_size))  # pytype: disable=wrong-arg-types


In [30]:
#@title Setup SDEs

# Setup SDEs
if config.training.sde.lower() == 'vpsde':
  sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
  sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
  sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  sampling_eps = 1e-5
else:
  raise NotImplementedError(f"SDE {config.training.sde} unknown.")


In [31]:
#@title Build utilities for training
tb_dir = os.path.join(config.workdir, "tensorboard")
os.makedirs(tb_dir, exist_ok=True)
writer = tensorboard.SummaryWriter(tb_dir)

ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
optimizer = losses.get_optimizer(config, score_model.parameters()) # Adam optimizer, lr 2e-3
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, epoch=0)

checkpoint_dir = os.path.join(config.workdir, "checkpoints")
checkpoint_meta_dir = os.path.join(config.workdir, "checkpoints-meta", "checkpoint.pth")
checkpoint_finetune_dir = os.path.join(config.workdir, "checkpoints_finetune")

os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
os.makedirs(checkpoint_finetune_dir, exist_ok=True)

scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)

optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting

def loss_fn(model, batch):
  score_fn = mutils.get_score_fn(sde, model, train=True, continuous=continuous)
  t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - 1e-5) + 1e-5
  z = torch.randn_like(batch)
  mean, std = sde.marginal_prob(batch, t)
  perturbed_data = mean + std[:, None] * z

  score = score_fn(perturbed_data, t)

  loss_values = torch.square(score * std[:, None] + z)
  loss_values = torch.mean(loss_values.reshape(loss_values.shape[0], -1), dim=-1)

  return loss_values

# Building sampling functions
sampling_shape = (config.training.batch_size, config.data.image_size)
sampling_fn = sampling_.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

scores_max = 0

likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)
sampling_shape = (train_ds.shape[0], config.data.image_size)


In [32]:
#@title Build utilities for v scheduling

def min_max_scaling(factor, scale=(0, 1)):

  std = (factor - factor.min()) / (factor.max() - factor.min())
  new_min = torch.tensor(scale[0])
  new_max = torch.tensor(scale[1])
  return std * (new_max - new_min) + new_min


def compute_v(ll, alpha, beta):

  v = -torch.ones(ll.shape).to(ll.device)
  v[torch.gt(ll, beta)] = torch.tensor(0., device=v.device)
  v[torch.le(ll, alpha)] = torch.tensor(1., device=v.device)

  if ll[torch.eq(v, -1)].shape[0] !=0 and ll[torch.eq(v, -1)].shape[0] !=1 :
        v[torch.eq(v, -1)] = min_max_scaling(ll[torch.eq(v, -1)], scale=(1, 0)).to(v.device)
  else:
        v[torch.eq(v, -1)] = torch.tensor(0.5, device=v.device)

  return v

In [None]:
#@title Start model training
alpha0 = config.model.alpha0
beta0 = config.model.beta0

for epoch in range(config.training.epoch+1):
  state['epoch'] += 1
  for iteration, batch in enumerate(train_iter):
    print(len(train_iter), iteration)
    batch = batch.to(config.device).float()
    # loss = train_step_fn(state, batch)

    # model = state['model']
    optimizer = state['optimizer']
    optimizer.zero_grad()
    loss_values = loss_fn(score_model, batch)

    q_alpha = torch.tensor(alpha0 + torch.log( torch.tensor(1+ 0.0001718*state['step']* (1-alpha0), dtype=torch.float32) )).clamp_(max=1).to(loss_values.device)
    q_beta = torch.tensor(beta0 + torch.log( torch.tensor(1+ 0.0001718*state['step']* (1-beta0), dtype=torch.float32) )).clamp_(max=1).to(loss_values.device)

    alpha = torch.quantile(loss_values, q_alpha)
    beta = torch.quantile(loss_values, q_beta)
    assert alpha <= beta
    v = compute_v(loss_values, alpha, beta)
    loss = torch.mean(v*loss_values)

    loss.backward()
    optimize_fn(optimizer, score_model.parameters(), step=state['step'])
    state['step'] += 1
    state['ema'].update(score_model.parameters())

  print("epoch: %d, iter: %d, training_loss: %.5e, q_alpha: %.3e, q_beta: %.3e" % (epoch, iteration, loss.item(), q_alpha, q_beta))
  if epoch % 10 == 0:
    save_checkpoint(checkpoint_meta_dir, state)


1 0


  q_alpha = torch.tensor(alpha0 + torch.log( torch.tensor(1+ 0.0001718*state['step']* (1-alpha0), dtype=torch.float32) )).clamp_(max=1).to(loss_values.device)
  q_beta = torch.tensor(beta0 + torch.log( torch.tensor(1+ 0.0001718*state['step']* (1-beta0), dtype=torch.float32) )).clamp_(max=1).to(loss_values.device)


In [None]:
#@title Start fine-tune the pre-trained model

hutchinson_type = config.training.hutchinson_type
tolerance = config.training.tolerance

likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler, hutchinson_type, tolerance, tolerance)

train_ds_tensor = torch.tensor(train_ds, device=config.device, dtype=torch.float32)
train_ll = likelihood_fn(score_model, train_ds_tensor, eps_iters = config.training.eps_iters)[0]

if config.training.retrain_type == 'median':
  idx = torch.where(train_ll <= torch.median(train_ll), True, False)
elif config.training.retrain_type == 'mean':
  idx = torch.where(train_ll <= torch.mean(train_ll), True, False)

re_train = train_ds_tensor[idx]
train_iter = DataLoader(re_train, batch_size=config.training.batch_size)
step = 0

# model = state['model']
samples, n = sampling_fn(score_model, sampling_shape=sampling_shape)
samples = apply_activate(samples, transformer.output_info)
samples = transformer.inverse_transform(samples.cpu().numpy())
scores_max = 0

for epoch in range(config.training.fine_tune_epochs):
  for iteration, batch in enumerate(train_iter):
    batch = batch.to(config.device).float()

    optimizer = state['optimizer']
    optimizer.zero_grad()
    loss_values = loss_fn(score_model, batch)
    loss = torch.mean(loss_values)

    state['step'] += 1

    loss.backward()
    optimize_fn(optimizer, score_model.parameters(), step=state['step'])
    state['step'] += 1
    state['ema'].update(score_model.parameters())
  train_ll_after = likelihood_fn(score_model, train_ds_tensor, eps_iters = config.training.eps_iters)[0]

  diff = train_ll_after - train_ll
  idx_after = torch.where(diff < -0.1, True, False)
  re_train = train_ds_tensor[idx_after]

  train_iter = DataLoader(re_train, batch_size=config.training.batch_size)

  save_checkpoint(os.path.join(checkpoint_finetune_dir, "checkpoint_finetune.pth"), state)
  print("epoch: %d, iter: %d, finetuning_loss: %.5e" % (epoch, iteration, loss.item()))



In [None]:
#@title Start evaluating the model after the training

samples, n = sampling_fn(score_model, sampling_shape=sampling_shape)
samples = apply_activate(samples, transformer.output_info)
samples = transformer.inverse_transform(samples.cpu().numpy())

eval_samples = transformer.inverse_transform(eval_ds)
train_samples = transformer.inverse_transform(train_ds)

scores, _ = evaluation.compute_scores(train=train_samples, test=eval_samples, synthesized_data=[samples], metadata=meta)


In [None]:
print(scores)

binary_f1      0.220284
roc_auc        0.488710
weighted_f1    0.265565
accuracy       0.420114
dtype: float64
