#### multiquadratic_rbf

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFKANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def multiquadratic_rbf(self, distances):
        return (1 + (self.alpha * distances) ** 2) ** 0.5

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.multiquadratic_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFKAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFKAN, self).__init__()
        self.rbf_kan_layer = RBFKANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_kan_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFKAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:21<00:00, 44.12it/s, accuracy=0.875, loss=0.344]


Epoch 1, Train Loss: 1.3331204660729306, Train Accuracy: 0.6869336353944563, Val Loss: 0.5092402221101104, Val Accuracy: 0.8365843949044586


100%|██████████| 938/938 [00:21<00:00, 44.62it/s, accuracy=0.844, loss=0.754]


Epoch 2, Train Loss: 0.5317786001383877, Train Accuracy: 0.8356876332622601, Val Loss: 0.45442688009541504, Val Accuracy: 0.856687898089172


100%|██████████| 938/938 [00:20<00:00, 44.79it/s, accuracy=0.875, loss=0.596]


Epoch 3, Train Loss: 0.47396965979385985, Train Accuracy: 0.8546775053304904, Val Loss: 0.4623399981932276, Val Accuracy: 0.85828025477707


100%|██████████| 938/938 [00:20<00:00, 44.87it/s, accuracy=0.969, loss=0.129] 


Epoch 4, Train Loss: 0.4533071270041755, Train Accuracy: 0.8618903251599147, Val Loss: 0.4867182159974317, Val Accuracy: 0.8470342356687898


100%|██████████| 938/938 [00:21<00:00, 42.90it/s, accuracy=0.812, loss=0.5]  


Epoch 5, Train Loss: 0.44235368565455685, Train Accuracy: 0.863339552238806, Val Loss: 0.40798378541211416, Val Accuracy: 0.8772890127388535


100%|██████████| 938/938 [00:20<00:00, 44.96it/s, accuracy=0.938, loss=0.162] 


Epoch 6, Train Loss: 0.4230411807134716, Train Accuracy: 0.8704857409381663, Val Loss: 0.3293195556445866, Val Accuracy: 0.8992834394904459


100%|██████████| 938/938 [00:20<00:00, 44.86it/s, accuracy=0.719, loss=0.674] 


Epoch 7, Train Loss: 0.4199728215141083, Train Accuracy: 0.8709021855010661, Val Loss: 0.48304354451644194, Val Accuracy: 0.8533041401273885


100%|██████████| 938/938 [00:20<00:00, 45.21it/s, accuracy=0.875, loss=0.562]


Epoch 8, Train Loss: 0.3942965001487401, Train Accuracy: 0.8792977078891258, Val Loss: 0.3657194176202367, Val Accuracy: 0.8863455414012739


100%|██████████| 938/938 [00:20<00:00, 45.37it/s, accuracy=0.875, loss=0.421]


Epoch 9, Train Loss: 0.3929344529885727, Train Accuracy: 0.8813299573560768, Val Loss: 0.44346745098662227, Val Accuracy: 0.8609673566878981


100%|██████████| 938/938 [00:20<00:00, 44.86it/s, accuracy=0.844, loss=0.55]  


Epoch 10, Train Loss: 0.3687275981566291, Train Accuracy: 0.8889259061833689, Val Loss: 0.32881565809629526, Val Accuracy: 0.8995820063694268


#### thin_plate_spline_rbf

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFKANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def thin_plate_spline_rbf(self, distances):
        return distances ** 2 * torch.log(self.alpha * distances)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.thin_plate_spline_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFKAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFKAN, self).__init__()
        self.rbf_kan_layer = RBFKANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_kan_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFKAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:22<00:00, 40.91it/s, accuracy=0.844, loss=5.46]   


Epoch 1, Train Loss: 52.49743112896297, Train Accuracy: 0.6394422974413646, Val Loss: 10.743994658919656, Val Accuracy: 0.6907842356687898


100%|██████████| 938/938 [00:23<00:00, 39.20it/s, accuracy=0.625, loss=16.2] 


Epoch 2, Train Loss: 12.287046090626259, Train Accuracy: 0.7488672707889126, Val Loss: 21.633233185786352, Val Accuracy: 0.6130573248407644


100%|██████████| 938/938 [00:24<00:00, 38.85it/s, accuracy=0.844, loss=9.14] 


Epoch 3, Train Loss: 10.788128211681268, Train Accuracy: 0.790944829424307, Val Loss: 12.194597823604656, Val Accuracy: 0.7600517515923567


100%|██████████| 938/938 [00:23<00:00, 39.23it/s, accuracy=0.656, loss=24.9]  


Epoch 4, Train Loss: 10.22712245590683, Train Accuracy: 0.7996401918976546, Val Loss: 8.237443765433078, Val Accuracy: 0.8284235668789809


100%|██████████| 938/938 [00:23<00:00, 39.36it/s, accuracy=0.781, loss=5.55] 


Epoch 5, Train Loss: 9.040300528187233, Train Accuracy: 0.8012226812366737, Val Loss: 6.179725476131318, Val Accuracy: 0.8075238853503185


100%|██████████| 938/938 [00:24<00:00, 37.89it/s, accuracy=0.688, loss=8.43] 


Epoch 6, Train Loss: 6.757920517087745, Train Accuracy: 0.8190631663113006, Val Loss: 6.748548487331837, Val Accuracy: 0.8248407643312102


100%|██████████| 938/938 [00:23<00:00, 39.35it/s, accuracy=0.938, loss=2.86] 


Epoch 7, Train Loss: 5.771717983951319, Train Accuracy: 0.8263925906183369, Val Loss: 3.3508190594516756, Val Accuracy: 0.876890923566879


100%|██████████| 938/938 [00:22<00:00, 41.06it/s, accuracy=0.875, loss=2.17] 


Epoch 8, Train Loss: 4.417131943012605, Train Accuracy: 0.824160447761194, Val Loss: 6.2312569271796825, Val Accuracy: 0.7412420382165605


100%|██████████| 938/938 [00:21<00:00, 43.90it/s, accuracy=0.719, loss=2.91] 


Epoch 9, Train Loss: 3.2719493453373025, Train Accuracy: 0.8138159648187633, Val Loss: 1.7306252520911063, Val Accuracy: 0.8607683121019108


100%|██████████| 938/938 [00:20<00:00, 45.91it/s, accuracy=0.781, loss=0.692]


Epoch 10, Train Loss: 1.5846328088707888, Train Accuracy: 0.7624933368869936, Val Loss: 0.8684299250317228, Val Accuracy: 0.7932921974522293


#### gaussian_rbf

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFKANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def gaussian_rbf(self, distances):
        return torch.exp(-self.alpha * distances ** 2)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.gaussian_rbf(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFKAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFKAN, self).__init__()
        self.rbf_kan_layer = RBFKANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_kan_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFKAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

#### inverse_quadric

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.init as init

class RBFKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_centers, alpha=1.0):
        super(RBFKANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_centers = num_centers
        self.alpha = alpha

        self.centers = nn.Parameter(torch.empty(num_centers, input_dim))
        init.xavier_uniform_(self.centers)

        self.weights = nn.Parameter(torch.empty(num_centers, output_dim))
        init.xavier_uniform_(self.weights)

    def inverse_quadric(self, distances):
        return 1 / (1 + (self.alpha * distances) ** 2)

    def forward(self, x):
        distances = torch.cdist(x, self.centers)
        basis_values = self.inverse_quadric(distances)
        output = torch.sum(basis_values.unsqueeze(2) * self.weights.unsqueeze(0), dim=1)
        return output

class RBFKAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_centers):
        super(RBFKAN, self).__init__()
        self.rbf_kan_layer = RBFKANLayer(input_dim, hidden_dim, num_centers)
        self.output_weights = nn.Parameter(torch.empty(hidden_dim, output_dim))
        init.xavier_uniform_(self.output_weights)

    def forward(self, x):
        x = self.rbf_kan_layer(x)
        x = torch.relu(x)
        x = torch.matmul(x, self.output_weights)
        return x

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFKAN(28 * 28, 64, 10, num_centers=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:20<00:00, 46.50it/s, accuracy=0.688, loss=1.12] 


Epoch 1, Train Loss: 1.5977591815025314, Train Accuracy: 0.4398654051172708, Val Loss: 0.9699718990143696, Val Accuracy: 0.7251194267515924


100%|██████████| 938/938 [00:20<00:00, 45.45it/s, accuracy=0.719, loss=0.707]


Epoch 2, Train Loss: 0.7924936451256148, Train Accuracy: 0.7754697494669509, Val Loss: 0.6597338320723005, Val Accuracy: 0.8110071656050956


100%|██████████| 938/938 [00:20<00:00, 45.70it/s, accuracy=0.875, loss=0.535]


Epoch 3, Train Loss: 0.6203015694168331, Train Accuracy: 0.8271755063965884, Val Loss: 0.5615962621322863, Val Accuracy: 0.8405652866242038


100%|██████████| 938/938 [00:20<00:00, 46.63it/s, accuracy=0.719, loss=0.98] 


Epoch 4, Train Loss: 0.5460911259404632, Train Accuracy: 0.8483308901918977, Val Loss: 0.5064548553934523, Val Accuracy: 0.8536027070063694


100%|██████████| 938/938 [00:20<00:00, 46.30it/s, accuracy=0.875, loss=0.456]


Epoch 5, Train Loss: 0.49157956267979097, Train Accuracy: 0.8660047974413646, Val Loss: 0.4543072434177824, Val Accuracy: 0.873109076433121


100%|██████████| 938/938 [00:20<00:00, 45.71it/s, accuracy=0.906, loss=0.57] 


Epoch 6, Train Loss: 0.4421602461987467, Train Accuracy: 0.8811633795309168, Val Loss: 0.3999748114188006, Val Accuracy: 0.8926154458598726


100%|██████████| 938/938 [00:23<00:00, 40.59it/s, accuracy=0.938, loss=0.197]


Epoch 7, Train Loss: 0.39452289286325737, Train Accuracy: 0.8956556503198294, Val Loss: 0.3606869113767982, Val Accuracy: 0.9018710191082803


100%|██████████| 938/938 [00:20<00:00, 45.22it/s, accuracy=0.906, loss=0.259]


Epoch 8, Train Loss: 0.3566824979841836, Train Accuracy: 0.9053005063965884, Val Loss: 0.3280747068251014, Val Accuracy: 0.9118232484076433


100%|██████████| 938/938 [00:20<00:00, 45.19it/s, accuracy=0.938, loss=0.274]


Epoch 9, Train Loss: 0.3305229269810069, Train Accuracy: 0.9127631929637526, Val Loss: 0.3088569798192401, Val Accuracy: 0.9168988853503185


100%|██████████| 938/938 [00:20<00:00, 44.78it/s, accuracy=0.938, loss=0.285]


Epoch 10, Train Loss: 0.3114428613771762, Train Accuracy: 0.9165445095948828, Val Loss: 0.2923462862137017, Val Accuracy: 0.9194864649681529
