In [None]:
import re
import os
import math
import torch
import pandas as pd
import pytorch_lightning as pl

from tqdm import tqdm
from rouge import Rouge
from datetime import datetime

from torch import nn
from konlpy.tag import Mecab
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset , DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# --- Checkpoint and data ------------------------------------------------------
model_id = "psyche/KoT5-summarization"
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../dataset/cleaned_train.csv',
        '../dataset/cleaned_dev.csv',
        "../dataset/test.csv",
    )
)

# --- Training loop settings ---------------------------------------------------
epochs = 100
batch_size = 8
num_workers = 0
patience = 5  # epochs without ROUGE-L improvement tolerated before early stopping

# --- Learning-rate schedule (consumed by CosineAnnealingWarmUpRestarts) -------
init_lr = 1e-4
max_lr = 1e-3
warmup_epochs = 10
T_0 = 100
T_mult = 1
T_gamma = 0.5

# --- Sequence length limits handed to CustomDataset ---------------------------
dig_max_len = 1024
sum_max_len = 512

# --- Tokenizer ----------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Speaker tags, anonymization placeholders and the turn separator are registered
# as additional special tokens.
special_tokens_dict = {
    'additional_special_tokens': [
        *(f'#Person{speaker}#' for speaker in range(1, 8)),
        '#PhoneNumber#', '#Address#', '#PassportNumber#',
        '#CardNumber#', '#Email#', '#DateOfBirth#',
        '<sep>',
    ]
}

tokenizer.add_special_tokens(special_tokens_dict)
print(tokenizer.special_tokens_map)

# Markers that get blanked out of decoded text before scoring.
remove_tokens = ['<usr>'] + [
    f"{marker}"
    for marker in (tokenizer.unk_token, tokenizer.eos_token, tokenizer.pad_token)
]

In [None]:
# Load the pretrained seq2seq checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# The tokenizer was extended with additional special tokens in the config cell.
# Grow the embedding table when any of those ids would fall outside it; otherwise
# looking them up indexes past the end of the table. The table is never shrunk,
# so every pretrained row is left untouched.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

model = model.to(device)

# Build the optimizer only after a possible resize so it tracks the final
# parameter tensors, and take the starting rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr, weight_decay=0.01)

In [None]:
class CustomDataset(Dataset):
    """Dialogue (and optionally summary) texts tokenized once, up front.

    Args:
        df: DataFrame with a 'dialogue' column and, when ``is_train`` is true,
            a 'summary' column. It is copied, so the caller's frame is not modified.
        tokenizer: Callable tokenizer whose result exposes ``.input_ids``.
        input_len: Maximum number of tokens kept per dialogue (truncation limit).
        summ_len: Maximum number of tokens kept per summary (truncation limit).
        is_train: When true, items are ``(input_ids, labels)`` pairs; otherwise
            only ``input_ids`` is returned.

    Notes:
        The whole split is tokenized in one call with ``padding=True``, so every
        row of a split shares one width. No attention mask is stored; consumers
        that need one must derive it from the pad token id.
    """

    def __init__(self, df, tokenizer, input_len, summ_len, is_train=True):
        self.tokenizer = tokenizer
        self.df = df.copy()
        self.source_len = input_len
        self.summ_len = summ_len
        self.is_train = is_train

        # Mark speaker changes before tokenizing.
        self.df.loc[:, 'dialogue'] = self.df['dialogue'].apply(self.add_sep_tokens)

        # Honour the configured limits instead of fixed 512 / 100 token caps.
        self.input_ids = self._encode(self.df['dialogue'].tolist(), self.source_len)
        if self.is_train:
            self.labels = self._encode(self.df['summary'].tolist(), self.summ_len)

    def _encode(self, texts, max_length):
        """Tokenize ``texts`` into one padded id tensor truncated to ``max_length``."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
        )
        return encoded.input_ids

    def __len__(self):
        """Number of examples in the split."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` for training data, else ``input_ids`` only."""
        if self.is_train:
            return self.input_ids[idx], self.labels[idx]
        return self.input_ids[idx]

    def add_sep_tokens(self, dialogue):
        """Insert '<sep>' in front of each speaker tag that differs from the previous one.

        The first speaker tag and consecutive turns by the same speaker get no
        separator; all other text is passed through unchanged.
        """
        pattern = r'(#Person\d+#)'
        # The capturing group makes re.split keep the speaker tags in the result.
        parts = re.split(pattern, dialogue)
        result = []
        prev_speaker = None
        for part in parts:
            if re.match(pattern, part):
                if prev_speaker and prev_speaker != part:
                    result.append('<sep>')
                prev_speaker = part
            result.append(part)
        return ''.join(result)


In [None]:
# Tokenize every split once; only the test split has no reference summaries.
train_dataset, val_dataset = (
    CustomDataset(split_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
    for split_df in (train_df, val_df)
)
test_dataset = CustomDataset(test_df[['dialogue']], tokenizer, dig_max_len, sum_max_len, is_train=False)

# Only the training loader shuffles; evaluation and test order stay fixed.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=batch_size, shuffle=reshuffle, num_workers=num_workers)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
def ids_to_words(tokenizer, preds, labels):
    """Decode predicted and reference id batches and blank out unwanted markers.

    Args:
        tokenizer: Object providing ``batch_decode``.
        preds: Batch of predicted token ids.
        labels: Batch of reference token ids.

    Returns:
        ``(predictions, references)`` as two lists of strings in which every
        entry of the module-level ``remove_tokens`` is replaced by one space.
    """
    def _blank_markers(sentences):
        cleaned = list(sentences)
        for marker in remove_tokens:
            cleaned = [text.replace(marker, " ") for text in cleaned]
        return cleaned

    pred_texts = tokenizer.batch_decode(preds, clean_up_tokenization_spaces=True)
    label_texts = tokenizer.batch_decode(labels, clean_up_tokenization_spaces=True)
    return _blank_markers(pred_texts), _blank_markers(label_texts)

In [None]:
def compute_metrics(replaced_predictions, replaced_labels):
    """Return the averaged ROUGE F-score for every metric the scorer reports.

    Args:
        replaced_predictions: Cleaned prediction strings.
        replaced_labels: Cleaned reference strings, aligned with the predictions.

    Returns:
        Dict mapping each metric name returned by ``Rouge.get_scores`` to its
        averaged ``"f"`` value.
    """
    averaged = Rouge().get_scores(replaced_predictions, replaced_labels, avg=True)

    f_scores = {}
    for metric_name, stats in averaged.items():
        f_scores[metric_name] = stats["f"]
    return f_scores

In [None]:
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine annealing with a linear warm-up at the start of every cycle.

    Within a cycle of length ``T_i`` the rate rises linearly from each group's
    base rate to ``eta_max`` over the first ``T_up`` steps, then follows a
    half cosine back down to the base rate. At every restart the peak is
    rescaled to ``eta_max * gamma ** cycle``.

    Args:
        optimizer: Wrapped optimizer; its initial rates act as the floor.
        T_0: Length of the first cycle (positive int).
        T_mult: Cycle-length growth factor applied at restarts (int >= 1).
        eta_max: Peak learning rate of the first cycle.
        T_up: Number of warm-up steps inside each cycle (non-negative int).
        gamma: Multiplicative decay applied to the peak per completed cycle.
        last_epoch: Index of the last epoch, forwarded to the base class.

    Raises:
        ValueError: If ``T_0``, ``T_mult`` or ``T_up`` fails validation.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected positive integer T_up, but got {}".format(T_up))

        # Cycle geometry.
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.T_up = T_up

        # Peak rate and its per-cycle decay.
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.gamma = gamma

        # Position bookkeeping; must exist before the base constructor runs
        # because it triggers an initial step().
        self.cycle = 0
        self.T_cur = last_epoch
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Learning rate for every parameter group at the current cycle position."""
        position = self.T_cur
        if position == -1:
            return self.base_lrs

        peak = self.eta_max
        if position < self.T_up:
            # Linear warm-up from the base rate towards the peak.
            return [(peak - floor_lr)*position / self.T_up + floor_lr for floor_lr in self.base_lrs]

        # Half-cosine decay from the peak back to the base rate.
        cosine_term = 1 + math.cos(math.pi * (position-self.T_up) / (self.T_i - self.T_up))
        return [floor_lr + (peak - floor_lr) * cosine_term / 2 for floor_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step, or jump to ``epoch`` when it is given."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur += 1
            if self.T_cur >= self.T_i:
                # Restart: wrap the position and stretch only the annealing part.
                self.cycle += 1
                self.T_cur -= self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        elif epoch >= self.T_0 and self.T_mult == 1:
            self.T_cur = epoch % self.T_0
            self.cycle = epoch // self.T_0
        elif epoch >= self.T_0:
            # Geometric cycle lengths: locate the cycle that contains `epoch`.
            n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
            self.cycle = n
            self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
            self.T_i = self.T_0 * self.T_mult ** (n)
        else:
            self.T_i = self.T_0
            self.T_cur = epoch

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)

        new_rates = self.get_lr()
        for group, rate in zip(self.optimizer.param_groups, new_rates):
            group['lr'] = rate

In [None]:
def sample_sequence(model, input_ids, max_length):
    """Draw one stochastic output sequence per input via ``model.generate``.

    Generation runs under ``torch.no_grad()`` with sampling enabled,
    ``temperature=0.7``, ``top_k=0`` and no repeated bigrams.

    Args:
        model: Object exposing ``generate``.
        input_ids: Encoder input ids for the batch.
        max_length: Upper bound on the generated sequence length.

    Returns:
        ``(sequences, scores)`` taken from the generation output.
    """
    sampling_kwargs = dict(
        input_ids=input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=0,
        temperature=0.7,
        no_repeat_ngram_size=2,
        return_dict_in_generate=True,
        output_scores=True,
    )
    with torch.no_grad():
        generated = model.generate(**sampling_kwargs)
    return generated.sequences, generated.scores

def greedy_sequence(model, input_ids, max_length):
    """Decode deterministically (single beam, no sampling) via ``model.generate``.

    Args:
        model: Object exposing ``generate``.
        input_ids: Encoder input ids for the batch.
        max_length: Upper bound on the generated sequence length.

    Returns:
        Whatever ``model.generate`` returns for these settings.
    """
    greedy_kwargs = {
        "input_ids": input_ids,
        "max_length": max_length,
        "num_beams": 1,
        "do_sample": False,
    }
    return model.generate(**greedy_kwargs)

def compute_rouge(predictions, references):
    """Averaged ROUGE F-scores for a batch of prediction/reference strings.

    Returns:
        Dict mapping each metric name returned by ``Rouge.get_scores`` (called
        with ``avg=True``) to its ``'f'`` value.
    """
    averaged = Rouge().get_scores(predictions, references, avg=True)
    return dict((metric, stats['f']) for metric, stats in averaged.items())


def compute_log_probs(model, input_ids, sampled_ids):
    """Log-probability the model assigns to each generated sequence.

    The model is re-run with gradients enabled under teacher forcing: tokens
    ``sampled_ids[:, :-1]`` are fed to the decoder and the log-probabilities of
    ``sampled_ids[:, 1:]`` are summed. The ids are passed as
    ``decoder_input_ids`` rather than ``labels``: a seq2seq model shifts
    ``labels`` right on its own, and shifting again here would score every
    logit against the wrong token.

    Args:
        model: Seq2seq model with ``config.pad_token_id``; calling it must
            return an object with ``.logits`` of shape
            ``(batch, decoder_len, vocab)``.
        input_ids: Encoder input ids, shape ``(batch, src_len)``.
        sampled_ids: Generated ids, shape ``(batch, tgt_len)``, assumed to be
            the raw output of ``generate`` (decoder start token first).

    Returns:
        Tensor of shape ``(batch,)`` holding the summed log-probability of the
        non-padding target tokens of each sequence.
    """
    pad_token_id = model.config.pad_token_id

    decoder_input_ids = sampled_ids[:, :-1]
    targets = sampled_ids[:, 1:]

    # Keep encoder padding out of attention, mirroring what generation saw.
    attention_mask = None
    if pad_token_id is not None:
        attention_mask = (input_ids != pad_token_id).long()

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )

    # logits[:, t] is conditioned on sampled_ids[:, :t + 1] and scores sampled_ids[:, t + 1].
    log_probs = F.log_softmax(outputs.logits, dim=-1)
    target_log_probs = torch.gather(log_probs, 2, targets.unsqueeze(-1)).squeeze(-1)

    # Ignore the padding that follows the end of each generated sequence.
    if pad_token_id is None:
        target_mask = torch.ones_like(target_log_probs)
    else:
        target_mask = (targets != pad_token_id).to(target_log_probs.dtype)

    return (target_log_probs * target_mask).sum(dim=1)

In [None]:
def _rouge_l_per_sample(hypotheses, references):
    """ROUGE-L F-score of every (hypothesis, reference) pair; 0.0 for unscorable pairs."""
    scorer = Rouge()
    scores = []
    for hypothesis, reference in zip(hypotheses, references):
        if not hypothesis.strip() or not reference.strip():
            # Nothing to compare: treat as zero reward instead of asking the scorer.
            scores.append(0.0)
            continue
        try:
            scores.append(scorer.get_scores(hypothesis, reference)[0]['rouge-l']['f'])
        except ValueError:
            # The scorer rejects text it reduces to an empty sentence list
            # (e.g. punctuation only); one bad sample must not abort training.
            scores.append(0.0)
    return scores


def scst_loss(model, input_ids, pred_ids, baseline_ids, labels, tokenizer):
    """Self-critical sequence training loss with a per-sample ROUGE-L advantage.

    Each sampled sequence is rewarded with its own ROUGE-L F-score and compared
    against the score of the baseline sequence for the same input. Using
    per-sample rewards (not one batch average shared by all samples) is what
    lets good and bad samples in a batch be pushed in opposite directions.

    Args:
        model: Model passed through to ``compute_log_probs``.
        input_ids: Encoder input ids for the batch.
        pred_ids: Sampled output ids.
        baseline_ids: Baseline (greedy) output ids for the same inputs.
        labels: Reference summary ids.
        tokenizer: Tokenizer used to decode ids into text.

    Returns:
        Scalar loss tensor ``-mean((reward - baseline) * log_prob)``.
    """
    # Decode sampled, baseline and reference sequences to cleaned text.
    sampled_texts, reference_texts = ids_to_words(tokenizer, pred_ids, labels)
    baseline_texts, _ = ids_to_words(tokenizer, baseline_ids, labels)

    # Rewards are plain numbers, so no gradient flows through them.
    sample_rewards = torch.tensor(
        _rouge_l_per_sample(sampled_texts, reference_texts),
        dtype=torch.float,
        device=pred_ids.device,
    )
    baseline_rewards = torch.tensor(
        _rouge_l_per_sample(baseline_texts, reference_texts),
        dtype=torch.float,
        device=pred_ids.device,
    )

    # Sequence log-probabilities of the sampled outputs (carry the gradient).
    log_probs = compute_log_probs(model, input_ids, pred_ids)

    advantage = sample_rewards - baseline_rewards
    return -(advantage * log_probs).mean()

In [None]:
def train_scst(epoch, model, device, train_loader, optimizer, tokenizer, writer):
    """Run one epoch of self-critical sequence training.

    For every batch a greedy baseline and a sampled sequence are generated,
    the SCST loss is computed from both and one optimizer step is taken.

    Args:
        epoch: Epoch number, used for the progress bar and as the logging step.
        model: Seq2seq model being fine-tuned.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(input_ids, labels)`` batches with ``len``.
        optimizer: Optimizer updating ``model``.
        tokenizer: Tokenizer used to decode sequences for the reward.
        writer: Object with ``add_scalar``; receives 'Loss/train'.

    Returns:
        Mean loss over the batches of this epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)

    for batch in tqdm(train_loader, total=num_batches, desc=f"Training Epoch {epoch}", leave=False):
        input_ids = batch[0].to(device, dtype=torch.long)
        labels = batch[1].to(device, dtype=torch.long)

        optimizer.zero_grad()

        # Baseline: greedy decoding as used at inference time, so switch dropout
        # off while decoding it and back on afterwards.
        model.eval()
        with torch.no_grad():
            baseline_ids = greedy_sequence(model, input_ids, max_length=256)
        model.train()

        # Exploration: one sampled sequence per input.
        pred_ids, _ = sample_sequence(model, input_ids, max_length=256)

        # SCST loss from the sampled sequence and its baseline.
        loss = scst_loss(model, input_ids, pred_ids, baseline_ids, labels, tokenizer)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # An empty loader yields 0.0 instead of a division by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    writer.add_scalar('Loss/train', avg_loss, epoch)
    return avg_loss

In [None]:
def validate_scst(epoch, model, device, val_loader, tokenizer, writer):
    """Score the model on the validation split with deterministic decoding.

    Greedy decoding is used so the reported scores, and therefore checkpoint
    selection and early stopping, do not depend on sampling noise. Batch scores
    are weighted by batch size so a smaller final batch is not over-weighted.

    Args:
        epoch: Epoch number used as the logging step.
        model: Seq2seq model to evaluate.
        device: Device the batches are moved to.
        val_loader: Iterable of ``(input_ids, labels)`` batches with ``len``.
        tokenizer: Tokenizer used to decode predictions and references.
        writer: Object with ``add_scalar``; receives the three ROUGE scalars.

    Returns:
        ``(scores, predictions, references)`` where ``scores`` maps 'rouge-1',
        'rouge-2' and 'rouge-l' to the validation average and the other two
        are the decoded strings in loader order.
    """
    model.eval()
    metric_names = ('rouge-1', 'rouge-2', 'rouge-l')
    weighted_totals = {name: 0.0 for name in metric_names}
    all_predictions = []
    all_references = []

    with torch.no_grad():
        for batch in tqdm(val_loader, total=len(val_loader), desc="Validating", leave=False):
            input_ids = batch[0].to(device, dtype=torch.long)
            labels = batch[1].to(device, dtype=torch.long)

            # Deterministic prediction.
            pred_ids = greedy_sequence(model, input_ids, max_length=256)

            # Decode ids to cleaned text.
            predictions, references = ids_to_words(tokenizer, pred_ids, labels)

            # Batch-level ROUGE, weighted by the number of examples it covers.
            rouge_scores = compute_rouge(predictions, references)
            for name in metric_names:
                weighted_totals[name] += rouge_scores[name] * len(predictions)

            all_predictions.extend(predictions)
            all_references.extend(references)

    num_examples = len(all_predictions)
    averages = {
        name: (weighted_totals[name] / num_examples if num_examples else 0.0)
        for name in metric_names
    }

    for name in metric_names:
        writer.add_scalar(f'ROUGE/{name}', averages[name], epoch)

    return averages, all_predictions, all_references

In [None]:
train_step = 0

# Every run gets its own timestamped directory for weights and TensorBoard logs.
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
save_path = os.path.join("./T5_runs", timestamp)

weights_path = os.path.join(save_path, 'weights')
logs_path = os.path.join(save_path, 'logs')
os.makedirs(weights_path, exist_ok=True)
os.makedirs(logs_path, exist_ok=True)

best_rouge_l = 0
best_avg_rouge = 0
early_stopping_counter = 0
writer = SummaryWriter(log_dir=logs_path)
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_max=max_lr,  T_up=warmup_epochs, gamma=T_gamma)

for epoch in range(1, epochs + 1):
    # Record the rate this epoch is trained with.
    writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

    train_loss = train_scst(epoch, model, device, train_loader, optimizer, tokenizer, writer)
    val_result, val_predictions, val_references = validate_scst(epoch, model, device, val_loader, tokenizer, writer)

    # Advance the warm-up / cosine schedule once per epoch; without this call the
    # scheduler never changes the learning rate after construction.
    scheduler.step()

    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.6f}")
    print(f"Validation ROUGE-1: {val_result['rouge-1']:.6f}")
    print(f"Validation ROUGE-2: {val_result['rouge-2']:.6f}")
    print(f"Validation ROUGE-L: {val_result['rouge-l']:.6f}")

    # Show a few prediction/reference pairs.
    print('-' * 100)
    for i in range(min(3, len(val_predictions))):
        print(f"예측: {val_predictions[i]}")
        print(f"참조: {val_references[i]}")
        print('-' * 100)

    # Checkpoint on improvement; otherwise count towards early stopping.
    if val_result['rouge-l'] > best_rouge_l:
        best_rouge_l = val_result['rouge-l']
        early_stopping_counter = 0
        torch.save(model.state_dict(), os.path.join(weights_path, 'best_model.pth'))
        print(f"새로운 최고 모델이 저장되었습니다. ROUGE-L: {best_rouge_l:.6f}")
    else:
        early_stopping_counter += 1
        print(f"성능 향상 없음. 조기 종료 카운터: {early_stopping_counter}/{patience}")

    if early_stopping_counter >= patience:
        print("조기 종료 실행.")
        break

writer.close()
torch.save(model.state_dict(), os.path.join(weights_path, 'final_model.pth'))
print("훈련이 완료되었습니다. 최종 모델이 저장되었습니다.")

In [None]:
def predict(tokenizer, model, device, test_loader, fname):
    """Generate beam-search summaries for every test batch and tabulate them.

    Args:
        tokenizer: Tokenizer providing ``decode`` and the ``unk_token``,
            ``eos_token`` and ``pad_token`` attributes.
        model: Model exposing ``eval`` and ``generate``.
        device: Device the input batches are moved to.
        test_loader: Iterable of input-id batches.
        fname: File names aligned with the generated summaries.

    Returns:
        DataFrame with columns 'fname' and 'summary', where '<usr>' and the
        tokenizer's unk/eos/pad tokens are each replaced by one space.
    """
    beam_kwargs = dict(
        max_length=256,
        num_beams=4,
        repetition_penalty=2.0,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )

    model.eval()
    decoded = []
    with torch.no_grad():
        for batch_ids in tqdm(test_loader):
            batch_ids = batch_ids.to(device, dtype=torch.long)
            generated = model.generate(input_ids=batch_ids, **beam_kwargs)
            decoded.extend(tokenizer.decode(sequence) for sequence in generated)

    # Blank out markers that must not appear in the submitted text.
    unwanted = ['<usr>'] + [
        f"{marker}"
        for marker in (tokenizer.unk_token, tokenizer.eos_token, tokenizer.pad_token)
    ]
    cleaned = list(decoded)
    for marker in unwanted:
        cleaned = [text.replace(marker, " ") for text in cleaned]

    return pd.DataFrame({"fname": fname, "summary": cleaned})

In [None]:
# Restore the best checkpoint written by the training loop ('best_model.pth' in
# weights_path) and apply it to the model, so inference does not run with the
# weights of the last epoch.
best_model = torch.load(os.path.join(weights_path, 'best_model.pth'), map_location=device)
model.load_state_dict(best_model)
output = predict(tokenizer, model, device, test_loader, test_df['fname'])

In [None]:
# Write the predictions next to this run's weights and logs, without the index column.
output.to_csv(f"{save_path}/prediction.csv", index=False)