In [None]:
# !pip install evaluate
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes datasets

In [None]:
# Debug switch: when True, compute_metrics prints the decoded predictions/labels
# and the Trainer evaluates every 2 steps instead of every 100.
DEBUG = False

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch

# Model checkpoint (local directory containing the pretrained weights).
checkpoint = "/workspace/LSY/Llama-3.2-3B-Instruct"

print("start load model")
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Load the tokenizer from the same checkpoint.
print("start load tokenizer")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
from datasets import load_dataset, DatasetDict

# Load the arXiv summarization dataset from a local directory (adjust the path
# to your environment). Alternative: load_dataset("allenai/scitldr", "Abstract").
raw_datasets = load_dataset("/workspace/LSY/4-1/arxiv-summarization dataset")

SPLIT_SEED = 40
# Rows drawn per split. DEBUG (defined in an earlier cell) selects a tiny subset
# for quick smoke tests; previously this switch was commented out and ignored.
SPLIT_SIZES = (
    {"train": 3, "validation": 3, "test": 3}
    if DEBUG
    else {"train": 10000, "validation": 250, "test": 250}
)


def _sample_split(split, n, seed=SPLIT_SEED):
    """Shuffle ``split`` with ``seed`` and keep its first ``n`` rows.

    If the split has fewer than ``n`` rows, all rows are kept instead of
    raising on an out-of-range ``select``.
    """
    return split.shuffle(seed=seed).select(range(min(n, len(split))))


train_subset = _sample_split(raw_datasets["train"], SPLIT_SIZES["train"])
validation_subset = _sample_split(raw_datasets["validation"], SPLIT_SIZES["validation"])
test_subset = _sample_split(raw_datasets["test"], SPLIT_SIZES["test"])

# Assemble the sampled splits into a new DatasetDict.
datasets = DatasetDict({
    "train": train_subset,
    "validation": validation_subset,
    "test": test_subset,
})

In [None]:
# Inspect the sampled training split.
train_subset

In [None]:
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Summarize the following article.

### Article:
{}

### Summary:
{}"""

# Appended to every training example (presumably so the model learns to stop
# after the summary).
EOS_TOKEN = tokenizer.eos_token


def formatting_prompts_func(example):
    """Build the training text for one row: prompt + article + reference abstract + EOS."""
    filled = alpaca_prompt.format(example["article"], example["abstract"])
    return {"text": filled + EOS_TOKEN}


# Add a "text" column to every split.
datasets = datasets.map(formatting_prompts_func)

print(datasets)

In [None]:
from transformers import DataCollatorForSeq2Seq
def tokenize_function(example, max_length=256, tok=None):
    """Tokenize the formatted "text" field and build causal-LM labels.

    Works for a single example or a batch (``datasets.map(..., batched=True)``).

    Args:
        example: mapping with a ``"text"`` entry (a string or a list of strings).
        max_length: every sequence is truncated/padded to exactly this many tokens.
            NOTE(review): "text" is prompt + article + summary + EOS, so when the
            article is long the summary and EOS are truncated away and never
            trained on; consider a larger value.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry equal to
        ``input_ids`` except that padded positions (attention_mask == 0) are
        -100, so the loss ignores padding instead of learning to emit [PAD].
    """
    tok = tokenizer if tok is None else tok
    model_inputs = tok(
        example["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    def _mask_padding(ids, mask):
        # New list per sequence (no shared inner lists with input_ids).
        return [token if keep else -100 for token, keep in zip(ids, mask)]

    ids = model_inputs["input_ids"]
    attention = model_inputs["attention_mask"]
    if ids and isinstance(ids[0], list):  # batched input
        model_inputs["labels"] = [_mask_padding(i, m) for i, m in zip(ids, attention)]
    else:
        model_inputs["labels"] = _mask_padding(ids, attention)
    return model_inputs


In [None]:
# Show the model's module structure.
model

In [None]:
# padding="max_length" in tokenize_function needs a pad token. Add "[PAD]" only
# when the tokenizer has none, so an existing pad token is not replaced by a
# brand-new (untrained) embedding row.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Grow the embedding matrix to cover any newly added token id.
# NOTE(review): run this before the LoRA cell wraps `model` with get_peft_model.
model.resize_token_embeddings(len(tokenizer))
model

In [None]:
# Tokenize every split in batches and drop the original columns.
train_columns = datasets["train"].column_names
datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=train_columns,
)
print(datasets)
# Seq2seq collator that pads each batch dynamically.
data_collator = DataCollatorForSeq2Seq(padding=True, tokenizer=tokenizer)

In [None]:
import evaluate  # ✅ 새 라이브러리
import numpy as np
# ROUGE over teacher-forced argmax predictions (not free-running generation),
# so scores reflect next-token reconstruction of the whole prompt+summary text.
metric = evaluate.load('rouge')


def compute_metrics(eval_pred):
    """Decode argmax predictions and labels and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the Trainer evaluation loop.
            ``predictions`` may be a tuple whose first item is the LM logits.
            Assumes logits are (batch, seq, vocab) and labels (batch, seq) —
            TODO confirm.

    Returns:
        The dict returned by ``metric.compute`` (ROUGE scores).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]

    # Causal LM: the logit at position t predicts token t+1, so compare
    # predictions[:, :-1] with labels[:, 1:].
    pred_ids = np.argmax(logits, axis=-1)[:, :-1]
    label_ids = np.asarray(labels)[:, 1:]

    # Negative ids (-100 marks ignored/padded positions) are not decodable; map
    # them to the pad id. Keep int64: casting to uint16 wraps every id > 65535,
    # which corrupts decoding for large vocabularies.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    pred_ids = np.where(pred_ids < 0, pad_id, pred_ids).astype(np.int64)
    label_ids = np.where(label_ids < 0, pad_id, label_ids).astype(np.int64)

    # skip_special_tokens drops [PAD]/EOS so padding does not inflate ROUGE.
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if DEBUG:
        print("\n\n====================")
        print(f"{predictions=}")
        print(f"\n{labels=}")
    return metric.compute(predictions=predictions, references=labels)
    

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapter: rank 16, alpha 16, light dropout, biases left frozen.
# target_modules is not given, so PEFT falls back to its per-architecture defaults.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.01,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
from transformers import TrainingArguments, EarlyStoppingCallback, Trainer

    
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    # Use the collator built next to the tokenized dataset; without this the
    # Trainer falls back to its default collator and `data_collator` goes unused.
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Stop after 100 consecutive evaluations without improvement of the
    # best-model metric (eval loss unless metric_for_best_model is set).
    callbacks=[EarlyStoppingCallback(early_stopping_patience=100)],
    args=TrainingArguments(
        per_device_train_batch_size=8,   # sized for an H100 80GB
        per_device_eval_batch_size=1,    # kept small: compute_metrics receives full-vocab logits
        gradient_accumulation_steps=4,   # effective batch per device = 8 * 4 = 32
        num_train_epochs=100,            # upper bound; early stopping may end training sooner
        eval_strategy="steps",
        eval_steps=2 if DEBUG else 100,
        save_steps=1000,                 # must be a multiple of eval_steps for load_best_model_at_end
        load_best_model_at_end=True,
        learning_rate=2e-4,
        bf16=True,
        logging_steps=100,
        weight_decay=0.01,
        seed=40,
        output_dir="outputs",
        save_strategy="steps",
        save_total_limit=3,
        gradient_checkpointing=False,
    ),
)

In [None]:
# Run fine-tuning (early stopping and best-checkpoint reload as configured on the Trainer).
trainer.train()

In [None]:
# save_pretrained writes a directory (adapter weights + config) despite the
# ".pt" suffix in its name. Save the tokenizer alongside it: the vocabulary was
# extended with a [PAD] token, and reloading needs the same tokenizer/vocab size.
ADAPTER_DIR = "/workspace/LSY/4-1/pretrained/model_early100steps.pt"
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)

In [None]:
def summarize_paper(paper_text, max_new_tokens=250):
    """Generate a summary of ``paper_text`` with the fine-tuned model.

    Args:
        paper_text: text of the section/passage to summarize.
        max_new_tokens: generation budget for the summary (default 250).

    Returns:
        The decoded output *including the prompt* (``generate`` on a decoder-only
        model returns the prompt tokens followed by the new tokens); the summary
        is the text following the prompt's "### Summary:" header.

    NOTE(review): this instruction differs from the training template
    ("Summarize the following article."); consider aligning them.
    """
    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Summarize the following section of a research paper in 3 detailed sentences. The summary should include the main arguments, methodology, results, or conclusions of the section, while preserving important context and details. Base your summary on the text provided below.

### Article:
{paper_text}

### Summary:
"""
    # Disable dropout (incl. LoRA dropout) in case training left the model in train mode.
    model.eval()
    # Send inputs to wherever the model lives instead of assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    summary = tokenizer.decode(output[0], skip_special_tokens=True)
    return summary

# Example: summarize a short passage from a paper (PDF-extracted text, so it
# contains line-break hyphenation such as "Me\ndia").
sample_paper = ''' This work is subject to limitations in two main
 aspects: (1) Limited Focus on LLM Bias in Me
dia Bias Prediction: The scope of bias analysis is
 constrained by the availability of three-way (left-,
 center-, and right-leaning) labeled data. Our study
 relies on two political bias prediction datasets with
 three-way labels to investigate biases during LLM
 prediction. However, datasets with only biased and
 non-biased labels would not suffice for our analysis
 in this paper. (2) Assumption of Ground Truth: We
 operate under the assumption that human-labeled
 data serves as an unbiased ground truth for assess
ing LLM biases. Nevertheless, human annotations
 are inherently subjective and may be influenced by
 individual biases, potentially impacting the validity
 of our evaluations'''
# summarize_paper returns the decoded prompt followed by the generated text.
summary = summarize_paper(sample_paper)
print("Generated Summary:", summary)


In [None]:
import sys
import requests
from bs4 import BeautifulSoup
import re



def sep_no(s):
    """Split an arXiv HTML heading into its section number and title text.

    ``get_text(strip=True)`` glues the number to the title (e.g. ``"3.1Setup"``,
    ``"IVResults"``), so the number is recovered heuristically.

    Args:
        s: heading text.

    Returns:
        ``(number, title)``:

        * Roman headings (first char I/V/X): if the heading contains '.' (else
          '-'), everything up to one character past the first separator
          (``"I.AData"`` -> ``("I.A", "Data")``); otherwise a numeral I..X.
        * Arabic headings: all leading non-letter characters
          (``"3.1Setup"`` -> ``("3.1", "Setup")``).
        * Anything else (e.g. "References") or an empty string: ``('0', s)``.

    Note: an unnumbered title starting with I/V/X is still read as a numeral.
    """
    if not s:
        return '0', s

    if s[0] in ('I', 'V', 'X'):
        # "I.A"/"I.1" or "I-A"/"I-1" style: keep one character past the separator.
        for sep in ('.', '-'):
            if sep in s:
                cut = s.find(sep) + 2
                return s[:cut], s[cut:]

        # Glued "I" + "Introduction": the first 'I' is the numeral.
        if s.startswith('IIntroduction') or s.startswith('IINTRODUCTION'):
            return 'I', 'Introduction'

        # Most specific numerals first so e.g. "III" is not read as "I".
        for numeral in ('III', 'II', 'IX', 'IV', 'I', 'VIII', 'VII', 'VI', 'V', 'X'):
            if numeral == 'II' and s.startswith('IIntro'):
                continue
            if s.startswith(numeral):
                return numeral, s[len(numeral):].strip()
        return '', s  # not reached: s[0] is always one of I/V/X

    if s[0].isdigit():
        cut = next((i for i, ch in enumerate(s) if ch.isalpha()), len(s))
        return s[:cut], s[cut:]

    # Unnumbered heading, e.g. the references section.
    return '0', s



def extract(url):
    """Fetch an arXiv HTML page and split it into titled sections of paragraphs.

    Args:
        url: URL of the HTML rendering of a paper.

    Returns:
        A list of ``{'title': ..., 'paragraphs': [str, ...]}`` dicts in document
        order, or None when ``url`` is empty. ``title`` is ``sep_no(heading)``
        for h1/h2/h3, ``('0', heading)`` for h6 (abstract), and the string
        ``'default'`` for text before the first heading. Sections without a
        kept paragraph are dropped.

    Raises:
        requests.HTTPError: on an error status from the server.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    if not url:
        print('error: invalid url', file=sys.stderr)
        return

    resp = requests.get(url, timeout=30)
    resp.raise_for_status()

    soup = BeautifulSoup(resp.text, 'html.parser')
    sections = []
    now = {'title': 'default', 'paragraphs': []}

    for element in soup.find_all(['h1', 'h2', 'h3', 'h6', 'p']):
        if element.name == 'p':
            paragraph = element.get_text(strip=True)
            # Skip paragraphs containing LaTeX ($ or \( \) delimiters).
            if paragraph and not re.search(r'[\$]|\\\(|\\\)', paragraph):
                now['paragraphs'].append(paragraph)
            continue

        # Every heading (including h6) closes the current section; previously an
        # h6 silently discarded the paragraphs collected so far.
        if now['paragraphs']:
            sections.append(now)
        heading = element.get_text(strip=True)
        # h6 is the abstract heading; h1-h3 carry "number + title name".
        title = ('0', heading) if element.name == 'h6' else sep_no(heading)
        now = {'title': title, 'paragraphs': []}

    # last section
    if now['paragraphs']:
        sections.append(now)

    return sections

url = "https://arxiv.org/html/2504.07495v1"
SUMMARY_MARKER = "### Summary:"


def _generated_summary(decoded):
    """Return only the model-generated part of ``summarize_paper``'s output.

    The decoded output is prompt + continuation and the prompt ends with the
    "### Summary:" header, so the text after its first occurrence is the
    summary. (Assumes the article itself does not contain that header.) If the
    marker is absent, the whole string is used. This replaces fixed character
    offsets that silently broke whenever the prompt text changed.
    """
    _, marker, tail = decoded.partition(SUMMARY_MARKER)
    return (tail if marker else decoded).strip()


# execute
sections = extract(url)

if sections:
    for sec in sections:
        if sec['title'][1] != 'References':
            print(sec['title'])
            article = '\n'.join(sec['paragraphs'])
            summary = _generated_summary(summarize_paper(article))
            print(summary)
            # Replace the paragraph list with its summary string (used by generate_markdown).
            sec['paragraphs'] = summary
        print()
print(sections)

In [None]:
import arxiv
import re

def process_arxiv_url(url):
    """Fetch a paper's title and abstract from the arXiv API and write them as Markdown.

    Args:
        url: arXiv URL of the form ``arxiv.org/abs/<id>``,
            ``arxiv.org/pdf/<id>[.pdf]`` or ``arxiv.org/html/<id>``.

    Returns:
        The file name written: ``"<id>.md"`` in the current working directory.

    Raises:
        ValueError: if no arXiv id can be parsed from ``url``.
        LookupError: if the arXiv API returns no paper for that id.
    """
    match = re.search(r'arxiv\.org/(?:abs|pdf|html)/([\w.]+)', url)
    if match is None:
        raise ValueError(f"not an arXiv abs/pdf/html URL: {url!r}")
    arxiv_id = match.group(1).removesuffix('.pdf')
    md_path = f"{arxiv_id}.md"

    client = arxiv.Client()
    search = arxiv.Search(id_list=[arxiv_id])
    found = False
    for paper in client.results(search):
        with open(md_path, "w", encoding="utf-8") as f:
            f.write(f"# {paper.title}\n\n## Abstract\n{paper.summary}")
        found = True
    if not found:
        raise LookupError(f"no arXiv paper found for id {arxiv_id}")
    return md_path

# Process the received URL. It was previously undefined here (NameError): use a
# value set earlier in the session if present, otherwise fall back to the same
# paper as the HTML URL processed above.
received_arxiv_url = globals().get("received_arxiv_url", "https://arxiv.org/abs/2504.07495v1")
md_file = process_arxiv_url(received_arxiv_url)
print(f"Markdown file created: {md_file}")

In [None]:
def extract_md(md_text, keyword, out_dir='/workspace/LSY/4-1/arxiv-summarization/outputs'):
    """Write ``md_text`` to ``<out_dir>/<keyword>_output.md``.

    Missing directories are created first. ``os`` is imported locally because the
    notebook never imports it at top level (the original raised NameError).

    Args:
        md_text: Markdown content to write (UTF-8).
        keyword: file-name stem; converted with ``str()`` by the f-string.
        out_dir: destination directory (default keeps the original location).

    Returns:
        The path of the written file.
    """
    import os

    path = f'{out_dir}/{keyword}_output.md'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w', encoding='utf-8') as file:
        file.write(md_text)
    return path

In [None]:
def generate_markdown(related_keywords, write=True):
    """Render summarized sections as Markdown and optionally save them.

    Args:
        related_keywords: section dicts as produced by ``extract`` and the
            summarization loop. The ``'default'`` section's ``paragraphs``
            becomes the top-level heading and the output-file keyword. Other
            sections have a ``(number, name)`` title.
        write: if True (default), save the result via ``extract_md``.

    Returns:
        The Markdown string ("" when there are no sections).
    """
    markdown = ""
    keyword = "untitled"  # file-name fallback when there is no 'default' section
    for section in related_keywords:
        title, body = section['title'], section['paragraphs']
        if title == 'default':
            keyword = body
            markdown = f"# {keyword}\n"
            continue

        number, name = title
        # Subsection numbers contain a separator ("3.1", "I.A", "II-B"); top-level
        # ones do not ("1", "10", "II") - unlike the old len(number) < 2 test,
        # which demoted "II" and "10" to subsections.
        if '.' in number.rstrip('.') or '-' in number:
            markdown += f" #### {number} {name}\n"   # subsection
            markdown += f" ##### {body}\n"
        else:
            markdown += f" ## {number}. {name}\n"    # top-level section
            markdown += f" ### {body}\n"

    if write:
        extract_md(markdown, keyword)

    return markdown

In [None]:
# Render the summarized sections as Markdown (generate_markdown also writes it to disk via extract_md).
md = generate_markdown(sections)
print(md)