In [None]:
import glob
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import trange

# Global Matplotlib style for this notebook: draw a dashed grid on every axes.
params = dict([
    ("axes.grid", True),
    ("grid.linestyle", "--"),
])
plt.rcParams.update(params)

import vae

In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

In [None]:
# Load the raw signals from a pickle and cast them to float32.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("data/tt_epi_ds_long.pkl", "rb") as fh:
    signals_np = pickle.load(fh).astype(np.float32)

# Insert a singleton axis at dim 1, move the whole dataset to `device`, then
# right-pad the last axis with 6 zeros.
# Assumes the pickled array is (n_signals, length), giving (n_signals, 1, length + 6) — TODO confirm.
data = F.pad(
    torch.from_numpy(signals_np).unsqueeze(1).to(device),
    (0, 6),
    "constant",
    0,
)

# Shuffled mini-batches of 256 signals; each item is a 1-tuple (signal,).
ds = TensorDataset(data)
dl = DataLoader(ds, batch_size=256, shuffle=True)

In [None]:
# Model hyper-parameters. `features` is passed to vae.VAE as its second
# argument (presumably a feature/channel width — confirm in vae.py).
latent_dim, features = 16, 16
model = vae.VAE(latent_dim, features).to(device)

# Summed (not averaged) squared error, so the reconstruction term grows with
# the number of elements in the batch.
loss_fn = torch.nn.MSELoss(reduction='sum')
opt = torch.optim.Adam(model.parameters(), lr=3e-4)

# Per-epoch training losses; the training cell appends one value per epoch.
loss_history = []

In [None]:
def train_loop(kl_weight=0.0, dataloader=None, net=None, optimizer=None,
               criterion=None, target_device=None):
    """Run one training epoch of the VAE and return per-sample mean losses.

    For every batch the model reconstructs the input, and the loss
    ``reconstruction + kl_weight * KL`` is back-propagated and stepped.

    Args:
        kl_weight: Weight of the KL-divergence term (0 disables it).
        dataloader: Iterable of 1-tuples ``(X,)``; defaults to the global ``dl``.
        net: Model whose forward returns ``(reconstruction, mu, logvar)``;
            defaults to the global ``model``.
        optimizer: Optimizer over ``net``'s parameters; defaults to ``opt``.
        criterion: Reconstruction loss. It is expected to SUM over the batch
            (like the global ``loss_fn``) so that dividing by the sample count
            yields a per-sample value. Defaults to ``loss_fn``.
        target_device: Device each batch is moved to; defaults to ``device``.

    Returns:
        ``(loss, reconstruction_loss, kl_loss)``, each averaged over all
        samples seen during the epoch. All three are 0.0 if the loader is empty.
    """
    # Fall back to the notebook globals so existing calls keep working.
    dataloader = dl if dataloader is None else dataloader
    net = model if net is None else net
    optimizer = opt if optimizer is None else optimizer
    criterion = loss_fn if criterion is None else criterion
    target_device = device if target_device is None else target_device

    # Other cells switch the model to eval(); make sure we train in train mode.
    net.train()

    total_loss = 0.0
    total_rec_loss = 0.0
    total_kl_loss = 0.0
    n_samples = 0

    for X, in dataloader:
        X = X.to(target_device)
        optimizer.zero_grad()

        Y, mu, logvar = net(X)
        reconstruction_loss = criterion(Y, X)
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims (eq. 10).
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())
        loss = reconstruction_loss + kl_weight * kl_divergence

        loss.backward()
        optimizer.step()

        # Accumulate batch sums and divide once at the end. This weights every
        # sample equally; summing per-batch means instead would scale with the
        # number of batches and over-weight a short final batch.
        total_rec_loss += reconstruction_loss.item()
        total_kl_loss += kl_divergence.item()
        total_loss += loss.item()
        n_samples += len(X)

    if n_samples == 0:
        return 0.0, 0.0, 0.0
    return (total_loss / n_samples,
            total_rec_loss / n_samples,
            total_kl_loss / n_samples)

In [None]:
# Training: `epochs` passes over the data, with the KL term weighted by `kl_weight`.
epochs = 200
kl_weight = 0.2
model.train()
loop = trange(epochs)
for epoch in loop:
    loss, rec_loss, kl_loss = train_loop(kl_weight)
    loss_history.append(loss)
    loop.set_postfix(loss=loss, reconstruction=rec_loss, kl=kl_loss)

# loss_history keeps growing if this cell is re-run. The first two entries are
# skipped, presumably because the early losses are large enough to flatten the curve.
fig, ax = plt.subplots()
ax.plot(loss_history[2:])
ax.set(title="VAE training loss", xlabel="epoch (first 2 skipped)", ylabel="loss");

In [None]:
# Generation: decode random latent vectors drawn from a standard normal.

n_generated = 16
z = torch.randn(n_generated, latent_dim).to(device)
model.eval()

with torch.no_grad():
    # Outputs created under no_grad carry no autograd history, so no .detach() is needed.
    generated = model.dec(z).cpu()

fig, ax = plt.subplots()
for g in generated:
    # Plots channel 0 of each decoded sample; assumes (batch, channel, length) output — TODO confirm.
    ax.plot(g[0])
ax.set_xlim(0, 220)
ax.set(title="Signals decoded from random latent vectors", xlabel="sample index", ylabel="amplitude");

In [None]:
# Reconstruction: compare a few input signals with the model's reconstructions.

model.eval()
rows, cols = 2, 3
signal,  = next(iter(dl))
with torch.no_grad():
    # Inference only: avoid building an autograd graph for the whole batch.
    decoded = model(signal.to(device))[0].cpu()

fig, axs = plt.subplots(rows, cols, figsize=(10, 5), sharex=True, sharey=True)
# zip stops early if the batch holds fewer than rows * cols signals.
for ax, i in zip(axs.flat, range(len(signal))):
    ax.plot(signal[i, 0].cpu().numpy(), label="input")
    ax.plot(decoded[i, 0].numpy(), '--r', label="reconstruction")
    ax.set_xlim(0, 120)
axs.flat[0].legend()
fig.suptitle("Input vs. reconstruction");

# Zapis i odczyt modelu

In [None]:
# Available checkpoints, sorted so that fnames[0] is deterministic.
# glob returns [] when nothing matches; `!ls` would instead capture its error message as an entry.
fnames = sorted(glob.glob("checkpoints/*.pth"))
fnames

In [None]:
# Load a trained model from the first checkpoint in `fnames`.
# The architecture must match the checkpoint: with the default strict loading,
# a latent_dim/features mismatch makes load_state_dict raise size-mismatch errors.
# Checkpoints saved by this notebook are named vae_z{latent_dim}_f{features}*.pth.
latent_dim = 16
features = 16
if not fnames:
    raise FileNotFoundError("No checkpoints found matching checkpoints/*.pth")
fn = fnames[0]
model = vae.VAE(latent_dim, features).to(device)
model.load_state_dict(torch.load(fn, map_location=device, weights_only=True))

In [None]:
# from safetensors.torch import save_model
# save_model(model, fn.replace(".pth", ".safetensors"))

In [None]:
# save model
# torch.save(model.state_dict(), f"checkpoints/vae_z{latent_dim}_f{features}.pth")
# torch.save(model.state_dict(), f"checkpoints/vae_z{latent_dim}_f{features}_long_kl02.pth")

# load model
# model = vae.VAE(latent_dim, features).to(device)
# model.load_state_dict(torch.load(f"checkpoints/vae_z{latent_dim}_f{features}.pth", map_location=device, weights_only=True))