In [None]:
# Colab dependency install. NOTE(review): only trl is pinned (and `-U` always pulls
# the newest bitsandbytes/flash_attn); pin the remaining packages to the versions this
# notebook was last run with so that Restart & Run All stays reproducible.
!pip install peft wandb datasets trl==0.8.5 transformers accelerate -q
!pip install -U bitsandbytes flash_attn -q

In [None]:
# Mount Google Drive and make the session folder the working directory.
# Relative paths used later ('train_token.csv', 'sample_val_data.csv',
# os.getcwd()/model_chkpt) resolve against it; presumably the local
# `llavadataset` module is imported from this folder as well.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab_Notebooks/Session30

In [3]:
import tqdm
from llavadataset import llavadataset, collate_fn
import pickle
import peft
from peft import LoraConfig
from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from torch.utils.data import random_split, DataLoader
import pandas as pd
from torch.nn import functional as F
import csv

import random
from PIL import Image
import requests
import wandb
import os
from peft import PeftModel
import torch.nn as nn

import torch
import torch.nn.functional as F
from tqdm import tqdm



In [None]:
import wandb
from google.colab import userdata

# --- Secrets ----------------------------------------------------------------
# The W&B API key is read from Colab's secret store (key "wandb"), not hard-coded.
wandb1 = userdata.get('wandb')
os.environ["WANDB_API_KEY"] = wandb1

# --- Backbones & tokenizer --------------------------------------------------
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name  = "microsoft/phi-2"
#phi_model_name  = "microsoft/Phi-3-mini-128k-instruct"
# NOTE(review): switching to the Phi-3 line above would presumably also require new
# values for phi_embed, vocab_size, EOS_TOKEN_ID and the LoRA target modules.
tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
# Two extra tokens: '[QA]' and a dedicated '[PAD]' used as the pad token.
tokenizer.add_tokens('[QA]')
tokenizer.add_special_tokens({'pad_token':'[PAD]'})

# --- Hyper-parameters & model constants -------------------------------------
train_batch_size    = 32
clip_embed = 768          # feature size fed into `projection` (CLIP side)
phi_embed  = 2560         # output size of `projection` / `resblock` (Phi side)
device = "cuda" if torch.cuda.is_available() else "cpu"
num_workers = 10
IMAGE_TOKEN_ID = 23893 # token for word comment
max_steps      = 100000
EOS_TOKEN_ID   = 50256
phi_patches    = 49       # number of image patch embeddings (presumably 7x7 for ViT-B/32) -- confirm
vocab_size     = 51200
max_generate_length = 100
model_val_step      = 1000
model_log_step      = 100
model_save_step     = 100

# The token embedding matrix is not resized anywhere in this notebook after the
# tokens above are added, so every tokenizer id must still fit inside vocab_size.
assert len(tokenizer) <= vocab_size, (
    f"tokenizer has {len(tokenizer)} tokens but vocab_size is {vocab_size}; "
    "resize the model's token embeddings before training"
)

# NOTE(review): the run name says "phi3" although phi_model_name is Phi-2.
wandb.init(project="clip_phi2_project", name="clip_phi3_finetune")
torch.set_float32_matmul_precision('medium')
tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer('[QA]')

In [5]:
# --- training data ----------------------------------------------------------
# The QA table is wrapped by the project's `llavadataset`; `collate_fn` builds batches.
csv_file = 'train_token.csv'
qa_dataset = pd.read_csv(csv_file)

# --- data loader ------------------------------------------------------------
train_dataset = llavadataset(qa_dataset, phi_model_name, clip_model_name, processor)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Small validation sample used by model_run_val: row 0 is the header and each later
# row is indexed there as [image_url, question, answer].
# `with` closes the file even if reading fails; newline='' is what the csv module
# expects, so quoted fields that contain newlines are parsed correctly.
with open('sample_val_data.csv', newline='') as file:
    sample_val_data = list(csv.reader(file))

# model_run_val samples from rows 1..len-1, so at least one data row is required.
assert len(sample_val_data) > 1, "sample_val_data.csv contains no data rows"
print(sample_val_data[1])

In [7]:
class SimpleResBlock(nn.Module):
    """Residual MLP block applied to the projected image features.

    The input is layer-normalised first, and the residual connection adds the
    *normalised* tensor (not the raw input) to the MLP output.

    Args:
        phi_embed: size of the last dimension of the input; the output has the
            same shape as the input.
    """

    def __init__(self, phi_embed):
        super().__init__()
        # Attribute names (pre_norm, proj) are the state_dict keys stored in the
        # saved resblock checkpoints, so they must not change.
        self.pre_norm = nn.LayerNorm(phi_embed)
        mlp_layers = [
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        residual = self.proj(normed)
        return normed + residual

In [None]:
# Vision encoder (its parameters are frozen in a later cell).
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)

# Phi is loaded 4-bit (NF4) quantised, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
phi_load_kwargs = {
    "torch_dtype": "auto",
    "quantization_config": bnb_config,
    "trust_remote_code": True,
}
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, **phi_load_kwargs)
# KV caching is switched off for the training forward passes.
phi_model.config.use_cache = False

# Newly initialised, trainable bridge: clip_embed -> phi_embed linear map,
# followed by the residual MLP block.
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)

In [None]:
from peft import LoraConfig

lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

# LoRA target names must match module names inside the loaded model (see the
# `print(peft_model)` cell). In Phi-2's decoder layers, attention uses
# q_proj / k_proj / v_proj / dense and the MLP uses fc1 / fc2.
# The previous list used o_proj / gate_proj / up_proj / down_proj (Llama-style
# names), which Phi-2 does not have. PEFT only errors when *no* name matches, so
# those entries were silently skipped. That left the attention output projection
# and the whole MLP without adapters.
# NOTE(review): adapters saved with the old list contain fewer LoRA layers, and a
# different backbone (e.g. Phi-3) needs its own target names.
phi2_lora_target_modules = ['q_proj', 'k_proj', 'v_proj', 'dense', 'fc1', 'fc2']

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=phi2_lora_target_modules,
)
peft_model = peft.get_peft_model(phi_model, peft_config).to(device)
peft_model.print_trainable_parameters()

In [10]:
# CLIP is not trained: turn off requires_grad on every one of its parameters.
# (The optimizers defined later cover peft_model, projection and resblock.)
clip_model.requires_grad_(False)


In [None]:
# check trainable parameters
def count_parameters(model):
    """Return the total number of elements in `model`'s parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter counts per component. peft_model wraps phi_model
# (get_peft_model above), so those two counts overlap.
for label, component in [
    ("PEFT MODEL", peft_model),
    ("PROJECTION MODEL", projection),
    ("CLIP MODEL", clip_model),
    ("PHI MODEL", phi_model),
    ("RESNET MODEL", resblock),
]:
    print(f"{label}:{count_parameters(component)}")


In [12]:
import os
import torch

# CUDA debugging switches.
# CUDA_LAUNCH_BLOCKING is only honoured when it is set before CUDA is initialised.
# Earlier cells have already moved models to the GPU, so setting it here has no
# effect in this kernel. PyTorch's own error messages describe TORCH_USE_CUDA_DSA as
# a build-time option, so it does nothing for prebuilt wheels either.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
if torch.cuda.is_initialized():
    print("Warning: CUDA is already initialised, so CUDA_LAUNCH_BLOCKING set here has no "
          "effect; set it before any CUDA work (e.g. in the first cell) and restart the runtime.")

In [None]:
# Show the wrapped module tree (useful for checking the LoRA target module names).
print(f"{peft_model}")

In [None]:
# random validation prediction
def model_run_val(sample_val_data,max_generate_length=10):
    """Generate an answer for one random validation row and print it beside the reference.

    Args:
        sample_val_data: rows read from ``sample_val_data.csv``. Row 0 is the header;
            every later row is indexed as ``[image_url, question, answer]``.
        max_generate_length: passed to ``generate`` as ``max_new_tokens``.

    Relies on notebook globals (clip_model, projection, resblock, peft_model,
    processor, tokenizer, IMAGE_TOKEN_ID, device). Prints the result and returns
    None. If the image cannot be downloaded or opened, a message is printed and the
    sample is skipped instead of raising, so a dead URL cannot abort the training
    loop that calls this function.
    """
    total_val_len = len(sample_val_data)
    random_val_datapoint = random.randrange(1,total_val_len) # 0 is header

    val_image_url = sample_val_data[random_val_datapoint][0]
    val_q = sample_val_data[random_val_datapoint][1]
    val_a = sample_val_data[random_val_datapoint][2]

    embed_tokens = peft_model.model.model.embed_tokens

    try:
        # timeout: without it a stalled server would hang the caller indefinitely.
        response = requests.get(val_image_url, stream=True, timeout=30)
        response.raise_for_status()
        image_load = Image.open(response.raw)
    except (requests.RequestException, OSError) as exc:
        print(f"Skipping validation sample, could not load image {val_image_url}: {exc}")
        return

    with torch.no_grad():
        image_processed = processor(images=image_load, return_tensors="pt").to(device)
        # Drop the CLS token and keep only the patch features, as in training.
        clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
        val_image_embeds = resblock(projection(clip_val_outputs))

        img_token_tensor = torch.tensor(IMAGE_TOKEN_ID, device=device)
        img_token_embeds = embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

        # The tokenizer returns CPU tensors; they must be on the same device as the
        # embedding weights (previously they were left on the CPU).
        val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
        val_q_embeds  = embed_tokens(val_q_tokenised).unsqueeze(0)

        # Match the token-embedding dtype instead of assuming float16.
        val_image_embeds = val_image_embeds.to(val_q_embeds.dtype)

        # (1, num_patches + 1 + question_len, hidden)
        val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1)

        predicted_caption = peft_model.generate(inputs_embeds=val_combined_embeds,
                                                max_new_tokens=max_generate_length,
                                                return_dict_in_generate = True)

        # When generate() receives only `inputs_embeds`, transformers returns just the
        # newly generated tokens. There is no prompt prefix to strip, and the former
        # `[:, 1:]` dropped the first predicted token.
        # NOTE(review): confirm against the installed transformers version.
        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences)[0]
        predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")

    print(f"Image: {val_image_url}")
    print(f"Question: {val_q}")
    print(f"Answer:   {val_a}")
    print(f"Model Predicted Ans: {predicted_captions_decoded}")

model_run_val(sample_val_data,max_generate_length=100)


In [None]:
# One Adam optimizer per trainable component. The LoRA-wrapped Phi gets a 10x
# smaller learning rate than the newly initialised projection and residual block.
phi_optimizer        = torch.optim.Adam(peft_model.parameters(), lr=1e-6)
projection_optimizer = torch.optim.Adam(projection.parameters(), lr=1e-5)
resnet_optimizer     = torch.optim.Adam(resblock.parameters(),   lr=1e-5)

# Training-loop state.
step = 0
running_loss = 0.

for trainable_component in (projection, peft_model, resblock):
    trainable_component.train()



In [None]:

# Train the projection, the residual block and the LoRA adapters for one epoch
# (or until `max_steps`). Each batch is turned into
#   [projected CLIP patch features | IMAGE_TOKEN_ID embedding | question embeddings]
# and Phi's logits at the question positions are scored against `answers`.
# NOTE(review): the loss lines logits and `answers` up element-wise. This assumes
# collate_fn returns `answers` with the same length as `questions`, already aligned
# as next-token targets -- confirm in llavadataset.

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
checkpoint_dir = os.path.join(os.getcwd(), "model_chkpt")
os.makedirs(checkpoint_dir, exist_ok=True)

# Positions holding the tokenizer's '[PAD]' id are excluded from the loss. The id is
# read from the tokenizer instead of being hard-coded (previously 50296).
# NOTE(review): assumes llavadataset pads `answers` with the same tokenizer setup.
pad_token_id = tokenizer.pad_token_id

embed_tokens = peft_model.model.model.embed_tokens
max_position_embeddings = peft_model.config.max_position_embeddings

for epoch in tqdm(range(1), desc="Epochs"):
    for batch_idx, (images, questions, answers) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}", leave=False):
        # Move data to device
        images = {'pixel_values': images.to(device)}
        questions = questions.to(device)
        answers = answers.to(device)

        # CLIP is frozen, so its forward pass needs no autograd graph.
        with torch.no_grad():
            clip_outputs = clip_model(**images)
        images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove cls token

        # Trainable bridge: linear projection followed by the residual block.
        image_embeds = resblock(projection(images_embeds))

        # One IMAGE_TOKEN_ID per sample separates the image features from the question.
        img_token_tensor = torch.full((questions.size(0), 1), IMAGE_TOKEN_ID, dtype=torch.long, device=device)
        img_token_embeds = embed_tokens(img_token_tensor)
        questions_embed = embed_tokens(questions)

        # Cast image features to the token-embedding dtype so the three parts
        # concatenate without silent type promotion.
        image_embeds = image_embeds.to(questions_embed.dtype)

        # Zero-pad (or, for a negative size, crop) the image features to the token
        # embedding width; this is a no-op when phi_embed matches the model.
        if image_embeds.shape[2] != img_token_embeds.shape[2]:
            padding_size = img_token_embeds.shape[2] - image_embeds.shape[2]
            image_embeds = F.pad(image_embeds, (0, padding_size))

        combined_embeds = torch.cat([image_embeds, img_token_embeds, questions_embed], dim=1)
        batch_size, seq_length = combined_embeds.shape[:2]

        # Truncating here (as before) could only produce logits that no longer line
        # up with `answers` and then fail inside cross_entropy, so fail early instead.
        if seq_length > max_position_embeddings:
            raise ValueError(
                f"Sequence length ({seq_length}) exceeds maximum position embeddings "
                f"({max_position_embeddings}); shorten the questions/answers in the dataset."
            )

        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)

        # Forward pass
        phi_output = peft_model(inputs_embeds=combined_embeds, position_ids=position_ids)

        # Drop the logits at the image-patch and separator positions, then flatten to
        # (batch * question_len, vocab_size) for cross_entropy.
        phi_output_logits = phi_output.logits[:, images_embeds.shape[1] + 1:, :]
        phi_output_logits = phi_output_logits.reshape(-1, peft_model.config.vocab_size)

        phi_optimizer.zero_grad()
        projection_optimizer.zero_grad()
        resnet_optimizer.zero_grad()

        loss = F.cross_entropy(phi_output_logits, answers.contiguous().view(-1), ignore_index=pad_token_id, label_smoothing=0.1)

        # loss backprop
        loss.backward()
        phi_optimizer.step()
        projection_optimizer.step()
        resnet_optimizer.step()

        loss_value = loss.item()

        if step % model_log_step == 0:
            tqdm.write(f"Iteration {step}/{max_steps}, Loss: {loss_value}")

        if step % model_val_step == 0:
            projection.eval()
            peft_model.eval()
            resblock.eval()
            model_run_val(sample_val_data, max_generate_length)
            projection.train()
            peft_model.train()
            resblock.train()

        if step % model_save_step == 0:
            tqdm.write("Saving Checkpoint")
            # File names (including the "finetunned" spelling) are kept unchanged so
            # existing loading code still finds them.
            torch.save(projection.state_dict(), os.path.join(checkpoint_dir, "finetunned_projection.pth"))
            torch.save(resblock.state_dict(), os.path.join(checkpoint_dir, "finetuned_resblock.pth"))
            peft_model.save_pretrained(os.path.join(checkpoint_dir, "lora_adaptor"), save_adapter=True, save_config=True)

        # Only leaves the batch loop; the epoch loop runs a single epoch anyway.
        if step >= max_steps:
            tqdm.write("Training finished.")
            break

        wandb.log({"step": step, "train_loss": loss_value})
        step += 1

    tqdm.write(f"Epoch {epoch+1} completed.")

In [None]:
import torch

# Sanity check of the position-id layout used in the training loop: one row of
# 0..seq_len-1, repeated for every sample in a batch.
# Uses the notebook's `device` instead of a hard-coded 'cuda' (which fails on CPU
# runtimes). The result goes into `demo_position_ids`, so running this cell does not
# overwrite the training loop's `position_ids`.
demo_seq_len, demo_batch = 71, 8

aa = torch.arange(demo_seq_len, dtype=torch.long)
print("aa",aa.shape)
bb = aa.unsqueeze(0)
print("bb",bb.shape)
cc = bb.repeat(demo_batch, 1)
print("cc",cc.shape)

demo_position_ids = cc.to(dtype=torch.long, device=device)
print("position_ids",demo_position_ids.shape)

assert demo_position_ids.shape == (demo_batch, demo_seq_len)
assert bool((demo_position_ids[-1].cpu() == aa).all())

