In [8]:
import os, math, time, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from scipy.stats import ttest_rel, t
from typing import Tuple, Dict, List

# ===== MNIST için =====
try:
    from torchvision import datasets, transforms
except ImportError as e:
    raise SystemExit("torchvision gerekli. Lütfen: pip install torchvision") from e


# =========================================
# Yardımcılar
# =========================================
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy_from_logits(logits, y_true):
    preds = logits.argmax(dim=1)
    return (preds == y_true).float().mean().item()

def paired_stats(a: List[float], b: List[float]) -> Dict[str, float]:
    """a: ANN skorları, b: ONN skorları; her ikisi de aynı koşuların sonuçları"""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    d = b - a
    n = len(d)
    mean_a = a.mean(); std_a = a.std(ddof=1)
    mean_b = b.mean(); std_b = b.std(ddof=1)
    mean_diff = d.mean()
    sd_diff = d.std(ddof=1) if n > 1 else 0.0

    # paired t-test
    if n > 1 and sd_diff > 0:
        t_stat, p_val = ttest_rel(b, a)
        se = sd_diff / math.sqrt(n)
        tcrit = t.ppf(0.975, df=n-1)
        ci_low = mean_diff - tcrit * se
        ci_high = mean_diff + tcrit * se
        # Cohen's dz (eşleşik etki büyüklüğü)
        cohend = mean_diff / sd_diff
    else:
        t_stat, p_val, cohend, ci_low, ci_high = float("nan"), float("nan"), float("nan"), float("nan"), float("nan")

    return {
        "mean_a": mean_a, "std_a": std_a,
        "mean_b": mean_b, "std_b": std_b,
        "t": t_stat, "p": p_val, "dz": cohend,
        "ci_low": ci_low, "ci_high": ci_high
    }


# =========================================
# Veri setleri: Iris, Breast Cancer, MNIST
# =========================================
def load_iris_torch():
    X, y = load_iris(return_X_y=True)
    X, y = shuffle(X, y, random_state=42)
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.long)
    # train/val/test = 60/20/20
    n = len(X)
    i1 = int(0.6 * n); i2 = int(0.8 * n)
    return (TensorDataset(X[:i1], y[:i1]),
            TensorDataset(X[i1:i2], y[i1:i2]),
            TensorDataset(X[i2:], y[i2:])), 4, 3

def load_breast_cancer_torch():
    X, y = load_breast_cancer(return_X_y=True)
    X, y = shuffle(X, y, random_state=42)
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.long)
    n = len(X)
    i1 = int(0.6 * n); i2 = int(0.8 * n)
    return (TensorDataset(X[:i1], y[:i1]),
            TensorDataset(X[i1:i2], y[i1:i2]),
            TensorDataset(X[i2:], y[i2:])), X.shape[1], 2

def load_mnist_torch(download_root: str = "./data", limit_train:int=None, limit_test:int=None, val_ratio:float=0.1):
    """
    MNIST: 60k train, 10k test.
    limit_train/test ile hızlı denemeler için örnek sayısını kısıtlayabilirsin (None = tümü).
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])  # 28x28 -> 784 flatten
    train_ds_full = datasets.MNIST(download_root, train=True, transform=transform, download=True)
    test_ds = datasets.MNIST(download_root, train=False, transform=transform, download=True)

    if limit_train is not None:
        train_ds_full = torch.utils.data.Subset(train_ds_full, list(range(min(limit_train, len(train_ds_full)))))
    if limit_test is not None:
        test_ds = torch.utils.data.Subset(test_ds, list(range(min(limit_test, len(test_ds)))))

    # Train/Val böl
    n_train = len(train_ds_full)
    n_val = int(val_ratio * n_train)
    n_tr = n_train - n_val
    train_ds, val_ds = random_split(train_ds_full, [n_tr, n_val], generator=torch.Generator().manual_seed(42))

    # Paket boyutu: tensör dataset'e dönüştürelim (Subset zaten tensörleri içeriyor)
    # Giriş boyutu 784, çıkış 10
    input_size = 28*28
    output_size = 10
    return (train_ds, val_ds, test_ds), input_size, output_size


# =========================================
# Modeller
# =========================================
class ClassicANN(nn.Module):
    def __init__(self, input_size, hidden=[128, 64], output_size=10, pdrop=0.1):
        super().__init__()
        layers = []
        last = input_size
        for h in hidden:
            layers += [nn.Linear(last, h), nn.ReLU(), nn.LayerNorm(h), nn.Dropout(pdrop)]
            last = h
        layers += [nn.Linear(last, output_size)]
        self.net = nn.Sequential(*layers)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)

class ONN_Bio(nn.Module):
    """
    Basit MLP + biyolojik modüller:
    - STDP-benzeri local plasticity (etkinlik korelasyonu ile weight gate güncelleme)
    - Astrocyte modulation (aktivite düzeyine göre kazanç)
    - Homeostatic scaling (katman aktivitesi hedef aralığa çekilir)
    - Metabolic efficiency (L2 benzeri enerji reg.)
    - Hafif osilasyon (sin modülasyonu)
    Biyoloji kapatılabilir: bio_on=False => klasik MLP davranışı (aynı mimari)
    """
    def __init__(self, input_size, hidden=[128, 64], output_size=10,
                 pdrop=0.1,
                 bio_on=True,
                 plasticity=0.3,
                 astrocyte_modulation=0.5,
                 glia_neuron_interaction=0.5,
                 oscillatory_pattern=0.3,
                 homeostatic_scaling=1.0,
                 metabolic_efficiency=0.0):
        super().__init__()
        self.bio_on = bio_on
        self.plasticity = plasticity
        self.astro = astrocyte_modulation
        self.glia = glia_neuron_interaction
        self.osc = oscillatory_pattern
        self.homeo = homeostatic_scaling
        self.metabolic = metabolic_efficiency

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.drops = nn.ModuleList()

        sizes = [input_size] + hidden + [output_size]
        for i in range(len(sizes) - 1):
            self.layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if i < len(sizes)-1:  # hidden katmanlar için
                if i < len(hidden):  # son linear (output) hariç
                    self.norms.append(nn.LayerNorm(sizes[i+1]))
                    self.drops.append(nn.Dropout(pdrop))

        self.apply(self._init)
        self.t = 0  # osilasyon zamanı
        # Kapılayıcılar (gate) – ağırlık üzerine çarpımsal etki
        self.register_buffer("weight_gate_0", torch.ones(1))
        self.weight_gates = nn.ParameterList()
        for i in range(len(sizes)-1):
            # her katman için ayrı gate vektörü (out_features boyutunda)
            self.weight_gates.append(nn.Parameter(torch.ones(sizes[i+1]), requires_grad=False))

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    def _bio_modulate(self, x, h_idx):
        """
        Astrocyte + osilasyon + homeostatic gain modülasyonu.
        h_idx: 0..len(hidden)-1
        """
        if not self.bio_on or h_idx is None:
            return x
        # Astrocyte: aktivite tabanlı kazanç
        gain = 1.0 + self.astro * torch.tanh(x.mean(dim=0, keepdim=True))  # (1, features)
        # Osilasyon: zaman bazlı küçük sin modülasyonu
        if self.osc != 0:
            osc = 1.0 + 0.05 * math.sin(self.t * self.osc)
            gain = gain * osc
        # Homeostatic scaling: ortalama aktiviteyi hedefe çek
        if self.homeo != 0:
            target = 0.0
            mean_act = x.mean().item()
            gain = gain * (1.0 + 0.01 * self.homeo * (target - mean_act))
        return x * gain

    def _plasticity_update(self, pre, post, layer_idx):
        """
        Basit STDP-benzeri: pre-post korelasyonuna göre kapı (gate) güncelle.
        Gate değerleri ağırlığa çarpımsal etki yapar (effective W = W * gate).
        """
        if (not self.bio_on) or self.plasticity == 0:
            return
        with torch.no_grad():
            # pre:(B, in), post:(B, out) -> out-düzeyinde gate
            corr = torch.einsum("bi,bo->o", pre, post) / (pre.size(0) + 1e-6)  # (out,)
            corr = torch.tanh(corr)  # sınırlı etki
            g = self.weight_gates[layer_idx]
            g.data = torch.clamp(g.data + self.plasticity * 0.01 * corr, 0.5, 1.5)

    def forward(self, x):
        self.t += 1
        h = x
        pre = None
        norm_i = 0
        for i, layer in enumerate(self.layers):
            W = layer.weight
            b = layer.bias
            # Kapı uygulama (glia etkileşimi ile güçlendirilebilir)
            gate = self.weight_gates[i]
            eff_W = W * gate.unsqueeze(1) * (1.0 + 0.1 * self.glia if self.bio_on else 1.0)
            h = F.linear(h, eff_W, b)

            is_hidden = i < (len(self.layers)-1)
            if is_hidden:
                h = F.gelu(h)
                if self.bio_on:
                    # astro & homeo & osc
                    h = self._bio_modulate(h, h_idx=i)
                # LayerNorm + Dropout
                h = self.norms[norm_i](h)
                h = self.drops[norm_i](h)
                # plasticity update (pre, post)
                if pre is not None:
                    self._plasticity_update(pre, h, layer_idx=i)
                pre = h.detach()
                norm_i += 1

        # Metabolic regularizer: forward’da ceza değeri döndürmek için hook yerine dışarıdan erişeceğiz
        return h

    def metabolic_penalty(self):
        if not self.bio_on or self.metabolic == 0:
            return 0.0
        penalty = 0.0
        for p in self.parameters():
            if p.requires_grad and p.dim() >= 2:
                penalty = penalty + p.pow(2).mean()
        return self.metabolic * 1e-4 * penalty


# =========================================
# Eğitim & Değerlendirme
# =========================================
def train_one(model, train_loader, val_loader, device, epochs=10, lr=1e-3, weight_decay=0.0, patience=5):
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val = -1.0
    best_state = None
    wait = 0

    for ep in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device).long()
            opt.zero_grad()
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            # ONN metabolik ceza
            if hasattr(model, "metabolic_penalty"):
                pen = model.metabolic_penalty()
                if isinstance(pen, torch.Tensor):
                    loss = loss + pen
                else:
                    loss = loss + torch.tensor(pen, device=xb.device)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        # val
        model.eval()
        val_accs = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device).long()
                logits = model(xb)
                val_accs.append(accuracy_from_logits(logits, yb))
        cur_val = float(np.mean(val_accs))
        if cur_val > best_val:
            best_val = cur_val
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def test_acc(model, test_loader, device):
    model.to(device)
    model.eval()
    accs = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.to(device)
            yb = yb.to(device).long()
            logits = model(xb)
            accs.append(accuracy_from_logits(logits, yb))
    return float(np.mean(accs))


def run_dataset(name:str,
                loader_fn,
                input_size:int=None, output_size:int=None,
                hidden=[128,64],
                runs:int=5,
                epochs_small:int=50,
                epochs_mnist:int=5,
                batch_small:int=32,
                batch_mnist:int=128):
    """
    Iris / Breast Cancer gibi küçük setler -> epochs_small
    MNIST -> epochs_mnist (hız için)
    """
    device = get_device()
    ann_scores = []
    onn_scores = []

    # Veri
    (train_ds, val_ds, test_ds), in_sz, out_sz = loader_fn()

    # Loader seçimi
    if name.lower().startswith("mnist"):
        train_loader = DataLoader(train_ds, batch_size=batch_mnist, shuffle=True, drop_last=False)
        val_loader   = DataLoader(val_ds,   batch_size=batch_mnist, shuffle=False, drop_last=False)
        test_loader  = DataLoader(test_ds,  batch_size=batch_mnist, shuffle=False, drop_last=False)
        EPOCHS = epochs_mnist
    else:
        train_loader = DataLoader(train_ds, batch_size=batch_small, shuffle=True, drop_last=False)
        val_loader   = DataLoader(val_ds,   batch_size=batch_small, shuffle=False, drop_last=False)
        test_loader  = DataLoader(test_ds,  batch_size=batch_small, shuffle=False, drop_last=False)
        EPOCHS = epochs_small

    for r in range(runs):
        set_seed(42 + r)

        # ANN
        ann = ClassicANN(in_sz, hidden=hidden, output_size=out_sz, pdrop=0.1)
        ann = train_one(ann, train_loader, val_loader, device, epochs=EPOCHS, lr=1e-3, weight_decay=1e-4, patience=8)
        acc_ann = test_acc(ann, test_loader, device)
        ann_scores.append(100.0 * acc_ann)

        # ONN (bio açık)
        onn = ONN_Bio(in_sz, hidden=hidden, output_size=out_sz,
                      pdrop=0.1, bio_on=True,
                      plasticity=0.3,                # stable varsayılanlar
                      astrocyte_modulation=0.3,
                      glia_neuron_interaction=0.3,
                      oscillatory_pattern=0.1,
                      homeostatic_scaling=0.5,
                      metabolic_efficiency=0.0)
        onn = train_one(onn, train_loader, val_loader, device, epochs=EPOCHS, lr=1e-3, weight_decay=1e-4, patience=8)
        acc_onn = test_acc(onn, test_loader, device)
        onn_scores.append(100.0 * acc_onn)

    # İstatistik
    stats = paired_stats(ann_scores, onn_scores)
    print(f"\n=== {name} ===")
    print(f"Classic ANN: {np.mean(ann_scores):.2f}% ± {np.std(ann_scores, ddof=1):.2f}")
    print(f"ONN_Bio:    {np.mean(onn_scores):.2f}% ± {np.std(onn_scores, ddof=1):.2f}")
    print(f"P-value: {stats['p']:.4f}, Cohen's d(z): {stats['dz']:.3f}, 95% CI (ONN-ANN): [{stats['ci_low']:.2f}, {stats['ci_high']:.2f}]")
    return ann_scores, onn_scores, stats


# =========================================
# Çalıştır
# =========================================
if __name__ == "__main__":
    # MNIST indirme yolu
    os.makedirs("./data", exist_ok=True)

    # Veri yükleyiciler
    def iris_loader():
        return load_iris_torch()  # (train,val,test), in, out

    def bc_loader():
        return load_breast_cancer_torch()

    def mnist_loader():
        # Hızlı deneme için limitleri None bırak (tam 60k/10k). Hız istersen limit_train=20000, limit_test=5000 yap.
        return load_mnist_torch(download_root="./data", limit_train=None, limit_test=None, val_ratio=0.1)

    # Koşular
    iris_ann, iris_onn, iris_stats = run_dataset("Iris", iris_loader, hidden=[64, 32], runs=5, epochs_small=60)
    bc_ann, bc_onn, bc_stats = run_dataset("Breast Cancer", bc_loader, hidden=[128, 64], runs=5, epochs_small=60)
    mnist_ann, mnist_onn, mnist_stats = run_dataset("MNIST", mnist_loader, hidden=[256, 128], runs=3, epochs_mnist=5)



=== Iris ===
Classic ANN: 86.00% ± 4.35
ONN_Bio:    88.67% ± 5.58
P-value: 0.3739, Cohen's d(z): 0.447, 95% CI (ONN-ANN): [-4.74, 10.07]

=== Breast Cancer ===
Classic ANN: 95.97% ± 1.36
ONN_Bio:    96.28% ± 0.81
P-value: 0.6440, Cohen's d(z): 0.223, 95% CI (ONN-ANN): [-1.43, 2.05]
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:05<00:00, 1903852.37it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 145603.49it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1587903.87it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 821568.43it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw


=== MNIST ===
Classic ANN: 97.79% ± 0.07
ONN_Bio:    97.91% ± 0.15
P-value: 0.3619, Cohen's d(z): 0.677, 95% CI (ONN-ANN): [-0.33, 0.59]
