In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from speech_command_dataset import SpeechCommandDataset
import numpy as np
import matplotlib.pyplot as plt
from model import M5
import argparse
import torch.nn.utils.prune as prune
import copy
from tqdm import tqdm
import time 
from prettytable import PrettyTable
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [41]:
batchsize   = 256
Epoch       = 300
lr       = 0.0001
pruning       = 0.26
name       = 'coarse_1_'

In [5]:
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True 
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [8]:
# declare dataloader  ------------------------------------------------------------------------------------
training_params = {"batch_size":batchsize ,
                    "shuffle": True,
                    "drop_last": False,
                    "num_workers": 1}

testing_params = {"batch_size": batchsize,
                    "shuffle": False,
                    "drop_last": False,
                    "num_workers": 1}

train_set = SpeechCommandDataset()
train_loader = DataLoader(train_set, **training_params) 

train_set_aug = SpeechCommandDataset(aug=True)
train_loader_aug = DataLoader(train_set_aug, **training_params) 

test_set = SpeechCommandDataset(is_training=False)
test_loader = DataLoader(test_set, **testing_params)

In [10]:

def train(model, epoch,data_loader,data_set,data_loader_aug,data_set_aug,device,optimizer):
    model.train()
    total_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss() 
    print("----------------------------------------------------------------------------------------------------")

    for data, target in tqdm(data_loader):
              
        data = data.to(device)
        target = target.to(device)


        output = model(data)

        target = target.to(torch.int64) 
        loss = criterion(output, target)


        total_loss += loss.item()
        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()

        #backward
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()  



    for data, target in tqdm(data_loader_aug):
              
        data = data.to(device)
        target = target.to(device)


        output = model(data)

        target = target.to(torch.int64) 
        loss = criterion(output, target)


        total_loss += loss.item()
        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()

        #backward
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()  
  


   
    train_loss = float(total_loss) / (len(train_loader)+len(train_loader_aug))
    train_acc = 100.0 * float(correct) / (len(data_set)+len(data_set_aug))
    # print('Epochhhh: %3d' % epoch, '|train loss: %.4f' % train_loss, '|train accuracy: %.2f' % train_acc)
    return train_acc,train_loss 


In [11]:

def test(model):
    model.eval()
    correct = 0

    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        #forward
        output = model(data)

        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()

    # print testing stats
    test_acc = 100.0 * float(correct) / len(test_set)
    # print('test accuracy: %.2f' % test_acc)
    return test_acc


In [12]:

def count_parameters(model,show=True):
    table = PrettyTable(["Modules", "Parameters","zero number","Sparsity"])
    total_params = 0
    org_w=[]
    org_z=[]
    length = len(list(model.parameters()))
    for i, (name, parameter )in enumerate(model.named_parameters()):
   
        if not parameter.requires_grad: continue
        w = parameter.detach().cpu().numpy()    
        z_num=np.sum(np.sum(np.where(w,0,1)))
        if len(parameter.size())!=1 and i<length-2:
            org_w.append(parameter.numel())
            org_z.append(z_num)
        params = parameter.numel()
        table.add_row([name, params,z_num,str(round((z_num/params)*100,1))+" %"])
        total_params+=params
    sparsity=round(np.sum(org_z)/np.sum(org_w),1)
    if show:
        print(table)
        print(f"Total Trainable Params: {total_params}, Sparsity: {sparsity*100} %")

    return total_params,sparsity
    


## load model

In [14]:
# load model  ------------------------------------------------------------------------------------
model_path = './log/best_model_clean.pth.tar'
print("=> loading checkpoint '{}'".format(model_path))
checkpoint = torch.load(model_path, map_location = device)
model_org = M5(cfg = checkpoint['cfg']).to(device)
model_org.load_state_dict(checkpoint['state_dict'])
cfg = checkpoint['cfg']
# print(cfg)
# calulate parameter  ------------------------------------------------------------------------------------
print(model_org)
# count_parameters(model_org)



    

=> loading checkpoint './log/best_model_clean.pth.tar'
M5(
  (features): Sequential(
    (0): Conv1d(1, 128, kernel_size=(40,), stride=(2,), padding=(19,))
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (4): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (8): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (12): Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine

In [28]:
total_param = sum([param.nelement() for param in model_org.parameters()])
test_acc = test(model_org)
print("Number of parameter before pruning: %.2fk" % (total_param/1e3))
print('Accuracy before pruning:'+str(round(test_acc,2))+"%")

Number of parameter before pruning: 553.99k
Accuracy before pruning:91.47%


## Coarse-grained Pruning

In [19]:
# you can choose your pruning rate
pruning_rate = 0.26

In [23]:
#calculate pruning threshold
cfgs = []         #example format:  [125, 120, 155, 403]
cfgs_mask = [] 
# for i, param in enumerate(model_org.parameters()):
#     print(param)
BN_torch = torch.tensor([]).to(device)
for m in model_org.modules():
    if isinstance(m, nn.BatchNorm1d):
        BN_torch=torch.cat((BN_torch,m.weight.data),0)

sorted_weight = torch.sort(BN_torch)[0]
thres_index = int(len(sorted_weight) * pruning_rate)
thres = sorted_weight[thres_index]




In [24]:
#get configuration and mask for pruned network
for m in model_org.modules():
    if isinstance(m, nn.BatchNorm1d):
        cfg = (m.weight.data > thres).sum().to(device)
        mask = (m.weight.data > thres).to(device)
        # cfgs.append(len(cfg))
        cfgs.append(cfg.item())
        cfgs_mask.append(mask)
print("Format of model:")
print(cfgs)
print('Pre-processing Successful!')

Format of model:
[20, 81, 144, 512]
Pre-processing Successful!


In [25]:
new_model = M5(cfgs).to(device)

In [33]:
old_modules = list(model_org.modules())
new_modules = list(new_model.modules())

layer_id_in_cfg = 0
start_mask = torch.ones(1, dtype = bool)
end_mask = cfgs_mask[layer_id_in_cfg]

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm1d):
        m1.weight.data = m0.weight.data[end_mask].clone()
        m1.bias.data = m0.bias.data[end_mask].clone()
        m1.running_mean = m0.running_mean[end_mask].clone()
        m1.running_var = m0.running_var[end_mask].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfgs_mask): #prevent out of range
            end_mask = cfgs_mask[layer_id_in_cfg]
            
    elif isinstance(m0, nn.Conv1d):
        w1 = m0.weight.data[:, start_mask, :].clone()
        w1 = w1[end_mask, :, :].clone()
        m1.weight.data = w1.clone()
        m1.bias.data = m0.bias.data[end_mask]

    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data[:, start_mask].clone()
        m1.bias.data = m0.bias.data.clone()
        
        

print('cfg', cfg)
# torch.save({'cfg': cfg, 'state_dict': new_model.state_dict()}, './Checkpoint/coarse_model.pth.tar')

cfg tensor(512, device='cuda:0')


## load pruned model and see result

In [34]:
# checkpoint = torch.load('./Checkpoint/coarse_model.pth.tar')
# prune_model = M5(cfg=checkpoint['cfg']).to(device)
# prune_model.load_state_dict(checkpoint['state_dict'])
prune_model=new_modules
print(prune_model)

[M5(
  (features): Sequential(
    (0): Conv1d(1, 20, kernel_size=(40,), stride=(2,), padding=(19,))
    (1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (4): Conv1d(20, 81, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(81, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (8): Conv1d(81, 144, kernel_size=(3,), stride=(1,), padding=(1,))
    (9): BatchNorm1d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (12): Conv1d(144, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
    (15): 

In [36]:

total_param = sum([param.nelement() for param in new_model.parameters()])
test_acc = test(new_model)
print("Number of parameter after pruning: %.2fk" % (total_param/1e3))
print('Accuracy after pruning, before fine-tuning:'+str(round(test_acc,2))+"%")

Number of parameter after pruning: 269.24k
Accuracy after pruning, before fine-tuning:9.04%


## Fine-grained

In [44]:
def Pruning_fg(model,p_pruning=60):
    model_fg = copy.deepcopy(model)
    length = len(list(model_fg.parameters()))
    mask_list=[]
    for i, param in enumerate(model_fg.parameters()):
        if len(param.size())==3:
            weight = param.detach().cpu().numpy()
            w_mask=np.abs(weight)<np.percentile(np.abs(weight),p_pruning)
            mask_list.append(w_mask)
            weight[w_mask] = 0       
            weight = torch.from_numpy(weight).to(device)
            param.data = weight
    return model_fg,mask_list

In [46]:
model_fg ,mask_list =Pruning_fg(model_org,p_pruning=pruning)

## Fine-tune

In [38]:
model=model_fg # Fine-grained
model=new_model # Coarse-grained
# declare optimizer and loss function
optimizer = optim.Adam(new_model.parameters(), lr=lr  )
print('start finetuning')

best_accuracy = 0
prune_model=model

localtime = time.asctime( time.localtime(time.time()) )
timecode=localtime[9:10]+"_"+localtime[11:13]+"_"+localtime[14:16]
checkpoint = open('./Checkpoint/'+name+str(timecode)+'_batchsize_'+str(batchsize)+'.txt', 'w')

for epoch in range(1, Epoch + 1):
    train_acc ,train_loss= train(new_model, epoch,train_loader,train_set,train_loader_aug,train_set_aug,device,optimizer)
    test_acc = test(new_model)
    # total_params,sparsity_train= count_parameters(model_fg,show=False)
    
    print('Epoch: %3d' % epoch, '|train loss: %.4f' % train_loss, '|train accuracy: %.2f' % train_acc,'|test_acc:  %.2f' % test_acc)
    
    checkpoint = open('./Checkpoint/'+name+str(timecode)+'_batchsize_'+str(batchsize)+'.txt', 'a')
    checkpoint.write('Epoch: %3d' % epoch + '|train loss: %.4f' % train_loss+ '|train accuracy: %.2f' % train_acc+ '|test accuracy: %.2f' % test_acc+'\n')        
    checkpoint.close()
    if test_acc > best_accuracy:
        
        print('Saving..')
        torch.save({'cfg': new_model.cfg, 'state_dict': new_model.state_dict()}, './Checkpoint/'+name+str(timecode)+'_batchsize_'+str(batchsize)+'.pth.tar')
        best_accuracy = test_acc
    
    # print('Epoch: %3d' % epoch, '|train loss: %.4f' % train_loss, '|train accuracy: %.2f' % train_acc,'|test accuracy: %.2f' % test_acc,'|best accuracy: %.2f' % best_accuracy)  
print('Best accuracy: %.2f' % best_accuracy)

start finetuning
----------------------------------------------------------------------------------------------------


  9%|▉         | 7/75 [00:08<01:26,  1.26s/it]


KeyboardInterrupt: 

In [None]:
checkpoint = torch.load('./Checkpoint/finetuned_model.pth.tar')
finetuned_model = M5(cfg=checkpoint['cfg']).to(device)
finetuned_model.load_state_dict(checkpoint['state_dict'])

In [None]:
total_param = sum([param.nelement() for param in finetuned_model.parameters()])
print("Number of parameter after pruning: %.2fk" % (total_param/1e3))

print('Accuracy after fine-tuning')
test_acc = test(finetuned_model, 0)