In [None]:
# Mount Google Drive so that files stored there are reachable under /content/drive/.
import sys
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)
# NOTE(review): an empty string on sys.path refers to the current working
# directory; presumably a project directory was meant here — confirm.
sys.path.append('')

from pprint import pprint
from IPython.core.interactiveshell import InteractiveShell
# Set the `pprint` flag on the IPython shell class.
InteractiveShell.pprint = True

In [None]:
import torch

from data.dataset import get_dataset
from data.dataloader import get_dataloader
from models.ResNet import ResNet18, ResNet34
from evaluation.validate import validate
from train.fine_tune import fine_tune
from train.distillation import distillation

# Global Config

In [None]:
# Number of training epochs; distillation_loop reads this as a notebook global.
epoch = 300
# Mini-batch size handed to get_dataloader.
batch_size = 256

# NOTE(review): theta, alpha and beta are not referenced in any visible cell —
# presumably loss-weighting hyperparameters; confirm they are still needed.
theta = 0.1
alpha = 0.5
beta = 1e-5

# CIFAR-10 Model

## Load Data

In [None]:
# Build the train and validation datasets for "cifar10" with the project helper.
train_dataset_cifar10, val_dataset_cifar10 = get_dataset("cifar10")

In [None]:
# Wrap both datasets in loaders, using the global batch_size.
train_dataloader_cifar10, val_dataloader_cifar10 = get_dataloader(train_dataset_cifar10, val_dataset_cifar10, batch_size)

## Distillation

In [None]:
def distillation_loop(teacher, student, train_dataloader, val_dataloader,
                      output_path, is_norm, norm_type, is_soft_kl, use_soft, penalty_output):
  """Distill `teacher` into `student`, then save the student's weights.

  Builds a cross-entropy loss, an SGD optimizer over the student's parameters
  (lr=0.01, momentum=0.9, weight_decay=0.003) and a cosine-annealing schedule,
  passes them to the project's `distillation` routine, and finally writes
  ``{'model_state_dict': student.state_dict()}`` to `output_path`.

  The epoch count is taken from the notebook-level variable `epoch`, which must
  already be defined when this function is called.

  Args:
    teacher: model passed to `distillation` as the teacher.
    student: model whose parameters are optimized and saved.
    train_dataloader: training loader forwarded to `distillation`.
    val_dataloader: validation loader forwarded to `distillation`.
    output_path: destination file for the saved student checkpoint.
    is_norm, norm_type, is_soft_kl, use_soft, penalty_output: forwarded
      unchanged as keyword arguments to `distillation`.
  """
  loss_fn = torch.nn.CrossEntropyLoss()
  sgd = torch.optim.SGD(
      student.parameters(),
      lr=0.01,
      momentum=0.9,
      weight_decay=0.003,
  )
  # NOTE(review): T_max presumes about 200 scheduler steps per epoch — confirm
  # against how `distillation` advances the scheduler.
  lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=epoch * 200)

  distill_options = dict(
      is_norm=is_norm,
      norm_type=norm_type,
      is_soft_kl=is_soft_kl,
      use_soft=use_soft,
      penalty_output=penalty_output,
  )
  distillation(teacher, student,
               train_dataloader, val_dataloader,
               loss_fn, sgd, lr_schedule,
               epoch, **distill_options)

  torch.save({'model_state_dict': student.state_dict()}, output_path)

### Base Model

In [None]:
# Load the ResNet34 baseline checkpoint, mapping its tensors onto the CUDA device.
ResNet34_b_10_config = torch.load("./ResNet34_b_10.pth", map_location=torch.device("cuda"))

In [None]:
# Rebuild the ResNet34 teacher, restore its weights from the checkpoint's
# "model_state_dict" entry, move it to the GPU and validate it on CIFAR-10.
ResNet34_b_10_t = ResNet34(10)
ResNet34_b_10_t.load_state_dict(ResNet34_b_10_config["model_state_dict"])
ResNet34_b_10_t.to("cuda")
validate(ResNet34_b_10_t, val_dataloader_cifar10)

In [None]:
# Run the project's sparsity helper on the ResNet34 teacher.
from evaluation.sparsity import calculate_sparsity
calculate_sparsity(ResNet34_b_10_t)

In [None]:
# Load the ResNet18 baseline checkpoint, mapping its tensors onto the CUDA device.
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))

In [None]:
# Rebuild the ResNet18 baseline and restore its weights.
# The checkpoint may be a bare state_dict or the {"model_state_dict": ...}
# wrapper that this notebook writes when it saves models, so unwrap it when the
# key is present instead of assuming one layout.
ResNet18_b_10_t = ResNet18(10)
ResNet18_b_10_state = ResNet18_b_10_config.get("model_state_dict", ResNet18_b_10_config)
ResNet18_b_10_t.load_state_dict(ResNet18_b_10_state)
ResNet18_b_10_t.to("cuda")
validate(ResNet18_b_10_t, val_dataloader_cifar10)

In [None]:
# Fresh ResNet18 student, to be distilled from the ResNet18 baseline.
ResNet18_d_l_s_5e3_10 = ResNet18(10)

In [None]:
# Distill the ResNet18 baseline into the fresh ResNet18 student and save the
# result. The norm/soft-target flags are forwarded unchanged to `distillation`.
distillation_loop(teacher=ResNet18_b_10_t,
                  student=ResNet18_d_l_s_5e3_10,
                  train_dataloader=train_dataloader_cifar10,
                  val_dataloader=val_dataloader_cifar10,
                  output_path="ResNet18_d_l_s_5e3_10.pth",
                  is_norm=True,
                  norm_type="l1",
                  penalty_output=0.005,
                  is_soft_kl=True,
                  use_soft=True)

In [None]:
# Fresh ResNet18 student, to be distilled from the ResNet34 teacher.
ResNet18_d_f34_l_s_5e3_10 = ResNet18(10)

In [None]:
# Distill the ResNet34 teacher into the fresh ResNet18 student and save the
# result. The norm/soft-target flags are forwarded unchanged to `distillation`.
distillation_loop(teacher=ResNet34_b_10_t,
                  student=ResNet18_d_f34_l_s_5e3_10,
                  train_dataloader=train_dataloader_cifar10,
                  val_dataloader=val_dataloader_cifar10,
                  output_path="ResNet18_d_f34_l_s_5e3_10.pth",
                  is_norm=True,
                  norm_type="l1",
                  penalty_output=0.005,
                  is_soft_kl=True,
                  use_soft=True)

In [None]:


import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from evaluation.validate import validate
from evaluation.sparsity import calculate_sparsity
from train.distillation import get_weights_norm


def prune_conv2d_l2_in_blocks(model: nn.Module, prune_ratio=0.3) -> None:
    """Zero the smallest-magnitude weights of block convolutions, in place.

    Only top-level children of `model` that are ``nn.Sequential`` are visited.
    Each ``nn.Conv2d`` nested inside them is pruned independently, except
    convolutions whose path inside the block contains ``"downsample"``.

    Despite the ``l2`` in the function name, weights are ranked by their
    individual absolute value.

    Args:
        model: network whose convolution weights are modified in place.
        prune_ratio: fraction in [0, 1] of each layer's weights to zero. Weights
            tied with the cut-off magnitude are zeroed as well.

    Raises:
        ValueError: if `prune_ratio` lies outside [0, 1].
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")

    with torch.no_grad():
        for child_name, child in model.named_children():
            if not isinstance(child, nn.Sequential) or "downsample" in child_name:
                continue
            # named_modules() exposes the dotted path of every nested module, so
            # downsample branches can be skipped wherever they are nested.
            for sub_name, submodule in child.named_modules():
                if not isinstance(submodule, nn.Conv2d) or "downsample" in sub_name:
                    continue
                weight = submodule.weight.data
                magnitudes = weight.abs()
                k = int(magnitudes.numel() * prune_ratio)
                if k == 0:
                    continue
                # reshape() also handles non-contiguous weights, where view() raises.
                threshold = torch.kthvalue(magnitudes.reshape(-1), k).values
                # "<=" includes the k-th smallest weight itself, so k weights are
                # removed rather than k - 1.
                weight[magnitudes <= threshold] = 0.0


def fine_tune(model: nn.Module,
              train_loader: DataLoader,
              val_loader: DataLoader,
              criterion: nn.Module,
              optimizer: optim.Optimizer,
              scheduler: optim.lr_scheduler,
              epoch: int,
              device: str='cuda') -> None:
    """Train `model` for `epoch` passes over `train_loader`, validating after each.

    Every batch runs a forward pass, computes `criterion`, back-propagates,
    clips the gradient norm to 1.0, then steps both the optimizer and the
    scheduler. After each pass, the value returned by `validate` is shown on the
    outer progress bar under "accuracy".

    NOTE: this notebook also imports a `fine_tune` from `train.fine_tune`; once
    this cell has run, this definition is the one that gets called.

    Args:
        model: network to train; moved to `device` first.
        train_loader: yields ``(x, y)`` batches for training.
        val_loader: loader handed to `validate` after each pass.
        criterion: loss called as ``criterion(output, y)``.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped once per batch.
        epoch: number of passes over `train_loader`.
        device: device that the model and each batch are moved to.
    """
    model.to(device)

    epoch_bar = tqdm(range(epoch), desc="Epoch")
    # The loop variable reuses the name `epoch`; range(epoch) was already built
    # from the parameter, so the number of iterations is unaffected.
    for epoch in epoch_bar:
        model.train()
        batch_bar = tqdm(train_loader, desc=f"Train E{epoch}", leave=False)
        for inputs, targets in batch_bar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            # Clip the overall gradient norm before applying the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # The schedule advances per batch, not per epoch.
            scheduler.step()

            batch_bar.set_postfix({"loss": loss.item()})

        accuracy = validate(model, val_loader, device)
        epoch_bar.set_postfix({"accuracy": accuracy})


In [None]:
def get_conv2d_masks(model: nn.Module) -> dict:
    """Build 0/1 float masks marking the non-zero weights of Conv2d layers.

    Visits each top-level child of `model` that is an ``nn.Sequential`` and whose
    own name does not contain ``"downsample"``, and collects every ``nn.Conv2d``
    nested anywhere inside it.

    NOTE(review): only the top-level child name is tested, so convolutions that
    sit in a nested submodule named "downsample" are still included.

    Returns:
        dict mapping ``id(conv.weight)`` to a float tensor shaped like the
        weight, holding 1.0 where the weight is non-zero and 0.0 elsewhere. The
        parameter's id is the key so a mask can later be matched to that object.
    """
    return {
        id(conv.weight): (conv.weight.data != 0).float()
        for child_name, child in model.named_children()
        if isinstance(child, nn.Sequential) and "downsample" not in child_name
        for conv in child.modules()
        if isinstance(conv, nn.Conv2d)
    }


def apply_mask_gradient_hooks(model: nn.Module, mask_dict: dict) -> None:
    """Register gradient hooks that multiply each conv weight gradient by its mask.

    Walks the same layers as the mask builder: every ``nn.Conv2d`` inside a
    top-level ``nn.Sequential`` child whose name does not contain
    ``"downsample"``. A hook is attached only when ``id(weight)`` is a key of
    `mask_dict`.

    Each call adds new hooks; nothing here removes previously registered ones.

    Args:
        model: network whose convolution weights receive the hooks.
        mask_dict: mapping from ``id(weight)`` to a tensor multiplied into the
            gradient of that weight.
    """
    def _masked(mask):
        # Bind `mask` per layer so each hook multiplies by its own tensor.
        def _hook(grad):
            return grad * mask
        return _hook

    for child_name, child in model.named_children():
        if not isinstance(child, nn.Sequential) or "downsample" in child_name:
            continue
        for layer in child.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            param = layer.weight
            key = id(param)
            if key in mask_dict:
                param.register_hook(_masked(mask_dict[key]))

def prune_and_retrain(model: nn.Module,
                      train_loader: DataLoader,
                      val_loader: DataLoader,
                      epoch: int,
                      device: str = 'cuda') -> None:
    """Prune the model's block convolutions in place, then retrain it.

    Steps: move the model to `device`, zero 30% of the weights in each pruned
    convolution, print the validation result right after pruning, record which
    weights are non-zero, register hooks that zero the gradients of pruned
    weights, and retrain with SGD under a cosine-annealing schedule.

    Args:
        model: network to prune and retrain; modified in place.
        train_loader: training batches used for retraining.
        val_loader: validation loader used after pruning and after each epoch.
        epoch: number of retraining epochs.
        device: device used for validation and training.
    """
    # Move the model first so the masks built below live on the same device as
    # the gradients they are multiplied with inside the hooks.
    model.to(device)

    # Step 1: prune in place.
    prune_conv2d_l2_in_blocks(model, prune_ratio=0.3)

    # Report the result right after pruning, on the loader and device supplied
    # by the caller rather than on a notebook-level global.
    print(validate(model, val_loader, device))

    # Step 2: record which weights survived pruning.
    mask_dict = get_conv2d_masks(model)

    # Step 3: register mask-based hooks that zero the gradients of pruned weights.
    apply_mask_gradient_hooks(model, mask_dict)

    # Step 4: set up the training configuration and retrain with fine_tune.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    # NOTE(review): fine_tune steps the scheduler once per batch; 200 presumably
    # approximates the number of batches per epoch — confirm for other loaders.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch*200)

    fine_tune(model, train_loader, val_loader, criterion, optimizer, scheduler, epoch, device)

In [None]:
# Start the pruning experiment from a fresh copy of the ResNet18 baseline.
ResNet18_b_l2_retrain_10 = ResNet18(10)
ResNet18_b_10_config = torch.load("./ResNet18_b_10.pth", map_location=torch.device("cuda"))
# The checkpoint may be a bare state_dict or the {"model_state_dict": ...}
# wrapper that this notebook writes when it saves models; accept both layouts.
ResNet18_b_l2_retrain_10.load_state_dict(
    ResNet18_b_10_config.get("model_state_dict", ResNet18_b_10_config)
)
ResNet18_b_l2_retrain_10.to("cuda")
validate(ResNet18_b_l2_retrain_10, val_dataloader_cifar10)

In [None]:
# Prune the reloaded ResNet18 baseline and retrain it for 300 epochs on CIFAR-10.
prune_and_retrain(ResNet18_b_l2_retrain_10, train_dataloader_cifar10, val_dataloader_cifar10, epoch=300)

In [None]:
# Save the retrained weights under the "model_state_dict" key.
model_save = {'model_state_dict': ResNet18_b_l2_retrain_10.state_dict()}
torch.save(model_save, "ResNet18_b_l2_retrain_10.pth")