In [1]:
import torch
import torch.nn as nn
import torchvision
import transformers
import os
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb
import random
import timm
import gc
import requests

from PIL import Image
from pathlib import Path

from PIL import Image
from tqdm import tqdm

from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Authenticate this kernel with Weights & Biases; the run itself is created by wandb.init() in the next cell.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwjnwjn59[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Single source of truth for the experiment's hyper-parameters. The whole dict is
# also passed to wandb.init() below so it is recorded with the run.
config = {
    # Passed to set_seed() in the next cell.
    'seed': 59,
    # Read by the AdamW optimizer cell (lr / weight_decay).
    'learning_rate': 1e-5,
    # Upper bound on training epochs; early stopping (see 'patience') may end sooner.
    'epochs': 50,
    # DataLoader batch sizes for the train and validation loaders.
    'train_batch_size': 32,
    'val_batch_size': 64,
    # Width of the first hidden layer of the Classifier head.
    'hidden_dim': 2048,
    # Output width of the text and image projection layers.
    'projection_dim': 2048,
    'weight_decay': 1e-5,
    # Number of consecutive epochs without validation-loss improvement before stopping.
    'patience': 10,
    # Used as max_length (padding + truncation) when tokenizing a question.
    'text_max_len': 50,
    # Descriptive tags only: no visible cell reads 'fusion_strategy' or 'dataset'.
    'fusion_strategy': 'concat+smalllen',
    'text_encoder_id': 'vinai/bartpho-word',
    'img_encoder_id': 'timm/beitv2_base_patch16_224.in1k_ft_in22k', # id from timm
    'dataset': 'ViVQA'
}
PROJECT_NAME = 'vivqa_paraphrase_augmentation'
EXP_NAME = 'vivqa_baseline_bartphoword_beitv2'
# Start the W&B run; metrics are logged to it from inside train().
wandb.init(
    project=PROJECT_NAME,
    name=EXP_NAME,
    config=config
)

In [4]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU, plus
    all CUDA devices when CUDA is available), then switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Set unconditionally, exactly as before: the flags are harmless on CPU-only runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Apply the configured seed once, before any model or DataLoader is created.
RANDOM_SEED = config['seed']
set_seed(RANDOM_SEED)

In [5]:
# Dataset locations, relative to the notebook's working directory.
DATASET_DIR = Path('../../datasets')
# Ground-truth CSVs for the two ViVQA splits.
VIVQA_GT_TRAIN_PATH = DATASET_DIR / 'ViVQA' / 'train.csv'
VIVQA_GT_TEST_PATH = DATASET_DIR / 'ViVQA' / 'test.csv'
# One image folder; get_data() resolves images for both splits from it.
VIVQA_IMG_TRAIN_DIR = DATASET_DIR / 'MS_COCO2014' / 'merge'


def visualize_sample(question, answer, img_path):
    """Show one VQA sample: the image with its question and answer as the title.

    Args:
        question: Question text shown in the title.
        answer: Answer text shown in the title.
        img_path: Path of the image file to open and display.
    """
    rgb_image = Image.open(img_path).convert('RGB')
    caption = f'Question: {question}?. Answer: {answer}'

    plt.imshow(rgb_image)
    plt.title(caption)
    plt.axis('off')
    plt.show()

# File names present in the image folder. NOTE(review): not referenced by any other visible cell.
img_lst = os.listdir(VIVQA_IMG_TRAIN_DIR)

def get_data(df_path):
    """Read one ViVQA ground-truth CSV into three parallel lists.

    The CSV is read with its first column as the index and must provide the
    columns ``question``, ``answer`` and ``img_id``. Each image id is
    zero-padded to 12 digits and resolved to a ``.jpg`` inside
    ``VIVQA_IMG_TRAIN_DIR``.

    Args:
        df_path: Path of the CSV file to load.

    Returns:
        Tuple ``(questions, img_paths, answers)`` of equally long lists.
    """
    frame = pd.read_csv(df_path, index_col=0)

    # Column-wise access replaces the row-by-row loop; list order still follows the CSV rows.
    questions = list(frame['question'])
    answers = list(frame['answer'])
    img_paths = [
        VIVQA_IMG_TRAIN_DIR / f'{img_id:012}.jpg'
        for img_id in frame['img_id']
    ]

    return questions, img_paths, answers


# Load both splits up front; the answers are needed below to build the label space.
train_questions, train_img_paths, train_answers = get_data(VIVQA_GT_TRAIN_PATH)    
test_questions, test_img_paths, test_answers = get_data(VIVQA_GT_TEST_PATH)    

train_set_size = len(train_questions)
test_set_size = len(test_questions)

print(f'Number of train sample: {train_set_size}')
print(f'Number of test sample: {test_set_size}')


Number of train sample: 11999
Number of test sample: 3001


In [6]:
# Longest question across both splits, measured in characters (len of the raw string), not in tokens.
max([len(text) for text in train_questions + test_questions])

110

In [7]:
# Every distinct answer seen in either split becomes one output class.
answer_space = set(train_answers + test_answers)
# A set of strings has no stable iteration order between interpreter sessions
# (string hashing is randomized per process), so enumerating it directly gives
# a different label <-> index mapping after every kernel restart. A checkpoint
# saved in one session would then decode to the wrong answers in another.
# Sorting first makes the mapping reproducible; key=str keeps the sort total
# even if a non-string value (e.g. a NaN from an empty CSV cell) slipped in.
idx2label = dict(enumerate(sorted(answer_space, key=str)))
label2idx = {label: idx for idx, label in idx2label.items()}
answer_space_len = len(answer_space)

In [8]:
import py_vncorenlp

from transformers import AutoModel, AutoTokenizer
from contextlib import contextmanager

# Tone-mark placement table for the vowel pairs oa / oe / uy.
# Each key carries the tone mark on the first vowel of the pair and maps to the
# same pair with the mark moved to the second vowel. Every pair/tone combination
# is listed in three casings (lower, Capitalized, UPPER).
# Consumed by text_tone_normalize(), which applies the replacements in this order.
dict_map = {
    "òa": "oà",
    "Òa": "Oà",
    "ÒA": "OÀ",
    "óa": "oá",
    "Óa": "Oá",
    "ÓA": "OÁ",
    "ỏa": "oả",
    "Ỏa": "Oả",
    "ỎA": "OẢ",
    "õa": "oã",
    "Õa": "Oã",
    "ÕA": "OÃ",
    "ọa": "oạ",
    "Ọa": "Oạ",
    "ỌA": "OẠ",
    "òe": "oè",
    "Òe": "Oè",
    "ÒE": "OÈ",
    "óe": "oé",
    "Óe": "Oé",
    "ÓE": "OÉ",
    "ỏe": "oẻ",
    "Ỏe": "Oẻ",
    "ỎE": "OẺ",
    "õe": "oẽ",
    "Õe": "Oẽ",
    "ÕE": "OẼ",
    "ọe": "oẹ",
    "Ọe": "Oẹ",
    "ỌE": "OẸ",
    "ùy": "uỳ",
    "Ùy": "Uỳ",
    "ÙY": "UỲ",
    "úy": "uý",
    "Úy": "Uý",
    "ÚY": "UÝ",
    "ủy": "uỷ",
    "Ủy": "Uỷ",
    "ỦY": "UỶ",
    "ũy": "uỹ",
    "Ũy": "Uỹ",
    "ŨY": "UỸ",
    "ụy": "uỵ",
    "Ụy": "Uỵ",
    "ỤY": "UỴ",
    }

def text_tone_normalize(text, dict_map):
    """Apply every substring replacement in ``dict_map`` to ``text``.

    Replacements run one after another in the mapping's iteration order, so a
    later entry sees the result of the earlier ones.

    Args:
        text: Input string.
        dict_map: Mapping of substring -> replacement.

    Returns:
        The string after all replacements have been applied.
    """
    normalized = text
    for source, target in dict_map.items():
        normalized = normalized.replace(source, target)
    return normalized

@contextmanager
def temporary_directory_change(directory):
    """Context manager that runs its body with ``directory`` as the working directory.

    The working directory that was active on entry is restored on exit, both
    when the body finishes normally and when it raises.

    Args:
        directory: Directory to switch into for the duration of the ``with`` block.
    """
    previous_cwd = os.getcwd()
    os.chdir(directory)
    try:
        yield
    finally:
        # Always go back, even if the with-body raised.
        os.chdir(previous_cwd)

# Text side setup: word segmenter (VnCoreNLP), tokenizer and pretrained text encoder.
TEXT_MODEL_ID = config['text_encoder_id']
VNCORENLP_PATH = Path('../models/VnCoreNLP')
# Absolute path is computed once so it stays valid while the working directory is switched below.
ABS_VNCORENLP_PATH = VNCORENLP_PATH.resolve()
os.makedirs(VNCORENLP_PATH, exist_ok=True)

# Download the segmenter files only when the 'models' sub-folder is not there yet.
if not (ABS_VNCORENLP_PATH / 'models').exists():
    py_vncorenlp.download_model(save_dir=str(ABS_VNCORENLP_PATH))

# NOTE(review): presumably VnCoreNLP changes the process working directory while
# loading; the context manager guarantees the notebook's cwd is restored — confirm.
with temporary_directory_change(ABS_VNCORENLP_PATH):
    rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], 
                                        save_dir=str(ABS_VNCORENLP_PATH))

# Single device string shared by the text model, the image model and the datasets.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModel.from_pretrained(TEXT_MODEL_ID,
                                       device_map=device)

def text_processor(text):
    """Turn one raw question into model-ready tensors for the text encoder.

    Steps: normalize tone-mark placement with ``dict_map``, word-segment with
    ``rdrsegmenter``, join the segmenter output with spaces, then tokenize with
    padding/truncation to ``config['text_max_len']`` tokens.

    Args:
        text: Raw question string.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'``, each an integer
        tensor of shape ``(1, config['text_max_len'])`` on ``device``. The mask
        is 0 at padding positions and 1 elsewhere.
    """
    text = text_tone_normalize(text, dict_map)
    # word_segment returns a list of strings; they are re-joined into one input string.
    segmented_text = ' '.join(rdrsegmenter.word_segment(text))

    input_ids = torch.tensor(
        [tokenizer.encode(segmented_text,
                          max_length=config['text_max_len'],
                          padding='max_length',
                          truncation=True)]).to(device)

    # Ask the tokenizer for its padding id instead of hard-coding 1, so the mask
    # stays correct if config['text_encoder_id'] points at a tokenizer whose pad
    # id differs. (padding='max_length' above already requires a pad token.)
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

# sentence = 'Có bao nhiêu người trong bức ảnh ?' 
# phobert_outputs = text_processor(sentence)

# with torch.no_grad():
#     features = text_model(**phobert_outputs)
#     print(features['last_hidden_state'].shape)

2024-06-26 13:28:27 INFO  WordSegmenter:24 - Loading Word Segmentation model




In [9]:
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)

# Image side setup: pretrained timm backbone plus its matching preprocessing pipeline.
IMG_MODEL_ID = config['img_encoder_id']

# get model specific transforms (normalization, resize)
img_model = timm.create_model(
    IMG_MODEL_ID,
    pretrained=True,
    num_classes=0 # remove classifier nn.Linear
).to(device)
# The preprocessing settings are derived from the model itself.
data_config = timm.data.resolve_model_data_config(img_model)
# is_training=False builds the evaluation transform; the same transform is used for both splits.
img_processor = timm.data.create_transform(**data_config, 
                                           is_training=False)

# output = img_model(img_processor(image).unsqueeze(0))  # output is (batch_size, num_features) shaped tensor
# or equivalently (without needing to set num_classes=0)


# img_model = img_model.eval()
# with torch.no_grad():
#     output = img_model.forward_features(img_processor(image).to(device).unsqueeze(0))
#     print(output.shape)

In [10]:
class ViVQADataset(Dataset):
    """Map-style dataset yielding (question, image, answer-label) samples of ViVQA.

    Expects the layout ``<data_dir>/ViVQA/{train,test}.csv`` for annotations and
    ``<data_dir>/MS_COCO2014/merge/<img_id zero-padded to 12 digits>.jpg`` for
    images.

    Args:
        data_dir: Root ``pathlib.Path`` of the datasets folder.
        data_mode: ``'train'`` loads ``train.csv``; any other value loads ``test.csv``.
        text_processor: Callable mapping a question string to a dict of tensors
            that carry a leading batch dimension of size 1.
        img_processor: Callable mapping an RGB PIL image to an image tensor.
        label_encoder: Mapping from answer string to class index. Required by
            ``__getitem__``.
        device: Device every returned tensor is moved to.
    """

    def __init__(self, data_dir, data_mode, text_processor, img_processor, label_encoder=None, device='cpu'):
        self.data_dir = data_dir
        split_file = 'train.csv' if data_mode == 'train' else 'test.csv'
        self.data_path = data_dir / 'ViVQA' / split_file
        self.text_processor = text_processor
        self.img_processor = img_processor
        self.label_encoder = label_encoder
        self.device = device

        self.questions, self.img_paths, self.answers = self.get_data()

    def get_data(self):
        """Read the split's CSV and return ``(questions, img_paths, answers)`` lists."""
        df = pd.read_csv(self.data_path, index_col=0)
        img_dir = self.data_dir / 'MS_COCO2014' / 'merge'

        # Column-wise access instead of a row-by-row iterrows() loop.
        questions = df['question'].tolist()
        answers = df['answer'].tolist()
        img_paths = [img_dir / f'{img_id:012}.jpg' for img_id in df['img_id'].tolist()]

        return questions, img_paths, answers

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``.

        Returns:
            Dict with ``'text_inputs'`` (dict of tensors without the batch
            dimension), ``'img_inputs'`` (image tensor) and ``'labels'``
            (scalar ``torch.long`` tensor), all on ``self.device``.

        Raises:
            KeyError: If the sample's answer is missing from ``label_encoder``.
        """
        question = self.questions[idx]
        answer = self.answers[idx]
        img_path = self.img_paths[idx]

        img_pil = Image.open(img_path).convert('RGB')
        text_inputs = self.text_processor(question)

        # Use the device handed to this dataset. The previous code read the
        # notebook-global `device` here, which silently ignored the constructor
        # argument (and fails outright where no such global exists).
        img_inputs = self.img_processor(img_pil).to(self.device)
        label = self.label_encoder[answer]

        # squeeze(0) drops only the processor's batch dimension; a bare squeeze()
        # would also collapse a sequence dimension of length 1.
        text_inputs = {k: v.squeeze(0).to(self.device) for k, v in text_inputs.items()}
        labels = torch.tensor(label, dtype=torch.long).to(self.device)

        return {
            'text_inputs': text_inputs,
            'img_inputs': img_inputs,
            'labels': labels
        }

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.questions)
    
TRAIN_BATCH_SIZE = config['train_batch_size']
VAL_BATCH_SIZE = config['val_batch_size']
    
train_dataset = ViVQADataset(DATASET_DIR.resolve(), 'train', 
                             text_processor=text_processor,
                             img_processor=img_processor, 
                             label_encoder=label2idx,
                             device=device)
# NOTE(review): any data_mode other than 'train' makes ViVQADataset read test.csv,
# so this "validation" set is the test split. It also drives early stopping in train().
val_dataset = ViVQADataset(DATASET_DIR.resolve(), 'val', 
                           text_processor=text_processor,
                           img_processor=img_processor, 
                           label_encoder=label2idx,
                           device=device)
# Only the training loader shuffles; validation order is fixed.
train_loader = DataLoader(train_dataset,
                          batch_size=TRAIN_BATCH_SIZE,
                          shuffle=True)
val_loader = DataLoader(val_dataset,
                          batch_size=VAL_BATCH_SIZE,
                          shuffle=False)

In [11]:
# Sanity check: pull one validation batch and inspect the collated input_ids shape.
batch = next(iter(val_loader))
batch['text_inputs']['input_ids'].shape

torch.Size([64, 50])

In [12]:
class TextEncoder(nn.Module):
    """Wrap a pretrained text model and project its output to ``projection_dim``.

    Args:
        text_model: Module called as ``text_model(**inputs)`` whose output
            exposes ``'last_hidden_state'`` and whose ``config.hidden_size``
            gives the feature width.
        projection_dim: Width of the projected question embedding.
    """

    def __init__(self, text_model, projection_dim):
        super().__init__()
        self.model = text_model
        hidden_size = self.model.config.hidden_size
        self.linear = nn.Linear(hidden_size, projection_dim)

    def forward(self, inputs):
        """Encode tokenized questions.

        Args:
            inputs: Dict of keyword arguments forwarded to the wrapped model.

        Returns:
            Tensor produced by applying the linear projection and GELU to the
            hidden state at sequence position 0.
        """
        encoder_outputs = self.model(**inputs)
        # Only the vector at the first sequence position represents the question.
        first_position = encoder_outputs['last_hidden_state'][:, 0, :]
        return F.gelu(self.linear(first_position))

class ImageEncoder(nn.Module):
    """Wrap a pretrained image backbone and project its output to ``projection_dim``.

    Every backbone parameter is explicitly marked trainable, so the backbone is
    fine-tuned together with the projection layer.

    Args:
        img_model: Backbone exposing ``parameters()``, ``embed_dim`` and
            ``forward_features``.
        projection_dim: Width of the projected image embedding.
    """

    def __init__(self, img_model, projection_dim):
        super().__init__()
        for weight in img_model.parameters():
            weight.requires_grad = True
        self.model = img_model
        self.linear = nn.Linear(self.model.embed_dim, projection_dim)

    def forward(self, inputs):
        """Encode a batch of preprocessed images.

        Args:
            inputs: Image tensor passed to the backbone's ``forward_features``.

        Returns:
            Tensor produced by applying the linear projection and GELU to the
            feature at index 0 of the backbone output's second dimension.
        """
        token_features = self.model.forward_features(inputs)
        # Keep only index 0 along the second dimension of the backbone output.
        first_token = token_features[:, 0, :]
        return F.gelu(self.linear(first_token))

class Classifier(nn.Module):
    """Fusion head: concatenate image and text features, then classify with an MLP.

    Layers: ``2 * projection_dim -> hidden_dim -> hidden_dim // 2 -> answer_space``
    with GELU activations and dropout (p=0.4, then p=0.3) after the two hidden
    layers.

    Args:
        projection_dim: Width of each incoming (text / image) feature vector.
        hidden_dim: Width of the first hidden layer.
        answer_space: Number of output classes (an integer count).
    """

    def __init__(self, projection_dim, hidden_dim, answer_space):
        super().__init__()
        fused_dim = projection_dim * 2
        half_hidden = hidden_dim // 2

        # Construction order is kept so seeded weight initialisation is unchanged.
        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.dropout1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(hidden_dim, half_hidden)
        self.dropout2 = nn.Dropout(0.3)
        self.classifier = nn.Linear(half_hidden, answer_space)

    def forward(self, text_f, img_f):
        """Return class logits for a batch of text and image features.

        Args:
            text_f: Text features, one row per sample.
            img_f: Image features, one row per sample.

        Returns:
            Logits tensor with one row per sample and one column per class.
        """
        # Image features come first in the concatenation.
        fused = torch.cat((img_f, text_f), 1)
        hidden = self.dropout1(F.gelu(self.fc1(fused)))
        hidden = self.dropout2(F.gelu(self.fc2(hidden)))
        return self.classifier(hidden)


class ViVQAModel(nn.Module):
    """End-to-end VQA model: text encoder + image encoder + fusion classifier.

    Args:
        text_encoder: Module mapping tokenized questions to text features.
        img_encoder: Module mapping image tensors to image features.
        classifier: Module called as ``classifier(text_features, image_features)``
            that returns answer logits.
    """

    def __init__(self, text_encoder, img_encoder, classifier):
        super().__init__()
        self.text_encoder = text_encoder
        self.img_encoder = img_encoder
        self.classifier = classifier

    def forward(self, text_inputs, img_inputs):
        """Encode both modalities and return the classifier's logits."""
        text_features = self.text_encoder(text_inputs)
        img_features = self.img_encoder(img_inputs)
        return self.classifier(text_features, img_features)

# Assemble the full model from the pretrained backbones loaded in earlier cells.
PROJECTION_DIM = config['projection_dim']
HIDDEN_DIM = config['hidden_dim']
text_encoder = TextEncoder(text_model=text_model,
                           projection_dim=PROJECTION_DIM)
img_encoder = ImageEncoder(img_model=img_model,
                         projection_dim=PROJECTION_DIM)
# One output logit per distinct answer found in the train + test CSVs.
classifier = Classifier(projection_dim=PROJECTION_DIM,
                        hidden_dim=HIDDEN_DIM,
                        answer_space=answer_space_len)

model = ViVQAModel(text_encoder=text_encoder,
                   img_encoder=img_encoder,
                   classifier=classifier).to(device)

# Smoke test: run one downloaded image and one question through the model and
# check the logits shape. This needs network access, and model.eval() has not
# been called at this point, so only the printed shape is meaningful.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
question = 'Có bao nhiêu con mèo trong bức ảnh?'
with torch.no_grad():
    text_inputs = text_processor(question)
    # unsqueeze(0) adds the batch dimension that the DataLoader would normally provide.
    img_inputs = img_processor(image).to(device).unsqueeze(0)
    
    logits = model(text_inputs, img_inputs)

    print(logits.shape)

torch.Size([1, 353])


In [13]:
# Optimisation hyper-parameters, read from the shared config dict.
LR = config['learning_rate']
EPOCHS = config['epochs']
PATIENCE = config['patience']
WEIGHT_DECAY = config['weight_decay']
# A single parameter group: both backbones and the classifier head share one learning rate.
optimizer = torch.optim.AdamW(model.parameters(),
                             lr=LR,
                             weight_decay=WEIGHT_DECAY)
# step_size = EPOCHS * 0.4
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
#                                             step_size=step_size, 
#                                             gamma=0.1)
# Created with default arguments; expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()

def compute_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        logits: Tensor with one row of class scores per sample.
        labels: Tensor of integer class indices, one per sample.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.
    """
    predicted = torch.max(logits, dim=1).indices
    num_correct = (predicted == labels).sum().item()
    return num_correct / logits.shape[0]

def evaluate(model, val_loader, criterion):
    """Compute the mean loss and accuracy of ``model`` over ``val_loader``.

    Both metrics are averaged per sample, not per batch, so a smaller final
    batch no longer carries the same weight as a full one. The model is put
    into eval mode and left there; gradients are disabled during the pass.

    Args:
        model: Module called as ``model(text_inputs, img_inputs)`` returning logits.
        val_loader: Iterable of dict batches with the keys ``'text_inputs'``,
            ``'img_inputs'`` and ``'labels'``.
        criterion: Loss callable ``criterion(logits, labels)``. It is assumed
            to return the batch mean (the ``nn.CrossEntropyLoss`` default).

    Returns:
        Tuple ``(eval_loss, eval_acc)`` of Python floats.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct_sum = 0
    sample_count = 0
    with torch.no_grad():
        for batch in val_loader:
            # Read without pop() so the caller's batch dicts are left intact.
            text_inputs = batch['text_inputs']
            img_inputs = batch['img_inputs']
            labels = batch['labels']

            logits = model(text_inputs, img_inputs)
            batch_size = labels.size(0)

            # Undo the per-batch mean so every sample contributes equally.
            loss_sum += criterion(logits, labels).item() * batch_size
            correct_sum += (logits.argmax(dim=1) == labels).sum().item()
            sample_count += batch_size

    if sample_count == 0:
        raise ValueError('val_loader yielded no samples; cannot compute evaluation metrics.')

    eval_loss = loss_sum / sample_count
    eval_acc = correct_sum / sample_count

    return eval_loss, eval_acc


def train(model,
          train_loader,
          val_loader,
          epochs,
          criterion,
          optimizer,
          #scheduler,
          patience=5,
          checkpoint_path='best_model.pth'):
    """Train ``model`` with early stopping on validation loss.

    After each epoch the model is evaluated on ``val_loader``, the metrics are
    logged to Weights & Biases and printed, and the weights are saved to
    ``checkpoint_path`` whenever the validation loss improves. Training stops
    once ``patience`` consecutive epochs bring no improvement. The best
    weights are only written to disk; the in-memory model keeps the weights of
    the last epoch that ran.

    Training metrics are averaged per sample, so a smaller final batch does not
    skew the epoch values.

    Args:
        model: Module called as ``model(text_inputs, img_inputs)`` returning logits.
        train_loader: Iterable of dict batches with ``'text_inputs'``,
            ``'img_inputs'`` and ``'labels'``.
        val_loader: Validation batches in the same format.
        epochs: Maximum number of epochs.
        criterion: Loss callable; assumed to return the batch mean.
        optimizer: Optimizer over the model's parameters.
        patience: Epochs without validation-loss improvement tolerated before stopping.
        checkpoint_path: File the best ``state_dict`` is saved to.

    Returns:
        Tuple ``(train_loss_lst, train_acc_lst, val_loss_lst, val_acc_lst)``
        with one entry per completed epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    best_val_loss = np.inf
    epochs_no_improve = 0

    train_loss_lst = []
    train_acc_lst = []
    val_loss_lst = []
    val_acc_lst = []
    for epoch in range(epochs):
        loss_sum = 0.0
        correct_sum = 0.0
        sample_count = 0

        epoch_iterator = tqdm(train_loader,
                              desc=f'Epoch {epoch + 1}/{epochs}',
                              unit='batch')
        model.train()
        for batch in epoch_iterator:
            text_inputs = batch['text_inputs']
            img_inputs = batch['img_inputs']
            labels = batch['labels']

            # Clear gradients before the backward pass so that anything left
            # over from an earlier, interrupted run cannot leak into this step.
            optimizer.zero_grad()

            logits = model(text_inputs, img_inputs)
            loss = criterion(logits, labels)

            batch_loss = loss.item()
            batch_size = labels.size(0)
            # Weight by batch size: the last batch of an epoch may be smaller.
            loss_sum += batch_loss * batch_size
            correct_sum += compute_accuracy(logits, labels) * batch_size
            sample_count += batch_size

            loss.backward()
            optimizer.step()

            epoch_iterator.set_postfix({'Batch Loss': batch_loss})

        if sample_count == 0:
            raise ValueError('train_loader yielded no samples; nothing to train on.')

        # scheduler.step()

        val_loss, val_acc = evaluate(model,
                                     val_loader,
                                     criterion)

        train_loss = loss_sum / sample_count
        train_acc = correct_sum / sample_count

        train_loss_lst.append(train_loss)
        train_acc_lst.append(train_acc)
        val_loss_lst.append(val_loss)
        val_acc_lst.append(val_acc)

        wandb.log({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })

        print(f'EPOCH {epoch + 1}: Train loss: {train_loss:.4f}\tTrain acc: {train_acc:.4f}\tVal loss: {val_loss:.4f}\tVal acc: {val_acc:.4f}')

        if val_loss < best_val_loss:
            # New best: reset the patience counter and checkpoint the weights.
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f'Early stopping triggered after {epochs_no_improve} epochs without improvement.')
                break

    return train_loss_lst, train_acc_lst, val_loss_lst, val_acc_lst

# Launch training; the four returned lists hold one value per completed epoch.
train_loss_lst, train_acc_lst, val_loss_lst, val_acc_lst = train(model, 
                                                                 train_loader, 
                                                                 val_loader, 
                                                                 epochs=EPOCHS, 
                                                                 criterion=criterion, 
                                                                 optimizer=optimizer, 
                                                                 #scheduler=scheduler,
                                                                 patience=PATIENCE)

Epoch 1/50: 100%|██████████| 375/375 [04:32<00:00,  1.38batch/s, Batch Loss=3.33]


EPOCH 1: Train loss: 4.2868	Train acc: 0.1694	Val loss: 2.7692	Val acc: 0.3893


Epoch 2/50: 100%|██████████| 375/375 [04:19<00:00,  1.45batch/s, Batch Loss=1.84]


EPOCH 2: Train loss: 2.6299	Train acc: 0.4165	Val loss: 2.2255	Val acc: 0.4792


Epoch 3/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=2.48] 


EPOCH 3: Train loss: 2.1360	Train acc: 0.4895	Val loss: 1.9968	Val acc: 0.5243


Epoch 4/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=1.79] 


EPOCH 4: Train loss: 1.8139	Train acc: 0.5578	Val loss: 1.8641	Val acc: 0.5582


Epoch 5/50: 100%|██████████| 375/375 [04:19<00:00,  1.45batch/s, Batch Loss=0.951]


EPOCH 5: Train loss: 1.5279	Train acc: 0.6145	Val loss: 1.7664	Val acc: 0.5737


Epoch 6/50: 100%|██████████| 375/375 [04:34<00:00,  1.37batch/s, Batch Loss=1.08] 


EPOCH 6: Train loss: 1.2910	Train acc: 0.6726	Val loss: 1.7584	Val acc: 0.5800


Epoch 7/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.901]


EPOCH 7: Train loss: 1.0801	Train acc: 0.7189	Val loss: 1.7495	Val acc: 0.5970


Epoch 8/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=1.23] 


EPOCH 8: Train loss: 0.9183	Train acc: 0.7618	Val loss: 1.8139	Val acc: 0.6020


Epoch 9/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.8]  


EPOCH 9: Train loss: 0.7828	Train acc: 0.7945	Val loss: 1.7886	Val acc: 0.6082


Epoch 10/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.692]


EPOCH 10: Train loss: 0.6530	Train acc: 0.8275	Val loss: 1.8694	Val acc: 0.6109


Epoch 11/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.717]


EPOCH 11: Train loss: 0.5607	Train acc: 0.8527	Val loss: 1.9117	Val acc: 0.6026


Epoch 12/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.303] 


EPOCH 12: Train loss: 0.4740	Train acc: 0.8728	Val loss: 1.9723	Val acc: 0.6043


Epoch 13/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.674] 


EPOCH 13: Train loss: 0.4181	Train acc: 0.8863	Val loss: 2.0308	Val acc: 0.6059


Epoch 14/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.386] 


EPOCH 14: Train loss: 0.3749	Train acc: 0.8994	Val loss: 2.0962	Val acc: 0.6056


Epoch 15/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.231] 


EPOCH 15: Train loss: 0.3179	Train acc: 0.9136	Val loss: 2.1426	Val acc: 0.6103


Epoch 16/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.143] 


EPOCH 16: Train loss: 0.2860	Train acc: 0.9183	Val loss: 2.1321	Val acc: 0.6062


Epoch 17/50: 100%|██████████| 375/375 [04:18<00:00,  1.45batch/s, Batch Loss=0.266] 


EPOCH 17: Train loss: 0.2633	Train acc: 0.9272	Val loss: 2.2027	Val acc: 0.6097
Early stopping triggered after 10 epochs without improvement.
