In [2]:
from util.config_util import get_task_params, get_model_params
import os
from tasks.sv_agreement import SvAgreementLM, WordSvAgreementLM
from tf2_models.lm_transformer import LmGPT2
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.lm_lstm import LmLSTM
from absl import flags
import sys
import tensorflow as tf
from tqdm import tqdm 


FLAGS = flags.FLAGS
flags.DEFINE_string('task', 'word_sv_agreement_lm', 'sv_agreement_lm | word_sv_agreement_lm')


flags.DEFINE_string('teacher_exp_name', 'trial4', 'experiment directory')
flags.DEFINE_string('teacher_model', 'lm_lstm', 'lm_lstm | lm_gpt2')

flags.DEFINE_string('student_exp_name', 'trial1', 'experiment directory')
flags.DEFINE_string('student_model', 'lm_lstm', 'lm_lstm | lm_gpt2')

flags.DEFINE_string('f', None,'kernel')

FLAGS(sys.argv)

hparams = flags.FLAGS


MODELS = {"lm_lstm": LmLSTM,
          "lm_gpt2": LmGPT2}

TASKS = {
  'sv_agreement_lm': SvAgreementLM,
  'word_sv_agreement_lm': WordSvAgreementLM,
}


In [2]:
class DistillLoss(tf.keras.losses.Loss):
    def __init__(self, padding_symbol=0, tmp=1.0,
                 **kwargs):
        super(DistillLoss, self).__init__(**kwargs)
        self.tmp = tf.constant(tmp, dtype=tf.float32)
        self.padding_symbol = tf.constant(padding_symbol, dtype=tf.int32)
    
    def call(self, y_true, y_pred):
      y_true = tf.cast(tf.squeeze(y_true), dtype=tf.float32)
      sequence_mask = tf.cast(y_true[...,self.padding_symbol] != 1.0, dtype=tf.float32)
      sequence_mask = sequence_mask / tf.reduce_sum(sequence_mask)
      return tf.reduce_sum(tf.compat.v2.nn.softmax_cross_entropy_with_logits(logits=y_pred/self.tmp,
                                                                      labels=y_true,
                                                                      name='loss') * sequence_mask)

In [3]:
log_dir = "../logs"
chkpt_dir = "../tf_ckpts"

# Create task
task = TASKS[hparams.task](get_task_params(), data_dir='../data')

# Create the Model
teacher_model = MODELS[hparams.teacher_model](hparams=get_model_params(task, hparams.teacher_model))
student_model = MODELS[hparams.student_model](hparams=get_model_params(task, hparams.student_model))

teacher_log_dir = os.path.join(log_dir, task.name, teacher_model.model_name + "_" + hparams.teacher_exp_name)
teacher_ckpt_dir = os.path.join(chkpt_dir, task.name, teacher_model.model_name + "_" + hparams.teacher_exp_name)


Vocab len:  10034




In [4]:
student_model.compile(
                      optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                      loss=DistillLoss(tmp=1))

In [5]:
print(teacher_ckpt_dir)
teacher_ckpt = tf.train.Checkpoint(net=teacher_model)
teacher_manager = tf.train.CheckpointManager(teacher_ckpt, teacher_ckpt_dir, max_to_keep=2)
teacher_ckpt.restore(teacher_manager.latest_checkpoint)
if teacher_manager.latest_checkpoint:
  print("Restored from {}".format(teacher_manager.latest_checkpoint))
else:
  print("Initializing from scratch.")

../tf_ckpts/word_sv_agreement_lm/lm_lstm_h-512_d-3_hdrop-0.5_indrop-0.2_trial4
Restored from ../tf_ckpts/word_sv_agreement_lm/lm_lstm_h-512_d-3_hdrop-0.5_indrop-0.2_trial4/ckpt-18


In [6]:
student_model.optimizer
student_model.trainable_weights

[]

In [7]:
# Apply teacher and student
train_iter = iter(task.valid_dataset)

In [11]:
@tf.function(experimental_relax_shapes=True)
def get_logits(x):
    return student_model(x) 
  
@tf.function(experimental_relax_shapes=True)
def train_step(x,y):
    with tf.GradientTape() as tape:
       logits = student_model(x)
       loss_value = student_model.loss(y_pred=logits, y_true=y)

    grads = tape.gradient(loss_value, student_model.trainable_weights)
    student_model.optimizer.apply_gradients(zip(grads, student_model.trainable_weights))
    return loss_value

@tf.function(experimental_relax_shapes=True)
def get_probs(logits, y, tmp):
    teacher_probs = tf.nn.softmax(logits/tmp, axis=-1)
    sequence_mask = tf.cast(y != 0, dtype=tf.float32)
    masked_teacher_probs = teacher_probs * sequence_mask[...,None] + tf.eye(tf.shape(teacher_probs)[-1])[0] * (1 - sequence_mask[...,None])

    return masked_teacher_probs

soft_targets = []
inputs = []
tmp = tf.constant(1, dtype=tf.float32)

step = 0
for  (x,y) in train_iter:
    x = tf.convert_to_tensor(x, dtype=tf.int64)
    y = tf.convert_to_tensor(x, dtype=tf.int64)
    
    teacher_logits = teacher_model(x)
    masked_teacher_probs = get_probs(teacher_logits, y, tmp)
    
    loss_value = train_step(x,masked_teacher_probs)
    # Log every 200 batches.
    if step % 10 == 0:
        print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
        print('Seen so far: %s samples' % ((step + 1) * 64))

Training loss (for one batch) at step 0: 5.463839530944824
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.45330810546875
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.35497522354126
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.404634475708008
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.398497581481934
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.464436054229736
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.433465957641602
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.383969306945801
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.3649983406066895
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.466148853302002
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.416903495788574
Seen so far: 64 samples
Training loss (for one batch) at step 0: 5.425065517425537
Seen so far: 64 samples


KeyboardInterrupt: 

In [None]:
teacher_probs = tf.nn.softmax(teacher_logits/1.0, axis=-1)
print(teacher_probs[0])

In [None]:
teacher_probs = teacher_probs * sequence_mask[...,None] + tf.eye(tf.shape(teacher_probs)[-1])[0] * (1 - sequence_mask[...,None])

In [None]:
teacher_probs[0]

In [66]:
def get_masked_probs(logits, labels, temperature, padding_symbol=0):
  teacher_probs = tf.nn.softmax(logits / temperature, axis=-1)
  sequence_mask = tf.cast(labels != padding_symbol, dtype=tf.float32)
  masked_teacher_probs = teacher_probs * sequence_mask[..., None] + tf.eye(tf.shape(teacher_probs)[-1])[0] * (
      1 - sequence_mask[..., None])

  return masked_teacher_probs

def get_topk_mask(inputs, k):
    values, indices = tf.nn.top_k(inputs, k=k, sorted=False)

    temp_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(
           tf.shape(inputs)[:(inputs.get_shape().ndims - 1)]) + [k])], indexing='ij')
    temp_indices = tf.stack(temp_indices[:-1] + [indices], axis=-1)
    full_indices = tf.reshape(temp_indices, [-1, inputs.get_shape().ndims])
    values = tf.reshape(values, [-1])

    mask_st = tf.SparseTensor(indices=tf.cast(
          full_indices, dtype=tf.int64), values=tf.ones_like(values), dense_shape=inputs.shape)
    mask = tf.sparse.to_dense(tf.sparse.reorder(mask_st))
    
    return mask

def get_topk_masked_probs(logits, labels, temperature, k=100, padding_symbol=0):
  topk_mask =(1 - tf.cast(get_topk_mask(logits, k), dtype=tf.float32)) * -10e8
  teacher_probs = tf.nn.softmax((logits+topk_mask) / temperature, axis=-1)
  sequence_mask = tf.cast(labels != padding_symbol, dtype=tf.float32)
  masked_teacher_probs = teacher_probs * sequence_mask[..., None] + tf.eye(tf.shape(teacher_probs)[-1])[0] * (
      1 - sequence_mask[..., None])

  return masked_teacher_probs

In [67]:
import numpy as np
logits = np.float32(np.random.random(size=(3,8)))
print(logits)
labels = np.float32(np.eye(8, dtype=np.float32)[0:3])
print(labels)



[[0.32073522 0.6911475  0.1812065  0.52984554 0.40381116 0.7097301
  0.6959685  0.01675281]
 [0.8634111  0.3515953  0.78308344 0.03304483 0.11540707 0.8430617
  0.70706254 0.7702668 ]
 [0.5482357  0.7089918  0.07130605 0.31061053 0.09298615 0.2160058
  0.99783164 0.7269276 ]]
[[1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0.]]


In [68]:
probs = get_topk_masked_probs(tf.convert_to_tensor(logits), tf.convert_to_tensor(np.argmax(labels, axis=-1)), 1.0, 2)
print(probs)

tf.Tensor(
[[1.         0.         0.         0.         0.         0.
  0.         0.        ]
 [0.5050872  0.         0.         0.         0.         0.4949128
  0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.56731486 0.43268517]], shape=(3, 8), dtype=float32)


In [29]:
np.argmax(probs, axis=-1) == np.argmax(logits,axis=-1)

array([ True,  True,  True])

In [50]:
K = 4
arr = tf.random.normal(shape=(5, 3, 8))
values, indices = tf.nn.top_k(arr, k=K, sorted=False)

temp_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(
       tf.shape(arr)[:(arr.get_shape().ndims - 1)]) + [K])], indexing='ij')
temp_indices = tf.stack(temp_indices[:-1] + [indices], axis=-1)
full_indices = tf.reshape(temp_indices, [-1, arr.get_shape().ndims])
values = tf.reshape(values, [-1])

mask_st = tf.SparseTensor(indices=tf.cast(
      full_indices, dtype=tf.int64), values=tf.ones_like(values), dense_shape=arr.shape)
mask = tf.sparse.to_dense(tf.sparse.reorder(mask_st))

In [51]:
mask

<tf.Tensor: id=427, shape=(5, 3, 8), dtype=float32, numpy=
array([[[0., 1., 1., 0., 1., 0., 0., 1.],
        [0., 1., 0., 1., 0., 1., 0., 1.],
        [0., 0., 1., 1., 1., 0., 0., 1.]],

       [[0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 1., 1., 0., 0., 1., 1., 0.],
        [1., 0., 1., 0., 1., 1., 0., 0.]],

       [[1., 1., 0., 0., 1., 1., 0., 0.],
        [0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 0., 0., 1.]],

       [[0., 0., 1., 0., 1., 0., 1., 1.],
        [0., 1., 1., 0., 0., 0., 1., 1.],
        [0., 1., 1., 0., 0., 1., 1., 0.]],

       [[1., 1., 0., 0., 1., 0., 0., 1.],
        [1., 0., 0., 1., 0., 0., 1., 1.],
        [1., 1., 0., 1., 0., 0., 1., 0.]]], dtype=float32)>