In [None]:
import pandas as pd
from datasets import Dataset

# Paths (inputs live on the mounted Google Drive)
train_path = "/content/drive/MyDrive/datasets/train_100k.csv"
test_path = "/content/drive/MyDrive/datasets/test_1k.csv"

# Only these text fields are used downstream. Reading just these columns avoids
# materialising every other column of the 100k-row CSV in memory.
TEXT_COLUMNS = ["abstract", "title"]

# Sampling configuration for quick experiments.
SAMPLE_SEED = 42
N_TRAIN_SAMPLE = 1000
N_TEST_SAMPLE = 100

# Load full datasets restricted to the text columns. `usecols` keeps columns in
# file order, so re-select with TEXT_COLUMNS to get the same
# ["abstract", "title"] ordering as before.
train_df = pd.read_csv(train_path, usecols=TEXT_COLUMNS)[TEXT_COLUMNS]
test_df = pd.read_csv(test_path, usecols=TEXT_COLUMNS)[TEXT_COLUMNS]

# Reproducible random subsets. The size is capped at the frame length because
# DataFrame.sample(n=...) raises when n exceeds the number of rows (replace=False).
train_sample = train_df.sample(n=min(N_TRAIN_SAMPLE, len(train_df)), random_state=SAMPLE_SEED)
test_sample = test_df.sample(n=min(N_TEST_SAMPLE, len(test_df)), random_state=SAMPLE_SEED)


In [None]:
from transformers import MarianMTModel, MarianTokenizer
import pandas as pd
from tqdm.auto import tqdm

# 1. Load the MarianMT tokenizer/model pairs for both translation directions.
#    Each model is moved to "cuda", so a GPU runtime is required.
en_to_fr_model_name = "Helsinki-NLP/opus-mt-en-fr"
fr_to_en_model_name = "Helsinki-NLP/opus-mt-fr-en"


def _load_marian(checkpoint):
    """Load a MarianMT checkpoint; return (tokenizer, model) with the model on CUDA."""
    tok = MarianTokenizer.from_pretrained(checkpoint)
    mdl = MarianMTModel.from_pretrained(checkpoint).to("cuda")
    return tok, mdl


en_to_fr_tokenizer, en_to_fr_model = _load_marian(en_to_fr_model_name)
fr_to_en_tokenizer, fr_to_en_model = _load_marian(fr_to_en_model_name)

def back_translate(text, batch_size=1, max_new_tokens=512):
    """Paraphrase English text by translating it to French and back to English.

    Uses the module-level ``en_to_fr_tokenizer``/``en_to_fr_model`` and
    ``fr_to_en_tokenizer``/``fr_to_en_model`` pairs. Tokenized inputs are moved
    to ``"cuda"``, so the models are expected to already live on the GPU.

    Args:
        text: A single English string, or an iterable of English strings.
        batch_size: Number of strings sent through each ``generate`` call when
            ``text`` contains several strings.
        max_new_tokens: Upper bound on generated tokens for each direction.

    Returns:
        The back-translated string when ``text`` is a ``str``. Otherwise, a list
        of back-translated strings in input order (empty input gives ``[]``).

    Raises:
        TypeError: If ``text`` is neither a string nor iterable (for example a
            NaN value read from a CSV).
        ValueError: If ``batch_size`` is less than 1.
    """
    # NOTE: max_new_tokens was hard-coded to 256. French renderings usually run
    # longer than the English source, so long abstracts could be cut off in the
    # intermediate step, and that loss carries into the English output. 512
    # matches the tokenizer's default truncation length for the opus-mt
    # checkpoints -- TODO confirm against the model config.
    single = isinstance(text, str)
    if single:
        texts = [text]
    else:
        try:
            texts = list(text)
        except TypeError:
            raise TypeError(
                f"expected a str or an iterable of str, got {type(text).__name__}"
            ) from None
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    def _translate(batch, tokenizer, model):
        # One tokenize -> generate -> decode pass for a list of strings.
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to("cuda")
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)

    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        fr_chunk = _translate(chunk, en_to_fr_tokenizer, en_to_fr_model)           # English -> French
        results.extend(_translate(fr_chunk, fr_to_en_tokenizer, fr_to_en_model))   # French -> English

    return results[0] if single else results

# 2. Back-translate a subset of abstracts to create paraphrased training rows.
# Start with the 1,000-row `train_sample` built in the loading cell; each abstract
# costs two GPU generate calls. Set `rows_to_augment = train_df` for a full run.
rows_to_augment = train_sample

augmented_rows = []
skipped_indices = []

for idx, abstract, title in tqdm(
    zip(rows_to_augment.index, rows_to_augment["abstract"], rows_to_augment["title"]),
    total=len(rows_to_augment),
    desc="Back-translating",
):
    # Missing abstracts come back from read_csv as NaN (a float). Skip them
    # explicitly rather than relying on the tokenizer to raise.
    if not isinstance(abstract, str) or not abstract.strip():
        skipped_indices.append(idx)
        continue
    try:
        new_abstract = back_translate(abstract)
    except Exception as e:  # best-effort augmentation: report and move on
        print(f"Skipping index {idx}: {e}")
        skipped_indices.append(idx)
        continue
    augmented_rows.append({"abstract": new_abstract, "title": title})

print(f"Back-translated {len(augmented_rows)} abstracts; skipped {len(skipped_indices)}.")

# 3. Convert the augmented data into a DataFrame. Explicit columns keep the
#    expected schema even if every row was skipped.
augmented_df = pd.DataFrame(augmented_rows, columns=["abstract", "title"])

# 4. Combine the full original training set with the augmented rows.
final_train_df = pd.concat([train_df, augmented_df], ignore_index=True)

# 5. Save the new dataframe (optional).
# NOTE(review): unlike the inputs (under /content/drive/MyDrive), this path is
# outside the Drive mount and is presumably lost when the Colab runtime resets.
# Consider saving under /content/drive/MyDrive instead.
final_train_df.to_csv("/content/augmented_train.csv", index=False)

print("Augmented dataset has", len(final_train_df), "rows.")
