In [2]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pandas as pd
import sys
import numpy as np
import glob
import torch

sys.path.insert(1,"/home/showalte/research/prob_seq_queries/")
from seq_queries.utils import read_pkl, write_pkl

In [4]:
def tuple_of_tuples_match(tup1, tup2):
    """Assert that two nested tuples of tensor pairs hold identical tensors.

    Both arguments are sequences with one entry per layer; every entry is a
    pair of tensors. Corresponding tensors are compared with ``torch.equal``
    (same shape and same values). Prints "All good!" when every check passes.

    Args:
        tup1: Sequence of ``(tensor, tensor)`` pairs.
        tup2: Sequence of ``(tensor, tensor)`` pairs compared against ``tup1``.

    Raises:
        AssertionError: If the number of layers, the length of a layer entry,
            or any pair of corresponding tensors differs.
        ValueError: If a layer entry does not unpack into exactly two items.
    """
    assert len(tup1) == len(tup2),"Layer lengths do not match"
    for t1,t2 in zip(tup1, tup2):
        # Report lengths here: t1/t2 are tuples and have no ``.shape``, so
        # formatting their shape would raise AttributeError instead of the
        # intended AssertionError.
        assert len(t1) == len(t2),f"Tuples of 2 do not match, t1 {len(t1)} | t2 {len(t2)}"
        t11,t12 = t1
        t21,t22 = t2
        # torch.equal checks values as well as shapes; shapes are printed only
        # as a debugging aid.
        assert torch.equal(t11,t21),\
            f"First tuple position tensors do not match: t11 {t11.shape} | t21 {t21.shape}"
        assert torch.equal(t12,t22),\
            f"Second tuple position tensors do not match: t12 {t12.shape} | t22 {t22.shape}"
    print("All good!")
        

In [5]:
def split_2xtups(hidden_states,max_batch_size = 16):
    """Split per-layer tensor pairs into batches along dim 0.

    Args:
        hidden_states: Iterable with one entry per layer; items 0 and 1 of
            each entry are tensors that are split with ``torch.split``.
        max_batch_size: Chunk size handed to ``torch.split`` (the final chunk
            may be smaller).

    Returns:
        A list with one element per batch. Each element is a tuple with one
        ``(chunk_of_item_0, chunk_of_item_1)`` pair per layer.

    Also prints the number of batches, the number of layers, the pair length
    and the shape of the first chunk.
    """
    per_layer_chunks = []
    for layer in hidden_states:
        first_parts = torch.split(layer[0], max_batch_size)
        second_parts = torch.split(layer[1], max_batch_size)
        # Pair the i-th chunk of both tensors of this layer.
        per_layer_chunks.append(zip(first_parts, second_parts))

    # Transpose from layer-major to batch-major.
    batches = list(zip(*per_layer_chunks))

    print(len(batches))
    print(len(batches[0]))
    print(len(batches[0][0]))
    print(batches[0][0][0].shape)
    return batches

def unsplit_2xtups(step_outputs):
    """Concatenate batched per-layer tensor pairs back along dim 0.

    Args:
        step_outputs: Iterable with one element per batch; each element holds
            one pair of tensors per layer.

    Returns:
        A tuple with one ``(tensor, tensor)`` pair per layer, where each tensor
        is the dim-0 concatenation of that layer's chunks, moved to the CPU.
    """
    # zip(*step_outputs) regroups by layer; zip(*per_layer) then separates the
    # chunks of pair position 0 from those of pair position 1.
    columns_by_layer = (tuple(zip(*per_layer)) for per_layer in zip(*step_outputs))
    return tuple(
        (torch.cat(cols[0], dim=0).cpu(), torch.cat(cols[1], dim=0).cpu())
        for cols in columns_by_layer
    )

In [6]:

def make_hidden_state(num_layers, tup_size, tensor_dim):
    """Build a random nested tuple of tensors.

    Args:
        num_layers: Number of outer entries.
        tup_size: Number of tensors inside each outer entry.
        tensor_dim: Iterable of ints, unpacked as the size of each tensor.

    Returns:
        A tuple of ``num_layers`` tuples, each holding ``tup_size`` tensors
        drawn with ``torch.randn(*tensor_dim)``.
    """
    return tuple(
        tuple(torch.randn(*tensor_dim) for _ in range(tup_size))
        for _ in range(num_layers)
    )

In [7]:
# Random fixture used by the split/unsplit round-trip check in the next cell.
num_layers = 8
# Tensor shape: (samples, heads, seq_len, head_dim)
tensor_dim = (400, 12, 10, 68)
# Two tensors per layer, matching the pairs split_2xtups / unsplit_2xtups index.
tup_size = 2
h = make_hidden_state(num_layers, tup_size, tensor_dim)

In [8]:
# Round-trip check: splitting `h` into batches and concatenating them back
# must reproduce `h` exactly. The sizes include values that divide the
# 400-sample dimension evenly (1, 8, 16) and values that leave a smaller
# final batch (9, 31, 128, 32).
max_batch_sizes = [1,8,9,31,16,128,32]
for max_batch_size in max_batch_sizes:
    tuple_of_tuples_match(
        h,
        unsplit_2xtups(
            split_2xtups(h,
                         max_batch_size), 
        )

    )

400
8
2
torch.Size([1, 12, 10, 68])
All good!
50
8
2
torch.Size([8, 12, 10, 68])
All good!
45
8
2
torch.Size([9, 12, 10, 68])
All good!
13
8
2
torch.Size([31, 12, 10, 68])
All good!
25
8
2
torch.Size([16, 12, 10, 68])
All good!
4
8
2
torch.Size([128, 12, 10, 68])
All good!
13
8
2
torch.Size([32, 12, 10, 68])
All good!


## Check sequences for ground truth vs. hybrid

In [9]:
# Results of the hybrid beam-search run (file name indicates 10 queries, `10q`).
hybrid = read_pkl("../data/beam_search_is_hybrid/wikitext/val_dl/val-dl_wikitext_beam-search-is-hybrid_13h_15s_10000mc_10q.pkl")

In [10]:
# Pseudo-ground-truth results (file name indicates 500 queries, `500q`).
pgt = read_pkl("../data/pseudo_gt/wikitext/val_dl/val-dl_wikitext_pseudo-gt_13h_15s_1000mc_pgt_500q.pkl")

In [21]:
# Second pseudo-ground-truth results file (file name indicates 100 queries, `100q`).
pgt2 = read_pkl("../data/pseudo_gt/wikitext/val_dl/val-dl_wikitext_pseudo-gt_13h_15s_1000mc_pgt_100q.pkl")

In [22]:
# `num_samples` is otherwise only assigned in a later cell (also to 10), which
# makes this cell raise NameError on Restart & Run All; define it here so the
# notebook runs top-to-bottom.
num_samples = 10
# For every row, pick the estimate at that row's 'excluded_terms' index.
pgt_est2 = torch.gather(pgt2['sample_estimates'],1,pgt2['excluded_terms'].unsqueeze(-1)).squeeze()
# Keep the first `num_samples` entries for the comparison cells below.
pgt_est_test2 = pgt_est2[:num_samples]

In [13]:
# Number of leading entries to compare; the hybrid tensor shown below has 10 entries.
num_samples =10
# For every row, pick the estimate at that row's 'excluded_terms' index.
pgt_est = torch.gather(pgt['sample_estimates'],1,pgt['excluded_terms'].unsqueeze(-1)).squeeze()
pgt_est_test = pgt_est[:num_samples]

In [16]:
# Take the last entry along dim 1 of 'sample_estimates', then pick, for every
# row, the estimate at that row's 'excluded_terms' index.
# NOTE(review): this rebinds `hybrid` from the loaded results to a tensor, so
# re-running the cell without reloading the pickle fails; consider binding the
# tensor to a separate name (later cells read `hybrid` as the tensor).
hybrid = torch.gather(hybrid['sample_estimates'][:,-1,:],1,hybrid['excluded_terms'].unsqueeze(-1)).squeeze()
hybrid.shape

torch.Size([10])

In [23]:
# Estimates kept from the `pgt2` run (first `num_samples` entries of `pgt_est2`).
pgt_est_test2

tensor([3.9509e-03, 6.3127e-07, 3.0005e-04, 3.1273e-03, 7.6277e-02, 1.9384e-03,
        1.8733e-04, 2.4132e-06, 1.1433e-05, 1.0377e-01])

In [18]:
# Estimates kept from the `pgt` run (first `num_samples` entries of `pgt_est`).
pgt_est_test

tensor([8.1826e-04, 5.2079e-07, 3.1304e-04, 4.3466e-02, 7.4360e-02, 5.3057e-03,
        4.8004e-04, 1.6603e-06, 6.2787e-06, 1.1255e-01])

In [19]:
# Hybrid estimates gathered at each row's 'excluded_terms' index.
hybrid

tensor([1.6305e-03, 5.5858e-07, 5.1883e-04, 2.4782e-02, 1.4181e-01, 4.8309e-03,
        7.3468e-04, 2.2985e-06, 9.0691e-06, 3.2000e-02])

In [26]:
# Absolute difference between the hybrid estimates and `pgt_est_test`.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
torch.abs(hybrid - pgt_est_test)

tensor([8.1227e-04, 3.7790e-08, 2.0579e-04, 1.8684e-02, 6.7446e-02, 4.7484e-04,
        2.5464e-04, 6.3820e-07, 2.7905e-06, 8.0546e-02])

In [27]:
# Absolute difference between the hybrid estimates and `pgt_est_test2`.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
torch.abs(hybrid - pgt_est_test2)

tensor([2.3204e-03, 7.2685e-08, 2.1878e-04, 2.1654e-02, 6.5530e-02, 2.8925e-03,
        5.4735e-04, 1.1462e-07, 2.3640e-06, 7.1769e-02])

In [28]:
# Absolute difference between the two pseudo-ground-truth estimates.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
torch.abs(pgt_est_test - pgt_est_test2)

tensor([3.1326e-03, 1.1047e-07, 1.2991e-05, 4.0338e-02, 1.9164e-03, 3.3674e-03,
        2.9271e-04, 7.5282e-07, 5.1544e-06, 8.7766e-03])

In [20]:
# Error of the hybrid estimates relative to `pgt_est_test`, in percent.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
(torch.abs(hybrid - pgt_est_test)/pgt_est_test)*100

tensor([99.2686,  7.2563, 65.7401, 42.9855, 90.7018,  8.9496, 53.0460, 38.4377,
        44.4438, 71.5671])

In [24]:
# Error of the hybrid estimates relative to `pgt_est_test2`, in percent.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
(torch.abs(hybrid - pgt_est_test2)/pgt_est_test2)*100

tensor([ 58.7302,  11.5141,  72.9161, 692.4354,  85.9104, 149.2244, 292.1930,
          4.7498,  20.6765,  69.1623])

In [25]:
# Difference between the two pseudo-ground-truth estimates relative to
# `pgt_est_test`, in percent.
# torch.abs keeps the computation in torch; np.abs on a Tensor round-trips
# through NumPy's array protocol and fails for CUDA or grad-tracking tensors.
(torch.abs(pgt_est_test - pgt_est_test2)/pgt_est_test)*100

tensor([382.8433,  21.2128,   4.1500,  92.8052,   2.5772,  63.4665,  60.9769,
         45.3410,  82.0945,   7.7982])