In [1]:
import os
import tensorflow as tf
from util import constants
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.trainer import Trainer
from absl import app
from absl import flags
import numpy as np
from util.models import MODELS
from util.tasks import TASKS
from notebook_utils import *

%matplotlib inline
import pandas as pd
import seaborn as sns; sns.set()

from tqdm import tqdm

In [14]:
def get_reps(inputs, model, index=1, layer=None):
    """
    If Model is LSTM:
        1: final_rnn_outputs, 
        2: hidden_activation (for all layers, including input embeddings)
    """
    outputs = model.detailed_call(inputs)
    reps = outputs[index]
    
    if layer is not None:
        reps = reps[layer]
    
    return reps

def normalized_pairwisedot_product_sim(reps1, reps2):    
    reps1 = reps1 / tf.norm(reps1, axis=-1)[...,None]
    reps2 = reps2 / tf.norm(reps2, axis=-1)[...,None]
    
    pw_dot_product = tf.cast(tf.matmul(reps1, reps2, transpose_b=True), dtype=tf.float32) 

    p_max = tf.reduce_max(pw_dot_product, axis=-1)
    p_min =  tf.reduce_min(pw_dot_product, axis=-1)
    
    
    #pw_dot_product = (pw_dot_product  - p_max) / (p_max - p_min)
    return pw_dot_product


def normalized_dot_product_sim(reps1, reps2):
    #normalize reps:
    reps1 = reps1 / tf.norm(reps1, axis=-1)[...,None]
    reps2 = reps2 / tf.norm(reps2, axis=-1)[...,None]
    
    norm1 = tf.norm(reps1, axis=-1)
    norm2 = tf.norm(reps2, axis=-1)

    # Elementwise multiplication
    dot_product = tf.multiply(reps1, reps2)
    
    # Sum over last axis to get the dot product similarity between corresponding pairs
    dot_product = tf.reduce_sum(dot_product, axis=-1) * padding_mask
    
    return dot_product
    
    
def second_order_rep_sim(reps1, reps2, padding_mask):
    
    sims1 = normalized_pairwisedot_product_sim(reps1, reps1)
    sims2 = normalized_pairwisedot_product_sim(reps2, reps2)
    
    padding_mask = tf.ones((reps1.shape))
    so_sims = normalized_dot_product_sim(sims1, sims2, padding_mask) * padding_mask
    
    mean_sim = so_sims / np.reduce_sum(padding_mask)
    
    return mean_sim, so_sims

def compare_models(inputs, model1, model2, index1=1, index2=1,layer1=None, layer2=None, padding_symbol=None):
    reps1 = get_reps(inputs, model1)
    reps2 = get_reps(inputs, model2)
    if padding_symbol is not None:
        padding_mask = tf.cast(1.0 - (inputs == padding_symbol), dtype=tf.float32)
    else:
        padding_mask = tf.ones
    
    reps1 = tf.reshape(reps1, (-1, tf.shape(reps1)[-1]))
    reps2 = tf.reshape(reps2, (-1, tf.shape(reps2)[-1]))
    
    similarity_measures = second_order_rep_sim(reps1, reps2, padding_mask=padding_mask)
    
    return similarity_measures

def compare_reps(reps1, reps2):
    reps1 = tf.reshape(reps1, (-1, tf.shape(reps1)[-1]))
    reps2 = tf.reshape(reps2, (-1, tf.shape(reps2)[-1]))
    
    similarity_measures = second_order_rep_sim(reps1, reps2)
    
    return similarity_measures

In [15]:
task_name = 'word_sv_agreement_vp'
chkpt_dir='../tf_ckpts'
task = TASKS[task_name](get_task_params(), data_dir='../data')
cl_token = task.databuilder.sentence_encoder().encode(constants.bos)

Vocab len:  10032


In [4]:
config={'student_exp_name':'lisa_fd131',
    'teacher_exp_name':'0.001_samira_offlineteacher_v11',
    'teacher_config':'small_lstm_v4',
    'task_name':'word_sv_agreement_vp',
    'student_model':'cl_lstm',
    'teacher_model':'cl_lstm',
    'student_config':'small_lstm_v4',
    'distill_config':'pure_dstl_4_crs_slw',
    'distill_mode':'offline',
    'chkpt_dir':'../tf_ckpts',
       }

std_hparams=get_model_params(task, config['student_model'], config['student_config'])
std_hparams.output_attentions = True
std_hparams.output_embeddings = True
std_hparams.output_hidden_states = True

model1, _ = get_student_model(config, task, std_hparams, cl_token)

tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])
tchr_hparams.output_attentions = True
tchr_hparams.output_embeddings = True
tchr_hparams.output_hidden_states = True

model2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)


model config: small_lstm_v4
{'hidden_dim': 256, 'embedding_dim': 256, 'depth': 2, 'hidden_dropout_rate': 0.8, 'input_dropout_rate': 0.2, 'initializer_range': 0.1}
model config: small_lstm_v4
{'hidden_dim': 256, 'embedding_dim': 256, 'depth': 2, 'hidden_dropout_rate': 0.8, 'input_dropout_rate': 0.2, 'initializer_range': 0.1}
student_checkpoint: ../tf_ckpts/word_sv_agreement_vp/offline_pure_dstl_4_crs_slw_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_0.001_samira_offlineteacher_v11_student_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_lisa_fd131
Restored student from ../tf_ckpts/word_sv_agreement_vp/offline_pure_dstl_4_crs_slw_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_0.001_samira_offlineteacher_v11_student_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_lisa_fd131/ckpt-60
model config: small_lstm_v4
{'hidden_dim': 256, 'embedding_dim': 256, 'depth': 2, 'hidden_dropout_rate': 0.8, 'input_dropout_rate': 0.2, 'initia

In [18]:
for inputs, labels in task.valid_dataset:
    reps1 = get_reps(inputs, model1, index=1, layer=None)
    reps2 = get_reps(inputs, model2, index=1, layer=None)
    print(compare_reps(reps1, reps2))
    print(compare_reps(reps1, reps1))


(0.9217837, <tf.Tensor: id=12324, shape=(64,), dtype=float32, numpy=
array([ 0.9896728 ,  0.9732193 ,  0.9824896 ,  0.97884035,  0.98024136,
        0.9796636 ,  0.981256  ,  0.9709201 ,  0.9880232 ,  0.9770104 ,
        0.972235  ,  0.9807387 ,  0.9770463 ,  0.98081416,  0.9872096 ,
        0.9752636 ,  0.9747678 ,  0.9770424 ,  0.9810164 ,  0.98590946,
        0.97396654,  0.9790244 ,  0.97351813,  0.93374074,  0.9788846 ,
        0.9746642 ,  0.98091495,  0.97526336,  0.98450315,  0.98809594,
        0.96822846,  0.97443044,  0.9891899 ,  0.9739234 ,  0.979493  ,
        0.964818  ,  0.981316  ,  0.981212  ,  0.9638127 ,  0.98080385,
        0.50190663,  0.9731533 ,  0.9844091 ,  0.9459042 ,  0.9784997 ,
        0.9817607 ,  0.673989  ,  0.97494805,  0.98368967,  0.9638833 ,
        0.98150396, -0.29589486,  0.9777565 ,  0.97949785,  0.9800193 ,
        0.9809008 ,  0.9623152 ,  0.9848262 ,  0.9807231 , -0.5021987 ,
        0.9708454 ,  0.9760765 ,  0.9876416 ,  0.97481644], dtype=f

In [None]:
outputs = model2.detailed_call(inputs)

In [None]:
outputs[1]

In [None]:
len(outputs)