In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
import nflib
from nflib.flows import SequentialFlow, NormalizingFlow, ActNorm, AffineConstantFlow
import nflib.coupling_flows as icf
import nflib.inn_flow as inn
import nflib.res_flow as irf

In [3]:
from torch import distributions
from torch.distributions import MultivariateNormal

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [5]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [10]:
class Subset_Dataset(data.Dataset):
    
    def __init__(self, dataset, index):
        self.dataset = dataset
        self.index = index
        
    def __len__(self):
        return len(self.index)
    
    def __getitem__(self, idx):
        idx = self.index[idx]
        img, lbl = self.dataset[idx]
        return img, lbl

In [11]:
## The classifiers store all the data in INDEX FORM

In [12]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [13]:
train_dataset[[0,1,2]]

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([8, 0, 6]))

In [14]:
class ConnectedClassifier_SoftKMeans(nn.Module):
    
    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)
        
        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        
#         self.cls_weight = nn.Parameter(torch.ones(num_sets, output_dim)/output_dim)
        init_val = torch.randn(num_sets, output_dim)*0.01
        for ns in range(num_sets):
            init_val[ns, ns%output_dim] = 2.
#             init_val[ns, 0] = 2. ### initialize same class in all clusters

        self.cls_weight = nn.Parameter(init_val)

        self.cls_confidence = None
        
        
    def forward(self, x, hard=False):
        self.cls_weight.data = torch.abs(self.cls_weight.data)/self.cls_weight.data.sum(dim=1, keepdim=True)
        
        x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        dists = dists/np.sqrt(self.input_dim) ### correction to make diagonal of unit square 1 in nD space
        
        if hard:
            x = torch.softmax(-dists*1e5, dim=1)
        else:
            x = torch.softmax(-dists*self.inv_temp, dim=1)
        self.cls_confidence = x
#         c = torch.softmax(self.cls_weight, dim=1)
        c = self.cls_weight
        return x@c ## since both are normalized, it is also normalized

In [15]:
# class ClassifierTree:
    
#     def __init__(self, train_data, test_data, device):
#         self.depthwise_classifiers = []
        
#         self.train_data = train_data
#         self.test_data = test_data
#         self.device = device
#         pass
    
#     def create_new_depth(self, x, hidden_dims):
#         output_dim = 10
#         num_classifier = 10
#         if len(self.depthwise_classifiers) == 0:
#             classifier = LocalClassifier(784, hidden_dims, output_dim, num_classifier, self.device)

#             ### to initialize to all classes
#             init_val = torch.randn(num_sets, output_dim)*0.01
#             for ns in range(num_sets):
#                 init_val[ns, ns%output_dim] = 2.

#             classifier.classifier.cls_weight.data = init_val
            
#         else:
#             for 
    
    
#     def inference_forward(self, x): ## inference for one data
#         x = torch.unsqueeze(0)
#         classifier_indx = 0 ## in first depth, there is only single classifier.
#         for i in range(len(self.depthwise_classifiers)):
#             classifier = self.depthwise_classifiers[i][classifier_indx]
#             classifier_indx = classifier.inference_forward(x)[0]
#             if i < len(self.depthwise_classifiers):
#                 if self.depthwise_classifiers[i+1][classifier_indx] is not None:
#                     continue
#             return classifier.prediction_stat[classifier_indx]

In [114]:
class ClassifierTree:
    
    def __init__(self, train_data, test_data, device):
        self.root = LocalClassifier(device)
        self.root.create_network_0(784, [784], 10, 10)
        
        self.train_data = train_data
        self.test_data = test_data
        self.device = device
        pass
    
    def display_stats(self):
        indexing = "0"
        self.root.display_stats(indexing)
        acc, tot = self.root.get_correct_train()
        train_acc = acc/tot
        acc, tot = self.root.get_correct_test()
        test_acc = acc/tot
        print(f"Final Accuracy is Train: {train_acc :.5f} Test: {test_acc :.5f}")
            
    def get_parent_node(self, index_list:list):
        parent = self.root
        index_list = index_list[1:]
        for idx in index_list[:-1]:
            parent = parent.children[idx]
        return parent
    
    def get_node(self, index_list:list):
        parent = self.root
        index_list = index_list[1:]
        for idx in index_list[:-1]:
            parent = parent.children[idx]
        child = parent.children[index_list[-1]]
        return child
    
    
    def get_all_child_index(self):
        child_list = []
        self.root.get_all_index([0], child_list)
        return child_list

In [115]:
class LeafNode:
    def __init__(self):
        self.pred = None
        self.classes = None
        self.num_correct = None
        self.train_indices = None
        self.test_indices = None
        self.test_correct = None
        
    def display_stats(self, indexing):
        print(f"[{indexing}] : Train -> {self.num_correct/len(self.train_indices) :.4f}", end=" ")
        
        if len(self.test_indices)>0:
            test_acc = self.test_correct/len(self.test_indices)
        else:
            test_acc = -1
        print(f"Test -> {test_acc :.4f}, NUM: {len(self.train_indices)}, classes: {self.pred}:{self.classes}")

    def get_correct_train(self):
        return self.num_correct, len(self.train_indices)
    
    def get_correct_test(self):
        return self.test_correct, len(self.test_indices)
    
    def get_all_index(self, indexing, indx_lst):
        indx_lst.append(indexing)
                
        
    

class LocalClassifier:
    
    def __init__(self, device):
        self.model = None
        self.classifier = None
        self.device = device
        
        ### for training purposes
        self.train_loader = None
        self.test_loader = None
        self.optimizer = None
        self.frozen = False
        self.criterion = None
        
        ### after freazing the model, record stats
        self.children = []
    
    def create_network_0(self, input_dim, hidden_dims:list, output_dims, num_classifiers):
        actf = irf.Swish
        flows = []
        flows.append(ActNorm(input_dim))
        for i in range(len(hidden_dims)):
            if isinstance(hidden_dims[i], list):
                hdi = hidden_dims[i]
            else:
                hdi = [hidden_dims[i]]
            flows.append(irf.ResidualFlow(input_dim, hdi, activation=actf))
            flows.append(ActNorm(input_dim))
        
        invertible = SequentialFlow(flows)
        self.model = invertible.to(device)
        
        classifier = ConnectedClassifier_SoftKMeans(784, num_classifiers, output_dims)
        self.classifier = classifier.to(device)
        
    def create_train_loader_1(self, train_dataset, index, batch_size):
        dataset = Subset_Dataset(train_dataset, index)
        print(f"Train Dataset Num: {len(index)}")
        self.train_loader = data.DataLoader(dataset=dataset,
                                            num_workers=4, 
                                            batch_size=batch_size, 
                                            shuffle=True)
    
    def create_test_loader_2(self, test_dataset, index, batch_size):
        dataset = Subset_Dataset(test_dataset, index)
        print(f"Test Dataset Num: {len(index)}")
        self.test_loader = data.DataLoader(dataset=dataset,
                                            num_workers=4, 
                                            batch_size=batch_size, 
                                            shuffle=False)
        
    def create_optimizer_3(self, lr):
        self.optimizer = optim.Adam(list(self.model.parameters())+list(self.classifier.parameters()), 
                                    lr=lr, weight_decay=1e-15)
        self.criterion = nn.CrossEntropyLoss()
        
        
    def train_classifier_4(self, epochs, ):
        if self.frozen:
            raise ValueError("This classifier is frozen. Training it might cause errors in childern classifiers")
            
    ############# TRAINING FUNCTIONALITY BELOW ####################    

        index = 0
        train_accs, test_accs = [], []
        for epoch in range(epochs):
            train_acc = 0
            train_count = 0
            for xx, yy in tqdm(self.train_loader):
                xx, yy = xx.to(device), yy.to(device)
#                 print(xx)
                yout = self.model(xx)
#                 print(yout, torch.count_nonzero(torch.isnan(yout)))
                yout = self.classifier(yout)    
#                 print(yout, torch.count_nonzero(torch.isnan(yout)))
                loss = self.criterion(yout, yy)
#                 print(loss)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
#                 losses.append(float(loss))

                outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
                correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
                train_acc += correct
                train_count += len(outputs)

            train_accs.append(float(train_acc)/train_count*100)
            train_acc = 0
            train_count = 0

            print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
            test_count = 0
            test_acc = 0
            for xx, yy in tqdm(self.test_loader):
                xx, yy = xx.to(device), yy.to(device)
                with torch.no_grad():
                    yout = self.classifier(self.model(xx))    
                outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
                correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
                test_acc += correct
                test_count += len(xx)
            test_accs.append(float(test_acc)/test_count*100)
            print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
            print()

        ### after each class index is finished training
        print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')
        
        
    def freeze_and_compute_stats_5(self, MIN_POINTS):
        assert MIN_POINTS > 0
        if self.frozen:
            raise ValueError("This classifier is frozen. The stat has already been calculated")
            
        self.frozen = True
        
        ### delete optimizer, frees memory
        del self.optimizer
        ### take classifier to eval mode
        self.model.eval()
        self.classifier.eval()
        
        with torch.no_grad():
            #### remove classifier with no data or few data.

            set_count = torch.zeros(self.classifier.num_sets).to(device)
            for xx, yy in tqdm(self.train_loader):
                xx, yy = xx.to(device), yy.to(device)
                with torch.no_grad():
                    yout = self.classifier(self.model(xx), hard=True)

                cls_indx = torch.argmax(self.classifier.cls_confidence, dim=1)
                set_indx, count = torch.unique(cls_indx, return_counts=True) 
                set_count[set_indx] += count

            #### find only the classifier having some data
            classifier_index = []
            classifier_count = []

            for i, cnt in enumerate(set_count.type(torch.long).tolist()):
    #             if cnt == 0: continue
                if cnt < MIN_POINTS: continue

                classifier_index.append(i)
                classifier_count.append(int(cnt))

            #### remove the classifier representing no data
            #### OR representing data less than given N
            print(f"Keeping only N={len(classifier_index)}/{len(self.classifier.centers)} classifiers.")
            self.classifier.centers.data = self.classifier.centers.data[classifier_index]
            self.classifier.cls_weight.data = self.classifier.cls_weight.data[classifier_index]
            ### removed

            ###### compute stats now, from pruned tree.
            def get_Cs_Os_Ts(data_loader):
                Cs = [] ## winning classifier
                Os = [] ## output of winning classifier
                Ts = [] ## target class

                for xx, yy in tqdm(data_loader):
                    Ts.append(yy)
                    xx, yy = xx.to(device), yy.to(device)
                    with torch.no_grad():
                        zz = self.model(xx)
                        yout = self.classifier(zz, hard=True)
                        Os.append(torch.argmax(yout, dim=1).data.cpu())

                    cls_indx = torch.argmax(self.classifier.cls_confidence, dim=1)
                    Cs.append(cls_indx)

                Cs = torch.cat(Cs, dim=0)
                Ts = torch.cat(Ts, dim=0)
                Os = torch.cat(Os, dim=0)
                return Cs, Ts, Os


            unshuffled_data = data.DataLoader(dataset=self.train_loader.dataset,
                                                num_workers=4, 
                                                batch_size=self.train_loader.batch_size, 
                                                shuffle=False)
            Cs, Ts, Os = get_Cs_Os_Ts(unshuffled_data)
            _Cs, _Ts, _ = get_Cs_Os_Ts(self.test_loader)

            print("Hard inference on the data !")
            self.children = []
            acc = 0
            for cls_idx in range(len(self.classifier.centers)):
                data_idx = torch.nonzero(Cs == cls_idx)
                Ti = Ts[data_idx]

                ### get prediction according to data
                cls, count = torch.unique(Ti, return_counts=True, sorted=True)
                pred = cls[torch.argmax(count)]
                p = (Ti==pred).type(torch.float32).sum()
                acc += p

                child = LeafNode()
                child.pred = int(pred)
                child.classes = cls.tolist()
                child.num_correct = int(p)
                child.train_indices = data_idx.cpu().reshape(-1)
                
                test_idx = torch.nonzero(_Cs == cls_idx)
                test_p = (_Ts[test_idx]==pred).type(torch.float32).sum()
                child.test_indices = test_idx.cpu().reshape(-1)
                child.test_correct = int(test_p)

                self.children.append(child)

                print(f"idx: {cls_idx}\tout: {int(pred)} \t acc: {p/len(Ti)*100 :.3f} \tclasses:{cls.tolist()}")

            print(f"Accuracy: {float(acc)/len(Ts)}")
        
        
    def display_stats(self, indexing):
        for i, c in enumerate(self.children):
            c.display_stats(indexing+f", {i}")
            
    def get_all_index(self, indexing:list, indx_lst):
        for i, c in enumerate(self.children):
            c.get_all_index(indexing+[i], indx_lst)
        pass

    def get_correct_train(self):
        a, b = 0, 0
        for i, c in enumerate(self.children):
            _a, _b = c.get_correct_train()
            a+= _a
            b+= _b
        return a, b
    
    def get_correct_test(self):
        a, b = 0, 0
        for i, c in enumerate(self.children):
            _a, _b = c.get_correct_test()
            a+= _a
            b+= _b
        return a, b
        
    def inference_forward(self, x):
        with torch.no_grad():
            zz = self.model(x)
            yout = self.classifier(zz, hard=True)
            return torch.argmax(classifier.cls_confidence, dim=1)

In [116]:
tree = ClassifierTree(train_dataset, test_dataset, device)

In [117]:
tree.root.create_network_0(784, [784], 10, 10)

In [118]:
tree.root.create_train_loader_1(train_dataset, 
                                torch.arange(0, len(train_dataset), dtype=torch.long), 
                                50)

Train Dataset Num: 60000


In [119]:
tree.root.create_test_loader_2(test_dataset, 
                               torch.arange(0, len(test_dataset), dtype=torch.long), 
                               50)

Test Dataset Num: 10000


In [120]:
tree.root.create_optimizer_3(lr=0.0003)

In [121]:
# list(tree.root.classifier.parameters())

In [122]:
# tree.root.model(torch.randn(10, 784).to(device))

In [123]:
tree.root.train_classifier_4(1)

100%|██████████| 1200/1200 [00:04<00:00, 267.31it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:2.227100133895874


100%|██████████| 200/200 [00:00<00:00, 431.08it/s]

Train Acc:67.65%, Test Acc:68.96%

	-> MAX Train Acc 67.65 ; Test Acc 68.96





In [124]:
# asdfsdfasdf

In [125]:
# torch.isinf(tree.root.model.flows[0].s)

In [126]:
tree.root.freeze_and_compute_stats_5(MIN_POINTS=5)

100%|██████████| 1200/1200 [00:01<00:00, 693.72it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Keeping only N=10/10 classifiers.


100%|██████████| 1200/1200 [00:01<00:00, 719.26it/s]
100%|██████████| 200/200 [00:00<00:00, 527.08it/s]

Hard inference on the data !
idx: 0	out: 0 	 acc: 72.764 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 1	out: 1 	 acc: 90.466 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 2	out: 2 	 acc: 54.914 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 3	out: 3 	 acc: 70.162 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 4	out: 4 	 acc: 52.369 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 5	out: 5 	 acc: 66.358 	classes:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
idx: 6	out: 6 	 acc: 69.266 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 7	out: 7 	 acc: 70.514 	classes:[0, 2, 3, 4, 5, 6, 7, 8, 9]
idx: 8	out: 8 	 acc: 87.261 	classes:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
idx: 9	out: 9 	 acc: 73.111 	classes:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Accuracy: 0.6995833333333333





In [127]:
tree.display_stats()

[0, 0] : Train -> 0.7276 Test -> 0.7121, NUM: 6385, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 1] : Train -> 0.9047 Test -> 0.9049, NUM: 6251, classes: 1:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 2] : Train -> 0.5491 Test -> 0.5439, NUM: 7204, classes: 2:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 3] : Train -> 0.7016 Test -> 0.6900, NUM: 7142, classes: 3:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 4] : Train -> 0.5237 Test -> 0.4956, NUM: 7367, classes: 4:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 5] : Train -> 0.6636 Test -> 0.6917, NUM: 5291, classes: 5:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 6] : Train -> 0.6927 Test -> 0.6486, NUM: 872, classes: 6:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 7] : Train -> 0.7051 Test -> 0.7077, NUM: 7244, classes: 7:[0, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 8] : Train -> 0.8726 Test -> 0.8446, NUM: 4765, classes: 8:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 9] : Train -> 0.7311 Test -> 0.7213, NUM: 7479, classes: 9:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Final Accuracy is Train: 0.69958 Test: 0.69010


In [128]:
node = tree.get_node([0, 0])

In [129]:
len(node.train_indices)

6385

In [141]:
tree.get_all_child_index()

[[0, 0, 0],
 [0, 0, 1],
 [0, 0, 2],
 [0, 0, 3],
 [0, 0, 4],
 [0, 0, 5],
 [0, 1],
 [0, 2],
 [0, 3],
 [0, 4],
 [0, 5],
 [0, 6],
 [0, 7],
 [0, 8],
 [0, 9]]

In [73]:
asdfsadf

NameError: name 'asdfsadf' is not defined

### Select a node and train a new classifier

In [142]:
### make selecting index automatic
###### for all child index analyse how much incorrect training examples are present

### 1. train the cild with maximum incorrect ones.
### 2. train the network for at least 30,000 steps:
###     for a 5K dataset with 50 batch size, 100 steps in one epoch, so train for 300 epochs

MIN_POINTS = 10
batch_size = 50

max_incorrect = 0
max_inc_node = None
train_epoch = None
for ci in tree.get_all_child_index():
    node = tree.get_node(ci)
    num_data = len(node.train_indices)
    if num_data < MIN_POINTS: continue
        
    incorrect = num_data - node.num_correct
    if incorrect > max_incorrect:
        max_incorrect = incorrect
        max_inc_node = ci
        steps_in_epoch = max(num_data/batch_size, 1)
        train_epoch = int(30000/steps_in_epoch)
        
print(f"Max incorrect: {max_incorrect}, {max_inc_node}, train for: {train_epoch}")

Max incorrect: 3509, [0, 4], train for: 203


In [145]:
# indx = [0, 0]
indx = max_inc_node
parent = tree.get_parent_node(indx)
node = tree.get_node(indx)

In [146]:
# node.train_indices, node.test_indices
node.pred, node.classes

(4, [0, 1, 2, 3, 4, 5, 6, 8, 9])

In [133]:
alt_node = LocalClassifier(device)

In [134]:
### make classifier with only available classes
avl_cls = node.classes
num_cls = len(node.classes)
output_dim = 10
num_sets = num_cls*2
init_val = torch.randn(num_sets, output_dim)*0.01
for ns in range(num_sets):
    init_val[ns, avl_cls[ns%num_cls]] = 2.

alt_node.create_network_0(784, [784], output_dim, num_sets)
alt_node.classifier.cls_weight.data = init_val.to(device)

alt_node.create_train_loader_1(train_dataset, node.train_indices, batch_size=50)
alt_node.create_test_loader_2(test_dataset, node.test_indices, batch_size=50)
alt_node.create_optimizer_3(lr=0.0003)

Train Dataset Num: 6385
Test Dataset Num: 1063


In [135]:
node.display_stats("")

[] : Train -> 0.7276 Test -> 0.7121, NUM: 6385, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]


In [136]:
alt_node.train_classifier_4(3)

100%|██████████| 128/128 [00:00<00:00, 206.30it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:2.2397239208221436


100%|██████████| 22/22 [00:00<00:00, 144.51it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

Train Acc:69.49%, Test Acc:71.50%



100%|██████████| 128/128 [00:00<00:00, 221.56it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 1:0,  Loss:2.206883192062378


100%|██████████| 22/22 [00:00<00:00, 144.99it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

Train Acc:73.08%, Test Acc:71.21%



100%|██████████| 128/128 [00:00<00:00, 213.14it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 2:0,  Loss:2.1750831604003906


100%|██████████| 22/22 [00:00<00:00, 135.46it/s]

Train Acc:72.76%, Test Acc:71.21%

	-> MAX Train Acc 73.07752545027408 ; Test Acc 71.49576669802445





In [137]:
alt_node.freeze_and_compute_stats_5(MIN_POINTS)

100%|██████████| 128/128 [00:00<00:00, 438.08it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

Keeping only N=6/18 classifiers.


100%|██████████| 128/128 [00:00<00:00, 452.79it/s]
100%|██████████| 22/22 [00:00<00:00, 152.27it/s]

Hard inference on the data !
idx: 0	out: 0 	 acc: 73.527 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 1	out: 2 	 acc: 57.143 	classes:[0, 2, 6]
idx: 2	out: 6 	 acc: 62.963 	classes:[0, 2, 6, 8]
idx: 3	out: 0 	 acc: 73.400 	classes:[0, 1, 2, 3, 4, 5, 6, 8, 9]
idx: 4	out: 2 	 acc: 66.667 	classes:[0, 2, 4, 5, 8]
idx: 5	out: 6 	 acc: 77.778 	classes:[0, 6]
Accuracy: 0.7332811276429131





In [138]:
alt_node.display_stats("")

[, 0] : Train -> 0.7353 Test -> 0.6894, NUM: 2410, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[, 1] : Train -> 0.5714 Test -> 0.3333, NUM: 21, classes: 2:[0, 2, 6]
[, 2] : Train -> 0.6296 Test -> 0.5000, NUM: 27, classes: 6:[0, 2, 6, 8]
[, 3] : Train -> 0.7340 Test -> 0.7381, NUM: 3891, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[, 4] : Train -> 0.6667 Test -> 0.6667, NUM: 27, classes: 2:[0, 2, 4, 5, 8]
[, 5] : Train -> 0.7778 Test -> -1.0000, NUM: 9, classes: 6:[0, 6]


In [139]:
### replace the leaf node with Local Classifier Node
parent.children[indx[-1]] = alt_node

In [140]:
#### After modification status
tree.display_stats()

[0, 0, 0] : Train -> 0.7353 Test -> 0.6894, NUM: 2410, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 0, 1] : Train -> 0.5714 Test -> 0.3333, NUM: 21, classes: 2:[0, 2, 6]
[0, 0, 2] : Train -> 0.6296 Test -> 0.5000, NUM: 27, classes: 6:[0, 2, 6, 8]
[0, 0, 3] : Train -> 0.7340 Test -> 0.7381, NUM: 3891, classes: 0:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 0, 4] : Train -> 0.6667 Test -> 0.6667, NUM: 27, classes: 2:[0, 2, 4, 5, 8]
[0, 0, 5] : Train -> 0.7778 Test -> -1.0000, NUM: 9, classes: 6:[0, 6]
[0, 1] : Train -> 0.9047 Test -> 0.9049, NUM: 6251, classes: 1:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 2] : Train -> 0.5491 Test -> 0.5439, NUM: 7204, classes: 2:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 3] : Train -> 0.7016 Test -> 0.6900, NUM: 7142, classes: 3:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 4] : Train -> 0.5237 Test -> 0.4956, NUM: 7367, classes: 4:[0, 1, 2, 3, 4, 5, 6, 8, 9]
[0, 5] : Train -> 0.6636 Test -> 0.6917, NUM: 5291, classes: 5:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 6] : Train -> 0.6927 Test -> 0.6486, NUM: 872, c

In [None]:
[0, 0, 0, 0] : pred -> 0.9799, NUM: 546, classes: [0, 2, 4, 6]
[0, 0, 0, 1] : pred -> 0.9983, NUM: 606, classes: [1, 2]
[0, 0, 0, 2] : pred -> 0.9831, NUM: 592, classes: [0, 2, 4, 6]
[0, 0, 0, 3] : pred -> 0.9788, NUM: 565, classes: [0, 1, 2, 3, 4, 6, 8]
[0, 0, 0, 4] : pred -> 0.9744, NUM: 586, classes: [1, 2, 3, 4, 6]
[0, 0, 0, 5] : pred -> 1.0000, NUM: 589, classes: [5]
[0, 0, 0, 6] : pred -> 0.9898, NUM: 588, classes: [0, 1, 2, 4, 6]
[0, 0, 0, 7] : pred -> 0.9877, NUM: 568, classes: [5, 7, 9]
[0, 0, 0, 8] : pred -> 0.9966, NUM: 592, classes: [0, 8]
[0, 0, 0, 9] : pred -> 0.9947, NUM: 565, classes: [5, 7, 9]
[0, 0, 1, 0] : pred -> 0.9571, NUM: 163, classes: [0, 6]
[0, 0, 1, 1] : pred -> 1.0000, NUM: 149, classes: [1]
[0, 0, 1, 2] : pred -> 0.9940, NUM: 168, classes: [2, 4]
[0, 0, 1, 3] : pred -> 0.9939, NUM: 164, classes: [3, 6]
[0, 0, 1, 4] : pred -> 0.9827, NUM: 173, classes: [2, 4, 6]
[0, 0, 1, 5] : pred -> 1.0000, NUM: 167, classes: [5]
[0, 0, 1, 6] : pred -> 1.0000, NUM: 143, classes: [6]
[0, 0, 1, 7] : pred -> 0.9934, NUM: 151, classes: [7, 8]
[0, 0, 1, 8] : pred -> 0.9864, NUM: 147, classes: [0, 6, 8]
[0, 0, 1, 9] : pred -> 0.9831, NUM: 177, classes: [7, 9]
[0, 1] : pred -> 0.9888, NUM: 5898, classes: [0, 1, 2, 3, 4, 6, 8, 9]
[0, 2, 0, 0] : pred -> 0.9571, NUM: 466, classes: [0, 1, 2, 4, 6, 8]
[0, 2, 0, 1] : pred -> 1.0000, NUM: 476, classes: [1]
[0, 2, 0, 2] : pred -> 0.9876, NUM: 483, classes: [2, 3, 4, 6]
[0, 2, 0, 3] : pred -> 0.9912, NUM: 457, classes: [0, 3, 4, 6]
[0, 2, 0, 4] : pred -> 0.9790, NUM: 477, classes: [2, 3, 4, 6]
[0, 2, 0, 5] : pred -> 1.0000, NUM: 451, classes: [5]
[0, 2, 0, 6] : pred -> 0.9956, NUM: 456, classes: [2, 4, 6]
[0, 2, 0, 7] : pred -> 0.9886, NUM: 440, classes: [5, 7, 8, 9]
[0, 2, 0, 8] : pred -> 0.9883, NUM: 426, classes: [0, 6, 8]
[0, 2, 0, 9] : pred -> 0.9918, NUM: 490, classes: [5, 7, 9]
[0, 2, 1] : pred -> 0.8632, NUM: 541, classes: [0, 1, 3, 4, 5, 6]
[0, 2, 2] : pred -> 0.8750, NUM: 8, classes: [0, 4]
[0, 2, 3] : pred -> 0.8986, NUM: 138, classes: [0, 1, 3, 4, 6, 8]
[0, 2, 4] : pred -> 0.8000, NUM: 5, classes: [0, 4]
[0, 3, 0] : pred -> 0.9524, NUM: 210, classes: [0, 1, 2, 6, 8]
[0, 3, 1] : pred -> 1.0000, NUM: 7, classes: [2]
[0, 3, 2] : pred -> 0.9417, NUM: 5988, classes: [0, 1, 2, 3, 4, 5, 6, 8]
[0, 3, 3, 0] : pred -> 0.9672, NUM: 61, classes: [0, 6]
[0, 3, 3, 1] : pred -> 1.0000, NUM: 56, classes: [1]
[0, 3, 3, 2] : pred -> 1.0000, NUM: 43, classes: [2]
[0, 3, 3, 3] : pred -> 0.9583, NUM: 48, classes: [3, 4, 6]
[0, 3, 3, 4] : pred -> 0.9811, NUM: 53, classes: [2, 4]
[0, 3, 3, 5] : pred -> 1.0000, NUM: 44, classes: [5]
[0, 3, 3, 6] : pred -> 0.9796, NUM: 49, classes: [0, 6]
[0, 3, 3, 7] : pred -> 1.0000, NUM: 53, classes: [7]
[0, 3, 3, 8] : pred -> 1.0000, NUM: 48, classes: [8]
[0, 3, 3, 9] : pred -> 1.0000, NUM: 53, classes: [9]
[0, 4, 0, 0] : pred -> 1.0000, NUM: 103, classes: [0]
[0, 4, 0, 1] : pred -> 1.0000, NUM: 117, classes: [1]
[0, 4, 0, 2] : pred -> 0.9910, NUM: 111, classes: [2, 6]
[0, 4, 0, 3] : pred -> 1.0000, NUM: 112, classes: [3]
[0, 4, 0, 4] : pred -> 1.0000, NUM: 117, classes: [4]
[0, 4, 0, 5] : pred -> 1.0000, NUM: 110, classes: [5]
[0, 4, 0, 6] : pred -> 1.0000, NUM: 119, classes: [6]
[0, 4, 0, 7] : pred -> 0.9831, NUM: 118, classes: [5, 7, 9]
[0, 4, 0, 8] : pred -> 1.0000, NUM: 103, classes: [8]
[0, 4, 0, 9] : pred -> 1.0000, NUM: 124, classes: [9]
[0, 4, 1, 0] : pred -> 0.9752, NUM: 483, classes: [0, 2, 6]
[0, 4, 1, 1] : pred -> 0.9981, NUM: 520, classes: [1, 3]
[0, 4, 1, 2] : pred -> 0.9693, NUM: 521, classes: [0, 2, 4, 6, 8]
[0, 4, 1, 3] : pred -> 0.9838, NUM: 493, classes: [1, 2, 3, 4, 6]
[0, 4, 1, 4] : pred -> 0.9745, NUM: 510, classes: [0, 2, 3, 4, 6]
[0, 4, 1, 5] : pred -> 1.0000, NUM: 497, classes: [5]
[0, 4, 1, 6] : pred -> 0.9917, NUM: 482, classes: [0, 2, 3, 4, 6]
[0, 4, 1, 7] : pred -> 0.9874, NUM: 476, classes: [5, 7, 9]
[0, 4, 1, 8] : pred -> 0.9939, NUM: 493, classes: [0, 6, 8]
[0, 4, 1, 9] : pred -> 0.9941, NUM: 511, classes: [7, 9]
[0, 5] : pred -> 0.9624, NUM: 5902, classes: [0, 1, 2, 3, 4, 5, 7, 8, 9]
[0, 6, 0] : pred -> 0.9314, NUM: 452, classes: [0, 1, 2, 3, 4, 8]
[0, 6, 1, 0] : pred -> 0.9726, NUM: 401, classes: [0, 2, 6]
[0, 6, 1, 1] : pred -> 1.0000, NUM: 403, classes: [1]
[0, 6, 1, 2] : pred -> 0.9851, NUM: 404, classes: [1, 2, 3, 4, 6]
[0, 6, 1, 3] : pred -> 0.9814, NUM: 376, classes: [1, 2, 3, 4, 6]
[0, 6, 1, 4] : pred -> 0.9754, NUM: 407, classes: [2, 3, 4, 6]
[0, 6, 1, 5] : pred -> 1.0000, NUM: 396, classes: [5]
[0, 6, 1, 6] : pred -> 0.9897, NUM: 388, classes: [2, 4, 6, 8]
[0, 6, 1, 7] : pred -> 0.9923, NUM: 391, classes: [5, 7, 8, 9]
[0, 6, 1, 8] : pred -> 0.9972, NUM: 362, classes: [7, 8]
[0, 6, 1, 9] : pred -> 0.9910, NUM: 444, classes: [5, 7, 9]
[0, 6, 2] : pred -> 1.0000, NUM: 11, classes: [8]
[0, 7, 0] : pred -> 1.0000, NUM: 100, classes: [5]
[0, 7, 1] : pred -> 0.9763, NUM: 5877, classes: [5, 7, 8, 9]
[0, 7, 2] : pred -> 0.7500, NUM: 8, classes: [6, 8]
[0, 7, 3] : pred -> 0.9838, NUM: 308, classes: [5, 9]
[0, 8] : pred -> 0.9482, NUM: 6123, classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 9] : pred -> 0.9605, NUM: 5803, classes: [0, 1, 4, 5, 6, 7, 8, 9]
Final Accuracy is : 0.97130