In [1]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import SplineTransformer
from torch.utils.data import DataLoader, TensorDataset, Subset
from collections import OrderedDict
from torch import nn
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from collections import Counter
import glob
import cv2
import os

In [2]:
def diag_mat_weights(dimp, type = 'first'):
    if type == 'first':
        dg = np.zeros((dimp-1, dimp))
        for i in range(dimp-1):
            dg[i,i] = -1
            dg[i,i+1]= 1
    elif type == 'second':
        dg = np.zeros((dimp-2, dimp))
        for i in range(dimp-2):
            dg[i,i] = -1
            dg[i,i+1]= 2
            dg[i,i+2]= -1
    else:
        pass
    return torch.Tensor(dg)
    

class BSL(nn.Module):
    def __init__(self, degree, num_knots, num_neurons, bias = True):
        super(BSL, self).__init__()
        self.degree = degree
        self.num_knots = num_knots
        self.num_neurons = num_neurons
        self.control_p = nn.Parameter(torch.randn(self.num_knots, self.num_neurons))
        
        if bias:
            self.bias = nn.Parameter(torch.randn(self.num_neurons))
        else:
            self.register_parameter('bias', None)
            
        self.inter = {}
    
    def basis_function(self, x, i, k, t):
    
        # Base case: degree 0 spline
        if k == 0:
            return ((t[i] <= x) & (x < t[i + 1])).float()
    
        # Recursive case
        denom1 = t[i + k] - t[i]
        denom2 = t[i + k + 1] - t[i + 1]
    
        term1 = 0
        if denom1 != 0:
            term1 = (x - t[i]) / denom1 * self.basis_function(x, i, k - 1, t)
    
        term2 = 0
        if denom2 != 0:
            term2 = (t[i + k + 1] - x) / denom2 * self.basis_function(x, i + 1, k - 1, t)
    
        return term1 + term2

    def knots_distribution(self, dg, nk):

        knots = torch.cat([torch.linspace(-0.002, -0.001, steps=dg),            # Add repeated values at the start for clamping
            torch.linspace(0, 1, nk-2*dg-2),  # Uniform knot spacing in the middle
            torch.linspace(1.001, 1.002, steps=dg)           # Add repeated values at the end for clamping
            ]).view(-1,1)
        
        knots = torch.cat([torch.linspace(0, 1, nk-2)          # Add repeated values at the end for clamping
            ]).view(-1,1)

        return knots
    
    def basis_function2(self, x, spl):
        basis_output = spl.fit_transform(x.cpu().numpy())
        return basis_output
            
    def forward(self, x):
        batch_size, num_features = x.size()
        device = x.device
        
        # Create knot vector and apply B-spline basis functions for each feature
    
        basises = []
        knots = self.knots_distribution(self.degree, self.num_knots)
        spl = SplineTransformer(n_knots=self.num_knots, degree=self.degree, knots = knots)

        
        for feature in range(num_features):
            # Calculate B-spline basis functions for this feature
            
            basis = self.basis_function2(x[:, feature].reshape(-1,1), spl)
            basis = torch.Tensor(basis).to(device)
            basises.append(basis)
        
        if num_features == 1:
            tout = basises[0] @ self.control_p
            self.inter['basic'] = basises[0].T
        else:
            self.inter['basic'] = torch.reshape(torch.stack(basises, dim = 1), (batch_size, self.num_knots * self.num_neurons)).T
            basises = torch.stack(basises)
            tout = basises.permute(1,2,0) * self.control_p
            tout = tout.sum(dim =1)
                
        if self.bias is not None:
            tout += self.bias        
            
        return tout


class NormLayer(nn.Module):
    def __init__(self):
        super(NormLayer, self).__init__()

    def forward(self, x):
        min_val = torch.min(x, axis = 1).values.reshape(-1,1)
        max_val = torch.max(x, axis = 1).values.reshape(-1,1)

        x = (x - min_val)/(max_val - min_val)  # Rescale to [0, 1]
        return x.detach()
    
class BSpline_block(nn.Module):
    def __init__(self, degree, num_knots, num_neurons, dropout = 0.0, bias = True):
        super(BSpline_block, self).__init__()

        self.block = nn.Sequential(OrderedDict([
            ('norm', NormLayer()),
            ('BSL', BSL(degree = degree, num_knots = num_knots, num_neurons = num_neurons, bias = bias)),
            ('drop', nn.Dropout(dropout)),
        ]))
        
    def forward(self, x):
        return self.block(x)
        
class StackBS_block(nn.Module):
    def __init__(self, block, degree, num_knots, num_neurons, num_blocks, dropout = 0.0, bias = True):
        super().__init__()
        self.model = nn.ModuleDict({
            f'block_{i}': block(degree = degree, num_knots = num_knots, num_neurons = num_neurons)
            for i in range(num_blocks)
        })

    def forward(self, x):
        for name, block in self.model.items():
            x = block(x)
        return x

## Modeling

In [3]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss() 
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        scheduler.step()
        
def test(model, device, test_loader):
    model.eval()
    
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            output = torch.log_softmax(output, dim=1)
            _, pred = torch.max(output, dim = 1)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
device = torch.device("mps:0" if torch.mps.is_available() else "cpu")

# Download the MNIST dataset
train_dataset_full = datasets.MNIST(root='./data', train=True, download=True)
test_dataset_full = datasets.MNIST(root='./data', train=False, download=True)

# Function to extract N samples for a given digit
def extract_n_samples(data, targets, digit, n):
    indices = (targets == digit).nonzero().squeeze()
    selected_indices = indices[:n]
    return data[selected_indices], targets[selected_indices]

data = train_dataset_full.data
targets = train_dataset_full.targets

# Extract 200 samples for each digit
data_0, labels_0 = extract_n_samples(data, targets, digit=0, n=200)
data_1, labels_1 = extract_n_samples(data, targets, digit=1, n=200)

# Combine and apply transform manually
X_train = torch.cat([data_0, data_1]).to(device)
y_train = torch.cat([labels_0, labels_1]).to(device)

X_train = X_train.unsqueeze(1).float() / 255.0
X_train = (X_train - 0.1307) / 0.3081

perm = torch.randperm(X_train.size(0))
X_train = X_train[perm]
y_train = y_train[perm]

data = test_dataset_full.data
targets = test_dataset_full.targets

# Extract 200 samples for each digit
data_0, labels_0 = extract_n_samples(data, targets, digit=0, n=200)
data_1, labels_1 = extract_n_samples(data, targets, digit=1, n=200)

# Combine and apply transform manually
X_test = torch.cat([data_0, data_1]).to(device)
y_test = torch.cat([labels_0, labels_1]).to(device)

X_test = X_test.unsqueeze(1).float() / 255.0
X_test = (X_test - 0.1307) / 0.3081

In [11]:
class MNISTClassifier(nn.Module):
    def __init__(self, dg, nk, nm, nbl, dropout, Fout, bias):
        super(MNISTClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.gap = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(1568, nm),
            StackBS_block(BSpline_block, degree = dg, num_knots = nk, num_neurons = nm, num_blocks = nbl, dropout = dropout),
            nn.Linear(nm, Fout))
        self.sm = nn.Softmax(dim = 1)
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.classifier(x)
        x = self.sm(x)

        return x

    def get_para_ecm(self, x):

        '''
        ecm_para: A dictionary that collects the parameter we need to the following ECM algorithm.
        ecm_para.basic: Store the output of each B-Spline block; Dimension = [n_sample, n_neurons]
        ecm_para.ebasic Store the weight matrix of each B-Spline expansion; Dimension = [n_knots * n_neurons, n_sample]

        '''
        ecm_para = {}
        bs_block_out = {}
        bs_spline_weight = {}
        bs_spline_value = {}
        bs_spline_bias = {}

        _ = self(x)
        
        def get_activation(name):
            def hook(model, input, output):
                bs_block_out[name] = output.detach()
            return hook

        handles = []
        for name, layer in self.named_modules():
            if 'block.drop' in name:
                handles.append(layer.register_forward_hook(get_activation(name)))
            elif 'block.BSL' in name:
                bs_spline_value[name] = layer.inter['basic'].detach()
                bs_spline_weight[name] = layer.control_p.detach()
                bs_spline_bias[name] = layer.bias.detach()
        # Run forward pass (triggers hooks)
        _ = self(x)

        # Clean up hooks
        for h in handles:
            h.remove()
            
        ecm_para['basic'] = torch.stack(list(bs_block_out.values()), dim=0)
        ecm_para['ebasic'] = torch.stack(list(bs_spline_value.values()), dim=0)
        ecm_para['wbasic'] = torch.stack(list(bs_spline_weight.values()), dim=0)
        ecm_para['bbasic'] = torch.stack(list(bs_spline_bias.values()), dim=0)
        del bs_block_out, bs_spline_weight, bs_spline_value, bs_spline_bias
        
        return ecm_para

nm = 50; nk = 15; dg = 3; nl = 1
mnist_DeepBS = MNISTClassifier(dg = dg, nk = nk, nm = nm, nbl = nl, dropout = 0.0, Fout = 2, bias = True).to(device)
learning_r = 1e-2
optimizer = torch.optim.Adam(mnist_DeepBS.parameters(), lr=learning_r)
criterion = nn.CrossEntropyLoss()

In [110]:
for t in range(Iteration):

    # Forward pass: Compute predicted y by passing x to the modelsp
    pyb_af = mnist_DeepBS(X_train)
    loss = criterion(pyb_af, y_train)
    
    prediction = torch.argmax(pyb_af, axis = 1)
    acc = (torch.argmax(pyb_af, axis = 1) == y_train).sum()/len(y_train)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if(t % 10 == 0):
        print('| Epoch: ',t+1,'/',str(Iteration),' | Loss: ', np.round(loss.item(), 4),' | Acc: ', acc.item())
        if(t % 100 == 0):
            with torch.no_grad():
                print((torch.argmax(mnist_DeepBS(X_test).detach(), axis = 1) == y_test).sum()/len(y_test))

torch.save(mnist_DeepBS.state_dict(), './MNIST'+str(X_train.size()[0])+'h'+str(nm)+'k'+str(nk))

| Epoch:  1 / 100  | Loss:  0.7827  | Acc:  0.5
tensor(0.4900, device='mps:0')
| Epoch:  11 / 100  | Loss:  0.5086  | Acc:  0.875
| Epoch:  21 / 100  | Loss:  0.3977  | Acc:  0.9725000262260437
| Epoch:  31 / 100  | Loss:  0.3593  | Acc:  0.9850000143051147
| Epoch:  41 / 100  | Loss:  0.3419  | Acc:  0.9950000047683716
| Epoch:  51 / 100  | Loss:  0.3333  | Acc:  0.9950000047683716
| Epoch:  61 / 100  | Loss:  0.3285  | Acc:  0.9950000047683716
| Epoch:  71 / 100  | Loss:  0.3255  | Acc:  0.9975000023841858
| Epoch:  81 / 100  | Loss:  0.3234  | Acc:  0.9975000023841858
| Epoch:  91 / 100  | Loss:  0.3219  | Acc:  0.9975000023841858


In [12]:
eval_model = MNISTClassifier(dg = dg, nk = nk, nm = nm, nbl = nl, dropout = 0.0, Fout = 2, bias = True).to(device)
eval_model.load_state_dict(torch.load('./MNIST'+str(X_train.size()[0])+'h'+str(nm)+'k'+str(nk), weights_only = True))
eval_model.eval()

with torch.no_grad():
    pred_postecm = eval_model(X_train)
    print('Accuracy: ', (torch.argmax(pred_postecm, axis = 1) == y_train).sum())
    pred_postecm = eval_model(X_test)
    print('Accuracy: ', (torch.argmax(pred_postecm, axis = 1) == y_test).sum())
    
    CLoss = criterion(pred_postecm.detach(), y_train)

Accuracy:  tensor(399, device='mps:0')
Accuracy:  tensor(397, device='mps:0')


In [41]:
def ECM(par, initial_xi = 1, initial_sigma = 1, initial_lambda = 1e-4):
    lambdab = initial_lambda
    sigma = initial_sigma
    xi = initial_xi
    
    n_block, num_knots, num_neurons = par['wbasic'].size()
    ls_lambda = torch.empty(n_block)
    
    for l in range(n_block):
        B = par['ebasic'][l]
        By = par['basic'][l]
        WB = par['wbasic'][l]
        
        DB = diag_mat_weights(WB.size()[0]).to(device)
        size = B.size()[1]
        S = DB.T @ DB
        Cov_a = (xi**2)* torch.linalg.pinv(S)
        Cov_a.to(device)
        Cov_e = (torch.eye(size*num_neurons)* sigma).to(device)
        
        block_y = torch.reshape(By, (-1,1))
        flatB = B.view(num_neurons, num_knots, size)
            
        sqr_xi= 0
        sqr_sig = 0

        for i in range(num_neurons):
            Ncov = (Cov_a -(Cov_a @ flatB[i]) @ (torch.linalg.pinv(flatB[i].T @ Cov_a @ flatB[i] + Cov_e[size*i:size*(i+1),size*i:size*(i+1)]) @ flatB[i].T @ Cov_a))
            Nmu = (Cov_a @ flatB[i]) @ (torch.linalg.pinv(flatB[i].T @ Cov_a @ flatB[i] + Cov_e[size*i:size*(i+1),size*i:size*(i+1)])) @ By[:,i].reshape(-1,1)
            
            first_xi = S @ Ncov
            second_xi = (Nmu.T @ S @ Nmu)
            sqr_xi += torch.trace(first_xi) + second_xi
                
            first_sig = torch.norm(By[:,i])
            second_sig = 2 * (By[:,i] @ flatB[i].T) @ Nmu 
            third_sig = torch.trace((flatB[i] @ flatB[i].T) @ Ncov)
            four_sig = (Nmu.T @ flatB[i] @ flatB[i].T @ Nmu)
            
            sqr_sig += (first_sig + second_sig + third_sig + four_sig)
            
            del first_xi, second_xi, first_sig, second_sig, third_sig, four_sig

        sqr_xi /= num_neurons
        sqr_sig /= (num_neurons*size)

        ls_lambda[l] = (sqr_sig/sqr_xi).item()
        
        del Cov_a, Cov_e, flatB, B, By, WB
        torch.cuda.empty_cache()
    
    return ls_lambda
    
def ECM_layersise_update(model, par, Lambda, x, y):

    model.eval()
    device = x.device
    
    B_out, B_in, B_w, B_b = par['basic'], par['ebasic'], par['wbasic'], par['bbasic']
    n_layer, nk, nm = B_w.size()
    DB = diag_mat_weights(B_w[0].size()[0], 'second').to(device)

    Project_matrix = (torch.linalg.pinv(B_in[-1].T @ B_in[-1]) @ B_in[-1].T @ B_in[-1])
    Size = [b.size()[1] for b in B_in]

    B_in = B_in.view(n_layer, nm, nk, Size[0])

    for l in range(n_layer):    
        NW = torch.empty((nk, nm)).to(device)
        NB = torch.empty((nm)).to(device)
        
        for i in range(nm):
            B1y = B_out[l][:,i] - B_b[l][i]
            BB = B_in[l][i].T
    
            # Update the weights and bias
            NW[:, i] = (torch.inverse(BB.T @ BB + (Lambda[l]/Size[l]) * (DB.T @ DB)) @ BB.T @ B1y)
            NB[i] = torch.mean(B_out[l][:,i] - (NW[:,i] @ BB.T))
                
        # update the weight
        block = getattr(model.classifier[1].model, f'block_{l}')
        getattr(block.block.BSL, 'control_p').data = NW
        getattr(block.block.BSL, 'bias').data = NB

        del NW, NB, B1y, BB, block
        torch.cuda.empty_cache()
        
    with torch.no_grad():
        DPSy = model(x)
        CLoss = criterion(DPSy.detach(), y)
        GCV = CLoss/(Size[-1]-torch.trace(Project_matrix))
    
    return model, GCV

def ECM_update(model, max_iter, x, y):
    BestGCV = prev = 9999
    patient = 10
    pcount = 0
    for i in range(max_iter):
        _ = model(X_train)
        ECM_para = model.get_para_ecm(x)
        ECM_Lambda = ECM(ECM_para, initial_xi = 1, initial_sigma = 1, initial_lambda = 1e-4)

        model, GCV = ECM_layersise_update(model, ECM_para, ECM_Lambda, x, y)
        if np.abs(prev - GCV.cpu().detach().numpy()) < 5e-5:
            print('GCV Converge at ',i+1,' iteration')
            break
            
        if GCV < BestGCV:
            BestLambda = ECM_Lambda
            BestGCV = GCV
            pcount = 0
        else:
            pcount += 1

        if pcount == patient:
            print('GCV Converge at ',i,' iteration')
            break

        prev = GCV.cpu().detach().numpy()

        del ECM_para, ECM_Lambda
        
    del model
    
    return BestLambda

In [None]:
eval_model = MNISTClassifier(dg = dg, nk = nk, nm = nm, nbl = nl, dropout = 0.0, Fout = 2, bias = True).to(device)
eval_model.load_state_dict(torch.load('./MNIST'+str(X_train.size()[0])+'h'+str(nm)+'k'+str(nk), weights_only = True))
eval_model.eval()
with torch.no_grad():
    BestLambda = ECM_update(eval_model, 10, X_train, y_train)

In [71]:
"""

`fast_epoch`: number of epoch to run the fast tuning

"""
fast_epoch = 200
DPS = MNISTClassifier(dg = dg, nk = nk, nm = nm, nbl = nl, dropout = 0.0, Fout = 2, bias = True).to(device)
DPS.load_state_dict(torch.load('./MNIST'+str(X_train.size()[0])+'h'+str(nm)+'k'+str(nk), weights_only = True))

lr_ft = 1e-2
optimizer = torch.optim.Adam(DPS.parameters(), lr=lr_ft)

for t in range(fast_epoch):

    # Forward pass: Compute predicted y by passing x to the modelsp
    pyb_af = DPS(X_train)
    loss = criterion(pyb_af, y_train)
    
    for l in range(nl):
        block = getattr(DPS.classifier[1].model, f'block_{l}')
        W = getattr(block.block.BSL, 'control_p')
        D = diag_mat_weights(W.size()[0]).to(device)
        loss += BestLambda[l].to(device)/X_train.size()[0] * torch.norm(D@W)
    
    prediction = torch.argmax(pyb_af, axis = 1)
    acc = (torch.argmax(pyb_af, axis = 1) == y_train).sum()/len(y_train)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 10 == 0:
        print('| Epoch: ',t+1,'/',str(fast_epoch),' | Loss: ', loss.item(),' | Acc: ', np.round(acc.item(), 5))
        if t % 100 == 0:
            with torch.no_grad():
                print((torch.argmax(DPS(X_test).detach(), axis = 1) == y_test).sum()/len(y_test))

| Epoch:  1 / 200  | Loss:  0.34722980856895447  | Acc:  0.9975
tensor(0.9950, device='mps:0')
| Epoch:  11 / 200  | Loss:  0.342163622379303  | Acc:  0.9975
| Epoch:  21 / 200  | Loss:  0.3391888737678528  | Acc:  0.9975
| Epoch:  31 / 200  | Loss:  0.3366234600543976  | Acc:  0.9975
| Epoch:  41 / 200  | Loss:  0.3333474099636078  | Acc:  1.0
| Epoch:  51 / 200  | Loss:  0.3313204050064087  | Acc:  1.0
| Epoch:  61 / 200  | Loss:  0.329637348651886  | Acc:  1.0
| Epoch:  71 / 200  | Loss:  0.3281386196613312  | Acc:  1.0
| Epoch:  81 / 200  | Loss:  0.3267662525177002  | Acc:  1.0
| Epoch:  91 / 200  | Loss:  0.3255162835121155  | Acc:  1.0
| Epoch:  101 / 200  | Loss:  0.32437586784362793  | Acc:  1.0
tensor(0.9925, device='mps:0')
| Epoch:  111 / 200  | Loss:  0.32334110140800476  | Acc:  1.0
| Epoch:  121 / 200  | Loss:  0.32240742444992065  | Acc:  1.0
| Epoch:  131 / 200  | Loss:  0.32157114148139954  | Acc:  1.0
| Epoch:  141 / 200  | Loss:  0.3208281993865967  | Acc:  1.0
| Ep