In [32]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoProcessor,
    CLIPVisionModel,
    AdamW, 
    get_linear_schedule_with_warmup,
    AutoModel,
    AutoModelForSeq2SeqLM
)
import json
import logging
import faiss
from PIL import Image
import re
from tqdm import tqdm

In [33]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [34]:
# Set random seeds for reproducibility
def set_seed(seed: int = 11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
set_seed()

In [35]:
os.environ["MISTRAL_API_KEY"]="jJAuJZkjVcy2ynUhan375sHNviHiBeJU"

In [36]:
# Config class to hold hyperparameters
class Config:
    def __init__(self):
        # Base Model
        self.model_name = "bert-base-uncased"  # can be replaced with any suitable LLM
        self.tokenizer_name = "bert-base-uncased"
        
        # Vision Model
        self.vision_model_name = "openai/clip-vit-base-patch32"
        
        # ScienceQA Dataset
        self.file_path = "/Users/Viku/Datasets/ScienceQA"
        self.train_path = "/Users/Viku/Datasets/ScienceQA/train/train.json"
        self.val_path = "/Users/Viku/Datasets/ScienceQA/val/val.json"
        self.max_seq_length = 512
        self.batch_size = 4
        
        # Training
        self.learning_rate = 5e-5
        self.weight_decay = 0.01
        self.epochs = 3
        self.warmup_steps = 100
        self.max_grad_norm = 1.0
        self.gradient_accumulation_steps = 8
        
        # RL Training
        self.ppo_epochs = 4
        self.reward_scale = 0.01
        self.clip_param = 0.2
        self.value_loss_coef = 0.5
        self.entropy_coef = 0.01
        
        # Transformer Refiner
        self.refiner_model_name = "bert-base-uncased"  # Can be smaller than main model
        self.refiner_learning_rate = 2e-5
        self.refiner_weight_decay = 0.01
        self.refiner_batch_size = 16
        self.refiner_epochs = 2
        self.refiner_max_seq_length = 256  # Can be shorter than main model
        
        # Retrieval
        self.retrieval_top_k = 3
        self.embedding_dim = 768
        
        # Reflection
        self.reflection_threshold = 0.7
        
        # Paths
        self.output_dir = "outputs/"
        self.checkpoint_dir = "checkpoints/"
        self.exemplar_path = "data/exemplars.json"
        
        # Device
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        self.device = torch.device("mps" if torch.mps.is_available() else "cpu")

        # Self Training
        self.max_answer_length = 64
        self.rl_updates = 1000
        self.self_training_iterations = 3

config = Config()

In [37]:
from transformers import BertTokenizer, BertLMHeadModel, AutoImageProcessor, BertConfig

# Define the initialize_bert function
def initialize_bert():
    logger.info(f"Initializing BERT model: {config.model_name}")
    
    # First get the original config
    bert_config = BertConfig.from_pretrained(config.model_name)
    
    # Set is_decoder=True
    bert_config.is_decoder = True
    
    # Initialize tokenizer normally
    tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name)
    
    # Initialize model with this modified config
    model = BertLMHeadModel.from_pretrained(
        config.model_name, 
        config=bert_config,
        ignore_mismatched_sizes=True
    )
    
    vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
    return model, tokenizer, vision_processor

# Now call the function to initialize the models
bert_model, bert_tokenizer, vision_processor = initialize_bert()
logger.info(f"BERT model initialized successfully")

# After initializing the model
model_path = os.path.join(config.checkpoint_dir, "bert_decoder")
os.makedirs(model_path, exist_ok=True)
bert_model.save_pretrained(model_path)
bert_tokenizer.save_pretrained(model_path)

# Then reload it
bert_model = BertLMHeadModel.from_pretrained(model_path)
bert_tokenizer = BertTokenizer.from_pretrained(model_path)


2025-03-06 15:47:12,567 - __main__ - INFO - Initializing BERT model: bert-base-uncased
2025-03-06 15:47:15,753 - __main__ - INFO - BERT model initialized successfully


In [38]:
class ScienceQADataset(Dataset):
    def __init__(self, file_path, tokenizer, vision_processor, config, is_train=True):
        self.tokenizer = tokenizer
        self.vision_processor = vision_processor
        self.config = config
        self.is_train = is_train
        self.base_dir = os.path.dirname(file_path)  # Get directory containing the JSON file
        self.data = self.load_and_preprocess_data(file_path)
        
    def load_and_preprocess_data(self, file_path):
        logger.info(f"Loading ScienceQA data from {file_path}")
        with open(file_path, 'r') as f:
            data = json.load(f)
        
        processed_data = []
        for item in data:
            # Extract fields specific to ScienceQA
            question = item.get('question', '')
            context = item.get('context', '')
            choices = item.get('choices', [])
            answer = item.get('answer', '')
            explanation = item.get('explanation', '')
            question_id = item.get('id', '')  # Get question ID for image path
            
            # Format choices as text
            choices_text = ""
            for i, choice in enumerate(choices):
                choices_text += f"({chr(65+i)}) {choice} "
            
            # Combine context and question
            full_question = f"Context: {context}\nQuestion: {question}\nChoices: {choices_text}"
            
            # Split explanation into reasoning steps
            steps = self.extract_reasoning_steps(explanation)
            
            # Process image if available
            visual_features = None
            if 'image' in item and item['image']:
                # Construct image path based on question ID in train/val folder structure
                # Assuming question_id corresponds to the folder name
                image_folder = os.path.join(self.base_dir, str(question_id))
                image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')] if os.path.exists(image_folder) else []
                
                if image_files:
                    image_path = os.path.join(image_folder, image_files[0])
                    try:
                        image = Image.open(image_path).convert('RGB')
                        visual_features = self.process_image(image)
                    except Exception as e:
                        logger.error(f"Error processing image {image_path}: {e}")
            
            # Tokenize question
            question_tokens = self.tokenizer.encode(
                "Let's think step by step! " + full_question, 
                add_special_tokens=True,
                truncation=True,
                max_length=self.config.max_seq_length // 2
            )
            
            # Tokenize each step separately
            steps_tokens = []
            for step in steps:
                step_tokens = self.tokenizer.encode(
                    step,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=self.config.max_seq_length // (2 * max(1, len(steps)))
                )
                steps_tokens.append(step_tokens)
            
            # Tokenize answer
            answer_tokens = self.tokenizer.encode(
                f"Therefore, the answer is {answer}",
                add_special_tokens=False,
                truncation=True,
                max_length=self.config.max_seq_length // 4
            )
            
            processed_data.append({
                'question': full_question,
                'question_tokens': question_tokens,
                'steps': steps,
                'steps_tokens': steps_tokens,
                'answer': answer,
                'answer_tokens': answer_tokens,
                'visual_features': visual_features,
                'has_image': visual_features is not None
            })
        
        logger.info(f"Processed {len(processed_data)} ScienceQA examples")
        return processed_data
    
    def debug_label_issues(self):
        """Print detailed debugging info for label issues"""
        # Get a few samples
        for idx in range(3):
            sample = self.__getitem__(idx)
            
            # Get data
            input_ids = sample['input_ids'].tolist()
            attention_mask = sample['attention_mask'].tolist()
            labels = sample['labels'].tolist()
            
            # Count stats
            total = len(labels)
            non_ignored = sum(1 for l in labels if l != -100)
            padded = attention_mask.count(0)
            non_padded = attention_mask.count(1)
            
            print(f"\n=== SAMPLE {idx} ===")
            print(f"Total length: {total}")
            print(f"Non-padded tokens: {non_padded}")
            print(f"Padded tokens: {padded}")
            print(f"Non-ignored labels: {non_ignored}")
            print(f"Non-ignored percentage: {non_ignored/total*100:.2f}%")
            print(f"Non-ignored / Non-padded ratio: {non_ignored/max(1,non_padded)*100:.2f}%")
            
            # Show a sample of the tokens
            print("\nSample tokens (first 10):")
            for i in range(min(10, len(input_ids))):
                input_token = self.tokenizer.decode([input_ids[i]])
                label_val = labels[i]
                label_token = self.tokenizer.decode([label_val]) if label_val != -100 else "IGNORED"
                mask = attention_mask[i]
                
                print(f"Pos {i}: Input='{input_token}' | Label='{label_token}' | Mask={mask}")
            
            # Print stats on where ignored labels are
            print("\nIgnored label positions:")
            ignored_positions = [i for i, l in enumerate(labels) if l == -100]
            if len(ignored_positions) > 20:
                print(f"{ignored_positions[:10]} ... {ignored_positions[-10:]}")
            else:
                print(ignored_positions)
    
    def process_image(self, image):
        """Process image using CLIP vision encoder"""
        print("Processing Image")
        inputs = self.vision_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            vision_model = CLIPVisionModel.from_pretrained(self.config.vision_model_name)
            vision_model.to(self.config.device)
            outputs = vision_model(**{k: v.to(self.config.device) for k, v in inputs.items()})
            visual_features = outputs.pooler_output.cpu().numpy()
        return visual_features[0]  # Return the feature vector
    
    def extract_reasoning_steps(self, explanation):
        """Extract reasoning steps from explanation"""
        # Method 1: Split by numbered steps if present
        numbered_pattern = re.compile(r'\d+\.\s+')
        if numbered_pattern.search(explanation):
            steps = [step.strip() for step in numbered_pattern.split(explanation) if step.strip()]
            if steps and not steps[0][0].isdigit():  # Remove introduction if it doesn't start with a number
                steps = steps[1:]
            return steps or [explanation]
        
        # Method 2: Split by sentences assuming each sentence is a step
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', explanation)
        if len(sentences) > 1:
            return [s.strip() for s in sentences if s.strip()]
        
        # Default: Treat the whole explanation as one step
        return [explanation]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Combine question, steps, and answer tokens for input
        input_tokens = item['question_tokens'].copy()
        for step_tokens in item['steps_tokens']:
            input_tokens.extend(step_tokens)
        input_tokens.extend(item['answer_tokens'])
        
        # Pad or truncate to max sequence length
        if len(input_tokens) > self.config.max_seq_length:
            input_tokens = input_tokens[:self.config.max_seq_length]
        
        # Create attention mask
        attention_mask = [1] * len(input_tokens)
        padding_length = self.config.max_seq_length - len(input_tokens)
        input_tokens.extend([self.tokenizer.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)
        
        # FIXED: Create shifted labels for causal language modeling
        # Each token should predict the next token in the sequence
        labels = input_tokens.copy()
        labels = [-100] + labels[:-1]  # Shift right by one position
        
        # Make sure padding tokens are ignored in loss
        for i in range(len(attention_mask)):
            if attention_mask[i] == 0:
                labels[i] = -100
        
        # Optional: If you still want to focus more on reasoning steps and answer
        # You can keep some labels as -100 but not all of them
        question_length = len(item['question_tokens'])
        if question_length > 10:  # Only if question is long enough
            # Keep first and last few tokens of question, mask middle ones
            # This maintains signal while focusing on important parts
            middle_start = min(5, question_length // 4)
            middle_end = min(question_length - 5, question_length * 3 // 4)
            for i in range(middle_start, middle_end):
                if i < self.config.max_seq_length:
                    labels[i] = -100
                    
        return {
            'input_ids': torch.tensor(input_tokens, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'question': item['question'],
            'steps': item['steps'],
            'answer': item['answer'],
            'visual_features': torch.tensor(item['visual_features'], dtype=torch.float) if item['visual_features'] is not None else torch.zeros(768),
            'has_image': item['has_image']
        }

In [39]:
# 2. Base LLM & Chain-of-Thought Generator

class ChainOfThoughtGenerator(nn.Module):
    def __init__(self, config, model=None):
        super().__init__()
        self.config = config
        
        # First, initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        
        # Then check for special tokens
        special_tokens = {"pad_token": "[PAD]"} if self.tokenizer.pad_token is None else {}
        
        # Now initialize the model
        if model is not None:
            self.model = model
        else:
            # Initialize model from config
            self.model = AutoModelForCausalLM.from_pretrained(config.model_name)
        
        # Add special tokens if needed
        if special_tokens:
            self.tokenizer.add_special_tokens(special_tokens)
            self.model.resize_token_embeddings(len(self.tokenizer))
        
        # The rest of the code remains the same
        # Vision encoder for multimodal inputs
        self.vision_processor = AutoProcessor.from_pretrained(config.vision_model_name)
        self.vision_model = CLIPVisionModel.from_pretrained(config.vision_model_name)
        
        # Vision-language integration layer
        self.vision_projection = nn.Linear(
            self.vision_model.config.hidden_size,
            self.model.config.hidden_size
        )
        
        # Move models to device
        self.model.to(config.device)
        self.vision_model.to(config.device)
        self.vision_projection.to(config.device)

        print("ChainOfThoughtGenerator initialized with:")
        print(f"  Tokenizer: {config.tokenizer_name}")
        print(f"  Model: {config.model_name}")
        print(f"  Vision Model: {config.vision_model_name}")
        print(f"  Device: {config.device}")

    def encode_image(self, image):
        """Encode image using vision model"""
        print("Encoding image...")
        vision_inputs = self.vision_processor(images=image, return_tensors="pt")
        vision_inputs = {k: v.to(self.config.device) for k, v in vision_inputs.items()}
        
        with torch.no_grad():
            vision_outputs = self.vision_model(**vision_inputs)
            image_features = vision_outputs.pooler_output
            projected_features = self.vision_projection(image_features)
        
        print("Image encoding completed.")
        return projected_features
    
    def forward(self, input_ids, attention_mask, labels=None, visual_features=None):
        # """Forward pass with optional visual features"""
        # print("Forward pass started.")
        # print(f"  Input IDs: {input_ids.shape}")
        # print(f"  Attention Mask: {attention_mask.shape}")
        # if labels is not None:
        #     print(f"  Labels: {labels.shape}")

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # print("Model forward pass completed.")

        # If visual features are available, enhance the hidden states
        if visual_features is not None:
            # print("Processing visual features...")
            if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                projected_visual = self.vision_projection(visual_features)
                last_hidden = outputs.hidden_states[-1]
                enhanced_hidden = last_hidden + projected_visual.unsqueeze(1)
                # print("Visual features integrated into hidden states.")

        return outputs
    
    def generate_step_by_step(self, question, image=None, num_steps=5, max_length=512):
        """Generate a chain-of-thought reasoning process for a given question"""
        print(f"Generating reasoning for question: {question}")

        # Prepare input
        prompt = f"Let's think step by step! {question}"
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        print(f"Tokenized input: {inputs}")

        # Process image if provided
        visual_embedding = None
        if image is not None:
            print("Processing image for reasoning...")
            visual_embedding = self.encode_image(image)

        # Generate reasoning steps and answer
        with torch.no_grad():
            print("Generating response from model...")
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True
            )
        
        # Decode the generated text
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text: {generated_text}")

        # Extract reasoning steps and answer
        generated_text = generated_text[len(prompt):]  # Remove the prompt
        parts = generated_text.split("\nTherefore, the answer is")
        
        reasoning = parts[0]
        answer = parts[1] if len(parts) > 1 else "No clear answer provided."

        # Split reasoning into steps
        steps = []
        for step in reasoning.split("\n"):
            if step.strip():
                steps.append(step.strip())

        print("Generated reasoning steps:")
        for i, step in enumerate(steps):
            print(f"  Step {i+1}: {step}")

        print(f"Final answer: {answer.strip()}")

        return {
            "question": question,
            "steps": steps,
            "answer": answer.strip(),
            "full_text": generated_text
        }

In [40]:
# 3. Reflection Module

class ReflectionModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Use a smaller model for efficiency
        self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
        
        # Scoring layers
        self.coherence_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        self.language_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        self.progress_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        
        # Move to device
        self.encoder.to(config.device)
        self.coherence_scorer.to(config.device)
        self.language_scorer.to(config.device)
        self.progress_scorer.to(config.device)
    
    def forward(self, question, steps, previous_steps=None):
        """
        Evaluate the quality of reasoning steps
        
        Args:
            question: The original question
            steps: List of reasoning steps to evaluate
            previous_steps: Optional previous steps for context
        
        Returns:
            Dictionary of scores for each step and overall
        """
        all_scores = []
        
        # Process each step
        for i, step in enumerate(steps):
            # Create context from question and previous steps
            context = question
            if previous_steps:
                context += " " + " ".join(previous_steps)
            
            # Tokenize
            inputs = self.tokenizer(
                context, 
                step, 
                truncation=True, 
                padding=True, 
                return_tensors="pt"
            )
            inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
            
            # Get embeddings
            with torch.no_grad():
                outputs = self.encoder(**inputs)
                pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token
            
            # Calculate scores
            coherence_score = torch.sigmoid(self.coherence_scorer(pooled_output)).item()
            language_score = torch.sigmoid(self.language_scorer(pooled_output)).item()
            progress_score = torch.sigmoid(self.progress_scorer(pooled_output)).item()
            
            # Calculate a combined score
            combined_score = (coherence_score + language_score + progress_score) / 3
            
            all_scores.append({
                'step': i+1,
                'coherence': coherence_score,
                'language': language_score,
                'progress': progress_score,
                'combined': combined_score
            })
            
            # Update previous steps for next iteration
            previous_steps = (previous_steps or []) + [step]
        
        # Calculate overall score
        overall_score = sum(s['combined'] for s in all_scores) / len(all_scores) if all_scores else 0
        
        return {
            'step_scores': all_scores,
            'overall_score': overall_score
        }
    
    def evaluate_reasoning(self, question, steps, answer=None):
        """Evaluate the overall reasoning process"""
        step_scores = self.forward(question, steps)
        
        # Check if reasoning meets the threshold
        meets_threshold = step_scores['overall_score'] >= self.config.reflection_threshold
        
        return {
            'scores': step_scores,
            'meets_threshold': meets_threshold,
            'feedback': self.generate_feedback(step_scores) if not meets_threshold else None
        }
    
    def generate_feedback(self, scores):
        """Generate feedback based on scores"""
        feedback = []
        
        for step_score in scores['step_scores']:
            step_num = step_score['step']
            if step_score['coherence'] < 0.6:
                feedback.append(f"Step {step_num} lacks coherence with the context.")
            if step_score['language'] < 0.6:
                feedback.append(f"Step {step_num} has language issues.")
            if step_score['progress'] < 0.6:
                feedback.append(f"Step {step_num} doesn't make sufficient progress toward the answer.")
        
        if not feedback:
            feedback = ["The reasoning needs improvement, but specific issues weren't identified."]
        
        return feedback

In [41]:
# 4. Retrieval Module

class RetrievalModule:
    def __init__(self, config):
        self.config = config
        self.encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        
        # Move to device
        self.encoder.to(config.device)
        
        # Load exemplars
        self.exemplars = self.load_exemplars()
        
        # Build index for fast retrieval
        self.index = self.build_index()
    
    def load_exemplars(self):
        """Load exemplar reasoning sequences"""
        if not os.path.exists(self.config.exemplar_path):
            logger.warning(f"Exemplar file {self.config.exemplar_path} not found, using empty exemplars")
            return []
        
        with open(self.config.exemplar_path, 'r') as f:
            exemplars = json.load(f)
        
        # Pre-compute embeddings for each exemplar
        for exemplar in exemplars:
            exemplar['embedding'] = self.encode_text(exemplar['question']).cpu().numpy()
        
        logger.info(f"Loaded {len(exemplars)} exemplars")
        return exemplars
    
    def build_index(self):
        """Build FAISS index for fast retrieval"""
        if not self.exemplars:
            return None
        
        # Extract embeddings
        embeddings = np.array([ex['embedding'] for ex in self.exemplars]).astype('float32')
        
        # Build index
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        
        return index
    
    def encode_text(self, text):
        """Encode text using the sentence transformer"""
        inputs = self.tokenizer(
            text, 
            truncation=True, 
            padding=True, 
            return_tensors="pt"
        )
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Use mean pooling
            attention_mask = inputs['attention_mask']
            mask = attention_mask.unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
            masked_embeddings = outputs.last_hidden_state * mask
            summed = torch.sum(masked_embeddings, 1)
            counts = torch.clamp(torch.sum(mask, 1), min=1e-9)
            mean_pooled = summed / counts
        
        return mean_pooled[0]  # Return the embedding vector
    
    def retrieve_similar_examples(self, question, k=None):
        """Retrieve similar exemplars for a given question"""
        if k is None:
            k = self.config.retrieval_top_k
        
        if not self.exemplars or self.index is None:
            return []
        
        # Encode the query
        query_embedding = self.encode_text(question).cpu().numpy().reshape(1, -1).astype('float32')
        
        # Search for similar examples
        distances, indices = self.index.search(query_embedding, k)
        
        # Get the exemplars
        results = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.exemplars):
                exemplar = self.exemplars[idx].copy()
                exemplar['similarity'] = float(1.0 / (1.0 + distances[0][i]))  # Convert distance to similarity
                exemplar.pop('embedding', None)  # Remove embedding from result
                results.append(exemplar)
        
        return results
    
    def get_demonstration_prompt(self, question, k=None):
        """Get a few-shot demonstration prompt based on retrieved examples"""
        examples = self.retrieve_similar_examples(question, k)
        
        if not examples:
            return f"Let's think step by step! {question}"
        
        prompt = "I'll solve some similar problems step by step, then answer your question.\n\n"
        
        # Add examples
        for i, example in enumerate(examples):
            prompt += f"Example {i+1}:\n"
            prompt += f"Question: {example['question']}\n"
            prompt += "Reasoning:\n"
            for j, step in enumerate(example['steps']):
                prompt += f"{j+1}. {step}\n"
            prompt += f"Answer: {example['answer']}\n\n"
        
        # Add the current question
        prompt += f"Now, let's solve your question step by step!\n{question}\n"
        
        return prompt

In [42]:
# 5. Transformer Module

class SimpleTransformerRefiner:
    """A lightweight transformer-based model for refining reasoning steps"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Use a small, efficient transformer model
        self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
        self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
        
        # Move model to device
        self.model.to(config.device)
        
    def train(self, train_examples, epochs=3, batch_size=8):
        """Train the refiner model on examples"""
        # Prepare optimizer
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.refiner_learning_rate or 5e-5,
            weight_decay=self.config.refiner_weight_decay or 0.01
        )
        
        # Prepare data
        train_inputs = []
        train_targets = []
        
        for example in train_examples:
            # Format: "Question: {question} Previous steps: {prev_steps} Original step: {orig_step}"
            context = f"Question: {example['context']} Original step: {example['step']}"
            train_inputs.append(f"refine: {context}")
            train_targets.append(example['step'])  # Initially target is same as input for good examples
        
        # Convert to dataset
        dataset = self._prepare_dataset(train_inputs, train_targets)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        # Training loop
        self.model.train()
        total_steps = len(dataloader) * epochs
        
        # Create scheduler
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )
        
        # Main training loop
        global_step = 0
        for epoch in range(epochs):
            epoch_loss = 0
            
            # Training with progress bar
            pbar = tqdm(dataloader, desc=f"Training refiner (Epoch {epoch+1}/{epochs})")
            for batch in pbar:
                # Get batch
                input_ids = batch["input_ids"].to(self.config.device)
                attention_mask = batch["attention_mask"].to(self.config.device)
                labels = batch["labels"].to(self.config.device)
                
                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                
                loss = outputs.loss
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                
                # Update metrics
                epoch_loss += loss.item()
                global_step += 1
                
                # Update progress bar
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
                
            # Log epoch stats
            avg_loss = epoch_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{epochs}: Average Loss = {avg_loss:.4f}")
    
    def _prepare_dataset(self, inputs, targets):
        """Prepare a dataset from inputs and targets"""
        input_encodings = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        target_encodings = self.tokenizer(
            targets,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        
        # Create dataset
        class RefinementDataset(torch.utils.data.Dataset):
            def __init__(self, input_encodings, target_encodings):
                self.input_encodings = input_encodings
                self.target_encodings = target_encodings
                
            def __len__(self):
                return len(self.input_encodings["input_ids"])
                
            def __getitem__(self, idx):
                return {
                    "input_ids": self.input_encodings["input_ids"][idx],
                    "attention_mask": self.input_encodings["attention_mask"][idx],
                    "labels": self.target_encodings["input_ids"][idx]
                }
                
        dataset = RefinementDataset(input_encodings, target_encodings)
        return dataset
    
    def refine_step(self, context, original_step, max_length=100):
        """Refine a reasoning step"""
        # Prepare input
        input_text = f"refine: Question: {context} Original step: {original_step}"
        
        # Tokenize
        inputs = self.tokenizer(input_text, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        # Generate refined step
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        
        # Decode output
        refined_step = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        return refined_step
    
    def refine_reasoning_steps(self, question, original_steps):
        """Refine a sequence of reasoning steps"""
        refined_steps = []
        context = question
        
        for i, step in enumerate(original_steps):
            # Generate refined step
            refined_step = self.refine_step(context, step)
            refined_steps.append(refined_step)
            
            # Update context for next step
            context += f"\nStep {i+1}: {refined_step}"
        
        return refined_steps
    
    def evaluate_refinement(self, examples, reflection_module=None):
        """Evaluate refinement quality"""
        if not examples:
            return {"success_rate": 0, "average_improvement": 0}
        
        success_count = 0
        total_improvement = 0
        
        for example in examples:
            # Get original steps
            question = example['question']
            original_steps = example['steps']
            
            # Generate refined steps
            refined_steps = self.refine_reasoning_steps(question, original_steps)
            
            # Evaluate with reflection module if available
            if reflection_module:
                original_score = reflection_module.evaluate_reasoning(
                    question, original_steps
                )['scores']['overall_score']
                
                refined_score = reflection_module.evaluate_reasoning(
                    question, refined_steps
                )['scores']['overall_score']
                
                improvement = refined_score - original_score
                total_improvement += improvement
                
                if improvement > 0:
                    success_count += 1
            else:
                # Simple evaluation (just count non-identical refinements as success)
                changes = sum(1 for orig, ref in zip(original_steps, refined_steps) if orig != ref)
                if changes > 0:
                    success_count += 1
        
        success_rate = success_count / len(examples) * 100
        avg_improvement = total_improvement / len(examples)
        
        return {
            "success_rate": success_rate,
            "average_improvement": avg_improvement
        }
    
    def save_model(self, path):
        """Save the model"""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
    
    def load_model(self, path):
        """Load the model"""
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model.to(self.config.device)

In [43]:
# 6. Dual Reward Function and RL Training

class RewardFunction:
    """Combines outcome and process rewards for RL training"""
    def __init__(self, reflection_module, config):
        self.reflection_module = reflection_module
        self.config = config
        
        # For outcome verification
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        
    def calculate_outcome_reward(self, predicted_answer, ground_truth):
        """Calculate reward based on correctness of final answer"""
        # For ScienceQA, we can do a simple match check
        # In a real system, you might use more sophisticated answer verification
        predicted_normalized = predicted_answer.strip().lower()
        ground_truth_normalized = ground_truth.strip().lower()
        
        # Check for exact match
        if predicted_normalized == ground_truth_normalized:
            return 1.0
        
        # Check for partial match (for longer answers)
        if len(ground_truth_normalized) > 10:
            # Using simple token overlap as a metric
            pred_tokens = set(self.tokenizer.tokenize(predicted_normalized))
            truth_tokens = set(self.tokenizer.tokenize(ground_truth_normalized))
            
            if not truth_tokens:
                return 0.0
                
            overlap = len(pred_tokens.intersection(truth_tokens)) / len(truth_tokens)
            return max(0.0, overlap - 0.3)  # Only reward significant overlap
        
        return 0.0
    
    def calculate_process_reward(self, question, steps, original_steps=None):
        """Calculate reward based on quality of reasoning process"""
        # Get reflection scores
        reflection_result = self.reflection_module.evaluate_reasoning(question, steps)
        process_score = reflection_result['scores']['overall_score']
        
        # If we have original steps, reward improvement
        if original_steps:
            original_result = self.reflection_module.evaluate_reasoning(question, original_steps)
            original_score = original_result['scores']['overall_score']
            
            # Reward improvement, penalize degradation
            improvement = process_score - original_score
            if improvement > 0:
                process_score += 0.2 * improvement  # Bonus for improvement
            else:
                process_score += 0.1 * improvement  # Smaller penalty for degradation
        
        return process_score
    
    def calculate_combined_reward(self, sample, ground_truth, original_steps=None):
        """Calculate combined reward from outcome and process"""
        question = sample['question']
        steps = sample['steps']
        predicted_answer = sample['answer']
        
        # Calculate component rewards
        outcome_reward = self.calculate_outcome_reward(predicted_answer, ground_truth)
        process_reward = self.calculate_process_reward(question, steps, original_steps)
        
        # Combine rewards
        # The balance between outcome and process rewards is important
        # In this implementation, we favor process for ScienceQA
        combined_reward = 0.4 * outcome_reward + 0.6 * process_reward
        
        return {
            'combined': combined_reward,
            'outcome': outcome_reward,
            'process': process_reward,
            'details': {
                'answer_correct': outcome_reward > 0.9,
                'reasoning_quality': process_reward
            }
        }

class PPOTrainer:
    """PPO-based RL trainer for the reasoning model"""
    def __init__(self, cot_generator, reward_function, transformer_refiner, config):
        self.cot_generator = cot_generator
        self.reward_function = reward_function
        self.transformer_refiner = transformer_refiner  # Changed from GAN to transformer
        self.config = config
        
        # Create a reference model for KL penalty
        self.ref_model = AutoModelForCausalLM.from_pretrained(config.model_name)
        self.ref_model.to(config.device)
        self.ref_model.eval()
        
        # Optimizer
        self.optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # Initialize policy entropy and value losses
        self.policy_loss = 0
        self.value_losses = []
        self.entropy_losses = []
    
    def train_step(self, batch, transformer_refiner=None):
        """Perform a single PPO training step"""
        # Extract data
        input_ids = batch['input_ids'].to(self.config.device)
        attention_mask = batch['attention_mask'].to(self.config.device)
        questions = batch['question']
        ground_truth_answers = batch['answer']
        
        # Forward pass with current policy to get initial log probs and values
        with torch.no_grad():
            outputs = self.cot_generator.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            old_logits = outputs.logits
            
            # Extract values (implicitly learned through the LM head)
            # In a full implementation, you would have a separate value head
            values = torch.mean(old_logits, dim=-1)  # Simplistic value estimation
        
        # Generate samples from current policy
        generated_samples = []
        for i in range(len(questions)):
            sample = self.cot_generator.generate_step_by_step(questions[i], image=None)
            generated_samples.append(sample)
        
        # Optionally refine with transformer
        if transformer_refiner or self.transformer_refiner:
            refiner = transformer_refiner if transformer_refiner else self.transformer_refiner
            for i, sample in enumerate(generated_samples):
                refined_steps = refiner.refine_reasoning_steps(
                    sample['question'], 
                    sample['steps']
                )
                generated_samples[i]['refined_steps'] = refined_steps
        
        # Calculate rewards
        rewards = []
        for i, sample in enumerate(generated_samples):
            steps_to_evaluate = sample.get('refined_steps', sample['steps'])
            reward = self.reward_function.calculate_combined_reward(
                {
                    'question': sample['question'],
                    'steps': steps_to_evaluate,
                    'answer': sample['answer']
                },
                ground_truth_answers[i]
            )
            rewards.append(reward['combined'])
        
        rewards_tensor = torch.tensor(rewards, device=self.config.device)
        
        # PPO optimization loop
        for _ in range(self.config.ppo_epochs):
            # Forward pass with current policy
            outputs = self.cot_generator.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            logits = outputs.logits
            
            # Calculate new log probabilities and values
            new_values = torch.mean(logits, dim=-1)  # Simplistic value estimation
            
            # Compute KL divergence penalty
            kl_div = self._compute_kl_divergence(old_logits, logits, attention_mask)
            
            # Compute policy loss (PPO clipped objective)
            # In a full implementation, you would compute proper action probabilities
            # For simplicity, we're using a proxy based on logits difference
            logit_diff = torch.sum(torch.abs(logits - old_logits), dim=-1)
            policy_ratio = torch.exp(-logit_diff * 0.01)  # Proxy for probability ratio
            
            clipped_ratio = torch.clamp(
                policy_ratio, 
                1.0 - self.config.clip_param, 
                1.0 + self.config.clip_param
            )
            
            policy_reward = rewards_tensor.unsqueeze(-1).expand_as(policy_ratio)
            policy_loss = -torch.min(
                policy_ratio * policy_reward,
                clipped_ratio * policy_reward
            ).mean()
            
            # Value loss
            value_loss = F.mse_loss(new_values, rewards_tensor.unsqueeze(-1).expand_as(new_values))
            
            # Entropy for exploration
            # Simplified entropy calculation
            entropy = torch.mean(torch.std(logits, dim=-1))
            entropy_loss = -self.config.entropy_coef * entropy
            
            # Total loss
            loss = policy_loss + self.config.value_loss_coef * value_loss + entropy_loss + 0.01 * kl_div
            
            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cot_generator.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            
            # Store metrics
            self.policy_loss = policy_loss.item()
            self.value_losses.append(value_loss.item())
            self.entropy_losses.append(entropy_loss.item())
        
        return {
            'policy_loss': self.policy_loss,
            'value_loss': sum(self.value_losses) / len(self.value_losses),
            'entropy_loss': sum(self.entropy_losses) / len(self.entropy_losses),
            'mean_reward': rewards_tensor.mean().item(),
            'transformer_refinement': transformer_refiner is not None or self.transformer_refiner is not None
        }
    
    def _compute_kl_divergence(self, old_logits, new_logits, attention_mask):
        """Compute KL divergence between old and new policies"""
        old_probs = F.softmax(old_logits, dim=-1)
        new_probs = F.softmax(new_logits, dim=-1)
        
        # KL divergence
        kl = old_probs * (torch.log(old_probs) - torch.log(new_probs))
        kl = kl.sum(-1)
        
        # Apply attention mask
        kl = kl * attention_mask.float()
        
        # Average over non-masked tokens
        kl = kl.sum() / attention_mask.float().sum()
        
        return kl

In [44]:
# 7. Self-Training and Distillation

class SelfTrainer:
    """Self-training through iterative pseudo-labeling"""
    def __init__(self, cot_generator, config):
        self.cot_generator = cot_generator
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        
        # Optimizer for fine-tuning
        self.optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # LR scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=1000  # Will be updated when dataset size is known
        )
    
    def generate_pseudo_labels(self, unlabeled_data, ground_truth_answers=None):
        """Generate pseudo-labels for unlabeled data"""
        pseudo_labeled_data = []
        
        for i, sample in enumerate(unlabeled_data):
            question = sample['question']
            
            # Generate reasoning steps and answer
            generated = self.cot_generator.generate_step_by_step(question, image=None)
            
            # Check if the answer is correct (if ground truth is available)
            is_correct = False
            if ground_truth_answers is not None:
                ground_truth = ground_truth_answers[i]
                predicted = generated['answer'].strip().lower()
                ground_truth = ground_truth.strip().lower()
                is_correct = predicted == ground_truth
            
            # Only include correct answers or all if no ground truth
            if is_correct or ground_truth_answers is None:
                pseudo_labeled_data.append({
                    'question': question,
                    'steps': generated['steps'],
                    'answer': generated['answer'],
                    'confidence': 1.0  # In a real implementation, you'd use model confidence
                })
        
        return pseudo_labeled_data
    
    def finetune_on_pseudo_labels(self, pseudo_labeled_data, epochs=None):
        """Fine-tune model on pseudo-labeled data"""
        if epochs is None:
            epochs = self.config.epochs
        
        # Create dataset
        dataset = self._create_dataset_from_samples(pseudo_labeled_data)
        
        # Adjust scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=epochs * len(dataset)
        )
        
        # Fine-tuning loop
        self.cot_generator.model.train()
        total_loss = 0
        global_step = 0
        
        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch+1}/{epochs}")
            epoch_loss = 0
            
            for batch in dataset:
                # Move batch to device
                batch = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v 
                         for k, v in batch.items()}
                
                # Forward and backward
                outputs = self.cot_generator.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels']
                )
                
                loss = outputs.loss
                epoch_loss += loss.item()
                
                # Backward pass
                loss.backward()
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    self.cot_generator.model.parameters(), 
                    self.config.max_grad_norm
                )
                
                # Update weights
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                
                global_step += 1
                
                # Log progress
                if global_step % 50 == 0:
                    logger.info(f"Step {global_step}: loss = {loss.item()}")
            
            avg_epoch_loss = epoch_loss / len(dataset)
            total_loss += avg_epoch_loss
            logger.info(f"Epoch {epoch+1} completed: Average loss = {avg_epoch_loss}")
        
        avg_loss = total_loss / epochs
        logger.info(f"Fine-tuning completed: Average loss = {avg_loss}")
        
        return avg_loss
    
    def _create_dataset_from_samples(self, samples):
        """Create a dataset from generated samples"""
        dataset = []
        
        for sample in samples:
            question = sample['question']
            steps = sample['steps']
            answer = sample['answer']
            
            # Prepare input text
            input_text = f"Let's think step by step! {question}\n"
            for i, step in enumerate(steps):
                input_text += f"{i+1}. {step}\n"
            input_text += f"Therefore, the answer is {answer}"
            
            # Tokenize
            encodings = self.tokenizer(
                input_text,
                truncation=True,
                max_length=self.config.max_seq_length,
                padding="max_length",
                return_tensors="pt"
            )
            
            # Create labels for causal LM training
            input_ids = encodings['input_ids'][0]
            attention_mask = encodings['attention_mask'][0]
            labels = input_ids.clone()
            
            # Mask question part in labels
            question_tokens = self.tokenizer.encode(
                f"Let's think step by step! {question}",
                add_special_tokens=True
            )
            labels[:len(question_tokens)] = -100
            
            dataset.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            })
        
        return dataset
    
    def self_training_loop(self, labeled_data, unlabeled_data, num_iterations=3):
        """Run the complete self-training loop"""
        for iteration in range(num_iterations):
            logger.info(f"Starting self-training iteration {iteration+1}/{num_iterations}")
            
            # Generate pseudo-labels
            pseudo_labels = self.generate_pseudo_labels(
                unlabeled_data,
                ground_truth_answers=[item['answer'] for item in unlabeled_data]
            )
            
            if not pseudo_labels:
                logger.warning("No pseudo-labels generated. Stopping self-training.")
                break
            
            logger.info(f"Generated {len(pseudo_labels)} pseudo-labels")
            
            # Combine with labeled data
            combined_data = labeled_data + pseudo_labels
            
            # Fine-tune on combined data
            loss = self.finetune_on_pseudo_labels(combined_data)
            
            logger.info(f"Iteration {iteration+1} completed: loss = {loss}")
            
            # Update labeled data for next iteration
            labeled_data = combined_data
        
        logger.info("Self-training completed")
        
        # Save the final model
        self.cot_generator.model.save_pretrained(os.path.join(self.config.output_dir, "self_trained_model"))
        self.tokenizer.save_pretrained(os.path.join(self.config.output_dir, "self_trained_tokenizer"))
        
        return labeled_data

class ModelDistiller:
    """Knowledge distillation for creating smaller, efficient models"""
    def __init__(self, teacher_model, config, student_model_name="distilbert-base-uncased"):
        self.teacher_model = teacher_model
        self.config = config
        
        # Load smaller student model
        self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
        self.student_model = AutoModelForCausalLM.from_pretrained(student_model_name)
        
        # Add special tokens if needed
        special_tokens = {"pad_token": "[PAD]"} if self.student_tokenizer.pad_token is None else {}
        if special_tokens:
            self.student_tokenizer.add_special_tokens(special_tokens)
            self.student_model.resize_token_embeddings(len(self.student_tokenizer))
        
        # Move to device
        self.student_model.to(config.device)
        
        # Optimizer
        self.optimizer = AdamW(
            self.student_model.parameters(),
            lr=2e-5,  # Usually higher for distillation
            weight_decay=0.01
        )
        
        # Temperature for softening distributions
        self.temperature = 2.0
    
    def distill(self, dataset, epochs=3):
        """Distill knowledge from teacher to student"""
        self.student_model.train()
        self.teacher_model.model.eval()
        
        total_loss = 0
        step_count = 0
        
        for epoch in range(epochs):
            logger.info(f"Starting distillation epoch {epoch+1}/{epochs}")
            epoch_loss = 0
            
            for batch in dataset:
                # Convert to student tokenization
                student_inputs = self._convert_teacher_to_student_inputs(batch)
                
                # Move to device
                student_inputs = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v
                                 for k, v in student_inputs.items()}
                
                # Get teacher predictions
                with torch.no_grad():
                    teacher_outputs = self.teacher_model.model(
                        input_ids=batch['input_ids'].to(self.config.device),
                        attention_mask=batch['attention_mask'].to(self.config.device)
                    )
                    
                    # Apply temperature scaling to logits
                    teacher_logits = teacher_outputs.logits / self.temperature
                
                # Student forward pass
                student_outputs = self.student_model(**student_inputs)
                student_logits = student_outputs.logits / self.temperature
                
                # Compute distillation loss
                # Standard cross-entropy loss for task performance
                task_loss = F.cross_entropy(
                    student_logits.view(-1, student_logits.size(-1)),
                    student_inputs['labels'].view(-1),
                    ignore_index=-100
                )
                
                # Distillation loss (KL divergence)
                # We need to align teacher and student token representations
                aligned_teacher_logits = self._align_teacher_student_representations(
                    teacher_logits, 
                    student_logits,
                    batch['attention_mask'].to(self.config.device),
                    student_inputs['attention_mask']
                )
                
                distillation_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(aligned_teacher_logits, dim=-1),
                    reduction='batchmean'
                )
                
                # Combine losses
                loss = 0.5 * task_loss + 0.5 * distillation_loss
                
                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), 1.0)
                self.optimizer.step()
                
                # Track losses
                epoch_loss += loss.item()
                step_count += 1
                
                # Log progress
                if step_count % 100 == 0:
                    logger.info(f"Distillation step {step_count}: loss = {loss.item()}")
            
            avg_epoch_loss = epoch_loss / len(dataset)
            total_loss += avg_epoch_loss
            logger.info(f"Distillation epoch {epoch+1} completed: Average loss = {avg_epoch_loss}")
        
        avg_loss = total_loss / epochs
        logger.info(f"Distillation completed: Average loss = {avg_loss}")
        
        # Save distilled model
        self.student_model.save_pretrained(os.path.join(self.config.output_dir, "distilled_model"))
        self.student_tokenizer.save_pretrained(os.path.join(self.config.output_dir, "distilled_tokenizer"))
        
        return self.student_model
    
    def _convert_teacher_to_student_inputs(self, teacher_batch):
        """Convert teacher batch to student tokenization"""
        # This is a placeholder - in practice, you would need to implement 
        # conversion between different tokenizers
        return teacher_batch
    
    def _align_teacher_student_representations(self, teacher_logits, student_logits, 
                                              teacher_mask, student_mask):
        """Align teacher and student token representations"""
        # This is a placeholder for token alignment
        # In practice, you'd need to implement vocabulary mapping
        return teacher_logits

In [45]:
# 8. Integration - Full Reasoning Pipeline
class ReasoningPipeline:
    """Full reasoning pipeline integrating all components"""
    def __init__(self, config, bert_model=None, bert_tokenizer=None, vision_processor=None):
        self.config = config
        
        # Use provided models/tokenizers if available, otherwise initialize from config
        if bert_tokenizer is not None:
            self.tokenizer = bert_tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
            
        if vision_processor is not None:
            self.vision_processor = vision_processor
        else:
            self.vision_processor = AutoProcessor.from_pretrained(config.vision_model_name)
        
        # Add special tokens if needed - FIXED: Make sure all necessary special tokens are present
        special_tokens = {}
        if self.tokenizer.pad_token is None:
            special_tokens["pad_token"] = "[PAD]"
        if self.tokenizer.eos_token is None:
            special_tokens["eos_token"] = "[EOS]"
        if self.tokenizer.bos_token is None:
            special_tokens["bos_token"] = "[BOS]"
        
        if special_tokens:
            self.tokenizer.add_special_tokens(special_tokens)
        
        # FIXED: Resize embeddings if special tokens were added
        if bert_model is not None and special_tokens:
            bert_model.resize_token_embeddings(len(self.tokenizer))
        
        # Initialize components, passing the BERT model to CoT generator if provided
        self.cot_generator = ChainOfThoughtGenerator(config, model=bert_model)
        
        # FIXED: Ensure the CoT model's embedding size matches the tokenizer
        if hasattr(self.cot_generator, 'model'):
            self.cot_generator.model.resize_token_embeddings(len(self.tokenizer))
            
        self.reflection_module = ReflectionModule(config)
        self.retrieval_module = RetrievalModule(config)
        
        # Replace GAN with simple transformer refiner
        self.transformer_refiner = SimpleTransformerRefiner(config)
        
        # Create reward function
        self.reward_function = RewardFunction(self.reflection_module, config)
        
        # Create RL trainer
        self.ppo_trainer = PPOTrainer(self.cot_generator, self.reward_function, self.transformer_refiner, config)
        
        # Create self-trainer
        self.self_trainer = SelfTrainer(self.cot_generator, config)
        
        # Initialize model distiller
        self.distiller = None
        
        # FIXED: Set the loss_type explicitly in the config
        if not hasattr(self.config, 'loss_type') or self.config.loss_type is None:
            self.config.loss_type = 'custom'
    
    def train(self, train_file, val_file, num_rl_updates=1000, num_self_training_iterations=3):
        """Train the full reasoning pipeline"""
        # Load datasets
        train_dataset = ScienceQADataset(
            train_file, 
            self.tokenizer, 
            self.vision_processor, 
            self.config, 
            is_train=True
        )
        
        val_dataset = ScienceQADataset(
            val_file, 
            self.tokenizer, 
            self.vision_processor, 
            self.config, 
            is_train=False
        )

        # DEBUG: Run label debugging
        print("Debugging labels in the dataset:")
        train_dataset.debug_label_issues()
    
        # Create data loaders
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True
        )
        
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False
        )

        # Verify model initialization
        loss_verification = self.verify_loss_calculation()
        if not loss_verification:
            print("The model is not producing non-zero loss on test inputs!")
            print("Check your model configuration and loss calculation.")

        # Analyze dataset
        label_stats = self.analyze_labels(train_dataloader)
        
        # FIXED: Check and fix label distribution before training
        if label_stats["non_ignored"] / label_stats["total_labels"] < 0.05:
            print("⚠️ WARNING: Less than 5% of labels are non-ignored. Checking dataset preparation...")
            self._check_and_fix_dataset_preparation(train_dataset)
            
            # Re-analyze after fixes
            print("Re-analyzing labels after dataset fixes...")
            label_stats = self.analyze_labels(train_dataloader)
        
        # 1. Initial supervised fine-tuning
        logger.info("Starting supervised fine-tuning")
        self._supervised_finetuning(train_dataloader, val_dataloader)
        
        # 2. Train refiner instead of GAN
        logger.info("Training transformer refiner")
        self._train_refiner(train_dataloader)
        
        # 3. RL fine-tuning
        logger.info("Starting RL fine-tuning")
        self._rl_finetuning(train_dataloader, num_rl_updates)
        
        # 4. Self-training loop
        logger.info("Starting self-training")
        labeled_data = train_dataset.data[:100]  # Start with a small labeled subset
        unlabeled_data = train_dataset.data[100:]  # The rest is unlabeled
        final_labeled_data = self.self_trainer.self_training_loop(
            labeled_data, 
            unlabeled_data, 
            num_iterations=num_self_training_iterations
        )
        
        # 5. Distillation to smaller model
        logger.info("Starting model distillation")
        self.distiller = ModelDistiller(self.cot_generator, self.config)
        distilled_model = self.distiller.distill(train_dataloader, epochs=3)
        
        logger.info("Training complete")
        
        return {
            'cot_generator': self.cot_generator,
            'reflection_module': self.reflection_module,
            'retrieval_module': self.retrieval_module,
            'refiner': self.refiner,
            'distilled_model': distilled_model
        }
    
    # FIXED: Add method to check and fix dataset preparation issues
    def _check_and_fix_dataset_preparation(self, dataset):
        """Check and fix common dataset preparation issues"""
        print("Performing dataset preparation checks and fixes...")
        
        # Check if the dataset has a prepare_inputs method we can modify
        if hasattr(dataset, 'prepare_inputs'):
            original_prepare = dataset.prepare_inputs
            
            # Define fixed prepare_inputs method
            def fixed_prepare_inputs(item):
                # Call the original preparation
                result = original_prepare(item)
                
                # FIXED: Ensure labels are properly set for causal language modeling
                # For causal LM, typically labels are the same as input_ids but shifted
                if 'input_ids' in result and 'labels' in result:
                    # Clone input_ids for labels
                    input_ids = result['input_ids']
                    
                    # Create labels by shifting input_ids right by one position
                    labels = torch.full_like(input_ids, -100)  # Initialize with -100
                    
                    # For causal LM: labels are next tokens (shifted by 1)
                    if len(input_ids.shape) > 1:  # For batched inputs
                        labels[:, :-1] = input_ids[:, 1:].clone()
                    else:  # For single inputs
                        labels[:-1] = input_ids[1:].clone()
                    
                    # Keep only non-padding positions for loss computation
                    if 'attention_mask' in result:
                        # Only compute loss on positions with attention
                        mask = result['attention_mask'] == 1
                        # Apply mask to labels (keep attention positions, set others to -100)
                        labels = labels * mask + (-100) * (~mask)
                    
                    result['labels'] = labels
                
                return result
            
            # Replace the dataset's prepare_inputs with our fixed version
            dataset.prepare_inputs = fixed_prepare_inputs
            print("✓ Fixed dataset preparation method")
        else:
            print("❌ Could not fix dataset - no prepare_inputs method found")
            
        # Add additional checks and fixes as needed
        print("Dataset preparation checks complete")
    
    def _supervised_finetuning(self, train_dataloader, val_dataloader, epochs=None):
        """Supervised fine-tuning on labeled data"""
        if epochs is None:
            epochs = self.config.epochs
        
        # FIXED: Force model to use our custom loss function
        if hasattr(self.cot_generator.model, 'config'):
            original_loss_config = getattr(self.cot_generator.model.config, 'loss_type', None)
            self.cot_generator.model.config.loss_type = 'custom'
            print(f"Updated loss_type in model config from {original_loss_config} to 'custom'")
        
        # Setup optimizer
        optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        
        # Setup scheduler
        total_steps = len(train_dataloader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )
        
        # Training loop
        self.cot_generator.model.train()
        global_step = 0
        best_val_loss = float('inf')
        
        for epoch in range(epochs):
            epoch_loss = 0
            
            # Training with tqdm progress bar
            pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=True)
            for batch in pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)
                visual_features = None
                if batch.get('visual_features') is not None:
                    visual_features = batch['visual_features'].to(self.config.device)
                
                # FIXED: Ensure we're not accidentally using the default loss
                # Explicitly disable built-in loss calculation in forward pass
                outputs = self.cot_generator.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=None,  # Pass None to prevent internal loss calculation
                    visual_features=visual_features
                )
                
                # Get loss using our improved loss function
                loss = self._calculate_improved_loss(outputs, labels)
                
                # Diagnostic information with less verbose output
                non_ignored = (labels != -100).sum().item()
                total_labels = labels.numel()
                
                # Only log detailed diagnostics if there's an issue or very occasionally
                if non_ignored < 5 or global_step % 50 == 0:
                    print(f"Non-ignored labels: {non_ignored}/{total_labels} ({non_ignored/total_labels*100:.2f}%)")
                    print(f"Raw loss value: {loss.item()}")
                    
                    # Check if your labels have any values within the vocabulary range
                    valid_range = (labels >= 0) & (labels < len(self.tokenizer))
                    valid_count = (valid_range & (labels != -100)).sum().item()
                    print(f"Labels in valid vocab range: {valid_count}")
                
                # Zero loss detection
                if loss.item() == 0:
                    print("\n⚠️ Zero loss detected in batch!")
                    # Inspect some samples from this batch
                    sample_idx = 0  # Check the first sample in batch
                    print(f"Input shape: {input_ids.shape}, Labels shape: {labels.shape}")
                    print(f"Sample input: {self.tokenizer.decode(input_ids[sample_idx])}")
                    
                    # Check where we have non-ignored labels
                    non_ignored_pos = (labels[sample_idx] != -100).nonzero().flatten()
                    if len(non_ignored_pos) > 0:
                        print(f"First few non-ignored positions: {non_ignored_pos[:10].tolist()}")
                        for pos in non_ignored_pos[:5]:
                            input_token = self.tokenizer.decode([input_ids[sample_idx, pos]])
                            label_token = self.tokenizer.decode([labels[sample_idx, pos]])
                            print(f"  Pos {pos}: Input='{input_token}', Label='{label_token}'")
                    else:
                        print("No non-ignored labels found in this sample!")
                        
                    # FIXED: Try to recover with a simple loss if our custom loss fails
                    if non_ignored > 0:
                        print("Attempting to recover with simple cross-entropy loss...")
                        shifted_logits = outputs.logits[:, :-1, :].contiguous().view(-1, outputs.logits.size(-1))
                        shifted_labels = labels[:, 1:].contiguous().view(-1)
                        loss = torch.nn.CrossEntropyLoss(ignore_index=-100)(shifted_logits, shifted_labels)
                        print(f"Recovery loss: {loss.item()}")
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.cot_generator.model.parameters(), 
                    self.config.max_grad_norm
                )
                optimizer.step()
                scheduler.step()
                
                epoch_loss += loss.item()
                global_step += 1
                
                # Update progress bar
                pbar.set_postfix({"loss": f"{loss.item():.4f}", "step": global_step, "device": self.config.device})
            
            avg_train_loss = epoch_loss / len(train_dataloader)
            logger.info(f"Epoch {epoch+1}/{epochs} completed: Average training loss = {avg_train_loss}")
            
            # Validation with tqdm
            val_loss = self._evaluate(val_dataloader)
            logger.info(f"Validation loss: {val_loss}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.cot_generator.model.save_pretrained(
                    os.path.join(self.config.checkpoint_dir, "best_model")
                )
                logger.info("Saved new best model")
    
    def _calculate_improved_loss(self, outputs, labels):
        """
        Calculate improved loss that combines cross-entropy with auxiliary losses
        to enhance training stability and reasoning capabilities
        """
        # Get logits from model outputs
        logits = outputs.logits
        
        # Initialize weights for each loss component
        weights = {
            "ce_loss": 1.0,  # Cross-entropy loss weight
            "consistency_loss": 0.2,  # Consistency loss weight
            "coverage_loss": 0.1  # Coverage loss weight
        }
        
        # FIXED: Improved handling of labels, with better checks and warnings
        # First check that our inputs are valid
        if labels is None:
            logger.warning("Labels are None, cannot calculate loss")
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
            
        if logits.size(0) != labels.size(0):
            logger.warning(f"Batch size mismatch: logits={logits.size(0)}, labels={labels.size(0)}")
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
            
        if logits.size(1) != labels.size(1):
            # FIXED: Handle sequence length mismatch for causal LM
            logger.warning(f"Sequence length mismatch: logits={logits.size(1)}, labels={labels.size(1)}")
            # Truncate to shorter length
            min_len = min(logits.size(1), labels.size(1))
            logits = logits[:, :min_len, :]
            labels = labels[:, :min_len]
        
        # 1. Cross-entropy loss - standard training loss
        # Create a tensor with -100 weight values for ignored positions
        loss_weights = torch.ones_like(labels, dtype=torch.float)
        loss_weights[labels == -100] = 0.0
        
        # Count non-ignored tokens for debugging
        non_ignored = loss_weights.sum().item()
        if non_ignored == 0:
            logger.warning("No non-ignored labels found, returning zero loss")
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        
        # FIXED: If we're dealing with a causal LM, we need to shift labels and logits
        if logits.size(1) > 1:  # Only perform shifts for sequence lengths > 1
            # For causal language modeling:
            # - predictions at position i should be for the token at position i+1
            # - we want logits[:, :-1, :] and labels[:, 1:]
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            shifted_weights = loss_weights[:, 1:].contiguous()
        else:
            # For single token prediction, no need to shift
            shifted_logits = logits
            shifted_labels = labels
            shifted_weights = loss_weights
            
        # Create a loss function with ignore_index=-100
        ce_loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        
        # Reshape for loss calculation
        flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
        flat_labels = shifted_labels.view(-1)
        flat_weights = shifted_weights.view(-1)
        
        # Calculate per-token cross-entropy loss
        per_token_loss = ce_loss_fn(flat_logits, flat_labels)
        
        # Apply weights to ignore padding (-100) tokens
        weighted_loss = per_token_loss * flat_weights
        
        # Get the mean loss over non-ignored tokens
        non_ignored = flat_weights.sum()
        ce_loss = weighted_loss.sum() / (non_ignored + 1e-8)
        
        # 2. Consistency loss - encourages logical coherence between steps
        consistency_loss = torch.tensor(0.0, device=logits.device)
        
        # If we have enough tokens and non-zero tokens to calculate consistency
        if non_ignored > 10:
            # Simple consistency metric: adjacent tokens should have some correlation
            # Get top predictions for each position
            top_preds = logits.argmax(dim=-1)
            
            # Calculate consistency as prediction stability over sequences
            for b in range(logits.size(0)):
                # Check consecutive predictions excluding padding
                valid_positions = (labels[b] != -100).nonzero().flatten()
                if len(valid_positions) > 2:
                    # Get embedding-based consistency
                    token_embeddings = self.cot_generator.model.get_input_embeddings()(top_preds[b, valid_positions])
                    similarities = torch.nn.functional.cosine_similarity(
                        token_embeddings[:-1], token_embeddings[1:], dim=1
                    )
                    # Consistency loss: encourage smooth transitions (higher similarity)
                    consistency_loss += (1.0 - similarities.mean())
        
            # Average across batch
            consistency_loss /= logits.size(0)
        
        # 3. Coverage loss - encourages diverse vocabulary usage
        coverage_loss = torch.tensor(0.0, device=logits.device)
        
        # If we have enough tokens to calculate coverage
        if non_ignored > 5:
            # Get token probability distribution averaged over sequence
            token_probs = torch.softmax(logits.view(-1, logits.size(-1)), dim=-1)
            mean_probs = token_probs.mean(dim=0)
            
            # Calculate negative entropy of this distribution
            # Lower entropy = more concentrated on few tokens = bad
            # Higher entropy = more diverse vocabulary = good
            eps = 1e-8  # For numerical stability
            entropy = -torch.sum(mean_probs * torch.log(mean_probs + eps))
            coverage_loss = 1.0 / (entropy + eps)  # Inverse of entropy
        
        # Combine losses with weights
        total_loss = (
            weights["ce_loss"] * ce_loss + 
            weights["consistency_loss"] * consistency_loss + 
            weights["coverage_loss"] * coverage_loss
        )
        
        # Store loss components for logging - use fields that won't conflict with HF
        outputs.ce_loss_value = ce_loss
        outputs.consistency_loss_value = consistency_loss
        outputs.coverage_loss_value = coverage_loss
        outputs.total_loss_value = total_loss
        
        return total_loss
    
    def _evaluate(self, dataloader):
        """Evaluate model on dataloader"""
        self.cot_generator.model.eval()
        total_loss = 0
        
        # Evaluation with tqdm progress bar
        pbar = tqdm(dataloader, desc="Evaluation", leave=False)
        with torch.no_grad():
            for batch in pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)
                visual_features = None
                if batch.get('visual_features') is not None:
                    visual_features = batch['visual_features'].to(self.config.device)
                
                # FIXED: Consistent with training, disable internal loss calculation
                outputs = self.cot_generator.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=None,  # Pass None to prevent internal loss calculation
                    visual_features=visual_features
                )
                
                # Calculate improved loss
                loss = self._calculate_improved_loss(outputs, labels)
                
                total_loss += loss.item()
                
                # Update progress bar
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        
        # Set back to training mode
        self.cot_generator.model.train()
        
        return total_loss / len(dataloader)
    
    def _train_refiner(self, dataloader, epochs=None):
        """Train the transformer refiner"""
        if epochs is None:
            epochs = self.config.refiner_epochs or 3
        
        # Create training examples
        refiner_training_data = []
        
        # Sample batches for refiner training with progress bar
        pbar = tqdm(dataloader, desc="Preparing refiner training data", leave=True)
        for batch in pbar:
            questions = batch['question']
            steps_list = batch['steps']
            
            for question, steps in zip(questions, steps_list):
                for i in range(1, len(steps)):
                    context = question + " " + " ".join(steps[:i])
                    refiner_training_data.append({
                        'context': context,
                        'step': steps[i]
                    })
                    
                    # Break after a few examples per sample
                    if i >= 3:
                        break
            
            # Update progress bar
            pbar.set_postfix({"examples": len(refiner_training_data)})
            
            # Limit training data size
            if len(refiner_training_data) >= 1000:
                break
        
        # Train the refiner
        logger.info(f"Training refiner for {epochs} epochs")
        self.refiner.train(refiner_training_data, epochs=epochs)
        
        # Evaluate the refiner
        eval_results = self._evaluate_refiner(dataloader)
        logger.info(f"Refiner evaluation: Success rate = {eval_results['success_rate']:.1f}%, "
                   f"Average improvement = {eval_results['average_improvement']:.3f}")
        
        # Save refiner checkpoint
        os.makedirs(os.path.join(self.config.checkpoint_dir, "refiner"), exist_ok=True)
        self.refiner.save_model(
            os.path.join(self.config.checkpoint_dir, "refiner")
        )
    
    def _evaluate_refiner(self, dataloader):
        """Evaluate refiner quality"""
        # Sample a few examples from validation set
        examples = []
        
        # Use tqdm for sampling
        pbar = tqdm(dataloader, desc="Sampling for refiner evaluation", leave=False)
        for batch in pbar:
            questions = batch['question']
            steps_list = batch['steps']
            
            for question, steps in zip(questions, steps_list):
                if len(steps) >= 3:  # Ensure enough steps for evaluation
                    examples.append({
                        'question': question,
                        'steps': steps
                    })
                
                # Limit evaluation examples
                if len(examples) >= 10:
                    break
            
            if len(examples) >= 10:
                break
                
            # Update progress bar
            pbar.set_postfix({"samples": len(examples)})
        
        # Evaluate refiner quality
        return self.refiner.evaluate_refinement(examples, self.reflection_module)
    
    def _rl_finetuning(self, dataloader, num_updates):
        """Perform RL fine-tuning"""
        # Setup learning rate scheduler
        scheduler = get_linear_schedule_with_warmup(
            self.ppo_trainer.optimizer,
            num_warmup_steps=int(0.1 * num_updates),
            num_training_steps=num_updates
        )
        
        # RL training loop
        global_step = 0
        best_reward = float('-inf')
        
        logger.info(f"Starting RL fine-tuning: {num_updates} updates")
        
        # Create progress bar for RL updates
        pbar = tqdm(total=num_updates, desc="RL Fine-tuning", leave=True)
        
        while global_step < num_updates:
            # Sample batch from dataloader
            for batch in dataloader:
                # Perform PPO update with refiner instead of GAN
                metrics = self.ppo_trainer.train_step(batch, self.refiner)
                
                # Update learning rate
                scheduler.step()
                
                # Update progress bar
                global_step += 1
                pbar.update(1)
                pbar.set_postfix({
                    "policy_loss": f"{metrics['policy_loss']:.4f}", 
                    "value_loss": f"{metrics['value_loss']:.4f}", 
                    "reward": f"{metrics['mean_reward']:.4f}"
                })
                
                # Log metrics periodically
                if global_step % 10 == 0:
                    logger.info(
                        f"RL step {global_step}/{num_updates}: "
                        f"Policy loss = {metrics['policy_loss']:.4f}, "
                        f"Value loss = {metrics['value_loss']:.4f}, "
                        f"Mean reward = {metrics['mean_reward']:.4f}"
                    )
                
                # Save best model
                if metrics['mean_reward'] > best_reward:
                    best_reward = metrics['mean_reward']
                    self.cot_generator.model.save_pretrained(
                        os.path.join(self.config.checkpoint_dir, "best_rl_model")
                    )
                    logger.info(f"New best reward: {best_reward:.4f} - Saved model")
                
                # Evaluate periodically
                if global_step % 100 == 0:
                    eval_reward = self._evaluate_rl()
                    logger.info(f"RL evaluation reward: {eval_reward:.4f}")
                    pbar.set_postfix({
                        "policy_loss": f"{metrics['policy_loss']:.4f}", 
                        "value_loss": f"{metrics['value_loss']:.4f}", 
                        "reward": f"{metrics['mean_reward']:.4f}",
                        "eval": f"{eval_reward:.4f}"
                    })
                
                # Check if we've reached the target number of updates
                if global_step >= num_updates:
                    break
        
        # Close the progress bar
        pbar.close()
        logger.info(f"RL fine-tuning complete: Best reward = {best_reward:.4f}")
    
    def _evaluate_rl(self, num_samples=10):
        """Evaluate current policy"""
        self.cot_generator.model.eval()
        total_reward = 0.0
        
        # Sample examples for evaluation
        examples = []
        temp_data = self.self_trainer.generate_pseudo_labels(
            [{'question': item} for item in self.config.eval_questions]
        )
        data_loader = DataLoader(temp_data, batch_size=1)
        
        # Use tqdm for sampling
        pbar = tqdm(data_loader, desc="Sampling for RL evaluation", leave=False)
        for batch in pbar:
            examples.append(batch)
            if len(examples) >= num_samples:
                break
            
            # Update progress bar
            pbar.set_postfix({"samples": len(examples)})
        
        # Evaluate on examples with progress bar
        eval_pbar = tqdm(examples, desc="RL Evaluation", leave=False)
        for example in eval_pbar:
            # Generate reasoning
            sample = self.cot_generator.generate_step_by_step(example['question'][0])
            
            # Calculate reward
            reward = self.reward_function.calculate_combined_reward(
                sample,
                example['answer'][0]
            )
            
            total_reward += reward['combined']
            
            # Update progress bar
            eval_pbar.set_postfix({"reward": f"{reward['combined']:.4f}"})
        
        # Set back to training mode
        self.cot_generator.model.train()
        
        # Calculate average reward
        avg_reward = total_reward / len(examples)
        return avg_reward
    
    def generate(self, question, image=None):
        """Generate reasoning for a question"""
        logger.info(f"Generating reasoning for: {question}")
        
        # Lookup relevant information
        logger.info("Retrieving context information...")
        retrieved_info = self.retrieval_module.retrieve(question)
        
        # Generate initial chain-of-thought
        logger.info("Generating initial chain-of-thought...")
        result = self.cot_generator.generate_step_by_step(
            question, 
            image=image,
            retrieved_context=retrieved_info
        )
        
        # Apply refiner
        logger.info("Applying transformer refiner...")
        refined_steps = self.refiner.refine_reasoning_steps(
            question,
            result['steps']
        )
        result['refined_steps'] = refined_steps
        
        # Evaluate if refinement is better
        logger.info("Evaluating refinement quality...")
        original_score = self.reflection_module.evaluate_reasoning(
            question,
            result['steps']
        )['scores']['overall_score']
        
        refined_score = self.reflection_module.evaluate_reasoning(
            question,
            refined_steps
        )['scores']['overall_score']
        
        # Use refined steps if better
        if refined_score > original_score:
            logger.info(f"Using refined steps (score improved: {original_score:.2f} -> {refined_score:.2f})")
            result['steps'] = refined_steps
            result['used_refinement'] = True
        else:
            logger.info(f"Keeping original steps (refinement score: {refined_score:.2f} vs original: {original_score:.2f})")
        
        # Get final reflection
        logger.info("Generating final reflection...")
        reflection = self.reflection_module.evaluate_reasoning(
            question,
            result['steps']
        )
        
        logger.info("Reasoning generation complete")
        
        return {
            'question': question,
            'steps': result['steps'],
            'answer': result['answer'],
            'reflection': reflection,
            'retrieved_context': retrieved_info
        }
    
    def analyze_labels(self, dataloader, num_batches=5):
        """Analyze label distribution in the dataset"""
        label_stats = {
            "total_labels": 0,
            "non_ignored": 0,
            "min_value": float('inf'),
            "max_value": -float('inf'),
            "label_counts": {},
            "avg_length": 0,
            "vocab_size": len(self.tokenizer)
        }
        
        print(f"Analyzing labels from {num_batches} batches...")
        
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
                
            labels = batch['labels'].numpy().flatten()
            
            # Update statistics
            label_stats["total_labels"] += labels.size
            label_stats["non_ignored"] += (labels != -100).sum()
            
            # Update min/max (excluding -100)
            valid_labels = labels[labels != -100]
            if valid_labels.size > 0:
                label_stats["min_value"] = min(label_stats["min_value"], valid_labels.min())
                label_stats["max_value"] = max(label_stats["max_value"], valid_labels.max())
                
            # Count values
            unique, counts = np.unique(labels, return_counts=True)
            for val, count in zip(unique, counts):
                if val not in label_stats["label_counts"]:
                    label_stats["label_counts"][val] = 0
                label_stats["label_counts"][val] += count
                
            # Track average length of non-ignored sequences
            for l in batch['labels']:
                non_ignored_length = (l != -100).sum().item()
                label_stats["avg_length"] += non_ignored_length
        
        # Calculate average
        if i > 0:
            label_stats["avg_length"] /= (i * batch['labels'].size(0))
        
        # Print results
        print("\nLabel Analysis Results:")
        print(f"Vocabulary size: {label_stats['vocab_size']}")
        print(f"Total labels: {label_stats['total_labels']}")
        print(f"Non-ignored labels: {label_stats['non_ignored']} ({label_stats['non_ignored']/label_stats['total_labels']*100:.2f}%)")
        print(f"Min value (excluding -100): {label_stats['min_value']}")
        print(f"Max value: {label_stats['max_value']}")
        print(f"Average non-ignored length: {label_stats['avg_length']:.1f} tokens")
        
        # Check for out of vocabulary indices
        if label_stats["max_value"] >= label_stats["vocab_size"]:
            print("\n⚠️ WARNING: Some labels are outside the vocabulary range!")
            print(f"Max label value ({label_stats['max_value']}) >= Vocabulary size ({label_stats['vocab_size']})")
        
        # Value distribution excluding -100
        print("\nTop label values (excluding -100):")
        value_counts = {k: v for k, v in label_stats["label_counts"].items() if k != -100}
        top_values = sorted(value_counts.items(), key=lambda x: x[1], reverse=True)[:10]
        for val, count in top_values:
            token = self.tokenizer.decode([val])
            print(f"  {val} ('{token}'): {count} occurrences")
        
        return label_stats
    
    def verify_loss_calculation(self):
        """Verify loss calculation with a simple input"""
        print("\nVerifying loss calculation with test input...")
        
        # Create a simple input
        text = "This is a test sentence."
        encoding = self.tokenizer(text, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.config.device)
        attention_mask = encoding["attention_mask"].to(self.config.device)
        
        # Create labels (shift input_ids right by one)
        labels = input_ids.clone()
        
        # Check if EOS token exists, use PAD token or a default token if not
        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            # Try pad token, or use a default token ID (usually 0 or 1)
            eos_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            print(f"Warning: No EOS token found. Using alternative token ID: {eos_token_id}")
        
        # Now create the shifted labels
        labels = torch.cat([labels[:, 1:], torch.tensor([[eos_token_id]]).to(self.config.device)], dim=1)
        
        # Show the tokens and labels
        print("Input tokens:", self.tokenizer.convert_ids_to_tokens(input_ids[0]))
        print("Label tokens:", self.tokenizer.convert_ids_to_tokens(labels[0]))
        
        # Forward pass
        self.cot_generator.model.train()  # Ensure in training mode
        outputs = self.cot_generator(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        # Calculate our improved loss
        loss = self._calculate_improved_loss(outputs, labels)
        
        print(f"Test loss calculation result: {loss.item()}")
        
        if loss.item() == 0:
            print("⚠️ WARNING: Loss is still zero on test input!")
            return False
        else:
            print("✓ Loss calculation is working properly!")
            
        # Check logits shape and activation
        logits = outputs.logits
        print(f"Logits shape: {logits.shape}")
        print(f"Logits mean: {logits.mean().item()}")
        print(f"Logits std: {logits.std().item()}")
        
        # Check predictions for expected tokens
        for i in range(min(3, input_ids.shape[1])):
            next_token_logits = logits[0, i, :]
            top_tokens = torch.topk(next_token_logits, 5)
            print(f"\nTop predictions for position {i} (input: '{self.tokenizer.decode([input_ids[0, i]])}'):")
            for token_id, score in zip(top_tokens.indices, top_tokens.values):
                token = self.tokenizer.decode([token_id])
                print(f"  {token} (ID: {token_id}): {score.item():.4f}")
        
        return loss.item() > 0

In [46]:
# Create output directories
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Setup logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[
        logging.FileHandler(os.path.join(config.output_dir, "train.log")),
        logging.StreamHandler()
    ]
)
logger.info(f"Configuration: {vars(config)}")

# Set random seeds for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Sample evaluation questions
config.eval_questions = [
    "What happens when water boils?",
    "How does gravity work?",
    "Why does the moon have phases?",
    "What is photosynthesis?",
    "How do magnets work?"
]

# Initialize pipeline
pipeline = ReasoningPipeline(config, bert_model, bert_tokenizer, vision_processor)

# Train pipeline
trained_components = pipeline.train(
    config.train_path,
    config.val_path,
    num_rl_updates=config.rl_updates,
    num_self_training_iterations=config.self_training_iterations
)

# Evaluate final model
logger.info("Evaluating final model")

# Create test dataloader
test_dataset = ScienceQADataset(
    config.val_path,
    pipeline.tokenizer,
    pipeline.vision_processor,
    config,
    is_train=False
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False
)

# Calculate validation metrics
val_loss = pipeline._evaluate(test_dataloader)
logger.info(f"Final validation loss: {val_loss:.4f}")

# Generate examples for qualitative evaluation
logger.info("Generating example outputs")
for question in config.eval_questions:
    result = pipeline.generate(question)
    logger.info(f"Question: {question}")
    logger.info(f"Steps:")
    for i, step in enumerate(result['steps']):
        logger.info(f" {i+1}. {step}")
    logger.info(f"Answer: {result['answer']}")
    logger.info(f"Reflection score: {result['reflection']['scores']['overall_score']:.2f}")
    logger.info("---")

logger.info("Training and evaluation complete")

2025-03-06 15:47:17,412 - __main__ - INFO - Configuration: {'model_name': 'bert-base-uncased', 'tokenizer_name': 'bert-base-uncased', 'vision_model_name': 'openai/clip-vit-base-patch32', 'file_path': '/Users/Viku/Datasets/ScienceQA', 'train_path': '/Users/Viku/Datasets/ScienceQA/train/train.json', 'val_path': '/Users/Viku/Datasets/ScienceQA/val/val.json', 'max_seq_length': 512, 'batch_size': 4, 'learning_rate': 5e-05, 'weight_decay': 0.01, 'epochs': 3, 'warmup_steps': 100, 'max_grad_norm': 1.0, 'gradient_accumulation_steps': 8, 'ppo_epochs': 4, 'reward_scale': 0.01, 'clip_param': 0.2, 'value_loss_coef': 0.5, 'entropy_coef': 0.01, 'refiner_model_name': 'bert-base-uncased', 'refiner_learning_rate': 2e-05, 'refiner_weight_decay': 0.01, 'refiner_batch_size': 16, 'refiner_epochs': 2, 'refiner_max_seq_length': 256, 'retrieval_top_k': 3, 'embedding_dim': 768, 'reflection_threshold': 0.7, 'output_dir': 'outputs/', 'checkpoint_dir': 'checkpoints/', 'exemplar_path': 'data/exemplars.json', 'devic

ChainOfThoughtGenerator initialized with:
  Tokenizer: bert-base-uncased
  Model: bert-base-uncased
  Vision Model: openai/clip-vit-base-patch32
  Device: mps


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
2025-03-06 15:47:38,965 - __main__ - INFO - Loading ScienceQA data from /Users/Viku/Datasets/ScienceQA/train/train.json
2025-03-06 15:48:07,531 - __main__ - INFO - Processed 12726 ScienceQA examples
2025-03-06 15:48:07,538 - __main__ - INFO - Loading ScienceQA data from /Users/Viku/Datasets/ScienceQA/val/val.json
2025-03-06 15:48:12,721 - __main__ - INFO - Processed 4241 ScienceQA examples


Debugging labels in the dataset:

=== SAMPLE 0 ===
Total length: 512
Non-padded tokens: 48
Padded tokens: 464
Non-ignored labels: 21
Non-ignored percentage: 4.10%
Non-ignored / Non-padded ratio: 43.75%

Sample tokens (first 10):
Pos 0: Input='[CLS]' | Label='IGNORED' | Mask=1
Pos 1: Input='let' | Label='[CLS]' | Mask=1
Pos 2: Input=''' | Label='let' | Mask=1
Pos 3: Input='s' | Label=''' | Mask=1
Pos 4: Input='think' | Label='s' | Mask=1
Pos 5: Input='step' | Label='IGNORED' | Mask=1
Pos 6: Input='by' | Label='IGNORED' | Mask=1
Pos 7: Input='step' | Label='IGNORED' | Mask=1
Pos 8: Input='!' | Label='IGNORED' | Mask=1
Pos 9: Input='context' | Label='IGNORED' | Mask=1

Ignored label positions:
[0, 5, 6, 7, 8, 9, 10, 11, 12, 13] ... [502, 503, 504, 505, 506, 507, 508, 509, 510, 511]

=== SAMPLE 1 ===
Total length: 512
Non-padded tokens: 88
Padded tokens: 424
Non-ignored labels: 31
Non-ignored percentage: 6.05%
Non-ignored / Non-padded ratio: 35.23%

Sample tokens (first 10):
Pos 0: Input='

2025-03-06 15:48:18,033 - __main__ - INFO - Starting supervised fine-tuning


Test loss calculation result: 9.087302207946777
✓ Loss calculation is working properly!
Logits shape: torch.Size([1, 8, 30524])
Logits mean: -7.146275520324707
Logits std: 2.88424015045166

Top predictions for position 0 (input: '[CLS]'):
  and (ID: 1998): 7.7626
  of (ID: 1997): 6.7241
  in (ID: 1999): 6.5357
  , (ID: 1010): 6.2768
  . (ID: 1012): 6.2642

Top predictions for position 1 (input: 'this'):
  and (ID: 1998): 7.9312
  is (ID: 2003): 5.9540
  in (ID: 1999): 5.9331
  , (ID: 1010): 5.6260
  with (ID: 2007): 4.9781

Top predictions for position 2 (input: 'is'):
  is (ID: 2003): 5.5263
  and (ID: 1998): 4.9881
  , (ID: 1010): 4.7009
  act (ID: 2552): 4.6559
  formed (ID: 2719): 4.5885
Analyzing labels from 5 batches...

Label Analysis Results:
Vocabulary size: 30524
Total labels: 10240
Non-ignored labels: 483 (4.72%)
Min value (excluding -100): 101
Max value: 28290
Average non-ignored length: 24.1 tokens

Top label values (excluding -100):
  1007 (')'): 33 occurrences
  1006 ('(

Epoch 1/3 [Train]:   0%|          | 0/3182 [00:00<?, ?it/s]

Non-ignored labels: 151/2048 (7.37%)
Raw loss value: 3.2973084449768066
Labels in valid vocab range: 151


Epoch 1/3 [Train]:   0%|          | 2/3182 [01:17<34:02:28, 38.54s/it, loss=4.1196, step=2, device=mps]


KeyboardInterrupt: 