In [1]:
import os

# Restrict this notebook to physical GPU 1. It is set in the very first cell,
# before torch is imported, so that every later `.cuda()` call targets that GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [2]:
from tqdm import tqdm_notebook as tqdm
import numpy as np

In [3]:
import torch as T
import torch.nn as nn

In [4]:
import gym

# Build the Reacher task, show its observation/action space shapes, and print the
# return of six purely random-action episodes as a reference point for the
# rewards reported during training.
env = gym.make('Reacher-v2')
print(env.observation_space.shape, env.action_space.shape)

for ep_idx in range(6):
    env.reset()
    ep_return = 0
    done = False

    # Step with uniformly sampled actions until the environment signals the end.
    while not done:
        _, r, done, _ = env.step(env.action_space.sample())
        ep_return += r

    print('Ep %d: %.2f' % (ep_idx+1, ep_return))

(11,) (2,)
Ep 1: -41.53
Ep 2: -40.14
Ep 3: -44.62
Ep 4: -40.81
Ep 5: -43.25
Ep 6: -48.14


In [5]:
class BCO(nn.Module):
    """Behavioral Cloning from Observation model: a policy plus an inverse-dynamics net.

    Both heads share the same 3-layer LeakyReLU MLP shape:

    * ``pol`` maps an observation to an action.
    * ``inv`` maps a pair of observations ``(obs1, obs2)``, concatenated along
      dim 1, to the action predicted to have produced that transition.

    Args:
        env: environment; ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` give the flat observation / action sizes.
        policy: network family. Only ``'mlp'`` is implemented.
        hidden: width of each hidden layer (defaults to the original 32).

    Raises:
        NotImplementedError: for ``policy='cnn'``, which is still a placeholder.
        ValueError: for any other unknown ``policy`` value.
    """

    def __init__(self, env, policy='mlp', hidden=32):
        super(BCO, self).__init__()

        self.policy = policy
        self.act_n = env.action_space.shape[0]

        if self.policy == 'mlp':
            self.obs_n = env.observation_space.shape[0]
            # Built in the same order as before (pol, then inv) so parameter
            # initialisation and state_dict keys ('pol.0.weight', ...) are unchanged.
            self.pol = self._mlp(self.obs_n, self.act_n, hidden)
            self.inv = self._mlp(self.obs_n*2, self.act_n, hidden)

        elif self.policy == 'cnn':
            # Previously a silent `pass`: the module ended up with no `pol`/`inv`
            # and only failed later (empty optimizer / AttributeError). Fail fast.
            raise NotImplementedError("policy='cnn' is not implemented yet")

        else:
            raise ValueError("unknown policy %r; expected 'mlp' or 'cnn'" % (policy,))

    @staticmethod
    def _mlp(in_n, out_n, hidden):
        """Build the Linear-LeakyReLU-Linear-LeakyReLU-Linear stack used by both heads."""
        return nn.Sequential(nn.Linear(in_n, hidden), nn.LeakyReLU(),
                             nn.Linear(hidden, hidden), nn.LeakyReLU(),
                             nn.Linear(hidden, out_n))

    def pred_act(self, obs):
        """Return the policy's action for ``obs`` (last dim must equal ``obs_n``)."""
        out = self.pol(obs)

        return out

    def pred_inv(self, obs1, obs2):
        """Predict the action that led from ``obs1`` to ``obs2``.

        The inputs are concatenated along dim 1, so they are expected to be
        batched 2-D tensors of shape ``(batch, obs_n)``.
        """
        obs = T.cat([obs1, obs2], dim=1)
        out = self.inv(obs)

        return out

POLICY = 'mlp'

# Instantiate the policy + inverse-dynamics networks, then move them to the GPU.
model = BCO(env, policy=POLICY)
model = model.cuda()

In [6]:
from torch.utils.data import Dataset, DataLoader

class DS_Inv(Dataset):
    """Flat dataset of transitions for training the inverse-dynamics model.

    ``trajs`` is an iterable of trajectories, each an iterable of
    ``(obs, act, new_obs)`` triples. Items come back re-ordered as
    ``(obs, new_obs, act)``: the two model inputs first, the target last.
    """

    def __init__(self, trajs):
        # Flatten every trajectory into one list of [obs, new_obs, act] records.
        self.dat = [[o, o_next, a]
                    for traj in trajs
                    for o, a, o_next in traj]

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        o, o_next, a = self.dat[idx]
        return o, o_next, a

class DS_Policy(Dataset):
    """Dataset of ``(obs, act)`` pairs used to behaviour-clone the policy.

    ``traj`` is a single flat iterable of two-element ``(obs, act)`` records.
    """

    def __init__(self, traj):
        self.dat = [[o, a] for o, a in traj]

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        o, a = self.dat[idx]
        return o, a

In [7]:
import pickle

# Expert demonstrations, presumably a list of trajectories of (obs, act, new_obs)
# triples (the layout DS_Inv unpacks) -- TODO confirm against the demo generator.
# The training loop discards the demo actions and only uses the observation pairs.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('Demo/demo_reacher.pkl', 'rb') as fh:   # `with` closes the handle
    trajs_demo = pickle.load(fh)
ld_demo = DataLoader(DS_Inv(trajs_demo), batch_size=100)

# Sanity check: number of demo batches and the shape of the first batch.
print(len(ld_demo))
for obs1, obs2, _ in ld_demo:
    print(obs1.shape, obs2.shape)

    break

50
torch.Size([100, 11]) torch.Size([100, 11])


In [8]:
# Training hyper-parameters.
EPOCHS = 20    # number of BCO outer iterations
M = 5000       # minimum environment steps collected per iteration (whole episodes)

EPS = 0.9      # initial probability of taking a random action while collecting
DECAY = 0.5    # multiplicative decay of that probability after every iteration

# One MSE regression objective and one Adam optimizer shared by both the
# inverse-dynamics model and the policy.
loss_func = nn.MSELoss()
loss_func.cuda()
optim = T.optim.Adam(params=model.parameters(), lr=0.0005)

In [9]:
# BCO training loop. Each iteration:
#   1. collects >= M environment steps (whole episodes) with an eps-random policy,
#   2. fits the inverse-dynamics model on ALL transitions collected so far,
#   3. labels the demo observation pairs with actions inferred by that model,
#   4. behaviour-clones the policy on the labelled demo pairs,
#   5. checkpoints the model and decays the exploration probability.
trajs_inv = []

# Local copy of the exploration rate. Mutating the global EPS meant a re-run of
# this cell silently started from an already-decayed value.
eps = EPS

# T.save does not create directories; make sure the checkpoint folder exists.
os.makedirs('Model', exist_ok=True)

for e in tqdm(range(EPOCHS)):

    # step1, generate inverse samples
    cnt = 0   # environment steps taken this iteration
    epn = 0   # episodes finished this iteration

    rews = 0

    while True:
        traj = []
        rew = 0

        obs = env.reset()
        while True:
            # Inference only: no autograd graph is needed for the rollout.
            with T.no_grad():
                inp = T.from_numpy(obs).unsqueeze(0).float().cuda()
                out = model.pred_act(inp).cpu().numpy()

            # With probability eps act randomly, otherwise follow the current policy.
            if np.random.rand() >= eps:
                act = out[0]
            else:
                act = env.action_space.sample()

            new_obs, r, done, _ = env.step(act)

            traj.append([obs, act, new_obs])
            obs = new_obs
            rew += r

            cnt += 1

            if done:
                rews += rew
                trajs_inv.append(traj)

                epn += 1

                break

        # Checked only between episodes, so episodes are never truncated and
        # epn >= 1 for the division below.
        if cnt >= M:
            break

    rews /= epn
    print('Ep %d: reward=%.2f' % (e+1, rews))

    # step2, update inverse model
    ld_inv = DataLoader(DS_Inv(trajs_inv), batch_size=32, shuffle=True)

    with tqdm(ld_inv) as TQ:
        ls_ep = 0

        for obs1, obs2, act in TQ:
            out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
            # Cast the target so its dtype matches the float32 network output.
            ls_bh = loss_func(out, act.float().cuda())

            optim.zero_grad()
            ls_bh.backward()
            optim.step()

            ls_bh = ls_bh.item()
            TQ.set_postfix(loss_inv='%.3f' % (ls_bh))
            ls_ep += ls_bh

        ls_ep /= len(TQ)
        print('Ep %d: loss_inv=%.3f' % (e+1, ls_ep))

    # step3, predict inverse action for demo samples (demo actions are ignored)
    traj_policy = []

    with T.no_grad():
        for obs1, obs2, _ in ld_demo:
            out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())

            # Pair rows using the batch's real length. The old hard-coded
            # range(100) raised IndexError on a final batch smaller than 100.
            traj_policy.extend(zip(obs1.cpu().numpy(), out.cpu().numpy()))

    # step4, update policy via demo samples
    ld_policy = DataLoader(DS_Policy(traj_policy), batch_size=32, shuffle=True)

    with tqdm(ld_policy) as TQ:
        ls_ep = 0

        for obs, act in TQ:
            out = model.pred_act(obs.float().cuda())
            ls_bh = loss_func(out, act.float().cuda())

            optim.zero_grad()
            ls_bh.backward()
            optim.step()

            ls_bh = ls_bh.item()
            TQ.set_postfix(loss_policy='%.3f' % (ls_bh))
            ls_ep += ls_bh

        ls_ep /= len(TQ)
        print('Ep %d: loss_policy=%.3f' % (e+1, ls_ep))

    # step5, save model
    T.save(model.state_dict(), 'Model/model_reacher_%d.pt' % (e+1))

    eps *= DECAY

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

Ep 1: reward=-41.75


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 1: loss_inv=0.267


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 1: loss_policy=0.007
Ep 2: reward=-34.40


HBox(children=(IntProgress(value=0, max=313), HTML(value='')))

Ep 2: loss_inv=0.083


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 2: loss_policy=0.004
Ep 3: reward=-43.24


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Ep 3: loss_inv=0.052


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 3: loss_policy=0.003
Ep 4: reward=-20.75


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Ep 4: loss_inv=0.032


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 4: loss_policy=0.003
Ep 5: reward=-20.34


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))

Ep 5: loss_inv=0.022


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 5: loss_policy=0.002
Ep 6: reward=-16.03


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Ep 6: loss_inv=0.016


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 6: loss_policy=0.002
Ep 7: reward=-16.49


HBox(children=(IntProgress(value=0, max=1094), HTML(value='')))

Ep 7: loss_inv=0.013


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 7: loss_policy=0.002
Ep 8: reward=-13.95


HBox(children=(IntProgress(value=0, max=1250), HTML(value='')))

Ep 8: loss_inv=0.010


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 8: loss_policy=0.002
Ep 9: reward=-12.96


HBox(children=(IntProgress(value=0, max=1407), HTML(value='')))

Ep 9: loss_inv=0.009


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 9: loss_policy=0.002
Ep 10: reward=-12.78


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

Ep 10: loss_inv=0.008


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 10: loss_policy=0.002
Ep 11: reward=-13.51


HBox(children=(IntProgress(value=0, max=1719), HTML(value='')))

Ep 11: loss_inv=0.007


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 11: loss_policy=0.002
Ep 12: reward=-12.55


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

Ep 12: loss_inv=0.006


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 12: loss_policy=0.002
Ep 13: reward=-11.60


HBox(children=(IntProgress(value=0, max=2032), HTML(value='')))

Ep 13: loss_inv=0.005


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 13: loss_policy=0.002
Ep 14: reward=-12.66


HBox(children=(IntProgress(value=0, max=2188), HTML(value='')))

Ep 14: loss_inv=0.005


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 14: loss_policy=0.003
Ep 15: reward=-12.44


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

Ep 15: loss_inv=0.004


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 15: loss_policy=0.003
Ep 16: reward=-12.20


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Ep 16: loss_inv=0.004


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 16: loss_policy=0.003
Ep 17: reward=-11.53


HBox(children=(IntProgress(value=0, max=2657), HTML(value='')))

Ep 17: loss_inv=0.003


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 17: loss_policy=0.002
Ep 18: reward=-12.09


HBox(children=(IntProgress(value=0, max=2813), HTML(value='')))

Ep 18: loss_inv=0.003


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 18: loss_policy=0.002
Ep 19: reward=-11.29


HBox(children=(IntProgress(value=0, max=2969), HTML(value='')))

Ep 19: loss_inv=0.003


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 19: loss_policy=0.002
Ep 20: reward=-11.78


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Ep 20: loss_inv=0.002


HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

Ep 20: loss_policy=0.002

