In [None]:
import re
import os
import math
import torch
import random
import pandas as pd
import pytorch_lightning as pl

from tqdm import tqdm
from rouge import Rouge
from datetime import datetime

from torch import nn
from konlpy.tag import Mecab
from tensorboardX import SummaryWriter
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset , DataLoader

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import BartTokenizer, BartForConditionalGeneration, BartTokenizerFast

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Training loop ---------------------------------------------------------
epochs = 1000            # upper bound for fine-tuning; early stopping usually ends sooner
batch_size = 8
accumulation_steps = 4   # gradients are accumulated over this many batches per optimizer step
num_workers = 0
patience = 10            # epochs without improvement tolerated before stopping

# --- Optimizer / LR schedule -----------------------------------------------
init_lr = 1e-6           # starting (and floor) learning rate
max_lr = 1e-5            # peak learning rate reached at the end of warm-up
weight_decay = 0
warmup_epochs = 10
T_0 = 100                # length of the first scheduler cycle, in scheduler steps
T_mult = 1
T_gamma = 0.5            # per-restart decay factor applied to the peak learning rate

# --- Token budgets ----------------------------------------------------------
dig_max_len = 512        # dialogue (encoder input) length cap
sum_max_len = 256        # summary (decoder / generation) length cap

# --- Output locations: every run gets its own timestamped folder -----------
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
save_path = os.path.join("./BART_runs", timestamp)
weights_path = os.path.join(save_path, 'weights')
logs_path = os.path.join(save_path, 'logs')

os.makedirs(weights_path, exist_ok=True)
os.makedirs(logs_path, exist_ok=True)

In [None]:
"noahkim/KoBigBird-KoBart-News-Summarization"
# Checkpoint that provides the BART weights; the tokenizer comes from a local
# SentencePiece directory rather than from the checkpoint itself.
# NOTE(review): confirm the local tokenizer's vocabulary/ids match this checkpoint.
model_id = "alaggung/bart-r3f"
tokenizer = BartTokenizerFast.from_pretrained("../tokenizer/ko_sentencepiece")
model = BartForConditionalGeneration.from_pretrained(model_id)
model = model.to(device)

In [None]:
# Placeholder tokens registered with the tokenizer so each is kept as a single
# token: seven speaker tags followed by the masked-entity tags.
special_tokens = [f'#Person{speaker}#' for speaker in range(1, 8)] + [
    '#SSN#',
    '#Email#',
    '#Address#',
    '#Reaction#',
    '#CarNumber#',
    '#Movietitle#',
    '#DateOfBirth#',
    '#CardNumber#',
    '#PhoneNumber#',
    '#PassportNumber#',
]

# Marker strings that are blanked out of decoded text before scoring/export.
# Built BEFORE the placeholders are registered, so it only holds the
# tokenizer's own control tokens (an unset token becomes the string "None").
remove_tokens = ['<usr>'] + [
    str(marker)
    for marker in (
        tokenizer.bos_token,
        tokenizer.unk_token,
        tokenizer.eos_token,
        tokenizer.pad_token,
        tokenizer.sep_token,
        tokenizer.mask_token,
    )
]

# NOTE(review): if this call actually adds new ids, the model's embedding
# matrix is not resized anywhere in this notebook
# (model.resize_token_embeddings) — confirm the ids fit the checkpoint's vocabulary.
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

print(tokenizer.special_tokens_map)

In [5]:
class SpanCorruptionDataset(Dataset):
    """Denoising dataset: yields ``(corrupted_ids, target_ids)`` pairs.

    Every row of ``df['total']`` is tokenized once up front (padded to the
    longest row, truncated to ``input_len``).  On access the sequence is
    optionally sentence-permuted and then has a fraction of its tokens
    replaced by the mask token.  Despite the class name, masking is done on
    independent token positions, not on contiguous spans.

    Args:
        df: DataFrame with a ``'total'`` text column.
        tokenizer: callable tokenizer exposing ``decode`` and the
            ``mask_token_id`` / ``pad_token_id`` / ``bos_token_id`` /
            ``eos_token_id`` attributes.
        input_len: maximum token length of a sequence.
        mask_prob: fraction of content tokens replaced by the mask token.
        permute_prob: probability that the corrupted input is sentence-permuted.
    """

    def __init__(self, df, tokenizer, input_len, mask_prob=0.15, permute_prob=0.3):
        self.tokenizer = tokenizer
        self.df = df.copy()
        self.source_len = input_len
        self.mask_prob = mask_prob
        self.permute_prob = permute_prob

        self.input_ids = tokenizer(self.df['total'].tolist(),
                                   return_tensors="pt",
                                   padding=True,
                                   add_special_tokens=True,
                                   truncation=True,
                                   max_length=self.source_len,
                                   return_token_type_ids=False).input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(corrupted_ids, target_ids)`` for row ``idx``.

        The target is always the original sequence; only the model input is
        permuted/masked, so the model has to undo both corruptions.
        """
        target_ids = self.input_ids[idx]
        source_ids = target_ids
        if random.random() < self.permute_prob:
            source_ids = self.apply_sentence_permutation(target_ids)

        return self.apply_span_corruption(source_ids), target_ids

    def _structural_token_ids(self):
        """Ids of padding/BOS/EOS tokens, which must never be masked or decoded into text."""
        candidates = (
            getattr(self.tokenizer, attr, None)
            for attr in ("pad_token_id", "bos_token_id", "eos_token_id")
        )
        return {token_id for token_id in candidates if token_id is not None}

    def apply_span_corruption(self, input_ids):
        """Return a copy of ``input_ids`` with ``mask_prob`` of its content tokens masked.

        Padding, BOS and EOS positions are excluded both from the mask budget
        and from the candidate positions, so the amount of masking depends on
        the real text length rather than on the padded length.
        """
        structural = self._structural_token_ids()
        maskable = [pos for pos, token_id in enumerate(input_ids.tolist())
                    if token_id not in structural]
        num_to_mask = int(len(maskable) * self.mask_prob)

        corrupted_ids = input_ids.clone()
        if num_to_mask > 0:
            mask_indices = random.sample(maskable, num_to_mask)
            corrupted_ids[mask_indices] = self.tokenizer.mask_token_id
        return corrupted_ids

    def apply_sentence_permutation(self, input_ids):
        """Shuffle the ``". "``-separated sentences of a sequence and re-tokenize it.

        Padding/BOS/EOS ids are dropped before decoding, and decoding keeps
        the remaining special tokens so registered placeholders (e.g. speaker
        tags) survive the round trip instead of being stripped.
        """
        structural = self._structural_token_ids()
        content_ids = [token_id for token_id in input_ids.tolist()
                       if token_id not in structural]
        text = self.tokenizer.decode(content_ids, skip_special_tokens=False)

        sentences = text.split(". ")
        random.shuffle(sentences)
        permuted_text = ". ".join(sentences)

        permuted_input_ids = self.tokenizer(permuted_text,
                                            return_tensors="pt",
                                            padding=True,
                                            truncation=True,
                                            max_length=self.source_len).input_ids.squeeze(0)
        return permuted_input_ids

def collate_fn(batch):
    """Pad a list of ``(masked_ids, original_ids)`` pairs into two batch tensors.

    Each side is padded independently to its own longest sequence using the
    module-level tokenizer's pad token id.

    Returns:
        ``(masked_batch, original_batch)``, both laid out batch-first.
    """
    pad_id = tokenizer.pad_token_id

    def _pad_column(position):
        # Gather the tensors found at `position` of every sample and right-pad them.
        column = [sample[position] for sample in batch]
        return pad_sequence(column, batch_first=True, padding_value=pad_id)

    return _pad_column(0), _pad_column(1)

In [6]:
# Original (cleaned) corpora.
train_df = pd.read_csv('../dataset/cleaned_train.csv')
val_df = pd.read_csv('../dataset/cleaned_dev.csv')

# Translated corpora; their dialogues are flattened onto a single line.
trans_train = pd.read_csv("../dataset/translated_train.csv")
trans_val = pd.read_csv("../dataset/translated_valid.csv")
trans_train['dialogue'] = trans_train['dialogue'].str.replace('\n', ' ')
trans_val['dialogue'] = trans_val['dialogue'].str.replace('\n', ' ')

# Pre-training text = dialogue followed by its summary, for every corpus.
# (trans_train previously concatenated the dialogue with itself.)
train_df['total'] = train_df['dialogue'] + " " + train_df['summary']
trans_train['total'] = trans_train['dialogue'] + " " + trans_train['summary']
val_df['total'] = val_df['dialogue'] + " " + val_df['summary']
trans_val['total'] = trans_val['dialogue'] + " " + trans_val['summary']

# NOTE(review): trans_val is folded into the pre-training set; if it is a
# translation of the dev split, this leaks validation content — confirm.
train_df = pd.concat([train_df, trans_train, trans_val])

pretrain_dataset = SpanCorruptionDataset(train_df[['total']], tokenizer, dig_max_len)
prevalid_dataset = SpanCorruptionDataset(val_df[['total']], tokenizer, dig_max_len)

pretrain_loader = DataLoader(pretrain_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn)
# Validation order is kept fixed so the printed samples are comparable across epochs.
preval_loader = DataLoader(prevalid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

In [None]:
# Sanity check: pull one pre-training batch and show its first sample before
# and after corruption, with special tokens left visible.
batch = next(iter(pretrain_loader))
masked_input_ids, original_input_ids = batch[:2]

masked_example, original_example = (
    tokenizer.decode(sample_ids[0], skip_special_tokens=False)
    for sample_ids in (masked_input_ids, original_input_ids)
)

masked_example, original_example

In [8]:
def ids_to_words(tokenizer, preds, labels):
    """Decode predicted and reference id batches into cleaned-up strings.

    Both batches are decoded with ``tokenizer.batch_decode`` and every marker
    listed in the module-level ``remove_tokens`` is replaced by a single space.

    Returns:
        ``(prediction_texts, label_texts)`` as two lists of strings.
    """
    def _blank_markers(text):
        # Apply the replacements in the same order as `remove_tokens`.
        for marker in remove_tokens:
            text = text.replace(marker, " ")
        return text

    prediction_texts = tokenizer.batch_decode(preds, clean_up_tokenization_spaces=True)
    label_texts = tokenizer.batch_decode(labels, clean_up_tokenization_spaces=True)

    cleaned_predictions = [_blank_markers(text) for text in prediction_texts]
    cleaned_labels = [_blank_markers(text) for text in label_texts]
    return cleaned_predictions, cleaned_labels

In [9]:
def compute_metrics(replaced_predictions, replaced_labels):
    """Average ROUGE-1/2/L F1 between predictions and references.

    The ``rouge`` scorer raises on hypotheses/references that contain no
    words, which a model can easily produce early in training.  Such pairs are
    therefore not sent to the scorer; they count as a score of 0 in the
    average instead of aborting the whole evaluation.

    Args:
        replaced_predictions: list of predicted strings.
        replaced_labels: list of reference strings, aligned with the predictions.

    Returns:
        dict with keys ``'rouge-1'``, ``'rouge-2'`` and ``'rouge-l'`` mapping
        to the mean F1 over all pairs (0.0 for an empty input).

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(replaced_predictions) != len(replaced_labels):
        raise ValueError(
            "predictions and labels must have the same length, got "
            f"{len(replaced_predictions)} and {len(replaced_labels)}"
        )

    def _has_words(text):
        # The scorer splits on "." and drops empty pieces, so mirror that here.
        return any(piece.strip() for piece in text.split("."))

    metric_names = ("rouge-1", "rouge-2", "rouge-l")
    totals = {name: 0.0 for name in metric_names}

    scorable = [(pred, label)
                for pred, label in zip(replaced_predictions, replaced_labels)
                if _has_words(pred) and _has_words(label)]
    if scorable:
        hypotheses, references = (list(column) for column in zip(*scorable))
        for pair_scores in Rouge().get_scores(hypotheses, references):
            for name in metric_names:
                totals[name] += pair_scores[name]["f"]

    num_pairs = len(replaced_predictions)
    return {name: (totals[name] / num_pairs if num_pairs else 0.0) for name in metric_names}

In [10]:
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warm-up and restarts.

    Within a cycle of ``T_i`` steps the learning rate rises linearly from each
    group's base lr to ``eta_max`` during the first ``T_up`` steps, then
    follows a half cosine from ``eta_max`` back down to the base lr.  After
    every cycle the peak is multiplied by ``gamma`` and the cycle length is
    rescaled using ``T_mult``.

    Args:
        optimizer: wrapped optimizer; its lr at construction acts as the floor.
        T_0: length of the first cycle (positive int).
        T_mult: cycle-length multiplier applied at each restart (int >= 1).
        eta_max: peak learning rate of the first cycle.
        T_up: number of warm-up steps per cycle (non-negative int).
        gamma: factor applied to the peak learning rate per completed cycle.
        last_epoch: index of the last step, forwarded to the base class.

    Raises:
        ValueError: if ``T_0``, ``T_mult`` or ``T_up`` fail the checks below.
    """
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        # Note: zero is accepted here even though the message says "positive".
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max  # undecayed peak, used to recompute eta_max per cycle
        self.eta_max = eta_max       # peak of the current cycle
        self.T_up = T_up
        self.T_i = T_0               # length of the current cycle
        self.gamma = gamma
        self.cycle = 0               # number of completed cycles
        self.T_cur = last_epoch      # position inside the current cycle
        # NOTE(review): the base-class constructor is expected to invoke step()
        # once, which would move T_cur from -1 to 0 — confirm for the installed torch.
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        """Return one learning rate per param group for the current ``T_cur``."""
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            # Warm-up: linear interpolation from base_lr towards eta_max.
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            # Annealing: half cosine from eta_max (at T_cur == T_up) towards base_lr.
            # NOTE(review): divides by (T_i - T_up); T_0 == T_up would divide by zero.
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new lr into the optimizer.

        Args:
            epoch: optional absolute step index.  When omitted the schedule
                advances by one step; when given, the cycle number and the
                position within it are recomputed from ``epoch``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Cycle finished: restart and rescale the non-warm-up part of the cycle.
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    # Constant cycle length: plain modular arithmetic.
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    # Geometric cycle lengths: n is the number of whole cycles
                    # contained in `epoch`, from the geometric-series sum.
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
                
        # Decay the peak once per completed cycle, then apply the new rates.
        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

In [11]:
def pretrain_epoch(model, device, train_loader, optimizer, epoch, accumulation_steps):
    """Run one denoising pre-training epoch with gradient accumulation.

    Args:
        model: seq2seq model whose forward call returns an object with ``.loss``.
        device: device the batches are moved to.
        train_loader: iterable of ``(corrupted_ids, target_ids)`` batches.
        optimizer: optimizer updated once per accumulation window.
        epoch: zero-based epoch index (only used in the progress-bar label).
        accumulation_steps: number of batches whose gradients are summed
            before each optimizer step.

    Returns:
        Mean (unscaled) loss over the batches of the epoch.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    # NOTE(review): assumes the model config's pad id is the id the batches
    # were padded with — confirm it matches the tokenizer in use.
    pad_token_id = getattr(getattr(model, "config", None), "pad_token_id", None)
    optimizer.zero_grad()

    for idx, batch in tqdm(enumerate(train_loader), total=num_batches, desc=f"Pretraining Epoch {epoch+1}", leave=False):
        masked_input_ids = batch[0].to(device)
        labels = batch[1].to(device)

        attention_mask = None
        if pad_token_id is not None:
            # Keep the encoder from attending to padding, and exclude padded
            # target positions from the loss (-100 is the ignored label index).
            attention_mask = masked_input_ids.ne(pad_token_id).long()
            labels = labels.masked_fill(labels.eq(pad_token_id), -100)

        outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask, labels=labels)
        # Scale so the summed gradients of one window equal an average.
        loss = outputs.loss / accumulation_steps
        loss.backward()

        # Step at the end of each window AND after the final batch, so the
        # gradients of a trailing partial window are not thrown away.
        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    avg_loss = total_loss / max(num_batches, 1)
    print(f"Pretrain Loss: {avg_loss:.6f}")

    return avg_loss


def pretrain_validate(tokenizer, model, device, val_loader, epoch):
    """Evaluate the denoising objective and print a few reconstructions.

    Args:
        tokenizer: tokenizer used for decoding and for the pad token id.
        model: seq2seq model supporting ``generate`` and a loss-returning forward call.
        device: device the batches are moved to.
        val_loader: iterable of ``(corrupted_ids, target_ids)`` batches.
        epoch: current epoch index (accepted for call-site symmetry; unused).

    Returns:
        Mean validation loss over the batches.
    """
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []
    num_batches = len(val_loader)
    pad_token_id = tokenizer.pad_token_id

    with torch.no_grad():
        for idx, batch in tqdm(enumerate(val_loader), total=num_batches, desc="Pretrain Validating", leave=False):
            input_ids = batch[0].to(device)
            labels = batch[1].to(device)

            pred_ids = model.generate(input_ids=input_ids, max_length=sum_max_len, num_beams=4, repetition_penalty=2.0, length_penalty=1.0, early_stopping=True)

            attention_mask, loss_labels = None, labels
            if pad_token_id is not None:
                # Ignore padding both as encoder context and as a loss target
                # (-100 is the ignored label index); the unmodified `labels`
                # are still what gets decoded for display below.
                attention_mask = input_ids.ne(pad_token_id).long()
                loss_labels = labels.masked_fill(labels.eq(pad_token_id), -100)

            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels).loss
            total_loss += loss.item()

            replaced_predictions, replaced_labels = ids_to_words(tokenizer, pred_ids, labels)
            all_predictions.extend(replaced_predictions)
            all_labels.extend(replaced_labels)

    avg_loss = total_loss / max(num_batches, 1)

    print(f"Prevalid Loss: {avg_loss:.6f}")
    # Show up to three prediction/reference pairs for a quick visual check.
    for i in range(min(3, len(all_predictions))):
        print(f"[예측 문장 {i+1}]: {all_predictions[i]}")
        print(f"[정답 문장 {i+1}]: {all_labels[i]}")
        print("-" * 80)

    return avg_loss

In [None]:
pretrain_epochs = 10
best_pretrain_loss = float('inf')
# Initialised up front: previously it only came into existence inside the
# "improved" branch, so a first-epoch non-improvement (e.g. a NaN loss) raised NameError.
early_stopping_counter = 0
optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr, weight_decay=weight_decay)
# NOTE(review): warmup_epochs equals pretrain_epochs, so this stage appears to
# run entirely inside the scheduler's warm-up phase — confirm that is intended.
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_max=max_lr,  T_up=warmup_epochs, gamma=T_gamma)

for epoch in range(pretrain_epochs):
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{pretrain_epochs}, Current Learning Rate: {current_lr:.6f}")
    train_loss = pretrain_epoch(model, device, pretrain_loader, optimizer, epoch, accumulation_steps)
    val_loss = pretrain_validate(tokenizer, model, device, preval_loader, epoch)

    scheduler.step()
    if val_loss < best_pretrain_loss:
        # New best validation loss: reset patience and overwrite the best checkpoint.
        best_pretrain_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), os.path.join(weights_path, 'best_pretrain.pth'))
        print(f"Best pretrain model saved at Epoch {epoch+1} with validation loss {best_pretrain_loss:.6f}\n")
    else:
        early_stopping_counter += 1
        print(f"No improvement. Early stopping counter: {early_stopping_counter}/{patience}\n")

    if early_stopping_counter >= patience:
        print("Early stopping triggered.")
        break

    # Per-epoch snapshot, written in addition to the best checkpoint.
    torch.save(model.state_dict(), os.path.join(weights_path, f'epoch-{epoch+1}_pretrain.pth'))

In [None]:
# Reload the raw splits for fine-tuning (this rebinds train_df / val_df, which
# the pre-training cell had extended with extra columns and rows).
train_df = pd.read_csv('../dataset/cleaned_train.csv')
val_df = pd.read_csv('../dataset/cleaned_dev.csv')
test_df = pd.read_csv("../dataset/test.csv")

# Start from a fresh copy of the base checkpoint, then overwrite its weights
# with the best pre-training checkpoint written under this run's save_path.
# NOTE(review): save_path embeds the timestamp taken when the config cell ran,
# so after a kernel restart this file only exists if pre-training is re-run.
model = BartForConditionalGeneration.from_pretrained(model_id).to(device)
weights_file = f"{save_path}/weights/best_pretrain.pth"
model.load_state_dict(torch.load(weights_file))

# Fresh optimizer for the fine-tuning stage.
optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr, weight_decay=weight_decay)

print(model.config)

In [21]:
class CustomDataset(Dataset):
    """Dialogue-summarization dataset tokenized once at construction time.

    Args:
        df: DataFrame with a ``'dialogue'`` column and, when ``is_train`` is
            true, a ``'summary'`` column.
        tokenizer: callable tokenizer returning an object with ``.input_ids``.
        input_len: maximum token length of a dialogue.
        summ_len: maximum token length of a summary.
        is_train: when true, items are ``(input_ids, labels)`` pairs;
            otherwise items are ``input_ids`` only.
    """

    # Matches a speaker tag such as "#Person1#"; the capture group makes
    # re.split keep the tags in its output.
    _SPEAKER_PATTERN = re.compile(r'(#Person\d+#)')

    def __init__(self, df, tokenizer, input_len, summ_len, is_train=True):
        self.tokenizer = tokenizer
        self.df = df.copy()
        self.source_len = input_len
        self.summ_len = summ_len
        self.is_train = is_train

        # Insert a separator wherever the speaker changes.
        self.df.loc[:, 'dialogue'] = self.df['dialogue'].apply(self.add_sep_tokens)

        # The length limits come from the constructor arguments (they used to
        # be read from module-level globals, silently ignoring the arguments).
        self.input_ids = self._encode(self.df['dialogue'].tolist(), self.source_len)
        if self.is_train:
            self.labels = self._encode(self.df['summary'].tolist(), self.summ_len)

    def _encode(self, texts, max_length):
        """Tokenize ``texts`` into a padded/truncated id batch."""
        return self.tokenizer(texts,
                              return_tensors="pt",
                              padding=True,
                              add_special_tokens=True,
                              truncation=True,
                              max_length=max_length,
                              return_token_type_ids=False).input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if self.is_train:
            return self.input_ids[idx], self.labels[idx]
        return self.input_ids[idx]

    def add_sep_tokens(self, dialogue):
        """Return ``dialogue`` with ``'<sep>'`` inserted before each change of speaker.

        Consecutive turns by the same speaker are left untouched, and nothing
        is inserted before the very first speaker tag.
        """
        # NOTE(review): the literal '<sep>' is inserted as plain text; confirm
        # the tokenizer treats it as a single token (tokenizer.sep_token may differ).
        result = []
        prev_speaker = None
        for part in self._SPEAKER_PATTERN.split(dialogue):
            if self._SPEAKER_PATTERN.match(part):
                if prev_speaker and prev_speaker != part:
                    result.append('<sep>')
                prev_speaker = part
            result.append(part)
        return ''.join(result)

In [22]:
# Tokenize the three splits; the test split has no summaries, so it is built
# in inference mode and yields input ids only.
train_dataset = CustomDataset(train_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
val_dataset = CustomDataset(val_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
test_dataset = CustomDataset(test_df[['dialogue']], tokenizer, dig_max_len, sum_max_len, is_train=False)

# Only the training loader shuffles; validation and test keep file order so
# predictions stay aligned with their source rows.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [23]:
def train(epoch, model, device, train_loader, optimizer, writer, accumulation_steps):
    """Run one fine-tuning epoch with gradient accumulation.

    Args:
        epoch: epoch number used for the progress label and the log step.
        model: seq2seq model whose forward call returns an object with ``.loss``.
        device: device the batches are moved to.
        train_loader: iterable of ``(input_ids, labels)`` batches.
        optimizer: optimizer updated once per accumulation window.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.
        accumulation_steps: number of batches whose gradients are summed
            before each optimizer step.

    Returns:
        Mean (unscaled) training loss over the batches of the epoch.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    # NOTE(review): assumes the model config's pad id is the id the batches
    # were padded with — confirm it matches the tokenizer in use.
    pad_token_id = getattr(getattr(model, "config", None), "pad_token_id", None)
    optimizer.zero_grad()

    for idx, batch in tqdm(enumerate(train_loader), total=num_batches, desc=f"Training Epoch {epoch}", leave=False):
        input_ids = batch[0].to(device, dtype=torch.long)
        labels = batch[1].to(device, dtype=torch.long)

        attention_mask = None
        if pad_token_id is not None:
            # Keep the encoder from attending to padding, and exclude padded
            # target positions from the loss (-100 is the ignored label index).
            attention_mask = input_ids.ne(pad_token_id).long()
            labels = labels.masked_fill(labels.eq(pad_token_id), -100)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Scale so the summed gradients of one window equal an average.
        ce_loss = outputs.loss / accumulation_steps
        ce_loss.backward()

        # Step at the end of each window AND after the final batch, so the
        # gradients of a trailing partial window are not thrown away.
        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += ce_loss.item() * accumulation_steps

    avg_loss = total_loss / max(num_batches, 1)
    writer.add_scalar('Loss/train', avg_loss, epoch)
    return avg_loss

In [24]:
def validate(tokenizer, model, device, val_loader, writer, epoch):
    """Evaluate loss and ROUGE on the validation set and log them.

    Args:
        tokenizer: tokenizer used for decoding and for the pad token id.
        model: seq2seq model supporting ``generate`` and a loss-returning forward call.
        device: device the batches are moved to.
        val_loader: iterable of ``(input_ids, labels)`` batches.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.
        epoch: step index under which the scalars are logged.

    Returns:
        ``(val_loss, avg_result, all_predictions, all_labels)`` where
        ``avg_result`` maps ``'rouge-1'``/``'rouge-2'``/``'rouge-l'`` to F1.
    """
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []
    num_batches = len(val_loader)
    pad_token_id = tokenizer.pad_token_id

    with torch.no_grad():
        for idx, batch in tqdm(enumerate(val_loader), total=num_batches, desc="Validating", leave=False):
            input_ids = batch[0].to(device, dtype=torch.long)
            labels = batch[1].to(device, dtype=torch.long)

            pred_ids = model.generate(input_ids=input_ids, max_length=sum_max_len, num_beams=4, repetition_penalty=2.0,
                                      length_penalty=1.0, early_stopping=True, no_repeat_ngram_size=2)

            attention_mask, loss_labels = None, labels
            if pad_token_id is not None:
                # Ignore padding both as encoder context and as a loss target
                # (-100 is the ignored label index); the unmodified `labels`
                # are still what gets decoded for scoring below.
                attention_mask = input_ids.ne(pad_token_id).long()
                loss_labels = labels.masked_fill(labels.eq(pad_token_id), -100)

            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels).loss
            total_loss += loss.item()

            replaced_predictions, replaced_labels = ids_to_words(tokenizer, pred_ids, labels)
            all_predictions.extend(replaced_predictions)
            all_labels.extend(replaced_labels)

    val_loss = total_loss / max(num_batches, 1)
    # Score the whole split in one call: averaging per-batch averages gave a
    # short final batch the same weight as a full one.
    avg_result = compute_metrics(all_predictions, all_labels)

    writer.add_scalar('Loss/valid', val_loss, epoch)
    writer.add_scalar('ROUGE/rouge-1', avg_result['rouge-1'], epoch)
    writer.add_scalar('ROUGE/rouge-2', avg_result['rouge-2'], epoch)
    writer.add_scalar('ROUGE/rouge-l', avg_result['rouge-l'], epoch)

    return val_loss, avg_result, all_predictions, all_labels

In [None]:
best_rouge = 0
early_stopping_counter = 0
best_val_loss = float('inf')
writer = SummaryWriter(log_dir=logs_path)
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_max=max_lr,  T_up=warmup_epochs, gamma=T_gamma)

for epoch in range(1, epochs + 1):
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}/{epochs}, Current Learning Rate: {current_lr:.6f}")
    train_loss = train(epoch, model, device, train_loader, optimizer, writer, accumulation_steps)
    val_loss, val_result, val_predictions, val_labels = validate(tokenizer, model, device, val_loader, writer, epoch)

    print(f"Train Loss: {train_loss:.6f}, Valid Loss: {val_loss:.6f}")
    print(f"Rouge-1: {val_result['rouge-1']:.6f}, Rouge-2: {val_result['rouge-2']:.6f}, Rouge-l: {val_result['rouge-l']:.6f}")

    scheduler.step()
    print('-'*150)
    # Show up to three samples; guarded so a tiny validation set cannot raise IndexError.
    for i in range(min(3, len(val_predictions))):
        print(f"PRED: {val_predictions[i].strip()}")
        print(f"GOLD: {val_labels[i]}")
        print('-'*150)

    # Model selection is driven by ROUGE-2 only.
    if val_result['rouge-2'] > best_rouge:
        best_rouge = val_result['rouge-2']
        early_stopping_counter = 0
        torch.save(model.state_dict(), os.path.join(weights_path, 'best_finetune.pth'))
        print(f"New best model saved with ROUGE-2: {best_rouge:.6f}")

    else:
        early_stopping_counter += 1
        print(f"Not improve. Early stopping counter: {early_stopping_counter}/{patience}")

    if early_stopping_counter >= patience:
        print("Early stopping triggered.")
        break

    # Per-epoch snapshot, written in addition to the best checkpoint.
    torch.save(model.state_dict(), os.path.join(weights_path, f'epoch-{epoch}_finetune.pth'))
    print()

writer.close()
torch.save(model.state_dict(), os.path.join(weights_path, 'last_finetune.pth'))
print("Training completed. Last model saved.")

In [26]:
def predict(tokenizer, model, device, test_loader, fname):
    """Generate a summary for every test dialogue.

    Args:
        tokenizer: tokenizer whose ``decode`` turns generated ids into text.
        model: seq2seq model supporting ``generate``.
        device: device the input batches are moved to.
        test_loader: iterable of input-id batches.
        fname: identifiers stored in the ``fname`` column, one per dialogue.

    Returns:
        DataFrame with the columns ``fname`` and ``summary``; every marker in
        the module-level ``remove_tokens`` is replaced by a space in ``summary``.
    """
    def _blank_markers(text):
        # Apply the replacements in the same order as `remove_tokens`.
        for marker in remove_tokens:
            text = text.replace(marker, " ")
        return text

    model.eval()
    decoded = []
    with torch.no_grad():
        for input_ids in tqdm(test_loader):
            generated = model.generate(
                input_ids=input_ids.to(device, dtype=torch.long),
                max_length=sum_max_len,
                num_beams=4,
                repetition_penalty=2.0,
                length_penalty=1.0,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )
            decoded.extend(tokenizer.decode(ids) for ids in generated)

    cleaned = [_blank_markers(text) for text in decoded]
    return pd.DataFrame({"fname": fname, "summary": cleaned})

In [None]:
# Restore the checkpoint with the best validation ROUGE-2 from this run,
# generate summaries for the test split and write them next to the weights.
# NOTE(review): torch.load is called without map_location, so the checkpoint
# is restored onto the device it was saved from — confirm for CPU-only runs.
model.load_state_dict(torch.load(f'{save_path}/weights/best_finetune.pth'))
output = predict(tokenizer, model, device, test_loader, test_df['fname'])
output.to_csv(f"{save_path}/prediction.csv", index=False)