<a href="https://colab.research.google.com/github/rcbusinesstechlab/realtime-face-recognition/blob/main/kan_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.spatial.distance import directed_hausdorff
from torch.utils.data import DataLoader, TensorDataset


In [2]:
# ─── Spline Basis ─────────────────────────────────────────────
class SplineBasis(nn.Module):
    def __init__(self, num_basis=16, domain=(-1.0, 1.0)):
        super().__init__()
        self.num_basis = num_basis
        self.domain = domain
        self.register_buffer('knots', torch.linspace(domain[0], domain[1], num_basis))

    def forward(self, x):
        x = x.unsqueeze(-1)
        distances = torch.abs(x - self.knots)
        basis = torch.clamp(1 - distances * self.num_basis, min=0)
        return basis


In [9]:
# ─── KAN Layer 1D ─────────────────────────────────────────────
class KANLayer1D(nn.Module):
    def __init__(self, in_channels, out_channels, num_basis=16):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_basis = num_basis
        self.spline = SplineBasis(num_basis)
        self.coeffs = nn.Parameter(torch.randn(out_channels, in_channels, num_basis))

    def forward(self, x):
        # x shape: (batch_size, in_channels)
        basis_vals = self.spline(x)  # shape: (batch_size, in_channels, num_basis)
        out = torch.einsum('bic,oic->bo', basis_vals, self.coeffs)
        return out


In [10]:
# ─── KAN-Enhanced SE Block ───────────────────────────────────
class KANSEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16, num_basis=16):
        super().__init__()
        mid_channels = max(1, in_channels // reduction)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.kan1 = KANLayer1D(in_channels, mid_channels, num_basis)
        self.kan2 = KANLayer1D(mid_channels, in_channels, num_basis)
        self.activation = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        B, C, H, W = x.shape
        z = self.global_pool(x).view(B, C)
        h = self.activation(self.kan1(z))
        s = self.sigmoid(self.kan2(h)).view(B, C, 1, 1)
        return x * s


In [11]:
# ─── Evaluation Metrics ──────────────────────────────────────
def dice_coefficient(pred, target, epsilon=1e-6):
    pred = (pred > 0.5).float()
    target = (target > 0.5).float()
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2. * intersection + epsilon) / (union + epsilon)
    return dice.mean().item()

def hd95(pred, target):
    pred_np = pred.squeeze().detach().cpu().numpy()
    target_np = target.squeeze().detach().cpu().numpy()
    pred_coords = np.argwhere(pred_np > 0)
    target_coords = np.argwhere(target_np > 0)
    if len(pred_coords) == 0 or len(target_coords) == 0:
        return np.inf
    hd_forward = directed_hausdorff(pred_coords, target_coords)[0]
    hd_backward = directed_hausdorff(target_coords, pred_coords)[0]
    return np.percentile([hd_forward, hd_backward], 95)


In [12]:
# ─── Training Loop ────────────────────────────────────────────
def train_model(model, dataloader, optimizer, criterion, num_epochs=5, device='cpu'):
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")
    return model


In [13]:
# ─── Example CNN Using KANSEBlock ─────────────────────────────
class SampleKANUNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            KANSEBlock(out_channels)
        )

    def forward(self, x):
        return self.conv(x)


In [16]:
# ─── Synthetic Test ───────────────────────────────────────────
if __name__ == "__main__":
    # Synthetic data (B, C, H, W)
    images = torch.rand(10, 1, 64, 64)
    masks = (images > 0.5).float()

    dataset = TensorDataset(images, masks)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Model setup
    model = SampleKANUNetBlock(in_channels=1, out_channels=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    # Train
    trained_model = train_model(model, dataloader, optimizer, criterion, num_epochs=100)

    # Evaluate
    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            outputs = torch.sigmoid(model(images))
            dice = dice_coefficient(outputs, masks)
            hd = hd95(outputs, masks)
            print(f"Dice: {dice:.4f} | HD95: {hd:.2f}")
            break  # evaluate one batch


Epoch [1/100], Loss: 3.5367
Epoch [2/100], Loss: 3.5304
Epoch [3/100], Loss: 3.5248
Epoch [4/100], Loss: 3.5181
Epoch [5/100], Loss: 3.5111
Epoch [6/100], Loss: 3.5038
Epoch [7/100], Loss: 3.4964
Epoch [8/100], Loss: 3.4883
Epoch [9/100], Loss: 3.4800
Epoch [10/100], Loss: 3.4712
Epoch [11/100], Loss: 3.4617
Epoch [12/100], Loss: 3.4514
Epoch [13/100], Loss: 3.4402
Epoch [14/100], Loss: 3.4281
Epoch [15/100], Loss: 3.4151
Epoch [16/100], Loss: 3.4008
Epoch [17/100], Loss: 3.3857
Epoch [18/100], Loss: 3.3687
Epoch [19/100], Loss: 3.3510
Epoch [20/100], Loss: 3.3311
Epoch [21/100], Loss: 3.3095
Epoch [22/100], Loss: 3.2874
Epoch [23/100], Loss: 3.2628
Epoch [24/100], Loss: 3.2369
Epoch [25/100], Loss: 3.2154
Epoch [26/100], Loss: 3.1900
Epoch [27/100], Loss: 3.1665
Epoch [28/100], Loss: 3.1440
Epoch [29/100], Loss: 3.1218
Epoch [30/100], Loss: 3.1007
Epoch [31/100], Loss: 3.0804
Epoch [32/100], Loss: 3.0620
Epoch [33/100], Loss: 3.0451
Epoch [34/100], Loss: 3.0301
Epoch [35/100], Loss: 3