In [None]:
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers trl peft accelerate bitsandbytes

In [None]:
from tqdm import tqdm
from huggingface_hub import notebook_login 
from datasets import load_from_disk, Dataset, concatenate_datasets

from unsloth import FastLanguageModel
import torch
import re
import pandas as pd

## Hyperparameters and Config

In [None]:
# Models
# Model identifier handed to `FastLanguageModel.from_pretrained` when loading the generator.
HF_MODEL_ID = "RodrigoSalazar-U/ang-base"

# Dataset
# Directory read with `load_from_disk`; the dataset is expected to expose a `prompt` column.
INPUT_DATASET_PATH = "./hf-repo/unseen"
# Intended output location for the augmented dataset.
OUTPUT_DATASET_PATH = "./hf-repo/augmented"

## Accounts login

In [None]:
# Login
# Prompts for Hugging Face Hub credentials inside the notebook.
# Presumably required to access the model repository — confirm whether it is private.
notebook_login()

## Load model

Download the base model and initialize it with Unsloth for inference.

In [None]:
# Log which checkpoint is about to be loaded.
print(f"Loading model {HF_MODEL_ID}")
# Load the model together with its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = HF_MODEL_ID,
    max_seq_length = 2048,  # maximum sequence length handed to Unsloth
    dtype = None,  # presumably lets Unsloth choose the dtype automatically — confirm
    load_in_4bit = True,  # request 4-bit loading of the weights
)

# Switch the model into Unsloth's inference mode before any call to `generate`.
FastLanguageModel.for_inference(model)

## Load Dataset

Expects a Hugging Face dataset saved on disk (it is read with `load_from_disk`).
The dataset must contain the following column:
- `prompt`: the prompt for the task

In [None]:
# Load
# Read the dataset stored on disk at INPUT_DATASET_PATH.
input_dataset = load_from_disk(INPUT_DATASET_PATH)

## Generation

In [None]:
# Stop string for generation: the opening characters of a closing tag such as `[/EN]`.
END_TAG = "[/"

## Function to generate translations of a list of input prompts
def generate_translation(inp_prompts):
    """Generate one continuation per prompt with the globally loaded model.

    Args:
        inp_prompts: List of prompt strings forming one batch.

    Returns:
        A list of decoded strings, one per prompt and in the same order.
        Special tokens are stripped. The whole sequence returned by
        `model.generate` is decoded (for decoder-only models this presumably
        includes the prompt — confirm), so callers can parse both the source
        and the generated section from it.
    """
    # Nothing to tokenize or generate for an empty batch.
    if not isinstance(inp_prompts, str) and len(inp_prompts) == 0:
        return []

    # Tokenize the whole batch at once and move it to the GPU.
    inputs = tokenizer(
        inp_prompts,
        return_tensors="pt",
        padding=True,  # Ensure batch processing works
        truncation=True,  # In case any prompt exceeds max length
    ).to("cuda")

    # Greedy decoding; generation halts once END_TAG shows up in the output.
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        min_new_tokens=3,  # Force generation
        stop_strings=[END_TAG],
        tokenizer=tokenizer,  # required by `generate` when stop_strings is used
        use_cache=True,
        do_sample=False,  # Deterministic
    )

    # `batch_decode` already returns a fresh list of strings.
    generated_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Drop the GPU tensors (no need to copy them to the CPU first) so the
    # cached memory can be released before the next batch.
    del inputs, outputs
    torch.cuda.empty_cache()

    return generated_outputs

## Batch generation of translations for a dataset
def generate_translation_dataset(dataset: Dataset, batch_size: int = 32) -> Dataset:
    """Run `generate_translation` over every prompt of `dataset` in batches.

    Args:
        dataset: Dataset exposing a `prompt` column.
        batch_size: Number of prompts sent to the model per call (must be >= 1).

    Returns:
        A new Dataset with one row per input prompt and two columns:
        `prompt` (the input) and `generated` (the decoded model output).

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    translation_pairs = []

    # Iterate over the dataset in batches
    for start in tqdm(range(0, len(dataset), batch_size)):
        # Slice the rows first and then pick the column, so only the current
        # batch is read instead of selecting the whole column on every iteration.
        inp_batch = dataset[start:start + batch_size]["prompt"]

        # Generate translations for the batch of inputs
        generated_batch = generate_translation(inp_batch)

        # Record each (prompt, generated) pair
        translation_pairs.extend(
            {"prompt": prompt, "generated": generated}
            for prompt, generated in zip(inp_batch, generated_batch)
        )

    # Convert the list of pairs back into a Dataset
    return Dataset.from_list(translation_pairs)


In [None]:
# Run batched generation over every prompt of the input dataset (default batch size).
synthetic_dataset = generate_translation_dataset(input_dataset)

## Filter
Remove any low-quality outputs.

In [None]:
# Column names of the intermediate parallel-corpus DataFrame.
COL_ANG_TEXT = "ANG_text"
COL_EN_TEXT = "EN_text"
# Language codes, used as tag names inside prompts (e.g. `[ANG]...[/ANG]`).
LANG_CODE_ANG = "ANG"
LANG_CODE_EN = "EN"

# Maps a language code to the language name spelled out in the instruction text.
language_codes = {
    "ANG": "Anglo-Saxon",
    "EN": "English"
}

def get_translation_prompt(src, tgt, text, translation):
    """Assemble one instruction-style translation example.

    Args:
        src: Language code of the source text (a key of `language_codes`).
        tgt: Language code of the translation (a key of `language_codes`).
        text: Source-language fragment.
        translation: Target-language rendering of `text`.

    Returns:
        A dict with three strings: `prompt` (instruction, tagged source and
        the opening target tag), `answer` (translation plus the closing target
        tag) and `text` (prompt immediately followed by answer).

    Raises:
        KeyError: If `src` or `tgt` is not a key of `language_codes`.
    """
    # The target code is resolved before the source code.
    target_language, source_language = language_codes[tgt], language_codes[src]

    prompt = f"[INST]Translate the following {source_language} fragment to {target_language}[/INST]\n[{src}]{text}[/{src}]\n[{tgt}]"
    answer = f"{translation}[/{tgt}]"

    return dict(prompt=prompt, answer=answer, text=f"{prompt}{answer}")

def synthetic_filter(
    synth: pd.DataFrame,
    min_generation_length: int = 20,
    min_word_count: int = 5,
    max_generation_length: int = 1000,
    # Generation is capped at 1024 new tokens; this limit is kept slightly lower.
    # NOTE(review): this limit counts characters, not tokens — confirm it is intended.
) -> pd.DataFrame:
    """Remove low-quality rows from the synthetic parallel corpus.

    A row is kept only when it has no missing value and its English text:
    - is strictly longer than `min_generation_length` characters,
    - is strictly shorter than `max_generation_length` characters,
    - has strictly more than `min_word_count` whitespace-separated words,
    - contains none of the characters ``ÆæǷƿÞþÐð``.

    Args:
        synth: DataFrame holding at least the `COL_EN_TEXT` column. Not modified.
        min_generation_length: Exclusive lower bound on the character count.
        min_word_count: Exclusive lower bound on the word count.
        max_generation_length: Exclusive upper bound on the character count.

    Returns:
        A new DataFrame with the surviving rows and the same columns as the
        input. It may be empty.
    """
    # Measure initial size
    initial_size = synth.shape[0]
    print(f"Initial size: {initial_size}")

    # Remove NaN
    synth = synth.dropna()

    en_text = synth[COL_EN_TEXT]
    char_count = en_text.str.len()
    word_count = en_text.str.split().str.len()

    # Vectorised check for Anglo-Saxon-only letters in the English column.
    # Unlike a row-wise `apply`, this also works when no rows are left.
    has_ang_chars = en_text.str.contains("[ÆæǷƿÞþÐð]", regex=True, na=False).astype(bool)

    keep = (
        (char_count > min_generation_length)
        & (char_count < max_generation_length)
        & (word_count > min_word_count)
        & ~has_ang_chars
    )
    synth = synth[keep].copy()

    # Get final size; guard the percentage against an empty input.
    final_size = synth.shape[0]
    removed = initial_size - final_size
    removed_pct = removed / initial_size * 100 if initial_size else 0.0
    print(f"Filtered {removed} rows from {initial_size} to {final_size} (- {removed_pct:.2f}%)")

    return synth

def build_train_from_synth(input_dataset,  random_seed=751):
    """Turn raw generations into a shuffled, bidirectional training dataset.

    Each row's decoded text is expected to look like
    ``...[ANG]source[/ANG]\\n[EN]generated[/``. The Anglo-Saxon source and the
    generated English text are extracted, quality-filtered with
    `synthetic_filter`, and every surviving pair is turned into two training
    examples: ANG -> EN and EN -> ANG.

    Args:
        input_dataset: Sized iterable of dict-like rows. The decoded text is
            read from the `text` column if present, otherwise from the
            `generated` column produced by `generate_translation_dataset`.
        random_seed: Seed used to shuffle the final dataset.

    Returns:
        A Dataset with columns `prompt`, `answer` and `text`. It is empty when
        no usable pair is left.

    Raises:
        KeyError: If a row has neither a `text` nor a `generated` column.
    """
    input_len = len(input_dataset)
    print(f"Raw dataset: {input_len} rows")

    # Extract EN and ANG sections from the dataset
    data = []
    for row in input_dataset:
        # `generate_translation_dataset` stores the output under "generated";
        # "text" is still accepted for datasets that use that name.
        if "text" in row:
            text = row["text"]
        elif "generated" in row:
            text = row["generated"]
        else:
            raise KeyError("row has neither a 'text' nor a 'generated' column")

        # If the text does not end with "[/" then it is not a valid entry
        # NOTE(review): a final token that extends past the stop string would
        # also be rejected here — confirm this strictness is wanted.
        if not isinstance(text, str) or not text.endswith("[/"):
            continue

        # Capture [EN]text[/ (closing tag cut by the stop string) and [ANG]text[/ANG]
        en_match = re.search(r'\[EN\](.*?)\[/', text)
        ang_match = re.search(r'\[ANG\](.*?)\[/ANG\]', text)
        if en_match is None or ang_match is None:
            # Skip if no match
            continue

        data.append({
            COL_ANG_TEXT: ang_match.group(1).strip(),
            COL_EN_TEXT: en_match.group(1).strip(),
        })

    # Display stats; guard the percentage against an empty input.
    missed = input_len - len(data)
    missed_pct = missed / input_len * 100 if input_len else 0.0
    print(f"Miss generated: {missed} rows ({missed_pct:.2f}%)")

    # Name the columns explicitly so the frame is well-formed even with no rows.
    synth_df = pd.DataFrame(data, columns=[COL_ANG_TEXT, COL_EN_TEXT])
    # Apply quality filter
    if not synth_df.empty:
        synth_df = synthetic_filter(synth_df)
    if synth_df.empty:
        print("No usable rows left after extraction and filtering; returning an empty dataset")
        return Dataset.from_dict({"prompt": [], "answer": [], "text": []})

    # Create a dataset; the pandas index carries no information worth keeping.
    ds = Dataset.from_pandas(synth_df, preserve_index=False)

    # ANG -> EN: the generated English text is the answer.
    ds_forward = ds.map(
        lambda x: get_translation_prompt(LANG_CODE_ANG, LANG_CODE_EN, x[COL_ANG_TEXT], x[COL_EN_TEXT]),
        remove_columns=ds.column_names
    )
    # EN -> ANG: the original Anglo-Saxon text is the answer.
    ds_backward = ds.map(
        lambda x: get_translation_prompt(LANG_CODE_EN, LANG_CODE_ANG, x[COL_EN_TEXT], x[COL_ANG_TEXT]),
        remove_columns=ds.column_names
    )

    # Return combined and shuffled dataset
    return concatenate_datasets([ds_forward, ds_backward]).shuffle(seed=random_seed)

In [None]:
# Build the augmented dataset and save it
output_dataset = build_train_from_synth(synthetic_dataset)
# Write to disk now: the next cell releases the runtime, after which anything
# held only in memory is gone.
output_dataset.save_to_disk(OUTPUT_DATASET_PATH)

In [None]:
# Shutdown the Colab runtime
# Imported here rather than with the other imports; presumably because
# `google.colab` is only importable when running on Colab — confirm.
from google.colab import runtime
# Release the runtime once all work above has finished.
runtime.unassign()