<a href="https://colab.research.google.com/github/walnashgit/S30-Capstone/blob/main/S30ProjectionLayerTrainingipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import userdata
import os

# Copy the Kaggle credentials from the Colab secret store into environment
# variables for the rest of the session (presumably consumed by the `kaggle`
# CLI call in the next cell — confirm against the Kaggle API docs).
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')

In [None]:
#!/bin/bash
# Download the dataset archive `walnash/coco2014embeddingsm2` from Kaggle.
# The recorded output shows it lands in /content as coco2014embeddingsm2.zip.
!kaggle datasets download walnash/coco2014embeddingsm2


Dataset URL: https://www.kaggle.com/datasets/walnash/coco2014embeddingsm2
License(s): unknown
Downloading coco2014embeddingsm2.zip to /content
100% 6.64G/6.65G [01:30<00:00, 124MB/s]
100% 6.65G/6.65G [01:30<00:00, 78.5MB/s]


In [None]:
# Extract the archive; per the recorded output it contains
# captions_train2014.json and coco2014_clip_embeddings_m2.h5.
!unzip /content/coco2014embeddingsm2.zip

Archive:  /content/coco2014embeddingsm2.zip
  inflating: captions_train2014.json  
  inflating: coco2014_clip_embeddings_m2.h5  


In [None]:
# Run configuration shared by the cells below.
# Keys read by code visible in this notebook: num_epoch, saved_model, resume,
# checkpoint_dir, clip_dim, phi_dim and image_token. The remaining keys are
# not read anywhere in the visible cells.
cfg = {
    'num_epoch': 1,
    'last_epoch': 2,
    'saved_model': '',  # model name with dir
    'resume': True,
    'data_dir': '../data',
    'checkpoint_dir': './checkpoint',
    'max_seqlen': 80,
    'batch_size': 2,
    'vision_projector_file': '',
    'validation_phase': False,
    'clip_dim': 768,  # 512,
    'phi_dim': 2560,
    'image_token': "<image>",
}

# Sentinel id that marks the image position inside a prompt-id sequence; it is
# inserted by tokenizer_image_token and searched for when input embeddings are built.
IMAGE_TOKEN_INDEX = -200

In [None]:
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt`` and put ``image_token_index`` where each '<image>' marker was.

    The prompt is split on the literal '<image>', every piece is tokenized on
    its own, and the pieces are re-joined with ``image_token_index`` between
    them. If the first piece starts with ``tokenizer.bos_token_id``, that id is
    kept once at the front and the leading id of every piece is dropped.

    Args:
        prompt: text that may contain '<image>' markers.
        tokenizer: callable returning an object with ``input_ids``; must expose
            ``bos_token_id``.
        image_token_index: id to insert for each marker.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a long tensor.

    Returns:
        The id sequence as a list, or as a ``torch.long`` tensor when
        ``return_tensors == 'pt'``.

    Raises:
        ValueError: for any other non-None ``return_tensors`` value.
    """
    chunk_ids = [tokenizer(piece).input_ids for piece in prompt.split('<image>')]

    starts_with_bos = (
        len(chunk_ids) > 0
        and len(chunk_ids[0]) > 0
        and chunk_ids[0][0] == tokenizer.bos_token_id
    )
    if starts_with_bos:
        skip = 1
        input_ids = [chunk_ids[0][0]]
    else:
        skip = 0
        input_ids = []

    # The separator carries `skip` extra leading ids so that slicing with
    # [skip:] leaves exactly one image id, like it strips BOS from text pieces.
    separator = [image_token_index] * (skip + 1)
    for position, ids in enumerate(chunk_ids):
        if position > 0:
            input_ids.extend(separator[skip:])
        input_ids.extend(ids[skip:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')

In [None]:
import json
import random

import torch
from torch.utils.data import Dataset, Sampler
import h5py


class ProjectionLayerDataset2(Dataset):
    """Caption dataset backed by an HDF5 file of precomputed image embeddings.

    Each item pairs one caption annotation with the stored embedding of its
    image. Only annotations whose ``image_id`` is a key of the HDF5 file are
    kept.

    Args:
        embedding_file: path of the HDF5 file, keyed by image id (as string).
        caption_file: path of a JSON file with an ``'annotations'`` list whose
            entries have ``'image_id'`` and ``'caption'``.
        tokenizer: tokenizer providing ``encode``, ``eos_token_id``,
            ``pad_token_id`` and whatever ``tokenizer_image_token`` needs.
        max_length: total sequence budget (prompt + image rows + caption + EOS).
    """

    def __init__(self, embedding_file, caption_file, tokenizer, max_length=80):
        # The HDF5 handle stays open; __getitem__ reads embeddings from it lazily.
        self.embedding_file = h5py.File(embedding_file, 'r')
        self.image_ids = set(self.embedding_file.keys())
        self.caption_file = caption_file
        # Must run after image_ids/caption_file are set: it filters on both.
        self.captions = self.get_caption_dict()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_token = cfg['image_token']

    def __len__(self):
        """Number of caption annotations that have a stored embedding."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the embedding, prompt ids and padded label ids for caption ``idx``.

        Returns:
            dict with keys ``'image_embedding'``, ``'prompt_ids'`` (int32) and
            ``'labels'`` (int32, caption ids + EOS + padding).
        """
        caption = self.captions[idx]
        image_id = str(caption.get('image_id'))
        image_embedding = torch.tensor(self.embedding_file[image_id][()])
        # Drop the first row (original note: leaves 49 x 768; presumably the
        # first row is the CLS token — TODO confirm).
        image_embedding = image_embedding[1:, :]
        text = caption.get('caption')
        prompt = f'{self.image_token} caption: '

        prompt_ids = torch.tensor(tokenizer_image_token(prompt, tokenizer=self.tokenizer), dtype=torch.int32)
        labels = self.tokenizer.encode(text)
        labels = self.get_final_label_ids(labels, prompt_ids, image_embedding)

        return {
            'image_embedding': image_embedding,
            'prompt_ids': prompt_ids,
            'labels': labels,
        }

    def get_final_label_ids(self, labels, prompt_ids, image_embedding):
        """Truncate/pad caption ids so prompt + image + caption + EOS fills ``max_length``.

        Args:
            labels: list of caption token ids.
            prompt_ids: 1-D tensor of prompt ids, including the image placeholder.
            image_embedding: tensor whose first dimension is the number of image rows.

        Returns:
            int32 tensor: caption ids (truncated if needed), one EOS id, then
            pad ids. Its length is the same for every caption given equal
            prompt and image sizes.

        Raises:
            ValueError: if ``max_length`` leaves no room for even the EOS id.
                Previously a negative budget was used as a slice bound and
                silently produced a wrong-length tensor.
        """
        # The placeholder in prompt_ids is replaced by the image rows, hence -1.
        prefix_len = prompt_ids.size(0) + image_embedding.size(0) - 1
        # Room left for caption tokens after the prefix and the mandatory EOS.
        caption_budget = self.max_length - prefix_len - 1
        if caption_budget < 0:
            raise ValueError(
                f'max_length={self.max_length} is too small for a prefix of '
                f'{prefix_len} positions plus the EOS token'
            )

        labels = labels[:caption_budget]
        pad_token_count = caption_budget - len(labels)

        return torch.cat(
            [
                torch.tensor(labels, dtype=torch.int32),
                torch.tensor([self.tokenizer.eos_token_id], dtype=torch.int32),
                torch.tensor([self.tokenizer.pad_token_id] * pad_token_count, dtype=torch.int32)
            ],
            dim=0
        )

    def get_caption_dict(self):
        """Load the caption JSON and keep annotations whose image has an embedding."""
        with open(self.caption_file, 'r') as f:
            caption_file_json = json.load(f)
        return [
            cap for cap in caption_file_json['annotations']
            if str(cap['image_id']) in self.image_ids
        ]

In [None]:
import torch

class ProjectionPreTrainer(torch.nn.Module):
    """Wraps a projection layer and a language model to compute a captioning loss.

    The image embedding is projected, spliced into the embedded prompt at the
    image placeholder, followed by the embedded caption ids, and run through
    the language model. The loss is cross entropy between the logits that
    predict each caption position and the caption ids.

    Args:
        projectionModel: module mapping image embeddings to the language
            model's embedding width.
        phi_model: language model exposing ``get_input_embeddings()`` and
            accepting ``inputs_embeds=...``; its output is indexed with ``'logits'``.
        phi_tokenizer: stored on the instance; not used by the methods below.
        device: device the language model is moved to and tensors are placed on.
    """

    def __init__(self, projectionModel, phi_model, phi_tokenizer, device='cuda'):
        super().__init__()
        self.device = device
        self.projectionModel = projectionModel
        self.phi_model = phi_model
        self.phi_tokenizer = phi_tokenizer
        self.phi_model.to(device)
        self.phi_embeddings = self.phi_model.get_input_embeddings()
        # Spelled torch.nn so this class does not depend on an `nn` alias
        # imported in a different cell.
        # NOTE(review): no ignore_index is set and forward() scores every
        # position of label_ids, so padding ids contribute to the loss —
        # confirm that this is intended.
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, image_embedding, label_ids, prompt_ids):
        """Return ``(logits, loss)`` for one batch.

        Args:
            image_embedding: batch of image embeddings fed to ``projectionModel``.
            label_ids: (batch, caption_len) integer caption ids; used both as
                model input (embedded) and as loss targets.
            prompt_ids: (batch, prompt_len) integer ids containing
                ``IMAGE_TOKEN_INDEX`` once per row.

        Returns:
            Tuple of the full logits tensor and the scalar cross-entropy loss.
        """
        projected_clip = self.projectionModel(image_embedding)
        new_prompt_embeds = self.prepare_input_embed(projected_clip, prompt_ids)
        # The logit at position i predicts token i + 1, so the first caption
        # token is predicted by the last prompt position.
        ie_size = new_prompt_embeds.size(1) - 1

        phi_text_embedding = self.phi_embeddings(label_ids)
        combined_embeddings = torch.cat([new_prompt_embeds, phi_text_embedding], dim=1)
        phi_outputs = self.phi_model(inputs_embeds=combined_embeddings)

        logits = phi_outputs['logits']

        caption_logits = logits[:, ie_size:ie_size + label_ids.size(1), :]
        # Cast and place in one step; `.type(torch.LongTensor)` would copy the
        # tensor to the CPU first and back again on every call.
        targets = label_ids.to(device=self.device, dtype=torch.long)

        loss_val = self.loss(
            caption_logits.reshape(-1, caption_logits.size(-1)),
            targets.reshape(-1)
        )

        return logits, loss_val

    def prepare_input_embed(self, image_embeds, prompt_ids):
        """Embed each prompt and replace its image placeholder with the image rows.

        Args:
            image_embeds: (batch, image_rows, dim) projected image embeddings.
            prompt_ids: (batch, prompt_len) ids containing ``IMAGE_TOKEN_INDEX``.

        Returns:
            Tensor of shape (batch, prompt_len - 1 + image_rows, dim). All rows
            must come out the same length because they are stacked.

        Raises:
            ValueError: if a prompt row contains no ``IMAGE_TOKEN_INDEX``.
        """
        new_input_embeds = []
        for batch_idx, cur_prompt_ids in enumerate(prompt_ids):
            image_token_indices = torch.where(cur_prompt_ids == IMAGE_TOKEN_INDEX)[0]
            if image_token_indices.numel() == 0:
                raise ValueError(
                    f'prompt {batch_idx} does not contain the image token id {IMAGE_TOKEN_INDEX}'
                )
            # Only the first placeholder is expanded.
            image_token_start = int(image_token_indices[0])

            pieces = [
                self.phi_embeddings(cur_prompt_ids[:image_token_start]),
                image_embeds[batch_idx],
                self.phi_embeddings(cur_prompt_ids[image_token_start + 1:]),
            ]
            pieces = [piece.to(self.device) for piece in pieces]
            new_input_embeds.append(torch.cat(pieces, dim=0))

        return torch.stack(new_input_embeds, dim=0)

In [None]:
class ProjectionLayer(torch.nn.Module):
    """Single linear map from the image-embedding width to the language-model width."""

    def __init__(self, clip_dim, phi_dim):
        """Create the layer.

        Args:
            clip_dim: size of the last dimension of the incoming embeddings.
            phi_dim: size of the last dimension of the projected output.
        """
        super().__init__()
        # The attribute name is part of the checkpoint format (state_dict keys).
        self.linear = torch.nn.Linear(in_features=clip_dim, out_features=phi_dim)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.linear(x)
        return projected

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc


# Shared tokenizer for the dataset and the trainer.
phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
# Reuse EOS as the padding token, so pad_token_id and eos_token_id are the
# same id everywhere below.
phi_tokenizer.pad_token = phi_tokenizer.eos_token


def pretrain_projection(model, dataloader, device):
    """Train the trainable parameters of ``model`` on ``dataloader``.

    Reads ``cfg['resume']``, ``cfg['saved_model']`` and ``cfg['num_epoch']``.
    Whenever a batch loss is lower than the best seen so far, a checkpoint is
    written through ``save_model``.

    Args:
        model: module called as ``model(clip_embedding, label_ids, prompt_ids)``
            returning ``(logits, loss)``; must expose ``projectionModel``.
        dataloader: yields dict batches with keys ``'image_embedding'``,
            ``'labels'`` and ``'prompt_ids'``.
        device: device string; used for tensor placement and as the autocast
            ``device_type``.

    Returns:
        None. Progress and the per-epoch average loss are printed.
    """
    model = model.to(device)
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

    last_epoch = 0
    prev_loss = 10000
    if cfg['resume'] and not cfg['saved_model']:
        # Nothing to resume from: train from scratch instead of failing inside
        # torch.load('') with an unrelated-looking error.
        print("cfg['resume'] is set but cfg['saved_model'] is empty; training from scratch")
    elif cfg['resume']:
        # map_location lets a checkpoint written on one device load on another.
        saved_model = torch.load(cfg['saved_model'], map_location=device)
        model.projectionModel.load_state_dict(saved_model['model_state_dict'])
        optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        # Checkpoints written by save_model store the epoch under 'epoch';
        # accept both spellings so resuming does not raise KeyError.
        last_epoch = saved_model.get('last_epoch', saved_model.get('epoch', 0))
        prev_loss = saved_model['loss']
        print('prev_loss: ', prev_loss)
        print('last_epoch: ', last_epoch)

    # Starts at -1 and is incremented at the end of each step, so the first
    # batch is logged as step -1.
    step_count = -1
    num_epochs = cfg['num_epoch']

    bestLoss = 0
    bestStep = 0

    for epoch in range(last_epoch, num_epochs):
        total_loss = 0
        batch_iterator = tqdm(dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:
            optimizer.zero_grad()

            clip_embedding = batch['image_embedding'].to(device).requires_grad_(True)
            label_ids = batch['labels'].to(device)
            prompt_ids = batch['prompt_ids'].to(device)

            # Only the forward pass runs under autocast; the PyTorch AMP
            # guidance is to call backward() outside the autocast context.
            with torch.autocast(device_type=device, dtype=torch.float16):
                logits, loss = model(clip_embedding, label_ids, prompt_ids)

            loss.backward()

            # Read the scalar once; every .item() call synchronises with the device.
            loss_value = loss.item()
            total_loss += loss_value

            if step_count == -1:
                print(f"\n Epoch {epoch + 1}, step: {step_count}, loss: {loss_value}, total loss: {total_loss}")

            if loss_value < prev_loss:
                bestLoss = loss_value
                bestStep = step_count
                print(f"\n Epoch {epoch + 1}, step: {step_count}, Loss: {loss_value}, total loss: {total_loss}")
                # Saved before optimizer.step(), so the stored weights are the
                # ones that produced the logged loss.
                save_model(epoch, model, loss_value, optimizer, step_count)
                prev_loss = loss_value
            elif step_count > 0 and step_count % 100 == 0:
                print(f"\n Epoch {epoch + 1}, step: {step_count}, loss: {loss_value}, total loss: {total_loss}")

            step_count += 1
            optimizer.step()

            gc.collect()
            torch.cuda.empty_cache()

        if device == 'cuda':
            torch.cuda.empty_cache()
        elif device == 'mps':
            torch.mps.empty_cache()

        # Guard against an empty dataloader, which would divide by zero.
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch + 1}, Average Loss: {avg_loss}")
        print(f"BestLoss: {bestLoss}, BestStep: {bestStep}")


def save_model(epoch, model, loss, optimizer, step_count):
    """Write a checkpoint of the projection layer and the optimizer state.

    The file is ``<cfg['checkpoint_dir']>/projectionModel_ckpt_<epoch>.pth``,
    so later saves in the same epoch overwrite earlier ones.

    Args:
        epoch: epoch index, stored in the checkpoint and used in the file name.
        model: object exposing ``projectionModel`` (only its state is saved).
        loss: loss value recorded with the checkpoint.
        optimizer: optimizer whose ``state_dict()`` is stored.
        step_count: step counter recorded with the checkpoint.
    """
    import os  # local import keeps this function independent of other cells

    print("saving model")
    checkpoint_dir = cfg['checkpoint_dir']
    if checkpoint_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        # Same value under the key the resume code reads (`last_epoch`).
        'last_epoch': epoch,
        'model_state_dict': model.projectionModel.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'step_count': step_count,
    }, '%s/projectionModel_ckpt_%s.pth' % (checkpoint_dir, epoch))


def train(embed_file='/content/coco2014_clip_embeddings_m2.h5',
          caption_file='/content/captions_train2014.json',
          batch_size=16):
    """Build the dataset, dataloader and model, then run projection pre-training.

    Calling ``train()`` with no arguments behaves as before; the parameters
    only expose values that used to be hard-coded.

    Args:
        embed_file: path of the HDF5 file with the image embeddings.
        caption_file: path of the caption annotation JSON.
        batch_size: batch size for the DataLoader.
    """
    # Prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'

    gc.collect()
    torch.cuda.empty_cache()

    # Seed before the model is built and before the shuffled loader is iterated.
    torch.manual_seed(14)

    dataset = ProjectionLayerDataset2(embed_file, caption_file, phi_tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=False)

    model = get_model(device)
    pretrain_projection(model, dataloader, device)


def get_model(device):
    """Build the trainer: a fresh projection layer plus a frozen Phi-2 model.

    Args:
        device: device passed on to ``ProjectionPreTrainer``.

    Returns:
        A ``ProjectionPreTrainer`` whose only parameters requiring gradients
        are those of the projection layer.
    """
    projectionModel = ProjectionLayer(cfg['clip_dim'], cfg['phi_dim'])
    phi_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    projectionModel.train()
    # Phi-2 is used for inference only: evaluation mode, no gradients.
    phi_model.eval()
    phi_model.requires_grad_(False)

    return ProjectionPreTrainer(projectionModel, phi_model, phi_tokenizer, device)


def find_image_toke_pos(self, input_ids):
    """Return the positions in ``input_ids`` that hold the image token's id.

    Args:
        self: object exposing ``phi_tokenizer`` with ``convert_tokens_to_ids``.
        input_ids: sequence of token ids, converted with ``torch.tensor``.

    Returns:
        Tensor with the indices (along the first dimension) of matching entries.
    """
    image_token = cfg['image_token']
    placeholder_id = self.phi_tokenizer.convert_tokens_to_ids(image_token)
    is_placeholder = torch.tensor(input_ids) == placeholder_id
    return torch.where(is_placeholder)[0]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Start pre-training with the default arguments.
# The next line is a note only and is not executed:
# PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
train()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Processing Epoch 00:   0%|          | 0/15633 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)



 Epoch 1, step: -1, Loss: 6.3405938148498535, total loss: 6.3405938148498535
saving model


Processing Epoch 00:   0%|          | 23/15633 [00:32<6:00:24,  1.39s/it]


 Epoch 1, step: 22, Loss: 6.311634063720703, total loss: 160.87029933929443
saving model


Processing Epoch 00:   0%|          | 27/15633 [00:38<6:14:41,  1.44s/it]


 Epoch 1, step: 26, Loss: 6.302954196929932, total loss: 186.96160554885864
saving model


Processing Epoch 00:   0%|          | 30/15633 [00:42<6:19:45,  1.46s/it]


 Epoch 1, step: 29, Loss: 6.271252155303955, total loss: 206.28550100326538
saving model


Processing Epoch 00:   0%|          | 31/15633 [00:44<6:19:13,  1.46s/it]


 Epoch 1, step: 30, Loss: 6.261171817779541, total loss: 212.54667282104492
saving model


Processing Epoch 00:   0%|          | 32/15633 [00:45<6:18:30,  1.46s/it]


 Epoch 1, step: 31, Loss: 6.006778240203857, total loss: 218.55345106124878
saving model


Processing Epoch 00:   0%|          | 35/15633 [00:50<6:10:56,  1.43s/it]


 Epoch 1, step: 34, Loss: 5.978465557098389, total loss: 236.8836374282837
saving model


Processing Epoch 00:   0%|          | 65/15633 [01:34<6:15:02,  1.45s/it]


 Epoch 1, step: 64, Loss: 5.9445037841796875, total loss: 425.28772592544556
saving model


Processing Epoch 00:   0%|          | 69/15633 [01:40<6:16:24,  1.45s/it]


 Epoch 1, step: 68, Loss: 5.826568603515625, total loss: 449.4175395965576
saving model


Processing Epoch 00:   0%|          | 74/15633 [01:47<6:20:08,  1.47s/it]


 Epoch 1, step: 73, Loss: 5.632808685302734, total loss: 479.26116943359375
saving model


Processing Epoch 00:   0%|          | 77/15633 [01:52<6:12:44,  1.44s/it]


 Epoch 1, step: 76, Loss: 5.629109859466553, total loss: 496.3482346534729
saving model


Processing Epoch 00:   0%|          | 78/15633 [01:53<6:15:28,  1.45s/it]


 Epoch 1, step: 77, Loss: 5.431206703186035, total loss: 501.77944135665894
saving model


Processing Epoch 00:   1%|          | 88/15633 [02:07<6:08:56,  1.42s/it]


 Epoch 1, step: 87, Loss: 5.277965068817139, total loss: 557.8313755989075
saving model


Processing Epoch 00:   1%|          | 95/15633 [02:18<6:21:19,  1.47s/it]


 Epoch 1, step: 94, Loss: 5.2776570320129395, total loss: 595.5243182182312
saving model


Processing Epoch 00:   1%|          | 97/15633 [02:21<6:15:00,  1.45s/it]


 Epoch 1, step: 96, Loss: 5.104087829589844, total loss: 606.0059957504272
saving model


Processing Epoch 00:   1%|          | 101/15633 [02:26<6:09:23,  1.43s/it]


 Epoch 1, step: 100, Loss: 5.080832481384277, total loss: 626.9612984657288
saving model


Processing Epoch 00:   1%|          | 107/15633 [02:35<6:20:35,  1.47s/it]


 Epoch 1, step: 106, Loss: 5.023874759674072, total loss: 657.9963803291321
saving model


Processing Epoch 00:   1%|          | 109/15633 [02:38<6:15:35,  1.45s/it]


 Epoch 1, step: 108, Loss: 4.921847820281982, total loss: 667.951765537262
saving model


Processing Epoch 00:   1%|          | 112/15633 [02:43<6:13:19,  1.44s/it]


 Epoch 1, step: 111, Loss: 4.881966590881348, total loss: 682.9150495529175
saving model


Processing Epoch 00:   1%|          | 120/15633 [02:54<6:14:14,  1.45s/it]


 Epoch 1, step: 119, Loss: 4.844529151916504, total loss: 723.0004224777222
saving model


Processing Epoch 00:   1%|          | 121/15633 [02:56<6:17:15,  1.46s/it]


 Epoch 1, step: 120, Loss: 4.819287300109863, total loss: 727.819709777832
saving model


Processing Epoch 00:   1%|          | 123/15633 [02:59<6:15:53,  1.45s/it]


 Epoch 1, step: 122, Loss: 4.528977394104004, total loss: 737.2345404624939
saving model


Processing Epoch 00:   1%|          | 130/15633 [03:09<6:16:11,  1.46s/it]


 Epoch 1, step: 129, Loss: 4.437427043914795, total loss: 769.9990882873535
saving model


Processing Epoch 00:   1%|          | 132/15633 [03:12<6:14:29,  1.45s/it]


 Epoch 1, step: 131, Loss: 4.425318241119385, total loss: 779.2611742019653
saving model


Processing Epoch 00:   1%|          | 135/15633 [03:17<6:16:30,  1.46s/it]


 Epoch 1, step: 134, Loss: 4.163666248321533, total loss: 792.5036401748657
saving model


Processing Epoch 00:   1%|          | 145/15633 [03:31<6:09:52,  1.43s/it]


 Epoch 1, step: 144, Loss: 4.102085590362549, total loss: 836.3390684127808
saving model


Processing Epoch 00:   1%|          | 147/15633 [03:34<6:30:38,  1.51s/it]


 Epoch 1, step: 146, Loss: 4.033354759216309, total loss: 844.4830446243286
saving model


Processing Epoch 00:   1%|          | 150/15633 [03:39<6:29:49,  1.51s/it]


 Epoch 1, step: 149, Loss: 3.897423028945923, total loss: 856.8262150287628
saving model


Processing Epoch 00:   1%|          | 152/15633 [03:42<6:19:12,  1.47s/it]


 Epoch 1, step: 151, Loss: 3.8064823150634766, total loss: 864.6809175014496
saving model


Processing Epoch 00:   1%|          | 155/15633 [03:46<6:11:19,  1.44s/it]


 Epoch 1, step: 154, Loss: 3.7323596477508545, total loss: 876.5037198066711
saving model


Processing Epoch 00:   1%|          | 159/15633 [03:52<6:23:09,  1.49s/it]


 Epoch 1, step: 158, Loss: 3.51489520072937, total loss: 891.8760316371918
saving model


Processing Epoch 00:   1%|          | 163/15633 [03:58<6:15:43,  1.46s/it]


 Epoch 1, step: 162, Loss: 3.486243963241577, total loss: 906.9988663196564
saving model


Processing Epoch 00:   1%|          | 168/15633 [04:05<6:10:47,  1.44s/it]


 Epoch 1, step: 167, Loss: 3.3165552616119385, total loss: 924.9638092517853
saving model


Processing Epoch 00:   1%|          | 175/15633 [04:15<6:10:06,  1.44s/it]


 Epoch 1, step: 174, Loss: 3.0678493976593018, total loss: 949.0615022182465
saving model


Processing Epoch 00:   1%|          | 192/15633 [04:40<6:19:46,  1.48s/it]


 Epoch 1, step: 191, Loss: 2.691357374191284, total loss: 1004.8392353057861
saving model


Processing Epoch 00:   1%|▏         | 200/15633 [04:52<6:06:31,  1.42s/it]


 Epoch 1, step: 199, Loss: 2.6190662384033203, total loss: 1027.6461079120636
saving model


Processing Epoch 00:   1%|▏         | 201/15633 [04:53<6:14:06,  1.45s/it]


 Epoch 1, step: 200, loss: 2.7375259399414062, total loss: 1030.383633852005


Processing Epoch 00:   1%|▏         | 202/15633 [04:55<6:17:17,  1.47s/it]


 Epoch 1, step: 201, Loss: 2.5711429119110107, total loss: 1032.954776763916
saving model


Processing Epoch 00:   1%|▏         | 214/15633 [05:12<6:18:46,  1.47s/it]


 Epoch 1, step: 213, Loss: 2.452059507369995, total loss: 1065.6575510501862
saving model


Processing Epoch 00:   1%|▏         | 215/15633 [05:14<6:29:28,  1.52s/it]


 Epoch 1, step: 214, Loss: 2.389263391494751, total loss: 1068.046814441681
saving model


Processing Epoch 00:   1%|▏         | 219/15633 [05:20<6:12:20,  1.45s/it]


 Epoch 1, step: 218, Loss: 2.310533285140991, total loss: 1078.0476253032684
saving model


Processing Epoch 00:   1%|▏         | 226/15633 [05:30<6:20:58,  1.48s/it]


 Epoch 1, step: 225, Loss: 2.30560564994812, total loss: 1095.5869524478912
saving model


Processing Epoch 00:   1%|▏         | 227/15633 [05:32<6:29:22,  1.52s/it]


 Epoch 1, step: 226, Loss: 2.148407459259033, total loss: 1097.7353599071503
saving model


Processing Epoch 00:   2%|▏         | 301/15633 [07:18<6:04:14,  1.43s/it]


 Epoch 1, step: 300, loss: 2.5168445110321045, total loss: 1273.4560492038727


Processing Epoch 00:   3%|▎         | 401/15633 [09:43<6:02:29,  1.43s/it]


 Epoch 1, step: 400, loss: 2.642922878265381, total loss: 1526.3495831489563


Processing Epoch 00:   3%|▎         | 501/15633 [12:07<6:03:14,  1.44s/it]


 Epoch 1, step: 500, loss: 2.5447840690612793, total loss: 1786.315841436386


Processing Epoch 00:   3%|▎         | 520/15633 [12:34<6:01:04,  1.43s/it]


 Epoch 1, step: 519, Loss: 2.134439706802368, total loss: 1832.3478112220764
saving model


Processing Epoch 00:   3%|▎         | 543/15633 [13:08<5:56:36,  1.42s/it]


 Epoch 1, step: 542, Loss: 2.11423921585083, total loss: 1886.8277423381805
saving model


Processing Epoch 00:   4%|▍         | 601/15633 [14:31<5:53:27,  1.41s/it]


 Epoch 1, step: 600, loss: 2.805877923965454, total loss: 2027.1102902889252


Processing Epoch 00:   4%|▍         | 638/15633 [15:25<6:14:58,  1.50s/it]


 Epoch 1, step: 637, Loss: 2.095550298690796, total loss: 2114.5133740901947
saving model


Processing Epoch 00:   4%|▍         | 651/15633 [15:43<6:00:40,  1.44s/it]


 Epoch 1, step: 650, Loss: 2.0776944160461426, total loss: 2144.2186374664307
saving model


Processing Epoch 00:   4%|▍         | 660/15633 [15:56<6:08:31,  1.48s/it]


 Epoch 1, step: 659, Loss: 2.069135904312134, total loss: 2164.163196325302
saving model


Processing Epoch 00:   4%|▍         | 686/15633 [16:34<5:53:34,  1.42s/it]


 Epoch 1, step: 685, Loss: 1.9374605417251587, total loss: 2224.3225506544113
saving model


Processing Epoch 00:   4%|▍         | 701/15633 [16:55<5:53:54,  1.42s/it]


 Epoch 1, step: 700, loss: 2.2123279571533203, total loss: 2259.009969830513


Processing Epoch 00:   5%|▌         | 801/15633 [19:20<5:52:27,  1.43s/it]


 Epoch 1, step: 800, loss: 2.430595636367798, total loss: 2516.0721954107285


Processing Epoch 00:   6%|▌         | 901/15633 [21:44<5:46:19,  1.41s/it]


 Epoch 1, step: 900, loss: 2.2651784420013428, total loss: 2761.2363471984863


Processing Epoch 00:   6%|▋         | 1001/15633 [24:08<5:45:30,  1.42s/it]


 Epoch 1, step: 1000, loss: 2.3303236961364746, total loss: 2993.9928278923035


Processing Epoch 00:   7%|▋         | 1101/15633 [26:33<5:44:56,  1.42s/it]


 Epoch 1, step: 1100, loss: 2.759902238845825, total loss: 3254.091414451599


Processing Epoch 00:   8%|▊         | 1201/15633 [28:57<5:41:09,  1.42s/it]


 Epoch 1, step: 1200, loss: 2.477870225906372, total loss: 3526.432317495346


Processing Epoch 00:   8%|▊         | 1279/15633 [30:50<5:39:43,  1.42s/it]


 Epoch 1, step: 1278, Loss: 1.9176008701324463, total loss: 3702.0469992160797
saving model


Processing Epoch 00:   8%|▊         | 1301/15633 [31:21<5:38:29,  1.42s/it]


 Epoch 1, step: 1300, loss: 2.1507627964019775, total loss: 3749.628192424774


Processing Epoch 00:   8%|▊         | 1302/15633 [31:23<5:37:06,  1.41s/it]


 Epoch 1, step: 1301, Loss: 1.9026405811309814, total loss: 3751.530833005905
saving model


Processing Epoch 00:   8%|▊         | 1328/15633 [32:00<5:42:18,  1.44s/it]


 Epoch 1, step: 1327, Loss: 1.8375381231307983, total loss: 3807.038556456566
saving model


Processing Epoch 00:   9%|▉         | 1399/15633 [33:42<5:43:40,  1.45s/it]


 Epoch 1, step: 1398, Loss: 1.7019420862197876, total loss: 3955.4910390377045
saving model


Processing Epoch 00:   9%|▉         | 1401/15633 [33:45<5:40:14,  1.43s/it]


 Epoch 1, step: 1400, loss: 1.9447224140167236, total loss: 3959.2810691595078


Processing Epoch 00:  10%|▉         | 1494/15633 [35:58<5:30:32,  1.40s/it]


 Epoch 1, step: 1493, Loss: 1.5926201343536377, total loss: 4144.490428566933
saving model


Processing Epoch 00:  10%|▉         | 1501/15633 [36:08<5:36:56,  1.43s/it]


 Epoch 1, step: 1500, loss: 2.024434804916382, total loss: 4157.5078048706055


Processing Epoch 00:  10%|█         | 1601/15633 [38:32<5:35:52,  1.44s/it]


 Epoch 1, step: 1600, loss: 2.0108182430267334, total loss: 4351.314712524414


Processing Epoch 00:  11%|█         | 1661/15633 [39:57<5:27:14,  1.41s/it]


 Epoch 1, step: 1660, Loss: 1.5786371231079102, total loss: 4464.410191297531
saving model


Processing Epoch 00:  11%|█         | 1701/15633 [40:55<5:39:10,  1.46s/it]


 Epoch 1, step: 1700, loss: 2.1191763877868652, total loss: 4541.030363082886


Processing Epoch 00:  12%|█▏        | 1801/15633 [43:18<5:37:14,  1.46s/it]


 Epoch 1, step: 1800, loss: 1.7639840841293335, total loss: 4730.480536103249


Processing Epoch 00:  12%|█▏        | 1901/15633 [45:41<5:31:00,  1.45s/it]


 Epoch 1, step: 1900, loss: 2.221797227859497, total loss: 4924.5789052248


Processing Epoch 00:  12%|█▏        | 1948/15633 [46:48<5:36:21,  1.47s/it]


 Epoch 1, step: 1947, Loss: 1.561288833618164, total loss: 5014.545948028564
saving model


Processing Epoch 00:  13%|█▎        | 1961/15633 [47:07<5:27:32,  1.44s/it]


 Epoch 1, step: 1960, Loss: 1.5380991697311401, total loss: 5038.081069231033
saving model


Processing Epoch 00:  13%|█▎        | 2001/15633 [48:04<5:21:30,  1.42s/it]


 Epoch 1, step: 2000, loss: 1.6995036602020264, total loss: 5111.192231059074


Processing Epoch 00:  13%|█▎        | 2088/15633 [50:09<5:19:38,  1.42s/it]


 Epoch 1, step: 2087, Loss: 1.5000646114349365, total loss: 5268.89532494545
saving model


Processing Epoch 00:  13%|█▎        | 2101/15633 [50:27<5:16:28,  1.40s/it]


 Epoch 1, step: 2100, loss: 1.904961347579956, total loss: 5292.7475645542145


Processing Epoch 00:  14%|█▎        | 2116/15633 [50:49<5:26:26,  1.45s/it]


 Epoch 1, step: 2115, Loss: 1.4183470010757446, total loss: 5318.8591248989105
saving model


Processing Epoch 00:  14%|█▍        | 2201/15633 [52:51<5:14:47,  1.41s/it]


 Epoch 1, step: 2200, loss: 1.9929684400558472, total loss: 5479.250538945198


Processing Epoch 00:  15%|█▍        | 2301/15633 [55:14<5:18:50,  1.43s/it]


 Epoch 1, step: 2300, loss: 1.779798150062561, total loss: 5670.819621801376


Processing Epoch 00:  15%|█▌        | 2401/15633 [57:38<5:18:06,  1.44s/it]


 Epoch 1, step: 2400, loss: 1.939523458480835, total loss: 5850.905344724655


Processing Epoch 00:  16%|█▌        | 2501/15633 [1:00:00<5:21:30,  1.47s/it]


 Epoch 1, step: 2500, loss: 1.8195264339447021, total loss: 6031.048939228058


Processing Epoch 00:  17%|█▋        | 2601/15633 [1:02:22<5:09:46,  1.43s/it]


 Epoch 1, step: 2600, loss: 1.7085669040679932, total loss: 6208.368551850319


Processing Epoch 00:  17%|█▋        | 2701/15633 [1:04:44<4:59:15,  1.39s/it]


 Epoch 1, step: 2700, loss: 2.0394225120544434, total loss: 6384.385156869888


Processing Epoch 00:  18%|█▊        | 2801/15633 [1:07:06<4:59:36,  1.40s/it]


 Epoch 1, step: 2800, loss: 1.8898820877075195, total loss: 6584.241739392281


Processing Epoch 00:  19%|█▊        | 2901/15633 [1:09:28<4:59:31,  1.41s/it]


 Epoch 1, step: 2900, loss: 2.0910255908966064, total loss: 6770.758103966713


Processing Epoch 00:  19%|█▉        | 3001/15633 [1:11:50<5:01:33,  1.43s/it]


 Epoch 1, step: 3000, loss: 1.7673197984695435, total loss: 6956.226130723953


Processing Epoch 00:  20%|█▉        | 3101/15633 [1:14:12<5:04:53,  1.46s/it]


 Epoch 1, step: 3100, loss: 1.7282720804214478, total loss: 7147.567577600479


Processing Epoch 00:  20%|██        | 3201/15633 [1:16:34<4:51:35,  1.41s/it]


 Epoch 1, step: 3200, loss: 2.046186685562134, total loss: 7340.838427901268


Processing Epoch 00:  21%|██        | 3301/15633 [1:18:56<4:46:37,  1.39s/it]


 Epoch 1, step: 3300, loss: 1.8081865310668945, total loss: 7528.387335896492


Processing Epoch 00:  22%|██▏       | 3401/15633 [1:21:17<4:46:51,  1.41s/it]


 Epoch 1, step: 3400, loss: 1.571982741355896, total loss: 7703.99308514595


Processing Epoch 00:  22%|██▏       | 3501/15633 [1:23:39<4:51:04,  1.44s/it]


 Epoch 1, step: 3500, loss: 1.6124202013015747, total loss: 7877.132068634033


Processing Epoch 00:  23%|██▎       | 3601/15633 [1:26:00<4:48:25,  1.44s/it]


 Epoch 1, step: 3600, loss: 1.5876076221466064, total loss: 8055.384760141373


Processing Epoch 00:  24%|██▎       | 3701/15633 [1:28:21<4:38:04,  1.40s/it]


 Epoch 1, step: 3700, loss: 1.7072776556015015, total loss: 8237.699759840965


Processing Epoch 00:  24%|██▍       | 3801/15633 [1:30:43<4:33:21,  1.39s/it]


 Epoch 1, step: 3800, loss: 2.0893709659576416, total loss: 8434.72203218937


Processing Epoch 00:  25%|██▍       | 3901/15633 [1:33:04<4:32:57,  1.40s/it]


 Epoch 1, step: 3900, loss: 1.739772081375122, total loss: 8627.132200598717


Processing Epoch 00:  26%|██▌       | 4001/15633 [1:35:25<4:36:45,  1.43s/it]


 Epoch 1, step: 4000, loss: 1.7207248210906982, total loss: 8806.965348124504


Processing Epoch 00:  26%|██▌       | 4101/15633 [1:37:46<4:29:48,  1.40s/it]


 Epoch 1, step: 4100, loss: 1.7700507640838623, total loss: 8986.26793396473


Processing Epoch 00:  27%|██▋       | 4201/15633 [1:40:06<4:23:33,  1.38s/it]


 Epoch 1, step: 4200, loss: 1.7114845514297485, total loss: 9165.797843456268


Processing Epoch 00:  28%|██▊       | 4301/15633 [1:42:28<4:22:25,  1.39s/it]


 Epoch 1, step: 4300, loss: 1.8178526163101196, total loss: 9349.606652259827


Processing Epoch 00:  28%|██▊       | 4401/15633 [1:44:48<4:29:18,  1.44s/it]


 Epoch 1, step: 4400, loss: 1.713474988937378, total loss: 9528.485793590546


Processing Epoch 00:  29%|██▉       | 4501/15633 [1:47:09<4:23:34,  1.42s/it]


 Epoch 1, step: 4500, loss: 1.8185430765151978, total loss: 9709.111825466156


Processing Epoch 00:  29%|██▉       | 4563/15633 [1:48:37<4:17:50,  1.40s/it]


 Epoch 1, step: 4562, Loss: 1.4059160947799683, total loss: 9814.87218272686
saving model


Processing Epoch 00:  29%|██▉       | 4602/15633 [1:49:31<4:15:43,  1.39s/it]


 Epoch 1, step: 4600, loss: 1.7384731769561768, total loss: 9877.717487335205


Processing Epoch 00:  30%|███       | 4701/15633 [1:51:51<4:15:05,  1.40s/it]


 Epoch 1, step: 4700, loss: 1.7126386165618896, total loss: 10048.258151650429


Processing Epoch 00:  31%|███       | 4777/15633 [1:53:38<4:23:39,  1.46s/it]


 Epoch 1, step: 4776, Loss: 1.3383914232254028, total loss: 10178.873227715492
saving model


Processing Epoch 00:  31%|███       | 4801/15633 [1:54:12<4:19:27,  1.44s/it]


 Epoch 1, step: 4800, loss: 1.7428756952285767, total loss: 10219.085227370262


Processing Epoch 00:  31%|███       | 4884/15633 [1:56:09<4:10:17,  1.40s/it]


 Epoch 1, step: 4883, Loss: 1.3373826742172241, total loss: 10357.566935658455
saving model


Processing Epoch 00:  31%|███▏      | 4901/15633 [1:56:33<4:16:20,  1.43s/it]


 Epoch 1, step: 4900, loss: 1.9052549600601196, total loss: 10385.680579423904


Processing Epoch 00:  32%|███▏      | 5001/15633 [1:58:53<4:06:28,  1.39s/it]


 Epoch 1, step: 5000, loss: 1.8309067487716675, total loss: 10549.668045520782


Processing Epoch 00:  32%|███▏      | 5026/15633 [1:59:29<4:11:36,  1.42s/it]


 Epoch 1, step: 5025, Loss: 1.2582365274429321, total loss: 10589.988594651222
saving model


Processing Epoch 00:  32%|███▏      | 5052/15633 [2:00:05<4:11:53,  1.43s/it]


 Epoch 1, step: 5051, Loss: 1.2571699619293213, total loss: 10630.821589589119
saving model


Processing Epoch 00:  33%|███▎      | 5101/15633 [2:01:14<4:03:42,  1.39s/it]


 Epoch 1, step: 5100, loss: 1.5015268325805664, total loss: 10709.328580379486


Processing Epoch 00:  33%|███▎      | 5201/15633 [2:03:35<4:06:49,  1.42s/it]


 Epoch 1, step: 5200, loss: 1.8393810987472534, total loss: 10868.930366754532


Processing Epoch 00:  34%|███▍      | 5301/15633 [2:05:55<4:08:49,  1.44s/it]


 Epoch 1, step: 5300, loss: 1.6800910234451294, total loss: 11028.314090371132


Processing Epoch 00:  35%|███▍      | 5401/15633 [2:08:15<3:54:51,  1.38s/it]


 Epoch 1, step: 5400, loss: 1.5827566385269165, total loss: 11185.88696861267


Processing Epoch 00:  35%|███▌      | 5501/15633 [2:10:36<3:55:31,  1.39s/it]


 Epoch 1, step: 5500, loss: 1.4603593349456787, total loss: 11351.34306883812


Processing Epoch 00:  36%|███▌      | 5601/15633 [2:12:56<3:59:02,  1.43s/it]


 Epoch 1, step: 5600, loss: 1.6338520050048828, total loss: 11517.760630488396


Processing Epoch 00:  36%|███▋      | 5701/15633 [2:15:17<3:55:54,  1.43s/it]


 Epoch 1, step: 5700, loss: 1.6053637266159058, total loss: 11679.8893378973


Processing Epoch 00:  37%|███▋      | 5801/15633 [2:17:37<3:46:30,  1.38s/it]


 Epoch 1, step: 5800, loss: 1.6219069957733154, total loss: 11842.13503909111


Processing Epoch 00:  38%|███▊      | 5901/15633 [2:19:58<3:44:40,  1.39s/it]


 Epoch 1, step: 5900, loss: 1.7953112125396729, total loss: 12006.266588807106


Processing Epoch 00:  38%|███▊      | 6001/15633 [2:22:18<3:49:06,  1.43s/it]


 Epoch 1, step: 6000, loss: 1.8994648456573486, total loss: 12173.119334816933


Processing Epoch 00:  39%|███▉      | 6101/15633 [2:24:38<3:44:11,  1.41s/it]


 Epoch 1, step: 6100, loss: 1.7751065492630005, total loss: 12338.656559467316


Processing Epoch 00:  40%|███▉      | 6201/15633 [2:26:59<3:37:06,  1.38s/it]


 Epoch 1, step: 6200, loss: 1.6857184171676636, total loss: 12503.389309883118


Processing Epoch 00:  40%|████      | 6301/15633 [2:29:19<3:35:41,  1.39s/it]


 Epoch 1, step: 6300, loss: 1.663983702659607, total loss: 12670.272683501244


Processing Epoch 00:  41%|████      | 6401/15633 [2:31:40<3:40:25,  1.43s/it]


 Epoch 1, step: 6400, loss: 1.6569445133209229, total loss: 12832.472758054733


Processing Epoch 00:  41%|████      | 6405/15633 [2:31:46<3:38:39,  1.42s/it]


KeyboardInterrupt: 

# Multimodal GPT - Trained on PHI2 LLM

This project is a multimodal GPT that accepts images, audio and text as context for question answering. It is built on the PHI2 LLM.

## Training

### Pre training
To handle images as input, a small projection layer was pre-trained to project the input image embedding into the latent space that PHI2 can handle.


* The image embedding is obtained using the CLIP vision model.
* The image embedding is passed through the projection layer.
* The projected image embedding is concatenated with the embedding of the corresponding caption, obtained from the PHI2 text embedding layer. For this step the parameters of PHI2 are frozen so that they do not change during back propagation.
* The concatenated embedding is fed to the PHI2 model, and the model's output is compared with the actual caption to compute the loss.
* The dataset used is [COCO2014](https://www.kaggle.com/datasets/nadaibrahim/coco2014).


```
class ProjectionLayer(torch.nn.Module):
    """Single linear map from the CLIP embedding size to the PHI2 embedding size.

    Args:
        clip_dim: Size of the last dimension of the incoming embedding.
        phi_dim: Size of the last dimension of the projected output.
    """

    def __init__(self, clip_dim, phi_dim):
        super().__init__()
        # The attribute stays named `linear` so parameter keys in saved
        # state dicts ("linear.weight", "linear.bias") are unchanged.
        self.linear = torch.nn.Linear(in_features=clip_dim, out_features=phi_dim)

    def forward(self, x):
        """Apply the linear projection to `x` and return the result."""
        projected = self.linear(x)
        return projected
```

<Image>

### Pre-Training Log

```
 Epoch 1, step: 4000, loss: 1.7207248210906982, total loss: 8806.965348124504
Processing Epoch 00:  26%|██▌       | 4101/15633 [1:37:46<4:29:48,  1.40s/it]

 Epoch 1, step: 4100, loss: 1.7700507640838623, total loss: 8986.26793396473
Processing Epoch 00:  27%|██▋       | 4201/15633 [1:40:06<4:23:33,  1.38s/it]

 Epoch 1, step: 4200, loss: 1.7114845514297485, total loss: 9165.797843456268
Processing Epoch 00:  28%|██▊       | 4301/15633 [1:42:28<4:22:25,  1.39s/it]

 Epoch 1, step: 4300, loss: 1.8178526163101196, total loss: 9349.606652259827
Processing Epoch 00:  28%|██▊       | 4401/15633 [1:44:48<4:29:18,  1.44s/it]

 Epoch 1, step: 4400, loss: 1.713474988937378, total loss: 9528.485793590546
Processing Epoch 00:  29%|██▉       | 4501/15633 [1:47:09<4:23:34,  1.42s/it]

 Epoch 1, step: 4500, loss: 1.8185430765151978, total loss: 9709.111825466156
Processing Epoch 00:  29%|██▉       | 4563/15633 [1:48:37<4:17:50,  1.40s/it]

 Epoch 1, step: 4562, Loss: 1.4059160947799683, total loss: 9814.87218272686
saving model
Processing Epoch 00:  29%|██▉       | 4602/15633 [1:49:31<4:15:43,  1.39s/it]

 Epoch 1, step: 4600, loss: 1.7384731769561768, total loss: 9877.717487335205
Processing Epoch 00:  30%|███       | 4701/15633 [1:51:51<4:15:05,  1.40s/it]

 Epoch 1, step: 4700, loss: 1.7126386165618896, total loss: 10048.258151650429
Processing Epoch 00:  31%|███       | 4777/15633 [1:53:38<4:23:39,  1.46s/it]

 Epoch 1, step: 4776, Loss: 1.3383914232254028, total loss: 10178.873227715492
saving model
Processing Epoch 00:  31%|███       | 4801/15633 [1:54:12<4:19:27,  1.44s/it]

 Epoch 1, step: 4800, loss: 1.7428756952285767, total loss: 10219.085227370262
Processing Epoch 00:  31%|███       | 4884/15633 [1:56:09<4:10:17,  1.40s/it]

 Epoch 1, step: 4883, Loss: 1.3373826742172241, total loss: 10357.566935658455
saving model
Processing Epoch 00:  31%|███▏      | 4901/15633 [1:56:33<4:16:20,  1.43s/it]

 Epoch 1, step: 4900, loss: 1.9052549600601196, total loss: 10385.680579423904
Processing Epoch 00:  32%|███▏      | 5001/15633 [1:58:53<4:06:28,  1.39s/it]

 Epoch 1, step: 5000, loss: 1.8309067487716675, total loss: 10549.668045520782
Processing Epoch 00:  32%|███▏      | 5026/15633 [1:59:29<4:11:36,  1.42s/it]

 Epoch 1, step: 5025, Loss: 1.2582365274429321, total loss: 10589.988594651222
saving model
Processing Epoch 00:  32%|███▏      | 5052/15633 [2:00:05<4:11:53,  1.43s/it]

 Epoch 1, step: 5051, Loss: 1.2571699619293213, total loss: 10630.821589589119
saving model
```



### Fine Tuning
After the projection layer was pretrained, PHI2 was fine tuned.

* Input image embedding is obtained using CLIP vision model.
* Image embedding is passed through the pre-trained projection layer. For fine tuning, the parameters of the projection layer are frozen to prevent changes during back propagation.
* Projected embedding is concatenated with the embeddings of Question-Answer text data from [LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)
* Concatenated embedding is used as input to PHI2 model and the output is compared with the answer part for the loss.

<Image>

### Fine Tuning Log
```
 Epoch 1, step: 8580, loss: 2.147397041320801, total loss: 20604.322937846184
clearing cache
Processing Epoch 00:  47%|████▋     | 8591/18148 [7:30:17<8:25:17,  3.17s/it]

 Epoch 1, step: 8590, loss: 2.0401413440704346, total loss: 20627.469371914864
clearing cache
Processing Epoch 00:  47%|████▋     | 8601/18148 [7:30:48<8:24:41,  3.17s/it]

 Epoch 1, step: 8600, loss: 2.2512617111206055, total loss: 20651.634582281113
clearing cache
Processing Epoch 00:  47%|████▋     | 8611/18148 [7:31:19<8:25:37,  3.18s/it]

 Epoch 1, step: 8610, loss: 2.054892063140869, total loss: 20674.192289114
clearing cache
Processing Epoch 00:  48%|████▊     | 8621/18148 [7:31:50<8:22:25,  3.16s/it]

 Epoch 1, step: 8620, loss: 2.16609787940979, total loss: 20694.78110229969
clearing cache
Processing Epoch 00:  48%|████▊     | 8631/18148 [7:32:21<8:22:07,  3.17s/it]

 Epoch 1, step: 8630, loss: 2.1779747009277344, total loss: 20719.170233249664
clearing cache
Processing Epoch 00:  48%|████▊     | 8641/18148 [7:32:52<8:24:29,  3.18s/it]

 Epoch 1, step: 8640, loss: 2.852278709411621, total loss: 20743.28934788704
clearing cache
Processing Epoch 00:  48%|████▊     | 8651/18148 [7:33:23<8:22:59,  3.18s/it]

 Epoch 1, step: 8650, loss: 2.776125431060791, total loss: 20767.941595435143
clearing cache
Processing Epoch 00:  48%|████▊     | 8661/18148 [7:33:54<8:21:47,  3.17s/it]

 Epoch 1, step: 8660, loss: 2.1857335567474365, total loss: 20790.802763819695
clearing cache

```


## Further Improvements
* Fine tuned model is not accurate and can be trained further by improving the data input and loss calculation.
* The context of the queries is limited to the current query and can be improved to include previous queries and responses as context.

