# Multimodal RE Tutorial
> Tutorial author: 乔硕斐（shuofei@zju.edu.cn）

In this tutorial, we use `IFAformer`, a Transformer-based two-stream multimodal model, to extract relations. We hope this tutorial helps you understand the process of multimodal relation extraction.

This tutorial uses `Python3`.

## RE

**Relation Extraction** (RE), a key task in information extraction, predicts semantic relations between pairs of entities from unstructured text.

## Multimodal RE

**Multimodal Relation Extraction** (MRE) in DeepKE pairs each sentence, which contains a pair of entities, with a related image and uses the visual information to enhance textual relation extraction.

## Dataset

We use the [MNRE](https://github.com/thecharm/MNRE) (Multimodal Neural Relation Extraction) dataset in this tutorial. Each sample contains a sentence and an image. The data formats are as follows:

**Text**
```
{
  'token': ['RT', '@TheHerd', ':', 'LeBron', 'is', 'in', 'their', 'head', '.', '2', '-', '0', 'Cavs', '.'],
  'h': {
      'name': 'LeBron',
      'pos': [3, 4]
     },
  't': {
      'name': 'Cavs',
      'pos': [12, 13]
     }, 
  'img_id': 'twitter_stream_2018_05_07_16_0_2_180.jpg',
  'relation': '/per/org/member_of'
}
```
**Image**

![image](https://github.com/zjunlp/DeepKE/blob/main/tutorial-notebooks/re/multimodal/image/image.png?raw=1)
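The text record above can be read back with a few lines of Python. This is a minimal sketch, assuming (as the loading code later in this tutorial suggests) that each line of the text file is one Python dict literal and that `pos` is a half-open `[start, end)` token span. The helper name `entity_tokens` is ours and is not part of DeepKE.

```python
import ast

# One line of the text file, copied from the sample record above.
line = ("{'token': ['RT', '@TheHerd', ':', 'LeBron', 'is', 'in', 'their', 'head', '.', "
        "'2', '-', '0', 'Cavs', '.'], "
        "'h': {'name': 'LeBron', 'pos': [3, 4]}, "
        "'t': {'name': 'Cavs', 'pos': [12, 13]}, "
        "'img_id': 'twitter_stream_2018_05_07_16_0_2_180.jpg', "
        "'relation': '/per/org/member_of'}")

# The record uses single quotes, so it is a Python literal rather than strict JSON.
record = ast.literal_eval(line)


def entity_tokens(record, key):
    """Return the tokens covered by the span stored in record[key]['pos']."""
    start, end = record[key]['pos']
    return record['token'][start:end]


head = entity_tokens(record, 'h')
tail = entity_tokens(record, 't')
print(head, tail, record['relation'])
```

Reading the span as half-open is consistent with the sample: `token[3:4]` is the head name and `token[12:13]` is the tail name.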

The structure of the dataset folder `./data/` is as follows:

```
.
├── img_detect                     # Detected objects using RCNN
├── img_vg                       # Detected objects using visual grounding toolkit
├── img_org                      # Original images
├── txt                        # Text Set
├── vg_data                      # Bounding image and img_vg
└── ours_rel2id.json                  # Relation Set
```

We use objects detected by RCNN and by visual grounding as local visual information. RCNN detection uses [faster_rcnn](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py), and visual grounding uses [onestage_grounding](https://github.com/zyang-ur/onestage_grounding).
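Before training, it can help to confirm that the extracted archive matches the folder layout listed above. The following is a small sketch using only the standard library. The helper `missing_entries` is a hypothetical name of ours, and the entry list is copied from the tree above.

```python
from pathlib import Path

# Entries expected directly under ./data/, taken from the folder listing above.
EXPECTED = ["img_detect", "img_vg", "img_org", "txt", "vg_data", "ours_rel2id.json"]


def missing_entries(root):
    """Return the expected entries that do not exist under `root`."""
    root = Path(root)
    return [name for name in EXPECTED if not (root / name).exists()]


# An empty list means the layout looks complete.
print(missing_entries("data"))
```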

## IFAformer

IFAformer is a novel dual multimodal Transformer model with implicit feature alignment for the RE task. It applies the Transformer structure uniformly to the visual and textual modalities, without an explicitly designed modal alignment structure.

![IFAformer](https://github.com/zjunlp/DeepKE/blob/main/tutorial-notebooks/re/multimodal/image/IFAformer.png?raw=1)

## Prepare environment

In [None]:
! nvidia-smi
! pip install deepke
! git clone https://github.com/zjunlp/DeepKE.git

Colab ships the latest **torchvision**, which is incompatible with the **torch** version required by deepke, so the two versions need to be aligned. **This step can be skipped in non-Colab environments.**

In [None]:
! pip install torch==1.10.0+cu102 torchvision==0.11.0+cu102 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

## Import package

Colab provides Python 3.7, whereas deepke requires Python 3.8. To work around this, change ***import importlib.metadata as importlib\_metadata*** to ***import importlib\_metadata*** in the file /usr/local/lib/python3.7/dist-packages/deepke/relation\_extraction/multimodal/models/clip/file\_utils.py. **In non-Colab environments, use Python 3.8 and skip this step.**
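The manual edit described above can also be scripted. The sketch below is one possible way to do it, not something DeepKE ships. `patch_import` is a hypothetical helper, and it is demonstrated on a throwaway file so that it runs anywhere. On Colab you would call it on the `file_utils.py` path named above.

```python
import tempfile
from pathlib import Path

OLD_IMPORT = "import importlib.metadata as importlib_metadata"
NEW_IMPORT = "import importlib_metadata"


def patch_import(path):
    """Rewrite the import line in place; return True if the file was changed."""
    path = Path(path)
    source = path.read_text(encoding="utf-8")
    if OLD_IMPORT not in source:
        return False  # already patched, or the line is not present
    path.write_text(source.replace(OLD_IMPORT, NEW_IMPORT), encoding="utf-8")
    return True


# Demonstration on a temporary file rather than the installed package.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "file_utils.py"
    demo.write_text(OLD_IMPORT + "\n", encoding="utf-8")
    changed = patch_import(demo)        # first call rewrites the line
    patched_source = demo.read_text(encoding="utf-8")
    changed_again = patch_import(demo)  # second call finds nothing to do

print(changed, changed_again)
```

Because the function checks for the old line first, running it twice is harmless.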

In [None]:
import os
import hydra
import torch
import numpy as np
import random
from torchvision import transforms
from hydra import utils
import json
import ast
from PIL import Image
from torch.utils.data import Dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
from deepke.relation_extraction.multimodal.models.clip.processing_clip import CLIPProcessor
from deepke.relation_extraction.multimodal.models.IFA_model import IFAREModel
from deepke.relation_extraction.multimodal.modules.dataset import MMREProcessor, MMREDataset
from deepke.relation_extraction.multimodal.modules.train import Trainer

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

import wandb
writer = wandb.init(project="DeepKE_RE_MM")

## Prepare dataset

In [None]:
! wget 120.27.214.45/Data/re/multimodal/data.tar.gz
! tar -xzvf data.tar.gz

In [None]:
DATA_PATH = {
        'train': 'data/txt/ours_train.txt',
        'dev': 'data/txt/ours_val.txt',
        'test': 'data/txt/ours_test.txt',
        'train_auximgs': 'data/txt/mre_train_dict.pth',
        'dev_auximgs': 'data/txt/mre_dev_dict.pth',
        'test_auximgs': 'data/txt/mre_test_dict.pth',
        'train_img2crop': 'data/img_detect/train/train_img2crop.pth',
        'dev_img2crop': 'data/img_detect/val/val_img2crop.pth',
        'test_img2crop': 'data/img_detect/test/test_img2crop.pth'}

IMG_PATH = {
    'train': 'data/img_org/train/',
    'dev': 'data/img_org/val/',
    'test': 'data/img_org/test'}

AUX_PATH = {
        'train': 'data/img_vg/train/crops',
        'dev': 'data/img_vg/val/crops',
        'test': 'data/img_vg/test/crops'}

re_path = 'data/ours_rel2id.json'

## Configure parameters

In [None]:
class Config(object):
    seed = 1234
    
    bert_name = "bert-base-uncased"
    vit_name = "openai/clip-vit-base-patch32"
    device = "cuda"
    
    num_epochs = 12
    batch_size = 32
    lr = 1e-5
    warmup_ratio = 0.06
    eval_begin_epoch = 1

    max_seq = 80
    aux_size = 128
    rcnn_size = 64

    save_path = "checkpoints"
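    load_path = None
    cwd = os.getcwd()  # base directory; the processor and dataset join relative data paths onto args.cwd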

cfg = Config()

In [None]:
def set_seed(seed=2021):
    """set random seed"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])])

set_seed(cfg.seed) # fix all random seeds with cfg.seed (1234 in this config)
if cfg.save_path is not None:  # make save_path dir
    if not os.path.exists(cfg.save_path):
        os.makedirs(cfg.save_path, exist_ok=True)
print(cfg)
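
The effect of fixing the seed can be checked in isolation. This sketch uses only Python's `random` module to stay lightweight. `set_seed` above applies the same idea to `torch` and `numpy` as well. The helper `sample_with_seed` is ours, for illustration only.

```python
import random


def sample_with_seed(seed, n=3):
    """Seed the generator, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = sample_with_seed(1234)
second = sample_with_seed(1234)  # same seed as cfg.seed
other = sample_with_seed(2021)   # the default seed of set_seed

print(first == second)  # identical seeds reproduce the same sequence
print(first == other)
```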

## Process dataset

In [None]:
class MMREProcessor(object):
    def __init__(self, data_path, re_path, args):
        self.args = args
        self.data_path = data_path
        self.re_path = re_path
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_name, do_lower_case=True)
        self.tokenizer.add_special_tokens({'additional_special_tokens':['<s>', '</s>', '<o>', '</o>']})

        self.clip_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor.feature_extractor.size, self.aux_processor.feature_extractor.crop_size = args.aux_size, args.aux_size
        self.rcnn_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.rcnn_processor.feature_extractor.size, self.rcnn_processor.feature_extractor.crop_size = args.rcnn_size, args.rcnn_size


    def load_from_file(self, mode="train"):
        load_file = os.path.join(self.args.cwd,self.data_path[mode])
        logger.info("Loading data from {}".format(load_file))
        with open(load_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
            words, relations, heads, tails, imgids, dataid = [], [], [], [], [], []
            for i, line in enumerate(lines):
                line = ast.literal_eval(line)   # str to dict
                words.append(line['token'])
                relations.append(line['relation'])
                heads.append(line['h']) # {name, pos}
                tails.append(line['t'])
                imgids.append(line['img_id'])
                dataid.append(i)

        assert len(words) == len(relations) == len(heads) == len(tails) == (len(imgids))

        # auxiliary images
        aux_imgs = None
        # if not self.use_clip_vit:
        aux_path = os.path.join(self.args.cwd,self.data_path[mode+"_auximgs"])
        aux_imgs = torch.load(aux_path)
        rcnn_imgs = torch.load(os.path.join(self.args.cwd,self.data_path[mode+'_img2crop']))
        return {'words':words, 'relations':relations, 'heads':heads, 'tails':tails, 'imgids': imgids, 'dataid': dataid, 'aux_imgs':aux_imgs, "rcnn_imgs":rcnn_imgs}


    def get_relation_dict(self):
        
        with open(os.path.join(self.args.cwd,self.re_path), 'r', encoding="utf-8") as f:
            line = f.readlines()[0]
            re_dict = json.loads(line)
        return re_dict

    def get_rel2id(self, train_path):
        with open(os.path.join(self.args.cwd,self.re_path), 'r', encoding="utf-8") as f:
            line = f.readlines()[0]
            re_dict = json.loads(line)
        re2id = {key:[] for key in re_dict.keys()}
        with open(train_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                line = ast.literal_eval(line)   # str to dict
                assert line['relation'] in re2id
                re2id[line['relation']].append(i)
        return re2id


class MMREDataset(Dataset):
    def __init__(self, processor, transform, img_path=None, aux_img_path=None, mode="train") -> None:
        self.processor = processor
        self.args = self.processor.args
        self.transform = transform
        self.max_seq = self.args.max_seq
        self.img_path = img_path[mode]  if img_path is not None else img_path
        self.aux_img_path = aux_img_path[mode] if aux_img_path is not None else aux_img_path
        self.rcnn_img_path = 'data'
        self.mode = mode
        self.data_dict = self.processor.load_from_file(mode)
        self.re_dict = self.processor.get_relation_dict()
        self.tokenizer = self.processor.tokenizer
        self.aux_size = self.args.aux_size
        self.rcnn_size = self.args.rcnn_size
    
    def __len__(self):
        return len(self.data_dict['words'])

    def __getitem__(self, idx):
        word_list, relation, head_d, tail_d, imgid = self.data_dict['words'][idx], self.data_dict['relations'][idx], self.data_dict['heads'][idx], self.data_dict['tails'][idx], self.data_dict['imgids'][idx]
        item_id = self.data_dict['dataid'][idx]
        # [CLS] ... <s> head </s> ... <o> tail </o> ... [SEP]
        head_pos, tail_pos = head_d['pos'], tail_d['pos']
        # insert the entity markers <s> </s> <o> </o>
        extend_word_list = []
        for i in range(len(word_list)):
            if  i == head_pos[0]:
                extend_word_list.append('<s>')
            if i == head_pos[1]:
                extend_word_list.append('</s>')
            if i == tail_pos[0]:
                extend_word_list.append('<o>')
            if i == tail_pos[1]:
                extend_word_list.append('</o>')
            extend_word_list.append(word_list[i])
        extend_word_list = " ".join(extend_word_list)   # join into a string: a list input would not be split into subwords
        encode_dict = self.tokenizer.encode_plus(text=extend_word_list, max_length=self.max_seq, truncation=True, padding='max_length')
        input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
        input_ids, token_type_ids, attention_mask = torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask)
        
        re_label = self.re_dict[relation]   # label to id

         # image process
        if self.img_path is not None:
            try:
                img_path = os.path.join(os.path.join(self.args.cwd,self.img_path), imgid)
                img_path = img_path.replace('test', 'train')
                image = Image.open(img_path).convert('RGB')
                image = self.processor.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()
            except Exception:  # image missing or unreadable: fall back to the placeholder 'inf.png'
                img_path = os.path.join(os.path.join(self.args.cwd,self.img_path), 'inf.png')
                image = Image.open(img_path).convert('RGB')
                image = self.processor.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()
            if self.aux_img_path is not None:
                # auxiliary images
                aux_imgs = []
                aux_img_paths = []
                imgid = imgid.split(".")[0]
                if item_id in self.data_dict['aux_imgs']:
                    aux_img_paths  = self.data_dict['aux_imgs'][item_id]
                    # print(aux_img_paths)
                    aux_img_paths = [os.path.join(os.path.join(self.args.cwd,self.aux_img_path), path) for path in aux_img_paths]
                # keep at most 3 images; discard the rest
                for i in range(min(3, len(aux_img_paths))):
                    aux_img = Image.open(aux_img_paths[i]).convert('RGB')
                    aux_img = self.processor.aux_processor(images=aux_img, return_tensors='pt')['pixel_values'].squeeze()
                    aux_imgs.append(aux_img)

                # if there are fewer than 3 images, pad with zero tensors
                for i in range(3-len(aux_imgs)):
                    aux_imgs.append(torch.zeros((3, self.aux_size, self.aux_size))) 

                aux_imgs = torch.stack(aux_imgs, dim=0)
                assert len(aux_imgs) == 3

                if self.rcnn_img_path is not None:
                    rcnn_imgs = []
                    rcnn_img_paths = []
                    if imgid in self.data_dict['rcnn_imgs']:
                        rcnn_img_paths = self.data_dict['rcnn_imgs'][imgid]
                        rcnn_img_paths = [os.path.join(os.path.join(self.args.cwd,self.rcnn_img_path), path) for path in rcnn_img_paths]
                    # keep at most 3 images; discard the rest
                    for i in range(min(3, len(rcnn_img_paths))):
                        rcnn_img = Image.open(rcnn_img_paths[i]).convert('RGB')
                        rcnn_img = self.processor.rcnn_processor(images=rcnn_img, return_tensors='pt')['pixel_values'].squeeze()
                        rcnn_imgs.append(rcnn_img)
                    # if there are fewer than 3 images, pad with zero tensors
                    for i in range(3-len(rcnn_imgs)):
                        rcnn_imgs.append(torch.zeros((3, self.rcnn_size, self.rcnn_size))) 

                    rcnn_imgs = torch.stack(rcnn_imgs, dim=0)
                    assert len(rcnn_imgs) == 3
                    return input_ids, token_type_ids, attention_mask, torch.tensor(re_label), image, aux_imgs, rcnn_imgs

                return input_ids, token_type_ids, attention_mask, torch.tensor(re_label), image, aux_imgs
            
        return input_ids, token_type_ids, attention_mask, torch.tensor(re_label)
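
The marker insertion inside `__getitem__` can be tried on its own. The sketch below mirrors that loop as a standalone function, with `insert_entity_markers` being our own name. The two checks after the loop are an assumption that goes beyond the tutorial code: they close a span that ends exactly at the end of the sentence, a case the loop alone never reaches.

```python
def insert_entity_markers(tokens, head_pos, tail_pos):
    """Wrap the head span in <s> </s> and the tail span in <o> </o>."""
    marked = []
    for i, token in enumerate(tokens):
        if i == head_pos[0]:
            marked.append('<s>')
        if i == head_pos[1]:
            marked.append('</s>')
        if i == tail_pos[0]:
            marked.append('<o>')
        if i == tail_pos[1]:
            marked.append('</o>')
        marked.append(token)
    # Assumption (not in the tutorial code): close spans that end at the last token.
    if head_pos[1] == len(tokens):
        marked.append('</s>')
    if tail_pos[1] == len(tokens):
        marked.append('</o>')
    return marked


tokens = ['RT', '@TheHerd', ':', 'LeBron', 'is', 'in', 'their', 'head', '.',
          '2', '-', '0', 'Cavs', '.']
marked = insert_entity_markers(tokens, head_pos=[3, 4], tail_pos=[12, 13])
print(" ".join(marked))
```

For the sample record, the head `LeBron` ends up between `<s>` and `</s>`, and the tail `Cavs` between `<o>` and `</o>`.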

In [None]:
processor = MMREProcessor(DATA_PATH, re_path, cfg)
train_dataset = MMREDataset(processor, transform, IMG_PATH, AUX_PATH, mode='train')
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=4, pin_memory=True)

dev_dataset = MMREDataset(processor, transform, IMG_PATH, AUX_PATH, mode='dev')
dev_dataloader = DataLoader(dev_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True)

test_dataset = MMREDataset(processor, transform, IMG_PATH, AUX_PATH, mode='test')
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True)

re_dict = processor.get_relation_dict()
num_labels = len(re_dict)
tokenizer = processor.tokenizer
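
Each item returned by the dataset carries exactly three visual grounding crops and three RCNN crops: extra crops are dropped and missing ones are filled with zeros. The rule can be sketched with plain lists. Here `pad_or_truncate` and the `"<zero-image>"` placeholder are illustrative stand-ins of ours for the zero tensors used in `__getitem__`.

```python
def pad_or_truncate(crops, size=3, pad_value="<zero-image>"):
    """Keep the first `size` crops and pad with `pad_value` up to `size`."""
    kept = list(crops[:size])
    kept += [pad_value] * (size - len(kept))
    return kept


print(pad_or_truncate(["crop_0", "crop_1", "crop_2", "crop_3"]))  # one crop too many
print(pad_or_truncate(["crop_0"]))                                # two crops missing
```

A fixed count per sample lets the default `DataLoader` collation stack the crops into one tensor per batch.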

## Model construction

In [None]:
import torch
from torch import nn

import torch.nn.functional as F
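# Relative imports only work inside the deepke package, so use absolute paths in a notebook
from deepke.relation_extraction.multimodal.models.modeling_IFA import IFAModel
from deepke.relation_extraction.multimodal.models.clip.modeling_clip import CLIPModel
from deepke.relation_extraction.multimodal.models.clip.configuration_clip import CLIPConfig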
from transformers import BertConfig, BertModel

class IFAREModel(nn.Module):
    def __init__(self, num_labels, tokenizer, args):
        super(IFAREModel, self).__init__()

        self.args = args
        self.vision_config = CLIPConfig.from_pretrained(self.args.vit_name).vision_config
        self.text_config = BertConfig.from_pretrained(self.args.bert_name)

        clip_model_dict = CLIPModel.from_pretrained(self.args.vit_name).vision_model.state_dict()
        bert_model_dict = BertModel.from_pretrained(self.args.bert_name).state_dict()

        print(self.vision_config)
        print(self.text_config)

        # for re
        self.vision_config.device = args.device
        self.model = IFAModel(self.vision_config, self.text_config)

        # load:
        vision_names, text_names = [], []
        model_dict = self.model.state_dict()
        for name in model_dict:
            if 'vision' in name:
                clip_name = name.replace('vision_', '').replace('model.', '')
                if clip_name in clip_model_dict:
                    vision_names.append(clip_name)
                    model_dict[name] = clip_model_dict[clip_name]
            elif 'text' in name:
                text_name = name.replace('text_', '').replace('model.', '')
                if text_name in bert_model_dict:
                    text_names.append(text_name)
                    model_dict[name] = bert_model_dict[text_name]
        assert len(vision_names) == len(clip_model_dict) and len(text_names) == len(bert_model_dict), \
                    (len(vision_names), len(text_names), len(clip_model_dict), len(bert_model_dict))
        self.model.load_state_dict(model_dict)

        self.model.resize_token_embeddings(len(tokenizer))

        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.text_config.hidden_size*2, num_labels)
        self.head_start = tokenizer.convert_tokens_to_ids("<s>")
        self.tail_start = tokenizer.convert_tokens_to_ids("<o>")
        self.tokenizer = tokenizer

    def forward(
            self, 
            input_ids=None, 
            attention_mask=None, 
            token_type_ids=None, 
            labels=None, 
            images=None, 
            aux_imgs=None,
            rcnn_imgs=None,
    ):
        bsz = input_ids.size(0)
        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,

                            pixel_values=images,
                            aux_values=aux_imgs, 
                            rcnn_values=rcnn_imgs,
                            return_dict=True,)

        last_hidden_state, pooler_output = output.last_hidden_state, output.pooler_output
        bsz, seq_len, hidden_size = last_hidden_state.shape
        entity_hidden_state = torch.Tensor(bsz, 2*hidden_size) # batch, 2*hidden
        for i in range(bsz):
            head_idx = input_ids[i].eq(self.head_start).nonzero().item()
            tail_idx = input_ids[i].eq(self.tail_start).nonzero().item()
            head_hidden = last_hidden_state[i, head_idx, :].squeeze()
            tail_hidden = last_hidden_state[i, tail_idx, :].squeeze()
            entity_hidden_state[i] = torch.cat([head_hidden, tail_hidden], dim=-1)
        entity_hidden_state = entity_hidden_state.to(self.args.device)
        logits = self.classifier(entity_hidden_state)
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            return loss_fn(logits, labels.view(-1)), logits
        return logits

In [None]:
model = IFAREModel(num_labels, tokenizer, cfg)
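
The classifier input in `forward` is the hidden state at the `<s>` marker concatenated with the hidden state at the `<o>` marker, which is why the linear layer takes `hidden_size*2` features. The sketch below reproduces that lookup with plain Python lists instead of tensors. The token ids and the two-dimensional hidden states are invented for illustration, and `entity_representation` is our own name.

```python
def entity_representation(input_ids, hidden_states, head_start_id, tail_start_id):
    """Concatenate the hidden vectors found at the head and tail start markers."""
    head_idx = input_ids.index(head_start_id)
    tail_idx = input_ids.index(tail_start_id)
    # List concatenation plays the role of torch.cat(..., dim=-1) here.
    return hidden_states[head_idx] + hidden_states[tail_idx]


# Illustrative ids: 30522 stands for <s> and 30524 for <o>.
HEAD_START, TAIL_START = 30522, 30524
input_ids = [101, 30522, 7, 30523, 30524, 8, 30525, 102]
# One toy hidden vector of size 2 per token position.
hidden_states = [[float(i), float(i) + 0.5] for i in range(len(input_ids))]

entity_vec = entity_representation(input_ids, hidden_states, HEAD_START, TAIL_START)
print(entity_vec)
```

The result has twice the hidden size, here 4 values for a hidden size of 2.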

## Metric function

In [None]:
def eval_result(true_labels, pred_result, rel2id, logger, use_name=False):
    correct = 0
    total = len(true_labels)
    correct_positive = 0
    pred_positive = 0
    gold_positive = 0

    neg = -1
    for name in ['NA', 'na', 'no_relation', 'Other', 'Others', 'none', 'None']:
        if name in rel2id:
            if use_name:
                neg = name
            else:
                neg = rel2id[name]
            break
    for i in range(total):
        golden = true_labels[i]

        if golden == pred_result[i]:
            correct += 1
            if golden != neg:
                correct_positive += 1
        if golden != neg:
            gold_positive += 1
        if pred_result[i] != neg:
            pred_positive += 1
    acc = float(correct) / float(total)
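    try:
        micro_p = float(correct_positive) / float(pred_positive)
    except ZeroDivisionError:  # no positive predictions
        micro_p = 0
    try:
        micro_r = float(correct_positive) / float(gold_positive)
    except ZeroDivisionError:  # no positive gold labels
        micro_r = 0
    try:
        micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
    except ZeroDivisionError:  # precision and recall are both zero
        micro_f1 = 0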

    result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}
    logger.info('Evaluation result: {}.'.format(result))
    return result
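
A small worked example makes the metric concrete. The function below is a compact restatement of the counting logic in `eval_result`, written by us for illustration (`micro_prf` is not a DeepKE function). It assumes that label id `0` is the negative class, and the labels are made up.

```python
def micro_prf(true_labels, pred_labels, neg_id):
    """Accuracy plus micro precision/recall/F1 that ignore the negative class."""
    pairs = list(zip(true_labels, pred_labels))
    correct = sum(t == p for t, p in pairs)
    correct_positive = sum(t == p and t != neg_id for t, p in pairs)
    gold_positive = sum(t != neg_id for t in true_labels)
    pred_positive = sum(p != neg_id for p in pred_labels)

    acc = correct / len(pairs)
    micro_p = correct_positive / pred_positive if pred_positive else 0.0
    micro_r = correct_positive / gold_positive if gold_positive else 0.0
    denom = micro_p + micro_r
    micro_f1 = 2 * micro_p * micro_r / denom if denom else 0.0
    return {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}


# Toy labels: 0 is the assumed negative ("no relation") class.
true_labels = [0, 1, 2, 1, 0]
pred_labels = [0, 1, 1, 2, 1]
result = micro_prf(true_labels, pred_labels, neg_id=0)
print(result)
```

Here two of five predictions are correct (accuracy 0.4), but only one of them is a positive relation, so precision is 1/4, recall is 1/3, and F1 is 2/7.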

## Train the model

In [None]:
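# Imports used by Trainer that the import cell above does not provide
from torch import optim
from tqdm import tqdm
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup

class Trainer(object):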
    def __init__(self, train_data=None, dev_data=None, test_data=None, re_dict=None, model=None, args=None, logger=None, writer=None) -> None:
        self.train_data = train_data
        self.dev_data = dev_data
        self.test_data = test_data
        self.re_dict = re_dict
        self.model = model
        self.logger = logger
        self.refresh_step = 2
        self.best_dev_metric = 0
        self.best_test_metric = 0
        self.best_dev_epoch = None
        self.best_test_epoch = None
        self.optimizer = None
        self.writer = writer
        
        self.step = 0
        self.args = args

        if self.train_data is not None:
            self.train_num_steps = len(self.train_data) * args.num_epochs
            self.before_multimodal_train()
        self.model.to(self.args.device)

    def train(self):
        self.step = 0
        self.model.train()
        self.logger.info("***** Running training *****")
        self.logger.info("  Num instance = %d", len(self.train_data)*self.args.batch_size)
        self.logger.info("  Num epoch = %d", self.args.num_epochs)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        self.logger.info("  Learning rate = {}".format(self.args.lr))
        self.logger.info("  Evaluate begin = %d", self.args.eval_begin_epoch)

        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")


        with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
            self.pbar = pbar
            avg_loss = 0
            for epoch in range(1, self.args.num_epochs+1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
                for batch in self.train_data:
                    self.step += 1
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)
                    (loss, logits), labels = self._step(batch, mode="train")
                    avg_loss += loss.detach().cpu().item()

                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                    if self.step % self.refresh_step == 0:
                        avg_loss = float(avg_loss) / self.refresh_step
                        print_output = "loss:{:<6.5f}".format(avg_loss)
                        pbar.update(self.refresh_step)
                        pbar.set_postfix_str(print_output)
                        self.writer.log({'avg_loss': avg_loss})
                        avg_loss = 0

                if epoch >= self.args.eval_begin_epoch:
                    if self.dev_data:
                        self.evaluate(epoch)   # generator to dev.
                    if self.test_data:
                        self.test(epoch)
            
            pbar.close()
            self.pbar = None
            self.logger.info("Get best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
            self.logger.info("Get best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))

    def evaluate(self, epoch):
        self.model.eval()
        self.logger.info("***** Running evaluate *****")
        self.logger.info("  Num instance = %d", len(self.dev_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        step = 0
        true_labels, pred_labels = [], []
        with torch.no_grad():
            with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Dev")
                total_loss = 0
                for batch in self.dev_data:
                    step += 1
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
                    (loss, logits), labels = self._step(batch, mode="dev")    # logits: batch, 3
                    total_loss += loss.detach().cpu().item()
                    
                    preds = logits.argmax(-1)
                    true_labels.extend(labels.view(-1).detach().cpu().tolist())
                    pred_labels.extend(preds.view(-1).detach().cpu().tolist())
                    pbar.update()
                # evaluate done
                pbar.close()
                sk_result = classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
                self.logger.info("%s\n", sk_result)
                result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
                acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)

                self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}, acc: {}."\
                            .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, micro_f1, acc))
                self.writer.log({'eva_f1': micro_f1, 'eva_accuracy': acc})
                if micro_f1 >= self.best_dev_metric:  # this epoch get best performance
                    self.logger.info("Get better performance at epoch {}".format(epoch))
                    self.best_dev_epoch = epoch
                    self.best_dev_metric = micro_f1 # update best metric(f1 score)
                    if self.args.save_path is not None:
                        torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth")
                        self.logger.info("Save best model at {}".format(self.args.save_path))

        self.model.train()

    def test(self, epoch):
        self.model.eval()
        self.logger.info("\n***** Running testing *****")
        self.logger.info("  Num instance = %d", len(self.test_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        
        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")
        true_labels, pred_labels = [], []
        with torch.no_grad():
            with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Testing")
                total_loss = 0
                for batch in self.test_data:
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device  
                    (loss, logits), labels = self._step(batch, mode="dev")    # logits: batch, 3
                    total_loss += loss.detach().cpu().item()
                    
                    preds = logits.argmax(-1)
                    true_labels.extend(labels.view(-1).detach().cpu().tolist())
                    pred_labels.extend(preds.view(-1).detach().cpu().tolist())
                    
                    pbar.update()
                # evaluate done
                pbar.close()
                sk_result = classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
                self.logger.info("%s\n", sk_result)
                result = eval_result(true_labels, pred_labels, self.re_dict, self.logger)
                acc, micro_f1 = round(result['acc']*100, 4), round(result['micro_f1']*100, 4)
                total_loss = 0
                self.logger.info("Epoch {}/{}, best test f1: {}, best epoch: {}, current test f1 score: {}, acc: {}"\
                            .format(epoch, self.args.num_epochs, self.best_test_metric, self.best_test_epoch, micro_f1, acc))
                self.writer.log({'test_f1': micro_f1, 'test_accuracy': acc})
                if micro_f1 >= self.best_test_metric:  # this epoch get best performance
                    self.best_test_metric = micro_f1
                    self.best_test_epoch = epoch
        
        self.model.train()
    
    def predict(self):
        self.model.eval()
        self.logger.info("\n***** Running predicting *****")
        self.logger.info("  Num instance = %d", len(self.test_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        
        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")
        true_labels, pred_labels = [], []
        with torch.no_grad():
            with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Predicting")
                total_loss = 0
                for batch in self.test_data:
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device  
                    (loss, logits), labels = self._step(batch, mode="dev")    # logits: batch, 3
                    total_loss += loss.detach().cpu().item()
                    
                    preds = logits.argmax(-1)
                    true_labels.extend(labels.view(-1).detach().cpu().tolist())
                    pred_labels.extend(preds.view(-1).detach().cpu().tolist())
                    
                    pbar.update()
                # evaluate done
                pbar.close()
                sk_result = classification_report(y_true=true_labels, y_pred=pred_labels, labels=list(self.re_dict.values())[1:], target_names=list(self.re_dict.keys())[1:], digits=4)
                self.logger.info("%s\n", sk_result)
                # save predict results
                import os
                with open(os.path.join(self.args.cwd,'data/txt/result.txt'), 'w', encoding="utf-8") as wf:
                    wf.write(sk_result)
                    print('Successful write!!')

                
        
        self.model.train()

    def _step(self, batch, mode="train"):
        input_ids, token_type_ids, attention_mask, labels, images, aux_imgs, rcnn_imgs = batch
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs, rcnn_imgs=rcnn_imgs)
        return outputs, labels
    
    def before_multimodal_train(self):
        optimizer_grouped_parameters = []
        params = {'lr':self.args.lr, 'weight_decay':1e-2}
        params['params'] = []
        for name, param in self.model.named_parameters():
            if 'model' in name:
                params['params'].append(param)
        optimizer_grouped_parameters.append(params)

        self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.args.lr)
        self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer, 
                                 num_warmup_steps=self.args.warmup_ratio*self.train_num_steps, 
                                 num_training_steps=self.train_num_steps)
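
The learning rate set up in `before_multimodal_train` rises linearly during warmup and then decays linearly to zero. The function below is our own restatement of that multiplier as we understand the linear warmup schedule, not code taken from the library. The step counts are illustrative; only `lr` comes from the config above.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-5             # cfg.lr
num_training_steps = 1000  # illustrative: batches per epoch * num_epochs
num_warmup_steps = 60      # warmup_ratio 0.06 of 1000 steps

for step in (0, 30, 60, 530, 1000):
    factor = linear_warmup_factor(step, num_warmup_steps, num_training_steps)
    print(step, factor, base_lr * factor)
```

With these numbers the rate reaches its peak at step 60, falls to half of it at step 530, and reaches zero at the final step.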

In [None]:
trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, re_dict=re_dict, model=model, args=cfg, logger=logger, writer=writer)
trainer.train()

## Predict

In [None]:
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, re_dict=re_dict, model=model, args=cfg, logger=logger, writer=writer)
trainer.predict()