<a href="https://colab.research.google.com/github/tanattiyarungtham/Book_Rating_Project/blob/main/DL_fitness_coach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Step 1. install necessary packages.

# Clean slate - Uninstall possibly conflicting packages
!pip uninstall -y torch torchvision numpy xformers bitsandbytes

# Install compatible versions
!pip install torch==2.5.1 torchvision==0.18.1 numpy==1.26.4 bitsandbytes==0.45.4 xformers==0.0.35.post1

# Clone Axolotl repo (if not already)
!git clone https://github.com/OpenAccess-AI-Collective/axolotl.git || echo "Already cloned"
%cd axolotl

# Install Axolotl in editable mode
!pip install -e .


Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Successfully uninstalled numpy-2.0.2
[0mCollecting torch==2.5.1
  Downloading torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision==0.18.1
  Downloading torchvision-0.18.1-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes==0.45.4
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
[31mERROR: Ignored the followi

In [None]:
# 2. Clone and Install Axolotl (if fine-tuning)
!git clone https://github.com/OpenAccess-AI-Collective/axolotl.git || echo "Already cloned"
%cd axolotl
!pip install -e .

Cloning into 'axolotl'...
remote: Enumerating objects: 31351, done.[K
remote: Counting objects: 100% (657/657), done.[K
remote: Compressing objects: 100% (286/286), done.[K
remote: Total 31351 (delta 481), reused 395 (delta 370), pack-reused 30694 (from 3)[K
Receiving objects: 100% (31351/31351), 13.17 MiB | 16.91 MiB/s, done.
Resolving deltas: 100% (21072/21072), done.
/content/axolotl/axolotl
Obtaining file:///content/axolotl/axolotl
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: axolotl
  Building editable for axolotl (pyproject.toml) ... [?25l[?25hdone
  Created wheel for axolotl: filename=axolotl-0.8.0-0.editable-py3-none-any.whl size=9650 sha256=cf9a312df60fdf0554c66d7fa6e1b19e85c1e95831bfbdf6771ca0b133d68178
  Stored in director

In [None]:
# 3. Load Model & Pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import json, os, re
from datetime import datetime
import gradio as gr

# Hugging Face Hub id of the checkpoint used for inference in this notebook.
model_id = "Soorya03/Llama-3.2-1B-Instruct-FitnessAssistant"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): device_map="auto" presumably leaves device placement to the
# library instead of pinning a device here — confirm against the runtime used.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
# Module-level text-generation pipeline; chat() calls it to produce replies.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# 4. In-memory user profile & log management
# Maps user_id -> profile dict. Kept only in this process's memory, so every
# profile is lost when the kernel restarts.
user_profiles = {}
# Profile keys that must be non-None before chat() sends a prompt to the model.
required_fields = ["name", "gender", "age", "weight", "height", "goal"]
# Directory for the per-user JSON logs written by save_log(); the path is
# relative to the current working directory at the time this cell runs.
os.makedirs("logs", exist_ok=True)

def get_or_create_profile(user_id):
    """Return the stored profile for `user_id`, creating a blank one on first use.

    A new profile has every key from `required_fields` set to None and the
    internal "__greeted__" flag set to False. The returned dict is the same
    object held in `user_profiles`, so callers mutate it in place.
    """
    try:
        return user_profiles[user_id]
    except KeyError:
        blank = dict.fromkeys(required_fields)
        blank["__greeted__"] = False
        user_profiles[user_id] = blank
        return blank

def reset_profile(user_id):
    """Replace the profile of `user_id` with a blank, not-yet-greeted one.

    Any previously stored values are discarded; returns None.
    """
    blank = dict.fromkeys(required_fields)
    blank["__greeted__"] = False
    user_profiles[user_id] = blank

def parse_profile_input(text, profile):
    """Extract profile fields from free text and fill the unset ones in place.

    Only keys of `profile` that are currently None are written, so values
    captured from earlier messages are never overwritten.

    Args:
        text: Raw user message.
        profile: Dict with the keys "name", "gender", "age", "weight",
            "height" and "goal"; mutated in place.

    Returns:
        None.

    The extraction is regex-based and heuristic: weight is read from
    "<n> kg", height from "<n> cm", and age from a 1-2 digit number that is
    not part of a longer number and is not followed by kg/cm.
    """
    lowered = text.lower()

    if profile["name"] is None:
        # "I am a ..." / "I'm male" introduce descriptions, not names, so skip
        # these follow-up words and keep looking for a real name.
        not_names = {
            "a", "an", "the", "male", "female", "man", "woman", "boy", "girl",
            "non", "not", "trying", "looking", "hoping", "aiming", "currently",
            "about", "around", "just", "here", "new",
        }
        for name_match in re.finditer(
            r"\b(?:my name is|i am|i['’]m)\s+([A-Za-z]+)", text, re.IGNORECASE
        ):
            candidate = name_match.group(1)
            if candidate.lower() not in not_names:
                profile["name"] = candidate.capitalize()
                break

    if profile["gender"] is None:
        gender_match = re.search(
            r"\b(male|female|non-binary|man|woman|girl|boy)\b", lowered
        )
        if gender_match:
            g = gender_match.group(1)
            if g in ("man", "boy"):
                g = "male"
            elif g in ("woman", "girl"):
                g = "female"
            profile["gender"] = g

    if profile["age"] is None:
        # The lookaheads stop "80kg" / "180 cm" / "72.5" from being read as an
        # age: the number must end here and must not carry a kg/cm unit.
        age_match = re.search(
            r"\b(\d{1,2})(?!\d)(?![.,]\d)(?!\s*(?:kg|cm)\b)", lowered
        )
        if age_match:
            profile["age"] = int(age_match.group(1))

    if profile["weight"] is None:
        weight_match = re.search(r"\b(\d{2,3})\s*kg", lowered)
        if weight_match:
            profile["weight"] = int(weight_match.group(1))

    if profile["height"] is None:
        height_match = re.search(r"\b(\d{2,3})\s*cm", lowered)
        if height_match:
            profile["height"] = int(height_match.group(1))

    if profile["goal"] is None:
        # Intent phrases ("i want to ...") are dropped from the goal, but a bare
        # verb ("lose 5 kg") is itself the goal and must be kept in the capture.
        goal_match = re.search(
            r"(?:goal is|i want to|trying to|aim to|my goal is|i'm looking to|"
            r"i would like to|i wish to|i need to|i hope to|want to)\s+([^.\n]+)"
            r"|\b((?:gain|lose|be more|become)\s+[^.\n]+)",
            lowered,
        )
        if goal_match:
            goal_text = (goal_match.group(1) or goal_match.group(2)).strip()
            # Require the whole word "to" so that e.g. "tone up" is prefixed too.
            if not goal_text.startswith("to "):
                goal_text = "to " + goal_text
            # Capitalise after prefixing so the result reads "To lose weight",
            # not "To Lose weight".
            profile["goal"] = goal_text.capitalize()

def missing_fields(profile):
    """List the required profile keys whose value is still None.

    The result follows the order of `required_fields`; it is empty once the
    profile is complete.
    """
    pending = []
    for field in required_fields:
        if profile[field] is None:
            pending.append(field)
    return pending

def build_prompt(user_input, profile):
    """Assemble the model prompt from the user's profile and current message.

    The required profile fields are serialised as JSON. When the message
    mentions a plan, routine, diet or exercise (case-insensitive substring
    match), a line asking for a JSON-formatted answer is appended.
    """
    snapshot = {}
    for field in required_fields:
        snapshot[field] = profile[field]
    user_context = json.dumps(snapshot)

    lowered = user_input.lower()
    format_hint = ""
    for keyword in ["plan", "routine", "diet", "exercise"]:
        if keyword in lowered:
            format_hint = "\nRespond in JSON format with keys: task_type, exercise, meals, notes"
            break

    return f"User profile: {user_context}\nInstruction: {user_input}{format_hint}"

def save_log(user_id, user_input, bot_reply):
    """Append one chat exchange to the JSON log file of `user_id`.

    The log lives at logs/<id>.json and holds a JSON list of entries with the
    keys "timestamp" (ISO 8601, local time), "user_input" and "bot_reply".

    Args:
        user_id: Identifier typed by the user; treated as untrusted.
        user_input: The user's message.
        bot_reply: The reply shown to the user.

    Returns:
        None.
    """
    # user_id comes straight from a text box, so it must never be used as a
    # path component verbatim ("../x" would escape the logs directory). Keep
    # only characters that are safe in a file name; ordinary ids such as
    # "bas_123" are unchanged.
    safe_id = re.sub(r"[^A-Za-z0-9_-]", "_", str(user_id))
    # Re-create the directory here so logging still works if the working
    # directory changed after the setup code ran.
    os.makedirs("logs", exist_ok=True)
    log_path = os.path.join("logs", f"{safe_id}.json")

    entry = {
        "timestamp": datetime.now().isoformat(),
        "user_input": user_input,
        "bot_reply": bot_reply
    }
    if os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        data = []
    data.append(entry)
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

def view_logs(user_id):
    """Return the chat history of `user_id` as display text.

    Args:
        user_id: Identifier typed by the user; treated as untrusted.

    Returns:
        One block per logged exchange (timestamp, user message, bot reply),
        blocks separated by a blank line, or "No logs found yet." when the
        user has no log file.
    """
    # Same file-name sanitisation as the log writer: without it an id such as
    # "../../some/file" would make this function read JSON outside logs/.
    safe_id = re.sub(r"[^A-Za-z0-9_-]", "_", str(user_id))
    log_path = os.path.join("logs", f"{safe_id}.json")
    if not os.path.exists(log_path):
        return "No logs found yet."
    with open(log_path, "r", encoding="utf-8") as f:
        logs = json.load(f)
    return "\n\n".join(
        f"📅 {entry['timestamp']}\n🙋 {entry['user_input']}\n🤖 {entry['bot_reply']}"
        for entry in logs
    )

def chat(user_input, history, user_id):
    """Handle one chat turn for the Gradio UI.

    Flow: fill the user's profile from the message, ask for whatever is still
    missing, and only once the profile is complete send a prompt to the model.

    Args:
        user_input: Text the user just submitted.
        history: Chatbot history as a list of [user, bot] pairs (None is
            treated as an empty history).
        user_id: Identifier selecting the profile and the log file.

    Returns:
        A pair (updated history, ""); the empty string clears the input box.
    """
    history = history or []
    profile = get_or_create_profile(user_id)
    first_contact = not profile["__greeted__"]
    profile["__greeted__"] = True

    # Parse every message, including the very first one: the input box invites
    # users to start with their details, so dropping that message would force
    # them to type everything twice.
    parse_profile_input(user_input, profile)
    missing = missing_fields(profile)
    if missing:
        if first_contact:
            return history + [[user_input, "Hi there! 👋 Please tell me your name, gender, age, weight, height, and goal."]], ""
        return history + [[user_input, f"Thanks! I still need your {', '.join(missing)}."]], ""

    prompt = build_prompt(user_input, profile)
    # return_full_text=False: by default the text-generation pipeline returns
    # the prompt followed by the completion, which would echo the prompt to the
    # user and make the JSON check below fail on every request.
    reply = pipe(prompt, max_new_tokens=300, return_full_text=False)[0]["generated_text"].strip()

    # Structured requests must come back as valid JSON; otherwise apologise
    # instead of showing a malformed plan.
    if any(w in user_input.lower() for w in ["plan", "routine", "diet", "exercise"]):
        try:
            json.loads(reply)
        except json.JSONDecodeError:
            reply = "Sorry, I couldn't format that properly. Could you rephrase your request?"

    save_log(user_id, user_input, reply)
    return history + [[user_input, reply]], ""

# 5. RAG support - prepare corpus
# Seed snippets for retrieval. The instruction/input/output fields are left
# empty; only content, source and tags carry data.
rag_entries = [
    dict(
        instruction="",
        input="",
        output="",
        content="High-protein diets preserve lean mass during a cut.",
        source="PubMed_1",
        tags=["nutrition", "fat loss"],
    ),
    dict(
        instruction="",
        input="",
        output="",
        content="Crossfit improves fat oxidation when paired with a moderate caloric deficit.",
        source="Study_2023",
        tags=["crossfit", "training"],
    ),
]

# Write the corpus as JSON Lines: one serialised entry per line.
os.makedirs("data", exist_ok=True)
with open("data/rag_corpus.jsonl", "w") as corpus_file:
    corpus_file.writelines(json.dumps(entry) + "\n" for entry in rag_entries)

# 6. Axolotl training config
os.makedirs("training", exist_ok=True)

# Axolotl reads a flat YAML document: LoRA and trainer options are top-level
# keys (lora_r, num_epochs, eval_steps, ...) and datasets is a list. Nested
# "lora:" / "training:" / "tokenizer:" sections and a singular "dataset:" key
# are not part of that schema, so the hyper-parameters would not be picked up.
#
# NOTE(review): data/fine_tune_dataset.jsonl is not created anywhere in this
# notebook and is assumed to be in alpaca format (instruction/input/output) —
# confirm before training. val_set_size is set because step-based evaluation
# needs a held-out split — adjust to the real dataset size.
axolotl_config = """\
base_model: Soorya03/Llama-3.2-1B-Instruct-FitnessAssistant
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_4bit: true
adapter: qlora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

datasets:
  - path: data/fine_tune_dataset.jsonl
    ds_type: json
    type: alpaca
val_set_size: 0.05

micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 3
learning_rate: 0.0002
eval_steps: 100
output_dir: ./output
"""

with open("training/llama3_config.yaml", "w") as f:
    f.write(axolotl_config)

# 7. Launch Gradio App
def _reset_and_notify(uid):
    """Reset the profile of `uid`; return (new chat history, new input text).

    The chat is replaced by a single bot notice and the input box is cleared.
    """
    reset_profile(uid)
    # A [None, text] pair shows only the bot side of the exchange.
    return [[None, "Profile reset! Please re-enter your info."]], ""


with gr.Blocks() as demo:
    gr.Markdown("## 🧘 AI Fitness Assistant Chat")
    user_id = gr.Textbox(label="Your ID", value="bas_123")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Start by telling me your name, gender, age, weight, height, and goal...")
    update_btn = gr.Button("📝 Update Profile")
    view_btn = gr.Button("📖 View My Logs")
    logs_output = gr.Textbox(label="Your Log History", lines=10)

    # chat returns (history, ""), which updates the chat and clears the input.
    msg.submit(chat, [msg, chatbot, user_id], [chatbot, msg])
    # One return value per output component. The previous handler returned
    # three values into [chatbot, chatbot, msg], which sent reset_profile's
    # None to the chatbot and left the notice as editable text in the input box.
    update_btn.click(_reset_and_notify, inputs=[user_id], outputs=[chatbot, msg])
    view_btn.click(view_logs, inputs=[user_id], outputs=[logs_output])

# Launch after the Blocks context has closed, i.e. once the layout is complete.
demo.launch()


Device set to use cpu
  chatbot = gr.Chatbot()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://582dbd721f014930d4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
