In [15]:
import matplotlib.pyplot as plt

import torch

from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

from diffusion_framework.ddpm_1d import DDPM_1d
from diffusion_framework.nets import ErrorNet, CondErrorNet

from torch.optim import Adam
import numpy as np

import os
import gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
import platform
if platform.system() == "Linux":
    os.environ['PYOPENGL_PLATFORM'] = 'egl'

### Select environment
env_id = "MountainCarContinuous-v0"
# Note that the algorithm is SAC

gym_env = gym.make(env_id)
max_obs_values = gym_env.observation_space.high
min_obs_values = gym_env.observation_space.low

max_act_value = gym_env.action_space.high
min_act_value = gym_env.action_space.low

bias = max_obs_values + min_obs_values
bias = bias / 2
scale = max_obs_values - min_obs_values
scale = scale / 2

env = make_vec_env(env_id, n_envs=1)
best_model = SAC.load('/home/sai-admin/advanced_ml_project/rl-baselines3-zoo/rl-trained-agents/sac/' + 
                      env_id + '_1/' + 
                      env_id + '.zip', env=env)

In [36]:
# Define data
expert_demo_path = 'expert/logs/expert_demonstrations.npy'

dataset = np.load(expert_demo_path, allow_pickle=True)
dataset = torch.tensor(dataset, dtype=torch.float32)

batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define models
timesteps = 100
diffusion = DDPM_1d(timesteps)

In [38]:
# Visualizing corruption process

# x_sample = dataset[:10000].unsqueeze(0)
# t = torch.tensor([0,250,500,750,999])
# out = diffusion.q_sample(x_sample, t)
# N = 5
# fig, ax = plt.subplots(1,5,figsize=(N*2,2))
# for i in range(N):
#     ax[i].scatter(out[i, :, 0], out[i, :, 1], s=2)
#     ax[i].set_xlim(-1,1)
#     ax[i].set_ylim(-1,1)
# plt.tight_layout()

In [39]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = CondErrorNet(dim=1,cond_dim=2)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
model

cuda:0


CondErrorNet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=64, out_features=128, bias=True)
    (2): GELU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): SiLU()
    (5): Linear(in_features=64, out_features=2, bias=True)
  )
  (state_mlp): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): GELU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): GELU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): GELU()
    (6): Linear(in_features=64, out_features=1, bias=True)
  )
  (res_mlp): Sequential(
    (0): Linear(in_features=4, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
  (final_mlp): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): GELU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): GELU()
    (4): Linear(in_features=128, out_featu

In [40]:
epochs = 20

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        batch_size = batch.shape[0]
        batch = batch.to(device)

        # Algorithm 1 line 3: sample t uniformally for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = diffusion.conditional_p_losses(model, batch, t, loss_type="huber")

        if step % 100 == 0:
            print("Epoch: %d, Loss: %f" %(epoch, loss.item()))

        loss.backward()
        optimizer.step()


Epoch: 0, Loss: 0.430114
Epoch: 0, Loss: 0.357732
Epoch: 0, Loss: 0.217975
Epoch: 0, Loss: 0.145122
Epoch: 0, Loss: 0.206579
Epoch: 0, Loss: 0.163589
Epoch: 0, Loss: 0.171231
Epoch: 0, Loss: 0.157102
Epoch: 0, Loss: 0.162236
Epoch: 0, Loss: 0.099001
Epoch: 0, Loss: 0.136750
Epoch: 0, Loss: 0.123100
Epoch: 0, Loss: 0.086983
Epoch: 0, Loss: 0.135005
Epoch: 0, Loss: 0.161553
Epoch: 0, Loss: 0.130541
Epoch: 0, Loss: 0.114306
Epoch: 0, Loss: 0.095399
Epoch: 0, Loss: 0.103303
Epoch: 0, Loss: 0.129993
Epoch: 0, Loss: 0.109435
Epoch: 0, Loss: 0.077038
Epoch: 0, Loss: 0.091411
Epoch: 0, Loss: 0.101343
Epoch: 0, Loss: 0.076956
Epoch: 0, Loss: 0.102985
Epoch: 0, Loss: 0.083000
Epoch: 0, Loss: 0.094485
Epoch: 0, Loss: 0.099095
Epoch: 0, Loss: 0.065592
Epoch: 0, Loss: 0.060605
Epoch: 0, Loss: 0.087256
Epoch: 0, Loss: 0.074713
Epoch: 0, Loss: 0.057595
Epoch: 0, Loss: 0.066759
Epoch: 0, Loss: 0.045417
Epoch: 0, Loss: 0.055380
Epoch: 0, Loss: 0.056336
Epoch: 0, Loss: 0.050846
Epoch: 0, Loss: 0.055994


Epoch: 4, Loss: 0.013671
Epoch: 4, Loss: 0.016053
Epoch: 4, Loss: 0.025013
Epoch: 4, Loss: 0.028327
Epoch: 4, Loss: 0.020766
Epoch: 4, Loss: 0.012013
Epoch: 4, Loss: 0.017137
Epoch: 4, Loss: 0.027727
Epoch: 4, Loss: 0.013306
Epoch: 4, Loss: 0.013166
Epoch: 4, Loss: 0.010040
Epoch: 4, Loss: 0.032870
Epoch: 4, Loss: 0.010741
Epoch: 4, Loss: 0.016809
Epoch: 4, Loss: 0.034235
Epoch: 4, Loss: 0.021551
Epoch: 4, Loss: 0.010542
Epoch: 4, Loss: 0.025413
Epoch: 4, Loss: 0.014254
Epoch: 4, Loss: 0.019865
Epoch: 4, Loss: 0.051387
Epoch: 4, Loss: 0.019933
Epoch: 4, Loss: 0.013534
Epoch: 4, Loss: 0.022058
Epoch: 4, Loss: 0.013166
Epoch: 4, Loss: 0.027430
Epoch: 4, Loss: 0.021431
Epoch: 4, Loss: 0.019225
Epoch: 4, Loss: 0.016859
Epoch: 4, Loss: 0.015152
Epoch: 4, Loss: 0.008568
Epoch: 4, Loss: 0.021791
Epoch: 4, Loss: 0.015998
Epoch: 4, Loss: 0.021424
Epoch: 4, Loss: 0.015685
Epoch: 4, Loss: 0.024452
Epoch: 4, Loss: 0.010634
Epoch: 4, Loss: 0.010325
Epoch: 4, Loss: 0.027455
Epoch: 4, Loss: 0.023856


Epoch: 8, Loss: 0.023989
Epoch: 8, Loss: 0.034135
Epoch: 8, Loss: 0.025515
Epoch: 8, Loss: 0.005622
Epoch: 8, Loss: 0.016998
Epoch: 8, Loss: 0.004387
Epoch: 8, Loss: 0.003562
Epoch: 8, Loss: 0.015082
Epoch: 8, Loss: 0.019754
Epoch: 8, Loss: 0.008440
Epoch: 8, Loss: 0.012654
Epoch: 8, Loss: 0.014601
Epoch: 8, Loss: 0.004279
Epoch: 8, Loss: 0.008896
Epoch: 8, Loss: 0.009156
Epoch: 8, Loss: 0.011961
Epoch: 8, Loss: 0.005751
Epoch: 8, Loss: 0.015179
Epoch: 8, Loss: 0.023014
Epoch: 8, Loss: 0.008555
Epoch: 8, Loss: 0.010561
Epoch: 8, Loss: 0.005550
Epoch: 8, Loss: 0.016098
Epoch: 8, Loss: 0.011278
Epoch: 8, Loss: 0.008379
Epoch: 8, Loss: 0.013069
Epoch: 8, Loss: 0.021893
Epoch: 8, Loss: 0.014761
Epoch: 8, Loss: 0.012266
Epoch: 8, Loss: 0.016863
Epoch: 8, Loss: 0.009229
Epoch: 8, Loss: 0.009660
Epoch: 8, Loss: 0.010422
Epoch: 8, Loss: 0.011213
Epoch: 8, Loss: 0.013431
Epoch: 8, Loss: 0.016112
Epoch: 8, Loss: 0.009798
Epoch: 8, Loss: 0.025606
Epoch: 8, Loss: 0.010105
Epoch: 8, Loss: 0.011880


Epoch: 12, Loss: 0.010170
Epoch: 12, Loss: 0.005123
Epoch: 12, Loss: 0.008568
Epoch: 12, Loss: 0.012900
Epoch: 12, Loss: 0.009857
Epoch: 12, Loss: 0.007693
Epoch: 12, Loss: 0.007165
Epoch: 12, Loss: 0.004209
Epoch: 12, Loss: 0.012198
Epoch: 12, Loss: 0.007672
Epoch: 12, Loss: 0.014557
Epoch: 12, Loss: 0.012632
Epoch: 12, Loss: 0.008377
Epoch: 12, Loss: 0.005519
Epoch: 12, Loss: 0.013161
Epoch: 12, Loss: 0.012123
Epoch: 12, Loss: 0.011639
Epoch: 12, Loss: 0.004345
Epoch: 12, Loss: 0.003349
Epoch: 12, Loss: 0.008645
Epoch: 12, Loss: 0.008030
Epoch: 12, Loss: 0.014844
Epoch: 12, Loss: 0.010861
Epoch: 12, Loss: 0.009066
Epoch: 12, Loss: 0.013532
Epoch: 12, Loss: 0.007548
Epoch: 12, Loss: 0.012695
Epoch: 12, Loss: 0.003910
Epoch: 12, Loss: 0.005475
Epoch: 12, Loss: 0.005302
Epoch: 12, Loss: 0.006338
Epoch: 12, Loss: 0.019589
Epoch: 12, Loss: 0.004839
Epoch: 12, Loss: 0.021625
Epoch: 12, Loss: 0.013311
Epoch: 12, Loss: 0.022571
Epoch: 12, Loss: 0.002953
Epoch: 12, Loss: 0.014674
Epoch: 12, L

Epoch: 16, Loss: 0.005633
Epoch: 16, Loss: 0.009408
Epoch: 16, Loss: 0.008184
Epoch: 16, Loss: 0.006743
Epoch: 16, Loss: 0.007087
Epoch: 16, Loss: 0.013512
Epoch: 16, Loss: 0.014774
Epoch: 16, Loss: 0.003035
Epoch: 16, Loss: 0.008752
Epoch: 16, Loss: 0.006441
Epoch: 16, Loss: 0.007314
Epoch: 16, Loss: 0.003435
Epoch: 16, Loss: 0.006320
Epoch: 16, Loss: 0.007881
Epoch: 16, Loss: 0.006892
Epoch: 16, Loss: 0.004675
Epoch: 16, Loss: 0.004861
Epoch: 16, Loss: 0.004176
Epoch: 16, Loss: 0.005443
Epoch: 16, Loss: 0.007656
Epoch: 16, Loss: 0.004299
Epoch: 16, Loss: 0.012046
Epoch: 16, Loss: 0.012885
Epoch: 16, Loss: 0.009662
Epoch: 16, Loss: 0.010548
Epoch: 16, Loss: 0.003884
Epoch: 16, Loss: 0.006001
Epoch: 16, Loss: 0.009561
Epoch: 16, Loss: 0.019247
Epoch: 16, Loss: 0.019203
Epoch: 16, Loss: 0.013306
Epoch: 16, Loss: 0.003682
Epoch: 16, Loss: 0.004303
Epoch: 16, Loss: 0.005715
Epoch: 16, Loss: 0.016024
Epoch: 16, Loss: 0.005089
Epoch: 16, Loss: 0.007069
Epoch: 16, Loss: 0.010628
Epoch: 16, L

In [6]:
for step, batch in enumerate(dataloader): break
cond = batch[:, :-1].to(device) 
samples = diffusion.conditional_sample(cond, model, image_size=1, batch_size=128)

sampling loop time step:   0%|          | 0/10000 [00:00<?, ?it/s]

In [7]:
cond = cond.detach().cpu().numpy()
trial_samples = cond * scale + bias
best_model.predict(trial_samples)[0]

array([[-0.842015  ],
       [-0.8195985 ],
       [ 0.98319817],
       [ 0.9791057 ],
       [-0.42925155],
       [-0.03757697],
       [ 0.14848042],
       [ 0.41027284],
       [ 0.8430635 ],
       [-0.45497656],
       [-0.05592513],
       [ 0.9832127 ],
       [ 0.98097944],
       [ 0.9548286 ],
       [-0.21146631],
       [-0.25092345],
       [ 0.34486234],
       [-0.38975686],
       [ 0.98119617],
       [ 0.82282114],
       [ 0.9799963 ],
       [-0.34440982],
       [ 0.98209167],
       [ 0.5164199 ],
       [ 0.5192437 ],
       [ 0.44996905],
       [ 0.8777826 ],
       [ 0.9787164 ],
       [-0.21505892],
       [-0.12742108],
       [ 0.95991135],
       [-0.3614183 ],
       [-0.3639124 ],
       [ 0.2473414 ],
       [ 0.98543215],
       [-0.19490314],
       [-0.12956715],
       [ 0.7173605 ],
       [ 0.9825808 ],
       [ 0.622939  ],
       [-0.17044902],
       [ 0.95837593],
       [-0.0820961 ],
       [-0.13803995],
       [ 0.98026514],
       [ 0

In [41]:
for i in range(1):
    obs = env.reset()
    print(obs.shape)
    done = False 
    while not done:
        #action, _= best_model.predict(obs)
        print(obs)
        cond = torch.tensor((obs-bias)/scale).to(device)
        samples = diffusion.conditional_sample(cond, model, image_size=1, batch_size=1)
        action = samples[-1]
        obs, reward, done, info = env.step(action)
        if done:
            print("reward at the end of the episode : ", reward)

(1, 2)
[[-0.56191593  0.        ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-5.616663e-01  2.496418e-04]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-5.6114668e-01  5.1966263e-04]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5603324   0.00081423]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5592276   0.00110481]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.55781037  0.00141719]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5560709   0.00173946]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5539894   0.00208149]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.551578    0.00241144]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5488104   0.00276758]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.54568726  0.00312314]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.54220486  0.00348239]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5383605   0.00384437]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5341741   0.00418638]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.52961636  0.00455776]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.52469903  0.0049173 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5194593   0.00523972]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5138948   0.00556453]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5080315   0.00586331]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.50187176  0.00615972]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.49545074  0.00642102]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.48878884  0.00666189]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.4819407   0.00684814]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.47493246  0.00700821]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.46782294  0.00710953]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.46064037  0.00718257]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.45342544  0.00721494]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.44624555  0.00717987]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.43910876  0.00713679]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.4321018   0.00700697]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.425291    0.00681079]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.41875732  0.00653369]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.41256732  0.00618999]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.40680546  0.00576187]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.40155065  0.00525481]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.39689493  0.00465572]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.39291078  0.00398415]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.38968927  0.00322151]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.3873243   0.00236496]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.3859216  0.0014027]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-3.8556677e-01  3.5481242e-04]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.38634816 -0.0007814 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.38834807 -0.0019999 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.39164156 -0.00329348]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.39629996 -0.00465841]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.402366   -0.00606605]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.40990046 -0.00753444]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.4189156  -0.00901514]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.42939192 -0.01047632]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.44130164 -0.01190973]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.45459762 -0.01329597]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.46921033 -0.01461269]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.4850279  -0.01581758]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.50194967 -0.01692178]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5198467  -0.01789702]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.53856224 -0.01871554]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5579445  -0.01938226]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5777946  -0.01985011]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5979055  -0.02011088]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.61810535 -0.02019984]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.6381813 -0.020076 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.657938   -0.01975667]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.6771817  -0.01924369]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.6957461 -0.0185644]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.713418   -0.01767189]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.730052   -0.01663396]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.74545914 -0.01540717]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.75946903 -0.0140099 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.77197284 -0.01250382]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.7828049  -0.01083208]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.79186016 -0.00905525]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.79902935 -0.00716917]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.8042236  -0.00519423]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.80735254 -0.00312893]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.80836135 -0.0010088 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.807194    0.00116734]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.8037642   0.00342975]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.79799736  0.00576689]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.78981084  0.00818649]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.7791236   0.01068723]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.7658829   0.01324068]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.7500193   0.01586362]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.73151684  0.0185025 ]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.71037644  0.02114039]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.68664324  0.02373323]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.66040343  0.02623979]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.6317808   0.02862264]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.60093087  0.03084996]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.56807977  0.03285109]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.5334506   0.03462915]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.49726897  0.03618164]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.45982006  0.03744891]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.4213747   0.03844535]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.38222137  0.03915335]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.34261483  0.03960653]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.30285466  0.03976018]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.26315278  0.03970189]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.22375643  0.03939635]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.18483634  0.03892009]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.14656147  0.03827487]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.10906814  0.03749333]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.07245228  0.03661586]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.0368095   0.03564278]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[-0.00217848  0.03463102]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.0314226  0.03360109]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.06402238 0.03259978]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.09565756 0.03163518]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.12636814 0.03071057]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.15625621 0.02988808]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.18539083 0.02913462]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.21388198 0.02849116]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.24183904 0.02795705]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.26938337 0.02754432]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.29666358 0.02728022]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.3238444  0.02718083]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.3510915  0.02724711]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.37857598 0.02748449]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.4064931  0.02791713]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

[[0.4350193  0.02852623]]


sampling loop time step:   0%|          | 0/100 [00:00<?, ?it/s]

reward at the end of the episode :  [99.90146]


In [21]:
cond = torch.tensor(obs).to(device)

  cond = torch.tensor(obs).to(device)


In [22]:
samples = diffusion.conditional_sample(cond, model, image_size=1, batch_size=1)

sampling loop time step:   0%|          | 0/10000 [00:00<?, ?it/s]

In [27]:
samples[-1].detach()

AttributeError: 'numpy.ndarray' object has no attribute 'detach'