In [1]:
import sys

# Make the parent directory importable (presumably where the project's
# `models` / `utils` packages live - they are imported in the next cell).
# Guarded so re-running this cell doesn't keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from asyncio import MultiLoopChildWatcher
from doctest import OutputChecker

from turtle import hideturtle
import warnings

from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *

warnings.filterwarnings("ignore")

The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!


In [3]:
# Show the PyTorch version string this notebook was run with.
getattr(torch, "__version__")

'1.11.0+cu113'

In [18]:
# Build the run configuration by hand as an argparse.Namespace instead of
# parsing command-line flags. Several helpers below receive the whole
# namespace (`arguments=arguments`), so keep every field even if it looks unused.
arguments = argparse.Namespace(
    # --- device ---
    device="cuda",
    # --- model ---
    # model="ResNet18",  # ResNet not supported for structured
    model="LeNet5",
    hidden_dim=None,
    # input_dim=None,  # for ResNet
    input_dim=(1, 1, 1),  # for LeNet5
    output_dim=10,
    # 0 for disable mask, 1 for mask (unstructured).
    # NOTE(review): misleading name - the value is forwarded to the model as
    # `is_maskable`, so a truthy value *enables* masking.
    disable_masking=1,
    track_weights=0,
    enable_rewinding=0,
    growing_rate=0.0,
    outer_layer_pruning=1,
    # prune_criterion="SNIPit",  # unstructured
    prune_criterion="SNAPitDuring",  # or SNAPit ... # structured
    l0=0,
    l0_reg=1.0,
    l1_reg=0,
    lp_reg=0,
    l2_reg=5e-5,
    hoyer_reg=0.001,
    # Different for different datasets. NOTE(review): MNIST has 60000 training
    # images - confirm 6000 is intended (forwarded to the model as `N`).
    N=6000,
    beta_ema=0.999,
    # --- pruning criterion ---
    pruning_limit=0.7,
    snip_steps=6,
    # --- not a pre-trained model ---
    checkpoint_name=None,
    checkpoint_model=None,
    # --- dataset ---
    data_set="MNIST",
    batch_size=512,
    mean=(0.1307,),
    std=(0.3081,),
    tuning=0,
    preload_all_data=0,
    random_shuffle_labels=0,
    # --- loss ---
    loss="CrossEntropy",
    # --- optimizer ---
    optimizer="ADAM",
    learning_rate=2e-3,
    # --- training ---
    save_freq=1e6,
    eval=0,
    train_scheme="DefaultTrainer",
    seed=1234,
    epochs=10,
    grad_noise=0,
    grad_clip=10,
    eval_freq=1000,
    max_training_minutes=6120,
    plot_weights_freq=50,
    prune_delay=0,
    prune_freq=1,
    rewind_to=6,
    # --- plotting / logging switches ---
    skip_first_plot=0,
    disable_histograms=0,
    disable_saliency=0,
    disable_confusion=0,
    disable_weightplot=0,
    disable_netplot=0,
    disable_activations=0,
    # --- pruning during training ---
    # NOTE(review): `prune_freq` above and `pruning_freq` here look like
    # duplicates; which one the trainer reads isn't visible - keep both.
    pruning_rate=0,
    pruning_freq=1,
)

In [19]:
metrics = Metrics()

# Send notebook logging through the metrics logger. `print` is rebound as
# well, so every later `print(...)` in this notebook goes through
# `metrics.log_line` (this shadows the builtin; `del print` restores it).
out = metrics.log_line
print = out

# NOTE(review): project helper whose effect isn't visible here - presumably it
# fixes the working directory; confirm.
ensure_current_directory()
out(f"starting at {get_date_stamp()}")

metrics._batch_size = arguments.batch_size
metrics._eval_freq = arguments.eval_freq

starting at 2022-04-18_01.10.59


In [20]:
device = configure_device(arguments)

# Apply the configured seed before the model and data loaders are created;
# otherwise `arguments.seed` only ends up in the run name in the visible cells.
# torch.manual_seed seeds the RNGs of all devices (CPU and CUDA).
# NOTE(review): if configure_device already seeds, re-seeding with the same
# value is harmless. Python's `random` and NumPy are not seeded here - add
# them if the project relies on those RNGs.
torch.manual_seed(arguments.seed)

In [21]:
# Get the model: the architecture named by `arguments.model`, looked up in
# NETWORKS_DIR, built with the settings below and moved to `device`.
model: GeneralModel = find_right_model(
    NETWORKS_DIR,
    arguments.model,
    device=device,
    hidden_dim=arguments.hidden_dim,
    input_dim=arguments.input_dim,
    output_dim=arguments.output_dim,
    # NOTE(review): despite its name, `disable_masking` is passed straight
    # through as `is_maskable`, so a truthy value enables masking.
    is_maskable=arguments.disable_masking,
    is_tracking_weights=arguments.track_weights,
    is_rewindable=arguments.enable_rewinding,
    is_growable=arguments.growing_rate > 0,
    outer_layer_pruning=arguments.outer_layer_pruning,
    # Only true when outer layers are NOT pruned and the criterion name
    # contains "Structured".
    maintain_outer_mask_anyway=(
        not arguments.outer_layer_pruning
        and "Structured" in arguments.prune_criterion
    ),
    l0=arguments.l0,
    l0_reg=arguments.l0_reg,
    N=arguments.N,
    beta_ema=arguments.beta_ema,
    l2_reg=arguments.l2_reg,
).to(device)

In [22]:
# Show the shape of every parameter tensor, in named_parameters() order.
for param_name, param in model.named_parameters():
    print(param.shape)

torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([6])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([120, 16, 5, 5])
torch.Size([120])
torch.Size([120])
torch.Size([120])
torch.Size([84, 1080])
torch.Size([84])
torch.Size([84])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])


In [23]:
# Get the pruning criterion named by `arguments.prune_criterion`.
# Pass the `device` returned by configure_device - as the model, loss and
# trainer cells do - instead of the raw `arguments.device` string, so every
# component agrees on the device actually selected.
criterion = find_right_model(
    CRITERION_DIR,
    arguments.prune_criterion,
    model=model,
    limit=arguments.pruning_limit,
    start=0.5,  # NOTE(review): hard-coded; meaning defined by the criterion - confirm
    steps=arguments.snip_steps,
    device=device,
)

In [24]:
def load_checkpoint(arguments, metrics, model):
    """Load pre-trained weights into ``model`` if a checkpoint is configured.

    Nothing happens unless both ``arguments.checkpoint_name`` and
    ``arguments.checkpoint_model`` are set. The state is read from
    ``RESULTS_DIR/<checkpoint_name>/MODELS_DIR/<checkpoint_model>`` via
    ``DATA_MANAGER.load_python_obj`` and applied with ``model.load_state_dict``.

    Args:
        arguments: namespace providing ``checkpoint_name`` and ``checkpoint_model``.
        metrics: unused; kept so existing calls keep working.
        model: object whose ``load_state_dict`` receives the loaded state.

    Returns:
        None.

    Raises:
        KeyError, RuntimeError: re-raised from ``model.load_state_dict`` after
            the keys of the loaded state have been printed to help debugging.
    """
    if arguments.checkpoint_name is None or arguments.checkpoint_model is None:
        return
    path = os.path.join(RESULTS_DIR, arguments.checkpoint_name, MODELS_DIR, arguments.checkpoint_model)
    # NOTE(review): if load_python_obj unpickles, only load checkpoints you trust.
    state = DATA_MANAGER.load_python_obj(path)
    try:
        model.load_state_dict(state)
    except (KeyError, RuntimeError):
        # torch's strict load_state_dict reports missing/unexpected keys as a
        # RuntimeError, so catching only KeyError never showed the key dump.
        print(list(state.keys()))
        raise
    out(f"Loaded checkpoint {arguments.checkpoint_name} from {arguments.checkpoint_model}")

# Load pre-trained weights; a no-op unless both checkpoint_name and
# checkpoint_model are set in `arguments`.
load_checkpoint(arguments=arguments, metrics=metrics, model=model)

In [25]:
# Load data: the dataset named by `arguments.data_set` yields a
# (train, test) pair of loaders, normalised with the configured mean/std.
train_loader, test_loader = find_right_model(
    DATASETS,
    arguments.data_set,
    arguments=arguments,
    mean=arguments.mean,
    std=arguments.std,
)

Using mean (0.1307,)


In [26]:
# Get the loss function named by `arguments.loss`, passing the
# regularisation strengths from the configuration.
loss = find_right_model(
    LOSS_DIR,
    arguments.loss,
    device=device,
    l1_reg=arguments.l1_reg,
    lp_reg=arguments.lp_reg,
    l0_reg=arguments.l0_reg,
    hoyer_reg=arguments.hoyer_reg,
)

In [27]:
# Get the optimizer over all model parameters. Weight decay is l2_reg unless
# `arguments.l0` is set, in which case it is switched off.
# NOTE(review): l2_reg is also passed to the model constructor; confirm the
# L2 penalty isn't applied twice.
optimizer = find_right_model(
    OPTIMS,
    arguments.optimizer,
    params=model.parameters(),
    lr=arguments.learning_rate,
    weight_decay=0 if arguments.l0 else arguments.l2_reg,
)


In [28]:
if not arguments.eval:
    # Build the trainer. `run_name` encodes the main settings; it appears in
    # the "Made datestamp" line and in the results directory name.
    # NOTE: `trainer` is only defined on this branch - guard any later use
    # with the same condition.
    run_name = (
        f'_model={arguments.model}_dataset={arguments.data_set}_prune-criterion={arguments.prune_criterion}'
        f'_pruning-limit={arguments.pruning_limit}_train-scheme={arguments.train_scheme}_seed={arguments.seed}'
    )
    trainer = find_right_model(
        TRAINERS_DIR,
        arguments.train_scheme,
        model=model,
        loss=loss,
        optimizer=optimizer,
        device=device,
        arguments=arguments,
        train_loader=train_loader,
        test_loader=test_loader,
        metrics=metrics,
        criterion=criterion,
        run_name=run_name,
    )

Made datestamp: 2022-04-18_01.11.01_model=LeNet5_dataset=MNIST_prune-criterion=SNAPitDuring_pruning-limit=0.7_train-scheme=DefaultTrainer_seed=1234


In [29]:
# Train (and prune) the model. `trainer` is only built when `arguments.eval`
# is falsy, so guard the call the same way instead of hitting a NameError.
if not arguments.eval:
    trainer.train()
else:
    out("arguments.eval is set: no trainer was built, skipping training")

[1mStarted training[0m


[1mEPOCH 0 [0m 




Training... 0/118

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.0976562  |  2.5814126   |  2.5531989  | 0.1848920  |    0.0000000    |   0.0000000   |  0.0000000  |       15.3924092       |    1.9441691    |     666174.0000000      |      8.5328576      
$ |  cuda/ram_footprint  |  time/batch_time  |  
$ |      2931200.0       |     0.0054764     |
Training... 117/118

plotting..
finished plotting


[1mEPOCH 1 [0m 




Training... 0/118

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.9824219  |  0.0649967   |  0.0715947  | 0.9786420  |    0.0000000    |   0.0000000   |  0.0000000  |       1



[1mEPOCH 5 [0m 




Training... 0/118

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.9746094  |  0.0982891   |  0.0525089  | 0.9874368  |    0.8842809    |   0.6769911   |  0.9330162  |       13.2360599       |   10.8410203    |     501787.0000000      |      11.2176924     
$ |  cuda/ram_footprint  |  time/batch_time  |  
$ |      1571840.0       |     0.0083905     |
Training... 117/118


PRUNING...

Saved results/2022-04-18_01.11.01_model=LeNet5_dataset=MNIST_prune-criterion=SNAPitDuring_pruning-limit=0.7_train-scheme=DefaultTrainer_seed=1234/output/scores
set to zero but not removed because of input-output compatibility: 0 (0.0 features)
trimming nodes in layer conv.0.weight from 6 to 6
pruning 0 percentage 0.0 length_nonzero 150
trimming nodes in layer conv.4.weight from 14 to 14
pruning 0 percentage 0.0



[1mEPOCH 9 [0m 




Training... 0/118

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.9863281  |  0.0393877   |  0.0387464  | 0.9890740  |    0.9001166    |   0.6991150   |  0.9425009  |       13.0889307       |   11.2752195    |     499413.0000000      |      11.4562303     
$ |  cuda/ram_footprint  |  time/batch_time  |  
$ |      1525760.0       |     0.0083773     |
Training... 117/118


PRUNING...

finished all pruning events already


<Figure size 432x288 with 0 Axes>

In [17]:
# Show the parameter shapes again to see what pruning removed.
# NOTE(review): run this after the training cell; a saved output with a lower
# execution count than the training cell comes from an earlier kernel run.
for param_name, param in model.named_parameters():
    print(param.shape)

torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([3])
torch.Size([3])
torch.Size([8, 3, 5, 5])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([6, 8, 5, 5])
torch.Size([6])
torch.Size([6])
torch.Size([6])
torch.Size([7, 54])
torch.Size([7])
torch.Size([7])
torch.Size([7])
torch.Size([10, 7])
torch.Size([10])


In [16]:
# Fraction of zero entries across all tensors in `model.mask` (sparsity
# measure for unstructured pruning).
# NOTE(review): with structured pruning (as configured here) pruned nodes are
# trimmed from the tensors themselves, so this figure likely understates
# the overall sparsity - compare the parameter shapes instead.
zero = 0
non_zero = 0
for mask_tensor in model.mask.values():
    zero += torch.sum(mask_tensor == 0).item()
    non_zero += torch.sum(mask_tensor != 0).item()

total = zero + non_zero
# An empty mask dict would otherwise raise ZeroDivisionError.
print(zero / total if total else float("nan"))

0.006026689625484288
