# Multimodal NER Tutorial
> Tutorial author: 乔硕斐 (shuofei@zju.edu.cn)

In this tutorial, we use IFAformer, a Transformer-based two-stream multimodal model, to recognize named entities. We hope it helps you understand the process of multimodal named entity recognition.

This tutorial uses `Python3`.

## NER
**Named-entity recognition** (also known as named entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.

## Multimodal NER
**Multimodal named entity recognition** (MNER) in DeepKE pairs each sentence with a related image to enhance text-based named entity recognition.

## Dataset
We use the Twitter15 dataset from [UMT](https://github.com/jefferyYu/UMT/) in this tutorial. Each sample consists of a sentence and an image. The data format is as follows:

**Text**

```
IMGID:16_05_01_6
5	O
days	O
until	O
JUSTIN	B-PER
BIEBER	I-PER
CONCERT	O
omfg	O
Ima	O
die	O
#	O
puposetour	O
@	O
alminababexox	O
@	O
justinbieber	B-PER
```
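The text format above can be read with a few lines of Python. The sketch below is only an illustration: `parse_mner_text` is a hypothetical helper, not part of DeepKE, and it assumes tab-separated `token<TAB>label` lines following each `IMGID:` header. The loader used later in this tutorial follows the same idea but relies on blank lines to separate samples.

```python
from typing import List, Tuple

# (image file name, words, labels) for one tweet
Sample = Tuple[str, List[str], List[str]]

def parse_mner_text(text: str) -> List[Sample]:
    """Split 'IMGID:' blocks into image names, words and BIO labels."""
    samples: List[Sample] = []
    img_id, words, labels = None, [], []
    for line in text.splitlines():
        if line.startswith("IMGID:"):
            # a new header closes the previous sample, if any
            if img_id is not None and words:
                samples.append((img_id, words, labels))
            img_id = line[len("IMGID:"):].strip() + ".jpg"
            words, labels = [], []
        elif line.strip():
            word, label = line.split("\t")
            words.append(word)
            labels.append(label.strip())
    if img_id is not None and words:
        samples.append((img_id, words, labels))
    return samples

sample_text = "IMGID:16_05_01_6\n5\tO\ndays\tO\nuntil\tO\nJUSTIN\tB-PER\nBIEBER\tI-PER\n"
parsed = parse_mner_text(sample_text)
print(parsed)
```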

**Image**

![image](https://github.com/zjunlp/DeepKE/blob/main/tutorial-notebooks/ner/multimodal/image/image.jpg?raw=1)

The structure of the dataset folder `./data/` is as follows:

```
.
├── twitter15_detect                     # Detected objects using RCNN
├── twitter2015_aux_images                  # Detected objects using visual grounding toolkit
├── twitter2015_images                    # Original images
├── train.txt                        # Train set
└── ...
```

We use RCNN-detected objects and visual grounding objects as local visual information. The RCNN objects are detected with [faster_rcnn](https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py), and the visual grounding objects with [onestage_grounding](https://github.com/zyang-ur/onestage_grounding).

## IFAformer
IFAformer is a dual multimodal Transformer model with implicit feature alignment, originally proposed for the RE (relation extraction) task. It applies the Transformer structure uniformly to the visual and textual modalities, without an explicitly designed modal alignment structure. Here **we add a CRF module in place of the relation predictor** after IFAformer to perform the NER task.
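To give a feel for what the CRF module contributes, here is a minimal pure-Python sketch of Viterbi decoding over per-token emission scores and label-to-label transition scores. It is a simplification for illustration, not the `torchcrf` implementation used later (which, for example, also learns start and end transitions). The function name and the toy scores are assumptions.

```python
from typing import List

def viterbi_decode(emissions: List[List[float]],
                   transitions: List[List[float]]) -> List[int]:
    """Return the highest-scoring label sequence.

    emissions[t][k]   : score of label k at position t
    transitions[i][j] : score of moving from label i to label j
    """
    num_labels = len(emissions[0])
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, pointers = [], []
        for j in range(num_labels):
            # best previous label for landing on label j
            best_i = max(range(num_labels),
                         key=lambda i: score[i] + transitions[i][j])
            new_score.append(score[best_i] + transitions[best_i][j] + emissions[t][j])
            pointers.append(best_i)
        score = new_score
        backpointers.append(pointers)
    # follow the back-pointers from the best final label
    best = max(range(num_labels), key=lambda k: score[k])
    path = [best]
    for pointers in reversed(backpointers):
        best = pointers[best]
        path.append(best)
    return path[::-1]

# Toy label set: 0 = O, 1 = B-PER, 2 = I-PER
emissions = [[2.0, 0.5, 0.0],
             [0.0, 1.0, 1.5],
             [0.2, 0.0, 1.0]]
# Penalise the invalid transition O -> I-PER
transitions = [[0.0, 0.0, -10.0],
               [0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0]]

greedy = [max(range(3), key=lambda k: row[k]) for row in emissions]
decoded = viterbi_decode(emissions, transitions)
print(greedy, decoded)
```

Taking the per-token argmax yields `O, I-PER, I-PER`, which is not a valid BIO sequence. The transition penalty steers the decoder to `O, B-PER, I-PER` instead.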

![IFAformer](https://github.com/zjunlp/DeepKE/blob/main/tutorial-notebooks/ner/multimodal/image/ner.png?raw=1)

## Prepare environment

In [None]:
! nvidia-smi
! pip install deepke
! git clone https://github.com/zjunlp/DeepKE.git

Colab ships the latest **torchvision**, which is incompatible with the **torch** version required by deepke, so the two must be pinned to matching versions. **This step can be skipped in non-Colab environments.**

In [None]:
! pip install torch==1.10.0+cu102 torchvision==0.11.0+cu102 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

## Import package

Colab provides Python 3.7, whereas deepke requires python=3.8. To work around this, change ***import importlib.metadata as importlib\_metadata*** to ***import importlib\_metadata*** in both /usr/local/lib/python3.7/dist-packages/deepke/relation\_extraction/multimodal/models/clip/file\_utils.py and /usr/local/lib/python3.7/dist-packages/deepke/name\_entity\_re/multimodal/models/clip/file\_utils.py. **In non-Colab environments, use Python 3.8 and skip this step.**
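For reference, a common way to make such an import work on both Python versions is a guarded import. This is a generic pattern, not what DeepKE ships, and it assumes the `importlib_metadata` backport package is installed when running on Python 3.7.

```python
try:
    # Python 3.8+: part of the standard library
    import importlib.metadata as importlib_metadata
except ImportError:
    # Python 3.7: falls back to the backport package (assumed to be installed)
    import importlib_metadata

print(importlib_metadata.__name__)
```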

In [None]:
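import os
import random

import hydra
import numpy as np
import torch
from PIL import Image
from torch import optim  # used by Trainer (AdamW)
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm  # progress bars in Trainer
from seqeval.metrics import classification_report  # sequence-labeling report used by Trainer
from transformers import BertTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
from hydra import utils
from deepke.name_entity_re.multimodal.models.clip.processing_clip import CLIPProcessor
# The classes below are also re-defined in later cells so that their source can be read here.
from deepke.name_entity_re.multimodal.models.IFA_model import IFANERCRFModel
from deepke.name_entity_re.multimodal.modules.dataset import MMPNERProcessor, MMPNERDataset
from deepke.name_entity_re.multimodal.modules.train import Trainer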

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

# import wandb
# writer = wandb.init(project="DeepKE_NER_MM")
writer=None

## Prepare dataset

In [None]:
! wget 120.27.214.45/Data/ner/multimodal/data.tar.gz
! tar -xzvf data.tar.gz

In [None]:
DATA_PATH = {
    'twitter15': {'train': 'data/twitter2015/train.txt',
                'dev': 'data/twitter2015/valid.txt',
                'test': 'data/twitter2015/test.txt',
                'train_auximgs': 'data/twitter2015/twitter2015_train_dict.pth',
                'dev_auximgs': 'data/twitter2015/twitter2015_val_dict.pth',
                'test_auximgs': 'data/twitter2015/twitter2015_test_dict.pth',
                'rcnn_img_path': 'data/twitter2015',
                'img2crop': 'data/twitter2015/twitter2015_img2crop.pth'},

    'twitter17': {'train': 'data/twitter2017/train.txt',
                'dev': 'data/twitter2017/valid.txt',
                'test': 'data/twitter2017/test.txt',
                'train_auximgs': 'data/twitter2017/twitter2017_train_dict.pth',
                'dev_auximgs': 'data/twitter2017/twitter2017_val_dict.pth',
                'test_auximgs': 'data/twitter2017/twitter2017_test_dict.pth',
                'rcnn_img_path': 'data/twitter2017',
                'img2crop': 'data/twitter2017/twitter17_img2crop.pth'}
    }

IMG_PATH = {
    'twitter15': 'data/twitter2015/twitter2015_images',
    'twitter17': 'data/twitter2017/twitter2017_images'
}

AUX_PATH = {
    'twitter15': {'train': 'data/twitter2015/twitter2015_aux_images/train/crops',
                'dev': 'data/twitter2015/twitter2015_aux_images/val/crops',
                'test': 'data/twitter2015/twitter2015_aux_images/test/crops'},

    'twitter17': {'train': 'data/twitter2017/twitter2017_aux_images/train/crops',
                'dev': 'data/twitter2017/twitter2017_aux_images/val/crops',
                'test': 'data/twitter2017/twitter2017_aux_images/test/crops'}
}

LABEL_LIST = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

## Configure parameters

In [None]:
class Config(object):
    seed = 1234
    
    bert_name = "bert-base-uncased"
    vit_name = "openai/clip-vit-base-patch32"
    device = "cuda"
    
    num_epochs = 30
    batch_size = 32
    lr = 5e-5
    warmup_ratio = 0.06
    eval_begin_epoch = 1

    max_seq = 40
    aux_size = 128
    rcnn_size = 64
    ignore_idx = 0

    save_path = "checkpoints/twitter15"
    load_path = None

    dataset_name = "twitter15"

cfg = Config()

In [None]:
def set_seed(seed=2021):
    """set random seed"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)

set_seed(cfg.seed)  # uses cfg.seed (1234); the function default is 2021
if cfg.save_path is not None:  # create the checkpoint directory if needed
    os.makedirs(cfg.save_path, exist_ok=True)
print(cfg)
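The effect of seeding can be checked with the standard-library `random` module alone. The snippet below is a small illustration (the helper `sample_numbers` is hypothetical): the same seed reproduces the same draws, which is what `set_seed` aims for across `torch`, `numpy` and `random`.

```python
import random

def sample_numbers(seed: int, n: int = 3):
    """Re-seed the generator and draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = sample_numbers(1234)
second = sample_numbers(1234)   # same seed, same sequence
other = sample_numbers(2021)    # different seed, different sequence
print(first == second, first == other)
```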

## Process dataset

In [None]:
class MMPNERProcessor(object):
    def __init__(self, data_path, args) -> None:
        self.data_path = data_path
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_name, do_lower_case=True)
        self.clip_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor.feature_extractor.size, self.aux_processor.feature_extractor.crop_size = args.aux_size, args.aux_size
        self.rcnn_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.rcnn_processor.feature_extractor.size, self.rcnn_processor.feature_extractor.crop_size = args.rcnn_size, args.rcnn_size
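        # Config above defines no `cwd`, so fall back to the current working directory
        self.cwd = getattr(args, "cwd", os.getcwd())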

    def load_from_file(self, mode="train"):
        load_file = os.path.join(self.cwd,self.data_path[mode])
        logger.info("Loading data from {}".format(load_file))
        with open(load_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
            raw_words, raw_targets = [], []
            raw_word, raw_target = [], []
            imgs = []
            for line in lines:
                if line.startswith("IMGID:"):
                    img_id = line.strip().split('IMGID:')[1] + '.jpg'
                    imgs.append(img_id)
                    continue
                if line != "\n":
                    raw_word.append(line.split('\t')[0])
                    label = line.split('\t')[1][:-1]
                    if 'OTHER' in label:
                        label = label[:2] + 'MISC'
                    raw_target.append(label)
                else:
                    raw_words.append(raw_word)
                    raw_targets.append(raw_target)
                    raw_word, raw_target = [], []

        assert len(raw_words) == len(raw_targets) == len(imgs), "{}, {}, {}".format(len(raw_words), len(raw_targets), len(imgs))
        # visual grounding crops for this split: image file name -> list of crop files
        aux_path = os.path.join(self.cwd, self.data_path[mode+"_auximgs"])
        aux_imgs = torch.load(aux_path)

        # RCNN crops: image id (without extension) -> list of crop files
        rcnn_imgs = torch.load(os.path.join(self.cwd, self.data_path['img2crop']))

        return {"words": raw_words, "targets": raw_targets, "imgs": imgs, "aux_imgs":aux_imgs, "rcnn_imgs":rcnn_imgs}


class MMPNERDataset(Dataset):
    def __init__(self, processor, label_mapping, img_path=None, aux_path=None, rcnn_img_path=None, max_seq=40, ignore_idx=-100, aux_size=128, rcnn_size=64, mode='train',cwd='') -> None:
        self.processor = processor
        self.data_dict = processor.load_from_file(mode)
        self.tokenizer = processor.tokenizer
        self.label_mapping = label_mapping
        self.max_seq = max_seq
        self.ignore_idx = ignore_idx
        self.img_path = img_path
        self.aux_img_path = aux_path[mode]  if aux_path is not None else None
        self.rcnn_img_path = rcnn_img_path
        self.mode = mode
        self.clip_processor = self.processor.clip_processor
        self.aux_processor = self.processor.aux_processor
        self.rcnn_processor = self.processor.rcnn_processor
        self.aux_size = aux_size
        self.rcnn_size = rcnn_size
        self.cwd = cwd
    
    def __len__(self):
        return len(self.data_dict['words'])

    def __getitem__(self, idx):
        word_list, label_list, img = self.data_dict['words'][idx], self.data_dict['targets'][idx], self.data_dict['imgs'][idx]
        tokens, labels = [], []
        for i, word in enumerate(word_list):
            token = self.tokenizer.tokenize(word)
            tokens.extend(token)
            label = label_list[i]
            for m in range(len(token)):
                if m == 0:
                    labels.append(self.label_mapping[label])
                else:
                    labels.append(self.label_mapping["X"])
        if len(tokens) >= self.max_seq - 1:
            tokens = tokens[0:(self.max_seq - 2)]
            labels = labels[0:(self.max_seq - 2)]

        encode_dict = self.tokenizer.encode_plus(tokens, max_length=self.max_seq, truncation=True, padding='max_length')
        input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
        labels = [self.label_mapping["[CLS]"]] + labels + [self.label_mapping["[SEP]"]] + [self.ignore_idx]*(self.max_seq-len(labels)-2)

        
        if self.img_path is not None:
            self.img_path = os.path.join(self.cwd,self.img_path)
            # image process
            try:
                img_path = os.path.join(self.img_path, img)
                image = Image.open(img_path).convert('RGB')
                image = self.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()
            except Exception:  # missing or unreadable image: fall back to the placeholder inf.png
                img_path = os.path.join(self.img_path, 'inf.png')
                image = Image.open(img_path).convert('RGB')
                image = self.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()

            if self.aux_img_path is not None:
                aux_imgs = []
                aux_img_paths = []
                if img in self.data_dict['aux_imgs']:
                    aux_img_paths  = self.data_dict['aux_imgs'][img]
                    aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
                for i in range(min(3, len(aux_img_paths))):
                    aux_img = Image.open(os.path.join(self.cwd,aux_img_paths[i])).convert('RGB')
                    aux_img = self.aux_processor(images=aux_img, return_tensors='pt')['pixel_values'].squeeze()
                    aux_imgs.append(aux_img)

                for i in range(3-len(aux_imgs)):
                    aux_imgs.append(torch.zeros((3, self.aux_size, self.aux_size))) 

                aux_imgs = torch.stack(aux_imgs, dim=0)
                assert len(aux_imgs) == 3

                if self.rcnn_img_path is not None:
                    rcnn_imgs = []
                    rcnn_img_paths = []
                    img = img.split('.')[0]
                    if img in self.data_dict['rcnn_imgs']:
                        rcnn_img_paths = self.data_dict['rcnn_imgs'][img]
                        rcnn_img_paths = [os.path.join(self.rcnn_img_path, path) for path in rcnn_img_paths]
                    for i in range(min(3, len(rcnn_img_paths))):
                        rcnn_img = Image.open(rcnn_img_paths[i]).convert('RGB')
                        rcnn_img = self.rcnn_processor(images=rcnn_img, return_tensors='pt')['pixel_values'].squeeze()
                        rcnn_imgs.append(rcnn_img)

                    for i in range(3-len(rcnn_imgs)):
                        rcnn_imgs.append(torch.zeros((3, self.rcnn_size, self.rcnn_size))) 

                    rcnn_imgs = torch.stack(rcnn_imgs, dim=0)
                    assert len(rcnn_imgs) == 3
                    return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs, rcnn_imgs

                return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs

        assert len(input_ids) == len(token_type_ids) == len(attention_mask) == len(labels)
        return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels)
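The label handling in `__getitem__` is easier to follow on a tiny example. The sketch below mirrors its logic in plain Python: the first sub-token of each word keeps the word's label, the remaining sub-tokens get `X`, and the sequence is wrapped in `[CLS]`/`[SEP]` labels and padded. The `toy_tokenize` function is a stand-in for the BERT tokenizer and is purely an assumption for illustration.

```python
LABEL_LIST = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG",
              "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
label_mapping = {label: idx for idx, label in enumerate(LABEL_LIST, 1)}
label_mapping["PAD"] = 0

def toy_tokenize(word):
    """Stand-in tokenizer: split words longer than 4 characters into two pieces."""
    return [word] if len(word) <= 4 else [word[:4], "##" + word[4:]]

def align_labels(words, word_labels, max_seq, ignore_idx=0):
    tokens, labels = [], []
    for word, label in zip(words, word_labels):
        pieces = toy_tokenize(word)
        if not pieces:
            continue
        tokens.extend(pieces)
        # first piece keeps the word label, the rest are marked "X"
        labels.extend([label_mapping[label]] + [label_mapping["X"]] * (len(pieces) - 1))
    # leave room for [CLS] and [SEP]
    if len(tokens) >= max_seq - 1:
        tokens = tokens[:max_seq - 2]
        labels = labels[:max_seq - 2]
    padded = ([label_mapping["[CLS]"]] + labels + [label_mapping["[SEP]"]]
              + [ignore_idx] * (max_seq - len(labels) - 2))
    return tokens, padded

tokens, padded_labels = align_labels(["JUSTIN", "BIEBER", "omfg"],
                                     ["B-PER", "I-PER", "O"], max_seq=10)
print(tokens)
print(padded_labels)
```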

In [None]:
label_mapping = {label:idx for idx, label in enumerate(LABEL_LIST, 1)}
label_mapping["PAD"] = 0
data_path, img_path, aux_path = DATA_PATH[cfg.dataset_name], IMG_PATH[cfg.dataset_name], AUX_PATH[cfg.dataset_name]
rcnn_img_path = DATA_PATH[cfg.dataset_name]['rcnn_img_path']
cwd = os.getcwd()  # the relative data paths above are resolved against this directory

processor = MMPNERProcessor(data_path, cfg)
train_dataset = MMPNERDataset(processor, label_mapping, img_path, aux_path, rcnn_img_path, max_seq=cfg.max_seq, ignore_idx=cfg.ignore_idx, aux_size=cfg.aux_size, rcnn_size=cfg.rcnn_size, mode='train',cwd=cwd)
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=4, pin_memory=True)

dev_dataset = MMPNERDataset(processor, label_mapping, img_path, aux_path, rcnn_img_path, max_seq=cfg.max_seq, ignore_idx=cfg.ignore_idx, aux_size=cfg.aux_size, rcnn_size=cfg.rcnn_size, mode='dev',cwd=cwd)
dev_dataloader = DataLoader(dev_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True)

test_dataset = MMPNERDataset(processor, label_mapping, img_path, aux_path, rcnn_img_path, max_seq=cfg.max_seq, ignore_idx=cfg.ignore_idx, aux_size=cfg.aux_size, rcnn_size=cfg.rcnn_size, mode='test',cwd=cwd)
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True)

## Model construction

In [None]:
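import torch
from torch import nn
from torchcrf import CRF
import torch.nn.functional as F
from transformers import BertConfig, BertModel
# Absolute imports: the relative form (from .modeling_IFA import ...) only works inside the deepke package, not in a notebook cell.
from deepke.name_entity_re.multimodal.models.modeling_IFA import IFAModel
from deepke.name_entity_re.multimodal.models.clip.modeling_clip import CLIPModel
from deepke.name_entity_re.multimodal.models.clip.configuration_clip import CLIPConfig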

class IFANERCRFModel(nn.Module):
    def __init__(self, label_list, args):
        super(IFANERCRFModel, self).__init__()
        self.args = args
        self.vision_config = CLIPConfig.from_pretrained(self.args.vit_name).vision_config
        self.text_config = BertConfig.from_pretrained(self.args.bert_name)

        clip_model_dict = CLIPModel.from_pretrained(self.args.vit_name).vision_model.state_dict()
        bert_model_dict = BertModel.from_pretrained(self.args.bert_name).state_dict()

        print(self.vision_config)
        print(self.text_config)

        self.vision_config.device = args.device
        self.model = IFAModel(self.vision_config, self.text_config)

        self.num_labels  = len(label_list) + 1  # pad
        self.crf = CRF(self.num_labels, batch_first=True)
        self.fc = nn.Linear(self.text_config.hidden_size, self.num_labels)
        self.dropout = nn.Dropout(0.1)

        # load:
        vision_names, text_names = [], []
        model_dict = self.model.state_dict()
        for name in model_dict:
            if 'vision' in name:
                clip_name = name.replace('vision_', '').replace('model.', '')
                if clip_name in clip_model_dict:
                    vision_names.append(clip_name)
                    model_dict[name] = clip_model_dict[clip_name]
            elif 'text' in name:
                text_name = name.replace('text_', '').replace('model.', '')
                if text_name in bert_model_dict:
                    text_names.append(text_name)
                    model_dict[name] = bert_model_dict[text_name]
        assert len(vision_names) == len(clip_model_dict) and len(text_names) == len(bert_model_dict), \
                    (len(vision_names), len(text_names), len(clip_model_dict), len(bert_model_dict))
        self.model.load_state_dict(model_dict)

    def forward(
            self, 
            input_ids=None, 
            attention_mask=None, 
            token_type_ids=None, 
            labels=None, 
            images=None, 
            aux_imgs=None,
            rcnn_imgs=None,
    ):
        bsz = input_ids.size(0)

        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,

                            pixel_values=images,
                            aux_values=aux_imgs, 
                            rcnn_values=rcnn_imgs,
                            return_dict=True,)

        sequence_output = output.last_hidden_state       # bsz, len, hidden
        sequence_output = self.dropout(sequence_output)  # bsz, len, hidden
        emissions = self.fc(sequence_output)             # bsz, len, labels
        
        logits = self.crf.decode(emissions, attention_mask.byte())
        loss = None
        if labels is not None:
            loss = -1 * self.crf(emissions, labels, mask=attention_mask.byte(), reduction='mean')  # negative CRF log-likelihood (original note: remove CLS)
            return logits, loss
        return logits, None

In [None]:
model = IFANERCRFModel(LABEL_LIST, cfg)
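The weight-loading loop in `__init__` maps parameter names of the combined model onto the names used by the pretrained CLIP and BERT checkpoints by stripping a prefix. The sketch below replays that idea on plain dictionaries. All parameter names here are hypothetical placeholders rather than the real IFAModel names, and integers stand in for tensors.

```python
def copy_pretrained(model_state, pretrained, keyword, prefix):
    """Copy values from `pretrained` into `model_state` for names containing `keyword`."""
    matched = []
    for name in model_state:
        if keyword in name:
            # same renaming rule as in the model: drop the prefix and any 'model.'
            source_name = name.replace(prefix, "").replace("model.", "")
            if source_name in pretrained:
                model_state[name] = pretrained[source_name]
                matched.append(source_name)
    return matched

# Hypothetical names, with integers standing in for tensors
model_state = {
    "vision_embeddings.patch.weight": 0,
    "text_embeddings.word.weight": 0,
    "classifier.weight": 0,
}
pretrained_vision = {"embeddings.patch.weight": 1}
pretrained_text = {"embeddings.word.weight": 2}

vision_names = copy_pretrained(model_state, pretrained_vision, "vision", "vision_")
text_names = copy_pretrained(model_state, pretrained_text, "text", "text_")
print(model_state)
```

Checking that every pretrained entry was matched, as the `assert` in the model does, guards against silently training from randomly initialised weights.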

## Train the model

In [None]:
class Trainer(object):
    def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, process=None, label_map=None, args=None, logger=None,  writer=None) -> None:
        self.train_data = train_data
        self.dev_data = dev_data
        self.test_data = test_data
        self.model = model
        self.process = process
        self.logger = logger
        self.label_map = label_map
        self.writer = writer
        self.refresh_step = 2
        self.best_dev_metric = 0
        self.best_test_metric = 0
        self.best_dev_epoch = None
        self.best_test_epoch = None
        self.optimizer = None
        self.step = 0
        self.args = args
        if self.train_data is not None:
            self.train_num_steps = len(self.train_data) * args.num_epochs
            self.multiModal_before_train()
        
        
    
    def train(self):
        self.step = 0
        self.model.train()
        self.logger.info("***** Running training *****")
        self.logger.info("  Num instance = %d", len(self.train_data)*self.args.batch_size)
        self.logger.info("  Num epoch = %d", self.args.num_epochs)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        self.logger.info("  Learning rate = {}".format(self.args.lr))
        self.logger.info("  Evaluate begin = %d", self.args.eval_begin_epoch)

        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")
            
        with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
            self.pbar = pbar
            avg_loss = 0
            for epoch in range(1, self.args.num_epochs+1):
                y_true, y_pred = [], []
                y_true_idx, y_pred_idx = [], []
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.args.num_epochs))
                for batch in self.train_data:
                    self.step += 1
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)
                    attention_mask, labels, logits, loss = self._step(batch, mode="train")
                    avg_loss += loss.detach().cpu().item()

                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()

                    self.optimizer.zero_grad()

                    if isinstance(logits, torch.Tensor): 
                        logits = logits.argmax(-1).detach().cpu().numpy()  # batch, seq, 1
                    label_ids = labels.to('cpu').numpy()
                    input_mask = attention_mask.to('cpu').numpy()
                    label_map = {idx:label for label, idx in self.label_map.items()}
                    for i, mask in enumerate(input_mask):
                        temp_1 = []
                        temp_2 = []
                        temp_1_idx, temp_2_idx = [], []
                        for j, m in enumerate(mask):
                            if j == 0:
                                continue
                            if m:
                                if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "[SEP]":
                                    temp_1.append(label_map[label_ids[i][j]])
                                    temp_2.append(label_map[logits[i][j]])
                                    temp_1_idx.append(label_ids[i][j])
                                    temp_2_idx.append(logits[i][j])
                            else:
                                break
                        y_true.append(temp_1)
                        y_pred.append(temp_2)
                        y_true_idx.append(temp_1_idx)
                        y_pred_idx.append(temp_2_idx)

                    if self.step % self.refresh_step == 0:
                        avg_loss = float(avg_loss) / self.refresh_step
                        print_output = "loss:{:<6.5f}".format(avg_loss)
                        pbar.update(self.refresh_step)
                        pbar.set_postfix_str(print_output)
                        if self.writer:
                            self.writer.log({'avg_loss': avg_loss})
                        avg_loss = 0
               
                if epoch >= self.args.eval_begin_epoch:
                    if self.dev_data:
                        self.evaluate(epoch)   # evaluate on the dev set
                    if self.test_data:
                        self.test(epoch)
            
            torch.cuda.empty_cache()
            
            pbar.close()
            self.pbar = None
            self.logger.info("Get best dev performance at epoch {}, best dev f1 score is {}".format(self.best_dev_epoch, self.best_dev_metric))
            self.logger.info("Get best test performance at epoch {}, best test f1 score is {}".format(self.best_test_epoch, self.best_test_metric))

    def evaluate(self, epoch):
        self.model.eval()
        self.logger.info("***** Running evaluate *****")
        self.logger.info("  Num instance = %d", len(self.dev_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)

        y_true, y_pred = [], []
        y_true_idx, y_pred_idx = [], []
        step = 0
        with torch.no_grad():
            with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Dev")
                total_loss = 0
                for batch in self.dev_data:
                    step += 1
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
                    attention_mask, labels, logits, loss = self._step(batch, mode="dev")    # logits: batch, seq, num_labels
                    total_loss += loss.detach().cpu().item()

                    if isinstance(logits, torch.Tensor):    
                        logits = logits.argmax(-1).detach().cpu().numpy()  # batch, seq, 1
                    label_ids = labels.detach().cpu().numpy()
                    input_mask = attention_mask.detach().cpu().numpy()
                    label_map = {idx:label for label, idx in self.label_map.items()}
                    for i, mask in enumerate(input_mask):
                        temp_1 = []
                        temp_2 = []
                        temp_1_idx, temp_2_idx = [], []
                        for j, m in enumerate(mask):
                            if j == 0:
                                continue
                            if m:
                                if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "[SEP]":
                                    temp_1.append(label_map[label_ids[i][j]])
                                    temp_2.append(label_map[logits[i][j]])
                                    temp_1_idx.append(label_ids[i][j])
                                    temp_2_idx.append(logits[i][j])
                            else:
                                break
                        y_true.append(temp_1)
                        y_pred.append(temp_2)
                        y_true_idx.append(temp_1_idx)
                        y_pred_idx.append(temp_2_idx)
                    
                    pbar.update()
                # evaluate done
                pbar.close()

                results = classification_report(y_true, y_pred, digits=4)  
                self.logger.info("***** Dev Eval results *****")
                self.logger.info("\n%s", results)
                f1_score = float(results.split('\n')[-4].split('      ')[-2].split('    ')[-1])
                if self.writer: 
                    self.writer.log({'eva_f1': f1_score})

                self.logger.info("Epoch {}/{}, best dev f1: {}, best epoch: {}, current dev f1 score: {}."\
                            .format(epoch, self.args.num_epochs, self.best_dev_metric, self.best_dev_epoch, f1_score))
                if f1_score >= self.best_dev_metric:  # this epoch get best performance
                    self.logger.info("Get better performance at epoch {}".format(epoch))
                    self.best_dev_epoch = epoch
                    self.best_dev_metric = f1_score # update best metric(f1 score)
                    if self.args.save_path is not None: # save model
                        torch.save(self.model.state_dict(), self.args.save_path+"/best_model.pth")
                        self.logger.info("Save best model at {}".format(self.args.save_path))
               

        self.model.train()

    def test(self, epoch):
        self.model.eval()
        self.logger.info("\n***** Running testing *****")
        self.logger.info("  Num instance = %d", len(self.test_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)

        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")
        y_true, y_pred = [], []
        y_true_idx, y_pred_idx = [], []
        with torch.no_grad():
            with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Testing")
                for batch in self.test_data:
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
                    attention_mask, labels, logits, loss = self._step(batch, mode="dev")    # logits: batch, seq, num_labels
            
                    if isinstance(logits, torch.Tensor):    #
                        logits = logits.argmax(-1).detach().cpu().tolist()  # batch, seq, 1
                    label_ids = labels.detach().cpu().numpy()
                    input_mask = attention_mask.detach().cpu().numpy()
                    label_map = {idx:label for label, idx in self.label_map.items()}
                    for i, mask in enumerate(input_mask):
                        temp_1 = []
                        temp_2 = []
                        temp_1_idx, temp_2_idx = [], []
                        for j, m in enumerate(mask):
                            if j == 0:
                                continue
                            if m:
                                if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "[SEP]":
                                    temp_1.append(label_map[label_ids[i][j]])
                                    temp_2.append(label_map[logits[i][j]])
                                    temp_1_idx.append(label_ids[i][j])
                                    temp_2_idx.append(logits[i][j])
                            else:
                                break
                        y_true.append(temp_1)
                        y_pred.append(temp_2)
                        y_true_idx.append(temp_1_idx)
                        y_pred_idx.append(temp_2_idx)
                    
                    pbar.update()
                # evaluate done
                pbar.close()

                results = classification_report(y_true, y_pred, digits=4) 
                self.logger.info("***** Test Eval results *****")
                self.logger.info("\n%s", results)
                f1_score = float(results.split('\n')[-4].split('      ')[-2].split('    ')[-1])
                if self.writer:
                    self.writer.log({'test_f1': f1_score})
                total_loss = 0
                
                self.logger.info("Epoch {}/{}, best test f1: {}, best epoch: {}, current test f1 score: {}."\
                            .format(epoch, self.args.num_epochs, self.best_test_metric, self.best_test_epoch, f1_score))
                if f1_score >= self.best_test_metric:  # this epoch get best performance
                    self.best_test_metric = f1_score
                    self.best_test_epoch = epoch
                   
        self.model.train()


    def predict(self):
        self.model.eval()
        self.logger.info("\n***** Running predicting *****")
        self.logger.info("  Num instance = %d", len(self.test_data)*self.args.batch_size)
        self.logger.info("  Batch size = %d", self.args.batch_size)
        if self.args.load_path is not None:  # load model from load_path
            self.logger.info("Loading model from {}".format(self.args.load_path))
            self.model.load_state_dict(torch.load(self.args.load_path))
            self.logger.info("Load model successful!")
            self.model.to(self.args.device)
        y_pred = []

        with torch.no_grad():
            with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
                pbar.set_description_str(desc="Predicting")
                for batch in self.test_data:
                    batch = (tup.to(self.args.device)  if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
                    attention_mask, labels, logits, loss = self._step(batch, mode="dev")    # logits: batch, seq, num_labels
            
                    if isinstance(logits, torch.Tensor):    # 
                        logits = logits.argmax(-1).detach().cpu().tolist()  # batch, seq, 1
                    label_ids = labels.detach().cpu().numpy()
                    input_mask = attention_mask.detach().cpu().numpy()
                    label_map = {idx:label for label, idx in self.label_map.items()}
                    for i, mask in enumerate(input_mask):
                        temp_1 = []
                        for j, m in enumerate(mask):
                            if j == 0:
                                continue
                            if m:
                                if label_map[label_ids[i][j]] != "X" and label_map[label_ids[i][j]] != "[SEP]":
                                    temp_1.append(label_map[logits[i][j]])
                            else:
                                break
                        y_pred.append(temp_1)
                    
                    pbar.update()
                # prediction done
                pbar.close()
        return y_pred  # predicted label sequence for each sentence
        
    def _step(self, batch, mode="train"):
        input_ids, token_type_ids, attention_mask, labels, images, aux_imgs, rcnn_imgs = batch
        logits, loss = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels, images=images, aux_imgs=aux_imgs, rcnn_imgs=rcnn_imgs)
        return attention_mask, labels, logits, loss


    def multiModal_before_train(self):
        # bert lr
        parameters = []
        params = {'lr':self.args.lr, 'weight_decay':1e-2}
        params['params'] = []
        for name, param in self.model.named_parameters():
            if 'text' in name:
                params['params'].append(param)
        parameters.append(params)

        # vit lr
        params = {'lr':3e-5, 'weight_decay':1e-2}
        params['params'] = []
        for name, param in self.model.named_parameters():
            if 'vision' in name:
                params['params'].append(param)
        parameters.append(params)

        # crf lr
        params = {'lr':5e-2, 'weight_decay':1e-2}
        params['params'] = []
        for name, param in self.model.named_parameters():
            if 'crf' in name or name.startswith('fc'):
                params['params'].append(param)
        parameters.append(params)

        self.optimizer = optim.AdamW(parameters)

        self.model.to(self.args.device)
            
        self.scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer, 
                                 num_warmup_steps=self.args.warmup_ratio*self.train_num_steps, 
                                 num_training_steps=self.train_num_steps)
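The scheduler created in `multiModal_before_train` scales each parameter group's learning rate by a factor that rises linearly during warmup and then decays linearly to zero. A minimal pure-Python sketch of that factor, assuming the standard linear-warmup rule and using made-up step counts, looks like this:

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < num_warmup_steps:
        # linear ramp from 0 to 1
        return step / max(1, num_warmup_steps)
    # linear decay from 1 to 0 over the remaining steps
    return max(0.0, (num_training_steps - step)
               / max(1, num_training_steps - num_warmup_steps))

total_steps = 1000    # illustrative value, not the real number of steps
warmup_steps = 60     # 6% of total_steps, matching warmup_ratio = 0.06
base_lr = 5e-5

for step in (0, 30, 60, 530, 1000):
    factor = linear_warmup_factor(step, warmup_steps, total_steps)
    print(step, factor, base_lr * factor)
```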

In [None]:
trainer = Trainer(train_data=train_dataloader, dev_data=dev_dataloader, test_data=test_dataloader, model=model, label_map=label_mapping, args=cfg, logger=logger, writer=writer)
trainer.train()
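During evaluation, the trainer only scores the first sub-token of each word: it skips `[CLS]`, stops at padding, and ignores positions whose gold label is `X` or `[SEP]`. The following plain-Python sketch reproduces that filtering for a single sentence. The function name and the example ids are assumptions for illustration.

```python
LABEL_LIST = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG",
              "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
label_mapping = {label: idx for idx, label in enumerate(LABEL_LIST, 1)}
label_mapping["PAD"] = 0
id2label = {idx: label for label, idx in label_mapping.items()}

def keep_word_level(label_ids, pred_ids, mask):
    """Return gold and predicted tags at word-level positions only."""
    y_true, y_pred = [], []
    for j, m in enumerate(mask):
        if j == 0:        # position 0 is [CLS]
            continue
        if not m:         # padding starts here
            break
        gold = id2label[label_ids[j]]
        if gold not in ("X", "[SEP]"):
            y_true.append(gold)
            y_pred.append(id2label[pred_ids[j]])
    return y_true, y_pred

label_ids = [11, 4, 10, 5, 10, 1, 12, 0, 0, 0]
pred_ids  = [11, 4, 1, 5, 1, 1, 12, 0, 0, 0]   # predictions at "X" positions are ignored
mask      = [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]

y_true, y_pred = keep_word_level(label_ids, pred_ids, mask)
print(y_true, y_pred)
```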

## Predict

In [None]:
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, label_map=label_mapping, args=cfg, logger=logger, writer=writer)
trainer.predict()
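Predicted BIO tags are usually converted into entity spans before they are used. The helper below is a hypothetical post-processing sketch, not part of DeepKE. It assumes you have the words of a sentence and one predicted tag per word, and it drops `I-` tags that do not continue an open entity of the same type.

```python
def bio_to_spans(words, tags):
    """Group B-/I- tagged words into (entity_type, text) pairs."""
    spans, current_type, current_words = [], None, []
    for word, tag in zip(words, tags):
        if tag.startswith("B-"):
            if current_type:
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = tag[2:], [word]
        elif tag.startswith("I-") and current_type == tag[2:]:
            current_words.append(word)
        else:
            # "O" or a stray I- tag closes any open entity
            if current_type:
                spans.append((current_type, " ".join(current_words)))
            current_type, current_words = None, []
    if current_type:
        spans.append((current_type, " ".join(current_words)))
    return spans

words = ["5", "days", "until", "JUSTIN", "BIEBER", "CONCERT", "@", "justinbieber"]
tags = ["O", "O", "O", "B-PER", "I-PER", "O", "O", "B-PER"]
entities = bio_to_spans(words, tags)
print(entities)
```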