In [2]:
# delete all folders in working directory, leave files alone
!find . -mindepth 1 -maxdepth 1 -type d -exec rm -r {} +

In [None]:
!pip install unsloth
!pip install instructor
!pip install openai 
!pip install pydantic
!pip install dotenv
!pip install huggingface_hub
!python -m pip install --upgrade typing_extensions
!pip install vllm
!pip install mistralai

In [None]:

import os
import torch
from datasets import load_dataset, Dataset
import json
import re
import requests
import traceback
from datetime import datetime
from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
import sys
import logging

from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback, TrainerState, TrainerControl


MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MAX_SEQ_LENGTH = 9000
LORA_RANK = 64
JSON_DATA_PATH = "combined_unsloth_dataset.json"

HF_USERNAME = "TTahir"
HF_REPO_NAME_TEMPLATE_SFT = f"{HF_USERNAME}/act-therapist-sft-llama-answer-only-{{date}}"
HF_TOKEN = ""

KEYWORD_THINKING = "Thinking:"
KEYWORD_ANSWER = "Answer:"

OLD_THINKING_TAG = "<|thinking|>"
OLD_ANSWER_TAG = "<|answer|>"

OUTPUT_DIR_SFT = "act-therapist-sft-llama-answer-only"
LOG_FILE_PATH_SFT = os.path.join(OUTPUT_DIR_SFT, "training_log_sft.tsv")


print("Loading base model and tokenizer for SFT...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=False,
    token=HF_TOKEN if "meta-llama" in MODEL_NAME else None,
    max_lora_rank=LORA_RANK * 2,
)
assert model is not None and tokenizer is not None, "Model or tokenizer failed to load."


print("Applying PEFT (LoRA) for SFT...")
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=LORA_RANK,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    max_seq_length=MAX_SEQ_LENGTH,
    use_rslora=False,
    loftq_config=None,
)
print("SFT Model and LoRA setup complete.")

therapist_system_prompt = f"""You are an AI simulating an Acceptance and Commitment Therapy (ACT) therapist. Your primary goal is to guide the patient toward psychological flexibility by helping them change their relationship with their thoughts and feelings, connect with their values, and take committed action. You facilitate movement without giving direct advice.

Your response should be a natural, concise, and have a single focus. If exploring, ask a direct, open-ended question. If validating, do it briefly and then move to your exploratory question or ACT-aligned statement.

Core Directives for your response:
1. MAINTAIN A COLLABORATIVE, NON-JUDGMENTAL STANCE:
    * Your role is a curious and compassionate guide, not a coach, judge, or expert giving advice.
    * DO NOT give advice (e.g., "You should try..."). Instead, explore possibilities ("What might happen if...").
    * DO NOT use praise or cheerleading (e.g., "I'm proud of you," "That's a great job!"). Instead, acknowledge the patient's effort and connect it back to their values ("Taking that step, even though it was hard, seems really connected to that value of...").
2. PRACTICE PURE ACT - NO CBT:
    * Your primary goal is to foster acceptance and defusion, not to change or dispute the content of thoughts.
    * AVOID COGNITIVE REFRAMING. Do not suggest changing a negative thought into a neutral or positive one.
    * INSTEAD OF REFRAMING, USE DEFUSION. Help the patient notice their thoughts as thoughts (e.g., "So the 'I am a failure' story shows up then," or "Can you thank your mind for that 'helpful' warning?"). The goal is to see the thought, not believe it or change it.
3. THE ACT PIVOT - FROM PROBLEM TO PROCESS:
    * After 1-2 questions exploring a problem, look for where the patient's current strategy is unworkable ("it's exhausting," "it's not helping").
    * CRITICAL PIVOT: Once unworkability is clear, pivot from analyzing the problem to introducing an ACT process. Move from asking "Why do you feel X?" to "What would it be like to make room for X, if it meant you could do Y (valued action)?".
4. INTRODUCE EXPERIENTIAL WORK NATURALLY:
    * When introducing a mindfulness or acceptance exercise, frame it as a small, low-stakes experiment.
    * Gain buy-in first: "Would you be willing to try a little experiment with that feeling right here, just for a moment?"
    * Connect it directly to what the patient just said. Avoid introducing generic, decontextualized exercises.
5. CONCISE & FOCUSED TURNS: Each response should have ONE primary goal. Avoid multiple questions or complex instructions.

Example of What to AVOID (CBT Reframing & Cheerleading):
Patient: It feels stupid to not know this stuff.
AVOID THIS RESPONSE: It's not stupid at all, it's a sign of strength! Can you try reframing that thought to something more positive, like "I am a capable person who is learning a new skill"? I'm so proud of you for being willing to try.
(This is BAD: It's CBT, gives advice, and uses praise, all of which are forbidden.)

Crucially: DO NOT EVER SUGGEST ENDING THE SESSION or mention time. Focus solely on the therapeutic interaction.
"""
SYSTEM_PROMPT = therapist_system_prompt


hf_token_dl = HF_TOKEN
repo_id_dl = "TTahir/ACT_Dataset_April_17"
filename_dl = JSON_DATA_PATH

if not os.path.exists(filename_dl):
    print(f"File '{filename_dl}' not found. Downloading...")
    if "/" in repo_id_dl and not repo_id_dl.startswith("datasets/"):
        url = f"https://huggingface.co/{repo_id_dl}/resolve/main/{filename_dl}"
    else:
        url = f"https://huggingface.co/datasets/{repo_id_dl.replace('datasets/','')}/resolve/main/{filename_dl}"

    headers = {}
    if hf_token_dl:
         headers["Authorization"] = f"Bearer {hf_token_dl}"
    else:
        print("Warning: Hugging Face token not found for download. Trying without token.")

    try:
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        with open(filename_dl, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded '{filename_dl}' successfully from {url}")
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Failed to download file '{filename_dl}' from '{url}'. Error: {e}")
    except Exception as e:
         raise RuntimeError(f"An unexpected error occurred during download: {e}")
else:
    print(f"File '{filename_dl}' already exists. Skipping download.")


print("Defining data loading and preparation functions for SFT...")

def clean_content(role: str, content: str) -> str:
    if not content: return ""
    if role == "assistant":
        content = re.sub(rf"{re.escape(OLD_THINKING_TAG)}.*?{re.escape(OLD_ANSWER_TAG)}", OLD_ANSWER_TAG, content, flags=re.DOTALL)
        content = content.replace(OLD_ANSWER_TAG, "").strip()
    elif role == "user":
        content = content.replace(OLD_THINKING_TAG, "").replace(OLD_ANSWER_TAG, "").strip()

    if content.startswith("Patient: "): content = content[len("Patient: "):].strip()
    elif content.startswith("User: "): content = content[len("User: "):].strip()
    elif content.startswith("Assistant: "): content = content[len("Assistant: "):].strip()
    return content.strip()

def extract_reference_answer_from_file(content: str) -> str:
    start_tag = OLD_ANSWER_TAG
    if start_tag in content:
        parts = content.split(start_tag, 1)
        if len(parts) > 1: return parts[1].strip()
        else: return ""
    else:
        return ""

def load_act_data_from_json(json_file_path: str) -> Dataset:
    print(f"Loading data from {json_file_path} for SFT (Answer-Only)...")
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)
    except FileNotFoundError: raise FileNotFoundError(f"Error: JSON file not found at {json_file_path}")
    except json.JSONDecodeError: raise ValueError(f"Error: Could not decode JSON from {json_file_path}")

    processed_data = []
    skipped_counts = {'format': 0, 'role': 0, 'ref_tag': 0, 'ref_empty': 0, 'user_msg': 0}

    for entry in raw_data:
        if "conversations" not in entry or not isinstance(entry["conversations"], list) or not entry["conversations"]:
            skipped_counts['format'] += 1; continue
        conversation_history = entry["conversations"]
        if not conversation_history or conversation_history[-1].get("role") != "assistant":
            skipped_counts['role'] += 1; continue

        last_assistant_content_from_dataset = conversation_history[-1].get("content", "")

        target_assistant_response = ""
        if OLD_THINKING_TAG in last_assistant_content_from_dataset and OLD_ANSWER_TAG in last_assistant_content_from_dataset:
            old_think_idx = last_assistant_content_from_dataset.find(OLD_THINKING_TAG)
            old_answer_idx = last_assistant_content_from_dataset.find(OLD_ANSWER_TAG)
            if old_think_idx != -1 and old_answer_idx != -1 and old_think_idx < old_answer_idx:
                thinking_part_ds = last_assistant_content_from_dataset[old_think_idx + len(OLD_THINKING_TAG) : old_answer_idx].strip()
                answer_part_ds = last_assistant_content_from_dataset[old_answer_idx + len(OLD_ANSWER_TAG):].strip()
                if thinking_part_ds and answer_part_ds:
                    target_assistant_response = answer_part_ds
                else:
                     skipped_counts['ref_empty'] +=1; continue
            else:
                 skipped_counts['ref_tag'] +=1; continue
        elif OLD_ANSWER_TAG in last_assistant_content_from_dataset:
            answer_part_ds = extract_reference_answer_from_file(last_assistant_content_from_dataset)
            if answer_part_ds:
                 skipped_counts['ref_tag'] += 1; continue
            else: skipped_counts['ref_empty'] += 1; continue
        else:
            skipped_counts['ref_tag'] += 1; continue

        if not target_assistant_response.strip():
            skipped_counts['ref_empty'] += 1; continue

        prompt_messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
        context_messages = conversation_history[:-1]
        has_user_message = False
        for msg in context_messages:
            role, content = msg.get("role"), msg.get("content")
            if role and content and role in ["user", "assistant"]:
                cleaned = clean_content(role, content)
                if cleaned:
                    prompt_messages.append({"role": role, "content": cleaned})
                    if role == "user": has_user_message = True
        
        if not has_user_message and len(prompt_messages) <=1 :
            skipped_counts['user_msg'] += 1; continue

        processed_data.append({"prompt_messages": prompt_messages, "target_response": target_assistant_response.strip()})

    total_skipped = sum(skipped_counts.values())
    print(f"Loaded {len(processed_data)} entries for SFT.")
    print(f"Skipped {total_skipped} entries (Format: {skipped_counts['format']}, Role: {skipped_counts['role']}, Ref Tag/Order: {skipped_counts['ref_tag']}, Ref Empty: {skipped_counts['ref_empty']}, No User Msg: {skipped_counts['user_msg']}).")
    if not processed_data: raise ValueError("No valid data loaded after filtering for SFT.")
    dataset = Dataset.from_list(processed_data)
    print("SFT Dataset prepared (raw form).")
    return dataset

def sft_formatting_function(examples):
    texts = []
    prompt_messages_batch = examples["prompt_messages"]
    target_responses_batch = examples["target_response"]

    for prompt_msgs, target_resp in zip(prompt_messages_batch, target_responses_batch):
        full_conversation_turn = prompt_msgs + [{"role": "assistant", "content": target_resp}]
        
        try:
            formatted_text = tokenizer.apply_chat_template(
                full_conversation_turn,
                tokenize=False,
                add_generation_prompt=False
            )
            texts.append(formatted_text)
        except Exception as e:
            print(f"Error applying chat template: {e}")
            print(f"Problematic conversation turn: {full_conversation_turn}")
            texts.append(None)
            
    valid_texts = [t for t in texts if t is not None]
    if len(valid_texts) != len(texts):
        print(f"Warning: Dropped {len(texts) - len(valid_texts)} examples due to templating errors.")

    return {"text": valid_texts}

try:
    raw_train_dataset = load_act_data_from_json(JSON_DATA_PATH)
    if len(raw_train_dataset) > 0:
        print("\nExample SFT raw data point (before formatting function):")
        print("Prompt Messages (System prompt + history):")
        for msg in raw_train_dataset[0]['prompt_messages'][:5]: print(f"  Role: {msg['role']}, Content: {msg['content'][:150]}...")
        if len(raw_train_dataset[0]['prompt_messages']) > 5: print("  ...")
        print("\nTarget Response (ONLY the answer part from dataset):")
        print(f"  {raw_train_dataset[0]['target_response'][:300]}...")

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ('{tokenizer.eos_token}')")

        train_dataset_sft = raw_train_dataset.map(
            sft_formatting_function,
            batched=True,
            num_proc=os.cpu_count() // 2 or 1,
            remove_columns=raw_train_dataset.column_names
        )
        if len(train_dataset_sft) > 0:
             print("\nExample SFT formatted data point (after sft_formatting_function):")
             print(f"Text:\n{train_dataset_sft[0]['text'][:1000]}...")
        else:
            raise ValueError("SFT dataset is empty after formatting. Check errors in sft_formatting_function.")
    else:
        raise ValueError("Raw training dataset is empty.")
except Exception as e:
    print(f"ERROR during SFT data loading/processing: {e}\n{traceback.format_exc()}")
    raise RuntimeError(f"Failed to load or process dataset for SFT from '{JSON_DATA_PATH}': {e}")


class SFTFileLoggingCallback(TrainerCallback):
    _first_log_debug_done = False

    def __init__(self, log_file_path):
        self.log_file_path = log_file_path
        self.log_file = None
        self.header = "Step\tLoss\tLR\tEpoch\n"
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
        self._initialize_log_file()

    def _initialize_log_file(self):
        file_exists = os.path.exists(self.log_file_path)
        self.log_file = open(self.log_file_path, 'a+', encoding='utf-8')
        if not file_exists or os.path.getsize(self.log_file_path) == 0:
            self.log_file.write(self.header)
            self.log_file.flush()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs: dict = None, **kwargs):
        if not SFTFileLoggingCallback._first_log_debug_done and logs is not None:
            print(f"DEBUG SFTFileLoggingCallback on_log (first call):\nLogs dictionary: {json.dumps(logs, indent=2)}")
            SFTFileLoggingCallback._first_log_debug_done = True

        if self.log_file is None: self._initialize_log_file()
        
        if state.is_local_process_zero and logs is not None:
            step = state.global_step
            loss = logs.get("loss", logs.get("train_loss", "N/A"))
            lr = logs.get("learning_rate", "N/A")
            epoch = logs.get("epoch", "N/A")

            def format_val(v):
                if hasattr(v, 'item'):
                    try: v = v.item()
                    except: pass
                if isinstance(v, (float)): return f"{v:.6f}"
                if isinstance(v, (int)): return str(v)
                return str(v)

            log_entry = (
                f"{step}\t"
                f"{format_val(loss)}\t"
                f"{format_val(lr)}\t"
                f"{format_val(epoch)}\n"
            )
            try:
                self.log_file.write(log_entry)
                self.log_file.flush()
            except Exception as e:
                print(f"ERROR writing to SFT log file at step {step}: {e}")
                print(f"Log entry data: {log_entry}")
                print(f"CONSOLE LOG FALLBACK (SFT):\n{self.header.strip()}\n{log_entry.strip()}")

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if self.log_file:
            self.log_file.close(); self.log_file = None
            print(f"SFT Training log saved to: {self.log_file_path}")

    def __del__(self):
        if self.log_file:
            try: self.log_file.close()
            except Exception as e: print(f"Error closing SFT training log file in __del__: {e}")

print("Configuring SFTTrainer...")
os.makedirs(OUTPUT_DIR_SFT, exist_ok=True)

training_args_sft = TrainingArguments(
    output_dir=OUTPUT_DIR_SFT,
    learning_rate=5e-6,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    max_steps=-1,
    logging_steps=5,
    save_steps=50,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    weight_decay=0.01,
    max_grad_norm=1.0,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    seed=3407,
    report_to="none",
    remove_unused_columns=True,
)

sft_file_logging_callback = SFTFileLoggingCallback(log_file_path=LOG_FILE_PATH_SFT)
print(f"Logging SFT training metrics to: {LOG_FILE_PATH_SFT}")

trainer_sft = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args_sft,
    train_dataset=train_dataset_sft,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=os.cpu_count() // 2 or 1,
    packing=False,
    callbacks=[sft_file_logging_callback],
)
print("SFTTrainer configured.")

print("Starting SFT training...")
SFTFileLoggingCallback._first_log_debug_done = False

if trainer_sft:
    try:
        last_checkpoint = None
        if os.path.isdir(training_args_sft.output_dir):
            from transformers.trainer_utils import get_last_checkpoint
            last_checkpoint = get_last_checkpoint(training_args_sft.output_dir)
            if last_checkpoint: print(f"Found potential SFT checkpoint: {last_checkpoint}")

        if last_checkpoint and os.path.exists(os.path.join(last_checkpoint, "trainer_state.json")):
            print(f"Resuming SFT training from checkpoint: {last_checkpoint}")
            train_result = trainer_sft.train(resume_from_checkpoint=last_checkpoint)
        else:
            if last_checkpoint: print(f"SFT Checkpoint at {last_checkpoint} seems incomplete. Starting fresh.")
            else: print(f"No valid SFT checkpoint found in {training_args_sft.output_dir}. Starting SFT training from scratch.")
            train_result = trainer_sft.train()

        print("SFT Training finished!")
        print("\n--- Saving Final SFT Adapter ---")
        adapter_save_path_sft = os.path.join(OUTPUT_DIR_SFT, "final_sft_adapter")
        trainer_sft.model.save_pretrained(adapter_save_path_sft)
        tokenizer.save_pretrained(adapter_save_path_sft)
        print(f"Final SFT LoRA adapter and tokenizer saved to {adapter_save_path_sft}")

        print("\n--- Uploading SFT artifacts to Hugging Face Hub ---")
        hf_token_upload = HF_TOKEN
        if not hf_token_upload or hf_token_upload == "YOUR_HF_TOKEN_HERE":
            print(f"WARNING: Hugging Face token not provided or is placeholder. Skipping SFT upload.")
        else:
            try:
                current_date = datetime.now().strftime('%Y-%m-%d')
                repo_name_dated_sft = HF_REPO_NAME_TEMPLATE_SFT.format(date=current_date)
                print(f"Attempting to create/access private repo for SFT: {repo_name_dated_sft}")
                create_repo(repo_id=repo_name_dated_sft, token=hf_token_upload, private=True, exist_ok=True)
                print(f"SFT Repo '{repo_name_dated_sft}' ensured.")

                commit_message_suffix_sft = f"SFT_Llama3.1_8B_AnswerOnly_{current_date}"

                print(f"Uploading final SFT adapter folder '{adapter_save_path_sft}'...")
                upload_folder(
                    folder_path=adapter_save_path_sft, repo_id=repo_name_dated_sft, token=hf_token_upload,
                    repo_type="model", commit_message=f"Upload SFT adapter ({commit_message_suffix_sft})"
                )
                print("Final SFT adapter uploaded.")

                if os.path.exists(JSON_DATA_PATH):
                    print(f"Uploading dataset file '{JSON_DATA_PATH}' (used for SFT)...")
                    upload_file(path_or_fileobj=JSON_DATA_PATH, path_in_repo=os.path.basename(JSON_DATA_PATH), repo_id=repo_name_dated_sft, token=hf_token_upload, repo_type="model", commit_message=f"Upload training dataset ({commit_message_suffix_sft})")
                if os.path.exists(LOG_FILE_PATH_SFT):
                    print(f"Uploading SFT training log file '{LOG_FILE_PATH_SFT}'...")
                    upload_file(path_or_fileobj=LOG_FILE_PATH_SFT, path_in_repo=os.path.basename(LOG_FILE_PATH_SFT), repo_id=repo_name_dated_sft, token=hf_token_upload, repo_type="model", commit_message=f"Upload SFT training log ({commit_message_suffix_sft})")

                script_path = None; script_filename = "sft_training_script_answer_only.py"
                try: script_path = os.path.abspath(__file__); script_filename = os.path.basename(script_path)
                except NameError:
                    try:
                        if sys.argv and sys.argv[0] and os.path.exists(sys.argv[0]):
                           script_path = os.path.abspath(sys.argv[0]); script_filename = os.path.basename(script_path)
                        else:
                            from IPython import get_ipython
                            if get_ipython() and 'IPKernelApp' in get_ipython().config:
                                script_filename = "sft_notebook_session_script.ipynb.py"; script_path = None
                            if not script_path: print(f"Warning: Could not reliably determine script path. Defaulting script name to '{script_filename}', but not uploading."); script_path = None
                    except Exception as e_script_path: print(f"Warning: Exception while determining script path: {e_script_path}. Defaulting script name, not uploading."); script_path = None

                if script_path and os.path.exists(script_path):
                    print(f"Uploading SFT training script '{script_filename}'...")
                    upload_file( path_or_fileobj=script_path, path_in_repo=script_filename, repo_id=repo_name_dated_sft, token=hf_token_upload, repo_type="model", commit_message=f"Upload SFT training script ({commit_message_suffix_sft})")
                elif script_filename.endswith(".ipynb.py"): print(f"Skipping upload of heuristically named SFT script '{script_filename}'. Please save and upload manually if needed.")
                else: print(f"Warning: SFT Script file '{script_filename}' (path: {script_path}) not found or path undetermined. Skipping script upload.")

                print(f"Successfully uploaded SFT artifacts to private repo: https://huggingface.co/{repo_name_dated_sft}")
            except Exception as hf_e: print(f"ERROR during Hugging Face upload for SFT: {hf_e}\n{traceback.format_exc()}")

        print("\n--- SFT Inference Example ---")
        if hasattr(trainer_sft, 'model') and trainer_sft.model is not None:
             inference_model = trainer_sft.model
             print("Using trained SFT model directly for inference.")
        else:
            print("Loading saved SFT adapter for inference...")
            inference_model, tokenizer_inf = FastLanguageModel.from_pretrained(
                model_name = adapter_save_path_sft,
                max_seq_length = MAX_SEQ_LENGTH,
                dtype = None,
                load_in_4bit = False,
            )

        FastLanguageModel.for_inference(inference_model)
        inference_model.eval()

        test_prompt = "I keep having this thought that I'm a complete failure, and it just spirals."
        messages_inf = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": test_prompt}
        ]
        
        inference_input_text = tokenizer.apply_chat_template(messages_inf, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(inference_input_text, return_tensors="pt").to(inference_model.device)

        gen_temperature = 0.7; gen_top_p = 0.9; gen_do_sample = True
        max_new_tokens_inference = 512
        
        stop_sequences = ["<|im_end|>", "<|endoftext|>", tokenizer.eos_token, "<|eot_id|>"]
        valid_stop_sequences = list(set(seq for seq in stop_sequences if seq))
        
        eos_token_id_list = [tokenizer.eos_token_id] + tokenizer.convert_tokens_to_ids(valid_stop_sequences)
        eos_token_id_list = list(set(tid for tid in eos_token_id_list if tid is not None and tid != tokenizer.unk_token_id))
        if not eos_token_id_list and tokenizer.eos_token_id is not None: eos_token_id_list = [tokenizer.eos_token_id]
        elif not eos_token_id_list: print("Warning: No valid eos_token_id found for generation.")


        print(f"\nGenerating SFT response for prompt: '{test_prompt}'")
        print(f"Input text to model (ends with generation prompt):\n...{inference_input_text[-300:]}")
        
        with torch.no_grad():
            outputs_ids = inference_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens_inference,
                temperature=gen_temperature,
                top_p=gen_top_p,
                do_sample=gen_do_sample,
                eos_token_id=eos_token_id_list if eos_token_id_list else tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
            )
        
        generated_ids = outputs_ids[0][inputs.input_ids.shape[1]:]
        generated_response_only = tokenizer.decode(generated_ids, skip_special_tokens=False)

        cleaned_response = generated_response_only
        all_stop_tokens_for_cleaning = valid_stop_sequences + ["<|assistant|>"]
        for stop_seq in all_stop_tokens_for_cleaning:
             if cleaned_response.startswith(stop_seq): cleaned_response = cleaned_response[len(stop_seq):].lstrip()
             while cleaned_response.endswith(stop_seq): cleaned_response = cleaned_response[:-len(stop_seq)].rstrip()
        cleaned_response = cleaned_response.strip()
        
        print("\nGenerated SFT ACT Response (Cleaned Model Output):"); print(cleaned_response)

        print(f"\n--- SFT Inference example complete. ---")


    except Exception as e:
        print(f"An error occurred during SFT training or subsequent steps: {e}"); print(traceback.format_exc())
    finally:
        if hasattr(sft_file_logging_callback, 'log_file') and sft_file_logging_callback.log_file is not None:
            try:
                if not sft_file_logging_callback.log_file.closed:
                    sft_file_logging_callback.log_file.close()
                    print("Closed SFT training log file in finally block.")
            except Exception as close_e: print(f"Error closing SFT training log file in finally block: {close_e}")
elif not trainer_sft:
    print("SFT Training skipped. SFTTrainer not initialized correctly.")

print("\nSFT Script finished.")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 08-21 19:39:16 [__init__.py:241] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!
Loading base model and tokenizer for SFT...
==((====))==  Unsloth 2025.8.9: Fast Llama patching. Transformers: 4.55.3. vLLM: 0.10.1.1.
   \\   /|    NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Applying PEFT (LoRA) for SFT...


Unsloth 2025.8.9 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


SFT Model and LoRA setup complete.
File 'combined_unsloth_dataset.json' already exists. Skipping download.
Defining data loading and preparation functions for SFT...
Loading data from combined_unsloth_dataset.json for SFT (Answer-Only)...
Loaded 1250 entries for SFT.
Skipped 0 entries (Format: 0, Role: 0, Ref Tag/Order: 0, Ref Empty: 0, No User Msg: 0).
SFT Dataset prepared (raw form).

Example SFT raw data point (before formatting function):
Prompt Messages (System prompt + history):
  Role: system, Content: You are an AI simulating an Acceptance and Commitment Therapy (ACT) therapist. Your primary goal is to guide the patient toward psychological flexibil...
  Role: user, Content: I don't even know where to start. I've been feeling so overwhelmed lately. It's like everything sets me off, especially stuff related to my new job. I...

Target Response (ONLY the answer part from dataset):
  It's really tough when feelings of anger and overwhelm take over, especially at a new job. Could y

Map (num_proc=48):   0%|          | 0/1250 [00:00<?, ? examples/s]


Example SFT formatted data point (after sft_formatting_function):
Text:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Aug 2025

You are an AI simulating an Acceptance and Commitment Therapy (ACT) therapist. Your primary goal is to guide the patient toward psychological flexibility by helping them change their relationship with their thoughts and feelings, connect with their values, and take committed action. You facilitate movement without giving direct advice.

Your response should be a natural, concise, and have a single focus. If exploring, ask a direct, open-ended question. If validating, do it briefly and then move to your exploratory question or ACT-aligned statement.

Core Directives for your response:
1. MAINTAIN A COLLABORATIVE, NON-JUDGMENTAL STANCE:
    * Your role is a curious and compassionate guide, not a coach, judge, or expert giving advice.
    * DO NOT give advice (e.g., "You should try..."). Instead

Unsloth: Tokenizing ["text"]:   0%|          | 0/1250 [00:00<?, ? examples/s]

SFTTrainer configured.
Starting SFT training...
No valid SFT checkpoint found in act-therapist-sft-llama-answer-only. Starting SFT training from scratch.


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,250 | Num Epochs = 1 | Total steps = 313
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 97,255,424 of 3,310,005,248 (2.94% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
5,2.577
10,2.5685
15,2.5458
20,2.532
25,2.5117
30,2.4832
35,2.3914
40,2.3066
45,2.2226
50,2.1455


DEBUG SFTFileLoggingCallback on_log (first call):
Logs dictionary: {
  "loss": 2.577,
  "grad_norm": 1.843297004699707,
  "learning_rate": 6.25e-07,
  "epoch": 0.016
}
SFT Training log saved to: act-therapist-sft-llama-answer-only/training_log_sft.tsv
SFT Training finished!

--- Saving Final SFT Adapter ---
Final SFT LoRA adapter and tokenizer saved to act-therapist-sft-llama-answer-only/final_sft_adapter

--- Uploading SFT artifacts to Hugging Face Hub ---
Attempting to create/access private repo for SFT: TTahir/act-therapist-sft-llama-answer-only-2025-08-21
SFT Repo 'TTahir/act-therapist-sft-llama-answer-only-2025-08-21' ensured.
Uploading final SFT adapter folder 'act-therapist-sft-llama-answer-only/final_sft_adapter'...


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...ly/final_sft_adapter/tokenizer.json: 100%|##########| 17.2MB / 17.2MB            

  ...t_adapter/adapter_model.safetensors:   0%|          | 46.4kB /  389MB            

Final SFT adapter uploaded.
Uploading dataset file 'combined_unsloth_dataset.json' (used for SFT)...


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  combined_unsloth_dataset.json         : 100%|##########| 11.5MB / 11.5MB            

Uploading SFT training log file 'act-therapist-sft-llama-answer-only/training_log_sft.tsv'...
Uploading SFT training script 'ipykernel_launcher.py'...
Successfully uploaded SFT artifacts to private repo: https://huggingface.co/TTahir/act-therapist-sft-llama-answer-only-2025-08-21

--- SFT Inference Example ---
Using trained SFT model directly for inference.

Generating SFT response for prompt: 'I keep having this thought that I'm a complete failure, and it just spirals.'
Input text to model (ends with generation prompt):
....)

Crucially: DO NOT EVER SUGGEST ENDING THE SESSION or mention time. Focus solely on the therapeutic interaction.<|eot_id|><|start_header_id|>user<|end_header_id|>

I keep having this thought that I'm a complete failure, and it just spirals.<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Generated SFT ACT Response (Cleaned Model Output):
Can you tell me more about what happens when that thought shows up for you? How do you feel right after it comes in, 