In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import AdamW, lr_scheduler
import pytorch_lightning as pl
from torchvision import models
import torch_pruning as tp
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import networkx as nx
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset

# Set MPS as the default device
torch.set_default_device("mps")

In [2]:
class AlexNetFineTuner(pl.LightningModule):
    def __init__(self, learning_rate=1e-4, num_classes=10):
        super(AlexNetFineTuner, self).__init__()
        self.save_hyperparameters()

        # Load pre-trained AlexNet
        self.model = models.alexnet(pretrained=False)
        self.model.classifier[6] = torch.nn.Linear(4096, num_classes)  # Update the classifier layer

        # Initialize metrics dictionary
        self.metrics = {
            "pruning_percentage": [],
            "test_accuracy": [],
            "test_loss": [],
            "model_size": []
        }
        
        self.test_outputs = []
    def forward(self, x):
        return self.model(x)

    def print_dependency_graph(self, DG):
        print("\nDependency Graph Details:")
        for module, node in DG.module2node.items():
            print(f"Module: {module}")
            for dep in node.dependencies:
                print(f"    * Target Module: {dep.target.module}")

    def visualize_dependency_graph(self, DG):
        """Visualize the dependency graph using networkx."""
        G = nx.DiGraph()
        for module, node in DG.module2node.items():
            for dep in node.dependencies:
                G.add_edge(str(module), str(dep.target.module))
        plt.figure(figsize=(12, 8))
        nx.draw(G, with_labels=True, node_size=1000, font_size=8, node_color="skyblue", edge_color="gray")
        plt.title("Dependency Graph")
        plt.show()

    # def prune_model(self, pruning_percentage=0.2):
    #     # example_inputs = torch.randn(1, 3, 224, 224)  # Example input for dependency graph
    #     example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float32).to("mps")
    #     imp = tp.importance.TaylorImportance()

    #     ignored_layers = []
    #     for m in model.modules():
    #         if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
    #             ignored_layers.append(m) # DO NOT prune the final classifier!

    #     iterative_steps = 5 # progressive pruning
    #     pruner = tp.pruner.MagnitudePruner(
    #         model,
    #         example_inputs,
    #         importance=imp,
    #         iterative_steps=iterative_steps,
    #         ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    #         ignored_layers=ignored_layers,
    #     )
        
    #     self.model = self.model.to("mps")
    #     self.model = self.model.to(torch.float32)

    #     print("MODEL BEFORE PRUNING --------------------->", self.model)

    #     DG = tp.DependencyGraph().build_dependency(self.model, example_inputs)
    #     self.print_dependency_graph(DG)
    #     self.visualize_dependency_graph(DG)

    #     # Choose a layer to prune
    #     layer_to_prune = self.model.classifier[1]

    #     # Prune layer
    #     num_features = layer_to_prune.out_features
    #     pruning_idxs = torch.arange(0, int(num_features * pruning_percentage)).tolist()
    #     group = DG.get_pruning_group(layer_to_prune, tp.prune_linear_out_channels, idxs=pruning_idxs)
    #     if DG.check_pruning_group(group):
    #         group.prune()

    #     print("MODEL AFTER PRUNING --------------------->", self.model)
    #     print("PRUNING GROUP ---------------------->", group)

    #     all_groups = list(DG.get_all_groups())
    #     print("Number of Groups: --------------------->%d"%len(all_groups))
    #     print("The last Group: --------------------->", all_groups[-1])
    #     # Update metrics after pruning
    #     self.metrics["model_size"].append(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
    #     self.metrics["pruning_percentage"].append(pruning_percentage * 100)

    def high_level_prune_model(self, ch_sparsity=0.5, iterative_steps=5, example_inputs=None):
        if example_inputs is None:
            example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float32).to("mps")
        
        print(f"Initial model state -> Device: {next(self.model.parameters()).device}, Dtype: {next(self.model.parameters()).dtype}")
        self.model = self.model.to("mps").to(torch.float32)

        # Define importance criteria
        imp = tp.importance.TaylorImportance()
        print("IMPORTANCE CRITERIA------------------->", imp)

        # Specify ignored layers (e.g., the final classification layer)
        ignored_layers = []
        for m in self.model.modules():
            if isinstance(m, torch.nn.Linear) and m.out_features == 10:  # Adjust for CIFAR-10
                ignored_layers.append(m)

        # Initialize high-level pruner
        pruner = tp.pruner.MagnitudePruner(
            self.model,
            example_inputs,
            importance=imp,
            iterative_steps=iterative_steps,
            ch_sparsity=ch_sparsity,
            ignored_layers=ignored_layers,
        )

        # Count initial MACs and parameters
        base_macs, base_nparams = tp.utils.count_ops_and_params(self.model, example_inputs)
        print(f"Initial MACs: {base_macs}, Parameters: {base_nparams}")

        # Iteratively prune and update metrics
        for i in range(iterative_steps):
            print(f"Pruning step {i + 1}/{iterative_steps}...")

            # If using TaylorImportance, calculate gradients
            if isinstance(imp, tp.importance.TaylorImportance):
                self.model.zero_grad()  # Clear gradients
                loss = self.model(example_inputs).sum()  # Dummy loss
                loss.backward()  # Backpropagate to calculate gradients

            # Prune the model
            pruner.step()

            # Recalculate MACs and parameters
            macs, nparams = tp.utils.count_ops_and_params(self.model, example_inputs)
            print(f"After step {i + 1}: MACs={macs}, Params={nparams}")

            # Update metrics
            self.metrics["model_size"].append(nparams)
            self.metrics["pruning_percentage"].append((ch_sparsity * 100 * (i + 1) / iterative_steps))

        print("Pruning complete.")
        # Final model stats
        final_macs, final_nparams = tp.utils.count_ops_and_params(self.model, example_inputs)
        print(f"Final MACs: {final_macs}, Parameters: {final_nparams}")



    def training_step(self, batch, batch_idx):
        images, labels = batch
        images = images.to(torch.float32)  # Convert inputs to float32
        labels = labels.to("mps")  # Ensure labels are on MPS

        print("Images device and dtype:---------------------", images.device, images.dtype)
        print("Labels device and dtype:", labels.device, labels.dtype)

        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        images = images.to(torch.float32)  # Convert inputs to float32
        labels = labels.to("mps")  # Ensure labels are on MPS
        outputs = self(images)
        val_loss = F.cross_entropy(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        val_acc = (preds == labels).float().mean()
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        images = images.to(torch.float32)
        labels = labels.to("mps")
        outputs = self(images)
        test_loss = F.cross_entropy(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        test_acc = (preds == labels).float().mean()

        # Store batch results for aggregation later
        self.test_outputs.append({"test_loss": test_loss.item(), "test_acc": test_acc.item()})

        # Log per-batch results if needed
        self.log("test_loss_batch", test_loss, prog_bar=True)
        self.log("test_acc_batch", test_acc, prog_bar=True)

        return test_loss

    def on_test_epoch_end(self):
        # Aggregate metrics across batches
        avg_loss = sum(o["test_loss"] for o in self.test_outputs) / len(self.test_outputs)
        avg_acc = sum(o["test_acc"] for o in self.test_outputs) / len(self.test_outputs)

        # Append aggregated metrics for the epoch
        self.metrics["test_accuracy"].append(avg_acc)
        self.metrics["test_loss"].append(avg_loss)

        # Log aggregated metrics
        self.log("test_loss_epoch", avg_loss, prog_bar=True)
        self.log("test_acc_epoch", avg_acc, prog_bar=True)

        # Clear outputs for the next test epoch
        self.test_outputs = []


    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}



In [3]:
def plot_metrics(metrics):
    metrics["pruning_percentage"] = [0, 0.1,0.2,0.3,0.4,0.5]
    print("----------------------->1", metrics["pruning_percentage"], metrics["test_accuracy[1,7]"])
    print("----------------------->2", metrics["pruning_percentage"], metrics["test_loss"])
    print("----------------------->2", metrics["pruning_percentage"], metrics["model_size"])

    plt.figure()
    plt.plot(metrics["pruning_percentage"], metrics["test_accuracy"], marker='o', label="Accuracy")
    plt.title("Test Accuracy vs. Pruning Percentage")
    plt.xlabel("Pruning Percentage (%)")
    plt.ylabel("Accuracy")
    plt.grid(True)
    plt.legend()
    plt.show()

    plt.figure()
    plt.plot(metrics["pruning_percentage"], metrics["test_loss"], marker='o', color="orange", label="Loss")
    plt.title("Test Loss vs. Pruning Percentage")
    plt.xlabel("Pruning Percentage (%)")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()
    plt.show()

    plt.figure()
    plt.plot(metrics["pruning_percentage"], metrics["model_size"], marker='o', color="green", label="Model Size")
    plt.title("Model Size vs. Pruning Percentage")
    plt.xlabel("Pruning Percentage (%)")
    plt.ylabel("Number of Parameters")
    plt.grid(True)
    plt.legend()
    plt.show()

In [4]:
def main():
    wandb.init(project='alexnet_depGraph', name='AlexNet_HighLevelPruning')
    wandb_logger = WandbLogger(log_model=False)

    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_dataset = Subset(CIFAR10(root='./data', train=False, download=True, transform=transform), range(100))
    test_dataloader = DataLoader(test_dataset, batch_size=32)

    checkpoint_path = "../checkpointsAlex/best-checkpoint.ckpt"
    model = AlexNetFineTuner.load_from_checkpoint(checkpoint_path)
    model = model.to(torch.float32).to("mps")

    # Baseline test (0% pruning)
    print("Evaluating baseline model (0% pruning)...")
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=1)
    trainer.test(model, dataloaders=test_dataloader)

    # Pruning
    pruning_percentages = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
    for pruning_percentage in pruning_percentages:
        print(f"Applying pruning at {pruning_percentage * 100}%...")
        model.high_level_prune_model(ch_sparsity=pruning_percentage, iterative_steps=1)
        trainer.test(model, dataloaders=test_dataloader)

    # Plot Metrics
    plot_metrics(model.metrics)
    wandb.finish()



if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mp-mangal[0m ([33mp-mangal-university-of-amsterdam[0m). Use [1m`wandb login --relogin`[0m to force relogin


Files already downloaded and verified


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Evaluating baseline model (0% pruning)...


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30eaa54f0>




Initial MACs: 711505866.0, Parameters: 57044810
Pruning step 1/1...
After step 1: MACs=711505866.0, Params=57044810
Pruning complete.
Final MACs: 711505866.0, Parameters: 57044810


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 10.0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30e7a76e0>
Initial MACs: 711505866.0, Parameters: 57044810
Pruning step 1/1...
After step 1: MACs=578047534.0, Params=46142053
Pruning complete.
Final MACs: 578047534.0, Parameters: 46142053


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 20.0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30e9feb40>
Initial MACs: 578047534.0, Parameters: 46142053
Pruning step 1/1...
After step 1: MACs=377242916.0, Params=29526996
Pruning complete.
Final MACs: 377242916.0, Parameters: 29526996


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 30.0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30e644b90>
Initial MACs: 377242916.0, Parameters: 29526996
Pruning step 1/1...
After step 1: MACs=192553590.0, Params=14407299
Pruning complete.
Final MACs: 192553590.0, Parameters: 14407299


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 40.0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30d34aea0>
Initial MACs: 192553590.0, Parameters: 14407299
Pruning step 1/1...
After step 1: MACs=75838955.0, Params=5151620
Pruning complete.
Final MACs: 75838955.0, Parameters: 5151620


Testing: |          | 0/? [00:00<?, ?it/s]

Applying pruning at 50.0%...
Initial model state -> Device: cpu, Dtype: torch.float32
IMPORTANCE CRITERIA-------------------> <torch_pruning.pruner.importance.TaylorImportance object at 0x30d34aea0>
Initial MACs: 75838955.0, Parameters: 5151620
Pruning step 1/1...
After step 1: MACs=23825366.0, Params=1291365
Pruning complete.
Final MACs: 23825366.0, Parameters: 1291365


Testing: |          | 0/? [00:00<?, ?it/s]

KeyError: 'test_accuracy[1,7]'