In [1]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt
import seaborn as sns
from joblib import load

In [2]:
# Load the trained network weights and layer dimensions from an .npz archive.
def load_model(file_path):
    """Load weights, biases and layer sizes saved with ``np.savez``.

    Parameters
    ----------
    file_path : str or path-like
        Path to the ``.npz`` archive. It must contain the entries ``W1``, ``b1``,
        ``W2``, ``b2``, ``W3``, ``b3``, ``input_dim``, ``hidden1_dim``,
        ``hidden2_dim`` and ``output_dim``.

    Returns
    -------
    tuple
        ``(W1, b1, W2, b2, W3, b3, input_dim, hidden1_dim, hidden2_dim, output_dim)``
        exactly as stored in the archive (the dimension entries are returned as
        NumPy arrays; they are not converted to ``int``).

    Raises
    ------
    KeyError
        If one of the expected entries is missing from the archive.
    """
    keys = ('W1', 'b1', 'W2', 'b2', 'W3', 'b3',
            'input_dim', 'hidden1_dim', 'hidden2_dim', 'output_dim')
    # For .npz files np.load returns an NpzFile holding an open file handle; the
    # context manager guarantees it is closed. Indexing reads each member fully
    # into memory, so the returned arrays remain valid after the file is closed.
    with np.load(file_path) as data:
        return tuple(data[key] for key in keys)

# Forward pass of the 3-layer MLP: tanh -> tanh -> softmax.
def forward_pass(X, W1, b1, W2, b2, W3, b3):
    """Run the network and return the softmax class probabilities.

    Parameters
    ----------
    X : ndarray
        Input batch; the softmax normalises along ``axis=1``, so this is expected
        to be 2-D ``(n_samples, n_features)`` (already scaled).
    W1, b1, W2, b2, W3, b3 : ndarray
        Weights and biases of the two tanh hidden layers and the output layer.

    Returns
    -------
    ndarray
        Row-wise softmax of the output-layer logits; each row sums to 1.
    """
    Z1 = np.dot(X, W1) + b1
    A1 = np.tanh(Z1)
    Z2 = np.dot(A1, W2) + b2
    A2 = np.tanh(Z2)
    Z3 = np.dot(A2, W3) + b3
    # Numerically stable softmax: subtracting each row's max does not change the
    # result, but prevents np.exp from overflowing to inf (inf/inf -> NaN).
    exp_Z3 = np.exp(Z3 - np.max(Z3, axis=1, keepdims=True))
    A3 = exp_Z3 / np.sum(exp_Z3, axis=1, keepdims=True)
    return A3  # Only the softmax output is needed for prediction


In [3]:
# Predict class labels for raw (unscaled) feature rows.
def predict(X, W1, b1, W2, b2, W3, b3, classes, scaler):
    """Predict one class label per input sample.

    Parameters
    ----------
    X : array-like
        Raw, un-normalised features: a 2-D ``(n_samples, n_features)`` array
        (or DataFrame), or a single 1-D sample, which is treated as one row.
    W1, b1, W2, b2, W3, b3 : ndarray
        Network parameters, passed unchanged to ``forward_pass``.
    classes : sequence
        Class labels indexed by output-neuron position (list or ndarray). Must be
        in the same order as the label encoding used during training.
    scaler : object
        The fitted scaler from training; only its ``transform`` method is used.

    Returns
    -------
    ndarray
        The predicted label for each row of ``X``.
    """
    # scikit-learn transformers reject 1-D input, so promote a single sample to one
    # row. 2-D inputs (including DataFrames) are passed through untouched so a
    # scaler fitted with feature names still receives them.
    if np.ndim(X) == 1:
        X = np.reshape(X, (1, -1))
    X_normalized = scaler.transform(X)  # Normalise with the training scaler
    A3 = forward_pass(X_normalized, W1, b1, W2, b2, W3, b3)
    predictions = np.argmax(A3, axis=1)
    # np.asarray lets plain Python lists be fancy-indexed by the index array too.
    return np.asarray(classes)[predictions]

In [12]:

# Load the trained weights and layer sizes.
model_file = 'model/stunted_model.npz'
W1, b1, W2, b2, W3, b3, input_dim, hidden1_dim, hidden2_dim, output_dim = load_model(model_file)

# Class labels, indexed by output-neuron position; must match the training order.
# NOTE(review): the original marks this list as a placeholder ("replace with your
# actual classes"). Confirm it against the label encoder used in training, e.g.
# sklearn's LabelEncoder orders classes alphabetically, which this list is not.
status_gizi_classes = np.array(['Normal', 'Stunted', 'Wasted', 'Overweight'])
# Fail fast if the label list and the network's output layer disagree, instead of
# raising IndexError later or silently mislabelling predictions.
assert len(status_gizi_classes) == W3.shape[1], (
    f"{len(status_gizi_classes)} class labels but the model has {W3.shape[1]} outputs"
)

# Load the scaler saved during training; it must be the same fitted scaler,
# otherwise inputs are normalised differently from what the network learned on.
scaler = load('model/scaler.joblib')

# Example inputs: [age (months), sex (0/1), height (cm)]
# NOTE(review): 45 cm / 40 cm at 12-14 months looks implausibly short, yet both
# rows are predicted "Normal" — worth cross-checking class order and feature order.
new_data = np.array([[14, 1, 45],
                     [12, 1, 40]])

# Run the prediction and report one label per input row.
predictions = predict(new_data, W1, b1, W2, b2, W3, b3, status_gizi_classes, scaler)
print("Prediksi status gizi untuk data baru:")
for i, pred in enumerate(predictions):
    print(f"Data {i+1}: {pred}")

Prediksi status gizi untuk data baru:
Data 1: Normal
Data 2: Normal
