In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import numpy as np
import os
import sys
import torch.nn.functional as F
import torch.optim as optim
import subprocess
import glob
import scipy.cluster
import scipy.spatial
from sklearn.preprocessing import MinMaxScaler
from sklearn import metrics
import argparse
import pandas as pd
from sklearn.model_selection import train_test_split    


In [2]:
def multiClassHingeLoss(logits, labels):
    '''
    MultiClassHingeLoss to match C++ Version - No pytorch internal version
    '''
    flatLogits = torch.reshape(logits, [-1, ])
    labels_ = labels.argmax(dim=1)

    correctId = torch.arange(labels.shape[0]).to(
        logits.device) * labels.shape[1] + labels_
    correctLogit = torch.gather(flatLogits, 0, correctId)

    maxLabel = logits.argmax(dim=1)
    top2, _ = torch.topk(logits, k=2, sorted=True)

    wrongMaxLogit = torch.where((maxLabel == labels_), top2[:, 1], top2[:, 0])

    return torch.mean(F.relu(1. + wrongMaxLogit - correctLogit))


def crossEntropyLoss(logits, labels):
    '''
    Cross Entropy loss for MultiClass case in joint training for
    faster convergence
    '''
    return F.cross_entropy(logits, labels.argmax(dim=1))


def binaryHingeLoss(logits, labels):
    '''
    BinaryHingeLoss to match C++ Version - No pytorch internal version
    '''
    return torch.mean(F.relu(1.0 - (2 * labels - 1) * logits))


def hardThreshold(A: torch.Tensor, s):
    '''
    Hard thresholds and modifies in-palce nn.Parameter A with sparsity s 
    '''
    #PyTorch disallows numpy access/copy to tensors in graph.
    #.detach() creates a new tensor not attached to the graph.
    A_ = A.data.cpu().detach().numpy().ravel()    
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    A_ = A_.reshape(A.shape)
    return torch.tensor(A_, requires_grad=True)

def supportBasedThreshold(dst: torch.Tensor, src: torch.Tensor):
    '''
    zero out entries in dst.data that are zeros in src tensor
    '''
    return copySupport(src, dst.data)

def copySupport(src, dst):
    '''
    zero out entries in dst.data that are zeros in src tensor
    '''
    zeroSupport = (src.view(-1) == 0.0).nonzero()
    dst = dst.reshape(-1)
    dst[zeroSupport] = 0
    dst = dst.reshape(src.shape)
    del zeroSupport
    return dst


def estimateNNZ(A, s, bytesPerVar=4):
    '''
    Returns # of non-zeros and representative size of the tensor
    Uses dense for s >= 0.5 - 4 byte
    Else uses sparse - 8 byte
    '''
    params = 1
    hasSparse = False
    for i in range(0, len(A.shape)):
        params *= int(A.shape[i])
    if s < 0.5:
        nnZ = np.ceil(params * s)
        hasSparse = True
        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
    else:
        nnZ = params
        return nnZ, nnZ * bytesPerVar, hasSparse


def countNNZ(A: torch.nn.Parameter, isSparse):
    '''
    Returns # of non-zeros 
    '''
    A_ = A.detach().numpy()
    if isSparse:
        return np.count_nonzero(A_)
    else:
        nnzs = 1
        for i in range(0, len(A.shape)):
            nnzs *= int(A.shape[i])
        return nnzs

def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
    '''
    Restructures a matrix from [nNodes*nClasses, Proj] to
    [nClasses*nNodes, Proj] for SeeDot
    '''
    tempMatrix = np.zeros(A.shape)
    rowIndex = 0

    for i in range(0, nClasses):
        for j in range(0, nNodes):
            tempMatrix[rowIndex] = A[j * nClasses + i]
            rowIndex += 1

    return tempMatrix

class TriangularLR(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, stepsize, lr_min, lr_max, gamma):
        self.stepsize = stepsize
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.gamma = gamma
        super(TriangularLR, self).__init__(optimizer)

    def get_lr(self):
        it = self.last_epoch
        cycle = math.floor(1 + it / (2 * self.stepsize))
        x = abs(it / self.stepsize - 2 * cycle + 1)
        decayed_range = (self.lr_max - self.lr_min) * self.gamma ** (it / 3)
        lr = self.lr_min + decayed_range * x
        return [lr]

class ExponentialResettingLR(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, gamma, reset_epoch):
        self.gamma = gamma
        self.reset_epoch = int(reset_epoch)
        super(ExponentialResettingLR, self).__init__(optimizer)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch > self.reset_epoch:
            epoch -= self.reset_epoch
        return [base_lr * self.gamma ** epoch
                for base_lr in self.base_lrs]

In [3]:
class ProtoNN(nn.Module):
    def __init__(self, inputDimension, projectionDimension, numPrototypes,
                 numOutputLabels, gamma, W=None, B=None, Z=None):
        '''
        Forward computation graph for ProtoNN.
        inputDimension: Input data dimension or feature dimension.
        projectionDimension: hyperparameter
        numPrototypes: hyperparameter
        numOutputLabels: The number of output labels or classes
        W, B, Z: Numpy matrices that can be used to initialize
            projection matrix(W), prototype matrix (B) and prototype labels
            matrix (B).
            Expected Dimensions:
                W   inputDimension (d) x projectionDimension (d_cap)
                B   projectionDimension (d_cap) x numPrototypes (m)
                Z   numOutputLabels (L) x numPrototypes (m)
        '''
        super(ProtoNN, self).__init__()
        self.__d = inputDimension
        self.__d_cap = projectionDimension
        self.__m = numPrototypes
        self.__L = numOutputLabels

        self.W, self.B, self.Z = None, None, None
        self.gamma = gamma

        self.__validInit = False
        self.__initWBZ(W, B, Z)
        self.__validateInit()

    def __validateInit(self):
        self.__validinit = False
        errmsg = "Dimensions mismatch! Should be W[d, d_cap]"
        errmsg+= ", B[d_cap, m] and Z[L, m]"
        d, d_cap, m, L, _ = self.getHyperParams()
        assert self.W.shape[0] == d, errmsg
        assert self.W.shape[1] == d_cap, errmsg
        assert self.B.shape[0] == d_cap, errmsg
        assert self.B.shape[1] == m, errmsg
        assert self.Z.shape[0] == L, errmsg
        assert self.Z.shape[1] == m, errmsg
        self.__validInit = True

    def __initWBZ(self, inW, inB, inZ):
        if inW is None:
            self.W = torch.randn([self.__d, self.__d_cap])
            self.W = nn.Parameter(self.W)
        else:
            self.W = nn.Parameter(torch.from_numpy(inW.astype(np.float32)))

        if inB is None:
            self.B = torch.randn([self.__d_cap, self.__m])
            self.B = nn.Parameter(self.B)
        else:
            self.B = nn.Parameter(torch.from_numpy(inB.astype(np.float32)))

        if inZ is None:
            self.Z = torch.randn([self.__L, self.__m])
            self.Z = nn.Parameter(self.Z)
        else:
            self.Z = nn.Parameter(torch.from_numpy(inZ.astype(np.float32)))

    def getHyperParams(self):
        '''
        Returns the model hyperparameters:
            [inputDimension, projectionDimension, numPrototypes,
            numOutputLabels, gamma]
        '''
        d =  self.__d
        dcap = self.__d_cap
        m = self.__m
        L = self.__L
        return d, dcap, m, L, self.gamma

    def getModelMatrices(self):
        '''
        Returns model matrices, which can then be evaluated to obtain
        corresponding numpy arrays.  These can then be exported as part of
        other implementations of ProtonNN, for instance a C++ implementation or
        pure python implementation.
        Returns
            [ProjectionMatrix (W), prototypeMatrix (B),
             prototypeLabelsMatrix (Z), gamma]
        '''
        return self.W, self.B, self.Z, self.gamma

    def forward(self, X):
        '''
        This method is responsible for construction of the forward computation
        graph. The end point of the computation graph, or in other words the
        output operator for the forward computation is returned.
        X: Input of shape [-1, inputDimension]
        returns: The forward computation outputs, self.protoNNOut
        '''
        assert self.__validInit is True, "Initialization failed!"

        W, B, Z, gamma = self.W, self.B, self.Z, self.gamma
        WX = torch.matmul(X, W)
        dim = [-1, WX.shape[1], 1]
        WX = torch.reshape(WX, dim)
        dim = [1, B.shape[0], -1]
        B_ = torch.reshape(B, dim)
        l2sim = B_ - WX
        l2sim = torch.pow(l2sim, 2)
        l2sim = torch.sum(l2sim, dim=1, keepdim=True)
        self.l2sim = l2sim
        gammal2sim = (-1 * gamma * gamma) * l2sim
        M = torch.exp(gammal2sim)
        dim = [1] + list(Z.shape)
        Z_ = torch.reshape(Z, dim)
        y = Z_ * M
        y = torch.sum(y, dim=2)
        return y


In [33]:
class ProtoNNTrainer:

    def __init__(self, protoNNObj, regW, regB, regZ, sparcityW, sparcityB,
                 sparcityZ, learningRate, lossType='l2', device=None):
        '''
        A wrapper for the various techniques used for training ProtoNN. This
        subsumes both the responsibility of loss graph construction and
        performing training. The original training routine that is part of the
        C++ implementation of EdgeML used iterative hard thresholding (IHT),
        gamma estimation through median heuristic and other tricks for
        training ProtoNN. This module implements the same in pytorch
        and python.
        protoNNObj: An instance of ProtoNN class defining the forward
            computation graph. The loss functions and training routines will be
            attached to this instance.
        regW, regB, regZ: Regularization constants for W, B, and
            Z matrices of protoNN.
        sparcityW, sparcityB, sparcityZ: Sparsity constraints
            for W, B and Z matrices. A value between 0 (exclusive) and 1
            (inclusive) is expected. A value of 1 indicates dense training.
        learningRate: Initial learning rate for ADAM optimizer.
        X, Y : Placeholders for data and labels.
            X [-1, featureDimension]
            Y [-1, num Labels]
        lossType: ['l2', 'xentropy']
        '''
        self.protoNNObj = protoNNObj
        self.__regW = regW
        self.__regB = regB
        self.__regZ = regZ
        self.__sW = sparcityW
        self.__sB = sparcityB
        self.__sZ = sparcityZ
        self.__lR = learningRate
        self.sparseTraining = True
        if (sparcityW == 1.0) and (sparcityB == 1.0) and (sparcityZ == 1.0):
            self.sparseTraining = False
            print("Sparse training disabled.", file=sys.stderr)
        self.W_th = None
        self.B_th = None
        self.Z_th = None
        self.__lossType = lossType
        self.optimizer = self.__optimizer()
        self.lossCriterion = None
        #assert lossType in ['l2', 'xentropy']
        if lossType == 'l2':
            self.lossCriterion = torch.nn.MSELoss()
            print("Using L2 (MSE) loss")
        else :
            self.lossCriterion = torch.nn.CrossEntropyLoss()
            print("Using x-entropy loss")
        self.__validInit = False
        self.__validInit = self.__validateInit()
        if device is None:
            self.device = "cpu"
        else:
            self.device = device

    def __validateInit(self):
        assert self.__validInit == False
        msg = "Sparsity values should be between 0 and 1 (both inclusive)"
        assert 0 <= self.__sW <= 1, msg
        assert 0 <= self.__sB <= 1, msg
        assert 0 <= self.__sZ <= 1, msg
        return True

    def __optimizer(self):
        optimizer = torch.optim.Adam(self.protoNNObj.parameters(),
                                     lr=self.__lR)
        return optimizer

    def loss(self, logits, labels_or_target):
        labels = labels_or_target
        assert len(logits) == len(labels)
        assert len(labels.shape) == 2
        assert len(logits.shape) == 2
        regLoss = (self.__regW * (torch.norm(self.protoNNObj.W)**2) +
                   self.__regB * (torch.norm(self.protoNNObj.B)**2) +
                   self.__regZ * (torch.norm(self.protoNNObj.Z)**2))
        if self.__lossType == 'xentropy':
            _, labels = torch.max(labels, dim=1)
            assert len(labels.shape)== 1
        loss = self.lossCriterion(logits, labels) + regLoss
        return loss

    def accuracy(self, predictions, labels):
        '''
        Returns accuracy and number of correct predictions.
        '''
        assert len(predictions.shape) == 1
        assert len(labels.shape) == 1
        assert len(predictions) == len(labels)
        correct = (predictions == labels).double()
        numCorrect = torch.sum(correct)
        acc = torch.mean(correct)
        return acc, numCorrect

    def hardThreshold(self):
        prtn = self.protoNNObj
        W, B, Z = prtn.W.data, prtn.B.data, prtn.Z.data
        newW = hardThreshold(W, self.__sW)
        newB = hardThreshold(B, self.__sB)
        newZ = hardThreshold(Z, self.__sZ)
        prtn.W.data = torch.FloatTensor(newW).to(self.device)
        prtn.B.data = torch.FloatTensor(newB).to(self.device)
        prtn.Z.data = torch.FloatTensor(newZ).to(self.device)
        
    def estimate_confusion_matrix(decision,y_val):
        L=3; D=3
        ConfusionMatrix=np.zeros((D,L))
        for d in range(D):
            for l in range(L):
                idx=((decision==d) & (y_val == l)) 

                if y_val[l]==0:
                    ConfusionMatrix[d][l]=sum(idx)
                elif y_val[l]==1:
                    ConfusionMatrix[d][l]=sum(idx)
                else:
                    ConfusionMatrix[d][l]= sum(idx)# / sum((labels == l))
        return ConfusionMatrix

    def train(self, batchSize, epochs, x_train, x_val, y_train, y_val,
              printStep=10, valStep=1):
        '''
        Performs dense training of ProtoNN followed by iterative hard
        thresholding to enforce sparsity constraints.
        batchSize: Batch size per update
        epochs : The number of epochs to run training for. One epoch is
            defined as one pass over the entire training data.
        x_train, x_val, y_train, y_val: The numpy array containing train and
            validation data. x data is assumed to in of shape [-1,
            featureDimension] while y should have shape [-1, numberLabels].
        printStep: Number of batches between echoing of loss and train accuracy.
        valStep: Number of epochs between evaluations on validation set.
        '''
        d, dcap, m, L, _ = self.protoNNObj.getHyperParams()
        assert batchSize >= 1, 'Batch size should be positive integer'
        assert epochs >= 1, 'Total epochs should be positive integer'
        assert x_train.ndim == 2, 'Expected training data to be of rank 2'
        assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
        assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
        assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
        assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
        assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
        assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
        assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

        trainNumBatches = int(np.ceil(len(x_train) / batchSize))
        valNumBatches = int(np.ceil(len(x_val) / batchSize))
        x_train_batches = np.array_split(x_train, trainNumBatches)
        y_train_batches = np.array_split(y_train, trainNumBatches)
        x_val_batches = np.array_split(x_val, valNumBatches)
        y_val_batches = np.array_split(y_val, valNumBatches)
        vacc=0; pred=0
        for epoch in range(epochs):
            for i in range(len(x_train_batches)):
                x_batch, y_batch = x_train_batches[i], y_train_batches[i]
                x_batch, y_batch = torch.Tensor(x_batch), torch.Tensor(y_batch)
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                self.optimizer.zero_grad()
                logits = self.protoNNObj.forward(x_batch)
                loss = self.loss(logits, y_batch)
                loss.backward()
                self.optimizer.step()
                _, predictions = torch.max(logits, dim=1)
                _, target = torch.max(y_batch, dim=1)
                acc, _ = self.accuracy(predictions, target)
                if i % printStep == 0:
                    print("Epoch %d batch %d loss %f acc %f" % (epoch, i, loss,acc))
            # Perform IHT Here.
            if self.sparseTraining:
                self.hardThreshold()
            # Perform validation set evaluation
            if (epoch + 1) % valStep == 0:
                numCorrect = 0
                for i in range(len(x_val_batches)):
                    x_batch, y_batch = x_val_batches[i], y_val_batches[i]
                    x_batch, y_batch = torch.Tensor(x_batch), torch.Tensor(y_batch)
                    x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                    logits = self.protoNNObj.forward(x_batch)
                    _, predictions = torch.max(logits, dim=1)
                    _, target = torch.max(y_batch, dim=1)
                    _, count = self.accuracy(predictions, target)
                    numCorrect += count
                    vacc +=  (numCorrect / len(x_val))
#                     pred+=predictions
                print("Validation accuracy: %f" % (numCorrect / len(x_val)))
#         print(len(pred), (y_val).shape)

        print("Accuracy: ", vacc/ len(y_val_batches))
        print("Loss: ", loss/ len(y_val_batches))
#         estimate_confusion_matrix(y_val, )

In [5]:
#Extract the features and the predictors
time_data = pd.read_csv("/Users/vanshika/Downloads/dataset_fog_release/dataset_fog_release/dataset/time.csv")
target = time_data['0'] #action 0, 1, 2
time_data = time_data.drop(['0'], axis = 1)

scaler = MinMaxScaler((-1, 1)) #scaling
X = scaler.fit_transform(time_data)
Y = target

#Split training data 
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = .25, random_state = 7)

dataDimension = X.shape[1]
numClasses = len(np.unique(Y))

print("Feature Dimension: ", dataDimension)
print("Num classes: ", numClasses)

Feature Dimension:  45
Num classes:  3


In [1]:
# for i in range(len(target)):
#     if target[i] == 0:
#         target= target.drop(i)

In [6]:
# one hot y-train
lab = Y_train.astype('uint8')
lab = np.array(lab) - min(lab)
lab_ = np.zeros((X_train.shape[0], numClasses))
lab_[np.arange(X_train.shape[0]), lab] = 1
y_train = lab_

# one hot y-test
lab = Y_test.astype('uint8')
lab = np.array(lab) - min(lab)
lab_ = np.zeros((X_test.shape[0], numClasses))
lab_[np.arange(X_test.shape[0]), lab] = 1
y_test = lab_

In [35]:
PROJECTION_DIM = 5 #d^
NUM_PROTOTYPES = 40 #m
REG_W = 0.000005
REG_B = 0.0
REG_Z = 0.00005
SPAR_W = 1.0
SPAR_B = 0.8
SPAR_Z = 0.8
LEARNING_RATE = 0.05
NUM_EPOCHS = 200
BATCH_SIZE = 32
GAMMA = 0.024634

In [36]:
#utils

def medianHeuristic(data, projectionDimension, numPrototypes, W_init=None):
    '''
    This method can be used to estimate gamma for ProtoNN. An approximation to
    median heuristic is used here.
    1. First the data is collapsed into the projectionDimension by W_init. If
    W_init is not provided, it is initialized from a random normal(0, 1). Hence
    data normalization is essential.
    2. Prototype are computed by running a  k-means clustering on the projected
    data.
    3. The median distance is then estimated by calculating median distance
    between prototypes and projected data points.

    data needs to be [-1, numFeats]
    If using this method to initialize gamma, please use the W and B as well.

    TODO: Return estimate of Z (prototype labels) based on cluster centroids
    andand labels

    TODO: Clustering fails due to singularity error if projecting upwards

    W [dxd_cap]
    B [d_cap, m]
    returns gamma, W, B
    '''
    assert data.ndim == 2
    X = data
    featDim = data.shape[1]
    if projectionDimension > featDim:
        print("Warning: Projection dimension > feature dimension. Gamma")
        print("\t estimation due to median heuristic could fail.")
        print("\tTo retain the projection dataDimension, provide")
        print("\ta value for gamma.")

    if W_init is None:
        W_init = np.random.normal(size=[featDim, projectionDimension])
    W = W_init
    XW = np.matmul(X, W)
    assert XW.shape[1] == projectionDimension
    assert XW.shape[0] == len(X)
    # Requires [N x d_cap] data matrix of N observations of d_cap-dimension and
    # the number of centroids m. Returns, [n x d_cap] centroids and
    # elementwise center information.
    B, centers = scipy.cluster.vq.kmeans2(XW, numPrototypes)
    # Requires two matrices. Number of observations x dimension of observation
    # space. Distances[i,j] is the distance between XW[i] and B[j]
    distances = scipy.spatial.distance.cdist(XW, B, metric='euclidean')
    distances = np.reshape(distances, [-1])
    gamma = np.median(distances)
    gamma = 1 / (2.5 * gamma)
    return gamma.astype('float32'), W.astype('float32'), B.T.astype('float32')

#helper methods
def getGamma(gammaInit, projectionDim, dataDim, numPrototypes, x_train):
    if gammaInit is None:
        print("Using median heuristic to estimate gamma.")
        gamma, W, B = medianHeuristic(x_train, projectionDim,
                                            numPrototypes)
        print("Gamma estimate is: %f" % gamma)
        return W, B, gamma
    return None, None, gammaInit

In [37]:
W, B, gamma = getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, X_train) #x_train for small gamma

In [38]:
# Setup input and train protoNN
protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)

trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, lossType='l2')

trainer.train(BATCH_SIZE, NUM_EPOCHS, X_train, X_test, y_train, y_test,
              printStep=100, valStep=1)

Using L2 (MSE) loss
Epoch 0 batch 0 loss 13.868752 acc 0.031250
Epoch 0 batch 100 loss 0.133627 acc 0.843750
Epoch 0 batch 200 loss 0.117879 acc 0.812500
Epoch 0 batch 300 loss 0.087056 acc 0.875000
Epoch 0 batch 400 loss 0.123929 acc 0.781250
Epoch 0 batch 500 loss 0.067556 acc 0.906250
Epoch 0 batch 600 loss 0.080463 acc 0.875000
Epoch 0 batch 700 loss 0.095900 acc 0.843750
Epoch 0 batch 800 loss 0.080475 acc 0.875000
Epoch 0 batch 900 loss 0.083544 acc 0.875000
Epoch 0 batch 1000 loss 0.093888 acc 0.843750
Epoch 0 batch 1100 loss 0.118266 acc 0.781250
Validation accuracy: 0.825322
Epoch 1 batch 0 loss 0.053911 acc 0.937500
Epoch 1 batch 100 loss 0.095737 acc 0.843750
Epoch 1 batch 200 loss 0.111083 acc 0.812500
Epoch 1 batch 300 loss 0.081878 acc 0.875000
Epoch 1 batch 400 loss 0.115514 acc 0.781250
Epoch 1 batch 500 loss 0.076038 acc 0.906250
Epoch 1 batch 600 loss 0.087231 acc 0.875000
Epoch 1 batch 700 loss 0.095529 acc 0.843750
Epoch 1 batch 800 loss 0.078349 acc 0.875000
Epoch 

Epoch 14 batch 600 loss 0.097010 acc 0.843750
Epoch 14 batch 700 loss 0.086997 acc 0.843750
Epoch 14 batch 800 loss 0.081775 acc 0.875000
Epoch 14 batch 900 loss 0.072508 acc 0.875000
Epoch 14 batch 1000 loss 0.078148 acc 0.843750
Epoch 14 batch 1100 loss 0.104319 acc 0.781250
Validation accuracy: 0.829663
Epoch 15 batch 0 loss 0.041701 acc 0.937500
Epoch 15 batch 100 loss 0.084422 acc 0.812500
Epoch 15 batch 200 loss 0.096896 acc 0.812500
Epoch 15 batch 300 loss 0.086139 acc 0.875000
Epoch 15 batch 400 loss 0.105654 acc 0.781250
Epoch 15 batch 500 loss 0.076539 acc 0.906250
Epoch 15 batch 600 loss 0.096697 acc 0.843750
Epoch 15 batch 700 loss 0.087036 acc 0.843750
Epoch 15 batch 800 loss 0.081471 acc 0.875000
Epoch 15 batch 900 loss 0.072299 acc 0.875000
Epoch 15 batch 1000 loss 0.078548 acc 0.843750
Epoch 15 batch 1100 loss 0.104005 acc 0.781250
Validation accuracy: 0.829505
Epoch 16 batch 0 loss 0.041221 acc 0.937500
Epoch 16 batch 100 loss 0.084417 acc 0.812500
Epoch 16 batch 200 l

Epoch 28 batch 900 loss 0.072211 acc 0.875000
Epoch 28 batch 1000 loss 0.079279 acc 0.843750
Epoch 28 batch 1100 loss 0.101999 acc 0.781250
Validation accuracy: 0.829426
Epoch 29 batch 0 loss 0.040774 acc 0.937500
Epoch 29 batch 100 loss 0.084220 acc 0.812500
Epoch 29 batch 200 loss 0.095173 acc 0.812500
Epoch 29 batch 300 loss 0.084571 acc 0.875000
Epoch 29 batch 400 loss 0.104428 acc 0.812500
Epoch 29 batch 500 loss 0.074782 acc 0.906250
Epoch 29 batch 600 loss 0.096529 acc 0.843750
Epoch 29 batch 700 loss 0.090175 acc 0.843750
Epoch 29 batch 800 loss 0.082007 acc 0.875000
Epoch 29 batch 900 loss 0.072304 acc 0.875000
Epoch 29 batch 1000 loss 0.079258 acc 0.843750
Epoch 29 batch 1100 loss 0.101833 acc 0.781250
Validation accuracy: 0.829426
Epoch 30 batch 0 loss 0.040511 acc 0.937500
Epoch 30 batch 100 loss 0.084252 acc 0.812500
Epoch 30 batch 200 loss 0.095140 acc 0.812500
Epoch 30 batch 300 loss 0.083770 acc 0.875000
Epoch 30 batch 400 loss 0.104446 acc 0.812500
Epoch 30 batch 500 l

Validation accuracy: 0.829426
Epoch 43 batch 0 loss 0.040702 acc 0.937500
Epoch 43 batch 100 loss 0.083902 acc 0.843750
Epoch 43 batch 200 loss 0.094695 acc 0.812500
Epoch 43 batch 300 loss 0.081130 acc 0.875000
Epoch 43 batch 400 loss 0.103757 acc 0.812500
Epoch 43 batch 500 loss 0.071947 acc 0.906250
Epoch 43 batch 600 loss 0.097631 acc 0.843750
Epoch 43 batch 700 loss 0.087499 acc 0.843750
Epoch 43 batch 800 loss 0.079884 acc 0.875000
Epoch 43 batch 900 loss 0.072486 acc 0.875000
Epoch 43 batch 1000 loss 0.079789 acc 0.843750
Epoch 43 batch 1100 loss 0.100296 acc 0.781250
Validation accuracy: 0.829426
Epoch 44 batch 0 loss 0.040709 acc 0.937500
Epoch 44 batch 100 loss 0.083919 acc 0.843750
Epoch 44 batch 200 loss 0.094650 acc 0.812500
Epoch 44 batch 300 loss 0.081373 acc 0.875000
Epoch 44 batch 400 loss 0.103682 acc 0.812500
Epoch 44 batch 500 loss 0.071807 acc 0.906250
Epoch 44 batch 600 loss 0.097735 acc 0.843750
Epoch 44 batch 700 loss 0.087474 acc 0.843750
Epoch 44 batch 800 los

Epoch 57 batch 300 loss 0.084881 acc 0.875000
Epoch 57 batch 400 loss 0.103392 acc 0.812500
Epoch 57 batch 500 loss 0.072581 acc 0.906250
Epoch 57 batch 600 loss 0.099462 acc 0.843750
Epoch 57 batch 700 loss 0.088525 acc 0.843750
Epoch 57 batch 800 loss 0.079745 acc 0.875000
Epoch 57 batch 900 loss 0.072599 acc 0.875000
Epoch 57 batch 1000 loss 0.079630 acc 0.843750
Epoch 57 batch 1100 loss 0.098335 acc 0.781250
Validation accuracy: 0.829347
Epoch 58 batch 0 loss 0.040335 acc 0.937500
Epoch 58 batch 100 loss 0.083486 acc 0.843750
Epoch 58 batch 200 loss 0.094038 acc 0.812500
Epoch 58 batch 300 loss 0.084542 acc 0.875000
Epoch 58 batch 400 loss 0.103397 acc 0.812500
Epoch 58 batch 500 loss 0.072926 acc 0.906250
Epoch 58 batch 600 loss 0.099581 acc 0.843750
Epoch 58 batch 700 loss 0.088580 acc 0.843750
Epoch 58 batch 800 loss 0.079618 acc 0.875000
Epoch 58 batch 900 loss 0.072614 acc 0.875000
Epoch 58 batch 1000 loss 0.079580 acc 0.843750
Epoch 58 batch 1100 loss 0.099699 acc 0.781250
Va

Epoch 71 batch 600 loss 0.099178 acc 0.843750
Epoch 71 batch 700 loss 0.086844 acc 0.843750
Epoch 71 batch 800 loss 0.079654 acc 0.875000
Epoch 71 batch 900 loss 0.072266 acc 0.875000
Epoch 71 batch 1000 loss 0.080006 acc 0.843750
Epoch 71 batch 1100 loss 0.102016 acc 0.781250
Validation accuracy: 0.830215
Epoch 72 batch 0 loss 0.041463 acc 0.937500
Epoch 72 batch 100 loss 0.084061 acc 0.843750
Epoch 72 batch 200 loss 0.094004 acc 0.812500
Epoch 72 batch 300 loss 0.081708 acc 0.875000
Epoch 72 batch 400 loss 0.103288 acc 0.812500
Epoch 72 batch 500 loss 0.075190 acc 0.906250
Epoch 72 batch 600 loss 0.098694 acc 0.843750
Epoch 72 batch 700 loss 0.086881 acc 0.843750
Epoch 72 batch 800 loss 0.079715 acc 0.875000
Epoch 72 batch 900 loss 0.072132 acc 0.875000
Epoch 72 batch 1000 loss 0.080054 acc 0.843750
Epoch 72 batch 1100 loss 0.100082 acc 0.781250
Validation accuracy: 0.829900
Epoch 73 batch 0 loss 0.042178 acc 0.937500
Epoch 73 batch 100 loss 0.084203 acc 0.843750
Epoch 73 batch 200 l

Epoch 85 batch 900 loss 0.071979 acc 0.875000
Epoch 85 batch 1000 loss 0.080403 acc 0.843750
Epoch 85 batch 1100 loss 0.104470 acc 0.781250
Validation accuracy: 0.830137
Epoch 86 batch 0 loss 0.041771 acc 0.937500
Epoch 86 batch 100 loss 0.084426 acc 0.843750
Epoch 86 batch 200 loss 0.094202 acc 0.812500
Epoch 86 batch 300 loss 0.085865 acc 0.875000
Epoch 86 batch 400 loss 0.102713 acc 0.812500
Epoch 86 batch 500 loss 0.076175 acc 0.906250
Epoch 86 batch 600 loss 0.096601 acc 0.843750
Epoch 86 batch 700 loss 0.087806 acc 0.843750
Epoch 86 batch 800 loss 0.079369 acc 0.875000
Epoch 86 batch 900 loss 0.071939 acc 0.875000
Epoch 86 batch 1000 loss 0.080399 acc 0.843750
Epoch 86 batch 1100 loss 0.102185 acc 0.781250
Validation accuracy: 0.832426
Epoch 87 batch 0 loss 0.045230 acc 0.937500
Epoch 87 batch 100 loss 0.084393 acc 0.843750
Epoch 87 batch 200 loss 0.094265 acc 0.812500
Epoch 87 batch 300 loss 0.083560 acc 0.875000
Epoch 87 batch 400 loss 0.102853 acc 0.812500
Epoch 87 batch 500 l

Validation accuracy: 0.830689
Epoch 100 batch 0 loss 0.042047 acc 0.937500
Epoch 100 batch 100 loss 0.084805 acc 0.843750
Epoch 100 batch 200 loss 0.094387 acc 0.812500
Epoch 100 batch 300 loss 0.082038 acc 0.875000
Epoch 100 batch 400 loss 0.102912 acc 0.812500
Epoch 100 batch 500 loss 0.075694 acc 0.906250
Epoch 100 batch 600 loss 0.096277 acc 0.843750
Epoch 100 batch 700 loss 0.088155 acc 0.843750
Epoch 100 batch 800 loss 0.078920 acc 0.875000
Epoch 100 batch 900 loss 0.071946 acc 0.875000
Epoch 100 batch 1000 loss 0.080469 acc 0.843750
Epoch 100 batch 1100 loss 0.104749 acc 0.781250
Validation accuracy: 0.832110
Epoch 101 batch 0 loss 0.043937 acc 0.937500
Epoch 101 batch 100 loss 0.085013 acc 0.843750
Epoch 101 batch 200 loss 0.094407 acc 0.812500
Epoch 101 batch 300 loss 0.082384 acc 0.875000
Epoch 101 batch 400 loss 0.102862 acc 0.812500
Epoch 101 batch 500 loss 0.076222 acc 0.906250
Epoch 101 batch 600 loss 0.096175 acc 0.843750
Epoch 101 batch 700 loss 0.088137 acc 0.843750
Ep

Validation accuracy: 0.831873
Epoch 114 batch 0 loss 0.043205 acc 0.937500
Epoch 114 batch 100 loss 0.085222 acc 0.843750
Epoch 114 batch 200 loss 0.094540 acc 0.812500
Epoch 114 batch 300 loss 0.083840 acc 0.875000
Epoch 114 batch 400 loss 0.102917 acc 0.812500
Epoch 114 batch 500 loss 0.075200 acc 0.906250
Epoch 114 batch 600 loss 0.096403 acc 0.843750
Epoch 114 batch 700 loss 0.088054 acc 0.843750
Epoch 114 batch 800 loss 0.078611 acc 0.875000
Epoch 114 batch 900 loss 0.071934 acc 0.875000
Epoch 114 batch 1000 loss 0.080510 acc 0.843750
Epoch 114 batch 1100 loss 0.104297 acc 0.781250
Validation accuracy: 0.829821
Epoch 115 batch 0 loss 0.040899 acc 0.937500
Epoch 115 batch 100 loss 0.085062 acc 0.843750
Epoch 115 batch 200 loss 0.094469 acc 0.812500
Epoch 115 batch 300 loss 0.083599 acc 0.875000
Epoch 115 batch 400 loss 0.102881 acc 0.812500
Epoch 115 batch 500 loss 0.075354 acc 0.906250
Epoch 115 batch 600 loss 0.096262 acc 0.843750
Epoch 115 batch 700 loss 0.088059 acc 0.843750
Ep

Validation accuracy: 0.829900
Epoch 128 batch 0 loss 0.041326 acc 0.937500
Epoch 128 batch 100 loss 0.085114 acc 0.843750
Epoch 128 batch 200 loss 0.094556 acc 0.812500
Epoch 128 batch 300 loss 0.082522 acc 0.875000
Epoch 128 batch 400 loss 0.102998 acc 0.812500
Epoch 128 batch 500 loss 0.074875 acc 0.906250
Epoch 128 batch 600 loss 0.096381 acc 0.843750
Epoch 128 batch 700 loss 0.088054 acc 0.843750
Epoch 128 batch 800 loss 0.078482 acc 0.875000
Epoch 128 batch 900 loss 0.071944 acc 0.875000
Epoch 128 batch 1000 loss 0.080499 acc 0.843750
Epoch 128 batch 1100 loss 0.104875 acc 0.781250
Validation accuracy: 0.829900
Epoch 129 batch 0 loss 0.040500 acc 0.937500
Epoch 129 batch 100 loss 0.085374 acc 0.843750
Epoch 129 batch 200 loss 0.094541 acc 0.812500
Epoch 129 batch 300 loss 0.086118 acc 0.875000
Epoch 129 batch 400 loss 0.102867 acc 0.812500
Epoch 129 batch 500 loss 0.074625 acc 0.906250
Epoch 129 batch 600 loss 0.096452 acc 0.843750
Epoch 129 batch 700 loss 0.087980 acc 0.843750
Ep

Validation accuracy: 0.829900
Epoch 142 batch 0 loss 0.040512 acc 0.937500
Epoch 142 batch 100 loss 0.085423 acc 0.843750
Epoch 142 batch 200 loss 0.094607 acc 0.812500
Epoch 142 batch 300 loss 0.081069 acc 0.875000
Epoch 142 batch 400 loss 0.103097 acc 0.812500
Epoch 142 batch 500 loss 0.074970 acc 0.906250
Epoch 142 batch 600 loss 0.096310 acc 0.843750
Epoch 142 batch 700 loss 0.088121 acc 0.843750
Epoch 142 batch 800 loss 0.078335 acc 0.875000
Epoch 142 batch 900 loss 0.071937 acc 0.875000
Epoch 142 batch 1000 loss 0.080399 acc 0.843750
Epoch 142 batch 1100 loss 0.103669 acc 0.781250
Validation accuracy: 0.829900
Epoch 143 batch 0 loss 0.040671 acc 0.937500
Epoch 143 batch 100 loss 0.085246 acc 0.843750
Epoch 143 batch 200 loss 0.094590 acc 0.812500
Epoch 143 batch 300 loss 0.084297 acc 0.875000
Epoch 143 batch 400 loss 0.102956 acc 0.812500
Epoch 143 batch 500 loss 0.074719 acc 0.906250
Epoch 143 batch 600 loss 0.096365 acc 0.843750
Epoch 143 batch 700 loss 0.088049 acc 0.843750
Ep

Validation accuracy: 0.830058
Epoch 156 batch 0 loss 0.041148 acc 0.937500
Epoch 156 batch 100 loss 0.085296 acc 0.843750
Epoch 156 batch 200 loss 0.094655 acc 0.812500
Epoch 156 batch 300 loss 0.082568 acc 0.875000
Epoch 156 batch 400 loss 0.103052 acc 0.812500
Epoch 156 batch 500 loss 0.074688 acc 0.906250
Epoch 156 batch 600 loss 0.096331 acc 0.843750
Epoch 156 batch 700 loss 0.088089 acc 0.843750
Epoch 156 batch 800 loss 0.078233 acc 0.875000
Epoch 156 batch 900 loss 0.071942 acc 0.875000
Epoch 156 batch 1000 loss 0.080465 acc 0.843750
Epoch 156 batch 1100 loss 0.103888 acc 0.781250
Validation accuracy: 0.829900
Epoch 157 batch 0 loss 0.040749 acc 0.937500
Epoch 157 batch 100 loss 0.085368 acc 0.843750
Epoch 157 batch 200 loss 0.094644 acc 0.812500
Epoch 157 batch 300 loss 0.083594 acc 0.875000
Epoch 157 batch 400 loss 0.103066 acc 0.812500
Epoch 157 batch 500 loss 0.074379 acc 0.906250
Epoch 157 batch 600 loss 0.096492 acc 0.843750
Epoch 157 batch 700 loss 0.087979 acc 0.843750
Ep

Validation accuracy: 0.831873
Epoch 170 batch 0 loss 0.044024 acc 0.937500
Epoch 170 batch 100 loss 0.085659 acc 0.812500
Epoch 170 batch 200 loss 0.094798 acc 0.812500
Epoch 170 batch 300 loss 0.084185 acc 0.875000
Epoch 170 batch 400 loss 0.102909 acc 0.812500
Epoch 170 batch 500 loss 0.073454 acc 0.906250
Epoch 170 batch 600 loss 0.096695 acc 0.843750
Epoch 170 batch 700 loss 0.088167 acc 0.843750
Epoch 170 batch 800 loss 0.078032 acc 0.875000
Epoch 170 batch 900 loss 0.072053 acc 0.875000
Epoch 170 batch 1000 loss 0.080319 acc 0.843750
Epoch 170 batch 1100 loss 0.104624 acc 0.781250
Validation accuracy: 0.830610
Epoch 171 batch 0 loss 0.041253 acc 0.937500
Epoch 171 batch 100 loss 0.085421 acc 0.843750
Epoch 171 batch 200 loss 0.094690 acc 0.812500
Epoch 171 batch 300 loss 0.084859 acc 0.875000
Epoch 171 batch 400 loss 0.103017 acc 0.812500
Epoch 171 batch 500 loss 0.073943 acc 0.906250
Epoch 171 batch 600 loss 0.096676 acc 0.843750
Epoch 171 batch 700 loss 0.088010 acc 0.843750
Ep

Validation accuracy: 0.829900
Epoch 184 batch 0 loss 0.040595 acc 0.937500
Epoch 184 batch 100 loss 0.085670 acc 0.812500
Epoch 184 batch 200 loss 0.094718 acc 0.812500
Epoch 184 batch 300 loss 0.085922 acc 0.875000
Epoch 184 batch 400 loss 0.102924 acc 0.812500
Epoch 184 batch 500 loss 0.073921 acc 0.906250
Epoch 184 batch 600 loss 0.096551 acc 0.843750
Epoch 184 batch 700 loss 0.088062 acc 0.843750
Epoch 184 batch 800 loss 0.078056 acc 0.875000
Epoch 184 batch 900 loss 0.071962 acc 0.875000
Epoch 184 batch 1000 loss 0.080448 acc 0.843750
Epoch 184 batch 1100 loss 0.103998 acc 0.781250
Validation accuracy: 0.831794
Epoch 185 batch 0 loss 0.043111 acc 0.937500
Epoch 185 batch 100 loss 0.085720 acc 0.812500
Epoch 185 batch 200 loss 0.094824 acc 0.812500
Epoch 185 batch 300 loss 0.081747 acc 0.875000
Epoch 185 batch 400 loss 0.103132 acc 0.812500
Epoch 185 batch 500 loss 0.074171 acc 0.906250
Epoch 185 batch 600 loss 0.096538 acc 0.843750
Epoch 185 batch 700 loss 0.088135 acc 0.843750
Ep

Validation accuracy: 0.831873
Epoch 198 batch 0 loss 0.042867 acc 0.937500
Epoch 198 batch 100 loss 0.085954 acc 0.812500
Epoch 198 batch 200 loss 0.094810 acc 0.812500
Epoch 198 batch 300 loss 0.085854 acc 0.875000
Epoch 198 batch 400 loss 0.102938 acc 0.812500
Epoch 198 batch 500 loss 0.073476 acc 0.906250
Epoch 198 batch 600 loss 0.096648 acc 0.843750
Epoch 198 batch 700 loss 0.088038 acc 0.843750
Epoch 198 batch 800 loss 0.077934 acc 0.875000
Epoch 198 batch 900 loss 0.071960 acc 0.875000
Epoch 198 batch 1000 loss 0.080352 acc 0.843750
Epoch 198 batch 1100 loss 0.106940 acc 0.781250
Validation accuracy: 0.830847
Epoch 199 batch 0 loss 0.041741 acc 0.937500
Epoch 199 batch 100 loss 0.085617 acc 0.812500
Epoch 199 batch 200 loss 0.094810 acc 0.812500
Epoch 199 batch 300 loss 0.083087 acc 0.875000
Epoch 199 batch 400 loss 0.103224 acc 0.812500
Epoch 199 batch 500 loss 0.073901 acc 0.906250
Epoch 199 batch 600 loss 0.096609 acc 0.843750
Epoch 199 batch 700 loss 0.088030 acc 0.843750
Ep

In [15]:

def countnnZ(A, s, bytesPerVar=4):
    '''
    Returns # of non-zeros and representative size of the tensor
    Uses dense for s >= 0.5 - 4 byte
    Else uses sparse - 8 byte
    '''
    params = 1
    hasSparse = False
    for i in range(0, len(A.shape)):
        params *= int(A.shape[i])
    if s < 0.5:
        nnZ = np.ceil(params * s)
        hasSparse = True
        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
    else:
        nnZ = params
        return nnZ, nnZ * bytesPerVar, hasSparse
    
def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4):
    '''
    expected: Expected size according to the parameters set. The number of
        zeros could actually be more than that is required to satisfy the
        sparsity constraint.
    '''
    nnzList, sizeList, isSparseList = [], [], []
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        assert A.ndim == 2
        assert s >= 0
        assert s <= 1
        nnz, size, sparse = countnnZ(A, s, bytesPerVar=bytesPerVar)
        nnzList.append(nnz)
        sizeList.append(size)
        hasSparse = (hasSparse or sparse)

    totalnnZ = np.sum(nnzList)
    totalSize = np.sum(sizeList)
    if expected:
        return totalnnZ, totalSize, hasSparse
    numNonZero = 0
    totalSize = 0
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        numNonZero_ = np.count_nonzero(A)
        numNonZero += numNonZero_
        hasSparse = (hasSparse or (s < 0.5))
        if s <= 0.5:
            totalSize += numNonZero_ * 2 * bytesPerVar
        else:
            totalSize += A.size * bytesPerVar
    return numNonZero, totalSize, hasSparse

# W, B, Z are tensorflow graph nodes
W, B, Z, _ = protoNN.getModelMatrices()
matrixList = ([W, B, Z])
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]
nnz, size, sparse = getModelSize(matrixList, sparcityList)
# print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
# print("Number of non-zeros: ", nnz)

Model size constraint (Bytes):  2180


In [None]:
# # summarize the fit of the model

# print("Classification summary: \n",metrics.classification_report(Y_test, y_pred))
# print("Confusion matrix: \n",metrics.confusion_matrix(Y_test, y_pred))