In [1]:
import os
import sys
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torchsummary import summary
from typing import List, Tuple
from torch.nn import functional as F
# Make the parent directory importable so the project-local `utils` package
# resolves. Use the already-imported `sys` module directly rather than the
# undocumented `os.sys` alias.
if '..' not in sys.path:
    sys.path.append('..')
from utils.distribution_utils import get_probability_distributions_from_sequence
# Use GPU index 7 whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")

In [2]:
# Dataset identifier; every path below is derived from it.
dataset = 'icbc'
data_dir = f'/mnt/ssd1/hsj/encore/{dataset}'
# Locations of the coarse size / interval CDF tables for this dataset.
size_cdf_file = os.path.join(data_dir, f'{dataset}_cdf', f'{dataset}_size_cdf_coarse.csv')
interval_cdf_file = os.path.join(data_dir, f'{dataset}_cdf', f'{dataset}_interval_cdf_coarse.csv')
size_cdf = pd.read_csv(size_cdf_file)
interval_cdf = pd.read_csv(interval_cdf_file)
# The number of rows in each CDF table defines the size of the corresponding
# categorical distribution (presumably one row per bin -- TODO confirm).
n_size, n_interval = len(size_cdf), len(interval_cdf)
# NOTE(review): presumably the first id after all (size, interval) pair ids;
# it is only printed in the cells visible here -- confirm where it is consumed.
start_token = n_size * n_interval
print(f'n_size: {n_size}, n_interval: {n_interval}, start_token: {start_token}')

n_size: 30, n_interval: 30, start_token: 900


In [3]:
def get_trainset(pair_index_file:str, n_size:int, n_interval:int, device:torch.device) -> Tuple[torch.utils.data.TensorDataset, np.ndarray, np.ndarray]:
    """Build the training set of per-sequence size / interval distributions.

    Args:
        pair_index_file: Text file with one sequence per line, each line a
            comma-separated list of integers.
        n_size: Forwarded to ``get_probability_distributions_from_sequence``.
        n_interval: Forwarded to ``get_probability_distributions_from_sequence``.
        device: Device on which the dataset tensors are stored.

    Returns:
        A tuple ``(dataset, size_probs, interval_probs)`` where ``dataset`` is
        a ``TensorDataset`` of float32 tensors living on ``device`` and the two
        arrays are the NumPy versions of the same data.

    Raises:
        ValueError: If a non-blank line contains a field that is not an integer.
    """
    pair_index = []
    with open(pair_index_file, 'r') as f:
        # Stream the file instead of materialising every line first.
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing empty line) would otherwise
                # crash with ``int('')``.
                continue
            pair_index.append([int(x) for x in line.split(',')])

    size_probs, interval_probs = [], []
    for seq in tqdm(pair_index):
        # The last integer of every line is dropped before computing the
        # distributions (presumably a terminator token -- TODO confirm).
        seq = np.array(seq[:-1])
        size_prob, interval_prob = get_probability_distributions_from_sequence(seq, n_size, n_interval)
        size_probs.append(size_prob)
        interval_probs.append(interval_prob)

    size_probs, interval_probs = np.array(size_probs), np.array(interval_probs)
    size_probe_tensor = torch.tensor(size_probs, dtype=torch.float32).to(device)
    interval_probe_tensor = torch.tensor(interval_probs, dtype=torch.float32).to(device)
    dataset = torch.utils.data.TensorDataset(size_probe_tensor, interval_probe_tensor)
    return dataset, size_probs, interval_probs

# Build the training set from the dataset's pair-index file; the NumPy copies
# (size_probs, interval_probs) are kept for the evaluation cells further down.
pair_index_file = os.path.join(data_dir, f'{dataset}_pair_index.txt')
trainset, size_probs, interval_probs = get_trainset(pair_index_file, n_size, n_interval, device)

100%|██████████| 7569/7569 [00:00<00:00, 35200.72it/s]


In [4]:
class VAE(nn.Module):
    """Fully connected VAE over concatenated size / interval probability vectors.

    The encoder maps an input of width ``n_size + n_interval`` through
    ``hidden_dims`` to the mean and log-variance of an ``n_latent``-dimensional
    Gaussian. The decoder maps a latent vector through ``hidden_dims`` in
    reverse order and ends in two linear heads whose softmax outputs have
    widths ``n_size`` and ``n_interval``.

    Args:
        n_size: Width of the size distribution.
        n_interval: Width of the interval distribution.
        hidden_dims: Hidden layer widths of the encoder, in order. The sequence
            is copied and never modified.
        n_latent: Dimensionality of the latent space.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, n_size, n_interval, hidden_dims, n_latent):
        super(VAE, self).__init__()
        if len(hidden_dims) == 0:
            raise ValueError('hidden_dims must contain at least one layer width')
        # initialize the parameters
        self.n_size = n_size
        self.n_interval = n_interval
        self.n_latent = n_latent
        # Independent copies: the encoder and decoder lists must not alias each
        # other or the caller's list, otherwise reversing the decoder widths
        # would also reverse the encoder widths and the caller's argument.
        self.encorer_hidden_dims = list(hidden_dims)
        # Correctly spelled alias; the misspelled name is kept for compatibility.
        self.encoder_hidden_dims = self.encorer_hidden_dims
        self.decoder_hidden_dims = list(reversed(self.encorer_hidden_dims))
        # The heads are created before the encoder/decoder stacks so that the
        # parameter registration order stays the same as before.
        self.fc_mu = nn.Linear(self.encorer_hidden_dims[-1], n_latent)
        self.fc_var = nn.Linear(self.encorer_hidden_dims[-1], n_latent)
        self.fc_size = nn.Linear(self.decoder_hidden_dims[-1], n_size)
        self.fc_interval = nn.Linear(self.decoder_hidden_dims[-1], n_interval)

        # construct encoder
        self.encoder = self._build_mlp(n_size + n_interval, self.encorer_hidden_dims)

        # construct decoder
        self.decoder = self._build_mlp(n_latent, self.decoder_hidden_dims)

    @staticmethod
    def _build_mlp(in_dim:int, dims:List[int]) -> nn.Sequential:
        """Stack ``Linear -> ReLU`` pairs going from ``in_dim`` through ``dims``."""
        layers = []
        prev_dim = in_dim
        for dim in dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            prev_dim = dim
        return nn.Sequential(*layers)

    def encode(self, x:Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu:Tensor, log_var:Tensor) -> Tensor:
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z:Tensor) -> Tuple[Tensor, Tensor]:
        """Return the unnormalised ``(size_logits, interval_logits)`` for ``z``."""
        h = self.decoder(z)
        return self.fc_size(h), self.fc_interval(h)

    def forward(self, x:Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Encode ``x``, sample a latent and decode it.

        Returns:
            ``(size_recon, interval_recon, mu, log_var)`` where the two
            reconstructions are softmax-normalised along the last dimension.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        size_logits, interval_logits = self.decode(z)
        size_recon = F.softmax(size_logits, dim=-1)
        interval_recon = F.softmax(interval_logits, dim=-1)
        return size_recon, interval_recon, mu, log_var

    def generate(self, n:int) -> Tuple[Tensor, Tensor]:
        """Decode ``n`` latents drawn from the standard normal prior.

        Returns:
            ``(size_recon, interval_recon)``, each softmax-normalised along the
            last dimension, on the same device as the model parameters.
        """
        # Follow the model's own device instead of a notebook-level global so
        # the method also works after the model has been moved.
        model_device = next(self.parameters()).device
        # Sample on the CPU first and then move, as before, so the random
        # stream used for sampling does not change.
        z = torch.randn(n, self.n_latent).to(model_device)
        size_logits, interval_logits = self.decode(z)
        size_recon = F.softmax(size_logits, dim=-1)
        interval_recon = F.softmax(interval_logits, dim=-1)
        return size_recon, interval_recon
    

class WeightedLoss(nn.Module):
    """VAE objective: L1 reconstruction loss plus a weighted KL divergence term.

    Args:
        kld_weight: Multiplier applied to the KL divergence before it is added
            to the reconstruction loss.
    """

    def __init__(self, kld_weight:float=1.0):
        super(WeightedLoss, self).__init__()
        self.kld_weight = kld_weight

    def forward(self, size_recon:Tensor, size_target:Tensor, interval_recon:Tensor, interval_target:Tensor, mu:Tensor, log_var:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the weighted loss and its two components.

        Returns:
            ``(weighted_loss, recon_loss, kld_loss)``, each a scalar tensor,
            where ``weighted_loss = recon_loss + kld_weight * kld_loss``.
        """
        # Mean absolute error of each head, summed over the two heads.
        size_l1 = F.l1_loss(size_recon, size_target)
        interval_l1 = F.l1_loss(interval_recon, interval_target)
        recon_loss = size_l1 + interval_l1
        # KL(N(mu, exp(log_var)) || N(0, I)): summed over the latent
        # dimensions (dim=1), then averaged over the batch (dim=0).
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)
        weighted_loss = recon_loss + self.kld_weight * kld_loss
        return weighted_loss, recon_loss, kld_loss

In [5]:
def train_epoch(model:nn.Module, dataloader:DataLoader, optimizer:torch.optim.Optimizer, criterion:nn.Module, device:torch.device) -> Tuple[float, float, float]:
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Called with the concatenated ``[size, interval]`` batch; must
            return ``(size_recon, interval_recon, mu, log_var)``.
        dataloader: Yields ``(size_batch, interval_batch)`` pairs.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Called as ``criterion(size_recon, size_batch, interval_recon,
            interval_batch, mu, log_var)``; must return
            ``(loss, recon_loss, kld_loss)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(average_loss, average_recon_loss, average_kld_loss)``, each weighted
        by batch size and averaged over the samples actually processed. All
        three are ``0.0`` when the dataloader yields no batch.
    """
    model.train()
    running_loss, running_recon_loss, running_kld_loss = 0.0, 0.0, 0.0
    # Count the samples that were really seen: dividing by
    # len(dataloader.dataset) under-reports the averages whenever the loader
    # skips samples (e.g. drop_last=True or a subset sampler).
    n_seen = 0
    for size_batch, interval_batch in dataloader:
        size_batch, interval_batch = size_batch.to(device), interval_batch.to(device)
        input_tensor = torch.cat([size_batch, interval_batch], dim=1)
        size_recon, interval_recon, mu, log_var = model(input_tensor)
        loss, recon_loss, kld_loss = criterion(size_recon, size_batch, interval_recon, interval_batch, mu, log_var)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_n = size_batch.size(0)
        n_seen += batch_n
        # Weight each batch mean by its size so a smaller final batch does not
        # skew the epoch average.
        running_loss += loss.item() * batch_n
        running_recon_loss += recon_loss.item() * batch_n
        running_kld_loss += kld_loss.item() * batch_n
    denom = max(n_seen, 1)  # avoid dividing by zero when no batch was produced
    average_loss = running_loss / denom
    average_recon_loss = running_recon_loss / denom
    average_kld_loss = running_kld_loss / denom
    return average_loss, average_recon_loss, average_kld_loss


def train(model:nn.Module, dataloader:DataLoader, optimizer:torch.optim.Optimizer, scheduler:torch.optim.lr_scheduler, n_epochs:int, plot_every:int, init_kld_weight:float, kld_update_every:int, device:torch.device, kld_growth:float=10.0, max_kld_weight:float=1e-3) -> None:
    """Train ``model`` for ``n_epochs`` epochs with a stepwise KLD weight schedule.

    Args:
        model: Model passed to ``train_epoch``.
        dataloader: Training batches passed to ``train_epoch``.
        optimizer: Optimiser passed to ``train_epoch``.
        scheduler: Learning-rate scheduler; ``step()`` is called once per epoch
            and ``get_last_lr()`` is read for logging.
        n_epochs: Number of epochs to run.
        plot_every: Print the epoch statistics every this many epochs.
        init_kld_weight: KLD weight used from the first epoch.
        kld_update_every: Grow the KLD weight every this many epochs.
        device: Device passed to ``train_epoch``.
        kld_growth: Factor the KLD weight is multiplied by at each update
            (default 10, the value that used to be hard-coded).
        max_kld_weight: Upper bound of the KLD weight (default 1e-3, the value
            that used to be hard-coded).
    """
    start_time = time.time()
    kld_weight = init_kld_weight
    for epoch in range(1, n_epochs+1):
        epoch_start_time = time.time()
        # The criterion is rebuilt each epoch so it picks up the current weight.
        criterion = WeightedLoss(kld_weight)
        train_loss, train_recon_loss, train_kld_loss = train_epoch(model, dataloader, optimizer, criterion, device)
        scheduler.step()
        if epoch % plot_every == 0:
            print(f'Epoch: {epoch}/{n_epochs}, Epoch time: {time.time() - epoch_start_time:.2f}s, Total time: {time.time() - start_time:.2f}s, LR: {scheduler.get_last_lr()[0]:.2e}, KLD Weight: {kld_weight:.2e}')
            print(f'Train Loss: {train_loss:.4f}, Train Recon Loss: {train_recon_loss:.4f}, Train KLD Loss: {train_kld_loss:.4f}')
            print('-' * 80)
        if epoch % kld_update_every == 0:
            # Multiplicative growth, clamped at the configured ceiling.
            kld_weight = min(kld_weight * kld_growth, max_kld_weight)
    print(f'Training finished after {n_epochs} epochs in {time.time() - start_time:.2f}s')

In [6]:
# Model / batching hyper-parameters.
batch_size = 512
hidden_dims = [256, 512, 1024, 1024, 512, 256]
latent_dim = 64
vae = VAE(n_size, n_interval, hidden_dims, latent_dim)
# The layer summary is produced on the CPU; the model is moved to the
# training device only afterwards.
vae = vae.to('cpu')
summary(vae, input_size=(n_size + n_interval,), device='cpu')
vae = vae.to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]          15,616
              ReLU-2                  [-1, 256]               0
            Linear-3                  [-1, 512]         131,584
              ReLU-4                  [-1, 512]               0
            Linear-5                 [-1, 1024]         525,312
              ReLU-6                 [-1, 1024]               0
            Linear-7                 [-1, 1024]       1,049,600
              ReLU-8                 [-1, 1024]               0
            Linear-9                  [-1, 512]         524,800
             ReLU-10                  [-1, 512]               0
           Linear-11                  [-1, 256]         131,328
             ReLU-12                  [-1, 256]               0
           Linear-13                   [-1, 64]          16,448
           Linear-14                   

In [7]:
# Restore previously trained weights when a checkpoint is available.
model_dir = f'/mnt/ssd1/hsj/encore/{dataset}/model/vae-01-08/'
model_path = os.path.join(model_dir, 'encore_vae_70000.pt')
if os.path.exists(model_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads when that GPU is not available.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    vae.load_state_dict(torch.load(model_path, map_location=device))
    print(f'Loaded model from {model_path}')
else:
    # Say so explicitly: otherwise later cells would silently evaluate a
    # randomly initialised model.
    print(f'No checkpoint found at {model_path}; using randomly initialised weights')

Loaded model from /mnt/ssd1/hsj/encore/icbc/model/vae-01-08/encore_vae_70000.pt


In [8]:
# Count only the parameters that receive gradients (requires_grad=True).
params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print(f'Total trainable parameters: {params}')

Total trainable parameters: 4805820


In [9]:
# Shuffled loader over the training set built earlier.
dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Smoke test: push a single batch through the model to check that the
# shapes line up before starting a long training run.
size_batch, interval_batch = next(iter(dataloader))
# The model input is the size and interval vectors concatenated feature-wise.
input_tensor = torch.cat([size_batch, interval_batch], dim=1)
size_recon, interval_recon, mu, log_var = vae(input_tensor)

In [25]:
# Training schedule.
n_epochs = 10000
plot_every = 100
# Passed to train() as its `kld_update_every` argument.
update_kld_every = 1000
# Equal to the 1e-3 ceiling train() applies to the KLD weight, so the weight
# stays constant for the whole run.
init_kld_weight = 1e-3
optimizer = torch.optim.Adam(vae.parameters(), lr=5e-4)
# train() steps the scheduler once per epoch; with step_size equal to
# n_epochs the learning rate is only decayed by the final step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.8)
train(vae, dataloader, optimizer, scheduler, n_epochs, plot_every, init_kld_weight, update_kld_every, device)

Epoch: 100/10000, Epoch time: 0.21s, Total time: 23.65s, LR: 5.00e-04, KLD Weight: 1.00e-03
Train Loss: 0.0148, Train Recon Loss: 0.0067, Train KLD Loss: 8.0213
--------------------------------------------------------------------------------
Epoch: 200/10000, Epoch time: 0.23s, Total time: 46.50s, LR: 5.00e-04, KLD Weight: 1.00e-03
Train Loss: 0.0154, Train Recon Loss: 0.0070, Train KLD Loss: 8.3058
--------------------------------------------------------------------------------
Epoch: 300/10000, Epoch time: 0.23s, Total time: 69.32s, LR: 5.00e-04, KLD Weight: 1.00e-03
Train Loss: 0.0149, Train Recon Loss: 0.0070, Train KLD Loss: 7.8934
--------------------------------------------------------------------------------
Epoch: 400/10000, Epoch time: 0.09s, Total time: 85.33s, LR: 5.00e-04, KLD Weight: 1.00e-03
Train Loss: 0.0150, Train Recon Loss: 0.0070, Train KLD Loss: 7.9672
--------------------------------------------------------------------------------
Epoch: 500/10000, Epoch time: 0.

KeyboardInterrupt: 

In [10]:
# Draw as many synthetic distributions from the prior as there are training
# sequences. No random seed is set in this notebook, so the samples (and the
# metrics computed from them) differ from run to run.
size_gen, interval_gen = vae.generate(len(size_probs))
size_gen, interval_gen = size_gen.detach().cpu().numpy(), interval_gen.detach().cpu().numpy()
# Zero out negligible size probabilities and renormalise every row so it sums
# to 1 again. Only size_gen is post-processed; interval_gen is left as is.
size_gen[size_gen < 1e-4] = 0
size_gen = size_gen / size_gen.sum(axis=1, keepdims=True)

In [11]:
def get_mse(ori, recon, chunk_size=128):
    """Pairwise Euclidean distance between the rows of ``ori`` and ``recon``.

    Despite the name, this is not a mean squared error: entry ``[i, j]`` of the
    result is ``sqrt(sum((ori[i] - recon[j]) ** 2))``.

    The rows of ``ori`` are processed in chunks so the broadcast intermediate
    is at most ``(chunk_size, len(recon), D)`` instead of
    ``(len(ori), len(recon), D)``, which does not fit in memory for a few
    thousand rows on each side. The returned values do not depend on
    ``chunk_size``.

    Args:
        ori: Array of shape ``(N, D)``.
        recon: Array of shape ``(M, D)``.
        chunk_size: Number of ``ori`` rows handled per step; must be >= 1.

    Returns:
        Array of shape ``(N, M)`` with the pairwise distances.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be >= 1')
    ori, recon = np.asarray(ori), np.asarray(recon)

    def _block(rows):
        # Broadcast (n, 1, D) against (1, M, D) and reduce over the feature axis.
        return np.sqrt(np.sum((rows[:, np.newaxis, :] - recon[np.newaxis, :, :]) ** 2, axis=2))

    n_rows = ori.shape[0]
    if n_rows <= chunk_size:
        # Small (or empty) input: a single broadcast is enough.
        return _block(ori)
    blocks = [_block(ori[start:start + chunk_size]) for start in range(0, n_rows, chunk_size)]
    return np.concatenate(blocks, axis=0)
# mse[i, j] is the distance between original size distribution i and generated sample j.
mse = get_mse(size_probs, size_gen)

In [12]:
# Accuracy: for every generated sample, the distance to its nearest original
# distribution (minimum over the rows of `mse`).
mse_accuracy = mse.min(axis=0)
# Coverage: for every original distribution, the distance to its nearest
# generated sample (minimum over the columns of `mse`).
mse_coverage = mse.min(axis=1)
print(f'Accuracy, mean: {mse_accuracy.mean():.3f}, p90: {np.percentile(mse_accuracy, 90):.3f}, p95: {np.percentile(mse_accuracy, 95):.3f}, p99: {np.percentile(mse_accuracy, 99):.3f}')
print(f'Coverage, mean: {mse_coverage.mean():.3f}, p90: {np.percentile(mse_coverage, 90):.3f}, p95: {np.percentile(mse_coverage, 95):.3f}, p99: {np.percentile(mse_coverage, 99):.3f}')

Accuracy, mean: 0.017, p90: 0.051, p95: 0.090, p99: 0.203
Coverage, mean: 0.018, p90: 0.050, p95: 0.073, p99: 0.164


In [None]:
def get_mse(ori, recon, chunk_size=128):
    """Pairwise Euclidean distance between the rows of ``ori`` and ``recon``.

    Despite the name, this is not a mean squared error: entry ``[i, j]`` of the
    result is ``sqrt(sum((ori[i] - recon[j]) ** 2))``.

    The rows of ``ori`` are processed in chunks so the broadcast intermediate
    is at most ``(chunk_size, len(recon), D)`` instead of
    ``(len(ori), len(recon), D)``. The returned values do not depend on
    ``chunk_size``.

    Args:
        ori: Array of shape ``(N, D)``.
        recon: Array of shape ``(M, D)``.
        chunk_size: Number of ``ori`` rows handled per step; must be >= 1.

    Returns:
        Array of shape ``(N, M)`` with the pairwise distances.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be >= 1')
    ori, recon = np.asarray(ori), np.asarray(recon)

    def _block(rows):
        # Broadcast (n, 1, D) against (1, M, D) and reduce over the feature axis.
        return np.sqrt(np.sum((rows[:, np.newaxis, :] - recon[np.newaxis, :, :]) ** 2, axis=2))

    n_rows = ori.shape[0]
    if n_rows <= chunk_size:
        # Small (or empty) input: a single broadcast is enough.
        return _block(ori)
    blocks = [_block(ori[start:start + chunk_size]) for start in range(0, n_rows, chunk_size)]
    return np.concatenate(blocks, axis=0)

# NOTE(review): this cell repeats the accuracy / coverage summary of an
# earlier cell. It does not recompute `mse`, so it relies on `mse` already
# existing in the kernel.
mse_accuracy = mse.min(axis=0)
mse_coverage = mse.min(axis=1)
print(f'Accuracy, mean: {mse_accuracy.mean():.3f}, p90: {np.percentile(mse_accuracy, 90):.3f}, p95: {np.percentile(mse_accuracy, 95):.3f}, p99: {np.percentile(mse_accuracy, 99):.3f}')
print(f'Coverage, mean: {mse_coverage.mean():.3f}, p90: {np.percentile(mse_coverage, 90):.3f}, p95: {np.percentile(mse_coverage, 95):.3f}, p99: {np.percentile(mse_coverage, 99):.3f}')

In [13]:
# Deciles (10th to 90th percentile) of the nearest-original distances.
[np.percentile(mse_accuracy, i) for i in range(10, 100, 10)]

[0.0,
 1.8303059785570975e-05,
 0.00018246617818991768,
 0.00040697942087060683,
 0.0009566476372602447,
 0.004010271883551145,
 0.011588841660199946,
 0.023936840160009524,
 0.05061360908652382]

In [14]:
# Baseline: how far each original size distribution is from the dataset-wide
# mean distribution.
avg_size_prob = size_probs.mean(axis=0)
# Squared Euclidean distance per row. No square root is taken here, unlike in
# get_mse, so these values are not on the same scale as the metrics above.
avg_diff = np.sum((size_probs - avg_size_prob) ** 2, axis=1)

In [15]:
# Minimum, mean and p90 / p95 / p99 of the squared distances to the mean distribution.
np.min(avg_diff), np.mean(avg_diff), np.percentile(avg_diff, 90), np.percentile(avg_diff, 95), np.percentile(avg_diff, 99)    

(0.03520654593132698,
 0.6344232657271119,
 0.9394095746923774,
 0.9904461265397568,
 1.0246181633645304)