In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision import datasets, transforms

In [None]:
# NOTE: work-in-progress (experimental) version

In [82]:
class TensorProjection(nn.Module):
    """Mode-wise orthogonal projection of a 4-D tensor (Tucker-style reduction).

    For an input of shape ``(n, p1, p2, p3)``, every non-batch mode ``i`` whose
    size changes (``p_i != q_i``) is contracted with ``U_i^T``. Here ``U_i`` is
    the ``(p_i, q_i)`` matrix with (approximately) orthonormal columns derived
    from the learnable weight ``W_i``:

        U_i = W_i (W_i^T W_i + eps * I)^{-1/2}

    Modes with ``p_i == q_i`` are passed through unchanged and own no weight.

    Args:
        p1, p2, p3: sizes of input dims 1, 2, 3.
        q1, q2, q3: sizes of output dims 1, 2, 3.
        eps: ridge term added to ``W^T W`` before the inverse square root,
            which keeps the matrix invertible (default 1e-6).

    Output shape: ``(n, q1, q2, q3)``.
    """

    def __init__(self, p1, p2, p3, q1, q2, q3, eps=1e-6):
        super().__init__()
        self.p1, self.p2, self.p3 = p1, p2, p3
        self.q1, self.q2, self.q3 = q1, q2, q3
        self.eps = eps
        # Only modes whose size actually changes get a weight. Each weight is
        # initialised right after it is created, so equal-size modes no longer
        # crash. The names W1/W2/W3 are kept so saved state_dicts still load.
        for name, p, q in (("W1", p1, q1), ("W2", p2, q2), ("W3", p3, q3)):
            if p != q:
                weight = nn.Parameter(torch.empty(p, q))
                nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
                setattr(self, name, weight)

    def sqrtm(self, A):
        """Principal square root of a symmetric positive (semi-)definite matrix.

        For such ``A`` the SVD is ``A = U diag(S) U^T`` (so ``Vh == U^T``).
        ``U diag(sqrt(S)) Vh`` is then its square root. This is not valid for
        general non-symmetric matrices.
        """
        U, S, Vh = torch.linalg.svd(A)
        return U @ torch.diag(torch.sqrt(S)) @ Vh

    def _orthonormalize(self, W):
        """Return ``W (W^T W + eps*I)^{-1/2}``: a (p, q) matrix with ~orthonormal columns."""
        # Build the identity on W's device/dtype so the layer also works on GPU.
        eye = torch.eye(W.shape[1], device=W.device, dtype=W.dtype)
        gram = W.transpose(0, 1) @ W + self.eps * eye
        return W @ torch.linalg.inv(self.sqrtm(gram))

    def forward(self, x):
        """Project ``x`` of shape ``(n, p1, p2, p3)`` to ``(n, q1, q2, q3)``."""
        z = x
        # Each einsum contracts one mode with U_i^T of shape (q_i, p_i). The
        # matrix is broadcast over the batch dim instead of being tiled n times.
        if self.p1 != self.q1:
            z = torch.einsum('npqr,sp->nsqr', z, self._orthonormalize(self.W1).transpose(0, 1))
        if self.p2 != self.q2:
            z = torch.einsum('npqr,sq->npsr', z, self._orthonormalize(self.W2).transpose(0, 1))
        if self.p3 != self.q3:
            z = torch.einsum('npqr,sr->npqs', z, self._orthonormalize(self.W3).transpose(0, 1))
        # einsum may return a permuted (non-contiguous) view. Make the result
        # contiguous so callers can safely .view() it.
        return z.contiguous()

In [83]:
# set hyperparameters
num_epochs = 5
num_batch = 100          # mini-batch size (batch_size of the data loaders)
learning_rate = 0.001
image_size = 28*28       # flattened MNIST image size (passed to Net)
seed = 0                 # fixed seed so weight init and shuffling are reproducible

torch.manual_seed(seed)

# use the GPU when available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [84]:
# create data
transform = transforms.Compose([
    transforms.ToTensor()
])
# MNIST
# https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST
# training split
train_dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)
# validation / test split. download=True makes this cell work on a fresh
# machine; torchvision skips the download when the files already exist.
test_dataset = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=transform,
)

# data loaders: only the training data is shuffled. Evaluation needs no random
# order, and a fixed order keeps evaluation runs reproducible.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=num_batch, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=num_batch, shuffle=False)

In [85]:
# define neural network
class Net(nn.Module):
    """Small CNN: two conv layers, a TensorProjection, and a linear readout.

    The layer sizes below assume single-channel 28x28 inputs.

    Args:
        input_size: not used by any layer (the conv stack fixes the expected
            input size); kept so existing calls such as
            ``Net(image_size, 10)`` keep working.
        output_size: number of classes produced by the final linear layer.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # 1st block: (1,28,28) -conv5-> (16,24,24) -pool2-> (16,12,12)
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        # 2nd block: (16,12,12) -conv5-> (32,8,8)
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        # orthogonal mode-wise projection: (32,8,8) -> (10,4,4)
        self.tensorprojection = TensorProjection(32, 8, 8, 10, 4, 4)
        # readout on the flattened 10*4*4 = 160 features. Previously hard-coded
        # to 10 classes; it now honours output_size.
        self.fc1 = nn.Linear(10 * 4 * 4, output_size)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, output_size)."""
        x = self.maxpool1(self.relu1(self.cnn1(x)))
        x = self.relu2(self.cnn2(x))
        x = self.tensorprojection(x)
        # flatten all non-batch dims; unlike .view this also accepts
        # non-contiguous inputs
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

# define a neural network
model = Net(image_size, 10).to(device)

# Net.forward already returns log-probabilities (log_softmax), so the matching
# criterion is NLLLoss. CrossEntropyLoss would apply log_softmax a second time,
# which gives the same value but is redundant work.
criterion = nn.NLLLoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# training mode (the returned module repr is displayed as the cell output)
model.train()

Net(
  (cnn1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (tensorprojection): TensorProjection()
  (fc1): Linear(in_features=160, out_features=10, bias=True)
)

In [86]:
# Enter training mode explicitly. This cell switches the model to eval mode at
# the end, so re-running it would otherwise train in eval mode.
model.train()

for epoch in range(num_epochs):
    loss_sum = 0.0

    for inputs, labels in train_dataloader:
        # move the batch to the selected device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # clear gradients from the previous step
        optimizer.zero_grad()

        # forward pass
        outputs = model(inputs)

        # compute loss. Accumulate a Python float: summing the loss tensor
        # itself would chain autograd nodes for every batch of the epoch.
        loss = criterion(outputs, labels)
        loss_sum += loss.item()

        # backpropagate and update the parameters
        loss.backward()
        optimizer.step()

    # mean of the per-batch losses for this epoch
    print(f"Epoch: {epoch+1}/{num_epochs}, Loss: {loss_sum / len(train_dataloader)}")

    # checkpoint the weights after every epoch
    torch.save(model.state_dict(), 'model_weights.pth')

# evaluation
model.eval()

loss_sum = 0
correct = 0

Epoch: 1/5, Loss: 0.3957036844889323
Epoch: 2/5, Loss: 0.11463255564371745
Epoch: 3/5, Loss: 0.07972813924153646
Epoch: 4/5, Loss: 0.06644113540649414
Epoch: 5/5, Loss: 0.05648375193277995


In [87]:
# Switch to eval mode and reset the accumulators in this cell, so re-running it
# does not double-count `correct` or reuse a stale `loss_sum`.
model.eval()
loss_sum = 0.0
correct = 0

with torch.no_grad():
    for inputs, labels in test_dataloader:
        # move the batch to the selected device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = model(inputs)

        # per-batch loss, accumulated as a Python float
        loss_sum += criterion(outputs, labels).item()

        # predicted class = index of the largest output per sample
        pred = outputs.argmax(1)
        correct += pred.eq(labels.view_as(pred)).sum().item()

print(f"Loss: {loss_sum / len(test_dataloader)}, Accuracy: {100*correct/len(test_dataset)}% ({correct}/{len(test_dataset)})")

Loss: 0.04603756427764893, Accuracy: 98.52% (9852/10000)
