In [1]:
import torch
import sys
import os
import json
import time
import numpy as np
import argparse

from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params

sys.path.append("..")
from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, SingleVisLoss
from singleVis.edge_dataset import DataHandler
from singleVis.data import NormalDataProvider
from singleVis.spatial_edge_constructor import SingleEpochSpatialEdgeConstructor

In [2]:
# Visualization method; also used below as the top-level key into config.json
VIS_METHOD = "DVI" # DeepVisualInsight
# Directory containing config.json; it is also appended to sys.path below.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
CONTENT_PATH = "/home/xianglin/projects/DVI_data/resnet18_fmnist"
# Passed as the second argument to SingleEpochSpatialEdgeConstructor —
# presumably the epoch/iteration to visualize (TODO confirm)
I = 5
# CUDA device index; formatted into "cuda:<GPU_ID>" when DEVICE is built
GPU_ID = "2"

In [3]:
# Put the content directory on the import path, then read config.json from it
# and keep only the section that belongs to the selected visualization method.
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
    config = json.load(f)[VIS_METHOD]

In [4]:
# Top-level entries of the selected config section
SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
# Whether the preprocessing steps of the data provider are run (see the data provider cell)
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
# GPU_ID is set by hand in an earlier cell instead of being read from the config:
# GPU_ID = config["GPU"]

# Training parameters (subject model)
TRAINING_PARAMETER = config["TRAINING"]
# Name of the subject-model class/factory that is instantiated from Model.model
NET = TRAINING_PARAMETER["NET"]
# Number of training samples; LEN // 10 is later passed to the boundary estimation
LEN = TRAINING_PARAMETER["train_num"]
# Epoch range and step handed to the data provider
EPOCH_START = config["EPOCH_START"]
EPOCH_END = config["EPOCH_END"]
EPOCH_PERIOD = config["EPOCH_PERIOD"]

# Training parameters (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
# Weight passed to SingleVisLoss as `lambd`
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
# Boundary-related settings; boundary estimation only runs when B_N_EPOCHS > 0
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
# NOTE(review): ENCODER_DIMS / DECODER_DIMS are read here, but the models in this
# notebook are built from hard-coded dimension lists instead.
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]
# S_N_EPOCHS also scales the number of sampled edges (n_samples) further down
S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]

VIS_MODEL_NAME = VISUALIZATION_PARAMETER["VIS_MODEL_NAME"]
EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]

# Use the GPU selected by GPU_ID when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")

In [5]:
import Model.model as subject_model
# Look the architecture up by name rather than eval()-ing a string assembled
# from config.json: getattr cannot execute arbitrary code, and it raises a clear
# AttributeError when NET does not name anything in the subject model module.
net = getattr(subject_model, NET)()

In [6]:
# Define visualization models.
# modelK uses K intermediate layers of width 256: the first dimension list runs
# 512 -> 256 (x K) -> 2 and the second list is its mirror image, 2 -> ... -> 512.
model1, model2, model3, model4, model5 = (
    VisModel([512] + [256] * n_hidden + [2], [2] + [256] * n_hidden + [512])
    for n_hidden in range(1, 6)
)


In [7]:
# Build the data provider for the subject model over the configured epoch range
data_provider = NormalDataProvider(
    CONTENT_PATH,
    net,
    EPOCH_START,
    EPOCH_END,
    EPOCH_PERIOD,
    device=DEVICE,
    classes=CLASSES,
    verbose=1,
)
# Preprocessing is optional; the boundary estimation additionally requires
# that boundary epochs are enabled.
if PREPROCESS:
    data_provider._meta_data()
if PREPROCESS and B_N_EPOCHS > 0:
    data_provider._estimate_boundary(LEN // 10, l_bound=L_BOUND)

# Define the loss components
negative_sample_rate = 5
min_dist = .1
# find_ab_params comes from umap; presumably it fits the (a, b) curve parameters
# for a spread of 1.0 and the given min_dist — confirm against the umap docs.
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, DEVICE, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)

# Define the DVI loss: combines the UMAP and reconstruction losses, with LAMBDA1
# passed as `lambd`. This `criterion` is what the SNIP cells pass in as `loss_fn`.
criterion = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA1)

# Construct the edges (endpoints, weights, features, attention) for index I
spatial_cons = SingleEpochSpatialEdgeConstructor(data_provider, I, S_N_EPOCHS, B_N_EPOCHS, N_NEIGHBORS)
edge_to, edge_from, probs, feature_vectors, attention = spatial_cons.construct()

# Rescale the edge weights by their (slightly padded) maximum, then drop every
# edge whose rescaled weight is not above 1e-2 (1e-3 was an earlier threshold).
probs = probs / (probs.max() + 1e-3)
eliminate_zeros = probs > 1e-2
edge_to, edge_from, probs = (
    edge_to[eliminate_zeros],
    edge_from[eliminate_zeros],
    probs[eliminate_zeros],
)

# Edge dataset over the surviving edges
dataset = DataHandler(edge_to, edge_from, feature_vectors, attention)

# Number of edges drawn per pass over the loader: floor of sum(S_N_EPOCHS * probs)
n_samples = int(np.sum(S_N_EPOCHS * probs) // 1)
# Choose the sampler based on the number of edges.
# NOTE: the threshold has to be written 2 ** 24. In Python `2 ^ 24` is a bitwise
# XOR and evaluates to 26, which sent practically every dataset down the
# custom-sampler branch. Presumably the custom sampler exists because the stock
# WeightedRandomSampler cannot handle more than 2 ** 24 weights — confirm
# against the torch.multinomial limits.
if len(edge_to) > 2 ** 24:
    sampler = CustomWeightedRandomSampler(probs, n_samples, replacement=True)
else:
    sampler = WeightedRandomSampler(probs, n_samples, replacement=True)
edge_loader = DataLoader(dataset, batch_size=1000, sampler=sampler)

Finish initialization...
Wed Dec 21 13:52:06 2022 Building RP forest with 17 trees
Wed Dec 21 13:52:07 2022 NN descent for 16 iterations
	 1  /  16
	 2  /  16
	 3  /  16
	 4  /  16
	Stopping threshold met -- exiting after 4 iterations


In [8]:
len(edge_loader.dataset)

376664

In [9]:
import torch.nn as nn
def get_layer_metric_array(network, metric, mode):
    """Apply ``metric`` to every Conv2d / Linear module of ``network``.

    Args:
        network: module whose ``modules()`` are traversed (in traversal order).
        metric: callable taking a layer and returning that layer's metric value.
        mode: when equal to ``'channel'``, modules that carry a
            ``dont_ch_prune`` attribute are skipped; any other value applies
            no such filter.

    Returns:
        A list with one ``metric(layer)`` result per selected layer.
    """
    measured_types = (nn.Conv2d, nn.Linear)
    return [
        metric(module)
        for module in network.modules()
        if not (mode == 'channel' and hasattr(module, 'dont_ch_prune'))
        and isinstance(module, measured_types)
    ]

In [10]:
def grad_norm(layer):
    """Return the gradient currently stored on ``layer.weight``.

    Despite the name, no norm is taken: the raw ``layer.weight.grad`` tensor is
    returned as-is. When no gradient has been populated, a zero tensor shaped
    like the weight is returned instead.
    """
    grad = layer.weight.grad
    return torch.zeros_like(layer.weight) if grad is None else grad

In [11]:
def get_grad_norm_arr(network, data, loss_fn):
    """Run one forward/backward pass on a single batch and collect weight gradients.

    Args:
        network: model called as ``network(edge_to, edge_from)``.
        data: one batch that unpacks into ``(edge_to, edge_from, a_to, a_from)``.
        loss_fn: callable invoked as
            ``loss_fn(edge_to, edge_from, a_to, a_from, outputs)`` and returning
            three values, the last of which is back-propagated.

    Returns:
        The list produced by ``get_layer_metric_array`` with ``grad_norm`` as the
        metric, i.e. one gradient tensor per Conv2d / Linear layer.

    Side effects:
        ``network`` is moved to ``DEVICE``, put in training mode, and its
        gradients are cleared and then populated by the backward pass.
    """
    network.to(device=DEVICE)
    network.train()
    network.zero_grad()

    # Unpack the batch, then move every tensor to the target device as float32.
    edge_to, edge_from, a_to, a_from = data
    edge_to, edge_from, a_to, a_from = (
        tensor.to(device=DEVICE, dtype=torch.float32)
        for tensor in (edge_to, edge_from, a_to, a_from)
    )

    outputs = network(edge_to, edge_from)
    # Only the third returned value is needed for the backward pass.
    _, _, loss = loss_fn(edge_to, edge_from, a_to, a_from, outputs)
    loss.backward()

    return get_layer_metric_array(network, grad_norm, mode='param')

In [12]:
def get_l2_norm_array(network):
    """Collect the weight of every Conv2d / Linear layer in ``network``.

    Despite the name, no norm is computed here: each list entry is the layer's
    ``weight`` object itself, gathered through ``get_layer_metric_array``.
    """
    def layer_weight(layer):
        return layer.weight

    return get_layer_metric_array(network, layer_weight, mode="param")

In [13]:
# snip
import types
import torch.nn.functional as F
def snip_forward_conv2d(self, x):
    """Conv2d forward pass that uses ``self.weight * self.weight_mask`` as kernel.

    Meant to be bound to a Conv2d layer that carries a ``weight_mask``
    attribute; bias, stride, padding, dilation and groups are taken from the
    layer unchanged.
    """
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x, masked_weight, self.bias,
        self.stride, self.padding, self.dilation, self.groups,
    )
def snip_forward_linear(self, x):
    """Linear forward pass that uses ``self.weight * self.weight_mask`` as weight.

    Meant to be bound to a Linear layer that carries a ``weight_mask``
    attribute; the layer's bias is applied unchanged.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)

def compute_snip_per_weight(network, data, loss_fn):
    """Compute SNIP-style saliencies: |d loss / d mask| for each Conv2d / Linear weight.

    Every Conv2d / Linear layer gets an all-ones ``weight_mask`` parameter and a
    forward method that multiplies the weight by that mask. One forward/backward
    pass on ``data`` then yields the gradient with respect to each mask.

    Args:
        network: model called as ``network(edge_to, edge_from)``.
        data: one batch that unpacks into ``(edge_to, edge_from, a_to, a_from)``.
        loss_fn: callable invoked as
            ``loss_fn(edge_to, edge_from, a_to, a_from, outputs)`` and returning
            three values, the last of which is back-propagated.

    Returns:
        A list with one tensor per Conv2d / Linear layer: the absolute mask
        gradient, or zeros shaped like the weight when no gradient was produced.

    Side effects:
        The changes to ``network`` are NOT undone afterwards: patched layers keep
        their masked ``forward``, their extra ``weight_mask`` parameter, and
        ``weight.requires_grad == False``. The network is also moved to
        ``DEVICE`` and left in training mode.
    """
    # Attach a mask to each layer, freeze its weight and swap in the masked forward.
    for module in network.modules():
        if isinstance(module, nn.Linear):
            masked_forward = snip_forward_linear
        elif isinstance(module, nn.Conv2d):
            masked_forward = snip_forward_conv2d
        else:
            continue
        module.weight_mask = nn.Parameter(torch.ones_like(module.weight))
        module.weight.requires_grad = False
        module.forward = types.MethodType(masked_forward, module)

    # Compute gradients (but don't apply them).
    network.to(device=DEVICE)
    network.train()
    network.zero_grad()

    edge_to, edge_from, a_to, a_from = data
    edge_to, edge_from, a_to, a_from = (
        tensor.to(device=DEVICE, dtype=torch.float32)
        for tensor in (edge_to, edge_from, a_to, a_from)
    )

    outputs = network(edge_to, edge_from)
    # Only the third returned value is needed for the backward pass.
    _, _, loss = loss_fn(edge_to, edge_from, a_to, a_from, outputs)
    loss.backward()

    def mask_saliency(layer):
        # Magnitude of the gradient w.r.t. the mask; zeros if the layer got none.
        mask_grad = layer.weight_mask.grad
        if mask_grad is None:
            return torch.zeros_like(layer.weight)
        return torch.abs(mask_grad)

    return get_layer_metric_array(network, mask_saliency, mode="param")

In [14]:
import torch.autograd as autograd

def compute_grasp_per_weight(network, data, loss_fn):
    """Compute a GraSP-style per-weight score for every Conv2d / Linear layer.

    Two forward passes are run on the same batch. The first yields plain
    gradients ``g`` of the loss w.r.t. the weights; the second yields gradients
    with the graph retained, so that back-propagating ``sum(g * grad)`` leaves a
    Hessian-gradient product in each ``weight.grad``. The score of a weight is
    ``-weight * weight.grad``.

    Args:
        network: model called as ``network(edge_to, edge_from)``.
        data: one batch that unpacks into ``(edge_to, edge_from, a_to, a_from)``.
        loss_fn: callable invoked as
            ``loss_fn(edge_to, edge_from, a_to, a_from, outputs)`` and returning
            three values, the last of which is differentiated.

    Returns:
        A list with one score tensor per Conv2d / Linear layer; layers without a
        gradient contribute zeros shaped like their weight.

    Side effects:
        ``network`` is moved to ``DEVICE``, put in training mode, its gradients
        are cleared and then overwritten by the second-order backward pass.
    """
    # Collect all applicable weights and make sure they are differentiable.
    weights = []
    for layer in network.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            weights.append(layer.weight)
            layer.weight.requires_grad_(True)

    # NOTE: the original GraSP code split inputs/targets in two, presumably
    # because of GPU memory limits; here the whole batch is used at once.
    network.to(device=DEVICE)
    network.train()
    network.zero_grad()

    # Both passes use the same batch, so it is moved to the device only once.
    edge_to, edge_from, a_to, a_from = data
    edge_to, edge_from, a_to, a_from = (
        tensor.to(device=DEVICE, dtype=torch.float32)
        for tensor in (edge_to, edge_from, a_to, a_from)
    )

    # Forward/grad pass #1: first-order gradients, later used as constants.
    outputs = network(edge_to, edge_from)
    _, _, loss = loss_fn(edge_to, edge_from, a_to, a_from, outputs)
    grad_w = autograd.grad(loss, weights, allow_unused=True)

    # Forward/grad pass #2: same gradients, but differentiable (create_graph).
    outputs = network(edge_to, edge_from)
    _, _, loss = loss_fn(edge_to, edge_from, a_to, a_from, outputs)
    grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)

    # Accumulate sum(g * grad) over the layers that received a gradient.
    # `weights`, `grad_w` and `grad_f` share the same layer order.
    z = None
    for first_order, differentiable in zip(grad_w, grad_f):
        if first_order is None:
            continue
        term = (first_order.data * differentiable).sum()
        z = term if z is None else z + term

    # With no contributing layer there is nothing to back-propagate; calling
    # .backward() on a plain number would fail, and every score is zero anyway.
    if z is not None:
        z.backward()

    # Compute the final sensitivity metric from the accumulated gradients.
    def grasp(layer):
        if layer.weight.grad is not None:
            # -theta_q Hg
            # NOTE: the GraSP code keeps the *bottom* (1-p)% of values while the
            # *top* (1-p)% is taken here, which would call for dropping the minus
            # sign. Accuracy appeared negatively correlated with the unsigned
            # metric, so the minus sign is kept.
            return -layer.weight.data * layer.weight.grad
        return torch.zeros_like(layer.weight)

    return get_layer_metric_array(network, grasp, mode="param")



In [15]:
def get_sum(arr):
    """Add up all elements of all tensors in ``arr`` and return a Python float.

    Each entry is reduced with ``.sum().item()`` and the per-entry totals are
    accumulated in order, starting from ``0.0`` (the result for an empty input).
    """
    total = 0.0
    for tensor in arr:
        total = total + tensor.sum().item()
    return total

In [16]:
# Keep the num_b-th batch produced by the edge loader (the first one here)
# as the fixed batch on which the architectures are scored.
num_b = 1
for d in edge_loader:
    num_b -= 1
    if num_b != 0:
        continue
    data = d
    break

In [17]:
# grasp1 = get_sum(compute_grasp_per_weight(model1, data, criterion))
# print("layer=1\t", grasp1)

In [18]:
# Summed SNIP saliency of each candidate architecture on the same batch,
# evaluated in order model1 .. model5 before anything is printed.
snip_grad1, snip_grad2, snip_grad3, snip_grad4, snip_grad5 = (
    get_sum(compute_snip_per_weight(candidate, data, criterion))
    for candidate in (model1, model2, model3, model4, model5)
)
print("layer=1\t", snip_grad1)
print("layer=2\t", snip_grad2)
print("layer=3\t", snip_grad3)
print("layer=4\t", snip_grad4)
print("layer=5\t", snip_grad5)

layer=1	 420.97770261764526
layer=2	 326.5517023205757
layer=3	 225.5781311839819
layer=4	 181.85866290330887
layer=5	 170.13697430491447


In [19]:
# Alternative architecture: layer widths halve step by step (512 -> 256 -> 128 -> 64 -> 2)
# instead of staying at 256; the second dimension list mirrors the first.
model6 = VisModel([512,256,128,64,2], [2,64,128,256,512]) 
# Score it with the same batch and loss as model1..model5 so the numbers are comparable
snip_grad6 = get_sum(compute_snip_per_weight(model6, data, criterion))
print("new arch\t", snip_grad6)

new arch	 182.0700019299984
