In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init as init
import torch.nn.functional as F
import visdom
import copy
import torch.nn.utils.prune as prune
from tqdm.notebook import tqdm
import numpy as np
import timeit

# custom librarys
import custom.model as cm # 저장된 model

In [2]:
#torch.manual_seed(55)
#torch.cuda.manual_seed_all(55)
#torch.backends.cudnn.enabled = False

In [3]:
GPU_NUM = 1
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
torch.cuda.set_device(device)

print ('Available devices ', torch.cuda.device_count())
print ('Current cuda device ', torch.cuda.current_device())
print(torch.cuda.get_device_name(device))

print("cpu와 cuda 중 다음 기기로 학습함:", device, '\n')

Available devices  2
Current cuda device  1
GeForce RTX 2080 Ti
cpu와 cuda 중 다음 기기로 학습함: cuda:1 



In [4]:
# visdom setting
vis = visdom.Visdom()
vis.close(env="main")

# make plot
vis_plt = vis.line(X=torch.Tensor(1).zero_(), Y=torch.Tensor(1).zero_(), 
                    opts=dict(title = 'LeNet300_Accuracy_Tracker',
                              legend=['100'],
                             showlegend=True,
                              xtickmin = 0,
                              xtickmax = 20000,
                              ytickmin = 0.95,
                              ytickmax = 0.99
                             )
                   )

def visdom_plot(loss_plot, loss_value, num, name):
    vis.line(X = num,
            Y = loss_value,
            win = loss_plot,
            name = name,
            update = 'append'
            )

Setting up a new session...


In [5]:
# parameter
lr = 0.0012
#epochs = 50
#epochs = 20
epochs = 30
batch_size = 60
weight_decay = 1.2e-3
iteration = 0
remaining_weight = 1
prune_per = 0.2
# number of iteration
noi = 11

switch = 0
best_accu = []
# 마지막 layer의 Pruning rate는 기존의 1/2
# prune_per_ll = prune_per/2

In [6]:
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

mnist_train = dsets.MNIST(root='../MNIST_data/',
                         train=True,
                         transform=transforms,
                         download=True)
mnist_test = dsets.MNIST(root='../MNIST_data/',
                        train=False,
                        transform=transforms,
                        download=True)

train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                         shuffle=False,
                                         drop_last=True)

In [7]:
def train(model, dataloader, optimizer, criterion, cp_mask, switch):
    model.train()
    running_loss = 0.0
    EPS = 1e-6
    for data, label in dataloader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, label)
        loss.backward()

        if switch == 1:
            # 0-weight 학습 방지 code
            i = 0
            for name, p in model.named_parameters():
                if 'weight' in name:
                    p.grad.data *= cp_mask[i]
                    i += 1

        optimizer.step()

        running_loss = loss / len(dataloader)
    return running_loss

def test(model, dataloader, criterion):
    model.eval()
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            outputs = model(data)
            
            predicted = torch.argmax(outputs.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
            accuracy = (correct/total)

    return accuracy

# prune function
# pruning mask 생성 -> mask 복사 -> init값 복사 -> prune 진행
def weight_init(model1, model2, rate):
    # prune mask 생성
    for name, module in model1.named_modules():
        if isinstance(module, nn.Linear):
            # bottle neck 방지
            if name == 'fc3':
                prune.l1_unstructured(module, name = 'weight', amount = (rate/2))
            else:
                prune.l1_unstructured(module, name = 'weight', amount = rate)
                        
    # mask 복사
    cp_mask = []
    for name, mask in model1.named_buffers():
        cp_mask.append(mask)
    # init 값 복사
    
    for name, p in model1.named_parameters():
        if 'weight_orig' in name:
            for name2, p2 in model2.named_parameters():
                if name[0:len(name) - 5] in name2:
                    p.data = copy.deepcopy(p2.data)
                    break
        if 'bias_orig' in name:
            for name2, p2 in model2.named_parameters():
                if name[0:len(name) - 5] in name2:
                    p.data = copy.deepcopy(p2.data)
                    break
                    
    # prune 진행
    for name, module in model1.named_modules():
        if isinstance(module, nn.Linear):
            prune.remove(module, name = 'weight')
            
    return cp_mask

In [8]:
model = cm.LeNet300().to(device)
model_init = copy.deepcopy(model)

criterion = nn.CrossEntropyLoss().to(device)
#optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = 1.2e-3)

In [9]:
#EPS = 1e-6
# number of weight
a = ((model.fc1.weight != 0).sum(dim=1)).sum(dim=0) + ((model.fc2.weight != 0).sum(dim=1)).sum(dim=0) + ((model.fc3.weight != 0).sum(dim=1)).sum(dim=0)
#b = ((model.fc1.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc2.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc3.weight == 0).sum(dim=1)).sum(dim=0)
b = ((model.fc1.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc2.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc3.weight == 0).sum(dim=1)).sum(dim=0)

now = (a + b)

In [10]:
def calc_now(model):
    fc1_1 = ((model.fc1.weight != 0).sum(dim=1)).sum(dim=0)
    fc1_0 = ((model.fc1.weight == 0).sum(dim=1)).sum(dim=0)
    fc1 = fc1_1 + fc1_0
    fc1_p = fc1_0 / fc1_1
    fc2_1 = ((model.fc2.weight != 0).sum(dim=1)).sum(dim=0)
    fc2_0 = ((model.fc2.weight == 0).sum(dim=1)).sum(dim=0)
    fc2 = fc2_1 + fc2_0
    fc3_1 = ((model.fc3.weight != 0).sum(dim=1)).sum(dim=0)
    fc3_0 = ((model.fc3.weight == 0).sum(dim=1)).sum(dim=0)
    fc3 = fc3_1 + fc3_0
    #print(fc1, fc2, fc3, fc1+fc2+fc3, fc1_1 + fc2_1 + fc3_1 ,fc1_0 + fc2_0 + fc3_0)
    print('%d' % (fc1+fc2+fc3),
         '(%d |' % (fc1_1+fc2_1+fc3_1),
         '%d)' % (fc1_0+fc2_0+fc3_0)
         )
    print('%d' % fc1,
         '(%d |' % fc1_1,
         '%d)' % fc1_0
         )
    print('%d' % fc2,
         '(%d |' % fc2_1,
         '%d)' % fc2_0
         )
    print('%d' % fc3,
         '(%d |' % fc3_1,
         '%d)' % fc3_0
         )

a

b

weight_init(model, model_init, 1 - weight_remaining)

weight_init(model, model_init, 1 - weight_remaining)

model.state_dict().keys()

print(model.fc3.weight_orig)

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name = 'weight', amount = 0.9)

    # init 값 복사
for name, p in model.named_parameters():
     if 'weight_orig' in name:
        for name2, p2 in model_init.named_parameters():
            if name[0:len(name) - 5] in name2:
                p.data = copy.deepcopy(p2.data)
                break
    if 'bias_orig' in name:
        for name2, p2 in modelinit.named_parameters():
            if name[0:len(name) - 5] in name2:
                p.data = copy.deepcopy(p2.data)
                break

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.remove(module, name = 'weight')

print(model.fc3.weight[0][0])

print(model_init.fc3.weight[0][0])

In [11]:
for i in range(noi):
    best_accu.append(0)
    best_accu[i] = [0, 0, 0]
    cp_mask = []
    if i != 0:
        remaining_weight = remaining_weight * (1-prune_per)
        cp_mask = weight_init(model, model_init, 1 - remaining_weight)
        switch = 1
    optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = weight_decay)
    print("Learning start(remaining weight : %d%%)" % round(remaining_weight * 100, 1))
    start_time = timeit.default_timer()
    pw = ((model.fc1.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc2.weight == 0).sum(dim=1)).sum(dim=0) + ((model.fc3.weight == 0).sum(dim=1)).sum(dim=0)
    print('pruned weight (All | Pruned) %d |' % now,'%d' % pw)
    #print(model.fc3.weight[0][0])
    #print(model_init.fc3.weight[0][0])
    calc_now(model)
    
    
    
    for epoch in tqdm(range(epochs)):
        if epoch == 0:
            accuracy = test(model, test_loader, criterion)
            visdom_plot(vis_plt, torch.Tensor([accuracy]), torch.Tensor([0]), str(round(remaining_weight*100, 1)))
            print('[epoch : %d]' % (epoch),
             '(loss: x.xxxxx)',
             '(accu: %.4f)' % (accuracy)
             )
        running_loss = train(model, train_loader, optimizer, criterion, cp_mask, switch)
        accuracy = test(model, test_loader, criterion)
        visdom_plot(vis_plt, torch.Tensor([accuracy]), torch.Tensor([(epoch+1) * 1000]), str(round(remaining_weight*100, 1)))
        
        # best accuracy list (weight_remain, epoch, accuracy)
        if best_accu[i][2] <= accuracy:
            best_accu[i] = [remaining_weight, epoch, accuracy]
        
        print('[epoch : %d]' % (epoch+1),
             '(loss: %.5f)' % (running_loss),
             '(accu: %.4f)' % (accuracy)
             )
    stop_time = timeit.default_timer()
    #print(model.fc3.weight[0][0])
    #print(model_init.fc3.weight[0][0])
    print("Finish!",
          "(Best accu: %.4f)" % best_accu[i][2],
          "(Time taken(sec) : %.2f)" % (stop_time - start_time),
          "\n")
    calc_now(model)
    print("\n\n\n\n\n\n")
    

Learning start(remaining weight : 100%)
pruned weight (All | Pruned) 266200 | 0
266200 (266200 | 0)
235200 (235200 | 0)
30000 (30000 | 0)
1000 (1000 | 0)




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0813)
[epoch : 1] (loss: 0.00019) (accu: 0.9607)
[epoch : 2] (loss: 0.00016) (accu: 0.9701)
[epoch : 3] (loss: 0.00002) (accu: 0.9743)
[epoch : 4] (loss: 0.00011) (accu: 0.9731)
[epoch : 5] (loss: 0.00011) (accu: 0.9712)
[epoch : 6] (loss: 0.00011) (accu: 0.9708)
[epoch : 7] (loss: 0.00013) (accu: 0.9738)
[epoch : 8] (loss: 0.00009) (accu: 0.9701)
[epoch : 9] (loss: 0.00015) (accu: 0.9751)
[epoch : 10] (loss: 0.00006) (accu: 0.9775)
[epoch : 11] (loss: 0.00002) (accu: 0.9726)
[epoch : 12] (loss: 0.00017) (accu: 0.9675)
[epoch : 13] (loss: 0.00006) (accu: 0.9735)
[epoch : 14] (loss: 0.00004) (accu: 0.9763)
[epoch : 15] (loss: 0.00005) (accu: 0.9729)
[epoch : 16] (loss: 0.00003) (accu: 0.9776)
[epoch : 17] (loss: 0.00010) (accu: 0.9753)
[epoch : 18] (loss: 0.00007) (accu: 0.9744)
[epoch : 19] (loss: 0.00011) (accu: 0.9751)
[epoch : 20] (loss: 0.00014) (accu: 0.9682)
[epoch : 21] (loss: 0.00008) (accu: 0.9763)
[epoch : 22] (loss: 0.00003) (accu: 0.9780

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0961)
[epoch : 1] (loss: 0.00014) (accu: 0.9642)
[epoch : 2] (loss: 0.00017) (accu: 0.9649)
[epoch : 3] (loss: 0.00006) (accu: 0.9683)
[epoch : 4] (loss: 0.00003) (accu: 0.9706)
[epoch : 5] (loss: 0.00012) (accu: 0.9722)
[epoch : 6] (loss: 0.00008) (accu: 0.9776)
[epoch : 7] (loss: 0.00007) (accu: 0.9723)
[epoch : 8] (loss: 0.00003) (accu: 0.9765)
[epoch : 9] (loss: 0.00005) (accu: 0.9761)
[epoch : 10] (loss: 0.00003) (accu: 0.9742)
[epoch : 11] (loss: 0.00010) (accu: 0.9723)
[epoch : 12] (loss: 0.00005) (accu: 0.9750)
[epoch : 13] (loss: 0.00014) (accu: 0.9732)
[epoch : 14] (loss: 0.00008) (accu: 0.9753)
[epoch : 15] (loss: 0.00003) (accu: 0.9749)
[epoch : 16] (loss: 0.00007) (accu: 0.9777)
[epoch : 17] (loss: 0.00007) (accu: 0.9754)
[epoch : 18] (loss: 0.00004) (accu: 0.9738)
[epoch : 19] (loss: 0.00002) (accu: 0.9764)
[epoch : 20] (loss: 0.00007) (accu: 0.9728)
[epoch : 21] (loss: 0.00002) (accu: 0.9767)
[epoch : 22] (loss: 0.00009) (accu: 0.9771

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0959)
[epoch : 1] (loss: 0.00023) (accu: 0.9676)
[epoch : 2] (loss: 0.00021) (accu: 0.9733)
[epoch : 3] (loss: 0.00003) (accu: 0.9717)
[epoch : 4] (loss: 0.00003) (accu: 0.9784)
[epoch : 5] (loss: 0.00008) (accu: 0.9764)
[epoch : 6] (loss: 0.00005) (accu: 0.9777)
[epoch : 7] (loss: 0.00004) (accu: 0.9704)
[epoch : 8] (loss: 0.00003) (accu: 0.9755)
[epoch : 9] (loss: 0.00006) (accu: 0.9768)
[epoch : 10] (loss: 0.00007) (accu: 0.9735)
[epoch : 11] (loss: 0.00008) (accu: 0.9698)
[epoch : 12] (loss: 0.00002) (accu: 0.9765)
[epoch : 13] (loss: 0.00001) (accu: 0.9766)
[epoch : 14] (loss: 0.00004) (accu: 0.9759)
[epoch : 15] (loss: 0.00004) (accu: 0.9749)
[epoch : 16] (loss: 0.00009) (accu: 0.9724)
[epoch : 17] (loss: 0.00006) (accu: 0.9735)
[epoch : 18] (loss: 0.00005) (accu: 0.9723)
[epoch : 19] (loss: 0.00008) (accu: 0.9753)
[epoch : 20] (loss: 0.00012) (accu: 0.9753)
[epoch : 21] (loss: 0.00003) (accu: 0.9715)
[epoch : 22] (loss: 0.00004) (accu: 0.9756

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0972)
[epoch : 1] (loss: 0.00004) (accu: 0.9617)
[epoch : 2] (loss: 0.00008) (accu: 0.9646)
[epoch : 3] (loss: 0.00010) (accu: 0.9725)
[epoch : 4] (loss: 0.00004) (accu: 0.9718)
[epoch : 5] (loss: 0.00008) (accu: 0.9757)
[epoch : 6] (loss: 0.00007) (accu: 0.9779)
[epoch : 7] (loss: 0.00009) (accu: 0.9756)
[epoch : 8] (loss: 0.00011) (accu: 0.9733)
[epoch : 9] (loss: 0.00017) (accu: 0.9721)
[epoch : 10] (loss: 0.00002) (accu: 0.9696)
[epoch : 11] (loss: 0.00002) (accu: 0.9701)
[epoch : 12] (loss: 0.00004) (accu: 0.9746)
[epoch : 13] (loss: 0.00017) (accu: 0.9717)
[epoch : 14] (loss: 0.00008) (accu: 0.9685)
[epoch : 15] (loss: 0.00008) (accu: 0.9761)
[epoch : 16] (loss: 0.00007) (accu: 0.9758)
[epoch : 17] (loss: 0.00002) (accu: 0.9769)
[epoch : 18] (loss: 0.00002) (accu: 0.9746)
[epoch : 19] (loss: 0.00003) (accu: 0.9793)
[epoch : 20] (loss: 0.00009) (accu: 0.9748)
[epoch : 21] (loss: 0.00003) (accu: 0.9750)
[epoch : 22] (loss: 0.00008) (accu: 0.9770

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00009) (accu: 0.9586)
[epoch : 2] (loss: 0.00025) (accu: 0.9702)
[epoch : 3] (loss: 0.00002) (accu: 0.9714)
[epoch : 4] (loss: 0.00012) (accu: 0.9728)
[epoch : 5] (loss: 0.00005) (accu: 0.9736)
[epoch : 6] (loss: 0.00012) (accu: 0.9737)
[epoch : 7] (loss: 0.00021) (accu: 0.9769)
[epoch : 8] (loss: 0.00011) (accu: 0.9708)
[epoch : 9] (loss: 0.00009) (accu: 0.9743)
[epoch : 10] (loss: 0.00004) (accu: 0.9748)
[epoch : 11] (loss: 0.00002) (accu: 0.9774)
[epoch : 12] (loss: 0.00002) (accu: 0.9734)
[epoch : 13] (loss: 0.00009) (accu: 0.9776)
[epoch : 14] (loss: 0.00006) (accu: 0.9749)
[epoch : 15] (loss: 0.00003) (accu: 0.9786)
[epoch : 16] (loss: 0.00004) (accu: 0.9739)
[epoch : 17] (loss: 0.00005) (accu: 0.9771)
[epoch : 18] (loss: 0.00002) (accu: 0.9773)
[epoch : 19] (loss: 0.00001) (accu: 0.9772)
[epoch : 20] (loss: 0.00005) (accu: 0.9765)
[epoch : 21] (loss: 0.00005) (accu: 0.9783)
[epoch : 22] (loss: 0.00013) (accu: 0.9748

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00013) (accu: 0.9636)
[epoch : 2] (loss: 0.00014) (accu: 0.9703)
[epoch : 3] (loss: 0.00015) (accu: 0.9712)
[epoch : 4] (loss: 0.00004) (accu: 0.9692)
[epoch : 5] (loss: 0.00005) (accu: 0.9775)
[epoch : 6] (loss: 0.00017) (accu: 0.9738)
[epoch : 7] (loss: 0.00014) (accu: 0.9733)
[epoch : 8] (loss: 0.00009) (accu: 0.9701)
[epoch : 9] (loss: 0.00004) (accu: 0.9764)
[epoch : 10] (loss: 0.00005) (accu: 0.9774)
[epoch : 11] (loss: 0.00007) (accu: 0.9744)
[epoch : 12] (loss: 0.00005) (accu: 0.9809)
[epoch : 13] (loss: 0.00003) (accu: 0.9773)
[epoch : 14] (loss: 0.00006) (accu: 0.9763)
[epoch : 15] (loss: 0.00008) (accu: 0.9761)
[epoch : 16] (loss: 0.00005) (accu: 0.9758)
[epoch : 17] (loss: 0.00005) (accu: 0.9769)
[epoch : 18] (loss: 0.00002) (accu: 0.9686)
[epoch : 19] (loss: 0.00002) (accu: 0.9762)
[epoch : 20] (loss: 0.00002) (accu: 0.9777)
[epoch : 21] (loss: 0.00001) (accu: 0.9754)
[epoch : 22] (loss: 0.00002) (accu: 0.9739

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00010) (accu: 0.9608)
[epoch : 2] (loss: 0.00019) (accu: 0.9674)
[epoch : 3] (loss: 0.00005) (accu: 0.9722)
[epoch : 4] (loss: 0.00013) (accu: 0.9708)
[epoch : 5] (loss: 0.00013) (accu: 0.9667)
[epoch : 6] (loss: 0.00002) (accu: 0.9704)
[epoch : 7] (loss: 0.00004) (accu: 0.9684)
[epoch : 8] (loss: 0.00003) (accu: 0.9756)
[epoch : 9] (loss: 0.00009) (accu: 0.9753)
[epoch : 10] (loss: 0.00001) (accu: 0.9776)
[epoch : 11] (loss: 0.00014) (accu: 0.9710)
[epoch : 12] (loss: 0.00001) (accu: 0.9749)
[epoch : 13] (loss: 0.00005) (accu: 0.9739)
[epoch : 14] (loss: 0.00006) (accu: 0.9783)
[epoch : 15] (loss: 0.00002) (accu: 0.9762)
[epoch : 16] (loss: 0.00006) (accu: 0.9789)
[epoch : 17] (loss: 0.00008) (accu: 0.9772)
[epoch : 18] (loss: 0.00011) (accu: 0.9770)
[epoch : 19] (loss: 0.00004) (accu: 0.9776)
[epoch : 20] (loss: 0.00004) (accu: 0.9731)
[epoch : 21] (loss: 0.00003) (accu: 0.9767)
[epoch : 22] (loss: 0.00002) (accu: 0.9801

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00021) (accu: 0.9565)
[epoch : 2] (loss: 0.00009) (accu: 0.9684)
[epoch : 3] (loss: 0.00011) (accu: 0.9717)
[epoch : 4] (loss: 0.00016) (accu: 0.9725)
[epoch : 5] (loss: 0.00010) (accu: 0.9729)
[epoch : 6] (loss: 0.00007) (accu: 0.9769)
[epoch : 7] (loss: 0.00001) (accu: 0.9774)
[epoch : 8] (loss: 0.00006) (accu: 0.9741)
[epoch : 9] (loss: 0.00016) (accu: 0.9742)
[epoch : 10] (loss: 0.00015) (accu: 0.9728)
[epoch : 11] (loss: 0.00002) (accu: 0.9768)
[epoch : 12] (loss: 0.00015) (accu: 0.9776)
[epoch : 13] (loss: 0.00007) (accu: 0.9795)
[epoch : 14] (loss: 0.00003) (accu: 0.9764)
[epoch : 15] (loss: 0.00001) (accu: 0.9749)
[epoch : 16] (loss: 0.00001) (accu: 0.9757)
[epoch : 17] (loss: 0.00011) (accu: 0.9755)
[epoch : 18] (loss: 0.00007) (accu: 0.9780)
[epoch : 19] (loss: 0.00004) (accu: 0.9759)
[epoch : 20] (loss: 0.00005) (accu: 0.9795)
[epoch : 21] (loss: 0.00002) (accu: 0.9761)
[epoch : 22] (loss: 0.00006) (accu: 0.9766

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00011) (accu: 0.9560)
[epoch : 2] (loss: 0.00022) (accu: 0.9672)
[epoch : 3] (loss: 0.00005) (accu: 0.9686)
[epoch : 4] (loss: 0.00005) (accu: 0.9743)
[epoch : 5] (loss: 0.00006) (accu: 0.9717)
[epoch : 6] (loss: 0.00010) (accu: 0.9755)
[epoch : 7] (loss: 0.00013) (accu: 0.9765)
[epoch : 8] (loss: 0.00016) (accu: 0.9759)
[epoch : 9] (loss: 0.00016) (accu: 0.9761)
[epoch : 10] (loss: 0.00003) (accu: 0.9748)
[epoch : 11] (loss: 0.00011) (accu: 0.9772)
[epoch : 12] (loss: 0.00002) (accu: 0.9797)
[epoch : 13] (loss: 0.00004) (accu: 0.9739)
[epoch : 14] (loss: 0.00003) (accu: 0.9779)
[epoch : 15] (loss: 0.00004) (accu: 0.9727)
[epoch : 16] (loss: 0.00013) (accu: 0.9784)
[epoch : 17] (loss: 0.00001) (accu: 0.9754)
[epoch : 18] (loss: 0.00005) (accu: 0.9746)
[epoch : 19] (loss: 0.00012) (accu: 0.9752)
[epoch : 20] (loss: 0.00007) (accu: 0.9764)
[epoch : 21] (loss: 0.00009) (accu: 0.9762)
[epoch : 22] (loss: 0.00008) (accu: 0.9758

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00020) (accu: 0.9591)
[epoch : 2] (loss: 0.00018) (accu: 0.9659)
[epoch : 3] (loss: 0.00001) (accu: 0.9702)
[epoch : 4] (loss: 0.00010) (accu: 0.9723)
[epoch : 5] (loss: 0.00027) (accu: 0.9742)
[epoch : 6] (loss: 0.00005) (accu: 0.9733)
[epoch : 7] (loss: 0.00004) (accu: 0.9729)
[epoch : 8] (loss: 0.00008) (accu: 0.9736)
[epoch : 9] (loss: 0.00003) (accu: 0.9750)
[epoch : 10] (loss: 0.00002) (accu: 0.9749)
[epoch : 11] (loss: 0.00005) (accu: 0.9759)
[epoch : 12] (loss: 0.00002) (accu: 0.9774)
[epoch : 13] (loss: 0.00007) (accu: 0.9755)
[epoch : 14] (loss: 0.00009) (accu: 0.9740)
[epoch : 15] (loss: 0.00012) (accu: 0.9737)
[epoch : 16] (loss: 0.00008) (accu: 0.9749)
[epoch : 17] (loss: 0.00005) (accu: 0.9751)
[epoch : 18] (loss: 0.00004) (accu: 0.9778)
[epoch : 19] (loss: 0.00004) (accu: 0.9763)
[epoch : 20] (loss: 0.00002) (accu: 0.9778)
[epoch : 21] (loss: 0.00005) (accu: 0.9780)
[epoch : 22] (loss: 0.00001) (accu: 0.9767

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

[epoch : 0] (loss: x.xxxxx) (accu: 0.0974)
[epoch : 1] (loss: 0.00011) (accu: 0.9534)
[epoch : 2] (loss: 0.00008) (accu: 0.9669)
[epoch : 3] (loss: 0.00015) (accu: 0.9642)
[epoch : 4] (loss: 0.00003) (accu: 0.9747)
[epoch : 5] (loss: 0.00006) (accu: 0.9736)
[epoch : 6] (loss: 0.00012) (accu: 0.9744)
[epoch : 7] (loss: 0.00004) (accu: 0.9749)
[epoch : 8] (loss: 0.00002) (accu: 0.9748)
[epoch : 9] (loss: 0.00003) (accu: 0.9742)
[epoch : 10] (loss: 0.00005) (accu: 0.9760)
[epoch : 11] (loss: 0.00003) (accu: 0.9759)
[epoch : 12] (loss: 0.00003) (accu: 0.9764)
[epoch : 13] (loss: 0.00003) (accu: 0.9787)
[epoch : 14] (loss: 0.00006) (accu: 0.9736)
[epoch : 15] (loss: 0.00005) (accu: 0.9760)
[epoch : 16] (loss: 0.00002) (accu: 0.9743)
[epoch : 17] (loss: 0.00005) (accu: 0.9759)
[epoch : 18] (loss: 0.00003) (accu: 0.9761)
[epoch : 19] (loss: 0.00005) (accu: 0.9780)
[epoch : 20] (loss: 0.00010) (accu: 0.9745)
[epoch : 21] (loss: 0.00002) (accu: 0.9748)
[epoch : 22] (loss: 0.00011) (accu: 0.9742

In [12]:
print(model.fc3.bias)

Parameter containing:
tensor([-0.0220, -0.0270,  0.0028, -0.0519,  0.0015, -0.0087, -0.0189, -0.0410,
         0.1603,  0.0056], device='cuda:1', requires_grad=True)


In [13]:
print("Maximum accuracy per weight remaining")
for i in range(len(best_accu)):
    print("Remaining weight %.1f %% " % (best_accu[i][0] * 100),
         "Epoch %d" % best_accu[i][1],
         "Accu %.4f %%" % best_accu[i][2])

Maximum accuracy per weight remaining
Remaining weight 100.0 %  Epoch 21 Accu 0.9780 %
Remaining weight 80.0 %  Epoch 25 Accu 0.9804 %
Remaining weight 64.0 %  Epoch 3 Accu 0.9784 %
Remaining weight 51.2 %  Epoch 23 Accu 0.9803 %
Remaining weight 41.0 %  Epoch 26 Accu 0.9794 %
Remaining weight 32.8 %  Epoch 11 Accu 0.9809 %
Remaining weight 26.2 %  Epoch 21 Accu 0.9801 %
Remaining weight 21.0 %  Epoch 19 Accu 0.9795 %
Remaining weight 16.8 %  Epoch 11 Accu 0.9797 %
Remaining weight 13.4 %  Epoch 24 Accu 0.9794 %
Remaining weight 10.7 %  Epoch 12 Accu 0.9787 %


In [14]:
print(model.fc3.weight)

Parameter containing:
tensor([[ 4.1834e-12,  0.0000e+00,  5.4876e-01, -5.2419e-39,  3.2088e-01,
          0.0000e+00, -1.9255e-18, -2.0705e-01,  0.0000e+00,  4.2849e-12,
         -8.2884e-02, -2.3371e-01,  6.1075e-04,  0.0000e+00,  0.0000e+00,
         -2.7548e-01, -0.0000e+00,  0.0000e+00, -0.0000e+00, -8.3592e-02,
          0.0000e+00,  4.4344e-02,  2.1923e-01, -2.4551e-41,  0.0000e+00,
          3.5267e-27, -0.0000e+00, -1.1557e-02, -2.4512e-01, -0.0000e+00,
         -0.0000e+00, -3.8251e-33,  0.0000e+00,  2.3943e-01,  7.8822e-03,
          2.8310e-01,  2.8572e-01,  0.0000e+00,  2.5408e-02,  0.0000e+00,
         -3.6546e-01, -0.0000e+00,  2.5015e-01, -2.7091e-02,  4.8094e-02,
         -0.0000e+00, -1.5852e-01, -0.0000e+00,  3.6886e-07,  2.4070e-01,
         -4.2816e-13,  2.0453e-01,  0.0000e+00, -8.7301e-43, -9.6522e-02,
         -0.0000e+00,  5.6883e-08,  0.0000e+00, -1.2275e-01, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2855e-01,  7.4609e-02,
         -0.0000

for name, p in model.named_parameters():
    EPS = 1e-6
    if 'weight' in name:
        tensor = p.data.cpu().numpy()
        grad_tensor = p.grad.data.cpu().numpy()
        grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
        p.grad.data = torch.from_numpy(grad_tensor).to(device)
        print(p.grad.data)

데이터 숫자 60000
배치 길이 60
배치 개수 1000
epoch = 50

이터레이션 횟수 50000