In [None]:
# pytorchライブラリのimport
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# utils
import copy
import math

In [None]:
# Initial setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global torch RNG so that data shuffling, augmentation and the
# weight initialisation in later cells are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# NOTE: relative to the current working directory. A later cell changes the
# cwd, so re-running this cell after that one will look in a different place.
DATA_ROOT = '../datasets'
BATCH_SIZE = 100

# Training-time augmentation: zero-pad by 4, random horizontal flip, random 32x32 crop.
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                             train=True,
                                             transform=transform,
                                             download=True)

# The test split is not augmented; download=True makes this cell self-sufficient.
test_dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                            train=False,
                                            transform=transforms.ToTensor(),
                                            download=True)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

In [None]:
import os
import sys

# The project package `src` appears to live one directory above the notebook's
# starting directory (a later cell runs `cd ../` before importing it again).
# That `cd` runs too late for this cell on a fresh kernel, so make the parent
# directory importable here when `src` is not in the current directory.
_PROJECT_ROOT = os.path.abspath('..')
if not os.path.isdir('src') and _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
import copy

# Dense baseline ResNet18 (its parameter count is printed in the next cells).
model = ResNet18().to(device)

In [None]:
# Parameter counting
def get_n_params(model, trainable_only=False):
    """Return the total number of scalar elements in ``model``'s parameters.

    Args:
        model: a ``torch.nn.Module``; its ``parameters()`` are counted.
        trainable_only: if True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts every
            parameter (the original behaviour).

    Returns:
        int: sum of ``p.numel()`` over the selected parameters (0 if none).
    """
    # numel() is the product of the tensor's dimensions (1 for a 0-d tensor),
    # exactly what the previous hand-rolled loop computed. It also avoids
    # shadowing the module-level ``nn`` (torch.nn) alias with a local variable.
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

# Total number of parameter elements in the dense baseline `model` (every tensor yielded by model.parameters()).
print(get_n_params(model))

In [None]:
# Optional: train the dense baseline `model` for one epoch
# (tens of seconds per epoch on a Colab GPU).
# Previously this code was parked inside a bare string literal, which did
# nothing but echo the whole string as the cell's output. It is now real code
# behind a flag; set TRAIN_BASELINE = True to run it.
TRAIN_BASELINE = False

if TRAIN_BASELINE:
    # hyper parameters
    learning_rate = 0.001
    num_epochs = 1

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # For updating learning rate
    def update_lr(optimizer, lr):
        """Set the learning rate of every parameter group of `optimizer` to `lr`."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Ensure training-mode behaviour even if an earlier cell switched to eval().
    model.train()

    # Train the model
    total_step = len(train_loader)
    curr_lr = learning_rate
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                       .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

        # Decay learning rate by 3x every 20 epochs
        # (never fires while num_epochs == 1).
        if (epoch+1) % 20 == 0:
            curr_lr /= 3
            update_lr(optimizer, curr_lr)

In [None]:
# Move the working directory up to the project root so that `src.*` imports
# resolve. Guarded so that re-running this cell does not keep climbing
# directories (a bare `cd ../` goes up one more level on every run).
import os
if not os.path.isdir('src'):
    os.chdir('..')
os.getcwd()

In [None]:
# Build a dense ResNet18 and its SLTH counterpart produced by the edgepopup module.
# NOTE(review): these imports repeat an earlier cell; they rely on the preceding
# `cd ../` cell having made `src` importable.
from src.models.resnet.resnet import ResNet18
from src.pruning.slth.edgepopup import modify_module_for_slth
import copy

resnet18 = ResNet18().to(device)
# remain_rate=0.3 -- presumably the fraction of weights kept by the learned mask; confirm in edgepopup.
# NOTE(review): it is not visible here whether modify_module_for_slth mutates `resnet18`
# in place (then `resnet18` and `resnet18_slth` would share tensors); verify before
# comparing the two models.
resnet18_slth = modify_module_for_slth(resnet18, remain_rate=0.3).to(device)
# Untrained snapshot; later cells compare against it to see which tensors training changed.
resnet18_slth_init = copy.deepcopy(resnet18_slth).to(device)


In [None]:
# (Reference) Right after the deepcopy, the SLTH model and its snapshot hold
# identical first-conv weights, so this is expected to be True.
# torch.equal returns a single bool instead of an element-wise boolean tensor.
# The same call against `resnet18.conv` compares with the dense model instead.
torch.equal(resnet18_slth.conv.state_dict()['weight'],
            resnet18_slth_init.conv.state_dict()['weight'])


In [None]:
# (Reference) Parameter count of the dense model vs. the SLTH model.
# NOTE(review): the SLTH model also carries `scores` tensors in its state_dict;
# whether they show up here depends on whether they are registered as parameters.
for caption, net in (("#params in the standard model: ", resnet18),
                     ("#params in the SLTH model: ", resnet18_slth)):
    print(caption, get_n_params(net))

In [None]:
# Train the SLTH model with the same procedure as the dense baseline.
# Takes tens of seconds per epoch on a Colab GPU.

# hyper parameters
learning_rate = 0.001
num_epochs = 1

# Loss and optimizer.
# NOTE(review): Adam receives every parameter; parameters that never get a
# gradient (e.g. frozen weights, if that is how edgepopup freezes them) are
# skipped by the optimizer step.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18_slth.parameters(), lr=learning_rate)

# For updating learning rate
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer` to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Put the model in training mode explicitly so that re-running this cell after
# any evaluation cell that called .eval() still trains with training-mode layers.
resnet18_slth.train()

# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet18_slth(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Decay learning rate by 3x every 20 epochs (never fires while num_epochs == 1).
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

In [None]:
# Summarise the SLTH model's parameters as (name, shape, requires_grad) rather
# than dumping every tensor's values; requires_grad shows what training updates.
[(name, tuple(param.shape), param.requires_grad)
 for name, param in resnet18_slth.named_parameters()]

In [None]:
# After training, every parameter whose name contains 'weight' must still equal
# its value in the untrained snapshot `resnet18_slth_init`; fail loudly otherwise.
# NOTE(review): the substring match also catches e.g. batch-norm `*.weight`
# parameters, if the model has any -- confirm that those are meant to be frozen too.
# The snapshot's state_dict is built once instead of on every loop iteration.
init_state = resnet18_slth_init.state_dict()
for name, param in resnet18_slth.named_parameters():
    if 'weight' in name:  # only check parameters whose name contains 'weight'
        init_param = init_state[name]
        # The message no longer reads the `epoch` variable leaked from the training cell.
        assert torch.equal(param.detach(), init_param), f"Weight mismatch found in {name} after training"



In [None]:
# (Reference) Are the first conv layer's weights still identical to the untrained
# snapshot after training? torch.equal returns a single bool instead of an
# element-wise boolean tensor.
torch.equal(resnet18_slth.conv.state_dict()['weight'],
            resnet18_slth_init.conv.state_dict()['weight'])


In [None]:
# Fraction of the first conv layer's `scores` entries that differ from the untrained
# snapshot (0.0 means training left the scores untouched). This replaces an
# element-wise boolean dump with a single number.
(resnet18_slth.conv.state_dict()['scores'] != resnet18_slth_init.conv.state_dict()['scores']).float().mean().item()
