In [None]:
import torch
import sys
sys.path.append('../src')

import torch.nn.functional as F

from models.utils import get_model
from models.config import TOKENS_RAW_CUTOFF
from models.seq2seqattn import init_weights, EncRnn, DecRnn, Seq2SeqAttn

In [None]:
# Load the pretrained word2vec model; its vocabulary defines all token indices.
w2v_model = get_model()
# token -> index, following the embedding-matrix row order (wv.index2word)
w2ind = {token: token_index for token_index, token in enumerate(w2v_model.wv.index2word)}
# The vocabulary is expected to be sorted by frequency.
assert w2v_model.vocabulary.sorted_vocab, 'expected a frequency-sorted word2vec vocabulary'
# (word, count) pairs, most frequent first (reverse=True keeps ties stable).
word_counts = sorted(
    ((word, vocab_obj.count) for word, vocab_obj in w2v_model.wv.vocab.items()),
    key=lambda pair: pair[1],
    reverse=True,
)
# index -> token, the exact inverse of w2ind. Re-sorting the counts is not
# guaranteed to order equal-count words the same way as index2word, which
# would make decoded indices map to the wrong words.
words = list(w2v_model.wv.index2word)
# sentence-marker token indices
sos_ind = w2ind['<sos>']
eos_ind = w2ind['<eos>']
# decoded sequence length: 5 content tokens plus <sos> and <eos>
SEQ_LEN = 5 + 2
# "." doubles as the padding token for now (noted as index 0 when written)
TRG_PAD_IDX = w2ind["."]
# vocabulary size and embedding dimension, read from the embedding matrix
VOCAB_SIZE, EMBED_DIM = w2v_model.wv.vectors.shape
VOCAB_SIZE, EMBED_DIM

In [None]:

# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Encoder/decoder sizes; these must match the checkpoint loaded into `model`
# (load_state_dict is strict by default).
HIDDEN_SIZE, NUM_LAYERS = 64, 2
enc = EncRnn(hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, embed_size=EMBED_DIM)
dec = DecRnn(hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, embed_size=EMBED_DIM,
             output_size=VOCAB_SIZE)
model = Seq2SeqAttn(enc, dec, TRG_PAD_IDX, VOCAB_SIZE, device)
model = model.to(device)

In [None]:

# Pretrained generator weights (absolute cluster path; adjust for your machine).
CHECKPOINT_PATH = '/scratch/nsk367/deepRL/limitation-learning/src/pretrained_generators/model-epoch10.pt'
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [None]:
# Processed (state, expert action) pairs. Loaded onto the CPU because the main
# loop concatenates them with CPU marker tensors before moving to `device`.
# NOTE: torch.load unpickles the file -- only load trusted data.
d = torch.load('../dat/processed/padded_vectorized_states_v3.pt', map_location='cpu')


In [None]:
def translate_sentence(words, input_state, next_state, model, eos_ind, max_len, device):
    """Greedily decode a reply to ``input_state`` and collect attention maps.

    Args:
        words: index -> token lookup used to turn predicted indices into strings.
        input_state: 1-D tensor of source token indices (a batch dim is added here).
        next_state: sequence whose first element seeds the decoder
            (presumably the ``<sos>`` index -- TODO confirm).
        model: seq2seq model exposing ``eval``, ``encoder``, ``decoder`` and
            ``create_mask``.
        eos_ind: index that stops decoding when it is predicted.
        max_len: maximum number of decoding steps; also passed to the encoder
            as the source length.
        device: device the source and decoder inputs are moved to.

    Returns:
        ``(tokens, attentions)``: the decoded tokens without the seed token,
        and the attention rows for those steps, shaped
        ``(len(tokens), 1, len(input_state))``.
    """
    model.eval()
    source = input_state.unsqueeze(0).to(device)
    source_length = torch.Tensor([int(max_len)])

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(source, source_length)

    mask = model.create_mask(source.transpose(1, 0)).to(device)

    decoded = [next_state[0]]
    attention_log = torch.zeros(max_len, 1, len(input_state))

    step = 0
    while step < max_len:
        decoder_input = torch.LongTensor([decoded[-1]]).to(device)
        with torch.no_grad():
            output, hidden, attention = model.decoder(decoder_input, hidden, encoder_outputs, mask)
        attention_log[step] = attention
        predicted = output.argmax(1).item()
        if predicted == eos_ind:
            break
        decoded.append(predicted)
        step += 1

    # Look up every index (seed included), then drop the seed token.
    all_tokens = [words[int(index)] for index in decoded]
    tokens = all_tokens[1:]
    return tokens, attention_log[:len(tokens)]

In [None]:
def get_action_probs(model, input_state, sos_ind, eos_ind, SEQ_LEN, device):
    """Greedily roll out the decoder and return per-step token distributions.

    Gradients are kept (no ``torch.no_grad``) so the result can feed a
    policy-gradient loss.

    Args:
        model: seq2seq model exposing ``encoder``, ``decoder`` and ``create_mask``.
        input_state: 1-D tensor of source token indices (a batch dim is added here).
        sos_ind: index fed to the decoder at the first step.
        eos_ind: index that ends the rollout; the step that predicts it is
            not included in the result.
        SEQ_LEN: maximum number of decoding steps; also passed to the encoder
            as the source length.
        device: device the inputs are moved to.

    Returns:
        Tensor of shape ``(steps, *decoder_output.shape)`` -- ``(steps, 1, vocab)``
        for a ``(1, vocab)`` decoder output -- holding a softmax over the last
        (vocabulary) axis for each non-``<eos>`` step. When ``<eos>`` is predicted
        at the first step, ``steps`` is 0.
    """
    src_tensor = input_state.unsqueeze(0).to(device)
    src_len = torch.Tensor([int(SEQ_LEN)])
    encoder_outputs, hidden = model.encoder(src_tensor, src_len)
    mask = model.create_mask(src_tensor.transpose(1, 0)).to(device)

    prev_token = sos_ind
    step_logits = []
    output = None
    for _ in range(SEQ_LEN):
        trg_tensor = torch.LongTensor([prev_token]).to(device)
        output, hidden, _attention = model.decoder(trg_tensor, hidden, encoder_outputs, mask)
        prev_token = output.argmax(1).item()
        if prev_token == eos_ind:  # end of sentence
            break
        step_logits.append(output)

    if not step_logits:
        # No non-<eos> step: return an empty batch of distributions instead of
        # letting torch.stack([]) raise.
        if output is None:
            return torch.empty(0, device=device)
        return output.new_zeros((0,) + tuple(output.shape))

    # Normalise over the vocabulary axis explicitly: without `dim`, F.softmax
    # picks dim=0 for 3-D input, which would normalise across time steps.
    return F.softmax(torch.stack(step_logits), dim=-1).to(device)

In [None]:
# Sample a token from each step's distribution and show its log-probability.
# TODO(review): `output_dist` is not assigned anywhere in this notebook, so this
# cell raises NameError on a fresh kernel; it presumably should hold the result
# of get_action_probs(...).
from torch.distributions import Categorical

for toke in range(output_dist.shape[0]):
    m = Categorical(output_dist[toke])
    action = m.sample()
    action_log_prob = m.log_prob(action)  # * reward
    print(action_log_prob)

In [None]:
# Greedy token index per step.
# NOTE(review): `outputs` is only a local variable inside get_action_probs and is
# not assigned at notebook scope, so this raises NameError on a fresh kernel.
outputs.argmax(dim=2)

In [None]:
# Policy-gradient objective: negative log-probability of the sampled action
# scaled by the reward.
# NOTE(review): `irl_reward` is not assigned anywhere in this notebook yet, so
# this line raises NameError on a fresh kernel.
J = -action_log_prob * irl_reward

# irl_reward = what the discriminator says about the (state, action) pair


In [None]:
# Pseudo-code: discriminator(state, action) -> score between 0 and 1
# (the Discriminator class below ends in a sigmoid and takes the state first).



In [None]:
from torch.distributions import Categorical

In [None]:
def get_action(action_probs, return_log_probs=False):
    """Sample one token index per decoding step.

    Args:
        action_probs: iterable of per-step probability tensors, e.g. the output
            of ``get_action_probs`` (one ``(1, vocab)`` distribution per step).
        return_log_probs: when True, also return each sample's log-probability
            (tensors that keep their autograd graph, for a policy-gradient loss).

    Returns:
        A list of sampled token indices, or ``(actions, log_probs)`` when
        ``return_log_probs`` is True.
    """
    actions = []
    log_probs = []
    for step_probs in action_probs:
        dist = Categorical(step_probs)
        sampled = dist.sample()
        actions.append(sampled.item())
        log_probs.append(dist.log_prob(sampled))
    if return_log_probs:
        return actions, log_probs
    return actions

In [None]:
# Sample one token per decoding step from `action_probs` and keep each
# sample's log-probability for the policy-gradient loss computed next.
action = []
action_log_probs = []
for i in action_probs:
    m = Categorical(i)
    action_ = m.sample()
    action_log_prob = m.log_prob(action_)  # TODO: multiply by the reward
    action.append(action_.item())
    action_log_probs.append(action_log_prob)

In [None]:
# Policy-gradient loss to minimise: the negated sum of the sampled actions'
# log-probabilities (the same -log_prob * reward form used elsewhere in this
# notebook, with reward 1 for now). TODO: weight each term by the IRL reward.
policy_loss = -torch.stack(action_log_probs).sum()

In [None]:
# Backpropagate into the model that produced `action_probs` (its rollout was not
# under torch.no_grad). Running this cell twice raises, because the graph is
# freed after the first backward() (retain_graph defaults to False).
policy_loss.backward()

In [None]:
# Display the scalar loss value.
policy_loss

In [None]:
            # Reference per-sample policy-gradient loss (pseudo-code: `log_probs`,
            # `rewards` and `batch_index` are not defined in this notebook):
            # policy_loss = -log_probs[batch_index] * rewards[batch_index]


# GAIL (Generative Adversarial Imitation Learning)

In [None]:
# Main loop over (state, expert action) pairs.
# NOTE: stops after the first pair (debugging). `input_state`, `expert_action`,
# `action_probs` and `action` are read by later cells.


def add_sentence_markers(seq):
    """Return ``seq`` wrapped as ``<sos> seq <eos>`` and moved to ``device``.

    The marker tensors are created on ``seq``'s own device, so ``torch.cat``
    works whether the loaded data lives on the CPU or the GPU.
    """
    sos = torch.tensor([sos_ind], device=seq.device)
    eos = torch.tensor([eos_ind], device=seq.device)
    return torch.cat((sos, seq, eos), dim=0).to(device)


for _, (input_state, expert_action) in d.items():
    input_state = add_sentence_markers(input_state)
    expert_action = add_sentence_markers(expert_action)

    # Roll out the policy, then sample one token per decoding step.
    action_probs = get_action_probs(model, input_state, sos_ind, eos_ind, SEQ_LEN, device)
    action = get_action(action_probs)
    break

In [None]:
# Add a leading batch dimension before scoring with the discriminator.
# NOTE(review): this rebinds `expert_action`, so re-running this cell keeps
# adding dimensions; re-run the main-loop cell first.
state = input_state.unsqueeze(0).to(device)
expert_action = expert_action.unsqueeze(0).to(device)

In [None]:
import torch.nn as nn

In [None]:
class Discriminator(nn.Module):
    """Score a (state, action) token-sequence pair with a value in (0, 1).

    Both sequences are encoded. The encoder outputs are flattened and
    concatenated, then passed through a 3-layer MLP with a sigmoid head.

    Args:
        model: seq2seq model whose ``encoder`` is reused here.
        SEQ_LEN: source length passed to the encoder for both inputs.
        in_features: size of the concatenated flattened encodings. The default
            of 1280 is the original hard-coded value and presumably matches
            SEQ_LEN=5 with the pretrained encoder -- confirm when changing either.
        hidden_dim: width of the two hidden MLP layers (default 512).
        share_encoder: if True (default, the original behaviour), ``model.encoder``
            itself is used for both inputs. Its parameters are then also this
            module's parameters, so an optimizer over ``Discriminator.parameters()``
            updates the policy's encoder too. Pass False to use two independent
            deep copies instead.
    """

    def __init__(self, model, SEQ_LEN, in_features=1280, hidden_dim=512, share_encoder=True):
        super().__init__()
        if share_encoder:
            self.state_encoder = model.encoder
            self.action_encoder = model.encoder
        else:
            import copy
            self.state_encoder = copy.deepcopy(model.encoder)
            self.action_encoder = copy.deepcopy(model.encoder)

        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        # Plain attribute (not a registered buffer), so Module.to() leaves it on the CPU.
        self.src_len = torch.Tensor([int(SEQ_LEN)])

    def forward(self, x1, x2):
        """Return the sigmoid score for state ``x1`` and action ``x2``.

        NOTE(review): ``flatten()`` merges the batch dimension and ``src_len``
        has a single entry, so this appears to support batch size 1 only.
        """
        state_z, _ = self.state_encoder(x1, self.src_len)
        action_z, _ = self.action_encoder(x2, self.src_len)

        state_action = torch.cat(
            [state_z.flatten().unsqueeze(0), action_z.flatten().unsqueeze(0)], dim=1)

        hidden = torch.relu(self.fc1(state_action))
        hidden = torch.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))

In [None]:
# NOTE(review): the encoder length here is 5, not the global SEQ_LEN (7);
# presumably it must stay consistent with the discriminator's fc1 input size.
discrim = Discriminator(model, SEQ_LEN=5)
discrim = discrim.to(device)

In [None]:
# Discriminator score for the expert (state, action) pair.
pred = discrim(state,expert_action)

In [None]:
# Convert the sampled token list to a torch.Tensor (float by default); later
# cells cast it with .long().
# NOTE(review): this rebinds `action` from a list to a tensor, so the cell is
# not idempotent.
action = torch.Tensor(action).to(device)


In [None]:
# Shape of the sampled action with a leading batch dimension, as it is passed to
# the discriminator.
action.unsqueeze(0).shape

In [None]:
# Discriminator score for the policy-sampled action (batch dim added; cast back
# to integer indices, presumably what the encoder expects).
pred = discrim(state,action.unsqueeze(0).long())

In [None]:
# Display the discriminator score.
pred

In [None]:
# Inspect the sampled action tensor.
action

In [None]:
# Inspect the expert action sequence for comparison with `action`.
expert_action