# Finetuning with Llava as teacher/student

In [2]:
init_prompt_task = """
Enhanced Prompt
Identify the target sphere according to the description
Outline its position using a bounding box and provide its coordinates in the format:

x0 (left)
y0 (top)
x1 (right)
y1 (bottom)

Format for Response:
"Bounding box coordinates: [x0, y0, x1, y1]"
"""
init_prompt_instruct = """
Describe the location of the blue sphere relative to the environment features.
"""


In [1]:
import torch
from torch.distributions import Normal
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)

device = "cuda:1"
# Initialize models and LoRA configuration
model_name = "llava-hf/llava-1.5-7b-hf"

def create_lora_model(base_model, device, trainable=True):
    """Create a LoRA-adapted model"""
    # LoRA configuration
    lora_config = LoraConfig(
        r=8,  # rank
        lora_alpha=32,  # alpha scaling
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    # Add LoRA adapters
    model = get_peft_model(base_model, lora_config)
    return model

# Create base models
base_speaker_model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device)

listener_model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device).eval()

# Apply LoRA
speaker_model = create_lora_model(base_speaker_model, 'cuda:7', trainable=True)
processor = AutoProcessor.from_pretrained(model_name)
# dataset = load you own dataset

  from .autonotebook import tqdm as notebook_tqdm
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.15it/s]
Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [3]:
def create_dummy_image():
    """Creates a dummy image tensor. Replace with actual image data as needed."""
    # For example, a random RGB image of size 224x224
    return torch.randn(3, 224, 224)

# Example Dataset with 2 Dummy Examples
dataset = [
    {
        'speaker_view_image': create_dummy_image(),
        'listener_view_image': create_dummy_image(),
        'listener_target_bbox': [50, 50, 150, 150],
        'listener_distractor_0_bbox': [30, 30, 100, 100],
        'listener_distractor_1_bbox': [160, 160, 220, 220],
    },
    {
        'speaker_view_image': create_dummy_image(),
        'listener_view_image': create_dummy_image(),
        'listener_target_bbox': [60, 60, 140, 140],
        'listener_distractor_0_bbox': [20, 20, 80, 80],
        'listener_distractor_1_bbox': [170, 170, 230, 230],
    }
]

# PPO

In [None]:
import torch
from torch.distributions import Normal
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
import copy

class DualLLaVATrainer:
    def __init__(
        self, 
        speaker_model, 
        listener_model, 
        processor, 
        learning_rate=1e-5, 
        gamma=0.99, 
        epsilon=0.2, 
        c1=1, 
        c2=0.01
    ):
        self.speaker_model = speaker_model
        self.listener_model = listener_model
        self.processor = processor
        self.optimizer = torch.optim.Adam(speaker_model.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.c1 = c1
        self.c2 = c2
        
        # Freeze listener model
        for param in self.listener_model.parameters():
            param.requires_grad = False
        self.listener_model.eval()
        
    def calculate_iou(self, pred_box, gt_box):
        """Calculate IoU between predicted and ground truth boxes"""
        if isinstance(gt_box, str):
            gt_box = eval(gt_box)
        
        pred_x1, pred_y1, pred_x2, pred_y2 = pred_box
        gt_x1, gt_y1, gt_x2, gt_y2 = gt_box
        
        x1 = max(pred_x1, gt_x1)
        y1 = max(pred_y1, gt_y1)
        x2 = min(pred_x2, gt_x2)
        y2 = min(pred_y2, gt_y2)
        
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        
        pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
        gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
        union = pred_area + gt_area - intersection
        
        return intersection / (union + 1e-6)

    def get_bbox_from_output(self, output_text):
        """Extract bounding box coordinates from model output"""
        try:
            import re
            coords = re.findall(r'\[([\d\.,\s]+)\]', output_text)
            if coords:
                return [float(x) for x in coords[-1].split(',')]
            return None
        except:
            return None

    def get_listener_predictions(self, images, speaker_messages, batch):
        """Get bounding box predictions from listener model by selecting from available boxes"""
        boxes = [
            batch[0]['listener_target_bbox'],
            batch[0]['listener_distractor_0_bbox'],
            batch[0]['listener_distractor_1_bbox']
        ]
        # Format the boxes as choices
        box_choices = f"""
        Boxes:
        A={boxes[0]} B={boxes[1]} C={boxes[2]}
        Description: {speaker_messages[0]}
        Choose box A, B, or C.
        """
        # Get speaker outputs (old policy)
        conversation = [
            {
            "role": "user",
            "content": [
                {"type": "text", "text": box_choices},
                {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        # print(prompt)
        inputs = processor(images=images[0], text=prompt, return_tensors='pt').to(self.listener_model.device).to(torch.bfloat16)
        # Generate predictions
        with torch.no_grad():
            outputs = self.listener_model.generate(**inputs, max_length=300)
            texts = self.processor.batch_decode(outputs, skip_special_tokens=True)
            
        # Parse selections and convert to boxes
        selected_boxes = []
        # print(texts[0])
        for i, text in enumerate(texts):
            # Extract the selection (A, B, or C)
            if 'A' in text.upper() or 'CHOICE A' in text.upper() or 'SELECT A' in text.upper():
                selected_boxes.append(eval(batch[i]['listener_target_bbox']))
            elif 'B' in text.upper() or 'CHOICE B' in text.upper() or 'SELECT B' in text.upper():
                selected_boxes.append(eval(batch[i]['listener_distractor_0_bbox']))
            elif 'C' in text.upper() or 'CHOICE C' in text.upper() or 'SELECT C' in text.upper():
                selected_boxes.append(eval(batch[i]['listener_distractor_1_bbox']))
            else:
                # Default to target box if no clear selection (you might want to handle this differently)
                selected_boxes.append(eval(batch[i]['listener_distractor_0_bbox']))
        
        return selected_boxes

    def compute_rewards(self, listener_boxes, gt_boxes):
        """Compute rewards based on listener's box predictions"""
        rewards = []
        for pred, gt in zip(listener_boxes, gt_boxes):
            if pred is None:
                rewards.append(-1.0)
            else:
                iou = self.calculate_iou(pred, gt)
                rewards.append(iou)  # Use raw IOU as reward
        return torch.tensor(rewards, device=self.speaker_model.device)

    def train_step(self, batch):
        """Single training step"""
        speaker_images = [item['speaker_view_image'] for item in batch]
        listener_images = [item['listener_view_image'] for item in batch]
        gt_boxes = [item['listener_target_bbox'] for item in batch]
        
        # Get speaker outputs (old policy)
        conversation = [
            {
            "role": "user",
            "content": [
                {"type": "text", "text": init_prompt_instruct},
                {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        speaker_inputs = processor(images=speaker_images[0], text=prompt, return_tensors='pt').to(self.speaker_model.device).to(torch.bfloat16)

        with torch.no_grad():
            old_outputs = self.speaker_model.generate(
                **speaker_inputs,
                max_length=300,
                num_beams=1,
                do_sample=True,
                temperature=0.7
            )
            old_logprobs = self.speaker_model(**speaker_inputs).logits
            old_messages = self.processor.batch_decode(
                old_outputs, 
                skip_special_tokens=True
            )
            old_messages = [item.split('ASSISTANT: ')[-1][:100] for item in old_messages]
        
        # Get listener predictions and compute rewards
        listener_boxes = self.get_listener_predictions(listener_images, old_messages, batch)
        rewards = self.compute_rewards(listener_boxes, gt_boxes).to(torch.bfloat16)
        # Store the old policy outputs
        # print(rewards.shape)
        old_policy = {
            'logprobs': old_logprobs.detach(),
            'rewards': rewards.detach(),
        }
        
        # PPO update for speaker using stored values
        for _ in range(3):
            outputs = self.speaker_model(**speaker_inputs)
            new_logprobs = outputs.logits
            
            # Compute policy ratio using stored logprobs
            ratio = torch.exp(new_logprobs - old_policy['logprobs'])
            clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
            # print(ratio.shape)
            # Compute losses using stored rewards
            policy_loss = -torch.min(
                ratio * old_policy['rewards'], 
                clipped_ratio * old_policy['rewards']
            ).mean()
            value_loss = F.mse_loss(new_logprobs, old_policy['rewards'].unsqueeze(-1))
            entropy_loss = -torch.mean(
                torch.distributions.Categorical(logits=new_logprobs).entropy()
            )
            
            total_loss = policy_loss + self.c1 * value_loss + self.c2 * entropy_loss
            
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            
        return total_loss.item(), rewards.mean().item()

def train_dual_llava(num_epochs=10, batch_size=1):
    trainer = DualLLaVATrainer(speaker_model, listener_model, processor)
    # Training loop
    for epoch in range(num_epochs):
        epoch_losses = []
        epoch_rewards = []
        
        for i in tqdm(range(0, len(dataset), batch_size)):
            try:
                batch = [dataset[index] for index in range(i, i+batch_size)]
                loss, reward = trainer.train_step(batch)
                epoch_losses.append(loss)
                epoch_rewards.append(reward)
                
                # Print examples periodically
                if i % (batch_size * 10) == 0:
                    print("\nExample outputs:")
                    # Get speaker message
                    conversation = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": init_prompt_instruct},
                                {"type": "image"},
                            ],
                        },
                    ]
                    prompt = processor.apply_chat_template(
                        conversation, 
                        add_generation_prompt=True,
                        max_length=200  # Limit input length
                    )
                    speaker_input = processor(
                        images=batch[0]['speaker_view_image'], 
                        text=prompt, 
                        return_tensors='pt',
                        max_length=200,  # Limit input length
                        truncation=True
                    ).to(trainer.speaker_model.device).to(torch.bfloat16)
                    
                    generated_message = trainer.processor.batch_decode(
                        trainer.speaker_model.generate(
                            **speaker_input, 
                            max_length=100,
                            temperature=0.7
                        ),
                        skip_special_tokens=True
                    )[0]
                    generated_message = generated_message.split('ASSISTANT: ')[-1][:100]
                    
                    # Get listener prediction with bounding box choices
                    boxes = [
                        batch[0]['listener_target_bbox'],
                        batch[0]['listener_distractor_0_bbox'],
                        batch[0]['listener_distractor_1_bbox']
                    ]
                    box_choices = f"""Boxes:
                        A={boxes[0]} B={boxes[1]} C={boxes[2]}
                        Description: {generated_message}
                        Choose box A, B, or C."""
                    
                    conversation = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": box_choices},
                                {"type": "image"},
                            ],
                        },
                    ]
                    prompt = processor.apply_chat_template(
                        conversation, 
                        add_generation_prompt=True,
                        max_length=300  # Limit input length
                    )
                    listener_input = processor(
                        images=batch[0]['listener_view_image'], 
                        text=prompt, 
                        return_tensors='pt',
                        max_length=300,  # Limit input length
                        truncation=True
                    ).to(trainer.listener_model.device).to(torch.bfloat16)
                    
                    listener_output = trainer.processor.batch_decode(
                        trainer.listener_model.generate(
                            **listener_input, 
                            max_length=300
                        ),
                        skip_special_tokens=True
                    )[0]
                    
                    print(f"Speaker Message: {generated_message}")
                    print(f"Listener Response: {listener_output}")
                    print(f"Ground Truth Box: {batch[0]['listener_target_bbox']}")
                    if 'A' in listener_output.upper():
                        print("Selected: Box A (Target)")
                    elif 'B' in listener_output.upper():
                        print("Selected: Box B (Distractor 0)")
                    elif 'C' in listener_output.upper():
                        print("Selected: Box C (Distractor 1)")
                    else:
                        print("No clear selection")
                    print()
            except:
                pass
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Average Loss: {np.mean(epoch_losses):.4f}")
        print(f"Average Reward: {np.mean(epoch_rewards):.4f}")
        
        # Save speaker checkpoint
        speaker_model.save_pretrained(f"llava_speaker_dual_checkpoint_epoch")
train_dual_llava()

# Preference Learning

In [6]:
import torch
from torch.distributions import Categorical
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset, Dataset
import copy
import re

class DualLLaVATrainer:
    def __init__(
        self, 
        speaker_model, 
        listener_model, 
        processor, 
        learning_rate=1e-5, 
        gamma=0.99, 
        epsilon=0.2, 
        c1=1, 
        c2=0.01
    ):
        self.speaker_model = speaker_model
        self.listener_model = listener_model
        self.processor = processor
        self.optimizer = torch.optim.Adam(speaker_model.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.c1 = c1
        self.c2 = c2
        
        # Freeze listener model
        for param in self.listener_model.parameters():
            param.requires_grad = False
        self.listener_model.eval()
        
    def calculate_iou(self, pred_box, gt_box):
        """Calculate IoU between predicted and ground truth boxes"""
        if isinstance(gt_box, str):
            gt_box = eval(gt_box)
        
        pred_x1, pred_y1, pred_x2, pred_y2 = pred_box
        gt_x1, gt_y1, gt_x2, gt_y2 = gt_box
        
        x1 = max(pred_x1, gt_x1)
        y1 = max(pred_y1, gt_y1)
        x2 = min(pred_x2, gt_x2)
        y2 = min(pred_y2, gt_y2)
        
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        
        pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
        gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
        union = pred_area + gt_area - intersection
        
        return intersection / (union + 1e-6)
    
    def get_bbox_from_output(self, output_text):
        """Extract bounding box coordinates from model output"""
        try:
            coords = re.findall(r'\[([\d\.,\s]+)\]', output_text)
            if coords:
                return [float(x) for x in coords[-1].split(',')]
            return None
        except:
            return None
    
    def get_listener_predictions(self, images, speaker_messages, batch):
        """Get bounding box predictions from listener model by selecting from available boxes"""
        boxes = [
            batch[0]['listener_target_bbox'],
            batch[0]['listener_distractor_0_bbox'],
            batch[0]['listener_distractor_1_bbox']
        ]
        # Format the boxes as choices
        box_choices = f"""
Boxes:
A={boxes[0]} B={boxes[1]} C={boxes[2]}
Description: {speaker_messages[0]}
Choose box A, B, or C.
"""
        # Get speaker outputs (old policy)
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": box_choices},
                    {"type": "image"},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        # Assuming images are pre-loaded; replace 'path/to/image' with actual image tensors if necessary
        inputs = self.processor(images=images[0], text=prompt, return_tensors='pt').to(self.listener_model.device)
        # Generate predictions
        with torch.no_grad():
            outputs = self.listener_model.generate(**inputs, max_length=300)
            texts = self.processor.batch_decode(outputs, skip_special_tokens=True)
            
        # Parse selections and convert to boxes
        selected_boxes = []
        for i, text in enumerate(texts):
            # Extract the selection (A, B, or C) using regex for robustness
            match = re.search(r'\b(A|B|C)\b', text.upper())
            if match:
                choice = match.group(1)
                if choice == 'A':
                    selected_boxes.append(eval(batch[i]['listener_target_bbox']))
                elif choice == 'B':
                    selected_boxes.append(eval(batch[i]['listener_distractor_0_bbox']))
                elif choice == 'C':
                    selected_boxes.append(eval(batch[i]['listener_distractor_1_bbox']))
            else:
                # Default to distractor 0 if no clear selection
                selected_boxes.append(eval(batch[i]['listener_distractor_0_bbox']))
        
        return selected_boxes
    
    def compute_rewards_ppl(self, batch, chosen_boxes):
        """Compute rewards based on Pairwise Preference Learning"""
        rewards = []
        for i, (chosen_box, example) in enumerate(zip(chosen_boxes, batch)):
            target_box = eval(example['listener_target_bbox'])
            intended_target = target_box
            chosen_target = chosen_box
            
            if chosen_target == intended_target:
                # Successful communication
                rewards.append(1.0)
            else:
                # Communicative failure: reward is p(x|chosen) - p(x|intended)
                # Compute p_speaker(refex | scene, referents, chosen)
                # and p_speaker(refex | scene, referents, intended)
                # Assuming we have a method to compute these probabilities
                # For simplicity, using placeholder probabilities
                # Replace with actual probability computations
                # Example:
                # p_chosen = compute_probability(example, chosen_target)
                # p_intended = compute_probability(example, intended_target)
                # reward = p_chosen - p_intended
                # Here, we'll use dummy values
                p_chosen = 0.6  # Placeholder
                p_intended = 0.4  # Placeholder
                reward = p_chosen - p_intended
                rewards.append(reward)
        
        return torch.tensor(rewards, device=self.speaker_model.device)
    
    def compute_rewards(self, listener_boxes, batch):
        """Compute rewards based on listener's box predictions using PPL"""
        rewards = []
        for pred, example in zip(listener_boxes, batch):
            if pred is None:
                rewards.append(-1.0)
            else:
                target_box = eval(example['listener_target_bbox'])
                chosen_box = pred
                if chosen_box == target_box:
                    rewards.append(1.0)
                else:
                    # Compute p_speaker(x|chosen) - p_speaker(x|intended)
                    # Placeholder for actual computation
                    # Replace with actual model probability computations
                    # For demonstration, using a fixed difference
                    p_chosen = 0.7  # Example probability for chosen target
                    p_intended = 0.3  # Example probability for intended target
                    reward = p_chosen - p_intended
                    rewards.append(reward)
        return torch.tensor(rewards, device=self.speaker_model.device)
    
    def train_step(self, batch):
        """Single training step with PPL"""
        speaker_images = [item['speaker_view_image'] for item in batch]
        listener_images = [item['listener_view_image'] for item in batch]
        gt_boxes = [item['listener_target_bbox'] for item in batch]
        
        # Get speaker outputs (old policy)
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please describe the target object."},
                    {"type": "image"},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        speaker_inputs = self.processor(images=speaker_images, text=prompt, return_tensors='pt', padding=True).to(self.speaker_model.device)
        
        # Generate referring expressions
        with torch.no_grad():
            old_outputs = self.speaker_model.generate(
                **speaker_inputs,
                max_length=100,
                num_beams=1,
                do_sample=True,
                temperature=0.7
            )
            old_messages = self.processor.batch_decode(
                old_outputs, 
                skip_special_tokens=True
            )
            old_messages = [msg.split('ASSISTANT: ')[-1].strip() for msg in old_messages]
        
        # Get listener predictions
        listener_boxes = self.get_listener_predictions(listener_images, old_messages, batch)
        
        # Compute rewards using PPL
        rewards = self.compute_rewards(listener_boxes, batch).detach()
        
        # Encode the inputs again for gradient computation
        speaker_inputs = self.processor(images=speaker_images, text=prompt, return_tensors='pt', padding=True).to(self.speaker_model.device)
        
        # Forward pass to get logits
        outputs = self.speaker_model(**speaker_inputs, labels=old_outputs)
        log_probs = -F.cross_entropy(outputs.logits.view(-1, outputs.logits.size(-1)), old_outputs.view(-1), reduction='none')
        log_probs = log_probs.view(old_outputs.size())  # Reshape to match outputs
        
        # Mask padding tokens
        attention_mask = speaker_inputs['attention_mask']
        log_probs = (log_probs * attention_mask).sum(dim=1)
        
        # Compute policy ratio
        ratios = torch.exp(log_probs - log_probs.detach())
        
        # Compute surrogate losses
        surrogate1 = ratios * rewards
        surrogate2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * rewards
        policy_loss = -torch.min(surrogate1, surrogate2).mean()
        
        # Value loss (optional, can be implemented if using a value network)
        # Here, we'll omit it for simplicity
        # value_loss = ...
        
        # Entropy loss for exploration
        # Assuming logits are available
        entropy = Categorical(logits=outputs.logits).entropy().mean()
        entropy_loss = -entropy
        
        # Total loss
        total_loss = policy_loss + self.c1 * 0 + self.c2 * entropy_loss  # Assuming no value loss
        
        # Backpropagation
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return total_loss.item(), rewards.mean().item()

def train_dual_llava(num_epochs=10, batch_size=1):
    # Initialize models and processor
    trainer = DualLLaVATrainer(speaker_model, listener_model, processor)
    
    # Training loop
    for epoch in range(num_epochs):
        epoch_losses = []
        epoch_rewards = []
        
        for i in tqdm(range(0, len(dataset), batch_size)):
            try:
                batch = [dataset[index] for index in range(i, min(i + batch_size, len(dataset)))]
                loss, reward = trainer.train_step(batch)
                epoch_losses.append(loss)
                epoch_rewards.append(reward)
                
                # Print examples periodically
                if i % (batch_size * 1) == 0:
                    print("\nExample outputs:")
                    # Get speaker message
                    conversation = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "Please describe the target object."},
                                {"type": "image"},
                            ],
                        },
                    ]
                    prompt = processor.apply_chat_template(
                        conversation, 
                        add_generation_prompt=True,
                        max_length=200  # Limit input length
                    )
                    speaker_input = processor(
                        images=batch[0]['speaker_view_image'], 
                        text=prompt, 
                        return_tensors='pt',
                        max_length=200,  # Limit input length
                        truncation=True,
                        padding=True
                    ).to(trainer.speaker_model.device)
                    
                    generated_outputs = trainer.speaker_model.generate(
                        **speaker_input, 
                        max_length=100,
                        temperature=0.7
                    )
                    generated_message = processor.batch_decode(
                        generated_outputs,
                        skip_special_tokens=True
                    )[0]
                    generated_message = generated_message.split('ASSISTANT: ')[-1].strip()[:100]
                    
                    # Get listener prediction with bounding box choices
                    boxes = [
                        batch[0]['listener_target_bbox'],
                        batch[0]['listener_distractor_0_bbox'],
                        batch[0]['listener_distractor_1_bbox']
                    ]
                    box_choices = f"""Boxes:
A={boxes[0]} B={boxes[1]} C={boxes[2]}
Description: {generated_message}
Choose box A, B, or C."""
                    
                    conversation = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": box_choices},
                                {"type": "image"},
                            ],
                        },
                    ]
                    prompt = processor.apply_chat_template(
                        conversation, 
                        add_generation_prompt=True,
                        max_length=300  # Limit input length
                    )
                    listener_input = processor(
                        images=batch[0]['listener_view_image'], 
                        text=prompt, 
                        return_tensors='pt',
                        max_length=300,  # Limit input length
                        truncation=True,
                        padding=True
                    ).to(trainer.listener_model.device)
                    
                    listener_output = trainer.listener_model.generate(
                        **listener_input, 
                        max_length=300
                    )
                    listener_response = processor.batch_decode(
                        listener_output,
                        skip_special_tokens=True
                    )[0]
                    
                    print(f"Speaker Message: {generated_message}")
                    print(f"Listener Response: {listener_response}")
                    print(f"Ground Truth Box: {batch[0]['listener_target_bbox']}")
                    if 'A' in listener_response.upper():
                        print("Selected: Box A (Target)")
                    elif 'B' in listener_response.upper():
                        print("Selected: Box B (Distractor 0)")
                    elif 'C' in listener_response.upper():
                        print("Selected: Box C (Distractor 1)")
                    else:
                        print("No clear selection")
                    print()
            except Exception as e:
                print(f"Error during training step: {e}")
                continue
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Average Loss: {np.mean(epoch_losses):.4f}")
        print(f"Average Reward: {np.mean(epoch_rewards):.4f}")
        
        # Save speaker checkpoint
        speaker_model.save_pretrained(f"llava_speaker_ppl_checkpoint_epoch_{epoch+1}")

# Example Usage
# Ensure that the paths to the processor and models are correctly specified
train_dual_llava(num_epochs=5, batch_size=2)


100%|██████████| 1/1 [00:00<00:00, 48.16it/s]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Error during training step: The image to be converted to a PIL image contains values outside the range [0, 1], got [-4.688628196716309, 4.558362007141113] which cannot be converted to uint8.
Epoch 1/5
Average Loss: nan
Average Reward: nan


100%|██████████| 1/1 [00:00<00:00, 332.01it/s]


Error during training step: The image to be converted to a PIL image contains values outside the range [0, 1], got [-4.688628196716309, 4.558362007141113] which cannot be converted to uint8.
Epoch 2/5
Average Loss: nan
Average Reward: nan


100%|██████████| 1/1 [00:00<00:00, 321.92it/s]


Error during training step: The image to be converted to a PIL image contains values outside the range [0, 1], got [-4.688628196716309, 4.558362007141113] which cannot be converted to uint8.
Epoch 3/5
Average Loss: nan
Average Reward: nan


100%|██████████| 1/1 [00:00<00:00, 327.32it/s]


Error during training step: The image to be converted to a PIL image contains values outside the range [0, 1], got [-4.688628196716309, 4.558362007141113] which cannot be converted to uint8.
Epoch 4/5
Average Loss: nan
Average Reward: nan


100%|██████████| 1/1 [00:00<00:00, 311.17it/s]


Error during training step: The image to be converted to a PIL image contains values outside the range [0, 1], got [-4.688628196716309, 4.558362007141113] which cannot be converted to uint8.
Epoch 5/5
Average Loss: nan
Average Reward: nan


# Generative Finetuning

In [None]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from tqdm import tqdm
from datasets import load_dataset

model_name = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16
).to(device).to(torch.float16)
processor = AutoProcessor.from_pretrained(model_name)
rephrase = False
# dataset = load you own dataset


In [None]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from tqdm import tqdm
from datasets import load_dataset
import numpy as np

from openai import OpenAI
# client = OpenAI(api_key='your-key')

def setup_lora_model():
    # Initialize base model
    model_name = "llava-hf/llava-1.5-7b-hf"
    base_model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float16
    ).to('cuda:7')
    processor = AutoProcessor.from_pretrained(model_name)
    
    # Define LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,  # rank of LoRA update matrices
        lora_alpha=32,  # alpha scaling factor
        lora_dropout=0.1,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
        ],
    )
    
    # Create PEFT model
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()
    
    return model, processor

class LoRASpeakerTrainer:
    def __init__(
        self,
        model,
        processor,
        learning_rate=1e-4  # Higher learning rate for LoRA
    ):
        self.model = model
        self.processor = processor
        # Only optimize LoRA parameters
        self.optimizer = torch.optim.AdamW(
            [p for n, p in model.named_parameters() if "lora" in n.lower()],
            lr=learning_rate
        )
        
    def train_step(self, batch):
        """Single training step with LoRA parameters"""
        # Get speaker image and reference message
        speaker_image = batch[0]['speaker_view_image']

        reference_message = batch[0]['human_speaker_message']
        if rephrase:
            chat_completion = client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": f"rephrase the below message: {reference_message}",
                    }
                ],
                model="gpt-3.5-turbo",
            )
            reference_message = chat_completion.choices[0].message.content

        # Prepare conversation prompt
        init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": init_prompt_instruct},
                    {"type": "image"},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        
        # Process inputs
        inputs = self.processor(
            images=speaker_image,
            text=prompt,
            return_tensors='pt',
            max_length=200,
            truncation=True
        ).to(self.model.device)
        
        # Create labels
        labels = inputs["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        
        # Convert inputs to bfloat16 while preserving input_ids and attention_mask
        inputs = {
            k: v.to(torch.float16) if k not in ['input_ids', 'attention_mask'] else v
            for k, v in inputs.items()
        }
        inputs["labels"] = labels
        
        with torch.autocast("cuda:7", dtype=torch.float16):
            outputs = self.model(**inputs)
            loss = outputs.loss
        
        # Backward pass and optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        torch.cuda.empty_cache()
        
        # Generate prediction for monitoring
        with torch.no_grad():
            generated = self.model.generate(
                **inputs,
                max_length=100,
                num_beams=1,
                do_sample=True,
                temperature=0.7
            )
            generated_message = self.processor.batch_decode(
                generated,
                skip_special_tokens=True
            )[0].split('ASSISTANT: ')[-1][:100]
        
        return loss.item(), reference_message, generated_message

def train_speaker_lora(num_epochs=10):
    # Initialize LoRA model and trainer
    model, processor = setup_lora_model()
    trainer = LoRASpeakerTrainer(model, processor)
    
    # Load dataset
    dataset = load_dataset("ZinengTang/PersReFex", split="validation")
    dataset = dataset.select(range(100))
    
    # Training loop
    best_loss = float('inf')
    for epoch in range(num_epochs):
        epoch_losses = []
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        for i in tqdm(range(len(dataset))):
            try:
                # Get single example
                batch = [dataset[i]]
                
                # Training step
                loss, reference, generated = trainer.train_step(batch)
                epoch_losses.append(loss)
                
                # Print examples periodically
                if i % 10 == 0:
                    print(f"\nStep {i}")
                    print(f"Loss: {loss:.4f}")
                    print(f"Reference: {reference}")
                    print(f"Generated: {generated}")
                    print("-" * 50)
            
            except Exception as e:
                print(f"Error in batch {i}: {str(e)}")
                continue
        
        # Epoch summary
        avg_loss = np.mean(epoch_losses)
        print(f"\nEpoch {epoch+1} Average Loss: {avg_loss:.4f}")
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            model.save_pretrained(f"llava_speaker_lora_best_model")
        
        # Regular checkpoint
        if (epoch + 1) % 5 == 0:
            model.save_pretrained(f"llava_speaker_lora_checkpoint_epoch_{epoch+1}")

if __name__ == "__main__":
    train_speaker_lora()

# Testing/ Be sure the replace the model

In [None]:
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
# Load base model
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Load LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    "ZinengTang/llava-lora-spatial"
).to(device)


In [None]:
from PIL import Image
init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # This will be replaced with the actual image
        ],
    },
]
speaker_image = Image.open('/home/terran/projects/spatial/vlsim/source/embodied/final_data/output_data_1/images/0/speaker.jpg')
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# print(prompt)
# Process the input image and prompt
inputs = processor(
    images=speaker_image,
    text=prompt,
    return_tensors="pt",
    max_length=256,
).to('cuda:7')

with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        max_length=512,
        num_beams=1,
        do_sample=True,
        temperature=0.7
    )
    generated_message = processor.batch_decode(
        generated, 
        skip_special_tokens=True
    )
    print(generated_message)
    generated_message = generated_message[0].split('ASSISTANT: ')[-1][:100]
