In [1]:
import sys
# Make the parent directory importable — presumably the repo root holding the
# `models` / `utils` packages imported in the next cell.
sys.path.append("..")

In [2]:
# NOTE(review): the next three imports are not used by any visible cell and look like
# editor auto-import accidents. `turtle` imports tkinter (breaks on headless machines) and
# asyncio's child watchers are deprecated since Python 3.12 — consider removing them.
from asyncio import MultiLoopChildWatcher
from doctest import OutputChecker

from turtle import hideturtle
import warnings

from models import GeneralModel
from models.statistics.Metrics import Metrics
# NOTE(review): later cells use names that are never imported explicitly (argparse, os,
# torch, find_right_model, NETWORKS_DIR, ...); presumably they arrive through these star
# imports — explicit imports would make the dependencies visible.
# NOTE(review): this cell printed the Zen of Python, so something imported here likely
# executes `import this` — worth tracking down.
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *

# Silences ALL warnings for the rest of the kernel session (deprecations included);
# consider narrowing the filter to the specific warnings that are noisy.
warnings.filterwarnings("ignore")

The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!


In [3]:
# Define the run configuration by hand instead of parsing a command line.
# `argparse` is imported explicitly: this cell previously relied on the name
# leaking in through one of the `from utils.* import *` lines.
import argparse

arguments = argparse.Namespace()

# --- device ---
arguments.device = "cuda"

# --- model ---
arguments.model = "ResNet18"
arguments.hidden_dim = None
arguments.input_dim = None
arguments.output_dim = 10
# Despite its name, this value is passed to the model as `is_maskable`:
# 0 disables masking, 1 enables (unstructured) masking.
arguments.disable_masking = 1
arguments.track_weights = 0
arguments.enable_rewinding = 0
arguments.growing_rate = 0.0  # the model is built as growable only when this is > 0
arguments.outer_layer_pruning = 1
arguments.prune_criterion = "SNIPit"
arguments.l0 = 0  # when truthy, the optimizer's weight decay is forced to 0
arguments.l0_reg = 1.0
arguments.l1_reg = 0
arguments.lp_reg = 0
arguments.l2_reg = 5e-5
arguments.hoyer_reg = 0.001
arguments.N = 6000  # different for different dataset
arguments.beta_ema = 0.999

# --- pruning criterion ---
arguments.pruning_limit = 0.5
arguments.snip_steps = 6

# --- checkpoint: None for both means no pre-trained weights are loaded ---
arguments.checkpoint_name = None
arguments.checkpoint_model = None

# --- dataset ---
arguments.data_set = "CIFAR10"
arguments.batch_size = 512
# per-channel statistics (one value per colour channel) handed to the data loader
arguments.mean = (0.4914, 0.4822, 0.4465)
arguments.std = (0.2471, 0.2435, 0.2616)
arguments.tuning = 0
arguments.preload_all_data = 0
arguments.random_shuffle_labels = 0

# --- loss ---
arguments.loss = "CrossEntropy"

# --- optimizer ---
arguments.optimizer = "ADAM"
arguments.learning_rate = 2e-3

# --- training ---
arguments.save_freq = 1e6
arguments.eval = 0  # 0 builds a trainer and trains; truthy skips trainer construction
arguments.train_scheme = "DefaultTrainer"
# NOTE(review): no visible cell seeds torch/numpy with this value — confirm the trainer does.
arguments.seed = 1234
arguments.epochs = 10

arguments.grad_noise = 0
arguments.grad_clip = 10
arguments.eval_freq = 1000
arguments.max_training_minutes = 6120
arguments.plot_weights_freq = 10
arguments.prune_delay = 0
arguments.prune_freq = 1
arguments.rewind_to = 6

# --- plotting / logging switches ---
arguments.skip_first_plot = 0
arguments.disable_histograms = 0
arguments.disable_saliency = 0
arguments.disable_confusion = 0
arguments.disable_weightplot = 0
arguments.disable_netplot = 0
arguments.disable_activations = 0

In [4]:
# Route notebook logging through the Metrics logger.
metrics = Metrics()
out = metrics.log_line
# NOTE(review): this shadows the builtin `print` for every later cell, so later
# `print(...)` calls (e.g. inside load_checkpoint and the sparsity cell) go through
# `metrics.log_line`. Kept for compatibility; prefer calling `out(...)` explicitly.
print = out

ensure_current_directory()
out(f"starting at {get_date_stamp()}")

# Copy batch size and evaluation frequency onto the metrics object
# (private attributes — presumably read by Metrics itself).
metrics._batch_size = arguments.batch_size
metrics._eval_freq = arguments.eval_freq

starting at 2022-04-17_19.12.53


In [5]:
# Resolve the device object used by the model, loss and trainer from the config
# (arguments.device is "cuda" in the config cell).
# NOTE(review): whether configure_device falls back to CPU without CUDA is not visible here.
device = configure_device(arguments)

In [6]:
# Instantiate the network named by `arguments.model` (looked up in NETWORKS_DIR),
# then move it onto the selected device.
model: GeneralModel = find_right_model(
    NETWORKS_DIR,
    arguments.model,
    device=device,
    # layer sizes
    hidden_dim=arguments.hidden_dim,
    input_dim=arguments.input_dim,
    output_dim=arguments.output_dim,
    # capability flags — note `disable_masking == 1` means the model IS maskable
    is_maskable=arguments.disable_masking,
    is_tracking_weights=arguments.track_weights,
    is_rewindable=arguments.enable_rewinding,
    is_growable=arguments.growing_rate > 0,
    # outer-layer handling: keep the outer mask only for structured criteria
    # when outer layers themselves are excluded from pruning
    outer_layer_pruning=arguments.outer_layer_pruning,
    maintain_outer_mask_anyway=(
        (not arguments.outer_layer_pruning)
        and ("Structured" in arguments.prune_criterion)
    ),
    # regularisation settings
    l0=arguments.l0,
    l0_reg=arguments.l0_reg,
    N=arguments.N,
    beta_ema=arguments.beta_ema,
    l2_reg=arguments.l2_reg,
)
model = model.to(device)

output_dim:10


In [7]:
# Build the pruning criterion named in the config (e.g. "SNIPit") for this model.
criterion = find_right_model(
    CRITERION_DIR, arguments.prune_criterion,
    model=model,
    limit=arguments.pruning_limit,
    # NOTE(review): hard-coded start value; its meaning depends on the criterion class — confirm.
    start=0.5,
    steps=arguments.snip_steps,
    # Use the device resolved by configure_device, like the model, loss and trainer,
    # rather than the raw "cuda" string from the config.
    device=device
)

In [8]:
def load_checkpoint(arguments, metrics, model):
    """Load pre-trained weights into ``model`` when a checkpoint is configured.

    Does nothing unless both ``arguments.checkpoint_name`` and
    ``arguments.checkpoint_model`` are set.

    Args:
        arguments: config namespace; reads ``checkpoint_name`` and ``checkpoint_model``.
        metrics: unused; kept so existing call sites keep working.
        model: module whose ``load_state_dict`` receives the loaded state.

    Raises:
        KeyError, RuntimeError: re-raised from ``model.load_state_dict`` after the
            checkpoint's keys have been printed to help diagnose the mismatch.
    """
    if arguments.checkpoint_name is None or arguments.checkpoint_model is None:
        return

    path = os.path.join(RESULTS_DIR, arguments.checkpoint_name, MODELS_DIR, arguments.checkpoint_model)
    # NOTE(review): if load_python_obj unpickles, only load checkpoints from trusted sources.
    state = DATA_MANAGER.load_python_obj(path)
    try:
        model.load_state_dict(state)
    except (KeyError, RuntimeError):
        # torch's load_state_dict reports missing/unexpected keys as RuntimeError,
        # so catching only KeyError would skip this diagnostic.
        print(list(state.keys()))
        raise
    out(f"Loaded checkpoint {arguments.checkpoint_name} from {arguments.checkpoint_model}")

# load pre-trained weights if specified
# (a no-op while checkpoint_name / checkpoint_model are None, as set in the config cell)
load_checkpoint(arguments, metrics, model)  

In [9]:
# load data
# Look up the dataset named by `arguments.data_set` in DATASETS and get its
# train/test loaders; mean/std are the per-channel values from the config
# (presumably used for input normalisation).
train_loader, test_loader = find_right_model(
        DATASETS, arguments.data_set,
        arguments=arguments,
        mean=arguments.mean,
        std=arguments.std
    )

Using mean (0.4914, 0.4822, 0.4465)
Files already downloaded and verified
Files already downloaded and verified


In [10]:
# get loss function
# Build the loss named by `arguments.loss` with the regularisation coefficients
# from the config.
# NOTE(review): l0_reg is forwarded even though arguments.l0 is 0 in the config —
# presumably the loss ignores it when L0 is off; confirm.
loss = find_right_model(
        LOSS_DIR, arguments.loss,
        device=device,
        l1_reg=arguments.l1_reg,
        lp_reg=arguments.lp_reg,
        l0_reg=arguments.l0_reg,
        hoyer_reg=arguments.hoyer_reg
    )

In [11]:
# Build the optimizer named in the config over all model parameters.
# Weight decay carries the L2 strength only while L0 regularisation is off.
# NOTE(review): l2_reg is also handed to the model constructor — check that L2
# is not applied twice (once in the model, once here as weight decay).
optimizer = find_right_model(
    OPTIMS,
    arguments.optimizer,
    params=model.parameters(),
    lr=arguments.learning_rate,
    weight_decay=0 if arguments.l0 else arguments.l2_reg,
)


In [12]:
if not arguments.eval:
    # Assemble a run identifier from the key hyper-parameters; judging by the
    # saved result paths, the trainer appends it to a datestamp.
    run_name = "".join([
        f"_model={arguments.model}",
        f"_dataset={arguments.data_set}",
        f"_prune-criterion={arguments.prune_criterion}",
        f"_pruning-limit={arguments.pruning_limit}",
        f"_train-scheme={arguments.train_scheme}",
        f"_seed={arguments.seed}",
    ])

    # Build the trainer named by `arguments.train_scheme`, wiring in every
    # component created by the previous cells.
    trainer = find_right_model(
        TRAINERS_DIR,
        arguments.train_scheme,
        model=model,
        loss=loss,
        optimizer=optimizer,
        device=device,
        arguments=arguments,
        train_loader=train_loader,
        test_loader=test_loader,
        metrics=metrics,
        criterion=criterion,
        run_name=run_name,
    )

Made datestamp: 2022-04-17_19.12.58_model=ResNet18_dataset=CIFAR10_prune-criterion=SNIPit_pruning-limit=0.5_train-scheme=DefaultTrainer_seed=1234


In [13]:
# Train only when a trainer was built: the previous cell skips trainer
# construction when `arguments.eval` is set, so an unconditional call would
# raise NameError in eval mode.
if not arguments.eval:
    trainer.train()

[1mStarted training[0m
Saved results/2022-04-17_19.12.58_model=ResNet18_dataset=CIFAR10_prune-criterion=SNIPit_pruning-limit=0.5_train-scheme=DefaultTrainer_seed=1234/output/scores
1728.0
pruning 8 percentage 0.004629629629629629 length_nonzero 1728.0
36864.0
pruning 7491 percentage 0.20320638020833334 length_nonzero 36864.0
36864.0
pruning 7249 percentage 0.19664171006944445 length_nonzero 36864.0
36864.0
pruning 6850 percentage 0.1858181423611111 length_nonzero 36864.0
36864.0
pruning 7047 percentage 0.191162109375 length_nonzero 36864.0
73728.0
pruning 10710 percentage 0.145263671875 length_nonzero 73728.0
147456.0
pruning 44523 percentage 0.30194091796875 length_nonzero 147456.0
8192.0
pruning 291 percentage 0.0355224609375 length_nonzero 8192.0
147456.0
pruning 62982 percentage 0.4271240234375 length_nonzero 147456.0
147456.0
pruning 51859 percentage 0.3516913519965278 length_nonzero 147456.0
294912.0
pruning 81770 percentage 0.2772691514756944 length_nonzero 294912.0
589824.0
p

Saved results/2022-04-17_19.12.58_model=ResNet18_dataset=CIFAR10_prune-criterion=SNIPit_pruning-limit=0.5_train-scheme=DefaultTrainer_seed=1234/output/scores
1720.0
pruning 8 percentage 0.004629629629629629 length_nonzero 1728.0
29373.0
pruning 7491 percentage 0.20320638020833334 length_nonzero 36864.0
29615.0
pruning 7249 percentage 0.19664171006944445 length_nonzero 36864.0
30014.0
pruning 6850 percentage 0.1858181423611111 length_nonzero 36864.0
29817.0
pruning 7047 percentage 0.191162109375 length_nonzero 36864.0
63018.0
pruning 10710 percentage 0.145263671875 length_nonzero 73728.0
102933.0
pruning 44523 percentage 0.30194091796875 length_nonzero 147456.0
7901.0
pruning 291 percentage 0.0355224609375 length_nonzero 8192.0
84474.0
pruning 62982 percentage 0.4271240234375 length_nonzero 147456.0
95597.0
pruning 51859 percentage 0.3516913519965278 length_nonzero 147456.0
213142.0
pruning 81770 percentage 0.2772691514756944 length_nonzero 294912.0
363052.0
pruning 226772 percentage 0.

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.7480469  |  0.6875496   |  0.9487832  | 0.6893382  |    0.0000000    |   0.0000000   |  0.0000000  |       19.0602786       |   198.3574891   |       567743488.0       |      14.1545036     
$ |  cuda/ram_footprint  |  time/batch_time  |  
$ |     231605248.0      |     0.0259371     |
Training... 97/98



[1mEPOCH 6 [0m 




Training... 0/98

Evaluating... 19/20

$  acc/train  |  loss/train  |  loss/test  |  acc/test  |  sparse/weight  |  sparse/node  |  sparse/hm  |  sparse/log_disk_size  |  time/gpu_time  |  time/flops_per_sample  |  time/flops_log_cum 
$  0.8164062  |  0.5420224   |  0.7510330  | 0.7513500  |    0.0000000    |   0.0000000   |  0.0000000  |       19.0602786       |   199.7883278   |       567743488.0       |      14.2335374     
$ |  cuda/ram_foo

<Figure size 432x288 with 0 Axes>

In [48]:
# Global sparsity of the pruning masks: fraction of mask entries that are exactly zero.
zero = 0
non_zero = 0
for weight in model.mask.values():
    zero += torch.sum(weight == 0).item()
    non_zero += torch.sum(weight != 0).item()

total = zero + non_zero
# Guard against an empty mask dict instead of dividing by zero.
sparsity = zero / total if total else 0.0
print(f"mask sparsity: {sparsity:.4f} ({zero} of {total} mask entries are zero)")

tensor(8, device='cuda:0')
tensor(7491, device='cuda:0')
tensor(7249, device='cuda:0')
tensor(6850, device='cuda:0')
tensor(7047, device='cuda:0')
tensor(10710, device='cuda:0')
tensor(44523, device='cuda:0')
tensor(291, device='cuda:0')
tensor(62982, device='cuda:0')
tensor(51859, device='cuda:0')
tensor(81770, device='cuda:0')
tensor(226772, device='cuda:0')
tensor(3072, device='cuda:0')
tensor(304163, device='cuda:0')
tensor(257919, device='cuda:0')
tensor(487975, device='cuda:0')
tensor(1190977, device='cuda:0')
tensor(18489, device='cuda:0')
tensor(1575536, device='cuda:0')
tensor(1248050, device='cuda:0')
tensor(35, device='cuda:0')


In [25]:
# NOTE(review): `parameter` is not defined by any visible cell — it leaked from a
# deleted/unsaved cell, so this fails on a fresh kernel (Restart & Run All).
# Define it explicitly (e.g. from model.named_parameters()) or delete this debugging cell.
parameter==0

tensor([[[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        [[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        [[[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]],

         [[False, False, False],
          [False, False, False],
          [False, False, False]]],


        ...,


        [[[False, False, False],
          [False, False, False],
    