In [None]:
from collections import OrderedDict
import os
from datetime import datetime
from copy import deepcopy
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

from scripts.data import get_train_and_test_loaders
from scripts.models import *
from scripts.visualizations import *
from scripts.pruning import *
from scripts.activation_patterns import *
from scripts.stats import *
from scripts.activation_regions import *
from scripts.lp_decision_trees import *
from scripts.utils import *

In [None]:
# Show the installed packages whose name contains "torch" so the library versions of this run are in the output.
!pip list | grep torch

In [None]:
# Single seed for the whole notebook. torch is seeded from it further down (before the data loaders are
# built); numpy's global generator is seeded here as well so that any helper drawing from np.random
# (e.g. when `classes = None` asks for randomized classes) is reproducible too.
random_seed = 1
np.random.seed(random_seed)

In [None]:
# Use the GPU when one is visible. DataLoader worker/pinned-memory options are only passed with CUDA.
# NOTE(review): `device` is only printed in this notebook; models/batches are never moved to it here —
# confirm whether the scripts.* helpers do that.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    kwargs = dict(num_workers=1, pin_memory=True)
else:
    device = torch.device("cpu")
    kwargs = dict()
print(f'device: {device}')

## Hyperparams
The settings you actually change between runs.

In [None]:
# Collects every setting of this run; it is dumped to metadata.json in the results/checkpoint folders when saving.
hyperparams = {}

In [None]:
notebook_name = 'sparse-vs-dense'
test_name = 'example'

# Identifiers of this run; notebook_name and test_name also appear in the results/checkpoint folder names.
hyperparams['misc'] = dict(
    notebook_name=notebook_name,
    test_name=test_name,
    random_seed=random_seed,
)

### Saving

In [None]:
save_models = True # write the initial and trained models into the run's checkpoint folder
load_models = False # load models from an earlier checkpoint folder instead of training them
# NOTE(review): this flag was originally marked "not implemented", but a loading path exists further down
# (sdicts / models_init / models_trained) — confirm its status before relying on it.

save_data = True # write metadata, figures and training results into the results folder
compute_acc = True # evaluate train/test loss and accuracy during training
compute_stats = False # to compute specialization and AR-data as well

In [None]:
# Loading checkpoints and saving new ones in the same run is contradictory, so refuse that combination.
# Testing truthiness (instead of `save_models is not load_models`) also rejects non-bool truthy pairs such as 1/True.
assert not (save_models and load_models), 'you cannot load and save models at the same time'

In [None]:
# Which experiments run and whether their results are written to disk.
hyperparams['experiments'] = dict(
    compute_stats=compute_stats,
    compute_acc=compute_acc,
    save_results=save_data,
)

### Data

In [None]:
# Optional input normalization; the three sub-options only take effect when normalize_data is True.
normalize_data = False
normalize_by_features = True # doesn't matter if normalize_data == False

normalize_by_moving = True # doesn't matter if normalize_data == False
normalize_by_scaling = True # doesn't matter if normalize_data == False

batch_size = 60 # training mini-batch size (one optimizer step per batch in train())
test_batch_size = 1000 # evaluation batch size; also the reference count for the dark-neuron limit below

# which dataset to use: 'digit' or 'fashion'
data_name = 'fashion'

In [None]:
# Root data directory handed to get_train_and_test_loaders() (presumably where the dataset is downloaded/cached).
training_data_dir = './mnist/data'

In [None]:
# Everything needed to rebuild the data loaders of this run.
hyperparams['data'] = dict(
    normalize_data=normalize_data,
    normalize_by_features=normalize_by_features,
    normalize_by_moving=normalize_by_moving,
    normalize_by_scaling=normalize_by_scaling,
    batch_size=batch_size,
    test_batch_size=test_batch_size,
    data_name=data_name,
    data_dir=training_data_dir,
)

### Models

In [None]:
# Bounds for the hidden-layer widths; get_sizes() below turns them into n_sizes network sizes
# (presumably spaced according to size_scheme — TODO confirm in the helper).
max_hid_neurons = 400
min_hid_neurons = 16

n_sizes = 4
size_scheme = 'log'

# Densities passed to the model-count / dense-params / sparse-setup helpers below
# (presumably the weight densities used for the smallest network — verify).
smallest_net_densities = [1,0.5]

In [None]:
# Network sizes to compare, computed from the bounds above.
sizes = get_sizes(n_sizes, max_hid_neurons, min_hid_neurons, size_scheme)
print(sizes)

In [None]:
iterations = 50000  # optimizer steps per model

bias_std = 10**-6  # handed to init_models_sparse_vs_dense (presumably the std of the bias initialization)

# Number of independently initialized models per architecture; entries with 0 are left out of
# model_architectures further down.
n_models = {'lenet': 5, 'deepfc': 0, 'pipefc': 0}

In [None]:
# Total number of models in this run; used for the progress counter and in the results/checkpoint folder names.
models_total =  count_models_n_total(n_sizes, smallest_net_densities, n_models)
print(f'models in total {models_total}')

In [None]:
# Learning rate of the Adam optimizer created in train().
lr = 1.2e-3

In [None]:
# Hidden-layer widths derived from the largest width (presumably keyed by architecture name — verify).
hidden_dims_dict = get_hidden_dims(max_hid_neurons)

In [None]:
# Iterations after which train() evaluates the models (presumably a list of ints; train() also accepts a single int step).
evaluation_scheme = get_eval_scheme(iterations)
print(evaluation_scheme, len(evaluation_scheme))

In [None]:
# Everything that defines the model family and its training budget.
hyperparams['models'] = dict(
    sizes=sizes,
    iterations=iterations,
    bias_std=bias_std,
    max_hid_neurons=max_hid_neurons,
    smallest_net_densities=smallest_net_densities,
    hidden_dims_dict=hidden_dims_dict,
    n_models=n_models,
    models_total=models_total,
    learning_rate=lr,
    evaluation_scheme=evaluation_scheme,
)

### Pruning

In [None]:
prune_weights = True # True: unstructured (weight) pruning; False -> prune nodes (structured)
layer_wise_pruning = True # False -> global pruning (not implemented for pytorch pruning)

prune_all_layers = False # should the weights of the output layer be pruned as well?

random_mask = True # does the randomly initialized network get a random mask as well?

# depends on how the pruning rates are defined: 'linear' or 'logit'
# NOTE(review): xscale is not used elsewhere in this notebook and is not recorded in hyperparams.
xscale = 'linear'

In [None]:
# Pruning settings of this run (xscale is not recorded).
hyperparams['pruning'] = dict(
    prune_weights=prune_weights,
    layer_wise_pruning=layer_wise_pruning,
    prune_all_layers=prune_all_layers,
    random_mask=random_mask,
)

### Statistics

In [None]:
# Counting activation regions (ARs)

## 2D
classes = [0,1,9] # shirts, pants, shoes
classes = None # to randomize classes -- overrides the line above; remove this line to use the fixed classes

# How large a fraction of one class's samples should be covered? (several coverage levels)
class_coverage = [0.8,0.95,0.99,1]

use_three_samples = True # True: 2D plane spanned by three images; False: by the origin and two images
average_over_images = 5 # number of spanning image groups the 2D statistics are computed for (presumably averaged)

# Mining layer patterns (LPs) with decision trees
lp_samples = 10000 # number of samples used to record layer patterns

# Dark neurons
dm_limit_perc = 0.01
# NOTE(review): dark_mask_limit is only printed in this notebook; presumably consumed by helper code.
dark_mask_limit = int(dm_limit_perc*test_batch_size)
print(f'dark mask flags neurons with n <= {dark_mask_limit} images (out of {test_batch_size})')

In [None]:
# Settings of the AR / specialization statistics, stored next to the experiment flags.
hyperparams['experiments'].update(
    classes=classes,
    class_coverage=class_coverage,
    use_three_samples=use_three_samples,
    average_over_images=average_over_images,
    lp_samples=lp_samples,
)

In [None]:
# Display all collected settings of this run.
hyperparams

### Init variables

In [None]:
# Architectures that are actually trained (at least one model requested).
model_architectures = [arc for arc, count in n_models.items() if count > 0]

print(model_architectures)

In [None]:
# Result containers filled by the training cell, each keyed by architecture name.
data_from_training = OrderedDict()
ar_data_from_training = OrderedDict()
spec_data_from_training = OrderedDict()

# `data` bundles the three containers under fixed keys; this is the object written to data.pt.
data = OrderedDict([
    ('acc_training', data_from_training),
    ('ar_2d_training', ar_data_from_training),
    ('spec_training', spec_data_from_training),
])

## MNIST / Fashion-MNIST data

In [None]:
# Seed torch before the data loaders and the models are created, so shuffling and initialization are reproducible.
torch.manual_seed(random_seed)

In [None]:
# Build the train/test DataLoaders for `data_name` with the normalization options chosen above.
train_loader, test_loader = get_train_and_test_loaders(training_data_dir,
                                                       data_name, batch_size,
                                                       test_batch_size, normalize_data,
                                                       normalize_by_features, normalize_by_moving,
                                                       normalize_by_scaling, kwargs)

# number of training images (first axis of the raw dataset); NOTE(review): not used further in this notebook
train_samples = train_loader.dataset.data.shape[0]
print(f'train data {train_loader.dataset.data.shape} on device {train_loader.dataset.data.device}')
print(test_loader.dataset.data.shape)

In [None]:
# Model input size = number of values per sample, i.e. the product of all non-batch dimensions of the raw
# data: 28*28 for MNIST-style (N, 28, 28) data, and still correct for data that carries a channel axis.
input_features = int(np.prod(train_loader.dataset.data.shape[1:]))
output_dim = len(train_loader.dataset.classes)  # one output unit per class

## Helpers for saving

In [None]:
# Parent folder of all model checkpoints; each run gets its own sub-folder (checkpoint_dir).
models_dir = './models/'

In [None]:
# Run timestamp (yymmdd-HHMMSS), used in the folder names and as the first metadata line.
ts = datetime.now().strftime('%y%m%d-%H%M%S')

In [None]:
# One results folder per run, named after dataset, timestamp, notebook, test and model count.
folder_name = f'{data_name}-{ts}-{notebook_name}-{test_name}-{models_total}_models/'
saving_folder = f'results/{folder_name}'
if save_data and not os.path.isdir(saving_folder):
    # makedirs also creates the parent `results/` folder, which plain os.mkdir fails on in a fresh checkout.
    os.makedirs(saving_folder, exist_ok=True)
    print('created folder for saving figures', saving_folder)


In [None]:
# models_dir itself is always created because the loading section lists its sub-folders;
# this run's checkpoint sub-folder only exists when models are saved.
if not os.path.isdir(models_dir):
    os.mkdir(models_dir)

checkpoint_dir = f'{models_dir}{folder_name}'
if save_models:
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    print(f'save models to {checkpoint_dir}')
else:
    print("don't save models")

#### Images that span the subspaces 

In [None]:
if not compute_stats:
    print('no need for spanning images')
else:
    # `average_over_images` groups of example images; each group (presumably) spans one 2D subspace
    # used for counting activation regions. n_example_sets is read by train() as a global.
    n_example_sets, example_img_labels = get_n_spanning_image_groups(average_over_images, test_loader, classes)
    for subspace_idx, (example_imgs, labels) in enumerate(zip(n_example_sets, example_img_labels)):
        visualize_example_images(example_imgs, labels,
                                 classes_str=test_loader.dataset.classes,
                                 horizontal=True)
        if save_data:
            plt.savefig(f'{saving_folder}spanning_images_subspace_{subspace_idx}.png')

    if save_data:
        torch.save(example_img_labels, f'{saving_folder}spanning_img_labels.pt')

## Loading models

In [None]:
# List the checkpoint folders found in models_dir; the one to load is picked by its index via `dir_ind` below.
model_folders = get_folders(models_dir)
for i, folder in enumerate(model_folders):
    print(i, folder)

### Choose which models to load

In [None]:
dir_ind = 0  # index into model_folders (listed above) of the checkpoint folder to load
# Disabled on purpose (`and False`): enable it to override the pruning settings with the values
# that were used when the models to be loaded were trained.
if load_models and False:
    pruning_new_version = True
    prune_weights = True
    layer_wise_pruning = False

if load_models:
    if not model_folders:
        raise FileNotFoundError(f'no model folders found in {models_dir}')
    path_to_models = models_dir + model_folders[dir_ind]

    sdict_files = get_pt_files_in_dir(path_to_models)
    print(f'loading models from\n\t{model_folders[dir_ind]}')
    # Results of a loading run go to saving_folder; checkpoint_dir is only written when save_models is
    # set, which cannot be combined with load_models.
    if save_data:
        print(f'saving to\n\t{saving_folder}')
else:
    print('dont load models')

In [None]:
# Loaded checkpoint objects keyed by file stem, e.g. 'hyperparams', 'models_init', 'models_trained'.
sdicts = {}

if load_models:
    for file_name in sdict_files:
        # splitext copes with names containing several dots (or none), which `split('.')` unpacking does not.
        name, ending = os.path.splitext(file_name)
        if ending != '.pt':
            continue
        file_path = os.path.join(path_to_models, file_name)
        # These files presumably hold pickled model objects (whole `models` structures are torch.save'd),
        # so they need full unpickling. Only load checkpoints you trust.
        try:
            sdicts[name] = torch.load(file_path, weights_only=False)
        except TypeError:  # torch < 1.13 does not know the weights_only argument
            sdicts[name] = torch.load(file_path)

## Save the metadata to a file

In [None]:
# Sentinel meaning "use the notebook-level variable of the same name" (keeps the original call working).
_USE_NOTEBOOK_VALUE = object()


def get_metadata_strings(prune_weights, layer_wise_pruning, bias_std, random_mask, prune_all_layers,
                        normalize_data, normalize_by_features, normalize_by_moving, normalize_by_scaling,
                        use_three_samples, lp_samples, class_coverage,
                        load_models=_USE_NOTEBOOK_VALUE, classes=_USE_NOTEBOOK_VALUE):
    '''Return human-readable lines describing the run configuration (written to metadata.txt).

    Every setting is passed in explicitly. `load_models` and `classes` are optional keyword arguments;
    when omitted they fall back to the notebook globals of the same name.

    Returns:
        list of str; the normalization lines are appended last (indented with a tab when normalizing).
    '''
    if load_models is _USE_NOTEBOOK_VALUE:
        load_models = globals()['load_models']
    if classes is _USE_NOTEBOOK_VALUE:
        classes = globals()['classes']

    metadata = [
        'loaded models' if load_models else 'fresh models',
        'unstructured pruning' if prune_weights else 'structured pruning',
        'pruning by layer' if layer_wise_pruning else 'pruning globally',
        f'bias std {bias_std}',
        'random_mask' if random_mask else 'pruned_mask',
        'prune all layers' if prune_all_layers else 'dont prune the output layer',
        '2D plane is spun by three images' if use_three_samples else '2D plane is spun by origin and two images',
        f'use {lp_samples} samples to get layer patterns for specialization and dtree',
        # class_coverage holds fractions (e.g. 0.95), so it is not labelled with a '%'.
        f'specialization blanket sample coverage (fraction of each class): {class_coverage}',
        f'classes for the AR & specialization: {classes}',
    ]
    if normalize_data:
        metadata += [
            'normalize data',
            '\tby features' if normalize_by_features else '\tas whole',
            '\tmove data' if normalize_by_moving else '\tdont move the data',
            '\tscale data' if normalize_by_scaling else '\tdont scale the data',
        ]
    else:
        metadata.append('dont normalize the data')

    return metadata

In [None]:
# Human-readable description of this run; its lines are written to metadata.txt below.
metadata = get_metadata_strings(prune_weights, layer_wise_pruning, bias_std, random_mask, prune_all_layers,
                        normalize_data, normalize_by_features, normalize_by_moving, normalize_by_scaling,
                        use_three_samples, lp_samples, class_coverage)

In [None]:
def write_metadata_to_file(fo):
    '''Write the run's metadata as text lines to the open, writable text file `fo`.

    Reads the notebook globals ts, data_name, notebook_name, test_name and metadata.
    '''
    header = [ts, 'dataset: ' + data_name, 'notebook: ' + notebook_name, 'test: ' + test_name]
    for line in header + list(metadata):
        fo.write(line + '\n')


def _save_metadata(folder):
    '''Write metadata.txt and metadata.json into `folder` (a path ending in a separator).'''
    with open(folder + 'metadata.txt', 'w') as fo:
        write_metadata_to_file(fo)
    with open(folder + 'metadata.json', 'w') as fo:
        # default=str keeps the dump from failing on values json cannot encode natively
        # (e.g. numpy scalars, should one of the helpers return them).
        json.dump(hyperparams, fo, default=str)


# The same two files go to the results folder and/or the checkpoint folder.
if save_data:
    _save_metadata(saving_folder)

if save_models:
    _save_metadata(checkpoint_dir)

## A simple dense network - Lenet

Using the same architecture as Zhou et al. (https://arxiv.org/pdf/1905.01067.pdf) and Frankle & Carbin (https://arxiv.org/pdf/1803.03635.pdf), i.e., LeNet from LeCun et al. 1998.

| FC layers | 300, 100, 10 |
|:--|:--|
| Weights | 266k  |
| Iterations | 50k |
| Batch size | 60 |
| initialization | normal distribution std=0.1 |
| Optimizer | Adam 1.2e-3 |
| Pruning Rate (for iterative) | fc 20% |
| Loss Function | Cross Entropy Loss |

## A slightly deeper dense network

| FC layers | 200, 100, 100, 10 |
|:--|:--|
| Weights | 188k  |

## Pipefc (layers have shared width)

| FC layers | 100, 100, 100, 10 |
|:--|:--|
| Weights | 100k  |

## Set up sparse networks with the same number of parameters

In [None]:
# Layer dimensions for every architecture and network size (presumably keyed by architecture — verify in the helper).
arc_dims = get_architecture_dims(hidden_dims_dict, sizes, output_dim)

In [None]:
# Inspect the architecture dimensions.
arc_dims

In [None]:
# Parameter counts of the dense networks (presumably the budgets the sparse networks are matched to).
dense_params = get_dense_params(arc_dims, input_features, smallest_net_densities)
dense_params

In [None]:
if load_models:
    # Reuse the setups stored with the checkpoint (hyperparams.pt).
    hyperparams_for_sparse = sdicts['hyperparams']
else:
    # Derive the sparse/dense setups that share the dense parameter budgets.
    hyperparams_for_sparse = get_hyperparams_for_sparses(dense_params,
                                                         arc_dims,
                                                         n_models,
                                                         smallest_net_densities=smallest_net_densities,
                                                         count_output=prune_all_layers,
                                                         input_dim=input_features)

In [None]:
# Summarize the model setups.
print_hyperparams(hyperparams_for_sparse)

In [None]:
# For every architecture, the last entry of each setup's first element
# (presumably the parameter-count key used to index the models).
param_keys = OrderedDict()
for name, setups in hyperparams_for_sparse.items():
    keys = []
    for hyp in setups:
        keys.append(hyp[0][-1])
    param_keys[name] = keys

param_keys

In [None]:
# Plot the model setups.
draw_hyperparam_setup(hyperparams_for_sparse)

## Init the models

In [None]:
if load_models:
    # Untrained weights as stored in the checkpoint (models_init.pt).
    models = sdicts['models_init']
else:
    # Fresh models for every setup in hyperparams_for_sparse.
    models = init_models_sparse_vs_dense(hyperparams_for_sparse, n_models, input_features, output_dim, bias_std,
                                         random_mask, prune_all_layers, prune_weights)

In [None]:
# When models come from a checkpoint, recount n_models from them: the size of the first model
# group of the first parameter setup of each architecture.
if load_models:
    n_models = {}
    for name, models_by_params in models.items():
        first_setup = next(iter(models_by_params.values()), None)
        if first_setup is not None:
            n_models[name] = len(first_setup[0])
    print(f'n_models: {n_models}')

In [None]:
# Overview of the initialized models.
print_model_information(models)

## Save the original weights

In [None]:
# Keep the untrained models and the setups they came from so the run can be reloaded later.
if save_models:
    for artifact, artifact_file in ((hyperparams_for_sparse, 'hyperparams.pt'), (models, 'models_init.pt')):
        torch.save(artifact, checkpoint_dir + artifact_file)
    print('models saved')

## Train the networks

In [None]:
# Loss used by train() for both the optimization steps and the evaluations.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
def print_progress(train_log, val_log):
    '''Print the most recent training and validation entries.

    Each log is a sequence of entries whose first two items are printed as loss and accuracy (in %).
    An empty log (nothing evaluated yet) is reported as such instead of raising IndexError.
    '''
    for label, log in (('training', train_log), ('validation', val_log)):
        if not log:
            print(f'\t{label} \tno evaluation recorded')
            continue
        latest = log[-1]
        print(f'\t{label} \tloss {latest[0]:.3f} acc {latest[1]:.2f}%')

In [None]:
def compute_stats_in_middle_of_training(model, test_loader, lp_samples, coverage, example_image_sets,
                                        average_over_images, use_three_samples=True, n_classes=10):
    '''Compute activation-region (AR) and specialization statistics for `model` at its current weights.

    Args:
        model: the network being trained.
        test_loader: data used to record layer patterns.
        lp_samples: number of samples used to record layer patterns.
        coverage: class coverage level(s) forwarded to record_lps_and_max_patterns.
        example_image_sets: image groups that span the 2D subspaces.
        average_over_images: forwarded to compute_2D_ARs_and_specialization_for_subspaces.
        use_three_samples: True spans each plane by three images, False by the origin and two images.
        n_classes: number of classes in the data.

    Returns:
        (ars, specs) as returned by compute_2D_ARs_and_specialization_for_subspaces.
    '''
    with torch.no_grad():
        _, _, max_patterns = record_lps_and_max_patterns(model,
                                                         test_loader,
                                                         n_samples=lp_samples,
                                                         n_classes=n_classes,
                                                         coverage=coverage)
        # Forward the caller's choice instead of always using three samples.
        ars, specs = compute_2D_ARs_and_specialization_for_subspaces(model,
                                                                    example_image_sets=example_image_sets,
                                                                    average_over_images=average_over_images,
                                                                    use_three_samples=use_three_samples,
                                                                    max_patterns=max_patterns)
    return ars, specs


In [None]:
def train(model, iterations, lr, evaluation_scheme, average_over_images=5, print_every_n_iteration=100, compute_stats=False, compute_acc=True, verbose=True):
    '''Train `model` in place with Adam on the notebook's `train_loader` for `iterations` optimizer steps.

    Uses the notebook globals train_loader, test_loader, loss_fn and evaluate, and, when compute_stats is
    set, lp_samples, class_coverage, n_example_sets and use_three_samples.

    Args:
        model: network to train; each batch is flattened to (batch, features) before the forward pass.
        iterations: number of optimizer steps (mini-batches); the loader is re-iterated as often as needed.
        lr: Adam learning rate.
        evaluation_scheme: iterations after which to evaluate / compute statistics, either a collection
            of ints or a single int step (evaluate every `evaluation_scheme` iterations).
        average_over_images: forwarded to compute_stats_in_middle_of_training.
        print_every_n_iteration: how often the in-place iteration counter is refreshed when verbose.
        compute_stats: also compute AR / specialization statistics at every evaluation point.
        compute_acc: log `evaluate` results on the train and test loaders at every evaluation point.
        verbose: print progress.

    An evaluation is also recorded before the first step.

    Returns:
        (logs, ar_data, spec_data): logs = np.dstack((train_log, val_log)), i.e. one row per evaluation
        with the train and test entries of `evaluate` stacked on the last axis (presumably (loss, accuracy));
        ar_data / spec_data are np.array of the collected statistics (empty when compute_stats is False).

    Raises:
        ValueError: if train_loader yields no batches (training could never make progress).
    '''
    if isinstance(evaluation_scheme, int):
        # An int is a step size: evaluate after step, 2*step, ... up to and including `iterations`.
        evaluation_scheme = range(evaluation_scheme, iterations + 1, evaluation_scheme)
    eval_points = set(evaluation_scheme)  # O(1) membership test on every training step

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    val_log, train_log, ar_data, spec_data = [], [], [], []

    def record(step):
        '''Log accuracy and/or statistics for the current weights after `step` iterations.'''
        with torch.no_grad():
            if compute_acc:
                train_log.append(evaluate(model, train_loader, loss_fn))
                val_log.append(evaluate(model, test_loader, loss_fn))
            if verbose:
                plur = '' if step == 1 else 's'
                print(f' after {step} iteration{plur}')
                if compute_acc:  # the logs are empty otherwise
                    print_progress(train_log, val_log)
            if compute_stats:
                ars, specs = compute_stats_in_middle_of_training(model, test_loader, lp_samples, class_coverage,
                                                                 n_example_sets, average_over_images,
                                                                 use_three_samples=use_three_samples)
                ar_data.append(ars)
                spec_data.append(specs)
        # `evaluate` may switch the model to eval mode; make sure training continues in train mode.
        model.train()

    record(0)  # state before any training

    step = 0
    while step < iterations:
        batches_in_epoch = 0
        for inputs, target in train_loader:
            if step >= iterations:
                break
            batches_in_epoch += 1

            if verbose and step % print_every_n_iteration == 0:
                print(f'iteration {step}/{iterations}', end='\r')

            inputs = torch.flatten(inputs, start_dim=1)  # (batch, ...) -> (batch, features)
            loss = loss_fn(model(inputs), target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1  # one more iteration trained

            if step in eval_points and (compute_stats or compute_acc):
                record(step)

        if batches_in_epoch == 0:
            # Without this guard an empty loader would spin in the while loop forever.
            raise ValueError('train_loader yielded no batches')

    return np.dstack((train_log, val_log)), np.array(ar_data), np.array(spec_data)

In [None]:
%%time

verbose = True
k = 0
if not load_models:
    for name in models:
        print(name)
        models_p = models[name]
        td, regions, specializations = OrderedDict(), OrderedDict(), OrderedDict()

        for i, params in enumerate(list(models_p.keys())):
            models_h = models_p[params]
            td_h, regions_h, specs_h = [], [], []
            for j, models_n in enumerate(models_h):
                td_n, regions_n, specs_n = [], [], []
                for model in models_n:
                    k += 1
                    end = '\n' if verbose else '\r'
                    print(f'train model {k}/{models_total}', end=end)
                    t, regs, specs = train(model, iterations, lr, evaluation_scheme,
                                    average_over_images=average_over_images,
                                    compute_stats=compute_stats,
                                    compute_acc=compute_acc,
                                    verbose=verbose)
                    td_n.append(t)
                    regions_n.append(regs)
                    specs_n.append(specs)
                td_h.append(td_n)
                regions_h.append(regions_n)
                specs_h.append(specs_n)
                
            td[params] = np.array(td_h)
            regions[params] = np.array(regions_h)
            specializations[params] = np.array(specs_h)
            
        data_from_training[name] = td
        ar_data_from_training[name] = regions
        spec_data_from_training[name] = specializations
        
        models_trained = models
else:
    print('Models will be loaded, no training needed.')
    models = sdicts['models_trained']
    model_architectures = list(models.keys())
    
    
# save

if save_models:
    torch.save(models, checkpoint_dir + 'models_trained.pt')
    print('models saved')
else:
    print('dont save models')
    
if save_data and (compute_stats or compute_acc):
    torch.save(data, saving_folder + 'data.pt')
    torch.save(hyperparams_for_sparse, saving_folder + 'hyperparams.pt')
    print('saved')
elif not save_data:
    print('dont save data')
else:
    print('nothing to save')

## Save the trained networks

In [None]:
# Persist the whole (trained) `models` object into this run's checkpoint folder.
if not save_models:
    print('dont save models')
else:
    torch.save(models, checkpoint_dir + 'models_trained.pt')
    print('models saved')

#### Save the data

In [None]:
# Write the collected training data (and the model setups) to the results folder, if anything was computed.
if not save_data:
    print('dont save data')
elif compute_stats or compute_acc:
    torch.save(data, saving_folder + 'data.pt')
    torch.save(hyperparams_for_sparse, saving_folder + 'hyperparams.pt')
    print('saved')
else:
    print('nothing to save')