In [1]:
!nvidia-smi

Fri Jun 10 07:03:12 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    27W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install -q datasets
!pip install -q sentence-transformers

[K     |████████████████████████████████| 346 kB 15.9 MB/s 
[K     |████████████████████████████████| 212 kB 65.0 MB/s 
[K     |████████████████████████████████| 140 kB 73.1 MB/s 
[K     |████████████████████████████████| 1.1 MB 60.0 MB/s 
[K     |████████████████████████████████| 86 kB 5.5 MB/s 
[K     |████████████████████████████████| 86 kB 6.4 MB/s 
[K     |████████████████████████████████| 596 kB 68.0 MB/s 
[K     |████████████████████████████████| 127 kB 56.5 MB/s 
[K     |████████████████████████████████| 94 kB 4.0 MB/s 
[K     |████████████████████████████████| 271 kB 53.8 MB/s 
[K     |████████████████████████████████| 144 kB 61.4 MB/s 
[K     |████████████████████████████████| 112 kB 63.3 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0

In [3]:
import os, sys
import random
import pickle
from pathlib import Path
from tqdm.notebook import tqdm
from typing import Dict

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from datasets import load_metric

from transformers import RobertaForMaskedLM, RobertaTokenizerFast, RobertaTokenizer
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

from sentence_transformers import InputExample
from sentence_transformers import models, SentenceTransformer
from sentence_transformers import losses
from sentence_transformers.util import cos_sim
from sentence_transformers.evaluation import LabelAccuracyEvaluator
from sentence_transformers.losses.TripletLoss import TripletDistanceMetric

In [4]:
clean_legal_corpus_df = pd.read_pickle('/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/datasets/clean_legal_corpus.pkl')
clean_legal_corpus_df

Unnamed: 0_level_0,Unnamed: 1_level_0,text
law_id,article_id,Unnamed: 2_level_1
01/2009/tt-bnn,1,phạm_vi áp_dụng thông_tư hướng_dẫn tuần_tra ca...
01/2009/tt-bnn,2,tổ_chức lực_lượng hàng trước mùa mưa_lũ uỷ_ban...
01/2009/tt-bnn,3,tiêu_chuẩn thành_viên lực_lượng tuần_tra canh_...
01/2009/tt-bnn,4,nhiệm_vụ lực_lượng tuần_tra canh_gác đê chấp_h...
01/2009/tt-bnn,5,phù_hiệu lực_lượng tuần_tra canh_gác đê phù_hi...
...,...,...
99/2020/nđ-cp,60,thẩm_quyền xử_phạt hải_quan chi_cục trưởng chi...
99/2020/nđ-cp,61,thẩm_quyền xử_phạt quản_lý thị_trường kiểm_soá...
99/2020/nđ-cp,62,thẩm_quyền xử_phạt thanh_tra thanh_tra_viên gi...
99/2020/nđ-cp,63,phân_định thẩm_quyền xử_phạt chủ_tịch uỷ_ban_n...


In [5]:
clean_question_answer_df = pd.read_pickle('/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/datasets/clean_question_answer_df.pkl')
clean_question_answer_df

Unnamed: 0,question_id,question,relevant_articles
0,0637bf82c8b290c7875c5bfddbf91df5,công_an xã xử_phạt lỗi mang bằng lái_xe đúng,"[{'law_id': '47/2011/tt-bca', 'article_id': '7'}]"
1,ade2b2ee4f5b869f75f0d183902382af,thực_hiện thao_tác nạp mẫu bình chứa xử_lý mẫu...,"[{'law_id': '41/2020/tt-bca', 'article_id': '1..."
2,8fa164edc7a8419cd7dc9ce66d8e695a,trình_tự đánh_giá chất_lượng hàng đối_với kiểm...,"[{'law_id': '159/2020/nđ-cp', 'article_id': '1..."
3,3aed91309b72097b34427dd28bebd98a,tử_tù chết tiêm thuốc_độc thân_nhân nhận xác h...,"[{'law_id': '53/2010/qh12', 'article_id': '60'..."
4,fe9131a8862ce1bfa7de7e2b30eeb39e,trách_nhiệm khoa_học công_nghệ quản_lý phát_tr...,"[{'law_id': '63/2020/nđ-cp', 'article_id': '20'}]"
...,...,...,...
3191,bd2f979dabdf0033a00ff488b2893e12,nội_dung kiểm_tra công_tác nghiệm_thu công_trì...,"[{'law_id': '06/2021/nđ-cp', 'article_id': '24'}]"
3192,b7347919f2bd65a2da6f9d0b1728a51c,mức phạt đối_với hành_vi chuyển mục_đích sử_dụ...,"[{'law_id': '35/2019/nđ-cp', 'article_id': '12'}]"
3193,bda416f28b8c88ce25604a26b01081cb,trích xuất thi_hành án hình_sự hiểu thế_nào,"[{'law_id': '41/2019/qh14', 'article_id': '3'}]"
3194,6088765b2e502c7d900a3f9a2f82f2b4,chạy ô_tô dầu_nhờn rơi xuống_đường phạt bao_nhiêu,"[{'law_id': '100/2019/nđ-cp', 'article_id': '2..."


In [6]:
with open('/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/datasets/neg_pairs_top_20_2.pkl', 'rb') as f:
    neg_pairs_top_20 = pickle.load(f)

In [7]:
neg_pairs_top_20[:5]

[{'negative': [{'article_id': '1', 'law_id': '26/2009/tt-bnn'},
   {'article_id': '2', 'law_id': '54/2013/tt-bnnptnt'},
   {'article_id': '1', 'law_id': '54/2013/tt-bnnptnt'},
   {'article_id': '1', 'law_id': '13/2012/ttlt-bca-btp-vksndtc-tandtc'},
   {'article_id': '1', 'law_id': '01/2011/tt-bca'},
   {'article_id': '1', 'law_id': '04/2019/tt-bgtvt'},
   {'article_id': '1', 'law_id': '13/2013/ttlt-bca-bqp-vksndtc-tandtc'},
   {'article_id': '1', 'law_id': '123/2014/tt-btc'},
   {'article_id': '1', 'law_id': '19/2018/tt-bnnptnt'},
   {'article_id': '2', 'law_id': '19/2018/tt-bnnptnt'},
   {'article_id': '1', 'law_id': '33/2015/tt-bca'},
   {'article_id': '1', 'law_id': '04/2014/tt-blđtbxh'},
   {'article_id': '1', 'law_id': '58/2015/tt-bca'},
   {'article_id': '1', 'law_id': '206/2016/tt-btc'},
   {'article_id': '1', 'law_id': '31/2018/tt-bnnptnt'},
   {'article_id': '1', 'law_id': '115/2013/tt-btc'},
   {'article_id': '1', 'law_id': '05/2019/tt-bnnptnt'},
   {'article_id': '1', 'law_i

In [8]:
train_df, val_df = train_test_split(clean_question_answer_df, test_size=0.1, random_state=42)

In [9]:
val_df

Unnamed: 0,question_id,question,relevant_articles
1951,8e2cfe626cebf209f94e0db8f147960c,mức xử_phạt đối_với hành_vi tổ_chức khám sức_k...,"[{'law_id': '28/2020/nđ-cp', 'article_id': '21'}]"
1204,0a32724630653580cc90c77bcf552baf,tranh_chấp giữa cơ_quan ký_kết hợp_đồng dự_án ...,"[{'law_id': '64/2020/qh14', 'article_id': '97'}]"
2661,e66ae5eecc1672bac2c5799673666bda,trình_tự cấp cấp giấy chứng_nhận đủ điều_kiện ...,"[{'law_id': '26/2019/nđ-cp', 'article_id': '28'}]"
705,6b815a5b10736472101034aa8bd58146,thành_viên hợp danh chủ doanh_nghiệp,"[{'law_id': '59/2020/qh14', 'article_id': '180'}]"
1036,4d4ed6ef19708f3e18fef48e4786a06e,luật_sư miễn đào_tạo nghề đấu_giá hay,"[{'law_id': '01/2016/qh14', 'article_id': '12'}]"
...,...,...,...
449,d3e3db3c2ac0799e4afa85e44766003b,trình_tự thủ_tục cấp giấy chứng_nhận thành_viê...,"[{'law_id': '155/2020/nđ-cp', 'article_id': '1..."
1563,383d04dfac51dda6b2b7e9ebd6336052,hồ_sơ đề_nghị hỗ_trợ chi_phí mai_táng gồm nhữn...,"[{'law_id': '20/2021/nđ-cp', 'article_id': '11'}]"
605,6a3175d244fa8cd461b7b9aa227f124a,quy_định giấy chứng_nhận huấn_luyện nghiệp_vụ ...,"[{'law_id': '03/2020/tt-bgtvt', 'article_id': ..."
2630,0aa5a051629e2a3580bbb55f2fab40cb,mức phạt phối_hợp cơ_quan nhà_nước thẩm_quyền ...,"[{'law_id': '15/2020/nđ-cp', 'article_id': '10'}]"


In [10]:
train_samples = []

for anc_sent, pos_ids in train_df[['question', 'relevant_articles']].values:
    for pos_id in pos_ids:
        pos_sent = clean_legal_corpus_df.loc[pos_id['law_id'], pos_id['article_id']].text
        train_samples.append(InputExample(texts=[anc_sent, pos_sent], label=1))

        for pair in neg_pairs_top_20:
            if pos_id == pair['positive']:
                neg_ids = pair['negative']
                for neg_id in neg_ids:
                    neg_sent = clean_legal_corpus_df.loc[neg_id['law_id'], neg_id['article_id']].text
                    train_samples.append(InputExample(texts=[anc_sent, neg_sent], label=0))
                break

In [11]:
val_samples = []

for anc_sent, pos_ids in val_df[['question', 'relevant_articles']].values:
    for pos_id in pos_ids:
        pos_sent = clean_legal_corpus_df.loc[pos_id['law_id'], pos_id['article_id']].text
        val_samples.append(InputExample(texts=[anc_sent, pos_sent], label=1))

        for pair in neg_pairs_top_20:
            if pos_id == pair['positive']:
                neg_ids = pair['negative']
                for neg_id in neg_ids:
                    neg_sent = clean_legal_corpus_df.loc[neg_id['law_id'], neg_id['article_id']].text
                    val_samples.append(InputExample(texts=[anc_sent, neg_sent], label=0))
                break

In [12]:
BATCH_SIZE = 16

train_dataloader = DataLoader(train_samples, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=0, drop_last=False)

val_dataloader = DataLoader(val_samples, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=0, drop_last=False)

In [13]:
model = SentenceTransformer('/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/checkpoints/task_training_1/')
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: RobertaModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

In [14]:
train_loss = losses.ContrastiveLoss(model=model)

In [15]:
from sentence_transformers.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator.from_input_examples(val_samples)

In [16]:
epochs = 5
warmup_steps = int(len(train_dataloader) * epochs * 0.1)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=epochs,
    evaluator=evaluator,
    evaluation_steps=1400,
    warmup_steps=warmup_steps,
    output_path='/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/checkpoints/task_training_2/',
    checkpoint_path='/content/drive/MyDrive/NLP/Information_Retrieval/Legal_Text_Retrieval_Zalo2021/checkpoints/task_training_2/',
    optimizer_params={'lr': 2e-6},
    checkpoint_save_steps=3800,
    checkpoint_save_total_limit=2,
    save_best_model=True,
    use_amp=True,
    show_progress_bar=True,
)



Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3895 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3895 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3895 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3895 [00:00<?, ?it/s]

KeyboardInterrupt: ignored