In [1]:
# Dependencies
"""
meteor
"""

'\nmeteor\n'

In [1]:
import os
HTTP_PROXY = 'http://10.10.78.61:3128'
HTTPS_PROXY = 'http://10.10.78.61:3128'

os.environ['http_proxy'] = HTTP_PROXY
os.environ['https_proxy'] = HTTPS_PROXY

# set path for locally downloaed models
student_path = "models/metaresearch/llama-3.2/transformer/1b"
teacher_path = "models/metaresearch/llama-3.1/transformers/8b/2"

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from evaluate import load
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn as nn

In [3]:
# Load and process datasets
def load_datasets():
    # Summarization
    cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:30000]")
    
    # Question Answering
    squad = load_dataset("squad_v2", split="train[:30000]")
    
    # Paraphrase Generation
    quora = load_dataset("quora", split="train[:80000]")
    # quora = load_dataset("quora", split="train[:363861]")
    print("train dataset loaded")
    
    return {
        "summarization": cnn_dm,
        "qa": squad,
        "paraphrase": quora
    }

In [5]:
def load_test_split():
    # Summarization
    cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="test[100:120]")
    
    # Question Answering
    squad = load_dataset("squad_v2", split="validation[100:120]")
    
    # Paraphrase Generation
    quora = load_dataset("quora", split="train[100:120]")
    #quora = load_dataset("quora", split="train[363861:]")
    print("test dataset loaded")
    return {
        "summarization": cnn_dm,
        "qa": squad,
        "paraphrase": quora
    }

In [6]:
data = load_test_split()

KeyboardInterrupt: 

In [6]:
# from datasets import load_dataset

# def load_test_split():
#     test_data = {}

#     # Load CNN/DailyMail for summarization
#     cnn = load_dataset("cnn_dailymail", "3.0.0", split="test[:100]")  # Adjust size as needed
#     test_data["summarization"] = [
#         {"prompt": example["article"], "target": example["highlights"]}
#         for example in cnn
#     ]

#     # Load SQuAD for QA
#     squad = load_dataset("squad", split="validation[:100]")  # SQuAD uses 'validation' as test
#     test_data["qa"] = [
#         {"prompt": f"question: {example['question']} context: {example['context']}",
#          "target": example["answers"]["text"][0] if example["answers"]["text"] else ""}
#         for example in squad
#     ]

#     # Load Quora for paraphrasing
#     quora = load_dataset("glue", "quora", split="test[:100]")  # Note: 'test' split doesn't have labels
#     # If you need labels (to find paraphrases), use validation split or your own format
#     test_data["paraphrase"] = [
#         {"prompt": example["question1"], "target": example["question2"]}
#         for example in quora
#         if example["question1"] is not None and example["question2"] is not None
#     ]

#     return test_data


In [7]:
# Format datasets into consistent prompt structure
def format_datasets(datasets):
    formatted = {}
    
    # Summarization formatting
    summarization_data = []
    for example in datasets["summarization"]:
        prompt = f"Summarize the following article:\n{example['article']}"
        target = example['highlights']
        summarization_data.append({"prompt": prompt, "target": target, "task": "summarization"})
    formatted["summarization"] = summarization_data
    
    # QA formatting
    qa_data = []
    for example in datasets["qa"]:
        prompt = f"Context: {example['context']}\nQuestion: {example['question']}"
        target = example['answers']['text'][0] if len(example['answers']['text']) > 0 else "No answer available."
        qa_data.append({"prompt": prompt, "target": target, "task": "qa"})
    formatted["qa"] = qa_data
    
    # Paraphrase formatting
    paraphrase_data = []
    for example in datasets["paraphrase"]:
        if example['is_duplicate']:  # Only use duplicate pairs for paraphrasing
            prompt = f"Paraphrase the following:\n{example['questions']['text'][0]}"
            target = example['questions']['text'][1]
            paraphrase_data.append({"prompt": prompt, "target": target, "task": "paraphrase"})
    formatted["paraphrase"] = paraphrase_data
    print("format the data in desired form")
    return formatted

In [7]:
## teacher model setup

class TeacherModel:
    def __init__(self, model_path=teacher_path):
        print("*** Initializing teacher model ***")
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
        # # Manually set pad_token_id if missing
        # if self.tokenizer.pad_token_id is None:
        #     if self.tokenizer.eos_token_id is not None:
        #         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        #     else:
        #         # If eos_token_id is also missing, assign a safe value like 0 or add a new token
        #         self.tokenizer.pad_token = self.tokenizer.eos_token or self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        #         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_uant=True,
            bnb_4bit_quant_type="nf4",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,

            device_map="cuda:0"
        )
        # Set model config pad_token_id
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.eval()  # Set to evaluation mode
        
    def generate_outputs(self, prompts, tokenizer, max_new_tokens):
        """Generate logits and outputs from the teacher model"""
        outputs = []
        logits_list = []
        # hidden_states_list = []
        inputs = tokenizer(
            prompts, 
            return_tensors="pt", 
            padding=True, 
            truncation=True,
            return_attention_mask=True
        ).to(self.device)
        
        with torch.no_grad():
            # Generate text outputs
            model_output = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=False # As in original code
            )
            
            # Get logits for each generation step
            # Stack all scores (logits) from each generation step
            all_prompt_logits = model_output.logits
            
            
            # Split logits to match each prompt's output
            logits_list = []
            # `attention_mask` is 1 for real tokens, 0 for padding. Summing gives actual length.
            actual_lengths = inputs.attention_mask.sum(dim=1)
            for i in range(all_prompt_logits.size(0)): # Iterate over batch
                prompt_len = actual_lengths[i].item()
                # Slice to get logits only for the actual tokens of this prompt
                # Shape: (prompt_len, vocab_size)
                prompt_specific_logits = all_prompt_logits[i, :prompt_len, :]
                # Add a leading batch dimension of 1 as per the original implied output structure
                logits_list.append(prompt_specific_logits.unsqueeze(0))
                
        return [], logits_list
        # with torch.no_grad():
        #     for prompt in prompts:
        #         inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
                
        #         # Get full output distribution (logits)
        #         model_output = self.model(**inputs, output_hidden_states=False)
        #         logits = model_output.logits
        #         # hidden_states = model_output.hidden_states[-1]  # Last layer hidden states
                
        #         # Generate text output
        #         generated = self.model.generate(
        #             **inputs,
        #             max_new_tokens=max_new_tokens,
        #             do_sample=False
        #         )
        #         decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
                
        #         outputs.append(decoded)
        #         logits_list.append(logits)
        #         # hidden_states_list.append(hidden_states)
                
        # return outputs, logits_list

In [19]:
class StudentSystem:
    def __init__(self,model_path=student_path):
        self.device = "cuda:3" if torch.cuda.is_available() else "cpu"
        print("*** Initializing student model ***")

        # Load the base model (shared backbone)
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="cuda:3",
        )
        
        # Initialize task-specific adapters
        self.setup_task_adapters()
        
    def setup_task_adapters(self):
        """Initialize LoRA adapters for each task"""
        
        # Define LoRA configurations for each task
        # Adjust r and alpha based on your parameter budget
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,  # Low-rank dimension
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        
        # Create task-specific LoRA adapters
        self.task_adapters = {
            "summarization": get_peft_model(self.base_model, lora_config),
            "qa": get_peft_model(self.base_model, lora_config),
            "paraphrase": get_peft_model(self.base_model, lora_config)
        }
        
    def task_router(self, prompt):
        """Determine which task adapter to use based on the prompt"""
        prompt_lower = prompt.lower()
        # Simple keyword-based routing
        if any(term in prompt_lower for term in ["summarize", "summary", "summarization"]):
            return "summarization"
        elif any(term in prompt_lower for term in ["question", "answer", "context"]):
            return "qa"
        elif any(term in prompt_lower for term in ["paraphrase", "rephrase", "rewrite"]):
            return "paraphrase"
        else:
            # Default to the most general task or run a more sophisticated classifier
            return "summarization"
    
    def generate(self, prompt, tokenizer, max_new_tokens=100):
        """Generate response using the appropriate task adapter"""
        task = self.task_router(prompt)
        adapter = self.task_adapters[task]
        
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = adapter.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )
            
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response, task

In [None]:
#### Knowledge Distillation Process
class KnowledgeDistillation:
    def __init__(self, teacher, student, tokenizer):
        self.teacher = teacher
        self.student = student
        self.tokenizer = tokenizer
        self.device = student.device
        
    def compute_kd_loss(self, student_logits, teacher_logits, temperature=2.0,attention_mask=None):
        """Compute knowledge distillation loss"""
         # Ensure teacher_logits is a tensor
        if not isinstance(teacher_logits, torch.Tensor):
            teacher_logits = torch.tensor(teacher_logits).to(self.device)
    
        # Add batch dimension if needed 
        if len(teacher_logits.shape) == 2:
            teacher_logits = teacher_logits.unsqueeze(0)
        # Both logits should have the same shape
        if student_logits.shape != teacher_logits.shape:
            # Truncate to the smallest length
            min_length = min(student_logits.shape[1], teacher_logits.shape[1])
            student_logits = student_logits[:, :min_length, :]
            teacher_logits = teacher_logits[:, :min_length, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :min_length]

        # Soften the distributions
        soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / temperature, dim=-1)
        # Compute KL divergence loss
        if attention_mask is not None:
            # Apply mask to focus on non-padding tokens
            mask = attention_mask.unsqueeze(-1).expand_as(soft_student)
            kd_loss = F.kl_div(
                soft_student * mask, 
                soft_teacher * mask, 
                reduction="sum"
            ) / (mask.sum() + 1e-8)  # Add small epsilon to avoid division by zero
        else:
            kd_loss = F.kl_div(
                soft_student, 
                soft_teacher, 
                reduction="batchmean"
            )
        return kd_loss * (temperature ** 2)
    
    def compute_task_loss(self, student_outputs, targets,attention_mask= None):
        """Compute task-specific loss (e.g., cross-entropy for next token prediction)"""
        # Standard language modeling loss
        shift_logits = student_outputs.logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()

        # Truncate both to the same minimum length
        min_len = min(shift_logits.size(1), shift_labels.size(1))
        shift_logits = shift_logits[:, :min_len, :]
        shift_labels = shift_labels[:, :min_len]
        if attention_mask is not None:
            shift_attention_mask = attention_mask[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction='none'
            )
            # print('shift_attention_mask: ',shift_attention_mask.size())
            # print("loss:",loss.size())
            shift_attention_mask = shift_attention_mask[:, :min_len]
            # print('shift_attention_mask changed: ',shift_attention_mask.size())

            loss = loss * shift_attention_mask.reshape(-1)
            task_loss = loss.sum() / shift_attention_mask.sum()
        else:
            task_loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1)
            )
        return task_loss
    
    def train_step(self, batch, task, alpha=0.5):
        """Single training step combining KD and task losses"""
        prompts = batch["prompt"]
        targets = batch["target"]
        
        max_input_length = 512
        # Prepare inputs for student
        inputs = self.tokenizer(prompts,
                                return_tensors="pt", 
                                padding=True,
                                truncation=True,
                                max_length=max_input_length).to(self.device)
        target_ids = self.tokenizer(targets,
                                    return_tensors="pt",
                                    padding=True,
                                    truncation=True,
                                    max_length=max_input_length).to(self.device)
        # Get teacher outputs and logits
        with torch.no_grad():
            teacher_outputs, teacher_logits_list = self.teacher.generate_outputs(prompts,
                                                                                    self.tokenizer,
                                                                                    max_new_tokens=50
                                                                                    )

        # Forward pass through student
        student_adapter = self.student.task_adapters[task]
        student_outputs = student_adapter(**inputs)
        
        actual_lengths = inputs.attention_mask.sum(dim=1)
        # for i in range(all_prompt_logits.size(0)): # Iterate over batch
        #     prompt_len = actual_lengths[i].item()
        #     # Slice to get logits only for the actual tokens of this prompt
        #     # Shape: (prompt_len, vocab_size)
        #     prompt_specific_logits = all_prompt_logits[i, :prompt_len, :]
        #     # Add a leading batch dimension of 1 as per the original implied output structure
        #     logits_list.append(prompt_specific_logits.unsqueeze(0))

        total_kd_loss = 0
        for i in range(len(prompts)):
            # Extract the individual student logits for this example
            student_logit = student_outputs.logits[i:i+1,:actual_lengths[i], :]  # Keep batch dimension
             # Extract the corresponding teacher logits - convert to tensor if not already
            teacher_logit = teacher_logits_list[i].to(self.device)
            if not isinstance(teacher_logit, torch.Tensor):
                teacher_logit = torch.tensor(teacher_logit).to(self.device)

            # Add batch dimension if needed
            if len(teacher_logit.shape) == 2:
                teacher_logit = teacher_logit.unsqueeze(0)

            # Extract attention mask for this example if available
            attention_mask = inputs["attention_mask"][i:i+1] if "attention_mask" in inputs else None

            # Compute KD loss for this example
            example_kd_loss = self.compute_kd_loss(
                student_logit, 
                teacher_logit,
                attention_mask=attention_mask
            )
        
            total_kd_loss += example_kd_loss
        
         # Average KD loss across batch
        kd_loss = total_kd_loss / len(prompts)
        
        # Compute task loss normally
        task_loss = self.compute_task_loss(
            student_outputs, 
            target_ids["input_ids"],
            attention_mask=target_ids["attention_mask"] if "attention_mask" in target_ids else None
        )
        
        # Combined loss with task-specific weighting
        alpha_map = {
            "summarization": 0.6,  # More weight on KD for summarization
            "qa": 0.5,             # Equal weight for QA
            "paraphrase": 0.4      # More weight on task loss for paraphrase
        }
        task_alpha = alpha_map.get(task, alpha)
        total_loss = task_alpha * kd_loss + (1 - task_alpha) * task_loss
        
        return total_loss, kd_loss, task_loss

In [2]:
# teacher = TeacherModel(teacher_path)
student = StudentSystem(student_path)
# prompts = [
#     "Summarize the following article:\nThe new iPhone was released yesterday with improved camera capabilities...",
#     # "Context: The Eiffel Tower is located in Paris, France. It was completed in 1889.\nQuestion: Where is the Eiffel Tower located?",
#     # "Paraphrase the following question:\nHow do I reset my password on this website?"
# ]

NameError: name 'StudentSystem' is not defined

In [1]:
student.task_adapters['qa'].print_trainable_parameters()

NameError: name 'student' is not defined

In [12]:
# student.base_model.device
# teacher.model.device

In [None]:
class MultiTaskDataset(Dataset):
    def __init__(self, formatted_data):
        self.data = []
        for task, examples in formatted_data.items():
            for example in examples:
                self.data.append(example)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def train_student_system(teacher, student, tokenizer, formatted_data, epochs=3):
    # Create dataset and dataloader
    dataset = MultiTaskDataset(formatted_data)
    dataloader = DataLoader(dataset, batch_size=6, shuffle=True,num_workers=4)
    
    # Initialize distillation trainer
    distiller = KnowledgeDistillation(teacher, student, tokenizer)
    
    # Optimizer - one per task adapter
    optimizers = {
        task: torch.optim.AdamW(adapter.parameters(), lr=5e-5)
        for task, adapter in student.task_adapters.items()
    }
    print("traing loop starts")
    # Training loop
    for epoch in range(epochs):
        total_loss = 0
        total_examples = 0
        
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            # Group examples by task
            task_examples = {}
            for i, task in enumerate(batch["task"]):
                if task not in task_examples:
                    task_examples[task] = {"prompt": [], "target": []}
                task_examples[task]["prompt"].append(batch["prompt"][i])
                task_examples[task]["target"].append(batch["target"][i])
            
            # Process each task separately
            batch_loss = 0
            batch_count =0
            for task, examples in task_examples.items():
                # Skip if no examples for this task in this batch
                if not examples["prompt"]:
                    continue
                
                # Number of examples for this task
                num_examples = len(examples["prompt"])
                batch_count += num_examples
                # Compute loss
                loss, kd_loss, task_loss = distiller.train_step(examples, task)
           
            
                # Backward pass and optimization
                optimizers[task].zero_grad()
                loss.backward()
                # Add gradient clipping
                torch.nn.utils.clip_grad_norm_(student.task_adapters[task].parameters(), max_norm=1.0)
                optimizers[task].step()

                # Accumulate weighted loss
                batch_loss += loss.item()*num_examples
    
            total_loss += batch_loss
            total_examples += batch_count
        
        # Calculate average loss properly (weighted by number of examples)
        avg_loss = total_loss / total_examples if total_examples > 0 else 0
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        
    # Save trained adapters
    for task, adapter in student.task_adapters.items():
        adapter.save_pretrained(f"./adapter_{task}")
    
    return student

In [14]:
def evaluate_model(student, test_data, tokenizer):
    print("evaluating results")
    # Load metrics
    rouge = load("rouge")
    bert_score = load("bertscore")
    sacrebleu = load("sacrebleu")
    #meteor = load("meteor")
    
    # results = {
    #     "summarization": {"rouge_l": []},
    #     "qa": {"rouge_l": [], "bert_score": []},
    #     "paraphrase": {"sacrebleu": [], "meteor": []}
    # }
    results = {
        "summarization": {"rouge_l": []},
        "qa": {"rouge_l": [], "bert_score": []},
        "paraphrase": {"sacrebleu": [],}
    }
    
    for task, examples in test_data.items():
        if not examples:
            continue
        for example in tqdm(examples, desc=f"Evaluating {task}"):
            prompt = example["prompt"]
            reference = example["target"]

            if not reference.strip():
                continue
            
            # Generate prediction
            prediction, detected_task = student.generate(prompt, tokenizer)
            
            if not prediction.strip():
                continue
            if task != detected_task:
                print(task, detected_task)
                print(prompt)
            # Calculate task-specific metrics
            if detected_task == "summarization":
                rouge_scores = rouge.compute(predictions=[prediction], references=[reference])
                results[detected_task]["rouge_l"].append(rouge_scores['rougeL'])
                
            elif detected_task == "qa":
                rouge_scores = rouge.compute(predictions=[prediction], references=[reference])
                bert_scores = bert_score.compute(predictions=[prediction], references=[reference], lang="en")
                
                results[detected_task]["rouge_l"].append(rouge_scores['rougeL'])
                results[detected_task]["bert_score"].append(bert_scores["f1"][0])
                
            elif detected_task == "paraphrase":
                sacrebleu_score = sacrebleu.compute(predictions=[prediction], references=[[reference]])
                #meteor_score = meteor.compute(predictions=[prediction], references=[reference])
                
                results[detected_task]["sacrebleu"].append(sacrebleu_score["score"])
                #results[task]["meteor"].append(meteor_score["meteor"])
    
    # Aggregate results
    print(results)
    aggregated = {}
    for task, metrics in results.items():
        aggregated[task] = {}
        for metric_name, scores in metrics.items():
            aggregated[task][metric_name] = sum(scores) / len(scores)
    
    return aggregated

In [15]:
def quantize_model(student):
    """Quantize the student model to improve inference efficiency"""
    from bitsandbytes.nn import Linear8bitLt
    print("quantization process")
    # Replace full precision linear layers with 8-bit quantized layers
    for task, adapter in student.task_adapters.items():
        # Keep a copy of adapter weights
        adapter_weights = {}
        for name, param in adapter.named_parameters():
            if "lora" in name:  # Save only LoRA parameters
                adapter_weights[name] = param.data.clone()
        
        # Quantize base model for this adapter
        for name, module in adapter.named_modules():
            if isinstance(module, torch.nn.Linear) and not "lora" in name:
                parent_name = name.rsplit(".", 1)[0] if "." in name else ""
                parent = adapter if parent_name == "" else adapter.get_submodule(parent_name)
                layer_name = name.rsplit(".", 1)[1] if "." in name else name
                setattr(parent, layer_name, Linear8bitLt.from_float(module))
        # Restore adapter weights
        for name, param in adapter.named_parameters():
            if name in adapter_weights:
                param.data.copy_(adapter_weights[name])
    
    return student

In [16]:
import warnings
from transformers.utils import logging

# Disable only Hugging Face warnings
logging.set_verbosity_error()

# Or use the Python warnings filter to suppress pad_token_id-specific messages
warnings.filterwarnings("ignore", message=".*pad_token_id.*")


In [17]:
# def main():
#     # Setup - teacher path from saved model
#     tokenizer = AutoTokenizer.from_pretrained(teacher_path)
#     tokenizer.pad_token = tokenizer.eos_token
    
#     # Load and format datasets
#     raw_datasets = load_datasets()
#     formatted_data = format_datasets(raw_datasets)
    
#     # Initialize models
#     teacher = TeacherModel()
#     student = StudentSystem()
    
#     # Train student system
#     trained_student = train_student_system(
#         teacher, 
#         student, 
#         tokenizer, 
#         formatted_data,
#         epochs=1
#     )
    
#     # Quantize for efficiency
#     #quantized_student = quantize_model(trained_student)

#     # load_test_split
#     test_datasets = load_test_split()
#     test_data = format_datasets(test_datasets)

#     # Evaluate
#     evaluation_results = evaluate_model(
#         trained_student,
#         test_data,
#         tokenizer
#     )
    
#     print("Evaluation Results:")
#     print(evaluation_results)
    
#     # Save final model
#     for task, adapter in trained_student.task_adapters.items():
#         adapter.save_pretrained(f"./final_adapter_{task}")

# if __name__ == "__main__":
#     main()

In [18]:
tokenizer = AutoTokenizer.from_pretrained(teacher_path)
tokenizer.pad_token = tokenizer.eos_token

# Initialize models
teacher = TeacherModel()
student = StudentSystem()


*** Initializing teacher model ***


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

*** Initializing student model ***


In [8]:
# Train student system
# Load and format datasets
raw_datasets = load_datasets()
formatted_data = format_datasets(raw_datasets)


train dataset loaded
format the data in desired form


In [None]:
out= {}

In [None]:
dataset = torch.load('./logits.pt',  map_location="cuda")
prompts = list(dataset.keys())


  dataset = torch.load('./logits.pt',  map_location="cuda")


In [15]:
out={}
prompts = set(prompts)
for task,examples in tqdm(formatted_data.items()):
    for example in examples:
        if example['prompt'] in prompts:
            # print(example)
            out[example['prompt']] = {'target':example['target'], 'task':task}

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:00<00:00, 10.35it/s]


In [17]:
torch.save(out, './target.pt')

In [None]:
        self.dataset = torch.load(top_k_path,  map_location="cuda:3")

In [None]:

trained_student = train_student_system(
    teacher, 
    student, 
    tokenizer, 
    formatted_data,
    epochs=1
)

# Quantize for efficiency
#quantized_student = quantize_model(trained_student)


traing loop starts


Epoch 1/1:   1%|          | 1090/138152 [1:09:03<153:36:44,  4.03s/it]

In [None]:

# load_test_split
test_datasets = load_test_split()
test_data = format_datasets(test_datasets)

# Evaluate
evaluation_results = evaluate_model(
    trained_student,
    test_data,
    tokenizer
)

print("Evaluation Results:")
print(evaluation_results)

# Save final model
# for task, adapter in trained_student.task_adapters.items():
#     adapter.save_pretrained(f"./final_adapter_{task}")

test dataset loaded
format the data in desired form
evaluating results


Evaluating summarization: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s]
Evaluating qa: 100%|██████████| 20/20 [00:06<00:00,  2.89it/s]
Evaluating paraphrase: 100%|██████████| 4/4 [00:10<00:00,  2.65s/it]

{'summarization': {'rouge_l': [0.055865921787709494, 0.16279069767441862, 0.07317073170731708, 0.07089552238805971, 0.056644880174291944, 0.08043875685557587, 0.04597701149425287, 0.05078124999999999, 0.04866743916570104, 0.058287795992714025, 0.023897058823529414, 0.04008438818565401, 0.10714285714285714, 0.10135135135135136, 0.04684317718940937, 0.072992700729927, 0.09523809523809523, 0.0737564322469983, 0.08970976253298153, 0.1377672209026128]}, 'qa': {'rouge_l': [0.0, 0.0, 0.0, 0.03508771929824561, 0.0, 0.0, 0.0, 0.0091324200913242, 0.009259259259259259, 0.009216589861751152, 0.0, 0.0, 0.0, 0.019801980198019802, 0.09523809523809523, 0.0196078431372549, 0.02, 0.0, 0.0, 0.0], 'bert_score': [0.8267420530319214, 0.8266420364379883, 0.8258287906646729, 0.7967613339424133, 0.8242578506469727, 0.8242000937461853, 0.825376033782959, 0.7772008776664734, 0.7792404294013977, 0.7784525156021118, 0.8083816170692444, 0.8094208836555481, 0.8086572885513306, 0.8000931143760681, 0.839154064655304, 


