In [None]:
from transformers import TFAutoModelForCausalLM, AutoTokenizer
import tensorflow as tf

# Load the pretrained teacher model and tokenizer.
# NOTE(review): Hugging Face Transformers does not appear to ship a TensorFlow
# implementation of Llama, so TFAutoModelForCausalLM may reject this checkpoint with
# "Unrecognized configuration class" — confirm, or switch to the PyTorch
# AutoModelForCausalLM. The Llama-2 checkpoints are also gated and need an
# authenticated Hugging Face login (never hardcode the access token in the notebook).
teacher_model_name = "meta-llama/Llama-2-13b-hf"
teacher_model = TFAutoModelForCausalLM.from_pretrained(teacher_model_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

# `padding=True` below raises a ValueError if the tokenizer defines no pad token,
# which is typically the case for Llama-2 tokenizers. Reusing EOS as the pad token is
# the usual workaround; padded positions are then excluded via `attention_mask`.
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

# Define input texts
texts = [
    "The movie was fantastic!",
    "The plot was dull and predictable.",
]

# Tokenize inputs (the texts differ in length, so the shorter one gets padded)
inputs = teacher_tokenizer(texts, return_tensors="tf", padding=True, truncation=True)

# Generate soft labels (logits) from the teacher model. The attention mask keeps the
# model from attending to the padding added to the shorter text.
teacher_logits = teacher_model(inputs.input_ids, attention_mask=inputs.attention_mask).logits


In [None]:
from transformers import TFAutoModelForCausalLM, AutoTokenizer
import tensorflow as tf
import numpy as np

# Load the smaller student model and tokenizer
student_model_name = "meta-llama/Llama-2-7b-hf"
student_model = TFAutoModelForCausalLM.from_pretrained(student_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# The training loop tokenizes with padding=True, which raises if the tokenizer has no
# pad token (typical for Llama-2). Fall back to EOS; padded positions are masked out
# through `attention_mask`.
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

# Distillation compares logits position-by-position along the vocabulary axis, so both
# tokenizers must map every token to the same id. `get_vocab()` is the public API on
# both slow and fast tokenizers (the `.vocab` attribute is not guaranteed on slow ones).
assert teacher_tokenizer.get_vocab() == student_tokenizer.get_vocab(), "Mismatch in vocabularies!"

# Define KL Divergence loss function
def kl_divergence_loss(teacher_logits, student_logits, temperature=1.0, mask=None):
    """Per-position KL divergence KL(teacher || student), averaged into a scalar.

    Args:
        teacher_logits: Teacher logits; the softmax is taken over the last axis.
            Presumably (batch, seq_len, vocab) — TODO confirm against callers.
        student_logits: Student logits with the same shape as ``teacher_logits``.
        temperature: Softmax temperature applied to both distributions. Values > 1
            soften them. The result is multiplied by ``temperature ** 2`` so gradient
            magnitudes stay comparable across temperatures. The default of 1.0
            reproduces the plain (untempered) loss.
        mask: Optional weights with shape ``teacher_logits.shape[:-1]`` (e.g. an
            attention mask) selecting which positions contribute. When None, every
            position is averaged with equal weight.

    Returns:
        Scalar tensor: the mean (or mask-weighted mean) per-position KL divergence.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Work in log space: softmax followed by log(p / q) produces inf/NaN as soon as a
    # probability underflows to 0, which is easy to hit over a large vocabulary.
    teacher_log_probs = tf.nn.log_softmax(teacher_logits / temperature, axis=-1)
    student_log_probs = tf.nn.log_softmax(student_logits / temperature, axis=-1)
    per_position = tf.reduce_sum(
        tf.exp(teacher_log_probs) * (teacher_log_probs - student_log_probs), axis=-1
    )

    if mask is None:
        kl = tf.reduce_mean(per_position)
    else:
        weights = tf.cast(mask, per_position.dtype)
        # Clamp the denominator so an all-zero mask yields 0 instead of NaN.
        kl = tf.reduce_sum(per_position * weights) / tf.maximum(tf.reduce_sum(weights), 1.0)
    return kl * (temperature ** 2)

# Prepare optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

# Distillation training loop
batch_size = 2  # Use small batches for demonstration
epochs = 3

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    total_loss = 0.0
    num_batches = 0

    # Process inputs in batches (the final batch may be smaller than batch_size)
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize for both teacher and student
        student_inputs = student_tokenizer(batch_texts, return_tensors="tf", padding=True, truncation=True)
        teacher_inputs = teacher_tokenizer(batch_texts, return_tensors="tf", padding=True, truncation=True)

        # The teacher is frozen: run it outside the tape so its activations are not
        # recorded for backprop, which for a 13B model is a large memory cost for nothing.
        # Passing the attention mask keeps padded positions from being attended to.
        teacher_logits_batch = teacher_model(
            teacher_inputs.input_ids, attention_mask=teacher_inputs.attention_mask
        ).logits

        # Forward pass (student only; this is what gets differentiated)
        with tf.GradientTape() as tape:
            student_logits = student_model(
                student_inputs.input_ids, attention_mask=student_inputs.attention_mask
            ).logits

            # Compute KL divergence loss
            loss = kl_divergence_loss(teacher_logits_batch, student_logits)

        # Backpropagation and optimization
        gradients = tape.gradient(loss, student_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

        total_loss += loss.numpy()
        num_batches += 1

    # Average over the batches actually processed: this counts a partial final batch
    # and avoids dividing by zero when len(texts) < batch_size or texts is empty.
    print(f"Loss: {total_loss / max(num_batches, 1):.4f}")


In [None]:
def combined_loss(teacher_logits, student_logits, true_labels, alpha=0.5):
    """Blend the distillation loss with ordinary supervised cross-entropy.

    Args:
        teacher_logits: Teacher logits, forwarded to ``kl_divergence_loss``.
        student_logits: Student logits, used both for distillation and as the
            ``from_logits=True`` input to sparse categorical cross-entropy.
        true_labels: Integer target ids aligned with ``student_logits``.
            Presumably already shifted for next-token prediction — TODO confirm.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            supervised term.

    Returns:
        ``alpha * KL(teacher || student) + (1 - alpha) * mean(cross-entropy)``.
    """
    distill_term = alpha * kl_divergence_loss(teacher_logits, student_logits)

    ce_per_item = tf.keras.losses.sparse_categorical_crossentropy(
        true_labels, student_logits, from_logits=True
    )
    supervised_term = (1 - alpha) * tf.reduce_mean(ce_per_item)

    return distill_term + supervised_term
