In [1]:
import os
import torch
import warnings
warnings.filterwarnings("ignore", message="Empty candidate sentence detected; setting raw BERTscores to 0.")

import pandas as pd
import pytorch_lightning as pl


from tqdm import tqdm
from rouge import Rouge
from datetime import datetime
from bert_score import BERTScorer

from torch import nn
from transformers import AutoTokenizer
from torch.utils.data import Dataset , DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
model_id = "lcw99/t5-large-korean-text-summary"
train_df = pd.read_csv('../dataset/cleaned_train.csv')
val_df = pd.read_csv('../dataset/cleaned_dev.csv')
test_df = pd.read_csv("../dataset/test.csv")

epochs = 10
batch_size = 2
num_workers = 0
log_interval = 300
dig_max_len = 1000
sum_max_len = 200

tokenizer = AutoTokenizer.from_pretrained(model_id)
special_tokens_dict={'additional_special_tokens': ['#Person1#', '#Person2#','#Person3#', '#Person4#', '#Person5#', '#Person6#', '#Person7#', '#PhoneNumber#', 
                                                   '#Address#', '#PassportNumber#', '#CardNumber#', '#Email#', '#DateOfBirth#',]}

tokenizer.add_special_tokens(special_tokens_dict)
print(tokenizer.special_tokens_map)

remove_tokens = [
    '<usr>',
    f"{tokenizer.unk_token}", 
    f"{tokenizer.eos_token}", 
    f"{tokenizer.pad_token}"
]

model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)



{'eos_token': '</s>', 'unk_token': '<pad>', 'pad_token': '<pad>', 'additional_special_tokens': ['#Person1#', '#Person2#', '#Person3#', '#Person4#', '#Person5#', '#Person6#', '#Person7#', '#PhoneNumber#', '#Address#', '#PassportNumber#', '#CardNumber#', '#Email#', '#DateOfBirth#']}


In [4]:
class CustomDataset(Dataset):
    def __init__(self, df, tokenizer, input_len, summ_len, is_train=True):
        self.tokenizer = tokenizer
        self.df = df
        self.source_len = input_len
        self.summ_len = summ_len
        self.is_train = is_train
        if self.is_train:
            self.input_ids = tokenizer(self.df['dialogue'].tolist(), return_tensors="pt", padding=True,
                                add_special_tokens=True, truncation=True, max_length=512, return_token_type_ids=False).input_ids
            self.labels = tokenizer(self.df['summary'].tolist(), return_tensors="pt", padding=True,
                                add_special_tokens=True, truncation=True, max_length=100, return_token_type_ids=False).input_ids
        else:
            self.input_ids = tokenizer(self.df['dialogue'].tolist(), return_tensors="pt", padding=True,
                                add_special_tokens=True, truncation=True, max_length=512, return_token_type_ids=False).input_ids
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if self.is_train:
            return self.input_ids[idx], self.labels[idx]
        else:
            return self.input_ids[idx]

In [5]:
train_dataset = CustomDataset(train_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
val_dataset = CustomDataset(val_df[['dialogue', 'summary']], tokenizer, dig_max_len, sum_max_len)
test_dataset = CustomDataset(test_df[['dialogue']], tokenizer, dig_max_len, sum_max_len, is_train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [6]:
def compute_bert_score(predictions, references):
    scorer = BERTScorer(lang="ko", rescale_with_baseline=False)
    P, R, F1 = scorer.score(predictions, references)
    return F1.mean().item()

In [7]:
def ids_to_words(tokenizer, preds, labels):
    decoded_preds = tokenizer.batch_decode(preds, clean_up_tokenization_spaces=True)
    labels = tokenizer.batch_decode(labels, clean_up_tokenization_spaces=True)

    replaced_predictions = decoded_preds.copy()
    replaced_labels = labels.copy()
    # remove_tokens = ['<usr>', f"{tokenizer.unk_token}", f"{tokenizer.eos_token}", f"{tokenizer.pad_token}"]

    for token in remove_tokens:
        replaced_predictions = [sentence.replace(token," ") for sentence in replaced_predictions]
        replaced_labels = [sentence.replace(token," ") for sentence in replaced_labels]
    return replaced_predictions, replaced_labels

In [8]:
def compute_metrics(replaced_predictions, replaced_labels):
    rouge = Rouge()

    results = rouge.get_scores(replaced_predictions, replaced_labels,avg=True)
    result = {key: value["f"] for key, value in results.items()}
    
    return result

In [9]:
def train_rl(epoch, model, device, train_loader, optimizer, log_interval, train_step):
    model.train()
    total_loss = 0.0
    
    for idx, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Training Epoch {epoch}"):
        input_ids = batch[0].to(device, dtype=torch.long)
        labels = batch[1].to(device, dtype=torch.long)

        # Baseline (greedy decoding)
        with torch.no_grad():
            baseline_output = model.generate(
                input_ids=input_ids,
                max_length=256,
                num_beams=1,
                do_sample=False
            )
        
        # Sample (with sampling)
        sample_output = model.generate(
            input_ids=input_ids,
            max_length=256,
            do_sample=True,
            top_k=0,
            temperature=0.7
        )

        # Compute rewards
        baseline_preds, _ = ids_to_words(tokenizer, baseline_output, labels)
        sample_preds, ref_sums = ids_to_words(tokenizer, sample_output, labels)
        
        baseline_reward = compute_bert_score(baseline_preds, ref_sums)
        sample_reward = compute_bert_score(sample_preds, ref_sums)

        # Compute RL loss
        rl_loss = -(sample_reward - baseline_reward) * model(input_ids=input_ids, labels=sample_output).loss

        optimizer.zero_grad()
        rl_loss.backward()
        optimizer.step()

        total_loss += rl_loss.item()
        train_step += 1

    avg_loss = total_loss / len(train_loader)
    return train_step, avg_loss

In [10]:
def validate(tokenizer, model, device, val_loader):
    model.eval()
    total_loss = 0
    all_results = []
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for idx, batch in tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating"):
            input_ids = batch[0].to(device, dtype=torch.long)
            labels = batch[1].to(device, dtype=torch.long)

            pred_ids = model.generate(
                input_ids=input_ids,
                max_length=256, 
                num_beams=4,
                repetition_penalty=2.0, 
                length_penalty=1.0, 
                early_stopping=True,
                no_repeat_ngram_size=2
            )

            loss = model(input_ids=input_ids, labels=labels).loss
            total_loss += loss.item()

            replaced_predictions, replaced_labels = ids_to_words(tokenizer, pred_ids, labels)
            result = compute_metrics(replaced_predictions, replaced_labels)
            
            all_results.append(result)
            all_predictions.extend(replaced_predictions)
            all_labels.extend(replaced_labels)

    val_loss = total_loss / len(val_loader)
    avg_result = {key: sum(r[key] for r in all_results) / len(all_results) for key in all_results[0]}
    
    return val_loss, avg_result, all_predictions, all_labels

In [11]:
train_step = 0
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
save_path = os.path.join("./T5_RL_runs", timestamp)
os.makedirs(save_path, exist_ok=True)

best_bert_score = 0
for epoch in range(1, epochs + 1):
    train_step, train_loss = train_rl(epoch, model, device, train_loader, optimizer, log_interval, train_step)
    val_loss, val_result, val_predictions, val_labels = validate(tokenizer, model, device, val_loader)
    
    bert_score = compute_bert_score(val_predictions, val_labels)
    
    print(f"Epoch {epoch}/{epochs}")
    print(f"Train Loss: {train_loss:.6f}")
    print(f"Validation Loss: {val_loss:.6f}, BERT Score: {bert_score:.6f}")
    
    print('-'*150)
    for i in range(3):
        print(f"PRED: {val_predictions[i]}")
        print(f"GOLD: {val_labels[i]}")
        print('-'*150)
    
    if bert_score > best_bert_score:
        best_bert_score = bert_score
        torch.save(model.state_dict(), os.path.join(save_path, 'best.pth'))
        print(f"New best model saved with BERT Score: {best_bert_score:.6f}")
    
    print()
    torch.save(model.state_dict(), os.path.join(save_path, f'epoch-{epoch}.pth'))

torch.save(model.state_dict(), os.path.join(save_path, 'last.pth'))
print("Training completed. Last model saved.")

Training Epoch 1:   0%|          | 0/6229 [00:00<?, ?it/s]/opt/conda/conda-bld/pytorch_1712608885084/work/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [856,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1712608885084/work/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [856,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1712608885084/work/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [856,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1712608885084/work/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [856,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1712608885084/work/aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [856,0,0], thread: [36,0,0] 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
def predict(tokenizer, model, device, test_loader, fname):
    model.eval()
    summary = []
    with torch.no_grad():
        for input_ids in tqdm(test_loader):
            input_ids = input_ids.to(device, dtype=torch.long)

            pred_ids = model.generate(
                input_ids=input_ids,
                max_length=256, 
                num_beams=4,
                repetition_penalty=2.0, 
                length_penalty=1.0, 
                early_stopping=True,
                no_repeat_ngram_size=2
            )
            for ids in pred_ids:
                result = tokenizer.decode(ids)
                summary.append(result)
                
    remove_tokens = ['<usr>', f"{tokenizer.unk_token}", f"{tokenizer.eos_token}", f"{tokenizer.pad_token}"]
    preprocessed_summary = summary.copy()
    for token in remove_tokens:
        preprocessed_summary = [sentence.replace(token," ") for sentence in preprocessed_summary]

    output = pd.DataFrame(
        {
            "fname": fname,
            "summary" : preprocessed_summary,
        }
    )
    return output

In [None]:
# ckpt_path = "/home/pervinco/Upstage_Ai_Lab/project/notebooks/T5_runs/2024-09-05-15-51-10"
# best_model = torch.load(f'{ckpt_path}/best.pth')
# output = predict(tokenizer, model, device, test_loader, test_df['fname'])
# output.to_csv(f"{ckpt_path}/prediction.csv", index=False)