In [2]:
import json
import torch
import numpy as np
import torch.nn as nn
from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder
from pose_decoder import PoseDecoder

ModuleNotFoundError: No module named 'torchvision_updated'

Resource Estimates

In [3]:
# Cost-model constants used by hook_fn to turn FLOP counts and tensor byte
# sizes into per-node / per-edge energy and time estimates.
# NOTE(review): units are not stated anywhere in this notebook -- presumably
# joules, seconds and bytes/second; confirm before quoting absolute numbers.
ENERGY_PER_FLOP = 1e-12  # multiplied by a node's FLOP count -> node "energy"
TIME_PER_FLOP = 1e-9  # multiplied by a node's FLOP count -> node "runtime"
COMM_BANDWIDTH = 1e10  # divides an edge's byte size -> edge "latency"
ENERGY_PER_BYTE = 1e-9  # multiplied by an edge's byte size -> edge "energy"

Data Extraction Functions

In [27]:
def _kernel_window_size(kernel_size):
    """Return the number of elements in a 2-D kernel window.

    ``nn.Conv2d`` stores ``kernel_size`` as a tuple, but the pooling modules
    keep whatever the caller passed, so ``nn.MaxPool2d(3)`` has
    ``kernel_size == 3`` (an int) and must be squared rather than fed to
    ``np.prod`` directly.
    """
    if isinstance(kernel_size, (tuple, list)):
        return int(np.prod(kernel_size))
    return int(kernel_size) ** 2


def compute_flops(module, input, output):
    """Estimate the FLOPs of a single forward call of a leaf module.

    The signature matches a PyTorch forward hook so it can be called straight
    from ``hook_fn``. All estimates scale with ``output.numel()``, so the
    batch dimension is accounted for consistently in every branch.

    Args:
        module: The ``nn.Module`` that was executed.
        input: The hook's input (unused; kept for hook-signature parity).
        output: The tensor returned by ``module``.

    Returns:
        int: Estimated FLOP count; ``0`` for module types that are not modelled.
    """
    if isinstance(module, nn.Conv2d):
        # One multiply + one add per kernel tap per input channel *of the
        # group* for every output element (batch included via numel()).
        kernel_ops = _kernel_window_size(module.kernel_size)
        cin_per_group = module.in_channels // module.groups
        return 2 * kernel_ops * cin_per_group * output.numel()

    if isinstance(module, nn.Linear):
        # output.numel() == rows * out_features, so every row of the input
        # is counted, not just a single sample.
        return 2 * module.in_features * output.numel()

    # Element-wise activations: fixed per-element cost heuristics.
    if isinstance(module, nn.ReLU):
        return output.numel()
    if isinstance(module, nn.ELU):
        return 4 * output.numel()
    if isinstance(module, nn.GELU):
        return 12 * output.numel()

    if isinstance(module, nn.BatchNorm2d):
        return 4 * output.numel()

    if isinstance(module, nn.MaxPool2d):
        # window - 1 comparisons per output element.
        return (_kernel_window_size(module.kernel_size) - 1) * output.numel()
    if isinstance(module, nn.AvgPool2d):
        # window - 1 additions plus one division per output element.
        return _kernel_window_size(module.kernel_size) * output.numel()

    return 0

def hook_fn(module, input, output):
    """Forward hook that records ``module`` as a graph node with its incoming edges.

    Relies on notebook-level state: appends to the global ``nodes`` and
    ``edges`` lists, reads/writes the global ``tensor_to_node`` mapping
    (tensor object -> id of the node that produced it) and increments the
    global ``node_id`` counter. Costs are derived with ``compute_flops`` and
    the ``ENERGY_PER_FLOP`` / ``TIME_PER_FLOP`` / ``COMM_BANDWIDTH`` /
    ``ENERGY_PER_BYTE`` constants.

    Args:
        module: The leaf module that just ran.
        input: The positional inputs passed to ``module``.
        output: The value returned by ``module`` (tensor or tuple/list of tensors).
    """
    global node_id
    flops = compute_flops(module, input, output)
    params = list(module.parameters())

    nodes.append({
        "name": module.__class__.__name__,
        "id": node_id,
        "opcode": type(module).__name__,
        "param_shapes": [list(p.shape) for p in params],
        "energy": flops * ENERGY_PER_FLOP,
        "runtime": flops * TIME_PER_FLOP,
        "flops": flops,
        "size": sum(p.numel() * p.element_size() for p in params),
    })

    # Resolve incoming edges BEFORE claiming the outputs. An in-place module
    # (e.g. nn.ReLU(inplace=True)) returns the very tensor it received, so
    # registering the output first would overwrite the producer's entry and
    # the edge into this node would be silently lost.
    if isinstance(input, (tuple, list)):
        for inp in input:
            if not isinstance(inp, torch.Tensor) or inp not in tensor_to_node:
                continue
            source_id = tensor_to_node[inp]
            if source_id == node_id:
                # Guards against stale entries if node_id was reset without
                # clearing tensor_to_node.
                continue
            data_volume = inp.numel() * inp.element_size()
            edges.append({
                "source": source_id,
                "destination": node_id,
                "shape": list(inp.shape),
                "latency": data_volume / COMM_BANDWIDTH,
                "energy": data_volume * ENERGY_PER_BYTE,
                "size": data_volume
            })

    # Now mark every produced tensor as originating from this node.
    if isinstance(output, torch.Tensor):
        tensor_to_node[output] = node_id
    elif isinstance(output, (tuple, list)):
        for out in output:
            if isinstance(out, torch.Tensor):
                tensor_to_node[out] = node_id

    node_id += 1

def convert_to_serializable(obj):
    """Recursively turn NumPy / torch values into plain JSON-friendly Python objects.

    Dicts and lists are walked recursively; NumPy arrays and torch tensors
    become nested lists; NumPy scalars become native Python scalars. Anything
    else is returned unchanged.
    """
    # Array-like leaves share the same conversion.
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, list):
        return list(map(convert_to_serializable, obj))
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = convert_to_serializable(value)
        return converted
    return obj

In [6]:
# Scratch sanity check: run a BatchNorm2d -> ReLU forward pass and confirm the
# result is a plain torch.Tensor (hook_fn branches on
# isinstance(output, torch.Tensor)).
# NOTE(review): `input` below shadows the Python builtin for the rest of the
# kernel session -- consider renaming if this cell is kept.
# With Learnable Parameters
m = nn.BatchNorm2d(100)
rel = nn.ReLU(inplace=False)

# Without Learnable Parameters
# m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)
output = rel(output)
print(isinstance(output, torch.Tensor))

True


ONNX Reconstruction

In [28]:
import onnx
import onnx.helper as helper
import numpy as np

def build_onnx_from_json(json_nodes, json_edges):
    """Assemble an ONNX ``ModelProto`` that mirrors the traced node/edge graph.

    Each JSON node becomes one ONNX node named ``node_<id>`` producing the
    tensor ``node_<id>_output``. Every edge wires the source's output tensor
    into the destination node. Parameter shapes become FLOAT initializers
    filled with random placeholder values and are attached as trailing inputs
    of their node.

    Graph boundary handling:
      * a node with no incoming edge gets a graph input ``node_<id>_input``
        of unknown shape;
      * an edge whose source is not in ``json_nodes`` becomes a graph input
        with the edge's shape;
      * a node whose output feeds no node in ``json_nodes`` (or feeds a
        destination outside it) is exposed as a graph output.

    NOTE(review): ``opcode`` values are copied verbatim as ``op_type``; if
    they are PyTorch class names rather than ONNX operator names the result
    is a structural visualisation, not a runnable/checker-valid model.

    Args:
        json_nodes: Iterable of node dicts with at least ``id`` and ``opcode``
            (``param_shapes`` optional).
        json_edges: Iterable of edge dicts with ``source``, ``destination``
            and ``shape``.

    Returns:
        onnx.ModelProto: The reconstructed model.
    """
    float_type = onnx.TensorProto.FLOAT
    node_ids = {node['id'] for node in json_nodes}

    # Index the edges once instead of rescanning them for every node.
    incoming = {}          # destination id -> edges arriving at it
    consumed = set()       # ids whose output feeds another node in the graph
    exported_shape = {}    # id -> shape of its output when it leaves the graph
    for edge in json_edges:
        incoming.setdefault(edge['destination'], []).append(edge)
        if edge['destination'] in node_ids:
            consumed.add(edge['source'])
        else:
            exported_shape[edge['source']] = edge['shape']

    graph_nodes = []
    graph_inputs = []
    graph_outputs = []
    initializers = []
    declared_inputs = set()
    declared_outputs = set()

    def _declare(target, seen, name, shape):
        # Register a graph-level value once; shape=None means "unknown".
        if name not in seen:
            seen.add(name)
            target.append(helper.make_tensor_value_info(name, float_type, shape))

    for node in json_nodes:
        node_id = node['id']
        node_name = f"node_{node_id}"
        output_name = f"{node_name}_output"

        input_names = []
        for edge in incoming.get(node_id, []):
            source_name = f"node_{edge['source']}_output"
            input_names.append(source_name)
            if edge['source'] not in node_ids:
                # Produced outside this graph: must be fed by the caller.
                _declare(graph_inputs, declared_inputs, source_name, edge['shape'])

        if not input_names:
            # No traced producer: give the node an explicit graph input so
            # its parameters do not end up in the data-input position.
            placeholder = f"{node_name}_input"
            input_names.append(placeholder)
            _declare(graph_inputs, declared_inputs, placeholder, None)

        for idx, shape in enumerate(node.get('param_shapes') or []):
            param_name = f"{node_name}_param_{idx}"
            # np.asarray keeps 0-dim parameters working (rand() with no
            # arguments returns a bare Python float).
            values = np.asarray(np.random.rand(*shape), dtype=np.float32).ravel()
            initializers.append(helper.make_tensor(
                name=param_name,
                data_type=float_type,
                dims=shape,
                vals=values
            ))
            input_names.append(param_name)

        graph_nodes.append(helper.make_node(
            op_type=node['opcode'],
            inputs=input_names,
            outputs=[output_name],
            name=node_name
        ))

        if node_id not in consumed or node_id in exported_shape:
            _declare(graph_outputs, declared_outputs, output_name,
                     exported_shape.get(node_id))

    graph = helper.make_graph(
        nodes=graph_nodes,
        name="ReconstructedGraph",
        inputs=graph_inputs,
        outputs=graph_outputs,
        initializer=initializers
    )

    model = helper.make_model(graph, producer_name="json_to_onnx")
    return model

Resnet Encoder

In [33]:
# Reset the trace state that hook_fn appends to (nodes, edges), counts with
# (node_id) and uses to look up tensor producers (tensor_to_node).
nodes = []
edges = []
node_id = 0
tensor_to_node = {}

# Encoder configuration passed straight to ResnetEncoder below.
# NOTE(review): num_input_images=2 presumably corresponds to the 6-channel
# input used in the tracing cell (2 RGB frames) -- confirm against ResnetEncoder.
num_layers = 18  
pretrained = False  
num_input_images = 2

encoder = ResnetEncoder(num_layers=num_layers, pretrained=pretrained, num_input_images=num_input_images)

In [34]:
# Hook only leaf modules (no children) so each primitive layer becomes one node.
# Keep the handles: hooks left attached would fire again on every later
# encoder call (and be duplicated if this cell is re-run).
hook_handles = [
    module.register_forward_hook(hook_fn)
    for module in encoder.encoder.modules()
    if len(list(module.children())) == 0
]

input_image = torch.randn(1, 6, 224, 224)
try:
    # Tracing only needs shapes, not gradients.
    with torch.no_grad():
        encoder(input_image)
finally:
    for handle in hook_handles:
        handle.remove()

data = {"nodes": [convert_to_serializable(node) for node in nodes], "edges": [convert_to_serializable(edge) for edge in edges]}
with open('resnet_encoder_graph.json', 'w') as f:
    json.dump(data, f, indent=4)

onnx_model = build_onnx_from_json(nodes, edges)
onnx.save(onnx_model, "reconstructed_model.onnx")

DepthDecoder

In [18]:
# Reset the trace state shared with hook_fn before tracing the depth decoder.
nodes = []
edges = []
tensor_to_node = {}
node_id = 0

num_ch_enc = np.array([48, 80, 128])  # Adjusted to match LiteMono encoder

# Instantiate the DepthDecoder with adjusted scales:
# scales is range(len(num_ch_enc)), i.e. 0, 1, 2 for the three entries above.
depth_decoder= DepthDecoder(num_ch_enc=num_ch_enc, scales=range(len(num_ch_enc)))
# depth_decoder = DepthDecoder(num_ch_enc=num_ch_enc, scales=scales, num_output_channels=1, use_skips=True)

In [21]:
# Hook only leaf modules; keep the handles so the hooks can be detached once
# the trace is done (otherwise re-running this cell stacks duplicate hooks).
hook_handles = [
    module.register_forward_hook(hook_fn)
    for module in depth_decoder.modules()
    if len(list(module.children())) == 0
]

batch_size = 1
height = 224
width = 224

# Spatial downsampling factor applied to (height, width) for each feature map.
factors = [4, 8, 16]

# One random feature map per encoder stage; int() turns the NumPy integer
# channel counts into plain Python ints for torch.randn.
input_features = [
    torch.randn(batch_size, int(channels), height // factor, width // factor)
    for channels, factor in zip(num_ch_enc, factors)
]

try:
    # Tracing only needs shapes, not gradients.
    with torch.no_grad():
        depth_decoder(input_features)
finally:
    for handle in hook_handles:
        handle.remove()

data = {"nodes": [convert_to_serializable(node) for node in nodes], "edges": [convert_to_serializable(edge) for edge in edges]}
with open('depth_decoder_graph.json', 'w') as f:
    json.dump(data, f, indent=4)

onnx_model = build_onnx_from_json(nodes, edges)
onnx.save(onnx_model, "reconstructed_depth_decoder_model.onnx")

In [22]:
pose_decoder = PoseDecoder(num_ch_enc=encoder.num_ch_enc, num_input_features=1, num_frames_to_predict_for=2)

# Compute the encoder features BEFORE resetting the trace state: if any hooks
# are still attached to the encoder they write into the old lists, which are
# discarded right below and therefore cannot pollute the pose-decoder graph.
with torch.no_grad():
    input_features = [encoder(input_image)]

# Fresh trace state so the output contains only pose-decoder nodes/edges
# (previously this cell re-saved whatever the last traced model left behind).
nodes = []
edges = []
tensor_to_node = {}
node_id = 0

hook_handles = [
    module.register_forward_hook(hook_fn)
    for module in pose_decoder.modules()
    if len(list(module.children())) == 0
]

try:
    # The hooks only fire during a forward pass, which was missing before.
    # NOTE(review): assumes PoseDecoder.forward accepts the list built above
    # (one entry per input feature set) -- confirm against pose_decoder.py.
    with torch.no_grad():
        pose_decoder(input_features)
finally:
    for handle in hook_handles:
        handle.remove()

data = {"nodes": [convert_to_serializable(node) for node in nodes], "edges": [convert_to_serializable(edge) for edge in edges]}
with open('pose_decoder_graph.json', 'w') as f:
    json.dump(data, f, indent=4)

onnx_model = build_onnx_from_json(nodes, edges)
onnx.save(onnx_model, "reconstructed_pose_model.onnx")