In [1]:
import torch
import torch.nn.utils.prune as prune
import os
from model.networks import Generator
import copy
import numpy as np




In [2]:
def load_model(checkpoint_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator(checkpoint=checkpoint_path).to(device)
    return generator

def count_zero_weights(model):
    zero_count = 0
    total_count = 0
    for name, param in model.named_parameters():
        zero_count += torch.sum(param == 0).item()
        total_count += param.numel()
    return zero_count, total_count

def prune_model(model, amount):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')
    return model

def save_model(model, path):
    state_dict = {'G': model.state_dict()}  # Ensure the 'G' key is included
    torch.save(state_dict, path)

def print_model_size(model):
    param_size = sum(param.numel() for param in model.parameters())
    buffer_size = sum(buffer.numel() for buffer in model.buffers())
    size_all_mb = (param_size + buffer_size) * 4 / 1024**2
    print(f'Model size in memory: {size_all_mb:.3f} MB')

def test_model(model):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    with torch.no_grad():
        dummy_input = torch.randn(1, 5, 256, 256).to(device)  # assuming input shape
        dummy_mask = torch.ones(1, 1, 256, 256).to(device)
        try:
            output = model(dummy_input, dummy_mask)
            print("Model forward pass successful.")
        except Exception as e:
            print(f"Model forward pass failed: {e}")

def get_file_size(path):
    size = os.path.getsize(path) / 1024**2  # in MB
    return size

def main(prune_amounts, initial_checkpoint_path, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model = load_model(initial_checkpoint_path)

    for i, amount in enumerate(prune_amounts):
        zero_count_before, total_count = count_zero_weights(model)

        pruned_model = prune_model(copy.deepcopy(model), amount)

        save_path = os.path.join(save_dir, f"pruned_model_{int(amount*100)}.pth")
        save_model(pruned_model, save_path)
        print(f"Saved pruned model with {amount*100}% pruning at: {save_path}")

        test_model(pruned_model)



In [3]:
if __name__ == "__main__":
    initial_checkpoint_path = "C:/Users/tuant/Downloads/AI Thesis/FINAL_APP_AND_ANALYSIS/optimized_models/states_pt_places2.pth"
    save_dir = "C:/Users/tuant/Downloads/AI Thesis/FINAL_APP_AND_ANALYSIS/optimized_models"
    prune_amounts = [0.00]
    # for i in np.arange(0.01, 1.00, 0.01):
    #     prune_amounts.append(i)

    main(prune_amounts, initial_checkpoint_path, save_dir)

Saved pruned model with 0.0% pruning at: C:/Users/tuant/Downloads/AI Thesis/FINAL_APP_AND_ANALYSIS/optimized_models\pruned_model_0.pth
Model forward pass successful.
