<a href="https://colab.research.google.com/github/sayanbanerjee32/multimodal_llm/blob/main/phi_3_QLoRA_instruct150k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers==4.44.2
!pip install -Uq accelerate peft bitsandbytes trl dataset bitsandbytes
# !pip install -Uq flash_attn



In [None]:
import numpy as np
import requests
from tqdm import tqdm
import os, gc
import subprocess
import json
import random
### Download Phi-3 model
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, PreTrainedModel
from transformers.trainer_callback import TrainerCallback
from datasets import Dataset, DatasetDict
import joblib
from huggingface_hub import hf_hub_download

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

In [None]:
# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


### Download Phi-3 model

In [None]:
# Load the Phi-3 model and tokenizer
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="right",
                                           trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

### Downlaod image embedding

In [None]:
from google.colab import drive
# Mount Google Drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# URL of the embeddings file (replace with your actual URL)
embeddings_url = '/content/drive/MyDrive/multimodel_llm/image_embedding/coco_image_embeddings.npz'

# Load the embeddings
print("Loading embeddings...")
embeddings = np.load(embeddings_url, allow_pickle=True)

# Print embeddings and image names
for image_name, embedding in embeddings.items():
    print(f"Image: {image_name}")
    print(f"Embedding shape: {embedding.shape}")
    print(f"Embedding preview: {embedding[:5]}...")  # Print first 5 values
    print("-" * 50)
    break

print(f"Total number of embeddings: {len(embeddings)}")

Loading embeddings...
Image: 000000401144.jpg
Embedding shape: (512,)
Embedding preview: [-0.13     0.1564   0.02017  0.1678   0.2393 ]...
--------------------------------------------------
Total number of embeddings: 81479


### Data processing

In [None]:
# List of URLs to download
url = "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_instruct_150k.json"

# Download each file
subprocess.run(["wget", "-c", url])

# Load the downloaded JSON file
json_file = "llava_instruct_150k.json"
with open(json_file, 'r') as f:
    data = json.load(f)

# Function to convert conversation format
# def convert_conversation(conversation):
#     system_message = "<|system|>\nYou are a helpful assistant.<|end|>\n"
#     user_message = ""
#     assistant_message = ""

#     for item in conversation:
#         if item['from'] == 'human':
#             user_message = f"<|user|>\n{item['value']}<|end|>\n"
#         elif item['from'] == 'gpt':
#             assistant_message = f"<|assistant|>\n{item['value']}<|end|>\n"

#     return system_message + user_message + assistant_message

# Process and tokenize the data
# Process and tokenize the data


# tokenized_data = []
# for item in tqdm(data, desc="Tokenizing data"):
#     image_file = item['image']
#     if image_file in embeddings:
#         image_embedding = torch.tensor(embeddings[image_file], dtype=torch.float32, device=device)

#         # Use the existing conversation format
#         conversation = [
#             {"role": "system", "content": "You are a helpful assistant."}
#         ] + [{"role": "user" if msg['from'] == 'human' else "assistant", "content": msg['value']}
#              for msg in item['conversations']]

#         # Apply chat template and tokenize directly
#         tokenized_conversation = tokenizer.apply_chat_template(conversation, return_tensors='pt').to(device)

#         tokenized_item = {
#             'image': image_file,
#             'image_embedding': image_embedding,
#             'tokenized_conversation': tokenized_conversation
#         }
#         tokenized_data.append(tokenized_item)

# print(f"Total tokenized items: {len(tokenized_data)}")
# print(f"Sample tokenized item:")
# print(f"Image: {tokenized_data[0]['image']}")
# print(f"Image embedding shape: {tokenized_data[0]['image_embedding'].shape}")
# print(f"Tokenized conversation shape: {tokenized_data[0]['tokenized_conversation'].shape}")
# print(f"Image embedding device: {tokenized_data[0]['image_embedding'].device}")
# print(f"Tokenized conversation device: {tokenized_data[0]['tokenized_conversation'].device}")

In [None]:
# data = data[0:100]

In [None]:
def create_dataset():
    processed_data = []
    print("Processing data...")
    with tqdm(total=len(data)) as pbar:
        for item in data:
            image_file = item['image']
            if image_file in embeddings:
                processed_data.append({
                    'image': image_file,
                    'image_embedding': embeddings[image_file].tolist(),
                    'conversation': item['conversations']
                })
            pbar.update(1)

    print(f"Data processing completed. Total processed items: {len(processed_data)}")

    return Dataset.from_dict({
        "image": [item['image'] for item in processed_data],
        "image_embedding": [item['image_embedding'] for item in processed_data],
        "conversation": [item['conversation'] for item in processed_data]
    })

print("Creating HuggingFace dataset...")
hf_dataset = create_dataset()

print("HuggingFace dataset creation completed.")
print(f"Total samples in dataset: {len(hf_dataset)}")



Creating HuggingFace dataset...
Processing data...


100%|██████████| 157712/157712 [12:31<00:00, 209.99it/s]


Data processing completed. Total processed items: 157712
HuggingFace dataset creation completed.
Total samples in dataset: 157712


In [None]:
# print("Applying tokenization and preparing the dataset...")

# def prepare_dataset(examples):
#     image_embeddings = torch.stack([torch.tensor(item) for item in examples['image_embedding']])

#     conversations = []
#     for conv in examples['conversation']:
#         dialogue = [{"role": "system", "content": "You are a helpful assistant."}]

#         for i, message in enumerate(conv):
#             if message['from'] == 'human':
#                 content = message['value'].replace('<image>', '').strip()  # Remove '<image>' and strip whitespace
#                 if i == 0:
#                     content = f"Given the following information, provide a detailed and accurate response:\n{content}\n[An image is provided for this task.]\n"
#                 dialogue.append({"role": "user", "content": content})
#             elif message['from'] == 'gpt':
#                 dialogue.append({"role": "assistant", "content": message['value']})

#         conversations.append(dialogue)

#     tokenized_conversations = tokenizer.apply_chat_template(conversations,
#                                                             return_tensors='pt', padding=True)

#     return {
#         "image_embeddings": image_embeddings,
#         "input_ids": tokenized_conversations,
#         "attention_mask": torch.ones_like(tokenized_conversations),
#         "labels": tokenized_conversations.clone()
#     }

def prepare_dataset(examples):
    image_embeddings = []
    conversations = []

    for idx, conv in enumerate(examples['conversation']):
        image_embedding = torch.tensor(examples['image_embedding'][idx])
        dialogue_pairs = []

        for i in range(0, len(conv), 2):
            if i + 1 < len(conv):  # Ensure we have a pair
                human_msg = conv[i]['value'].replace('<image>', '').strip()
                gpt_msg = conv[i + 1]['value']

                dialogue = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": f"Given the following information, provide a detailed and accurate response:\n{human_msg}\n[An image is provided for this task.]\n"},
                    {"role": "assistant", "content": gpt_msg}
                ]

                dialogue_pairs.append(dialogue)
                image_embeddings.append(image_embedding)

        conversations.extend(dialogue_pairs)

    image_embeddings = torch.stack(image_embeddings)

    tokenized_conversations = tokenizer.apply_chat_template(conversations,
                                                            return_tensors='pt', padding=True)

    return {
        "image_embeddings": image_embeddings,
        "input_ids": tokenized_conversations,
        "attention_mask": torch.ones_like(tokenized_conversations),
        "labels": tokenized_conversations.clone()
    }

# Test the prepare_dataset function with a real training example
def test_prepare_dataset():
    # Get a batch of examples from the dataset
    batch_size = 1  # You can adjust this as needed
    sample_batch = hf_dataset[5:5+batch_size]

    print("Original conversations:")
    # for i, sample in enumerate(sample_batch):
    #     print(f"\nSample {i + 1}:")
    for message in sample_batch['conversation'][0]:
        print(f"{message['from']}: {message['value']}")

    # Process the sample batch
    result = prepare_dataset(sample_batch)

    # Print the structure of the result
    print("\nResult keys:", result.keys())
    print("Image embeddings shape:", result['image_embeddings'].shape)
    print("Input IDs shape:", result['input_ids'].shape)
    print("Attention mask shape:", result['attention_mask'].shape)
    print("Labels shape:", result['labels'].shape)

    for i in range(batch_size):
        decoded_input = tokenizer.decode(result['input_ids'][i])
        decoded_labels = tokenizer.decode(result['labels'][i])

        print(f"\nRestructured input for sample {i + 1}:")
        print(decoded_input)

        print(f"\nLabels for sample {i + 1}:")
        print(decoded_labels)

        # Optionally, you can print a more readable version of the labels
        print("\nReadable labels (non-padding tokens):")
        readable_labels = tokenizer.decode([token for token in result['labels'][i] if token != -100])
        print(readable_labels)

    # Optionally, you can print attention mask to see where it's applied
    print("\nAttention Mask:")
    print(result['attention_mask'][0])
# Run the test
test_prepare_dataset()


Original conversations:
human: <image>
What is the girl eating in the image?
gpt: The girl in the image is eating a dessert, which appears to be a graham cracker treat or a cookie sandwich.
human: Describe the girl's hair color and clothing.
gpt: The girl has blonde hair, and she is wearing a pink shirt.
human: What color is the plate that the dessert is on?
gpt: The dessert is on a green plate.
human: Is the girl looking at the camera or focusing on her dessert?
gpt: The girl is looking up at the camera while taking a bite of her dessert.
human: Where is the girl eating her dessert?
gpt: The girl is eating her dessert at the table.

Result keys: dict_keys(['image_embeddings', 'input_ids', 'attention_mask', 'labels'])
Image embeddings shape: torch.Size([5, 512])
Input IDs shape: torch.Size([5, 75])
Attention mask shape: torch.Size([5, 75])
Labels shape: torch.Size([5, 75])

Restructured input for sample 1:
<|system|> You are a helpful assistant.<|end|><|user|> Given the following infor

In [None]:
# Apply tokenization and prepare the dataset
print("Applying tokenization and preparing the dataset...")


# def prepare_dataset(examples):
#     image_embeddings = torch.stack([torch.tensor(item) for item in examples['image_embedding']])

#     conversations = [
#         [{"role": "system", "content": "You are a helpful assistant."}] +
#         [{"role": "user" if msg['from'] == 'human' else "assistant", "content": msg['value']}
#          for msg in conv]
#         for conv in examples['conversation']
#     ]

#     tokenized_conversations = tokenizer.apply_chat_template(conversations,
#                                                              return_tensors='pt', padding=True)

#     return {
#         "image_embeddings": image_embeddings,
#         "input_ids": tokenized_conversations,
#         "attention_mask": torch.ones_like(tokenized_conversations),
#         "labels": tokenized_conversations.clone()
#     }


hf_dataset_mapped = hf_dataset.map(
    prepare_dataset,
    batched=True,
    remove_columns=hf_dataset.column_names,
    batch_size=1024  # Adjust based on your memory constraints
).with_format("torch")

# Split the dataset
train_test_split = hf_dataset_mapped.train_test_split(test_size=0.05)

# Create a DatasetDict
dataset_dict = DatasetDict({
    'train': train_test_split['train'],
    'test': train_test_split['test']
})

print(f"Train dataset size: {len(dataset_dict['train'])}")
print(f"Test dataset size: {len(dataset_dict['test'])}")

Applying tokenization and preparing the dataset...


Map:   0%|          | 0/157712 [00:00<?, ? examples/s]

Train dataset size: 325269
Test dataset size: 36142


In [None]:
# Example of accessing an item:
sample = dataset_dict['train'][0]
print(f"Input IDs shape: {len(sample['input_ids'])}")
print(f"Attention mask shape: {len(sample['attention_mask'])}")
print(f"Labels shape: {len(sample['labels'])}")

Input IDs shape: 754
Attention mask shape: 754
Labels shape: 754


### Projection Layer

In [None]:
### Projection Layer
# class SimpleResBlock(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
#         self.pre_norm = nn.LayerNorm(input_dim)
#         self.proj = nn.Sequential(
#             nn.Linear(input_dim, output_dim),
#             nn.GELU(),
#             nn.Linear(output_dim, output_dim)
#         )

#     def forward(self, x):
#         x = self.pre_norm(x)
#         return x + self.proj(x)

# class ImageProjector(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
#         self.proj = nn.Linear(input_dim, output_dim)

#     def forward(self, x):
#         return self.proj(x)

In [None]:
# import torch
# from torch.utils.data import Dataset, random_split

# class Phi3Dataset(Dataset):
#     def __init__(self, tokenized_data, projector, tokenizer):
#         self.data = tokenized_data
#         self.projector = projector
#         self.tokenizer = tokenizer

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         item = self.data[idx]
#         image_embedding = item['image_embedding']
#         conversation = item['tokenized_conversation']

#         projected_image = self.projector(image_embedding.unsqueeze(0)).squeeze(0)

#         # Combine projected image and conversation
#         combined_input = torch.cat([projected_image, conversation.squeeze(0)])

#         # Create attention mask
#         attention_mask = torch.ones_like(combined_input)

#         # Prepare labels (shift right, set first token to -100)
#         labels = torch.cat([-100 * torch.ones(projected_image.shape[0], dtype=torch.long), conversation.squeeze(0)[:-1]])

#         return {
#             "input_ids": combined_input,
#             "attention_mask": attention_mask,
#             "labels": labels
#         }

# # Usage:
# image_embedding_dim = tokenized_data[0]['image_embedding'].shape[-1]
# projection_dim = 1024  # Adjust as needed
# projector = SimpleResBlock(image_embedding_dim, projection_dim)
# full_dataset = Phi3Dataset(tokenized_data, projector, tokenizer)

# # Split the dataset
# train_size = int(0.9 * len(full_dataset))
# test_size = len(full_dataset) - train_size
# train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# print(f"Train dataset size: {len(train_dataset)}")
# print(f"Test dataset size: {len(test_dataset)}")

# # Example of accessing an item:
# sample = train_dataset[0]
# print(f"Input IDs shape: {sample['input_ids'].shape}")
# print(f"Attention mask shape: {sample['attention_mask'].shape}")
# print(f"Labels shape: {sample['labels'].shape}")


### QLoRA set up

In [None]:
# new_model = "ms-phi3-custom"
lora_r = 32 #64
lora_alpha = 16
lora_dropout = 0.05
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 1
fp16 = False
bf16 = False
per_device_train_batch_size = 8
per_device_eval_batch_size = 4
gradient_accumulation_steps = 8
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 5e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 25
eval_steps = 50 # Evaluate every 25 steps
max_seq_length = 256
packing = False
device_map = {"": 0}

In [None]:
# import torch
# from transformers import PreTrainedModel

# class Phi3WithProjector(PreTrainedModel):
#     supports_gradient_checkpointing = True

#     def __init__(self, phi3_model, projector):
#         super().__init__(phi3_model.config)
#         self.phi3 = phi3_model
#         self.projector = projector

#     def forward(self, input_ids=None, attention_mask=None, image_embeddings=None, labels=None, **kwargs):
#         device = next(self.parameters()).device

#         if image_embeddings is not None:
#             image_embeddings = image_embeddings.to(device)
#             projected_images = self.projector(image_embeddings)
#             projected_images = projected_images.unsqueeze(1)

#             if 'inputs_embeds' in kwargs and kwargs['inputs_embeds'] is not None:
#                 inputs_embeds = kwargs['inputs_embeds']
#                 inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
#                 kwargs['inputs_embeds'] = inputs_embeds
#             elif input_ids is not None:
#                 inputs_embeds = self.get_input_embeddings()(input_ids.to(device))
#                 inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
#                 kwargs['inputs_embeds'] = inputs_embeds
#                 input_ids = None  # Set to None to avoid conflict

#             if attention_mask is not None:
#                 attention_mask = torch.cat([torch.ones(image_embeddings.size(0), 1, device=device), attention_mask.to(device)], dim=1)
#             else:
#                 attention_mask = torch.ones(image_embeddings.size(0), inputs_embeds.size(1), device=device)

#             if labels is not None:
#                 # Adjust labels to match the new sequence length
#                 labels = torch.cat([torch.full((labels.size(0), 1), -100, device=device), labels], dim=1)

#         if labels is not None:
#             labels = labels.to(device)

#         # Ensure attention_mask matches the sequence length
#         if 'inputs_embeds' in kwargs:
#             seq_length = kwargs['inputs_embeds'].size(1)
#         elif input_ids is not None:
#             seq_length = input_ids.size(1)
#         else:
#             raise ValueError("Either input_ids or inputs_embeds should be provided")

#         if attention_mask is not None and attention_mask.size(1) != seq_length:
#             attention_mask = attention_mask[:, :seq_length]

#         return self.phi3(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

#     def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
#         inputs = self.phi3.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask, **kwargs)
#         if 'image_embeddings' in kwargs:
#             inputs['image_embeddings'] = kwargs['image_embeddings']

#             # Adjust attention_mask if it's present
#             if attention_mask is not None:
#                 inputs['attention_mask'] = torch.cat([torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask], dim=1)

#             # Remove position_ids as they're not used
#             inputs.pop('position_ids', None)

#         return inputs

#     def get_input_embeddings(self):
#         return self.phi3.get_input_embeddings()

#     def set_input_embeddings(self, value):
#         self.phi3.set_input_embeddings(value)

#     def gradient_checkpointing_enable(self, **kwargs):
#         self.phi3.gradient_checkpointing_enable(**kwargs)

#     def gradient_checkpointing_disable(self):
#         self.phi3.gradient_checkpointing_disable()

#     def __getattr__(self, name):
#         try:
#             return super().__getattr__(name)
#         except AttributeError:
#             return getattr(self.phi3, name)

#     def generate(self, input_ids=None, attention_mask=None, image_embeddings=None, **kwargs):
#         if image_embeddings is not None:
#             kwargs['image_embeddings'] = image_embeddings

#         if attention_mask is not None and image_embeddings is not None:
#             # Add an extra attention mask token for the image embedding
#             attention_mask = torch.cat([torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask], dim=1)

#         return super().generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [None]:
# # import os
# # import torch
# from transformers import PreTrainedModel

# class ImageProjector(torch.nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
#         self.linear = torch.nn.Linear(input_dim, output_dim)

#     def forward(self, x):
#         return self.linear(x)

# class Phi3WithProjector(PreTrainedModel):
#     supports_gradient_checkpointing = True

#     def __init__(self, phi3_model, projector):
#         super().__init__(phi3_model.config)
#         self.phi3 = phi3_model
#         self.projector = projector

#     @classmethod
#     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
#         # Load the base Phi-3 model
#         phi3_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

#         # Load the projector weights
#         projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
#         if os.path.exists(projector_path):
#             projector_state_dict = torch.load(projector_path, map_location=phi3_model.device)

#             # Check if the state dict has the expected structure
#             if 'linear.weight' in projector_state_dict:
#                 input_dim = projector_state_dict['linear.weight'].size(1)
#                 output_dim = projector_state_dict['linear.weight'].size(0)
#             else:
#                 # If not, try to infer dimensions from the first layer's weight
#                 first_key = next(iter(projector_state_dict))
#                 input_dim = projector_state_dict[first_key].size(1)
#                 output_dim = phi3_model.config.hidden_size  # Assuming this is the correct output dimension

#             projector = ImageProjector(input_dim, output_dim)

#             # Try to load the state dict, ignoring mismatched keys
#             projector.load_state_dict(projector_state_dict, strict=False)
#             print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
#         else:
#             print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")
#             input_dim = 512  # Default CLIP embedding size
#             output_dim = phi3_model.config.hidden_size
#             projector = ImageProjector(input_dim, output_dim)

#         # Create and return the Phi3WithProjector instance
#         model = cls(phi3_model, projector)
#         return model

#     def save_pretrained(self, save_directory):
#         # Save the base model
#         self.phi3.save_pretrained(save_directory)

#         # Save the projector weights
#         projector_path = os.path.join(save_directory, "image_projector.pth")
#         torch.save(self.projector.state_dict(), projector_path)

#         # Save the config
#         self.config.save_pretrained(save_directory)

#     def forward(self, input_ids=None, attention_mask=None, image_embeddings=None, labels=None, **kwargs):
#         device = next(self.parameters()).device

#         if image_embeddings is not None:
#             image_embeddings = image_embeddings.to(device)
#             projected_images = self.projector(image_embeddings)
#             projected_images = projected_images.unsqueeze(1)

#             if 'inputs_embeds' in kwargs and kwargs['inputs_embeds'] is not None:
#                 inputs_embeds = kwargs['inputs_embeds']
#                 inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
#                 kwargs['inputs_embeds'] = inputs_embeds
#             elif input_ids is not None:
#                 inputs_embeds = self.get_input_embeddings()(input_ids.to(device))
#                 inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
#                 kwargs['inputs_embeds'] = inputs_embeds
#                 input_ids = None  # Set to None to avoid conflict

#             if attention_mask is not None:
#                 attention_mask = torch.cat([torch.ones(image_embeddings.size(0), 1, device=device), attention_mask.to(device)], dim=1)
#             else:
#                 attention_mask = torch.ones(image_embeddings.size(0), inputs_embeds.size(1), device=device)

#             if labels is not None:
#                 # Adjust labels to match the new sequence length
#                 labels = torch.cat([torch.full((labels.size(0), 1), -100, device=device), labels], dim=1)

#         if labels is not None:
#             labels = labels.to(device)

#         # Ensure attention_mask matches the sequence length
#         if 'inputs_embeds' in kwargs:
#             seq_length = kwargs['inputs_embeds'].size(1)
#         elif input_ids is not None:
#             seq_length = input_ids.size(1)
#         else:
#             raise ValueError("Either input_ids or inputs_embeds should be provided")

#         if attention_mask is not None and attention_mask.size(1) != seq_length:
#             attention_mask = attention_mask[:, :seq_length]

#         return self.phi3(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

#     def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
#         inputs = self.phi3.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask, **kwargs)
#         if 'image_embeddings' in kwargs:
#             inputs['image_embeddings'] = kwargs['image_embeddings']

#             # Adjust attention_mask if it's present
#             if attention_mask is not None:
#                 inputs['attention_mask'] = torch.cat([torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask], dim=1)

#             # Remove position_ids as they're not used
#             inputs.pop('position_ids', None)

#         return inputs

#     def get_input_embeddings(self):
#         return self.phi3.get_input_embeddings()

#     def set_input_embeddings(self, value):
#         self.phi3.set_input_embeddings(value)

#     def gradient_checkpointing_enable(self, **kwargs):
#         self.phi3.gradient_checkpointing_enable(**kwargs)

#     def gradient_checkpointing_disable(self):
#         self.phi3.gradient_checkpointing_disable()

#     def __getattr__(self, name):
#         try:
#             return super().__getattr__(name)
#         except AttributeError:
#             return getattr(self.phi3, name)

#     def generate(self, input_ids=None, attention_mask=None, image_embeddings=None, **kwargs):
#         if image_embeddings is not None:
#             kwargs['image_embeddings'] = image_embeddings

#         if attention_mask is not None and image_embeddings is not None:
#             # Add an extra attention mask token for the image embedding
#             attention_mask = torch.cat([torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask], dim=1)

#         return super().generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [None]:
class ImageProjector(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.GELU()  # Using GELU activation, but you can experiment with others
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.05)  # Adding dropout for regularization

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.layer2(x)
        return x

class Phi3WithProjector(PreTrainedModel):
    supports_gradient_checkpointing = True

    def __init__(self, phi3_model, projector, debug=False):
        super().__init__(phi3_model.config)
        self.phi3 = phi3_model
        self.projector = projector
        self.debug = debug

    def debug_print(self, *args, **kwargs):
        if self.debug:
            print(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        # Load the base Phi-3 model
        phi3_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # Determine if it's a local path or a Hugging Face model ID
        is_local = os.path.isdir(pretrained_model_name_or_path)

        if is_local:
            projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        else:
            try:
                # Try to download the projector weights from the Hugging Face Hub
                projector_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="image_projector.pth")
            except Exception as e:
                print(f"Failed to download projector weights: {e}")
                projector_path = None

        if projector_path and os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi3_model.device)
            # Check if the state dict has the expected structure
            if 'linear.weight' in projector_state_dict:
                input_dim = projector_state_dict['linear.weight'].size(1)
                output_dim = projector_state_dict['linear.weight'].size(0)
            else:
                # If not, try to infer dimensions from the first layer's weight
                first_key = next(iter(projector_state_dict))
                input_dim = projector_state_dict[first_key].size(1)
                output_dim = phi3_model.config.hidden_size  # Assuming this is the correct output dimension

            projector = ImageProjector(input_dim, output_dim)

            # Try to load the state dict, ignoring mismatched keys
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found. Initializing with default dimensions.")
            input_dim = 512  # Default CLIP embedding size
            output_dim = phi3_model.config.hidden_size
            projector = ImageProjector(input_dim, output_dim)

        # Move the projector to the same device as phi3_model
        projector = projector.to(phi3_model.device)

        # Create and return the Phi3WithProjector instance
        model = cls(phi3_model, projector, debug=debug)
        return model

    def save_pretrained(self, save_directory):
        # Save the base model
        self.phi3.save_pretrained(save_directory)

        # Save the projector weights
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.projector.state_dict(), projector_path)

        # Save the config
        self.config.save_pretrained(save_directory)

    def forward(self, input_ids=None, attention_mask=None, image_embeddings=None, labels=None, past_key_values=None, **kwargs):
        device = next(self.parameters()).device

        if image_embeddings is not None:
            image_embeddings = image_embeddings.to(device)
            projected_images = self.projector(image_embeddings)
            projected_images = projected_images.unsqueeze(1)
            self.debug_print(f"forward projected_images: {projected_images.size()}")

            if past_key_values is None:  # This is the first forward pass
                self.debug_print(f"forward before: {attention_mask.size() if attention_mask is not None else None}")
                if 'inputs_embeds' in kwargs and kwargs['inputs_embeds'] is not None:
                    inputs_embeds = kwargs['inputs_embeds']
                    self.debug_print(f"forward before inputs_embeds: {inputs_embeds.size()}")
                    inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
                    kwargs['inputs_embeds'] = inputs_embeds
                    self.debug_print(f"forward after inputs_embeds: {inputs_embeds.size()}")
                elif input_ids is not None:
                    self.debug_print(f"forward input_ids: {input_ids.size()}")
                    inputs_embeds = self.get_input_embeddings()(input_ids.to(device))
                    self.debug_print(f"forward before inputs_embeds: {inputs_embeds.size()}")
                    inputs_embeds = torch.cat([projected_images, inputs_embeds], dim=1)
                    self.debug_print(f"forward after inputs_embeds: {inputs_embeds.size()}")
                    kwargs['inputs_embeds'] = inputs_embeds
                    input_ids = None  # Set to None to avoid conflict

                if attention_mask is not None:
                    attention_mask = torch.cat([torch.ones(image_embeddings.size(0), 1, device=device), attention_mask.to(device)], dim=1)
                else:
                    attention_mask = torch.ones(image_embeddings.size(0), inputs_embeds.size(1), device=device)

                if labels is not None:
                    # Adjust labels to match the new sequence length
                    labels = torch.cat([torch.full((labels.size(0), 1), -100, device=device), labels], dim=1)

        if labels is not None:
            labels = labels.to(device)

        # Determine sequence length
        if 'inputs_embeds' in kwargs and kwargs['inputs_embeds'] is not None:
            seq_length = kwargs['inputs_embeds'].size(1)
        elif input_ids is not None:
            seq_length = input_ids.size(1)
        else:
            seq_length = attention_mask.size(1) if attention_mask is not None else None

        if seq_length is None:
            raise ValueError("Unable to determine sequence length. Provide either input_ids, inputs_embeds, or attention_mask.")

        # Ensure attention_mask matches the sequence length
        if attention_mask is not None:
            attention_mask = attention_mask[:, :seq_length]

        self.debug_print(f"forward final: input_ids shape: {input_ids.shape if input_ids is not None else None}")
        self.debug_print(f"forward final: attention_mask shape: {attention_mask.shape if attention_mask is not None else None}")
        self.debug_print(f"forward final: inputs_embeds shape: {kwargs.get('inputs_embeds', {}).shape if kwargs.get('inputs_embeds') is not None else None}")

        return self.phi3(input_ids=input_ids, attention_mask=attention_mask, labels=labels, past_key_values=past_key_values, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
        inputs = self.phi3.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask, **kwargs)

        if 'image_embeddings' in kwargs:
            inputs['image_embeddings'] = kwargs['image_embeddings']

            if past is None:  # First forward pass
                # Adjust attention_mask to account for the image token
                if attention_mask is not None:
                    inputs['attention_mask'] = torch.cat([torch.ones((attention_mask.size(0), 1), device=attention_mask.device), attention_mask], dim=1)
            else:  # Subsequent passes
                # Ensure attention_mask matches the current sequence length
                if attention_mask is not None:
                    current_seq_length = past[0][0].size(2) + 1  # past key's sequence length + 1 for the new token
                    inputs['attention_mask'] = attention_mask[:, :current_seq_length]

            inputs.pop('position_ids', None)

        # Safe printing of shapes
        self.debug_print(f"prepare_inputs_for_generation: input_ids shape: {inputs['input_ids'].shape if 'input_ids' in inputs else None}")
        self.debug_print(f"prepare_inputs_for_generation: attention_mask shape: {inputs['attention_mask'].shape if 'attention_mask' in inputs else None}")
        self.debug_print(f"prepare_inputs_for_generation: inputs_embeds shape: {inputs.get('inputs_embeds', {}).shape if inputs.get('inputs_embeds') is not None else None}")

        return inputs

    def get_input_embeddings(self):
        return self.phi3.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.phi3.set_input_embeddings(value)

    def gradient_checkpointing_enable(self, **kwargs):
        self.phi3.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi3.gradient_checkpointing_disable()

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.phi3, name)

    def generate(self, input_ids=None, attention_mask=None, image_embeddings=None, **kwargs):
        if image_embeddings is not None:
            kwargs['image_embeddings'] = image_embeddings
            self.debug_print(f"generate input_ids: {input_ids.size()}")
            self.debug_print(f"generate image_embedding: {image_embeddings.size()}")

        if attention_mask is not None and image_embeddings is not None:
            # Add an extra attention mask token for the image embedding
            self.debug_print(f"generate before: {attention_mask.size()}")
            attention_mask = torch.cat([torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask], dim=1)
            self.debug_print(f"generate after: {attention_mask.size()}")

        return super().generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

In [None]:
# Define the path in Google Drive where you want to save the checkpoints
gdrive_checkpoint_dir = "/content/drive/MyDrive/multimodel_llm/phi3_checkpoints"

# Ensure the directory exists
os.makedirs(gdrive_checkpoint_dir, exist_ok=True)
import dataclasses

class SaveLatestCheckpointCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            checkpoint_dir = os.path.join(gdrive_checkpoint_dir, f"checkpoint-{state.global_step}")

            # Save the model and tokenizer
            kwargs["model"].save_pretrained(checkpoint_dir)
            kwargs["tokenizer"].save_pretrained(checkpoint_dir)

            # Save the projector separately
            projector_path = os.path.join(checkpoint_dir, "image_projector.pth")
            # torch.save(kwargs["model"].projector.state_dict(), projector_path)
            torch.save(kwargs["model"].base_model.model.projector.state_dict(), projector_path)

            # Explicitly save the trainer state
            trainer_state_path = os.path.join(checkpoint_dir, "trainer_state.json")
            state_dict = dataclasses.asdict(state)
            with open(trainer_state_path, "w") as f:
                json.dump(state_dict, f, indent=2)

            # Remove previous checkpoint
            prev_checkpoint = os.path.join(gdrive_checkpoint_dir, f"checkpoint-{state.global_step - args.save_steps}")
            if os.path.exists(prev_checkpoint):
                import shutil
                shutil.rmtree(prev_checkpoint)

# Function to get the latest checkpoint
def get_latest_checkpoint(checkpoint_dir):
    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith('checkpoint-')]
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
    return os.path.join(checkpoint_dir, latest_checkpoint)

In [None]:
if torch.cuda.is_bf16_supported():
  compute_dtype = torch.bfloat16
#   attn_implementation = 'flash_attention_2'
else:
  compute_dtype = torch.float16
#   attn_implementation = 'sdpa'

# print(attn_implementation)
print(compute_dtype)

torch.bfloat16


In [None]:
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
# # Load the model again for quantization
# ### Download Phi-3 model
# phi3_model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     trust_remote_code=True,
#     quantization_config=bnb_config,
#     device_map=device_map,
#     torch_dtype=compute_dtype,
#     # attn_implementation=attn_implementation
# )

# print(phi3_model)

In [None]:
# # Initialize the projector
# image_embedding_dim = len(hf_dataset[0]['image_embedding'])
# projection_dim = phi3_model.config.hidden_size  # Get dimension from the model
# projector = ImageProjector(image_embedding_dim, projection_dim).to(device)

# # Combine Phi-3 with the projector
# model = Phi3WithProjector(phi3_model, projector)

# Get the latest checkpoint
latest_checkpoint = get_latest_checkpoint(gdrive_checkpoint_dir)

if latest_checkpoint:
    print(f"Loading model from checkpoint: {latest_checkpoint}")
    model = Phi3WithProjector.from_pretrained(
        latest_checkpoint,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map=device_map,
        torch_dtype=compute_dtype,
    )
else:
    print("No checkpoint found. Starting training from scratch.")
    phi3_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map=device_map,
        torch_dtype=compute_dtype,
    )
    image_embedding_dim = len(hf_dataset[0]['image_embedding'])
    projection_dim = phi3_model.config.hidden_size
    projector = ImageProjector(image_embedding_dim, projection_dim).to(device)
    model = Phi3WithProjector(phi3_model, projector)

# Prepare the model for k-bit training
model = prepare_model_for_kbit_training(model)

Loading model from checkpoint: /content/drive/MyDrive/multimodel_llm/phi3_checkpoints/checkpoint-150




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading adapter weights from /content/drive/MyDrive/multimodel_llm/phi3_checkpoints/checkpoint-150 led to unexpected keys not found in the model:  ['phi3.model.layers.0.mlp.down_proj.lora_A.default.weight', 'phi3.model.layers.0.mlp.down_proj.lora_B.default.weight', 'phi3.model.layers.0.self_attn.o_proj.lora_A.default.weight', 'phi3.model.layers.0.self_attn.o_proj.lora_B.default.weight', 'phi3.model.layers.1.mlp.down_proj.lora_A.default.weight', 'phi3.model.layers.1.mlp.down_proj.lora_B.default.weight', 'phi3.model.layers.1.self_attn.o_proj.lora_A.default.weight', 'phi3.model.layers.1.self_attn.o_proj.lora_B.default.weight', 'phi3.model.layers.10.mlp.down_proj.lora_A.default.weight', 'phi3.model.layers.10.mlp.down_proj.lora_B.default.weight', 'phi3.model.layers.10.self_attn.o_proj.lora_A.default.weight', 'phi3.model.layers.10.self_attn.o_proj.lora_B.default.weight', 'phi3.model.layers.11.mlp.down_proj.lora_A.default.weight', 'phi3.model.layers.11.mlp.down_proj.lora_B.default.weight', 'p

Loaded projector with input_dim=512, output_dim=3072


In [None]:
print(compute_dtype) , print(model)

torch.float16
Phi3WithProjector(
  (phi3): Phi3ForCausalLM(
    (model): Phi3Model(
      (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
      (embed_dropout): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0-31): 32 x Phi3DecoderLayer(
          (self_attn): Phi3Attention(
            (o_proj): lora.Linear4bit(
              (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.05, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=3072, out_features=32, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=32, out_features=3072, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (qkv_pr

(None, None)

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


In [None]:
# Define LoRA configuration
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)
print_trainable_parameters(model)

trainable params: 17825792 || all params: 2030640128 || trainable%: 0.8778410194009522


In [None]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3WithProjector(
      (phi3): Phi3ForCausalLM(
        (model): Phi3Model(
          (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
          (embed_dropout): Dropout(p=0.0, inplace=False)
          (layers): ModuleList(
            (0-31): 32 x Phi3DecoderLayer(
              (self_attn): Phi3Attention(
                (o_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=3072, out_features=32, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=32, out_features=3072, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
           

### Training

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Define training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size  = per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="all",
    eval_strategy="steps",
    eval_steps=eval_steps, # Evaluate every 25 steps
    # Add these new arguments
    do_eval=True,
    eval_delay=0,  # Start evaluation immediately
    # Enable gradient checkpointing
    gradient_checkpointing=gradient_checkpointing,
    # Disable data parallelism if not needed
    ddp_find_unused_parameters=False,
    save_total_limit=1,  # Keep only the latest checkpoint
)

# Custom data collator to handle pre-tokenized inputs
def custom_data_collator(features):
    batch = {k: [d[k] for d in features] for k in features[0].keys()}

    # Stack image embeddings
    batch['image_embeddings'] = torch.stack(batch['image_embeddings'])

    # Pad the sequences
    batch['input_ids'] = torch.nn.utils.rnn.pad_sequence(batch['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id)
    batch['attention_mask'] = torch.nn.utils.rnn.pad_sequence(batch['attention_mask'], batch_first=True, padding_value=0)
    batch['labels'] = torch.nn.utils.rnn.pad_sequence(batch['labels'], batch_first=True, padding_value=-100)

    return batch

In [None]:
# Function to select a random subset of the dataset
def select_subset(dataset, fraction=0.05):
    num_samples = int(len(dataset) * fraction)
    indices = random.sample(range(len(dataset)), num_samples)
    return dataset.select(indices)

# Select 5% of the training and test datasets
small_train_dataset = select_subset(dataset_dict['train'], fraction=0.1)
small_test_dataset = select_subset(dataset_dict['test'], fraction=0.1)

# Create a new DatasetDict with the smaller datasets
small_dataset_dict = DatasetDict({
    'train': small_train_dataset,
    'test': small_test_dataset
})

print(f"Small train dataset size: {len(small_dataset_dict['train'])}")
print(f"Small test dataset size: {len(small_dataset_dict['test'])}")

Small train dataset size: 32526
Small test dataset size: 3614


In [None]:
# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    # train_dataset=dataset_dict['train'],
    # eval_dataset=dataset_dict['test'],
    train_dataset=small_dataset_dict['train'],
    eval_dataset=small_dataset_dict['test'],
    tokenizer=tokenizer,
    data_collator=custom_data_collator,
    peft_config=lora_config,
    max_seq_length=max_seq_length,
    packing=packing,
    callbacks=[SaveLatestCheckpointCallback()],  # Add the custom callback
)

# Perform initial evaluation
print("Performing initial evaluation...")
eval_results = trainer.evaluate()
print(f"Initial evaluation results: {eval_results}")

# # Start training
trainer.train()
# Start or resume training
# trainer.train(resume_from_checkpoint=latest_checkpoint)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Performing initial evaluation...




Initial evaluation results: {'eval_loss': 9.622665405273438, 'eval_model_preparation_time': 0.0042, 'eval_runtime': 6453.6574, 'eval_samples_per_second': 0.56, 'eval_steps_per_second': 0.14}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss,Model Preparation Time
50,0.3685,0.23207,0.0042


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss,Model Preparation Time
50,0.3685,0.23207,0.0042
100,0.3595,0.227313,0.0042
150,0.3657,0.220686,0.0042
200,0.3621,0.216167,0.0042


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


KeyboardInterrupt: 

In [None]:
# Save the fine-tuned model
# trainer.model.save_pretrained(new_model)
final_model_path = os.path.join(gdrive_checkpoint_dir, "final_model")
trainer.model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)

('/content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/tokenizer_config.json',
 '/content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/special_tokens_map.json',
 '/content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/tokenizer.model',
 '/content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/added_tokens.json',
 '/content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/tokenizer.json')

In [None]:
torch.save(model.base_model.model.projector.state_dict(), final_model_path + '/image_projector.pth')
print(f"Projector saved to: {final_model_path}/image_projector.pth")

Projector saved to: /content/drive/MyDrive/multimodel_llm/phi3_checkpoints/final_model/image_projector.pth


In [None]:
!cp -r results /content/drive/MyDrive/multimodel_llm/

## sample inference code

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Create a custom text generation class
class CustomTextGenerator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, input_text, image_embedding, **generate_kwargs):
        # Tokenize the input text
        inputs = self.tokenizer(input_text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # Ensure image_embedding is a tensor and move it to the correct device
        if not isinstance(image_embedding, torch.Tensor):
            image_embedding = torch.tensor(image_embedding)
        image_embedding = image_embedding.to(self.model.device)

        # Adjust attention_mask to account for the image embedding token
        image_attention = torch.ones((1, 1), dtype=torch.long, device=self.model.device)
        attention_mask = torch.cat([image_attention, attention_mask], dim=1)

        # Generate text
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeddings=image_embedding.unsqueeze(0),  # Add batch dimension
            **generate_kwargs
        )

        # Decode the generated text
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

# Initialize the custom text generator
generator = CustomTextGenerator(model=model, tokenizer=tokenizer)




<|system|> You are a helpful assistant.<|end|><|user|> Given the following information, provide a detailed and accurate response:
How is the little girl dressed?
[An image is provided for this task.]
<|end|><|assistant|>


In [None]:
# Get a sample from the validation set
sample = dataset_dict['test'][1]
image_embedding = sample['image_embeddings']


def get_first_user_input(decoded_text):
    # Find the position of the first <|assistant|> tag
    assistant_pos = decoded_text.find('<|assistant|>')

    # If <|assistant|> is found, truncate the text
    if assistant_pos != -1:
        return decoded_text[:assistant_pos].strip()
    else:
        return decoded_text.strip()


# Decode the input_ids
full_text = tokenizer.decode(sample['input_ids'], skip_special_tokens=False)

# Extract only the first user input
input_text = get_first_user_input(full_text) + '<|assistant|>'
print(input_text)

<|system|> You are a helpful assistant.<|end|><|user|> Given the following information, provide a detailed and accurate response:
Analyze the image in a comprehensive and detailed manner.
[An image is provided for this task.]
<|end|><|assistant|>


In [None]:
# Generate text
generated_text = generator.generate(
    input_text,
    image_embedding=image_embedding,
    # max_length=200,
    # num_return_sequences=1,
    # do_sample=True,
    # temperature=0.7,
    # top_k=50,
    # top_p=0.95,
    max_new_tokens=150,
    num_return_sequences=1,
    do_sample=True,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
)

print("Input text:")
print(input_text)
print("\nGenerated text:")
print(generated_text)

  return fn(*args, **kwargs)


Input text:
<|system|> You are a helpful assistant.<|end|><|user|> Given the following information, provide a detailed and accurate response:
Analyze the image in a comprehensive and detailed manner.
[An image is provided for this task.]
<|end|><|assistant|>

Generated text:
You are a helpful assistant. Given the following information, provide a detailed and accurate response:
Analyze the image in a comprehensive and detailed manner.
[An image is provided for this task.]
 The scene shows two people standing on opposite sides of an open bedroom window that leads to another room or living space inside their home building complexes near each other's apartments located next door across from one another at 1234 Northwest Boulevard Street (Stanley Apartments). They appear comfortable with being outside together while having conversations without actually interacting directly as they sit facing away towards different directions around them - possibly enjoying some fresh air before returning i

### merge models and save in gdrive

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Save the projector
projector_path = '/content/drive/MyDrive/multimodel_llm/image_projector.pth'
os.makedirs(os.path.dirname(projector_path), exist_ok=True)
torch.save(model.projector.state_dict(), projector_path)
print(f"Projector saved to: {projector_path}")

Projector saved to: /content/drive/MyDrive/multimodel_llm/image_projector.pth


In [None]:
# Merge the fine-tuned adapter with the base model
from peft import AutoPeftModelForCausalLM
from peft import PeftModel

# Load the fine-tuned model with the LoRA adapter
# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
print(base_model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3RotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
  )
 

In [None]:
# gdrive_checkpoint_dir = "/content/drive/MyDrive/multimodel_llm/phi3_checkpoints"
# final_model_path = os.path.join(gdrive_checkpoint_dir, "final_model")

In [None]:
new_model = PeftModel.from_pretrained(base_model, final_model_path)
print(new_model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vect

In [None]:
print(new_model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
              (rotary_emb): Phi3RotaryEmbedding()
            )
            (mlp): Phi3MLP(
              (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
              (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
              (activation_fn): SiLU()
            )
            (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
            (resid_attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_mlp_dropout): Dropout(p=0.0, inplace=Fals

In [None]:
for name, param in new_model.named_parameters():
    if 'lora' in name:
        print(f"{name}: {param.data.abs().mean()}")

base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight: 0.009008973836898804
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight: 0.0
base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight: 0.005532282404601574
base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight: 0.0
base_model.model.model.layers.1.self_attn.o_proj.lora_A.default.weight: 0.009038114920258522
base_model.model.model.layers.1.self_attn.o_proj.lora_B.default.weight: 0.0
base_model.model.model.layers.1.mlp.down_proj.lora_A.default.weight: 0.00552835687994957
base_model.model.model.layers.1.mlp.down_proj.lora_B.default.weight: 0.0
base_model.model.model.layers.2.self_attn.o_proj.lora_A.default.weight: 0.009009761735796928
base_model.model.model.layers.2.self_attn.o_proj.lora_B.default.weight: 0.0
base_model.model.model.layers.2.mlp.down_proj.lora_A.default.weight: 0.005520865321159363
base_model.model.model.layers.2.mlp.down_proj.lora_B.default.weight: 0.0
base_

In [None]:
# Merge the LoRA adapter with the base model
merged_model = new_model.merge_and_unload()
print(new_model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
              (rotary_emb): Phi3RotaryEmbedding()
            )
            (mlp): Phi3MLP(
              (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
              (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
              (activation_fn): SiLU()
            )
            (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
            (resid_attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_mlp_dropout): Dropout(p=0.0, inplace=Fals

In [None]:
# Define the path to save the merged model in Google Drive
merged_model_path = '/content/drive/MyDrive/multimodel_llm/merged_phi3_llava_model'

# Save the merged model
# merged_model.save_pretrained(merged_model_path)

# Initialize the projector
image_embedding_dim = len(hf_dataset[0]['image_embedding'])
projection_dim = merged_model.config.hidden_size  # Get dimension from the model
projector = ImageProjector(image_embedding_dim, projection_dim).to(device)
projector.load_state_dict(torch.load(final_model_path + '/image_projector.pth'))
print(projector)

ImageProjector(
  (layer1): Linear(in_features=512, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (layer2): Linear(in_features=1024, out_features=3072, bias=True)
  (dropout): Dropout(p=0.05, inplace=False)
)


  projector.load_state_dict(torch.load(final_model_path + '/image_projector.pth'))


In [None]:
# Combine Phi-3 with the projector
phi3_with_projector = Phi3WithProjector(merged_model, projector)
print(phi3_with_projector)

Phi3WithProjector(
  (phi3): Phi3ForCausalLM(
    (model): Phi3Model(
      (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
      (embed_dropout): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0-31): 32 x Phi3DecoderLayer(
          (self_attn): Phi3Attention(
            (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
            (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
            (rotary_emb): Phi3RotaryEmbedding()
          )
          (mlp): Phi3MLP(
            (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
            (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
            (activation_fn): SiLU()
          )
          (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
          (resid_attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
          (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-

In [None]:
def compare_models(base_model, fine_tuned_model):
    differences = []
    for (name1, p1), (name2, p2) in zip(base_model.named_parameters(), fine_tuned_model.named_parameters()):
        if name1 == name2:
            diff = (p1 - p2).abs().max().item()  # Use max absolute difference
            differences.append((name1, diff))
    return differences
# Compare the models
differences = compare_models(base_model, new_model.model)

# Sort differences by magnitude
differences.sort(key=lambda x: x[1], reverse=True)
print("Top 10 layers with the largest differences:")
for name, diff in differences[:10]:
    print(f"{name}: {diff}")

Top 10 layers with the largest differences:
model.embed_tokens.weight: 0.0
model.layers.0.self_attn.o_proj.weight: 0.0
model.layers.0.self_attn.qkv_proj.weight: 0.0
model.layers.0.mlp.gate_up_proj.weight: 0.0
model.layers.0.mlp.down_proj.weight: 0.0
model.layers.0.input_layernorm.weight: 0.0
model.layers.0.post_attention_layernorm.weight: 0.0
model.layers.1.self_attn.o_proj.weight: 0.0
model.layers.1.self_attn.qkv_proj.weight: 0.0
model.layers.1.mlp.gate_up_proj.weight: 0.0


In [None]:
# Save the merged model with the projector
phi3_with_projector.save_pretrained(merged_model_path)

# Save the tokenizer
# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.save_pretrained(merged_model_path)

print(f"Merged model and tokenizer saved to: {merged_model_path}")

Merged model and tokenizer saved to: /content/drive/MyDrive/multimodel_llm/merged_phi3_llava_model


In [None]:
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=merged_model_path,
    repo_id="sayanbanerjee32/multimodal-phi3-4k-instruct-llava",
    repo_type="model",
    delete_patterns = "*.safetensors",
)
print("Model uploaded to Hugging Face Hub")

No files have been modified since last commit. Skipping to prevent empty commit.


Model uploaded to Hugging Face Hub


In [None]:
print('end')