In [1]:
pip install yacs timm sentencepiece regex

Looking in indexes: https://mirrors.aliyun.com/pypi/simple
Collecting yacs
  Downloading https://mirrors.aliyun.com/pypi/packages/38/4f/fe9a4d472aa867878ce3bb7efb16654c5d63672b86dc0e6e953a67018433/yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting timm
  Downloading https://mirrors.aliyun.com/pypi/packages/f0/1e/05287cb8984229d101874433df472b1fa3dcd6f746ccb6e8a26c7deeb1c7/timm-0.9.10-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting sentencepiece
  Downloading https://mirrors.aliyun.com/pypi/packages/c9/58/4fbd3f33a38c9809fedf57bbef7e086b9909d6807148f35d68c0c90896d3/sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting regex
  Downloading https://mirrors.aliyun.com/pypi/packages/7

In [2]:
import os
import argparse
import torch
import numpy as np
from torch.optim import AdamW
from yacs.config import CfgNode as CN
from tqdm import tqdm,trange
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from toolcls import AbmsaProcessor,seed_everything,convert_mm_examples_to_features,BertConfig
from models.swin.swintransformer import SwinTransformer,get_config
from sklearn.metrics import precision_recall_fscore_support
from models.deberta.spm_tokenizer import SPMTokenizer
from models.deberta.deberta import SwinBERTa
from models.logs import logger
from torch.optim.lr_scheduler import LambdaLR
import math

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

  from .autonotebook import tqdm as notebook_tqdm


# 参数配置

In [3]:

# img_encoder config
_C = CN()
config = _C.clone()
config.LOCAL_RANK = -1

class WarmupCosineSchedule(LambdaLR):
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return (1-x)/(1-warmup)


def macro_f1(y_true, y_pred):
    preds = np.argmax(y_pred, axis=-1)
    true = y_true
    p_macro, r_macro, f_macro, support_macro \
      = precision_recall_fscore_support(true, preds, average='macro')
    #f_macro = 2*p_macro*r_macro/(p_macro+r_macro)
    return p_macro, r_macro, f_macro


parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir",default='./twitterdataset/absa_data/twitter',type=str,#文本数据位置
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--img_ckpt", default='./pretrains/swin_base_patch4_window7_224_1k.pth', type=str,#swin预训练模型的准确位置
                    help="Bert pre-trained model selected in the list: S24-224,S24-336 ")
parser.add_argument("--spm_model_file",default='./pretrains/30k-clean.model',type=str)#albert预训练模型的准确位置
parser.add_argument("--model_name_or_path", default='./pretrains', type=str,#albert预训练模型存放的文件夹
                    help="Path to pre-trained model or shortcut name selected in the list")
parser.add_argument("--task_name",default='twitter17',type=str,#要加载哪个数据集 Twitter是17 Twitter15是15
                    help="The name of the task to train.")
parser.add_argument("--output_dir",default='./output',type=str,
                    help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument('--path_image', default='./twitterdataset/img_data/twitter2017_images',#图像的位置
                    help='path to images')
parser.add_argument('--init_model',
                    type=str,
                    default='./pretrains/pytorch_model.bin',
                    help="The model state file used to initialize the model weights.")
parser.add_argument('--model_config',
                    type=str,
                    default='./pretrains/config.json',
                    help="The config file of bert model.")
parser.add_argument('--pre_trained',
                    default=None,
                    type=str,
                    help="The path of pre-trained RoBERTa model")
parser.add_argument('--vocab_path',
                    default='./pretrains/spm.model',
                    type=str,
                    help="The path of the vocabulary")
## Other parameters
parser.add_argument('--crop_size', type=int, default=224, help='crop size of image')
parser.add_argument("--max_seq_length",default=64,type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--max_entity_length",default=16,type=int,
                    help="The maximum entity input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--do_train",action='store_true',default=True,
                    help="Whether to run training.")
parser.add_argument("--do_lower_case",action='store_true',default=True,
                    help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",default=24,type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size",default=24,type=int,
                    help="Total batch size for eval.")
parser.add_argument("--learning_rate",default=3e-5,type=float,#可能会暴毙
                    help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",default=8.0,type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",default=0.1,type=float,
                    help="Proportion of training to perform linear learning rate warmup for. "
                         "E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",action='store_true',default=False,
                    help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",type=int,default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed',type=int,default=6,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',type=int,default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16',default=False,action='store_true',#kaggle没搞懂怎么导入apex 索性就32位吧
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--overwrite_output_dir', action='store_true',default=True,#是否覆盖原本的输出文件
                    help="Overwrite the content of the output directory")
parser.add_argument("--config_name", default="", type=str,
                    help="Pretrained config name or path if not the same as model_name")
parser.add_argument('--cfg', type=str, default="./pretrains/swin_base_patch4_window7_224.yaml", metavar="FILE",
                    help='path to config file', )
args = parser.parse_args(args=[])

# 基础随机数 cuda 路径之类的东西配置

In [4]:
if args.task_name == "twitter17":
    args.path_image = "./twitterdataset/img_data/twitter2017_images"
elif args.task_name == "twitter15":
    args.path_image = "./twitterdataset/img_data/twitter2015_images"
else:
    print("The task name is not right!")
processors = {
        "twitter15": AbmsaProcessor,    # our twitter-2015 dataset
        "twitter17": AbmsaProcessor         # our twitter-2017 dataset
}
num_labels_task = {
    "twitter15": 3,                # our twitter-2015 dataset
    "twitter17": 3                     # our twitter-2017 dataset
}
seed_everything(args.seed) #固定随机数种子
task_name = args.task_name.lower()
#初始化输出的文件夹
if not os.path.exists(args.output_dir):
    os.mkdir(args.output_dir)
args.output_dir = args.output_dir
if os.path.exists(args.output_dir) and os.listdir(
        args.output_dir) and args.do_train and not args.overwrite_output_dir:
    raise ValueError(
        "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
            args.output_dir))
#设置cuda
if config.LOCAL_RANK == -1 or args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(config.LOCAL_RANK)
    device = torch.device("cuda", config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl')
    args.n_gpu = 1
args.device = device
processor = processors[task_name]()#获得读取tsv文件方法
num_labels = num_labels_task[task_name]#判定几分类
label_list = processor.get_labels()#获得分类标签

# 定义模型 载入预训练模型

In [5]:
config_file = os.path.join(args.model_name_or_path, 'bert_config.json')
bert_config = BertConfig.from_json_file(config_file)
# 读取词表 方便之后把文字转数字id 返回的是一个30000词字典
tokenizer = SPMTokenizer(args.vocab_path)
#创建并导入预训练模型 这个模型用于文本encoder 融合图文feature
model=SwinBERTa(args, bert_config,args.init_model)
model.to(device)
#设定图像encoder 
config = get_config(args)
encoder = SwinTransformer(img_size=224,
                          patch_size=config.MODEL.SWIN.PATCH_SIZE,
                          in_chans=config.MODEL.SWIN.IN_CHANS,
                          num_classes=config.MODEL.NUM_CLASSES,
                          embed_dim=config.MODEL.SWIN.EMBED_DIM,
                          depths=config.MODEL.SWIN.DEPTHS,
                          num_heads=config.MODEL.SWIN.NUM_HEADS,
                          window_size=config.MODEL.SWIN.WINDOW_SIZE,
                          mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                          qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                          qk_scale=config.MODEL.SWIN.QK_SCALE,
                          drop_rate=config.MODEL.DROP_RATE,
                          drop_path_rate=config.MODEL.DROP_PATH_RATE,
                          patch_norm=config.MODEL.SWIN.PATCH_NORM,
                          use_checkpoint=False)
pretrained_dict = torch.load(args.img_ckpt, map_location='cpu')
pretrained_dict = pretrained_dict['model']
unexpected_keys = {"head.weight", "head.bias"}
# 删除不匹配的键值
for key in unexpected_keys:
    del pretrained_dict[key]
missing_keys, unexpected_keys = encoder.load_state_dict(pretrained_dict, strict=False)
encoder.to(device)

2023-11-19 06:34:56,142 - root - INFO - Loaded pretrained model file ./pretrains/pytorch_model.bin
2023-11-19 06:34:58,680 - root - INFO - Loaded pretrained model file ./pretrains/pytorch_model.bin


=> merge config from ./pretrains/swin_base_patch4_window7_224.yaml


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate='none')
           

# 加载优化器 pth保存路径配置

In [6]:
train_examples = processor.get_train_examples(args.data_dir)#获取训练集文本内容
eval_examples = processor.get_dev_examples(args.data_dir)#获取验证集文本内容
num_train_steps = int(len(train_examples) / args.train_batch_size * args.num_train_epochs)
t_total = num_train_steps
#文本和融合部分参数的优化策略
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
#图像部分参数优化策略
def check_keywords_in_name(name, keywords=()):
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin
def set_weight_decay(model, skip_list=(), skip_keywords=()):
    has_decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
        else:
            has_decay.append(param)
    return [{'params': has_decay,'weight_decay': 0.01},
            {'params': no_decay, 'weight_decay': 0.}]
skip = {'absolute_pos_embed'}
skip_keywords = {'relative_position_bias_table'}
optimizer_grouped_parameters2 = set_weight_decay(encoder, skip, skip_keywords)
#合并两组参数统一传入adamw优化器调参
optimizer_grouped_parameters = optimizer_grouped_parameters1 + optimizer_grouped_parameters2
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate,eps=1e-6)

num_train_steps = int(
    len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
t_total = num_train_steps
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=int(t_total * args.warmup_proportion), t_total=t_total)
output_model_file = os.path.join(args.output_dir, "pytorch_model.pth")
output_encoder_file = os.path.join(args.output_dir, "pytorch_encoder.pth")

2023-11-19 06:35:07,089 - root - INFO - LOOKING AT ./twitterdataset/absa_data/twitter/train.tsv


# 载入训练集 验证集

In [7]:
train_features = convert_mm_examples_to_features(
            train_examples, label_list, args.max_seq_length, args.max_entity_length, tokenizer, args.crop_size,
            args.path_image)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_added_input_mask = torch.tensor([f.added_input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_s2_input_ids = torch.tensor([f.s2_input_ids for f in train_features], dtype=torch.long)
all_s2_input_mask = torch.tensor([f.s2_input_mask for f in train_features], dtype=torch.long)
all_s2_segment_ids = torch.tensor([f.s2_segment_ids for f in train_features], dtype=torch.long)
all_img_feats = torch.stack([f.img_feat for f in train_features])
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, \
                           all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids,
                           all_img_feats, all_label_ids)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size,
                              drop_last=True)
# 获取验证集
eval_features = convert_mm_examples_to_features(
    eval_examples, label_list, args.max_seq_length, args.max_entity_length, tokenizer, args.crop_size,
    args.path_image)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_s2_input_ids = torch.tensor([f.s2_input_ids for f in eval_features], dtype=torch.long)
all_s2_input_mask = torch.tensor([f.s2_input_mask for f in eval_features], dtype=torch.long)
all_s2_segment_ids = torch.tensor([f.s2_segment_ids for f in eval_features], dtype=torch.long)
all_img_feats = torch.stack([f.img_feat for f in eval_features])
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, \
                          all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids, \
                          all_img_feats, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, drop_last=True)
global_step = 0
nb_tr_steps = 0
tr_loss = 0
max_acc = 0.0

2023-11-19 06:35:07,198 - root - INFO - *** Example ***
2023-11-19 06:35:07,199 - root - INFO - guid: train-1
2023-11-19 06:35:07,200 - root - INFO - tokens: [CLS] ▁how ▁$ t $ ▁is ▁changing ▁the ▁influencer ▁game ▁: [SEP] ▁jake ▁paul [SEP]
2023-11-19 06:35:07,202 - root - INFO - input_ids: 1 361 419 297 1814 269 2198 262 29655 522 877 2 109757 38723 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:35:07,203 - root - INFO - input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:35:07,204 - root - INFO - segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:35:07,205 - root - INFO - label: 2 (id = 2)


the number of problematic samples: 134
the max length of sentence a: 49 entity b: 10 total length: 57


2023-11-19 06:36:26,381 - root - INFO - *** Example ***
2023-11-19 06:36:26,385 - root - INFO - guid: dev-1
2023-11-19 06:36:26,386 - root - INFO - tokens: [CLS] ▁looking ▁forward ▁to ▁the ▁$ t $ ▁from ▁4 ▁- ▁8 ▁july ▁! ▁more ▁info ▁here ▁# ▁heritage ▁# ▁music [SEP] ▁f other ing hay ▁festival [SEP]
2023-11-19 06:36:26,387 - root - INFO - input_ids: 1 562 939 264 262 419 297 1814 292 453 341 578 52434 1084 310 2470 422 953 5456 953 755 2 2994 10705 510 28577 3694 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:36:26,389 - root - INFO - input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:36:26,390 - root - INFO - segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:36:26,390 - root - INFO - label: 2 (id = 2)


the number of problematic samples: 29
the max length of sentence a: 48 entity b: 10 total length: 55


# 训练

In [8]:
logger.info("*************** Running training ***************")
for train_idx in trange(int(args.num_train_epochs), desc="Epoch"):
    logger.info("********** Epoch: " + str(train_idx) + " **********")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)
    model.train()
    encoder.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    progress_bar = tqdm(enumerate(train_dataloader), desc="Iteration", total=len(train_dataloader), position=0)
    for step, batch in progress_bar:
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, added_input_mask, segment_ids, s2_input_ids, s2_input_mask, s2_segment_ids, \
            img_feats, label_ids = batch
        img_att = encoder(img_feats)
        loss = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids, input_mask,
                         s2_input_mask, \
                         added_input_mask, label_ids)
        loss.backward()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        scheduler.step()  # 使用学习率调整算法
        for param_group in optimizer.param_groups:
            progress_bar.set_description(f"Iteration (loss: {loss.item():.4f},lr:{param_group['lr']:.10f})")
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
    logger.info("***** Running evaluation on Dev Set*****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    model.eval()
    encoder.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    true_label_list = []
    pred_label_list = []
    #验证
    progress_bar = tqdm(eval_dataloader, desc="Evaluating", position=0)
    for input_ids, input_mask, added_input_mask, segment_ids, s2_input_ids, s2_input_mask, s2_segment_ids, \
            img_feats, label_ids in progress_bar:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        added_input_mask = added_input_mask.to(device)
        segment_ids = segment_ids.to(device)
        s2_input_ids = s2_input_ids.to(device)
        s2_input_mask = s2_input_mask.to(device)
        s2_segment_ids = s2_segment_ids.to(device)
        img_feats = img_feats.to(device)
        label_ids = label_ids.to(device)
        with torch.no_grad():
            img_att = encoder(img_feats)
            tmp_eval_loss = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids,
                                  input_mask, s2_input_mask, added_input_mask, label_ids)
            logits = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids, input_mask,
                           s2_input_mask, added_input_mask)
        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.to('cpu').numpy()
        true_label_list.append(label_ids)
        pred_label_list.append(logits)
        tmp_eval_accuracy = accuracy(logits, label_ids)
        progress_bar.set_description(f"Evaluating (tmp_eval_loss: {tmp_eval_loss.item():.4f})")
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy
        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples
    loss = tr_loss / nb_tr_steps if args.do_train else None
    true_label = np.concatenate(true_label_list)
    pred_outputs = np.concatenate(pred_label_list)
    precision, recall, F_score = macro_f1(true_label, pred_outputs)
    result = {'eval_loss': eval_loss,
              'eval_accuracy': eval_accuracy,
              'f_score': F_score,
              'global_step': global_step,
              'loss': loss}
    logger.info("***** Dev Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
    if eval_accuracy >= max_acc:
        # Save a trained model
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        encoder_to_save = encoder.module if hasattr(encoder,
                                                    'module') else encoder  # Only save the model it-self
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(encoder_to_save.state_dict(), output_encoder_file)
        max_acc = eval_accuracy

2023-11-19 06:37:01,113 - root - INFO - *************** Running training ***************
Epoch:   0%|          | 0/8 [00:00<?, ?it/s]2023-11-19 06:37:01,120 - root - INFO - ********** Epoch: 0 **********
2023-11-19 06:37:01,121 - root - INFO -   Num examples = 3562
2023-11-19 06:37:01,122 - root - INFO -   Batch size = 24
2023-11-19 06:37:01,123 - root - INFO -   Num steps = 1187
Iteration (loss: 0.8746,lr:0.0000299417): 100%|██████████| 148/148 [01:12<00:00,  2.03it/s]
2023-11-19 06:38:14,028 - root - INFO - ***** Running evaluation on Dev Set*****
2023-11-19 06:38:14,029 - root - INFO -   Num examples = 1176
2023-11-19 06:38:14,030 - root - INFO -   Batch size = 24
Evaluating (tmp_eval_loss: 0.5127): 100%|██████████| 49/49 [00:09<00:00,  5.25it/s]
2023-11-19 06:38:23,392 - root - INFO - ***** Dev Eval results *****
2023-11-19 06:38:23,392 - root - INFO -   eval_accuracy = 0.608843537414966
2023-11-19 06:38:23,393 - root - INFO -   eval_loss = 0.8121368301158048
2023-11-19 06:38:23,39

# 取最好的结果在测试集上验证

In [9]:
torch.cuda.empty_cache()  # 先把cuda清空了
model.load_state_dict(torch.load(output_model_file))#载入eval最好的结果
encoder.load_state_dict(torch.load(output_encoder_file))#载入eval最好的结果
eval_examples = processor.get_test_examples(args.data_dir)#获得测试集
eval_features = convert_mm_examples_to_features(
    eval_examples, label_list, args.max_seq_length, args.max_entity_length, tokenizer, args.crop_size,
    args.path_image)
logger.info("***** Running evaluation on Test Set*****")
logger.info("  Num examples = %d", len(eval_examples))
logger.info("  Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_added_input_mask = torch.tensor([f.added_input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_s2_input_ids = torch.tensor([f.s2_input_ids for f in eval_features], dtype=torch.long)
all_s2_input_mask = torch.tensor([f.s2_input_mask for f in eval_features], dtype=torch.long)
all_s2_segment_ids = torch.tensor([f.s2_segment_ids for f in eval_features], dtype=torch.long)
all_img_feats = torch.stack([f.img_feat for f in eval_features])
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, \
                          all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids,
                          all_img_feats, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
encoder.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
true_label_list = []
pred_label_list = []
for input_ids, input_mask, added_input_mask, segment_ids, s2_input_ids, s2_input_mask, s2_segment_ids, \
        img_feats, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    added_input_mask = added_input_mask.to(device)
    segment_ids = segment_ids.to(device)
    s2_input_ids = s2_input_ids.to(device)
    s2_input_mask = s2_input_mask.to(device)
    s2_segment_ids = s2_segment_ids.to(device)
    img_feats = img_feats.to(device)
    label_ids = label_ids.to(device)
    with torch.no_grad():
        img_att = encoder(img_feats)
        tmp_eval_loss = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids,
                              input_mask, s2_input_mask, added_input_mask, label_ids)
        logits = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids, input_mask,
                       s2_input_mask, added_input_mask)
    logits = logits.detach().cpu().numpy()
    label_ids = label_ids.to('cpu').numpy()
    true_label_list.append(label_ids)
    pred_label_list.append(logits)
    tmp_eval_accuracy = accuracy(logits, label_ids)
    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy
    nb_eval_examples += input_ids.size(0)
    nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
loss = tr_loss / nb_tr_steps if args.do_train else None
true_label = np.concatenate(true_label_list)
pred_outputs = np.concatenate(pred_label_list)
precision, recall, F_score = macro_f1(true_label, pred_outputs)
result = {'eval_loss': eval_loss,
          'eval_accuracy': eval_accuracy,
          'precision': precision,
          'recall': recall,
          'f_score': F_score,
          'global_step': global_step,
          'loss': loss}
pred_label = np.argmax(pred_outputs, axis=-1)
fout_p = open(os.path.join(args.output_dir, "pred.txt"), 'w')
fout_t = open(os.path.join(args.output_dir, "true.txt"), 'w')
for i in range(len(pred_label)):
    attstr = str(pred_label[i])
    fout_p.write(attstr + '\n')
for i in range(len(true_label)):
    attstr = str(true_label[i])
    fout_t.write(attstr + '\n')
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
    logger.info("***** Test Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
fout_p.close()
fout_t.close()
print(result)

2023-11-19 06:47:46,054 - root - INFO - *** Example ***
2023-11-19 06:47:46,059 - root - INFO - guid: test-1
2023-11-19 06:47:46,060 - root - INFO - tokens: [CLS] ▁# ▁$ t $ ▁performs ▁at ▁stagecoach ▁# ▁music festival ▁2016 [SEP] ▁sam hunt [SEP]
2023-11-19 06:47:46,061 - root - INFO - input_ids: 1 953 419 297 1814 8993 288 109755 953 755 47550 892 2 20782 39396 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:47:46,062 - root - INFO - input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:47:46,063 - root - INFO - segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2023-11-19 06:47:46,064 - root - INFO - label: 2 (id = 2)
2023-11-19 06:48:13,213 - root - INFO - ***** Running evaluation on Test Set*****
2023-11-19 06:48:13,215 - root - INFO

the number of problematic samples: 62
the max length of sentence a: 71 entity b: 10 total length: 67


Evaluating: 100%|██████████| 52/52 [00:09<00:00,  5.22it/s]
2023-11-19 06:48:25,551 - root - INFO - ***** Test Eval results *****
2023-11-19 06:48:25,552 - root - INFO -   eval_accuracy = 0.7293354943273906
2023-11-19 06:48:25,553 - root - INFO -   eval_loss = 1.2056335107638285
2023-11-19 06:48:25,554 - root - INFO -   f_score = 0.7168582627870886
2023-11-19 06:48:25,555 - root - INFO -   global_step = 1184
2023-11-19 06:48:25,557 - root - INFO -   loss = 0.03584640223935649
2023-11-19 06:48:25,558 - root - INFO -   precision = 0.7189185119520857
2023-11-19 06:48:25,559 - root - INFO -   recall = 0.7148950932602686


{'eval_loss': 1.2056335107638285, 'eval_accuracy': 0.7293354943273906, 'precision': 0.7189185119520857, 'recall': 0.7148950932602686, 'f_score': 0.7168582627870886, 'global_step': 1184, 'loss': 0.03584640223935649}
