### import and load model

In [1]:
import sys
sys.path.append("..")

import torch
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer

from core.trainers.ppo_trainer.custom_ppo import CustomPPOTrainer
from core.trainers.ppo_trainer.config import CustomPPOConfig
from core.custom_components.custom_interaction.custom_test_model import CustomTestModel
from trl.core import LengthSampler
from trl import AutoModelForCausalLMWithValueHead

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
 # Initialize models and tokenizers
model_name = "meta-llama/Llama-3.2-1B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizers
child_tokenizer = AutoTokenizer.from_pretrained(model_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load models with value heads for PPO
child_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
teacher_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

# Move models to device
child_model.to(device)
teacher_model.to(device)



AutoModelForCausalLMWithValueHead(
  (pretrained_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 2048)
      (layers): ModuleList(
        (0-15): 16 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=512, bias=False)
            (v_proj): Linear(in_features=2048, out_features=512, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((20

In [3]:
# check gpu memory, or use nvidia-smi
import torch

if torch.cuda.is_available():
    free_mem, total_mem = torch.cuda.mem_get_info()
    print(f"Free memory: {free_mem / 1024**2:.2f} MB")
    print(f"Total memory: {total_mem / 1024**2:.2f} MB")
    print(f"Used memory: {(total_mem - free_mem) / 1024**2:.2f} MB")


Free memory: 30482.31 MB
Total memory: 40339.31 MB
Used memory: 9857.00 MB


In [4]:
# Define generation arguments for student and teacher
student_generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
    "pad_token_id": child_tokenizer.eos_token_id
}

teacher_generation_args = {
    "max_new_tokens": 150,
    "do_sample": False,
    "pad_token_id": teacher_tokenizer.eos_token_id
}

# Initialize the custom test model
model = CustomTestModel(
    child_model=child_model,
    teacher_model=teacher_model,
    child_tokenizer=child_tokenizer,
    teacher_tokenizer=teacher_tokenizer,
    student_generation_args=student_generation_args,
    teacher_generation_args=teacher_generation_args,
    device=device
)

### generate one interaction instance

In [5]:
# test single interaction
prompt = "The matrix is one of the be"
result = model.interact(prompt)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [6]:
result

{'child_query': tensor([128000,  45147,  31868,  65562,     60,   1134,  39031,  40171,   2675,
            527,    264,   5575,   6975,    311,   3350,     13,    720,    286,
          19121,    279,   2728,  10137,    304,    264,   2867,    323,  56887,
           1648,     13,   1115,   8712,    374,   5818,   3477,     11,   4587,
           4686,    433,   1615,    285,    989,    198,     27,    524,  39031,
          58645,    791,   6303,    374,    832,    315,    279,    387,  66028,
          65562,     60], device='cuda:0'),
 'child_response': tensor([128000,  45147,  31868,  65562,     60,   1134,  39031,  40171,   2675,
            527,    264,   5575,   6975,    311,   3350,     13,    720,    286,
          19121,    279,   2728,  10137,    304,    264,   2867,    323,  56887,
           1648,     13,   1115,   8712,    374,   5818,   3477,     11,   4587,
           4686,    433,   1615,    285,    989,    198,     27,    524,  39031,
          58645,    791,   6303,

In [7]:
# test batch interaction
prompts = ["The matrix is one of the best m", "what is thi stu"]
results = model.batch_interact(prompts)
results

{'child_queries': [tensor([128000,  45147,  31868,  65562,     60,   1134,  39031,  40171,   2675,
             527,    264,   5575,   6975,    311,   3350,     13,    720,    286,
           19121,    279,   2728,  10137,    304,    264,   2867,    323,  56887,
            1648,     13,   1115,   8712,    374,   5818,   3477,     11,   4587,
            4686,    433,   1615,    285,    989,    198,     27,    524,  39031,
           58645,    791,   6303,    374,    832,    315,    279,   1888,    296,
           66028,  65562,     60], device='cuda:0'),
  tensor([128000,  45147,  31868,  65562,     60,   1134,  39031,  40171,   2675,
             527,    264,   5575,   6975,    311,   3350,     13,    720,    286,
           19121,    279,   2728,  10137,    304,    264,   2867,    323,  56887,
            1648,     13,   1115,   8712,    374,   5818,   3477,     11,   4587,
            4686,    433,   1615,    285,    989,    198,     27,    524,  39031,
           58645,  12840,   