# 1. 데이터셋 전처리하기!

In [2]:
import ast
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer

In [3]:
class Preprocess:
    """Load a relation-extraction CSV and mark the entity spans in each sentence.

    The CSV must have ``id``, ``sentence``, ``subject_entity``,
    ``object_entity`` and ``label`` columns. The two entity columns hold the
    string form of a Python dict with (at least) ``start_idx``, ``end_idx``
    and ``type`` keys; indices are inclusive character positions in
    ``sentence``.

    Args:
        CSV_PATH: path (or file-like object) passed to ``pd.read_csv``.
        version: entity-marking scheme applied to each sentence.
            ``'SUB'`` gives ``[SUB]subj[/SUB]`` / ``[OBJ]obj[/OBJ]``.
            ``'PUN'`` gives ``@*TYPE*subj@`` / ``#^TYPE^obj#``.
            Any other value leaves sentences unchanged.
    """

    def __init__(self, CSV_PATH, version):
        self.version = version
        self.data = self.load_data(CSV_PATH)

    @staticmethod
    def _parse_entity(entity_str):
        """Return ``(start_idx, end_idx, type)`` from a stringified entity dict."""
        # literal_eval only accepts Python literals, so it is safe on data files
        # and, unlike substring scanning, is not confused by the entity word.
        entity = ast.literal_eval(entity_str)
        return int(entity['start_idx']), int(entity['end_idx']), entity['type']

    @staticmethod
    def _wrap(text, start, end, open_tag, close_tag):
        """Insert ``open_tag`` / ``close_tag`` around the inclusive span [start, end]."""
        return text[:start] + open_tag + text[start:end + 1] + close_tag + text[end + 1:]

    def _mark_entities(self, text, sub_i, obj_i, sub_typ, obj_typ):
        """Return ``text`` with subject/object spans wrapped per ``self.version``.

        Assumes the two spans do not overlap (as the original code did).
        """
        if self.version == 'SUB':
            sub_tags = ('[SUB]', '[/SUB]')
            obj_tags = ('[OBJ]', '[/OBJ]')
        elif self.version == 'PUN':
            sub_tags = ('@*' + sub_typ + '*', '@')
            obj_tags = ('#^' + obj_typ + '^', '#')
        else:
            return text

        sub_span = (sub_i[0], sub_i[1], *sub_tags)
        obj_span = (obj_i[0], obj_i[1], *obj_tags)
        # Mark the span that starts later first. Text before it is untouched, so
        # the earlier span's indices stay valid and no shift arithmetic is needed
        # (the old hard-coded +7/+11 offsets only held for 3-letter entity types).
        if sub_i[0] < obj_i[0]:
            later, earlier = obj_span, sub_span
        else:
            later, earlier = sub_span, obj_span
        text = self._wrap(text, *later)
        return self._wrap(text, *earlier)

    def load_data(self, path):
        """Read the CSV at ``path`` and return a preprocessed DataFrame.

        Returns:
            DataFrame with these columns:

            - ``id`` and ``label``: copied from the input.
            - ``sentence``: entity-marked according to ``self.version``.
            - ``subject_entity`` / ``object_entity``: surface text sliced from
              the original sentence.
            - ``subject_type`` / ``object_type``.
            - ``subject_idx`` / ``object_idx``: inclusive ``[start, end]``
              character indices into the original sentence.
        """
        data = pd.read_csv(path)

        sub_entity, sub_type, sub_idx = [], [], []
        obj_entity, obj_type, obj_idx = [], [], []
        sentence = []

        # Preprocess each row.
        for sub_str, obj_str, text in zip(data['subject_entity'], data['object_entity'], data['sentence']):
            # Parse each entity string on its own. The previous version scanned
            # both strings with range(len(subject_string)). When the object string
            # was longer, its indices were missed, leaving a stale value from the
            # previous row (or an UnboundLocalError on the first row).
            sub_start, sub_end, sub_typ = self._parse_entity(sub_str)
            obj_start, obj_end, obj_typ = self._parse_entity(obj_str)
            sub_i = [sub_start, sub_end]
            obj_i = [obj_start, obj_end]

            sub_entity.append(text[sub_start:sub_end + 1])
            obj_entity.append(text[obj_start:obj_end + 1])
            sub_type.append(sub_typ)
            obj_type.append(obj_typ)
            sub_idx.append(sub_i)
            obj_idx.append(obj_i)

            # Apply the entity-marking version.
            sentence.append(self._mark_entities(text, sub_i, obj_i, sub_typ, obj_typ))

        return pd.DataFrame({'id': data['id'], 'sentence': sentence,
                             'subject_entity': sub_entity, 'object_entity': obj_entity,
                             'subject_type': sub_type, 'object_type': obj_type,
                             'label': data['label'],
                             'subject_idx': sub_idx, 'object_idx': obj_idx})

    def tokenized_dataset(self, data, tokenizer, max_length=256):
        """Tokenize (entity-pair prompt, sentence) pairs.

        Registers the entity-type strings as tokens on ``tokenizer`` (mutating
        it). Each sample is then encoded as the pair
        ``'@*SUBTYPE*subj@와 #^OBJTYPE^obj#의 관계'`` + marked sentence.

        Args:
            data: DataFrame produced by ``load_data``.
            tokenizer: callable tokenizer exposing ``add_tokens``.
            max_length: truncation length in tokens (default 256).

        Returns:
            ``(encoding, len(tokens))``. The count is always the number of
            marker tokens requested (6), not how many ``add_tokens`` actually
            added.
        """
        # Add entity-type tokens.
        tokens = ['PER', 'LOC', 'POH', 'DAT', 'NOH', 'ORG']
        tokenizer.add_tokens(tokens)

        concat_entity = [
            '@*' + sub_typ + '*' + sub_ent + '@와 #^' + obj_typ + '^' + obj_ent + '#의 관계'
            for sub_ent, obj_ent, sub_typ, obj_typ in zip(
                data['subject_entity'], data['object_entity'],
                data['subject_type'], data['object_type'])
        ]

        tokenized_sentence = tokenizer(
            concat_entity,
            list(data['sentence']),  # must be passed as a list or a string
            return_tensors="pt",  # PyTorch tensors
            padding=True,  # pad shorter sequences
            truncation=True,  # truncate longer sequences
            max_length=max_length,  # maximum token length
            add_special_tokens=True,  # add special tokens
            return_token_type_ids=False,  # RoBERTa does not use token_type_ids
        )

        return tokenized_sentence, len(tokens)

    def label_to_num(self, label, dict_path='dict_label_to_num.pkl'):
        """Map string labels to ids using a pickled label->id mapping.

        Args:
            label: iterable of label strings.
            dict_path: pickle file holding the mapping. Defaults to the
                original cwd-relative ``'dict_label_to_num.pkl'``.

        Returns:
            list of mapped values, in input order.

        Raises:
            KeyError: if a label is missing from the mapping.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(dict_path, 'rb') as f:
            dict_label_to_num = pickle.load(f)

        return [dict_label_to_num[val] for val in label]

"""Train, Test Dataset"""
class Dataset:
    """Map-style dataset pairing per-field tensors with per-sample labels.

    Args:
        data: mapping of field name -> indexable object whose items support
            ``.clone()`` (presumably the tokenizer encoding built earlier in
            this notebook — confirm at the call site).
        labels: sequence with one label per sample; its length is the
            dataset length.
    """

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __getitem__(self, idx):
        """Return ``{field: tensor_row, ..., 'labels': tensor(label)}`` for sample ``idx``."""
        sample = {}
        for field, values in self.data.items():
            # Copy the row so the returned sample does not share storage or
            # autograd history with the full batch tensor.
            sample[field] = values[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

In [5]:
# Build the preprocessed table from the pseudo-labelled training CSV, using the
# punctuation-style ('PUN') entity markers.
preprocess = Preprocess(CSV_PATH='/opt/ml/dataset/train/pseudo_train.csv', version='PUN')
all_dataset = preprocess.data

In [15]:
# Notebook display: render the preprocessed DataFrame (output shown below).
all_dataset

Unnamed: 0,id,sentence,subject_entity,object_entity,subject_type,object_type,label,subject_idx,object_idx
0,0,〈Something〉는 #^PER^조지 해리슨#이 쓰고 @*ORG*비틀즈@가 196...,비틀즈,조지 해리슨,ORG,PER,no_relation,"[24, 26]","[13, 18]"
1,1,호남이 기반인 바른미래당·#^ORG^대안신당#·@*ORG*민주평화당@이 우여곡절 끝...,민주평화당,대안신당,ORG,ORG,no_relation,"[19, 23]","[14, 17]"
2,2,K리그2에서 성적 1위를 달리고 있는 @*ORG*광주FC@는 지난 26일 #^ORG...,광주FC,한국프로축구연맹,ORG,ORG,org:member_of,"[21, 24]","[34, 41]"
3,3,균일가 생활용품점 (주)@*ORG*아성다이소@(대표 #^PER^박정부#)는 코로나1...,아성다이소,박정부,ORG,PER,org:top_members/employees,"[13, 17]","[22, 24]"
4,4,#^DAT^1967#년 프로 야구 드래프트 1순위로 @*ORG*요미우리 자이언츠@에...,요미우리 자이언츠,1967,ORG,DAT,no_relation,"[22, 30]","[0, 3]"
...,...,...,...,...,...,...,...,...,...
40230,40230,코로나19 방역 조치의 일환으로 국민의 움직임을 통제하려는 @*ORG*정부@의 시도...,정부,이탈리아,ORG,LOC,org:place_of_headquarters,"[33, 34]","[41, 44]"
40231,40231,선 연구원은 “위식도역류질환치료제인 케이캡이 92억원 판매되면서 2019년 연간 3...,종근당,전년,ORG,DAT,no_relation,"[133, 135]","[143, 144]"
40232,40232,"@*ORG*한국전기안전공사@(사장 #^PER^조성완#)는 8월 1일부로, 3급 간부...",한국전기안전공사,조성완,ORG,PER,org:top_members/employees,"[0, 7]","[12, 14]"
40233,40233,#^DAT^1987년# @*PER*B. 슈나이더@(B. Schneider)에 의해 ...,B. 슈나이더,1987년,PER,DAT,no_relation,"[6, 12]","[0, 4]"


In [None]:
# Wrap tokenized encodings and labels as datasets for training / validation.
# NOTE(review): tokenized_train, train_label, tokenized_val and val_label are not
# defined in any visible cell, and this cell shows "In [None]" (never executed).
# Presumably they come from Preprocess.tokenized_dataset / label_to_num applied to
# a train/validation split. TODO: define them before running this cell.
trainset= Dataset(tokenized_train, train_label)
valset= Dataset(tokenized_val, val_label)