In [3]:
from argparse import Namespace
import logging
opt = {
    "batch_size": 4,
    "num_workers": 4,
    "lr": 1e-4,
    "max_epochs": 10,
    "warmup_ratio": 0.2,
    "print_step": 100,
    "save_path": "model_weights",
} 
args = Namespace(**opt)

In [4]:
import json
path = "dacon/"
train_file = path + "train.json"
with open(train_file, "r") as f:
    TRAIN_DATA = json.load(f)
    

In [5]:
import json
path = "dacon/"
test_file = path + "test.json"
with open(test_file, "r") as f:
    TEST_DATA = json.load(f)

In [6]:
#!pip install git+https://github.com/SKT-AI/KoBART#egg=kobart

In [7]:
from kobart import get_kobart_tokenizer
kobart_tokenizer = get_kobart_tokenizer()
kobart_tokenizer.tokenize("안녕하세요. 한국어 BART 입니다.🤣:)l^o")


using cached model


['▁안녕하', '세요.', '▁한국어', '▁B', 'A', 'R', 'T', '▁입', '니다.', '🤣', ':)', 'l^o']

In [8]:
from transformers import BartModel, BartForConditionalGeneration
from kobart import get_pytorch_kobart_model, get_kobart_tokenizer
kobart_tokenizer = get_kobart_tokenizer()
model = BartForConditionalGeneration.from_pretrained(get_pytorch_kobart_model())#BartModel.from_pretrained(get_pytorch_kobart_model())
inputs = kobart_tokenizer(['안녕하세요.'], return_tensors='pt')
model(inputs['input_ids'])

using cached model
using cached model


Seq2SeqLMOutput(loss=None, logits=tensor([[[-3.1152,  5.6895, -8.4704,  ..., -7.6067, -4.6915, -4.2773],
         [-6.7131,  4.5574, -7.6131,  ..., -6.8429, -5.2066, -6.7749]]],
       grad_fn=<AddBackward0>), past_key_values=((tensor([[[[-9.7980e-02, -6.6584e-01, -1.8089e+00,  ...,  9.6023e-01,
           -1.8818e-01, -1.3252e+00],
          [-6.2507e-01,  5.1009e-01, -7.4878e-01,  ...,  8.6230e-01,
            1.5722e-01, -6.0267e-01]],

         [[ 5.4597e-01, -2.3990e-01,  1.5901e+00,  ...,  4.3655e-01,
            7.9514e-01,  8.9880e-02],
          [-1.7327e-01, -6.3167e-01,  4.5152e-02,  ..., -1.4111e-01,
            1.8678e-01, -1.2081e-01]],

         [[ 1.4621e+00,  1.8980e+00, -7.6696e-01,  ...,  1.5695e+00,
            6.7921e-02, -3.9372e-01],
          [-4.1204e-02,  1.7132e+00, -1.1863e+00,  ..., -2.2272e-01,
            9.8310e-02,  8.1729e-01]],

         ...,

         [[ 4.8868e-01,  1.2633e+00, -1.1658e-01,  ..., -3.1989e-01,
            1.2202e+00, -7.9021e-02],
  

In [9]:
inputs

{'input_ids': tensor([[27616, 25161]]), 'token_type_ids': tensor([[0, 0]]), 'attention_mask': tensor([[1, 1]])}

In [10]:
import pandas as pd
train = pd.DataFrame(columns=['uid', 'title', 'region', 'context', 'summary'])
uid = 1000
for data in TRAIN_DATA:
    for agenda in data['context'].keys():
        context = ''
        for line in data['context'][agenda]:
            context += data['context'][agenda][line]
            context += ' '
        train.loc[uid, 'uid'] = uid
        train.loc[uid, 'title'] = data['title']
        train.loc[uid, 'region'] = data['region']
        train.loc[uid, 'context'] = context[:-1]
        train.loc[uid, 'summary'] = data['label'][agenda]['summary']
        uid += 1

test = pd.DataFrame(columns=['uid', 'title', 'region', 'context'])
uid = 2000
for data in TEST_DATA:
    for agenda in data['context'].keys():
        context = ''
        for line in data['context'][agenda]:
            context += data['context'][agenda][line]
            context += ' '
        test.loc[uid, 'uid'] = uid
        test.loc[uid, 'title'] = data['title']
        test.loc[uid, 'region'] = data['region']
        test.loc[uid, 'context'] = context[:-1]
        uid += 1

In [11]:
train['total'] = train.title + ' ' + train.region + ' ' + train.context
test['total'] = test.title + ' ' + test.region + ' ' + test.context

In [12]:
df_train = train.iloc[:-200]
df_val = train.iloc[-200:]

In [13]:
df_train.iloc[0]

uid                                                     1000
title                          제207회 완주군의회(임시회) 제 1 차 본회의회의록
region                                                    완주
context    의석을 정돈하여 주시기 바랍니다. 성원이 되었으므로 제207회 완주군의회 임시회 제...
summary                       제207회 완주군의회 임시회 제1차 본회의 개의 선포.
total      제207회 완주군의회(임시회) 제 1 차 본회의회의록 완주 의석을 정돈하여 주시기 ...
Name: 1000, dtype: object

In [14]:
import os
import glob
import torch
import ast
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from torch.utils.data import Dataset, DataLoader, IterableDataset

class KoBARTSummaryDataset(Dataset):
    def __init__(self, file, tok, max_len=512, pad_index = 0, ignore_index=-100, train=True):
        super().__init__()
        self.tok = tok
        self.max_len = max_len
        self.docs = file
        self.len = len(self.docs)
        self.pad_index = pad_index
        self.ignore_index = ignore_index
        self.train = train

    def add_padding_data(self, inputs):
        if len(inputs) < self.max_len:
            pad = np.array([self.pad_index] *(self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]

        return inputs

    def add_ignored_data(self, inputs):
        if len(inputs) < self.max_len:
            pad = np.array([self.ignore_index] *(self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]

        return inputs
    
    def __getitem__(self, idx):
        instance = self.docs.iloc[idx]
        context = instance['context']
        if self.train:
            summary = instance['summary']
        input_ids = self.tok.encode(context)
        input_ids = self.add_padding_data(input_ids)

        if self.train:
            label_ids = self.tok.encode(instance['summary'])
            label_ids.append(self.tok.eos_token_id)
            dec_input_ids = [self.pad_index]
            dec_input_ids += label_ids[:-1]
            dec_input_ids = self.add_padding_data(dec_input_ids)
            label_ids = self.add_ignored_data(label_ids)


            return {'input_ids': np.array(input_ids, dtype=np.int_),
                    'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_),
                    'labels': np.array(label_ids, dtype=np.int_)}
        else:
            return {'input_ids': np.array(input_ids, dtype=np.int_),}
    
    def __len__(self):
        return self.len

In [15]:
train_dataset = KoBARTSummaryDataset(file=df_train, tok=kobart_tokenizer, train=True)
valid_dataset = KoBARTSummaryDataset(file=df_val, tok=kobart_tokenizer, train=True)
test_dataset = KoBARTSummaryDataset(file=test, tok=kobart_tokenizer, train=False)

train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

In [16]:
#input = train_dataset[0]
#input_ids = torch.as_tensor(input["input_ids"]).unsqueeze(0)#
#decoder_input_ids = torch.as_tensor(input["decoder_input_ids"]).unsqueeze(0)
#labels = torch.as_tensor(input["labels"]).unsqueeze(0)
#return_dict = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels) 
#loss = return_dict["loss"]

In [21]:
import torch
import torch.nn as nn
class KobartSummaryModule(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.model = BartForConditionalGeneration.from_pretrained(get_pytorch_kobart_model()).to(self.device)
        
    def forward(self, input):
        input_ids = torch.as_tensor(input["input_ids"]).to(self.device)
        decoder_input_ids = torch.as_tensor(input["decoder_input_ids"]).to(self.device)
        labels = torch.as_tensor(input["labels"]).to(self.device)

        if len(input_ids.shape) == 1:
            input_ids, decoder_input_ids, labels = input_ids.unsqueeze(0), decoder_input_ids.unsqueeze(0), labels.unsqueeze(0)
        
        return_dict = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels) 
        loss = return_dict["loss"]

        return loss, return_dict


In [22]:
device = "cuda"
input = train_dataset[0]
model = KobartSummaryModule(device=device)
loss, _ = model(input)

using cached model


In [23]:
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(
        nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(
        nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters,
                    lr=args.lr, correct_bias=False)

In [24]:
train_len = len(train_dataloader.dataset)
val_len = len(valid_dataloader.dataset)
logging.info(f'data length {data_len}')
num_train_steps = int(train_len / (args.batch_size * args.num_workers) * args.max_epochs)
logging.info(f'num_train_steps : {num_train_steps}')
num_warmup_steps = int(num_train_steps * args.warmup_ratio)
logging.info(f'num_warmup_steps : {num_warmup_steps}')
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
lr_scheduler = {'scheduler': scheduler, 
                'monitor': 'loss', 'interval': 'step',
                'frequency': 1}

In [26]:
from tqdm import tqdm

for epoch in range(args.max_epochs):
    total_loss = 0
    for i, d in enumerate(tqdm(train_dataloader)):
        model.train()
        input = d
        optimizer.zero_grad()
        loss, _ = model(input)
        loss.backward()

        optimizer.step()
        scheduler.step()
        
        #if i % args.print_step == 0:
        #    print("step:", i)
        #    print("loss:{:.2f}".format(loss.item()))
    print("EPOCH:", epoch+1)
    print("valid_loss:{:.2f}".format(total_loss/train_len))   

    total_loss = 0
    for i, d in enumerate(tqdm(valid_dataloader)):
        model.eval()
        with torch.no_grad():
            input = d
            loss, _ = model(input) 
            total_loss += loss 
    print("EPOCH:", epoch+1)
    print("valid_loss:{:.2f}".format(total_loss/val_len))       

    torch.save(
            model.state_dict(),
            args.save_path + f"/model_{epoch}.pth"
            )

    """torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, 
            args.save_path + f"/model_{epoch}.pth"
        )"""
            




  0%|          | 1/699 [00:00<07:48,  1.49it/s]

step: 0
loss:6.76


 14%|█▍        | 101/699 [00:25<02:29,  4.01it/s]

step: 100
loss:3.02


 29%|██▉       | 201/699 [00:51<02:09,  3.86it/s]

step: 200
loss:1.23


 43%|████▎     | 301/699 [01:16<01:40,  3.98it/s]

step: 300
loss:0.46


 57%|█████▋    | 401/699 [01:42<01:17,  3.86it/s]

step: 400
loss:1.32


 72%|███████▏  | 501/699 [02:08<00:50,  3.92it/s]

step: 500
loss:0.49


 86%|████████▌ | 601/699 [02:34<00:25,  3.84it/s]

step: 600
loss:1.34


100%|██████████| 699/699 [02:59<00:00,  4.20it/s]
100%|██████████| 50/50 [00:04<00:00, 12.36it/s]


EPOCH: 1
valid_loss:0.22


  0%|          | 1/699 [00:00<07:50,  1.48it/s]

step: 0
loss:0.38


 14%|█▍        | 101/699 [00:26<02:40,  3.73it/s]

step: 100
loss:0.65


 29%|██▉       | 201/699 [00:53<02:10,  3.82it/s]

step: 200
loss:0.37


 43%|████▎     | 301/699 [01:18<01:44,  3.83it/s]

step: 300
loss:1.22


 57%|█████▋    | 401/699 [01:44<01:15,  3.93it/s]

step: 400
loss:0.09


 72%|███████▏  | 501/699 [02:09<00:51,  3.83it/s]

step: 500
loss:0.17


 86%|████████▌ | 601/699 [02:36<00:25,  3.78it/s]

step: 600
loss:0.30


100%|██████████| 699/699 [03:01<00:00,  4.27it/s]
100%|██████████| 50/50 [00:03<00:00, 12.52it/s]


EPOCH: 2
valid_loss:0.19


  0%|          | 1/699 [00:00<08:16,  1.41it/s]

step: 0
loss:0.17


 14%|█▍        | 101/699 [00:26<02:33,  3.89it/s]

step: 100
loss:0.53


 29%|██▉       | 201/699 [00:52<02:11,  3.79it/s]

step: 200
loss:0.02


 43%|████▎     | 301/699 [01:19<01:45,  3.78it/s]

step: 300
loss:0.37


 57%|█████▋    | 401/699 [01:45<01:19,  3.76it/s]

step: 400
loss:0.15


 72%|███████▏  | 501/699 [02:11<00:51,  3.85it/s]

step: 500
loss:0.15


 86%|████████▌ | 601/699 [02:37<00:25,  3.85it/s]

step: 600
loss:0.13


100%|██████████| 699/699 [03:02<00:00,  4.29it/s]
100%|██████████| 50/50 [00:04<00:00, 12.48it/s]


EPOCH: 3
valid_loss:0.20


  0%|          | 1/699 [00:00<08:02,  1.45it/s]

step: 0
loss:0.24


 14%|█▍        | 101/699 [00:25<02:31,  3.94it/s]

step: 100
loss:0.07


 29%|██▉       | 201/699 [00:51<02:13,  3.72it/s]

step: 200
loss:0.96


 43%|████▎     | 301/699 [01:17<01:42,  3.87it/s]

step: 300
loss:0.12


 57%|█████▋    | 401/699 [01:43<01:16,  3.91it/s]

step: 400
loss:0.26


 72%|███████▏  | 501/699 [02:08<00:50,  3.92it/s]

step: 500
loss:0.32


 86%|████████▌ | 601/699 [02:34<00:24,  3.96it/s]

step: 600
loss:0.13


100%|██████████| 699/699 [02:59<00:00,  4.26it/s]
100%|██████████| 50/50 [00:03<00:00, 12.57it/s]


EPOCH: 4
valid_loss:0.23


  0%|          | 1/699 [00:00<08:38,  1.35it/s]

step: 0
loss:1.09


 14%|█▍        | 101/699 [00:26<02:33,  3.89it/s]

step: 100
loss:0.43


 29%|██▉       | 201/699 [00:52<02:05,  3.96it/s]

step: 200
loss:0.22


 43%|████▎     | 301/699 [01:17<01:40,  3.94it/s]

step: 300
loss:0.25


 57%|█████▋    | 401/699 [01:43<01:17,  3.83it/s]

step: 400
loss:0.15


 72%|███████▏  | 501/699 [02:09<00:52,  3.77it/s]

step: 500
loss:0.07


 86%|████████▌ | 601/699 [02:35<00:25,  3.77it/s]

step: 600
loss:0.03


100%|██████████| 699/699 [03:00<00:00,  4.40it/s]
100%|██████████| 50/50 [00:03<00:00, 12.55it/s]


EPOCH: 5
valid_loss:0.24


  0%|          | 1/699 [00:00<07:52,  1.48it/s]

step: 0
loss:0.24


  5%|▌         | 35/699 [00:09<02:53,  3.83it/s]

KeyboardInterrupt: 