# In [0]:
# This code mainly refers to https://github.com/antspy/quantized_distillation

import time
from datetime import datetime
import math
import torch
import torch.nn as nn
import torch.nn.functional as Functional
import torch.optim as optim
import torch.utils.data
from torch.nn.init import xavier_uniform, calculate_gain
from torchvision import datasets
from torchvision import transforms
from torchvision.models import resnet34
from torchvision.models import resnet18
from google.colab import drive
import functools
import numpy as np
from collections import defaultdict
from heapq import heappush, heappop, heapify


'''
Implements the model and training techniques detailed in the paper:
"Do deep convolutional neural network really need to be deep and convolutional?"
'''

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


  
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block input is added back to the output of the second batch norm
    through ``shortcut`` before the final ReLU. ``shortcut`` is an empty
    ``nn.Sequential`` (identity) unless the stride or the channel count
    changes, in which case it is a 1x1 convolution followed by batch norm.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        """
        :param inplanes: number of input channels
        :param planes:   number of channels produced by both 3x3 convolutions
        :param stride:   stride of the first convolution (and of the projection shortcut)
        """
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only needed when the residual and the input would
        # otherwise disagree in spatial size or channel count.
        needs_projection = stride != 1 or inplanes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet-18/34 style classifier with 10 outputs that records per-stage features.

    Two modes are supported:

    * ``pretrained=False``: a network of ``BasicBlock`` stages is built from
      scratch; every channel width is multiplied by ``k``.
    * ``pretrained=True``: a torchvision model is built with
      ``pretrained=True`` and stored in ``self._pm``; its first convolution is
      replaced by a 3x3 stride-1 convolution and its ``fc`` by a 10-way linear
      layer.

    In both modes the stem max-pool is not applied in ``forward``, and after
    every forward pass ``self.grams`` holds the outputs of ``layer1`` ..
    ``layer4`` (in that order) for the most recent input.
    """
    # Number of BasicBlocks in each of the four stages, keyed by network depth.
    cfg = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
    }
    # Constructors, not instances: instantiating here would build (and fetch the
    # weights of) both pretrained networks as soon as the class is defined, and
    # every ResNet(pretrained=True) would share and mutate the same module.
    cfg_pm = {
        18: resnet18,
        34: resnet34,
    }

    def __init__(self, key, pretrained=False, k=1):
        """
        :param key:        network depth, a key of ``cfg`` / ``cfg_pm`` (18 or 34)
        :param pretrained: wrap a pretrained torchvision model instead of building from scratch
        :param k:          channel width multiplier (only used when ``pretrained`` is false)
        """
        super(ResNet, self).__init__()

        self.pretrained = pretrained
        self.grams = None
        if pretrained:
            self._pm = self.cfg_pm[key](pretrained=True)
            self._pm.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self._pm.fc = nn.Linear(self._pm.fc.in_features, 10)
            # NOTE(review): this freezes every parameter, including the freshly
            # initialised conv1 and fc created just above - confirm that is intended.
            for _para in self._pm.parameters():
                _para.requires_grad = False
        else:
            self.inplanes = 64*k
            self.conv1 = nn.Conv2d(3, 64*k, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64*k)
            self.relu = nn.ReLU(inplace=True)
            # Defined for parity with the torchvision layout; forward() does not call it.
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(64*k, self.cfg[key][0], stride=1)
            self.layer2 = self._make_layer(128*k, self.cfg[key][1], stride=2)
            self.layer3 = self._make_layer(256*k, self.cfg[key][2], stride=2)
            self.layer4 = self._make_layer(512*k, self.cfg[key][3], stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * k * BasicBlock.expansion, 10)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of ``blocks`` BasicBlocks; only the first block uses ``stride``.

        Updates ``self.inplanes`` so the next stage knows its input width.
        """
        layers = [BasicBlock(self.inplanes, planes, stride)]
        self.inplanes = planes * BasicBlock.expansion
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10-way logits for ``x`` and refresh ``self.grams``."""
        self.grams = list()
        # Both modes expose the same attribute names, so one code path serves both.
        net = self._pm if self.pretrained else self

        x = net.relu(net.bn1(net.conv1(x)))
        # The stem max-pool is not applied here.

        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
            self.grams.append(x)

        x = net.avgpool(x)
        x = x.view(x.size(0), -1)
        return net.fc(x)



class ConvolForwardNet(nn.Module):
    """Teacher model as described in the paper:
    "Do deep convolutional neural network really need to be deep and convolutional?"

    The network is a stack of convolutional layers followed by a stack of
    linear layers and a final 10-way linear output layer. Every conv/linear
    layer is followed by ReLU and then batch normalisation; max-pooling and
    dropout layers are applied after the layers whose index is listed in the
    corresponding spec.
    """

    def __init__(self, width, height, spec_conv_layers, spec_max_pooling, spec_linear, spec_dropout_rates):
        """
        :param width:  input image width
        :param height: input image height
        :param spec_conv_layers: list of tuples (numFilters, width, height), one per conv layer
        :param spec_max_pooling: list of tuples (posToInsert, width, height) of max-pooling layers
        :param spec_linear: list with the number of neurons of each linear layer
                            (i.e. [100, 200, 300] creates 3 layers)
        :param spec_dropout_rates: list of tuples (posToInsert, rate of dropout); dropout is
                                   applied after max-pooling at the same position
        """
        super().__init__()

        self.width = width
        self.height = height
        self.max_pooling_positions = []
        self.dropout_positions = []

        conv_stack = []
        pool_stack = []
        dropout_stack = []
        linear_stack = []
        norm_stack = []

        # Convolutional layers. Padding is (kernel - 1) // 2 so that odd kernels
        # leave the spatial size of the feature map unchanged.
        in_channels = 3
        for conv_spec in spec_conv_layers:
            out_channels = conv_spec[0]
            kernel = (conv_spec[1], conv_spec[2])
            pad = ((kernel[0] - 1) // 2, (kernel[1] - 1) // 2)
            conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel, padding=pad)
            xavier_uniform(conv.weight, calculate_gain('conv2d'))  # glorot weight initialization
            conv_stack.append(conv)
            norm_stack.append(nn.BatchNorm2d(out_channels, affine=True))
            in_channels = out_channels

        # Max-pooling layers and the layer index after which each one is applied.
        for pool_spec in spec_max_pooling:
            pool_stack.append(nn.MaxPool2d((pool_spec[1], pool_spec[2])))
            self.max_pooling_positions.append(pool_spec[0])

        # Dropout layers: channel-wise Dropout2d after conv layers, plain Dropout otherwise.
        for dropout_spec in spec_dropout_rates:
            position, rate = dropout_spec[0], dropout_spec[1]
            dropout_cls = nn.Dropout2d if position < len(conv_stack) else nn.Dropout
            dropout_stack.append(dropout_cls(rate))
            self.dropout_positions.append(position)

        # Linear layers. The flattened size assumes every pooling layer halves
        # both the width and the height of the feature map.
        in_features = in_channels * width * height // 2 ** (2 * len(pool_stack))
        for out_features in spec_linear:
            linear = nn.Linear(in_features=in_features, out_features=out_features)
            xavier_uniform(linear.weight, calculate_gain('linear'))  # glorot weight initialization
            linear_stack.append(linear)
            norm_stack.append(nn.BatchNorm1d(out_features, affine=True))
            in_features = out_features

        # Final output layer; it is registered before the module lists below.
        self.out_layer = nn.Linear(in_features=in_features, out_features=10)
        xavier_uniform(self.out_layer.weight, calculate_gain('linear'))

        self.conv_layers = nn.ModuleList(conv_stack)
        self.max_pooling_layers = nn.ModuleList(pool_stack)
        self.dropout_layers = nn.ModuleList(dropout_stack)
        self.linear_layers = nn.ModuleList(linear_stack)
        self.batchNormalizationLayers = nn.ModuleList(norm_stack)
        self.num_conv_layers = len(self.conv_layers)
        self.total_num_layers = self.num_conv_layers + len(self.linear_layers)

    def forward(self, input):
        """Run the network and return the ReLU of the output layer (no softmax)."""
        for idx in range(self.total_num_layers):
            if idx < self.num_conv_layers:
                layer = self.conv_layers[idx]
            else:
                if idx == self.num_conv_layers:
                    # first layer after the convolutional ones: flatten to a vector per example
                    input = input.view(input.size()[0], -1)
                layer = self.linear_layers[idx - self.num_conv_layers]

            input = self.batchNormalizationLayers[idx](Functional.relu(layer(input)))

            if idx in self.max_pooling_positions:
                pool = self.max_pooling_layers[self.max_pooling_positions.index(idx)]
                input = pool(input)

            if idx in self.dropout_positions:
                dropout = self.dropout_layers[self.dropout_positions.index(idx)]
                input = dropout(input)

        # No need to take softmax if the loss function is cross entropy
        return Functional.relu(self.out_layer(input))


def train_model(model, train_loader, test_loader, initial_learning_rate=0.001, initial_momentum=0.9,
                weight_decayL2=0.00022, epochs_to_train=300, use_distillation_loss=False,
                teacher_model=None, quantizeWeights=False, numBits=8, bucket_size=None, backprop_quantization_style='none', estimate_quant_grad_every=1):
    """Train ``model`` with Nesterov SGD, optionally with distillation and weight quantization.

    After every epoch the model is evaluated on ``test_loader`` and the
    resulting error drives a ``LearningRateScheduler`` that may halve the
    learning rate or stop training early.

    When ``quantizeWeights`` is true, each minibatch gradient is computed with
    uniformly quantized weights, the full-precision weights are restored before
    the optimizer step, and the weights are quantized once more after training.

    :param model: network to train (modified in place)
    :param train_loader: iterable of (inputs, labels) minibatches used for training
    :param test_loader: iterable of (inputs, labels) minibatches used for evaluation
    :param initial_learning_rate: starting learning rate of the SGD optimizer
    :param initial_momentum: SGD momentum
    :param weight_decayL2: SGD weight decay
    :param epochs_to_train: maximum number of epochs
    :param use_distillation_loss: forwarded to ``forward_and_backward``
    :param teacher_model: teacher network, put in eval mode and forwarded to ``forward_and_backward``
    :param quantizeWeights: compute gradients on quantized weights and quantize the final model
    :param numBits: the quantization uses ``2 ** numBits`` levels
    :param bucket_size: bucket size forwarded to ``uniformQuantization``
    :param backprop_quantization_style: accepted for backward compatibility; currently has no effect
    :param estimate_quant_grad_every: accepted for backward compatibility; currently has no effect
    :return: the trained model (the same object that was passed in)
    """
    if teacher_model is not None:
        teacher_model.eval()

    lr_scheduler = LearningRateScheduler(initial_learning_rate)
    optimizer = optim.SGD(model.parameters(), lr=initial_learning_rate, nesterov=True, momentum=initial_momentum,
                          weight_decay=weight_decayL2)
    startTime = time.time()

    pred_accuracy_epochs = []
    percentages_asked_teacher = []
    last_loss_saved = float('inf')

    if quantizeWeights:
        # uniform linear scaling with 2 ** numBits quantization points
        num_levels = 2 ** numBits

        def quantize_weights_model(net):
            # Rebinds p.data to a freshly allocated quantized tensor; the previous
            # tensor is left untouched, which the save/restore below relies on.
            for p in net.parameters():
                p.data, _ = uniformQuantization(p.data, num_levels, bucket_size=bucket_size)

    for epoch in range(epochs_to_train):
        model.train()
        print_loss_total = 0
        count_asked_teacher = 0
        count_asked_total = 0
        for idx_minibatch, data in enumerate(train_loader, start=1):

            if quantizeWeights:
                # Keep the full-precision weights: the quantized ones are only used
                # to compute the gradient. The state dict is not copied; this relies on
                # quantize_weights_model replacing p.data instead of writing in place.
                full_precision_state = model.state_dict()
                quantize_weights_model(model)

            model.zero_grad()
            print_loss, curr_c_teach, curr_c_total = forward_and_backward(model, data,
                                                                          use_distillation_loss=use_distillation_loss,
                                                                          teacher_model=teacher_model)
            count_asked_teacher += curr_c_teach
            count_asked_total += curr_c_total

            if quantizeWeights:
                # Apply the quantized gradient to the full-precision weights.
                model.load_state_dict(full_precision_state)
                del full_precision_state  # free memory

            optimizer.step()

            # print statistics (only every 500 minibatches, so shorter epochs
            # print nothing but the per-epoch summary below)
            print_loss_total += print_loss
            if idx_minibatch % 500 == 0:
                last_loss_saved = print_loss_total / 500

                # Format the elapsed seconds directly; converting them through a
                # calendar timestamp would add the local UTC offset and wrap at 24 h.
                hours, remainder = divmod(int(time.time() - startTime), 3600)
                minutes, seconds = divmod(remainder, 60)
                elapsed = '{:02d}:{:02d}:{:02d}'.format(hours, minutes, seconds)

                str_to_print = 'Time Elapsed: {}, [Start Epoch: {}, Epoch: {}, Minibatch: {}], loss: {:3f}'.format(
                    elapsed, 1, epoch + 1, idx_minibatch, last_loss_saved)
                if pred_accuracy_epochs:
                    str_to_print += ' Last prediction accuracy: {:2f}%'.format(pred_accuracy_epochs[-1] * 100)
                print(str_to_print)
                print_loss_total = 0

        curr_percentages_asked_teacher = count_asked_teacher / count_asked_total if count_asked_total != 0 else 0
        percentages_asked_teacher.append(curr_percentages_asked_teacher)
        curr_pred_accuracy = evaluateModel(model, test_loader)
        pred_accuracy_epochs.append(curr_pred_accuracy)
        print(' === Epoch: {} - prediction accuracy {:2f}% === '.format(epoch + 1, curr_pred_accuracy * 100))

        # updating the learning rate from the validation error
        new_learning_rate, stop_training = lr_scheduler.update_learning_rate(1 - curr_pred_accuracy)
        if stop_training:
            break
        for group in optimizer.param_groups:
            group['lr'] = new_learning_rate

    if quantizeWeights:
        # The returned model carries quantized weights.
        quantize_weights_model(model)

    print(f'percentages_asked_teacher {percentages_asked_teacher}')
    print(f'predictionAccuracy {pred_accuracy_epochs}')
    return model


class ScalingFunction(object):
    '''
    Bundles a min-max scaling of a tensor with its inverse.

    scale_down maps each bucket of the tensor (or the whole flattened tensor when
    bucket_size is None) linearly using the bucket minimum as offset and
    (maximum - minimum) as scale, and records everything inv_scale_down needs
    to undo the mapping: the offsets, the scales and the original shape.
    '''

    # TODO: Make static version of scale and inv_scale that take as arguments all that is necessary,
    # and then the class can just be a small wrapper about calling scale, saving the arguments,
    # and calling inv. So we would have both ways to call the scaling function, directly and through
    # the class.

    def __init__(self, bucket_size):
        self.bucket_size = bucket_size
        # ranges smaller than this are treated as zero
        self.tol_diff_zero = 1e-10

        # State needed to invert the scaling; populated by scale_down
        self.mean_tensor = None
        self.original_tensor_size = None
        self.original_tensor_length = None
        self.expected_tensor_size = None

        self.alpha = None
        self.beta = None

    def scale_down(self, tensor):
        '''
        Scales the tensor and returns it. When bucket_size is None the result is
        the flattened tensor; otherwise it is the 2-D bucketed tensor returned by
        create_bucket_tensor. The scaling itself is applied in place on that
        tensor. inv_scale_down restores both values and shape.
        '''
        self.mean_tensor = 0
        self.original_tensor_size = tensor.size()
        self.original_tensor_length = tensor.numel()

        bucketed = self.bucket_size is not None
        tensor = create_bucket_tensor(tensor, self.bucket_size)
        if not bucketed:
            tensor = tensor.view(-1)
        self.expected_tensor_size = tensor.size()

        # A bucketed tensor is 2-D (one bucket per row); otherwise it is 1-D.
        reduce_dim = 1 if bucketed else 0
        lowest = tensor.min(dim=reduce_dim, keepdim=True)[0]
        highest = tensor.max(dim=reduce_dim, keepdim=True)[0]

        span = highest - lowest
        # A (near) zero range means all values of that row are equal. Using 1
        # there avoids nan and inf and does not change the result.
        span[span < self.tol_diff_zero] = 1

        self.alpha = span
        self.beta = lowest

        tensor.sub_(lowest.expand_as(tensor)).div_(span.expand_as(tensor))
        return tensor

    def inv_scale_down(self, tensor):
        '''
        Inverts the scaling done by the last scale_down call (in place on the
        tensor passed), drops any filler values and restores the original shape.
        Raises ValueError when the tensor does not have the size scale_down returned.
        '''
        if tensor.size() != self.expected_tensor_size:
            raise ValueError('The tensor passed has not the expected size.')

        tensor.mul_(self.alpha.expand_as(tensor))
        tensor.add_(self.beta.expand_as(tensor))
        tensor.add_(self.mean_tensor)

        # remove the filler values, then go back to the original shape
        flat = tensor.view(-1)[:self.original_tensor_length]
        return flat.view(self.original_tensor_size)


def uniformQuantization(tensor, s, bucket_size=None):
    '''
    Quantizes the tensor passed to s uniformly spaced levels using deterministic
    rounding to the nearest level.

    The tensor is min-max scaled with a ScalingFunction, snapped to the grid
    {0, 1/(s-1), ..., 1}, and scaled back. The input tensor is not modified.

    :param tensor: tensor to quantize
    :param s: number of quantization levels, at least 2
    :param bucket_size: bucket size passed to ScalingFunction (None: scale the whole tensor at once)
    :return: (quantized tensor with the shape of the input, the ScalingFunction used)
    :raises ValueError: if s is smaller than 2
    '''
    if s < 2:
        # With fewer than 2 levels the grid step 1/(s-1) is undefined and the
        # result would silently be filled with nan.
        raise ValueError('s must be at least 2 to define a quantization grid, got {}'.format(s))

    # Work on a copy: the scaling and rounding below are done in place.
    tensor = tensor.clone()
    scaling_function = ScalingFunction(bucket_size)

    tensor = scaling_function.scale_down(tensor)

    # s levels delimit s - 1 intervals
    num_intervals = s - 1
    tensor.mul_(num_intervals).round_().div_(num_intervals)

    tensor = scaling_function.inv_scale_down(tensor)
    return tensor, scaling_function


def create_bucket_tensor(tensor, bucket_size):
    '''
    Reshapes the tensor into rows of bucket_size elements.

    :param tensor: tensor to bucket
    :param bucket_size: number of elements per row; None returns the tensor unchanged
    :return: the tensor itself if bucket_size is None; a (1, numel) view if the
             tensor has fewer than bucket_size elements; otherwise a
             (num_buckets, bucket_size) tensor, where the last row is padded with
             copies of the last element when numel is not a multiple of bucket_size
    '''
    if bucket_size is None:
        return tensor

    tensor = tensor.view(-1)
    total_length = tensor.numel()
    multiple, rest = divmod(total_length, bucket_size)

    if multiple == 0:
        # The tensor is smaller than one bucket. For consistency it is still
        # returned as a single row, whose length is the length of the tensor.
        return tensor.view(1, total_length)

    if rest != 0:
        # Pad with the last value to reach a multiple of the bucket size. The
        # filler is derived from the tensor itself so that it always has the
        # same dtype and lives on the same device as the data.
        filler = tensor[-1].expand(bucket_size - rest)
        tensor = torch.cat([tensor, filler])

    # this is the bucket tensor
    return tensor.view(-1, bucket_size)


def evaluateModel(model, testLoader, k=1):
    '''
    Computes the top-k accuracy of the model on the examples yielded by testLoader.

    The model is switched to eval mode and left in that mode.

    :param model: network to evaluate
    :param testLoader: iterable of (inputs, labels) minibatches
    :param k: a prediction counts as correct if the label is among the k largest outputs
    :return: fraction of correctly classified examples, in [0, 1]
    :raises ValueError: if testLoader yields no examples
    '''
    model.eval()

    # Send the data wherever the model lives; fall back to the module-level
    # device only for models without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    correctClass = 0
    totalNumExamples = 0

    with torch.no_grad():
        for inputs, labels in testLoader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            outputs = model(inputs)
            _, topk_predictions = outputs.topk(k, dim=1, largest=True, sorted=True)
            # one row per example: compare its label against each of its k predictions
            hits = topk_predictions.eq(labels.view(-1, 1).expand_as(topk_predictions))
            correctClass += hits.sum().item()
            totalNumExamples += len(labels)

    if totalNumExamples == 0:
        raise ValueError('testLoader yielded no examples, the accuracy is undefined')

    return correctClass / totalNumExamples


def forward_and_backward(model, batch, use_distillation_loss=False, teacher_model=None):
    '''
    Runs one forward and backward pass of the model on a minibatch.

    Without distillation the loss is the cross entropy between the model outputs
    and the labels. With distillation it is the loss of
    "Distilling the Knowledge in a Neural Network", Hinton et al., applied to
    every example of the batch:

        0.7 * T**2 * KLDiv(log_softmax(student / T), softmax(teacher / T))
        + 0.3 * cross_entropy(student, labels)            with T = 10

    Gradients are accumulated into the model parameters; the caller is
    responsible for zeroing them and for the optimizer step.

    :param model: network being trained
    :param batch: pair (inputs, labels)
    :param use_distillation_loss: whether to use the distillation loss
    :param teacher_model: network providing the soft targets; required for distillation
    :return: (loss value as float, number of examples the teacher was asked about,
              number of examples in the batch)
    :raises ValueError: if distillation is requested without a teacher model
    '''
    if use_distillation_loss and teacher_model is None:
        raise ValueError('To compute distillation loss you need to pass the teacher model')

    # Send the data wherever the model lives; fall back to the module-level
    # device only for models without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    inputs, labels = batch
    inputs, labels = inputs.to(target_device), labels.to(target_device)

    criterion = nn.CrossEntropyLoss()
    outputs = model(inputs)

    count_asked_teacher = 0

    if use_distillation_loss:
        weight_teacher_loss = 0.7
        temperature_distillation = 10

        # The teacher is consulted for every example of the batch. The previous
        # mask/index bookkeeping selected all examples for distillation but, through
        # an integer-indexing slip, also added the cross entropy of example 0 repeated
        # batch-size times; that spurious term is gone.
        count_asked_teacher = outputs.size(0)

        # The teacher only provides targets, so no graph is needed for it.
        with torch.no_grad():
            outputsTeacher = teacher_model(inputs)

        log_probs_student = nn.LogSoftmax(dim=1)(outputs / temperature_distillation)
        probs_teacher = nn.Softmax(dim=1)(outputsTeacher / temperature_distillation)
        # NOTE(review): KLDivLoss is used with its default reduction, which averages
        # over all elements rather than over the batch; kept as is so the scale of
        # the loss does not change.
        distillation_term = nn.KLDivLoss()(log_probs_student, probs_teacher)

        loss = weight_teacher_loss * temperature_distillation ** 2 * distillation_term
        loss = loss + (1 - weight_teacher_loss) * criterion(outputs, labels)
    else:
        loss = criterion(outputs, labels)

    loss.backward()

    count_total = inputs.size(0)
    return loss.item(), count_asked_teacher, count_total


class LearningRateScheduler:
    """Halves the learning rate when the validation error stalls and signals when to stop.

    Call ``update_learning_rate`` once per epoch with the current validation error.
    """

    def __init__(self, initial_learning_rate):
        self.initial_learning_rate = initial_learning_rate
        self.current_learning_rate = initial_learning_rate

        # best validation error seen so far and epochs elapsed since it was set
        self.old_validation_error = float('inf')
        self.epochs_since_validation_error_dropped = 0
        self.total_number_of_learning_rate_halves = 0
        # remaining cool-down epochs during which no further halving happens
        self.epochs_to_wait_for_halving = 0

    def update_learning_rate(self, validation_error):
        """Record ``validation_error`` and return ``(learning_rate, stop_training)``.

        * An epoch counts as an improvement only if the error beats the best one
          by more than 0.001 (a 0.1% error band).
        * After 10 or more epochs without improvement the rate is halved, and it
          is not halved again during the following 8 epochs.
        * ``stop_training`` is true once more than 30 epochs have passed without
          improvement or the rate has been halved more than 11 times.
        """
        patience_before_halving = 10
        cooldown_after_halving = 8
        patience_before_stopping = 30
        max_halves = 11

        improved = validation_error + 0.001 < self.old_validation_error
        if improved:
            self.old_validation_error = validation_error
            self.epochs_since_validation_error_dropped = 0
        else:
            self.epochs_since_validation_error_dropped += 1
        stalled_epochs = self.epochs_since_validation_error_dropped

        # one more epoch of cool-down has elapsed (never below zero)
        cooldown = self.epochs_to_wait_for_halving - 1
        if cooldown < 0:
            cooldown = 0
        self.epochs_to_wait_for_halving = cooldown

        if cooldown == 0 and stalled_epochs >= patience_before_halving:
            self.epochs_to_wait_for_halving = cooldown_after_halving
            self.total_number_of_learning_rate_halves += 1
            self.current_learning_rate = self.current_learning_rate / 2

        stop_training = (stalled_epochs > patience_before_stopping
                         or self.total_number_of_learning_rate_halves > max_halves)

        return self.current_learning_rate, stop_training

      
  
def huffman_encode(symb2freq):
    """Build a Huffman code for a dict mapping symbols to weights.

    Returns a list of ``[symbol, code]`` pairs, where ``code`` is a string of
    '0'/'1' characters, sorted by code length and then by the pair itself.
    """
    # approach taken from https://rosettacode.org/wiki/Huffman_coding#Python

    # Every heap entry is [total weight, [symbol, code], [symbol, code], ...].
    heap = [[weight, [symbol, ""]] for symbol, weight in symb2freq.items()]
    heapify(heap)

    while len(heap) > 1:
        lighter = heappop(heap)
        heavier = heappop(heap)
        # Merging two subtrees prepends one more bit to every code they contain.
        for bit, subtree in (('0', lighter), ('1', heavier)):
            for entry in subtree[1:]:
                entry[1] = bit + entry[1]
        heappush(heap, [lighter[0] + heavier[0], *lighter[1:], *heavier[1:]])

    _, *codes = heappop(heap)
    codes.sort(key=lambda entry: (len(entry[-1]), entry))
    return codes


  
def get_huffman_encoding_mean_bit_length(model_param_iter, quantization_functions, type_quantization='uniform',
                                         s=None):

    '''
    Returns the mean number of bits per parameter needed to encode the quantized
    model tensors with a Huffman code built over the quantization levels.

    :param model_param_iter: the iterator returning model parameters
    :param quantization_functions: the quantization function to use. Either a single one or a list with as many
                                   functions as there are tensors in the model. Each function must return a pair
                                   (quantized tensor, scaling function), as uniformQuantization does.
    :param type_quantization:      'uniform' or 'nonuniform' (case insensitive). Only 'uniform' is implemented.
    :param s:                      number of quantization levels; required for uniform quantization
    :return: the mean bit size of encoding the model tensors using huffman encoding
             (0.0 if the iterator yields no elements)
    :raises ValueError: if type_quantization is not recognized or s is missing
    :raises NotImplementedError: if type_quantization is 'nonuniform'
    '''

    type_quantization = type_quantization.lower()
    if type_quantization not in ('uniform', 'nonuniform'):
        raise ValueError('type_quantization not recognized')

    if type_quantization == 'nonuniform':
        # There is no code path computing the bins for this case; fail clearly
        # instead of crashing later on an unbound variable.
        raise NotImplementedError('Huffman bit length is only implemented for uniform quantization')

    if s is None:
        raise ValueError('If type of quantization is uniform, you must provide s')

    if not isinstance(quantization_functions, list):
        quantization_functions = [quantization_functions]

    single_quant_fun = len(quantization_functions) == 1
    total_length = 0
    frequency = defaultdict(int)

    # Lower edge of the bin around each of the s quantization points
    # 0, 1/(s-1), ..., 1; the tolerance absorbs floating point noise.
    tol = 1e-5
    bin_around_points = [x / (s - 1) - tol for x in range(s)]

    for idx, param in enumerate(model_param_iter):
        param = param.clone()
        if hasattr(param, 'data'):
            param = param.data

        total_length += param.numel()
        quant_fun = quantization_functions[0] if single_quant_fun else quantization_functions[idx]

        # Quantize, then scale the quantized values back down so that they can
        # be matched against the quantization points; filler values are dropped.
        q_tensor, scal = quant_fun(param)
        rescaled = scal.scale_down(q_tensor).view(-1)[0:scal.original_tensor_length]
        bin_indices = np.digitize(rescaled.cpu().numpy(), bin_around_points).flatten() - 1

        unique, counts = np.unique(bin_indices, return_counts=True)
        for val, count in zip(unique, counts):
            frequency[int(val)] += int(count)

    if total_length == 0:
        # nothing to encode
        return 0.0

    assert total_length == sum(frequency.values())
    frequency = {x: y/total_length for x, y in frequency.items()}
    huffman_code = huffman_encode(frequency)
    mean_bit_length = sum(frequency[x[0]]*len(x[1]) for x in huffman_code)

    return mean_bit_length
  
def get_size_quantized_model(model, numBits, quantization_functions, bucket_size=256,
                             type_quantization='uniform', quantizeFirstLastLayer=True):
    '''
    Returns the estimated size of the model in MB (10**6 bytes).

    With numBits None every parameter is counted as 4 bytes. Otherwise the
    quantized parameters are counted at the mean Huffman code length returned by
    get_huffman_encoding_mean_bit_length, the unquantized ones at 4 bytes, and,
    when bucket_size is not None, 8 bytes are added per bucket for its two
    32-bit scaling values.

    :param quantizeFirstLastLayer: if not True, the first and the last parameter
                                   tensors of model.parameters() are counted as unquantized
    '''
    bytes_per_mb = 1000000

    if numBits is None:
        return sum(p.numel() for p in model.parameters()) * 4 / bytes_per_mb

    all_params = list(model.parameters())
    if quantizeFirstLastLayer is True:
        quantized_params = all_params
        unquantized_params = []
    else:
        edge_positions = (0, len(all_params) - 1)
        quantized_params = [p for idx, p in enumerate(all_params) if idx not in edge_positions]
        unquantized_params = [p for idx, p in enumerate(all_params) if idx in edge_positions]

    count_quantized_parameters = sum(p.numel() for p in quantized_params)
    count_unquantized_parameters = sum(p.numel() for p in unquantized_params)

    # best Huffman bit length for the quantized parameters
    actual_bit_huffmman = get_huffman_encoding_mean_bit_length(iter(quantized_params), quantization_functions,
                                                               type_quantization, s=2**numBits)

    size_bytes = 0
    size_bytes += count_unquantized_parameters*4  # 32 bits / 8 = 4 bytes per parameter
    size_bytes += actual_bit_huffmman*count_quantized_parameters/8  # mean Huffman length, in bytes
    if bucket_size is not None:
        # every bucket stores 2 parameters: number of buckets times 2*32/8 = 8 bytes
        size_bytes += count_quantized_parameters/bucket_size*8

    return size_bytes / bytes_per_mb
  

class QD:
    """Driver for distilling a ResNet-34 teacher into a ResNet-18 student on CIFAR-10,
    optionally training the student with quantized weights."""

    def __init__(self):
        self.teacher = None
        self.model = None
        self.epochs = 300
        self.train_batch_size = 128  # 25
        self.test_batch_size = 128  # 25
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        """Create the CIFAR-10 train/test loaders (downloads the data into ./data if needed)."""
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size,
                                                        shuffle=True)
        test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

    def load_model(self):
        """Build the ResNet-34 teacher, load its weights from Google Drive and freeze it."""
        self.teacher = ResNet(34, False).to(device)
        model_param_save_name = 'resnet34_param.pt'
        path = f"/content/gdrive/My Drive/{model_param_save_name}"
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
        # runtime. NOTE: torch.load unpickles the file, only load trusted checkpoints.
        self.teacher.load_state_dict(torch.load(path, map_location=device))
        for param in self.teacher.parameters():
            param.requires_grad = False

    def main(self, quantized=False):
        """Train a ResNet-18 student with the distillation loss.

        :param quantized: if true, train with 4-bit quantized weights (bucket size 256)
                          and print the size of the resulting model
        """
        self.load_data()
        self.load_model()

        for numBit in [4]:
            model = ResNet(18, False, k=1).to(device)

            if quantized:  # distillation with quantized bits
                train_model(
                    model,
                    train_loader=self.train_loader,
                    test_loader=self.test_loader,
                    teacher_model=self.teacher,
                    epochs_to_train=self.epochs,
                    use_distillation_loss=True,
                    quantizeWeights=True,
                    numBits=numBit,
                    bucket_size=256,
                )
                quant_fun = functools.partial(uniformQuantization, s=2**numBit, bucket_size=256)
                size_model_MB = get_size_quantized_model(model, numBit, quant_fun, bucket_size=256, quantizeFirstLastLayer=True)
                print(F"The size of the quantized model using Huffman coding is {size_model_MB} MB")
            else:  # distillation
                train_model(
                    model,
                    train_loader=self.train_loader,
                    test_loader=self.test_loader,
                    teacher_model=self.teacher,
                    # honour self.epochs here as well, as the quantized branch does
                    epochs_to_train=self.epochs,
                    use_distillation_loss=True,
                )


# Script entry point: mount Google Drive at /content/gdrive (QD.load_model reads
# the teacher checkpoint from "/content/gdrive/My Drive/"), then run the
# quantized distillation experiment.
if __name__ == '__main__':
    drive.mount('/content/gdrive')
    # Alternative: plain distillation without weight quantization.
#     QD().main()
    QD().main(quantized=True)
