In [4]:
import csv
!pip install lxml
!pip install datasets



In [2]:
RROMANI_BIBLE_BASE_URL = "https://biblia-rromani.ro"

# Browser-like User-Agent sent with scraping requests.
HEADERS = {"User-Agent": "Mozilla/5.0"}

# Per-book scraping settings (URL prefix of the chapter pages).
# NOTE(review): no visible cell reads these dicts, and Mark uses the key
# 'n_pages' while John uses 'n_chapters' -- unify before relying on them.
BIBLE_BY_MARCU_CONFIG = dict(
    n_pages=16,
    url=RROMANI_BIBLE_BASE_URL + "/marku/41-MRK-",
)

BIBLE_BY_JOHN_CONFIG = dict(
    n_chapters=21,
    url=RROMANI_BIBLE_BASE_URL + "/joan/43-JHN-",
)

In [3]:
def get_page_content(page_url, timeout=30):
    """Download ``page_url`` and return its ``<div id="content">`` element.

    Args:
        page_url: Absolute URL of the page to fetch.
        timeout: Seconds to wait for the server; without a timeout
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The matching BeautifulSoup element, or ``None`` if the page has no
        ``div#content``.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, so an error page is never
            parsed as if it were chapter content.
    """
    resp = requests.get(page_url, headers=HEADERS, timeout=timeout)
    resp.raise_for_status()
    # Decode as UTF-8 regardless of what the response headers declare.
    resp.encoding = 'utf-8'
    soup = BeautifulSoup(resp.text, 'html.parser')
    return soup.find("div", {"id": "content"})

def clean_text(element):
    """Remove footnotes from ``element`` in place and return its text.

    Every descendant matching the CSS selector ``.footnote`` is decomposed
    (the caller's element is mutated). The remaining text fragments are
    stripped and joined with single spaces.
    """
    footnotes = element.select(".footnote")
    for footnote in footnotes:
        footnote.decompose()
    text = element.get_text(" ", strip=True)
    return text


In [4]:

from bs4 import BeautifulSoup
import requests
import re

def extract_chapter(book, chapter, url, timeout=30):
    """Scrape one biblia-rromani.ro chapter into ``[book, chapter, verse, text]`` rows.

    Verse boundaries are the ``<a id="vN">`` anchors inside ``div#content``.
    For each anchor, every ``div.txs`` between it and the next anchor is
    collected in document order. Footnotes are removed and a leading verse
    number is stripped, then the pieces are joined into the verse text.

    Args:
        book: Book label written into every row (e.g. "John").
        chapter: Chapter number written into every row.
        url: Chapter page URL.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A list of ``[book, chapter, verse_num, verse_text]`` lists in page order.

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
        ValueError: if the page has no ``div#content``.
    """
    resp = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=timeout)
    resp.raise_for_status()
    resp.encoding = "utf-8"

    soup = BeautifulSoup(resp.text, "html.parser")
    content = soup.find("div", id="content")
    if content is None:
        raise ValueError(f"No <div id='content'> found at {url}")

    anchors = content.find_all("a", id=re.compile(r"^v\d+$"))
    # The last verse has no following anchor: stop at the first node after
    # div#content instead of walking to the end of the whole document.
    after_content = _first_node_after(content)

    rows = []
    for i, a in enumerate(anchors):
        verse_num = int(a["id"][1:])
        end = anchors[i + 1] if i + 1 < len(anchors) else after_content

        parts = []
        node = a.next_element
        # Identity checks: bs4's Tag.__eq__ compares markup structurally, and
        # plain truthiness would stop early on an empty NavigableString.
        while node is not None and node is not end:
            if getattr(node, "name", None) == "div" and "txs" in (node.get("class") or []):
                # Remove footnotes before reading the text.
                for fn in node.select(".footnote"):
                    fn.decompose()

                text = node.get_text(" ", strip=True)
                # Drop the leading verse number printed inside the paragraph.
                text = re.sub(r"^\d+\s*", "", text).strip()
                if text:
                    parts.append(text)

            node = node.next_element

        verse_text = re.sub(r"\s+", " ", " ".join(parts)).strip()
        rows.append([book, chapter, verse_num, verse_text])

    return rows


def _first_node_after(tag):
    """Return the first node in document order outside ``tag``'s subtree, or None."""
    current = tag
    while current is not None:
        if current.next_sibling is not None:
            return current.next_sibling
        current = current.parent
    return None




## Rromani John Bible Scraper

In [5]:
# Chapter-page template for the Rromani Gospel of John (chapter 1 -> 43-JHN-001.html)
# and the CSV the scraper writes.
JOHN_BASE_URL = "https://biblia-rromani.ro/joan/43-JHN-{chapter:03d}.html"
JOHN_OUTPUT_CSV = "john_all_chapters.csv"

In [6]:
def run_pipeline():
    """Scrape John chapters 1-21 (Rromani) and write them to ``JOHN_OUTPUT_CSV``.

    The CSV has a single header row (book, chapter, verse, text) followed by
    one row per verse as returned by ``extract_chapter``.
    """
    all_rows = []
    for chapter in range(1, 22):
        url = JOHN_BASE_URL.format(chapter=chapter)
        print(f"Processing John chapter: {chapter}")
        all_rows.extend(extract_chapter("John", chapter, url))

    with open(JOHN_OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows treats each header string as a
        # row and would emit "b,o,o,k", "c,h,a,p,t,e,r", ... as bogus lines.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)


In [31]:
run_pipeline()

Processing John chapter: 1
Processing John chapter: 2
Processing John chapter: 3
Processing John chapter: 4
Processing John chapter: 5
Processing John chapter: 6
Processing John chapter: 7
Processing John chapter: 8
Processing John chapter: 9
Processing John chapter: 10
Processing John chapter: 11
Processing John chapter: 12
Processing John chapter: 13
Processing John chapter: 14
Processing John chapter: 15
Processing John chapter: 16
Processing John chapter: 17
Processing John chapter: 18
Processing John chapter: 19
Processing John chapter: 20
Processing John chapter: 21


## Rromani Mark (Marcu) Bible Scraper

In [32]:
# Chapter-page template for the Rromani Gospel of Mark (chapter 1 -> 41-MRK-001.html)
# and the CSV the scraper writes.
MARK_BASE_URL = "https://biblia-rromani.ro/marku/41-MRK-{chapter:03d}.html"
MARK_OUTPUT_CSV = "mark_all_chapters.csv"

In [33]:
def run_pipeline():
    """Scrape Mark chapters 1-16 (Rromani) and write them to ``MARK_OUTPUT_CSV``.

    The CSV has a single header row (book, chapter, verse, text) followed by
    one row per verse as returned by ``extract_chapter``.
    """
    all_rows = []
    for chapter in range(1, 17):
        url = MARK_BASE_URL.format(chapter=chapter)
        # Progress label names the book actually being scraped.
        print(f"Processing Mark chapter: {chapter}")
        all_rows.extend(extract_chapter("Mark", chapter, url))

    with open(MARK_OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [43]:
run_pipeline()

Processing John chapter: 1


HTTPError: 404 Client Error: Not Found for url: https://biblia-rromani.ro/apostola/01-GEN-001.html

## Rromani Acts Bible Scraper

In [23]:
# Chapter-page template for the Rromani Acts of the Apostles (chapter 1 -> 44-ACT-001.html)
# and the CSV the scraper writes.
ACTS_BASE_URL = "https://biblia-rromani.ro/apostola/44-ACT-{chapter:03d}.html"
ACTS_OUTPUT_CSV = "acts_all_chapters.csv"

In [37]:
def run_pipeline():
    """Scrape Acts chapters 1-28 (Rromani) and write them to ``ACTS_OUTPUT_CSV``.

    The CSV has a single header row (book, chapter, verse, text) followed by
    one row per verse as returned by ``extract_chapter``.
    """
    all_rows = []
    for chapter in range(1, 29):
        url = ACTS_BASE_URL.format(chapter=chapter)
        # Progress label names the book actually being scraped.
        print(f"Processing Acts chapter: {chapter}")
        all_rows.extend(extract_chapter("Acts", chapter, url))

    with open(ACTS_OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [38]:
run_pipeline()

Processing John chapter: 1
Processing John chapter: 2
Processing John chapter: 3
Processing John chapter: 4
Processing John chapter: 5
Processing John chapter: 6
Processing John chapter: 7
Processing John chapter: 8
Processing John chapter: 9
Processing John chapter: 10
Processing John chapter: 11
Processing John chapter: 12
Processing John chapter: 13
Processing John chapter: 14
Processing John chapter: 15
Processing John chapter: 16
Processing John chapter: 17
Processing John chapter: 18
Processing John chapter: 19
Processing John chapter: 20
Processing John chapter: 21
Processing John chapter: 22
Processing John chapter: 23
Processing John chapter: 24
Processing John chapter: 25
Processing John chapter: 26
Processing John chapter: 27
Processing John chapter: 28


## Rromani Geneza Bible Scraper

In [50]:
# Chapter-page template for the Rromani Genesis (chapter 1 -> 01-GEN-001.html)
# and the CSV the scraper writes.
GENEZA_BASE_URL = "https://biblia-rromani.ro/geneza/01-GEN-{chapter:03d}.html"
GENEZA_OUTPUT_CSV = "geneza_all_chapters.csv"

In [51]:
def run_pipeline():
    """Scrape Genesis chapters 1-4 (Rromani) and write them to ``GENEZA_OUTPUT_CSV``.

    Only chapters 1-4 are fetched. The CSV has a single header row (book,
    chapter, verse, text) followed by one row per verse.
    """
    all_rows = []
    for chapter in range(1, 5):
        url = GENEZA_BASE_URL.format(chapter=chapter)
        print(f"Processing Geneza chapter: {chapter}")
        all_rows.extend(extract_chapter("Geneza", chapter, url))

    with open(GENEZA_OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [52]:
run_pipeline()

Processing Geneza chapter: 1
Processing Geneza chapter: 2
Processing Geneza chapter: 3
Processing Geneza chapter: 4


## Romanian Bible Scraper (biblia.resursecrestine.ro)

In [20]:
def extract_chapter_romanian(book, chapter, url, timeout=30):
    """Scrape one biblia.resursecrestine.ro chapter into ``[book, chapter, verse, text]`` rows.

    Verses are read from ``div#biblia-text ul#verslist li.verset``. The number
    comes from ``span.numar-verset sup`` and the text from
    ``span.continut-verset-simplu``.

    Args:
        book: Book label written into every row (e.g. "Mark").
        chapter: Chapter number written into every row.
        url: Chapter page URL.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A list of ``[book, chapter, verse_num, verse_text]`` lists. ``verse_num``
        is an int when the page shows a plain number, otherwise the stripped
        string. ``verse_text`` has its whitespace collapsed.

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
        ValueError: if the verse list is missing from the page.
    """
    resp = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=timeout)
    resp.raise_for_status()
    resp.encoding = "utf-8"

    soup = BeautifulSoup(resp.text, "html.parser")
    content = soup.find("div", id="biblia-text")
    verse_list = content.find("ul", id="verslist") if content is not None else None
    if verse_list is None:
        raise ValueError(f"No verse list (div#biblia-text > ul#verslist) found at {url}")

    rows = []
    for vers in verse_list.find_all("li", class_="verset"):
        vers_content = vers.find("a", class_="verset-content")

        number_text = vers_content.find("span", class_="numar-verset").find("sup").get_text(strip=True)
        # Store an int when possible so the key matches the Rromani scraper's int verses.
        verse_num = int(number_text) if number_text.isdecimal() else number_text

        raw_text = vers_content.find("span", class_="continut-verset-simplu").get_text()
        verse_text = " ".join(raw_text.split())

        rows.append([book, chapter, verse_num, verse_text])

    return rows

In [53]:
ROMANIAN_MARK_BASE_URL = "https://biblia.resursecrestine.ro/marcu/{chapter}"
ROMANIAN_MARK_OUTPUT = "romanian_mark_all_chapters.csv"

def run_mark_romanian_pipeline():
    """Scrape Romanian Mark 1-16 and write them to ``ROMANIAN_MARK_OUTPUT``.

    The CSV has one header row (book, chapter, verse, text) followed by one
    row per verse from ``extract_chapter_romanian``.
    """
    all_rows = []
    for chapter in range(1, 17):
        url = ROMANIAN_MARK_BASE_URL.format(chapter=chapter)
        print(f"Processing Mark chapter: {chapter}")
        all_rows.extend(extract_chapter_romanian("Mark", chapter, url))

    with open(ROMANIAN_MARK_OUTPUT, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [54]:
run_mark_romanian_pipeline()

Processing Mark chapter: 1
Processing Mark chapter: 2
Processing Mark chapter: 3
Processing Mark chapter: 4
Processing Mark chapter: 5
Processing Mark chapter: 6
Processing Mark chapter: 7
Processing Mark chapter: 8
Processing Mark chapter: 9
Processing Mark chapter: 10
Processing Mark chapter: 11
Processing Mark chapter: 12
Processing Mark chapter: 13
Processing Mark chapter: 14
Processing Mark chapter: 15
Processing Mark chapter: 16


In [26]:
ROMANIAN_JOHN_BASE_URL = "https://biblia.resursecrestine.ro/ioan/{chapter}"
ROMANIAN_JOHN_OUTPUT = "romanian_john_all_chapters.csv"

def run_john_romanian_pipeline():
    """Scrape Romanian John 1-21 and write them to ``ROMANIAN_JOHN_OUTPUT``.

    The CSV has one header row (book, chapter, verse, text) followed by one
    row per verse from ``extract_chapter_romanian``.
    """
    all_rows = []
    for chapter in range(1, 22):
        url = ROMANIAN_JOHN_BASE_URL.format(chapter=chapter)
        print(f"Processing John chapter: {chapter}")
        all_rows.extend(extract_chapter_romanian("John", chapter, url))

    with open(ROMANIAN_JOHN_OUTPUT, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [27]:
run_john_romanian_pipeline()

Processing John chapter: 1
Processing John chapter: 2
Processing John chapter: 3
Processing John chapter: 4
Processing John chapter: 5
Processing John chapter: 6
Processing John chapter: 7
Processing John chapter: 8
Processing John chapter: 9
Processing John chapter: 10
Processing John chapter: 11
Processing John chapter: 12
Processing John chapter: 13
Processing John chapter: 14
Processing John chapter: 15
Processing John chapter: 16
Processing John chapter: 17
Processing John chapter: 18
Processing John chapter: 19
Processing John chapter: 20
Processing John chapter: 21


In [30]:
ROMANIAN_GENEZA_BASE_URL = "https://biblia.resursecrestine.ro/geneza/{chapter}"
ROMANIAN_GENEZA_OUTPUT = "romanian_geneza_all_chapters.csv"

def run_geneza_romanian_pipeline():
    """Scrape Romanian Genesis 1-4 and write them to ``ROMANIAN_GENEZA_OUTPUT``.

    Only chapters 1-4 are fetched, mirroring the Rromani Genesis scraper. The
    CSV has one header row (book, chapter, verse, text) followed by one row
    per verse.
    """
    all_rows = []
    for chapter in range(1, 5):
        url = ROMANIAN_GENEZA_BASE_URL.format(chapter=chapter)
        print(f"Processing Geneza chapter: {chapter}")
        all_rows.extend(extract_chapter_romanian("Geneza", chapter, url))

    with open(ROMANIAN_GENEZA_OUTPUT, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [31]:
run_geneza_romanian_pipeline()

Processing Geneza chapter: 1
Processing Geneza chapter: 2
Processing Geneza chapter: 3
Processing Geneza chapter: 4


In [32]:
ROMANIAN_ACTS_BASE_URL = "https://biblia.resursecrestine.ro/faptele-apostolilor/{chapter}"
ROMANIAN_ACTS_OUTPUT = "romanian_acts_all_chapters.csv"

def run_acts_romanian_pipeline():
    """Scrape Romanian Acts 1-28 and write them to ``ROMANIAN_ACTS_OUTPUT``.

    The CSV has one header row (book, chapter, verse, text) followed by one
    row per verse from ``extract_chapter_romanian``.
    """
    all_rows = []
    for chapter in range(1, 29):
        url = ROMANIAN_ACTS_BASE_URL.format(chapter=chapter)
        print(f"Processing Acts chapter: {chapter}")
        all_rows.extend(extract_chapter_romanian("Acts", chapter, url))

    with open(ROMANIAN_ACTS_OUTPUT, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # writerow, not writerows: writerows would split each header word
        # into characters and write four bogus rows.
        writer.writerow(["book", "chapter", "verse", "text"])
        writer.writerows(all_rows)

In [33]:
run_acts_romanian_pipeline()

Processing Acts chapter: 1
Processing Acts chapter: 2
Processing Acts chapter: 3
Processing Acts chapter: 4
Processing Acts chapter: 5
Processing Acts chapter: 6
Processing Acts chapter: 7
Processing Acts chapter: 8
Processing Acts chapter: 9
Processing Acts chapter: 10
Processing Acts chapter: 11
Processing Acts chapter: 12
Processing Acts chapter: 13
Processing Acts chapter: 14
Processing Acts chapter: 15
Processing Acts chapter: 16
Processing Acts chapter: 17
Processing Acts chapter: 18
Processing Acts chapter: 19
Processing Acts chapter: 20
Processing Acts chapter: 21
Processing Acts chapter: 22
Processing Acts chapter: 23
Processing Acts chapter: 24
Processing Acts chapter: 25
Processing Acts chapter: 26
Processing Acts chapter: 27
Processing Acts chapter: 28


In [55]:
import pandas as pd
import os

def align_plain_text_bibles(ro_path, rom_path, output_csv):
    """Pair two line-aligned plain-text files into a Romanian-Rromani CSV.

    Line *i* of ``ro_path`` is taken as the translation of line *i* of
    ``rom_path``. Pairs where either side is blank are dropped only *after*
    pairing, so a verse missing on one side does not shift every later line.

    Args:
        ro_path: Romanian text file, one segment per line.
        rom_path: Rromani text file, one segment per line.
        output_csv: Destination CSV with columns ``source_ro`` and ``target_rom``.

    Returns:
        The DataFrame that was written to ``output_csv``.
    """
    # utf-8-sig also strips a leading BOM, which str.strip() would not remove.
    with open(ro_path, 'r', encoding='utf-8-sig') as f:
        ro_lines = [line.strip() for line in f]

    with open(rom_path, 'r', encoding='utf-8-sig') as f:
        rom_lines = [line.strip() for line in f]

    # Diagnostics (raw line counts, blank lines included).
    print(f"Romanian lines: {len(ro_lines)}")
    print(f"Rromani lines: {len(rom_lines)}")

    diff = abs(len(ro_lines) - len(rom_lines))
    if diff > 0:
        print(f" Warning: Files differ by {diff} lines. Data may shift!")

    # zip() stops at the shorter file; blanks are filtered per pair so both
    # sides stay in lock-step.
    aligned_data = [
        {'source_ro': ro, 'target_rom': rom}
        for ro, rom in zip(ro_lines, rom_lines)
        if ro and rom
    ]

    df = pd.DataFrame(aligned_data, columns=['source_ro', 'target_rom'])
    df.to_csv(output_csv, index=False, encoding='utf-8')
    print(f"Created parallel dataset with {len(df)} pairs.")
    return df

def align_custom_dataset(base_dir, books=('acts', 'geneza', 'john', 'mark'),
                         output_csv='merged_parallel_corpus.csv'):
    """Join scraped Romanian and Rromani verse CSVs on (book, chapter, verse).

    Expected layout under ``base_dir``::

        romanian/romanian_<book>_all_chapters.csv
        rromani/<book>_all_chapters.csv

    Each CSV holds ``book, chapter, verse, text`` rows, with or without a
    header row. Some rows are dropped before matching so they cannot produce
    bogus keys:

    - rows whose chapter or verse is not numeric (header lines, junk);
    - rows whose text is empty.

    Args:
        base_dir: Root directory containing ``romanian/`` and ``rromani/``.
        books: Book file stems to process.
        output_csv: Path of the combined CSV to write.

    Returns:
        DataFrame with columns ``ID``, ``Text_ro`` and ``Text_rom`` (inner join).
    """
    cols = ['Book', 'Chapter', 'Verse', 'Text']
    all_frames = []

    for book in books:
        ro_path = os.path.join(base_dir, 'romanian', f'romanian_{book}_all_chapters.csv')
        rom_path = os.path.join(base_dir, 'rromani', f'{book}_all_chapters.csv')

        df_ro = _load_verse_csv(ro_path, cols)
        df_rom = _load_verse_csv(rom_path, cols)

        # Inner merge keeps only verses present in both languages.
        merged = pd.merge(
            df_ro[['ID', 'Text']],
            df_rom[['ID', 'Text']],
            on='ID',
            suffixes=('_ro', '_rom'),
        )

        print(f"Book: {book} | Matched {len(merged)} verses.")
        all_frames.append(merged)

    final_df = pd.concat(all_frames, ignore_index=True)
    final_df.to_csv(output_csv, index=False, encoding='utf-8')
    print(f"\n✅ Total Parallel Pairs: {len(final_df)}")
    return final_df


def _load_verse_csv(path, cols):
    """Read a verse CSV, drop non-verse/empty rows, and add an ``ID`` key like "Acts_1_1"."""
    # dtype=str + keep_default_na=False: verse text such as "NA" stays text.
    df = pd.read_csv(path, names=cols, header=None, dtype=str, keep_default_na=False)
    chapter = pd.to_numeric(df['Chapter'].str.strip(), errors='coerce')
    verse = pd.to_numeric(df['Verse'].str.strip(), errors='coerce')
    keep = chapter.notna() & verse.notna() & df['Text'].str.strip().ne('')
    df = df[keep].copy()
    # Canonical integer rendering so " 1" / "01" / "1.0" all map to "1".
    df['ID'] = (
        df['Book'].str.strip() + "_"
        + chapter[keep].astype(int).astype(str) + "_"
        + verse[keep].astype(int).astype(str)
    )
    return df

In [42]:
align_plain_text_bibles('../bible-uedin.ro-rom.ro', '../bible-uedin.ro-rom.rom', 'uedin_ro_rom.csv')

Romanian lines: 7931
Rromani lines: 7931
✅ Created parallel dataset with 7931 pairs.


In [57]:
custom_dataset = align_custom_dataset('../data')

Book: acts | Matched 1007 verses.
Book: geneza | Matched 106 verses.
Book: john | Matched 879 verses.
Book: mark | Matched 678 verses.

✅ Total Parallel Pairs: 2670


## Data Merging and Preprocessing

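- Deduplication
- Normalization (Unicode NFC, collapsed whitespace)
- Length-ratio filtering (keep pairs whose Rromani/Romanian word-count ratio is between 0.5 and 2.0)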

In [1]:
# Base checkpoint, the cleaned parallel corpus, and where the adapter is saved.
MODEL_NAME = "facebook/nllb-200-distilled-600M"
CSV_PATH = "full_corpus.csv"
# NOTE(review): the training cells use their own OUTPUT_DIR with the same value;
# this constant is not read by any visible cell.
OUTPUT_PATH = "./nllb_ro_rromani_lora"


In [2]:
import unicodedata

def normalize(text):
    """Return ``text`` in Unicode NFC form with every whitespace run collapsed to one space.

    Leading and trailing whitespace disappears as a consequence of split/join.
    """
    composed = unicodedata.normalize('NFC', text)
    return " ".join(composed.split())

In [6]:
import pandas as pd

# Keep pairs whose Rromani/Romanian word-count ratio lies strictly inside these bounds.
MIN_LEN_RATIO, MAX_LEN_RATIO = 0.5, 2.0
SAMPLE_SEED = 42

# keep_default_na=False: segments such as "NA" or "null" must stay text instead
# of becoming NaN floats, which normalize() cannot handle.
custom_df = pd.read_csv('merged_parallel_corpus.csv', keep_default_na=False)
uedin_df = pd.read_csv('uedin_ro_rom.csv', keep_default_na=False)

custom_df["source"] = "custom"
uedin_df["source"] = "uedin"

custom_df = custom_df.rename(columns={'Text_ro': 'ro', 'Text_rom': 'rmy'})
uedin_df = uedin_df.rename(columns={'source_ro': 'ro', 'target_rom': 'rmy'})

df = pd.concat([custom_df, uedin_df], ignore_index=True)

# Normalise BEFORE de-duplicating, so pairs that differ only in Unicode form
# or whitespace collapse into a single row.
df = df.assign(
    ro=df["ro"].astype(str).map(normalize),
    rmy=df["rmy"].astype(str).map(normalize),
)
df = df[(df["ro"] != "") & (df["rmy"] != "")]
df = df.drop_duplicates(subset=['ro', 'rmy'])

# Romanian sentences that still map to more than one distinct translation.
conflicting_ro = df.loc[df["ro"].duplicated(keep=False), "ro"]
print("RO sentences with multiple translations:", conflicting_ro.nunique())

df = df.assign(
    len_ro=df["ro"].str.split().str.len(),
    len_rmy=df["rmy"].str.split().str.len(),
)
df = df.assign(ratio=df["len_rmy"] / df["len_ro"])

df = df[(df["ratio"] > MIN_LEN_RATIO) & (df["ratio"] < MAX_LEN_RATIO)]

print("Final rows:", len(df))
# Seeded so the displayed sample is reproducible.
print(df.sample(min(5, len(df)), random_state=SAMPLE_SEED))

df.to_csv(CSV_PATH, index=False, encoding='utf-8')

RO sentences with multiple translations: 107
Final rows: 10379
            ID                                                 ro  \
8626       NaN  Dar Scriptura a închis totul supt păcat, pentr...   
7344       NaN  ,Cine eşti, Doamne?` am răspuns eu. Şi Domnul ...   
8788       NaN  ci, credincioşi adevărului, în dragoste, să cr...   
2073  Mark_3_9  Isus a poruncit ucenicilor să-I ţină la îndemâ...   
3857       NaN  dar n'au rădăcină în ei, ci ţin pînă la o vrem...   

                                                    rmy  source  len_ro  \
8626  Numa E Vorba le Devleski mothol ke sa e lumia ...   uedin      20   
7344  Ai me phendem, "Kon san tu Devla," ai o Del ph...   uedin      17   
8788  Numa phenas o chachimos la dragostiasa, sagda ...   uedin      19   
2073  O Isus porunćisardǎs le ućenićenqe te anen Les...  custom      17   
3857  Numa chi gelo divano dur ande lengo ilo ai nas...   uedin      27   

      len_rmy     ratio  
8626       36  1.800000  
7344       11  0.64

In [7]:
# Sanity check of the cleaned corpus: column names, then row count.
print(df.columns)
print("Rows:", df.shape[0])

Index(['ID', 'ro', 'rmy', 'source', 'len_ro', 'len_rmy', 'ratio'], dtype='object')
Rows: 10379


In [None]:
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from peft import LoraConfig, get_peft_model


OUTPUT_DIR = "./nllb_ro_rromani_lora"
MAX_LEN = 256

# Drop missing values BEFORE astype(str): afterwards NaN is the literal string
# "nan", which notna() can no longer catch and which would leak into training.
df = (
    pd.read_csv(CSV_PATH)
    .dropna(subset=["ro", "rmy"])
    .astype({"ro": str, "rmy": str})
)
df = df[(df["ro"].str.strip() != "") & (df["rmy"].str.strip() != "")].reset_index(drop=True)

print("Rows after cleaning:", len(df))

ds = Dataset.from_pandas(df[["ro", "rmy"]])
ds = ds.train_test_split(test_size=0.08, seed=42)  # 92/8 split
train_ds, eval_ds = ds["train"], ds["test"]

import warnings

SRC_LANG = "ron_Latn"
TGT_LANG = "rom_Latn"
# tgt_lang must be set as well: tokenizer(text_target=...) prefixes the labels
# with tokenizer.tgt_lang, and must agree with forced_bos_token_id below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang=SRC_LANG, tgt_lang=TGT_LANG)
forced_bos_token_id = tokenizer.convert_tokens_to_ids(TGT_LANG)
if forced_bos_token_id == tokenizer.unk_token_id:
    warnings.warn(
        f"{TGT_LANG!r} is not in the tokenizer vocabulary and maps to <unk>; "
        "labels and generation will start with <unk> instead of a language code."
    )


def preprocess(batch):
    """Tokenize a batch of Romanian sources and Rromani targets for seq2seq training.

    Rows where either side is missing or blank after stripping are skipped, so
    the returned batch may be shorter than the input. ``Dataset.map`` allows
    this because the original columns are removed.

    Returns:
        A mapping with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    pairs = [
        (str(ro).strip(), str(rmy).strip())
        for ro, rmy in zip(batch["ro"], batch["rmy"])
        if ro and rmy
    ]
    pairs = [(ro, rmy) for ro, rmy in pairs if ro and rmy]

    if not pairs:
        return {"input_ids": [], "attention_mask": [], "labels": []}

    sources = [ro for ro, _ in pairs]
    targets = [rmy for _, rmy in pairs]
    # text_target= tokenizes the labels in target mode (target-language prefix),
    # not with the source-language code used for the inputs.
    return tokenizer(
        sources,
        text_target=targets,
        truncation=True,
        max_length=MAX_LEN,
    )

train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=eval_ds.column_names)

In [None]:
# ---- 5) Load model (full precision, no quantization) and attach LoRA ----
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
)

# LoRA config for seq2seq (target attention projections)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

model = get_peft_model(model, lora_config)

# Force the target-language token as the first generated token. generate()
# takes its defaults from generation_config, so set it there as well as on
# model.config.
model.config.forced_bos_token_id = forced_bos_token_id
model.generation_config.forced_bos_token_id = forced_bos_token_id

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ---- 6) Metrics ----
def postprocess_text(preds, labels):
    """Return ``(preds, labels)`` as new lists with surrounding whitespace stripped."""
    def _strip_all(texts):
        return [text.strip() for text in texts]

    return _strip_all(preds), _strip_all(labels)

def compute_metrics(eval_preds):
    """Decode generated ids and reference labels, then score them with sacreBLEU.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by ``Seq2SeqTrainer``
            with ``predict_with_generate=True``.

    Returns:
        ``{"bleu": float, "chrf": float}`` corpus-level scores.
    """
    import numpy as np
    # sacrebleu is otherwise only imported in a later cell; import it here so
    # this cell works when run in order.
    import sacrebleu

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # The Trainer pads predictions (and the collator pads labels) with -100
    # when batching; -100 is not a valid token id and breaks batch_decode, so
    # map it back to the pad token.
    preds = np.where(np.asarray(preds) != -100, preds, tokenizer.pad_token_id)
    labels = np.where(np.asarray(labels) != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels]).score
    chrf = sacrebleu.corpus_chrf(decoded_preds, [decoded_labels]).score
    return {"bleu": bleu, "chrf": chrf}

# ---- 7) Training args ----
# Effective batch size = 8 (per device) x 4 (accumulation steps) = 32.
# LoRA adapters tolerate a higher learning rate than full fine-tuning.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and the trainer's `tokenizer=` to `processing_class=`;
# confirm against the installed version.
training_kwargs = {
    "output_dir": OUTPUT_DIR,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "num_train_epochs": 10,
    "warmup_ratio": 0.05,
    "logging_steps": 100,
    "eval_steps": 500,
    "save_steps": 500,
    "evaluation_strategy": "steps",
    "save_total_limit": 2,
    "predict_with_generate": True,
    "generation_max_length": MAX_LEN,
    "fp16": True,
    "report_to": "none",
}
args = Seq2SeqTrainingArguments(**training_kwargs)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the trained model and the tokenizer to the same directory.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Saved to:", OUTPUT_DIR)


In [None]:
import os
import re
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from peft import LoraConfig, get_peft_model
import sacrebleu

OUTPUT_DIR = "./nllb_ro_rromani_lora"

# ---- 1) Load data ----
df = pd.read_csv(CSV_PATH)


def clean(s: object) -> str:
    """Return ``s`` as a string with every whitespace run collapsed to one space.

    Missing values (NaN/None) become "" so they are filtered out below.
    Deliberately minimal: no Unicode or case normalisation.
    """
    if pd.isna(s):
        return ""
    return " ".join(str(s).split())


# Clean the raw columns directly. Calling astype(str) first would turn NaN
# into the literal text "nan", which then survives every emptiness check.
df = df.assign(ro=df["ro"].map(clean), rmy=df["rmy"].map(clean))

# Drop empty rows and exact duplicate pairs.
df = df[(df["ro"] != "") & (df["rmy"] != "")]
df = df.drop_duplicates(subset=["ro", "rmy"]).reset_index(drop=True)

assert df["ro"].map(lambda x: isinstance(x, str)).all()
assert df["rmy"].map(lambda x: isinstance(x, str)).all()

import warnings

SRC_LANG = "ron_Latn"
TGT_LANG = "rom_Latn"
# tgt_lang must be set: tokenizer(text_target=...) switches to target mode and
# prefixes the labels with tokenizer.tgt_lang.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang=SRC_LANG, tgt_lang=TGT_LANG)
forced_bos_token_id = tokenizer.convert_tokens_to_ids(TGT_LANG)
if forced_bos_token_id == tokenizer.unk_token_id:
    warnings.warn(
        f"{TGT_LANG!r} is not in the tokenizer vocabulary and maps to <unk>; "
        "labels and generation will start with <unk> instead of a language code."
    )

# ---- 3) Train/val split ----
ds = Dataset.from_pandas(df[["ro", "rmy"]])
ds = ds.train_test_split(test_size=0.08, seed=42)  # 92/8 split
train_ds, eval_ds = ds["train"], ds["test"]

# ---- 4) Preprocess ----
MAX_LEN = 256

def preprocess(batch):
    """Tokenize Romanian sources and Rromani targets in one call.

    Returns a mapping with ``input_ids``, ``attention_mask`` and ``labels``.
    The labels are tokenized in target mode via ``text_target``.
    """
    return tokenizer(
        batch["ro"],
        text_target=batch["rmy"],
        max_length=MAX_LEN,
        truncation=True,
    )

train_tok = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=eval_ds.column_names)

# ---- 5) Load model in 4-bit and attach LoRA ----
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# 4-bit settings go through quantization_config; passing load_in_4bit=True
# directly to from_pretrained is deprecated in transformers.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
# PEFT's documented preparation step before attaching adapters to a k-bit model.
model = prepare_model_for_kbit_training(model)

# LoRA config for seq2seq (target attention projections)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

model = get_peft_model(model, lora_config)

# Force the target-language token as the first generated token. generate()
# takes its defaults from generation_config, so set it there too.
model.config.forced_bos_token_id = forced_bos_token_id
model.generation_config.forced_bos_token_id = forced_bos_token_id

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ---- 6) Metrics ----
def postprocess_text(preds, labels):
    """Return ``(preds, labels)`` as new lists with surrounding whitespace stripped."""
    def _strip_all(texts):
        return [text.strip() for text in texts]

    return _strip_all(preds), _strip_all(labels)

def compute_metrics(eval_preds):
    """Decode generated ids and reference labels, then score them with sacreBLEU.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by ``Seq2SeqTrainer``
            with ``predict_with_generate=True``.

    Returns:
        ``{"bleu": float, "chrf": float}`` corpus-level scores.
    """
    import numpy as np

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # The Trainer pads predictions (and the collator pads labels) with -100
    # when batching; -100 is not a valid token id and breaks batch_decode, so
    # map it back to the pad token.
    preds = np.where(np.asarray(preds) != -100, preds, tokenizer.pad_token_id)
    labels = np.where(np.asarray(labels) != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels]).score
    chrf = sacrebleu.corpus_chrf(decoded_preds, [decoded_labels]).score
    return {"bleu": bleu, "chrf": chrf}

# ---- 7) Training args ----
# Effective batch size = 8 (per device) x 4 (accumulation steps) = 32.
# LoRA adapters tolerate a higher learning rate than full fine-tuning.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and the trainer's `tokenizer=` to `processing_class=`;
# confirm against the installed version.
training_kwargs = {
    "output_dir": OUTPUT_DIR,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "num_train_epochs": 10,
    "warmup_ratio": 0.05,
    "logging_steps": 100,
    "eval_steps": 500,
    "save_steps": 500,
    "evaluation_strategy": "steps",
    "save_total_limit": 2,
    "predict_with_generate": True,
    "generation_max_length": MAX_LEN,
    "fp16": True,
    "report_to": "none",
}
args = Seq2SeqTrainingArguments(**training_kwargs)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the trained model and the tokenizer to the same directory.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Saved to:", OUTPUT_DIR)
