In [None]:
import torch
import sys
sys.path.append('../src')


from models.utils import get_model
from models.config import TOKENS_RAW_CUTOFF
from models.seq2seqattn import init_weights, EncRnn, DecRnn, Seq2SeqAttn

In [None]:
w2v_model = get_model()
# token -> row index in the word2vec embedding matrix
w2ind = {token: token_index for token_index, token in enumerate(w2v_model.wv.index2word)}
assert w2v_model.vocabulary.sorted_vocab, "expected a frequency-sorted word2vec vocabulary"
# (word, corpus count) pairs, most frequent first
word_counts = sorted(
    ((word, vocab_obj.count) for word, vocab_obj in w2v_model.wv.vocab.items()),
    key=lambda item: -item[1],
)
# index -> token. Taken directly from index2word so it is by construction the
# exact inverse of `w2ind`; re-deriving it by sorting counts is not guaranteed
# to order equal-count words the same way as the embedding rows.
words = list(w2v_model.wv.index2word)
# sentence marker token indices
sos_ind = w2ind['<sos>']
eos_ind = w2ind['<eos>']
# adjusted sequence length: 5 raw tokens + <sos> + <eos>
# NOTE(review): the 5 presumably mirrors TOKENS_RAW_CUTOFF (imported above) -- confirm
SEQ_LEN = 5 + 2
# "." doubles as the padding token for now (original note says its index is 0; not checked here)
TRG_PAD_IDX = w2ind["."]
# vocab size and embedding dimension come from the embedding matrix shape
VOCAB_SIZE, EMBED_DIM = w2v_model.wv.vectors.shape
VOCAB_SIZE, EMBED_DIM

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shared by encoder and decoder so the two cannot drift apart; they must also
# match the architecture of the checkpoint loaded in the next cell.
HIDDEN_SIZE = 64
NUM_LAYERS = 2
enc = EncRnn(hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, embed_size=EMBED_DIM)
dec = DecRnn(hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, embed_size=EMBED_DIM,
             output_size=VOCAB_SIZE)
model = Seq2SeqAttn(enc, dec, TRG_PAD_IDX, VOCAB_SIZE, device).to(device)

In [None]:
# Hard-coded absolute path from the original author's environment; adjust per machine.
CHECKPOINT_PATH = ('/scratch/rz1567/deep_rl/final_project_new/limitation-learning/'
                   'src/pretrained_generators/model-epoch10.pt')
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads when `device` fell back to CPU. NOTE: torch.load unpickles -- trusted files only.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

In [None]:
def translate_sentence(words, input_state, next_state, model, eos_ind, max_len, device):
    """Greedily decode a response to ``input_state`` with the seq2seq model.

    Args:
        words: index -> token lookup used to turn predicted indices into tokens.
        input_state: source token indices; presumably a 1-D LongTensor already
            wrapped with <sos>/<eos> (as the caller in this notebook does).
        next_state: target sequence; only its first element (the <sos> index)
            is used, as the first decoder input.
        model: object exposing ``encoder``, ``decoder``, ``create_mask`` and ``eval``.
        eos_ind: index whose prediction stops decoding (it is not emitted).
        max_len: maximum number of decoding steps; also given to the encoder as
            the source length.
        device: device the model runs on.

    Returns:
        (tokens, attentions): decoded tokens without the leading <sos> and
        without <eos>, and the attention weights of the steps that produced them.
    """
    model.eval()
    src_tensor = input_state.unsqueeze(0).to(device)
    # NOTE(review): the encoder is told the source has `max_len` tokens; this
    # assumes len(input_state) == max_len, which holds for the caller below.
    src_len = torch.Tensor([int(max_len)])

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor, src_len)
        mask = model.create_mask(src_tensor.transpose(1, 0)).to(device)

    # Seed decoding with the target's first token (<sos>). Stored as a plain int
    # so every step builds a fresh tensor directly on `device`.
    trg_indexes = [int(next_state[0])]
    # One slot of attention weights per decoding step.
    attentions = torch.zeros(max_len, 1, len(input_state))

    for i in range(max_len):
        trg_tensor = torch.tensor([trg_indexes[-1]], dtype=torch.long, device=device)
        with torch.no_grad():
            output, hidden, attention = model.decoder(trg_tensor, hidden, encoder_outputs, mask)
        attentions[i] = attention
        pred_token = output.argmax(1).item()
        if pred_token == eos_ind:  # end of sentence; <eos> itself is not kept
            break
        trg_indexes.append(pred_token)

    trg_tokens = [words[int(ind)] for ind in trg_indexes]
    # Drop the leading <sos>; attentions[k] belongs to the step that produced trg_tokens[k + 1].
    return trg_tokens[1:], attentions[:len(trg_tokens) - 1]

In [None]:
# Mapping whose values hold (input_state, next_state) index tensors. Loaded onto the
# CPU because the evaluation loop concatenates them with CPU marker tensors before
# moving the result to `device`. NOTE: torch.load unpickles -- only load trusted files.
d = torch.load('../dat/processed/padded_vectorized_states_v3.pt', map_location='cpu')

In [None]:
from GAIL import get_cosine_sim

In [None]:
# Number of dataset examples to decode and print.
N_EXAMPLES = 12


def _with_markers(token_ids):
    """Return ``[<sos>, *token_ids, <eos>]`` as a LongTensor on ``device``."""
    sos = torch.tensor([sos_ind], dtype=torch.long)
    eos = torch.tensor([eos_ind], dtype=torch.long)
    # .cpu() so the concat works whatever device the dataset tensors live on.
    return torch.cat((sos, token_ids.cpu(), eos), dim=0).to(device)


def _to_tokens(token_ids):
    """Map a 1-D index tensor back to its word strings."""
    return [words[int(ind)] for ind in token_ids.tolist()]


with torch.no_grad():
    for _, vects in zip(range(N_EXAMPLES), d.values()):
        input_state = _with_markers(vects[0])
        next_state = _with_markers(vects[1])

        translation, attention = translate_sentence(
            words, input_state, next_state, model, eos_ind, SEQ_LEN, device)

        # Drop the <sos>/<eos> markers added above. Padding tokens are kept on
        # purpose (trimming them was disabled in the original).
        init_act = _to_tokens(input_state[1:-1])
        expert_act = _to_tokens(next_state[1:-1])

        vectorized_expert_act = [w2v_model.wv[tok] for tok in expert_act]
        vectorized_pred_act = [w2v_model.wv[tok] for tok in translation]
        cos_sim = get_cosine_sim(vectorized_expert_act, vectorized_pred_act,
                                 type=None,
                                 seq_len=SEQ_LEN - 2,  # length without <sos>/<eos>
                                 dim=EMBED_DIM)

        print(f'state = {" ".join(init_act)}')
        print(f'expert = {" ".join(expert_act)}')
        print(f'model = {" ".join(translation)}')
        print(cos_sim)
        print("\n")
