# **0. UTILS**

## 0.1 Utils

In [None]:
import contextlib
import numpy as np
import random
import shutil
import os

import torch

In [None]:
def set_seed(seed):
    """Seed every RNG used in training so runs are reproducible.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU plus
    the current and all CUDA devices), then forces cuDNN into deterministic
    mode with benchmark autotuning disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Trim two token lists in place until their combined length fits.

    One token at a time is popped from the end of whichever list is
    currently longer (``tokens_b`` on ties), so the shorter sequence keeps
    as much of its content as possible.
    Copied from https://github.com/huggingface/pytorch-pretrained-BERT
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()

In [None]:
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG and restore its prior state on exit.

    With ``seed=None`` the context does nothing. When extra seeds are given,
    they are folded together with ``seed`` through ``hash`` into a single
    integer seed below 1e6.
    """
    if seed is not None:
        if addl_seeds:
            seed = int(hash((seed, *addl_seeds)) % 1e6)
        saved_state = np.random.get_state()
        np.random.seed(seed)
        try:
            yield
        finally:
            np.random.set_state(saved_state)
    else:
        yield

## 0.2 Initialize arguments

> Khởi tạo đối số cố định thay vì nhập

In [None]:
class Args:
    """Fixed experiment configuration used instead of command-line arguments.

    Any default below can be replaced with a keyword override, e.g.
    ``Args(batch_sz=4)``. Further attributes (``labels``, ``label_freqs``,
    ``vocab``, ``vocab_sz``, ``n_classes``, ``train_data_len``) are filled in
    later by ``get_data_loaders``.

    Raises:
        TypeError: if an override names an attribute that has no default.
    """

    def __init__(self, **overrides):
        # --- run / optimisation schedule ---
        self.seed = 123
        self.batch_sz = 8
        self.max_epochs = 20
        self.task_type = "multilabel"
        self.n_workers = 4
        self.patience = 20  # epochs without validation improvement before early stop

        # --- output / checkpoint locations ---
        self.savedir = "/kaggle/working"
        self.save_name = 'mimic_par'
        self.loaddir = '/kaggle/input/medvill-weight'
        self.name = "scenario_name"

        # --- data ---
        self.openi = False
        self.data_path = '/kaggle/input/thesis-mimic-cxr-2-0-0'
        self.Train_dset_name = 'train.jsonl'
        self.Valid_dset_name = 'valid.jsonl'

        # --- model sizes ---
        self.embed_sz = 768
        self.hidden_sz = 768
        self.bert_model = "bert-base-uncased"
        self.init_model = "bert-base-uncased"

        # Probability that a row's image is replaced by a gray placeholder (JsonlDataset).
        self.drop_img_percent = 0.0
        self.dropout = 0.1

        # NOTE(review): not read anywhere in this notebook.
        self.freeze_img = 0
        self.freeze_txt = 0

        # Read by model_forward to set requires_grad on the whole image encoder / BERT encoder.
        self.freeze_img_all = False
        self.freeze_txt_all = False

        # NOTE(review): glove_path, hidden and include_bn are not read anywhere in this notebook.
        self.glove_path = "/path/to/glove_embeds/glove.840B.300d.txt"
        self.gradient_accumulation_steps = 2
        self.hidden = []

        self.img_embed_pool_type = "avg"
        self.img_hidden_sz = 2048  # feature size produced by ImageEncoder
        self.include_bn = True

        self.lr = 1e-3
        self.lr_factor = 0.75
        self.lr_patience = 5

        # JsonlDataset subtracts num_image_embeds from max_seq_len to get the text budget.
        self.max_seq_len = 512
        self.num_image_embeds = 256

        self.warmup = 0.1
        self.weight_classes = 1

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"Args got an unexpected keyword argument {key!r}")
            setattr(self, key, value)


args = Args()

# **1. DATA**

## 1.1 Vocab

In [None]:

class Vocab(object):
    """Bidirectional word <-> index mapping.

    Args:
        emptyInit: when True, start with no entries. Otherwise seed the
            vocabulary with the special tokens [PAD], [UNK], [CLS], [SEP]
            and [MASK] at indices 0-4.

    Attributes:
        stoi (dict): word -> index.
        itos (list): index -> word.
        vocab_sz (int): number of entries in ``itos``.

    Methods:
        add(words): append any words not yet present.
    """

    def __init__(self, emptyInit=False):
        specials = [] if emptyInit else ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        self.stoi = {token: idx for idx, token in enumerate(specials)}
        self.itos = list(self.stoi)
        self.vocab_sz = len(self.itos)

    def add(self, words):
        """Append every unseen word, giving each the next consecutive index."""
        for word in words:
            if word not in self.stoi:
                self.stoi[word] = len(self.itos)
                self.itos.append(word)
        self.vocab_sz = len(self.itos)

## 1.2 Dataset

> Đặt `ImageFile.LOAD_TRUNCATED_IMAGES = True` để PIL vẫn đọc được các ảnh bị hỏng/cắt cụt (truncated) thay vì báo lỗi

In [None]:
import json
import numpy as np
import os
from PIL import Image

import torch
from torch.utils.data import Dataset

# from utils.utils import truncate_seq_pair, numpy_seed

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

> Chuyển ảnh sang RGB (`.convert("RGB")`) để mọi ảnh đều có 3 kênh màu

In [None]:
class JsonlDataset(Dataset):
    """Dataset over a .jsonl file whose rows hold report text, an image path and labels.

    Each row is read through the keys "text", "img" (a path relative to the
    .jsonl file's directory, or a falsy value for no image) and "label".

    - Text is tokenized, truncated, wrapped in [SEP] tokens and mapped to
      vocabulary ids; unknown tokens map to [UNK].
    - Images are loaded as RGB. Rows without an image, including those
      dropped with probability ``args.drop_img_percent``, get a 512x512
      mid-gray placeholder. ``transforms`` is applied in both cases.
    - For ``args.task_type == "multilabel"`` the label becomes a multi-hot
      float vector over ``args.labels``. Otherwise it is a 1-element
      LongTensor holding the class index, which collate_fn concatenates.

    Attributes:
        data (list): Parsed JSON objects, one per non-blank line.
        tokenizer (callable): Maps a string to a list of tokens.
        vocab (Vocab): Token-to-index lookup.
        transforms (callable): Applied to every PIL image.
        n_classes (int): ``len(args.labels)``.
        max_seq_len (int): Text token budget, ``args.max_seq_len - args.num_image_embeds``.

    Methods:
        __len__(): Number of samples.
        __getitem__(index): ``(sentence, segment, image, label)`` for one row.
    """

    def __init__(self, data_path, tokenizer, transforms, vocab, args):
        with open(data_path, encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing empty line) rather than failing in json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.args = args
        self.vocab = vocab
        self.n_classes = len(args.labels)
        self.text_start_token = ["[SEP]"]

        # Drop images randomly for generalization; seeded so the choice is repeatable.
        with numpy_seed(0):
            for row in self.data:
                if np.random.random() < args.drop_img_percent:
                    row["img"] = None

        # Reserve room for the image embeddings out of the overall sequence length.
        self.max_seq_len = args.max_seq_len - args.num_image_embeds

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def _build_label(self, raw_label):
        """Convert a row's raw "label" field into its target tensor.

        A plain string is treated as one label, and an empty string as the
        "'Others'" label. Without this, iterating a string would treat each
        character as a separate label.
        """
        if self.args.task_type != "multilabel":
            return torch.LongTensor([self.args.labels.index(raw_label)])
        if isinstance(raw_label, str):
            raw_label = [raw_label] if raw_label else ["'Others'"]
        label = torch.zeros(self.n_classes)
        label[[self.args.labels.index(tgt) for tgt in raw_label]] = 1
        return label

    def __getitem__(self, index):
        """Return ``(sentence, segment, image, label)`` for row ``index``.

        ``sentence`` holds token ids for the text followed by a closing
        [SEP]. The opening [SEP] is dropped because it belongs to the image
        block. ``segment`` is all ones, since segment 0 is used for the
        image tokens.
        """
        row = self.data[index]

        tokens = self.tokenizer(row["text"])[: (self.max_seq_len - 1)]
        sentence = self.text_start_token + tokens + self.text_start_token
        segment = torch.zeros(len(sentence))
        sentence = torch.LongTensor(
            [
                self.vocab.stoi[w] if w in self.vocab.stoi else self.vocab.stoi["[UNK]"]
                for w in sentence
            ]
        )

        label = self._build_label(row["label"])

        if row["img"]:
            # The context manager closes the file once the RGB copy exists.
            with Image.open(os.path.join(self.data_dir, row["img"])) as img_file:
                image = img_file.convert("RGB")
        else:
            image = Image.fromarray(128 * np.ones((512, 512, 3), dtype=np.uint8))
        image = self.transforms(image)

        # The first SEP is part of the image token block; text tokens use segment id 1.
        segment = segment[1:] + 1
        sentence = sentence[1:]

        return sentence, segment, image, label

## 1.3 Helpers

In [None]:
!pip install pytorch-pretrained-bert

In [None]:
import functools
import json
import os
from collections import Counter

import torch
import torchvision.transforms as transforms
from pytorch_pretrained_bert import BertTokenizer
from torch.utils.data import DataLoader

# from data.dataset import JsonlDataset
# from data.vocab import Vocab

In [None]:
def get_transforms(args):
    """Return the image preprocessing pipeline.

    - If ``args.openi`` is truthy, images are first converted to grayscale
      replicated over 3 channels.
    - Images are then converted to tensors and normalized with ImageNet
      mean/std.

    Args:
        args: Configuration object. ``openi`` is optional and defaults to False.

    Returns:
        torchvision.transforms.Compose: The sequence of image transformations.
    """
    steps = []
    if getattr(args, "openi", False):
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

In [None]:
def get_labels_and_frequencies(path):
    """Collect the unique labels and their frequencies from a JSONL file.

    Each non-blank line's "label" field is expected to be a list of label
    names. A plain string is treated as a single label, and an empty string
    is counted as "'Others'". Without this, ``Counter.update`` on a string
    would count its characters.

    Args:
        path (str): Path to the JSONL file.

    Returns:
        tuple: (list of unique labels in first-seen order, Counter of label frequencies).
    """
    label_freqs = Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            label_row = json.loads(line)["label"]
            if isinstance(label_row, str):
                label_row = [label_row] if label_row else ["'Others'"]
            label_freqs.update(label_row)
    return list(label_freqs.keys()), label_freqs

In [None]:
def get_vocab(args):
    """Build a Vocab backed by the WordPiece vocabulary of ``args.bert_model``.

    The Vocab's ``stoi``/``itos`` are replaced with the BERT tokenizer's own
    token->id and id->token mappings, and ``vocab_sz`` is updated to match.
    """
    vocab = Vocab()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    vocab.stoi, vocab.itos = tokenizer.vocab, tokenizer.ids_to_tokens
    vocab.vocab_sz = len(vocab.itos)
    return vocab

In [None]:
def collate_fn(batch, args):
    """Pad a list of ``(tokens, segment, image, label)`` samples into batch tensors.

    Token ids and segment ids are right-padded with 0 up to the longest
    sample, and the mask is 1 on real tokens. Images are stacked. Labels are
    stacked for multilabel tasks, otherwise concatenated into a LongTensor.

    Returns:
        (text_tensor, segment_tensor, mask_tensor, img_tensor, tgt_tensor)
    """
    lengths = [len(sample[0]) for sample in batch]
    batch_size, longest = len(batch), max(lengths)

    mask_tensor, text_tensor, segment_tensor = (
        torch.zeros(batch_size, longest).long() for _ in range(3)
    )

    img_tensor = torch.stack([sample[2] for sample in batch])

    targets = [sample[3] for sample in batch]
    if args.task_type == "multilabel":
        tgt_tensor = torch.stack(targets)
    else:
        tgt_tensor = torch.cat(targets).long()

    for row, (sample, n_tok) in enumerate(zip(batch, lengths)):
        text_tensor[row, :n_tok] = sample[0]
        segment_tensor[row, :n_tok] = sample[1]
        mask_tensor[row, :n_tok] = 1

    return text_tensor, segment_tensor, mask_tensor, img_tensor, tgt_tensor

In [None]:
def get_data_loaders(args):
    """Build the train/validation DataLoaders and fill dataset-derived fields on ``args``.

    Side effects on ``args``: sets ``labels``, ``label_freqs``, ``vocab``,
    ``vocab_sz``, ``n_classes`` and ``train_data_len``.

    Returns:
        (train_loader, val_loader). The training loader shuffles; the
        validation loader does not.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True).tokenize
    image_transforms = get_transforms(args)

    train_path = os.path.join(args.data_path, args.Train_dset_name)
    valid_path = os.path.join(args.data_path, args.Valid_dset_name)

    args.labels, args.label_freqs = get_labels_and_frequencies(train_path)

    vocab = get_vocab(args)
    args.vocab = vocab
    args.vocab_sz = vocab.vocab_sz
    args.n_classes = len(args.labels)

    train_set = JsonlDataset(train_path, tokenizer, image_transforms, vocab, args)
    args.train_data_len = len(train_set)
    dev_set = JsonlDataset(valid_path, tokenizer, image_transforms, vocab, args)

    collate = functools.partial(collate_fn, args=args)

    def _make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=args.batch_sz,
            shuffle=shuffle,
            num_workers=args.n_workers,
            collate_fn=collate,
        )

    train_loader = _make_loader(train_set, True)
    val_loader = _make_loader(dev_set, False)
    return train_loader, val_loader

In [None]:
# train_loader, val_loader = get_data_loaders(args)
# next(iter(train_loader))

# **2. MODELS**

## 2.1 Image

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from glob import glob

- Lấy đặc trưng của ảnh qua mô hình ResNet50.
- Flattening các đặc trưng để chuẩn bị cho các bước xử lý sau.

In [None]:
class ImageEncoder(nn.Module):
    """ResNet-50 backbone that turns each image into a sequence of feature vectors.

    The last two children of torchvision's pretrained ResNet-50 (global
    average pool and fc) are removed. ``forward`` therefore returns the final
    convolutional feature map with its spatial grid flattened into a token
    axis. No pooling is applied, so the number of image tokens depends on
    the input resolution.
    """

    def __init__(self, args):
        super(ImageEncoder, self).__init__()
        self.args = args
        model = torchvision.models.resnet50(pretrained=True)
        # Keep everything up to the last conv stage (drop avgpool and fc).
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch. Presumably (B, 3, H, W) as produced by get_transforms; TODO confirm.

        Returns:
            Tensor (B, N, C) where N is the number of spatial positions of
            the final feature map and C its channel count (2048 per
            ``args.img_hidden_sz``). N must equal ``args.num_image_embeds``
            for the downstream embeddings to line up.
        """
        out = self.model(x)                      # (B, C, H', W')
        out = torch.flatten(out, start_dim=2)    # (B, C, H'*W')
        return out.transpose(1, 2).contiguous()  # (B, H'*W', C)

## 2.2 Model

In [None]:
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel
# from models.image import ImageEncoder

- Trước tiên, lớp chuyển đổi hình ảnh thành embeddings.
- Tiếp theo, mã hóa các token `[CLS]` và `[SEP]` để bao gồm chúng vào chuỗi embeddings.
- Cuối cùng, các embeddings của văn bản và hình ảnh được cộng với embeddings vị trí và loại token, sau đó chuẩn hóa và thực hiện dropout.

In [None]:
class ImageBertEmbeddings(nn.Module):
    """Embed image features as a BERT-style token block: [CLS] img_1 ... img_N [SEP].

    Shares the position, token-type and word embeddings and the LayerNorm of
    the text BERT ``embeddings`` passed in. Only the linear projection from
    ``args.img_hidden_sz`` to ``args.hidden_sz`` is newly created.
    """

    def __init__(self, args, embeddings):
        super(ImageBertEmbeddings, self).__init__()
        self.args = args
        self.img_embeddings = nn.Linear(args.img_hidden_sz, args.hidden_sz)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=args.dropout)

    def forward(self, input_imgs, token_type_ids):
        """Build embeddings for the image block.

        Args:
            input_imgs: (B, num_image_embeds, img_hidden_sz) image features.
            token_type_ids: (B, num_image_embeds + 2) segment ids for the block.

        Returns:
            (B, num_image_embeds + 2, hidden_sz) embeddings after LayerNorm and dropout.
        """
        bsz = input_imgs.size(0)
        # Follow the input's device instead of hard-coding .cuda(), so CPU runs work too.
        device = input_imgs.device
        seq_length = self.args.num_image_embeds + 2  # +2 for CLS and SEP Token

        def _special_token_embeds(token):
            token_id = torch.tensor(
                [[self.args.vocab.stoi[token]]], dtype=torch.long, device=device
            )
            return self.word_embeddings(token_id.expand(bsz, 1))

        cls_token_embeds = _special_token_embeds("[CLS]")
        sep_token_embeds = _special_token_embeds("[SEP]")

        imgs_embeddings = self.img_embeddings(input_imgs)
        token_embeddings = torch.cat(
            [cls_token_embeds, imgs_embeddings, sep_token_embeds], dim=1)

        # Positions restart at 0 for the image block.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(bsz, seq_length)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

- Tạo attention mask để xử lý cả văn bản và hình ảnh.
- Trích xuất đặc trưng hình ảnh bằng ResNet50 (ImageEncoder).
- Tạo embeddings cho hình ảnh và văn bản.
- Ghép embeddings hình ảnh và văn bản thành một chuỗi đầu vào cho BERT.
- Truyền dữ liệu qua BERT encoder để học các mối quan hệ giữa hình ảnh và văn bản.
- Trả về embedding cuối cùng để sử dụng cho các tác vụ tiếp theo.

> Sửa lại float 16 thành float 32 trong hàm forward

In [None]:
class MultimodalBertEncoder(nn.Module):
    """BERT encoder over a joint sequence of image tokens followed by text tokens.

    The image block ([CLS] + image features + [SEP]) is embedded by
    ImageBertEmbeddings, which shares BERT's embedding tables. It is
    concatenated in front of the text embeddings, and the combined sequence
    is run through BERT's encoder and pooler.
    """

    def __init__(self, args):
        super(MultimodalBertEncoder, self).__init__()
        self.args = args

        if args.init_model == "bert-base-scratch":
            config = BertConfig.from_pretrained("bert-base-uncased")
            bert = BertModel(config)
        else:
            bert = BertModel.from_pretrained(args.init_model)
        self.txt_embeddings = bert.embeddings

        self.img_embeddings = ImageBertEmbeddings(args, self.txt_embeddings)
        self.img_encoder = ImageEncoder(args)
        self.encoder = bert.encoder
        self.pooler = bert.pooler

    def forward(self, input_txt, attention_mask, segment, input_img):
        """Encode text + image and return BERT's pooled output.

        Args:
            input_txt: (B, T) text token ids.
            attention_mask: (B, T) mask, 1 on real text tokens.
            segment: (B, T) token-type ids for the text.
            input_img: Image batch passed to ImageEncoder.

        Returns:
            Pooled representation from ``bert.pooler``.
        """
        bsz = input_txt.size(0)
        # Build helper tensors on the inputs' device rather than hard-coding .cuda().
        device = input_txt.device
        n_img_tokens = self.args.num_image_embeds + 2  # image features + [CLS] + [SEP]

        # Image tokens are always attended to; prepend them to the text mask.
        attention_mask = torch.cat(
            [
                torch.ones(bsz, n_img_tokens, dtype=torch.long, device=device),
                attention_mask,
            ],
            dim=1)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        try:
            mask_dtype = next(self.parameters()).dtype  # fp16 compatibility
        except StopIteration:
            mask_dtype = torch.float32
        extended_attention_mask = extended_attention_mask.to(dtype=mask_dtype)

        # Additive mask: 0 where attended, -10000 on padding.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Image tokens use token-type id 0.
        img_tok = torch.zeros(bsz, n_img_tokens, dtype=torch.long, device=device)
        img = self.img_encoder(input_img)  # (B, num_image_embeds, img_hidden_sz)

        img_embed_out = self.img_embeddings(img, img_tok)
        txt_embed_out = self.txt_embeddings(input_txt, segment)
        encoder_input = torch.cat([img_embed_out, txt_embed_out], 1)  # Bx(IMG+TEXT)xHID
        encoded_layers = self.encoder(encoder_input, extended_attention_mask)

        # transformers' BertEncoder returns a ModelOutput/tuple whose first item is
        # the final hidden state; [-1] would pick a different field whenever extra
        # outputs (hidden states, attentions) are enabled.
        return self.pooler(encoded_layers[0])

- Lấy đầu ra từ MultimodalBertEncoder.
- Đưa đầu ra qua lớp phân loại (clf).

In [None]:
class MultimodalBertClf(nn.Module):
    """Multimodal BERT encoder followed by a linear classification head.

    ``forward`` returns unnormalized logits, ``args.n_classes`` per sample.
    """

    def __init__(self, args):
        super(MultimodalBertClf, self).__init__()
        self.args = args
        self.enc = MultimodalBertEncoder(args)
        self.clf = nn.Linear(args.hidden_sz, args.n_classes)

    def forward(self, txt, mask, segment, img):
        return self.clf(self.enc(txt, mask, segment, img))

In [None]:
# Registry of model constructors; get_model always builds the "model" entry.
MODELS = {
    "model": MultimodalBertClf,
}


def get_model(args):
    """Instantiate the registered multimodal classifier for ``args``."""
    model_cls = MODELS["model"]
    return model_cls(args)

# **3. Main**

In [None]:
import os
import csv
import argparse
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
import torch.optim as optim
from pytorch_pretrained_bert import BertAdam
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
# from data.helpers import get_data_loaders
# from utils.utils import *

In [None]:
def get_criterion(args, device):
    """Return the training loss for ``args.task_type``.

    - multilabel with ``args.weight_classes`` truthy: BCEWithLogitsLoss with
      per-class ``pos_weight = (positives / negatives) ** -1``, computed from
      the training label frequencies.
    - multilabel otherwise: plain BCEWithLogitsLoss.
    - any other task type: CrossEntropyLoss.
    """
    if args.task_type != "multilabel":
        return nn.CrossEntropyLoss()
    if not args.weight_classes:
        return nn.BCEWithLogitsLoss()

    positives = [args.label_freqs[l] for l in args.labels]
    negatives = [args.train_data_len - count for count in positives]
    label_weights = (torch.FloatTensor(positives) / torch.FloatTensor(negatives)) ** -1
    return nn.BCEWithLogitsLoss(pos_weight=label_weights.to(device))

In [None]:
def get_optimizer(model, args):
    """Create a BertAdam optimizer over all model parameters.

    Parameters whose name contains "bias", "LayerNorm.bias" or
    "LayerNorm.weight" get weight_decay=0.0; all others get 0.01.
    ``t_total`` is the number of optimizer steps over the whole run, i.e.
    samples / batch size / gradient-accumulation steps * epochs. ``warmup``
    comes from ``args.warmup``.
    """
    total_steps = (
        args.train_data_len
        / args.batch_sz
        / args.gradient_accumulation_steps
        * args.max_epochs
    )
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    grouped = [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return BertAdam(grouped, lr=args.lr, warmup=args.warmup, t_total=total_steps)

In [None]:
def get_scheduler(optimizer, args):
    """Reduce the LR by ``args.lr_factor`` when the monitored metric (higher is better) plateaus."""
    plateau_cls = optim.lr_scheduler.ReduceLROnPlateau
    return plateau_cls(
        optimizer,
        mode="max",
        factor=args.lr_factor,
        patience=args.lr_patience,
        verbose=True,
    )

In [None]:
def model_eval(data, model, args, criterion, device, store_preds=False):
    """Run ``model`` over ``data`` without gradients and compute validation metrics.

    Returns:
        (metrics, classACC, tgts, preds):
        - metrics: dict with "loss" and, for multilabel, "micro_roc_auc",
          "macro_roc_auc", "macro_f1" and "micro_f1"; otherwise "acc".
        - classACC: per-label ROC-AUC keyed by label name (multilabel only).
          It is 0 where sklearn cannot compute the score.
        - tgts, preds: stacked targets and sigmoid probabilities
          (multilabel), or flat lists of class ids.
    """
    multilabel = args.task_type == "multilabel"
    losses, preds, preds_bool, tgts = [], [], [], []
    with torch.no_grad():
        for batch in data:
            loss, out, tgt = model_forward(model, args, criterion, batch, device)
            losses.append(loss.item())
            if multilabel:
                prob = torch.sigmoid(out).cpu().numpy()
                preds.append(prob)
                preds_bool.append(prob > 0.5)  # fixed 0.5 decision threshold
            else:
                preds.append(
                    torch.nn.functional.softmax(out, dim=1).argmax(dim=1).cpu().numpy())
            tgts.append(tgt.cpu().numpy())

    metrics = {"loss": np.mean(losses)}
    classACC = dict()
    if multilabel:
        tgts = np.vstack(tgts)
        preds = np.vstack(preds)
        preds_bool = np.vstack(preds_bool)

        outAUROC = []
        for i in range(args.n_classes):
            try:
                outAUROC.append(roc_auc_score(tgts[:, i], preds[:, i]))
            except ValueError:
                # e.g. only one class present in this label's targets: score undefined.
                outAUROC.append(0)
        classACC = dict(zip(args.labels, outAUROC))

        metrics["micro_roc_auc"] = roc_auc_score(tgts, preds, average="micro")
        metrics["macro_roc_auc"] = roc_auc_score(tgts, preds, average="macro")
        metrics["macro_f1"] = f1_score(tgts, preds_bool, average="macro")
        metrics["micro_f1"] = f1_score(tgts, preds_bool, average="micro")
        print('micro_auc:', metrics["micro_roc_auc"])
        print('micro_f1:', metrics["micro_f1"])
        print('-----------------------------------------------------')
    else:
        tgts = [l for sl in tgts for l in sl]
        preds = [l for sl in preds for l in sl]
        metrics["acc"] = accuracy_score(tgts, preds)

    if store_preds:
        # NOTE(review): store_preds_to_disk is not defined in this notebook; define it
        # before passing store_preds=True.
        store_preds_to_disk(tgts, preds, args)

    return metrics, classACC, tgts, preds

In [None]:
def model_forward(model, args, criterion, batch, device):
    """Move one collated batch to ``device``, run the model and compute the loss.

    Also applies the encoder freeze flags. When ``args.freeze_img_all`` or
    ``args.freeze_txt_all`` is truthy, the parameters of the image encoder
    or the BERT encoder get ``requires_grad=False``; otherwise they stay
    trainable. The image-encoder flag only applies when
    ``args.num_image_embeds > 0``. Works for a bare model and for one
    wrapped in ``nn.DataParallel``.

    Returns:
        (loss, logits, targets moved to device)
    """
    txt, segment, mask, img, tgt = batch
    model.to(device)

    # DataParallel keeps the real model under ``.module``; a bare model has no such attribute.
    core = getattr(model, "module", model)
    if args.num_image_embeds > 0:
        for param in core.enc.img_encoder.parameters():
            param.requires_grad = not args.freeze_img_all
    for param in core.enc.encoder.parameters():
        param.requires_grad = not args.freeze_txt_all

    txt, img = txt.to(device), img.to(device)
    mask, segment = mask.to(device), segment.to(device)
    out = model(txt, mask, segment, img)

    tgt = tgt.to(device)
    loss = criterion(out, tgt)
    return loss, out, tgt

In [None]:
def _write_metrics_csv(save_path, metrics, classACC):
    """Overwrite ``save_path`` with a header row and one row of multilabel metrics."""
    title = ['micro_auc', 'macro_auc', 'micro_f1', 'macro_f1'] + list(classACC.keys())
    result = [
        metrics["micro_roc_auc"],
        metrics["macro_roc_auc"],
        metrics["micro_f1"],
        metrics["macro_f1"],
    ] + list(classACC.values())
    # newline="" is what the csv module requires to avoid doubled line endings.
    with open(save_path, 'w', encoding='utf-8', newline='') as f:
        wr = csv.writer(f)
        wr.writerow(title)
        wr.writerow(result)


def train(args):
    """Fine-tune the multimodal classifier, keeping the best checkpoint by validation score.

    Outputs go to ``<args.savedir>/<args.save_name>`` (``args.savedir`` is
    updated in place):
    - ``args.bin``: the pickled ``args``.
    - ``best_model.pth``: state dict of the best epoch (micro-F1 for
      multilabel, accuracy otherwise).
    - ``<save_name>.csv``: the latest epoch's multilabel validation metrics.

    Initial weights are loaded (non-strict) from
    ``<args.loaddir>/pytorch_model.bin`` when that file exists. Training
    stops early after ``args.patience`` epochs without improvement.
    """
    print("Training start!!")
    print(" # PID :", os.getpid())

    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.save_name)
    os.makedirs(args.savedir, exist_ok=True)

    train_loader, val_loader = get_data_loaders(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(args)

    criterion = get_criterion(args, device)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    torch.save(args, os.path.join(args.savedir, "args.bin"))

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

    pretrained_path = os.path.join(args.loaddir, "pytorch_model.bin")
    if os.path.exists(pretrained_path):
        # strict=False: presumably the pretraining checkpoint lacks some keys
        # (e.g. the classification head) — TODO confirm which keys are missing.
        # map_location lets a GPU-saved file load on a CPU-only machine.
        model.load_state_dict(torch.load(pretrained_path, map_location="cpu"), strict=False)
        print("This would load the trained model, then fine-tune the model.")
    else:
        print("")
        print("")
        print("this option initilize the model with random value. train from scratch.")
        print("Loaded model : ")

    print("freeze image?", args.freeze_img_all)
    print("freeze txt?", args.freeze_txt_all)
    model.to(device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    for i_epoch in range(start_epoch, args.max_epochs):
        train_losses = []
        model.train()
        optimizer.zero_grad()

        for batch in tqdm(train_loader, total=len(train_loader)):
            loss, out, target = model_forward(model, args, criterion, batch, device)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            train_losses.append(loss.item())
            loss.backward()
            global_step += 1
            # Step only every `gradient_accumulation_steps` mini-batches.
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        model.eval()
        metrics, classACC, tgts, preds = model_eval(val_loader, model, args, criterion, device)

        tuning_metric = (
            metrics["micro_f1"] if args.task_type == "multilabel" else metrics["acc"]
        )
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
            # Save the unwrapped model so keys carry no DataParallel "module." prefix.
            torch.save(
                getattr(model, "module", model).state_dict(),
                os.path.join(args.savedir, "best_model.pth"),
            )
        else:
            n_no_improve += 1

        # The CSV columns only exist for multilabel metrics.
        if args.task_type == "multilabel":
            save_path = os.path.join(args.savedir, args.save_name + '.csv')
            _write_metrics_csv(save_path, metrics, classACC)

        if n_no_improve >= args.patience:
            print("No improvement. Breaking out of loop.")
            break

In [None]:
def _load_weights(model, path):
    """Load weights from ``path`` into ``model`` non-strictly and report mismatches.

    Accepts either a bare state dict or a checkpoint dict with a
    "state_dict" entry. A leading "module." (DataParallel) prefix is
    stripped from keys so they match an unwrapped model.
    """
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    result = model.load_state_dict(state, strict=False)
    print("missing keys:", len(result.missing_keys),
          "unexpected keys:", len(result.unexpected_keys))


def test(args):
    """Evaluate a saved classifier on the validation split and print its metrics.

    Weights are loaded from ``<args.loaddir>/model_best.pt`` when that file
    exists; otherwise the freshly initialized model is evaluated. ``args``
    is saved under ``<args.savedir>/<args.name>`` (``args.savedir`` is
    updated in place).
    """
    print("Model Test")
    print(" # PID :", os.getpid())
    print('log:', args.Valid_dset_name)
    set_seed(args.seed)
    # Per-scenario output folder (os.name would just give the OS family, e.g. "posix").
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    _, val_loader = get_data_loaders(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(args)

    criterion = get_criterion(args, device)

    torch.save(args, os.path.join(args.savedir, "args.bin"))

    ckpt_path = os.path.join(args.loaddir, "model_best.pt")
    if os.path.exists(ckpt_path):
        _load_weights(model, ckpt_path)
    else:
        print("")
        print("")
        print("this option initilize the model with random value. train from scratch.")
        print("Loaded model : ")

    print("freeze image?", args.freeze_img_all)
    print("freeze txt?", args.freeze_txt_all)
    model.to(device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model.eval()
    # store_preds_to_disk is not defined in this notebook; only request storing
    # predictions when such a helper actually exists.
    can_store = callable(globals().get("store_preds_to_disk"))
    metrics, classACC, tgts, preds = model_eval(
        val_loader, model, args, criterion, device, store_preds=can_store)

    if args.task_type == "multilabel":
        print('micro_roc_auc:', round(metrics["micro_roc_auc"], 3))
        print('macro_roc_auc:', round(metrics["macro_roc_auc"], 3))
        print('macro_f1 f1 scroe:', round(metrics["macro_f1"], 3))
        print('micro f1 score:', round(metrics["micro_f1"], 3))
        for i in classACC:
            print(i, round(classACC[i], 3))
    else:
        print('acc:', round(metrics["acc"], 3))

# Classification

In [None]:
# Launch training with the fixed configuration held in the module-level `args`.
train(args)