In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from speech_command_dataset import SpeechCommandDataset
import numpy as np
import matplotlib.pyplot as plt
from model import M5
import argparse
import torch.nn.utils.prune as prune
import copy
from tqdm import tqdm
import time 

In [2]:
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True 
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
training_params = {"batch_size": 128,
                       "shuffle": True,
                       "drop_last": False,
                       "num_workers": 1}

testing_params = {"batch_size": 128,
                       "shuffle": False,
                       "drop_last": False,
                       "num_workers": 1}

train_set = SpeechCommandDataset()
train_loader = DataLoader(train_set, **training_params)

test_set = SpeechCommandDataset(is_training=False)
test_loader = DataLoader(test_set, **testing_params)

In [4]:
def train(model, epoch):
    model.train()
    total_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss() 
    
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        #forward
        output = model(data)
        target = target.to(torch.int64) 
        loss = criterion(output, target)
        
        total_loss += loss.item()
        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()        
        
        #backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # print training stats
    train_loss = float(total_loss) / len(train_loader)
    train_acc = 100.0 * float(correct) / len(train_set)
    print('Epoch: %3d' % epoch, '|train loss: %.4f' % train_loss, '|train accuracy: %.2f' % train_acc)
    return train_acc

In [5]:
def test(model, epoch):
    model.eval()
    correct = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        #forward
        output = model(data)

        pred = output.argmax(dim=-1)
        correct += pred.squeeze().eq(target).sum().item()
        
    # print testing stats
    test_acc = 100.0 * float(correct) / len(test_set)
    print('Epoch: %3d' % epoch, '|test accuracy: %.2f' % test_acc)
    return test_acc

## load model

In [6]:
model_path = './Checkpoint/best_model_org.pth.tar'
print("=> loading checkpoint '{}'".format(model_path))
checkpoint = torch.load(model_path, map_location = device)

model = M5(cfg = checkpoint['cfg']).to(device)
model.load_state_dict(checkpoint['state_dict'])

=> loading checkpoint './Checkpoint/best_model_org.pth.tar'


<All keys matched successfully>

In [7]:
total_param = sum([param.nelement() for param in model.parameters()])
print("Number of parameter before pruning: %.2fk" % (total_param/1e3))

print('\nAccuracy before pruning')
test_acc = test(model, 0)

Number of parameter before pruning: 553.99k

Accuracy before pruning
Epoch:   0 |test accuracy: 93.41


## Fine-tune

In [8]:
model_fg = copy.deepcopy(model)
length = len(list(model_fg.parameters()))
for i, param in enumerate(model_fg.parameters()):
    if len(param.size())!=1 and i<length-2:
        weight = param.detach().cpu().numpy()
        w_mask=np.abs(weight)<np.percentile(np.abs(weight),40)
        weight[w_mask] = 0       
        weight = torch.from_numpy(weight).to(device)
        param.data = weight
test_acc = test(model_fg, 0)
print('\nAccuracy after pruning')
print(test_acc)

Epoch:   0 |test accuracy: 68.96

Accuracy after pruning
68.95580589254766


In [9]:
EPOCH = 100
LR = 0.001

# declare optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LR  )
print('start finetuning')

best_accuracy = 0
prune_model=model
for epoch in range(1, EPOCH + 1):
    train_acc = train(prune_model, epoch)
    test_acc = test(prune_model, epoch)
    
    if test_acc > best_accuracy:
        print('Saving..')
        torch.save({'cfg': prune_model.cfg, 'state_dict': prune_model.state_dict()}, './Checkpoint/finetuned_model.pth.tar')
        best_accuracy = test_acc
        
print('Best accuracy: %.2f' % best_accuracy)

start finetuning
Epoch:   1 |train loss: 0.4657 |train accuracy: 86.60
Epoch:   1 |test accuracy: 56.46
Saving..
Epoch:   2 |train loss: 0.3494 |train accuracy: 89.35
Epoch:   2 |test accuracy: 81.74
Saving..
Epoch:   3 |train loss: 0.3282 |train accuracy: 90.18
Epoch:   3 |test accuracy: 67.33
Epoch:   4 |train loss: 0.2969 |train accuracy: 91.18
Epoch:   4 |test accuracy: 60.25
Epoch:   5 |train loss: 0.2928 |train accuracy: 90.99
Epoch:   5 |test accuracy: 77.92
Epoch:   6 |train loss: 0.2512 |train accuracy: 92.22
Epoch:   6 |test accuracy: 70.95


KeyboardInterrupt: 

In [None]:
checkpoint = torch.load('./Checkpoint/finetuned_model.pth.tar')
finetuned_model = M5(cfg=checkpoint['cfg']).to(device)
finetuned_model.load_state_dict(checkpoint['state_dict'])

In [None]:
total_param = sum([param.nelement() for param in finetuned_model.parameters()])
print("Number of parameter after pruning: %.2fk" % (total_param/1e3))

print('Accuracy after fine-tuning')
test_acc = test(finetuned_model, 0)