## original 

In [35]:
import pickle
import pandas as pd
import itertools
import copy as cp
from collections import Counter
import numpy as np
import nltk
from nltk import word_tokenize
from nltk.corpus import stopwords
from gensim.models import word2vec
from sklearn.linear_model import LogisticRegression
import os
import string
nltk.download('stopwords')
from tqdm.auto import tqdm
from datasets import load_dataset
dataset = load_dataset("poem_sentiment")

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\maenz\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Found cached dataset poem_sentiment (C:/Users/maenz/.cache/huggingface/datasets/poem_sentiment/default/1.0.0/4e44428256d42cdde0be6b3db1baa587195e91847adabf976e4f9454f6a82099)


  0%|                                                                                            | 0/3 [00:00<…

In [36]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'verse_text', 'label'],
        num_rows: 892
    })
    validation: Dataset({
        features: ['id', 'verse_text', 'label'],
        num_rows: 105
    })
    test: Dataset({
        features: ['id', 'verse_text', 'label'],
        num_rows: 104
    })
})

In [37]:
df=pd.DataFrame.from_dict(dataset["train"][0:700])

In [38]:
df.iloc[0:5]['verse_text']

0    with pale blue berries. in these peaceful shad...
1                  it flows so long as falls the rain,
2                   and that is why, the lonesome day,
3    when i peruse the conquered fame of heroes, an...
4              of inward strife for truth and liberty.
Name: verse_text, dtype: object

In [39]:
df_train=df

In [40]:
df_a=pd.DataFrame.from_dict(dataset["validation"])
df_b=pd.DataFrame.from_dict(dataset["test"])

In [41]:
df_test=pd.concat([df_a,df_b])

In [42]:
df_test=df_test.reset_index(drop=True)

In [43]:
df_test

Unnamed: 0,id,verse_text,label
0,0,"to water, cloudlike on the bush afar,",2
1,1,"shall yet be glad for him, and he shall bless",1
2,2,on its windy site uplifting gabled roof and pa...,2
3,3,(if haply the dark will of fate,0
4,4,"jehovah, jove, or lord!",2
...,...,...,...
204,99,shall live my highland mary.,2
205,100,now is past since last we met,2
206,101,begins to live,2
207,102,beneath the hazel bough;,2


In [44]:
# df_train.to_csv("poem_train")
# df_test.to_csv("poem_test")

In [45]:
df_train=pd.read_csv("poem_train")
df_text=pd.read_csv("poem_test")

In [46]:
X_train_original=df_train['verse_text']
X_valid_original=df_test['verse_text']
y_train_original=df_train['label']
y_valid_original=df_test['label']

In [47]:
import torch
assert torch.cuda.is_available()
device = torch.device("cuda")

In [48]:
from transformers import AutoTokenizer, AdamW,BertForSequenceClassification
tokenizer= AutoTokenizer.from_pretrained("bert-base-uncased")
X_train=tokenizer(list(X_train_original),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_original)
X_valid=tokenizer(list(X_valid_original),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_valid=list(y_valid_original)

In [49]:
class MyData(torch.utils.data.Dataset):
    def __init__(self, token):
        self.token = token
    def __getitem__(self, id):
        return {key: val[id] for key, val in self.token.items()}
    def __len__(self):
        return len(self.token.input_ids)

In [50]:
bat_size=128
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)
valid_dataloader = DataLoader(MyData(X_valid) , batch_size=bat_size,shuffle = False)

In [17]:
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [47]:
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)
    
        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.302384873231252
valid loss: 1.1306560635566711
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.1197502613067627
valid loss: 0.9944722354412079
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 1.041100839773814
valid loss: 0.9551325142383575
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 1.0168617963790894
valid loss: 0.9494473040103912
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 1.0143223305543263
valid loss: 0.929929256439209
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.9830000698566437
valid loss: 0.918822318315506
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.9489422539869944
valid loss: 0.8932844996452332
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.9001456002394358
valid loss: 0.8685807883739471
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.8305393060048422
valid loss: 0.8330943286418915
valid accuracy: 0.6794258373205742
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.7502157290776571
valid loss: 0.8106102049350739
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.6824499468008677
valid loss: 0.7934814393520355
valid accuracy: 0.7081339712918661
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.6090277433395386
valid loss: 0.823330283164978
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.5399445593357086
valid loss: 0.7983383536338806
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.49812538425127667
valid loss: 0.8722084760665894
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.4674835006395976
valid loss: 0.937417209148407
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.4313296526670456
valid loss: 0.8887096047401428
valid accuracy: 0.7272727272727273
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 16
train loss: 0.37259669105211896
valid loss: 0.9255449771881104
valid accuracy: 0.6842105263157895
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 17
train loss: 0.3255942364533742
valid loss: 0.9492078423500061
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 18
train loss: 0.2908825675646464
valid loss: 0.9627380669116974
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 19
train loss: 0.256252184510231
valid loss: 0.9885523617267609
valid accuracy: 0.6889952153110048
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 20
train loss: 0.23876200864712396
valid loss: 1.0239688456058502
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 21
train loss: 0.21285991867383322
valid loss: 1.0348764061927795
valid accuracy: 0.69377990430622
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 22
train loss: 0.20371799916028976
valid loss: 1.0806025266647339
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 23
train loss: 0.18118554105361304
valid loss: 1.0620994567871094
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 24
train loss: 0.1762292558948199
valid loss: 1.0613073706626892
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 25
train loss: 0.1482169988254706
valid loss: 1.096186339855194
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 26
train loss: 0.12271497771143913
valid loss: 1.0896676778793335
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 27
train loss: 0.11782162884871165
valid loss: 1.0872318744659424
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 28
train loss: 0.1052653081715107
valid loss: 1.1293259263038635
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 29
train loss: 0.09343202784657478
valid loss: 1.1425594687461853
valid accuracy: 0.7081339712918661
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 30
train loss: 0.08767187471191089
valid loss: 1.130857765674591
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 31
train loss: 0.08408034468690555
valid loss: 1.139640748500824
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 32
train loss: 0.07297161718209584
valid loss: 1.2277202606201172
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 33
train loss: 0.07061389523247878
valid loss: 1.2322866320610046
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 34
train loss: 0.06615132403870423
valid loss: 1.2207797169685364
valid accuracy: 0.7464114832535885
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 35
train loss: 0.06068310079475244
valid loss: 1.2312549948692322
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 36
train loss: 0.05416193356116613
valid loss: 1.247413992881775
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                  | 0/6 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [48]:
print(max(valid_accuracy_list))

0.7464114832535885


## baseline

In [49]:
import os
import openai
import time
from auto_tqdm import tqdm
def rephrase_by_chatgpt(text):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please rephrase the following sentence: " +text}])
    return completion.choices[0].message.content

In [50]:
rephrase_by_chatgpt(X_train_original[1])

'It continues to flow for as long as the rain falls.'

In [51]:
# x_new=[]
# y_new=[]
# for i in tqdm(range(len(X_train_original))):
#     count=0
#     while(count!=1):
#         try:
#             ans=rephrase_by_chatgpt(X_train_original[i])
#             x_new.append(ans)
#             y_new.append(y_train_original[i])
#             count=1  
#         except:
#             time.sleep(0.5)

In [52]:
# base_dict={"verse_text":x_new,"label":y_new}
# new_df=pd.DataFrame.from_dict(base_dict)
# new_df.to_csv("chatgpt_poem_base")
new_df=pd.read_csv("chatgpt_poem_base")
X_train_base=list(X_train_original)+list(new_df['verse_text'])
y_train_base=list(y_train_original)+list(new_df['label'])
X_train=tokenizer(list(X_train_base),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_base)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)

In [53]:
model= BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [54]:
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)

        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model_re")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model_re")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.14312021298842
valid loss: 0.9796668887138367
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.0268731930039146
valid loss: 0.929713249206543
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 1.0013412291353398
valid loss: 0.9047386646270752
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 0.9557891596447338
valid loss: 0.8863883018493652
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 0.8926624330607328
valid loss: 0.8593977987766266
valid accuracy: 0.6842105263157895
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.8146767182783647
valid loss: 0.8074006736278534
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.721032895825126
valid loss: 0.7877430617809296
valid accuracy: 0.7272727272727273
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.6548753543333574
valid loss: 0.7972751557826996
valid accuracy: 0.7559808612440191
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.5760138197378679
valid loss: 0.8543191850185394
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.519777777520093
valid loss: 0.934749960899353
valid accuracy: 0.7081339712918661
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.4896187348799272
valid loss: 0.9325160980224609
valid accuracy: 0.69377990430622
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.4732683138413863
valid loss: 0.9329074919223785
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.43721178444949066
valid loss: 0.7880803346633911
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.3893626088445837
valid loss: 0.984249472618103
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.36256465315818787
valid loss: 0.9371805489063263
valid accuracy: 0.7081339712918661
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.300536264072765
valid loss: 0.8141528069972992
valid accuracy: 0.7607655502392344
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 16
train loss: 0.23065392131155188
valid loss: 0.9128938019275665
valid accuracy: 0.7464114832535885
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 17
train loss: 0.194395743987777
valid loss: 0.9388521015644073
valid accuracy: 0.7559808612440191
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 18
train loss: 0.1562873531471599
valid loss: 1.0304115414619446
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 19
train loss: 0.1310361779548905
valid loss: 1.0253433883190155
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 20
train loss: 0.10582554137164896
valid loss: 1.0431211292743683
valid accuracy: 0.7272727272727273
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 21
train loss: 0.09070613370700316
valid loss: 1.0770219564437866
valid accuracy: 0.7272727272727273
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 22
train loss: 0.07716849581761794
valid loss: 1.1558043956756592
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 23
train loss: 0.06735729019750249
valid loss: 1.1525227427482605
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [55]:
print(max(valid_accuracy_list))

0.7607655502392344


### trial1

In [17]:
import os
import openai
import time
from auto_tqdm import tqdm
def rephrase_by_chatgpt2(text,label):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    if(label==0):
        tone="negative"
    elif(label==1):
        tone="positive"
    elif(label==2):
        tone="neutral"
    else:
        tone="mixed"
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please rephrase this sentence while retaining the"+ tone + "emotion: "+text}])
    return completion.choices[0].message.content

In [18]:
rephrase_by_chatgpt2(X_train_original[1],y_train_original[1])

'It continues to flow as long as the rain falls.'

In [19]:
x_new=[]
y_new=[]
for i in tqdm(range(len(X_train_original))):
    count=0
    while(count!=1):
        try:
            ans=rephrase_by_chatgpt2(X_train_original[i],y_train_original[i])
            x_new.append(ans)
            y_new.append(y_train_original[i])
            count=1  
        except:
            time.sleep(0.5)

  0%|                                                                                          | 0/700 [00:00<…

In [20]:
base_dict={"verse_text":x_new,"label":y_new}
new_df=pd.DataFrame.from_dict(base_dict)
new_df.to_csv("chatgpt_poem_trial1")
new_df=pd.read_csv("chatgpt_poem_trial1")
X_train_base=list(X_train_original)+list(new_df['verse_text'])
y_train_base=list(y_train_original)+list(new_df['label'])
X_train=tokenizer(list(X_train_base),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_base)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)

In [21]:
model= BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [22]:
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)

        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model_re")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model_re")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.3171073306690564
valid loss: 1.0668271780014038
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.0521593473174355
valid loss: 0.9570034742355347
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 0.9908972490917553
valid loss: 0.9123295843601227
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 0.9242495731873945
valid loss: 0.8706432282924652
valid accuracy: 0.6746411483253588
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 0.831388229673559
valid loss: 0.8097855746746063
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.7547060955654491
valid loss: 0.7879242897033691
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.6703039407730103
valid loss: 0.7851417660713196
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.5863475474444303
valid loss: 0.7871550917625427
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.5124301639470187
valid loss: 0.8089621067047119
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.43201467936689203
valid loss: 0.8065638542175293
valid accuracy: 0.7464114832535885
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.3759708377447995
valid loss: 0.8246808648109436
valid accuracy: 0.7799043062200957
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.3156012418595227
valid loss: 0.8697594106197357
valid accuracy: 0.7751196172248804
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.2850195521658117
valid loss: 0.8874067068099976
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.23346058482473547
valid loss: 0.9206729531288147
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.2170107906514948
valid loss: 0.9658450782299042
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.19722473350438205
valid loss: 1.0008128583431244
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [23]:
print(max(valid_accuracy_list))

0.7799043062200957


### back translation

In [29]:
import os
import openai
import time
from auto_tqdm import tqdm
def chatgpt_to_spain(text):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please translate this sentence to Spanish while retaining the emotion:"+text}])
    return completion.choices[0].message.content
def chatgpt_to_english(text):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please translate this sentence to English:"+text}])
    return completion.choices[0].message.content

In [30]:
# x_new=[]
# y_new=[]
# for i in tqdm(range(len(X_train_original))):
#     count=0
#     while(count!=1):
#         try:
#             temp=chatgpt_to_spain(X_train_original[i])
#             ans=chatgpt_to_english(temp)
#             x_new.append(ans)
#             y_new.append(y_train_original[i])
#             count=1  
#         except:
#             time.sleep(0.5)

In [31]:
# base_dict={"text":x_new,"label":y_new}
# new_df=pd.DataFrame.from_dict(base_dict)
# new_df.to_csv("poem_back_translate")
new_df=pd.read_csv("poem_back_translate")
X_train_base=list(X_train_original)+list(new_df['text'])
y_train_base=list(y_train_original)+list(new_df['label'])
X_train=tokenizer(list(X_train_base),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_base)
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)

In [32]:
model= BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [33]:
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)

        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model_re")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model_re")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.1549361835826526
valid loss: 1.0173993706703186
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.0301598202098499
valid loss: 0.9372580349445343
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 0.9770511497150768
valid loss: 0.903009444475174
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 0.9268957539038225
valid loss: 0.8662103712558746
valid accuracy: 0.69377990430622
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 0.8472975546663458
valid loss: 0.8317634165287018
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.7623242519118569
valid loss: 0.8185634016990662
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.6701455333016135
valid loss: 0.8068921566009521
valid accuracy: 0.7464114832535885
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.5945433080196381
valid loss: 0.8652442693710327
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.5152003060687672
valid loss: 0.8473688662052155
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.4492676935412667
valid loss: 0.8089805245399475
valid accuracy: 0.7607655502392344
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.43035573850978504
valid loss: 0.9546825885772705
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.41098006746985694
valid loss: 1.0963571667671204
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.485912176695737
valid loss: 1.1784954071044922
valid accuracy: 0.6842105263157895
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.5577012002468109
valid loss: 0.899105429649353
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.4090526916763999
valid loss: 0.9652886092662811
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.3581135733561082
valid loss: 0.8467852771282196
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 16
train loss: 0.28082134506919165
valid loss: 0.7694830894470215
valid accuracy: 0.7655502392344498
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 17
train loss: 0.21516680988398465
valid loss: 0.8227570652961731
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 18
train loss: 0.18697731603275647
valid loss: 0.9097400009632111
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 19
train loss: 0.1584176536310803
valid loss: 0.8790696263313293
valid accuracy: 0.7607655502392344
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 20
train loss: 0.135620308870619
valid loss: 0.9718517065048218
valid accuracy: 0.7559808612440191
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 21
train loss: 0.13828680190173062
valid loss: 0.9618074595928192
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 22
train loss: 0.10726982626048001
valid loss: 0.9804539084434509
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 23
train loss: 0.11052980273962021
valid loss: 1.0719385147094727
valid accuracy: 0.7464114832535885
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [34]:
print(max(valid_accuracy_list))

0.7655502392344498


### p2

In [18]:
import os
import openai
import time
from auto_tqdm import tqdm
def rephrase_by_chatgptp2(text):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please rephrase the following sentences from poem but keep their emotion " +text}])
    return completion.choices[0].message.content
x_new=[]
y_new=[]
for i in tqdm(range(len(X_train_original))):
    count=0
    while(count!=1):
        try:
            ans=rephrase_by_chatgptp2(X_train_original[i])
            x_new.append(ans)
            y_new.append(y_train_original[i])
            count=1  
        except:
            time.sleep(0.5)
base_dict={"verse_text":x_new,"label":y_new}
new_df=pd.DataFrame.from_dict(base_dict)
new_df.to_csv("chatgpt_poem_p2")
new_df=pd.read_csv("chatgpt_poem_p2")
X_train_base=list(X_train_original)+list(new_df['verse_text'])
y_train_base=list(y_train_original)+list(new_df['label'])
X_train=tokenizer(list(X_train_base),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_base)

  0%|                                                                                          | 0/700 [00:00<…

In [19]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)
model= BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)

        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model_re")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model_re")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.253892887722362
valid loss: 1.0453683137893677
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.0515467741272666
valid loss: 0.950000673532486
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 1.02396652915261
valid loss: 0.9179369807243347
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 0.9966316223144531
valid loss: 0.8925484120845795
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 0.9502051310105757
valid loss: 0.8862158954143524
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.8848507675257596
valid loss: 0.8568534851074219
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.8265145800330422
valid loss: 0.8962196111679077
valid accuracy: 0.69377990430622
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.8126004121520303
valid loss: 0.8868883848190308
valid accuracy: 0.6889952153110048
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.7462911660021002
valid loss: 0.8780803084373474
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.6565853736617349
valid loss: 0.964288592338562
valid accuracy: 0.7033492822966507
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.583088295026259
valid loss: 0.9758772850036621
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.5833671499382366
valid loss: 0.98283252120018
valid accuracy: 0.7129186602870813
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.5531266765160994
valid loss: 1.0765891075134277
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.6542212096127596
valid loss: 0.8122573792934418
valid accuracy: 0.7177033492822966
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.6046023775230754
valid loss: 0.7780191600322723
valid accuracy: 0.7081339712918661
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.48304173621264374
valid loss: 0.8940348923206329
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 16
train loss: 0.4355633827773007
valid loss: 0.8328381478786469
valid accuracy: 0.7607655502392344
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [20]:
print(max(valid_accuracy_list))

0.7607655502392344


## p3

In [28]:
import os
import openai
import time
from auto_tqdm import tqdm
def rephrase_by_chatgptp3(text):
    file1=open("api_key.txt",'r')
    api_key=file1.read()
    openai.api_key=api_key
    completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{'role': 'user', 'content': "Please rephrase the following sentences to around 12 length words without changing emotion: " + text}])
    return completion.choices[0].message.content
x_new=[]
y_new=[]
for i in tqdm(range(len(X_train_original))):
    count=0
    while(count!=1):
        try:
            ans=rephrase_by_chatgptp3(X_train_original[i])
            x_new.append(ans)
            y_new.append(y_train_original[i])
            count=1  
        except:
            time.sleep(0.5)
base_dict={"verse_text":x_new,"label":y_new}
new_df=pd.DataFrame.from_dict(base_dict)
new_df.to_csv("chatgpt_poem_p3")
new_df=pd.read_csv("chatgpt_poem_p3")
X_train_base=list(X_train_original)+list(new_df['verse_text'])
y_train_base=list(y_train_original)+list(new_df['label'])
X_train=tokenizer(list(X_train_base),truncation=True, padding=True, max_length=12,return_tensors='pt')
y_train=list(y_train_base)

  0%|                                                                                          | 0/700 [00:00<…

In [33]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(MyData(X_train) , batch_size=bat_size,shuffle = False)
model= BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels = 4,    
    output_attentions = False, 
    output_hidden_states = False,
)
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr = 1e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8
                  weight_decay=0.01,
                  correct_bias=True
                )
epochs = 50
train_lost_epoch=[]
valid_lost_epoch=[]
valid_accuracy_list=[]
for epoch in tqdm(range(epochs),desc='epoch'):
    model.train()
    loss_list=[]
    i=0
    for batch in tqdm(train_dataloader,desc='trainset',leave=False):
        model.zero_grad() 
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target=torch.tensor(y_train[i:i+bat_size]).to(device)
        i+=bat_size
        outputs = model(input_ids, 
                            token_type_ids, 
                            attention_mask, 
                            labels=target)

        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    train_lost_epoch.append(np.mean(loss_list))
    model.eval()
    loss_list=[]
    i=0
    num_correct=0
    for batch in tqdm(valid_dataloader,desc='validset',leave=False):
        with torch.no_grad():  
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            target=torch.tensor(y_valid[i:i+bat_size]).to(device)
            i+=bat_size
            outputs = model(input_ids, 
                                token_type_ids, 
                                attention_mask, 
                                labels=target)
            loss = outputs.loss
            logits = outputs.logits
            loss_list.append(loss.item())
            pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
            labels_flat = target.to('cpu').numpy().flatten()
            num_correct += np.sum(pred_flat == labels_flat)
    valid_accuracy=num_correct/len(y_valid)
    if  epoch==0:
        torch.save(model.state_dict(), "my_model_re")
    else:
        if  valid_accuracy>max(valid_accuracy_list):
            torch.save(model.state_dict(), "my_model_re")
    if epoch>=10:
        if valid_accuracy<min(valid_accuracy_list[-10:-1]):
            print("early stop")
            break
    valid_accuracy_list.append(valid_accuracy)
    valid_lost_epoch.append(np.mean(loss_list))
    print("epoch: "+str(epoch))
    print("train loss: "+str(train_lost_epoch[-1]))
    print("valid loss: "+str(valid_lost_epoch[-1]))
    print("valid accuracy: "+str(valid_accuracy))
    print("——————————————————————————————————————————————")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

epoch:   0%|                                                                                    | 0/50 [00:00<…

trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 0
train loss: 1.2254488847472451
valid loss: 1.0295893549919128
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 1
train loss: 1.0400849255648525
valid loss: 0.9459940493106842
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 2
train loss: 0.9994429078969088
valid loss: 0.9175613820552826
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 3
train loss: 0.9606176885691556
valid loss: 0.8888067901134491
valid accuracy: 0.6602870813397129
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 4
train loss: 0.8805646300315857
valid loss: 0.8561734259128571
valid accuracy: 0.6985645933014354
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 5
train loss: 0.7922210422429171
valid loss: 0.8228768408298492
valid accuracy: 0.7272727272727273
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 6
train loss: 0.7076939994638617
valid loss: 0.821491926908493
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 7
train loss: 0.6457946300506592
valid loss: 0.8034060597419739
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 8
train loss: 0.5720845786007968
valid loss: 0.7612061500549316
valid accuracy: 0.7416267942583732
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 9
train loss: 0.5247821807861328
valid loss: 0.7918435335159302
valid accuracy: 0.7511961722488039
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 10
train loss: 0.4662241041660309
valid loss: 0.8450749814510345
valid accuracy: 0.722488038277512
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 11
train loss: 0.41615560921755707
valid loss: 0.8759326040744781
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 12
train loss: 0.3571621667255055
valid loss: 0.9384317696094513
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 13
train loss: 0.33147365938533435
valid loss: 0.9410344660282135
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 14
train loss: 0.2950957945801995
valid loss: 1.0116201043128967
valid accuracy: 0.7320574162679426
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 15
train loss: 0.2839023606343703
valid loss: 0.9724260568618774
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

epoch: 16
train loss: 0.2589221379973672
valid loss: 0.9660221636295319
valid accuracy: 0.7368421052631579
——————————————————————————————————————————————


trainset:   0%|                                                                                 | 0/11 [00:00<…

validset:   0%|                                                                                  | 0/2 [00:00<…

early stop


In [34]:
print(max(valid_accuracy_list))

0.7511961722488039
