In [None]:
# Colab setup: mount Google Drive, then make the project importable.
# NOTE(review): an empty string on sys.path only means "the current directory"; this is
# presumably a placeholder for the project root on Drive — TODO set the real path.
import sys
from google.colab import drive

PROJECT_ROOT = ''

drive.mount("/content/drive/", force_remount=True)
sys.path.append(PROJECT_ROOT)

In [None]:
import copy

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from evaluation.evaluation import evaluate_activation_prune, evaluate_weights_prune
from evaluation.validate import validate
from models.ResNet import ResNet18, ResNet34
from utils.prune import weight_prune

In [None]:
from data.dataset import get_dataset
from data.dataloader import get_dataloader

In [None]:
# Only the second split is kept under a name. The first (presumably the training split) is bound
# to `_`, and the next cell passes it to get_dataloader.
# NOTE(review): `_` is also IPython's last-output variable; relying on it across cells is fragile.
_, val_dataset_cifar10 = get_dataset("cifar10")

In [None]:
# 256 is presumably the batch size. `_` (the first split from the previous cell) goes in as the
# first argument and is then rebound to the first returned loader, so after this cell `_` no
# longer holds a dataset.
_, val_dataloader_cifar10 = get_dataloader(_, val_dataset_cifar10, 256)

# ResNet18 to ResNet18

## Weights Prune

In [None]:
# 20 evenly spaced prune thresholds in [1e-4, 1e-2] for the weight-prune sweep (plain floats).
thresholds = [float(t) for t in np.linspace(1e-4, 1e-2, num=20)]

### ResNet18 Base + L1 Prune

In [None]:
ResNet18_b_10 = ResNet18(10)
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))
ResNet18_b_10.load_state_dict(ResNet18_b_10_config)

In [None]:
ResNet18_b_10.to("cuda")
validate(ResNet18_b_10, val_dataloader_cifar10)

In [None]:
acc_d_l_10, spar_d_l_10 = evaluate_weights_prune(ResNet18_b_10, thresholds, val_dataloader_cifar10)

In [None]:
for acc, sparsity in zip(acc_d_l_10, spar_d_l_10):
  print(f"accuracy is {acc}, sparsity is {sparsity}")

### ResNet18 Base + L2 retrain

In [None]:
ResNet18_b_l2_retrain_10 = ResNet18(10)
ResNet18_b_l2_retrain_10_config = torch.load("./ResNet18_b_l2_retrain_10.pth", map_location=torch.device("cuda"))
ResNet18_b_l2_retrain_10.load_state_dict(ResNet18_b_l2_retrain_10_config["model_state_dict"])

In [None]:
ResNet18_b_l2_retrain_10.to("cuda")
validate(ResNet18_b_l2_retrain_10, val_dataloader_cifar10)

In [None]:
acc_d_l2_10, spar_d_l2_10 = evaluate_weights_prune(ResNet18_b_l2_retrain_10, thresholds, val_dataloader_cifar10)

In [None]:
for acc, sparsity in zip(acc_d_l2_10, spar_d_l2_10):
  print(f"accuracy is {acc}, sparsity is {sparsity}")

### ResNet18 Distilled from ResNet18

In [None]:
ResNet18_d_f18_l_s_10 = ResNet18(10)
ResNet18_d_f18_l_s_10_config = torch.load("./ResNet18_d_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f18_l_s_10.load_state_dict(ResNet18_d_f18_l_s_10_config["model_state_dict"])

In [None]:
ResNet18_d_f18_l_s_10.to("cuda")
validate(ResNet18_d_f18_l_s_10, val_dataloader_cifar10)

In [None]:
acc_d_s_10, spar_d_s_10 = evaluate_weights_prune(ResNet18_d_f18_l_s_10, thresholds, val_dataloader_cifar10)

In [None]:
for acc, sparsity in zip(acc_d_s_10, spar_d_s_10):
  print(f"accuracy is {acc}, sparsity is {sparsity}")

### ResNet18 Distilled from ResNet34

In [None]:
ResNet18_d_f34_l_s_10 = ResNet18(10)
ResNet18_d_f34_l_s_10_config = torch.load("./ResNet18_d_f34_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f34_l_s_10.load_state_dict(ResNet18_d_f34_l_s_10_config["model_state_dict"])

In [None]:
ResNet18_d_f34_l_s_10.to("cuda")
validate(ResNet18_d_f34_l_s_10, val_dataloader_cifar10)

In [None]:
acc_d_34_s_10, spar_d_34_s_10 = evaluate_weights_prune(ResNet18_d_f34_l_s_10, thresholds, val_dataloader_cifar10)

In [None]:
for acc, sparsity in zip(acc_d_34_s_10, spar_d_34_s_10):
  print(f"accuracy is {acc}, sparsity is {sparsity}")

### Plot

In [None]:
# Parallel lists, one entry per model in the same order, for the weight-prune plot.
model_names = [
    'Base Model L1 Prune',
    'L2 Retrain Model Prune',
    'Distillation Model from ResNet18',
    'Distillation Model from ResNet34',
]
accuracy_lists = [acc_d_l_10, acc_d_l2_10, acc_d_s_10, acc_d_34_s_10]
sparsity_lists = [spar_d_l_10, spar_d_l2_10, spar_d_s_10, spar_d_34_s_10]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

def plot_prune_metrics_dual_axis_all_models(thresholds, accuracy_lists, sparsity_lists, model_names):
    """Plot accuracy and FLOPs/sparsity against prune threshold for several models on one figure.

    Accuracy curves are solid lines on the left axis. The second metric is drawn as dashed lines
    on a log-scaled right axis labelled "FLOPs". Each model's two curves share one colour.

    Args:
        thresholds: x values shared by every model (the x axis is log-scaled).
        accuracy_lists: one accuracy sequence per model, each aligned with ``thresholds``.
        sparsity_lists: one sequence per model for the right axis, aligned with ``thresholds``.
        model_names: legend prefix per model; its length sets the number of colours.
    """
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap(name, lut) is the surviving equivalent.
    colors = plt.get_cmap('tab10', len(model_names))

    for idx, (acc_list, name) in enumerate(zip(accuracy_lists, model_names)):
        ax1.plot(thresholds, acc_list, label=f'{name} Acc', color=colors(idx), linestyle='-')
    ax1.set_xlabel("Prune Threshold")
    ax1.set_ylabel("Accuracy")
    ax1.set_xscale("log")
    ax1.tick_params(axis='y')
    ax1.grid(False)

    # Second y axis sharing the same (log) x axis.
    ax2 = ax1.twinx()
    for idx, (sparsity_list, name) in enumerate(zip(sparsity_lists, model_names)):
        ax2.plot(thresholds, sparsity_list, label=f'{name} FLOPs', color=colors(idx), linestyle='--')
    ax2.set_ylabel("FLOPs")
    ax2.tick_params(axis='y')
    ax2.set_yscale("log")

    # Merge both axes' entries into a single legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower left')

    ax2.set_title("Threshold vs Accuracy & FLOPs")
    fig.tight_layout()
    plt.show()

In [None]:
# Weight-prune results for all four models.
plot_prune_metrics_dual_axis_all_models(
    thresholds,
    accuracy_lists,
    sparsity_lists,
    model_names,
)

## Activation Prune

In [None]:
# 50 evenly spaced activation-prune thresholds in [1e-4, 1e-2] (plain floats).
act_thresholds = [float(t) for t in np.linspace(1e-4, 1e-2, num=50)]

### ResNet18 Base

In [None]:
ResNet18_b_10 = ResNet18(10)
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))
ResNet18_b_10.load_state_dict(ResNet18_b_10_config)
ResNet18_b_10.to("cuda")
validate(ResNet18_b_10, val_dataloader_cifar10)

In [None]:
act_acc_b_10, act_flops_b_10 = evaluate_activation_prune(ResNet18_b_10, act_thresholds, val_dataloader_cifar10)

In [None]:
for acc, flops in zip(act_acc_b_10, act_flops_b_10):
  print(f"accuracy is {acc}, flops is {flops}")

### ResNet18 Base + L2 retrain

In [None]:
ResNet18_b_l2_retrain_10 = ResNet18(10)
ResNet18_b_l2_retrain_10_config = torch.load("./ResNet18_b_l2_retrain_10.pth", map_location=torch.device("cuda"))
ResNet18_b_l2_retrain_10.load_state_dict(ResNet18_b_l2_retrain_10_config["model_state_dict"])
ResNet18_b_l2_retrain_10.to("cuda")
validate(ResNet18_b_l2_retrain_10, val_dataloader_cifar10)

In [None]:
act_acc_d_l2_10, act_spar_d_l2_10 = evaluate_activation_prune(ResNet18_b_l2_retrain_10, act_thresholds, val_dataloader_cifar10)

In [None]:
for acc, flops in zip(act_acc_d_l2_10, act_spar_d_l2_10):
  print(f"accuracy is {acc}, flops is {flops}")

### ResNet18 Distilled from ResNet18

In [None]:
ResNet18_d_f18_l_s_10 = ResNet18(10)
ResNet18_d_f18_l_s_10_config = torch.load("./ResNet18_d_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f18_l_s_10.load_state_dict(ResNet18_d_f18_l_s_10_config["model_state_dict"])
ResNet18_d_f18_l_s_10.to("cuda")
validate(ResNet18_d_f18_l_s_10, val_dataloader_cifar10)

In [None]:
act_acc_d_s_10, act_spar_d_s_10 = evaluate_activation_prune(ResNet18_d_f18_l_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
for acc, flops in zip(act_acc_d_s_10, act_spar_d_s_10):
  print(f"accuracy is {acc}, flops is {flops}")

### ResNet18 Distilled from ResNet34

In [None]:
ResNet18_d_f34_l_s_10 = ResNet18(10)
ResNet18_d_f34_l_s_10_config = torch.load("./ResNet18_d_f34_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f34_l_s_10.load_state_dict(ResNet18_d_f34_l_s_10_config["model_state_dict"])
ResNet18_d_f34_l_s_10.to("cuda")
validate(ResNet18_d_f34_l_s_10, val_dataloader_cifar10)

In [None]:
act_acc_d_34_s_10, act_spar_d_34_s_10 = evaluate_activation_prune(ResNet18_d_f34_l_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
for acc, flops in zip(act_acc_d_34_s_10, act_spar_d_34_s_10):
  print(f"accuracy is {acc}, flops is {flops}")

### Plot

In [None]:
# Same four models as the weight-prune plot, now with activation-prune results.
# `sparsity_lists` holds FLOPs here; the name is kept because the plot call uses it.
model_names = [
    'Base Model L1 Prune',
    'L2 Retrain Model Prune',
    'Distillation Model from ResNet18',
    'Distillation Model from ResNet34',
]
accuracy_lists = [act_acc_b_10, act_acc_d_l2_10, act_acc_d_s_10, act_acc_d_34_s_10]
sparsity_lists = [act_flops_b_10, act_spar_d_l2_10, act_spar_d_s_10, act_spar_d_34_s_10]

In [None]:
# The activation sweeps above ran over `act_thresholds` (50 points), so that is the matching x
# axis. The 20-point weight-prune `thresholds` would not line up with these result lists.
plot_prune_metrics_dual_axis_all_models(act_thresholds, accuracy_lists, sparsity_lists, model_names)

## Joint Prune

In [None]:
# The joint sweep reuses the grids from the two single-criterion sweeps above.
weight_thresholds, activation_thresholds = thresholds, act_thresholds

In [None]:
def evaluate_joint_prune(model: nn.Module,
                         weight_thresholds: list,
                         activation_thresholds: list,
                         val_dataloader: DataLoader,
                         device: str = "cuda") -> tuple:
    """Run a full activation-prune sweep at each weight-prune threshold.

    For every weight threshold, a fresh copy of ``model`` is weight-pruned and then passed to
    ``evaluate_activation_prune`` together with the whole ``activation_thresholds`` grid.

    Args:
        model: network to evaluate. It is moved to ``device`` and put in eval mode. Its weights
            are not pruned; pruning happens on per-threshold copies.
        weight_thresholds: thresholds passed one at a time to ``weight_prune``.
        activation_thresholds: grid forwarded unchanged to ``evaluate_activation_prune``.
        val_dataloader: validation data forwarded to ``evaluate_activation_prune``.
        device: device ``model`` is moved to before copying.

    Returns:
        ``(accuracy_lists, flops_lists)``: one inner result per weight threshold, in the same
        order, each exactly as returned by ``evaluate_activation_prune``.
    """
    model.to(device)
    model.eval()

    accuracy_lists = []
    flops_lists = []

    for weight_threshold in tqdm(weight_thresholds, leave=False):
        # weight_prune's return value is not used, so it presumably prunes its argument in place.
        # Pruning a copy keeps thresholds from compounding on each other and leaves the
        # caller's model unpruned.
        pruned = copy.deepcopy(model)
        weight_prune(pruned, weight_threshold)
        accuracy_list, flops_list = evaluate_activation_prune(pruned, activation_thresholds, val_dataloader)
        accuracy_lists.append(accuracy_list)
        flops_lists.append(flops_list)
        del pruned  # free the copy before the next threshold

    return accuracy_lists, flops_lists

### ResNet18 Base L1

In [None]:
ResNet18_b_10 = ResNet18(10)
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))
ResNet18_b_10.load_state_dict(ResNet18_b_10_config)
ResNet18_b_10.to("cuda")
validate(ResNet18_b_10, val_dataloader_cifar10)

In [None]:
joint_b_acc, joint_b_flops = evaluate_joint_prune(ResNet18_b_10, weight_thresholds, activation_thresholds, val_dataloader_cifar10)

### ResNet18 Base + L2 retrain

In [None]:
ResNet18_b_l2_retrain_10 = ResNet18(10)
ResNet18_b_l2_retrain_10_config = torch.load("./ResNet18_b_l2_retrain_10.pth", map_location=torch.device("cuda"))
ResNet18_b_l2_retrain_10.load_state_dict(ResNet18_b_l2_retrain_10_config["model_state_dict"])
ResNet18_b_l2_retrain_10.to("cuda")
validate(ResNet18_b_l2_retrain_10, val_dataloader_cifar10)

In [None]:
joint_b_l2_acc, joint_b_l2_flops = evaluate_joint_prune(ResNet18_b_l2_retrain_10, weight_thresholds, activation_thresholds, val_dataloader_cifar10)

### ResNet18 Distilled from ResNet18

In [None]:
ResNet18_d_f18_l_s_10 = ResNet18(10)
ResNet18_d_f18_l_s_10_config = torch.load("./ResNet18_d_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f18_l_s_10.load_state_dict(ResNet18_d_f18_l_s_10_config["model_state_dict"])
ResNet18_d_f18_l_s_10.to("cuda")
validate(ResNet18_d_f18_l_s_10, val_dataloader_cifar10)

In [None]:
joint_d_18_acc, joint_d_18_flops = evaluate_joint_prune(ResNet18_d_f18_l_s_10, weight_thresholds, activation_thresholds, val_dataloader_cifar10)

### ResNet18 Distilled from ResNet34

In [None]:
# Fresh copy of the ResNet34-distilled student for the joint sweep.
ResNet18_d_f34_l_s_10 = ResNet18(10)
ResNet18_d_f34_l_s_10_config = torch.load("./ResNet18_d_f34_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f34_l_s_10.load_state_dict(state_dict=ResNet18_d_f34_l_s_10_config["model_state_dict"])
ResNet18_d_f34_l_s_10.to("cuda")
validate(ResNet18_d_f34_l_s_10, val_dataloader_cifar10)

In [None]:
# Joint sweep on hand-picked thresholds rather than the weight_thresholds / activation_thresholds
# grids. NOTE(review): none of these values belong to the current grids. Both activation values
# (~0.0138, ~0.0181) lie above the 1e-2 maximum of `act_thresholds`, so they presumably came
# from an earlier threshold grid — TODO confirm and name them.
joint_d_34_acc, joint_d_34_flops = evaluate_joint_prune(ResNet18_d_f34_l_s_10, [0.00076, 0.0011174], [0.013795918367346938, 0.01806122448979592], val_dataloader_cifar10)

In [None]:
# One line per (weight threshold, activation threshold) combination of the joint sweep.
for acc_list, flops_list in zip(joint_d_34_acc, joint_d_34_flops):
    for i, acc in enumerate(acc_list):
        flops = flops_list[i]
        print(f"accuracy is {acc}, flops is {flops}")

# Penalty Comparison

In [None]:
ResNet18_d_l_1e3_s_10 = ResNet18(10)
ResNet18_d_l_1e3_s_10_config = torch.load("./ResNet18_d_l_s_1e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_l_1e3_s_10.load_state_dict(ResNet18_d_l_1e3_s_10_config["model_state_dict"])
# load_state_dict copies the loaded tensors into the freshly built parameters, which live on the
# CPU by default. Move the model explicitly, as every other section does before evaluating.
ResNet18_d_l_1e3_s_10.to("cuda")
act_acc_d_l_1e3_s_10, act_flops_d_l_1e3_s_10 = evaluate_activation_prune(ResNet18_d_l_1e3_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
ResNet18_d_l_2e3_s_10 = ResNet18(10)
ResNet18_d_l_2e3_s_10_config = torch.load("./ResNet18_d_l_s_2e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_l_2e3_s_10.load_state_dict(ResNet18_d_l_2e3_s_10_config["model_state_dict"])
# load_state_dict copies the loaded tensors into the freshly built parameters, which live on the
# CPU by default. Move the model explicitly, as every other section does before evaluating.
ResNet18_d_l_2e3_s_10.to("cuda")
act_acc_d_l_2e3_s_10, act_flops_d_l_2e3_s_10 = evaluate_activation_prune(ResNet18_d_l_2e3_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
ResNet18_d_l_3e3_s_10 = ResNet18(10)
ResNet18_d_l_3e3_s_10_config = torch.load("./ResNet18_d_l_s_3e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_l_3e3_s_10.load_state_dict(ResNet18_d_l_3e3_s_10_config["model_state_dict"])
# load_state_dict copies the loaded tensors into the freshly built parameters, which live on the
# CPU by default. Move the model explicitly, as every other section does before evaluating.
ResNet18_d_l_3e3_s_10.to("cuda")
act_acc_d_l_3e3_s_10, act_flops_d_l_3e3_s_10 = evaluate_activation_prune(ResNet18_d_l_3e3_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
ResNet18_d_l_4e3_s_10 = ResNet18(10)
ResNet18_d_l_4e3_s_10_config = torch.load("./ResNet18_d_l_s_4e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_l_4e3_s_10.load_state_dict(ResNet18_d_l_4e3_s_10_config["model_state_dict"])
# load_state_dict copies the loaded tensors into the freshly built parameters, which live on the
# CPU by default. Move the model explicitly, as every other section does before evaluating.
ResNet18_d_l_4e3_s_10.to("cuda")
act_acc_d_l_4e3_s_10, act_flops_d_l_4e3_s_10 = evaluate_activation_prune(ResNet18_d_l_4e3_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
ResNet18_d_l_5e3_s_10 = ResNet18(10)
ResNet18_d_l_5e3_s_10_config = torch.load("./ResNet18_d_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_l_5e3_s_10.load_state_dict(ResNet18_d_l_5e3_s_10_config["model_state_dict"])
# load_state_dict copies the loaded tensors into the freshly built parameters, which live on the
# CPU by default. Move the model explicitly, as every other section does before evaluating.
ResNet18_d_l_5e3_s_10.to("cuda")
act_acc_d_l_5e3_s_10, act_flops_d_l_5e3_s_10 = evaluate_activation_prune(ResNet18_d_l_5e3_s_10, act_thresholds, val_dataloader_cifar10)

In [None]:
model_names = ['Soft KL', 'KL']
accuracy_lists = [act_acc_d_l_s_10, act_acc_d_l_ns_10]
flops_lists = [act_flops_d_l_s_10, act_flops_d_l_ns_10]

In [None]:
# One entry per penalty checkpoint loaded above (ResNet18_d_l_s_<penalty>_10.pth), same order.
# NOTE(review): the first label reads '1.5e-3' but its checkpoint file is "..._1e3_10.pth" —
# TODO confirm which value is correct.
model_names = ['1.5e-3', '2e-3', '3e-3', '4e-3', '5e-3']
accuracy_lists = [
    act_acc_d_l_1e3_s_10,
    act_acc_d_l_2e3_s_10,
    act_acc_d_l_3e3_s_10,
    act_acc_d_l_4e3_s_10,
    act_acc_d_l_5e3_s_10,
]
flops_lists = [
    act_flops_d_l_1e3_s_10,
    act_flops_d_l_2e3_s_10,
    act_flops_d_l_3e3_s_10,
    act_flops_d_l_4e3_s_10,
    act_flops_d_l_5e3_s_10,
]

In [None]:
import matplotlib.pyplot as plt

def plot_activation_prune_metrics(thresholds, accuracy_lists, flops_lists, model_names):
    """Draw two separate figures: accuracy vs. threshold and FLOPs vs. threshold (log-log).

    Args:
        thresholds: x values shared by every model (log-scaled axis).
        accuracy_lists: one accuracy sequence per model, aligned with ``thresholds``.
        flops_lists: one FLOPs sequence per model, aligned with ``thresholds``.
        model_names: legend label per model; the count appears in both titles.
    """
    # The titles previously hard-coded "(6 Models)" whatever was passed in.
    n_models = len(model_names)

    fig, ax = plt.subplots(figsize=(8, 5))
    for acc_list, name in zip(accuracy_lists, model_names):
        ax.plot(thresholds, acc_list, label=name)
    ax.set_xlabel("Prune Threshold")
    ax.set_ylabel("Accuracy")
    ax.set_title(f"Accuracy vs Threshold ({n_models} Models)")
    ax.set_xscale("log")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

    # Distinct loop name: reusing `flops_lists` here would shadow the parameter.
    fig, ax = plt.subplots(figsize=(8, 5))
    for model_flops, name in zip(flops_lists, model_names):
        ax.plot(thresholds, model_flops, label=name)
    ax.set_xlabel("Prune Threshold")
    ax.set_ylabel("FLOPs")
    ax.set_title(f"FLOPs vs Threshold ({n_models} Models)")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

def plot_prune_metrics_dual_axis_all_models(thresholds, accuracy_lists, sparsity_lists, model_names):
    """Plot accuracy and FLOPs against prune threshold for several models on one figure.

    Accuracy curves are solid lines on the left axis; FLOPs are dashed lines on a log-scaled
    right axis. Each model's two curves share one colour.

    NOTE(review): this redefines a function of the same name from the weight-prune section. It
    differs only in legend position ('upper right') and title suffix "(All Models)". Cells run
    after this one use this version.

    Args:
        thresholds: x values shared by every model (the x axis is log-scaled).
        accuracy_lists: one accuracy sequence per model, each aligned with ``thresholds``.
        sparsity_lists: one FLOPs sequence per model, each aligned with ``thresholds``.
        model_names: legend prefix per model; its length sets the number of colours.
    """
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap(name, lut) is the surviving equivalent.
    colors = plt.get_cmap('tab10', len(model_names))

    for idx, (acc_list, name) in enumerate(zip(accuracy_lists, model_names)):
        ax1.plot(thresholds, acc_list, label=f'{name} Acc', color=colors(idx), linestyle='-')
    ax1.set_xlabel("Prune Threshold")
    ax1.set_ylabel("Accuracy")
    ax1.set_xscale("log")
    ax1.tick_params(axis='y')
    ax1.grid(False)

    # Second y axis sharing the same (log) x axis.
    ax2 = ax1.twinx()
    for idx, (sparsity_list, name) in enumerate(zip(sparsity_lists, model_names)):
        ax2.plot(thresholds, sparsity_list, label=f'{name} FLOPs', color=colors(idx), linestyle='--')
    ax2.set_ylabel("FLOPs")
    ax2.set_yscale("log")
    ax2.tick_params(axis='y')

    # Merge both axes' entries into a single legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    ax2.set_title("Threshold vs Accuracy & FLOPs (All Models)")
    fig.tight_layout()
    plt.show()

In [None]:
# Penalty comparison: activation-prune accuracy and FLOPs for each penalty checkpoint.
plot_prune_metrics_dual_axis_all_models(
    act_thresholds,
    accuracy_lists,
    flops_lists,
    model_names,
)

In [None]:
# Joint-sweep grids. Weight thresholds step by 1e-4 up to 1e-3, then by 5e-4 up to 4.5e-3.
# Activation thresholds step by 0.005 from 0.01 to 0.1, then by 0.05 from 0.15 to 0.7.
# round() snaps each generated value to the float of its decimal literal, so these lists equal
# the numbers written out by hand.
weight_thresholds = [
    0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001,
    0.0015, 0.002, 0.0025, 0.003, 0.0035, 0.004, 0.0045,
]
activation_thresholds = (
    [round(0.005 * k, 3) for k in range(2, 21)]
    + [round(0.05 * k, 2) for k in range(3, 15)]
)

In [None]:
def evaluate_weight_activation_prune(model, weight_thresholds, activation_thresholds, val_dataloader):
  """Run an activation-prune sweep on a weight-pruned copy of ``model`` per weight threshold.

  Args:
      model: network to evaluate; it is not modified (each threshold prunes a deep copy).
      weight_thresholds: thresholds passed one at a time to ``weight_prune``.
      activation_thresholds: grid forwarded unchanged to ``evaluate_activation_prune``.
      val_dataloader: validation data forwarded to ``evaluate_activation_prune``.

  Returns:
      ``(acc_lists, flops_lists)``: one entry per weight threshold, in order, each exactly as
      returned by ``evaluate_activation_prune``. Previously these were computed and discarded.
  """
  acc_lists, flops_lists = [], []
  for weight_threshold in weight_thresholds:
    # weight_prune's return value is unused, so it presumably prunes in place. Work on a copy so
    # thresholds do not compound and the caller's model stays unpruned.
    pruned = copy.deepcopy(model)
    weight_prune(pruned, weight_threshold)
    acc_list, flops_list = evaluate_activation_prune(pruned, activation_thresholds, val_dataloader)
    acc_lists.append(acc_list)
    flops_lists.append(flops_list)
  return acc_lists, flops_lists

In [None]:
import torch
import matplotlib.pyplot as plt

def plot_sigmoid_variants(T_values, offset=0.0, x_range=(-5, 5), num_points=500):
    """Plot the logistic curve 1 / (1 + exp(-T * (x - offset))) once per value of T.

    Args:
        T_values: iterable of slopes; each gives one curve labelled "T = <value>".
        offset: horizontal shift applied to every curve.
        x_range: (low, high) bounds of the sampled x axis.
        num_points: number of evenly spaced x samples.
    """
    lo, hi = x_range[0], x_range[1]
    xs = torch.linspace(lo, hi, steps=num_points)
    xs_np = xs.numpy()

    fig, ax = plt.subplots(figsize=(10, 6))
    for T in T_values:
        curve = 1 / (1 + torch.exp(-T * (xs - offset)))
        ax.plot(xs_np, curve.numpy(), label=f'T = {T}')

    ax.set_xlabel("x")
    ax.set_ylabel("f(x)")
    ax.set_title(f"Soft Sigmoid Variants with Different T (offset = {offset})")
    ax.legend(loc="upper right")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
# Five slopes spanning roughly two orders of magnitude, default offset 0.
plot_sigmoid_variants(T_values=[1, 5, 25, 100, 500])

# Output Graph: Intermediate Feature Maps

In [None]:
# Batch size 1: plot_feature_maps below requires a batch dimension of exactly 1. The first
# argument used to be `_`, which by now holds the first loader returned by the earlier
# get_dataloader call (not a dataset). Only the second (validation) loader is used, so pass the
# validation dataset explicitly instead of relying on `_`.
_, test_dataloader_cifar10 = get_dataloader(val_dataset_cifar10, val_dataset_cifar10, 1)

In [None]:
ResNet18_b_10 = ResNet18(10)
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))
ResNet18_b_10.load_state_dict(ResNet18_b_10_config)
ResNet18_b_10.to("cuda")
validate(ResNet18_b_10, val_dataloader_cifar10)

In [None]:
ResNet18_b_l2_retrain_10 = ResNet18(10)
ResNet18_b_l2_retrain_10_config = torch.load("./ResNet18_b_l2_retrain_10.pth", map_location=torch.device("cuda"))
ResNet18_b_l2_retrain_10.load_state_dict(ResNet18_b_l2_retrain_10_config["model_state_dict"])
ResNet18_b_l2_retrain_10.to("cuda")
validate(ResNet18_b_l2_retrain_10, val_dataloader_cifar10)

In [None]:
ResNet18_d_f18_l_s_10 = ResNet18(10)
ResNet18_d_f18_l_s_10_config = torch.load("./ResNet18_d_l_s_5e3_10.pth", map_location=torch.device("cuda"))
ResNet18_d_f18_l_s_10.load_state_dict(ResNet18_d_f18_l_s_10_config["model_state_dict"])
ResNet18_d_f18_l_s_10.to("cuda")
validate(ResNet18_d_f18_l_s_10, val_dataloader_cifar10)

In [None]:
import torch.nn.functional as F

# Outputs of layer1..layer4 (in that order) of ResNet18_b_10 for the first batch of the
# batch-size-1 loader, for the feature-map plots below.
intermediate = []

# NOTE(review): this manual pass (conv1 -> bn1 -> relu -> layer1..4 -> avg_pool2d(4) -> linear)
# presumably mirrors ResNet18.forward in models/ResNet.py — keep it in sync if that changes.
ResNet18_b_10.eval()  # BatchNorm must use running stats, not the statistics of a single image
x, y = next(iter(test_dataloader_cifar10))
x, y = x.to("cuda"), y.to("cuda")

with torch.no_grad():  # inspection only; no autograd graph needs to be kept alive
    x = F.relu(ResNet18_b_10.bn1(ResNet18_b_10.conv1(x)))
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        x = getattr(ResNet18_b_10, name)(x)
        intermediate.append(x)
    x = F.avg_pool2d(x, 4)
    x = x.view(x.size(0), -1)
    x = ResNet18_b_10.linear(x)  # logits, kept for interactive inspection


In [None]:
import torch
import matplotlib.pyplot as plt

def plot_feature_maps(tensor: torch.Tensor, normalize=True, n_cols=8):
    """Show every channel of a single feature map as a grayscale image in a grid.

    Args:
        tensor: activations of shape [1, C, H, W] (batch of exactly one).
        normalize: if True, min-max scale each channel independently to roughly [0, 1].
        n_cols: number of columns in the plot grid; rows are added as needed.

    Raises:
        AssertionError: if ``tensor`` is not 4-D with a batch size of 1.
    """
    assert tensor.ndim == 4 and tensor.shape[0] == 1, "Input must be of shape [1, C, H, W]"
    C = tensor.shape[1]
    n_rows = (C + n_cols - 1) // n_cols  # ceil(C / n_cols)

    feature_maps = tensor.squeeze(0).detach()  # shape: [C, H, W]
    if normalize:
        # Per-channel min-max scaling. The denominator must be the channel's range
        # (max - min); dividing by max alone mis-scales channels whose minimum is non-zero.
        ch_min = feature_maps.amin(dim=(-2, -1), keepdim=True)
        ch_max = feature_maps.amax(dim=(-2, -1), keepdim=True)
        feature_maps = (feature_maps - ch_min) / (ch_max - ch_min + 1e-5)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)
    axes = axes.flatten()

    for i in range(C):
        axes[i].imshow(feature_maps[i].cpu().numpy(), cmap='gray')
        axes[i].axis('off')

    # Hide the unused cells in the last row.
    for ax in axes[C:]:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
intermediate[0].shape

In [None]:
plot_feature_maps(intermediate[1])

In [None]:
plot_feature_maps(intermediate[0])