In [1]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from transformers import BertTokenizer
from sklearn.utils import shuffle

from typing import *

from MEOW_Models import Pairwise_models, MT_models
from MEOW_Utils import Pairwise_utils, MT_utils

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 12
EPOCH_NUM = 10
INPUT_FILE_PATH = r'C:\Users\Administrator\codeblocks_workspace\MEOW\RTE_train.csv'
PRETRAINED_MODULE_NAME = 'bert-base-uncased'

In [2]:
# get dataframe
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODULE_NAME)
train_df = MT_utils.get_RTE_df(INPUT_FILE_PATH, tokenizer)
num_labels = len(train_df.value_counts('label').keys())
train_df.head()

Unnamed: 0,context1,context2,label,label_name,SEP_ind
0,He tried to persuade a local banker to loan mo...,Yunus started his own bank to loan money to th...,1,entailment,50
1,She would later return to the Contest in 2005 ...,Ms Siracura joined the Eurovision Song Contest.,0,not_entailment,96
2,Preliminary results from the National Assessme...,The National Assessment of Adult Literacy is a...,1,entailment,148
3,ARCHEOLOGISTS and forensic experts believe the...,The name of Queen Cleopatra's sister was Arsinöe.,1,entailment,140
4,Brad Pitt and Angelina Jolie are to let Getty ...,Brad Pitt and Angelina Jolie have a daughter.,1,entailment,104


In [3]:
#處理好model

bert_model = MT_models.BertWithoutEmbedding.from_pretrained(PRETRAINED_MODULE_NAME)
GETEMBEDDING_helper = MT_utils.get_bert_element(bertmodel=bert_model)
embedding_layer = GETEMBEDDING_helper.get_copy_embeddings_layer()

train_dataset = Pairwise_utils.Pairwise_dataset(train_df, tokenizer, num_labels)
train_loader = Pairwise_utils.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=Pairwise_utils.collate_batch)

Pairwise_model = Pairwise_models.Bert_pairwise(model=bert_model, embedding_layer=embedding_layer, device=DEVICE, num_labels=num_labels)
Pairwise_model.to(DEVICE)

optimizer = torch.optim.SGD(Pairwise_model.parameters(), lr=0.0001, momentum=0.9)
# optimizer = torch.optim.Adam(QA_model.parameters(), lr=0.001, betas=(0.88, 0.95), eps=1e-08)

H = {
    "train_loss": [],
    "train_acc": [],
    "val_loss":[],
    "val_acc": []
    }

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertWithoutEmbedding: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertWithoutEmbedding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertWithoutEmbedding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
#開始訓練

for epoch in range(EPOCH_NUM):
    print("the {:d} iter :".format(epoch+1))

    Pairwise_model.train()
    # train 
    training_loss = 0.0
    training_correct = 0.0

    for data in train_loader:
        input_ids, mask, token, label, SEPind = data
        
        # input_ids = input_ids.type(torch.IntTensor)
        # mask = mask.type(torch.IntTensor)
        # labels = labels.type(torch.LongTensor)

        input_ids = input_ids.to(DEVICE)
        mask = mask.to(DEVICE)
        token = token.to(DEVICE)
        label = label.to(DEVICE)
    
        outputs = Pairwise_model(input_ids, token=token, attention_mask=mask, label=label ,SEPind=SEPind)
        loss = outputs[0]
        prob = outputs[1]

        print(loss)
        # print(prob)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predict = torch.argmax(prob, dim=1)
        label = torch.argmax(label, dim=1)
        correct_num = (predict == label).type(torch.int).sum()
        
        training_loss += loss.item()
        training_correct += correct_num
        
        # print(correct_num)
        

    avg_loss = training_loss / len(train_loader)
    avg_acc = training_correct / len(train_dataset)

    H['train_loss'].append(avg_loss)
    H['train_acc'].append(avg_acc)
    print("Train loss: {:.6f}, Train accuracy {:.4f}".format(avg_loss, avg_acc))


the 1 iter :
tensor(1.1246, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.1243, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.1115, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.0946, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.1159, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.0771, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.0823, device='cuda:0', grad_fn=<DivBackward1>)
tensor(1.0280, device='cuda:0', grad_fn=<DivBackward1>)


KeyboardInterrupt: 

In [40]:
import torch

a = torch.tensor([1,2,3], dtype=torch.float, requires_grad=True)
b = torch.tensor([7,7,7], dtype=torch.float, requires_grad=True)

c = torch.stack([a,b])
loss = c.sum()
print(c)
loss.backward()

print(a.grad)

tensor([[1., 2., 3.],
        [7., 7., 7.]], grad_fn=<StackBackward0>)
tensor([1., 1., 1.])


In [None]:
import matplotlib
import matplotlib.pyplot as plt

# tensor to float
H['train_loss'] = [float(i) for i in H['train_loss']]
# H['train_acc'] = [float(i) for i in H['train_acc']]

# H['val_loss'] = [float(i) for i in H['val_loss']]
# H['val_acc'] = [float(i) for i in H['val_acc']]

# loss
plt.figure()
plt.title("Loss")
plt.xlabel("EPOCH")
plt.ylabel("Loss")
plt.plot(H["train_loss"], label="test_loss")
# plt.plot(H["val_loss"], label="test_loss")
plt.xticks(np.arange(10), range(1,11,1))
plt.show()

# accuracy
# plt.figure()
# plt.title("Test Accuracy")
# plt.xlabel("EPOCH")
# plt.ylabel("Accuracy")
# plt.plot(H["train_acc"], label="test_acc")
# # plt.plot(H["val_acc"], label="test_acc")
# plt.xticks(np.arange(6), range(1,7,1))
# plt.show()

