In [9]:
import pandas as pd
import pickle
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AdamW
from operator import itemgetter
from sklearn.model_selection import StratifiedKFold

import os
# Expose only physical GPU 0 to this process, then make it the current CUDA device.
# NOTE(review): CUDA_VISIBLE_DEVICES presumably only takes effect if CUDA has not been
# initialised yet in this kernel — re-running this cell after CUDA use changes nothing.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})
torch.cuda.set_device(0)

In [10]:
# Label vocabulary produced earlier in the pipeline: label2id maps a label name (as found
# in the comma-separated `cpc_ids` column) to its column index; id2label is indexed by
# that id and its length sets the classifier's output width.
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
with open('../temp_results/mini_label2id_dict.pkl', 'rb') as fh:
    label2id = pickle.load(fh)
with open('../temp_results/mini_id2label_lst.pkl', 'rb') as fh:
    id2label = pickle.load(fh)

In [11]:
# Raw Patent14K splits; each row is expected to carry `text` and `cpc_ids` columns.
train_data, test_data = (pd.read_csv(path) for path in ('../data/Patent14K/train.csv',
                                                        '../data/Patent14K/test.csv'))

In [12]:
from transformers import AutoModelForMaskedLM,AutoTokenizer,BertConfig
from collections import OrderedDict
save_path = './patent_bert_simcse/simcsepatent_bs64.pth'
# Load onto the CPU; load_state_dict later copies the tensors into the model's own parameters.
loaded_dict = torch.load(save_path, map_location='cpu')
# The checkpoint keys carry a leading "module." prefix (presumably because it was saved
# from an nn.DataParallel / DDP wrapper), which the plain model does not expect.
# Strip the prefix only where it is actually present, so keys without it are kept intact
# instead of losing their first 7 characters.
prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)  # same values, renamed keys
    for k, v in loaded_dict.items()
)
    
# Backbone checkpoint on the Hugging Face hub; tokenizer and config are loaded from it.
model_path = 'anferico/bert-for-patents'
tokenizer = AutoTokenizer.from_pretrained(model_path)
Config = BertConfig.from_pretrained(model_path)
# Override both backbone dropout rates to 0.1.
for dropout_attr in ('attention_probs_dropout_prob', 'hidden_dropout_prob'):
    setattr(Config, dropout_attr, 0.1)
# Sentence-embedding strategy used by NeuralNetwork ('pooler' = mean of last-layer token states).
output_way = 'pooler'

class NeuralNetwork(nn.Module):
    """SimCSE-style sentence encoder around a masked-LM BERT backbone.

    Args:
        model_path: hub id / local path passed to ``AutoModelForMaskedLM.from_pretrained``
            (together with the module-level ``Config``).
        output_way: how to reduce the last hidden layer to one vector per sequence:
            ``'cls'`` takes the first token's state; ``'pooler'`` takes the mean over
            all token positions (despite the name, not BERT's pooler output).

    Raises:
        ValueError: if ``output_way`` is not ``'cls'`` or ``'pooler'``.

    NOTE(review): the 'pooler' mean does not use ``attention_mask``, so padding positions
    are averaged in. Presumably the SimCSE checkpoint was trained with this exact pooling —
    confirm before switching to a masked mean.
    """

    _OUTPUT_WAYS = ('cls', 'pooler')

    def __init__(self, model_path, output_way):
        super().__init__()
        # Fail fast, before downloading/loading the backbone weights.
        if output_way not in self._OUTPUT_WAYS:
            raise ValueError(
                f"output_way must be one of {self._OUTPUT_WAYS}, got {output_way!r}")
        self.bert = AutoModelForMaskedLM.from_pretrained(model_path, config=Config)
        self.output_way = output_way

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one embedding per sequence taken from the last hidden layer."""
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        if self.output_way == 'cls':
            return last_hidden[:, 0]
        if self.output_way == 'pooler':
            return last_hidden.mean(dim=1)
        # Only reachable if output_way was mutated after construction.
        raise ValueError(f"unsupported output_way {self.output_way!r}")
    
# Build the SimCSE encoder and load the fine-tuned weights (strict key matching; the
# returned key report is this cell's displayed output).
model = NeuralNetwork(model_path=model_path, output_way=output_way)
model.load_state_dict(new_state_dict, strict=True)

Some weights of the model checkpoint at anferico/bert-for-patents were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

# Data Loader

In [13]:
def str2id_lst(str_label):
    """Turn a comma-separated string of label names into a list of label ids.

    Every name is looked up in the module-level ``label2id`` mapping, in order;
    a name missing from the mapping raises ``KeyError``.
    """
    return [label2id[name] for name in str_label.split(',')]

class PatentDataset(Dataset):
    """Map-style dataset over a patent DataFrame.

    Args:
        df: frame with a ``text`` column and, when ``labeled`` is True, a
            ``cpc_ids`` column of comma-separated label names (see ``str2id_lst``).
        labeled: whether to parse and return labels.

    Each item is ``(text, label_ids)``; ``label_ids`` is ``None`` when unlabeled.
    """

    def __init__(self, df, labeled=True):
        self.df = df
        self.labeled = labeled

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        # Positional lookup, so fold subsets with a non-contiguous index work.
        row = self.df.iloc[idx]
        # Drop the first 3 characters of the text.
        # NOTE(review): presumably a fixed prefix in the raw data — confirm against the CSV.
        text = row['text'][3:]
        if not self.labeled:
            # Do not touch `cpc_ids`: unlabeled frames may not carry usable labels.
            return text, None
        return text, str2id_lst(row['cpc_ids'])
        

In [14]:
# The test split keeps its labels so a per-fold test loss can be reported.
test_dataset = PatentDataset(df=test_data, labeled=True)

In [15]:
def collate_fn(data):
    """Collate ``(text, label_ids)`` pairs into model-ready tensors.

    Args:
        data: list of ``(text, label_ids)`` tuples as produced by
            ``PatentDataset.__getitem__``; ``label_ids`` is a list of column
            indices into ``id2label``, or ``None`` for an unlabeled sample.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, batch_label)``. The first
        three come from the module-level ``tokenizer`` (truncated and padded to
        exactly 128 tokens). ``batch_label`` is a float32 multi-hot tensor of
        shape ``(len(data), len(id2label))``; unlabeled samples get an all-zero row.
    """
    sents = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                          truncation=True,
                                          padding='max_length',
                                          max_length=128,
                                          return_tensors='pt')
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']

    batch_label = np.zeros((len(labels), len(id2label)))
    for i, label_ids in enumerate(labels):
        # Skip missing labels: indexing with None would broadcast and set the whole row to 1.
        if label_ids:
            batch_label[i, label_ids] = 1

    batch_label = torch.tensor(batch_label, dtype=torch.float32)

    return input_ids, attention_mask, token_type_ids, batch_label
    

In [16]:
# Fixed, unshuffled order so each fold's predictions line up row-by-row with test_data.
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Define Model

In [17]:
class PatentClsModel(nn.Module):
    """Multi-label classification head on top of a sentence encoder.

    Args:
        bert_model: module called as ``bert_model(input_ids=..., attention_mask=...,
            token_type_ids=...)`` that returns one feature vector per sample, i.e. a
            ``(batch, hidden_size)`` tensor (e.g. ``NeuralNetwork``).
        backbone_fixed: if True, the encoder runs under ``torch.no_grad()``, so only
            the head receives gradients.
        hidden_size: width of the encoder output; defaults to 1024 (the width that
            was previously hard-coded).
        num_labels: number of output labels; defaults to ``len(id2label)``.

    ``forward`` returns per-label sigmoid probabilities of shape
    ``(batch, num_labels)``, suitable for ``nn.BCELoss``.
    """

    def __init__(self, bert_model, backbone_fixed=True, hidden_size=1024, num_labels=None):
        super().__init__()
        if num_labels is None:
            num_labels = len(id2label)
        self.fc = nn.Sequential(nn.BatchNorm1d(hidden_size),
                                nn.Dropout(0.5),
                                nn.Linear(hidden_size, 768),
                                nn.ReLU(),
                                nn.BatchNorm1d(768),
                                nn.Dropout(0.5),
                                nn.Linear(768, num_labels))

        self.bert_model = bert_model
        self.sig = nn.Sigmoid()
        self.backbone_fixed = backbone_fixed

        # Head init: BatchNorm to identity affine, Linear to Kaiming-normal with zero bias.
        for module in self.fc:
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    # Weight-normalised layer: init magnitude (g) and direction (v) separately.
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    assert module.weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return sigmoid label probabilities of shape ``(batch, num_labels)``."""
        encoder_kwargs = dict(input_ids=input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        if self.backbone_fixed:
            # Frozen encoder: no autograd graph is built. The encoder is a submodule, so
            # it still follows this module's train()/eval() mode (its dropout is active
            # during training).
            with torch.no_grad():
                x = self.bert_model(**encoder_kwargs)
        else:
            x = self.bert_model(**encoder_kwargs)

        return self.sig(self.fc(x))
        

# Training

In [18]:
# Cross-validation setup: 30 epochs per fold, one test-set prediction tensor collected
# per fold, and 5 stratified folds without shuffling (splits are deterministic).
# NOTE(review): stratification is on the raw `cpc_ids` string, so each distinct label
# combination is its own stratum — rare combinations may trigger sklearn warnings.
total_epochs = 30
test_predict_lst = []
kfold = StratifiedKFold(n_splits=5, shuffle=False)

In [19]:
from tqdm import tqdm
CKPT_DIR = 'ckpt/001'
os.makedirs(CKPT_DIR, exist_ok=True)


def _to_cuda(batch):
    """Move every tensor of a collated (input_ids, attention_mask, token_type_ids, labels) batch to the GPU."""
    return tuple(t.cuda() for t in batch)


def _evaluate(net, loader, criterion, keep_predictions=False):
    """Run `net` over `loader` in eval mode without gradients.

    Returns ``(mean_loss, predictions)``. ``mean_loss`` is weighted per sample, so a
    smaller final batch is not over-weighted. ``predictions`` is the CPU concatenation
    of all outputs when ``keep_predictions`` is True, otherwise ``None``.
    """
    net.eval()
    total_loss, n_samples, preds = 0.0, 0, []
    with torch.no_grad():
        for batch in tqdm(loader):
            input_ids, attention_mask, token_type_ids, batch_label = _to_cuda(batch)
            prediction = net(input_ids, attention_mask, token_type_ids)
            batch_size = batch_label.shape[0]
            total_loss += criterion(prediction, batch_label).item() * batch_size
            n_samples += batch_size
            if keep_predictions:
                preds.append(prediction.cpu())
    if n_samples == 0:
        raise ValueError('evaluation loader yielded no batches')
    return total_loss / n_samples, (torch.cat(preds, dim=0) if keep_predictions else None)


# Start from an empty list so re-running this cell does not append a second set of folds.
test_predict_lst = []
for fold_idx, (train_index, valid_index) in enumerate(
        kfold.split(train_data, train_data['cpc_ids']), start=1):

    print('*' * 20)
    print(f'Fold{fold_idx}')
    print('*' * 20)
    # NOTE: the saved state_dict includes the (frozen) backbone weights as well as the head.
    ckpt_path = os.path.join(CKPT_DIR, 'best_model_mini_{}.pth'.format(fold_idx))

    train_dataset = PatentDataset(train_data.iloc[train_index])
    val_dataset = PatentDataset(train_data.iloc[valid_index])

    # drop_last avoids a size-1 training batch, which BatchNorm1d rejects in train mode.
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=collate_fn,
                                  batch_size=4,
                                  shuffle=True,
                                  drop_last=True)
    # Validation must see every sample: no shuffling, no dropped tail batch.
    val_dataloader = DataLoader(val_dataset,
                                collate_fn=collate_fn,
                                batch_size=4,
                                shuffle=False,
                                drop_last=False)

    patentModel = PatentClsModel(model, backbone_fixed=True).cuda()
    loss_func = nn.BCELoss()
    optimizer = AdamW(patentModel.parameters(), lr=5e-4)

    print('Dataloader Success---------------------')

    best_val_loss = float('inf')
    for epoch in range(total_epochs):
        if epoch % 5 == 0:
            print('|', ">" * epoch, " " * (total_epochs - epoch), '|')

        patentModel.train()
        for batch in tqdm(train_dataloader):
            input_ids, attention_mask, token_type_ids, batch_label = _to_cuda(batch)
            prediction = patentModel(input_ids, attention_mask, token_type_ids)
            loss = loss_func(prediction, batch_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss, _ = _evaluate(patentModel, val_dataloader, loss_func)

        if epoch % 10 == 0:
            print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(patentModel.state_dict(), ckpt_path)
            print('Best val loss found: ', best_val_loss)

    print('This fold, the best val loss is: ', best_val_loss)

    # Reload the best checkpoint of this fold and predict on the test set.
    patentModel = PatentClsModel(model, backbone_fixed=True).cuda()
    patentModel.load_state_dict(torch.load(ckpt_path))
    test_loss, test_predict = _evaluate(patentModel, test_dataloader, loss_func,
                                        keep_predictions=True)
    print('This fold, the test loss is: ', test_loss)

    # Kept on the CPU so five folds of predictions do not occupy GPU memory.
    test_predict_lst.append(test_predict)



********************
Fold1
********************




Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [01:36<00:00, 29.01it/s]
100%|██████████| 700/700 [00:20<00:00, 33.91it/s]


Epoch 0, val_loss 0.0147
Best val loss found:  0.014712733714176076


100%|██████████| 2800/2800 [01:35<00:00, 29.19it/s]
100%|██████████| 700/700 [00:20<00:00, 34.46it/s]


Best val loss found:  0.012028435957950672


100%|██████████| 2800/2800 [01:34<00:00, 29.58it/s]
100%|██████████| 700/700 [00:20<00:00, 33.90it/s]


Best val loss found:  0.010937071276961693


100%|██████████| 2800/2800 [01:34<00:00, 29.60it/s]
100%|██████████| 700/700 [00:20<00:00, 34.45it/s]


Best val loss found:  0.010581607057247311


100%|██████████| 2800/2800 [01:35<00:00, 29.38it/s]
100%|██████████| 700/700 [00:20<00:00, 33.95it/s]


Best val loss found:  0.010370161149185151
| >>>>>                           |


100%|██████████| 2800/2800 [01:35<00:00, 29.27it/s]
100%|██████████| 700/700 [00:19<00:00, 35.71it/s]


Best val loss found:  0.010179420153477363


100%|██████████| 2800/2800 [01:34<00:00, 29.55it/s]
100%|██████████| 700/700 [00:20<00:00, 33.67it/s]


Best val loss found:  0.010105474773029397


100%|██████████| 2800/2800 [01:34<00:00, 29.69it/s]
100%|██████████| 700/700 [00:20<00:00, 34.42it/s]


Best val loss found:  0.00997928097627924


100%|██████████| 2800/2800 [01:34<00:00, 29.71it/s]
100%|██████████| 700/700 [00:20<00:00, 34.33it/s]


Best val loss found:  0.009918309169921227


100%|██████████| 2800/2800 [01:34<00:00, 29.60it/s]
100%|██████████| 700/700 [00:20<00:00, 33.89it/s]


Best val loss found:  0.00977915283708301
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [01:35<00:00, 29.30it/s]
100%|██████████| 700/700 [00:20<00:00, 34.91it/s]


Epoch 10, val_loss 0.0097
Best val loss found:  0.009705506754606696


100%|██████████| 2800/2800 [01:35<00:00, 29.32it/s]
100%|██████████| 700/700 [00:20<00:00, 34.26it/s]
100%|██████████| 2800/2800 [01:34<00:00, 29.52it/s]
100%|██████████| 700/700 [00:19<00:00, 35.49it/s]


Best val loss found:  0.009667234823407074


100%|██████████| 2800/2800 [01:34<00:00, 29.58it/s]
100%|██████████| 700/700 [00:20<00:00, 33.90it/s]
100%|██████████| 2800/2800 [01:34<00:00, 29.65it/s]
100%|██████████| 700/700 [00:20<00:00, 34.12it/s]


Best val loss found:  0.009652071629285015
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [01:34<00:00, 29.63it/s]
100%|██████████| 700/700 [00:20<00:00, 34.19it/s]
100%|██████████| 2800/2800 [01:33<00:00, 29.87it/s]
100%|██████████| 700/700 [00:20<00:00, 34.94it/s]


Best val loss found:  0.009487149527495993


100%|██████████| 2800/2800 [01:35<00:00, 29.35it/s]
100%|██████████| 700/700 [00:20<00:00, 34.24it/s]
100%|██████████| 2800/2800 [01:35<00:00, 29.37it/s]
100%|██████████| 700/700 [00:20<00:00, 34.38it/s]
100%|██████████| 2800/2800 [01:30<00:00, 30.78it/s]
100%|██████████| 700/700 [00:20<00:00, 34.42it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [01:34<00:00, 29.63it/s]
100%|██████████| 700/700 [00:20<00:00, 34.54it/s]


Epoch 20, val_loss 0.0094
Best val loss found:  0.009439601032894903


100%|██████████| 2800/2800 [01:35<00:00, 29.38it/s]
100%|██████████| 700/700 [00:20<00:00, 34.51it/s]


Best val loss found:  0.009429808007386912


100%|██████████| 2800/2800 [01:33<00:00, 29.82it/s]
100%|██████████| 700/700 [00:20<00:00, 34.04it/s]
100%|██████████| 2800/2800 [01:34<00:00, 29.57it/s]
100%|██████████| 700/700 [00:20<00:00, 34.15it/s]
100%|██████████| 2800/2800 [01:35<00:00, 29.42it/s]
100%|██████████| 700/700 [00:20<00:00, 34.15it/s]


Best val loss found:  0.009358669734959092
| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [01:33<00:00, 29.88it/s]
100%|██████████| 700/700 [00:20<00:00, 34.00it/s]
100%|██████████| 2800/2800 [01:32<00:00, 30.35it/s]
100%|██████████| 700/700 [00:19<00:00, 35.39it/s]
100%|██████████| 2800/2800 [01:28<00:00, 31.81it/s]
100%|██████████| 700/700 [00:20<00:00, 34.11it/s]
100%|██████████| 2800/2800 [01:37<00:00, 28.80it/s]
100%|██████████| 700/700 [00:20<00:00, 33.60it/s]
100%|██████████| 2800/2800 [01:40<00:00, 27.90it/s]
100%|██████████| 700/700 [00:21<00:00, 32.75it/s]


Best val loss found:  0.009314788177476397
This fold, the best val loss is:  0.009314788177476397


100%|██████████| 1000/1000 [00:31<00:00, 32.24it/s]


This fold, the test loss is:  0.010924933600006626
********************
Fold2
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [01:38<00:00, 28.40it/s]
100%|██████████| 700/700 [00:20<00:00, 33.70it/s]


Epoch 0, val_loss 0.0144
Best val loss found:  0.014409687642141112


100%|██████████| 2800/2800 [01:38<00:00, 28.41it/s]
100%|██████████| 700/700 [00:21<00:00, 33.07it/s]


Best val loss found:  0.011639381488951455


100%|██████████| 2800/2800 [01:37<00:00, 28.77it/s]
100%|██████████| 700/700 [00:20<00:00, 34.04it/s]


Best val loss found:  0.010742284220842911


100%|██████████| 2800/2800 [01:36<00:00, 28.88it/s]
100%|██████████| 700/700 [00:20<00:00, 34.03it/s]


Best val loss found:  0.010339744195641418


100%|██████████| 2800/2800 [01:36<00:00, 28.94it/s]
100%|██████████| 700/700 [00:20<00:00, 33.62it/s]


Best val loss found:  0.010155534508272206
| >>>>>                           |


100%|██████████| 2800/2800 [01:36<00:00, 28.93it/s]
100%|██████████| 700/700 [00:20<00:00, 33.48it/s]


Best val loss found:  0.009986905287286001


100%|██████████| 2800/2800 [01:36<00:00, 28.90it/s]
100%|██████████| 700/700 [00:20<00:00, 33.65it/s]


Best val loss found:  0.009898087932462139


100%|██████████| 2800/2800 [01:39<00:00, 28.21it/s]
100%|██████████| 700/700 [00:21<00:00, 32.67it/s]


Best val loss found:  0.009862047581201685


100%|██████████| 2800/2800 [01:40<00:00, 27.79it/s]
100%|██████████| 700/700 [00:21<00:00, 32.59it/s]


Best val loss found:  0.009683475868244256


100%|██████████| 2800/2800 [01:41<00:00, 27.70it/s]
100%|██████████| 700/700 [00:21<00:00, 32.58it/s]


Best val loss found:  0.009615837971047897
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [01:40<00:00, 27.90it/s]
100%|██████████| 700/700 [00:21<00:00, 32.42it/s]


Epoch 10, val_loss 0.0096


100%|██████████| 2800/2800 [01:38<00:00, 28.35it/s]
100%|██████████| 700/700 [00:21<00:00, 33.03it/s]


Best val loss found:  0.009547143265039528


100%|██████████| 2800/2800 [01:39<00:00, 28.20it/s]
100%|██████████| 700/700 [00:21<00:00, 32.75it/s]


Best val loss found:  0.009546760203416592


100%|██████████| 2800/2800 [01:40<00:00, 27.96it/s]
100%|██████████| 700/700 [00:21<00:00, 32.79it/s]


Best val loss found:  0.009496605894694638


100%|██████████| 2800/2800 [01:39<00:00, 28.21it/s]
100%|██████████| 700/700 [00:21<00:00, 33.09it/s]


Best val loss found:  0.00944284460740164
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [01:39<00:00, 28.18it/s]
100%|██████████| 700/700 [00:21<00:00, 33.05it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.35it/s]
100%|██████████| 700/700 [00:21<00:00, 32.78it/s]


Best val loss found:  0.009435201397365225


100%|██████████| 2800/2800 [01:39<00:00, 28.27it/s]
100%|██████████| 700/700 [00:20<00:00, 33.58it/s]


Best val loss found:  0.009373378960715075


100%|██████████| 2800/2800 [01:38<00:00, 28.37it/s]
100%|██████████| 700/700 [00:21<00:00, 32.58it/s]
100%|██████████| 2800/2800 [01:39<00:00, 28.06it/s]
100%|██████████| 700/700 [00:21<00:00, 32.11it/s]


Best val loss found:  0.009360685756130676
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [01:38<00:00, 28.35it/s]
100%|██████████| 700/700 [00:21<00:00, 33.21it/s]


Epoch 20, val_loss 0.0094


100%|██████████| 2800/2800 [01:39<00:00, 28.13it/s]
100%|██████████| 700/700 [00:21<00:00, 33.05it/s]


Best val loss found:  0.00927844453270414


100%|██████████| 2800/2800 [01:39<00:00, 28.16it/s]
100%|██████████| 700/700 [00:21<00:00, 32.84it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.32it/s]
100%|██████████| 700/700 [00:21<00:00, 32.37it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.32it/s]
100%|██████████| 700/700 [00:21<00:00, 32.38it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [01:39<00:00, 28.15it/s]
100%|██████████| 700/700 [00:21<00:00, 32.88it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.35it/s]
100%|██████████| 700/700 [00:21<00:00, 33.02it/s]
100%|██████████| 2800/2800 [01:39<00:00, 28.16it/s]
100%|██████████| 700/700 [00:21<00:00, 32.66it/s]
100%|██████████| 2800/2800 [01:39<00:00, 28.24it/s]
100%|██████████| 700/700 [00:21<00:00, 32.71it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.31it/s]
100%|██████████| 700/700 [00:20<00:00, 33.43it/s]


Best val loss found:  0.009218221652720656
This fold, the best val loss is:  0.009218221652720656


100%|██████████| 1000/1000 [00:30<00:00, 32.73it/s]


This fold, the test loss is:  0.010810707543045282
********************
Fold3
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [01:39<00:00, 28.27it/s]
100%|██████████| 700/700 [00:21<00:00, 33.05it/s]


Epoch 0, val_loss 0.0151
Best val loss found:  0.015124811388751758


100%|██████████| 2800/2800 [01:39<00:00, 28.28it/s]
100%|██████████| 700/700 [00:20<00:00, 33.38it/s]


Best val loss found:  0.01206105775449292


100%|██████████| 2800/2800 [01:39<00:00, 28.25it/s]
100%|██████████| 700/700 [00:21<00:00, 32.49it/s]


Best val loss found:  0.010847751546784171


100%|██████████| 2800/2800 [01:39<00:00, 28.27it/s]
100%|██████████| 700/700 [00:21<00:00, 32.61it/s]


Best val loss found:  0.010500843457411974


100%|██████████| 2800/2800 [01:38<00:00, 28.34it/s]
100%|██████████| 700/700 [00:21<00:00, 32.86it/s]


Best val loss found:  0.01027253222418949
| >>>>>                           |


100%|██████████| 2800/2800 [01:38<00:00, 28.31it/s]
100%|██████████| 700/700 [00:21<00:00, 32.74it/s]


Best val loss found:  0.010023250557049843


100%|██████████| 2800/2800 [01:38<00:00, 28.29it/s]
100%|██████████| 700/700 [00:21<00:00, 32.61it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.37it/s]
100%|██████████| 700/700 [00:21<00:00, 32.98it/s]


Best val loss found:  0.009851430663838983


100%|██████████| 2800/2800 [01:38<00:00, 28.48it/s]
100%|██████████| 700/700 [00:21<00:00, 32.56it/s]
100%|██████████| 2800/2800 [01:39<00:00, 28.19it/s]
100%|██████████| 700/700 [00:21<00:00, 33.09it/s]


Best val loss found:  0.009776099210638286
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [01:39<00:00, 28.19it/s]
100%|██████████| 700/700 [00:21<00:00, 32.57it/s]


Epoch 10, val_loss 0.0097
Best val loss found:  0.009710213396465406


100%|██████████| 2800/2800 [01:38<00:00, 28.44it/s]
100%|██████████| 700/700 [00:21<00:00, 32.50it/s]


Best val loss found:  0.00961252015383382


100%|██████████| 2800/2800 [01:39<00:00, 28.20it/s]
100%|██████████| 700/700 [00:21<00:00, 32.80it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.35it/s]
100%|██████████| 700/700 [00:21<00:00, 32.93it/s]


Best val loss found:  0.009599930680124089


100%|██████████| 2800/2800 [01:38<00:00, 28.43it/s]
100%|██████████| 700/700 [00:21<00:00, 32.63it/s]


Best val loss found:  0.009494972498754837
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [01:38<00:00, 28.32it/s]
100%|██████████| 700/700 [00:21<00:00, 33.01it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.41it/s]
100%|██████████| 700/700 [00:21<00:00, 32.86it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.30it/s]
100%|██████████| 700/700 [00:21<00:00, 32.66it/s]
100%|██████████| 2800/2800 [01:38<00:00, 28.51it/s]
100%|██████████| 700/700 [00:21<00:00, 32.83it/s]


Best val loss found:  0.009492089412940134


100%|██████████| 2800/2800 [01:39<00:00, 28.27it/s]
100%|██████████| 700/700 [00:21<00:00, 32.54it/s]


Best val loss found:  0.009456169375916944
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [01:38<00:00, 28.29it/s]
100%|██████████| 700/700 [00:21<00:00, 33.05it/s]


Epoch 20, val_loss 0.0095


100%|██████████| 2800/2800 [01:39<00:00, 28.02it/s]
100%|██████████| 700/700 [00:21<00:00, 33.02it/s]


Best val loss found:  0.009451151268856067


100%|██████████| 2800/2800 [01:38<00:00, 28.42it/s]
100%|██████████| 700/700 [00:21<00:00, 32.67it/s]


Best val loss found:  0.009427051362914166


100%|██████████| 2800/2800 [01:39<00:00, 28.21it/s]
100%|██████████| 700/700 [00:18<00:00, 38.31it/s]


Best val loss found:  0.009394213353682842


100%|██████████| 2800/2800 [01:30<00:00, 31.10it/s]
100%|██████████| 700/700 [00:17<00:00, 39.23it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [01:29<00:00, 31.20it/s]
100%|██████████| 700/700 [00:17<00:00, 39.19it/s]


Best val loss found:  0.00935430069553799


100%|██████████| 2800/2800 [01:30<00:00, 30.99it/s]
100%|██████████| 700/700 [00:18<00:00, 38.61it/s]
100%|██████████| 2800/2800 [01:29<00:00, 31.22it/s]
100%|██████████| 700/700 [00:17<00:00, 38.97it/s]
100%|██████████| 2800/2800 [01:29<00:00, 31.32it/s]
100%|██████████| 700/700 [00:17<00:00, 39.15it/s]
100%|██████████| 2800/2800 [01:30<00:00, 31.03it/s]
100%|██████████| 700/700 [00:18<00:00, 37.80it/s]


Best val loss found:  0.00932327596437452
This fold, the best val loss is:  0.00932327596437452


100%|██████████| 1000/1000 [00:26<00:00, 37.90it/s]


This fold, the test loss is:  0.0107426609917311
********************
Fold4
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [01:29<00:00, 31.34it/s]
100%|██████████| 700/700 [00:18<00:00, 38.53it/s]


Epoch 0, val_loss 0.0147
Best val loss found:  0.014701925900631718


100%|██████████| 2800/2800 [01:30<00:00, 31.03it/s]
100%|██████████| 700/700 [00:18<00:00, 38.61it/s]


Best val loss found:  0.012221635929974063


100%|██████████| 2800/2800 [01:23<00:00, 33.55it/s]
100%|██████████| 700/700 [00:18<00:00, 38.07it/s]


Best val loss found:  0.011214083906545836


100%|██████████| 2800/2800 [01:25<00:00, 32.90it/s]
100%|██████████| 700/700 [00:17<00:00, 38.98it/s]


Best val loss found:  0.010957768837189567


100%|██████████| 2800/2800 [01:26<00:00, 32.48it/s]
100%|██████████| 700/700 [00:18<00:00, 38.88it/s]


Best val loss found:  0.010527018957405484
| >>>>>                           |


100%|██████████| 2800/2800 [01:26<00:00, 32.22it/s]
100%|██████████| 700/700 [00:18<00:00, 38.29it/s]


Best val loss found:  0.010495365040343521


100%|██████████| 2800/2800 [01:26<00:00, 32.21it/s]
100%|██████████| 700/700 [00:18<00:00, 38.71it/s]


Best val loss found:  0.01027216571682532


100%|██████████| 2800/2800 [01:25<00:00, 32.85it/s]
100%|██████████| 700/700 [00:18<00:00, 38.30it/s]


Best val loss found:  0.010241286465232926


100%|██████████| 2800/2800 [01:26<00:00, 32.54it/s]
100%|██████████| 700/700 [00:18<00:00, 38.15it/s]


Best val loss found:  0.010173782349697182


100%|██████████| 2800/2800 [01:24<00:00, 32.95it/s]
100%|██████████| 700/700 [00:18<00:00, 38.78it/s]


Best val loss found:  0.010098765601869673
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [01:26<00:00, 32.33it/s]
100%|██████████| 700/700 [00:18<00:00, 37.97it/s]


Epoch 10, val_loss 0.0100
Best val loss found:  0.010037894071844805


100%|██████████| 2800/2800 [01:26<00:00, 32.54it/s]
100%|██████████| 700/700 [00:18<00:00, 37.67it/s]


Best val loss found:  0.010016737241364483


100%|██████████| 2800/2800 [01:26<00:00, 32.43it/s]
100%|██████████| 700/700 [00:18<00:00, 38.44it/s]


Best val loss found:  0.0099825496638992


100%|██████████| 2800/2800 [01:26<00:00, 32.29it/s]
100%|██████████| 700/700 [00:17<00:00, 39.17it/s]


Best val loss found:  0.009899140099795268


100%|██████████| 2800/2800 [01:25<00:00, 32.57it/s]
100%|██████████| 700/700 [00:18<00:00, 38.80it/s]


Best val loss found:  0.009864652467159821
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [01:26<00:00, 32.39it/s]
100%|██████████| 700/700 [00:18<00:00, 38.21it/s]
100%|██████████| 2800/2800 [01:26<00:00, 32.26it/s]
100%|██████████| 700/700 [00:18<00:00, 38.57it/s]
100%|██████████| 2800/2800 [01:26<00:00, 32.36it/s]
100%|██████████| 700/700 [00:18<00:00, 37.97it/s]


Best val loss found:  0.009814881753908204


100%|██████████| 2800/2800 [01:26<00:00, 32.33it/s]
100%|██████████| 700/700 [00:17<00:00, 38.95it/s]
100%|██████████| 2800/2800 [01:26<00:00, 32.24it/s]
100%|██████████| 700/700 [00:18<00:00, 38.54it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [01:24<00:00, 33.03it/s]
100%|██████████| 700/700 [00:18<00:00, 37.57it/s]


Epoch 20, val_loss 0.0098
Best val loss found:  0.009804661112804232


100%|██████████| 2800/2800 [01:26<00:00, 32.41it/s]
100%|██████████| 700/700 [00:18<00:00, 38.38it/s]
100%|██████████| 2800/2800 [01:27<00:00, 32.14it/s]
100%|██████████| 700/700 [00:18<00:00, 38.11it/s]


Best val loss found:  0.009768222381660183


100%|██████████| 2800/2800 [01:26<00:00, 32.24it/s]
100%|██████████| 700/700 [00:18<00:00, 38.54it/s]
100%|██████████| 2800/2800 [01:26<00:00, 32.39it/s]
100%|██████████| 700/700 [00:18<00:00, 38.18it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [01:27<00:00, 32.14it/s]
100%|██████████| 700/700 [00:18<00:00, 38.33it/s]
100%|██████████| 2800/2800 [01:26<00:00, 32.50it/s]
100%|██████████| 700/700 [00:18<00:00, 38.78it/s]


Best val loss found:  0.00975188038012545


100%|██████████| 2800/2800 [01:23<00:00, 33.40it/s]
100%|██████████| 700/700 [00:18<00:00, 38.46it/s]
100%|██████████| 2800/2800 [01:23<00:00, 33.37it/s]
100%|██████████| 700/700 [00:18<00:00, 38.39it/s]


Best val loss found:  0.009682408830849453


100%|██████████| 2800/2800 [01:27<00:00, 32.14it/s]
100%|██████████| 700/700 [00:18<00:00, 37.96it/s]


This fold, the best val loss is:  0.009682408830849453


100%|██████████| 1000/1000 [00:26<00:00, 37.39it/s]


This fold, the test loss is:  0.010882767678354867
********************
Fold5
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [01:25<00:00, 32.86it/s]
100%|██████████| 700/700 [00:18<00:00, 37.61it/s]


Epoch 0, val_loss 0.0150
Best val loss found:  0.01502072431047314


100%|██████████| 2800/2800 [01:25<00:00, 32.81it/s]
100%|██████████| 700/700 [00:18<00:00, 37.53it/s]


Best val loss found:  0.012074494334602994


100%|██████████| 2800/2800 [01:26<00:00, 32.44it/s]
100%|██████████| 700/700 [00:18<00:00, 37.47it/s]


Best val loss found:  0.011248403774308307


100%|██████████| 2800/2800 [01:27<00:00, 32.08it/s]
100%|██████████| 700/700 [00:19<00:00, 36.34it/s]


Best val loss found:  0.010842517623511542


100%|██████████| 2800/2800 [01:24<00:00, 33.00it/s]
100%|██████████| 700/700 [00:18<00:00, 36.85it/s]


Best val loss found:  0.010535323692518951
| >>>>>                           |


100%|██████████| 2800/2800 [01:24<00:00, 33.03it/s]
100%|██████████| 700/700 [00:18<00:00, 37.18it/s]


Best val loss found:  0.010388164558847036


100%|██████████| 2800/2800 [01:25<00:00, 32.89it/s]
100%|██████████| 700/700 [00:18<00:00, 37.82it/s]


Best val loss found:  0.010282035267945112


100%|██████████| 2800/2800 [01:25<00:00, 32.90it/s]
100%|██████████| 700/700 [00:18<00:00, 37.82it/s]


Best val loss found:  0.010103190551744775


100%|██████████| 2800/2800 [01:24<00:00, 33.03it/s]
100%|██████████| 700/700 [00:18<00:00, 38.13it/s]


Best val loss found:  0.010036145005308623


100%|██████████| 2800/2800 [01:25<00:00, 32.94it/s]
100%|██████████| 700/700 [00:18<00:00, 38.55it/s]


Best val loss found:  0.010033971430135093
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [01:25<00:00, 32.82it/s]
100%|██████████| 700/700 [00:18<00:00, 38.01it/s]


Epoch 10, val_loss 0.0099
Best val loss found:  0.00992588200656298


100%|██████████| 2800/2800 [01:25<00:00, 32.67it/s]
100%|██████████| 700/700 [00:18<00:00, 38.32it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.79it/s]
100%|██████████| 700/700 [00:18<00:00, 38.13it/s]


Best val loss found:  0.009822360854928516


100%|██████████| 2800/2800 [01:25<00:00, 32.69it/s]
100%|██████████| 700/700 [00:18<00:00, 37.85it/s]


Best val loss found:  0.009805205006019346


100%|██████████| 2800/2800 [01:25<00:00, 32.87it/s]
100%|██████████| 700/700 [00:18<00:00, 38.02it/s]


Best val loss found:  0.009760077295192916
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [01:25<00:00, 32.84it/s]
100%|██████████| 700/700 [00:18<00:00, 38.28it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.76it/s]
100%|██████████| 700/700 [00:18<00:00, 38.24it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.87it/s]
100%|██████████| 700/700 [00:18<00:00, 38.10it/s]


Best val loss found:  0.009708451793828447


100%|██████████| 2800/2800 [01:24<00:00, 33.05it/s]
100%|██████████| 700/700 [00:18<00:00, 38.42it/s]


Best val loss found:  0.009635176060588233


100%|██████████| 2800/2800 [01:25<00:00, 32.72it/s]
100%|██████████| 700/700 [00:18<00:00, 38.10it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [01:25<00:00, 32.92it/s]
100%|██████████| 700/700 [00:18<00:00, 38.45it/s]


Epoch 20, val_loss 0.0096
Best val loss found:  0.009601968374003523


100%|██████████| 2800/2800 [01:25<00:00, 32.81it/s]
100%|██████████| 700/700 [00:18<00:00, 38.06it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.90it/s]
100%|██████████| 700/700 [00:18<00:00, 38.60it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.88it/s]
100%|██████████| 700/700 [00:18<00:00, 38.36it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.70it/s]
100%|██████████| 700/700 [00:18<00:00, 37.84it/s]


Best val loss found:  0.009503677094554795
| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [01:25<00:00, 32.94it/s]
100%|██████████| 700/700 [00:18<00:00, 37.82it/s]
100%|██████████| 2800/2800 [01:25<00:00, 32.79it/s]
100%|██████████| 700/700 [00:18<00:00, 37.17it/s]
100%|██████████| 2800/2800 [01:24<00:00, 33.17it/s]
100%|██████████| 700/700 [00:18<00:00, 37.22it/s]
100%|██████████| 2800/2800 [01:23<00:00, 33.34it/s]
100%|██████████| 700/700 [00:19<00:00, 36.22it/s]
100%|██████████| 2800/2800 [01:23<00:00, 33.45it/s]
100%|██████████| 700/700 [00:19<00:00, 35.65it/s]


Best val loss found:  0.009502152015588114
This fold, the best val loss is:  0.009502152015588114


100%|██████████| 1000/1000 [00:26<00:00, 38.01it/s]

This fold, the test loss is:  0.010810896016540937





In [20]:
# Persist the per-fold test-set probability tensors (one entry per fold);
# create the output directory first so the save cannot fail on a fresh checkout.
os.makedirs('test_predict', exist_ok=True)
torch.save(test_predict_lst,'test_predict/mini_30e_001.pt')

# Final model checkpoint (save / load cells, currently commented out)

In [16]:
# patentModel.load_state_dict(torch.load('ckpt/001/best_model.pth'))

<All keys matched successfully>

In [15]:
# torch.save(patentModel.state_dict(), 'ckpt/001/best_model.pth')