In [1]:
import transformers as T
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from torchmetrics import SpearmanCorrCoef, Accuracy, F1Score

device = "cuda:0" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 有些中文的標點符號在tokenizer編碼以後會變成[UNK]，所以將其換成英文標點
token_replacement = [
    ["：" , ":"],
    ["，" , ","],
    ["“" , "\""],
    ["”" , "\""],
    ["？" , "?"],
    ["……" , "..."],
    ["！" , "!"]
]

In [3]:
# model = MultiLabelModel().to(device)
tokenizer = T.BertTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="./cache/")



In [4]:
class SemevalDataset(Dataset):
    def __init__(self, split="train") -> None:
        super().__init__()
        assert split in ["train", "validation", "test"]
        self.data = load_dataset(
            "sem_eval_2014_task_1", split=split, cache_dir="./cache/"
        ).to_list()

    def __getitem__(self, index):
        d = self.data[index]
        # 把中文標點替換掉
        for k in ["premise", "hypothesis"]:
            for tok in token_replacement:
                d[k] = d[k].replace(tok[0], tok[1])
        return d

    def __len__(self):
        return len(self.data)

data_sample = SemevalDataset(split="train").data[:3]
print(f"Dataset example: \n{data_sample[0]} \n{data_sample[1]} \n{data_sample[2]}")

Dataset example: 
{'sentence_pair_id': 1, 'premise': 'A group of kids is playing in a yard and an old man is standing in the background', 'hypothesis': 'A group of boys in a yard is playing and a man is standing in the background', 'relatedness_score': 4.5, 'entailment_judgment': 0} 
{'sentence_pair_id': 2, 'premise': 'A group of children is playing in the house and there is no man standing in the background', 'hypothesis': 'A group of kids is playing in a yard and an old man is standing in the background', 'relatedness_score': 3.200000047683716, 'entailment_judgment': 0} 
{'sentence_pair_id': 3, 'premise': 'The young boys are playing outdoors and the man is smiling nearby', 'hypothesis': 'The kids are playing outdoors near a man with a smile', 'relatedness_score': 4.699999809265137, 'entailment_judgment': 1}


In [5]:
SemevalDataset(split="train").data[0]

{'sentence_pair_id': 1,
 'premise': 'A group of kids is playing in a yard and an old man is standing in the background',
 'hypothesis': 'A group of boys in a yard is playing and a man is standing in the background',
 'relatedness_score': 4.5,
 'entailment_judgment': 0}

In [6]:
# Define the hyperparameters
lr = 2e-5
epochs = 3
train_batch_size = 8
validation_batch_size = 8

In [7]:
# TODO1: Create batched data for DataLoader
# `collate_fn` is a function that defines how the data batch should be packed.
# This function will be called in the DataLoader to pack the data batch.

import torch.utils
import torch.utils.data
import torch.utils.data.dataloader
import torch.utils.data.dataset


def collate_fn(batch):
    # TODO1-1: Implement the collate_fn function
    # Write your code here
    # The input parameter is a data batch (tuple), and this function packs it into tensors.
    # Use tokenizer to pack tokenize and pack the data and its corresponding labels.
    # Return the data batch and labels for each sub-task.

    #提取出每個batch的資料

    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    relatedness_scores = [item['relatedness_score'] for item in batch]
    entailment_judgements = [item['entailment_judgment'] for item in batch]

    #將資料轉換成模型可以讀取的格式 {input_ids, token_type_ids, attention_mask}
    # input_ids: 代表每個token的id
    # token_type_ids: 代表每個token屬於第一句或第二句
    # attention_mask: 代表哪些token是padding
    encoding = tokenizer(
        premises,
        hypotheses,
        padding=True,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True, # 返回token_type_ids 用來區分兩個句子，第一句全為0，第二句全為1
    ).to(device)

    # 將label轉換成tensor
    relatedness_scores = torch.tensor(relatedness_scores, dtype=torch.float).to(device)
    entailment_judgements = torch.tensor(entailment_judgements, dtype=torch.long).to(device)

    # 將資料打包成輸出的字典
    batch_output = {
        'input_text':{
            'input_ids': encoding['input_ids'],
            'token_type_ids': encoding['token_type_ids'],
            'attention_mask': encoding['attention_mask']
        },
        'label1': relatedness_scores,
        'label2': entailment_judgements
    }

    return batch_output
    
    

# TODO1-2: Define your DataLoader
dl_train = torch.utils.data.DataLoader(
    SemevalDataset(split="train"),
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn
) # Write your code here

dl_validation = torch.utils.data.DataLoader(
    SemevalDataset(split="validation"),
    batch_size=validation_batch_size,
    shuffle=False,
    collate_fn=collate_fn
) # Write your code here

dl_test = torch.utils.data.DataLoader(
    SemevalDataset(split="test"),
    batch_size=validation_batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

In [8]:
# TODO2: Construct your model
class MultiLabelModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Write your code here
        # Define what modules you will use in the model
        
        # 先做一個BertModel，然後加上relatedness_score的regression和entailment_judgement的分類層

        self.bert = T.BertModel.from_pretrained("bert-base-uncased", cache_dir="./cache/")
        self.regression = torch.nn.Linear(self.bert.config.hidden_size, 1)
        self.classification = torch.nn.Linear(self.bert.config.hidden_size, 3)

    def forward(self, **kwargs):
        # Write your code here
        # Forward pass

        input_text = kwargs['input_text']
        outputs = self.bert(
            input_ids=input_text['input_ids'],
            token_type_ids=input_text['token_type_ids'],
            attention_mask=input_text['attention_mask']
        )
        cls_output = outputs.last_hidden_state[:, 0, :]
        regression_output = self.regression(cls_output)
        classification_output = self.classification(cls_output)
        
        return regression_output, classification_output

model = MultiLabelModel().to(device)

In [9]:
# TODO3: Define your optimizer and loss function

# TODO3-1: Define your Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # Write your code here

# TODO3-2: Define your loss functions (you should have two)
# Write your code here
loss_fn_regression = torch.nn.MSELoss()
loss_fn_classification = torch.nn.CrossEntropyLoss()

# scoring functions
spc = SpearmanCorrCoef()
acc = Accuracy(task="multiclass", num_classes=3)
f1 = F1Score(task="multiclass", num_classes=3, average='macro')



In [10]:
for ep in range(epochs):
    # Training Loop
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model.train()

    for batch_idx, batch in enumerate(pbar):

        # 清空梯度
        optimizer.zero_grad()

        relatedness_scores = batch['label1'].to(device)
        entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }

        # 前向傳播與計算損失
        regression_output, classification_output = model(**batch)
        loss1 = loss_fn_regression(regression_output.squeeze(), relatedness_scores)
        loss2 = loss_fn_classification(classification_output, entailment_judgements)
        loss = loss1 + loss2

        # 反向傳播與更新參數
        loss.backward()
        optimizer.step()


        pbar.set_postfix({
            "Loss": loss.item()
        })

    # Validation Loop
    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()

    spc.to(device)
    acc.to(device)
    f1.to(device)

    spc.reset()
    acc.reset()
    f1.reset()

    with torch.no_grad():
        for batch_idx, batch in enumerate(pbar):
            relatedness_scores = batch['label1'].to(device)
            entailment_judgements = batch['label2'].to(device)
            batch['input_text'] = {
                k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
            }

            regression_output, classification_output = model(**batch)

            # 更新度量
            spc.update(regression_output.squeeze(), relatedness_scores)
            acc.update(classification_output, entailment_judgements)
            f1.update(classification_output, entailment_judgements)

    print(f"SpearmanCorr: {spc.compute().item()}\n"
          f"Accuracy: {acc.compute().item()}\nF1Score: {f1.compute().item()}\n"
          )

    # 儲存模型
    torch.save(model, f'./saved_models/ep{ep}.ckpt')


Training epoch [1/3]: 100%|██████████| 563/563 [00:15<00:00, 36.57it/s, Loss=1.78] 
Validation epoch [1/3]: 100%|██████████| 63/63 [00:00<00:00, 153.16it/s]


SpearmanCorr: 0.7558720111846924
Accuracy: 0.75
F1Score: 0.7661084532737732



Training epoch [2/3]: 100%|██████████| 563/563 [00:14<00:00, 39.70it/s, Loss=0.574] 
Validation epoch [2/3]: 100%|██████████| 63/63 [00:00<00:00, 165.35it/s]


SpearmanCorr: 0.8257865309715271
Accuracy: 0.8259999752044678
F1Score: 0.8317697048187256



Training epoch [3/3]: 100%|██████████| 563/563 [00:15<00:00, 36.61it/s, Loss=0.366] 
Validation epoch [3/3]: 100%|██████████| 63/63 [00:00<00:00, 145.59it/s]


SpearmanCorr: 0.8328427672386169
Accuracy: 0.8600000143051147
F1Score: 0.8594834804534912



For test set predictions, you can write perform evaluation simlar to #TODO5.

In [11]:
# Test
pbar = tqdm(dl_test)
pbar.set_description(f"test")
model.eval()

spc.to(device)
acc.to(device)
f1.to(device)

spc.reset()
acc.reset()
f1.reset()

with torch.no_grad():
    for batch_idx, batch in enumerate(pbar):
        relatedness_scores = batch['label1'].to(device)
        entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }

        regression_output, classification_output = model(**batch)

        # 更新度量
        spc.update(regression_output.squeeze(), relatedness_scores)
        acc.update(classification_output, entailment_judgements)
        f1.update(classification_output, entailment_judgements)

print(f"Testset\nSpearmanCorr: {spc.compute().item()}\n"
        f"Accuracy: {acc.compute().item()}\nF1Score: {f1.compute().item()}\n"
        )

test: 100%|██████████| 616/616 [00:03<00:00, 156.34it/s]


Testset
SpearmanCorr: 0.8333613872528076
Accuracy: 0.8749746084213257
F1Score: 0.8655876517295837



## 各別任務訓練

### Relatedness

In [26]:
class SingleLabelModelRelatedness(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Write your code here
        # Define what modules you will use in the model
        
        # 先做一個BertModel，然後加上relatedness_score的regression和entailment_judgement的分類層

        self.bert = T.BertModel.from_pretrained("bert-base-uncased", cache_dir="./cache/")
        self.regression = torch.nn.Linear(self.bert.config.hidden_size, 1)
        # self.classification = torch.nn.Linear(self.bert.config.hidden_size, 3)

    def forward(self, **kwargs):
        # Write your code here
        # Forward pass

        input_text = kwargs['input_text']
        outputs = self.bert(
            input_ids=input_text['input_ids'],
            token_type_ids=input_text['token_type_ids'],
            attention_mask=input_text['attention_mask']
        )
        cls_output = outputs.last_hidden_state[:, 0, :]
        regression_output = self.regression(cls_output)
        # classification_output = self.classification(cls_output)
        
        return regression_output

model_relatedness = SingleLabelModelRelatedness().to(device)

In [33]:
for ep in range(epochs):
    # Training Loop
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model_relatedness.train()
    for batch_idx, batch in enumerate(pbar):
        # 清空梯度
        optimizer.zero_grad()
        relatedness_scores = batch['label1'].to(device)
        # entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }

        # 前向傳播與計算損失
        regression_output = model_relatedness(**batch)
        loss1 = loss_fn_regression(regression_output.squeeze(), relatedness_scores)
        # loss2 = loss_fn_classification(classification_output, entailment_judgements)
        loss = loss1
        # 反向傳播與更新參數
        loss.backward()
        optimizer.step()
        pbar.set_postfix({
            "Loss": loss.item()
        })

    # Validation Loop
    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()

    spc.to(device)
    # acc.to(device)
    # f1.to(device)

    spc.reset()
    # acc.reset()
    # f1.reset()

    with torch.no_grad():
        for batch_idx, batch in enumerate(pbar):
            relatedness_scores = batch['label1'].to(device)
            # entailment_judgements = batch['label2'].to(device)
            batch['input_text'] = {
                k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
            }

            regression_output = model_relatedness(**batch)

            # 更新度量
            spc.update(regression_output.squeeze(), relatedness_scores)
            # acc.update(classification_output, entailment_judgements)
            # f1.update(classification_output, entailment_judgements)

    print(
        f"SpearmanCorr: {spc.compute().item()}\n"
          )


Training epoch [1/3]: 100%|██████████| 563/563 [00:10<00:00, 54.58it/s, Loss=14.2]
Validation epoch [1/3]: 100%|██████████| 63/63 [00:00<00:00, 165.23it/s]


SpearmanCorr: 0.026445582509040833



Training epoch [2/3]: 100%|██████████| 563/563 [00:10<00:00, 55.85it/s, Loss=18.2]
Validation epoch [2/3]: 100%|██████████| 63/63 [00:00<00:00, 165.95it/s]


SpearmanCorr: -0.0391974151134491



Training epoch [3/3]: 100%|██████████| 563/563 [00:09<00:00, 58.49it/s, Loss=14.7]
Validation epoch [3/3]: 100%|██████████| 63/63 [00:00<00:00, -123.41it/s]

SpearmanCorr: -0.08608773350715637






### Entailment

In [34]:
# TODO2: Construct your model
class SingleLabelModelEntailment(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Write your code here
        # Define what modules you will use in the model
        
        # 先做一個BertModel，然後加上relatedness_score的regression和entailment_judgement的分類層

        self.bert = T.BertModel.from_pretrained("bert-base-uncased", cache_dir="./cache/")
        # self.regression = torch.nn.Linear(self.bert.config.hidden_size, 1)
        self.classification = torch.nn.Linear(self.bert.config.hidden_size, 3)

    def forward(self, **kwargs):
        # Write your code here
        # Forward pass

        input_text = kwargs['input_text']
        outputs = self.bert(
            input_ids=input_text['input_ids'],
            token_type_ids=input_text['token_type_ids'],
            attention_mask=input_text['attention_mask']
        )
        cls_output = outputs.last_hidden_state[:, 0, :]
        # regression_output = self.regression(cls_output)
        classification_output = self.classification(cls_output)
        
        return classification_output

model_entailment = SingleLabelModelEntailment().to(device)

In [35]:
for ep in range(epochs):
    # Training Loop
    pbar = tqdm(dl_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    model_entailment.train()
    for batch_idx, batch in enumerate(pbar):
        # 清空梯度
        optimizer.zero_grad()
        # relatedness_scores = batch['label1'].to(device)
        entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }
        # 前向傳播與計算損失
        classification_output = model_entailment(**batch)
        # loss1 = loss_fn_regression(regression_output.squeeze(), relatedness_scores)
        loss2 = loss_fn_classification(classification_output, entailment_judgements)
        loss = loss2
        # 反向傳播與更新參數
        loss.backward()
        optimizer.step()
        pbar.set_postfix({
            "Loss": loss.item()
        })

    # Validation Loop
    pbar = tqdm(dl_validation)
    pbar.set_description(f"Validation epoch [{ep+1}/{epochs}]")
    model.eval()

    spc.to(device)
    # acc.to(device)
    # f1.to(device)

    # spc.reset()
    acc.reset()
    f1.reset()

    with torch.no_grad():
        for batch_idx, batch in enumerate(pbar):
            # relatedness_scores = batch['label1'].to(device)
            entailment_judgements = batch['label2'].to(device)
            batch['input_text'] = {
                k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
            }

            classification_output = model_entailment(**batch)

            # 更新度量
            # spc.update(regression_output.squeeze(), relatedness_scores)
            acc.update(classification_output, entailment_judgements)
            f1.update(classification_output, entailment_judgements)

    print(
        f"Accuracy: {acc.compute().item()}\nF1Score: {f1.compute().item()}\n"
          )


Training epoch [1/3]: 100%|██████████| 563/563 [00:10<00:00, 55.05it/s, Loss=0.92] 
Validation epoch [1/3]: 100%|██████████| 63/63 [00:00<00:00, 131.50it/s]


Accuracy: 0.550000011920929
F1Score: 0.23655913770198822



Training epoch [2/3]: 100%|██████████| 563/563 [00:09<00:00, 56.83it/s, Loss=1]    
Validation epoch [2/3]: 100%|██████████| 63/63 [00:00<00:00, 142.70it/s]


Accuracy: 0.5540000200271606
F1Score: 0.24638409912586212



Training epoch [3/3]: 100%|██████████| 563/563 [00:09<00:00, 56.54it/s, Loss=1.39] 
Validation epoch [3/3]: 100%|██████████| 63/63 [00:00<00:00, -140.48it/s]

Accuracy: 0.5600000023841858
F1Score: 0.24419522285461426






## 用額外資料集個別訓練entailment跟relatedness

In [12]:
entailment_ds = load_dataset("nyu-mll/glue", "mnli", split='train')
relatedness_ds = load_dataset("nyu-mll/glue", "stsb", split='train')

In [13]:
entailment_ds[0]['label']

1

In [14]:
entailment_ds = entailment_ds.select(range(5000))

In [15]:
# 將entailment的label對齊
entailment_ds = [
    {**example, 'label': 0 if example['label'] == 1 else 1 if example['label'] == 0 else example['label']} for example in entailment_ds
]
entailment_ds[0]['label']

0

In [16]:
entailment_ds[0]

{'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.',
 'hypothesis': 'Product and geography are what make cream skimming work. ',
 'label': 0,
 'idx': 0}

In [17]:
relatedness_ds[0]

{'sentence1': 'A plane is taking off.',
 'sentence2': 'An air plane is taking off.',
 'label': 5.0,
 'idx': 0}

In [18]:
class RelatednessDataset(Dataset):
    def __init__(self, dataset) -> None:
        super().__init__()
        self.data = dataset

    def __getitem__(self, index):
        item = self.data[index]
        return {
            'premise': item['sentence1'],
            'hypothesis': item['sentence2'],
            'relatedness_score': item['label']
        }

    def __len__(self):
        return len(self.data)
    
class EntailmentDataset(Dataset):
    def __init__(self, dataset) -> None:
        super().__init__()
        self.data = dataset

    def __getitem__(self, index):
        item = self.data[index]
        return {
            'premise': item['premise'],
            'hypothesis': item['hypothesis'],
            'entailment_judgment': item['label']
        }

    def __len__(self):
        return len(self.data)

def collate_fn_relatedness(batch):
    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    relatedness_scores = [item['relatedness_score'] for item in batch]

    inputs = tokenizer(premises, hypotheses, padding=True, truncation=True, return_tensors="pt")
    relatedness_scores = torch.tensor(relatedness_scores, dtype=torch.float)

    return {'input_text': inputs, 'relatedness_score': relatedness_scores}

def collate_fn_entailment(batch):
    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    entailment_judgments = [item['entailment_judgment'] for item in batch]

    inputs = tokenizer(premises, hypotheses, padding=True, truncation=True, return_tensors="pt")
    entailment_judgments = torch.tensor(entailment_judgments, dtype=torch.long)

    return {'input_text': inputs, 'entailment_judgment': entailment_judgments}

relatedness_dl = DataLoader(
    RelatednessDataset(relatedness_ds),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn_relatedness
)

entailment_dl = DataLoader(
    EntailmentDataset(entailment_ds),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn_entailment
)


In [36]:
model.train()
pbar = tqdm(relatedness_dl)
pbar.set_description(f"relatedness")

# 訓練 Relatedness 任務
for batch_idx, batch in enumerate(pbar):
    optimizer.zero_grad()
    
    batch['input_text'] = {
        k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
    }
    input_text = batch['input_text']
    relatedness_score = batch['relatedness_score'].to(device)

    # 模型前向傳遞
    regression_output, _ = model(input_text=input_text)

    # 計算回歸損失
    loss_regression = loss_fn_regression(regression_output.squeeze(-1), relatedness_score)
    loss_regression.backward()
    optimizer.step()

pbar = tqdm(entailment_dl)
pbar.set_description(f"entailment")

# 訓練 Entailment 任務
for batch_idx, batch in enumerate(pbar):
    optimizer.zero_grad()
    batch['input_text'] = {
        k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
                            }
    input_text = batch['input_text']
    entailment_judgment = batch['entailment_judgment'].to(device)

    # 模型前向傳遞
    _, classification_output = model(input_text=input_text)

    # 計算分類損失
    loss_classification = loss_fn_classification(classification_output, entailment_judgment)
    loss_classification.backward()
    optimizer.step()

relatedness: 100%|██████████| 180/180 [00:09<00:00, 18.81it/s]
entailment: 100%|██████████| 313/313 [00:20<00:00, 14.93it/s]


In [None]:
# Test (不使用混合訓練)
pbar = tqdm(dl_test)
pbar.set_description(f"test")
model.eval()

spc.to(device)
acc.to(device)
f1.to(device)

spc.reset()
acc.reset()
f1.reset()

with torch.no_grad():
    for batch_idx, batch in enumerate(pbar):
        relatedness_scores = batch['label1'].to(device)
        entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }

        regression_output, classification_output = model(**batch)

        # 更新度量
        spc.update(regression_output.squeeze(), relatedness_scores)
        acc.update(classification_output, entailment_judgements)
        f1.update(classification_output, entailment_judgements)

print(f"Testset\nSpearmanCorr: {spc.compute().item()}\n"
        f"Accuracy: {acc.compute().item()}\nF1Score: {f1.compute().item()}\n"
        )

test: 100%|██████████| 616/616 [00:03<00:00, 154.05it/s]

Testset
SpearmanCorr: 0.7849445343017578
Accuracy: 0.7513700127601624
F1Score: 0.738945484161377






In [23]:
from itertools import cycle
from tqdm import tqdm
import random

# 合併兩個 DataLoader
def mixed_dataloader(dl1, dl2):
    iter1 = cycle(dl1)  # 確保 relatedness_dl 不會提前結束
    iter2 = cycle(dl2)  # 確保 entailment_dl 不會提前結束
    while True:
        # 隨機選擇一個 DataLoader
        if random.random() < 0.5:
            yield next(iter1), "relatedness"
        else:
            yield next(iter2), "entailment"

# 混合訓練過程
num_epochs = 5
total_steps = min(len(relatedness_dl), len(entailment_dl)) * num_epochs  # 總進度條步數
mixed_dl = mixed_dataloader(relatedness_dl, entailment_dl)

model.train()

# 創建 tqdm 進度條
pbar = tqdm(total=total_steps, desc="Training", unit="batch")

for epoch in range(num_epochs):
    for step in range(len(relatedness_dl) + len(entailment_dl)):  # 每個 epoch 遍歷所有任務
        batch, task_type = next(mixed_dl)
        optimizer.zero_grad()

        # 將 batch 的 input_text 移到裝置上
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }
        input_text = batch['input_text']

        if task_type == "relatedness":
            # 訓練 Relatedness 任務
            relatedness_score = batch['relatedness_score'].to(device)
            regression_output, _ = model(input_text=input_text)
            loss_regression = loss_fn_regression(regression_output.squeeze(-1), relatedness_score)
            loss_regression.backward()
            optimizer.step()
        elif task_type == "entailment":
            # 訓練 Entailment 任務
            entailment_judgment = batch['entailment_judgment'].to(device)
            _, classification_output = model(input_text=input_text)
            loss_classification = loss_fn_classification(classification_output, entailment_judgment)
            loss_classification.backward()
            optimizer.step()

        # 更新 tqdm 進度條
        pbar.update(1)
        pbar.set_postfix({"epoch": epoch + 1, "task": task_type})


Training: 1685batch [01:40, 16.08batch/s, epoch=5, task=entailment]                      

In [None]:
# Test (使用混合訓練)
pbar = tqdm(dl_test)
pbar.set_description(f"test")
model.eval()

spc.to(device)
acc.to(device)
f1.to(device)

spc.reset()
acc.reset()
f1.reset()

with torch.no_grad():
    for batch_idx, batch in enumerate(pbar):
        relatedness_scores = batch['label1'].to(device)
        entailment_judgements = batch['label2'].to(device)
        batch['input_text'] = {
            k: v.to(device) for k, v in batch['input_text'].items() if isinstance(v, torch.Tensor)
        }

        regression_output, classification_output = model(**batch)

        # 更新度量
        spc.update(regression_output.squeeze(), relatedness_scores)
        acc.update(classification_output, entailment_judgements)
        f1.update(classification_output, entailment_judgements)

print(f"Testset\nSpearmanCorr: {spc.compute().item()}\n"
        f"Accuracy: {acc.compute().item()}\nF1Score: {f1.compute().item()}\n"
        )

test: 100%|██████████| 616/616 [00:03<00:00, 155.76it/s]

Testset
SpearmanCorr: 0.8087190985679626
Accuracy: 0.7432514429092407
F1Score: 0.7365155220031738




