In [None]:
import os
import tensorflow as tf
from util import constants
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.trainer import Trainer
from absl import app
from absl import flags
import numpy as np
from util.models import MODELS
from util.tasks import TASKS

%matplotlib inline
import pandas as pd
import seaborn as sns; sns.set()

from tqdm import tqdm

In [2]:
# --- Experiment configuration ------------------------------------------------
# Name of the task; used as the key into TASKS.
task_name = 'word_sv_agreement_vp'

# Teacher: registry key into MODELS, model config name, experiment tag.
teacher_model = 'cl_lstm'
teacher_config = 'small_lstm_v4'
teacher_exp_name = '0.0001_offlineteacher_v3'

# Student: registry key into MODELS, model config name, experiment tag.
student_model = 'cl_gpt2'
student_config = 'small_gpt_v9'
student_exp_name = 'samira_fd1'

# Distillation settings; they become part of the student checkpoint dir name.
distill_config = 'pure_distill_2'
distill_mode = 'offline'

# Root directory under which the checkpoints are looked up.
chkpt_dir = '../tf_ckpts'

In [3]:
# Look up the constructor registered under `task_name` and build the task
# from the params returned by get_task_params(), pointing it at ../data.
task_factory = TASKS[task_name]
task = task_factory(get_task_params(), data_dir='../data')

Vocab len:  10034


In [4]:
# Encode the BOS marker once; both models receive it as their `cl_token`.
cl_token = task.databuilder.sentence_encoder().encode(constants.bos)

# NOTE: on entry `teacher_model` / `student_model` hold registry keys
# (strings); this cell rebinds both names to the constructed model objects.
# Re-running this cell on its own would therefore pass model objects where
# keys are expected -- re-run the configuration cell first.
teacher_cls = MODELS[teacher_model]
teacher_model = teacher_cls(
    hparams=get_model_params(task, teacher_model, teacher_config),
    cl_token=cl_token,
)

# Student hyper-parameters, with the attention/embedding output flags enabled.
std_hparams = get_model_params(task, student_model, student_config)
std_hparams.output_attentions = True
std_hparams.output_embeddings = True

student_cls = MODELS[student_model]
student_model = student_cls(std_hparams, cl_token=cl_token)

model config: small_lstm_v4
{'hidden_dim': 256, 'embedding_dim': 256, 'depth': 2, 'hidden_dropout_rate': 0.8, 'input_dropout_rate': 0.2, 'initializer_range': 0.1}
model config: small_gpt_v9
{'embedding_dim': 128, 'resid_pdrop': 0.4, 'embd_pdrop': 0.2, 'attn_pdrop': 0.6, 'initializer_range': 0.05}


In [5]:
# Directory name of the distilled student run, joined with "_":
#   <mode>, <distill config>, "teacher", <teacher model name>, <teacher exp>,
#   "student", <student model name>, <student config>, <student exp>
# (teacher_config is not included in the name.)
student_run_parts = [
    distill_mode,
    distill_config,
    "teacher",
    teacher_model.model_name,
    teacher_exp_name,
    "student",
    student_model.model_name,
    str(student_config),
    student_exp_name,
]
student_ckpt_dir = os.path.join(chkpt_dir, task.name, '_'.join(student_run_parts))
print("student_checkpoint:", student_ckpt_dir)

student_checkpoint: ../tf_ckpts/word_sv_agreement_vp/offline_pure_distill_2_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_0.0001_offlineteacher_v3_student_cl_gpt2_h-128_d-6_rdrop-0.4_adrop-0.6_indrop-0.2_small_gpt_v9_samira_fd1


In [6]:
# Restore the student from the newest checkpoint in `student_ckpt_dir`,
# then compile it with the task's loss/metrics and evaluate on 100 test batches.
student_ckpt = tf.train.Checkpoint(net=student_model)
student_manager = tf.train.CheckpointManager(student_ckpt, student_ckpt_dir, max_to_keep=None)

student_ckpt.restore(student_manager.latest_checkpoint)
if student_manager.latest_checkpoint:
  print("Restored student from {}".format(student_manager.latest_checkpoint))
else:
  # `latest_checkpoint` is None when the directory holds no checkpoint. Say so
  # explicitly; otherwise the metrics below would silently come from a model
  # whose weights were never restored.
  print("WARNING: no student checkpoint found in {}; evaluating un-restored weights".format(student_ckpt_dir))

student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())
student_model.evaluate(task.test_dataset, steps=100)

Restored student from ../tf_ckpts/word_sv_agreement_vp/offline_pure_distill_2_teacher_cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_0.0001_offlineteacher_v3_student_cl_gpt2_h-128_d-6_rdrop-0.4_adrop-0.6_indrop-0.2_small_gpt_v9_samira_fd1/ckpt-60






















[0.12402263533324004, 0.07817629, 0.97125]

In [7]:
# Teacher checkpoints live under
#   <chkpt_dir>/<task name>/<teacher model name>_<teacher config>_<teacher exp>
teacher_ckpt_dir = os.path.join(chkpt_dir, task.name,
                                  '_'.join([teacher_model.model_name, teacher_config,teacher_exp_name]))

teacher_ckpt = tf.train.Checkpoint(net=teacher_model)
teacher_manager = tf.train.CheckpointManager(teacher_ckpt, teacher_ckpt_dir, max_to_keep=None)

# Restore the newest teacher checkpoint and report which one was loaded.
teacher_ckpt.restore(teacher_manager.latest_checkpoint)
if teacher_manager.latest_checkpoint:
  # This is the teacher, not the student (the message used to say "student").
  print("Restored teacher from {}".format(teacher_manager.latest_checkpoint))
else:
  # `latest_checkpoint` is None when the directory holds no checkpoint; make
  # that visible instead of continuing silently with un-restored weights.
  print("WARNING: no teacher checkpoint found in {}; teacher weights were not restored".format(teacher_ckpt_dir))

Restored student from ../tf_ckpts/word_sv_agreement_vp/cl_lstm_em-256_h-256_d-2_hdrop-0.8_indrop-0.2_small_lstm_v4_0.0001_offlineteacher_v3/ckpt-60


In [8]:
# Compile the teacher with the task's loss and metrics, then evaluate it on
# 100 batches of the test dataset.
teacher_loss_fn = task.get_loss_fn()
teacher_metrics = task.metrics()
teacher_model.compile(loss=teacher_loss_fn, metrics=teacher_metrics)
teacher_model.evaluate(task.test_dataset, steps=100)























[0.10886244153603912, 0.08943609, 0.971875]

In [11]:
# A separately trained model (bound to `model`), restored from its own
# checkpoint directory and evaluated on 100 test batches.
model_name='cl_gpt2'
model_config='small_gpt_v9'
learning_rate=0.0001
exp_name='offlineteacher_v1'

cl_token = task.databuilder.sentence_encoder().encode(constants.bos)
hparams=get_model_params(task, model_name, model_config)
# Enable the attention/embedding output flags, as done for the student.
hparams.output_attentions = True
hparams.output_embeddings = True

model = MODELS[model_name](hparams=hparams, cl_token=cl_token)


# Directory name: <model name>_<model config>_<learning rate>_<exp name>
ckpt_dir = os.path.join(chkpt_dir,task.name,
                        model.model_name+"_"+str(model_config)+"_"+str(learning_rate)+"_"+exp_name)

ckpt = tf.train.Checkpoint(net=model)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=None)

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  # This restores `model`, not the distilled student (the message used to
  # say "student").
  print("Restored model from {}".format(manager.latest_checkpoint))
else:
  # `latest_checkpoint` is None when the directory holds no checkpoint. Say so
  # explicitly; otherwise the metrics below would silently come from a model
  # whose weights were never restored.
  print("WARNING: no checkpoint found in {}; evaluating un-restored weights".format(ckpt_dir))

model.compile(loss=task.get_loss_fn(), metrics=task.metrics())
model.evaluate(task.test_dataset, steps=100)

model config: small_gpt_v9
{'embedding_dim': 128, 'resid_pdrop': 0.4, 'embd_pdrop': 0.2, 'attn_pdrop': 0.6, 'initializer_range': 0.05}
Restored student from ../tf_ckpts/word_sv_agreement_vp/cl_gpt2_h-128_d-6_rdrop-0.4_adrop-0.6_indrop-0.2_small_gpt_v9_0.0001_offlineteacher_v1/ckpt-60


[0.2104924625903368, 0.19008262, 0.951875]

In [12]:
# Run the student, the teacher and the standalone `model` over the validation
# batches and collect, per example, each model's predicted class, the true
# label and the input.
student_preds = []
y_trues = []
teacher_preds = []
independent_preds = []
inputs = []
batch_count = task.n_valid_batches
for x, y in tqdm(task.valid_dataset, total=batch_count):
    # Predicted class = index of the largest value along the last axis.
    std_pred = tf.argmax(student_model(x), axis=-1)
    teach_pred = tf.argmax(teacher_model(x), axis=-1)
    indep_pred = tf.argmax(model(x), axis=-1)
    student_preds.extend(std_pred.numpy())
    teacher_preds.extend(teach_pred.numpy())
    y_trues.extend(y.numpy())
    # Convert to NumPy like the student/teacher predictions; extending with
    # the tensor itself would store one tf tensor per example.
    independent_preds.extend(indep_pred.numpy())
    inputs.extend(x)
    # Stop after `n_valid_batches` batches: `batch_count` counts down to zero.
    batch_count -= 1
    if batch_count == 0:
        break


  0%|          | 0/246 [00:00<?, ?it/s][A
  0%|          | 1/246 [00:03<14:10,  3.47s/it][A
  1%|          | 2/246 [00:07<14:22,  3.53s/it][A
  1%|          | 3/246 [00:10<14:00,  3.46s/it][A
  2%|▏         | 4/246 [00:14<14:10,  3.51s/it][A
  2%|▏         | 5/246 [00:17<13:48,  3.44s/it][A
  2%|▏         | 6/246 [00:21<14:02,  3.51s/it][A
  3%|▎         | 7/246 [00:24<13:40,  3.43s/it][A
  3%|▎         | 8/246 [00:28<14:04,  3.55s/it][A
  4%|▎         | 9/246 [00:31<13:40,  3.46s/it][A
  4%|▍         | 10/246 [00:35<13:56,  3.54s/it][A
  4%|▍         | 11/246 [00:38<13:34,  3.47s/it][A
  5%|▍         | 12/246 [00:42<13:43,  3.52s/it][A
  5%|▌         | 13/246 [00:45<13:22,  3.44s/it][A
  6%|▌         | 14/246 [00:49<13:42,  3.55s/it][A
  6%|▌         | 15/246 [00:52<13:28,  3.50s/it][A
  7%|▋         | 16/246 [00:56<13:34,  3.54s/it][A
  7%|▋         | 17/246 [00:59<13:23,  3.51s/it][A
  7%|▋         | 18/246 [01:03<13:32,  3.56s/it][A
  8%|▊         | 19/246 [01:0

In [13]:
# NOTE: despite the "*_mistakes" names, each mask is True where the prediction
# EQUALS the label, i.e. where that model was right. The comparisons in the
# following cells rely on this polarity.
y_true_arr = np.asarray(y_trues)
student_mistakes = np.asarray(student_preds) == y_true_arr
teacher_mistakes = np.asarray(teacher_preds) == y_true_arr
model_mistakes = np.asarray(independent_preds) == y_true_arr

In [None]:
# `student_mistakes` / `teacher_mistakes` are boolean masks that are True
# where the respective prediction matched the label.
# Indices where exactly one of student and teacher was right.
nonoverlapping_mistakes = np.nonzero(np.logical_xor(student_mistakes, teacher_mistakes))[0]
# Indices where the student was right and the teacher was not.
teach_wrong_student_right = np.nonzero(student_mistakes & ~teacher_mistakes)[0]
# Indices where the teacher was right and the student was not.
teach_right_student_wrong = np.nonzero(teacher_mistakes & ~student_mistakes)[0]

In [19]:
# The "*_mistakes" masks are truthy where the prediction matched the label, so
# "a < b" selects the examples that b got right and a got wrong.
# Standalone model vs. student.
model_right_student_wrong = np.nonzero(model_mistakes > student_mistakes)[0]
model_wrong_student_right = np.nonzero(model_mistakes < student_mistakes)[0]
# Standalone model vs. teacher.
model_right_teacher_wrong = np.nonzero(model_mistakes > teacher_mistakes)[0]
model_wrong_teacher_right = np.nonzero(model_mistakes < teacher_mistakes)[0]

In [20]:
# Number of examples in `model_wrong_student_right` (a 1-D index array).
model_wrong_student_right.shape[0]

402