In [1]:
import numpy as np 
import pandas as pd 
import json
from transformers import * 
import os
from tqdm.auto import tqdm 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler 
import pickle 
import time 



In [2]:
with open("qa_json(full).json") as f: 
    d = json.load(f) 

In [3]:
queries, passages, answers = [], [], [] 

for i in tqdm(range(len(d)), position=0, leave=True): 
    query = d[i]["query"] 
    passage = d[i]["passages"]
    answer = d[i]["answer"] 
    queries.append(query) 
    passages.append(passage[0])  
    answers.append(answer) 

  0%|          | 0/6786 [00:00<?, ?it/s]

In [4]:
cnt = 0
for i in range(len(answers)): 
    if answers[i] in ["소극", "적극"]: 
        cnt += 1 
        
exists = cnt / len(answers) * 100.0 

print(exists), cnt

37.50368405540819


(None, 2545)

In [5]:
all_data = pd.DataFrame(list(zip(queries, passages, answers)), columns=["queries", "passages", "answers"])

all_data.head()  

Unnamed: 0,queries,passages,answers
0,설립준비중의 재단법인에 대한 재산의 출연,재단법인의 설립준비중 제3자가 그 설립자에 대하여 장차 설립될 동 법인에 설립을 조...,
1,수임인이 위임사무를 처리함에 있어 받은 물건으로 위임인에게 인도한 목적물은 그것이 ...,집행불능시의 대상청구속에는 예비적으로 이행불능시의 전보배상청구도 포함된 것으로 보고...,
2,무권대리인의 상대방이 갖는 계약의 이행 또는 손해배상청구권의 소멸시효의 기산점,타인의 대리인으로 계약을 한 자가 그 대리권을 증명하지 못하고 또 본인의 추인을 얻...,
3,사실혼관계에 있는 당사자의 일방이 모르는 사이에 혼인 신고가 이루어진후 쌍방 당사자...,본법 제139조는 재산법에 관한 총칙규정이고 신분법에 관하여는 그대로 통 용될 ...,
4,"남편 소유의 부동산 매각과, 아내의 일상 가사 대리권의 한계",부부간의 일상가사대리권은 그 동거생활을 추지하기 위하여 각각 필요한 범위내의 법률행...,


In [6]:
query_index_dict = {} 

for i in range(len(queries)): 
    query_index_dict[queries[i]] = [] 
    
for i in range(len(queries)): 
    query_index_dict[queries[i]].append(i) 

In [7]:
model_name = "monologg/kobigbird-bert-base" 
q_tokenizer = AutoTokenizer.from_pretrained(model_name) 
p_tokenizer = AutoTokenizer.from_pretrained(model_name)  

loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/vocab.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/special_tokens_map.json
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/tokenizer_config.json
loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/vocab.txt
loading file tokenizer.json from cache at /root/.cache/huggingf

In [8]:
train_size = int(0.8 * all_data.shape[0]) 
val_size = int(0.1 * all_data.shape[0])  

train_df = all_data.iloc[:train_size] 
val_df = all_data.iloc[train_size:train_size+val_size] 
test_df = all_data.iloc[train_size+val_size:]  

train_df.shape, val_df.shape, test_df.shape 

((5428, 3), (678, 3), (680, 3))

In [9]:
train_questions, train_candidates = train_df["queries"].values, train_df["passages"].values 

q_input_ids, q_attn_masks = [], [] 
for i in tqdm(range(len(train_questions)), position=0, leave=True): 
    encoded_inputs = q_tokenizer(str(train_questions[i]), max_length=512, truncation=True, padding="max_length") 
    q_input_ids.append(encoded_inputs["input_ids"]) 
    q_attn_masks.append(encoded_inputs["attention_mask"]) 

c_input_ids, c_attn_masks = [], [] 
for i in tqdm(range(len(train_candidates)), position=0, leave=True): 
    encoded_inputs = p_tokenizer(str(train_candidates[i]), max_length=512, truncation=True, padding="max_length") 
    c_input_ids.append(encoded_inputs["input_ids"]) 
    c_attn_masks.append(encoded_inputs["attention_mask"]) 
    
q_input_ids = torch.tensor(q_input_ids, dtype=int) 
q_attn_masks = torch.tensor(q_attn_masks, dtype=int) 
c_input_ids = torch.tensor(c_input_ids, dtype=int) 
c_attn_masks = torch.tensor(c_attn_masks, dtype=int) 

q_input_ids.shape, q_attn_masks.shape, c_input_ids.shape, c_attn_masks.shape 

  0%|          | 0/5428 [00:00<?, ?it/s]

  0%|          | 0/5428 [00:00<?, ?it/s]

(torch.Size([5428, 512]),
 torch.Size([5428, 512]),
 torch.Size([5428, 512]),
 torch.Size([5428, 512]))

In [10]:
# prepare all candidate ids
candidates = all_data["passages"].values 

all_context_input_ids, all_context_attn_masks = [], [] 
for i in tqdm(range(len(candidates)), position=0, leave=True):
    encoded_inputs = p_tokenizer(str(candidates[i]), max_length=512, truncation=True, padding="max_length") 
    all_context_input_ids.append(encoded_inputs["input_ids"]) 
    all_context_attn_masks.append(encoded_inputs["attention_mask"]) 
     
all_context_input_ids = torch.tensor(all_context_input_ids, dtype=int) 
all_context_attn_masks = torch.tensor(all_context_attn_masks, dtype=int)  

all_context_input_ids.shape, all_context_attn_masks.shape 

  0%|          | 0/6786 [00:00<?, ?it/s]

(torch.Size([6786, 512]), torch.Size([6786, 512]))

In [11]:
best_recall = 0 

train_dataset = TensorDataset(q_input_ids, q_attn_masks, c_input_ids, c_attn_masks) 
train_sampler = RandomSampler(train_dataset) 
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=32) 

num_train_epochs = 10 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  

q_encoder = AutoModel.from_pretrained("monologg/kobigbird-bert-base") 
q_encoder.to(device) 
checkpoint = torch.load("KoBigBird_query_encoder__.pt")
q_encoder.load_state_dict(checkpoint) 

p_encoder = AutoModel.from_pretrained("monologg/kobigbird-bert-base") 
p_encoder.to(device) 
checkpoint = torch.load("KoBigBird_passage_encoder__.pt")
p_encoder.load_state_dict(checkpoint) 

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/config.json
Model config BigBirdConfig {
  "_name_or_path": "monologg/kobigbird-bert-base",
  "architectures": [
    "BigBirdForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_type": "block_sparse",
  "block_size": 64,
  "bos_token_id": 5,
  "classifier_dropout": null,
  "eos_token_id": 6,
  "gradient_checkpointing": false,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 4096,
  "model_type": "big_bird",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_random_blocks": 3,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "rescale_embeddings": false,
  "sep_token_id": 3,
  "tokenizer_class": "BertTokenizer",
  "torch_dtype": "floa

<All keys matched successfully>

In [12]:
params = list(q_encoder.parameters()) + list(p_encoder.parameters())  
optimizer = AdamW(params, lr=2e-5, eps=1e-8)  
t_total = len(train_dataloader) * num_train_epochs 
q_encoder.zero_grad()  
p_encoder.zero_grad()  
torch.cuda.empty_cache() 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.05 * t_total), num_training_steps=t_total) 



In [13]:
for epoch_i in tqdm(range(0, num_train_epochs), desc="Epochs", position=0, leave=True, total=num_train_epochs): 
    train_loss = 0
    q_encoder.train() 
    p_encoder.train() 
    with tqdm(train_dataloader, unit="batch") as tepoch: 
        for step, batch in enumerate(tepoch): 
            batch = tuple(t.to(device) for t in batch) 
            bq_input_ids, bq_attn_masks, bp_input_ids, bp_attn_masks = batch 
            q_outputs = q_encoder(bq_input_ids, bq_attn_masks).pooler_output 
            p_outputs = p_encoder(bp_input_ids, bp_attn_masks).pooler_output 
            # (batch_size * embedding_dims) * (embedding_dims * batch_size) 
            sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1)) 
            targets = torch.arange(0, q_outputs.shape[0]).long().to(device) 
            sim_scores = F.log_softmax(sim_scores, dim=1) 
            loss = F.nll_loss(sim_scores, targets)  
            train_loss += loss.item() 
            loss.backward() 
            optimizer.step() 
            scheduler.step() 
            q_encoder.zero_grad() 
            p_encoder.zero_grad()  
            tepoch.set_postfix(loss=train_loss / (step+1))  
            time.sleep(0.1) 
    avg_train_loss = train_loss / len(train_dataloader)  
    print(f"avg train loss : {avg_train_loss}") 
    
    val_loss = 0
    print("validating") 
    with torch.no_grad(): 
        p_encoder.eval() 
        p_embs = [] 
        inference_dataset = TensorDataset(all_context_input_ids, all_context_attn_masks)
        inference_sampler = SequentialSampler(inference_dataset) 
        inference_dataloader = DataLoader(inference_dataset, sampler=inference_sampler, batch_size=64)
        for step, batch in tqdm(enumerate(inference_dataloader), position=0, leave=True, total=len(inference_dataloader), desc="calculating candidate embeddings with current model"):
            batch = (t.to(device) for t in batch) 
            b_input_ids, b_attn_masks = batch 
            p_emb = p_encoder(b_input_ids, b_attn_masks).pooler_output 
            for i in range(p_emb.shape[0]): 
                p_embs.append(torch.reshape(p_emb[i], (-1, 768))) 
        # p_embs = torch.Tensor(p_embs).squeeze() 
        p_embs = torch.cat(p_embs, dim=0) 
        print(f"candidate embeddings shape: {p_embs.shape}")  
        top_50 = 0 
        q_encoder.eval()
        val_questions = val_df["queries"].values 
        for sample_idx in tqdm(range(len(val_questions)), position=0, leave=True, desc="calculating Recall@50"):
            query = val_questions[sample_idx] 
            encoded_query = q_tokenizer(str(query), max_length=512, truncation=True, padding="max_length", return_tensors="pt").to(device) 
            q_emb = q_encoder(**encoded_query).pooler_output 
            dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1)) 
            rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
            correct_idx = query_index_dict[query] # we only have a single correct index in this case  
            cnt = 0
            for idx in correct_idx: 
                if idx in rank[:50]: 
                    cnt += 1 
            top_50 += cnt / len(correct_idx) 
    avg_top50 = top_50 / len(val_questions) 
    print(f"mean recall@50 : {avg_top50}")
    if avg_top50 > best_recall: 
        print("saving best checkpoint so far!") 
        best_recall = avg_top50 
        torch.save(q_encoder.state_dict(), "large_law_KoBigBird_query_encoder.pt") 
        torch.save(p_encoder.state_dict(), "large_law_KoBigBird_passage_encoder.pt") 

print(f"best recall: {best_recall}") 
print("done training!") 

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/170 [00:00<?, ?batch/s]

Attention type 'block_sparse' is not possible if sequence_length: 512 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
Attention type 'block_sparse' is not possible if sequence_length: 512 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


avg train loss : 0.7241659341927837
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9800884955752213
saving best checkpoint so far!


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.4195954817292445
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9786135693215339


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.3875195291112451
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9800884955752213


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.3633478946744136
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9771386430678466


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.3353768384152585
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9771386430678466


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.31391136891701643
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9756637168141593


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.28880718599128374
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9771386430678466


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.2767469352703569
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9786135693215339


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.24528285970651162
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9771386430678466


  0%|          | 0/170 [00:00<?, ?batch/s]

avg train loss : 0.24846477454240598
validating


calculating candidate embeddings with current model:   0%|          | 0/107 [00:00<?, ?it/s]

candidate embeddings shape: torch.Size([6786, 768])


calculating Recall@50:   0%|          | 0/678 [00:00<?, ?it/s]

mean recall@50 : 0.9771386430678466
best recall: 0.9800884955752213
done training!


In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

best_q_chkpt = torch.load("large_law_KoBigBird_query_encoder.pt")
best_p_chkpt = torch.load("large_law_KoBigBird_passage_encoder.pt")

q_encoder = AutoModel.from_pretrained("monologg/kobigbird-bert-base") 
q_encoder.to(device) 
print(q_encoder.load_state_dict(best_q_chkpt)) 

p_encoder = AutoModel.from_pretrained("monologg/kobigbird-bert-base") 
p_encoder.to(device) 
print(p_encoder.load_state_dict(best_p_chkpt)) 

print()  

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--monologg--kobigbird-bert-base/snapshots/ceacda477e20abef2c929adfa4a07c6f811323be/config.json
Model config BigBirdConfig {
  "_name_or_path": "monologg/kobigbird-bert-base",
  "architectures": [
    "BigBirdForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_type": "block_sparse",
  "block_size": 64,
  "bos_token_id": 5,
  "classifier_dropout": null,
  "eos_token_id": 6,
  "gradient_checkpointing": false,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 4096,
  "model_type": "big_bird",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_random_blocks": 3,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "rescale_embeddings": false,
  "sep_token_id": 3,
  "tokenizer_class": "BertTokenizer",
  "torch_dtype": "floa

<All keys matched successfully>


Some weights of the model checkpoint at monologg/kobigbird-bert-base were not used when initializing BigBirdModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BigBirdModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BigBirdModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of BigBirdModel were initialized from the model checkpoint at monologg/kobigbird-bert-b

<All keys matched successfully>



In [17]:
top50_answers = []

with torch.no_grad(): 
    p_encoder.eval() 
    q_encoder.eval() 
    p_embs = [] 

    inference_dataset = TensorDataset(all_context_input_ids, all_context_attn_masks) 
    inference_sampler = SequentialSampler(inference_dataset)
    inference_dataloader = DataLoader(inference_dataset, sampler=inference_sampler, batch_size=32) 

    for step, batch in tqdm(enumerate(inference_dataloader), position=0, leave=True, total=len(inference_dataloader)): 
        batch = (t.to(device) for t in batch)
        b_input_ids, b_attn_masks = batch 
        p_emb = p_encoder(b_input_ids, b_attn_masks).pooler_output 
        for i in range(p_emb.shape[0]): 
            p_embs.append(torch.reshape(p_emb[i], (-1, 768))) 

    p_embs = torch.cat(p_embs, dim=0)
    print(f"candidate embeddings shape : {p_embs.shape}") 
    
    top_1, top_5, top_10, top_20, top_50 = 0, 0, 0, 0, 0 
    test_questions = test_df["queries"].values 
    for sample_idx in tqdm(range(len(test_questions)), position=0, leave=True, desc="Calculating Recall"):
        query = test_questions[sample_idx] 
        encoded_query = q_tokenizer(str(query), max_length=512, truncation=True, padding="max_length", return_tensors="pt").to(device) 
        q_emb = q_encoder(**encoded_query).pooler_output
        dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1)) 
        rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze() 
        top50_answers.append(rank[:50]) 
        correct_idx = query_index_dict[query] 
        
        cnt = 0 
        for idx in correct_idx: 
            if idx in rank[:1]: 
                cnt += 1 
        top_1 += cnt / len(correct_idx) 
        
        cnt = 0
        for idx in correct_idx: 
            if idx in rank[:5]: 
                cnt += 1
        top_5 += cnt / len(correct_idx) 
        
        cnt = 0 
        for idx in correct_idx: 
            if idx in rank[:10]:
                cnt += 1
        top_10 += cnt / len(correct_idx) 
        
        cnt = 0 
        for idx in correct_idx:
            if idx in rank[:20]:
                cnt += 1
        top_20 += cnt / len(correct_idx) 
        
        cnt = 0
        for idx in correct_idx: 
            if idx in rank[:50]:
                cnt += 1
        top_50 += cnt / len(correct_idx) 
        
    avg_top1 = top_1 / len(test_questions) 
    avg_top5 = top_5 / len(test_questions) 
    avg_top10 = top_10 / len(test_questions) 
    avg_top20 = top_20 / len(test_questions) 
    avg_top50 = top_50 / len(test_questions) 
    
    print(f"RECALL@1:{avg_top1} | RECALL@5:{avg_top5} | RECALL@10:{avg_top10} | RECALL@20:{avg_top20} | RECALL@50:{avg_top50}")
    

  0%|          | 0/213 [00:00<?, ?it/s]

candidate embeddings shape : torch.Size([6786, 768])


Calculating Recall:   0%|          | 0/680 [00:00<?, ?it/s]

RECALL@1:0.7218137254901962 | RECALL@5:0.9161764705882353 | RECALL@10:0.9367647058823529 | RECALL@20:0.9544117647058824 | RECALL@50:0.9691176470588235
