In [2]:
import pandas as pd

# Corpus file, read with pandas' default CSV options (comma separator, first row as header).
CORPUS_PATH = 'data/book_raw_corpus.txt'
raw_data = pd.read_csv(CORPUS_PATH)

In [16]:
# Number of rows loaded from the corpus.
raw_data.shape[0]

673269

In [None]:
from transformers import ElectraForPreTraining, ElectraTokenizer, ElectraForMaskedLM

GENERATOR_CHECKPOINT = "monologg/koelectra-base-v3-generator"
DISCRIMINATOR_CHECKPOINT = "monologg/koelectra-base-v3-discriminator"

# The tokenizer comes from the generator checkpoint; both models are fed ids produced by it.
tokenizer = ElectraTokenizer.from_pretrained(GENERATOR_CHECKPOINT)
generator = ElectraForMaskedLM.from_pretrained(GENERATOR_CHECKPOINT)
discriminator = ElectraForPreTraining.from_pretrained(DISCRIMINATOR_CHECKPOINT)


### Domain Adaptation을 위한 Electra Model

- Electra-pytorch 라이브러리를 koElectra에 맞게 일부 수정했습니다. 

- Generator 모델은 `ElectraForMaskedLM`로, Discriminator 모델은 `ElectraForPreTraining`로 불러와야 합니다.

- 모델 학습 과정을 이해할 수 있도록 generator의 fake sentence 생성 및 Discriminator의 예측을 출력합니다.

- Electra-pytorch 원본 GitHub 주소 : https://github.com/lucidrains/electra-pytorch



In [7]:
import math
from functools import reduce
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F

# constants

# Output of Electra.forward: combined loss, the two component losses, the two
# accuracies, and the discriminator labels/predictions tensors.
Results = namedtuple(
    'Results',
    'loss mlm_loss disc_loss gen_acc disc_acc disc_labels disc_predictions',
)

# helpers

def log(t, eps=1e-9):
    """Natural log of `t + eps`; the small `eps` keeps log(0) finite."""
    return (t + eps).log()

def gumbel_noise(t):
    """Standard Gumbel noise -log(-log(U)), U ~ Uniform[0, 1), with the shape of `t`."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    return log(log(uniform).neg()).neg()

def gumbel_sample(t, temperature = 1.):
    """Gumbel-max sampling: argmax over the last dim of t / temperature plus Gumbel noise."""
    scaled = t / temperature
    noise = gumbel_noise(t)
    return (scaled + noise).argmax(dim=-1)

def prob_mask_like(t, prob):
    """Bool tensor shaped like `t`; each entry is True independently with probability `prob`."""
    draws = torch.zeros_like(t, dtype=torch.float).uniform_(0, 1)
    return draws < prob

def mask_with_tokens(t, token_ids):
    """Bool tensor shaped like `t`, True wherever `t` equals any id in `token_ids`."""
    hits = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hits = hits | (t == token_id)
    return hits

def get_mask_subset_with_prob(mask, prob):
    """Randomly select ceil(prob * n_row) positions per row among those where `mask` is True.

    Args:
        mask: bool tensor of shape (batch, seq_len); True marks positions that may be selected.
        prob: fraction of each row's selectable positions to pick.

    Returns:
        bool tensor of shape (batch, seq_len), True at the selected positions. It is
        always a subset of `mask`.
    """
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    # per-row budget: ceil(prob * number of selectable positions)
    num_tokens = mask.sum(dim=-1, keepdim=True)
    budget = (num_tokens * prob).ceil()

    # non-selectable positions get -1e9, so topk ranks every selectable position first
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    # Drop samples ranked at or beyond the row's budget. The budget must be counted by
    # sample rank; counting by sequence position (mask.cumsum) kept extra samples
    # whenever non-selectable positions came first, selecting e.g. [CLS]/[PAD].
    rank = torch.arange(max_masked, device=device).unsqueeze(0)
    mask_excess = rank >= budget

    # shift indices by one so dropped samples are routed to a throwaway column 0
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

# hidden layer extractor class, for magically adding adapter to language model to be pretrained

class HiddenLayerExtractor(nn.Module):
    """Wrap `net` so that forward() returns the output of one of its inner layers.

    Args:
        net: module to run.
        layer: either a name from `net.named_modules()` or an integer index into
            `net.children()`. The value -1 is special-cased: forward() then returns
            `net(x)` directly and no hook is registered.
    """

    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        # output captured by the forward hook during the current forward() call
        self.hidden = None
        self.hook_registered = False
        # handle returned by register_forward_hook, kept so the hook can be removed later
        self.hook_handle = None

    def _find_layer(self):
        """Return the submodule selected by `self.layer`, or None if it cannot be resolved."""
        if isinstance(self.layer, str):
            return dict(self.net.named_modules()).get(self.layer)
        if isinstance(self.layer, int):
            return list(self.net.children())[self.layer]
        return None

    def _hook(self, _, __, output):
        # forward-hook signature is (module, input, output); only the output is kept
        self.hidden = output

    def _register_hook(self):
        """Attach `_hook` to the selected layer (done lazily on the first forward call)."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        # previously the handle was discarded, making the hook impossible to remove
        self.hook_handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def forward(self, x):
        """Run `net` on `x` and return the selected layer's output (or `net(x)` when layer == -1)."""
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        # take the captured output and reset it so a stale value is never reused
        hidden, self.hidden = self.hidden, None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

# main electra class

class Electra(nn.Module):
    """ELECTRA replaced-token-detection wrapper around a generator and a discriminator.

    Adapted from lucidrains/electra-pytorch for Hugging Face KoELECTRA models. The
    generator (e.g. `ElectraForMaskedLM`) and the discriminator (e.g.
    `ElectraForPreTraining`) may return either an output object with a `.logits`
    attribute or a plain logits tensor.

    Args:
        generator: masked-LM model mapping (batch, seq_len) ids to per-token vocabulary logits.
        discriminator: model producing one replaced/original logit per input token.
        num_tokens: vocabulary size; required only when `random_token_prob > 0`.
        discr_dim: if > 0, wrap the discriminator in HiddenLayerExtractor + nn.Linear(discr_dim, 1).
        discr_layer: layer handed to HiddenLayerExtractor when `discr_dim > 0`.
        mask_prob: fraction of maskable tokens per sequence selected for replacement.
        replace_prob: probability that a selected token is shown to the generator as the
            mask token (otherwise it is left unchanged).
        random_token_prob: probability that a selected token is shown to the generator as a
            random vocabulary token instead.
        mask_token_id: id written into masked positions.
        pad_token_id: padding id; it is never selected and doubles as the ignore label of the MLM loss.
        mask_ignore_token_ids: ids that are never selected (pad is always added). The defaults
            presumably are [CLS]/[SEP] in the KoELECTRA vocab — TODO confirm.
        disc_weight, gen_weight: weights of the discriminator and MLM losses in `loss`.
        temperature: Gumbel-sampling temperature applied to the generator logits.
        tokenizer: tokenizer used for the debug prints; when None the notebook-level
            `tokenizer` is used, as before.
        verbose: if True (default), print the masked input, generator predictions, sampled
            tokens and replaced/predicted sentences for the first sequence of the batch.
        return_loss_tensors: if True, return losses/accuracies as tensors so that
            `results.loss.backward()` can be used for training; by default they are
            converted to Python floats with `.item()`, as before.
    """

    def __init__(
        self,
        generator,
        discriminator,
        *,
        num_tokens = None,
        discr_dim = -1,
        discr_layer = -1,
        mask_prob = 0.15,
        replace_prob = 0.85,
        random_token_prob = 0.,
        mask_token_id = 4,
        pad_token_id = 0,
        mask_ignore_token_ids = (2, 3),
        disc_weight = 50.,
        gen_weight = 1.,
        temperature = 1.,
        tokenizer = None,
        verbose = True,
        return_loss_tensors = False):
        super().__init__()

        self.generator = generator
        self.discriminator = discriminator

        if discr_dim > 0:
            self.discriminator = nn.Sequential(
                HiddenLayerExtractor(discriminator, layer = discr_layer),
                nn.Linear(discr_dim, 1)
            )

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight

        # debug printing / return format
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.return_loss_tensors = return_loss_tensors

    def _debug_tokenizer(self):
        """Tokenizer for the debug prints: the one passed in, else the notebook-level `tokenizer`."""
        return self.tokenizer if self.tokenizer is not None else tokenizer

    @staticmethod
    def _highlight(tokens, flags):
        """Wrap each token whose flag is True in '<< >>' for display."""
        return ['<< ' + token + ' >>' if flag else token for token, flag in zip(tokens, flags)]

    def forward(self, input, **kwargs):
        """Run one ELECTRA step on `input`, a (batch, seq_len) tensor of token ids.

        Extra keyword arguments (e.g. `attention_mask`) are forwarded unchanged to both the
        generator and the discriminator.

        Returns:
            Results(loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions)
            with `loss = gen_weight * mlm_loss + disc_weight * disc_loss`.
        """
        verbose = self.verbose
        tok = self._debug_tokenizer() if verbose else None

        replace_prob = prob_mask_like(input, self.replace_prob)

        # never select [PAD] or the ignored special tokens (e.g. [CLS]/[SEP]) for masking
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # generator input; selected positions are overwritten below
        masked_input = input.clone().detach()

        # MLM targets only at selected positions; pad_token_id elsewhere is ignored by the loss
        gen_labels = input.masked_fill(~mask, self.pad_token_id)

        # positions that will receive the mask token; random-token positions are removed from it.
        # `mask` itself keeps every selected position, however it was replaced.
        masking_mask = mask.clone()

        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling'

            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            # only replace selected positions (previously any position, including
            # [CLS]/[SEP]/[PAD], could be overwritten) and never insert a special token
            random_token_prob &= mask & ~random_no_mask

            masked_input = torch.where(random_token_prob, random_tokens, masked_input)
            masking_mask = masking_mask & ~random_token_prob

        # selected positions become the mask token with probability replace_prob, else stay unchanged
        masked_input = masked_input.masked_fill(masking_mask & replace_prob, self.mask_token_id)

        if verbose:
            # positions of the first sequence that the generator actually sees as the mask token
            shown_masked = (masked_input[0] == self.mask_token_id).tolist()
            print('Masked token 생성 결과 : ', tok.convert_ids_to_tokens(masked_input[0].tolist()))

        gen_out = self.generator(masked_input, **kwargs)
        logits = getattr(gen_out, 'logits', gen_out)

        if verbose:
            print('Generator 예측 결과 : ', tok.convert_ids_to_tokens(logits[0].argmax(dim=-1).tolist()))

        # cross_entropy expects (batch, vocab, seq_len); the previous squeeze(0) form only
        # worked for batch size 1
        mlm_loss = F.cross_entropy(
            logits.transpose(1, 2),
            gen_labels,
            ignore_index = self.pad_token_id
        )

        # sample a replacement for every selected position from the generator distribution
        sample_logits = logits[mask_indices]
        sampled = gumbel_sample(sample_logits, temperature = self.temperature)
        if verbose:
            print('sample token :', tok.convert_ids_to_tokens(sampled.tolist()))

        # scatter the sampled values back into a copy of the original input
        disc_input = input.clone()
        disc_input[mask_indices] = sampled.detach()

        # replaced -> 1.0, original (including samples equal to the original token) -> 0.0
        disc_labels = (input != disc_input).float().detach()

        if verbose:
            replaced_tokens = tok.convert_ids_to_tokens(disc_input[0].tolist())
            print('변경된 문장 : ', self._highlight(replaced_tokens, shown_masked))

        # the discriminator loss is computed on non-padding positions only
        non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)

        disc_out = self.discriminator(disc_input, **kwargs)
        disc_logits = getattr(disc_out, 'logits', disc_out).reshape_as(disc_labels)

        disc_loss = F.binary_cross_entropy_with_logits(
            disc_logits[non_padded_indices],
            disc_labels[non_padded_indices]
        )

        # gather metrics
        with torch.no_grad():
            gen_predictions = torch.argmax(logits, dim=-1)

            if verbose:
                predicted_tokens = tok.convert_ids_to_tokens(gen_predictions[0].tolist())
                print('예측한 문장 : ', self._highlight(predicted_tokens, shown_masked))

            # logit > 0 -> 1, logit < 0 -> 0; logit == 0 gives 0.5, which torch.round maps to 0
            disc_predictions = torch.round((torch.sign(disc_logits) + 1.0) * 0.5)
            gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean()
            # balanced accuracy: average of accuracy on selected and on unselected positions
            disc_acc = 0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean() + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean()

        loss = self.gen_weight * mlm_loss + self.disc_weight * disc_loss
        if self.return_loss_tensors:
            return Results(loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions)
        return Results(loss.item(), mlm_loss.item(), disc_loss.item(), gen_acc.item(), disc_acc.item(), disc_labels, disc_predictions)



SEED = 42  # fixes which tokens get masked/sampled so the printed walkthrough is reproducible
USE_RANDOM_SAMPLE = False  # True: use a random corpus sentence containing '파이썬' instead of the fixed example

model = Electra(generator=generator, discriminator=discriminator)

if USE_RANDOM_SAMPLE:
    # assumes the corpus has a 'list' text column — TODO confirm
    python_rows = raw_data[raw_data['list'].str.contains('파이썬', na=False)]
    raw_str = python_rows['list'].sample(1, random_state=SEED).item()
else:
    raw_str = '이 책은 이제부터 파이썬을 시작하고자 하는 프로그래밍 경험자들에게 가장 적합한 파이썬 입문서입니다'

print('원본 : ', raw_str)
input_data = tokenizer(raw_str, return_tensors='pt')

torch.manual_seed(SEED)
# demonstration only: no backward pass, so skip building the autograd graph
with torch.no_grad():
    results = model(input_data['input_ids'])
results


원본 :  이 책은 이제부터 파이썬을 시작하고자 하는 프로그래밍 경험자들에게 가장 적합한 파이썬 입문서입니다
Masked token 생성 결과 :  ['[CLS]', '이', '책', '##은', '이제', '##부터', '파이', '##썬', '[MASK]', '시작', '##하', '##고', '##자', '하', '[MASK]', '프로그래밍', '경험', '##자', '[MASK]', '##에', '##게', '가장', '적합', '##한', '파이', '##썬', '입문', '##서', '##입니다', '[SEP]']
Generator 예측 결과 :  ['/', '이', '책', '##은', '이제', '##부터', '파이', '##썬', '##을', '시작', '##하', '##고', '##자', '하', '##는', '프로그래밍', '경험', '##자', '##들', '##에', '##게', '가장', '적합', '##한', '파이', '##썬', '입문', '##서', '##입니다', '/']
sample token : ['##을', '시작', '##자', '##는', '##들']
변경된 문장 :  ['[CLS]', '이', '책', '##은', '이제', '##부터', '파이', '##썬', '<< ##을 >>', '시작', '##하', '##고', '##자', '하', '<< ##는 >>', '프로그래밍', '경험', '##자', '<< ##들 >>', '##에', '##게', '가장', '적합', '##한', '파이', '##썬', '입문', '##서', '##입니다', '[SEP]']
예측한 문장 :  ['/', '이', '책', '##은', '이제', '##부터', '파이', '##썬', '<< ##을 >>', '시작', '##하', '##고', '##자', '하', '<< ##는 >>', '프로그래밍', '경험', '##자', '<< ##들 >>', '##에', '##게', '가장', '적합', '##한', '파이', '##썬', '입

Results(loss=1.8209502696990967, mlm_loss=0.013269344344735146, disc_loss=0.03615361824631691, gen_acc=1.0, disc_acc=1.0, disc_labels=tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0.]]), disc_predictions=tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0.]]))