Rough runthrough of GAIL for dialog generation, getting close...

In [1]:
from collections import deque  
#code for training 
import torch
import numpy as np

import sys
sys.path.append('../src')
from models import *
from dialog_environment import *
import torch.optim as optim
import math
import torch
from torch.distributions import Normal

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
env = DialogEnvironment()

# Normally args but not here :-)
seed = 0
render = False
gamma = 0.99
lamda = .98

train_discrim_flag = True
learning_rate = 3e-4
clip_param = .2
discrim_update_num = 2
actor_critic_update_num = 10
l2_rate = 1e-3 # weight decay
total_sample_size = 256 # total num of state-actions to collect before learning
batch_size = 32
suspend_accu_exp = 1. # do not need to be this high typically, but seems likely it has to be for a simple env like mountain car cont.
suspend_accu_gen = 1.
max_iter_num = 500


actor = Actor(hidden_size=3,num_layers=3)
critic = Critic(hidden_size=1,num_layers=3)
discrim = Discriminator(input_size = 300, hidden_size=1,device='cuda',num_layers=3)
actor.to(device), critic.to(device), discrim.to(device)
actor_optim = optim.Adam(actor.parameters(), lr=learning_rate)
critic_optim = optim.Adam(critic.parameters(), lr=learning_rate, 
                          weight_decay=l2_rate) 
discrim_optim = optim.Adam(discrim.parameters(), lr=learning_rate)



In [21]:


def subsample(data, target, n=15):
    return [x[::n] for x in data], [y[::n] for y in target]


def get_action(mu, std):
    action = torch.normal(mu, std)
    action = action.data.numpy()
    return action


def get_entropy(mu, std):
    dist = Normal(mu, std)
    entropy = dist.entropy().mean()
    return entropy

def log_prob_density(x, mu, std):
    log_prob_density = -(x - mu).pow(2) / (2 * std.pow(2)) \
                     - 0.5 * math.log(2 * math.pi)
    return log_prob_density.sum(1, keepdim=True)

def get_reward(discrim, state, action):
    """
    The reward function according to irl. It's log D(s,a). 
    
    Reward is higher the closer this is to 0, because the more similar it is to an expert action. :
    Is quite close to imitation learning, but hope here is that with such a large number of expert demonstrations and entropy bonuses etc. it learns more than direct imitation. 
    """

    action = torch.Tensor(action).to(device)# turn state into a tensor if not already

    with torch.no_grad():
        return -math.log(discrim(state.resize(1,60,300),action.resize(1,60,300))[0].item())

def save_checkpoint(state, filename):
    return


def train_discrim(discrim, memory, discrim_optim, discrim_update_num, clip_param):
    """
    Training the discriminator. 

    Use binary cross entropy to classify whether 
    or not a sequence was predicted by the expert (real data) or actor. 
    """
    states = torch.stack([memory[i][0] for i in range(len(memory))])
    actions = torch.stack([memory[i][1] for i in range(len(memory))])
    rewards = [memory[i][2] for i in range(len(memory))]

    masks = [memory[i][2] for i in range(len(memory))]
    expert_actions = torch.stack([memory[i][4] for i in range(len(memory))])

    criterion = torch.nn.BCELoss() # classify

    for _ in range(discrim_update_num):

        learner = discrim(states, actions) #pass (s,a) through discriminator

        # TODO
       # demonstrations = torch.Tensor([states, expert_actions]) # pass (s,a) of expert through discriminator
        expert = discrim(states,expert_actions) #discrimator "guesses" whether or not these 
        # actions came from expert or learner
        discrim_loss = criterion(learner, torch.ones((states.shape[0], 1)).to(device)) + \
                        criterion(expert, torch.zeros((states.shape[0], 1)).to(device))
                # discrim loss: predict agent is all wrong, get as close to 0, and predict expert is 1, getting as close to 1 as possible. 
        discrim_optim.zero_grad() # gan loss, it tries to always get it right. 
        discrim_loss.backward()
        discrim_optim.step()
            # take these steps, do it however many times specified. 
        #return discrim(states,expert_actions) , discrim(states,actions)
    expert_acc = ((discrim(states,expert_actions) < 0.5).float()).mean() #how often it realized the fake examples were fake
    learner_acc = ((discrim(states,actions) > 0.5).float()).mean() #how often if predicted expert correctly. 

    return expert_acc, learner_acc # accuracy, it's the same kind, but because imbalanced better to look at separately. 
 

In [22]:
def train_actor_critic(actor, critic, memory, actor_optim, critic_optim, actor_critic_update_num, batch_size, clip_param):
    """
    Take a PPO step or two to improve the actor critic model,  using GAE to estimate returns. 
    
    In our case each trajectory it most one step, so the value function will have to do. 
    
    
    """
    # tuple of a regular old RL problem, but now reward is what the discriminator says. 
    states = torch.stack([memory[i][0] for i in range(len(memory))])
    actions = torch.stack([memory[i][1] for i in range(len(memory))])
    rewards = [memory[i][2] for i in range(len(memory))]
    masks = [memory[i][2] for i in range(len(memory))]
    # compute value of what happened, see if what we can get us better. 
    old_values = critic(states)

    #GAE aka estimate of Value + actual return roughtly 
    returns, advants = get_gae(rewards, masks, old_values, gamma, lamda)
    
    # pass states through actor, get corresponding actions
    mu, std = actor(states)
    # new mus and stds? 
    old_policy = log_prob_density(actions, mu, std) # sum of log probability
    # of old actions

    criterion = torch.nn.MSELoss()
    n = len(states)
    arr = np.arange(n)

    for _ in range(actor_critic_update_num):
        np.random.shuffle(arr)

        for i in range(n // batch_size): 
            batch_index = arr[batch_size * i : batch_size * (i + 1)]
            #batch_index = torch.LongTensor(batch_index)
            
            inputs = states[batch_index]
            actions_samples = actions[batch_index]
            returns_samples = returns.unsqueeze(1)[batch_index].to(device)
            advants_samples = advants.unsqueeze(1)[batch_index].to(device)
            oldvalue_samples = old_values[batch_index].detach()
        
        
            values = critic(inputs) #
            clipped_values = oldvalue_samples + \
                             torch.clamp(values - oldvalue_samples,
                                         -clip_param, 
                                         clip_param)
            critic_loss1 = criterion(clipped_values, returns_samples)
            critic_loss2 = criterion(values, returns_samples)
            critic_loss = torch.max(critic_loss1, critic_loss2).mean()

            loss, ratio, entropy = surrogate_loss(actor, advants_samples, inputs,
                                         old_policy.detach(), actions_samples,
                                         batch_index)
            clipped_ratio = torch.clamp(ratio,
                                        1.0 - clip_param,
                                        1.0 + clip_param)
            clipped_loss = clipped_ratio * advants_samples
            actor_loss = -torch.min(loss, clipped_loss).mean()
            #print(actor_loss,critic_loss,entropy)
           # return actor_loss, critic_loss, entropy
            loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy #entropy bonus to promote exploration.

            actor_optim.zero_grad()
            loss.backward()
            actor_optim.step()

           # critic_optim.zero_grad()
           # loss.backward() 
            critic_optim.step()

def get_gae(rewards, masks, values, gamma, lamda):
    """
    How much better a particular action is in a particular state. 
    
    Uses reward of current action + value function of that state-action pair, discount factor gamma, and then lamda to compute. 
    """
    rewards = torch.Tensor(rewards)
    masks = torch.Tensor(masks)
    returns = torch.zeros_like(rewards)
    advants = torch.zeros_like(rewards)
    
    running_returns = 0
    previous_value = 0
    running_advants = 0

    for t in reversed(range(0, len(rewards))): #for LL, only ever one step :-)
        running_returns = rewards[t] + (gamma * running_returns * masks[t])
        returns[t] = running_returns

        running_delta = rewards[t] + (gamma * previous_value * masks[t]) - \
                                        values.data[t]
        previous_value = values.data[t]
        
        running_advants = running_delta + (gamma * lamda * \
                                            running_advants * masks[t])
        advants[t] = running_advants

    advants = (advants - advants.mean()) / advants.std()
    return returns, advants

def surrogate_loss(actor, advants, states, old_policy, actions, batch_index):
    """
    The loss for PPO. Re-run through network, recomput policy from states
    and see if this surrogate ratio is better. If it is, use as proximal policy update. It's very close to prior policy, but def better. 
    
    Not sure this actually works though. Should not the new mu and stds be used to draw,
    
        When do we use get_action? Only once in main, I think it should be for all? 
    """
    mu, std = actor(states)
    new_policy = log_prob_density(actions, mu, std)
    old_policy = old_policy[batch_index]

    ratio = torch.exp(new_policy - old_policy)
    surrogate_loss = ratio * advants
    entropy = get_entropy(mu, std)

    return surrogate_loss, ratio, entropy

In [23]:
episodes = 0
train_discrim_flag = True
total_sample_size = 256
max_iter_num = 500
render=False

In [None]:
# Now what we came for...

for iter in range(max_iter_num):
    actor.eval(), critic.eval()
    memory = deque()

    steps = 0
    scores = []

    while steps < total_sample_size: 
        state, expert_action, raw_state, raw_expert_action = env.reset()
        state = state.to(device)
        expert_action = expert_action.to(device)
        score = 0
            
        #print("breakpt")
       # break
        
        
        for _ in range(10000): 
            if render:
                print(raw_state, raw_expert_action)
            steps += 1

            #TODO

            mu, std = actor(state.resize(1,60,300))
            action = get_action(mu.cpu(), std.cpu())[0]
            done= env.step(action)
            irl_reward = get_reward(discrim, state, action)
            if done:
                mask = 0
            else:
                mask = 1

            memory.append([state, torch.from_numpy(action).to(device), irl_reward, mask,expert_action])

            sys.exit

            score += irl_reward

            if done:
                break

        episodes += 1
        scores.append(score)

    score_avg = np.mean(scores)
    print('{}:: {} episode score is {:.2f}'.format(iter, episodes, score_avg))

    actor.train(), critic.train(), discrim.train()
    if train_discrim_flag:
        expert_acc, learner_acc = train_discrim(discrim, memory, discrim_optim, discrim_update_num, clip_param)
        print("Expert: %.2f%% | Learner: %.2f%%" % (expert_acc * 100, learner_acc * 100))
        if expert_acc > suspend_accu_exp and learner_acc > suspend_accu_gen:
            train_discrim_flag = False
    train_actor_critic(actor, critic, memory, actor_optim, critic_optim, actor_critic_update_num, batch_size, clip_param)

    if iter % 100:
        score_avg = int(score_avg)

        print("score_avg:",score_avg)

0:: 256 episode score is 0.77
Expert: 100.00% | Learner: 0.00%
1:: 512 episode score is 0.75
Expert: 100.00% | Learner: 0.00%
score_avg: 0
2:: 768 episode score is 0.73
Expert: 100.00% | Learner: 0.00%
score_avg: 0
3:: 1024 episode score is 0.71
Expert: 100.00% | Learner: 0.00%
score_avg: 0
4:: 1280 episode score is 0.70
Expert: 0.00% | Learner: 100.00%
score_avg: 0
5:: 1536 episode score is 0.69
Expert: 0.00% | Learner: 100.00%
score_avg: 0
6:: 1792 episode score is 0.68
Expert: 0.00% | Learner: 100.00%
score_avg: 0
7:: 2048 episode score is 0.67
Expert: 0.00% | Learner: 100.00%
score_avg: 0
8:: 2304 episode score is 0.66
Expert: 0.00% | Learner: 100.00%
score_avg: 0
9:: 2560 episode score is 0.66
Expert: 0.00% | Learner: 100.00%
score_avg: 0
10:: 2816 episode score is 0.66
Expert: 0.00% | Learner: 100.00%
score_avg: 0
11:: 3072 episode score is 0.67
Expert: 0.00% | Learner: 100.00%
score_avg: 0
12:: 3328 episode score is 0.67
Expert: 0.00% | Learner: 100.00%
score_avg: 0
13:: 3584 ep

In [19]:
%debug

> [0;32m/scratch/nsk367/anaconda3/envs/irl/lib/python3.8/site-packages/torch/nn/functional.py[0m(2660)[0;36mmse_loss[0;34m()[0m
[0;32m   2658 [0;31m[0;34m[0m[0m
[0m[0;32m   2659 [0;31m    [0mexpanded_input[0m[0;34m,[0m [0mexpanded_target[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mbroadcast_tensors[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mtarget[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 2660 [0;31m    [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0m_C[0m[0;34m.[0m[0m_nn[0m[0;34m.[0m[0mmse_loss[0m[0;34m([0m[0mexpanded_input[0m[0;34m,[0m [0mexpanded_target[0m[0;34m,[0m [0m_Reduction[0m[0;34m.[0m[0mget_enum[0m[0;34m([0m[0mreduction[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2661 [0;31m[0;34m[0m[0m
[0m[0;32m   2662 [0;31m[0;34m[0m[0m
[0m
ipdb> u
> [0;32m/scratch/nsk367/anaconda3/envs/irl/lib/python3.8/site-packages/torch/nn/modules/loss.py[0m(446)[0;36mforward[0;34m()[0m
[0;32m    44

In [None]:
mu, std = actor(state.resize(1,60,300))

In [None]:
action = get_action(mu, std)[0]


In [None]:
action