In [None]:
# Mount Google Drive at /content/drive so the later cells can reach the
# project files stored there (Colab-only; requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Switch the working directory to the project folder on the mounted Drive;
# the relative paths used in later cells are resolved from here.
%cd /content/drive/MyDrive/Final_Project/TextSummarization

/content/drive/MyDrive/Final_Project/TextSummarization


In [None]:
# List the working directory to check that the expected data files and folders are present.
%ls

add_data.ipynb  [0m[01;34mKoBART-summarization[0m/  ROUGE.ipynb  tech_test.tsv  [01;34mTraining[0m/  [01;34mValidation[0m/


In [None]:
!pip install rouge
!pip install korouge_score

Collecting korouge_score
  Using cached korouge_score-0.1.4-py3-none-any.whl (28 kB)
Installing collected packages: korouge_score
Successfully installed korouge_score-0.1.4


In [None]:
import pandas as pd
import torch
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast
from rouge import Rouge
from korouge_score import rouge_scorer

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Load both models and the tokenizer.
# model_original is the published checkpoint, model_finetuned is read from the
# local kobart_summary directory; the tokenizer comes from the published checkpoint.
model_original = BartForConditionalGeneration.from_pretrained(
    'digit82/kobart-summarization'
)
model_original = model_original.to(device)
model_finetuned = BartForConditionalGeneration.from_pretrained(
    './KoBART-summarization/kobart_summary'
)
model_finetuned = model_finetuned.to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    'digit82/kobart-summarization'
)

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.


In [None]:
# Shared ROUGE scorer used for every reference/prediction pair in this notebook.
scorer = rouge_scorer.RougeScorer(
    ["rouge1", "rouge2", "rougeL", "rougeLsum"],
    use_stemmer=True,
)

def summarize(text, model):
    """Generate a summary of ``text`` with ``model`` and return it as a string.

    Uses the notebook-level ``tokenizer`` and ``device``.
    """
    # Numbers and other non-string values are converted to text before tokenizing.
    source = text if isinstance(text, str) else str(text)
    # Encode with truncation at 512 tokens and move the ids to the selected device.
    encoded = tokenizer.encode(
        source, return_tensors='pt', max_length=512, truncation=True
    ).to(device)
    generated = model.generate(encoded, eos_token_id=1, max_length=512, num_beams=5)
    # Only the first generated sequence is decoded.
    return tokenizer.decode(generated[0], skip_special_tokens=True)

def calculate_rouge_scores(data_path):
    """Summarize each row of a TSV file with both models and average the ROUGE scores.

    The file is read as tab-separated text with quoting disabled and must
    contain 'news' and 'summary' columns. Rows missing either value are
    dropped. Every remaining article is summarized by ``model_original`` and
    ``model_finetuned``, and each generated summary is scored against the
    reference summary with the notebook-level ``scorer``.

    Args:
        data_path: Path to the TSV file to evaluate.

    Returns:
        A tuple ``(average_scores_original, average_scores_finetuned)``. Each
        element maps a ROUGE type ('rouge1', 'rouge2', 'rougeL', 'rougeLsum')
        to a dict holding the mean 'precision', 'recall' and 'fmeasure'.

    Raises:
        ValueError: If no row has both a 'news' and a 'summary' value, because
            an average over zero rows is undefined.
    """
    rouge_types = ('rouge1', 'rouge2', 'rougeL', 'rougeLsum')
    metrics = ('precision', 'recall', 'fmeasure')

    # quoting=3 is csv.QUOTE_NONE: quote characters are kept as ordinary text.
    data = pd.read_csv(data_path, sep='\t', quoting=3)
    data = data.dropna(subset=['news', 'summary']).astype({'news': str, 'summary': str})

    # Fail early with a clear message instead of dividing by zero after the loop.
    num_entries = len(data)
    if num_entries == 0:
        raise ValueError(
            f"No rows with both 'news' and 'summary' values found in {data_path!r}"
        )

    # Running sums per ROUGE type and metric, one table per model.
    total_scores_original = {key: dict.fromkeys(metrics, 0) for key in rouge_types}
    total_scores_finetuned = {key: dict.fromkeys(metrics, 0) for key in rouge_types}

    for index, text, summary_label in zip(data.index, data['news'], data['summary']):
        print(f"Processing row {index}...")

        summary_original = summarize(text, model_original)
        summary_finetuned = summarize(text, model_finetuned)

        # scorer.score takes the reference first and the prediction second.
        scores_original = scorer.score(summary_label, summary_original)
        scores_finetuned = scorer.score(summary_label, summary_finetuned)

        # Accumulate both models' scores with the same code path.
        for totals, scores in (
            (total_scores_original, scores_original),
            (total_scores_finetuned, scores_finetuned),
        ):
            for key in scores:
                for metric in metrics:
                    totals[key][metric] += getattr(scores[key], metric)

    # Turn the running sums into per-row means.
    average_scores_original = {
        key: {metric: total / num_entries for metric, total in sums.items()}
        for key, sums in total_scores_original.items()
    }
    average_scores_finetuned = {
        key: {metric: total / num_entries for metric, total in sums.items()}
        for key, sums in total_scores_finetuned.items()
    }

    return average_scores_original, average_scores_finetuned

In [None]:
# Path to the TSV file to evaluate
file_path = './tech_test.tsv'

# Run the ROUGE evaluation and print the averaged scores for both models
average_scores_original, average_scores_finetuned = calculate_rouge_scores(file_path)
print("Average ROUGE scores for the original model:", average_scores_original)
print("Average ROUGE scores for the finetuned model:", average_scores_finetuned)

Processing row 0...
Processing row 2...
Processing row 4...
Processing row 6...
Processing row 8...
Processing row 10...
Processing row 12...
Processing row 14...
Processing row 16...
Processing row 18...
Processing row 20...
Processing row 22...
Processing row 24...
Processing row 26...
Processing row 28...
Processing row 30...
Processing row 32...
Processing row 34...
Processing row 36...
Processing row 38...
Processing row 40...
Processing row 42...
Processing row 44...
Processing row 46...
Processing row 48...
Processing row 50...
Processing row 52...
Processing row 54...
Processing row 56...
Processing row 58...
Processing row 60...
Processing row 62...
Processing row 64...
Processing row 66...
Processing row 68...
Processing row 70...
Processing row 72...
Processing row 74...
Processing row 76...
Processing row 78...
Processing row 80...
Processing row 82...
Processing row 84...
Processing row 86...
Processing row 88...
Processing row 90...
Processing row 92...
Processing row 94.

In [36]:
# Path to the second TSV file to evaluate
file_path2 = './KoBART-summarization/data/test_add.tsv'

# Run the ROUGE evaluation and print the averaged scores for both models
average_scores_original2, average_scores_finetuned2 = calculate_rouge_scores(file_path2)
print("Average ROUGE scores for the original model:", average_scores_original2)
print("Average ROUGE scores for the finetuned model:", average_scores_finetuned2)

Processing row 0...
Processing row 1...
Processing row 2...
Processing row 3...
Processing row 4...
Processing row 5...
Processing row 6...
Processing row 7...
Processing row 8...
Processing row 9...
Processing row 10...
Processing row 11...
Processing row 12...
Processing row 13...
Processing row 14...
Processing row 15...
Processing row 16...
Processing row 17...
Processing row 18...
Processing row 19...
Processing row 20...
Processing row 21...
Processing row 22...
Processing row 23...
Processing row 24...
Processing row 25...
Processing row 26...
Processing row 27...
Processing row 28...
Processing row 29...
Processing row 30...
Processing row 31...
Processing row 32...
Processing row 33...
Processing row 34...
Processing row 35...
Processing row 36...
Processing row 37...
Processing row 38...
Processing row 39...
Processing row 40...
Processing row 41...
Processing row 42...
Processing row 43...
Processing row 45...
Processing row 46...
Processing row 47...
Processing row 48...
Pr

KeyboardInterrupt: 