## マルチタスク学習におけるtask_weightsを探索するための試行実験

In [1]:
import torch
import sys
import warnings

# Suppress every Python warning for the rest of the notebook session.
warnings.filterwarnings("ignore")
# Add the project directory to the import search path.
# NOTE(review): presumably this is what makes the `src.*` imports below resolve — confirm
# against the directory the notebook is launched from (the path is relative).
sys.path.append("multi_task_learning_with_t5")

from src.t5_model import T5Model
from src.dataset import SingleTaskDataset, MultiTaskDataset
import src.config as config



In [2]:
# Build the training and validation datasets by concatenating the datasets of
# the two tasks (MARC and AMCD).
#
# Only the MARC train/valid splits are capped via `data_limit`; the AMCD splits
# and both test splits are loaded without a limit.
# NOTE(review): the caps presumably keep MARC from dominating the task mix —
# confirm the intended ratio before changing them.
MARC_TRAIN_DATA_LIMIT = 5600
MARC_VALID_DATA_LIMIT = 465

train_dataset_MARC = MultiTaskDataset(
    task_type="MARC", split="train", data_limit=MARC_TRAIN_DATA_LIMIT
)
train_dataset_AMCD = MultiTaskDataset(task_type="AMCD", split="train")
train_dataset = torch.utils.data.ConcatDataset([train_dataset_MARC, train_dataset_AMCD])

valid_dataset_MARC = MultiTaskDataset(
    task_type="MARC", split="valid", data_limit=MARC_VALID_DATA_LIMIT
)
valid_dataset_AMCD = MultiTaskDataset(task_type="AMCD", split="valid")
valid_dataset = torch.utils.data.ConcatDataset([valid_dataset_MARC, valid_dataset_AMCD])

# The test sets are kept separate (not concatenated) so that each task can be
# evaluated on its own.
test_dataset_AMCD = MultiTaskDataset(task_type="AMCD", split="test")
test_dataset_MARC = MultiTaskDataset(task_type="MARC", split="test")

In [3]:
# Candidate task_weights, each paired with the experiment name handed to T5Model.
# Name and weights live in one tuple so they cannot drift out of alignment:
# two parallel lists consumed through zip() are silently truncated or mispaired
# when only one of them is edited.
experiments = [
    ("07_03", {"MARC": 0.7, "AMCD": 0.3}),
    ("05_05", {"MARC": 0.5, "AMCD": 0.5}),
    ("03_07", {"MARC": 0.3, "AMCD": 0.7}),
    ("10_10", {"MARC": 1.0, "AMCD": 1.0}),
]

# Original names, kept so that any cell still reading them continues to work.
li = [name for name, _ in experiments]
task_weights_candidates = [weights for _, weights in experiments]

# Test-set metrics per experiment name, so the runs can be compared after the
# loop instead of being read back from the printed output.
test_results = {}

for ex_name, task_weights in experiments:
    # A fresh model is created for every candidate, then trained on the
    # combined MARC+AMCD data.
    model = T5Model(multi_task=True, task_weights=task_weights, experiment_name=ex_name)
    model.train(train_dataset, valid_dataset)
    print(f'task_weights:{task_weights}')

    # Each task is evaluated separately on its own test set.
    AMCD_results = model.evaluate(test_dataset_AMCD)
    print("Results of AMCD:\n", AMCD_results)
    MARC_results = model.evaluate(test_dataset_MARC)
    print("Results of MARC:\n", MARC_results)

    test_results[ex_name] = {"AMCD": AMCD_results, "MARC": MARC_results}

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.3926, val_loss: 0.1633, val_acc: 0.8185


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0569, val_loss: 0.0994, val_acc: 0.8539


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0490, val_loss: 0.1164, val_acc: 0.8475


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0424, val_loss: 0.1276, val_acc: 0.8099
task_weights:{'MARC': 0.7, 'AMCD': 0.3}
Results of AMCD:
 {'loss': 0.073825260168827, 'accuracy': 0.9261241970021413, 'f1_score_macro': 0.8243188354437281, 'f1_score_micro': 0.9261241970021412, 'precision_macro': 0.7913776112558422, 'precision_micro': 0.9261241970021413, 'recall_macro': 0.8712087311058074, 'recall_micro': 0.9261241970021413}
Results of MARC:
 {'loss': 0.18124423635303974, 'accuracy': 0.6992, 'f1_score_macro': 0.30360781587619096, 'f1_score_micro': 0.6992, 'precision_macro': 0.375, 'precision_micro': 0.6992, 'recall_macro': 0.25612500000000005, 'recall_micro': 0.6992}


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.5976, val_loss: 0.1652, val_acc: 0.8206


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0640, val_loss: 0.1035, val_acc: 0.8410


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0505, val_loss: 0.1101, val_acc: 0.8260


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0400, val_loss: 0.1344, val_acc: 0.8271
task_weights:{'MARC': 0.5, 'AMCD': 0.5}
Results of AMCD:
 {'loss': 0.07550727211945873, 'accuracy': 0.917558886509636, 'f1_score_macro': 0.8203787835269377, 'f1_score_micro': 0.9175588865096361, 'precision_macro': 0.7744326879974288, 'precision_micro': 0.917558886509636, 'recall_macro': 0.9033288583929993, 'recall_micro': 0.917558886509636}
Results of MARC:
 {'loss': 0.17237553269769995, 'accuracy': 0.7396, 'f1_score_macro': 0.22191588663293255, 'f1_score_micro': 0.7396, 'precision_macro': 0.3, 'precision_micro': 0.7396, 'recall_macro': 0.19590000000000002, 'recall_micro': 0.7396}


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.3614, val_loss: 0.1745, val_acc: 0.7723


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0624, val_loss: 0.1070, val_acc: 0.8485


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0477, val_loss: 0.1101, val_acc: 0.8453


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0407, val_loss: 0.1077, val_acc: 0.8475
task_weights:{'MARC': 0.3, 'AMCD': 0.7}
Results of AMCD:
 {'loss': 0.04811604023696138, 'accuracy': 0.949678800856531, 'f1_score_macro': 0.8642084912226432, 'f1_score_micro': 0.949678800856531, 'precision_macro': 0.8625491137962039, 'precision_micro': 0.949678800856531, 'recall_macro': 0.8658885242641209, 'recall_micro': 0.949678800856531}
Results of MARC:
 {'loss': 0.15860233374163507, 'accuracy': 0.7318, 'f1_score_macro': 0.2362223480916164, 'f1_score_micro': 0.7317999999999999, 'precision_macro': 0.3, 'precision_micro': 0.7318, 'recall_macro': 0.20265, 'recall_micro': 0.7318}


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.6900, val_loss: 0.1467, val_acc: 0.8034


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.1346, val_loss: 0.1077, val_acc: 0.8367


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.1116, val_loss: 0.1076, val_acc: 0.8324


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0996, val_loss: 0.1229, val_acc: 0.8153
task_weights:{'MARC': 1.0, 'AMCD': 1.0}
Results of AMCD:
 {'loss': 0.049559377048276045, 'accuracy': 0.9486081370449678, 'f1_score_macro': 0.8631969534225175, 'f1_score_micro': 0.9486081370449678, 'precision_macro': 0.856810551558753, 'precision_micro': 0.9486081370449678, 'recall_macro': 0.8699035401750199, 'recall_micro': 0.9486081370449678}
Results of MARC:
 {'loss': 0.19126679680012165, 'accuracy': 0.6944, 'f1_score_macro': 0.2341568277538667, 'f1_score_micro': 0.6944, 'precision_macro': 0.3333333333333333, 'precision_micro': 0.6944, 'recall_macro': 0.20344444444444443, 'recall_micro': 0.6944}


## 追加実験：task_weightsの候補を追加（MARC:AMCD = 0.4:0.6、0.2:0.8）

In [3]:
# Additional task_weights candidates, each paired with the experiment name
# handed to T5Model. Name and weights live in one tuple so they cannot drift
# out of alignment: two parallel lists consumed through zip() are silently
# truncated or mispaired when only one of them is edited.
additional_experiments = [
    ("04_06", {"MARC": 0.4, "AMCD": 0.6}),
    ("02_08", {"MARC": 0.2, "AMCD": 0.8}),
]

# Original names, kept so that any cell still reading them continues to work.
li = [name for name, _ in additional_experiments]
task_weights_candidates = [weights for _, weights in additional_experiments]

# Test-set metrics per experiment name for this additional sweep, so the runs
# can be compared after the loop instead of being read back from the printed output.
additional_test_results = {}

for ex_name, task_weights in additional_experiments:
    # A fresh model is created for every candidate, then trained on the
    # combined MARC+AMCD data.
    model = T5Model(multi_task=True, task_weights=task_weights, experiment_name=ex_name)
    model.train(train_dataset, valid_dataset)
    print(f'task_weights:{task_weights}')

    # Each task is evaluated separately on its own test set.
    AMCD_results = model.evaluate(test_dataset_AMCD)
    print("Results of AMCD:\n", AMCD_results)
    MARC_results = model.evaluate(test_dataset_MARC)
    print("Results of MARC:\n", MARC_results)

    additional_test_results[ex_name] = {"AMCD": AMCD_results, "MARC": MARC_results}

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.5182, val_loss: 0.1145, val_acc: 0.8314


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0560, val_loss: 0.1275, val_acc: 0.8163


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0463, val_loss: 0.1071, val_acc: 0.8389


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0377, val_loss: 0.1170, val_acc: 0.8324
task_weights:{'MARC': 0.4, 'AMCD': 0.6}
Results of AMCD:
 {'loss': 0.06163115676158629, 'accuracy': 0.9464668094218416, 'f1_score_macro': 0.8478093673823212, 'f1_score_micro': 0.9464668094218416, 'precision_macro': 0.8660322509872751, 'precision_micro': 0.9464668094218416, 'recall_macro': 0.8318168257756563, 'recall_micro': 0.9464668094218416}
Results of MARC:
 {'loss': 0.18305685287546367, 'accuracy': 0.7266, 'f1_score_macro': 0.2631668998124179, 'f1_score_micro': 0.7265999999999999, 'precision_macro': 0.3333333333333333, 'precision_micro': 0.7266, 'recall_macro': 0.22494444444444447, 'recall_micro': 0.7266}


Epoch 1:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [1/4], Loss: 0.8350, val_loss: 0.1894, val_acc: 0.6681


Epoch 2:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [2/4], Loss: 0.0787, val_loss: 0.1206, val_acc: 0.8238


Epoch 3:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [3/4], Loss: 0.0562, val_loss: 0.1049, val_acc: 0.8421


Epoch 4:   0%|          | 0/1400 [00:00<?, ?it/s]

Epoch [4/4], Loss: 0.0480, val_loss: 0.1151, val_acc: 0.8593
task_weights:{'MARC': 0.2, 'AMCD': 0.8}
Results of AMCD:
 {'loss': 0.0735309266178813, 'accuracy': 0.9336188436830836, 'f1_score_macro': 0.772517284726587, 'f1_score_micro': 0.9336188436830836, 'precision_macro': 0.8834161869876156, 'precision_micro': 0.9336188436830836, 'recall_macro': 0.7185884049323787, 'recall_micro': 0.9336188436830836}
Results of MARC:
 {'loss': 0.16995917913541198, 'accuracy': 0.7402, 'f1_score_macro': 0.25439875367045495, 'f1_score_micro': 0.7402, 'precision_macro': 0.3333333333333333, 'precision_micro': 0.7402, 'recall_macro': 0.2217222222222222, 'recall_micro': 0.7402}


: 