<a href="https://colab.research.google.com/github/ryanky31-code/model-training-gemma270m/blob/main/site/en/gemma/docs/core/huggingface_text_full_finetune_with_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Full Model Fine-Tune using Hugging Face Transformers (with embedded dataset generator)

This notebook is based on the working example `huggingface_text_full_finetune.ipynb` but embeds a synthetic dataset generator and adapts the data-loading cells so you can run end-to-end in Colab or locally.

Sections:
- Install dependencies
- Generate synthetic dataset (small default for smoke tests)
- Save dataset and compute SHA256
- Convert CSV → conversational dataset (train/test)
- Load model & tokenizer
- Configure SFT and train
- Test inference

Defaults are set to small values so you can test quickly; increase sample counts and epochs for real training.

In [None]:
# Install dependencies (Colab friendly)
# The %pip line below is commented out, so by default this cell only prints a reminder.
# Uncomment it in Colab, or install the same packages in your own environment.
# %pip install torch transformers datasets trl accelerate tensorboard sentencepiece protobuf pandas numpy

print('If running in Colab: enable the above pip install cell.')

## 1) Imports and reproducibility

In [None]:
import math
import json
import random
import hashlib
import os
import zipfile
from typing import List

import numpy as np
import pandas as pd

# Reproducibility: the dataset generator draws from both Python's `random` module
# and NumPy's global RNG, so both are seeded with the same value here.
RSEED = 42
random.seed(RSEED)
np.random.seed(RSEED)

## 2) Dataset generator (embedded)

This is the user's generator adapted to run inside the notebook. Change `N_SAMPLES` for bigger sets.

In [None]:
# Parameters (adjust for smoke tests or full runs)
N_SAMPLES = 200  # small default for smoke tests; increase to 10000+ for real training
# Outputs go to /content when that directory exists, otherwise to the current directory.
OUT_DIR = '/content' if os.path.exists('/content') else '.'
CSV_PATH = os.path.join(OUT_DIR, 'synthetic_wifi_5ghz_outdoor.csv')
ZIP_PATH = os.path.join(OUT_DIR, 'synthetic_wifi_5ghz_outdoor.zip')

# Constants
# Candidate channel frequencies: 4900..6100 MHz inclusive, in 5 MHz steps.
FREQS_MHZ = list(range(4900, 6101, 5))
# Categorical values the generator picks from with random.choice.
ENVIRONMENTS = ["Urban", "Rural"]
DENSITIES    = ["Low", "Medium", "High"]
WEATHER_COND = ["Clear", "Cloudy", "Rain", "Fog", "Snow", "Storm"]
OBST_TYPES   = ["None", "Tree", "Building", "Vehicle", "Crane", "Billboard"]
# Fixed antenna gains; the generator adds both to the TX power when computing RSSI.
TX_ANT_GAIN_DB = 15.0
RX_ANT_GAIN_DB = 15.0
# Channel widths the generator picks from.
BANDWIDTHS_MHZ = [20, 40, 80, 160]

# Helper functions
def haversine_m(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in metres between two points.

    Uses the haversine formula on a sphere of radius 6,371,000 m.

    Args:
        lat1, lon1: Latitude/longitude of the first point, in decimal degrees.
        lat2, lon2: Latitude/longitude of the second point, in decimal degrees.

    Returns:
        Distance along the sphere's surface, in metres (0.0 for identical points).
    """
    R = 6371000.0
    p = math.pi / 180.0  # degrees -> radians
    dlat = (lat2 - lat1) * p
    dlon = (lon2 - lon1) * p
    a = (math.sin(dlat / 2) ** 2
         + math.cos(lat1 * p) * math.cos(lat2 * p) * math.sin(dlon / 2) ** 2)
    # Floating-point rounding can push `a` marginally outside [0, 1] (e.g. for
    # near-antipodal points), which would make asin() raise ValueError.
    a = min(1.0, max(0.0, a))
    return 2 * R * math.asin(math.sqrt(a))

def snr_to_efficiency_bps_per_hz(snr_db: float) -> float:
    """Map an SNR in dB to a spectral efficiency via a logistic curve.

    The curve saturates at 10.0, is centred on 25 dB (where it returns 5.0)
    and has a slope scale of 4 dB.
    """
    exponent = -(snr_db - 25.0) / 4.0
    denominator = 1.0 + math.exp(exponent)
    return 10.0 / denominator

def fspl_db(distance_m: float, freq_mhz: float) -> float:
    """Return the free-space path loss in dB for a distance (m) and frequency (MHz).

    Distances below 1 m are treated as 1 m so the logarithm stays bounded.
    """
    effective_m = 1.0 if distance_m < 1 else distance_m
    log_km = math.log10(effective_m / 1000.0)
    log_mhz = math.log10(freq_mhz)
    return 32.45 + 20 * log_km + 20 * log_mhz

def weather_extra_loss_db(weather: str, distance_m: float) -> float:
    """Return the weather-related extra loss for a link of ``distance_m`` metres.

    Each weather condition has a fixed loss per kilometre which is scaled by
    the link length. An unknown condition raises ``KeyError``.
    """
    length_km = distance_m / 1000.0
    loss_per_km = {
        "Clear": 0.0,
        "Cloudy": 0.2,
        "Fog": 0.6,
        "Rain": 0.8,
        "Snow": 0.9,
        "Storm": 1.5,
    }
    return loss_per_km[weather] * length_km

def density_obstruction_factor(env: str, density: str) -> float:
    """Look up the clutter factor for an environment/density combination.

    The generator multiplies this value by the link length in kilometres.
    Unknown environments or densities raise ``KeyError``.
    """
    urban_factors = {"Low": 0.5, "Medium": 1.5, "High": 3.0}
    rural_factors = {"Low": 0.1, "Medium": 0.4, "High": 0.8}
    factors_by_env = {"Urban": urban_factors, "Rural": rural_factors}
    return factors_by_env[env][density]

def fresnel_penalty_db(f_clear: float) -> float:
    """Return the loss penalty in dB for a Fresnel-zone clearance percentage.

    The penalty is stepped: the first tier whose lower bound is met wins, and
    anything below the lowest tier gets the maximum penalty of 10 dB.
    """
    # (minimum clearance %, penalty dB), checked from best clearance downwards.
    tiers = ((90, 0.0), (70, 1.0), (50, 3.0), (30, 6.0))
    for min_clear_pct, penalty_db in tiers:
        if f_clear >= min_clear_pct:
            return penalty_db
    return 10.0

def obstruction_penalty_db(obstructed: bool, obst_type: str) -> float:
    """Return the loss penalty in dB caused by a detected obstruction.

    No penalty applies when nothing is obstructing the link or the type is the
    literal string "None". Unrecognised obstruction types fall back to 2 dB.
    """
    if obstructed and obst_type != "None":
        penalty_by_type = {
            "Tree": 3.0,
            "Vehicle": 2.0,
            "Billboard": 4.0,
            "Building": 8.0,
            "Crane": 5.0,
        }
        return penalty_by_type.get(obst_type, 2.0)
    return 0.0


def generate_synthetic_row(scenario_id: int):
    """Generate one synthetic outdoor link scenario as a flat, CSV-friendly dict.

    Draws a random environment, two endpoints, weather, noise levels and a set
    of available channels, then scores every available channel and records the
    best one together with its RSSI/SNR and a throughput estimate.

    Randomness comes from both ``random`` and ``np.random``; seed both for
    reproducible output.

    Args:
        scenario_id: Identifier copied verbatim into the ``scenario_id`` field.

    Returns:
        dict mapping column name to value. List/dict-valued fields
        (coordinates, available channels, spectral scan) are JSON-encoded
        strings so each row fits in plain CSV cells.
    """
    environment = random.choice(ENVIRONMENTS)
    density = random.choice(DENSITIES)

    # Endpoint A is placed in a fixed lat/lon box; endpoint B is offset from A
    # by at most 0.15 degrees on each axis.
    lat_a = random.uniform(33.0, 36.0)
    lon_a = random.uniform(35.0, 37.0)
    el_a = random.uniform(5, 900)
    lat_b = lat_a + random.uniform(-0.15, 0.15)
    lon_b = lon_a + random.uniform(-0.15, 0.15)
    el_b = random.uniform(5, 900)
    distance_m = haversine_m(lat_a, lon_a, lat_b, lon_b)
    distance_km = distance_m / 1000.0

    weather = random.choice(WEATHER_COND)
    humidity = random.uniform(20, 95)
    temp_c = random.uniform(-5, 42)

    # Fresnel clearance range (percent) per environment/density.
    f_range = {"Urban": {"Low": (60, 100), "Medium": (40, 90), "High": (15, 80)},
               "Rural": {"Low": (85, 100), "Medium": (65, 100), "High": (45, 100)}}
    fresnel_clear = random.uniform(*f_range[environment][density])

    # Probability that an obstruction is present.
    if environment == "Urban":
        p_obstructed = 0.65 if density == "High" else 0.35
    else:
        p_obstructed = 0.1 if density == "Low" else 0.2
    obstructed = random.random() < p_obstructed
    if obstructed:
        # An obstructed link must name a real obstruction: drawing from the
        # full OBST_TYPES list could yield "None", which contradicts
        # image_obstruction_detected=True and silently zeroes the penalty.
        obst_type = random.choice([t for t in OBST_TYPES if t != "None"])
    else:
        obst_type = "None"

    # Noise floor range (dBm) per environment/density; the measured noise sits
    # 0-8 dB above the floor.
    nf_range = {"Urban": {"Low": (-105, -92), "Medium": (-100, -88), "High": (-95, -82)},
                "Rural": {"Low": (-115, -102), "Medium": (-110, -98), "High": (-108, -96)}}
    noise_floor_dbm = random.uniform(*nf_range[environment][density])
    noise_dbm = noise_floor_dbm + random.uniform(0, 8)

    tx_power_dbm = random.uniform(10, 30)
    channel_bw_mhz = random.choice(BANDWIDTHS_MHZ)
    num_avail = random.randint(10, 50)
    available_channels = sorted(random.sample(FREQS_MHZ, k=num_avail))

    # Channel utilisation range (percent) per environment/density.
    util_map = {("Urban", "Low"): (20, 60), ("Urban", "Medium"): (40, 80), ("Urban", "High"): (60, 98),
                ("Rural", "Low"): (0, 20), ("Rural", "Medium"): (10, 40), ("Rural", "High"): (20, 55)}
    util_pct = random.uniform(*util_map[(environment, density)])

    # Per-channel interference level: floor + 2 dB + a non-negative congestion
    # bump whose mean grows with utilisation.
    spectral_scan = {}
    for ch in available_channels:
        congestion_bump = np.random.normal(loc=util_pct / 100 * 8.0, scale=1.5)
        spectral_scan[ch] = float(noise_floor_dbm + 2.0 + max(0.0, congestion_bump))

    # Everything except free-space path loss is independent of the channel,
    # so compute it once instead of once per candidate channel.
    loss = (weather_extra_loss_db(weather, distance_m)
            + density_obstruction_factor(environment, density) * distance_km
            + fresnel_penalty_db(fresnel_clear)
            + obstruction_penalty_db(obstructed, obst_type))
    radiated_dbm = tx_power_dbm + TX_ANT_GAIN_DB + RX_ANT_GAIN_DB

    # Pick the channel with the highest score; ties keep the lowest frequency
    # because channels are visited in ascending order and `>` is strict.
    best = None
    for ch in available_channels:
        rssi = radiated_dbm - fspl_db(distance_m, ch) - loss
        interference_dbm = spectral_scan[ch]
        snr = rssi - interference_dbm
        score = snr - 0.25 * (interference_dbm - noise_floor_dbm)
        if (best is None) or (score > best[3]):
            best = (ch, rssi, snr, score)

    ch_best, rssi_best, snr_best, _ = best
    # SNR is clamped to [-10, 60] dB before being mapped to an efficiency.
    eff = snr_to_efficiency_bps_per_hz(max(-10.0, min(60.0, snr_best)))
    # bps/Hz multiplied by a bandwidth in MHz is already in Mbps.
    expected_throughput_mbps = eff * channel_bw_mhz

    return {
        "scenario_id": scenario_id,
        "device_a_coordinates": json.dumps([lat_a, lon_a, el_a]),
        "device_b_coordinates": json.dumps([lat_b, lon_b, el_b]),
        "link_distance_m": float(distance_m),
        "noise_dbm": float(noise_dbm),
        "noise_floor_dbm": float(noise_floor_dbm),
        "rssi_dbm": float(rssi_best),
        "snr_db": float(snr_best),
        "tx_power_dbm": float(tx_power_dbm),
        "channel_bandwidth_mhz": int(channel_bw_mhz),
        "channel_utilization_pct": float(util_pct),
        "available_channels_mhz": json.dumps(available_channels),
        "spectral_scan_dbm": json.dumps(spectral_scan),
        "fresnel_clear_pct": float(fresnel_clear),
        "weather_temp_c": float(temp_c),
        "weather_humidity_pct": float(humidity),
        "weather_condition": weather,
        "image_obstruction_detected": bool(obstructed),
        "image_obstruction_type": obst_type,
        "environment_type": environment,
        "area_density": density,
        "recommended_channel_mhz": int(ch_best),
        "expected_throughput_mbps": float(expected_throughput_mbps),
    }

In [None]:
# Generate dataset and save
# One row per scenario id in 0..N_SAMPLES-1.
rows = [generate_synthetic_row(i) for i in range(N_SAMPLES)]
df = pd.DataFrame(rows)
# Write the CSV without the DataFrame index column.
df.to_csv(CSV_PATH, index=False)
# Also write a deflate-compressed ZIP holding just the CSV, stored under its base name.
with zipfile.ZipFile(ZIP_PATH, 'w', compression=zipfile.ZIP_DEFLATED) as z:
    z.write(CSV_PATH, arcname=os.path.basename(CSV_PATH))

def sha256_of(path, chunk=1024*1024):
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in binary mode, ``chunk`` bytes at a time, so large
    files never have to fit in memory.
    """
    digest = hashlib.sha256()
    with open(path, 'rb') as stream:
        while True:
            piece = stream.read(chunk)
            if piece == b'':
                break
            digest.update(piece)
    return digest.hexdigest()

# Report where the CSV went and print checksums of both artifacts.
print(f"Saved {len(df):,} rows -> {CSV_PATH}")
print('CSV SHA256:', sha256_of(CSV_PATH))
print('ZIP SHA256:', sha256_of(ZIP_PATH))

# Show head
# NOTE(review): `display` is imported but not used in this cell; the first two
# rows are printed as a list of plain dicts instead.
from IPython.display import display
print(df.head(2).to_dict(orient='records'))

## 4) Convert CSV -> conversational dataset (train/test)

We adapt the original notebook's dataset loading to read the CSV we just generated and convert each row to a message pair `{role: user, content: prompt}, {role: assistant, content: target}`. The default target is `recommended_channel_mhz`.

In [None]:
from datasets import Dataset

# Column of `df` whose value becomes the assistant's answer in each training sample.
TARGET_FIELD = 'recommended_channel_mhz'  # change to expected_throughput_mbps if desired

# Build prompt and target

def build_prompt_from_row(r):
    """Render one dataset row as the user prompt for the chat sample.

    ``r`` is any mapping (e.g. a DataFrame row) that provides the scenario
    columns referenced below. The prompt lists a subset of the scenario
    fields, one per line, followed by a blank line and the question.
    """
    # NOTE(review): the per-channel scan and available-channel list are not
    # part of the prompt, although they drive the recommended channel.
    lines = [
        f"Scenario id: {r['scenario_id']}",
        f"Distance (m): {r['link_distance_m']:.1f}",
        f"Env: {r['environment_type']} (density={r['area_density']})",
        f"Fresnel clear %: {r['fresnel_clear_pct']:.1f}",
        f"Weather: {r['weather_condition']}, temp C: {r['weather_temp_c']:.1f}",
        f"Noise floor (dBm): {r['noise_floor_dbm']:.1f}, RSSI (dBm): {r['rssi_dbm']:.1f}",
        f"Channel BW (MHz): {int(r['channel_bandwidth_mhz'])}",
        "",
        "Question: Based on the scenario above, provide the best channel in MHz as a single integer (no explanation).",
    ]
    return "\n".join(lines)

def build_target_from_row(r):
    """Render the training target for row ``r`` as a string.

    The value is read from the column named by ``TARGET_FIELD``: the
    recommended channel is formatted as an integer, any other target as a
    float with two decimals.
    """
    if TARGET_FIELD == 'recommended_channel_mhz':
        return str(int(r[TARGET_FIELD]))
    return f"{float(r[TARGET_FIELD]):.2f}"

# Convert to messages: each CSV row becomes a two-turn chat sample
# (user prompt -> assistant target).
samples = []
for _, row in df.iterrows():
    prompt = build_prompt_from_row(row)
    target = build_target_from_row(row)
    messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": target}]
    samples.append({"messages": messages})

dataset = Dataset.from_pandas(pd.DataFrame(samples))
# train/test split
# Hold out 10% for evaluation when there are more than 20 samples. The split is
# seeded with RSEED so the same rows land in train/test on every run; without a
# seed the shuffle differs between runs even though the generator is seeded.
if len(dataset) > 20:
    dataset = dataset.train_test_split(test_size=0.1, seed=RSEED)
else:
    # Too few samples to hold any out: expose only a 'train' split.
    dataset = {'train': dataset}

print('Dataset prepared. Train size:', len(dataset['train']))
if 'test' in dataset:
    print('Test size:', len(dataset['test']))

## 5) (Optional) Mount Drive / Hugging Face login

If you're running in Colab, mount your Google Drive to save checkpoints and/or store your Hugging Face token in Colab userdata under the key `HF_TOKEN`. Otherwise, set the `HF_TOKEN` environment variable before logging in.

In [None]:
# Colab-specific helpers (uncomment when using Colab)
# The commented lines mount Google Drive, read the token stored under the
# 'HF_TOKEN' userdata key, and log in to the Hugging Face Hub with it.
# from google.colab import drive, userdata
# drive.mount('/content/drive')
# hf_token = userdata.get('HF_TOKEN')
# from huggingface_hub import login
# login(hf_token)

print('If running in Colab, mount drive and login to Hugging Face as needed.')

## 6) Load model & tokenizer (Gemma base)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Base model (change as needed)
base_model = 'google/gemma-3-270m-it'
# Directory under OUT_DIR that the SFT configuration uses as its output_dir.
checkpoint_dir = os.path.join(OUT_DIR, 'gemma_finetune_checkpoint')

print('Loading model (this may require accepting the model license on HF)')
# dtype and device placement are left to the library ('auto'); the SFT cell
# reads model.dtype afterwards to decide between fp16 and bf16.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype='auto',
    device_map='auto',
    attn_implementation='eager',
)
# The tokenizer is loaded from the same model id as the weights.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Report what the 'auto' settings resolved to.
print('Model device:', model.device)
print('Model dtype:', model.dtype)

## 7) Configure SFT and trainer (TRL)

In [None]:
from trl import SFTConfig, SFTTrainer

# Mixed-precision flags below are derived from the dtype the model was loaded in.
torch_dtype = model.dtype

# Small datasets are not split, so there may be no 'test' split. Per-epoch
# evaluation must then be switched off: the Trainer rejects an eval strategy
# other than 'no' when no eval_dataset is supplied.
has_eval_split = 'test' in dataset

args = SFTConfig(
    output_dir=checkpoint_dir,          # checkpoints and the final model go here
    max_length=512,
    packing=False,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_checkpointing=False,
    optim='adamw_torch_fused',
    logging_steps=1,
    save_strategy='epoch',
    eval_strategy='epoch' if has_eval_split else 'no',
    learning_rate=5e-5,
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,
    lr_scheduler_type='constant',
    push_to_hub=False,
    report_to='tensorboard',
    dataset_kwargs={
        'add_special_tokens': False,
        'append_concat_token': True,
    }
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'] if has_eval_split else None,
    processing_class=tokenizer,
)

print('Trainer created')

## 8) Train the model

In [None]:
# Start training (warning: this will use GPU and can be long)
trainer.train()
# No path is passed, so the model is saved to the trainer's configured output
# directory (checkpoint_dir in this notebook).
trainer.save_model()
print('Training finished; model saved to', checkpoint_dir)

## 9) Test model inference on test samples

In [None]:
from transformers import pipeline

# Wrap the fine-tuned model and its tokenizer in a text-generation pipeline.
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)

if 'test' not in dataset:
    print('No test split available to run inference on.')
else:
    test_split = dataset['test']
    # Compare generated and reference answers on at most five held-out samples.
    for i in range(min(5, len(test_split))):
        sample = test_split[i]
        user_turn = sample['messages'][:1]
        # Format only the user turn and append the generation prompt so the
        # model produces the assistant reply.
        prompt = pipe.tokenizer.apply_chat_template(user_turn, tokenize=False, add_generation_prompt=True)
        outputs = pipe(prompt, max_new_tokens=64, disable_compile=True)
        # The generated text is sliced past the prompt so only the new text is shown.
        generated = outputs[0]['generated_text'][len(prompt):].strip()
        print('Question:')
        print(sample['messages'][0]['content'])
        print('Original Answer:')
        print(sample['messages'][1]['content'])
        print('Generated Answer:')
        print(generated)
        print('-'*60)

## 10) Save artifacts and simple loader example

Demonstrates how to reload the generated CSV (saved earlier alongside its zipped archive) for offline use, and shows, commented out, how you could load the fine-tuned model back. The trainer already saved model checkpoints to `checkpoint_dir`.

In [None]:
# Example: load CSV back
reloaded = pd.read_csv(CSV_PATH)
print('Reloaded rows:', len(reloaded))
# Show the first ten column names as a quick schema check.
print(reloaded.columns[:10])

# Example: load a saved model (path: checkpoint_dir)
# Left commented out so that running this cell does not replace the in-memory model.
# model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print('Notebook completed. Increase N_SAMPLES and epochs for full training runs.')