In [14]:
import torch
import numpy as np
from models.convnet import *
torch.manual_seed(777)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
saved_pth = "./resnet50/best.pt"
model = ConvNet().to(device)

model.load_state_dict(torch.load(saved_pth))

<All keys matched successfully>

In [15]:
from trainer import *
from face import *
import torchvision.transforms as transforms
from tqdm import tqdm
def validate(model, dataset=FaceDataset("./data/aflw_val/", transform=transforms.ToTensor()), batch_size=64, criterion=NME()):
        val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
        model.eval()
        val_loss = 0
        for imgs, landmarks in tqdm(val_loader):
            imgs, landmarks = imgs.to(device), landmarks.to(device)
            with torch.no_grad():
                pred_landmarks = model(imgs)
            loss = criterion(pred_landmarks, landmarks)
            val_loss += loss.item()
        return val_loss / len(val_loader)

print("Best.pt Validation Loss:")
validate(model=model)


Best.pt Validation Loss:


100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


0.020460421685129404

In [16]:
import torch.nn.utils.prune as prune
def remove_parameters(model):

    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass
        elif isinstance(module, torch.nn.Linear):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass

    return model

In [17]:
def compute_final_pruning_rate(pruning_rate, num_iterations):
    '''
    A function to compute the final pruning rate for iterative pruning.
        Note that this cannot be applied for global pruning rate if the pruning rate is heterogeneous among different layers.
    Args:
        pruning_rate (float): Pruning rate.
        num_iterations (int): Number of iterations.
    Returns:
        float: Final pruning rate.
    '''

    final_pruning_rate = 1 - (1 - pruning_rate) ** num_iterations

    return final_pruning_rate

In [18]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):

    num_zeros = 0
    num_elements = 0

    if use_mask == True:
        for buffer_name, buffer in module.named_buffers():
            if "weight_mask" in buffer_name and weight == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
            if "bias_mask" in buffer_name and bias == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
    else:
        for param_name, param in module.named_parameters():
            if "weight" in param_name and weight == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()
            if "bias" in param_name and bias == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity

In [19]:
def measure_global_sparsity(
    model, weight = True,
    bias = False, conv2d_use_mask = False,
    linear_use_mask = False):

    num_zeros = 0
    num_elements = 0

    for module_name, module in model.named_modules():

        if isinstance(module, torch.nn.Conv2d):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=conv2d_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

        elif isinstance(module, torch.nn.Linear):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=linear_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity

In [20]:
l1_regularization_strength = 0
l2_regularization_strength = 1e-4
learning_rate = 0.01
learning_rate_decay = 1
dataset=FaceDataset("./data/aflw_val/", transform=transforms.ToTensor())
train_dataset=FaceDataset("./data/synthetics_train/", transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.5), transforms.ToTensor()]))
val_loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)


# Finetune the trained model
def finetune_train_model(model, train_loader=train_loader, val_loader=val_loader, device=device, l1_regularization_strength=0,
                        l2_regularization_strength=l2_regularization_strength, learning_rate=1e-1, num_epochs=20):
    criterion = NME()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_regularization_strength)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 12], gamma=0.1, last_epoch=-1)
    model.eval()
    eval_loss = validate(model)
    print(f"Pre fine-tuning: val_loss = {eval_loss:.3f}")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        for img, landmark in tqdm(train_loader):
            img, landmark = img.to(device), landmark.to(device)
            optimizer.zero_grad()

            output = model(img)
            loss = criterion(output, landmark)

            l1_reg = torch.tensor(0.).to(device)
            for module in model.modules():
                mask = None
                weight = None
                for name, buffer in module.named_buffers():
                    if name == "weight_mask":
                        mask = buffer
                for name, param in module.named_buffers():
                    if name == "weight_orgi":
                        weight = param
                if mask is not None and weight is not None:
                    l1_reg += torch.norm(mask*weight, 1)
            
            loss += l1_regularization_strength*l1_reg
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*img.size(0)
        
        train_loss = running_loss / len(train_loader.dataset)
        model.eval()
        eval_loss = validate(model)
        scheduler.step()
        print(f"epoch = {epoch + 1} train_loss = {train_loss:.3f}, val_loss = {eval_loss:.4f}")
    
    return model



In [21]:
def iterative_pruning_finetuning(
    model, train_loader, val_loader, device,
    learning_rate, l1_regularization_strength,
    l2_regularization_strength, learning_rate_decay = 0.1,
    conv2d_prune_amount = 0.2, linear_prune_amount = 0.1,
    num_iterations = 10, num_epochs_per_iteration = 16,
    grouped_pruning = False):
    for i in range(num_iterations):
        print("\nPruning and Finetuning {}/{}".format(i + 1, num_iterations))

        print("Pruning...")


        # NOTE: For global pruning, linear/dense layer can also be pruned!
        if grouped_pruning == True:
            # grouped_pruning -> Global pruning
            parameters_to_prune = []
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    parameters_to_prune.append((module, "weight"))
                elif isinstance(module, torch.nn.Linear):
                    parameters_to_prune.append((module, "weight"))
        
            # L1Unstructured - prune (currently unpruned) entries in a tensor by zeroing
            # out the ones with the lowest absolute magnitude-
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method = prune.L1Unstructured,
                amount = conv2d_prune_amount,
            )
        
        # layer-wise pruning-
        else:
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = conv2d_prune_amount)
                elif isinstance(module, torch.nn.Linear):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = linear_prune_amount)

        # Compute validation accuracy just after pruning-
        eval_loss = validate(model)


        # Compute global sparsity-
        num_zeros, num_elements, sparsity = measure_global_sparsity(
            model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)
        
        print(f"Global sparsity = {sparsity * 100:.3f}% & val_loss = {eval_loss * 100:.4f}%")

        print("\nFine-tuning...")

        fine_tuned_model = finetune_train_model(
            model = model, train_loader = train_loader,
            val_loader = val_loader, device = device,
            l1_regularization_strength = l1_regularization_strength,
            l2_regularization_strength = l2_regularization_strength,
            # i -> current pruning round-
            # learning_rate = learning_rate * (learning_rate_decay ** i),
            learning_rate = learning_rate,
            num_epochs = num_epochs_per_iteration)

        eval_loss = validate(
            model)

        num_zeros, num_elements, sparsity = measure_global_sparsity(
            # model,
            fine_tuned_model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)
        # if eval_loss <= 0.04:
            # model_filename = "{}_{}_{}.pt".format({"pruned_resnet"}, i + 1, eval_loss)
            # model_filepath = os.path.join("./pruned/", model_filename)
            # print("Model Save")

        print(f"Post fine-tuning: Global sparsity = {sparsity * 100:.3f}% & val_loss = {eval_loss * 100:.3f}%")

        # model_filename = "{}_{}.pt".format({"pruned_resnet"}, i + 1)
        # model_filepath = os.path.join("./pruned/", model_filename)
        # save_model(model=model,
        #           model_dir=model_dir,
        #           model_filename=model_filename)
        # model = load_model(model=model,
        #                   model_filepath=model_filepath,
        #                   device=device)
        
    return model

In [22]:
eval_loss = validate(
    model)


100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


In [23]:
num_zeros, num_elements, sparsity = measure_global_sparsity(model)
print(f"Global sparsity = {sparsity:.3f}% & val_loss = {eval_loss * 100:.3f}%")

Global sparsity = 0.000% & val_loss = 2.046%


In [24]:
import copy
# pruned_model = copy.deepcopy(model)
model = ConvNet()
model.load_state_dict(torch.load('./pruned/ResNet50_trained_sparsity-89.989.pth'))
model.to(device)
pruned_model = iterative_pruning_finetuning(model=model, train_loader=train_loader, val_loader=val_loader, device=device,
                learning_rate=learning_rate, learning_rate_decay=learning_rate_decay, l1_regularization_strength=l1_regularization_strength,
                l2_regularization_strength=l2_regularization_strength, conv2d_prune_amount=0.35, linear_prune_amount=0.1, num_iterations=8,
                num_epochs_per_iteration=16, grouped_pruning=True)

remove_parameters(model=pruned_model)
eval_loss = validate(pruned_model)
num_zeros, num_elements, sparsity = measure_global_sparsity(pruned_model)
print(f"Global sparsity = {sparsity:.3f} & val_loss = {eval_loss:.3f}")


Pruning and Finetuning 1/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.37it/s]


Global sparsity = 39.161% & val_loss = 3.5136%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.67it/s]


Pre fine-tuning: val_loss = 0.035


100%|██████████| 1559/1559 [12:04<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.78it/s]


epoch = 1 train_loss = 0.040, val_loss = 0.0510


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.69it/s]


epoch = 2 train_loss = 0.037, val_loss = 0.0578


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


epoch = 3 train_loss = 0.037, val_loss = 0.0569


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


epoch = 4 train_loss = 0.035, val_loss = 0.0577


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


epoch = 5 train_loss = 0.029, val_loss = 0.0361


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.67it/s]


epoch = 6 train_loss = 0.028, val_loss = 0.0333


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 7 train_loss = 0.027, val_loss = 0.0336


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.49it/s]


epoch = 8 train_loss = 0.026, val_loss = 0.0365


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


epoch = 9 train_loss = 0.026, val_loss = 0.0335


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.80it/s]


epoch = 10 train_loss = 0.026, val_loss = 0.0312


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.88it/s]


epoch = 11 train_loss = 0.026, val_loss = 0.0350


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 12 train_loss = 0.025, val_loss = 0.0322


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 13 train_loss = 0.024, val_loss = 0.0301


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 14 train_loss = 0.024, val_loss = 0.0305


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 15 train_loss = 0.023, val_loss = 0.0301


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


epoch = 16 train_loss = 0.023, val_loss = 0.0301


100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


Post fine-tuning: Global sparsity = 35.052% & val_loss = 3.014%

Pruning and Finetuning 2/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.70it/s]


Global sparsity = 56.165% & val_loss = 3.0141%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.92it/s]


Pre fine-tuning: val_loss = 0.030


100%|██████████| 1559/1559 [12:00<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 1 train_loss = 0.038, val_loss = 0.0535


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


epoch = 2 train_loss = 0.036, val_loss = 0.4971


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 3 train_loss = 0.036, val_loss = 0.1100


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.78it/s]


epoch = 4 train_loss = 0.035, val_loss = 0.1026


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 5 train_loss = 0.029, val_loss = 0.0360


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.76it/s]


epoch = 6 train_loss = 0.028, val_loss = 0.0357


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 7 train_loss = 0.028, val_loss = 0.0354


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


epoch = 8 train_loss = 0.028, val_loss = 0.0335


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.49it/s]


epoch = 9 train_loss = 0.027, val_loss = 0.0363


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


epoch = 10 train_loss = 0.027, val_loss = 0.0356


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


epoch = 11 train_loss = 0.027, val_loss = 0.0379


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.66it/s]


epoch = 12 train_loss = 0.027, val_loss = 0.0346


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


epoch = 13 train_loss = 0.026, val_loss = 0.0325


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 14 train_loss = 0.026, val_loss = 0.0324


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.71it/s]


epoch = 15 train_loss = 0.025, val_loss = 0.0324


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


epoch = 16 train_loss = 0.025, val_loss = 0.0325


100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


Post fine-tuning: Global sparsity = 56.165% & val_loss = 3.249%

Pruning and Finetuning 3/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


Global sparsity = 69.895% & val_loss = 3.2495%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


Pre fine-tuning: val_loss = 0.032


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.68it/s]


epoch = 1 train_loss = 0.037, val_loss = 0.0458


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.80it/s]


epoch = 2 train_loss = 0.036, val_loss = 0.1795


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.75it/s]


epoch = 3 train_loss = 0.036, val_loss = 0.0690


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 4 train_loss = 0.036, val_loss = 0.0778


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


epoch = 5 train_loss = 0.030, val_loss = 0.0449


100%|██████████| 1559/1559 [11:51<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 6 train_loss = 0.030, val_loss = 0.0388


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 7 train_loss = 0.029, val_loss = 0.0370


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.64it/s]


epoch = 8 train_loss = 0.029, val_loss = 0.0372


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 9 train_loss = 0.029, val_loss = 0.0380


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.73it/s]


epoch = 10 train_loss = 0.029, val_loss = 0.0370


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 11 train_loss = 0.028, val_loss = 0.0407


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.76it/s]


epoch = 12 train_loss = 0.028, val_loss = 0.0357


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 13 train_loss = 0.027, val_loss = 0.0338


100%|██████████| 1559/1559 [11:53<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.55it/s]


epoch = 14 train_loss = 0.027, val_loss = 0.0340


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.61it/s]


epoch = 15 train_loss = 0.027, val_loss = 0.0341


100%|██████████| 1559/1559 [11:50<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 16 train_loss = 0.026, val_loss = 0.0335


100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


Post fine-tuning: Global sparsity = 69.895% & val_loss = 3.348%

Pruning and Finetuning 4/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


Global sparsity = 78.914% & val_loss = 3.3483%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.65it/s]


Pre fine-tuning: val_loss = 0.033


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.49it/s]


epoch = 1 train_loss = 0.039, val_loss = 0.0565


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


epoch = 2 train_loss = 0.036, val_loss = 0.0656


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 3 train_loss = 0.036, val_loss = 0.0625


100%|██████████| 1559/1559 [11:56<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.48it/s]


epoch = 4 train_loss = 0.036, val_loss = 0.1065


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.76it/s]


epoch = 5 train_loss = 0.030, val_loss = 0.0404


100%|██████████| 1559/1559 [11:55<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.65it/s]


epoch = 6 train_loss = 0.030, val_loss = 0.0375


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.68it/s]


epoch = 7 train_loss = 0.029, val_loss = 0.0373


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


epoch = 8 train_loss = 0.029, val_loss = 0.0359


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 9 train_loss = 0.029, val_loss = 0.0361


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.61it/s]


epoch = 10 train_loss = 0.029, val_loss = 0.0364


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.83it/s]


epoch = 11 train_loss = 0.028, val_loss = 0.0403


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.56it/s]


epoch = 12 train_loss = 0.028, val_loss = 0.0362


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 13 train_loss = 0.027, val_loss = 0.0346


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


epoch = 14 train_loss = 0.027, val_loss = 0.0345


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.75it/s]


epoch = 15 train_loss = 0.027, val_loss = 0.0345


100%|██████████| 1559/1559 [11:55<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 16 train_loss = 0.027, val_loss = 0.0346


100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


Post fine-tuning: Global sparsity = 78.914% & val_loss = 3.461%

Pruning and Finetuning 5/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


Global sparsity = 84.836% & val_loss = 3.4610%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.57it/s]


Pre fine-tuning: val_loss = 0.035


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 1 train_loss = 0.038, val_loss = 0.0561


100%|██████████| 1559/1559 [12:00<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 2 train_loss = 0.036, val_loss = 0.0501


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 3 train_loss = 0.036, val_loss = 0.0672


100%|██████████| 1559/1559 [12:01<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.80it/s]


epoch = 4 train_loss = 0.036, val_loss = 0.0794


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


epoch = 5 train_loss = 0.031, val_loss = 0.0388


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 6 train_loss = 0.030, val_loss = 0.0397


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.85it/s]


epoch = 7 train_loss = 0.030, val_loss = 0.0416


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.64it/s]


epoch = 8 train_loss = 0.030, val_loss = 0.0415


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.71it/s]


epoch = 9 train_loss = 0.030, val_loss = 0.0407


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.49it/s]


epoch = 10 train_loss = 0.029, val_loss = 0.0377


100%|██████████| 1559/1559 [11:55<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.56it/s]


epoch = 11 train_loss = 0.029, val_loss = 0.0375


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


epoch = 12 train_loss = 0.029, val_loss = 0.0369


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.78it/s]


epoch = 13 train_loss = 0.028, val_loss = 0.0356


100%|██████████| 1559/1559 [11:55<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.43it/s]


epoch = 14 train_loss = 0.028, val_loss = 0.0360


100%|██████████| 1559/1559 [11:50<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.39it/s]


epoch = 15 train_loss = 0.028, val_loss = 0.0356


100%|██████████| 1559/1559 [11:49<00:00,  2.20it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 16 train_loss = 0.028, val_loss = 0.0358


100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


Post fine-tuning: Global sparsity = 84.836% & val_loss = 3.581%

Pruning and Finetuning 6/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


Global sparsity = 88.701% & val_loss = 3.5811%

Fine-tuning...


100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


Pre fine-tuning: val_loss = 0.036


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]


epoch = 1 train_loss = 0.038, val_loss = 0.0577


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 2 train_loss = 0.036, val_loss = 0.0547


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.47it/s]


epoch = 3 train_loss = 0.036, val_loss = 0.0548


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


epoch = 4 train_loss = 0.036, val_loss = 0.0590


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 5 train_loss = 0.031, val_loss = 0.0405


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 6 train_loss = 0.030, val_loss = 0.0381


100%|██████████| 1559/1559 [11:51<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.66it/s]


epoch = 7 train_loss = 0.030, val_loss = 0.0383


100%|██████████| 1559/1559 [11:54<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.72it/s]


epoch = 8 train_loss = 0.030, val_loss = 0.0370


100%|██████████| 1559/1559 [11:52<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


epoch = 9 train_loss = 0.030, val_loss = 0.0399


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.68it/s]


epoch = 10 train_loss = 0.029, val_loss = 0.0385


100%|██████████| 1559/1559 [11:53<00:00,  2.19it/s]
100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


epoch = 11 train_loss = 0.029, val_loss = 0.0395


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


epoch = 12 train_loss = 0.030, val_loss = 0.0380


100%|██████████| 1559/1559 [12:06<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


epoch = 13 train_loss = 0.028, val_loss = 0.0354


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:03<00:00,  1.31it/s]


epoch = 14 train_loss = 0.028, val_loss = 0.0356


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


epoch = 15 train_loss = 0.028, val_loss = 0.0354


100%|██████████| 1559/1559 [11:57<00:00,  2.17it/s]
100%|██████████| 4/4 [00:03<00:00,  1.31it/s]


epoch = 16 train_loss = 0.028, val_loss = 0.0357


100%|██████████| 4/4 [00:02<00:00,  1.38it/s]


Post fine-tuning: Global sparsity = 88.701% & val_loss = 3.573%

Pruning and Finetuning 7/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.44it/s]


Global sparsity = 91.225% & val_loss = 3.5730%

Fine-tuning...


100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


Pre fine-tuning: val_loss = 0.036


100%|██████████| 1559/1559 [12:04<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.33it/s]


epoch = 1 train_loss = 0.037, val_loss = 0.0525


100%|██████████| 1559/1559 [12:08<00:00,  2.14it/s]
100%|██████████| 4/4 [00:03<00:00,  1.28it/s]


epoch = 2 train_loss = 0.038, val_loss = 0.0643


100%|██████████| 1559/1559 [12:04<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.37it/s]


epoch = 3 train_loss = 0.038, val_loss = 0.2434


100%|██████████| 1559/1559 [12:11<00:00,  2.13it/s]
100%|██████████| 4/4 [00:02<00:00,  1.35it/s]


epoch = 4 train_loss = 0.038, val_loss = 0.0966


100%|██████████| 1559/1559 [12:02<00:00,  2.16it/s]
100%|██████████| 4/4 [00:03<00:00,  1.30it/s]


epoch = 5 train_loss = 0.033, val_loss = 0.0416


100%|██████████| 1559/1559 [12:02<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.36it/s]


epoch = 6 train_loss = 0.032, val_loss = 0.0393


100%|██████████| 1559/1559 [12:00<00:00,  2.16it/s]
100%|██████████| 4/4 [00:03<00:00,  1.26it/s]


epoch = 7 train_loss = 0.032, val_loss = 0.0400


100%|██████████| 1559/1559 [12:01<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.39it/s]


epoch = 8 train_loss = 0.032, val_loss = 0.0398


100%|██████████| 1559/1559 [12:04<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.40it/s]


epoch = 9 train_loss = 0.032, val_loss = 0.0429


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.36it/s]


epoch = 10 train_loss = 0.032, val_loss = 0.0402


100%|██████████| 1559/1559 [12:01<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.39it/s]


epoch = 11 train_loss = 0.032, val_loss = 0.0412


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.38it/s]


epoch = 12 train_loss = 0.031, val_loss = 0.0392


100%|██████████| 1559/1559 [11:58<00:00,  2.17it/s]
100%|██████████| 4/4 [00:02<00:00,  1.37it/s]


epoch = 13 train_loss = 0.030, val_loss = 0.0377


100%|██████████| 1559/1559 [11:56<00:00,  2.18it/s]
100%|██████████| 4/4 [00:02<00:00,  1.39it/s]


epoch = 14 train_loss = 0.030, val_loss = 0.0378


100%|██████████| 1559/1559 [12:01<00:00,  2.16it/s]
100%|██████████| 4/4 [00:02<00:00,  1.35it/s]


epoch = 15 train_loss = 0.030, val_loss = 0.0377


100%|██████████| 1559/1559 [12:03<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


epoch = 16 train_loss = 0.030, val_loss = 0.0382


100%|██████████| 4/4 [00:03<00:00,  1.30it/s]


Post fine-tuning: Global sparsity = 91.225% & val_loss = 3.819%

Pruning and Finetuning 8/8
Pruning...


100%|██████████| 4/4 [00:02<00:00,  1.38it/s]


Global sparsity = 92.887% & val_loss = 3.8187%

Fine-tuning...


100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Pre fine-tuning: val_loss = 0.038


100%|██████████| 1559/1559 [12:05<00:00,  2.15it/s]
100%|██████████| 4/4 [00:03<00:00,  1.33it/s]


epoch = 1 train_loss = 0.039, val_loss = 0.0956


100%|██████████| 1559/1559 [12:04<00:00,  2.15it/s]
100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


epoch = 2 train_loss = 0.038, val_loss = 0.2248


 58%|█████▊    | 908/1559 [07:04<05:04,  2.14it/s]


KeyboardInterrupt: 

In [25]:
final_model = remove_parameters(model=model)

In [28]:
num_zeros, num_elements, sparsity = measure_global_sparsity(
    final_model, weight = True,
    bias = False, conv2d_use_mask = False,
    linear_use_mask = False
)
print(sparsity)

0.9681355074162337


In [29]:
torch.save(final_model.state_dict(), f"./pruned/ResNet50_trained_sparsity-{sparsity * 100:.3f}.pth")

In [30]:
eval_loss = validate(model)

100%|██████████| 4/4 [00:02<00:00,  1.41it/s]


In [31]:
print(eval_loss)

0.05962326191365719
