In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append('./utils')
sys.path.append('./utils/APIs')
from utils.common import save_model, write_to_file
import matplotlib.pyplot as plt
import torch
import timm
import argparse
from Config import config
from Trainer import Trainer
from Models.OTEModel import Model,MAMLModel

from torch.utils.data import DataLoader
from dataManagement.DatasetHelper import DatasetHelper
from dataManagement.DatasetLoader import DatasetLoader
from dataManagement.CustomDataset import CustomDataset

2025-06-20 15:48:11.427523: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-20 15:48:11.448802: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750405691.474082    2685 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750405691.481850    2685 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-20 15:48:11.508046: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Train the model.
do_train = True
# Run prediction on the test set.
do_test = True

# Hyper-parameters and paths, written onto the shared project `config` object.
config.num_words_to_keep = 3000
config.batch_size = 32
config.num_words_x_doc = 100
config.lr = 0.001
config.momentum = 0.9
config.wd = 1e-4
config.load_model_path = './save_models/CMAT/pytorch_model.bin'
config.fuse_model_type = 'CMAT'
config.epoch = 20
config.pre_train_epoch = 30

# Predict from text only.
text_only = False
# Predict from images only.
img_only = False

# Select a single modality only when exactly one flag is set. When both or
# neither are set, config.only is None. The previous version assigned
# config.only twice in a row, so the 'img' choice was always overwritten by
# the text_only expression and image-only mode could never be enabled.
if img_only and not text_only:
    config.only = 'img'
elif text_only and not img_only:
    config.only = 'text'
else:
    config.only = None

In [3]:
# Load the training CSV and build the train / validation DataLoaders used by
# the training cells below.
dataset_file_path = '../dataset/CornDataset/csv/train/train_data.csv'
data_loader = DatasetLoader()
# The second argument is False here and True in test(); presumably a
# "this is the test set" flag — TODO confirm against DatasetLoader.load_data.
data_loader.load_data(dataset_file_path,False)
train_data = data_loader.get_train_data()
val_data = data_loader.get_val_data()

# Preprocess labels, images and texts for both splits with one shared helper.
data_helper = DatasetHelper(config.num_words_to_keep)
train_y, val_y = data_helper.preprocess_labels(train_data, val_data)
train_i, val_i = data_helper.preprocess_images(train_data.get_images(), val_data.get_images())
train_t, val_t = data_helper.preprocess_texts(train_data.get_texts(), val_data.get_texts(), config.num_words_x_doc)

# Store the preprocessed labels, images and texts back on the data loader.
data_loader.set_train_data(train_y, train_i, train_t)
data_loader.set_val_data(val_y, val_i, val_t)

# Wrap each split in a CustomDataset (train and val).
train_custom_dataset = CustomDataset(data_loader.get_train_data())
val_custom_dataset = CustomDataset(data_loader.get_val_data())

# Shuffle only the training batches; validation keeps a fixed order.
train_loader = DataLoader(train_custom_dataset, config.batch_size, shuffle=True)
val_loader = DataLoader(val_custom_dataset, config.batch_size, shuffle=False)

Loading data...
----- [Loading]
Train/val split: 2968/743


In [4]:
# Initialization: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the backbone, wrap it in the MAML model and create the trainer.
# vocab_size is read from config.num_words_to_keep (the same value handed to
# DatasetHelper) instead of repeating the literal 3000, so changing the config
# cannot leave the two out of sync.
# NOTE(review): assumes the model vocabulary is meant to equal the number of
# words kept during text preprocessing — confirm against Model / DatasetHelper.
backbone = Model(num_classes=config.num_labels, vocab_size=config.num_words_to_keep, embedding_size=512)
model = MAMLModel(backbone, config.num_labels)
trainer = Trainer(config, model, device)

In [6]:
# PreTrain
def preTrain():
    """Run the pre-training phase for ``config.pre_train_epoch`` epochs.

    Each epoch prints a banner, calls ``trainer.preTrain`` on the global
    ``train_loader`` with the zero-based epoch index and the fixed task
    learning rate / inner-step count below, then prints a blank line.
    """
    # Settings forwarded to trainer.preTrain (inner_steps = 5 was tried before).
    inner_steps = 3
    task_lr = 0.01
    banner = '-' * 20
    for epoch_idx in range(config.pre_train_epoch):
        print(f"{banner} PreTrain Epoch {epoch_idx + 1} {banner}")
        trainer.preTrain(train_loader, epoch_idx, task_lr, inner_steps)
        print()

In [7]:
# Train
def train():
    """Train for ``config.epoch`` epochs, checkpointing the best model.

    Every epoch runs ``trainer.train`` on ``train_loader`` and then
    ``trainer.valid`` on ``val_loader``. Whenever the returned validation
    accuracy strictly exceeds the best seen so far, the global ``model`` is
    saved via ``save_model`` under ``./save_models`` using
    ``config.fuse_model_type`` as the model type.
    """
    checkpoint_dir = './save_models'
    banner = '-' * 20
    top_acc = 0
    for epoch_idx in range(config.epoch):
        print(f"{banner} Epoch {epoch_idx + 1} {banner}")
        trainer.train(train_loader, epoch_idx)
        val_acc = trainer.valid(val_loader, epoch_idx)
        if val_acc > top_acc:
            # New best validation accuracy: remember it and persist the model.
            top_acc = val_acc
            save_model(checkpoint_dir, config.fuse_model_type, model)
            print('Update best model!')
        print()

In [8]:
# Test
def test():
    test_file_path = '../dataset/CornDataset/csv/test/test_data.csv'
    
    data_loader = DatasetLoader()
    data_loader.load_data(test_file_path,True)
    test_data = data_loader.get_test_data()

    data_helper = DatasetHelper(config.num_words_to_keep)
    test_y = data_helper.preprocess_labels(test_data, None)
    test_i = data_helper.preprocess_images(test_data.get_images(), None)
    test_t = data_helper.preprocess_texts(test_data.get_texts(), None, config.num_words_x_doc)

    data_loader.set_test_data(test_y, test_i, test_t)

    # get CustomDataset
    test_custom_dataset = CustomDataset(data_loader.get_test_data())
    test_loader = DataLoader(test_custom_dataset, config.batch_size, shuffle=True)
    
    if config.load_model_path is not None:
        print("model load successfully")
        model.load_state_dict(torch.load(config.load_model_path))

    trainer.predict(test_loader)

In [None]:
# main
if __name__ == "__main__":
    # Training phase. The optional preTrain() warm-up is currently disabled.
    if do_train:
        train()

    # Test phase: needs either a model trained in this run or a checkpoint
    # path to load; otherwise print a hint instead of running test().
    if do_test:
        if do_train or config.load_model_path is not None:
            test()
        else:
            print('请输入已训练好模型的路径load_model_path或者选择添加do_train arg')

-------------------- Epoch 1 --------------------
epoch:0 - train loss: 1.024 and train acc: 0.628 total sample: 2968
              precision    recall  f1-score   support

           0     0.5218    0.8997    0.6605      1037
           1     0.9913    0.5493    0.7069       832
           2     0.1319    0.1000    0.1137       360
           3     0.9843    0.5940    0.7409       739

    accuracy                         0.6284      2968
   macro avg     0.6573    0.5358    0.5555      2968
weighted avg     0.7213    0.6284    0.6272      2968

epoch:0 - test loss: 0.425 and test acc: 0.855 total sample: 743
              precision    recall  f1-score   support

           0     0.6991    0.9916    0.8201       239
           1     0.9857    0.9718    0.9787       213
           2     0.0000    0.0000    0.0000       100
           3     0.9845    1.0000    0.9922       191

    accuracy                         0.8546       743
   macro avg     0.6673    0.7409    0.6978       743
we