In [1]:
import os
HTTP_PROXY = 'http://10.10.78.61:3128'
HTTPS_PROXY = 'http://10.10.78.61:3128'

os.environ['http_proxy'] = HTTP_PROXY
os.environ['https_proxy'] = HTTPS_PROXY

# set path for locally downloaed models
student_path = "models/metaresearch/llama-3.2/transformer/1b"
teacher_path = "models/metaresearch/llama-3.1/transformers/8b/2"

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from evaluate import load
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn as nn
import os
import numpy as np
import json
import threading

In [4]:
# Load and process datasets
def load_datasets():
    # Summarization
    cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:15000]")
    
    # Question Answering
    squad = load_dataset("squad_v2", split="train[:15000]")
    
    # Paraphrase Generation
    # quora = load_dataset("quora", split="train[:3]")
    # quora = load_dataset("quora", split="train[:363861]")
    quora = load_dataset("quora", split="train[:40000]")
    print("train dataset loaded")
    
    return {
        "summarization": cnn_dm,
        "qa": squad,
        "paraphrase": quora
    }
# Format datasets into consistent prompt structure
def format_datasets(datasets):
    formatted = {}
    
    # Summarization formatting
    summarization_data = []
    for example in datasets["summarization"]:
        prompt = f"Summarize the following article:\n{example['article']}"
        target = example['highlights']
        summarization_data.append({"prompt": prompt, "target": target, "task": "summarization"})
    formatted["summarization"] = summarization_data
    
    # QA formatting
    qa_data = []
    for example in datasets["qa"]:
        prompt = f"Context: {example['context']}\nQuestion: {example['question']}"
        target = example['answers']['text'][0] if len(example['answers']['text']) > 0 else "No answer available."
        qa_data.append({"prompt": prompt, "target": target, "task": "qa"})
    formatted["qa"] = qa_data
    
    # Paraphrase formatting
    paraphrase_data = []
    for example in datasets["paraphrase"]:
        if example['is_duplicate']:  # Only use duplicate pairs for paraphrasing
            prompt = f"Paraphrase the following:\n{example['questions']['text'][0]}"
            target = example['questions']['text'][1]
            paraphrase_data.append({"prompt": prompt, "target": target, "task": "paraphrase"})
    formatted["paraphrase"] = paraphrase_data
    print("format the data in desired form")
    return formatted

In [None]:
## teacher model setup

class TeacherModel:
    def __init__(self, model_path=teacher_path, gpu=3):
        print("*** Initializing teacher model ***")
        # os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        self.device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,

            device_map=self.device
        )
        # Set model config pad_token_id
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.eval()  # Set to evaluation mode
        
    def generate_outputs(self, prompts, max_new_tokens):
        """Generate logits and outputs from the teacher model"""
        outputs = []
        logits_list = []
        # hidden_states_list = []
        inputs = self.tokenizer(
            prompts, 
            return_tensors="pt", 
            padding=True, 
            truncation=True,
            return_attention_mask=True
        ).to(self.device)
        
        with torch.no_grad():
            # Generate text outputs
            model_output = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=True # As in original code
            )
            
            # Get logits for each generation step
            # Stack all scores (logits) from each generation step
            all_prompt_logits = model_output.logits
            hidden_state = model_output.hu
            
            # Split logits to match each prompt's output
            logits_list = []
            # `attention_mask` is 1 for real tokens, 0 for padding. Summing gives actual length.
            actual_lengths = inputs.attention_mask.sum(dim=1)
            for i in range(all_prompt_logits.size(0)): # Iterate over batch
                prompt_len = actual_lengths[i].item()
                # Slice to get logits only for the actual tokens of this prompt
                # Shape: (prompt_len, vocab_size)
                prompt_specific_logits = all_prompt_logits[i, :prompt_len, :]
                # Add a leading batch dimension of 1 as per the original implied output structure
                logits_list.append(prompt_specific_logits.unsqueeze(0))
                
        return logits_list

In [5]:
# Load and format datasets
datasets = load_datasets()
formatted_data = format_datasets(datasets)

# Flatten all tasks into one list for inference
all_data = formatted_data["summarization"] + formatted_data["qa"] + formatted_data["paraphrase"]

# Extract prompts only
prompts = [example["prompt"] for example in all_data]

# Initialize teacher model
teacher1 = TeacherModel(model_path=teacher_path, gpu=0)
teacher2 = TeacherModel(model_path=teacher_path, gpu=1)
teacher3 = TeacherModel(model_path=teacher_path, gpu=2)
teacher4 = TeacherModel(model_path=teacher_path, gpu=3)


train dataset loaded
format the data in desired form
*** Initializing teacher model ***


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

*** Initializing teacher model ***


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

*** Initializing teacher model ***


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

*** Initializing teacher model ***


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
datasets = load_datasets()
formatted_data = format_datasets(datasets)

train dataset loaded
format the data in desired form


In [6]:
formatted_data

{'summarization': [{'prompt': 'Summarize the following article:\nLONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places 

In [7]:
class ThreadWithReturnValue(threading.Thread):
    def __init__(self, group=None, target=None, name=None,
                 args=(), kwargs=None, *, daemon=None):
        super().__init__(group=group, target=target, name=name,
                         args=args, kwargs=kwargs or {}, daemon=daemon)
        self._return = None
        self._exception = None # To store exception if one occurs

    def run(self):
        if self._target is not None:
            try:
                self._return = self._target(*self._args, **self._kwargs)
            except Exception as e:
                self._exception = e # Store the exception

    def join(self, *args, **kwargs):
        super().join(*args, **kwargs)
        if self._exception:
            raise self._exception # Re-raise exception in the main thread
        return self._return


In [8]:
import threading
import time

from collections import defaultdict
class thread_out():
    def __init__(self):
        self.out= defaultdict()
        self.i=0
    def add_data(self, prompts, logits):
        self.i+=1
        for prompt, logits_tensor in zip(prompts, logits):
        # Convert logits tensor to list (removes batch dim)
            self.out[prompt] = logits_tensor.squeeze(0).tolist()
        if self.i%5==0:
            self.save()
    
    def save(self, path='./_logists.json'):
        with open(path, 'w') as f:
            json.dump(self.out, f)

In [9]:
class ThreadSafeLogitSaver:
    def __init__(self, save_interval_batches=20, intermediate_save_path='./_logits_intermediate.json'):
        self.all_logits_data = {}  # Use a regular dict
        self.processed_batch_count = 0
        self.lock = threading.Lock()
        self.save_interval_batches = save_interval_batches
        self.intermediate_save_path = intermediate_save_path
        print(f"Saver initialized. Intermediate saves every {save_interval_batches} batches to {self.intermediate_save_path}")

    def process_and_add_batch_data(self, prompts_in_batch, logit_tensors_in_batch):
        # This part is CPU-bound and done by each worker thread
        # No lock needed here as it operates on local data before updating shared dict
        local_batch_data = {}
        for prompt, logits_tensor in zip(prompts_in_batch, logit_tensors_in_batch):
            # Ensure logits are on CPU before .tolist() if they might be on GPU
            # squeezed_logits = logits_tensor.squeeze(0).cpu().tolist() # if on GPU
            squeezed_logits = logits_tensor.squeeze(0).tolist() # if already on CPU or .cpu() handled by generate_outputs
            local_batch_data[prompt] = squeezed_logits

        # Now, acquire lock to update shared data and handle periodic save
        with self.lock:
            self.all_logits_data.update(local_batch_data)
            self.processed_batch_count += 1
            # print(f"Batch {self.processed_batch_count} processed. Total prompts: {len(self.all_logits_data)}") # Debug
            if self.processed_batch_count % self.save_interval_batches == 0:
                self._save_data_unsafe(self.intermediate_save_path, is_intermediate=True)

    def _save_data_unsafe(self, path, is_intermediate=False):
        # This method MUST be called with the lock held
        action = "Intermediate" if is_intermediate else "Final"
        print(f"{action} saving {len(self.all_logits_data)} prompts to {path}...")
        try:
            with open(path, "w") as f:  # Corrected: 'w' mode
                json.dump(self.all_logits_data, f, indent=2) # indent for readability
            print(f"{action} save to {path} successful.")
        except Exception as e:
            print(f"Error during {action.lower()} save to {path}: {e}")

    def final_save(self, path):
        with self.lock: # Ensure exclusivity for the final save
            self._save_data_unsafe(path, is_intermediate=False)

In [10]:
# logit_saver = ThreadSafeLogitSaver(save_interval_batches=50) # Save intermediate every 50 batches
# output_path = './logits_out/out.json'
# batch_size = 32  # Adjust as needed
# # max_new_tokens = 100 # This might be unused if your teacher.generate_outputs is simplified
# obj = thread_out()
# active_threads = []
# # Determine a reasonable number of concurrent processing threads
# # Too many can lead to GIL contention or overwhelm CPU for .tolist()
# # Too few might not keep up if .tolist() is very slow.
# max_concurrent_threads = max(1, os.cpu_count() - 1) if os.cpu_count() else 4 # Leave one CPU for main/GPU tasks
# print(f"Using up to {max_concurrent_threads} concurrent processing threads.")


# for i in tqdm(range(0, len(prompts), batch_size), desc="Generating and queueing logits for processing"):
#     batch_prompts = prompts[i:i+batch_size]
#     if not batch_prompts:
#         continue

#     # Ensure your teacher.generate_outputs returns only the list of logit tensors for the prompts
#     # If it returns (None, batch_logits_list), adjust accordingly.
#     # Assuming simplified return: batch_logit_tensors_list = teacher.generate_outputs(...)
#     _ignored_outputs, batch_logit_tensors_list = teacher.generate_outputs(
#         batch_prompts,
#         teacher.tokenizer, # Use teacher's own tokenizer or the one you intend
#         max_new_tokens=1 # Minimal if only prompt logits are needed
#     )

#     # Clean up finished threads from the active_threads list
#     active_threads = [t for t in active_threads if t.is_alive()]

#     # If we have too many active threads, wait for some to complete
#     while len(active_threads) >= max_concurrent_threads:
#         # print(f"Max concurrent threads ({max_concurrent_threads}) reached. Waiting...") # Debug
#         time.sleep(0.1) # Brief pause
#         active_threads = [t for t in active_threads if t.is_alive()] # Re-check

#     # Create and start a new thread for processing this batch
#     thread = threading.Thread(
#         target=obj.out,
#         args=(batch_prompts, batch_logit_tensors_list)
#     )
#     # thread = threading.Thread(
#     #     target=logit_saver.process_and_add_batch_data,
#     #     args=(batch_prompts, batch_logit_tensors_list)
#     # )
#     thread.start()
#     active_threads.append(thread)

# # Wait for all remaining processing threads to complete
# print("Main generation loop finished. Waiting for all processing threads to complete...")
# for t in tqdm(active_threads, desc="Joining threads"):
#     t.join()

# # Perform the final save
# logit_saver.final_save(path=output_path)
# print(f"All logits processed. Final data saved for {len(logit_saver.all_logits_data)} prompts to {output_path}")

In [11]:
from safetensors.torch import save_file


In [None]:
tokenizer=teacher1.tokenizer
logits_dict = {}
logits=[]
batch_size = 32
max_new_tokens=100
output_path="prompt_logits.json"
obj = thread_out()
tensors = {}
for i in tqdm(range(0, len(prompts), batch_size), desc="Generating logits"):
    batch_prompts = prompts[i:i+batch_size]
    batch_logits=[]
    # _, batch_logits = teacher1.generate_outputs(batch_prompts, tokenizer, max_new_tokens=max_new_tokens)
    t1 = ThreadWithReturnValue(target=teacher1.generate_outputs, args=(batch_prompts[:8], max_new_tokens))
    t2 = ThreadWithReturnValue(target=teacher2.generate_outputs, args=(batch_prompts[8:16], max_new_tokens,))
    t3 = ThreadWithReturnValue(target=teacher3.generate_outputs, args=(batch_prompts[16:24], max_new_tokens,))
    t4 = ThreadWithReturnValue(target=teacher4.generate_outputs, args=(batch_prompts[24:32], max_new_tokens,))
    t1.start()
    t2.start()
    t3.start()
    t4.start()
    
    batch_logits.extend(t1.join())
    batch_logits.extend(t2.join())
    batch_logits.extend(t3.join())
    batch_logits.extend(t4.join())
    
    # t1 = threading.Thread(target=obj.add_data, args=(batch_prompts,batch_logits,))
    # t1.start()
    for prompt, logits_tensor in zip(batch_prompts, batch_logits):
        top_k_values, top_k_indices = torch.topk(logits_tensor, 100, dim=-1)
        tensors[prompt] = [top_k_values.to('cpu'), top_k_indices.to('cpu')]
        # tensors[prompt] = {'top_k_values':top_k_values.cpu(), 'top_k_indices':top_k_indices.cpu()}
    if i%100==0:
        torch.save(tensors, "model.pt")

# Save to JSON
# with open(output_path, "w") as f:
#     json.dump(logits_dict, f)
# obj.save()
# print(f"Saved logits for {len(logits_dict)} prompts to {output_path}")

Generating logits:   0%|          | 1/1403 [00:23<9:01:34, 23.18s/it]


KeyboardInterrupt: 

In [None]:
# # Function to batch process prompts and save logits to JSON
# def generate_logits(teacher_model, prompts, tokenizer, output_path="prompt_logits.json", batch_size=8, max_new_tokens=100):
#     logits_dict = {}
#     logits=[]
#     for i in tqdm(range(0, len(prompts), batch_size), desc="Generating logits"):
#         batch_prompts = prompts[i:i+batch_size]
#         _, batch_logits = teacher_model.generate_outputs(batch_prompts, tokenizer, max_new_tokens=max_new_tokens)
#         logits.extend(batch_logits)
#         # for prompt, logits_tensor in zip(batch_prompts, batch_logits):
#         #     # Convert logits tensor to list (removes batch dim)
#         #     logits_dict[prompt] = logits_tensor.squeeze(0).tolist()

#     # Save to JSON
#     with open(output_path, "w") as f:
#         json.dump(logits_dict, f)

#     print(f"Saved logits for {len(logits_dict)} prompts to {output_path}")

# # Run the generation and save process
# generate_logits(teacher, prompts, teacher.tokenizer)

Generating logits:   0%|          | 6/69076 [01:11<229:46:48, 11.98s/it]


KeyboardInterrupt: 

In [None]:
logits

NameError: name 'logits' is not defined