## Get param kenrel and bias

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
import scipy.stats

TRAIN_PATH = '../data/mini_train_data.csv'
TEST_PATH = '../data/mini_test_data.csv' # is not uesd

MODEL_NAME = "anferico/bert-for-patents"

POOLING = 'first_last_avg'
# POOLING = 'last_avg'
# POOLING = 'last2avg'

USE_WHITENING = True
N_COMPONENTS = 384
MAX_LENGTH = 512

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_dataset(path):
    sent_batch = []
    test_data =pd.read_csv(path)
    for sent in test_data['text']:
        sent_batch.append(sent[3:])
    return sent_batch

def build_model(name):
    tokenizer = BertTokenizer.from_pretrained(name)
    model = BertModel.from_pretrained(name)
    model = model.to(DEVICE)
    return tokenizer, model


def sents_to_vecs(sents, tokenizer, model):
    vecs = []
    with torch.no_grad():
        for sent in tqdm(sents):
            inputs = tokenizer(sent, return_tensors="pt", padding=True, truncation=True,  max_length=MAX_LENGTH)
            inputs['input_ids'] = inputs['input_ids'].to(DEVICE)
            inputs['token_type_ids'] = inputs['token_type_ids'].to(DEVICE)
            inputs['attention_mask'] = inputs['attention_mask'].to(DEVICE)

            hidden_states = model(**inputs, return_dict=True, output_hidden_states=True).hidden_states

            if POOLING == 'first_last_avg':
                output_hidden_state = (hidden_states[-1] + hidden_states[1]).mean(dim=1)
            elif POOLING == 'last_avg':
                output_hidden_state = (hidden_states[-1]).mean(dim=1)
            elif POOLING == 'last2avg':
                output_hidden_state = (hidden_states[-1] + hidden_states[-2]).mean(dim=1)
            else:
                raise Exception("unknown pooling {}".format(POOLING))

            vec = output_hidden_state.cpu().numpy()[0]
            vecs.append(vec)
    assert len(sents) == len(vecs)
    vecs = np.array(vecs)
    return vecs


def calc_spearmanr_corr(x, y):
    return scipy.stats.spearmanr(x, y).correlation


def compute_kernel_bias(vecs, n_components):
    """计算kernel和bias
    最后的变换：y = (x + bias).dot(kernel)
    """
    vecs = np.concatenate(vecs, axis=0)
    mu = vecs.mean(axis=0, keepdims=True)
    cov = np.cov(vecs.T)
    u, s, vh = np.linalg.svd(cov)
    W = np.dot(u, np.diag(s**0.5))
    W = np.linalg.inv(W.T)
    W = W[:, :n_components]
    return W, -mu


def transform_and_normalize(vecs, kernel, bias):
    """应用变换，然后标准化
    """
    if not (kernel is None or bias is None):
        vecs = (vecs + bias).dot(kernel)
    return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5


def normalize(vecs):
    """标准化
    """
    return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5


print(f"Configs: {MODEL_NAME}-{POOLING}-{USE_WHITENING}-{N_COMPONENTS}.")

a_sents_train = load_dataset(TRAIN_PATH)
print("Loading {} training samples from {}".format(len(a_sents_train), TRAIN_PATH))


tokenizer, model = build_model(MODEL_NAME)
print("Building {} tokenizer and model successfuly.".format(MODEL_NAME))

print("Transfer sentences to BERT vectors.")

if USE_WHITENING:
    a_vecs_train = sents_to_vecs(a_sents_train, tokenizer, model)

    print("Compute kernel and bias.")
    kernel, bias = compute_kernel_bias([
        a_vecs_train
    ], n_components=N_COMPONENTS)


Configs: anferico/bert-for-patents-first_last_avg-True-384.
Loading 14000 training samples from ../data/mini_train_data.csv


Some weights of the model checkpoint at anferico/bert-for-patents were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Building anferico/bert-for-patents tokenizer and model successfuly.
Transfer sentences to BERT vectors.


100%|██████████| 14000/14000 [06:01<00:00, 38.72it/s]


Compute kernel and bias.


# Main Model

In [56]:
import pandas as pd
import pickle
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AdamW
from operator import itemgetter
from sklearn.model_selection import StratifiedKFold

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.cuda.set_device(0)

In [57]:
kernel = torch.tensor(kernel,dtype=torch.float32).cuda()
bias = torch.tensor(bias,dtype=torch.float32).cuda()

  kernel = torch.tensor(kernel,dtype=torch.float32).cuda()
  bias = torch.tensor(bias,dtype=torch.float32).cuda()


In [58]:
label2id = pickle.load(open('../temp_results/mini_label2id_dict.pkl','rb'))
id2label = pickle.load(open('../temp_results/mini_id2label_lst.pkl','rb'))

In [59]:
train_data = pd.read_csv('../data/mini_train_data.csv')
test_data = pd.read_csv('../data/mini_test_data.csv')

In [60]:
from transformers import AutoModelForMaskedLM,AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained("anferico/bert-for-patents")
tokenizer = AutoTokenizer.from_pretrained("anferico/bert-for-patents")

Some weights of the model checkpoint at anferico/bert-for-patents were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data Loader

In [61]:
def str2id_lst(str_label):
    id_lst = []
    for l in str_label.split(','):
        id_lst.append(label2id[l])
    return id_lst

class PatentDataset(Dataset):
    def __init__(self,df,labeled = True):
        self.df = df
        self.labeled = labeled
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self,idx):
        text = self.df.iloc[idx]['text'][3:]
        label = str2id_lst(self.df.iloc[idx]['cpc_ids'])
        
        if self.labeled:
            return text,label
        else:
            return text,None
        

In [62]:
test_dataset = PatentDataset(test_data)

In [63]:
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                       truncation=True,
                                       padding='max_length',
                                       max_length=500,
                                       return_tensors='pt',
                                       return_length=True)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    
    batch_label = np.zeros((len(labels),len(id2label)))
    for i,_label in enumerate(labels):
        batch_label[i,_label]=1
    
    batch_label = torch.tensor(batch_label,dtype=torch.float32)
    
    return input_ids, attention_mask, token_type_ids, batch_label
    

In [64]:
test_dataloader = DataLoader(dataset = test_dataset,
                             batch_size = 4,
                             collate_fn = collate_fn)

# Define Model

In [65]:
class PatentClsModel(nn.Module):
    def __init__(self,bert_model,backbone_fixed = True):
        super().__init__()
        self.fc = nn.Sequential(nn.BatchNorm1d(384),
                                nn.Dropout(0.5),
                                nn.Linear(384,768),
                                nn.ReLU(),
                                nn.BatchNorm1d(768),
                                nn.Dropout(0.5),
                                nn.Linear(768,len(id2label)))
        
        self.bert_model = bert_model
        self.sig = nn.Sigmoid()
        self.backbone_fixed = backbone_fixed
        
        for i, module in enumerate(self.fc):
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    assert model[i].weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, input_ids, attention_mask, token_type_ids):
        if self.backbone_fixed:
            with torch.no_grad():
                hidden = self.bert_model(input_ids = input_ids,
                                         attention_mask = attention_mask,
                                         token_type_ids = token_type_ids,
                                         output_hidden_states=True).hidden_states
                x = (hidden[-1] + hidden[1]).mean(dim=1)
                x = torch.mm(x+bias,kernel)
                
        else:
            hidden = self.bert_model(input_ids = input_ids,
                                     attention_mask = attention_mask,
                                     token_type_ids = token_type_ids,
                                     output_hidden_states=True).hidden_states[-1]
            x = (hidden[-1] + hidden[1]).mean(dim=1)
            x = torch.mm(x+bias,kernel)
            
        x = self.fc(x)
        x = self.sig(x)
        
        return x
        

# Training

In [66]:
kfold = StratifiedKFold(n_splits=5)
total_epochs = 30
test_predict_lst = []

In [None]:
from tqdm import tqdm
for train_index, valid_index in kfold.split(train_data,train_data['cpc_ids']):
    
    print('*'*20)
    print(f'Fold{len(test_predict_lst)+1}')
    print('*'*20)
    train_dataset = PatentDataset(train_data.iloc[train_index])
    val_dataset = PatentDataset(train_data.iloc[valid_index])

    train_dataloader = DataLoader(train_dataset,
                                 collate_fn = collate_fn,
                                 batch_size = 4,
                                 shuffle = True,
                                 drop_last = True)
    val_dataloader = DataLoader(val_dataset,
                               collate_fn = collate_fn,
                               batch_size = 4,
                               shuffle = True,
                               drop_last = True)

    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    loss_func = nn.BCELoss()
    optimizer = AdamW(patentModel.parameters(), lr=5e-4)
    # reg_lambda = 0.035

    print('Dataloader Success---------------------')

    best_val_loss = 100
    for epoch in range(total_epochs):
        if epoch%5==0:
            print('|',">" * epoch," "*(total_epochs-epoch),'|')

        patentModel.train()
        for iter,(input_ids, attention_mask, token_type_ids, batch_label) in enumerate(tqdm(train_dataloader)):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()

            prediction = patentModel(input_ids, attention_mask, token_type_ids)
            
            # l2_reg = None
            # for w in patentModel.fc.parameters():
            #     if not l2_reg:
            #         l2_reg = w.norm(2)
            #     else:
            #         l2_reg = l2_reg + w.norm(2)

            loss = loss_func(prediction,batch_label)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss = 0
        patentModel.eval()
        with torch.no_grad():
            for iter,(input_ids, attention_mask, token_type_ids, batch_label) in enumerate(tqdm(val_dataloader)):
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()
                token_type_ids = token_type_ids.cuda()
                batch_label = batch_label.cuda()
                prediction = patentModel(input_ids, attention_mask, token_type_ids)
                loss = loss_func(prediction,batch_label)
                val_loss += loss.detach().item()
            val_loss = val_loss/(iter+1)

        if epoch%10 == 0:
            print('Epoch {}, val_loss {:.4f}'.format(epoch, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(patentModel.state_dict(), 'ckpt/best_model_mini_{}.pth'.format(len(test_predict_lst)+1))
            print('Best val loss found: ', best_val_loss)

    print('This fold, the best val loss is: ', best_val_loss)

    test_loss = 0
    test_predict = None
    patentModel = PatentClsModel(model,backbone_fixed = True).cuda()
    patentModel.load_state_dict(torch.load('ckpt/best_model_mini_{}.pth'.format(len(test_predict_lst)+1)))

    patentModel.eval()
    with torch.no_grad():
        for iter,(input_ids, attention_mask, token_type_ids, batch_label) in enumerate(tqdm(test_dataloader)):
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
            batch_label = batch_label.cuda()
            prediction = patentModel(input_ids, attention_mask, token_type_ids)

            if test_predict is None:
                test_predict = prediction
            else:
                test_predict = torch.cat((test_predict,prediction),axis = 0)

            loss = loss_func(prediction,batch_label)
            test_loss += loss.detach().item()

    test_loss /= (iter+1)
    print('This fold, the test loss is: ', test_loss)

    test_predict_lst.append(test_predict)



********************
Fold1
********************




Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:53<00:00, 13.09it/s]


Epoch 0, val_loss 0.0160
Best val loss found:  0.01600022033854787


100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:53<00:00, 13.13it/s]


Best val loss found:  0.013929396279355776


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.012648805676892931


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.01195431825605088


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.011478564590215684
| >>>>>                           |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.09it/s]


Best val loss found:  0.011311318645187255


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.011167860065387296


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.10it/s]


Best val loss found:  0.01093898075067305


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.010883142806789172


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.010741989121306687
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Epoch 10, val_loss 0.0107


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.010583799209645283


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.11it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.01053100017863991


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.25it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.33it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.25it/s]


Best val loss found:  0.010492515246649937


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.010439581741312785
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Epoch 20, val_loss 0.0104
Best val loss found:  0.010417796136212668


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.31it/s]


Best val loss found:  0.010391077890859118


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.33it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.010338177170404899
| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.21it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Best val loss found:  0.01033579321445099


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.21it/s]


Best val loss found:  0.010292984734488917


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.16it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.31it/s]


This fold, the best val loss is:  0.010292984734488917


100%|██████████| 1000/1000 [01:15<00:00, 13.33it/s]


This fold, the test loss is:  0.011685255819931626
********************
Fold2
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]


Epoch 0, val_loss 0.0158
Best val loss found:  0.015789745509890575


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]


Best val loss found:  0.013707554676636521


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.012675422957566167


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.17it/s]


Best val loss found:  0.011635177424177528


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.17it/s]


Best val loss found:  0.01152524605094056
| >>>>>                           |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.011104350934869476


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.13it/s]


Best val loss found:  0.010955703366281731


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.16it/s]


Best val loss found:  0.010768274823868914


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.17it/s]


Best val loss found:  0.010683475171348879


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.15it/s]


Epoch 10, val_loss 0.0105
Best val loss found:  0.010539268490392715


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.010397026047243603


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.010392519424536398


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.01036206041629027


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


Best val loss found:  0.010263539313870881
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]


Best val loss found:  0.010248394003803176


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.010216431538907013


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]


| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Epoch 20, val_loss 0.0102
Best val loss found:  0.010209044092667423


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.31it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


This fold, the best val loss is:  0.010209044092667423


100%|██████████| 1000/1000 [01:15<00:00, 13.27it/s]


This fold, the test loss is:  0.0117582929818891
********************
Fold3
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


Epoch 0, val_loss 0.0160
Best val loss found:  0.016047761175515395


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]


Best val loss found:  0.01413692626935829


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Best val loss found:  0.012602954499889164


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.012164422183663452


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.011760309334910873
| >>>>>                           |


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]


Best val loss found:  0.011258646007627248


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.011244842478938933


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]


Best val loss found:  0.011073156288226268


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.010959717727465821


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]


Best val loss found:  0.01077752781599494
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Epoch 10, val_loss 0.0108
Best val loss found:  0.010768000329013115


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]


Best val loss found:  0.010688310592834439


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.010674779733443367


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.010627497877326928


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.010554223964323423


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]


Best val loss found:  0.010552543924589243


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]


Best val loss found:  0.010524676720877844


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.010430415808555802
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Epoch 20, val_loss 0.0103
Best val loss found:  0.010295102148915509


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [03:34<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.21it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:53<00:00, 13.13it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


This fold, the best val loss is:  0.010295102148915509


100%|██████████| 1000/1000 [01:16<00:00, 13.10it/s]


This fold, the test loss is:  0.01175575838109944
********************
Fold4
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [03:34<00:00, 13.05it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Epoch 0, val_loss 0.0165
Best val loss found:  0.016519928939440952


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]


Best val loss found:  0.01429164373781532


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.012958872345874884


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.01232363181986979


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]


Best val loss found:  0.011897136773581483
| >>>>>                           |


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]


Best val loss found:  0.011684112048741164


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.011616864977404474


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]


Best val loss found:  0.011295678362117282


100%|██████████| 2800/2800 [03:35<00:00, 13.02it/s]
100%|██████████| 700/700 [00:53<00:00, 13.15it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.23it/s]


Best val loss found:  0.011109176030648605
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]


Epoch 10, val_loss 0.0110
Best val loss found:  0.011003736722216542


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.010972083457745611


100%|██████████| 2800/2800 [03:34<00:00, 13.05it/s]
100%|██████████| 700/700 [00:53<00:00, 13.18it/s]


Best val loss found:  0.010962884435430168


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:53<00:00, 13.16it/s]


Best val loss found:  0.010893025687962238


100%|██████████| 2800/2800 [03:34<00:00, 13.05it/s]
100%|██████████| 700/700 [00:53<00:00, 13.21it/s]


Best val loss found:  0.010789830614812672
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:52<00:00, 13.22it/s]


Best val loss found:  0.010725835573061237


100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:53<00:00, 13.16it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.04it/s]
100%|██████████| 700/700 [00:53<00:00, 13.16it/s]
100%|██████████| 2800/2800 [03:34<00:00, 13.03it/s]
100%|██████████| 700/700 [00:52<00:00, 13.21it/s]


Best val loss found:  0.01069061292268868


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.07it/s]


Best val loss found:  0.010639341691109751
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.17it/s]


Epoch 20, val_loss 0.0107


100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:53<00:00, 13.13it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.010556194866741342
| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.11it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.10it/s]


Best val loss found:  0.010492475037130394


100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.10it/s]


Best val loss found:  0.010448150972868981


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.15it/s]


This fold, the best val loss is:  0.010448150972868981


100%|██████████| 1000/1000 [01:15<00:00, 13.16it/s]


This fold, the test loss is:  0.011557566983508877
********************
Fold5
********************
Dataloader Success---------------------
|                                 |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.06it/s]


Epoch 0, val_loss 0.0169
Best val loss found:  0.016878496627655944


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.10it/s]


Best val loss found:  0.01438886555443917


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.012987098001675414


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.07it/s]


Best val loss found:  0.012506689419304686


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.11it/s]


Best val loss found:  0.0119824070283877
| >>>>>                           |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.011665181399356307


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.011450736456983057


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.19it/s]


Best val loss found:  0.011278337622061372


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]


Best val loss found:  0.011159405715464215


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.25it/s]


Best val loss found:  0.011034468421712518
| >>>>>>>>>>                      |


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Epoch 10, val_loss 0.0110
Best val loss found:  0.010990405014682828


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.08it/s]


Best val loss found:  0.01084767223302541


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.13it/s]
100%|██████████| 2800/2800 [03:35<00:00, 12.98it/s]
100%|██████████| 700/700 [00:53<00:00, 13.07it/s]


Best val loss found:  0.010737578336349023
| >>>>>>>>>>>>>>>                 |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.31it/s]


Best val loss found:  0.010624523154692724


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.30it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.24it/s]


Best val loss found:  0.010616425466391124


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


Best val loss found:  0.010529343752922225
| >>>>>>>>>>>>>>>>>>>>            |


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.27it/s]


Epoch 20, val_loss 0.0106


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.26it/s]


Best val loss found:  0.010501842426601798


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.29it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:52<00:00, 13.28it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


| >>>>>>>>>>>>>>>>>>>>>>>>>       |


100%|██████████| 2800/2800 [03:35<00:00, 12.99it/s]
100%|██████████| 700/700 [00:53<00:00, 13.14it/s]


Best val loss found:  0.010424756059822227


100%|██████████| 2800/2800 [03:35<00:00, 13.01it/s]
100%|██████████| 700/700 [00:53<00:00, 13.09it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.12it/s]


Best val loss found:  0.010409943605440536


100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
100%|██████████| 700/700 [00:53<00:00, 13.09it/s]
100%|██████████| 2800/2800 [03:35<00:00, 13.00it/s]
 86%|████████▌ | 603/700 [00:45<00:07, 13.10it/s]

In [None]:
torch.save(test_predict_lst,'test_result/mini_30e_004.pt')

In [None]:
import os 
os.system("shutdown")

# Last All Last

In [16]:
# patentModel.load_state_dict(torch.load('ckpt/001/best_model.pth'))

<All keys matched successfully>

In [15]:
# torch.save(patentModel.state_dict(), 'ckpt/001/best_model.pth')