## Packages

In [1]:
import torch.nn as nn
import torch

In [2]:
import torch

if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    
    for gpu_id in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_id)
        print(f"GPU {gpu_id}: {gpu_name}")
else:
    print("No GPUs available on this system.")


Number of GPUs available: 1
GPU 0: NVIDIA GeForce RTX 4080


## Internal Nodes and Tree Nodes

In [3]:
class Internal_Node():
    
    def __init__(self,depth):
        
        
        self.filter=nn.Linear(28*28,1, bias=True)
        
        #self.beta=nn.Parameter(torch.randn(1))
        
        self.leaf=False
        
        self.depth=depth
        self.alpha=None
        
        
    def proceed(self,x):
        
        #let batch (64,784) go through nn.Linear and get (64,1)
        x=self.filter(x)
        
        #inverse temperature sharpen the direction prob
        #x=x*self.beta
        
        #apply sigmoid to get going right probability (64,1) to (64,1)
        right_prob=torch.sigmoid(x)
        
        
        return right_prob
        
    

In [4]:
class Leaf_Node():
    def __init__(self,argument):
        
        
        self.num_classes=argument.number_classes
        
        self.batch_size=argument.batch_size
        self.leaf=True
        
        #initialize the parameter for outputing softmax probability
        self.param=torch.randn(1,self.num_classes)
        self.param=nn.Parameter(self.param)
        
    def partial_proceed(self):
        
        #for testing purpose
        distribution=torch.softmax(self.param, dim=1)
        
        return distribution
        
    def proceed(self):        
        
        distribution=torch.softmax(self.param, dim=1)
        
        #duplicate along the first dimention (1,10) to (64,10)
        
        distribution=torch.cat([distribution]*self.batch_size,dim=0)
        
        
        return distribution
        

## Construction of Path_dict and Node_dict
Path_prob_dict and Node dict

In [5]:
def path_dict_creation(tree_depth):
    path_dict={}
    
    for i in range(2**(tree_depth-1)):
    
        #useful for iteratively create Path_Names
        path_name=f"path_{i+1}"
    
        path_list=[]
    
        #initial path
        for k in range(tree_depth):
            path_list.append((k+1,1))
    
        #put all node_coordination into a list for counting purpose
        all_tuples = [item for sublist in path_dict.values() for item in sublist]
    
    
        #For iterative change the node coordination based on number of occurrence in the path_dict
        for i in range(tree_depth):
        
            #if different depth's node appear certain times, change it, use while to do it iteratively
            while all_tuples.count(path_list[-(i+1)])==2**i:
            
                #continue addition if there is repetitive occurrence
                path_list[-(i+1)]=(path_list[-(i+1)][0],path_list[-(i+1)][1]+1)
    
        path_dict[path_name]=path_list
    
    
    return path_dict

In [6]:
path_dict_4=path_dict_creation(tree_depth=4)
path_dict_5=path_dict_creation(tree_depth=5)
path_dict_6=path_dict_creation(tree_depth=6)

path_dict_4

{'path_1': [(1, 1), (2, 1), (3, 1), (4, 1)],
 'path_2': [(1, 1), (2, 1), (3, 1), (4, 2)],
 'path_3': [(1, 1), (2, 1), (3, 2), (4, 3)],
 'path_4': [(1, 1), (2, 1), (3, 2), (4, 4)],
 'path_5': [(1, 1), (2, 2), (3, 3), (4, 5)],
 'path_6': [(1, 1), (2, 2), (3, 3), (4, 6)],
 'path_7': [(1, 1), (2, 2), (3, 4), (4, 7)],
 'path_8': [(1, 1), (2, 2), (3, 4), (4, 8)]}

In [7]:
def node_dict_creation(path_dict):
    node_set=set()

    node_dict={}
    for value in path_dict.values():
    
        value=tuple(value)
    
        node_dict[value]=[]
    
    #iterate through key of node dict(value pf path_dict) (path nodes are going to take)
    for key in node_dict.keys():
        
        #add nodes according to path
        for i in range(len(key)):
            
            #avoid repetitive creation of internal nodes
            if key[i] not in node_set:
                
                #for adding internal nodes
                if i <len(key)-1:
                    
                    #attach a depth for regularization, key[i] is the coordination of the node
                    #k[i][0] is the depth (3,1) - depth=3
                    
                    node=Internal_Node(depth=key[i][0])
                    
                    node_dict[key].append(node)
                    
                    #add node's coordination into set to keep a record
                    node_set.add(key[i])
            
            #if already exists
            else:
                
                node_dict[key].append(node_dict[old_key][i])
            
            if i ==len(key)-1:
                
                node=Leaf_Node(argu)
                node_dict[key].append(node)
        
        # for sharing node purpose
        old_key=key
                    
                    
                    
        
    return node_dict


            

In [8]:
def get_dict(tree_depth):
    
    path_dict=path_dict_creation(tree_depth)
    
    node_dict=node_dict_creation(path_dict)
    return node_dict

## Tree models

In [9]:
class Arguments():
    def __init__(self):
        
        
        self.number_classes=10
        self.epochs=30
        self.cuda=True
        
        
        
        
        self.batch_size=128
        self.lr=0.012
        self.momentum=0.9
        #you considered the leaves as tree depth as well, remember the difference (+1)
        self.tree_depth=9
        self.lmbda=0.02
        
argu=Arguments() 

In [10]:
node_dict=get_dict(argu.tree_depth)

In [11]:
import torchvision
import torchvision.transforms as transforms

# Define data transformations to preprocess the images
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Download and load the training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=argu.batch_size, shuffle=True, drop_last=True)

# Download and load the testing dataset
testset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=argu.batch_size, shuffle=False, drop_last=True)


In [12]:
import torch.optim as optim

class Nt(nn.Module):
    def __init__(self,arguments,node_dict):
        super(Nt,self).__init__()

        self.args=arguments
        self.lmbda=self.args.lmbda
        
        
        self.batch_size=self.args.batch_size
        
        self.epochs=self.args.epochs
        self.module_list=nn.ModuleList()
        self.param_list=nn.ParameterList()
        
        
        #rewrite once automated
        self.node_dict=node_dict
        
        
        self.collect_parameters()
        self.loss=nn.CrossEntropyLoss()
        
        
        self.optimizer = optim.SGD(self.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        
        
        
    def collect_parameters(self):
        
        node_set=set()
        for value in self.node_dict.values():
            
            for i in range(len(value)):
                if value[i] not in node_set:
                    if value[i].leaf==False:
                        self.module_list.append(value[i].filter)
                        #self.param_list.append(value[i].beta)
                        
                        node_set.add(value[i])
                    else:
                        self.param_list.append(value[i].param)

    def predict(self,x):
        
        output_=0
        for key,value in self.node_dict.items():
    
   
            leaf_prob=value[-1].proceed()
    
    
            path_prob=torch.ones(self.batch_size,1)
    
            #iterate through nodes in the path
            for i in range(len(key)-1):
        
                going_right=value[i].proceed(x)
                going_left=1-going_right
                
                
                #For computing alpha
                
                if self.args.cuda==True:
                    path_prob=path_prob.cuda()
                    going_right=going_right.cuda()
                
                
                value[i].alpha=torch.sum(path_prob*going_right)/torch.sum(path_prob)
                
        
                #condition for going right
                if key[i+1][1]/key[i][1]==2:
            
                    path_prob=path_prob*going_right
            
                else:
                    
                    
                    path_prob=path_prob*going_left
    
            #aggregate prediction
            output_+=path_prob*leaf_prob

        return output_ 
    
    def train_(self,train_loader,test_loader):
        self.train()
        for i in range(self.epochs):
            
            print("Epoch:{}".format(i+1))
            
            
            total_loss=0
            
            batch=0
            total_samples=0
            correct_predictions=0
            
            #Tracking how loss is progressing
            avg_loss=0
            avg_regularization=0
            
            for batch_idx, (data, label) in enumerate(train_loader):
                
                batch+=1
            
                if self.args.cuda==True:
                    data, label=data.cuda(),label.cuda()
                    self.cuda()
                    
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                
                
               
                #change (batch,1,28,28) to (batch, 28*28)
                data=data.view(self.batch_size,-1)
                
                
                
                #forward_pass
                output=self.predict(data)
               
            
                #For accuracy purpose
                
                ##max_value, index of predicted label,     two position output
                _, predicted_label = torch.max(output, dim=1, keepdim=True)
                
                correct_predictions += (predicted_label == label.view_as(predicted_label)).sum().item()
               
                total_samples += label.size(0)
            
            
            
            
            
                #updating process
                self.optimizer.zero_grad()
                
                loss=self.loss(output,label)
                
                
                # Regularizers
                regularization=0
                
                #iterate through all nodes
                for value in node_dict.values():
                    
                    for k in range(len(value)):
                        #only for internal nodes
                        if value[k].leaf==False:
                            
                            regularization-=self.lmbda*2**(-value[k].depth)*0.5*(torch.log(value[k].alpha)+torch.log(1-value[k].alpha))
                        
                 #Tracking how loss is progressing
                avg_loss+=loss
                avg_regularization+=regularization
                
                loss=loss+regularization
                
                
                
                loss.backward(retain_graph=True)
                
                self.optimizer.step()
                
                total_loss+=loss.item()
                
                
                
                # verbose
                if batch%200==0:
                    
                    print("Batch: {}   Loss: {}   Accumulated Accuracy in this epoch:{:.2f}%"
                          .format(batch, total_loss/batch,correct_predictions/total_samples*100))
                    
                    print("\nLoss_normal: {}  Regularization: {}".format( avg_loss/batch, avg_regularization/batch ))
        
            print("\nOne Epoch done, check for overfitting")
            self.test_(test_loader)
        
                    
        return self.node_dict
    
    
    #evaluating on test dataset
    def test_(self,test_loader):
        self.eval()
        
        total_samples=0
        correct_predictions=0
        
        
        for batch_idx, (data, label) in enumerate(test_loader):
            
            if self.args.cuda==True:
                    data, label=data.cuda(),label.cuda()
                    self.cuda() 
            
            data=data.view(self.batch_size,-1)
            
            #predict
            output=self.predict(data)
               
            
            #For accuracy purpose
            ##max_value, index of predicted label
            _, predicted_label = torch.max(output, dim=1, keepdim=True)
                
            correct_predictions += (predicted_label == label.view_as(predicted_label)).sum().item()
               
            total_samples += label.size(0)
            
        print("Testset Accuracy: {:.2f}%\n".format(correct_predictions/total_samples*100))
        
        
        

## Testing Code

In [13]:
tre=Nt(argu,node_dict)

In [394]:
dd_ict=tre.train_(train_loader=trainloader,test_loader=testloader)

Epoch:1
Batch: 200   Loss: 5.294553232192993   Accumulated Accuracy in this epoch:94.40%

Loss_normal: 1.6316702365875244  Regularization: 3.6628828048706055
Batch: 400   Loss: 5.293789571523666   Accumulated Accuracy in this epoch:94.50%

Loss_normal: 1.6295607089996338  Regularization: 3.6642282009124756

One Epoch done, check for overfitting
Testset Accuracy: 94.20%

Epoch:2
Batch: 200   Loss: 5.295295097827911   Accumulated Accuracy in this epoch:94.56%

Loss_normal: 1.6276841163635254  Regularization: 3.6676108837127686
Batch: 400   Loss: 5.2934327316284175   Accumulated Accuracy in this epoch:94.56%

Loss_normal: 1.6282000541687012  Regularization: 3.6652331352233887

One Epoch done, check for overfitting
Testset Accuracy: 94.30%

Epoch:3
Batch: 200   Loss: 5.293015840053559   Accumulated Accuracy in this epoch:94.64%

Loss_normal: 1.6271252632141113  Regularization: 3.6658923625946045
Batch: 400   Loss: 5.291593989133835   Accumulated Accuracy in this epoch:94.64%

Loss_normal: 

Batch: 200   Loss: 5.260742292404175   Accumulated Accuracy in this epoch:95.50%

Loss_normal: 1.600421667098999  Regularization: 3.660320997238159
Batch: 400   Loss: 5.262895109653473   Accumulated Accuracy in this epoch:95.43%

Loss_normal: 1.601212978363037  Regularization: 3.6616814136505127

One Epoch done, check for overfitting
Testset Accuracy: 94.93%

Epoch:24
Batch: 200   Loss: 5.264065220355987   Accumulated Accuracy in this epoch:95.42%

Loss_normal: 1.6015303134918213  Regularization: 3.6625335216522217
Batch: 400   Loss: 5.263441658020019   Accumulated Accuracy in this epoch:95.44%

Loss_normal: 1.6009736061096191  Regularization: 3.6624667644500732

One Epoch done, check for overfitting
Testset Accuracy: 94.84%

Epoch:25
Batch: 200   Loss: 5.262025058269501   Accumulated Accuracy in this epoch:95.47%

Loss_normal: 1.5994606018066406  Regularization: 3.662564277648926
Batch: 400   Loss: 5.262692667245865   Accumulated Accuracy in this epoch:95.46%

Loss_normal: 1.600260734

In [14]:
tre.test_(testloader)

Testset Accuracy: 10.15%



## Check number of parameters in the tree

In [16]:
from prettytable import PrettyTable
# Initialize table
table = PrettyTable(["Layer Name", "Number of Parameters"])

total_params = 0

for name, parameter in tre.named_parameters():
    if not parameter.requires_grad: continue
    param = parameter.numel()
    table.add_row([name, param])
    total_params += param

table.add_row(["Total", total_params])

print(table)

+------------------------+----------------------+
|       Layer Name       | Number of Parameters |
+------------------------+----------------------+
|  module_list.0.weight  |         784          |
|   module_list.0.bias   |          1           |
|  module_list.1.weight  |         784          |
|   module_list.1.bias   |          1           |
|  module_list.2.weight  |         784          |
|   module_list.2.bias   |          1           |
|  module_list.3.weight  |         784          |
|   module_list.3.bias   |          1           |
|  module_list.4.weight  |         784          |
|   module_list.4.bias   |          1           |
|  module_list.5.weight  |         784          |
|   module_list.5.bias   |          1           |
|  module_list.6.weight  |         784          |
|   module_list.6.bias   |          1           |
|  module_list.7.weight  |         784          |
|   module_list.7.bias   |          1           |
|  module_list.8.weight  |         784          |
