In [1]:
import pickle

In [2]:
import numpy as np

In [3]:
import copy

In [4]:
import torch

In [5]:
from collections import OrderedDict

In [6]:
def picklify(path, obj):
    """Pickle ``obj`` and write the resulting bytes to the file at ``path``.

    The file is opened in binary write mode, so an existing file is overwritten.
    """
    with open(path, 'wb') as sink:
        sink.write(pickle.dumps(obj))


def unpicklify(path):
    """Open ``path`` in binary read mode and return the object pickled inside it."""
    with open(path, 'rb') as source:
        restored = pickle.load(source)
    return restored

In [4]:
# Load a previously pickled mask; the next cell iterates over it and prints one 2-D shape per entry.
mask = unpicklify('./dumps/lt/fc5/mnist/lt_mask_26.4.pkl')

In [9]:
# Show the shape of every array stored in the mask.
for layer_mask in mask:
    print(layer_mask.shape)

(256, 784)
(256, 256)
(256, 256)
(256, 256)
(10, 256)


In [58]:
# This checkpoint holds a whole pickled model object (its repr is displayed in the next cell),
# not just a state dict.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# presumably rejects pickled modules like this one — confirm against the installed version.
init_model = torch.load('./saves/fc5/mnist/initial_state_dict_lt.pth.tar')

In [59]:
init_model

fc5(
  (classifier): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [60]:
init_model.classifier[0].weight

Parameter containing:
tensor([[ 0.0082, -0.0284,  0.0101,  ...,  0.0559, -0.0141, -0.0623],
        [ 0.0074,  0.0714,  0.0041,  ..., -0.0214, -0.0116, -0.0516],
        [ 0.0163,  0.0478,  0.0005,  ..., -0.0298,  0.0102,  0.0261],
        ...,
        [-0.0312, -0.0247,  0.0187,  ..., -0.0133, -0.0078, -0.0139],
        [ 0.0156,  0.0035, -0.0214,  ..., -0.0427,  0.0868, -0.0070],
        [ 0.0508,  0.0040, -0.1378,  ..., -0.0457, -0.0197,  0.0103]],
       device='cuda:0', requires_grad=True)

In [61]:
# Extract the model's parameters as a state dict.
init_sd = init_model.state_dict()

In [62]:
# Persist the initial state dict; the random-pruning section later reloads this exact file.
torch.save(init_sd, 'mlp5_mnist_init_state_dict.pth.tar')

In [None]:
# Re-key the initial state dict for the Synaptic-Flow repo: drop the leading 'classifier.'
# component and shift every layer index up by one ('classifier.0.weight' -> '1.weight').
# The loop reads `init_sd`; the original iterated over `sd`, which no cell in this notebook
# defines, so the cell raised NameError on a fresh kernel.
new_sd = OrderedDict()
for k, v in init_sd.items():
    i, name = k.split('.')[1:]
    i = int(i)
    new_name = f'{i+1}.{name}'
    new_sd[new_name] = v

In [None]:
# Save the re-keyed state dict built in the previous cell.
torch.save(new_sd, 'mlp5_mnist_init_state_dict_for_synflow_repo.pth.tar')

In [25]:
def _load_pickled_model(path):
    """Load a checkpoint file that stores a whole pickled model object.

    NOTE: full unpickling can execute arbitrary code, so only use this on
    checkpoints you produced yourself.
    """
    try:
        # Newer PyTorch defaults to weights_only=True, which refuses pickled modules.
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the ``weights_only`` keyword.
        return torch.load(path)


def combine_weight_mask(dst_path, weight_src, mask_src, transpose=False):
    """Mask one model's parameters with another model's sparsity pattern and save them.

    Both checkpoints are loaded as model objects and their ``named_parameters()``
    are walked in parallel. Every parameter of the weight model is multiplied by a
    0/1 mask that is 1 wherever the corresponding parameter of the mask model is
    non-zero. The result is written to ``dst_path`` as an ``OrderedDict`` keyed by
    parameter name, and the shape of each saved tensor is printed.

    Args:
        dst_path: Output path handed to ``torch.save``.
        weight_src: Checkpoint providing the parameter values.
        mask_src: Checkpoint whose non-zero entries define what is kept.
        transpose: If true, parameters whose name contains ``'weight'`` are saved
            with dims 0 and 1 swapped.

    Raises:
        ValueError: If the two models have a different number of parameters.
        AssertionError: If parameter names differ at the same position.
    """
    weight_params = list(_load_pickled_model(weight_src).named_parameters())
    mask_params = list(_load_pickled_model(mask_src).named_parameters())
    # zip() would silently drop trailing parameters of the longer model.
    if len(weight_params) != len(mask_params):
        raise ValueError(
            f'parameter count mismatch: {len(weight_params)} vs {len(mask_params)}'
        )

    new_sd = OrderedDict()
    # Without no_grad the products are autograd outputs and the saved tensors
    # come back from torch.load with requires_grad=True.
    with torch.no_grad():
        for (k_w, v_w), (k_m, v_m) in zip(weight_params, mask_params):
            assert k_w == k_m, f'parameter order mismatch: {k_w!r} vs {k_m!r}'
            keep = (v_m != 0.0).to(dtype=v_w.dtype, device=v_w.device)
            masked_w = v_w * keep
            if transpose and 'weight' in k_w:
                masked_w = masked_w.transpose(0, 1)
            new_sd[k_w] = masked_w
            print(masked_w.shape)
    torch.save(new_sd, dst_path)

In [26]:
# Keep the initial weights only at positions where the 9_model_lt checkpoint's parameters are non-zero.
combine_weight_mask("mlp5_mnist_wt_untrained.pth.tar", './saves/fc5/mnist/initial_state_dict_lt.pth.tar', './saves/fc5/mnist/9_model_lt.pth.tar')

torch.Size([256, 784])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [67]:
# NOTE(review): `sd` is not defined in any visible cell — presumably a state dict left in the
# kernel by a removed cell; verify before re-running.
sd['classifier.0.weight'].dtype

torch.float32

In [11]:
def combine_weight_mask_synflow(dst_path, weight_src, mask_src, transpose=False):
    """Apply the masks stored in one checkpoint to the weights of another and save.

    Both files are loaded with ``torch.load`` and treated as key -> tensor mappings.
    Keys ending in ``'mask'`` are dropped from the output. A key ending in
    ``'weight'`` is multiplied by the ``<key>_mask`` entry of ``mask_src`` (and
    transposed over dims 0/1 when ``transpose`` is true); every other entry is
    copied from ``weight_src`` unchanged. The shape of each written tensor is
    printed and the resulting ``OrderedDict`` is saved to ``dst_path``.
    """
    weights = torch.load(weight_src)
    masks = torch.load(mask_src)
    combined = OrderedDict()
    for key, tensor in weights.items():
        # Every key of the weight checkpoint must also exist in the mask checkpoint.
        assert key in masks.keys()
        if key.endswith('mask'):
            continue
        if not key.endswith('weight'):
            combined[key] = tensor
        else:
            pruned = tensor * masks[key + '_mask']
            combined[key] = pruned.transpose(0, 1) if transpose else pruned
        print(combined[key].shape)
    torch.save(combined, dst_path)

In [42]:
# mlp5_mag run: mask the weights from init_model.pt with the masks stored in post-train_model.pt.
# NOTE(review): the sources are absolute, machine-specific paths.
combine_weight_mask_synflow(
    'mlp5_mnist_mag_pai_untrained.pth.tar', '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_mag/singleshot/0/init_model.pt',
    '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_mag/singleshot/0/post-train_model.pt'
)

torch.Size([256, 784])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [121]:
# mlp5_synflow run: mask the weights from init_model.pt with the masks stored in post-train_model.pt.
combine_weight_mask_synflow(
    'mlp5_mnist_synflow_pai_untrained.pth.tar', '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_synflow/singleshot/0/init_model.pt',
    '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_synflow/singleshot/0/post-train_model.pt'
)

torch.Size([256, 784])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [122]:
# mlp5_grasp run: mask the weights from init_model.pt with the masks stored in post-train_model.pt.
combine_weight_mask_synflow(
    'mlp5_mnist_grasp_pai_untrained.pth.tar', '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_grasp/singleshot/0/init_model.pt',
    '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_grasp/singleshot/0/post-train_model.pt'
)

torch.Size([256, 784])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


In [123]:
# mlp5_snip run: mask the weights from init_model.pt with the masks stored in post-train_model.pt.
combine_weight_mask_synflow(
    'mlp5_mnist_snip_pai_untrained.pth.tar', '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_snip/singleshot/0/init_model.pt',
    '/datadrive_c/xiaohan/network-dl/Synaptic-Flow/mnist/mlp5_snip/singleshot/0/post-train_model.pt'
)

torch.Size([256, 784])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([256, 256])
torch.Size([256])
torch.Size([10, 256])
torch.Size([10])


## Random Pruning

In [None]:
# Reload the initial state dict that an earlier cell saved to this file. The original cell
# was an unfinished assignment (`init_sd = `) and therefore a syntax error.
init_sd = torch.load('./mlp5_mnist_init_state_dict.pth.tar')

In [27]:
# Load the masked initial weights that the combine_weight_mask call wrote to this file.
wt_untrained_sd = torch.load('./mlp5_mnist_wt_untrained.pth.tar')

In [39]:
# Report, for every weight tensor, the percentage of entries that are exactly zero.
for key, tensor in wt_untrained_sd.items():
    if 'weight' not in key:
        continue
    values = tensor.detach().cpu().numpy()
    zero_pct = np.sum(values == 0.0) / values.size * 100
    print(key, zero_pct)

classifier.0.weight 86.5787428252551
classifier.2.weight 86.57989501953125
classifier.4.weight 86.57989501953125
classifier.6.weight 86.57989501953125
classifier.8.weight 86.6015625


In [33]:
# Collect every weight tensor as a transposed (dims 0/1 swapped) NumPy array and pickle the list.
wt_untrained_pckl = []
for key, tensor in wt_untrained_sd.items():
    if 'weight' not in key:
        continue
    print(key)
    transposed = tensor.transpose(0, 1)
    wt_untrained_pckl.append(transposed.detach().cpu().numpy())
picklify('random_pruning_exp_pckl_files/mlp5_mnist_wt_untrained.pckl', wt_untrained_pckl)

classifier.0.weight
classifier.2.weight
classifier.4.weight
classifier.6.weight
classifier.8.weight


In [49]:
# Load the masked weights that the mlp5_mag combine_weight_mask_synflow call wrote to this file.
mag_pai_untrained_sd = torch.load('./mlp5_mnist_mag_pai_untrained.pth.tar')

In [45]:
# Report, for every weight tensor, the percentage of entries that are exactly zero.
for key, tensor in mag_pai_untrained_sd.items():
    if 'weight' not in key:
        continue
    values = tensor.detach().cpu().numpy()
    zero_pct = np.sum(values == 0.0) / values.size * 100
    print(key, zero_pct)

1.weight 86.57774633290816
3.weight 86.57684326171875
5.weight 86.57684326171875
7.weight 86.57684326171875
9.weight 86.5625


In [46]:
# Collect every weight tensor as a transposed (dims 0/1 swapped) NumPy array and pickle the list.
mag_pai_untrained_pckl = []
for key, tensor in mag_pai_untrained_sd.items():
    if 'weight' not in key:
        continue
    print(key)
    transposed = tensor.transpose(0, 1)
    mag_pai_untrained_pckl.append(transposed.detach().cpu().numpy())
picklify('random_pruning_exp_pckl_files/mlp5_mnist_mag_pai_untrained.pckl', mag_pai_untrained_pckl)

1.weight
3.weight
5.weight
7.weight
9.weight


In [109]:
# Re-key '<idx>.<name>' entries as 'classifier.<idx-1>.<name>', move each tensor to
# device 0, and save the resulting state dict.
new_mag_pai_untrained_sd = OrderedDict()
for synflow_key, tensor in mag_pai_untrained_sd.items():
    i, name = synflow_key.split('.')
    new_name = f'classifier.{int(i)-1}.{name}'
    print(new_name)
    new_mag_pai_untrained_sd[new_name] = tensor.to(0)
torch.save(new_mag_pai_untrained_sd, './random_pruning_exp_state_dict_files/mlp5_mnist_mag_pai_untrained.pth.tar')

classifier.0.weight
classifier.0.bias
classifier.2.weight
classifier.2.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias
classifier.8.weight
classifier.8.bias


In [126]:
# NOTE(review): no visible cell writes a '..._from_synflow.pth.tar' file; the mlp5_synflow
# call saved 'mlp5_mnist_synflow_pai_untrained.pth.tar' — confirm this is the intended input.
synflow_pai_untrained_sd = torch.load('./mlp5_mnist_synflow_pai_untrained_from_synflow.pth.tar')

# 1) Percentage of exactly-zero entries per weight tensor.
for key, tensor in synflow_pai_untrained_sd.items():
    if 'weight' not in key:
        continue
    values = tensor.detach().cpu().numpy()
    zero_pct = np.sum(values == 0.0) / values.size * 100
    print(key, zero_pct)

# 2) Transposed NumPy arrays of the weights, pickled as a list.
synflow_pai_untrained_pckl = []
for key, tensor in synflow_pai_untrained_sd.items():
    if 'weight' not in key:
        continue
    print(key)
    transposed = tensor.transpose(0, 1)
    synflow_pai_untrained_pckl.append(transposed.detach().cpu().numpy())
picklify('random_pruning_exp_pckl_files/mlp5_mnist_synflow_pai_untrained.pckl', synflow_pai_untrained_pckl)

# 3) Re-key '<idx>.<name>' as 'classifier.<idx-1>.<name>', move to device 0, and save.
new_synflow_pai_untrained_sd = OrderedDict()
for key, tensor in synflow_pai_untrained_sd.items():
    i, name = key.split('.')
    new_name = f'classifier.{int(i)-1}.{name}'
    print(new_name)
    new_synflow_pai_untrained_sd[new_name] = tensor.to(0)
torch.save(new_synflow_pai_untrained_sd, './random_pruning_exp_state_dict_files/mlp5_mnist_synflow_pai_untrained.pth.tar')

1.weight 86.57774633290816
3.weight 86.57684326171875
5.weight 86.57684326171875
7.weight 86.57684326171875
9.weight 86.5625
1.weight
3.weight
5.weight
7.weight
9.weight
classifier.0.weight
classifier.0.bias
classifier.2.weight
classifier.2.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias
classifier.8.weight
classifier.8.bias


### Random Pruning Part

In [7]:
# Dense initial weights: the state dict saved earlier in this notebook.
init_sd = torch.load('./mlp5_mnist_init_state_dict.pth.tar')

In [17]:
# Sparse reference: the masked initial weights written by combine_weight_mask.
sparse_sd = torch.load('./mlp5_mnist_wt_untrained.pth.tar')

In [26]:
sparse_sd['classifier.0.weight']

tensor([[ 0.0000, -0.0000,  0.0000,  ...,  0.0559, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.0868, -0.0000],
        [ 0.0000,  0.0000, -0.1378,  ..., -0.0000, -0.0000,  0.0000]],
       device='cuda:0', requires_grad=True)

In [27]:
init_sd['classifier.0.weight']

tensor([[ 0.0082, -0.0284,  0.0101,  ...,  0.0559, -0.0141, -0.0623],
        [ 0.0074,  0.0714,  0.0041,  ..., -0.0214, -0.0116, -0.0516],
        [ 0.0163,  0.0478,  0.0005,  ..., -0.0298,  0.0102,  0.0261],
        ...,
        [-0.0312, -0.0247,  0.0187,  ..., -0.0133, -0.0078, -0.0139],
        [ 0.0156,  0.0035, -0.0214,  ..., -0.0427,  0.0868, -0.0070],
        [ 0.0508,  0.0040, -0.1378,  ..., -0.0457, -0.0197,  0.0103]],
       device='cuda:0')

In [37]:
# Load one alpha=0.0 result for inspection.
# NOTE(review): the alpha loop in this notebook currently iterates over 0.1..0.9 only, so this
# file presumably comes from an earlier run with the commented-out alpha grid — confirm.
alpha0_sd = torch.load('./random_pruning_alpha_exp_state_dict_files/mlp5_mnist_rp_alpha0.0_seed1_untrained.pth.tar')

In [38]:
alpha0_sd['classifier.0.weight']

tensor([[ 0.0000, -0.0000,  0.0000,  ...,  0.0559, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.0868, -0.0000],
        [ 0.0000,  0.0000, -0.1378,  ..., -0.0000, -0.0000,  0.0000]],
       device='cuda:0')

In [8]:
init_sd.keys()

odict_keys(['classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias', 'classifier.8.weight', 'classifier.8.bias'])

In [10]:
def random_prune_mlp_from_init_sd(init_sd, pruning_ratio, seed=42):
    """Randomly zero a fixed fraction of every weight tensor in a state dict.

    For each entry whose key contains ``'weight'``, ``int(numel * pruning_ratio)``
    positions are chosen uniformly at random and set to zero. All other entries
    are passed through unchanged (the same tensor objects).

    Args:
        init_sd: Mapping from parameter name to tensor.
        pruning_ratio: Fraction of each weight tensor to zero, in ``[0, 1]``.
        seed: Seed for the random permutations; the same seed always yields the
            same masks.

    Returns:
        ``(new_sd, new_pckl)``: an ``OrderedDict`` with the pruned weights, and a
        list with one NumPy array per pruned weight, dims 0 and 1 swapped.

    Raises:
        ValueError: If ``pruning_ratio`` lies outside ``[0, 1]``.
    """
    # A negative ratio would otherwise slice from the end and prune almost everything.
    if not 0.0 <= pruning_ratio <= 1.0:
        raise ValueError(f'pruning_ratio must be in [0, 1], got {pruning_ratio!r}')

    # A private legacy generator draws the same numbers np.random.seed(seed) would,
    # without resetting NumPy's process-wide random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    new_sd = OrderedDict()
    new_pckl = []
    for k, v in init_sd.items():
        if 'weight' not in k:
            new_sd[k] = v
            continue
        numel = v.numel()
        num_pruned = int(numel * pruning_ratio)
        mask = np.ones(numel, dtype=np.float32)
        mask[rng.permutation(numel)[:num_pruned]] = 0.0
        mask_t = v.new_tensor(mask).reshape_as(v)
        # detach() keeps autograd state out of the tensors that get saved.
        masked_v = v.detach() * mask_t
        new_pckl.append(masked_v.transpose(0, 1).cpu().numpy())
        new_sd[k] = masked_v
    return new_sd, new_pckl

In [40]:
def random_prune_mlp_alpha_from_init_sd(init_sd, sparse_sd, pruning_ratio, alpha=1.0, seed=42):
    """Randomly prune weights, biased towards positions that are zero in ``sparse_sd``.

    For each entry of ``init_sd`` whose key contains ``'weight'``, every element
    gets a uniform random score in ``[0, 1)``; scores at positions where the
    matching ``sparse_sd`` tensor is exactly zero are multiplied by ``alpha``.
    Elements whose score is less than or equal to the ``pruning_ratio`` quantile
    of all scores in that tensor are set to zero. Other entries are passed
    through unchanged (the same tensor objects).

    Args:
        init_sd: Mapping from parameter name to the dense tensor to prune.
        sparse_sd: Mapping with the same weight keys; its zeros mark the
            positions whose scores are scaled.
        pruning_ratio: Quantile passed to ``np.quantile``.
        alpha: Score multiplier for the marked positions. With ``1.0`` the scores
            are left untouched; with ``0.0`` the marked positions all score zero.
        seed: Seed for the random scores; the same seed always yields the same masks.

    Returns:
        ``(new_sd, new_pckl)``: an ``OrderedDict`` with the pruned weights, and a
        list with one NumPy array per pruned weight, dims 0 and 1 swapped.

    Raises:
        AssertionError: If a weight key is missing from ``sparse_sd`` or the two
            tensors differ in shape.
    """
    # A private legacy generator draws the same numbers np.random.seed(seed) would,
    # without resetting NumPy's process-wide random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    new_sd = OrderedDict()
    new_pckl = []
    for k, v in init_sd.items():
        if 'weight' not in k:
            new_sd[k] = v
            continue
        assert k in sparse_sd

        v_init_np = v.detach().cpu().numpy()
        v_sparse_np = sparse_sd[k].detach().cpu().numpy()
        # The boolean indexing below needs identical shapes, not merely equal sizes.
        assert v_init_np.shape == v_sparse_np.shape

        score = rng.rand(*v_init_np.shape)
        score[v_sparse_np == 0.0] *= alpha

        threshold = np.quantile(score, pruning_ratio)
        mask = np.where(score <= threshold, 0.0, 1.0).astype(np.float32)

        # detach() keeps autograd state out of the tensors that get saved.
        masked_v = v.detach() * v.new_tensor(mask)
        new_pckl.append(masked_v.transpose(0, 1).cpu().numpy())
        new_sd[k] = masked_v
    return new_sd, new_pckl

In [11]:
# Quick check of how np.percentile behaves on a small array.
a = np.array([[10, 7, 4], [3, 2, 1]])
np.percentile(a, 20)

2.0

In [100]:
# 0.8657774633290816 matches the first-layer zero fraction (86.5777...%) printed for the
# magnitude-pruned state dict in the "Random Pruning" section.
rp_sd, rp_pckl = random_prune_mlp_from_init_sd(init_sd, 0.8657774633290816)

In [93]:
# Report, for every weight tensor, the percentage of entries that are exactly zero.
for key, tensor in rp_sd.items():
    if 'weight' not in key:
        continue
    values = tensor.detach().cpu().numpy()
    zero_pct = np.sum(values == 0.0) / values.size * 100
    print(key, zero_pct)

classifier.0.weight 86.57774633290816
classifier.2.weight 86.57684326171875
classifier.4.weight 86.57684326171875
classifier.6.weight 86.57684326171875
classifier.8.weight 86.5625


In [102]:
# Print the shape of every exported array. A named loop variable is used because `_`
# conventionally marks an unused value, and assigning to `_` in IPython shadows the
# automatic "last output" variable.
for weight_array in rp_pckl:
    print(weight_array.shape)

(784, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 10)


In [103]:
# Produce one randomly pruned initialisation per seed (1..40) and store each both as a
# state dict and as a pickled list of arrays.
for seed in range(1, 41):
    sd_path = f'./random_pruning_exp_state_dict_files/mlp5_mnist_rp_seed{seed}_untrained.pth.tar'
    pckl_path = f'./random_pruning_exp_pckl_files/mlp5_mnist_rp_seed{seed}_untrained.pckl'
    rp_sd, rp_pckl = random_prune_mlp_from_init_sd(init_sd, 0.8657774633290816, seed=seed)
    torch.save(rp_sd, sd_path)
    picklify(pckl_path, rp_pckl)

In [41]:
# Sweep alpha for seeds 1..5 and store every result both as a state dict and as a
# pickled list of arrays.
# Alternative alpha grid kept for reference: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
for seed in range(1, 6):
    for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
        sd_path = f'./random_pruning_alpha_exp_state_dict_files/mlp5_mnist_rp_alpha{alpha}_seed{seed}_untrained.pth.tar'
        pckl_path = f'./random_pruning_alpha_exp_pckl_files/mlp5_mnist_rp_alpha{alpha}_seed{seed}_untrained.pckl'
        rp_sd, rp_pckl = random_prune_mlp_alpha_from_init_sd(
            init_sd, sparse_sd, 0.8657774633290816, alpha=alpha, seed=seed
        )
        torch.save(rp_sd, sd_path)
        picklify(pckl_path, rp_pckl)

## Collect Random Pruning Experimental Results

In [43]:
import os

In [44]:
import glob

In [57]:
# Gather every entry directly under the two experiment output folders.
exp_dirs = glob.glob('random_pruning_exp_outputs/*') + glob.glob('random_pruning_alpha_exp_outputs/*')

In [58]:
exp_dirs

['random_pruning_exp_outputs/mlp5_mnist_rp_seed4_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed31_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed19_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed3_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed12_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed18_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed15_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed10_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed34_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed33_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed40_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_grasp_pai_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed38_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed24_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed21_untrained',
 'random_pruning_exp_outputs/mlp5_mnist_rp_seed27_untrained',
 'random_p

In [59]:
len(exp_dirs)

100

In [52]:
# Maps experiment name -> best accuracy; filled by the loop over exp_dirs.
random_pruning_acc_results = dict()

In [60]:
# Store the first element of each run's pickled 'lt_bestaccuracy.dat' under the run
# directory's base name. A missing or unreadable dump raises, so incomplete runs are noticed.
for exp_dir in exp_dirs:
    dump_path = os.path.join(exp_dir, 'dumps', 'lt_bestaccuracy.dat')
    random_pruning_acc_results[os.path.basename(exp_dir)] = unpicklify(dump_path)[0]

In [61]:
# Persist the collected accuracies as a single pickle.
picklify('random_pruning_exp_accuracy_results.pckl', random_pruning_acc_results)

In [None]:
print(random_pruning_acc_results.keys())