In [1]:
import pandas as pd
import pickle
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AdamW
from operator import itemgetter
from sklearn.model_selection import StratifiedKFold

import os
# Restrict this process to GPU index 0 and make it the current CUDA device,
# so the bare `.cuda()` calls further down all land on the same card.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.cuda.set_device(0)

In [2]:
# Label vocabulary: `label2id` is indexed by label string (see str2id_lst) and
# `id2label` is used for its length, which sets the classifier's output width.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load artifacts produced by a trusted pipeline.
with open('../temp_results/mini_label2id_dict.pkl', 'rb') as label2id_file:
    label2id = pickle.load(label2id_file)
with open('../temp_results/mini_id2label_lst.pkl', 'rb') as id2label_file:
    id2label = pickle.load(id2label_file)

In [3]:
# Read the training frame and the held-out test frame from their CSV files.
train_data, test_data = (
    pd.read_csv(csv_path)
    for csv_path in ('../data/mini_train_data.csv', '../data/mini_test_data.csv')
)

In [4]:
from tflow_utils import TransformerGlow, AdamWeightDecayOptimizer
from transformers import AutoTokenizer

# Checkpoint name handed to AutoTokenizer; only the tokenizer is loaded from it.
model_name_or_path = 'anferico/bert-for-patents'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# The backbone itself is restored from the local 'output' path.
# NOTE(review): presumably 'output' was trained on top of the same checkpoint
# as the tokenizer above; confirm that the vocabularies match.
model = TransformerGlow.from_pretrained('output')  # Load model

# Data Loader

In [5]:
def str2id_lst(str_label):
    """Convert a comma-separated label string into a list of label ids.

    Each comma-delimited piece is looked up verbatim in the module-level
    ``label2id`` mapping; an unknown piece raises ``KeyError``.
    """
    return [label2id[piece] for piece in str_label.split(',')]

class PatentDataset(Dataset):
    """Map-style dataset yielding ``(text, label_ids)`` pairs from a DataFrame.

    Args:
        df: frame with a ``text`` column and, when ``labeled`` is true, a
            ``cpc_ids`` column of comma-separated label strings.
        labeled: when false the ``cpc_ids`` column is never read and ``None``
            is returned in place of the label list, so frames without that
            column can be used.
    """

    def __init__(self,df,labeled = True):
        self.df = df
        self.labeled = labeled

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self,idx):
        """Return ``(text, label_ids)`` for row ``idx`` (``label_ids`` is None if unlabeled)."""
        row = self.df.iloc[idx]
        # The first three characters of every text are dropped.
        # NOTE(review): presumably a fixed prefix in the raw data; confirm.
        text = row['text'][3:]

        if not self.labeled:
            # Do not touch 'cpc_ids' here: unlabeled frames may not have it.
            return text,None
        return text,str2id_lst(row['cpc_ids'])
        

In [6]:
# `labeled` keeps its default (True), so test items carry their cpc_ids labels;
# the training cell uses them to report a test loss per fold.
test_dataset = PatentDataset(test_data)

In [7]:
def collate_fn(data):
    """Tokenize a batch of ``(text, label_ids)`` items and build multi-hot targets.

    Args:
        data: list of 2-item sequences as produced by ``PatentDataset``;
            ``label_ids`` is a list of integer label ids, or ``None`` for an
            unlabeled item.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, batch_label)`` where the
        first three come straight from the module-level ``tokenizer`` and
        ``batch_label`` is a float32 tensor of shape
        ``(len(data), len(id2label))`` with ones at each item's label ids.
        Unlabeled items get an all-zero row.
    """
    sents = [item[0] for item in data]
    labels = [item[1] for item in data]

    encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                          truncation=True,
                                          padding='longest',
                                          max_length=512,
                                          return_tensors='pt',
                                          return_length=True)
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']

    batch_label = np.zeros((len(labels),len(id2label)))
    for row, label_ids in enumerate(labels):
        if label_ids is None:
            # Indexing with None means np.newaxis, which would set the whole
            # row to 1; an unlabeled item must stay all-zero instead.
            continue
        batch_label[row,label_ids]=1

    batch_label = torch.tensor(batch_label,dtype=torch.float32)

    return input_ids, attention_mask, token_type_ids, batch_label
    

In [8]:
# Batches of four test items in dataset order, tokenized by collate_fn.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

# Define Model

In [9]:
class PatentClsModel(nn.Module):
    """Multi-label classifier: sentence-embedding backbone plus an MLP head.

    Args:
        bert_model: callable module invoked as
            ``bert_model(input_ids=..., attention_mask=..., return_loss=True)``
            and returning ``(embedding, loss)``; the embedding is fed to a
            head whose first layer expects 1024 features.
        backbone_fixed: when true the backbone runs under ``torch.no_grad()``
            so no gradients reach it.

    ``forward`` returns per-label sigmoid probabilities with
    ``len(id2label)`` columns.
    """

    def __init__(self,bert_model,backbone_fixed = True):
        super().__init__()
        self.fc = nn.Sequential(nn.BatchNorm1d(1024),
                               nn.Dropout(0.5),
                               nn.Linear(1024,768),
                               nn.ReLU(),
                               nn.BatchNorm1d(768),
                               nn.Dropout(0.5),
                               nn.Linear(768,len(id2label)))

        self.bert_model = bert_model
        self.sig = nn.Sigmoid()
        self.backbone_fixed = backbone_fixed

        # Explicit initialisation of the head: identity BatchNorm, Kaiming
        # weights and zero biases for Linear layers.
        for module in self.fc:
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    # Weight-normalised layer: check the layer itself (the
                    # original asserted on the unrelated global `model`).
                    assert module.weight_g is not None
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                else:
                    nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def _embed(self, input_ids, attention_mask):
        """Run the backbone and return only its embedding output."""
        embedding, _ = self.bert_model(input_ids = input_ids,
                                       attention_mask = attention_mask,
                                       return_loss=True)
        return embedding

    def forward(self, input_ids, attention_mask, token_type_ids):
        # token_type_ids is accepted for interface compatibility but is not
        # passed to the backbone.
        if self.backbone_fixed:
            with torch.no_grad():
                x = self._embed(input_ids, attention_mask)
        else:
            # Leave the ambient grad mode untouched so an outer no_grad()
            # (e.g. during validation) is still honoured.
            x = self._embed(input_ids, attention_mask)

        x = self.fc(x)
        x = self.sig(x)

        return x
        

# Training

In [10]:
# 5-fold stratified splitter; the training cell stratifies on the raw
# 'cpc_ids' string of each row.
kfold = StratifiedKFold(n_splits=5)
# Number of training epochs run for every fold.
total_epochs = 30
# One tensor of test-set predictions is appended here per fold.
test_predict_lst = []

In [11]:
from tqdm import tqdm
def _evaluate(net, dataloader, loss_func, keep_predictions=False):
    """Score `net` on `dataloader` without gradients.

    Returns `(mean_loss, predictions)`. `mean_loss` is weighted by batch size
    so a shorter final batch does not skew it; `predictions` is the row-wise
    concatenation of all batch outputs when `keep_predictions` is true (and
    at least one batch was seen), otherwise None.
    """
    net.eval()
    total_loss = 0.0
    n_samples = 0
    kept = []
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, batch_label in tqdm(dataloader):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()

            prediction = net(input_ids, attention_mask, token_type_ids)
            batch_size = batch_label.size(0)
            total_loss += loss_func(prediction, batch_label).item() * batch_size
            n_samples += batch_size
            if keep_predictions:
                kept.append(prediction)

    mean_loss = total_loss / n_samples if n_samples else float('nan')
    # Concatenate once at the end instead of re-allocating on every batch.
    predictions = torch.cat(kept, dim=0) if kept else None
    return mean_loss, predictions


# Train one classifier head per fold, keep the checkpoint with the lowest
# validation loss, then score the shared test set with that checkpoint.
for fold, (train_index, valid_index) in enumerate(kfold.split(train_data,train_data['cpc_ids']), start=1):

    print('*'*20)
    print(f'Fold{fold}')
    print('*'*20)
    train_dataset = PatentDataset(train_data.iloc[train_index])
    val_dataset = PatentDataset(train_data.iloc[valid_index])

    # drop_last on the training loader avoids a trailing batch of one sample,
    # which BatchNorm1d rejects in training mode.
    train_dataloader = DataLoader(train_dataset,
                                 collate_fn = collate_fn,
                                 batch_size = 4,
                                 shuffle = True,
                                 drop_last = True)
    # Validation needs neither shuffling nor dropping: every sample is scored.
    val_dataloader = DataLoader(val_dataset,
                               collate_fn = collate_fn,
                               batch_size = 4,
                               shuffle = False,
                               drop_last = False)

    ckpt_path = 'ckpt/bertflow/best_model_mini_{}.pth'.format(fold)
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    # NOTE(review): patentModel.train() below also switches the frozen
    # backbone to training mode (e.g. dropout active); confirm this is intended.
    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    loss_func = nn.BCELoss()
    optimizer = AdamW(patentModel.parameters(), lr=5e-4)

    print('Dataloader Success---------------------')

    best_val_loss = float('inf')
    for epoch in range(total_epochs):
        if epoch%5==0:
            print('|',">" * epoch," "*(total_epochs-epoch),'|')

        patentModel.train()
        for input_ids, attention_mask, token_type_ids, batch_label in tqdm(train_dataloader):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()

            prediction = patentModel(input_ids, attention_mask, token_type_ids)
            loss = loss_func(prediction,batch_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss, _ = _evaluate(patentModel, val_dataloader, loss_func)

        if epoch%10 == 0:
            print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(patentModel.state_dict(), ckpt_path)
            print('Best val loss found: ', best_val_loss)

    print('This fold, the best val loss is: ', best_val_loss)

    # Rebuild the model and restore this fold's best checkpoint before testing.
    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    patentModel.load_state_dict(torch.load(ckpt_path))

    test_loss, test_predict = _evaluate(patentModel, test_dataloader, loss_func,
                                        keep_predictions=True)
    print('This fold, the test loss is: ', test_loss)

    test_predict_lst.append(test_predict)



********************
Fold1
********************




Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [02:47<00:00, 16.72it/s]
100%|██████████| 700/700 [00:37<00:00, 18.58it/s]


Epoch 0, val_loss 0.0176
Best val loss found:  0.017647038169338235


100%|██████████| 2800/2800 [02:45<00:00, 16.89it/s]
100%|██████████| 700/700 [00:38<00:00, 18.13it/s]


Best val loss found:  0.014719449019591723


100%|██████████| 2800/2800 [02:52<00:00, 16.25it/s]
100%|██████████| 700/700 [00:38<00:00, 18.33it/s]


Best val loss found:  0.012807031681295484


100%|██████████| 2800/2800 [02:44<00:00, 17.00it/s]
100%|██████████| 700/700 [00:36<00:00, 18.93it/s]


Best val loss found:  0.012032744581478514


100%|██████████| 2800/2800 [02:44<00:00, 17.06it/s]
100%|██████████| 700/700 [00:37<00:00, 18.70it/s]


Best val loss found:  0.011607506808359177
| >>>>>                           |


100%|██████████| 2800/2800 [02:52<00:00, 16.23it/s]
100%|██████████| 700/700 [00:38<00:00, 18.25it/s]


Best val loss found:  0.011206917427979144


100%|██████████| 2800/2800 [02:46<00:00, 16.81it/s]
100%|██████████| 700/700 [00:37<00:00, 18.75it/s]


Best val loss found:  0.011053023947668926


100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:38<00:00, 18.40it/s]


Best val loss found:  0.010807990934367158


100%|██████████| 2800/2800 [02:51<00:00, 16.36it/s]
100%|██████████| 700/700 [00:38<00:00, 18.11it/s]


Best val loss found:  0.010724073451544558


100%|██████████| 2800/2800 [02:46<00:00, 16.86it/s]
100%|██████████| 700/700 [00:36<00:00, 19.21it/s]


Best val loss found:  0.010694813486521265
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [02:43<00:00, 17.11it/s]
100%|██████████| 700/700 [00:37<00:00, 18.63it/s]


Epoch 10, val_loss 0.0106
Best val loss found:  0.01058538760391197


100%|██████████| 2800/2800 [02:53<00:00, 16.16it/s]
100%|██████████| 700/700 [00:39<00:00, 17.93it/s]


Best val loss found:  0.010503425475742135


100%|██████████| 2800/2800 [02:49<00:00, 16.53it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]


Best val loss found:  0.010412488061701879


100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:37<00:00, 18.70it/s]
100%|██████████| 2800/2800 [02:50<00:00, 16.40it/s]
100%|██████████| 700/700 [00:38<00:00, 18.35it/s]


Best val loss found:  0.01033830400140557
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [02:47<00:00, 16.73it/s]
100%|██████████| 700/700 [00:36<00:00, 18.97it/s]


Best val loss found:  0.010314330757994736


100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:38<00:00, 18.37it/s]
100%|██████████| 2800/2800 [02:53<00:00, 16.18it/s]
100%|██████████| 700/700 [00:38<00:00, 17.96it/s]


Best val loss found:  0.010191976431358073


100%|██████████| 2800/2800 [02:49<00:00, 16.48it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]
100%|██████████| 2800/2800 [02:47<00:00, 16.67it/s]
100%|██████████| 700/700 [00:37<00:00, 18.70it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [02:49<00:00, 16.47it/s]
100%|██████████| 700/700 [00:38<00:00, 18.18it/s]


Epoch 20, val_loss 0.0103


100%|██████████| 2800/2800 [02:48<00:00, 16.58it/s]
100%|██████████| 700/700 [00:38<00:00, 18.32it/s]


Best val loss found:  0.010155708078860438


100%|██████████| 2800/2800 [02:49<00:00, 16.52it/s]
100%|██████████| 700/700 [00:38<00:00, 18.14it/s]
100%|██████████| 2800/2800 [02:51<00:00, 16.33it/s]
100%|██████████| 700/700 [00:38<00:00, 18.02it/s]


Best val loss found:  0.010142386014174137


100%|██████████| 2800/2800 [02:49<00:00, 16.51it/s]
100%|██████████| 700/700 [00:38<00:00, 18.40it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [02:47<00:00, 16.74it/s]
100%|██████████| 700/700 [00:37<00:00, 18.59it/s]


Best val loss found:  0.010099504185574395


100%|██████████| 2800/2800 [02:47<00:00, 16.76it/s]
100%|██████████| 700/700 [00:37<00:00, 18.46it/s]
100%|██████████| 2800/2800 [02:48<00:00, 16.61it/s]
100%|██████████| 700/700 [00:37<00:00, 18.70it/s]
100%|██████████| 2800/2800 [02:51<00:00, 16.33it/s]
100%|██████████| 700/700 [00:38<00:00, 18.10it/s]
100%|██████████| 2800/2800 [02:50<00:00, 16.39it/s]
100%|██████████| 700/700 [00:37<00:00, 18.53it/s]


Best val loss found:  0.010089992808311113
This fold, the best val loss is:  0.010089992808311113


100%|██████████| 1000/1000 [00:56<00:00, 17.60it/s]


This fold, the test loss is:  0.011429158280603587
********************
Fold2
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [02:48<00:00, 16.66it/s]
100%|██████████| 700/700 [00:37<00:00, 18.73it/s]


Epoch 0, val_loss 0.0173
Best val loss found:  0.017251864092291465


100%|██████████| 2800/2800 [02:47<00:00, 16.69it/s]
100%|██████████| 700/700 [00:37<00:00, 18.65it/s]


Best val loss found:  0.014255635997812663


100%|██████████| 2800/2800 [02:48<00:00, 16.66it/s]
100%|██████████| 700/700 [00:36<00:00, 18.94it/s]


Best val loss found:  0.012648751370475761


100%|██████████| 2800/2800 [02:50<00:00, 16.42it/s]
100%|██████████| 700/700 [00:38<00:00, 17.99it/s]


Best val loss found:  0.01172409265069291


100%|██████████| 2800/2800 [02:53<00:00, 16.15it/s]
100%|██████████| 700/700 [00:37<00:00, 18.53it/s]


Best val loss found:  0.011382961505358773
| >>>>>                           |


100%|██████████| 2800/2800 [02:46<00:00, 16.78it/s]
100%|██████████| 700/700 [00:36<00:00, 19.08it/s]


Best val loss found:  0.011008452095390697


100%|██████████| 2800/2800 [02:46<00:00, 16.85it/s]
100%|██████████| 700/700 [00:37<00:00, 18.62it/s]


Best val loss found:  0.010791532138495572


100%|██████████| 2800/2800 [02:49<00:00, 16.49it/s]
100%|██████████| 700/700 [00:37<00:00, 18.78it/s]


Best val loss found:  0.010668517557870863


100%|██████████| 2800/2800 [02:46<00:00, 16.80it/s]
100%|██████████| 700/700 [00:37<00:00, 18.88it/s]


Best val loss found:  0.010563836581672408


100%|██████████| 2800/2800 [02:48<00:00, 16.64it/s]
100%|██████████| 700/700 [00:38<00:00, 18.12it/s]


Best val loss found:  0.010500704158164028
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [02:53<00:00, 16.11it/s]
100%|██████████| 700/700 [00:38<00:00, 18.36it/s]


Epoch 10, val_loss 0.0104
Best val loss found:  0.010400096476078034


100%|██████████| 2800/2800 [02:45<00:00, 16.91it/s]
100%|██████████| 700/700 [00:36<00:00, 19.09it/s]


Best val loss found:  0.010378824067302048


100%|██████████| 2800/2800 [02:44<00:00, 16.97it/s]
100%|██████████| 700/700 [00:37<00:00, 18.47it/s]


Best val loss found:  0.01029523354722187


100%|██████████| 2800/2800 [02:51<00:00, 16.31it/s]
100%|██████████| 700/700 [00:38<00:00, 18.15it/s]


Best val loss found:  0.010230975654974047


100%|██████████| 2800/2800 [02:48<00:00, 16.65it/s]
100%|██████████| 700/700 [00:37<00:00, 18.60it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [02:46<00:00, 16.87it/s]
100%|██████████| 700/700 [00:37<00:00, 18.49it/s]


Best val loss found:  0.01013305697290759


100%|██████████| 2800/2800 [02:52<00:00, 16.22it/s]
100%|██████████| 700/700 [00:38<00:00, 18.06it/s]
100%|██████████| 2800/2800 [02:47<00:00, 16.75it/s]
100%|██████████| 700/700 [00:36<00:00, 18.94it/s]


Best val loss found:  0.010114208569284529


100%|██████████| 2800/2800 [02:44<00:00, 17.06it/s]
100%|██████████| 700/700 [00:36<00:00, 19.20it/s]
100%|██████████| 2800/2800 [02:50<00:00, 16.45it/s]
100%|██████████| 700/700 [00:38<00:00, 18.20it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [02:49<00:00, 16.47it/s]
100%|██████████| 700/700 [00:37<00:00, 18.58it/s]


Epoch 20, val_loss 0.0101
Best val loss found:  0.010070010400377215


100%|██████████| 2800/2800 [02:45<00:00, 16.91it/s]
100%|██████████| 700/700 [00:37<00:00, 18.63it/s]


Best val loss found:  0.009973921219352633


100%|██████████| 2800/2800 [02:48<00:00, 16.60it/s]
100%|██████████| 700/700 [00:38<00:00, 18.06it/s]
100%|██████████| 2800/2800 [02:47<00:00, 16.69it/s]
100%|██████████| 700/700 [00:37<00:00, 18.83it/s]
100%|██████████| 2800/2800 [02:43<00:00, 17.10it/s]
100%|██████████| 700/700 [00:37<00:00, 18.84it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [02:45<00:00, 16.87it/s]
100%|██████████| 700/700 [00:37<00:00, 18.46it/s]


Best val loss found:  0.00996635021542066


100%|██████████| 2800/2800 [02:50<00:00, 16.42it/s]
100%|██████████| 700/700 [00:37<00:00, 18.54it/s]
100%|██████████| 2800/2800 [02:44<00:00, 17.03it/s]
100%|██████████| 700/700 [00:36<00:00, 18.93it/s]
100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:37<00:00, 18.68it/s]


Best val loss found:  0.009881131546571852


100%|██████████| 2800/2800 [02:47<00:00, 16.70it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]


This fold, the best val loss is:  0.009881131546571852


100%|██████████| 1000/1000 [00:56<00:00, 17.70it/s]


This fold, the test loss is:  0.011438390016788617
********************
Fold3
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [02:41<00:00, 17.33it/s]
100%|██████████| 700/700 [00:36<00:00, 18.94it/s]


Epoch 0, val_loss 0.0179
Best val loss found:  0.017910334820459995


100%|██████████| 2800/2800 [02:43<00:00, 17.17it/s]
100%|██████████| 700/700 [00:38<00:00, 18.37it/s]


Best val loss found:  0.014558496696076223


100%|██████████| 2800/2800 [02:48<00:00, 16.66it/s]
100%|██████████| 700/700 [00:37<00:00, 18.45it/s]


Best val loss found:  0.012725488142043884


100%|██████████| 2800/2800 [02:43<00:00, 17.09it/s]
100%|██████████| 700/700 [00:36<00:00, 19.08it/s]


Best val loss found:  0.011864397409704647


100%|██████████| 2800/2800 [02:43<00:00, 17.15it/s]
100%|██████████| 700/700 [00:37<00:00, 18.48it/s]


Best val loss found:  0.01138522544110726
| >>>>>                           |


100%|██████████| 2800/2800 [02:50<00:00, 16.47it/s]
100%|██████████| 700/700 [00:38<00:00, 18.18it/s]


Best val loss found:  0.011080617164594254


100%|██████████| 2800/2800 [02:44<00:00, 16.99it/s]
100%|██████████| 700/700 [00:37<00:00, 18.80it/s]


Best val loss found:  0.011006987082572387


100%|██████████| 2800/2800 [02:41<00:00, 17.29it/s]
100%|██████████| 700/700 [00:37<00:00, 18.47it/s]


Best val loss found:  0.01089490845732923


100%|██████████| 2800/2800 [02:48<00:00, 16.62it/s]
100%|██████████| 700/700 [00:38<00:00, 18.16it/s]


Best val loss found:  0.01073946749159534


100%|██████████| 2800/2800 [02:46<00:00, 16.82it/s]
100%|██████████| 700/700 [00:36<00:00, 18.96it/s]


Best val loss found:  0.010673277159886701
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [02:41<00:00, 17.39it/s]
100%|██████████| 700/700 [00:36<00:00, 18.94it/s]


Epoch 10, val_loss 0.0105
Best val loss found:  0.010538283907808364


100%|██████████| 2800/2800 [02:46<00:00, 16.77it/s]
100%|██████████| 700/700 [00:38<00:00, 18.09it/s]
100%|██████████| 2800/2800 [02:47<00:00, 16.68it/s]
100%|██████████| 700/700 [00:37<00:00, 18.88it/s]


Best val loss found:  0.010470256365370006


100%|██████████| 2800/2800 [02:41<00:00, 17.35it/s]
100%|██████████| 700/700 [00:36<00:00, 18.93it/s]
100%|██████████| 2800/2800 [02:45<00:00, 16.96it/s]
100%|██████████| 700/700 [00:38<00:00, 18.18it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [02:50<00:00, 16.45it/s]
100%|██████████| 700/700 [00:37<00:00, 18.73it/s]


Best val loss found:  0.010390140979018594


100%|██████████| 2800/2800 [02:42<00:00, 17.21it/s]
100%|██████████| 700/700 [00:36<00:00, 18.99it/s]


Best val loss found:  0.010328637545795313


100%|██████████| 2800/2800 [02:43<00:00, 17.11it/s]
100%|██████████| 700/700 [00:38<00:00, 18.38it/s]


Best val loss found:  0.01031131621782801


100%|██████████| 2800/2800 [02:48<00:00, 16.59it/s]
100%|██████████| 700/700 [00:38<00:00, 18.16it/s]
100%|██████████| 2800/2800 [02:43<00:00, 17.13it/s]
100%|██████████| 700/700 [00:37<00:00, 18.83it/s]


Best val loss found:  0.010292403714265675
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [02:42<00:00, 17.26it/s]
100%|██████████| 700/700 [00:38<00:00, 18.26it/s]


Epoch 20, val_loss 0.0103


100%|██████████| 2800/2800 [02:49<00:00, 16.48it/s]
100%|██████████| 700/700 [00:38<00:00, 18.25it/s]
100%|██████████| 2800/2800 [02:43<00:00, 17.10it/s]
100%|██████████| 700/700 [00:36<00:00, 19.13it/s]
100%|██████████| 2800/2800 [02:42<00:00, 17.20it/s]
100%|██████████| 700/700 [00:36<00:00, 18.95it/s]


Best val loss found:  0.01015834410608347


100%|██████████| 2800/2800 [02:47<00:00, 16.77it/s]
100%|██████████| 700/700 [00:38<00:00, 18.19it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [02:45<00:00, 16.89it/s]
100%|██████████| 700/700 [00:37<00:00, 18.89it/s]
100%|██████████| 2800/2800 [02:42<00:00, 17.20it/s]
100%|██████████| 700/700 [00:37<00:00, 18.72it/s]
100%|██████████| 2800/2800 [02:46<00:00, 16.81it/s]
100%|██████████| 700/700 [00:38<00:00, 17.97it/s]


Best val loss found:  0.010099207451567054


100%|██████████| 2800/2800 [02:48<00:00, 16.59it/s]
100%|██████████| 700/700 [00:37<00:00, 18.85it/s]
100%|██████████| 2800/2800 [02:41<00:00, 17.31it/s]
100%|██████████| 700/700 [00:37<00:00, 18.83it/s]


This fold, the best val loss is:  0.010099207451567054


100%|██████████| 1000/1000 [00:55<00:00, 18.09it/s]


This fold, the test loss is:  0.011375273110577837
********************
Fold4
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [02:47<00:00, 16.69it/s]
100%|██████████| 700/700 [00:37<00:00, 18.59it/s]


Epoch 0, val_loss 0.0181
Best val loss found:  0.018064627633430065


100%|██████████| 2800/2800 [02:48<00:00, 16.65it/s]
100%|██████████| 700/700 [00:35<00:00, 19.51it/s]


Best val loss found:  0.014766037577896246


100%|██████████| 2800/2800 [02:43<00:00, 17.08it/s]
100%|██████████| 700/700 [00:35<00:00, 19.46it/s]


Best val loss found:  0.013114529664162546


100%|██████████| 2800/2800 [02:46<00:00, 16.79it/s]
100%|██████████| 700/700 [00:37<00:00, 18.85it/s]


Best val loss found:  0.011980093303136527


100%|██████████| 2800/2800 [02:49<00:00, 16.52it/s]
100%|██████████| 700/700 [00:36<00:00, 19.31it/s]


Best val loss found:  0.011719081517242427
| >>>>>                           |


100%|██████████| 2800/2800 [02:43<00:00, 17.13it/s]
100%|██████████| 700/700 [00:35<00:00, 19.59it/s]


Best val loss found:  0.011409439451859465


100%|██████████| 2800/2800 [02:42<00:00, 17.21it/s]
100%|██████████| 700/700 [00:37<00:00, 18.69it/s]


Best val loss found:  0.011255457704454394


100%|██████████| 2800/2800 [02:51<00:00, 16.35it/s]
100%|██████████| 700/700 [00:37<00:00, 18.47it/s]


Best val loss found:  0.011057352914275335


100%|██████████| 2800/2800 [02:45<00:00, 16.95it/s]
100%|██████████| 700/700 [00:36<00:00, 19.25it/s]


Best val loss found:  0.011000776669077042


100%|██████████| 2800/2800 [02:45<00:00, 16.92it/s]
100%|██████████| 700/700 [00:37<00:00, 18.83it/s]


Best val loss found:  0.01090752500841128
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [02:50<00:00, 16.42it/s]
100%|██████████| 700/700 [00:37<00:00, 18.65it/s]


Epoch 10, val_loss 0.0107
Best val loss found:  0.010745092984288931


100%|██████████| 2800/2800 [02:45<00:00, 16.93it/s]
100%|██████████| 700/700 [00:35<00:00, 19.46it/s]


Best val loss found:  0.010683152154753251


100%|██████████| 2800/2800 [02:42<00:00, 17.20it/s]
100%|██████████| 700/700 [00:35<00:00, 19.55it/s]


Best val loss found:  0.010615605743535395


100%|██████████| 2800/2800 [02:48<00:00, 16.63it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]


Best val loss found:  0.010537262173768665


100%|██████████| 2800/2800 [02:47<00:00, 16.70it/s]
100%|██████████| 700/700 [00:36<00:00, 19.16it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [02:44<00:00, 17.06it/s]
100%|██████████| 700/700 [00:36<00:00, 19.39it/s]


Best val loss found:  0.010458883281264986


100%|██████████| 2800/2800 [02:47<00:00, 16.71it/s]
100%|██████████| 700/700 [00:37<00:00, 18.69it/s]


Best val loss found:  0.010453127881212693


100%|██████████| 2800/2800 [02:48<00:00, 16.64it/s]
100%|██████████| 700/700 [00:36<00:00, 19.29it/s]
100%|██████████| 2800/2800 [02:43<00:00, 17.11it/s]
100%|██████████| 700/700 [00:36<00:00, 19.29it/s]


Best val loss found:  0.01038409523316659


100%|██████████| 2800/2800 [02:46<00:00, 16.81it/s]
100%|██████████| 700/700 [00:37<00:00, 18.58it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [02:52<00:00, 16.20it/s]
100%|██████████| 700/700 [00:37<00:00, 18.58it/s]


Epoch 20, val_loss 0.0103
Best val loss found:  0.010342211348137686


100%|██████████| 2800/2800 [02:44<00:00, 17.03it/s]
100%|██████████| 700/700 [00:36<00:00, 19.31it/s]
100%|██████████| 2800/2800 [02:44<00:00, 17.02it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]


Best val loss found:  0.010295455233925688


100%|██████████| 2800/2800 [02:49<00:00, 16.48it/s]
100%|██████████| 700/700 [00:37<00:00, 18.89it/s]


Best val loss found:  0.01029194886197469


100%|██████████| 2800/2800 [02:42<00:00, 17.26it/s]
100%|██████████| 700/700 [00:35<00:00, 19.58it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [02:43<00:00, 17.15it/s]
100%|██████████| 700/700 [00:36<00:00, 19.44it/s]
100%|██████████| 2800/2800 [02:51<00:00, 16.35it/s]
100%|██████████| 700/700 [00:37<00:00, 18.48it/s]
100%|██████████| 2800/2800 [02:46<00:00, 16.85it/s]
100%|██████████| 700/700 [00:35<00:00, 19.45it/s]
100%|██████████| 2800/2800 [02:42<00:00, 17.23it/s]
100%|██████████| 700/700 [00:36<00:00, 19.43it/s]
100%|██████████| 2800/2800 [02:48<00:00, 16.59it/s]
100%|██████████| 700/700 [00:37<00:00, 18.86it/s]


This fold, the best val loss is:  0.01029194886197469


100%|██████████| 1000/1000 [00:57<00:00, 17.28it/s]


This fold, the test loss is:  0.011538849223055876
********************
Fold5
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [02:44<00:00, 17.03it/s]
100%|██████████| 700/700 [00:37<00:00, 18.51it/s]


Epoch 0, val_loss 0.0179
Best val loss found:  0.017934446020850113


100%|██████████| 2800/2800 [02:42<00:00, 17.24it/s]
100%|██████████| 700/700 [00:38<00:00, 18.29it/s]


Best val loss found:  0.014736121089663356


100%|██████████| 2800/2800 [02:49<00:00, 16.53it/s]
100%|██████████| 700/700 [00:39<00:00, 17.72it/s]


Best val loss found:  0.012781569662703468


100%|██████████| 2800/2800 [02:45<00:00, 16.87it/s]
100%|██████████| 700/700 [00:37<00:00, 18.88it/s]


Best val loss found:  0.012008910439908505


100%|██████████| 2800/2800 [02:40<00:00, 17.39it/s]
100%|██████████| 700/700 [00:37<00:00, 18.82it/s]


Best val loss found:  0.011611996884125151
| >>>>>                           |


100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:39<00:00, 17.92it/s]


Best val loss found:  0.011300585504421699


100%|██████████| 2800/2800 [02:46<00:00, 16.85it/s]
100%|██████████| 700/700 [00:37<00:00, 18.46it/s]


Best val loss found:  0.011197374700236


100%|██████████| 2800/2800 [02:41<00:00, 17.30it/s]
100%|██████████| 700/700 [00:37<00:00, 18.61it/s]


Best val loss found:  0.01103108163091487


100%|██████████| 2800/2800 [02:46<00:00, 16.78it/s]
100%|██████████| 700/700 [00:38<00:00, 18.00it/s]


Best val loss found:  0.010950793034197496


100%|██████████| 2800/2800 [02:47<00:00, 16.69it/s]
100%|██████████| 700/700 [00:37<00:00, 18.47it/s]


Best val loss found:  0.010893745401075908
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [02:41<00:00, 17.34it/s]
100%|██████████| 700/700 [00:37<00:00, 18.77it/s]


Epoch 10, val_loss 0.0108
Best val loss found:  0.010811179349797645


100%|██████████| 2800/2800 [02:43<00:00, 17.14it/s]
100%|██████████| 700/700 [00:38<00:00, 18.18it/s]


Best val loss found:  0.010673471729138068


100%|██████████| 2800/2800 [02:47<00:00, 16.68it/s]
100%|██████████| 700/700 [00:39<00:00, 17.90it/s]


Best val loss found:  0.010662879405343639


100%|██████████| 2800/2800 [02:45<00:00, 16.90it/s]
100%|██████████| 700/700 [00:37<00:00, 18.50it/s]


Best val loss found:  0.010602559715043753


100%|██████████| 2800/2800 [02:43<00:00, 17.11it/s]
100%|██████████| 700/700 [00:38<00:00, 18.03it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [02:47<00:00, 16.73it/s]
100%|██████████| 700/700 [00:38<00:00, 18.19it/s]


Best val loss found:  0.010516178701072932


100%|██████████| 2800/2800 [02:44<00:00, 17.05it/s]
100%|██████████| 700/700 [00:37<00:00, 18.73it/s]
100%|██████████| 2800/2800 [02:40<00:00, 17.41it/s]
100%|██████████| 700/700 [00:37<00:00, 18.67it/s]
100%|██████████| 2800/2800 [02:45<00:00, 16.87it/s]
100%|██████████| 700/700 [00:39<00:00, 17.87it/s]


Best val loss found:  0.01045592778761472


100%|██████████| 2800/2800 [02:46<00:00, 16.78it/s]
100%|██████████| 700/700 [00:38<00:00, 18.23it/s]


Best val loss found:  0.010362875263339707
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [02:41<00:00, 17.34it/s]
100%|██████████| 700/700 [00:37<00:00, 18.61it/s]


Epoch 20, val_loss 0.0104


100%|██████████| 2800/2800 [02:45<00:00, 16.92it/s]
100%|██████████| 700/700 [00:38<00:00, 17.95it/s]


Best val loss found:  0.010341459120557244


100%|██████████| 2800/2800 [02:46<00:00, 16.84it/s]
100%|██████████| 700/700 [00:38<00:00, 18.30it/s]
100%|██████████| 2800/2800 [02:42<00:00, 17.25it/s]
100%|██████████| 700/700 [00:37<00:00, 18.60it/s]
100%|██████████| 2800/2800 [02:44<00:00, 17.05it/s]
100%|██████████| 700/700 [00:38<00:00, 18.17it/s]


Best val loss found:  0.010308293180340634
| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [02:49<00:00, 16.57it/s]
100%|██████████| 700/700 [00:38<00:00, 18.22it/s]
100%|██████████| 2800/2800 [02:44<00:00, 16.98it/s]
100%|██████████| 700/700 [00:37<00:00, 18.72it/s]
100%|██████████| 2800/2800 [02:42<00:00, 17.20it/s]
100%|██████████| 700/700 [00:38<00:00, 18.17it/s]


Best val loss found:  0.010248285785450468


100%|██████████| 2800/2800 [02:48<00:00, 16.59it/s]
100%|██████████| 700/700 [00:39<00:00, 17.84it/s]
100%|██████████| 2800/2800 [02:46<00:00, 16.85it/s]
100%|██████████| 700/700 [00:37<00:00, 18.50it/s]


This fold, the best val loss is:  0.010248285785450468


100%|██████████| 1000/1000 [00:56<00:00, 17.75it/s]

This fold, the test loss is:  0.011433473503333517





In [12]:
# Persist the per-fold test predictions (a list with one tensor per fold).
# Create the target directory first so a long training run is not lost to a
# missing folder.
os.makedirs('test_predict', exist_ok=True)
torch.save(test_predict_lst,'test_predict/mini_30e_004.pt')

# Last All Last

In [16]:
# patentModel.load_state_dict(torch.load('ckpt/001/best_model.pth'))

<All keys matched successfully>

In [15]:
# torch.save(patentModel.state_dict(), 'ckpt/001/best_model.pth')