# TSAI Capstone - Fine tune projection layer + phi2 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig

#### ML projection layer taken from: 

In [2]:
def build_patch_mlp_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int
) -> nn.Module:
    modules = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
    for _ in range(1, num_layers):
        modules.append(nn.GELU())
        modules.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
    return nn.Sequential(*modules)


class _MLPVectorProjector(nn.Module):
    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super(_MLPVectorProjector, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            for _ in range(1, num_layers):
                mlp.append(nn.GELU())
                mlp.append(nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x):
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)


def build_mlp_vector_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int, num_tokens: int
):
    return _MLPVectorProjector(
        input_hidden_size, lm_hidden_size, num_layers, num_tokens
    )

#### Load the projection model that we obtained from Step 1


#### Use the proj + phi2 model from step 1. Use Q&A instead of captions. 

In [3]:
from transformers import AutoModelForCausalLM
import copy
import peft
from peft import LoraConfig

In [4]:

model_name = "microsoft/phi-2"
phi2 = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            # torch_dtype = torch.float16
        ).to("cuda")


lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    target_modules=[
        "Wqkv",
        "out_proj",
        "fc1",
        "fc2",
    ]
)

peft_phi_model = peft.get_peft_model(phi2, peft_config)
peft_phi_model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 52,428,800 || all params: 2,832,112,640 || trainable%: 1.8512258043521885


In [5]:
peft_phi_model

PeftModel(
  (base_model): LoraModel(
    (model): PhiForCausalLM(
      (model): PhiModel(
        (embed_tokens): Embedding(51200, 2560)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x PhiDecoderLayer(
            (self_attn): PhiAttention(
              (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
              (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
              (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
              (dense): Linear(in_features=2560, out_features=2560, bias=True)
              (rotary_emb): PhiRotaryEmbedding()
            )
            (mlp): PhiMLP(
              (activation_fn): NewGELUActivation()
              (fc1): lora.Linear(
                (base_layer): Linear(in_features=2560, out_features=10240, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )


In [6]:
class ImageWithPhiLayer(nn.Module):
    def __init__(self, 
                 clip_emb:int = 512, 
                 token_emb: int = 2560,
                 projection_n_tokens: int = 4,
                 projection_n_layers: int = 1
                ):
        super().__init__()       
        self.projection_n_tokens = projection_n_tokens
        self.ll1 = build_mlp_vector_projector(
            clip_emb, token_emb, projection_n_layers, self.projection_n_tokens).to("cuda")
        self.ll1.load_state_dict(torch.load('stage_2_proj_head.pth'))
        model_name = "microsoft/phi-2"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.vocab_size = len(self.tokenizer)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.phi2Model = peft_phi_model
        self.token_embedding = self.phi2Model.get_submodule('base_model.model.model.embed_tokens')
        
    def generate(self,  x, Qtokens):
        x = self.ll1(x)
        Qtoken_embeddings = self.token_embedding(Qtokens)
        inputs = torch.concat((x, Qtoken_embeddings), axis=-2)
        
        return self.tokenizer.batch_decode(
            model.phi2Model.generate(
                inputs_embeds=inputs, 
                max_new_tokens=20,
                bos_token_id=model.tokenizer.bos_token_id, 
                eos_token_id=model.tokenizer.eos_token_id,
                pad_token_id=model.tokenizer.pad_token_id
            )
        )

        
    def forward(self, x, QnAtokens, QTokenLength, QnA_length):
        x = self.ll1(x)
        QnAtoken_embeddings = self.token_embedding(QnAtokens)
        inputs = torch.concat((x, QnAtoken_embeddings), axis=-2)
        outputs = self.phi2Model(inputs_embeds=inputs)
        predictions = []
        
        b,t,v = outputs.logits.shape
        
        for i in range(b):
            if (i == 0):
                loss = F.cross_entropy(
                    outputs.logits[
                    i, self.projection_n_tokens + QTokenLength[i].item(): self.projection_n_tokens + QnA_length[i].item(), :],
                    QnAtokens[i][QTokenLength[i].item() + 1: QnA_length[i].item() +1]
                )
            else:
                loss += F.cross_entropy(
                    outputs.logits[
                    i, self.projection_n_tokens + QTokenLength[i].item(): self.projection_n_tokens + QnA_length[i].item(), :],
                    QnAtokens[i][QTokenLength[i].item() + 1: QnA_length[i].item() +1],
                )
        
        return loss / b, predictions

In [7]:
model = ImageWithPhiLayer()
#[(n, type(m)) for n, m in model.phi2Model.named_modules()]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# stage 2 - Instruction finetuning

In [8]:
import json 

instruct_dataset = f'./llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)

In [9]:
from torch.utils.data import Dataset, DataLoader

class CustomTextDataset(Dataset):
    def __init__(self, json_data, image_embedding_dict,  tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        
        self.entries = []        
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                continue
            
            conversations = entry['conversations']
            for i in range(len(conversations)):
                if conversations[i]['from'] == 'human':
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > 512:
                        continue
                    question = 'Question: ' + conversations[i]['value'].lstrip('<image>\n')
                    answer = 'Answer: ' + conversations[i + 1]['value']  
                    # Assuming the next message is from 'gpt' and contains the answer
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                        }) 
        print('------------- num entries = -----------------')
        print(len(self.entries))

    def getEmbeddingForImage(self, image):
        if image in self.image_embedding_dict:
            image_embedding = self.image_embedding_dict[image]
            return image_embedding
        else:
            return None      

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        image_name = entry['image_name']
        Q_caption_tokens = tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = tokenizer.encode(entry['QnAText'], add_special_tokens=True)
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)
        QnA_captions_tokens = QnA_captions_tokens + [tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))       

        return {'image_name': entry['image_name'], 
                'QText': entry['Question'], 
                'AText': entry['Answer'], 
                'image_embedding':  entry['image_embedding'].to("cuda"), 
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length
               }



In [10]:
# img_emb = torch.load("img_embeddings.pth").unsqueeze(1).to("cpu")
# print(img_emb.shape)

In [11]:
# with open("./image_names.json", 'r') as file:
#     image_names = json.load(file)
# imgEmbDict = dict(zip(image_names, img_emb))

In [12]:
imgEmbDict = torch.load('img_embeddings_dict.pth', map_location=torch.device('cpu'))

In [13]:
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict,  tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=8, shuffle=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


------------- num entries = -----------------
225484


#### Train - finetune the proj + phi2 peft model with QnA dataset

In [None]:
## Training loop
num_epochs = 200
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for epoch in range(num_epochs):
    model.train()
    for ix, batch in enumerate(custom_dataloader):
        
        embeddings = batch['image_embedding'].to('cuda')
        QnAtokens = batch['QnA_tokens'].to('cuda')
        QTokenLength = batch['QTokensLength'].to('cuda')
        QnA_length = batch['QnA_length'].to('cuda')
        
        # Backward pass and optimization
        
        if ix % 50 == 0:
            index = 2
            prediction = model.generate(
                embeddings[index].unsqueeze(0), 
                QnAtokens[index].unsqueeze(0)[:, :QTokenLength[index].item() + 2 ]
            )
            print("------------Questionsdf text = -------------------")
            print(''.join(model.tokenizer.batch_decode(QnAtokens[index])).rstrip('<|endoftext|>').rstrip("\n"))
            print("------------Teacher forced predictions text = -------------------")
            print(prediction[0].rstrip('<|endoftext|>').rstrip("\n")[:200])
        optimizer.zero_grad()
        loss, predictions = model(embeddings, QnAtokens, QTokenLength, QnA_length)
        if ix % 10 == 0: print(f"{epoch=} Step={ix + 1}, Loss={loss.item()}")
        loss.backward()
        optimizer.step()       
    model.phi2Model.save_pretrained("stage2_v3")
    torch.save(
        model.ll1.state_dict(), 
        "stage_2_proj_head_v3.pth"
    )

------------Questionsdf text = -------------------
Question: Where is the cat positioned in relation to the person's chest?Answer: The cat is resting between the person's chest and the keyboard.
------------Teacher forced predictions text = -------------------
<|endoftext|> The cat is positioned on the person's chest.Question: What is the cat doing?Answer:
epoch=0 Step=1, Loss=1.286075234413147
epoch=0 Step=11, Loss=1.2231391668319702
epoch=0 Step=21, Loss=1.44710111618042
epoch=0 Step=31, Loss=1.5892454385757446
epoch=0 Step=41, Loss=1.100769281387329
------------Questionsdf text = -------------------
Question: What are the main elements of the bathroom?Answer: The main elements of the bathroom include a clawfoot tub, a white toilet, and the damaged wall with peeling paint. Additionally, there is a large hole in the ceiling that contributes to the overall rundown appearance of the room.
------------Teacher forced predictions text = -------------------
<|endoftext|> The main elements o

In [None]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()