In [1]:
from data_loaders import data_utils
act_data = data_utils.load_jsonl_file("/data/users/sgarg6/capstone/multimodalqa/MMQA_train.jsonl")

In [2]:
from data_loaders.dataset_mmqa import MMQAKnowledgeBase
mmqa_kb = MMQAKnowledgeBase(
        "/data/users/sgarg6/capstone/multimodalqa/MMQA_texts.jsonl",
        "/data/users/sgarg6/capstone/multimodalqa/MMQA_images.jsonl",
        "/data/users/sgarg6/capstone/multimodalqa/final_dataset_images"
    )
mmqa_text = [text for text in mmqa_kb.get_all_texts()]
mmqa_img = [img for img in mmqa_kb.get_all_images()]

mmqa_map = {text["id"]: text for text in mmqa_text}
mmqa_map_img = {img["id"]: img for img in mmqa_img}

Loaded 218285 text passages
Loaded 57058 image sources


In [3]:
mmqa_map['a7d9e6350bafc46b700e4d0739a39594']

{'title': 'Hillaryland',
 'url': 'https://en.wikipedia.org/wiki/Hillaryland',
 'id': 'a7d9e6350bafc46b700e4d0739a39594',
 'text': 'Hillaryland was the self-designated name of a group of core advisors to Hillary Clinton, when she was First Lady of the United States and again when, as United States Senator, she was one of the Democratic Party candidates for President in the 2008 U.S. election.'}

In [4]:
from PIL import Image
blank_image = Image.open("resources/1x1_#00000000.png")

In [5]:
train_data = []
for ques in act_data:
    ques_type = ques["answers"][0]["modality"]
    if "table" in ques_type:
        continue
    question = ques["question"]
    #try:
    
    if len(ques["answers"][0]["text_instances"]) > 0:
        passage = mmqa_map[ques["answers"][0]["text_instances"][0]["doc_id"]]["text"]
    else:
        passage = ""
    if len(ques["answers"][0]["image_instances"]) > 0:
        image = mmqa_map_img[ques["answers"][0]["image_instances"][0]["doc_id"]]["path"]
    else:
        image = blank_image
    # except Exception as e:
    #     print(e)
    #     print(ques)
    #     break
    ans = ques["answers"][0]["answer"]
    prompt = f"You are a helpful Question Answering assistant. You are being provided with images and passages, a question about the image or the passage and an answer. Answer the question using either the image or the passage. <image> Passage: {passage} Question: {question}. Answer: {ans}<|endofchunk|>"
    train_data.append((prompt, ques_type, image))

In [6]:
len(train_data)

15135

In [7]:
len(act_data)

23817

In [8]:
train_data[1]

("You are a helpful Question Answering assistant. You are being provided with images and passages, a question about the image or the passage and an answer. Answer the question using either the image or the passage. <image> Passage: The Game Boy Advance (Japanese: ゲームボーイアドバンス, Hepburn: Gēmu Bōi Adobansu) (GBA) is a 32-bit handheld video game console developed, manufactured and marketed by Nintendo as the successor to the Game Boy Color. It was released in Japan on March 21, 2001, in North America on June 11, 2001, in Australia and Europe on June 22, 2001, and in mainland China on June 8, 2004 (iQue Player). Nintendo's competitors in the handheld market at the time were the Neo Geo Pocket Color, WonderSwan, GP32, Tapwave Zodiac, and the N-Gage. Despite the competitors' best efforts, Nintendo maintained a majority market share with the Game Boy Advance. Question: When did the virtual console system when Japan had 102 games come out?. Answer: March 21, 2001<|endofchunk|>",
 'text',
 <PIL.P

In [9]:
from torch.utils.data import Dataset

class MMQADataset(Dataset):
    def __init__(self, train_data):
        self.train_data = train_data

    def __len__(self,):
        return len(self.train_data)

    def __getitem__(self, idx):
        return self.train_data[idx]

In [10]:
import flamingo_model
model = flamingo_model.FlamingoModel("anas-awadalla/mpt-1b-redpajama-200b-dolly",
                                     "anas-awadalla/mpt-1b-redpajama-200b-dolly",
                                     1, 
                                    "openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct")

Using pad_token, but it is not set yet.


You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
Flamingo model initialized with 1046992944 trainable parameters
cuda


In [11]:

def collate_fn(batch):
    prompt = [item[0] for item in batch]
    ques_type = [item[1] for item in batch]
    images = [item[2] for item in batch]
    model.tokenizer.padding_side = "right"
    prompt = [f"{s.strip()}{model.tokenizer.eos_token}" for s in prompt]
    prmpt_tokens = model.tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    
    images = [mmqa_kb.get_image(image) if ques_type[idx] == "image" else image for idx, image in enumerate(images)]
    images = model.process_imgs(images)
    images = images.unsqueeze(1)
    return prmpt_tokens, ques_type, images

In [None]:
import torch
train_data = MMQADataset(train_data)
data_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True, collate_fn=collate_fn)
import torch.nn as nn
from tqdm import tqdm
optimizer = torch.optim.Adam(model.model.parameters(), 0.001)
running_loss = 0
loss_log = []
model.model.train()

for i, batch in tqdm(enumerate(data_loader)):
    
    optimizer.zero_grad()
    input_ids = batch[0]["input_ids"]
    labels = input_ids.clone()
    labels[labels == model.tokenizer.pad_token_id] = -100
    labels[labels == model.tokenizer.eos_token] = -100
    labels = labels.to(model.device)
    # Forward + backward + optimize
    loss = model.model(
            vision_x=batch[2].to(model.device),
            lang_x=batch[0]["input_ids"].to(model.device),
            attention_mask=batch[0]["attention_mask"].to(model.device),
            labels=labels,
        )[0]
    loss.backward()
    optimizer.step()
    loss_log.append(loss.item())
    # Print statistics
    running_loss += loss.item()
    if i % 500 == 0:
        print(running_loss/(i+1))

2it [00:01,  1.81it/s]

3.920361042022705


502it [01:53,  4.93it/s]

1.9892937496988596


1002it [03:49,  4.40it/s]

1.9173681898550554


1501it [05:42,  3.12it/s]

1.8653798380309465


2001it [07:36,  4.77it/s]

1.8457386346384026


2501it [09:31,  4.04it/s]

1.8298833826061536


3001it [11:25,  5.02it/s]

1.819169101472578


3502it [13:19,  4.61it/s]

1.8107281055187574


4001it [15:13,  4.37it/s]

1.8053132040564401


4502it [17:08,  4.95it/s]

1.7970049100977874


5001it [19:02,  5.19it/s]

1.7945326890105893


5502it [20:56,  4.67it/s]

1.7938888501121963


6002it [22:49,  4.39it/s]

1.7892400637812822


6322it [24:00,  5.10it/s]

In [None]:
torch.save(model.model.state_dict(), "/data/users/sgarg6/capstone/models/3b_finetune/model_eoc_no_eol.pt")


In [None]:
a = 1