In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoProcessor,
    CLIPVisionModel,
    AdamW, 
    get_linear_schedule_with_warmup,
    AutoModel
)
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Tuple, Optional, Union
import json
import logging
import faiss
from PIL import Image
import re
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [3]:
# Set random seeds for reproducibility
def set_seed(seed: int = 11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
set_seed()

In [4]:
os.environ["MISTRAL_API_KEY"]="jJAuJZkjVcy2ynUhan375sHNviHiBeJU"

In [5]:
# Config class to hold hyperparameters
class Config:
    def __init__(self):
        # Base Model
        self.model_name = "bert-base-uncased"  # can be replaced with any suitable LLM
        self.tokenizer_name = "bert-base-uncased"
        
        # Vision Model
        self.vision_model_name = "openai/clip-vit-base-patch32"
        
        # ScienceQA Dataset
        self.file_path = "/Users/Viku/Datasets/ScienceQA"
        self.train_path = "/Users/Viku/Datasets/ScienceQA/train/train.json"
        self.val_path = "/Users/Viku/Datasets/ScienceQA/val/val.json"
        self.max_seq_length = 512
        self.batch_size = 4
        
        # Training
        self.learning_rate = 5e-5
        self.weight_decay = 0.01
        self.epochs = 3
        self.warmup_steps = 100
        self.max_grad_norm = 1.0
        self.gradient_accumulation_steps = 8
        
        # RL Training
        self.ppo_epochs = 4
        self.reward_scale = 0.01
        self.clip_param = 0.2
        self.value_loss_coef = 0.5
        self.entropy_coef = 0.01
        
        # GAN
        self.gan_learning_rate = 2e-5
        self.gan_weight_decay = 0.01
        self.gan_batch_size = 16
        self.gan_epochs = 2
        
        # Retrieval
        self.retrieval_top_k = 3
        self.embedding_dim = 768
        
        # Reflection
        self.reflection_threshold = 0.7
        
        # Paths
        self.output_dir = "outputs/"
        self.checkpoint_dir = "checkpoints/"
        self.exemplar_path = "data/exemplars.json"
        
        # Device
        self.device = torch.device("mps" if torch.mps.is_available() else "cpu")

        self.max_answer_length = 64
        self.rl_updates = 1000
        self.self_training_iterations = 3

config = Config()

In [6]:
from transformers import BertTokenizer, BertLMHeadModel, AutoImageProcessor, BertConfig

# Define the initialize_bert function
def initialize_bert():
    logger.info(f"Initializing BERT model: {config.model_name}")
    
    # First get the original config
    bert_config = BertConfig.from_pretrained(config.model_name)
    
    # Set is_decoder=True
    bert_config.is_decoder = True
    
    # Initialize tokenizer normally
    tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name)
    
    # Initialize model with this modified config
    model = BertLMHeadModel.from_pretrained(
        config.model_name, 
        config=bert_config,
        ignore_mismatched_sizes=True
    )
    
    vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
    return model, tokenizer, vision_processor

# Now call the function to initialize the models
bert_model, bert_tokenizer, vision_processor = initialize_bert()
logger.info(f"BERT model initialized successfully")

# After initializing the model
model_path = os.path.join(config.checkpoint_dir, "bert_decoder")
os.makedirs(model_path, exist_ok=True)
bert_model.save_pretrained(model_path)
bert_tokenizer.save_pretrained(model_path)

# Then reload it
bert_model = BertLMHeadModel.from_pretrained(model_path)
bert_tokenizer = BertTokenizer.from_pretrained(model_path)


2025-03-06 12:26:31,148 - __main__ - INFO - Initializing BERT model: bert-base-uncased
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-03-06 12:26:34,802 - __main__ - INFO - BERT model initialized successfully


In [7]:
class ScienceQADataset(Dataset):
    def __init__(self, file_path, tokenizer, vision_processor, config, is_train=True):
        self.tokenizer = tokenizer
        self.vision_processor = vision_processor
        self.config = config
        self.is_train = is_train
        self.base_dir = os.path.dirname(file_path)  # Get directory containing the JSON file
        self.data = self.load_and_preprocess_data(file_path)
        
    def load_and_preprocess_data(self, file_path):
        logger.info(f"Loading ScienceQA data from {file_path}")
        with open(file_path, 'r') as f:
            data = json.load(f)
        
        processed_data = []
        for item in data:
            # Extract fields specific to ScienceQA
            question = item.get('question', '')
            context = item.get('context', '')
            choices = item.get('choices', [])
            answer = item.get('answer', '')
            explanation = item.get('explanation', '')
            question_id = item.get('id', '')  # Get question ID for image path
            
            # Format choices as text
            choices_text = ""
            for i, choice in enumerate(choices):
                choices_text += f"({chr(65+i)}) {choice} "
            
            # Combine context and question
            full_question = f"Context: {context}\nQuestion: {question}\nChoices: {choices_text}"
            
            # Split explanation into reasoning steps
            steps = self.extract_reasoning_steps(explanation)
            
            # Process image if available
            visual_features = None
            if 'image' in item and item['image']:
                # Construct image path based on question ID in train/val folder structure
                # Assuming question_id corresponds to the folder name
                image_folder = os.path.join(self.base_dir, str(question_id))
                image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')] if os.path.exists(image_folder) else []
                
                if image_files:
                    image_path = os.path.join(image_folder, image_files[0])
                    try:
                        image = Image.open(image_path).convert('RGB')
                        visual_features = self.process_image(image)
                    except Exception as e:
                        logger.error(f"Error processing image {image_path}: {e}")
            
            # Tokenize question
            question_tokens = self.tokenizer.encode(
                "Let's think step by step! " + full_question, 
                add_special_tokens=True,
                truncation=True,
                max_length=self.config.max_seq_length // 2
            )
            
            # Tokenize each step separately
            steps_tokens = []
            for step in steps:
                step_tokens = self.tokenizer.encode(
                    step,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=self.config.max_seq_length // (2 * max(1, len(steps)))
                )
                steps_tokens.append(step_tokens)
            
            # Tokenize answer
            answer_tokens = self.tokenizer.encode(
                f"Therefore, the answer is {answer}",
                add_special_tokens=False,
                truncation=True,
                max_length=self.config.max_seq_length // 4
            )
            
            processed_data.append({
                'question': full_question,
                'question_tokens': question_tokens,
                'steps': steps,
                'steps_tokens': steps_tokens,
                'answer': answer,
                'answer_tokens': answer_tokens,
                'visual_features': visual_features,
                'has_image': visual_features is not None
            })
        
        logger.info(f"Processed {len(processed_data)} ScienceQA examples")
        return processed_data
    
    def process_image(self, image):
        """Process image using CLIP vision encoder"""
        print("Processing Image")
        inputs = self.vision_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            vision_model = CLIPVisionModel.from_pretrained(self.config.vision_model_name)
            vision_model.to(self.config.device)
            outputs = vision_model(**{k: v.to(self.config.device) for k, v in inputs.items()})
            visual_features = outputs.pooler_output.cpu().numpy()
        return visual_features[0]  # Return the feature vector
    
    def extract_reasoning_steps(self, explanation):
        """Extract reasoning steps from explanation"""
        # Method 1: Split by numbered steps if present
        numbered_pattern = re.compile(r'\d+\.\s+')
        if numbered_pattern.search(explanation):
            steps = [step.strip() for step in numbered_pattern.split(explanation) if step.strip()]
            if steps and not steps[0][0].isdigit():  # Remove introduction if it doesn't start with a number
                steps = steps[1:]
            return steps or [explanation]
        
        # Method 2: Split by sentences assuming each sentence is a step
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', explanation)
        if len(sentences) > 1:
            return [s.strip() for s in sentences if s.strip()]
        
        # Default: Treat the whole explanation as one step
        return [explanation]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Combine question, steps, and answer tokens for input
        input_tokens = item['question_tokens'].copy()
        for step_tokens in item['steps_tokens']:
            input_tokens.extend(step_tokens)
        input_tokens.extend(item['answer_tokens'])
        
        # Pad or truncate to max sequence length
        if len(input_tokens) > self.config.max_seq_length:
            input_tokens = input_tokens[:self.config.max_seq_length]
        
        # Create attention mask
        attention_mask = [1] * len(input_tokens)
        padding_length = self.config.max_seq_length - len(input_tokens)
        input_tokens.extend([self.tokenizer.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)
        
        # Create labels for language modeling (shift right)
        # We don't compute loss for the question part
        labels = [-100] * len(item['question_tokens'])
        for step_tokens in item['steps_tokens']:
            labels.extend(step_tokens)
        labels.extend(item['answer_tokens'])
        
        # Pad labels
        if len(labels) > self.config.max_seq_length:
            labels = labels[:self.config.max_seq_length]
        labels.extend([-100] * (self.config.max_seq_length - len(labels)))
        
        return {
            'input_ids': torch.tensor(input_tokens, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'question': item['question'],
            'steps': item['steps'],
            'answer': item['answer'],
            'visual_features': torch.tensor(item['visual_features'], dtype=torch.float) if item['visual_features'] is not None else torch.zeros(768),
            'has_image': item['has_image']
        }

In [8]:
# 2. Base LLM & Chain-of-Thought Generator

class ChainOfThoughtGenerator(nn.Module):
    def __init__(self, config, model=None):
        super().__init__()
        self.config = config
        
        if model is not None:
            self.model = model
        else:
            # Initialize model from config
            self.model = AutoModelForCausalLM.from_pretrained(config.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        
        # Add special tokens if needed
        special_tokens = {"pad_token": "[PAD]"} if self.tokenizer.pad_token is None else {}
        if special_tokens:
            self.tokenizer.add_special_tokens(special_tokens)
        
        # Load the base language model
        self.model = AutoModelForCausalLM.from_pretrained(config.model_name)
        if special_tokens:
            self.model.resize_token_embeddings(len(self.tokenizer))
        
        # Vision encoder for multimodal inputs
        self.vision_processor = AutoProcessor.from_pretrained(config.vision_model_name)
        self.vision_model = CLIPVisionModel.from_pretrained(config.vision_model_name)
        
        # Vision-language integration layer
        self.vision_projection = nn.Linear(
            self.vision_model.config.hidden_size,
            self.model.config.hidden_size
        )
        
        # Move models to device
        self.model.to(config.device)
        self.vision_model.to(config.device)
        self.vision_projection.to(config.device)

        print("ChainOfThoughtGenerator initialized with:")
        print(f"  Tokenizer: {config.tokenizer_name}")
        print(f"  Model: {config.model_name}")
        print(f"  Vision Model: {config.vision_model_name}")
        print(f"  Device: {config.device}")

    def encode_image(self, image):
        """Encode image using vision model"""
        print("Encoding image...")
        vision_inputs = self.vision_processor(images=image, return_tensors="pt")
        vision_inputs = {k: v.to(self.config.device) for k, v in vision_inputs.items()}
        
        with torch.no_grad():
            vision_outputs = self.vision_model(**vision_inputs)
            image_features = vision_outputs.pooler_output
            projected_features = self.vision_projection(image_features)
        
        print("Image encoding completed.")
        return projected_features
    
    def forward(self, input_ids, attention_mask, labels=None, visual_features=None):
        # """Forward pass with optional visual features"""
        # print("Forward pass started.")
        # print(f"  Input IDs: {input_ids.shape}")
        # print(f"  Attention Mask: {attention_mask.shape}")
        # if labels is not None:
        #     print(f"  Labels: {labels.shape}")

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # print("Model forward pass completed.")

        # If visual features are available, enhance the hidden states
        if visual_features is not None:
            # print("Processing visual features...")
            if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                projected_visual = self.vision_projection(visual_features)
                last_hidden = outputs.hidden_states[-1]
                enhanced_hidden = last_hidden + projected_visual.unsqueeze(1)
                # print("Visual features integrated into hidden states.")

        return outputs
    
    def generate_step_by_step(self, question, image=None, num_steps=5, max_length=512):
        """Generate a chain-of-thought reasoning process for a given question"""
        print(f"Generating reasoning for question: {question}")

        # Prepare input
        prompt = f"Let's think step by step! {question}"
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        print(f"Tokenized input: {inputs}")

        # Process image if provided
        visual_embedding = None
        if image is not None:
            print("Processing image for reasoning...")
            visual_embedding = self.encode_image(image)

        # Generate reasoning steps and answer
        with torch.no_grad():
            print("Generating response from model...")
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True
            )
        
        # Decode the generated text
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text: {generated_text}")

        # Extract reasoning steps and answer
        generated_text = generated_text[len(prompt):]  # Remove the prompt
        parts = generated_text.split("\nTherefore, the answer is")
        
        reasoning = parts[0]
        answer = parts[1] if len(parts) > 1 else "No clear answer provided."

        # Split reasoning into steps
        steps = []
        for step in reasoning.split("\n"):
            if step.strip():
                steps.append(step.strip())

        print("Generated reasoning steps:")
        for i, step in enumerate(steps):
            print(f"  Step {i+1}: {step}")

        print(f"Final answer: {answer.strip()}")

        return {
            "question": question,
            "steps": steps,
            "answer": answer.strip(),
            "full_text": generated_text
        }

In [9]:
# 3. Reflection Module

class ReflectionModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Use a smaller model for efficiency
        self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
        
        # Scoring layers
        self.coherence_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        self.language_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        self.progress_scorer = nn.Linear(self.encoder.config.hidden_size, 1)
        
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        
        # Move to device
        self.encoder.to(config.device)
        self.coherence_scorer.to(config.device)
        self.language_scorer.to(config.device)
        self.progress_scorer.to(config.device)
    
    def forward(self, question, steps, previous_steps=None):
        """
        Evaluate the quality of reasoning steps
        
        Args:
            question: The original question
            steps: List of reasoning steps to evaluate
            previous_steps: Optional previous steps for context
        
        Returns:
            Dictionary of scores for each step and overall
        """
        all_scores = []
        
        # Process each step
        for i, step in enumerate(steps):
            # Create context from question and previous steps
            context = question
            if previous_steps:
                context += " " + " ".join(previous_steps)
            
            # Tokenize
            inputs = self.tokenizer(
                context, 
                step, 
                truncation=True, 
                padding=True, 
                return_tensors="pt"
            )
            inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
            
            # Get embeddings
            with torch.no_grad():
                outputs = self.encoder(**inputs)
                pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token
            
            # Calculate scores
            coherence_score = torch.sigmoid(self.coherence_scorer(pooled_output)).item()
            language_score = torch.sigmoid(self.language_scorer(pooled_output)).item()
            progress_score = torch.sigmoid(self.progress_scorer(pooled_output)).item()
            
            # Calculate a combined score
            combined_score = (coherence_score + language_score + progress_score) / 3
            
            all_scores.append({
                'step': i+1,
                'coherence': coherence_score,
                'language': language_score,
                'progress': progress_score,
                'combined': combined_score
            })
            
            # Update previous steps for next iteration
            previous_steps = (previous_steps or []) + [step]
        
        # Calculate overall score
        overall_score = sum(s['combined'] for s in all_scores) / len(all_scores) if all_scores else 0
        
        return {
            'step_scores': all_scores,
            'overall_score': overall_score
        }
    
    def evaluate_reasoning(self, question, steps, answer=None):
        """Evaluate the overall reasoning process"""
        step_scores = self.forward(question, steps)
        
        # Check if reasoning meets the threshold
        meets_threshold = step_scores['overall_score'] >= self.config.reflection_threshold
        
        return {
            'scores': step_scores,
            'meets_threshold': meets_threshold,
            'feedback': self.generate_feedback(step_scores) if not meets_threshold else None
        }
    
    def generate_feedback(self, scores):
        """Generate feedback based on scores"""
        feedback = []
        
        for step_score in scores['step_scores']:
            step_num = step_score['step']
            if step_score['coherence'] < 0.6:
                feedback.append(f"Step {step_num} lacks coherence with the context.")
            if step_score['language'] < 0.6:
                feedback.append(f"Step {step_num} has language issues.")
            if step_score['progress'] < 0.6:
                feedback.append(f"Step {step_num} doesn't make sufficient progress toward the answer.")
        
        if not feedback:
            feedback = ["The reasoning needs improvement, but specific issues weren't identified."]
        
        return feedback

In [10]:
# 4. Retrieval Module

class RetrievalModule:
    def __init__(self, config):
        self.config = config
        self.encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        
        # Move to device
        self.encoder.to(config.device)
        
        # Load exemplars
        self.exemplars = self.load_exemplars()
        
        # Build index for fast retrieval
        self.index = self.build_index()
    
    def load_exemplars(self):
        """Load exemplar reasoning sequences"""
        if not os.path.exists(self.config.exemplar_path):
            logger.warning(f"Exemplar file {self.config.exemplar_path} not found, using empty exemplars")
            return []
        
        with open(self.config.exemplar_path, 'r') as f:
            exemplars = json.load(f)
        
        # Pre-compute embeddings for each exemplar
        for exemplar in exemplars:
            exemplar['embedding'] = self.encode_text(exemplar['question']).cpu().numpy()
        
        logger.info(f"Loaded {len(exemplars)} exemplars")
        return exemplars
    
    def build_index(self):
        """Build FAISS index for fast retrieval"""
        if not self.exemplars:
            return None
        
        # Extract embeddings
        embeddings = np.array([ex['embedding'] for ex in self.exemplars]).astype('float32')
        
        # Build index
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        
        return index
    
    def encode_text(self, text):
        """Encode text using the sentence transformer"""
        inputs = self.tokenizer(
            text, 
            truncation=True, 
            padding=True, 
            return_tensors="pt"
        )
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Use mean pooling
            attention_mask = inputs['attention_mask']
            mask = attention_mask.unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
            masked_embeddings = outputs.last_hidden_state * mask
            summed = torch.sum(masked_embeddings, 1)
            counts = torch.clamp(torch.sum(mask, 1), min=1e-9)
            mean_pooled = summed / counts
        
        return mean_pooled[0]  # Return the embedding vector
    
    def retrieve_similar_examples(self, question, k=None):
        """Retrieve similar exemplars for a given question"""
        if k is None:
            k = self.config.retrieval_top_k
        
        if not self.exemplars or self.index is None:
            return []
        
        # Encode the query
        query_embedding = self.encode_text(question).cpu().numpy().reshape(1, -1).astype('float32')
        
        # Search for similar examples
        distances, indices = self.index.search(query_embedding, k)
        
        # Get the exemplars
        results = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.exemplars):
                exemplar = self.exemplars[idx].copy()
                exemplar['similarity'] = float(1.0 / (1.0 + distances[0][i]))  # Convert distance to similarity
                exemplar.pop('embedding', None)  # Remove embedding from result
                results.append(exemplar)
        
        return results
    
    def get_demonstration_prompt(self, question, k=None):
        """Get a few-shot demonstration prompt based on retrieved examples"""
        examples = self.retrieve_similar_examples(question, k)
        
        if not examples:
            return f"Let's think step by step! {question}"
        
        prompt = "I'll solve some similar problems step by step, then answer your question.\n\n"
        
        # Add examples
        for i, example in enumerate(examples):
            prompt += f"Example {i+1}:\n"
            prompt += f"Question: {example['question']}\n"
            prompt += "Reasoning:\n"
            for j, step in enumerate(example['steps']):
                prompt += f"{j+1}. {step}\n"
            prompt += f"Answer: {example['answer']}\n\n"
        
        # Add the current question
        prompt += f"Now, let's solve your question step by step!\n{question}\n"
        
        return prompt

In [11]:
class TextGenerator(nn.Module):
    """Generator model for refining intermediate reasoning steps"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Small model for efficiency
        self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.model.to(config.device)
    
    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask)
    
    def generate_refined_step(self, context, original_step=None, max_length=100):
        """Generate a refined reasoning step"""
        # Prepare input
        if original_step:
            prompt = f"{context}\nRefined step: "
        else:
            prompt = f"{context}\nNext step: "
        
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        # Generate refined step
        outputs = self.model.generate(
            **inputs,
            max_length=len(inputs["input_ids"][0]) + max_length,
            temperature=0.8,
            top_p=0.92,
            do_sample=True,
            num_return_sequences=1
        )
        
        # Decode the generated text
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        refined_step = generated_text[len(prompt):].strip()
        
        return refined_step

class TextDiscriminator(nn.Module):
    """Discriminator model for evaluating reasoning steps"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Use a small pretrained model
        self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        
        # Scoring layer
        self.scorer = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        self.encoder.to(config.device)
        self.scorer.to(config.device)
    
    def forward(self, context, step):
        """Evaluate the quality of a reasoning step"""
        # Tokenize
        inputs = self.tokenizer(
            context, 
            step, 
            truncation=True, 
            padding=True, 
            return_tensors="pt"
        )
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        
        # Get embeddings
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token
        
        # Calculate score
        score = self.scorer(pooled_output)
        
        return score

class GANModule:
    """GAN-based text refinement module"""
    def __init__(self, config):
        self.config = config
        self.generator = TextGenerator(config)
        self.discriminator = TextDiscriminator(config)
        
        # Optimizers
        self.gen_optimizer = AdamW(
            self.generator.parameters(),
            lr=config.gan_learning_rate,
            weight_decay=config.gan_weight_decay
        )
        self.disc_optimizer = AdamW(
            self.discriminator.parameters(),
            lr=config.gan_learning_rate,
            weight_decay=config.gan_weight_decay
        )
    
    def train_step(self, real_examples):
        """Single GAN training step"""
        # Train discriminator
        self.disc_optimizer.zero_grad()
        
        # Real examples
        real_contexts = [ex['context'] for ex in real_examples]
        real_steps = [ex['step'] for ex in real_examples]
        real_scores = []
        
        for context, step in zip(real_contexts, real_steps):
            real_score = self.discriminator(context, step)
            real_scores.append(real_score)
        
        real_loss = sum([(1 - score) ** 2 for score in real_scores]) / len(real_scores)
        
        # Generated examples
        fake_steps = []
        for context in real_contexts:
            with torch.no_grad():
                fake_step = self.generator.generate_refined_step(context)
                fake_steps.append(fake_step)
        
        fake_scores = []
        for context, step in zip(real_contexts, fake_steps):
            fake_score = self.discriminator(context, step)
            fake_scores.append(fake_score)
        
        fake_loss = sum([score ** 2 for score in fake_scores]) / len(fake_scores)
        
        # Combine losses
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        self.disc_optimizer.step()
        
        # Train generator
        self.gen_optimizer.zero_grad()
        
        # Generate new fake examples
        new_fake_steps = []
        for context in real_contexts:
            fake_step = self.generator.generate_refined_step(context)
            new_fake_steps.append(fake_step)
        
        # Calculate generator loss
        gen_scores = []
        for context, step in zip(real_contexts, new_fake_steps):
            gen_score = self.discriminator(context, step)
            gen_scores.append(gen_score)
        
        gen_loss = sum([(1 - score) ** 2 for score in gen_scores]) / len(gen_scores)
        gen_loss.backward()
        self.gen_optimizer.step()
        
        return {
            'disc_loss': disc_loss.item(),
            'gen_loss': gen_loss.item(),
            'real_score_mean': sum([s.item() for s in real_scores]) / len(real_scores),
            'fake_score_mean': sum([s.item() for s in fake_scores]) / len(fake_scores)
        }
    
    def refine_reasoning_steps(self, question, original_steps):
        """Refine a sequence of reasoning steps"""
        refined_steps = []
        context = question
        
        for i, step in enumerate(original_steps):
            # Context includes the question and previous steps
            current_context = context
            
            # Generate refined step
            refined_step = self.generator.generate_refined_step(current_context, step)
            
            # Evaluate original and refined
            orig_score = self.discriminator(current_context, step).item()
            refined_score = self.discriminator(current_context, refined_step).item()
            
            # Use the better step
            if refined_score > orig_score:
                refined_steps.append(refined_step)
                context += f"\nStep {i+1}: {refined_step}"
            else:
                refined_steps.append(step)
                context += f"\nStep {i+1}: {step}"
        
        return refined_steps

In [12]:
# 6. Dual Reward Function and RL Training

class RewardFunction:
    """Combines outcome and process rewards for RL training"""
    def __init__(self, reflection_module, config):
        self.reflection_module = reflection_module
        self.config = config
        
        # For outcome verification
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        
    def calculate_outcome_reward(self, predicted_answer, ground_truth):
        """Calculate reward based on correctness of final answer"""
        # For ScienceQA, we can do a simple match check
        # In a real system, you might use more sophisticated answer verification
        predicted_normalized = predicted_answer.strip().lower()
        ground_truth_normalized = ground_truth.strip().lower()
        
        # Check for exact match
        if predicted_normalized == ground_truth_normalized:
            return 1.0
        
        # Check for partial match (for longer answers)
        if len(ground_truth_normalized) > 10:
            # Using simple token overlap as a metric
            pred_tokens = set(self.tokenizer.tokenize(predicted_normalized))
            truth_tokens = set(self.tokenizer.tokenize(ground_truth_normalized))
            
            if not truth_tokens:
                return 0.0
                
            overlap = len(pred_tokens.intersection(truth_tokens)) / len(truth_tokens)
            return max(0.0, overlap - 0.3)  # Only reward significant overlap
        
        return 0.0
    
    def calculate_process_reward(self, question, steps, original_steps=None):
        """Calculate reward based on quality of reasoning process"""
        # Get reflection scores
        reflection_result = self.reflection_module.evaluate_reasoning(question, steps)
        process_score = reflection_result['scores']['overall_score']
        
        # If we have original steps, reward improvement
        if original_steps:
            original_result = self.reflection_module.evaluate_reasoning(question, original_steps)
            original_score = original_result['scores']['overall_score']
            
            # Reward improvement, penalize degradation
            improvement = process_score - original_score
            if improvement > 0:
                process_score += 0.2 * improvement  # Bonus for improvement
            else:
                process_score += 0.1 * improvement  # Smaller penalty for degradation
        
        return process_score
    
    def calculate_combined_reward(self, sample, ground_truth, original_steps=None):
        """Calculate combined reward from outcome and process"""
        question = sample['question']
        steps = sample['steps']
        predicted_answer = sample['answer']
        
        # Calculate component rewards
        outcome_reward = self.calculate_outcome_reward(predicted_answer, ground_truth)
        process_reward = self.calculate_process_reward(question, steps, original_steps)
        
        # Combine rewards
        # The balance between outcome and process rewards is important
        # In this implementation, we favor process for ScienceQA
        combined_reward = 0.4 * outcome_reward + 0.6 * process_reward
        
        return {
            'combined': combined_reward,
            'outcome': outcome_reward,
            'process': process_reward,
            'details': {
                'answer_correct': outcome_reward > 0.9,
                'reasoning_quality': process_reward
            }
        }

class PPOTrainer:
    """PPO-based RL trainer for the reasoning model"""
    def __init__(self, cot_generator, reward_function, config):
        self.cot_generator = cot_generator
        self.reward_function = reward_function
        self.config = config
        
        # Create a reference model for KL penalty
        self.ref_model = AutoModelForCausalLM.from_pretrained(config.model_name)
        self.ref_model.to(config.device)
        self.ref_model.eval()
        
        # Optimizer
        self.optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # Initialize policy entropy and value losses
        self.policy_loss = 0
        self.value_losses = []
        self.entropy_losses = []
    
    def train_step(self, batch, gan_module=None):
        """Perform a single PPO training step"""
        # Extract data
        input_ids = batch['input_ids'].to(self.config.device)
        attention_mask = batch['attention_mask'].to(self.config.device)
        questions = batch['question']
        ground_truth_answers = batch['answer']
        
        # Forward pass with current policy to get initial log probs and values
        with torch.no_grad():
            outputs = self.cot_generator.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            old_logits = outputs.logits
            
            # Extract values (implicitly learned through the LM head)
            # In a full implementation, you would have a separate value head
            values = torch.mean(old_logits, dim=-1)  # Simplistic value estimation
        
        # Generate samples from current policy
        generated_samples = []
        for i in range(len(questions)):
            sample = self.cot_generator.generate_step_by_step(questions[i], image=None)
            generated_samples.append(sample)
        
        # Optionally refine with GAN
        if gan_module:
            for i, sample in enumerate(generated_samples):
                refined_steps = gan_module.refine_reasoning_steps(
                    sample['question'], 
                    sample['steps']
                )
                generated_samples[i]['refined_steps'] = refined_steps
        
        # Calculate rewards
        rewards = []
        for i, sample in enumerate(generated_samples):
            steps_to_evaluate = sample.get('refined_steps', sample['steps'])
            reward = self.reward_function.calculate_combined_reward(
                {
                    'question': sample['question'],
                    'steps': steps_to_evaluate,
                    'answer': sample['answer']
                },
                ground_truth_answers[i]
            )
            rewards.append(reward['combined'])
        
        rewards_tensor = torch.tensor(rewards, device=self.config.device)
        
        # PPO optimization loop
        for _ in range(self.config.ppo_epochs):
            # Forward pass with current policy
            outputs = self.cot_generator.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
            logits = outputs.logits
            
            # Calculate new log probabilities and values
            new_values = torch.mean(logits, dim=-1)  # Simplistic value estimation
            
            # Compute KL divergence penalty
            kl_div = self._compute_kl_divergence(old_logits, logits, attention_mask)
            
            # Compute policy loss (PPO clipped objective)
            # In a full implementation, you would compute proper action probabilities
            # For simplicity, we're using a proxy based on logits difference
            logit_diff = torch.sum(torch.abs(logits - old_logits), dim=-1)
            policy_ratio = torch.exp(-logit_diff * 0.01)  # Proxy for probability ratio
            
            clipped_ratio = torch.clamp(
                policy_ratio, 
                1.0 - self.config.clip_param, 
                1.0 + self.config.clip_param
            )
            
            policy_reward = rewards_tensor.unsqueeze(-1).expand_as(policy_ratio)
            policy_loss = -torch.min(
                policy_ratio * policy_reward,
                clipped_ratio * policy_reward
            ).mean()
            
            # Value loss
            value_loss = F.mse_loss(new_values, rewards_tensor.unsqueeze(-1).expand_as(new_values))
            
            # Entropy for exploration
            # Simplified entropy calculation
            entropy = torch.mean(torch.std(logits, dim=-1))
            entropy_loss = -self.config.entropy_coef * entropy
            
            # Total loss
            loss = policy_loss + self.config.value_loss_coef * value_loss + entropy_loss + 0.01 * kl_div
            
            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.cot_generator.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            
            # Store metrics
            self.policy_loss = policy_loss.item()
            self.value_losses.append(value_loss.item())
            self.entropy_losses.append(entropy_loss.item())
        
        return {
            'policy_loss': self.policy_loss,
            'value_loss': sum(self.value_losses) / len(self.value_losses),
            'entropy_loss': sum(self.entropy_losses) / len(self.entropy_losses),
            'mean_reward': rewards_tensor.mean().item()
        }
    
    def _compute_kl_divergence(self, old_logits, new_logits, attention_mask):
        """Compute KL divergence between old and new policies"""
        old_probs = F.softmax(old_logits, dim=-1)
        new_probs = F.softmax(new_logits, dim=-1)
        
        # KL divergence
        kl = old_probs * (torch.log(old_probs) - torch.log(new_probs))
        kl = kl.sum(-1)
        
        # Apply attention mask
        kl = kl * attention_mask.float()
        
        # Average over non-masked tokens
        kl = kl.sum() / attention_mask.float().sum()
        
        return kl

In [13]:
# 7. Self-Training and Distillation

class SelfTrainer:
    """Self-training through iterative pseudo-labeling"""
    def __init__(self, cot_generator, config):
        self.cot_generator = cot_generator
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        
        # Optimizer for fine-tuning
        self.optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # LR scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=1000  # Will be updated when dataset size is known
        )
    
    def generate_pseudo_labels(self, unlabeled_data, ground_truth_answers=None):
        """Generate pseudo-labels for unlabeled data"""
        pseudo_labeled_data = []
        
        for i, sample in enumerate(unlabeled_data):
            question = sample['question']
            
            # Generate reasoning steps and answer
            generated = self.cot_generator.generate_step_by_step(question, image=None)
            
            # Check if the answer is correct (if ground truth is available)
            is_correct = False
            if ground_truth_answers is not None:
                ground_truth = ground_truth_answers[i]
                predicted = generated['answer'].strip().lower()
                ground_truth = ground_truth.strip().lower()
                is_correct = predicted == ground_truth
            
            # Only include correct answers or all if no ground truth
            if is_correct or ground_truth_answers is None:
                pseudo_labeled_data.append({
                    'question': question,
                    'steps': generated['steps'],
                    'answer': generated['answer'],
                    'confidence': 1.0  # In a real implementation, you'd use model confidence
                })
        
        return pseudo_labeled_data
    
    def finetune_on_pseudo_labels(self, pseudo_labeled_data, epochs=None):
        """Fine-tune model on pseudo-labeled data"""
        if epochs is None:
            epochs = self.config.epochs
        
        # Create dataset
        dataset = self._create_dataset_from_samples(pseudo_labeled_data)
        
        # Adjust scheduler
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=epochs * len(dataset)
        )
        
        # Fine-tuning loop
        self.cot_generator.model.train()
        total_loss = 0
        global_step = 0
        
        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch+1}/{epochs}")
            epoch_loss = 0
            
            for batch in dataset:
                # Move batch to device
                batch = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v 
                         for k, v in batch.items()}
                
                # Forward and backward
                outputs = self.cot_generator.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels']
                )
                
                loss = outputs.loss
                epoch_loss += loss.item()
                
                # Backward pass
                loss.backward()
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    self.cot_generator.model.parameters(), 
                    self.config.max_grad_norm
                )
                
                # Update weights
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                
                global_step += 1
                
                # Log progress
                if global_step % 50 == 0:
                    logger.info(f"Step {global_step}: loss = {loss.item()}")
            
            avg_epoch_loss = epoch_loss / len(dataset)
            total_loss += avg_epoch_loss
            logger.info(f"Epoch {epoch+1} completed: Average loss = {avg_epoch_loss}")
        
        avg_loss = total_loss / epochs
        logger.info(f"Fine-tuning completed: Average loss = {avg_loss}")
        
        return avg_loss
    
    def _create_dataset_from_samples(self, samples):
        """Create a dataset from generated samples"""
        dataset = []
        
        for sample in samples:
            question = sample['question']
            steps = sample['steps']
            answer = sample['answer']
            
            # Prepare input text
            input_text = f"Let's think step by step! {question}\n"
            for i, step in enumerate(steps):
                input_text += f"{i+1}. {step}\n"
            input_text += f"Therefore, the answer is {answer}"
            
            # Tokenize
            encodings = self.tokenizer(
                input_text,
                truncation=True,
                max_length=self.config.max_seq_length,
                padding="max_length",
                return_tensors="pt"
            )
            
            # Create labels for causal LM training
            input_ids = encodings['input_ids'][0]
            attention_mask = encodings['attention_mask'][0]
            labels = input_ids.clone()
            
            # Mask question part in labels
            question_tokens = self.tokenizer.encode(
                f"Let's think step by step! {question}",
                add_special_tokens=True
            )
            labels[:len(question_tokens)] = -100
            
            dataset.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            })
        
        return dataset
    
    def self_training_loop(self, labeled_data, unlabeled_data, num_iterations=3):
        """Run the complete self-training loop"""
        for iteration in range(num_iterations):
            logger.info(f"Starting self-training iteration {iteration+1}/{num_iterations}")
            
            # Generate pseudo-labels
            pseudo_labels = self.generate_pseudo_labels(
                unlabeled_data,
                ground_truth_answers=[item['answer'] for item in unlabeled_data]
            )
            
            if not pseudo_labels:
                logger.warning("No pseudo-labels generated. Stopping self-training.")
                break
            
            logger.info(f"Generated {len(pseudo_labels)} pseudo-labels")
            
            # Combine with labeled data
            combined_data = labeled_data + pseudo_labels
            
            # Fine-tune on combined data
            loss = self.finetune_on_pseudo_labels(combined_data)
            
            logger.info(f"Iteration {iteration+1} completed: loss = {loss}")
            
            # Update labeled data for next iteration
            labeled_data = combined_data
        
        logger.info("Self-training completed")
        
        # Save the final model
        self.cot_generator.model.save_pretrained(os.path.join(self.config.output_dir, "self_trained_model"))
        self.tokenizer.save_pretrained(os.path.join(self.config.output_dir, "self_trained_tokenizer"))
        
        return labeled_data

class ModelDistiller:
    """Knowledge distillation for creating smaller, efficient models"""
    def __init__(self, teacher_model, config, student_model_name="distilbert-base-uncased"):
        self.teacher_model = teacher_model
        self.config = config
        
        # Load smaller student model
        self.student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
        self.student_model = AutoModelForCausalLM.from_pretrained(student_model_name)
        
        # Add special tokens if needed
        special_tokens = {"pad_token": "[PAD]"} if self.student_tokenizer.pad_token is None else {}
        if special_tokens:
            self.student_tokenizer.add_special_tokens(special_tokens)
            self.student_model.resize_token_embeddings(len(self.student_tokenizer))
        
        # Move to device
        self.student_model.to(config.device)
        
        # Optimizer
        self.optimizer = AdamW(
            self.student_model.parameters(),
            lr=2e-5,  # Usually higher for distillation
            weight_decay=0.01
        )
        
        # Temperature for softening distributions
        self.temperature = 2.0
    
    def distill(self, dataset, epochs=3):
        """Distill knowledge from teacher to student"""
        self.student_model.train()
        self.teacher_model.model.eval()
        
        total_loss = 0
        step_count = 0
        
        for epoch in range(epochs):
            logger.info(f"Starting distillation epoch {epoch+1}/{epochs}")
            epoch_loss = 0
            
            for batch in dataset:
                # Convert to student tokenization
                student_inputs = self._convert_teacher_to_student_inputs(batch)
                
                # Move to device
                student_inputs = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v
                                 for k, v in student_inputs.items()}
                
                # Get teacher predictions
                with torch.no_grad():
                    teacher_outputs = self.teacher_model.model(
                        input_ids=batch['input_ids'].to(self.config.device),
                        attention_mask=batch['attention_mask'].to(self.config.device)
                    )
                    
                    # Apply temperature scaling to logits
                    teacher_logits = teacher_outputs.logits / self.temperature
                
                # Student forward pass
                student_outputs = self.student_model(**student_inputs)
                student_logits = student_outputs.logits / self.temperature
                
                # Compute distillation loss
                # Standard cross-entropy loss for task performance
                task_loss = F.cross_entropy(
                    student_logits.view(-1, student_logits.size(-1)),
                    student_inputs['labels'].view(-1),
                    ignore_index=-100
                )
                
                # Distillation loss (KL divergence)
                # We need to align teacher and student token representations
                aligned_teacher_logits = self._align_teacher_student_representations(
                    teacher_logits, 
                    student_logits,
                    batch['attention_mask'].to(self.config.device),
                    student_inputs['attention_mask']
                )
                
                distillation_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(aligned_teacher_logits, dim=-1),
                    reduction='batchmean'
                )
                
                # Combine losses
                loss = 0.5 * task_loss + 0.5 * distillation_loss
                
                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), 1.0)
                self.optimizer.step()
                
                # Track losses
                epoch_loss += loss.item()
                step_count += 1
                
                # Log progress
                if step_count % 100 == 0:
                    logger.info(f"Distillation step {step_count}: loss = {loss.item()}")
            
            avg_epoch_loss = epoch_loss / len(dataset)
            total_loss += avg_epoch_loss
            logger.info(f"Distillation epoch {epoch+1} completed: Average loss = {avg_epoch_loss}")
        
        avg_loss = total_loss / epochs
        logger.info(f"Distillation completed: Average loss = {avg_loss}")
        
        # Save distilled model
        self.student_model.save_pretrained(os.path.join(self.config.output_dir, "distilled_model"))
        self.student_tokenizer.save_pretrained(os.path.join(self.config.output_dir, "distilled_tokenizer"))
        
        return self.student_model
    
    def _convert_teacher_to_student_inputs(self, teacher_batch):
        """Convert teacher batch to student tokenization"""
        # This is a placeholder - in practice, you would need to implement 
        # conversion between different tokenizers
        return teacher_batch
    
    def _align_teacher_student_representations(self, teacher_logits, student_logits, 
                                              teacher_mask, student_mask):
        """Align teacher and student token representations"""
        # This is a placeholder for token alignment
        # In practice, you'd need to implement vocabulary mapping
        return teacher_logits

In [14]:
# 8. Integration - Full Reasoning Pipeline
class ReasoningPipeline:
    """Full reasoning pipeline integrating all components"""
    def __init__(self, config, bert_model=None, bert_tokenizer=None, vision_processor=None):
        self.config = config
        
        # Use provided models/tokenizers if available, otherwise initialize from config
        if bert_tokenizer is not None:
            self.tokenizer = bert_tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
            
        if vision_processor is not None:
            self.vision_processor = vision_processor
        else:
            self.vision_processor = AutoProcessor.from_pretrained(config.vision_model_name)
        
        # Add special tokens if needed
        special_tokens = {"pad_token": "[PAD]"} if self.tokenizer.pad_token is None else {}
        if special_tokens:
            self.tokenizer.add_special_tokens(special_tokens)
        
        # Initialize components, passing the BERT model to CoT generator if provided
        self.cot_generator = ChainOfThoughtGenerator(config, model=bert_model)
        self.reflection_module = ReflectionModule(config)
        self.retrieval_module = RetrievalModule(config)
        self.gan_module = GANModule(config)
        
        # Create reward function
        self.reward_function = RewardFunction(self.reflection_module, config)
        
        # Create RL trainer
        self.ppo_trainer = PPOTrainer(self.cot_generator, self.reward_function, config)
        
        # Create self-trainer
        self.self_trainer = SelfTrainer(self.cot_generator, config)
        
        # Initialize model distiller
        self.distiller = None
    
    def train(self, train_file, val_file, num_rl_updates=1000, num_self_training_iterations=3):
        """Train the full reasoning pipeline"""
        # Load datasets
        train_dataset = ScienceQADataset(
            train_file, 
            self.tokenizer, 
            self.vision_processor, 
            self.config, 
            is_train=True
        )
        
        val_dataset = ScienceQADataset(
            val_file, 
            self.tokenizer, 
            self.vision_processor, 
            self.config, 
            is_train=False
        )
        
        # Create data loaders
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True
        )
        
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False
        )
        
        # 1. Initial supervised fine-tuning
        logger.info("Starting supervised fine-tuning")
        self._supervised_finetuning(train_dataloader, val_dataloader)
        
        # 2. Train GAN components
        logger.info("Training GAN components")
        self._train_gan(train_dataloader)
        
        # 3. RL fine-tuning
        logger.info("Starting RL fine-tuning")
        self._rl_finetuning(train_dataloader, num_rl_updates)
        
        # 4. Self-training loop
        logger.info("Starting self-training")
        labeled_data = train_dataset.data[:100]  # Start with a small labeled subset
        unlabeled_data = train_dataset.data[100:]  # The rest is unlabeled
        final_labeled_data = self.self_trainer.self_training_loop(
            labeled_data, 
            unlabeled_data, 
            num_iterations=num_self_training_iterations
        )
        
        # 5. Distillation to smaller model
        logger.info("Starting model distillation")
        self.distiller = ModelDistiller(self.cot_generator, self.config)
        distilled_model = self.distiller.distill(train_dataloader, epochs=3)
        
        logger.info("Training complete")
        
        return {
            'cot_generator': self.cot_generator,
            'reflection_module': self.reflection_module,
            'retrieval_module': self.retrieval_module,
            'gan_module': self.gan_module,
            'distilled_model': distilled_model
        }
    
    def _supervised_finetuning(self, train_dataloader, val_dataloader, epochs=None):
        """Supervised fine-tuning on labeled data"""
        if epochs is None:
            epochs = self.config.epochs
        
        # Setup optimizer
        optimizer = AdamW(
            self.cot_generator.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        
        # Setup scheduler
        total_steps = len(train_dataloader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )
        
        # Training loop
        self.cot_generator.model.train()
        global_step = 0
        best_val_loss = float('inf')
        
        for epoch in range(epochs):
            epoch_loss = 0
            
            # Training with tqdm progress bar
            pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=True)
            for batch in pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)
                visual_features = None
                if batch.get('visual_features') is not None:
                    visual_features = batch['visual_features'].to(self.config.device)
                
                # Forward pass
                outputs = self.cot_generator(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    visual_features=visual_features
                )
                
                loss = outputs.loss
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.cot_generator.model.parameters(), 
                    self.config.max_grad_norm
                )
                optimizer.step()
                scheduler.step()
                
                epoch_loss += loss.item()
                global_step += 1
                
                # Update progress bar
                pbar.set_postfix({"loss": f"{loss.item():.4f}", "step": global_step})
            
            avg_train_loss = epoch_loss / len(train_dataloader)
            logger.info(f"Epoch {epoch+1}/{epochs} completed: Average training loss = {avg_train_loss}")
            
            # Validation with tqdm
            val_loss = self._evaluate(val_dataloader)
            logger.info(f"Validation loss: {val_loss}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.cot_generator.model.save_pretrained(
                    os.path.join(self.config.checkpoint_dir, "best_model")
                )
                logger.info("Saved new best model")
    
    def _evaluate(self, dataloader):
        """Evaluate model on dataloader"""
        self.cot_generator.model.eval()
        total_loss = 0
        
        # Evaluation with tqdm progress bar
        pbar = tqdm(dataloader, desc="Evaluation", leave=False)
        with torch.no_grad():
            for batch in pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)
                visual_features = None
                if batch.get('visual_features') is not None:
                    visual_features = batch['visual_features'].to(self.config.device)
                
                # Forward pass
                outputs = self.cot_generator(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    visual_features=visual_features
                )
                
                loss = outputs.loss.item()
                total_loss += loss
                
                # Update progress bar
                pbar.set_postfix({"loss": f"{loss:.4f}"})
        
        # Set back to training mode
        self.cot_generator.model.train()
        
        return total_loss / len(dataloader)
    
    def _train_gan(self, dataloader, epochs=None):
        """Train GAN components"""
        if epochs is None:
            epochs = self.config.gan_epochs
        
        # Create training examples
        gan_training_data = []
        
        # Sample batches for GAN training with progress bar
        pbar = tqdm(dataloader, desc="Preparing GAN training data", leave=True)
        for batch in pbar:
            questions = batch['question']
            steps_list = batch['steps']
            
            for question, steps in zip(questions, steps_list):
                for i in range(1, len(steps)):
                    context = question + " " + " ".join(steps[:i])
                    gan_training_data.append({
                        'context': context,
                        'step': steps[i]
                    })
                    
                    # Break after a few examples per sample
                    if i >= 3:
                        break
            
            # Update progress bar
            pbar.set_postfix({"examples": len(gan_training_data)})
            
            # Limit training data size
            if len(gan_training_data) >= 1000:
                break
        
        # Train GAN for specified epochs
        for epoch in range(epochs):
            logger.info(f"GAN training epoch {epoch+1}/{epochs}")
            epoch_g_loss = 0
            epoch_d_loss = 0
            
            # Shuffle training data
            random.shuffle(gan_training_data)
            
            # Create batches
            batch_size = self.config.gan_batch_size
            num_batches = len(gan_training_data) // batch_size
            
            # Train with progress bar
            pbar = tqdm(range(num_batches), desc=f"GAN Epoch {epoch+1}/{epochs}", leave=True)
            for i in pbar:
                batch_data = gan_training_data[i * batch_size:(i + 1) * batch_size]
                
                # Train discriminator
                d_loss = self.gan_module.train_discriminator(batch_data)
                epoch_d_loss += d_loss
                
                # Train generator
                g_loss = self.gan_module.train_generator(batch_data)
                epoch_g_loss += g_loss
                
                # Update progress bar
                pbar.set_postfix({
                    "G_loss": f"{g_loss:.4f}", 
                    "D_loss": f"{d_loss:.4f}"
                })
            
            # Average losses
            avg_g_loss = epoch_g_loss / num_batches
            avg_d_loss = epoch_d_loss / num_batches
            logger.info(f"GAN epoch {epoch+1} completed: "
                        f"Average G loss = {avg_g_loss:.4f}, "
                        f"Average D loss = {avg_d_loss:.4f}")
            
            # Evaluate GAN
            self._evaluate_gan(dataloader)
            
            # Save GAN checkpoint
            self.gan_module.save_checkpoint(
                os.path.join(self.config.checkpoint_dir, f"gan_checkpoint_epoch_{epoch+1}")
            )
    
    def _evaluate_gan(self, dataloader):
        """Evaluate GAN quality"""
        # Sample a few examples from validation set
        examples = []
        
        # Use tqdm for sampling
        pbar = tqdm(dataloader, desc="Sampling for GAN evaluation", leave=False)
        for batch in pbar:
            questions = batch['question']
            steps_list = batch['steps']
            
            for question, steps in zip(questions, steps_list):
                if len(steps) >= 3:  # Ensure enough steps for evaluation
                    examples.append({
                        'question': question,
                        'steps': steps
                    })
                
                # Limit evaluation examples
                if len(examples) >= 10:
                    break
            
            if len(examples) >= 10:
                break
                
            # Update progress bar
            pbar.set_postfix({"samples": len(examples)})
        
        # Evaluate GAN refinement quality
        logger.info("Evaluating GAN refinement quality:")
        success_count = 0
        
        # Progress bar for evaluation
        pbar = tqdm(examples, desc="GAN Evaluation", leave=False)
        for example in pbar:
            # Get partial reasoning chain (first half of steps)
            partial_steps = example['steps'][:len(example['steps']) // 2]
            
            # Generate continuation with GAN
            refined_steps = self.gan_module.refine_reasoning_steps(
                example['question'],
                partial_steps
            )
            
            # Evaluate with reflection module
            original_score = self.reflection_module.evaluate_reasoning(
                example['question'],
                example['steps']
            )['scores']['overall_score']
            
            refined_score = self.reflection_module.evaluate_reasoning(
                example['question'],
                partial_steps + refined_steps
            )['scores']['overall_score']
            
            # Check improvement
            improvement = refined_score - original_score
            status = "✓" if improvement >= 0 else "✗"
            
            # Update progress bar
            pbar.set_postfix({
                "orig": f"{original_score:.2f}",
                "refined": f"{refined_score:.2f}",
                "diff": f"{improvement:.2f} {status}"
            })
            
            # Log individual example
            logger.info(f"  Example {examples.index(example)+1}: "
                        f"Original: {original_score:.2f}, "
                        f"Refined: {refined_score:.2f}, "
                        f"Diff: {improvement:.2f} {status}")
            
            if improvement >= 0:
                success_count += 1
        
        success_rate = success_count / len(examples) * 100
        logger.info(f"GAN evaluation complete: Success rate = {success_rate:.1f}%")
        
        return success_rate
    
    def _rl_finetuning(self, dataloader, num_updates):
        """Perform RL fine-tuning"""
        # Setup learning rate scheduler
        scheduler = get_linear_schedule_with_warmup(
            self.ppo_trainer.optimizer,
            num_warmup_steps=int(0.1 * num_updates),
            num_training_steps=num_updates
        )
        
        # RL training loop
        global_step = 0
        best_reward = float('-inf')
        
        logger.info(f"Starting RL fine-tuning: {num_updates} updates")
        
        # Create progress bar for RL updates
        pbar = tqdm(total=num_updates, desc="RL Fine-tuning", leave=True)
        
        while global_step < num_updates:
            # Sample batch from dataloader
            for batch in dataloader:
                # Perform PPO update
                metrics = self.ppo_trainer.train_step(batch, self.gan_module)
                
                # Update learning rate
                scheduler.step()
                
                # Update progress bar
                global_step += 1
                pbar.update(1)
                pbar.set_postfix({
                    "policy_loss": f"{metrics['policy_loss']:.4f}", 
                    "value_loss": f"{metrics['value_loss']:.4f}", 
                    "reward": f"{metrics['mean_reward']:.4f}"
                })
                
                # Log metrics periodically
                if global_step % 10 == 0:
                    logger.info(
                        f"RL step {global_step}/{num_updates}: "
                        f"Policy loss = {metrics['policy_loss']:.4f}, "
                        f"Value loss = {metrics['value_loss']:.4f}, "
                        f"Mean reward = {metrics['mean_reward']:.4f}"
                    )
                
                # Save best model
                if metrics['mean_reward'] > best_reward:
                    best_reward = metrics['mean_reward']
                    self.cot_generator.model.save_pretrained(
                        os.path.join(self.config.checkpoint_dir, "best_rl_model")
                    )
                    logger.info(f"New best reward: {best_reward:.4f} - Saved model")
                
                # Evaluate periodically
                if global_step % 100 == 0:
                    eval_reward = self._evaluate_rl()
                    logger.info(f"RL evaluation reward: {eval_reward:.4f}")
                    pbar.set_postfix({
                        "policy_loss": f"{metrics['policy_loss']:.4f}", 
                        "value_loss": f"{metrics['value_loss']:.4f}", 
                        "reward": f"{metrics['mean_reward']:.4f}",
                        "eval": f"{eval_reward:.4f}"
                    })
                
                # Check if we've reached the target number of updates
                if global_step >= num_updates:
                    break
        
        # Close the progress bar
        pbar.close()
        logger.info(f"RL fine-tuning complete: Best reward = {best_reward:.4f}")
    
    def _evaluate_rl(self, num_samples=10):
        """Evaluate current policy"""
        self.cot_generator.model.eval()
        total_reward = 0.0
        
        # Sample examples for evaluation
        examples = []
        temp_data = self.self_trainer.generate_pseudo_labels(
            [{'question': item} for item in self.config.eval_questions]
        )
        data_loader = DataLoader(temp_data, batch_size=1)
        
        # Use tqdm for sampling
        pbar = tqdm(data_loader, desc="Sampling for RL evaluation", leave=False)
        for batch in pbar:
            examples.append(batch)
            if len(examples) >= num_samples:
                break
            
            # Update progress bar
            pbar.set_postfix({"samples": len(examples)})
        
        # Evaluate on examples with progress bar
        eval_pbar = tqdm(examples, desc="RL Evaluation", leave=False)
        for example in eval_pbar:
            # Generate reasoning
            sample = self.cot_generator.generate_step_by_step(example['question'][0])
            
            # Calculate reward
            reward = self.reward_function.calculate_combined_reward(
                sample,
                example['answer'][0]
            )
            
            total_reward += reward['combined']
            
            # Update progress bar
            eval_pbar.set_postfix({"reward": f"{reward['combined']:.4f}"})
        
        # Set back to training mode
        self.cot_generator.model.train()
        
        # Calculate average reward
        avg_reward = total_reward / len(examples)
        return avg_reward
    
    def generate(self, question, image=None):
        """Generate reasoning for a question"""
        logger.info(f"Generating reasoning for: {question}")
        
        # Lookup relevant information
        logger.info("Retrieving context information...")
        retrieved_info = self.retrieval_module.retrieve(question)
        
        # Generate initial chain-of-thought
        logger.info("Generating initial chain-of-thought...")
        result = self.cot_generator.generate_step_by_step(
            question, 
            image=image,
            retrieved_context=retrieved_info
        )
        
        # Apply GAN refinement if available
        if self.gan_module is not None:
            logger.info("Applying GAN refinement...")
            refined_steps = self.gan_module.refine_reasoning_steps(
                question,
                result['steps']
            )
            result['refined_steps'] = refined_steps
            
            # Evaluate if refinement is better
            logger.info("Evaluating refinement quality...")
            original_score = self.reflection_module.evaluate_reasoning(
                question,
                result['steps']
            )['scores']['overall_score']
            
            refined_score = self.reflection_module.evaluate_reasoning(
                question,
                refined_steps
            )['scores']['overall_score']
            
            # Use refined steps if better
            if refined_score > original_score:
                logger.info(f"Using refined steps (score improved: {original_score:.2f} -> {refined_score:.2f})")
                result['steps'] = refined_steps
                result['used_refinement'] = True
            else:
                logger.info(f"Keeping original steps (refinement score: {refined_score:.2f} vs original: {original_score:.2f})")
        
        # Get final reflection
        logger.info("Generating final reflection...")
        reflection = self.reflection_module.evaluate_reasoning(
            question,
            result['steps']
        )
        
        logger.info("Reasoning generation complete")
        
        return {
            'question': question,
            'steps': result['steps'],
            'answer': result['answer'],
            'reflection': reflection,
            'retrieved_context': retrieved_info
        }

In [None]:
# Create output directories
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Setup logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[
        logging.FileHandler(os.path.join(config.output_dir, "train.log")),
        logging.StreamHandler()
    ]
)
logger.info(f"Configuration: {vars(config)}")

# Set random seeds for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Sample evaluation questions
config.eval_questions = [
    "What happens when water boils?",
    "How does gravity work?",
    "Why does the moon have phases?",
    "What is photosynthesis?",
    "How do magnets work?"
]

# Initialize pipeline
pipeline = ReasoningPipeline(config, bert_model, bert_tokenizer, vision_processor)

# Train pipeline
trained_components = pipeline.train(
    config.train_path,
    config.val_path,
    num_rl_updates=config.rl_updates,
    num_self_training_iterations=config.self_training_iterations
)

# Evaluate final model
logger.info("Evaluating final model")

# Create test dataloader
test_dataset = ScienceQADataset(
    config.val_path,
    pipeline.tokenizer,
    pipeline.vision_processor,
    config,
    is_train=False
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False
)

# Calculate validation metrics
val_loss = pipeline._evaluate(test_dataloader)
logger.info(f"Final validation loss: {val_loss:.4f}")

# Generate examples for qualitative evaluation
logger.info("Generating example outputs")
for question in config.eval_questions:
    result = pipeline.generate(question)
    logger.info(f"Question: {question}")
    logger.info(f"Steps:")
    for i, step in enumerate(result['steps']):
        logger.info(f" {i+1}. {step}")
    logger.info(f"Answer: {result['answer']}")
    logger.info(f"Reflection score: {result['reflection']['scores']['overall_score']:.2f}")
    logger.info("---")

logger.info("Training and evaluation complete")

2025-03-06 12:26:35,797 - __main__ - INFO - Configuration: {'model_name': 'bert-base-uncased', 'tokenizer_name': 'bert-base-uncased', 'vision_model_name': 'openai/clip-vit-base-patch32', 'file_path': '/Users/Viku/Datasets/ScienceQA', 'train_path': '/Users/Viku/Datasets/ScienceQA/train/train.json', 'val_path': '/Users/Viku/Datasets/ScienceQA/val/val.json', 'max_seq_length': 512, 'batch_size': 4, 'learning_rate': 5e-05, 'weight_decay': 0.01, 'epochs': 3, 'warmup_steps': 100, 'max_grad_norm': 1.0, 'gradient_accumulation_steps': 8, 'ppo_epochs': 4, 'reward_scale': 0.01, 'clip_param': 0.2, 'value_loss_coef': 0.5, 'entropy_coef': 0.01, 'gan_learning_rate': 2e-05, 'gan_weight_decay': 0.01, 'gan_batch_size': 16, 'gan_epochs': 2, 'retrieval_top_k': 3, 'embedding_dim': 768, 'reflection_threshold': 0.7, 'output_dir': 'outputs/', 'checkpoint_dir': 'checkpoints/', 'exemplar_path': 'data/exemplars.json', 'device': device(type='mps'), 'max_answer_length': 64, 'rl_updates': 1000, 'self_training_iterat

ChainOfThoughtGenerator initialized with:
  Tokenizer: bert-base-uncased
  Model: bert-base-uncased
  Vision Model: openai/clip-vit-base-patch32
  Device: mps


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
2025-03-06 12:26:48,922 - __main__ - INFO - Loading ScienceQA data from /Users/Viku/Datasets/ScienceQA/train/train.json
2025-03-06 12:27:31,651 - __main__ - INFO - Processed 12726 ScienceQA examples
2025-03-06 12:27:31,659 - __main__ - INFO - Loading ScienceQA data from /Users/Viku/Datasets/ScienceQA/val/val.json
2025-03-06 12:27:38,995 - __main__ - INFO - Processed 4241 ScienceQA examples
2025-03-06 12:27:38,999 - __main__ - INFO - Starting supervised fine-tuning
Epoch 1/3 [Train]:   0%|          | 0/3182 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 1/3 [Train]:   2%|▏         | 77/3182 [03:56<2:48:05,  3.25s/it, loss=0.0001, step=77] 