In [None]:
import os
import torch
import pandas as pd
import pytorch_lightning as pl


from tqdm import tqdm
from rouge import Rouge
from datetime import datetime

from torch import nn
from transformers import AutoTokenizer
from torch.utils.data import Dataset , DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Select the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Hugging Face Hub id shared by the tokenizer and the model.
model_id = "psyche/KoT5-summarization"

# Training-loop and data-loading settings.
epochs = 100
batch_size = 2
num_workers = 0
log_interval = 300
# Length arguments passed to CustomDataset for dialogues and summaries.
dig_max_len = 1024
sum_max_len = 512

# Train / dev / test splits, read from CSV files relative to the notebook.
train_df = pd.read_csv('../dataset/cleaned_train.csv')
val_df = pd.read_csv('../dataset/cleaned_dev.csv')
test_df = pd.read_csv("../dataset/test.csv")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Placeholder markers registered with the tokenizer as additional special tokens.
person_tags = [f"#Person{n}#" for n in range(1, 8)]
other_tags = ['#PhoneNumber#', '#Address#', '#PassportNumber#',
              '#CardNumber#', '#Email#', '#DateOfBirth#']
special_tokens_dict = {'additional_special_tokens': person_tags + other_tags}

tokenizer.add_special_tokens(special_tokens_dict)
print(tokenizer.special_tokens_map)

# Token strings that are blanked out of decoded text before scoring.
remove_tokens = ['<usr>']
remove_tokens.append(f"{tokenizer.unk_token}")
remove_tokens.append(f"{tokenizer.eos_token}")
remove_tokens.append(f"{tokenizer.pad_token}")

In [None]:
# Load the pretrained seq2seq checkpoint (kept on the CPU until it has its final shape).
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Special tokens were added to the tokenizer, so its vocabulary may now be larger
# than the model's embedding table. Grow the table in that case; otherwise the new
# token ids would index past the end of the embedding matrix. The guard leaves the
# model untouched when the table is already large enough, which keeps previously
# saved state dicts loadable.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

model = model.to(device)

# Create the optimizer only after any resize so it tracks the final parameter tensors.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

In [None]:
class CustomDataset(Dataset):
    """Pre-tokenized dialogue (and optionally summary) dataset.

    All texts are tokenized once in ``__init__`` and kept as tensors, so
    ``__getitem__`` is a plain tensor lookup.

    Args:
        df: DataFrame with a ``'dialogue'`` column and, when ``is_train`` is
            true, a ``'summary'`` column.
        tokenizer: Callable tokenizer whose result exposes ``.input_ids``.
        input_len: Maximum token length used when truncating dialogues.
        summ_len: Maximum token length used when truncating summaries.
        is_train: When true, summaries are tokenized as labels and each item
            is an ``(input_ids, labels)`` pair; when false, each item is
            ``input_ids`` only.
    """

    def __init__(self, df, tokenizer, input_len, summ_len, is_train=True):
        self.tokenizer = tokenizer
        self.df = df
        self.source_len = input_len
        self.summ_len = summ_len
        self.is_train = is_train

        # Dialogues are needed in both modes. The limits come from the constructor
        # arguments, which were previously stored but ignored in favour of fixed
        # values (512 for dialogues, 100 for summaries).
        self.input_ids = self._encode(self.df['dialogue'].tolist(), self.source_len)
        if self.is_train:
            self.labels = self._encode(self.df['summary'].tolist(), self.summ_len)

    def _encode(self, texts, max_length):
        """Tokenize ``texts`` into a padded tensor of ids truncated to ``max_length``."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            return_token_type_ids=False,
        )
        return encoded.input_ids

    def __len__(self):
        """Number of examples (rows of the tokenized dialogue tensor)."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` in training mode, else ``input_ids``."""
        if self.is_train:
            return self.input_ids[idx], self.labels[idx]
        return self.input_ids[idx]

In [None]:
def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch settings."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


# Train and dev use the dialogue and summary columns; the test split is built
# from the dialogue column only, with is_train=False.
train_dataset = CustomDataset(train_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
val_dataset = CustomDataset(val_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
test_dataset = CustomDataset(test_df[['dialogue']], tokenizer, dig_max_len, sum_max_len, is_train=False)

# Only the training loader shuffles its examples.
train_loader = _build_loader(train_dataset, shuffle=True)
val_loader = _build_loader(val_dataset, shuffle=False)
test_loader = _build_loader(test_dataset, shuffle=False)

In [None]:
def ids_to_words(tokenizer, preds, labels):
    """Decode predicted and reference token ids into cleaned-up text.

    Both batches are decoded with ``tokenizer.batch_decode`` and then every
    occurrence of ``'<usr>'`` and of the tokenizer's unk / eos / pad token is
    replaced by a single space.

    The list of tokens to strip is built from ``tokenizer`` here rather than
    read from a module-level variable, so the function is self-contained.
    Special tokens that are unset (``None``) are skipped so that the literal
    text ``"None"`` is never stripped from a sentence.

    Args:
        tokenizer: Object providing ``batch_decode`` and the ``unk_token``,
            ``eos_token`` and ``pad_token`` attributes.
        preds: Batch of predicted token ids.
        labels: Batch of reference token ids.

    Returns:
        Tuple ``(replaced_predictions, replaced_labels)`` of lists of strings.
    """
    strip_tokens = ['<usr>']
    for special in (tokenizer.unk_token, tokenizer.eos_token, tokenizer.pad_token):
        if special:
            strip_tokens.append(special)

    def _decode_and_strip(batch_ids):
        # Decode the whole batch, then blank out each unwanted token in turn.
        texts = tokenizer.batch_decode(batch_ids, clean_up_tokenization_spaces=True)
        for token in strip_tokens:
            texts = [text.replace(token, " ") for text in texts]
        return texts

    replaced_predictions = _decode_and_strip(preds)
    replaced_labels = _decode_and_strip(labels)
    return replaced_predictions, replaced_labels

In [None]:
def compute_metrics(replaced_predictions, replaced_labels):
    """Score decoded predictions against references with ROUGE.

    Args:
        replaced_predictions: List of predicted summary strings.
        replaced_labels: List of reference summary strings.

    Returns:
        Dict mapping each metric name reported by ``Rouge.get_scores`` (called
        with ``avg=True``) to its ``"f"`` value.
    """
    scorer = Rouge()
    averaged = scorer.get_scores(replaced_predictions, replaced_labels, avg=True)

    # Keep only the F-score of every reported metric.
    f_scores = {}
    for metric_name, stats in averaged.items():
        f_scores[metric_name] = stats["f"]
    return f_scores

In [None]:
def train(epoch, model, device, train_loader, optimizer, log_interval, train_step):
    """Run one training epoch and return the updated step counter and mean loss.

    Padding is handled explicitly: an attention mask is derived from the
    model's pad token id so padded input positions are not attended to, and
    pad positions in the labels are set to -100 so they are excluded from the
    cross-entropy loss. If the model exposes no pad token id, batches are
    passed through unchanged.

    Args:
        epoch: Epoch number, used only in the progress-bar description.
        model: Seq2seq model called as ``model(input_ids=..., labels=...)``
            whose output exposes ``.loss``.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of ``(input_ids, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        log_interval: Accepted for interface compatibility; not used here.
        train_step: Running count of optimizer steps before this epoch.

    Returns:
        Tuple ``(train_step, avg_loss, avg_ce_loss)``. Both averages are the
        mean per-batch cross-entropy loss of the epoch.
    """
    model.train()
    pad_token_id = getattr(getattr(model, "config", None), "pad_token_id", None)
    ce_losses = []

    for batch in tqdm(train_loader, total=len(train_loader), desc=f"Training Epoch {epoch}"):
        input_ids = batch[0].to(device, dtype=torch.long)
        labels = batch[1].to(device, dtype=torch.long)

        attention_mask = None
        if pad_token_id is not None:
            # Mask padded inputs and drop padded targets from the loss.
            attention_mask = input_ids.ne(pad_token_id).long()
            labels = labels.masked_fill(labels.eq(pad_token_id), -100)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        ce_loss = outputs.loss  # only the cross-entropy loss is used

        # Backpropagate and update the weights.
        ce_loss.backward()
        optimizer.step()

        ce_losses.append(ce_loss.item())
        train_step += 1

    # Every batch contributes exactly one entry, so both averages are identical.
    avg_loss = sum(ce_losses) / len(train_loader)
    avg_ce_loss = sum(ce_losses) / len(ce_losses)

    return train_step, avg_loss, avg_ce_loss


In [None]:
def validate(tokenizer, model, device, val_loader):
    """Evaluate the model on the validation set.

    For every batch, summaries are generated with beam search and the
    teacher-forced loss is computed. An attention mask derived from the
    tokenizer's pad token id is passed to both calls, and pad positions in the
    labels are set to -100 for the loss only; the unmodified labels are still
    used for decoding the reference text.

    ROUGE is computed once over all decoded samples. Averaging per-batch
    averages instead would over-weight a smaller final batch.

    Args:
        tokenizer: Tokenizer used to decode ids and providing ``pad_token_id``.
        model: Seq2seq model supporting ``generate`` and a forward pass with
            ``labels`` whose output exposes ``.loss``.
        device: Device the batch tensors are moved to.
        val_loader: Iterable of ``(input_ids, labels)`` batches.

    Returns:
        Tuple ``(val_loss, avg_result, all_predictions, all_labels)`` where
        ``val_loss`` is the mean per-batch loss, ``avg_result`` is the dict
        returned by ``compute_metrics`` and the last two are lists of decoded
        strings.
    """
    model.eval()
    pad_token_id = tokenizer.pad_token_id
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, total=len(val_loader), desc="Validating"):
            input_ids = batch[0].to(device, dtype=torch.long)
            labels = batch[1].to(device, dtype=torch.long)

            attention_mask = None
            loss_labels = labels
            if pad_token_id is not None:
                attention_mask = input_ids.ne(pad_token_id).long()
                # Separate tensor: `labels` must keep real pad ids for decoding.
                loss_labels = labels.masked_fill(labels.eq(pad_token_id), -100)

            pred_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=256,
                num_beams=4,
                repetition_penalty=2.0,
                length_penalty=1.0,
                early_stopping=True,
                no_repeat_ngram_size=2
            )

            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels).loss
            total_loss += loss.item()

            replaced_predictions, replaced_labels = ids_to_words(tokenizer, pred_ids, labels)
            all_predictions.extend(replaced_predictions)
            all_labels.extend(replaced_labels)

    val_loss = total_loss / len(val_loader)
    # One sample-level average over the whole validation set.
    avg_result = compute_metrics(all_predictions, all_labels)

    return val_loss, avg_result, all_predictions, all_labels

In [None]:
train_step = 0

# Every run writes its checkpoints into a fresh, timestamped directory.
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
save_path = os.path.join("./T5_runs", timestamp)
os.makedirs(save_path, exist_ok=True)

best_avg_rouge = 0
best_saved = False  # becomes True once best.pth has actually been written
for epoch in range(1, epochs + 1):
    train_step, train_loss, train_ce_loss = train(epoch, model, device, train_loader, optimizer, log_interval, train_step)
    val_loss, val_result, val_predictions, val_labels = validate(tokenizer, model, device, val_loader)

    # Model selection uses the mean of the three ROUGE F-scores.
    avg_rouge = (val_result['rouge-1'] + val_result['rouge-2'] + val_result['rouge-l']) / 3
    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.6f}, CE Loss: {train_ce_loss:.6f}")
    print(f"Validation Loss: {val_loss:.6f}, Rouge-1: {val_result['rouge-1']:.6f}, Rouge-2: {val_result['rouge-2']:.6f}, Rouge-l: {val_result['rouge-l']:.6f}")
    print(f"Average ROUGE: {avg_rouge:.6f}")

    # Preview up to three prediction/reference pairs. Slicing avoids an
    # IndexError when the validation set holds fewer than three samples.
    print('-'*150)
    for pred_text, gold_text in zip(val_predictions[:3], val_labels[:3]):
        print(f"PRED: {pred_text}")
        print(f"GOLD: {gold_text}")
        print('-'*150)

    # Save the best-performing model so far.
    if avg_rouge > best_avg_rouge:
        best_avg_rouge = avg_rouge
        torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))
        best_saved = True
        print(f"New best model saved with average ROUGE: {best_avg_rouge:.6f}")

    print()
    # A checkpoint is also written after every epoch.
    torch.save(model.state_dict(), os.path.join(save_path, f'epoch-{epoch}.pth'))

torch.save(model.state_dict(), os.path.join(save_path, 'last.pth'))
print("Training completed. Last model saved.")

print(f"Best average ROUGE: {best_avg_rouge:.6f}")
# Only announce best.pth when it exists; it is written only if ROUGE exceeded 0.
if best_saved:
    print(f"Best model saved at: {os.path.join(save_path, 'best.pth')}")
else:
    print("No best model was saved: average ROUGE never exceeded 0.")
print(f"Last model saved at: {os.path.join(save_path, 'last.pth')}")

In [None]:
def predict(tokenizer, model, device, test_loader, fname):
    """Generate summaries for the test set and return them as a DataFrame.

    An attention mask derived from the tokenizer's pad token id is passed to
    ``generate`` so padded input positions are not attended to. Decoded text
    then has ``'<usr>'`` and the tokenizer's unk / eos / pad tokens replaced
    by a space; special tokens that are unset (``None``) are skipped.

    Args:
        tokenizer: Tokenizer providing ``decode``, ``pad_token_id`` and the
            ``unk_token`` / ``eos_token`` / ``pad_token`` attributes.
        model: Seq2seq model supporting ``generate``.
        device: Device the input tensors are moved to.
        test_loader: Iterable of ``input_ids`` batches.
        fname: Values for the ``"fname"`` column, one per generated summary.

    Returns:
        DataFrame with columns ``"fname"`` and ``"summary"``.
    """
    model.eval()
    pad_token_id = tokenizer.pad_token_id
    summary = []
    with torch.no_grad():
        for input_ids in tqdm(test_loader):
            input_ids = input_ids.to(device, dtype=torch.long)
            attention_mask = None
            if pad_token_id is not None:
                attention_mask = input_ids.ne(pad_token_id).long()

            pred_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=256,
                num_beams=4,
                repetition_penalty=2.0,
                length_penalty=1.0,
                early_stopping=True,
                no_repeat_ngram_size=2
            )
            summary.extend(tokenizer.decode(ids) for ids in pred_ids)

    # Build the strip list from tokens that are actually set, so an unset token
    # never turns into the literal string "None".
    remove_tokens = ['<usr>']
    for special in (tokenizer.unk_token, tokenizer.eos_token, tokenizer.pad_token):
        if special:
            remove_tokens.append(special)

    preprocessed_summary = summary.copy()
    for token in remove_tokens:
        preprocessed_summary = [sentence.replace(token, " ") for sentence in preprocessed_summary]

    output = pd.DataFrame(
        {
            "fname": fname,
            "summary": preprocessed_summary,
        }
    )
    return output

In [None]:
# Directory of the training run whose best checkpoint is used for inference.
# NOTE(review): machine-specific absolute path; point it at your own run directory.
ckpt_path = "/home/pervinco/Upstage_Ai_Lab/project/notebooks/T5_runs/2024-09-05-15-51-10"

# torch.load only returns the saved state dict; it must be copied into the model
# explicitly, otherwise predict() runs with whatever weights are in memory.
# map_location lets a GPU-saved checkpoint load on the current device.
best_model = torch.load(f'{ckpt_path}/best.pth', map_location=device)
model.load_state_dict(best_model)
output = predict(tokenizer, model, device, test_loader, test_df['fname'])

In [None]:
# Write the prediction table as CSV (without the index column) into the checkpoint directory.
prediction_csv = f"{ckpt_path}/prediction.csv"
output.to_csv(prediction_csv, index=False)