In [57]:
import matplotlib.pyplot as plt

import torch

from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

from diffusion_framework.ddpm_1d import DDPM_1d
from diffusion_framework.nets import ErrorNet, CondErrorNet

from torch.optim import Adam
import numpy as np

import os
import gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
import platform
if platform.system() == "Linux":
    os.environ['PYOPENGL_PLATFORM'] = 'egl'

### Select environment
env_id = "MountainCarContinuous-v0"
# Note that the algorithm is SAC

gym_env = gym.make(env_id)
max_obs_values = gym_env.observation_space.high
min_obs_values = gym_env.observation_space.low

max_act_value = gym_env.action_space.high
min_act_value = gym_env.action_space.low

bias = max_obs_values + min_obs_values
bias = bias / 2
scale = max_obs_values - min_obs_values
scale = scale / 2

env = make_vec_env(env_id, n_envs=1)
#best_model = SAC.load('/home/sai-admin/advanced_ml_project/rl-baselines3-zoo/rl-trained-agents/sac/' + 
#                      env_id + '_1/' + 
#                      env_id + '.zip', env=env)

In [58]:
# Define data
expert_demo_path = 'expert/logs/expert_demonstrations.npy'

dataset = np.load(expert_demo_path, allow_pickle=True)
dataset = torch.tensor(dataset, dtype=torch.float32)

batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define models
timesteps = 1000
diffusion = DDPM_1d(timesteps)

In [59]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = ErrorNet(dim=3)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
model

cuda:0


ErrorNet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=64, out_features=128, bias=True)
    (2): GELU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): SiLU()
    (5): Linear(in_features=64, out_features=6, bias=True)
  )
  (state_mlp): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): GELU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): GELU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): GELU()
    (6): Linear(in_features=64, out_features=3, bias=True)
  )
  (res_mlp): Sequential(
    (0): Linear(in_features=6, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=3, bias=True)
  )
  (final_mlp): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): GELU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): GELU()
    (4): Linear(in_features=128, out_features=

In [60]:
path = '/tmp/adv_ml/diffusion_models/mountaincar/joint/thousand_steps.pth'
model.load_state_dict(torch.load(path))

<All keys matched successfully>

In [None]:
epochs = 20

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        batch_size = batch.shape[0]
        batch = batch.to(device)

        # Algorithm 1 line 3: sample t uniformally for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = diffusion.p_losses(model, batch, t, loss_type="huber")

        if step % 100 == 0:
            print("Epoch: %d, Loss: %f" %(epoch, loss.item()))

        loss.backward()
        optimizer.step()
        
path = '/tmp/adv_ml/diffusion_models/mountaincar/joint/thousand_steps.pth'
torch.save(model.state_dict(), path)

In [None]:


for timestep in reversed(range(timesteps)):
    diffusion.p_sample(model, x, timestep, )

In [61]:
mask = torch.tensor([[1.0, 1.0, 0.0]], device=device)
a = torch.randn(1,1).to(device)
a = torch.clamp(a, -1, 1)


for i in range(1):
    obs = env.reset()
    print(obs.shape)
    done = False 
    while not done:
        print(obs)
        # inpainting
        cond = torch.tensor((obs-bias)/scale).to(device)
        x = torch.cat((cond,a),-1)
        g = torch.randn(x.shape, device=device)
        for i in reversed(range(0, timesteps)):
            t = torch.tensor([i], device=device)
            x_noisy = diffusion.q_sample(x, t)
            g = diffusion.p_sample(model, g, t, i)
            g = x_noisy * mask + g * (1 - mask)
        
        g = g.cpu().numpy()
        action = [[g[0][-1]]]
        obs, reward, done, info = env.step(action)
        if done:
            print("reward at the end of the episode : ", reward)

(1, 2)
[[-0.5372814  0.       ]]
[[-0.5365116   0.00076979]]
[[-0.5353332   0.00117837]]
[[-0.53403455  0.00129864]]
[[-0.5327331   0.00130148]]
[[-0.53124607  0.00148699]]
[[-0.52945113  0.00179494]]
[[-0.5273889   0.00206225]]
[[-0.52486205  0.00252679]]
[[-0.5222018  0.0026603]]
[[-0.5192413   0.00296053]]
[[-0.5158618   0.00337946]]
[[-0.5129677   0.00289411]]
[[-0.50957793  0.00338978]]
[[-0.5062923   0.00328564]]
[[-0.5033015   0.00299081]]
[[-0.49994075  0.00336075]]
[[-0.4971649   0.00277585]]
[[-0.49463698  0.00252792]]
[[-0.49232212  0.00231488]]
[[-0.48970488  0.00261725]]
[[-0.4871277   0.00257719]]
[[-0.48489556  0.00223214]]
[[-0.48213553  0.00276003]]
[[-0.47929072  0.00284481]]
[[-0.47652882  0.00276189]]
[[-0.4735205   0.00300834]]
[[-0.47019964  0.00332085]]
[[-0.46699584  0.0032038 ]]
[[-0.46452177  0.00247406]]
[[-0.46181217  0.00270961]]
[[-0.45985064  0.00196154]]
[[-0.4577859   0.00206472]]
[[-0.4562404   0.00154552]]
[[-0.45536652  0.00087387]]
[[-4.5557454e-01 

In [32]:

obs = env.reset()
cond = torch.tensor((obs-bias)/scale).to(device)

x = torch.cat((cond,a),-1)
g = torch.randn(x.shape, device=device)
for i in reversed(range(0, timesteps)):
    t = torch.tensor([i], device=device)
    x_noisy = diffusion.q_sample(x, t)
    g = diffusion.p_sample(model, g, t, i)
    g = x_noisy * mask + g * (1 - mask)
print(g)

tensor([[-0.2395,  0.0125,  0.3885]], device='cuda:0')


In [40]:
x_noisy = diffusion.q_sample(x, t)

In [54]:
g = g.cpu().numpy()
np.array([[g[0][-1]]])

array([[0.38854772]], dtype=float32)

In [48]:
x_noisy * (mask)

tensor([[-0.2377,  0.0009,  0.0000]], device='cuda:0')

tensor([1.0000e-04, 1.1992e-04, 1.3984e-04, 1.5976e-04, 1.7968e-04, 1.9960e-04,
        2.1952e-04, 2.3944e-04, 2.5936e-04, 2.7928e-04, 2.9920e-04, 3.1912e-04,
        3.3904e-04, 3.5896e-04, 3.7888e-04, 3.9880e-04, 4.1872e-04, 4.3864e-04,
        4.5856e-04, 4.7848e-04, 4.9840e-04, 5.1832e-04, 5.3824e-04, 5.5816e-04,
        5.7808e-04, 5.9800e-04, 6.1792e-04, 6.3784e-04, 6.5776e-04, 6.7768e-04,
        6.9760e-04, 7.1752e-04, 7.3744e-04, 7.5736e-04, 7.7728e-04, 7.9720e-04,
        8.1712e-04, 8.3704e-04, 8.5696e-04, 8.7688e-04, 8.9680e-04, 9.1672e-04,
        9.3664e-04, 9.5656e-04, 9.7648e-04, 9.9640e-04, 1.0163e-03, 1.0362e-03,
        1.0562e-03, 1.0761e-03, 1.0960e-03, 1.1159e-03, 1.1358e-03, 1.1558e-03,
        1.1757e-03, 1.1956e-03, 1.2155e-03, 1.2354e-03, 1.2554e-03, 1.2753e-03,
        1.2952e-03, 1.3151e-03, 1.3350e-03, 1.3550e-03, 1.3749e-03, 1.3948e-03,
        1.4147e-03, 1.4346e-03, 1.4546e-03, 1.4745e-03, 1.4944e-03, 1.5143e-03,
        1.5342e-03, 1.5542e-03, 1.5741e-