In [1]:
import pandas as pd

raw_data = pd.read_csv('data/pre_book_total_128.csv')

device = 'cpu'

In [2]:
len(raw_data)

193889

In [3]:
from transformers import ElectraForPreTraining, ElectraTokenizer, ElectraForMaskedLM

tokenizer = ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-generator")

generator = ElectraForMaskedLM.from_pretrained('monologg/koelectra-base-v3-generator')

discriminator = ElectraForPreTraining.from_pretrained("monologg/koelectra-base-v3-discriminator")


### Domain Adaptation을 위한 Electra Model

- Electra-pytorch 라이브러리를 koElectra에 맞게 일부 수정했습니다. 

- Generator 모델은 `ElectraForMaskedLM`로, Discriminator 모델은 `ElectraForPreTraining`로 불러와야 합니다.

- 모델 학습 과정을 이해할 수 있도록 generator의 fake sentence 생성 및 Discriminator의 예측을 출력합니다.

- Electra-pytorh 원본 github 주소 : https://github.com/lucidrains/electra-pytorch



In [4]:
### Electra-pytorch 라이브러리를 KoElectra에 활용할 수 있도록 일부 변형했습니다.

### Electra로 Domain Adaptation을 수행하기 위해 개발했습니다.

### Generator 모델은 ElectraForMaskedLM로, Discriminator 모델은 ElectraForPreTraining로 불러와야 합니다.

### 더 많은 내용을 알고 싶으신 경우 Domain Adaptation Tutorial을 참고해주세요.

### Electra-pytorh 원본 github 주소 : https://github.com/lucidrains/electra-pytorch


### Electra-pytorch 라이브러리를 KoElectra에 활용할 수 있도록 일부 변형했습니다.

### Electra로 Domain Adaptation을 수행하기 위해 개발했습니다.

### Generator 모델은 ElectraForMaskedLM로, Discriminator 모델은 ElectraForPreTraining로 불러와야 합니다.

### 더 많은 내용을 알고 싶으신 경우 Domain Adaptation Tutorial을 참고해주세요.

### Electra-pytorh 원본 github 주소 : https://github.com/lucidrains/electra-pytorch


import math
from functools import reduce
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F

# constants

Results = namedtuple(
    "Results",
    [
        "loss",
        "mlm_loss",
        "disc_loss",
        "gen_acc",
        "disc_acc",
        "disc_labels",
        "disc_predictions",
    ],
)

# helpers


def log(t, eps=1e-9):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)


def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob


def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask


def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = mask.cumsum(dim=-1) > (num_tokens * prob).ceil()
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()


# hidden layer extractor class, for magically adding adapter to language model to be pretrained


class HiddenLayerExtractor(nn.Module):
    def __init__(self, net, layer=-2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = output

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f"hidden layer ({self.layer}) not found"
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def forward(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f"hidden layer {self.layer} never emitted an output"
        return hidden


# main electra class


class Electra(nn.Module):
    def __init__(
        self,
        generator,
        discriminator,
        tokenizer,
        *,
        num_tokens=None,
        discr_dim=-1,
        discr_layer=-1,
        mask_prob=0.15,
        replace_prob=0.85,
        random_token_prob=0.0,
        mask_token_id=4,
        pad_token_id=0,
        mask_ignore_token_ids=[2, 3],
        disc_weight=50.0,
        gen_weight=1.0,
        temperature=1.0,
    ):
        super().__init__()

        self.generator = generator
        self.discriminator = discriminator
        self.tokenizer = tokenizer

        if discr_dim > 0:
            self.discriminator = nn.Sequential(
                HiddenLayerExtractor(discriminator, layer=discr_layer),
                nn.Linear(discr_dim, 1),
            )

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight

    def forward(self, input, **kwargs):

        replace_prob = prob_mask_like(input, self.replace_prob)

        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # get mask indices
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_input = input.clone().detach()

        # set inverse of mask to padding tokens for labels
        gen_labels = input.masked_fill(~mask, self.pad_token_id)

        # clone the mask, for potential modification if random tokens are involved
        # not to be mistakened for the mask above, which is for all tokens, whether not replaced nor replaced with random tokens
        masking_mask = mask.clone()

        # if random token probability > 0 for mlm
        if self.random_token_prob > 0:
            assert (
                self.num_tokens is not None
            ), "Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling"

            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(
                0, self.num_tokens, input.shape, device=input.device
            )
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask

            masked_input = torch.where(random_token_prob, random_tokens, masked_input)

            # remove random token prob mask from masking mask
            masking_mask = masking_mask & ~random_token_prob

        # [mask] input
        masked_input = masked_input.masked_fill(
            masking_mask * replace_prob, self.mask_token_id
        )

        # get generator output and get mlm loss(수정)
        logits = self.generator(masked_input, **kwargs).logits

        mlm_loss = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            gen_labels.reshape(-1),
            ignore_index=self.pad_token_id,
        )

        # use mask from before to select logits that need sampling
        sample_logits = logits[mask_indices]

        # sample
        sampled = gumbel_sample(sample_logits, temperature=self.temperature)
        # print("sample token :", self.tokenizer.convert_ids_to_tokens(sampled))

        # scatter the sampled values back to the input
        disc_input = input.clone()
        disc_input[mask_indices] = sampled.detach()

        # generate discriminator labels, with replaced as True and original as False
        disc_labels = (input != disc_input).float().detach()

        # get discriminator predictions of replaced / original
        non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)

        # get discriminator output and binary cross entropy loss
        disc_logits = self.discriminator(disc_input, **kwargs).logits
        disc_logits_reshape = disc_logits.reshape_as(disc_labels)

        disc_loss = F.binary_cross_entropy_with_logits(
            disc_logits_reshape[non_padded_indices], disc_labels[non_padded_indices]
        )

        # gather metrics
        with torch.no_grad():
            gen_predictions = torch.argmax(logits, dim=-1)

            disc_predictions = torch.round(
                (torch.sign(disc_logits_reshape) + 1.0) * 0.5
            )
            gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean()
            disc_acc = (
                0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean()
                + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean()
            )

            #####

            # [PAD] 제거
            try:
                n = masked_input.data[0].tolist().index(3) + 1
            except:
                n = None

            # masked_sen : Masking한 문장 생성
            masked_sen = self.tokenizer.convert_ids_to_tokens(
                masked_input.data[0].tolist()[:n]
            )

            # gen_prd : Generator 문장 예측
            # [:n] : padding 제거
            gen_prd = self.tokenizer.convert_ids_to_tokens(
                logits.unsqueeze(0).max(dim=-1)[1][0][0].tolist()[:n]
            )

            # gen_chg_sen : Generator 신규 문장 생성
            gen_chg_sen = self.tokenizer.convert_ids_to_tokens(
                disc_input[0].tolist()[:n]
            )

            for i in range(len(disc_input[0][:n])):
                if masked_sen[i] == "[MASK]":
                    gen_chg_sen[i] = "<< " + gen_chg_sen[i] + " >>"

        # return weighted sum of losses
        return (masked_sen, gen_prd, gen_chg_sen), Results(
            self.gen_weight * mlm_loss + self.disc_weight * disc_loss,
            mlm_loss,
            disc_loss,
            gen_acc,
            disc_acc,
            disc_labels,
            disc_predictions,
        )


In [None]:
optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.9, 0.98), eps=1e-9)
model = Electra(generator=generator,discriminator=discriminator,tokenizer=tokenizer)

model.to(device)
model.train()

In [6]:
from data import book_corpus
from torch.utils.data import DataLoader

def train_epoch(model) :
    losses = 0
    dataset = book_corpus('data/pre_book_total_128.csv')

    batch_size = 64
    max_length = 128


    train_dataloader = DataLoader(dataset,batch_size)
    torch.cuda.empty_cache()

    for i, sen in enumerate(train_dataloader) :
        
        torch.cuda.empty_cache()
        # weight tie
        generator.electra.embeddings.token_type_embeddings.weight = discriminator.electra.embeddings.token_type_embeddings.weight
        generator.electra.embeddings.position_embeddings.weight = discriminator.electra.embeddings.position_embeddings.weight

        loss = 0

        input_data = tokenizer(sen,return_tensors='pt',truncation=True,padding=True,max_length=max_length).to(device)

        prd,result = model(input_data['input_ids'])

        optimizer.zero_grad()

        result.loss.backward()

        optimizer.step()
        
        if i == 30 :
            break

        if i % 1000 == 0 :
            print(f'{i*batch_size} 개 학습 중')
            masked_sen = prd[0]
            gen_prd = prd[1]
            gen_chg_sen = prd[2]

            print('*--- 원본 ---*')
            print(sen[0])
            print('')
            print('*--- Input 단어 Masking ---*')
            print(tokenizer.convert_tokens_to_string(masked_sen[1:-1]))
            print('')
            print('Generator 학습---*')
            print(tokenizer.convert_tokens_to_string(gen_prd[1:-1]))
            print('')
            print('Generator 문장 생성---*')
            print(tokenizer.convert_tokens_to_string(gen_chg_sen[1:-1]))
            print('')
            print('Discriminator 예측 정확도---*')
            print('disc_acc : ',round(result.disc_acc.item(),3),'0일 경우 학습 에러 : ', torch.sum(result.disc_predictions).item() )


        losses += result.loss.item()

    return losses / len(train_dataloader)
            
        

from timeit import default_timer as timer

NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS+1):
    print('-'*30)
    print(f'{epoch}번째 epoch 실행')
    start_time = timer()
    loss = train_epoch(model)
    end_time = timer()

    print('----*'*20)
    print((f"Epoch: {epoch},loss: {loss:.3f},  "f"Epoch time = {(end_time - start_time):.3f}s"))
    print('----*'*20)


----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*
Epoch: 1,loss: 0.030,  Epoch time = 121.372s
----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*----*


In [7]:
old_model = ElectraForMaskedLM.from_pretrained('monologg/koelectra-base-v3-generator')

In [16]:
new_param = [i for i in generator.electra.parameters()]

len(new_param)

199

In [14]:
old_param = [i for i in old_model.parameters()]
new_param = [i for i in generator.electra.parameters()]

# Params 비교
idx_list = []
for i in range(len(old_param)) :
    x = old_param[i] - new_param[i] # 차이가 없다면 x = 0
    if torch.sum(x.reshape(-1)).item() != 0 :
        print(torch.sum(x.reshape(-1)).item())
        idx_list.append(i)

idx_list 



-0.07394019514322281
0.007522447034716606


IndexError: list index out of range