In [1]:
%cd ../..

/home/dev/sb-proj/tabsyn-concat


In [2]:
import json
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from utils_train import preprocess, TabularDataset
from tabsyn.vae.model import Model_VAE, Encoder_model, Decoder_model

In [4]:
torch.set_printoptions(sci_mode=False, precision=4)
np.set_printoptions(suppress=True)

In [5]:
LR = 1e-3
WD = 0
D_TOKEN = 6
TOKEN_BIAS = True

N_HEAD = 2
FACTOR = 64
NUM_LAYERS = 2

In [7]:
# dataname = "mbd_short_100k"
dataname = "mbd_debug"
data_dir = f'data/{dataname}_with_id'
ckpt_dir = f"tabsyn/vae/ckpt/{dataname}"
info_path = f'data/{dataname}_with_id/info.json'

with open(info_path, 'r') as f:
    info = json.load(f)
    
device = "cuda:0"

# Get latent codes

In [8]:
X_num, X_cat, categories, d_numerical = preprocess(data_dir, task_type = info['task_type'])

No NaNs in numerical features, skipping


In [9]:
send_telegram_msg('Data was loaded.')

In [10]:
d_numerical, categories

(3, [80000, 19, 11, 13, 36])

In [11]:
d_numerical = d_numerical - 1
categories.pop(0)
d_numerical, categories

(2, [19, 11, 13, 36])

In [12]:
X_train_num, _ = X_num
X_train_cat, _ = X_cat

X_train_num, X_test_num = X_num
X_train_cat, X_test_cat = X_cat

X_train_num, X_test_num = torch.tensor(X_train_num).float(), torch.tensor(X_test_num).float()
X_train_cat, X_test_cat =  torch.tensor(X_train_cat), torch.tensor(X_test_cat)

train_data = TabularDataset(X_train_num.float(), X_train_cat)
test_data = TabularDataset(X_test_num.float(), X_test_cat)

batch_size = 2**15
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=4,
    drop_last=False,
)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=4,
    drop_last=False,
)

In [13]:

model = Model_VAE(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR, bias = True)
model = model.to(device)

pre_encoder = Encoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)
pre_decoder = Decoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)

pre_encoder.eval()
pre_decoder.eval()

self.category_embeddings.weight.shape=torch.Size([79, 6])
self.category_embeddings.weight.shape=torch.Size([79, 6])


Decoder_model(
  (VAE_Decoder): Transformer(
    (layers): ModuleList(
      (0): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=6, out_features=6, bias=True)
          (W_k): Linear(in_features=6, out_features=6, bias=True)
          (W_v): Linear(in_features=6, out_features=6, bias=True)
          (W_out): Linear(in_features=6, out_features=6, bias=True)
        )
        (linear0): Linear(in_features=6, out_features=384, bias=True)
        (linear1): Linear(in_features=384, out_features=6, bias=True)
        (norm1): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      )
      (1): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=6, out_features=6, bias=True)
          (W_k): Linear(in_features=6, out_features=6, bias=True)
          (W_v): Linear(in_features=6, out_features=6, bias=True)
          (W_out): Linear(in_features=6, out_features=6, bias=True)
        )
        (linear0): Linear(in_fea

In [14]:
ckpt = torch.load(f"{ckpt_dir}/model.pt")
model.load_state_dict(ckpt)

<All keys matched successfully>

In [15]:
model.eval()

Model_VAE(
  (VAE): VAE(
    (Tokenizer): Tokenizer(
      (category_embeddings): Embedding(79, 6)
    )
    (encoder_mu): Transformer(
      (layers): ModuleList(
        (0): ModuleDict(
          (attention): MultiheadAttention(
            (W_q): Linear(in_features=6, out_features=6, bias=True)
            (W_k): Linear(in_features=6, out_features=6, bias=True)
            (W_v): Linear(in_features=6, out_features=6, bias=True)
            (W_out): Linear(in_features=6, out_features=6, bias=True)
          )
          (linear0): Linear(in_features=6, out_features=384, bias=True)
          (linear1): Linear(in_features=384, out_features=6, bias=True)
          (norm1): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
        )
        (1): ModuleDict(
          (attention): MultiheadAttention(
            (W_q): Linear(in_features=6, out_features=6, bias=True)
            (W_k): Linear(in_features=6, out_features=6, bias=True)
            (W_v): Linear(in_features=6, out_features

In [16]:
pre_encoder.load_weights(model)
pre_decoder.load_weights(model)

In [17]:
pre_encoder.eval()
pre_decoder.to(device)

Decoder_model(
  (VAE_Decoder): Transformer(
    (layers): ModuleList(
      (0): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=6, out_features=6, bias=True)
          (W_k): Linear(in_features=6, out_features=6, bias=True)
          (W_v): Linear(in_features=6, out_features=6, bias=True)
          (W_out): Linear(in_features=6, out_features=6, bias=True)
        )
        (linear0): Linear(in_features=6, out_features=384, bias=True)
        (linear1): Linear(in_features=384, out_features=6, bias=True)
        (norm1): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      )
      (1): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=6, out_features=6, bias=True)
          (W_k): Linear(in_features=6, out_features=6, bias=True)
          (W_v): Linear(in_features=6, out_features=6, bias=True)
          (W_out): Linear(in_features=6, out_features=6, bias=True)
        )
        (linear0): Linear(in_fea

In [18]:
latent_codes = []
client_ids = []
abs_times = []

with torch.inference_mode():
    for batch_num, batch_cat in tqdm(train_loader):
        # print(batch_num.shape, batch_cat.shape)
        # print(batch_num[:5])
        batch_num = batch_num.to(device)
        batch_cat = batch_cat.to(device)
        
        # break
        abs_time, batch_num = batch_num[:, 0], batch_num[:, 1:]
        client_id, batch_cat = batch_cat[:, 0], batch_cat[:, 1:]
        
        train_z = pre_encoder(batch_num, batch_cat).cpu().numpy()
        latent_codes.append(train_z)
        client_ids.append(client_id.cpu().numpy())
        abs_times.append(abs_time.cpu().numpy())

100%|██████████| 1477/1477 [03:14<00:00,  7.59it/s]


In [19]:
train_z = np.concatenate(latent_codes)
train_client_ids = np.concatenate(client_ids)
train_abs_times = np.concatenate(abs_times)

In [20]:
train_z.shape

(48390887, 7, 6)

In [21]:
latent_codes = []
client_ids = []
abs_times = []

with torch.inference_mode():
    for batch_num, batch_cat in tqdm(test_loader):
        batch_num = batch_num.to(device)
        batch_cat = batch_cat.to(device)
        abs_time, batch_num = batch_num[:, 0], batch_num[:, 1:]
        client_id, batch_cat = batch_cat[:, 0], batch_cat[:, 1:]
        
        z = pre_encoder(batch_num, batch_cat).cpu().numpy()
        latent_codes.append(z)
        client_ids.append(client_id.cpu().numpy())
        abs_times.append(abs_time.cpu().numpy())

  0%|          | 0/386 [00:00<?, ?it/s]

100%|██████████| 386/386 [00:56<00:00,  6.82it/s]


In [22]:
test_z = np.concatenate(latent_codes)
test_client_ids = np.concatenate(client_ids)
test_abs_times = np.concatenate(abs_times)

In [23]:
test_z.shape

(12647803, 7, 6)

In [24]:
ckpt_dir

'tabsyn/vae/ckpt/mbd_debug'

In [25]:
np.save(f'{ckpt_dir}/train_z.npy', train_z)
np.save(f'{ckpt_dir}/train_client_ids.npy', train_client_ids)
np.save(f'{ckpt_dir}/train_abs_times.npy', train_abs_times)

In [26]:
np.save(f'{ckpt_dir}/test_z.npy', test_z)
np.save(f'{ckpt_dir}/test_client_ids.npy', test_client_ids)
np.save(f'{ckpt_dir}/test_abs_times.npy', test_abs_times)

# Sort codes

In [63]:
import pandas as pd
from IPython.display import display

In [6]:
train_z = np.load(f'{ckpt_dir}/train_z.npy')
train_client_ids = np.load(f'{ckpt_dir}/train_client_ids.npy')
train_abs_times = np.load(f'{ckpt_dir}/train_abs_times.npy')

In [7]:
df_seqs = pd.DataFrame(dict(idx=train_client_ids, order=train_abs_times))
df_seqs

Unnamed: 0,idx,order
0,14315,-0.489501
1,17935,-0.566783
2,26622,-1.252068
3,48255,0.175591
4,36048,-0.455435
...,...,...
9356048,41529,0.143390
9356049,5793,-0.500177
9356050,39608,0.657490
9356051,5910,0.346575


In [8]:
df_seqs = df_seqs.sort_values(["idx", "order"])
df_seqs

Unnamed: 0,idx,order
7764011,0,-3.014032
3805364,0,-1.187927
5330317,0,-0.701916
587580,0,-0.652003
7456754,0,-0.186872
...,...,...
4961325,76799,1.861384
3090022,76799,1.862275
1721820,76799,1.878158
7467402,76799,1.878258


In [9]:
df_seqs.groupby("idx").size().describe()

count    76800.000000
mean       121.823607
std        133.992507
min          5.000000
25%         30.000000
50%         73.000000
75%        167.000000
max       1357.000000
dtype: float64

In [58]:
dev = df_seqs[df_seqs.idx <= 1]
dev

Unnamed: 0,idx,order
7764011,0,-3.014032
3805364,0,-1.187927
5330317,0,-0.701916
587580,0,-0.652003
7456754,0,-0.186872
...,...,...
6925063,1,1.374727
8270442,1,1.374869
220895,1,1.377111
5817231,1,2.651878


In [75]:
length = 12

In [86]:
def get_slices(df):
    w = filter(lambda df: len(df) == length, df.rolling(window=length))
    indices = list(map(lambda df: df.index.values, w))
    if not indices:
        return pd.DataFrame(columns=list(range(length)), dtype=int)
    return pd.DataFrame(np.stack(indices))

In [87]:
idx = dev.groupby("idx", group_keys=False).apply(get_slices, include_groups=False).values

In [88]:
train_z[idx].shape

(72, 12, 7, 4)

In [70]:
train_z.shape

(9356053, 7, 4)