In [1]:
import numpy as np
import random
from datetime import datetime
import time

import torch
import torchvision
from torch import tensor, nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from doc_sim_models import embeding_model, doc_sim_model
from doc_to_dataload import preprocess, set_dataLoader
from model_process import get_loss

from transformers import DistilBertModel, DistilBertTokenizer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch on all devices) so
# weight initialisation and any data shuffling are repeatable on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

  torch.utils._pytree._register_pytree_node(


### data load
* 구조 (document들의 리스트이며, 각 document는 문장들의 리스트):

```
[
  [[문장1], [문장2], ..., [문장n]],  # document 1개
  [[문장1], [문장2], ..., [문장n]],
  ...
  [[문장1], [문장2], ..., [문장n]]
]
```

* 적합한 sample document 데이터를 찾지 못해 문장 유사도(SICK) 데이터를 활용했다. 따라서 각 document에는 문장이 1개만 들어 있는 구조이다.

In [2]:
## data load
# Tokenizer for the uncased DistilBERT checkpoint; do_lower_case=True makes it
# lower-case text before tokenizing.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    do_lower_case=True,
)

import requests

def download_sick(f, timeout=30):
    """Load a tab-separated SICK sentence-pair file and split it into model inputs.

    Args:
        f: URL of the file (fetched with ``requests``) or a path to a local copy
            of it; a local copy lets the notebook re-run without the network.
        timeout: seconds to wait for the HTTP response (ignored for local files).

    The first line is treated as a header and skipped. Blank lines and rows that
    do not have exactly 5 tab-separated fields are dropped.

    Returns:
        ``(sentences_1, sentences_2, scores)``: two lists with one single-sentence
        "document" (``[sentence]``) per row, taken from the 2nd and 3rd columns,
        and a float32 numpy array parsed from the 4th column.

    Raises:
        requests.HTTPError: if the server answers with an HTTP error status.
    """
    import os

    if os.path.isfile(f):
        with open(f, encoding='utf-8') as fh:
            text = fh.read()
    else:
        response = requests.get(f, timeout=timeout)
        # Fail loudly instead of parsing an error page into an empty dataset.
        response.raise_for_status()
        text = response.text

    # splitlines() also strips '\r' from CRLF files, unlike split("\n").
    rows = [line.split("\t") for line in text.splitlines()[1:] if line]
    rows = [row for row in rows if len(row) == 5]

    return ([[row[1]] for row in rows],
            [[row[2]] for row in rows],
            np.array([row[3] for row in rows]).astype('float32'))
    
# Download the SICK train split and the annotated test split (used here for validation).
train_1, train_2, train_label = download_sick("https://raw.githubusercontent.com/alvations/stasis/master/SICK-data/SICK_train.txt")
valid_1, valid_2, valid_label = download_sick("https://raw.githubusercontent.com/alvations/stasis/master/SICK-data/SICK_test_annotated.txt")

# Length limit: keep only the first MAX_SAMPLES pairs of each split so the demo runs
# quickly. Set MAX_SAMPLES = None to use every pair (x[:None] keeps everything).
MAX_SAMPLES = 32
BATCH_SIZE = 16
train_1, train_2, train_label = [x[:MAX_SAMPLES] for x in (train_1, train_2, train_label)]
valid_1, valid_2, valid_label = [x[:MAX_SAMPLES] for x in (valid_1, valid_2, valid_label)]


def build_pair_loader(sents_1, sents_2, labels, loader_type='train', batch_size=BATCH_SIZE):
    """Tokenize both sides of each sentence pair and wrap them, with labels, in a loader.

    The outputs of ``preprocess`` for each side are splatted into ``set_dataLoader``
    followed by the label tensor (moved to ``device``), so each batch ends with the
    target. ``loader_type`` is forwarded as ``set_dataLoader``'s ``type`` argument.
    """
    return set_dataLoader(*preprocess(sents_1, tokenizer),
                          *preprocess(sents_2, tokenizer),
                          tensor(labels).to(device),
                          type=loader_type, batch_size=batch_size)


train_loader = build_pair_loader(train_1, train_2, train_label)
# NOTE(review): the validation loader is also built with type='train'. Check in
# doc_to_dataload.set_dataLoader whether that enables shuffling/random sampling, and
# pass the evaluation value instead if one exists.
valid_loader = build_pair_loader(valid_1, valid_2, valid_label)

## model load

In [3]:
# Sentence encoder: pretrained DistilBERT wrapped by embeding_model with mean pooling.
bert_emb = embeding_model(DistilBertModel.from_pretrained("distilbert-base-uncased"), pooling_type = 'mean')

# Freeze the first N_FROZEN_LAYERS transformer layers so they receive no gradient updates.
N_FROZEN_LAYERS = 2
for frozen_layer in bert_emb.encoder.transformer.layer[:N_FROZEN_LAYERS]:
    for param in frozen_layer.parameters():
        param.requires_grad = False

# Projection head on top of the pooled embedding:
#   Tanh, Linear(EMB_DIM->HIDDEN_DIM), BN, then N_HIDDEN_BLOCKS x [Tanh, Linear, BN],
#   then Tanh, Linear(HIDDEN_DIM->EMB_DIM), Tanh.
# Modules are created in the same order as a hand-written nn.Sequential would create
# them, so parameter names and random initialisation are unchanged.
EMB_DIM = 768        # size of the pooled embedding fed into the head
HIDDEN_DIM = 1024
N_HIDDEN_BLOCKS = 5  # number of HIDDEN_DIM -> HIDDEN_DIM Linear/BatchNorm blocks
head_layers = [nn.Tanh(), nn.Linear(EMB_DIM, HIDDEN_DIM), nn.BatchNorm1d(HIDDEN_DIM)]
for _ in range(N_HIDDEN_BLOCKS):
    head_layers += [nn.Tanh(), nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.BatchNorm1d(HIDDEN_DIM)]
head_layers += [nn.Tanh(), nn.Linear(HIDDEN_DIM, EMB_DIM), nn.Tanh()]
add_layers = nn.Sequential(*head_layers)

# Weight initializer passed to doc_sim_model (the cell output shows it reporting each
# Linear layer it initializes).
initializer = nn.init.xavier_normal_

bert_sim_model = doc_sim_model(bert_emb, add_layers = add_layers, initialize = initializer)
# Put the parameters on the same device the label tensors were placed on in the data cell.
bert_sim_model = bert_sim_model.to(device)

# With several GPUs, replicate the model that is actually optimized. Wrapping a global
# `model` instead raises NameError on a fresh kernel, or silently wraps a stale model
# left over from an earlier run of the training cell.
if torch.cuda.device_count() > 1:
    bert_sim_model = nn.DataParallel(bert_sim_model)

# Frozen parameters have requires_grad=False and therefore receive no gradients.
# NOTE(review): eps=1e-16 is far below AdamW's default (1e-8) and weight_decay=0.4 is
# large; confirm both are intentional tuning choices.
optimizer = AdamW(bert_sim_model.parameters(), lr=1e-5, eps = 1e-16, weight_decay = 0.4)
# Halve the LR once validation loss has not improved for more than `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

initial_lize :  Linear(in_features=768, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=1024, bias=True)
initial_lize :  Linear(in_features=1024, out_features=768, bias=True)




## model training

In [4]:
# Training loop: ReduceLROnPlateau on validation loss, best-checkpoint saving,
# patience-based early stopping and a wall-clock time budget.
best_loss = float('inf')  # the first finite validation loss always counts as an improvement
stopping_count = 0        # epochs since the last validation-loss improvement (early stopping)
epochs = 5
early_stopping = 3        # patience: stop after this many epochs without improvement
max_train_seconds = 7000  # stop once training has been running longer than this
model = bert_sim_model
opt = optimizer
sch = scheduler
criterion = nn.MSELoss()
path = './doc_sim_model_' + datetime.now().strftime('%Y_%m_%d') + '.pt'


def _run_epoch(model, loader, criterion, opt=None):
    """Run one pass over ``loader`` and return ``sum(losses) / sum(nums)``.

    Each batch is ``(*model_inputs, target)``. ``get_loss`` returns ``(loss, num)``
    per batch. When ``opt`` is given, the model is put in train mode and ``opt`` is
    forwarded to ``get_loss`` (which presumably performs the backward/step; see
    model_process). Otherwise the model is put in eval mode and autograd is disabled,
    because validation needs no gradients.
    """
    training = opt is not None
    model.train(training)
    loss_kwargs = {'opt': opt} if training else {}
    losses, nums = [], []
    with torch.set_grad_enabled(training):
        for b in loader:
            logits = model(*b[:-1])
            loss, num = get_loss(logits, b[-1], criterion, **loss_kwargs)
            losses.append(loss)
            nums.append(num)
    return np.sum(losses) / np.sum(nums)


start = time.time()  # start time; training stops early if it runs past max_train_seconds
for epoch in range(epochs):
    train_loss = _run_epoch(model, train_loader, criterion, opt=opt)
    val_loss = _run_epoch(model, valid_loader, criterion)

    if sch is not None:
        sch.step(val_loss)
    print("training loss : ", train_loss, "validation loss : ", val_loss, " & time : ", time.time() - start, " & epoch : ", epoch+1)

    ## best model save
    # NOTE(review): torch.save(model) pickles the whole module; saving
    # model.state_dict() is more portable. After the loop, `model` holds the
    # last-epoch weights, not the best checkpoint written to `path`.
    if val_loss < best_loss:
        torch.save(model, path)
        print('best model save at : ' + path)
        best_loss = val_loss
        stopping_count = 0
    else:
        stopping_count += 1

    ## early stopping / time budget
    if stopping_count >= early_stopping:
        break
    if time.time() - start > max_train_seconds:
        break

training loss :  0.7414435148239136 validation loss :  0.564044177532196  & time :  10.358126401901245  & epoch :  1
best model save at : ./doc_sim_model_2024_03_23.pt
training loss :  0.7148368656635284 validation loss :  0.5924107134342194  & time :  21.827714681625366  & epoch :  2
training loss :  0.6821464002132416 validation loss :  0.6095966696739197  & time :  31.62321424484253  & epoch :  3
training loss :  0.671469509601593 validation loss :  0.6164752542972565  & time :  41.12640309333801  & epoch :  4
