
# Scratch notebook for ad-hoc testing — not meant to be referenced or reused

In [32]:
import os
import time
import pickle
import random

import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from torch.nn import functional as F

from quadrotor_diffusion.utils.dataset.normalizer import NoNormalizer
from quadrotor_diffusion.utils.dataset.dataset import DiffusionDataset, VaeDataset
from quadrotor_diffusion.utils.file import get_checkpoint_file
from quadrotor_diffusion.utils.nn.training import Trainer
from quadrotor_diffusion.models.vae_wrapper import VAE_Wrapper
from quadrotor_diffusion.utils.dataset.boundary_condition import PolynomialTrajectory
from quadrotor_diffusion.utils.plotting import plot_states, add_gates_to_course, add_trajectory_to_course, course_base_plot, COLORS

# Automatically re-import edited project modules (quadrotor_diffusion.*) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Load the diffusion dataset (un-normalized) and the VAE checkpoint from a training run.
DATA_DIR = "../data"
LOG_DIR = "../logs/training"
TRAJ_LEN = 128
DIFFUSION_FOLDER = "diffusion4"

dataset = DiffusionDataset(DATA_DIR, traj_len=TRAJ_LEN, normalizer=NoNormalizer(), folder=DIFFUSION_FOLDER)
print(len(dataset))

vae_experiment: int = 249
chkpt = get_checkpoint_file(LOG_DIR, vae_experiment)
# Trainer.load returns a 4-tuple; only its first element (treated as the VAE wrapper) is kept.
# A bare annotation keeps the type hint without a dead `= None` assignment.
vae_wrapper: VAE_Wrapper
vae_wrapper, _, _, _ = Trainer.load(chkpt, get_ema=False)
vae_wrapper = vae_wrapper.cuda()

151544


In [None]:
# Sanity check on one DiffusionDataset sample: draw its gates, the stored trajectory,
# the VAE reconstruction of that trajectory, and the local-conditioning points.
SAMPLE_INDEX = -1350  # arbitrary sample picked for inspection

d = dataset[SAMPLE_INDEX]
course = d["global_conditioning"]
context = d["local_conditioning"]
traj = d["x_0"]

# Prepend a copy of the course's first row (presumably the start pose — TODO confirm why
# add_gates_to_course needs it duplicated).
course = torch.cat((course[0].unsqueeze(0), course))

_, axs = course_base_plot()
add_gates_to_course(course, axs, has_end=False)
add_trajectory_to_course(axs, traj)

# Pure inference: skip building an autograd graph, so the decoded tensor does not require
# grad when it is handed to the plotting helper (NumPy conversion of a grad-tracking
# tensor raises).
with torch.no_grad():
    mu, _ = vae_wrapper.encode(traj.unsqueeze(0).cuda())
    trajdec = vae_wrapper.decode(mu).squeeze(0).cpu()
add_trajectory_to_course(axs, trajdec, reference=True)

axs[0].scatter(context[:, 0], context[:, 1], color='red')
axs[1].scatter(context[:, 0], context[:, 2], color='red')

# Uncomment one of these to view/export; otherwise the figure is closed undisplayed.
# plt.show()
# plt.savefig("d.pdf")
plt.close()

In [7]:
# Collect the first row (x, y, z) of every sample's global conditioning —
# presumably the course start position; TODO confirm against DiffusionDataset.
x = []
y = []
z = []
# Index explicitly so the loop is bounded by len(dataset) rather than relying on
# __getitem__ raising IndexError (tqdm's `total` does not stop iteration).
for idx in tqdm.tqdm(range(len(dataset))):
    d = dataset[idx]
    course = d["global_conditioning"]
    start = course[0]
    # .item() stores plain Python floats; start[0] alone is a 0-d tensor view that
    # would keep the sample's whole tensor storage alive for every collected point.
    x.append(start[0].item())
    y.append(start[1].item())
    z.append(start[2].item())

100%|██████████| 72632/72632 [00:10<00:00, 6759.55it/s]


In [8]:
# Scatter the collected start positions: x vs y on axs[0], x vs z on axs[1]; export to data.pdf.
_, axs = course_base_plot()
start_scatter_kwargs = {"alpha": 0.05, "s": 0.5, "c": COLORS[2]}
for panel, vertical in ((0, y), (1, z)):
    axs[panel].scatter(x, vertical, **start_scatter_kwargs)
plt.savefig("data.pdf")
plt.close()

In [None]:
# Trajectory dataset used to train the VAE, loaded without normalization.
vae_normalizer = NoNormalizer()
vae_dataset = VaeDataset("../data", normalizer=vae_normalizer)


tensor([[ 0.0000e+00,  0.0000e+00,  5.0153e-01],
        [ 2.3445e-02,  1.4367e-03,  5.0390e-01],
        [ 4.6148e-02,  4.3536e-03,  5.0611e-01],
        [ 6.8068e-02,  8.7879e-03,  5.0817e-01],
        [ 8.9167e-02,  1.4770e-02,  5.1008e-01],
        [ 1.0941e-01,  2.2322e-02,  5.1184e-01],
        [ 1.2878e-01,  3.1460e-02,  5.1346e-01],
        [ 1.4724e-01,  4.2192e-02,  5.1494e-01],
        [ 1.6479e-01,  5.4515e-02,  5.1629e-01],
        [ 1.8140e-01,  6.8420e-02,  5.1751e-01],
        [ 1.9708e-01,  8.3891e-02,  5.1861e-01],
        [ 2.1184e-01,  1.0090e-01,  5.1959e-01],
        [ 2.2567e-01,  1.1941e-01,  5.2047e-01],
        [ 2.3860e-01,  1.3939e-01,  5.2124e-01],
        [ 2.5064e-01,  1.6077e-01,  5.2191e-01],
        [ 2.6182e-01,  1.8350e-01,  5.2250e-01],
        [ 2.7219e-01,  2.0752e-01,  5.2300e-01],
        [ 2.8177e-01,  2.3275e-01,  5.2343e-01],
        [ 2.9061e-01,  2.5911e-01,  5.2379e-01],
        [ 2.9876e-01,  2.8651e-01,  5.2408e-01],
        [ 3.0628e-01

In [None]:
# Grid of randomly drawn VAE-dataset trajectories: one column per sample and one row per
# position axis (x, y, z). The figure is saved to d.pdf.
num_samples = 5     # number of columns
sample_seed = 0     # seeds the draw so re-running the cell reproduces the same figure
axis_rows = (("X", "x (meters)"), ("Y", "y (meters)"), ("Z", "z (meters)"))

sample_rng = random.Random(sample_seed)
fig, axes = plt.subplots(len(axis_rows), num_samples, figsize=(18, 8), squeeze=False)
legend_handles = []

for col in range(num_samples):
    sample = sample_rng.choice(vae_dataset)
    # NOTE(review): the sample comes from `vae_dataset` but is un-normalized with
    # `dataset`'s normalizer. Both are NoNormalizer in this notebook, so it is harmless
    # here — confirm VaeDataset exposes `.normalizer` and switch to it if that changes.
    reference = dataset.normalizer.undo(sample.cpu().numpy())

    for row, (label, ylabel) in enumerate(axis_rows):
        ax = axes[row, col]
        (line,) = ax.plot(reference[:, row], label=label, linewidth=3.5, c=COLORS[row])
        if col == 0:
            ax.set_ylabel(ylabel)
            legend_handles.append(line)
        if row == 0:
            ax.set_title(f"Sample {col+1}", fontsize=10)
        if row < len(axis_rows) - 1:
            ax.set_xticks([])  # only the bottom row keeps x-axis ticks
        ax.tick_params(axis="y", labelsize=8)
        ax.grid(True)  # explicit on; a bare grid() call toggles instead

# Tight layout first, then reserve room at the top for the shared legend.
fig.tight_layout(pad=0.3)
fig.subplots_adjust(top=0.92, hspace=0.15, wspace=0.15)

# Explicit handles (first column's lines) instead of relying on the implicit order in
# which a labels-only legend picks lines from across all axes.
fig.legend(legend_handles,
           [label for label, _ in axis_rows],
           loc='upper center',
           bbox_to_anchor=(0.5, 1.03),
           ncol=3,
           frameon=False,
           fontsize=12)

# The legend is anchored above the figure area (y=1.03); bbox_inches="tight" expands the
# saved region so the legend is not clipped.
fig.savefig("d.pdf", bbox_inches="tight")
plt.close(fig)