In [2]:
import os
os.chdir("/workspaces/BAIT")
from copy import copy
from src.experiment import Experiment, get_instance_id, load_base_model_weights
from src.phase_actions import run_phase
from utils.params import ModelParams, DataParams, TrainingPhase, PruningPhase, IterativePruningPhase
import re
from datetime import datetime
from utils.monitor import MONITOR
import numpy as np
import pandas as pd
import torch
import seaborn as sns
from analysis.results_utils import *

algo = "mag"
sparsity = 0.2
base_model_id = "47ee372c-ce4a-11ec-8633-0242ac120002"

base_experiment = Experiment()
base_experiment.repeats = 1 # single repeat but there will be 3 container instances resulting in 3 repeats

base_experiment.model_params = ModelParams(model="lenet_300_100", init_strategy="synflow")
base_experiment.data_params = DataParams(dataset="mnist")
expr = copy(base_experiment)
expr.pruning_phases = [
    IterativePruningPhase(
        prune_params = PruningPhase(
            strategy=algo,
            sparsity=sparsity,
            prune_epochs=2
        ),
        train_params = TrainingPhase(
            train_epochs=2
        ),
        iterations = 2,
        rewind = True
    )
]
expr.name = f"oneshot_{algo}_{sparsity}"
time = re.sub("[-_\. :]", "", str(datetime.now()))
expr.full_name = f"{expr.name}_{get_instance_id()}_{time}"
MONITOR.start(expr.full_name)
s = State(expr.model_params, expr.data_params, expr)
load_base_model_weights(base_model_id, s.model)
s.base_model_id = base_model_id
s.bake_initial_state()

run_phase(s, expr.pruning_phases[0])

Loading mnist dataset.
Creating lottery-lenet_300_100 model.
iterative pruning with mag until 0.2 sparsity in 2 steps with max 2 epochs of training
starting with 100.0% weights remaining
callback bake_rewind_state at iteration 0
regular training for 2 epochs


100%|██████████| 2/2 [00:05<00:00,  2.96s/it]
100%|██████████| 2/2 [00:00<00:00, 556.64it/s]


rewinding state at iteration 1876
pruned to 60.0%
regular training for 2 epochs


100%|██████████| 2/2 [00:05<00:00,  2.67s/it]
100%|██████████| 2/2 [00:00<00:00, 555.83it/s]


pruned to 20.0%
regular training for 2 epochs


100%|██████████| 2/2 [00:05<00:00,  2.80s/it]


In [8]:
torch.rand(4).round_() == 0

tensor([False, False,  True,  True])