In [1]:
import random
import torch
import pandas as pd
import numpy as np
import torch.optim as optim
import torch.nn.functional as F

from einops import reduce
from tqdm import tqdm
from PIL import Image
from os.path import join as pjoin

from transformers import InstructBlipProcessor
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

from data_utils import mkdir_p, del_folder, read_json, write_json, read_text
from dataset import CUB_200_2011
from instruct_blip import InstructBlipForImageRetrieval

from torch.utils.tensorboard import SummaryWriter
# Shared TensorBoard writer; it is handed to train_qformer below, which logs
# the per-batch training scalars through it.
writer = SummaryWriter()

In [2]:

# Index of the GPU to train on; only used when CUDA is actually available.
device_id = 3

# Resolve the target device once. Every later cell reads `device`.
if torch.cuda.is_available():
    device = "cuda:"+str(device_id)
    torch.cuda.set_device(torch.device(device))
    print ('Cuda device %s | %s | %s/%sGB' % (torch.cuda.current_device(),
                                              torch.cuda.get_device_name(device_id),
                                              round(torch.cuda.memory_allocated(device_id)/1024**3,1),
                                              round(torch.cuda.memory_reserved(device_id)/1024**3,1)))
else:
    # torch.cuda.set_device / get_device_name raise without a CUDA device, so
    # the CPU fallback must not touch them.
    device = "cpu"
    print('CUDA is not available; running on CPU')

# Seed every RNG the notebook relies on so that runs are repeatable.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

Cuda device 3 | NVIDIA GeForce RTX 3090 | 0.0/0.0GB


In [3]:
# Local paths handed to from_pretrained() for the processor and the model.
blip_processor_file = "/home/jhkim980112/workspace/code/CV_project/processors"
blip_model_file = "/home/jhkim980112/workspace/code/CV_project/models"

# Load the InstructBLIP processor and the project's retrieval model.
processor = InstructBlipProcessor.from_pretrained(blip_processor_file)
instruct_blip = InstructBlipForImageRetrieval.from_pretrained(blip_model_file)
# Direct handle on the model's vision sub-module.
instruct_blip_image_encoder = instruct_blip.vision_model


In [4]:
# Freeze every parameter whose name contains "vision_model" and leave all the
# others trainable, counting how many parameters were frozen along the way.
cnt = 0
for network_name, parameter in instruct_blip.named_parameters():
    belongs_to_vision = "vision_model" in network_name
    parameter.requires_grad = not belongs_to_vision
    if belongs_to_vision:
        cnt = cnt + 1
print(f"# of the vision encoder's networks : {cnt}")

# of the vision encoder's networks : 474


In [5]:
#for network_name, parameter in instruct_blip.named_parameters():
    #print(network_name, parameter)

In [6]:
# Image directory, per-split annotation files and the attribute prompt file.
image_path = "/home/jhkim980112/workspace/dataset/CUB_200_2011/CUB_200_2011/images"
train_data_path = "/home/jhkim980112/workspace/dataset/CUB_200_2011/annotations/train.json"
valid_data_path = "/home/jhkim980112/workspace/dataset/CUB_200_2011/annotations/val.json"
# NOTE(review): this is the same train.json as train_data_path, so the "test"
# dataset below is built from the training annotations. Looks like a
# copy-paste slip; confirm the intended test annotation file.
test_data_path = "/home/jhkim980112/workspace/dataset/CUB_200_2011/annotations/train.json"
attribute_path = "/home/jhkim980112/workspace/code/CV_project/fire/playground/attribute_sample.txt"

# One dataset per split; all three share the processor, image directory and
# attribute file and differ only in the annotation file they read.
cub_train_data = CUB_200_2011(processor=processor, 
                              vision_dir=image_path, 
                              data_path=train_data_path,
                              attributes_path=attribute_path)

cub_valid_data = CUB_200_2011(processor=processor, 
                              vision_dir=image_path, 
                              data_path=valid_data_path,
                              attributes_path=attribute_path)

cub_test_data = CUB_200_2011(processor=processor, 
                              vision_dir=image_path, 
                              data_path=test_data_path,
                              attributes_path=attribute_path)

In [7]:
# Display the attribute prompts held by the training dataset.
cub_train_data.attributes

["What is the shape, color, and length of a bird's beak?",
 "What are the patterns and colors of the bird's belly?",
 "What is the color of the upper parts, the under parts and the primary color of the bird's body?"]

In [12]:
# Training hyper-parameters.
num_epochs = 10
batch_size = 32 # Loading the model and data uses about 23,032 MiB (batch size 32 looks appropriate on a 3090)
# Number of attribute prompts; train_qformer reads this module-level value.
num_attributes = len(cub_train_data.attributes)

# NOTE(review): every parameter is handed to Adam, including the vision-encoder
# parameters whose requires_grad was set to False in an earlier cell.
optimizer = optim.Adam(instruct_blip.parameters(), lr=0.001)
# The triplet loss and the triplet miner share one cosine-similarity distance.
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.3, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(margin=0.2, distance=distance)
#mining_func = miners.TripletMarginMiner(margin=0.9, type_of_triplets="all")

#accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [10]:
# Shuffled mini-batches over the training split, batch_size samples each.
train_loader = torch.utils.data.DataLoader(
    cub_train_data, batch_size=batch_size, shuffle=True
)

In [13]:
def _squeeze_keep_batch(tensor):
    """Squeeze singleton dimensions but never the leading batch dimension.

    A bare ``.squeeze()`` also removes axis 0 when the batch holds a single
    sample; this puts that axis back so downstream modules still receive a
    batched tensor. For batches larger than one it is identical to
    ``tensor.squeeze()``.
    """
    squeezed = tensor.squeeze()
    return squeezed.unsqueeze(0) if tensor.shape[0] == 1 else squeezed


def train_qformer(instruct_blip, loss_func, mining_func, device, train_loader, optimizer, epoch, writer):
    """Run one training epoch of triplet-loss training on instructed embeddings.

    For each batch the image is encoded once with ``instruct_blip.vision_model``.
    The model is then called once per attribute prompt, and the pooled Q-Former
    outputs of all prompts are averaged into a single embedding per sample.
    Triplets are mined from those embeddings, the loss is back-propagated and
    the optimizer is stepped.

    The number of prompts is read from the module-level ``num_attributes``.

    Args:
        instruct_blip: Model exposing ``vision_model`` and returning an object
            with ``qformer_outputs.pooler_output`` when called.
        loss_func: Callable ``(embeddings, labels, indices_tuple) -> loss``.
        mining_func: Callable ``(embeddings, labels) -> indices_tuple``; its
            ``num_triplets`` attribute is read for progress printing.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of batches exposing ``label``, ``pixel_values``,
            ``qformer_input_ids`` and ``qformer_attention_mask``.
        optimizer: Optimizer, zeroed and stepped once per batch.
        epoch: Epoch number, 1-based as passed by the training loop; used in
            the progress message and to offset the TensorBoard step.
        writer: Object providing ``add_scalar(tag, value, step)``.

    Returns:
        The mean batch loss of the epoch as a float (``0.0`` when the loader
        yielded no batch).
    """
    instruct_blip.train()

    # Offset the logging step by the batches of all previous epochs so that
    # curves from successive epochs do not overwrite each other.
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        # Loader without a length: fall back to per-epoch step numbering.
        steps_per_epoch = 0
    step_offset = max(epoch - 1, 0) * steps_per_epoch

    running_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Column 0 of the label tensor is used as the per-sample class label.
        labels = batch.label[:, 0]
        pixel_values = _squeeze_keep_batch(batch.pixel_values[:, 0, :, :]).to(device)

        # Only the first output (index 0) is consumed below, so do not ask the
        # encoder to also return every attention map and hidden state.
        image_encoder_output = instruct_blip.vision_model(pixel_values=pixel_values,
                                                          output_attentions=False,
                                                          output_hidden_states=False,
                                                          return_dict=True)
        image_embeds = image_encoder_output[0].to(device)

        # One forward pass per attribute prompt, all on the same image features.
        instructed_embeddings = []
        for attr in range(num_attributes):
            qformer_input_ids = _squeeze_keep_batch(batch.qformer_input_ids[:, attr, :]).to(device)
            qformer_attention_mask = _squeeze_keep_batch(batch.qformer_attention_mask[:, attr, :]).to(device)

            q_former_output = instruct_blip(image_embeds=image_embeds,
                                            qformer_input_ids=qformer_input_ids,
                                            qformer_attention_mask=qformer_attention_mask).qformer_outputs
            instructed_embeddings.append(q_former_output.pooler_output)

        # Average the per-attribute embeddings: (n, b, c) -> (b, c).
        instructed_embeddings = reduce(torch.stack(instructed_embeddings), 'n b c -> b c', 'mean')

        hard_pairs = mining_func(instructed_embeddings, labels)
        loss = loss_func(instructed_embeddings, labels, hard_pairs)
        loss.backward()
        optimizer.step()

        # Detach to a plain float so logging never holds on to the graph.
        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1

        if (batch_idx+1) % 100 == 0:
            print("Epoch {} Step {}: Loss = {}, Number of mined triplets = {}".format(
                epoch, batch_idx+1, loss_value, mining_func.num_triplets))

        global_step = step_offset + batch_idx + 1
        writer.add_scalar("train/batch_loss", loss_value, global_step)
        # Running mean of the batch losses seen so far in this epoch.
        writer.add_scalar("train/avg_loss", running_loss / num_batches, global_step)

    return running_loss / num_batches if num_batches else 0.0


In [14]:
# Moving the model onto the GPU takes roughly 4,916 MiB of VRAM.
instruct_blip.to(device)

# Train for num_epochs epochs (1-based) and write one full-model checkpoint
# after each of them.
for epoch in range(1, num_epochs + 1):
    train_qformer(
        instruct_blip=instruct_blip,
        loss_func=loss_func,
        mining_func=mining_func,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
        writer=writer,
    )
    checkpoint = f"/home/jhkim980112/workspace/code/CV_project/models/instruct_blip_{epoch}.pt"
    torch.save(instruct_blip, checkpoint)

    # Evaluation is not wired in yet:
    #test(dataset1, dataset2, model, accuracy_calculator)

Epoch 1 Step 100: Loss = 0.2988765835762024, Number of mined triplets = 180


TypeError: can only concatenate str (not "int") to str