# Simple CNN

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

#########################################
# 1. Define a simple CNN with 3 conv layers
#########################################
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # conv1: from 4x4 input → output remains 4x4
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # conv2: from 4x4 → 2x2 (stride=2)
        self.conv2 = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1, bias=False)
        # conv3: from 2x2 → 1x1 (stride=2)
        self.conv3 = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1, bias=False)
        self.fcn = nn.Linear(1, 10, bias=False)  # not used for interaction

    def forward(self, x):
        
        a1 = (self.conv1(x))    # shape: (1,1,4,4)
        a2 = (self.conv2(a1))     # shape: (1,1,2,2)
        a3 = (self.conv3(a2))     # shape: (1,1,1,1)
        return a3

model = SimpleCNN()
criterion = nn.CrossEntropyLoss(reduction="none")
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Gradient Extraction & Storing

In [2]:
#########################################
# 2. Register hooks to capture gradients
#########################################
activation_gradients = {}
gradient_flows = {}
def get_activation_grad(name, connet2name=None):
    def hook(module, grad_input, grad_output):
        if name is not None:
            # Lấy grad_out: shape [N, C_out, H_out, W_out]
            grad_out = grad_output[0].detach()
            N, C_out, H_out, W_out = grad_out.shape
            
            # Lấy weight của module: shape [C_out, C_in, kH, kW]
            weight = module.weight  
            C_out_w, C_in, kH, kW = weight.shape
            assert C_out == C_out_w, "Mismatch in output channels."

            # Chuyển weight thành dạng ma trận: [C_out, C_in*kH*kW]
            weight_reshaped = weight.view(C_out, -1)
            
            # Reshape grad_out thành [N, C_out, Len_out] với Len_out = H_out * W_out
            Len_out = H_out * W_out
            grad_out_reshaped = grad_out.view(N, C_out, Len_out)

            # Tính grad_input_cols: [N, C_in*kH*kW, Len_out]
            grad_input_cols = torch.matmul(weight_reshaped.t(), grad_out_reshaped)
            
            # Giả sử batch size N=1
            grad_input_cols = grad_input_cols[0]  # [C_in*kH*kW, Len_out]

            # Lấy kích thước input từ grad_input[0]: [N, C_in, H_in, W_in]
            H_in, W_in = grad_input[0].shape[2:]
            Len_in = H_in * W_in

            # Xây dựng ánh xạ từ các patch đến các vị trí trên input:
            # Tạo tensor chứa các chỉ số của các ô input, shape: [1, 1, H_in, W_in]
            input_indices = torch.arange(Len_in, device=grad_out.device).view(1, 1, H_in, W_in).float()+1
            # Sử dụng F.unfold để lấy ma trận ánh xạ, shape: [C_in*kH*kW, Len_out]
            idx_map = (F.unfold(input_indices, kernel_size=module.kernel_size, 
                                dilation=module.dilation, padding=module.padding, stride=module.stride)[0])

            # Khởi tạo gradient_flows với kích thước (Len_in, Len_out)
            gradient_flow = torch.zeros(Len_in+1, Len_out, device=grad_out.device)
            # Sử dụng scatter_add_ để cộng các giá trị từ grad_input_cols vào gradient_flows
            # Cho mỗi phần tử tại vị trí (p, j) trong grad_input_cols, ta cộng vào gradient_flows tại (idx_map[p,j], j)
            gradient_flow.scatter_add_(0, idx_map.long(), grad_input_cols)
            gradient_flow = (gradient_flow[1:,:])
            # print((grad_input[0]).cpu().numpy().reshape(-1))
            # print(gradient_flow.cpu().numpy().sum(axis=-1,keepdims=False))
            assert np.abs(gradient_flow.cpu().numpy().sum(axis=-1,keepdims=False)-(grad_input[0]).cpu().numpy().reshape(-1)).sum() < 1e-7, "Mismatch in gradient values."
            gradient_flow = torch.abs(gradient_flow)
            gradient_flow[gradient_flow>1e-5] = 1.
            gradient_flow[gradient_flow<=1e-5] = .99

            # Lưu kết quả vào activation_gradients
            gradient_flows[(name, connet2name)] = gradient_flow.cpu().numpy()  # kích thước: (Len_in, Len_out)
            
            activation_gradients[name] = (gradient_flows[(name, connet2name)]).sum(axis=-1,keepdims=False).reshape(H_in, W_in)
            activation_gradients[connet2name] = (gradient_flows[(name, connet2name)]).sum(axis=0,keepdims=False).reshape(H_out, W_out)
    return hook

model.conv1.register_backward_hook(get_activation_grad(None, "conv1"))
model.conv2.register_backward_hook(get_activation_grad("conv1", "conv2"))
model.conv3.register_backward_hook(get_activation_grad("conv2", "conv3"))

#########################################
# 3. Run forward/backward on a random input
#########################################
input_tensor = torch.randn(1, 1, 4, 4)
target = torch.randn(1, 1, 4, 4).long()
optimizer.zero_grad()
a3 = model(input_tensor)
loss = (a3.reshape(1,-1) - target[0,0,2,2].reshape(1,-1)).mean()  # use conv3 output for loss
loss.backward()

Save gradient flow information, includes: 
- activation_gradients: dict of aggregated weights of nodes at each layer. 
- gradient_flows: dict of values of gradient flows between adjacent layers.

In [3]:
import os,pickle

os.mkdir("debug_data")

flow_info = {"activation_gradients": activation_gradients, 
             "gradient_flows": gradient_flows}
with open('debug_data/flow_info.pkl', 'wb') as f:
    pickle.dump(flow_info, f)
np.save('debug_data/input_representation.npy', input_tensor[0,0,:,:].cpu().numpy())
np.save('debug_data/target_representation.npy', target[0,0,:,:].cpu().numpy())

FileExistsError: [WinError 183] Cannot create a file when that file already exists: 'debug_data'

# Rendering the Visualization

In [4]:
import subprocess
debug_dir = os.path.join(os.getcwd(), "debug_data") + "\\"
debug_dir

'C:\\Users\\Admin\\Documents\\GitHub\\vnittest\\examples\\breakpoints & data hooking\\debug_data\\'

In [5]:
subprocess.run(["vnittest", debug_dir])

CompletedProcess(args=['vnittest', 'C:\\Users\\Admin\\Documents\\GitHub\\vnittest\\examples\\breakpoints & data hooking\\debug_data\\'], returncode=0)