<a href="https://colab.research.google.com/github/zheien/FYP-2024-2025S2/blob/main/combine_collab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install jiwer editdistance accelerate

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
from jiwer import cer,wer
import editdistance
from transformers import DataCollatorForSeq2Seq
import torch
from datasets import Dataset
import json
import os
import sys
import transformers

# Trainer checkpoint directory. Defaults to ~/scratch/combine-qwen; set the
# COMBINE_QWEN_OUTPUT_DIR environment variable to put checkpoints elsewhere
# (e.g. on a machine without a scratch area). "~" is expanded either way.
output_dir = os.path.expanduser(
    os.environ.get("COMBINE_QWEN_OUTPUT_DIR", "~/scratch/combine-qwen")
)
os.makedirs(output_dir, exist_ok=True)  # Ensure the directory exists
def format_chat(example, tokenizer, inference=False):
    """Tokenize one normalization example for causal-LM training or inference.

    The prompt is ``"### Input\\n<normalized>"``. In training mode the target
    ``"\\n### Output\\n<unnormalized><|endoftext|>"`` is appended, and every
    prompt position is excluded from the loss with label -100.

    Args:
        example: Mapping with a ``'normalized'`` key, plus ``'unnormalized'``
            when ``inference`` is False.
        tokenizer: Tokenizer exposing ``encode(...)`` and ``model_max_length``.
        inference: If True, return only the prompt token ids.

    Returns:
        ``{"input_ids": [...], "labels": [...]}`` (equal lengths, at most
        ``tokenizer.model_max_length``) in training mode, or
        ``{"input_ids": [...]}`` in inference mode.

    Raises:
        KeyError: If ``'normalized'`` (or, in training mode, ``'unnormalized'``)
            is missing from ``example``.
    """
    try:
        input_text = f"### Input\n{example['normalized']}"
    except KeyError:
        print(f"Missing 'normalized' key in example: {example}")
        raise

    max_len = tokenizer.model_max_length
    # The prompt is encoded on its own in BOTH modes, so the prompt tokens seen
    # during training are exactly the ones generate() is given at inference.
    prompt_ids = tokenizer.encode(input_text, truncation=True, max_length=max_len)

    if inference:
        return {"input_ids": prompt_ids}

    # The leading "\n" belongs to the target: the inference prompt ends right
    # after the normalized text, so the model must learn to emit it itself.
    target_text = f"\n### Output\n{example['unnormalized']}<|endoftext|>"
    # Encode the target separately (no extra BOS/EOS) and concatenate, so the
    # -100 mask lines up with the prompt by construction rather than relying on
    # BPE never merging tokens across the prompt/target boundary.
    target_ids = tokenizer.encode(target_text, add_special_tokens=False)

    input_ids = (prompt_ids + target_ids)[:max_len]
    labels = ([-100] * len(prompt_ids) + target_ids)[:max_len]
    return {"input_ids": input_ids, "labels": labels}

def prepare_training_data(data_list, tokenizer, device):
    """Turn raw training records into a tokenized ``Dataset``.

    Args:
        data_list: List of records handed to ``Dataset.from_list``; each row is
            encoded with ``format_chat`` in training mode (input_ids + labels).
        tokenizer: Tokenizer forwarded to ``format_chat``.
        device: Accepted for signature symmetry with the callers; unused here.

    Returns:
        The mapped ``Dataset`` (original columns are kept alongside the new ones).
    """
    def _encode_record(record):
        return format_chat(record, tokenizer)

    encoded = Dataset.from_list(data_list).map(_encode_record, batched=False)

    print("Data preparation complete.")
    print(encoded)
    return encoded

def prepare_validation_data(data_list, tokenizer, device):
    """Turn raw validation records into a tokenized ``Dataset``.

    Rows are encoded exactly like the training split (``format_chat`` in
    training mode), so evaluation loss uses the same prompt-masked labels.

    Args:
        data_list: List of records handed to ``Dataset.from_list``.
        tokenizer: Tokenizer forwarded to ``format_chat``.
        device: Accepted for signature symmetry with the callers; unused here.

    Returns:
        The mapped ``Dataset``.
    """
    print("Starting validation data preparation...")
    raw_rows = Dataset.from_list(data_list)

    def _encode_row(row):
        return format_chat(row, tokenizer)

    encoded = raw_rows.map(_encode_row, batched=False)

    print("Validation data preparation complete.")
    print(encoded)
    return encoded

def main():
    """Fine-tune Qwen3-0.6B-Base on the combined corpus, then score each test set.

    Reads ``datasets/combined_train.json`` and ``datasets/combined_val.json``,
    trains with ``Trainer`` (checkpoints under the module-level ``output_dir``),
    saves the final model and tokenizer to ``./combine-qwen``, and writes one
    ``eval_<name>.txt`` report per test set with per-sample and corpus-level
    character error rate (CER). Missing test files are skipped with a warning.
    """
    import traceback  # local: only needed to report a failed training run

    print("Loading the model and tokenizer...")
    model_name = "Qwen/Qwen3-0.6B-Base"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(f"Model is running on: {device}")

    # JSON lists of {"normalized": ..., "unnormalized": ...} records.
    training_data_path = "datasets/combined_train.json"
    with open(training_data_path, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    validation_data_path = "datasets/combined_val.json"
    with open(validation_data_path, "r", encoding="utf-8") as f:
        validation_data = json.load(f)

    train_dataset = prepare_training_data(train_data, tokenizer, device)
    validation_dataset = prepare_validation_data(validation_data, tokenizer, device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        warmup_steps=10000,
        save_steps=50000,
        save_strategy="steps",
        bf16=True,
        eval_strategy="steps",
        eval_steps=50000,
        dataloader_num_workers=4,
    )

    # Pads input_ids with the pad token and labels with -100, so padded
    # positions never contribute to the loss.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=data_collator,
    )

    try:
        trainer.train()
    except Exception as e:
        # Deliberately best-effort: still save and evaluate whatever weights we
        # have, but keep the traceback so the failure is diagnosable.
        print(f"Training interrupted: {e}")
        traceback.print_exc()

    print("Saving the finetuned model...")
    model.save_pretrained("./combine-qwen")
    tokenizer.save_pretrained("./combine-qwen")

    # Generation should run without dropout; Trainer may leave the model in
    # train mode after train().
    model.eval()

    print("Evaluating the model on multiple test datasets...")
    test_files = [
        ("ami", "datasets/ami_test.json"),
        ("swbd", "datasets/swbd_test.json"),
        ("earnings", "datasets/earnings_test.json"),
        ("chime", "datasets/chime_test.json"),
        ("gtn", "datasets/gtn_test.json"),
        ("spgi", "datasets/spgi_test.json")
    ]

    for test_name, test_file_path in test_files:
        print(f"\n🚀 Evaluating on {test_name}...")
        _evaluate_test_set(model, tokenizer, device, test_name, test_file_path)


def _extract_prediction(decoded):
    """Return the answer text from a decoded generation continuation.

    The training target is ``"\\n### Output\\n<text><|endoftext|>"``, so keep
    what follows the first ``### Output`` marker (if present), cut off any
    further ``### Input`` turn the model may hallucinate, and drop a stray EOS.
    """
    _, marker, tail = decoded.partition("### Output")
    answer = tail if marker else decoded
    answer = answer.split("### Input", 1)[0]
    return answer.replace("<|endoftext|>", "").strip()


def _evaluate_test_set(model, tokenizer, device, test_name, test_file_path):
    """Generate predictions for one test file and write ``eval_<test_name>.txt``.

    Returns:
        The corpus-level CER (total edits / total reference characters), or
        None when the test file does not exist.
    """
    try:
        with open(test_file_path, "r", encoding="utf-8") as f:
            test_data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] Test file {test_file_path} not found. Skipping.")
        return None

    total_edits = 0
    total_chars = 0

    output_path = f"eval_{test_name}.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(f"--- Evaluation results for {test_name} ---\n\n")

        for i, raw_example in enumerate(test_data):
            try:
                eval_input = format_chat(raw_example, tokenizer, inference=True)
                input_ids = torch.tensor(eval_input['input_ids']).unsqueeze(0).to(device)
                # Single unpadded prompt: every position is a real token.
                attention_mask = torch.ones_like(input_ids)

                with torch.no_grad():
                    output_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=256
                    )
                # generate() returns prompt + continuation for decoder-only
                # models; decode only the new tokens so the prompt is not scored.
                new_tokens = output_ids[0][input_ids.shape[1]:]
                decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
                generated_text = _extract_prediction(decoded)

                expected_text = raw_example['unnormalized']
                edits = editdistance.eval(generated_text, expected_text)
                # Named sample_cer so it does not shadow the imported jiwer.cer.
                sample_cer = edits / len(expected_text) if len(expected_text) > 0 else 0
                total_edits += edits
                total_chars += len(expected_text)

                f.write(f"Expected : {expected_text}\n")
                f.write(f"Generated: {generated_text}\n")
                f.write(f"CER      : {sample_cer:.4f}\n\n")

            except Exception as e:
                # Per-sample best effort: record the failure and keep going.
                print(f"[ERROR] {test_name} test sample {i} failed: {e}")
                f.write(f"[Sample {i}] ERROR: {e}\n\n")

        global_cer = total_edits / total_chars if total_chars > 0 else 0
        print(f"✅ {test_name} CER: {global_cer:.4f}")
        f.write(f"\n✅ Global CER on {test_name}: {global_cer:.4f}\n")

    return global_cer


# Entry point: run the full train -> save -> evaluate pipeline.
# NOTE(review): Jupyter/Colab kernels normally execute cells with
# __name__ == "__main__", so running this cell likely starts training at once.
# main() reads datasets/combined_train.json and datasets/combined_val.json,
# so those files must already exist when this runs.
if __name__ == "__main__":
    print("Script started...")
    main()

In [None]:
import os
import json

# Tiny toy corpus so the training cell has something to read. Each record pairs
# the normalized input text with its unnormalized target text, using the keys
# that format_chat() expects.
os.makedirs("datasets", exist_ok=True)

train_data = [
    {"normalized": "hello world", "unnormalized": "Hello, World!"},
    {"normalized": "this is a test", "unnormalized": "This is a test."},
]
validation_data = [
    {"normalized": "another test", "unnormalized": "Another test."},
    {"normalized": "one more", "unnormalized": "One more."},
]

# Write each split to the path main() loads it from.
for dump_path, records in (
    ("datasets/combined_train.json", train_data),
    ("datasets/combined_val.json", validation_data),
):
    with open(dump_path, "w") as handle:
        json.dump(records, handle)