# Zenith Scientific Test 4: Energy Efficiency Auditor

**Objective:**
Measure the total energy consumption (Joules) of running a sustained AI workload. 
Higher performance often comes with higher power draw, so a speedup alone does not prove better efficiency. We want to verify whether Zenith improves **Energy Efficiency** (tokens per Joule, i.e. tokens/s per Watt).

**Methodology:**
1.  Use `pynvml` (NVIDIA Management Library) to query GPU power usage in real-time.
2.  Run a sustained inference workload (e.g., generating 200 tokens) for both PyTorch and Zenith.
3.  **Metric:** Total Energy (Joules) = Sum(Power_Watts * Time_Interval).

**Hardware Requirement:** An NVIDIA GPU is required, because `pynvml` reads the power draw through NVML.

In [None]:
!pip install -q -U pyzenith torch pynvml matplotlib

import torch
import time
import threading
import pynvml
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import zenith

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

In [None]:
# Energy Monitor Class
class EnergyMonitor:
    """Sample GPU power on a background thread and integrate it to Joules.

    Usage: call ``start()`` right before the workload and ``stop()`` right
    after it; ``stop()`` returns the energy consumed in between. If NVML
    cannot be initialised the monitor degrades to a no-op that reports
    ``0.0`` Joules.

    Attributes:
        interval: Target delay between two power samples, in seconds.
        running: True while the sampling thread should keep going.
        readings: Power samples in Watts.
        timestamps: Seconds since ``start()`` at which each sample was taken.
        elapsed: Seconds between ``start()`` and ``stop()``, or None until a
            measurement has been completed.
        available: False when NVML could not be initialised.
    """

    def __init__(self, interval=0.1):
        """Initialise NVML and bind to GPU index 0.

        Args:
            interval: Target sampling period in seconds.
        """
        self.interval = interval
        self.running = False
        self.readings = []
        self.timestamps = []
        self.thread = None
        self.start_time = None
        self.elapsed = None
        # Lets stop() wake the sampler immediately instead of waiting out a
        # full sleep interval.
        self._wake = threading.Event()
        try:
            pynvml.nvmlInit()
            self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self.available = True
        except Exception:
            # Best effort: without NVML the benchmark still runs, it just
            # reports no energy. Unlike a bare ``except:`` this no longer
            # swallows KeyboardInterrupt / SystemExit.
            print("WARNING: NVML not available. Energy monitoring disabled.")
            self.available = False

    def start(self):
        """Clear previous samples and start the background sampling thread."""
        if not self.available:
            return
        self.readings = []
        self.timestamps = []
        self.elapsed = None
        self._wake.clear()
        self.running = True
        # Monotonic clock: immune to wall-clock adjustments during the run.
        self.start_time = time.perf_counter()
        # Daemon thread so a forgotten stop() cannot keep the process alive.
        self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop sampling and return the total energy in Joules.

        Safe to call when ``start()`` was never called or when NVML is
        unavailable; in both cases ``0.0`` is returned.
        """
        self.running = False
        self._wake.set()
        if self.thread is not None:
            self.thread.join()
            self.thread = None
            self.elapsed = time.perf_counter() - self.start_time
        return self.get_total_energy()

    def _monitor_loop(self):
        """Append one (power, timestamp) sample per interval until stopped."""
        while self.running:
            # Get power in milliWatts, convert to Watts
            power = pynvml.nvmlDeviceGetPowerUsage(self.handle) / 1000.0
            self.readings.append(power)
            self.timestamps.append(time.perf_counter() - self.start_time)
            # Interruptible sleep: returns early once stop() sets the event.
            self._wake.wait(self.interval)

    def get_total_energy(self):
        """Integrate the recorded power samples over time.

        Joules = Watts * Seconds. The integration uses the *measured*
        timestamps (trapezoidal rule) rather than assuming every sample is
        exactly ``interval`` apart, because the real spacing also includes
        the NVML query time and scheduling jitter. The first reading is held
        from ``start()`` to the first sample, and the last reading is held
        from the last sample to ``stop()`` when that time is known.

        Returns:
            Energy in Joules, or ``0.0`` when no samples were recorded.
        """
        watts = list(self.readings)
        times = list(self.timestamps)
        if not watts:
            return 0.0
        if len(times) != len(watts):
            # No usable timestamps: fall back to the fixed-interval
            # Riemann-sum estimate.
            return sum(watts) * self.interval

        total_joules = watts[0] * times[0]
        for t0, t1, p0, p1 in zip(times, times[1:], watts, watts[1:]):
            total_joules += 0.5 * (p0 + p1) * (t1 - t0)
        if self.elapsed is not None:
            total_joules += watts[-1] * max(0.0, self.elapsed - times[-1])
        return total_joules

In [None]:
# Workload Setup
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load onto the device selected in the setup cell instead of hard-coding
# "cuda". Half precision is only requested on the GPU; on CPU the model is
# loaded in float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, device_map=device)

# The same prompt is reused for every run so both backends do identical work.
input_text = "Explain the theory of relativity in simple terms."
inputs = tokenizer(input_text, return_tensors="pt").to(device)

In [None]:
def run_energy_test(label, use_zenith=False, max_new_tokens=200, warmup_tokens=10):
    """Run one generation workload while recording GPU energy.

    Args:
        label: Name printed in the progress and result lines.
        use_zenith: When True, the model's forward pass is compiled with the
            "zenith" torch.compile backend for the duration of this call.
        max_new_tokens: Number of tokens to generate in the measured run.
        warmup_tokens: Number of tokens generated before measuring; pass 0
            to skip the warm-up.

    Returns:
        A tuple ``(total_joules, duration_seconds, monitor)`` where
        ``monitor`` is the EnergyMonitor holding the raw power trace.
    """
    monitor = EnergyMonitor(interval=0.1)

    def _wait_for_gpu():
        # CUDA kernels run asynchronously; block until queued work is done so
        # the timer and the power trace cover the whole workload.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    original_forward = model.forward
    if use_zenith:
        print(f"Compiling {label} with Zenith...")
        # Compile `forward` in place rather than wrapping the module:
        # `torch.compile(model).generate` resolves `generate` on the wrapped
        # original module, whose internal calls bypass the compiled wrapper.
        model.forward = torch.compile(original_forward, backend="zenith")

    try:
        # Warm up BOTH variants so one-time costs (lazy CUDA initialisation,
        # compilation) are kept out of the measurement for either side.
        if warmup_tokens > 0:
            model.generate(**inputs, max_new_tokens=warmup_tokens)
            _wait_for_gpu()

        print(f"Starting {label} workload...")
        monitor.start()
        try:
            # Sustained generation
            start_t = time.perf_counter()
            _ = model.generate(**inputs, max_new_tokens=max_new_tokens)
            _wait_for_gpu()
            end_t = time.perf_counter()
        finally:
            # Always stop the sampler thread, even if generation fails.
            total_joules = monitor.stop()
    finally:
        if use_zenith:
            # Leave the shared model uncompiled for later baseline runs.
            model.forward = original_forward

    duration = end_t - start_t
    avg_power = total_joules / duration if duration > 0 else 0.0

    print(f"Done {label}.")
    print(f"  Duration: {duration:.2f}s")
    print(f"  Energy:   {total_joules:.2f} Joules")
    print(f"  Avg Power: {avg_power:.2f} Watts")

    return total_joules, duration, monitor

In [None]:
# Baseline (eager PyTorch) is measured first, then the Zenith run. Each call
# yields (energy in Joules, duration in seconds, monitor with the raw trace).
pytorch_run = run_energy_test(label="PyTorch", use_zenith=False)
joules_py, time_py, mon_py = pytorch_run

zenith_run = run_energy_test(label="Zenith", use_zenith=True)
joules_zen, time_zen, mon_zen = zenith_run

In [None]:
# Plotting Power Curve
plt.figure(figsize=(12, 6))

# One power trace per run, with a light fill underneath so the enclosed area
# (the energy) is visible at a glance.
for trace, joules, run_name, color in (
    (mon_py, joules_py, "PyTorch", "gray"),
    (mon_zen, joules_zen, "Zenith", "green"),
):
    plt.plot(trace.timestamps, trace.readings, label=f"{run_name} ({joules:.1f} J)", color=color)
    plt.fill_between(trace.timestamps, trace.readings, color=color, alpha=0.1)

plt.title("GPU Power Consumption: Zenith vs PyTorch")
plt.xlabel("Time (s)")
plt.ylabel("Power (Watts)")
plt.legend()
plt.grid(True, alpha=0.3)
# Save before show(): some backends clear the figure once it is displayed.
plt.savefig("zenith_energy_chart.png")
plt.show()

# Relative energy saved by Zenith versus the PyTorch baseline. The baseline
# is 0.0 J when energy monitoring was unavailable, so guard the division.
if joules_py > 0:
    savings = ((joules_py - joules_zen) / joules_py) * 100
    print(f"ENERGY SAVINGS: {savings:+.2f}%")
else:
    savings = float("nan")
    print("ENERGY SAVINGS: n/a (no baseline energy was recorded)")