## Worker

In [1]:
import websocket
import rel
import json
import re
import numpy as np
import torch
import pickle
from keras.models import load_model
from gensim.models import Doc2Vec
from konlpy.tag import Okt
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import AutoTokenizer, AutoModel


# WebSocket endpoint of the TurtleMQ server this worker registers with.
# NOTE(review): hard-coded host/IP; consider reading it from configuration or an env var.
TURTLEMQ_URL = "ws://110.165.18.105/turtle/"

# Internal text-preprocessing resources:
# stopword list (one word per line) and the pickled tokenizer used to turn nouns into id sequences.
STOPWORDS_DICT_PATH = "./user_dic/stopwords"
TOKENIZER_PATH = "./model/tokenizer.pickle"

# Keras MLP models that take the padded token-id sequence as input.
# The numeric suffix looks like a YYMMDD build date — confirm.
IMPRISONMENT_MLP_MODEL_PATH = "./model/imprisonment_231119.keras"
PROBATION_MODEL_PATH = "./model/probation_231120.keras"
FINE_MODEL_PATH = "./model/fine_231121.keras"

# Keras model applied to mean-pooled KoBERT embeddings to predict imprisonment.
IMPRISONMENT_KOBERT_MODEL_PATH = "./model/kobert_imprisonment_231119_1.h5"

# Gensim Doc2Vec model used to look up similar precedents.
D2V_MODEL_PATH = "./model/d2v_231117.model"


class Worker:
    """TurtleMQ worker that answers sentencing-prediction tasks.

    Every model and preprocessing resource is loaded once in ``__init__``
    from the module-level paths.  ``run`` opens the WebSocket connection;
    each ``REQUEST_TASK`` message is cleaned, tokenized, passed through the
    MLP / KoBERT / Doc2Vec models and answered with a ``RESPONSE_TASK``
    message whose ``data`` field is a JSON-encoded result dict.
    """

    # One-pass replacement table used by get_clean_text: "(" and newlines
    # become spaces, the other punctuation characters are deleted.
    _PUNCTUATION_TABLE = str.maketrans({
        ',': None, '"': None, "'": None, '.': None, '(': ' ', ')': None,
        '!': None, '?': None, ':': None, ';': None, '\n': ' ',
    })

    # Okt POS tags that are glued to the previous token (no leading space).
    _ATTACHED_POS_TAGS = ('Josa', 'PreEomi', 'Eomi', 'Suffix', 'Punctuation')

    def __init__(self):
        # Set by run(); None until then so send()/on_exit() can guard on it.
        self.ws = None

        # Morphological analyzer (spacing correction and noun extraction).
        self.okt = Okt()

        # Stopword dictionary, one word per line.  The encoding is explicit so
        # loading does not depend on the platform locale (assumes UTF-8 — confirm).
        with open(STOPWORDS_DICT_PATH, "r", encoding="utf-8") as f:
            self.stopwords = {line.strip() for line in f}

        # Tokenizer saved at training time.
        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally produced artifacts from this path.
        with open(TOKENIZER_PATH, 'rb') as handle:
            self.tokenizer = pickle.load(handle)

        # MLP models.
        self.imprisonment_mlp_model = load_model(IMPRISONMENT_MLP_MODEL_PATH)
        self.probation_model = load_model(PROBATION_MODEL_PATH)
        self.fine_model = load_model(FINE_MODEL_PATH)

        # KoBERT encoder (PyTorch) plus the Keras head applied to its embeddings.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.imprisonment_kobert_model = load_model(IMPRISONMENT_KOBERT_MODEL_PATH)
        self.kobert_tokenizer = AutoTokenizer.from_pretrained("monologg/kobert")
        self.kobert_model = AutoModel.from_pretrained("monologg/kobert")
        self.kobert_model.to(self.device)

        # Doc2Vec model for similar-precedent lookup.
        self.d2v_model = Doc2Vec.load(D2V_MODEL_PATH)

    def on_open(self, ws):
        """Register this process as a worker as soon as the socket opens."""
        print("Connection opened")
        self.send('{"type":"REGISTER_WORKER", "data":""}')  # initial packet (worker registration)

    def on_message(self, ws, message):
        """Handle one TurtleMQ message; only ``REQUEST_TASK`` is acted upon.

        Args:
            ws: the WebSocketApp (unused; required by the callback signature).
            message: raw JSON text with ``type``, ``data`` (the case text) and,
                for tasks, ``taskId``.

        Sends a ``RESPONSE_TASK`` with the same ``taskId`` whose ``data`` is
        the JSON-encoded prediction result.
        """
        print("TurtleMQ →", message)

        message = json.loads(message)
        if message['type'] != "REQUEST_TASK":
            return  # e.g. the server's echo of REGISTER_WORKER

        result = { "imprisonment": 0, "probation": 0, "fine": 0, "judgementDecision": "예상 판결은 아직 지원되지 않습니다.", "similarPrecedents": [] }

        # Preprocessing: cleanse text -> fix spacing -> extract nouns -> tokenize.
        clean_text = self.get_clean_text(message['data'].strip())
        clean_text = self.correct_spacing(clean_text)
        extracted_nouns = self.extract_nouns(clean_text)
        tokens = self.get_tokenized_sequences(extracted_nouns)

        # Imprisonment: mean of the KoBERT and MLP predictions, truncated to int.
        kobert_imprisonment = self._to_scalar(self.predict_imprisonment_kobert(extracted_nouns))
        mlp_imprisonment = self._to_scalar(self.predict_imprisonment_mlp(tokens))
        result['imprisonment'] = int((kobert_imprisonment + mlp_imprisonment) / 2)

        # Probation: model output scaled by 100 and truncated.
        result['probation'] = int(self._to_scalar(self.predict_probation(tokens)) * 100)

        # Fine: model output truncated, then multiplied by 10000.
        result['fine'] = int(self._to_scalar(self.predict_fine(tokens))) * 10000

        # Similar precedents.
        result['similarPrecedents'] = self.get_similar_precedents(extracted_nouns)

        response = { "type": "RESPONSE_TASK", "taskId": message["taskId"], "data": json.dumps(result) }
        self.send(json.dumps(response))

    @staticmethod
    def _to_scalar(prediction):
        """Return the first element of a model output as a NumPy scalar.

        ``predict`` returns a batch array (presumably shape ``(1, 1)`` here).
        Taking the element explicitly avoids the implicit array-to-scalar
        conversion that NumPy >= 1.25 deprecates for arrays with ndim > 0.
        The element keeps its dtype, so arithmetic on it matches the old
        array arithmetic.
        """
        return np.ravel(prediction)[0]

    def on_error(self, ws, error):
        """Log a socket error reported by websocket-client."""
        print(error)

    def on_close(self, ws, close_status_code, close_msg):
        """Socket closed: tear the worker down."""
        self.on_exit()

    @staticmethod
    def _get_ipython():
        """Return the active IPython shell, or ``None`` when not under IPython/Jupyter."""
        try:
            return get_ipython()  # noqa: F821 -- builtin injected by IPython
        except NameError:
            return None

    def on_exit(self):
        """Close the socket if it is still running, then stop via ``SystemExit``.

        This method is also registered as an IPython ``post_execute`` hook, so
        it unregisters itself first.  Otherwise every later cell execution
        would call it again and raise ``SystemExit``.

        Raises:
            SystemExit: always.
        """
        shell = self._get_ipython()
        if shell is not None:
            try:
                shell.events.unregister('post_execute', self.on_exit)
            except ValueError:
                pass  # not registered (already removed, or run() never registered it)

        if self.ws is not None and self.ws.keep_running:
            self.ws.close()

        print("Connection closed")
        raise SystemExit("Socket connection is closed.")

    def run(self):
        """Connect to TurtleMQ and dispatch socket events with ``rel``."""
        self.ws = websocket.WebSocketApp(TURTLEMQ_URL,
                                         on_open=self.on_open,
                                         on_message=self.on_message,
                                         on_error=self.on_error,
                                         on_close=self.on_close)

        # Under IPython/Jupyter, tear down once the executing cell finishes.
        # Outside IPython there is no such hook, so skip it instead of
        # failing with NameError.
        shell = self._get_ipython()
        if shell is not None:
            shell.events.register('post_execute', self.on_exit)

        self.ws.run_forever(dispatcher=rel, reconnect=5)  # on connection failure, retry after 5 seconds
        rel.dispatch()

    def get_clean_text(self, text) -> str:
        """Normalize raw case text before morphological analysis.

        Strips punctuation ("(" and newlines become spaces), lower-cases,
        then removes dates, court names, boilerplate phrases and special
        symbols with the regexes below.
        """
        text = text.translate(self._PUNCTUATION_TABLE).lower()
        # NOTE(review): "." was already removed above, so this date pattern can
        # never match.  Kept unchanged because the models were presumably
        # trained on the same preprocessing; confirm before fixing.
        text = re.sub(r'\d+?\.\s\d+\.\s\d+\.', '', text)  # remove dates
        # NOTE(review): alternation binds loosely.  This removes "<word>법원" OR
        # any bare "지원", including inside words such as "지원금".
        # r'\b\w+(?:법원|지원)' was probably intended; changing it alters model
        # inputs, so verify against the training preprocessing first.
        text = re.sub(r'\b\w+법원|지원', '', text)  # remove court names
        text = re.sub('수사보고|범 죄 사 실|범죄사실', '', text)  # remove common boilerplate phrases
        text = re.sub(r'[「」『』\[\],.:%○]', '', text)  # remove special symbols
        return text

    def correct_spacing(self, text):
        """Re-space *text* using Okt POS tags.

        Particles, endings, suffixes and punctuation are glued to the previous
        token; every other token starts a new space-separated word.  Returns
        "" when Okt yields no tokens (previously an IndexError).
        """
        pieces = []
        for word, tag in self.okt.pos(text):
            pieces.append(word if tag in self._ATTACHED_POS_TAGS else " " + word)
        corrected = "".join(pieces)
        if corrected.startswith(" "):
            corrected = corrected[1:]
        return corrected

    def extract_nouns(self, text):
        """Return Okt nouns of *text* that are purely alphabetic and not stopwords."""
        # isalpha() drops tokens containing digits; stopwords is a set (O(1) lookups).
        return [noun for noun in self.okt.nouns(text)
                if noun.isalpha() and noun not in self.stopwords]

    def get_tokenized_sequences(self, nouns):
        """Map *nouns* to token ids and pad/truncate to a single length-256 row."""
        sequences = self.tokenizer.texts_to_sequences([nouns])
        return pad_sequences(sequences, 256)

    def predict_imprisonment_mlp(self, tokens):
        """Imprisonment prediction from the MLP; returns the raw ``predict`` batch output."""
        return self.imprisonment_mlp_model.predict(tokens, verbose=0)

    def predict_imprisonment_kobert(self, nouns):
        """Imprisonment prediction from a mean-pooled KoBERT embedding of *nouns*.

        The nouns are joined, tokenized, stripped of ``[UNK]`` pieces and
        re-encoded (max 512 tokens).  The last hidden state is averaged over
        the sequence axis and fed to the Keras head.  Returns the head's
        first output value.
        """
        processed_text = ' '.join(nouns)

        tokenized_text = self.kobert_tokenizer.tokenize(processed_text)
        tokenized_text = " ".join([word for word in tokenized_text if word != '[UNK]'])
        inputs = self.kobert_tokenizer(tokenized_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.kobert_model(**inputs)

        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
        return self.imprisonment_kobert_model.predict(np.array([embedding]), verbose=0)[0][0]

    def predict_probation(self, tokens):
        """Probation prediction from the MLP; returns the raw ``predict`` batch output."""
        return self.probation_model.predict(tokens, verbose=0)

    def predict_fine(self, tokens):
        """Fine prediction from the MLP; returns the raw ``predict`` batch output."""
        return self.fine_model.predict(tokens, verbose=0)

    def get_similar_precedents(self, nouns, topn=5):
        """Return up to *topn* ``"<doc tag> <similarity %>"`` strings, most similar first.

        The similarity percentage is rounded to the nearest integer.  The old
        ``int(round(sim, 2) * 100)`` lost a point to float error, e.g. 0.29 -> 28.
        """
        vector = self.d2v_model.infer_vector(nouns)
        similar = self.d2v_model.dv.most_similar([vector], topn=topn)
        return ['{} {}'.format(tag, int(round(similarity * 100)))
                for tag, similarity in similar[:topn]]

    def send(self, data):
        """Send *data* if the socket is running, and log it."""
        if self.ws is not None and self.ws.keep_running:
            self.ws.send(data)
            print("TurtleMQ ←", data)

# Entry point: build the worker (loads every model and resource up front),
# then connect to TurtleMQ and enter the rel dispatch loop started in Worker.run().
# `worker` stays a module-level name so later notebook cells can reach it.
worker = Worker()
worker.run()

2023-11-21 15:16:03.794645: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-21 15:16:04.321898: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-21 15:16:04.415422: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


Connection opened
TurtleMQ ← {"type":"REGISTER_WORKER", "data":""}
TurtleMQ → {"type":"REGISTER_WORKER","messageId":null,"data":""}
TurtleMQ → {"type":"REQUEST_TASK","messageId":null,"data":"전동킥보드 진로변경 금지 위반 중앙선 침범 도주치상 사문서위조 ","taskId":"b5f69a76-1344-4147-a780-a733c2e8d8cc"}
TurtleMQ ← {"type": "RESPONSE_TASK", "taskId": "b5f69a76-1344-4147-a780-a733c2e8d8cc", "data": "{\"imprisonment\": 9, \"probation\": 24, \"fine\": 0, \"judgementDecision\": \"\\uc608\\uc0c1 \\ud310\\uacb0\\uc740 \\uc544\\uc9c1 \\uc9c0\\uc6d0\\ub418\\uc9c0 \\uc54a\\uc2b5\\ub2c8\\ub2e4.\", \"similarPrecedents\": [\"\\uc11c\\uc6b8\\ub0a8\\ubd80\\uc9c0\\ubc29\\ubc95\\uc6d0/2020\\uace0\\ub2e84219 34\", \"\\uc11c\\uc6b8\\ub0a8\\ubd80\\uc9c0\\ubc29\\ubc95\\uc6d0/2021\\uace0\\ub2e83240 34\", \"\\uc6b8\\uc0b0\\uc9c0\\ubc29\\ubc95\\uc6d0/2021\\uace0\\ub2e82299 34\", \"\\uc11c\\uc6b8\\ub0a8\\ubd80\\uc9c0\\ubc29\\ubc95\\uc6d0/2020\\uace0\\ub2e83829 33\", \"\\uc11c\\uc6b8\\ub0a8\\ubd80\\uc9c0\\ubc29\\ubc95\\uc6d0/2020\\uace0\\

Socket connection is closed.
