In [1]:
pip install torch torchvision tqdm scipy

Note: you may need to restart the kernel to use updated packages.


In [2]:
### BesselTorch_rbf

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselTorch_rbf(self, distances):
        return torch.special.bessel_j0(self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselTorch_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:04<00:00, 234.44it/s, accuracy=0.406, loss=1.65]


Epoch 1, Train Loss: 1.87021499834081, Train Accuracy: 0.2996568496801706, Val Loss: 1.716408386351956, Val Accuracy: 0.3206608280254777


100%|█████████████| 938/938 [00:03<00:00, 244.35it/s, accuracy=0.562, loss=1.32]


Epoch 2, Train Loss: 1.499628776044988, Train Accuracy: 0.476529184434968, Val Loss: 1.328058517662583, Val Accuracy: 0.5308519108280255


100%|███████████████| 938/938 [00:03<00:00, 239.72it/s, accuracy=0.5, loss=1.23]


Epoch 3, Train Loss: 1.2260652352879042, Train Accuracy: 0.5944329690831557, Val Loss: 1.0983463013248078, Val Accuracy: 0.6385350318471338


100%|██████████████| 938/938 [00:03<00:00, 261.73it/s, accuracy=0.75, loss=1.07]


Epoch 4, Train Loss: 1.0376472146526328, Train Accuracy: 0.6727245469083155, Val Loss: 0.9103060645662295, Val Accuracy: 0.7292993630573248


100%|████████████| 938/938 [00:03<00:00, 256.07it/s, accuracy=0.719, loss=0.872]


Epoch 5, Train Loss: 0.9184473901669353, Train Accuracy: 0.7120035980810234, Val Loss: 0.8273504148622987, Val Accuracy: 0.7497014331210191


100%|████████████| 938/938 [00:03<00:00, 255.13it/s, accuracy=0.688, loss=0.864]


Epoch 6, Train Loss: 0.8208027053743537, Train Accuracy: 0.7485840884861408, Val Loss: 0.7452503957186535, Val Accuracy: 0.7772691082802548


100%|████████████| 938/938 [00:03<00:00, 240.82it/s, accuracy=0.656, loss=0.812]


Epoch 7, Train Loss: 0.7574538363576698, Train Accuracy: 0.7685567697228145, Val Loss: 0.6940304759391553, Val Accuracy: 0.8030453821656051


100%|████████████| 938/938 [00:03<00:00, 244.04it/s, accuracy=0.625, loss=0.898]


Epoch 8, Train Loss: 0.721475255959578, Train Accuracy: 0.7782016257995735, Val Loss: 0.7495489970893617, Val Accuracy: 0.7504976114649682


100%|█████████████| 938/938 [00:03<00:00, 251.27it/s, accuracy=0.844, loss=0.75]


Epoch 9, Train Loss: 0.687601710973518, Train Accuracy: 0.7893123667377399, Val Loss: 0.6800394685594899, Val Accuracy: 0.7861265923566879


100%|████████████| 938/938 [00:03<00:00, 248.54it/s, accuracy=0.812, loss=0.691]


Epoch 10, Train Loss: 0.6703090688376538, Train Accuracy: 0.7936433901918977, Val Loss: 0.6135912938102795, Val Accuracy: 0.8184713375796179


In [4]:
### BesselScipy_rbf

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:05<00:00, 162.57it/s, accuracy=0.406, loss=1.53]


Epoch 1, Train Loss: 1.8568752814711793, Train Accuracy: 0.29907382729211085, Val Loss: 1.68206569267686, Val Accuracy: 0.3922173566878981


100%|█████████████| 938/938 [00:06<00:00, 155.55it/s, accuracy=0.531, loss=1.29]


Epoch 2, Train Loss: 1.5072906341379895, Train Accuracy: 0.47333089019189767, Val Loss: 1.3003853119103013, Val Accuracy: 0.5588176751592356


100%|█████████████| 938/938 [00:06<00:00, 155.09it/s, accuracy=0.656, loss=1.03]


Epoch 3, Train Loss: 1.180981670043616, Train Accuracy: 0.6121735074626866, Val Loss: 1.0423499957011764, Val Accuracy: 0.6480891719745223


100%|████████████| 938/938 [00:05<00:00, 164.61it/s, accuracy=0.781, loss=0.767]


Epoch 4, Train Loss: 0.9940195738760902, Train Accuracy: 0.6775053304904051, Val Loss: 0.8962098659983107, Val Accuracy: 0.7236265923566879


100%|████████████| 938/938 [00:06<00:00, 154.22it/s, accuracy=0.688, loss=0.958]


Epoch 5, Train Loss: 0.8852014968644327, Train Accuracy: 0.7174173773987207, Val Loss: 0.8240200844919605, Val Accuracy: 0.7423367834394905


100%|████████████████| 938/938 [00:05<00:00, 160.47it/s, accuracy=0.781, loss=1]


Epoch 6, Train Loss: 0.8188450021911532, Train Accuracy: 0.7373400852878464, Val Loss: 0.7476661201495274, Val Accuracy: 0.7688097133757962


100%|█████████████| 938/938 [00:05<00:00, 165.31it/s, accuracy=0.75, loss=0.692]


Epoch 7, Train Loss: 0.7679368628304142, Train Accuracy: 0.7552305437100213, Val Loss: 0.7876296681203659, Val Accuracy: 0.7270103503184714


100%|████████████| 938/938 [00:05<00:00, 160.05it/s, accuracy=0.812, loss=0.678]


Epoch 8, Train Loss: 0.7364979100697584, Train Accuracy: 0.7658582089552238, Val Loss: 0.6932593658091916, Val Accuracy: 0.7631369426751592


100%|████████████| 938/938 [00:05<00:00, 159.48it/s, accuracy=0.906, loss=0.377]


Epoch 9, Train Loss: 0.7100324122064403, Train Accuracy: 0.7737040245202559, Val Loss: 0.645116623419865, Val Accuracy: 0.7990644904458599


100%|████████████| 938/938 [00:06<00:00, 145.65it/s, accuracy=0.719, loss=0.912]


Epoch 10, Train Loss: 0.6788152822934742, Train Accuracy: 0.7879297707889126, Val Loss: 0.6776232677660171, Val Accuracy: 0.7826433121019108


In [6]:
### Yukawa function

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.5, beta=1.0):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.beta = beta

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def yukawa_rbf(self, distances):
        return (self.beta / distances) * torch.exp(-self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.yukawa_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)
        
    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████████| 938/938 [00:04<00:00, 194.14it/s, accuracy=0.156, loss=2.3]


Epoch 1, Train Loss: 2.3025851272570805, Train Accuracy: 0.09683168976545842, Val Loss: 2.3025851143393545, Val Accuracy: 0.09723328025477707


100%|█████████████| 938/938 [00:04<00:00, 207.91it/s, accuracy=0.0625, loss=2.3]


Epoch 2, Train Loss: 2.3025850232984464, Train Accuracy: 0.09929704157782517, Val Loss: 2.3025849032553896, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:04<00:00, 199.32it/s, accuracy=0.0625, loss=2.3]


Epoch 3, Train Loss: 2.3025847869132883, Train Accuracy: 0.09869736140724947, Val Loss: 2.3025846147233513, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:04<00:00, 200.85it/s, accuracy=0.0625, loss=2.3]


Epoch 4, Train Loss: 2.302583869586367, Train Accuracy: 0.09869736140724947, Val Loss: 2.302582263946533, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:04<00:00, 202.27it/s, accuracy=0.156, loss=2.18]


Epoch 5, Train Loss: 2.2849717795975937, Train Accuracy: 0.09874733475479744, Val Loss: 2.1963904830300884, Val Accuracy: 0.09783041401273886


100%|████████████| 938/938 [00:04<00:00, 199.35it/s, accuracy=0.0625, loss=2.23]


Epoch 6, Train Loss: 2.162391730335984, Train Accuracy: 0.09869736140724947, Val Loss: 2.144921908712691, Val Accuracy: 0.09783041401273886


100%|████████████| 938/938 [00:04<00:00, 196.72it/s, accuracy=0.0625, loss=2.15]


Epoch 7, Train Loss: 2.1342803308450335, Train Accuracy: 0.09869736140724947, Val Loss: 2.127809716637727, Val Accuracy: 0.09783041401273886


100%|████████████| 938/938 [00:04<00:00, 198.81it/s, accuracy=0.0312, loss=2.27]


Epoch 8, Train Loss: 2.120784513985933, Train Accuracy: 0.09868070362473348, Val Loss: 2.1170110861966562, Val Accuracy: 0.09783041401273886


100%|████████████████| 938/938 [00:04<00:00, 197.25it/s, accuracy=0.125, loss=2]


Epoch 9, Train Loss: 2.1115353395943957, Train Accuracy: 0.09873067697228145, Val Loss: 2.109226473577463, Val Accuracy: 0.09783041401273886


100%|████████████| 938/938 [00:04<00:00, 198.87it/s, accuracy=0.0312, loss=2.24]


Epoch 10, Train Loss: 2.1038026725813777, Train Accuracy: 0.09868070362473348, Val Loss: 2.1012363130119955, Val Accuracy: 0.09783041401273886


In [8]:
### yukawa_rbf when beta value is greater than alpha by a huge difference

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.2, beta=10):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.beta = beta

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def yukawa_rbf(self, distances):
        return (self.beta / distances) * torch.exp(-self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.yukawa_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)
        
    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|████████████| 938/938 [00:04<00:00, 191.59it/s, accuracy=0.812, loss=0.595]


Epoch 1, Train Loss: 1.596172207771842, Train Accuracy: 0.46095415778251597, Val Loss: 0.6992951026008387, Val Accuracy: 0.7898089171974523


100%|████████████| 938/938 [00:05<00:00, 185.18it/s, accuracy=0.875, loss=0.432]


Epoch 2, Train Loss: 0.561266076828498, Train Accuracy: 0.8400852878464818, Val Loss: 0.4426335073105849, Val Accuracy: 0.8748009554140127


100%|████████████| 938/938 [00:04<00:00, 193.44it/s, accuracy=0.812, loss=0.581]


Epoch 3, Train Loss: 0.40958247661018676, Train Accuracy: 0.8862440031982942, Val Loss: 0.3550494434252666, Val Accuracy: 0.8974920382165605


100%|████████████| 938/938 [00:04<00:00, 189.86it/s, accuracy=0.969, loss=0.183]


Epoch 4, Train Loss: 0.34617711991262334, Train Accuracy: 0.902318763326226, Val Loss: 0.3097242308650047, Val Accuracy: 0.9105294585987261


100%|████████████| 938/938 [00:04<00:00, 189.85it/s, accuracy=0.938, loss=0.371]


Epoch 5, Train Loss: 0.3075347076902893, Train Accuracy: 0.9129964019189766, Val Loss: 0.27957596540878155, Val Accuracy: 0.9186902866242038


100%|████████████| 938/938 [00:04<00:00, 188.49it/s, accuracy=0.969, loss=0.343]


Epoch 6, Train Loss: 0.27913960023348267, Train Accuracy: 0.9212253464818764, Val Loss: 0.2553291678499834, Val Accuracy: 0.9259554140127388


100%|█████████████| 938/938 [00:04<00:00, 197.32it/s, accuracy=0.938, loss=0.34]


Epoch 7, Train Loss: 0.256848146539253, Train Accuracy: 0.9270389125799574, Val Loss: 0.23541942690232187, Val Accuracy: 0.9326234076433121


100%|████████████| 938/938 [00:04<00:00, 192.17it/s, accuracy=0.844, loss=0.444]


Epoch 8, Train Loss: 0.2388445749513503, Train Accuracy: 0.9326359275053305, Val Loss: 0.2221835786654691, Val Accuracy: 0.934812898089172


100%|███████████████| 938/938 [00:04<00:00, 198.79it/s, accuracy=1, loss=0.0635]


Epoch 9, Train Loss: 0.2236059221393391, Train Accuracy: 0.9369003198294243, Val Loss: 0.20893652722903877, Val Accuracy: 0.9408837579617835


100%|████████████| 938/938 [00:05<00:00, 185.39it/s, accuracy=0.906, loss=0.241]


Epoch 10, Train Loss: 0.21059394469345682, Train Accuracy: 0.9409814765458422, Val Loss: 0.20094871865644767, Val Accuracy: 0.9410828025477707


In [10]:
### BesselScipy_rbf when n=1

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=1):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:06<00:00, 143.96it/s, accuracy=0.188, loss=2.13]


Epoch 1, Train Loss: 2.2611518719557253, Train Accuracy: 0.13919243070362472, Val Loss: 2.179722252924731, Val Accuracy: 0.14639729299363058


100%|█████████████████| 938/938 [00:06<00:00, 151.57it/s, accuracy=0.25, loss=2]


Epoch 2, Train Loss: 2.0673088154304766, Train Accuracy: 0.23051039445628999, Val Loss: 1.9820359755473531, Val Accuracy: 0.2770700636942675


100%|█████████████| 938/938 [00:06<00:00, 147.12it/s, accuracy=0.406, loss=1.89]


Epoch 3, Train Loss: 1.963187652228992, Train Accuracy: 0.26724080490405117, Val Loss: 1.9250115353590365, Val Accuracy: 0.3092157643312102


100%|█████████████| 938/938 [00:06<00:00, 152.62it/s, accuracy=0.438, loss=1.59]


Epoch 4, Train Loss: 1.9252259572431731, Train Accuracy: 0.28519789445628996, Val Loss: 1.8935781785636951, Val Accuracy: 0.2880175159235669


100%|█████████████| 938/938 [00:06<00:00, 151.45it/s, accuracy=0.375, loss=1.91]


Epoch 5, Train Loss: 1.900750492300306, Train Accuracy: 0.2957256130063966, Val Loss: 1.9287402007230527, Val Accuracy: 0.3209593949044586


100%|█████████████| 938/938 [00:07<00:00, 127.16it/s, accuracy=0.375, loss=1.83]


Epoch 6, Train Loss: 1.886056650422021, Train Accuracy: 0.3101512526652452, Val Loss: 1.8659079105231413, Val Accuracy: 0.30851910828025475


100%|█████████████| 938/938 [00:06<00:00, 140.00it/s, accuracy=0.219, loss=2.28]


Epoch 7, Train Loss: 1.8666329957020562, Train Accuracy: 0.32387726545842216, Val Loss: 1.8522085908112254, Val Accuracy: 0.33479299363057324


100%|██████████████| 938/938 [00:06<00:00, 143.68it/s, accuracy=0.25, loss=1.85]


Epoch 8, Train Loss: 1.8476072388417177, Train Accuracy: 0.3414678837953092, Val Loss: 1.8129395740047383, Val Accuracy: 0.381468949044586


100%|█████████████| 938/938 [00:06<00:00, 148.79it/s, accuracy=0.344, loss=1.86]


Epoch 9, Train Loss: 1.825779037816184, Train Accuracy: 0.3519789445628998, Val Loss: 1.7967256808736523, Val Accuracy: 0.35201035031847133


100%|█████████████| 938/938 [00:06<00:00, 148.02it/s, accuracy=0.438, loss=1.62]


Epoch 10, Train Loss: 1.8008560168463539, Train Accuracy: 0.36468883262260127, Val Loss: 1.8136734468921734, Val Accuracy: 0.3414609872611465


In [12]:
### BesselScipy_rbf when n=2

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=2):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:06<00:00, 151.18it/s, accuracy=0.344, loss=1.59]


Epoch 1, Train Loss: 1.8216468393167198, Train Accuracy: 0.31378264925373134, Val Loss: 1.5644996879966395, Val Accuracy: 0.4394904458598726


100%|█████████████| 938/938 [00:06<00:00, 152.44it/s, accuracy=0.719, loss=1.06]


Epoch 2, Train Loss: 1.385536540291711, Train Accuracy: 0.5268190298507462, Val Loss: 1.1438822700719165, Val Accuracy: 0.6559514331210191


100%|█████████████| 938/938 [00:06<00:00, 154.17it/s, accuracy=0.875, loss=0.83]


Epoch 3, Train Loss: 1.0703157580483442, Train Accuracy: 0.6546008795309168, Val Loss: 0.9227082915366835, Val Accuracy: 0.7136743630573248


100%|████████████| 938/938 [00:06<00:00, 151.31it/s, accuracy=0.656, loss=0.771]


Epoch 4, Train Loss: 0.904788317837949, Train Accuracy: 0.712853144989339, Val Loss: 0.8441161953719558, Val Accuracy: 0.7079020700636943


100%|████████████| 938/938 [00:07<00:00, 128.71it/s, accuracy=0.719, loss=0.834]


Epoch 5, Train Loss: 0.8073424843074416, Train Accuracy: 0.746068763326226, Val Loss: 0.7917468697781775, Val Accuracy: 0.7528861464968153


100%|████████████| 938/938 [00:06<00:00, 145.66it/s, accuracy=0.812, loss=0.639]


Epoch 6, Train Loss: 0.7381506396699816, Train Accuracy: 0.7729211087420043, Val Loss: 0.6883984261257633, Val Accuracy: 0.7840366242038217


100%|████████████| 938/938 [00:07<00:00, 120.00it/s, accuracy=0.906, loss=0.609]


Epoch 7, Train Loss: 0.6956801519655724, Train Accuracy: 0.783315565031983, Val Loss: 0.645060041717663, Val Accuracy: 0.8021496815286624


100%|████████████| 938/938 [00:06<00:00, 146.83it/s, accuracy=0.906, loss=0.439]


Epoch 8, Train Loss: 0.6651310545485666, Train Accuracy: 0.7943097014925373, Val Loss: 0.689323677948326, Val Accuracy: 0.7765724522292994


100%|█████████████| 938/938 [00:06<00:00, 150.01it/s, accuracy=0.781, loss=0.87]


Epoch 9, Train Loss: 0.6381083912432575, Train Accuracy: 0.804670842217484, Val Loss: 0.5864686185766936, Val Accuracy: 0.8277269108280255


100%|█████████████| 938/938 [00:06<00:00, 134.04it/s, accuracy=0.75, loss=0.676]


Epoch 10, Train Loss: 0.6185251129652137, Train Accuracy: 0.8094849413646056, Val Loss: 0.5913325276724093, Val Accuracy: 0.8154856687898089


In [14]:
### BesselScipy_rbf when n=3

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0, n=3):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:06<00:00, 140.73it/s, accuracy=0.188, loss=2.24]


Epoch 1, Train Loss: 2.293046501653789, Train Accuracy: 0.11343949893390191, Val Loss: 2.2744884642825762, Val Accuracy: 0.11375398089171974


100%|███████████████| 938/938 [00:06<00:00, 138.44it/s, accuracy=0.25, loss=2.1]


Epoch 2, Train Loss: 2.2274325641233528, Train Accuracy: 0.1470882196162047, Val Loss: 2.1791342048887996, Val Accuracy: 0.16192277070063693


100%|█████████████| 938/938 [00:06<00:00, 135.25it/s, accuracy=0.281, loss=2.11]


Epoch 3, Train Loss: 2.158265300396917, Train Accuracy: 0.18700026652452026, Val Loss: 2.1349693901219946, Val Accuracy: 0.18740047770700638


100%|█████████████| 938/938 [00:06<00:00, 142.85it/s, accuracy=0.281, loss=2.08]


Epoch 4, Train Loss: 2.1289260505613234, Train Accuracy: 0.19871068763326227, Val Loss: 2.113569215604454, Val Accuracy: 0.20003980891719744


100%|█████████████| 938/938 [00:06<00:00, 137.81it/s, accuracy=0.312, loss=2.02]


Epoch 5, Train Loss: 2.1165375844247816, Train Accuracy: 0.2023420842217484, Val Loss: 2.104306097243242, Val Accuracy: 0.20431926751592358


100%|██████████████| 938/938 [00:07<00:00, 121.16it/s, accuracy=0.312, loss=1.9]


Epoch 6, Train Loss: 2.109882681227442, Train Accuracy: 0.2039079157782516, Val Loss: 2.0974174859417474, Val Accuracy: 0.2041202229299363


100%|█████████████| 938/938 [00:07<00:00, 126.90it/s, accuracy=0.219, loss=1.98]


Epoch 7, Train Loss: 2.106141815942996, Train Accuracy: 0.20452425373134328, Val Loss: 2.0936358430583004, Val Accuracy: 0.2044187898089172


100%|██████████████| 938/938 [00:06<00:00, 147.77it/s, accuracy=0.25, loss=2.05]


Epoch 8, Train Loss: 2.103309626518282, Train Accuracy: 0.20492404051172708, Val Loss: 2.091101906861469, Val Accuracy: 0.2047173566878981


100%|█████████████| 938/938 [00:06<00:00, 145.20it/s, accuracy=0.219, loss=2.07]


Epoch 9, Train Loss: 2.1015117198927826, Train Accuracy: 0.2052738539445629, Val Loss: 2.09147314964586, Val Accuracy: 0.20591162420382167


100%|█████████████| 938/938 [00:06<00:00, 145.91it/s, accuracy=0.281, loss=2.01]


Epoch 10, Train Loss: 2.0987612014132013, Train Accuracy: 0.20545708955223882, Val Loss: 2.0891462374644676, Val Accuracy: 0.20451831210191082


In [16]:
### BesselScipy_rbf when n=2 and alpha = 0.1

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=0.1, n=2):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████████| 938/938 [00:05<00:00, 170.15it/s, accuracy=0.125, loss=2.3]


Epoch 1, Train Loss: 2.3037120317345234, Train Accuracy: 0.09918043710021322, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 177.02it/s, accuracy=0.0625, loss=2.3]


Epoch 2, Train Loss: 2.3025851249694824, Train Accuracy: 0.09869736140724947, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 180.84it/s, accuracy=0.0938, loss=2.3]


Epoch 3, Train Loss: 2.3025851249694824, Train Accuracy: 0.09871401918976545, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 177.81it/s, accuracy=0.0625, loss=2.3]


Epoch 4, Train Loss: 2.3025851249694824, Train Accuracy: 0.09869736140724947, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|██████████████| 938/938 [00:05<00:00, 175.48it/s, accuracy=0.188, loss=2.3]


Epoch 5, Train Loss: 2.3025851249694824, Train Accuracy: 0.09876399253731344, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|██████████████████| 938/938 [00:05<00:00, 176.93it/s, accuracy=0, loss=2.3]


Epoch 6, Train Loss: 2.3025851249694824, Train Accuracy: 0.09866404584221748, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 174.31it/s, accuracy=0.0938, loss=2.3]


Epoch 7, Train Loss: 2.3025851249694824, Train Accuracy: 0.09871401918976545, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 179.13it/s, accuracy=0.0625, loss=2.3]


Epoch 8, Train Loss: 2.3025851249694824, Train Accuracy: 0.09869736140724947, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|█████████████| 938/938 [00:05<00:00, 179.06it/s, accuracy=0.0625, loss=2.3]


Epoch 9, Train Loss: 2.3025851249694824, Train Accuracy: 0.09869736140724947, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


100%|██████████████| 938/938 [00:05<00:00, 166.69it/s, accuracy=0.156, loss=2.3]


Epoch 10, Train Loss: 2.3025851249694824, Train Accuracy: 0.09874733475479744, Val Loss: 2.302585126488072, Val Accuracy: 0.09783041401273886


In [18]:
### BesselScipy_rbf when n=2 and alpha = 10

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.special as sc

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFBANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=10, n=2):
        super(RBFBANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha
        self.n = n # order of bessel function

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def besselScipy_rbf(self, distances):
        # Detach the tensor from the computation graph and convert to NumPy array
        distances_np = distances.detach().numpy()
        # Compute the Bessel function
        bessel_values = sc.jn(self.n, self.alpha * distances_np)
        # Convert back to PyTorch tensor
        return torch.from_numpy(bessel_values)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.besselScipy_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output
class RBFBAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFBAN, self).__init__()
        self.rbf_ban_layer = RBFBANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_ban_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFBAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)
    
    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|█████████████| 938/938 [00:05<00:00, 162.25it/s, accuracy=0.656, loss=1.16]


Epoch 1, Train Loss: 1.64636694673282, Train Accuracy: 0.45362473347547977, Val Loss: 1.1378585684830975, Val Accuracy: 0.6636146496815286


100%|██████████████| 938/938 [00:05<00:00, 161.05it/s, accuracy=0.75, loss=1.01]


Epoch 2, Train Loss: 0.9869466598099991, Train Accuracy: 0.7000099946695096, Val Loss: 0.848231544919834, Val Accuracy: 0.7459195859872612


100%|████████████| 938/938 [00:05<00:00, 160.67it/s, accuracy=0.875, loss=0.475]


Epoch 3, Train Loss: 0.8121030008170143, Train Accuracy: 0.7560800906183369, Val Loss: 0.751132014830401, Val Accuracy: 0.7727906050955414


100%|████████████| 938/938 [00:05<00:00, 159.09it/s, accuracy=0.656, loss=0.925]


Epoch 4, Train Loss: 0.7310588115186833, Train Accuracy: 0.7810167910447762, Val Loss: 0.6974633165225861, Val Accuracy: 0.790406050955414


100%|████████████| 938/938 [00:05<00:00, 166.33it/s, accuracy=0.844, loss=0.543]


Epoch 5, Train Loss: 0.68068932930925, Train Accuracy: 0.7965751599147122, Val Loss: 0.6627962190634126, Val Accuracy: 0.7965764331210191


100%|████████████| 938/938 [00:05<00:00, 161.17it/s, accuracy=0.844, loss=0.647]


Epoch 6, Train Loss: 0.6477712572637652, Train Accuracy: 0.8052538646055437, Val Loss: 0.6249354643047236, Val Accuracy: 0.8105095541401274


100%|████████████| 938/938 [00:06<00:00, 153.49it/s, accuracy=0.719, loss=0.766]


Epoch 7, Train Loss: 0.622089531630087, Train Accuracy: 0.8151319296375267, Val Loss: 0.612149286801648, Val Accuracy: 0.810609076433121


100%|████████████| 938/938 [00:05<00:00, 168.01it/s, accuracy=0.719, loss=0.709]


Epoch 8, Train Loss: 0.6003978619062061, Train Accuracy: 0.8217783848614072, Val Loss: 0.582731483563496, Val Accuracy: 0.8192675159235668


100%|█████████████| 938/938 [00:05<00:00, 169.06it/s, accuracy=0.781, loss=0.69]


Epoch 9, Train Loss: 0.5822623589717503, Train Accuracy: 0.825992803837953, Val Loss: 0.569099146849031, Val Accuracy: 0.8298168789808917


100%|████████████| 938/938 [00:05<00:00, 167.52it/s, accuracy=0.844, loss=0.702]


Epoch 10, Train Loss: 0.5674075498255585, Train Accuracy: 0.8325559701492538, Val Loss: 0.5668013189818449, Val Accuracy: 0.8303144904458599
