In [20]:
import argparse
import copy
import os
import sys
import numpy as np
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import os
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import torch.nn.init as init
import pickle

sys.path.append(os.path.abspath("/ocean/projects/asc170022p/shg121/PhD/Project_Pruning"))
import numpy as np
from torch.utils.data import DataLoader

from model_factory.model_meta import Model_Meta
from model_factory.models import Classifier
from run_manager import RunManager


import utils
import yaml
import pickle
import torch
from dataset.dataset_mnist import Dataset_mnist
from dataset.dataset_utils import get_dataset, get_transforms

In [85]:
# --- Reproducibility / hardware -------------------------------------------
seed = 0
device = utils.get_device()
print(f"Device: {device}")

# --- Locations on disk ----------------------------------------------------
data_root = "/ocean/projects/asc170022p/shg121/PhD/Project_Pruning/data/MNIST_EVEN_ODD"
json_root = "/ocean/projects/asc170022p/shg121/PhD/Project_Pruning/scripts_data"
logs = "/ocean/projects/asc170022p/shg121/PhD/Project_Pruning/output"
chk_pt_path = "seq_epoch_20.pth.tar"
cav_vector_file = "max_pooled_train_cavs.pkl"

# --- Model / task ---------------------------------------------------------
model_arch = "Resnet_18"
dataset_name = "mnist"
pretrained = True
transfer_learning = False
num_classes = 1
bb_layer = "layer3"  # layer3
concept_names = ["Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine"]
class_list = [0, 1]
num_labels = len(class_list)
# One integer per backbone layer name.
kernel_size = {"layer3": 14, "layer4": 7}

# --- Data loading / optimisation ------------------------------------------
img_size = 224
batch_size = 3
epochs = 50
num_workers = 4
lr = 1e-3

# --- Pruning schedule -----------------------------------------------------
prune_type = "lt"
ITERATION = 35
prune_percent = 10
start_iter = 0
end_iter = 100
resample = False
# Re-initialisation is switched on only by the "reinit" prune type.
reinit = prune_type == "reinit"

Device: cuda


In [23]:
# Seed the Python, NumPy and PyTorch RNGs from the configured `seed`.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Load the train and validation splits (same roots, different `mode`).
train_set = get_dataset(
    data_root=data_root,
    json_root=json_root,
    dataset_name=dataset_name,
    mode="train"
)

val_set = get_dataset(
    data_root=data_root,
    json_root=json_root,
    dataset_name=dataset_name,
    mode="val"
)

# A single transform object is shared by both splits.
transform = get_transforms(size=img_size)

# The loaders take their worker count from the configured `num_workers`
# instead of a second, hard-coded literal that could drift out of sync.
train_dataset = Dataset_mnist(train_set, transform)
train_loader = DataLoader(
    train_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True
)

# Validation data is iterated in a fixed order (no shuffling).
val_dataset = Dataset_mnist(val_set, transform)
val_loader = DataLoader(
    val_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False
)

Length of the [train] dataset: 48000
Length of the [val] dataset: 12000


In [122]:
def weight_init(m):
    '''
    Re-initialise the parameters of a single module in place, by module type.

    Usage:
        model = Model()
        model.apply(weight_init)

    Initialisation scheme:
        * Conv1d / ConvTranspose1d: normal weights, normal bias.
        * Conv2d / Conv3d / ConvTranspose2d / ConvTranspose3d:
          Xavier-normal weights, normal bias.
        * BatchNorm1d / 2d / 3d: weights ~ N(1, 0.02), bias = 0.
        * Linear: Xavier-normal weights, normal bias.
        * LSTM / LSTMCell / GRU / GRUCell: orthogonal init for parameters
          with two or more dimensions, normal init for the rest.

    Modules of any other type are left untouched. Optional parameters that
    are absent (``bias=False`` layers, ``affine=False`` batch norm) are
    skipped instead of raising ``AttributeError``.
    '''
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        # Linear(bias=False) has no bias tensor to initialise.
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

In [123]:
# Build the classifier from the configured architecture and options, then
# move it to the selected device. The call is left as the final expression
# so the notebook shows its return value.
model = Classifier(
    model_arch,
    num_classes,
    pretrained,
    transfer_learning,
)
model.to(device)

Classifier(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru

In [124]:
# List every named parameter of the model together with its size.
for name, param in model.named_parameters():
    param_size = param.size()
    print(name, param_size)

model.conv1.weight torch.Size([64, 3, 7, 7])
model.bn1.weight torch.Size([64])
model.bn1.bias torch.Size([64])
model.layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
model.layer1.0.bn1.weight torch.Size([64])
model.layer1.0.bn1.bias torch.Size([64])
model.layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
model.layer1.0.bn2.weight torch.Size([64])
model.layer1.0.bn2.bias torch.Size([64])
model.layer1.1.conv1.weight torch.Size([64, 64, 3, 3])
model.layer1.1.bn1.weight torch.Size([64])
model.layer1.1.bn1.bias torch.Size([64])
model.layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
model.layer1.1.bn2.weight torch.Size([64])
model.layer1.1.bn2.bias torch.Size([64])
model.layer2.0.conv1.weight torch.Size([128, 64, 3, 3])
model.layer2.0.bn1.weight torch.Size([128])
model.layer2.0.bn1.bias torch.Size([128])
model.layer2.0.conv2.weight torch.Size([128, 128, 3, 3])
model.layer2.0.bn2.weight torch.Size([128])
model.layer2.0.bn2.bias torch.Size([128])
model.layer2.0.downsample.0.weight torch.Size([1

In [125]:
# Show the [0][0] slice of the first named parameter only.
first_entry = next(iter(model.named_parameters()), None)
if first_entry is not None:
    name, param = first_entry
    print(name, param[0][0])

model.conv1.weight tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       device='cuda:0', grad_fn=<SelectBackward>)


In [126]:
# Re-initialise every sub-module with `weight_init`, then show the [0][0]
# slice of the first named parameter again.
model.apply(weight_init)
first_entry = next(iter(model.named_parameters()), None)
if first_entry is not None:
    name, param = first_entry
    print(name, param[0][0])

model.conv1.weight tensor([[ 0.0299, -0.0026, -0.0051,  0.0486,  0.0018, -0.0319,  0.0104],
        [ 0.0123,  0.0202, -0.0153,  0.0333,  0.0261,  0.0053, -0.0280],
        [ 0.0233, -0.0267, -0.0044, -0.0001, -0.0066, -0.0065, -0.0295],
        [-0.0304,  0.0186, -0.0403,  0.0035, -0.0094, -0.0013,  0.0125],
        [-0.0188,  0.0122, -0.0129, -0.0221, -0.0008, -0.0063,  0.0327],
        [-0.0093,  0.0137,  0.0181,  0.0178, -0.0103,  0.0218, -0.0203],
        [-0.0383,  0.0038,  0.0386, -0.0114,  0.0342,  0.0154,  0.0481]],
       device='cuda:0', grad_fn=<SelectBackward>)


In [127]:
def make_mask(model):
    """Create the global pruning mask for `model`.

    For every parameter whose name contains 'weight', an all-ones NumPy
    array shaped like that parameter is stored in the global list `mask`
    (in `named_parameters()` order). The global counter `step` is left at 0.
    Diagnostic information is printed along the way.
    """
    global step
    global mask

    weight_entries = [
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if 'weight' in param_name
    ]
    n_weights = len(weight_entries)

    mask = [None] * n_weights
    print(mask)
    for position, (_, tensor) in enumerate(weight_entries):
        mask[position] = np.ones_like(tensor.data.cpu().numpy())
    print(n_weights)
    step = 0
    print(type(mask))
    print(len(mask))
    print(mask[0][0])

    print("Final list: ")
    for idx, (name, param) in enumerate(weight_entries):
        print(f"name: {name}, param_size: {param.size()}, mask_size: {np.array(mask[idx]).shape}")

In [128]:
# Checkpoints for this run go under <chk_pt>/<model_arch>/<dataset_name>.
chk_pt = "/ocean/projects/asc170022p/shg121/PhD/Project_Pruning/output/chk_pt/Pruning"
chk_pt_file = os.path.join(chk_pt, model_arch, dataset_name)
try:
    os.makedirs(chk_pt_file, exist_ok=True)
    print("Checkpoint directory is created successfully at:")
    print(chk_pt_file)
except OSError:
    print(f"Checkpoint directory {chk_pt_file} can not be created")
    # Without the directory the torch.save below cannot succeed, so stop
    # here with the real cause instead of continuing.
    raise

# Snapshot of the current weights, taken before any mask is built.
initial_state_dict = copy.deepcopy(model.state_dict())
# NOTE(review): this pickles the whole `model` object, not
# `initial_state_dict`, although the file name says "initial_state_dict";
# confirm which one the code that loads this file expects.
torch.save(model, os.path.join(chk_pt_file, f"initial_state_dict_prune_type_{prune_type}.pth.tar"))

# Build the global all-ones pruning mask for the model.
make_mask(model)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Loss Function
criterion = nn.BCELoss()

Checkpoint directory is created successfully at:
/ocean/projects/asc170022p/shg121/PhD/Project_Pruning/output/chk_pt/Pruning/Resnet_18/mnist
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
41
<class 'list'>
41
[[[1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]]]
Final list: 
name: model.conv1.weight, param_size: torch.Siz

In [129]:
# Scalar trackers.
best_accuracy = 0
step = 0

# One slot per pruning iteration.
comp = np.zeros(ITERATION, dtype=float)
bestacc = np.zeros(ITERATION, dtype=float)

# One slot per training iteration.
all_loss = np.zeros(end_iter, dtype=float)
all_accuracy = np.zeros(end_iter, dtype=float)

for history in (comp, bestacc, all_loss, all_accuracy):
    print(history.shape)

(35,)
(35,)
(100,)
(100,)


In [130]:
def prune_by_percentile(percent, resample=False, reinit=False, **kwargs):
    """Magnitude-prune every 'weight' parameter of the global `model` in place.

    For each parameter whose name contains 'weight', the `percent`-th
    percentile of the absolute values of its non-zero entries is computed.
    Entries whose absolute value is below that threshold are zeroed, both in
    the parameter and in the matching entry of the global `mask` list.
    Parameters without 'weight' in their name (e.g. biases) are not touched.

    Args:
        percent: percentile (0-100) passed to ``np.percentile``.
        resample: accepted for interface compatibility; not used here.
        reinit: accepted for interface compatibility; not used here.
        **kwargs: ignored.

    Side effects:
        Updates `model` parameters and the global `mask`; the global `step`
        counter is 0 on return.
    """
    global step
    global mask
    global model

    step = 0
    for name, param in model.named_parameters():
        # We do not prune bias terms.
        if 'weight' not in name:
            continue

        tensor = param.data.cpu().numpy()
        alive = tensor[np.nonzero(tensor)]  # flattened array of non-zero values

        # A layer with no surviving weights has nothing left to prune, and
        # np.percentile cannot handle an empty array. Keep its mask as is.
        if alive.size == 0:
            step += 1
            continue

        percentile_value = np.percentile(np.abs(alive), percent)

        # Already-masked positions hold zeros, so they stay below the
        # threshold; the remaining positions inherit the previous mask value.
        new_mask = np.where(np.abs(tensor) < percentile_value, 0, mask[step])

        # Apply the new mask to the weights on their original device.
        param.data = torch.from_numpy(tensor * new_mask).to(param.device)
        mask[step] = new_mask
        step += 1
    step = 0

In [131]:
def original_initialization(mask_temp, initial_state_dict):
    """Reset the global `model` to its saved initial values, keeping the mask.

    Parameters whose name contains "weight" are set to
    ``mask_temp[i] * initial_state_dict[name]``, where ``i`` counts weight
    parameters in ``named_parameters()`` order. Parameters whose name
    contains "bias" are restored from `initial_state_dict` unchanged.

    Args:
        mask_temp: sequence of NumPy arrays, one per weight parameter.
        initial_state_dict: mapping from parameter name to its initial tensor.

    Side effects:
        Overwrites `model` parameters in place; the global `step` counter is
        0 on return.
    """
    global step
    global model  # this is a free function, so there is no `self.model`

    step = 0
    for name, param in model.named_parameters():
        if "weight" in name:
            weight_dev = param.device
            param.data = torch.from_numpy(
                mask_temp[step] * initial_state_dict[name].cpu().numpy()
            ).to(weight_dev)
            step = step + 1
        if "bias" in name:
            # Clone so later in-place updates of the parameter cannot modify
            # the saved initial state.
            param.data = initial_state_dict[name].clone().to(param.device)
    step = 0

In [132]:
def print_nonzeros(model):
    """Print a per-parameter sparsity report and return the alive percentage.

    One line is printed for every named parameter (non-zero count, size,
    percentage, pruned count, shape), followed by a summary line for the
    whole model. The return value is the share of non-zero entries over all
    parameters, in percent, rounded to one decimal place.
    """
    alive_total = 0
    param_total = 0
    for layer_name, layer_param in model.named_parameters():
        values = layer_param.data.cpu().numpy()
        alive = np.count_nonzero(values)
        size = np.prod(values.shape)
        alive_total += alive
        param_total += size
        print(f'{layer_name:20} | nonzeros = {alive:7} / {size:7} ({100 * alive / size:6.2f}%) | total_pruned = {size - alive :7} | shape = {values.shape}')
    print(f'alive: {alive_total}, pruned : {param_total - alive_total}, total: {param_total}, Compression rate : {param_total/alive_total:10.2f}x  ({100 * (param_total-alive_total) / param_total:6.2f}% pruned)')
    alive_share = alive_total / param_total
    return round(alive_share * 100, 1)


In [135]:
ITE = 1
for _ite in range(start_iter, ITERATION):
    # Level 0 uses the unpruned network; every later level prunes first and
    # then resets the surviving weights.
    if _ite > 0:
        print(_ite)
        prune_by_percentile(prune_percent, resample=resample, reinit=reinit)
        print("------")

        if reinit:
            # Fresh random initialisation, then re-apply the mask so pruned
            # positions stay at zero.
            model.apply(weight_init)
            step = 0
            for name, param in model.named_parameters():
                if 'weight' in name:
                    weight_dev = param.device
                    param.data = torch.from_numpy(param.data.cpu().numpy() * mask[step]).to(weight_dev)
                    step = step + 1
            step = 0
        else:
            # Rewind the surviving weights to their saved initial values.
            original_initialization(mask, initial_state_dict)
        # The notebook has no `args` object; use the configured `lr`.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")
    comp1 = print_nonzeros(model)
    comp[_ite] = comp1
    print(comp)
    pbar = tqdm(range(end_iter))
    print(pbar)
    # NOTE(review): this unconditional break stops after the first pruning
    # level and `pbar` is never iterated; no training step is present here.
    break
    


--- Pruning Level [1:0/35]: ---
model.conv1.weight   | nonzeros =    9408 /    9408 (100.00%) | total_pruned =       0 | shape = (64, 3, 7, 7)
model.bn1.weight     | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
model.bn1.bias       | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 | shape = (64,)
model.layer1.0.conv1.weight | nonzeros =   36864 /   36864 (100.00%) | total_pruned =       0 | shape = (64, 64, 3, 3)
model.layer1.0.bn1.weight | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
model.layer1.0.bn1.bias | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 | shape = (64,)
model.layer1.0.conv2.weight | nonzeros =   36864 /   36864 (100.00%) | total_pruned =       0 | shape = (64, 64, 3, 3)
model.layer1.0.bn2.weight | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
model.layer1.0.bn2.bias | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 


  0%|          | 0/100 [02:35<?, ?it/s][A

  0%|          | 0/100 [00:00<?, ?it/s]





In [137]:
# Scratch check: print each index produced by a range of length `end_iter_`.
end_iter_ = 2
for iter_ in range(0, end_iter_):
    print(iter_)

0
1


In [150]:
def flatten_cnn_activations(activations, kernel_size, stride=1):
    """Max-pool a NumPy activation array and flatten it to two dimensions.

    The array is converted to a tensor, passed through
    ``torch.nn.MaxPool2d(kernel_size, stride=stride)``, reshaped to
    ``(first_dim, -1)`` and converted back to NumPy. The flattened array and
    its shape are printed, and the array is also returned.

    Args:
        activations: NumPy array accepted by ``MaxPool2d``.
        kernel_size: pooling window size.
        stride: pooling stride. It is now honoured; previously the function
            always pooled with stride 1 regardless of this argument.

    Returns:
        The flattened, pooled activations as a NumPy array.
    """
    max_pool = torch.nn.MaxPool2d(kernel_size, stride=stride)
    torch_activation = torch.from_numpy(activations)
    max_pool_activation = max_pool(torch_activation)
    flatten_activations = max_pool_activation.view(
        max_pool_activation.size()[0], -1
    ).numpy()

    print(flatten_activations)
    print(flatten_activations.shape)
    return flatten_activations
    
# Toy input of shape (2, 4, 4), pooled with a window of size 4.
activations = np.array(
    [
        [
            [1., 2., 44., 2.],
            [11., 14., 1., 2.],
            [11., 14., 1., 2.],
            [11., 14., 1., 2.],
        ],
        [
            [1., 2., 44., 2.],
            [11., 14., 1., 2.],
            [101., 14., 1., 2.],
            [11., 14., 1., 2.],
        ],
    ]
)
print(activations)
print(activations.shape)
flatten_cnn_activations(activations, 4)

[[[  1.   2.  44.   2.]
  [ 11.  14.   1.   2.]
  [ 11.  14.   1.   2.]
  [ 11.  14.   1.   2.]]

 [[  1.   2.  44.   2.]
  [ 11.  14.   1.   2.]
  [101.  14.   1.   2.]
  [ 11.  14.   1.   2.]]]
(2, 4, 4)
[[ 44.]
 [101.]]
(2, 1)
