In [14]:
import torch
import torch.nn as nn
import torchvision.models as models
import json
from typing import Dict, List, Tuple
import numpy as np
from resnet_encoder import ResnetEncoder
from pose_decoder import PoseDecoder
from depth_decoder import Upsampling, DepthDecoder, ExtractInitial, ExtractSecond, ExtractThird
# from depth_encoder import LayerNorm, MatrixMultiply, Softmax, WeightedSum, LiteMono, Permute4d, GammaMultiply
import depth_encoder
# from residual_add import ResidualAdd
from torchvision.models.residual_add import ResidualAdd
from typing import Union
import onnx
import onnx.helper as helper
from timm.models.layers import DropPath

class NumpyFloatEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values.

    ``json`` only calls :meth:`default` for objects it cannot serialise itself.
    Integers stay integers (``np.int64(3)`` is written as ``3``, not ``3.0``),
    and every NumPy integer/floating width is handled, not only the 32/64-bit
    ones.
    """

    def default(self, obj):
        """Return a JSON-serialisable equivalent of *obj*.

        Falls back to the base class, which raises ``TypeError`` for
        unsupported types.
        """
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            # tolist() recursively converts the elements to Python scalars.
            return obj.tolist()
        return super().default(obj)

class ProcessBatch(nn.Module):
    """Depth network that runs the LiteMono encoder and the depth decoder step by step.

    ``forward`` calls the encoder's and decoder's sub-layers directly instead of
    calling the two modules as a whole (those whole-module calls are left
    commented out at the end of ``forward``). Every submodule is tagged with a
    ``_name`` attribute ('DepthEncoding' or 'DepthDecoding').
    """
    def __init__(self):
        """Build the encoder and decoder and tag their submodules."""
        super(ProcessBatch, self).__init__()
        self.DepthEncoding = depth_encoder.LiteMono()
        # The decoder is sized from the encoder's channel list and asked for scales 0, 1 and 2.
        self.DepthDecoding = DepthDecoder(self.DepthEncoding.num_ch_enc, scales = range(3))
        self._annotate_submodules(self.DepthEncoding, 'DepthEncoding')
        self._annotate_submodules(self.DepthDecoding, 'DepthDecoding')
        # The pose networks are not part of this module; the original construction is kept for reference:
        # self.PoseEncoding = ResnetEncoder(num_layers=18, pretrained=False, num_input_images=2)
        # self.PoseDecoding = PoseDecoder(num_ch_enc=self.PoseEncoding.num_ch_enc, num_input_features=1, num_frames_to_predict_for=2)
    def _annotate_submodules(self, module, model_name: str):
        """Set ``_name = model_name`` on every module yielded by ``module.named_modules()``."""
        for name, sub_module in module.named_modules():
            sub_module._name = model_name

    def forward(self, x):
        """Run the unrolled encoder and decoder and return the disparity outputs.

        Args:
            x: input batch. NOTE(review): the original comments assume 3-channel
               224 x 224 images — confirm against the caller.

        Returns:
            ``self.outputs``: a dict that maps ``("disp", scale)`` to the sigmoid
            of the upsampled ``dispconv`` output, for each of the scales 2, 1, 0
            that is present in ``self.DepthDecoding.scales``. The dict is stored
            on the module and replaced on every call.

        The debug prints at the end index ``("disp", 0)``, ``("disp", 1)`` and
        ``("disp", 2)`` unconditionally, so they raise ``KeyError`` if any of
        those scales was skipped.
        """

        features = []
        # Input normalisation with fixed constants (offset 0.45, scale 0.225).
        x = (x - 0.45) / 0.225

        x_down = []
        for i in range(3):
            # Three auxiliary versions of the normalised input, one per input_downsample entry
            # (author's note: average pooling — confirm in depth_encoder).
            x_down.append(self.DepthEncoding.input_downsample[i](x))

        tmp_x = []
        x = self.DepthEncoding.downsample_layers[0](x) # author's note: 3 sequential conv layers
        x = self.DepthEncoding.cat1(x, x_down[0]) # author's note: 48 x 112 x 112 features + 3 x 112 x 112 input -> 51 x 112 x 112 (TODO confirm shapes)
        x = self.DepthEncoding.stem2(x) # author's note: conv over the concatenated tensor, outputs 48 channels
        # Original fused form: x = self.stem2(torch.cat((x, x_down[0]), dim=1))
        tmp_x.append(x) # keep the stem2 output; it is passed to cat2 below

        # Stage 1: the loop applies every block except the last, which is applied on its own line.
        for s in range(len(self.DepthEncoding.stages[0])-1):
            x = self.DepthEncoding.stages[0][s](x) # author's note: dilated conv and LGFI blocks
        x = self.DepthEncoding.stages[0][-1](x) # stage 1 output
        x2 = x  # reused as the skip input of the second decoder step
        tmp_x.append(x)
        features.append(x) # stage 1 output (same tensor as x2)

        # cat2 receives [stem2 output, stage 1 output, x_down[1]].

        tmp_x.append(x_down[1])
        x = self.DepthEncoding.cat2(*tmp_x)
        # Original form: x = torch.cat(tmp_x, dim=1)
        x = self.DepthEncoding.downsample_layers[1](x)

        # Stage 2, same pattern as stage 1.
        tmp_x = [x]
        for s in range(len(self.DepthEncoding.stages[1]) - 1):
            x = self.DepthEncoding.stages[1][s](x)
        x = self.DepthEncoding.stages[1][-1](x)
        tmp_x.append(x)

        features.append(x) # stage 2 output (same tensor as x1)
        x1 = x  # reused as the skip input of the first decoder step


        # cat2 receives [downsample_layers[1] output, stage 2 output, x_down[2]].
        tmp_x.append(x_down[2])
        x = self.DepthEncoding.cat2(*tmp_x)
        # Original form: x = torch.cat(tmp_x, dim=1)
        x = self.DepthEncoding.downsample_layers[2](x)

        # Stage 3, same pattern as the earlier stages.
        tmp_x = [x]
        for s in range(len(self.DepthEncoding.stages[2]) - 1):
            x = self.DepthEncoding.stages[2][s](x)
        x = self.DepthEncoding.stages[2][-1](x)
        tmp_x.append(x)  # tmp_x is not read again after this point

        features.append(x) # stage 3 output -> x (this is what the decoder starts from)


        # Depth decoder, unrolled. `features` is only used by the debug print at the end;
        # the decoder works on x, x1 and x2 directly.
        self.outputs = {}
        # The extractor-based form is not used:
        # x = self.DepthDecoding.initialExtractor(input)
        # x1 = self.DepthDecoding.secondExtractor(input)
        # x2 = self.DepthDecoding.thirdExtractor(input)

        # Decoder step for scale 2: conv, upsample, optional skip from stage 2, conv.
        x = self.DepthDecoding.convs[("upconv", 2, 0)](x)
        x = self.DepthDecoding.upsampler(x)

        if self.DepthDecoding.use_skips:
                x = self.DepthDecoding.cat_append(x, x1)
        # Original form of the skip: x = torch.cat([upsample(x), input_features[i - 1]], 1)
        x = self.DepthDecoding.convs[("upconv", 2, 1)](x)

        if 2 in self.DepthDecoding.scales:
            f = self.DepthDecoding.convs[("dispconv", 2)](x)
            f = self.DepthDecoding.upsampler2(f)
            # Original form: f = upsample(self.convs[("dispconv", i)](x), mode='bilinear')
            self.outputs[("disp", 2)] = self.DepthDecoding.sigmoid(f)

        # Decoder step for scale 1: same sequence, skip comes from stage 1 (x2).
        x = self.DepthDecoding.convs[("upconv", 1, 0)](x)
        x = self.DepthDecoding.upsampler(x)

        if self.DepthDecoding.use_skips:
                x = self.DepthDecoding.cat_append(x, x2)
        x = self.DepthDecoding.convs[("upconv", 1, 1)](x)

        if 1 in self.DepthDecoding.scales:
            f = self.DepthDecoding.convs[("dispconv", 1)](x)
            f = self.DepthDecoding.upsampler2(f)
            self.outputs[("disp", 1)] = self.DepthDecoding.sigmoid(f)

        # Decoder step for scale 0: conv, upsample, conv. There is no skip
        # concatenation in this step, regardless of use_skips.

        x = self.DepthDecoding.convs[("upconv", 0, 0)](x)
        x = self.DepthDecoding.upsampler(x)
        x = self.DepthDecoding.convs[("upconv", 0, 1)](x)

        if 0 in self.DepthDecoding.scales:
            f = self.DepthDecoding.convs[("dispconv", 0)](x)
            f = self.DepthDecoding.upsampler2(f)
            self.outputs[("disp", 0)] = self.DepthDecoding.sigmoid(f)


        # Whole-module equivalent of the encoder section (not used): x = self.DepthEncoding(x)
        # Debug output: encoder feature shapes, then disparity shapes for scales 0, 1, 2.
        print(f"{len(features)}, {features[0].shape}, {features[1].shape}, {features[2].shape}")
        print("Outputs")
        print(f"{len(self.outputs)}, {self.outputs[('disp', 0)].shape}, {self.outputs[('disp', 1)].shape}, {self.outputs[('disp', 2)].shape}")
        # Whole-module equivalent of the decoder section (not used): x = self.DepthDecoding(x)
        return self.outputs

class ProcessPose(nn.Module):
    """Pose network: a ResnetEncoder followed by a PoseDecoder.

    The encoder is built with ``num_layers=18`` and ``num_input_images=2``; the
    decoder takes one input feature set and predicts for two frames. Every
    submodule is tagged with a ``_name`` attribute ('PoseEncoding' or
    'PoseDecoding').
    """

    def __init__(self):
        super().__init__()
        encoder = ResnetEncoder(num_layers=18, pretrained=False, num_input_images=2)
        decoder = PoseDecoder(
            num_ch_enc=encoder.num_ch_enc,
            num_input_features=1,
            num_frames_to_predict_for=2,
        )
        self.PoseEncoding = encoder
        self.PoseDecoding = decoder
        for tag, network in (('PoseEncoding', encoder), ('PoseDecoding', decoder)):
            self._annotate_submodules(network, tag)

    def _annotate_submodules(self, module, model_name: str):
        """Write ``model_name`` into ``_name`` on *module* and all of its descendants."""
        for child in module.modules():
            child._name = model_name

    def forward(self, x):
        """Encode *x*, wrap the result in a one-element list and decode it."""
        encoded = self.PoseEncoding(x)
        return self.PoseDecoding([encoded])

# Assumed hardware characteristics (example values) used by the runtime and
# energy estimators.
HARDWARE_CONFIG = dict(
    # --- compute ---
    compute_throughput=10e12,      # 10 TFLOPS
    compute_efficiency=5e-12,      # 5 pJ per FLOP
    # --- memory ---
    memory_bandwidth=900e9,        # 900 GB/s
    memory_energy=20e-12,          # 20 pJ per byte
    # --- interconnect ---
    interconnect_bandwidth=400e9,  # 400 Gbps
    interconnect_latency=100e-9,   # 100 ns base latency
    interconnect_energy=1e-12,     # 1 pJ per bit
)

def _channels_height_width(shape):
    """Return ``(channels, height, width)`` from a (C, H, W) or (N, C, H, W) shape."""
    if len(shape) == 3:
        channels, height, width = shape
    else:
        _, channels, height, width = shape
    return channels, height, width


def count_flops(module: nn.Module, in_shape: Tuple[int, ...], out_shape: Tuple[int, ...]) -> int:
    """Estimate the FLOPs of one forward call of *module*.

    Args:
        module: the layer that was executed.
        in_shape: shape of the layer's first input tensor.
        out_shape: shape of the layer's first output tensor.

    Returns:
        The estimated FLOP count as a plain ``int``. Unsupported module types
        yield 0. Any error while computing the estimate (for example a shape of
        unexpected rank) is reported with a printed warning and also yields 0,
        so a single odd layer never aborts a whole model trace.

    Notes:
        * Conv2d, Linear, BatchNorm2d and the pooling layers are counted per
          sample: the batch dimension of the shapes is ignored.
        * Element-wise layers (ReLU, GELU, ELU, ResidualAdd, LayerNorm,
          GammaMultiply) are counted over every element of ``in_shape``.
    """
    try:
        if isinstance(module, nn.Conv2d):
            in_channels, _, _ = _channels_height_width(in_shape)
            out_channels, out_h, out_w = _channels_height_width(out_shape)
            kernel_h, kernel_w = module.kernel_size
            # Per output element: k_h * k_w * C_in/groups multiplies and one fewer additions.
            flops = (2 * kernel_h * kernel_w * (in_channels // module.groups) - 1) * out_h * out_w * out_channels

        elif isinstance(module, nn.Linear):
            flops = (2 * module.in_features - 1) * module.out_features

        elif isinstance(module, nn.BatchNorm2d):
            channels, height, width = _channels_height_width(in_shape)
            flops = 2 * channels * height * width

        elif isinstance(module, (nn.ReLU, nn.ReLU6)):
            flops = np.prod(in_shape)

        elif isinstance(module, nn.GELU):
            flops = 12 * np.prod(in_shape)

        elif isinstance(module, nn.ELU):
            flops = 4 * np.prod(in_shape)

        elif isinstance(module, nn.MaxPool2d):
            channels, height, width = _channels_height_width(out_shape)
            kernel_size = np.prod(module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size**2
            # k - 1 comparisons per output element.
            flops = (kernel_size - 1) * channels * height * width

        elif isinstance(module, nn.AvgPool2d):
            channels, height, width = _channels_height_width(out_shape)
            kernel_size = np.prod(module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size**2
            flops = kernel_size * channels * height * width

        elif isinstance(module, ResidualAdd):
            # One addition per element. The previous 3-name unpack failed on the
            # 4-D (N, C, H, W) shapes the forward hook supplies, so every
            # residual add was reported as 0 FLOPs.
            flops = np.prod(in_shape)
        elif isinstance(module, depth_encoder.LayerNorm):
            flops = 8 * np.prod(in_shape)
        elif isinstance(module, depth_encoder.MatrixMultiply):
            B, heads, d_h, N = in_shape
            flops = 2 * B * heads * N * d_h * d_h
        elif isinstance(module, depth_encoder.Softmax):
            B, heads, dh, _ = in_shape
            flops = 2 * B * heads * dh * dh
        elif isinstance(module, depth_encoder.WeightedSum):
            B, heads, dh, _ = in_shape
            _, _, _, N = out_shape
            # Previously referenced the undefined name `d_h`, which raised a
            # NameError and made this branch always report 0.
            flops = 2 * N * dh * B * heads * dh
        elif isinstance(module, depth_encoder.GammaMultiply):
            flops = np.prod(in_shape)
        elif isinstance(module, depth_encoder.PosEncode):
            B, H_spatial, W_spatial = in_shape  # Assuming in_shape = [B, H_spatial, W_spatial]
            hidden_dim = module.hidden_dim
            dim = module.dim
            positions = B * H_spatial * W_spatial

            # 1. Bitwise NOT of the mask
            flops = positions

            # 2. Cumulative sums for y_embed and x_embed
            flops += 2 * positions

            # 3. Normalization and scaling of y_embed and x_embed
            flops += 2 * positions

            # 4. dim_t operations
            flops += 4 * hidden_dim

            # 5. pos_x and pos_y division
            flops += 2 * positions * hidden_dim

            # 6. sin/cos of pos_x and pos_y
            flops += 2 * positions * hidden_dim

            # 7. Token projection (1x1 conv, in_channels = 2 * hidden_dim, out_channels = dim):
            #    2 * out_channels * H * W * in_channels. Counted per sample.
            flops += 2 * dim * H_spatial * W_spatial * (hidden_dim * 2) * 1 * 1
        else:
            flops = 0

    except Exception as e:
        # Deliberately best-effort: report the problem and count the layer as 0.
        print(f"Warning: Error calculating FLOPs for {type(module)}: {str(e)}")
        flops = 0

    return int(flops)  # Convert to int to avoid numpy types


def calculate_tensor_bytes(shape: Tuple[int, ...], dtype=torch.float32) -> int:
    """Calculate the memory size in bytes of a tensor with the given shape and dtype.

    Args:
        shape: tensor dimensions; an empty shape counts as one element.
        dtype: element type. Any ``torch.dtype`` that can back an ordinary
            tensor is sized correctly (previously everything except float32,
            float16 and int8 was silently assumed to be 4 bytes). Values that
            are not a usable ``torch.dtype`` keep the old 4-byte fallback.

    Returns:
        Number of bytes as a plain ``int``.
    """
    known_sizes = {
        torch.float32: 4,
        torch.float16: 2,
        torch.int8: 1,
    }
    element_size = known_sizes.get(dtype)
    if element_size is None:
        element_size = 4  # historical default for unknown dtypes
        if isinstance(dtype, torch.dtype):
            try:
                # Ask PyTorch for the real width; an empty tensor allocates nothing.
                element_size = torch.empty(0, dtype=dtype).element_size()
            except (RuntimeError, TypeError):
                # e.g. quantized dtypes cannot back a plain tensor; keep the default.
                pass
    return int(np.prod(shape) * element_size)  # Convert to int

def estimate_compute_metrics(flops: int) -> Dict[str, float]:
    """Estimate runtime and energy for a computation of *flops* operations.

    Runtime is ``flops`` divided by ``HARDWARE_CONFIG['compute_throughput']``;
    energy is ``flops`` multiplied by ``HARDWARE_CONFIG['compute_efficiency']``.
    Both values are returned as Python floats under the keys ``'runtime'`` and
    ``'energy'``.
    """
    throughput = HARDWARE_CONFIG['compute_throughput']
    energy_per_flop = HARDWARE_CONFIG['compute_efficiency']
    return dict(
        runtime=float(flops / throughput),
        energy=float(flops * energy_per_flop),
    )

def estimate_communication_metrics(bytes_transferred: int) -> Dict[str, float]:
    """Estimate runtime and energy for moving *bytes_transferred* bytes over the interconnect.

    The payload is converted to bits. Runtime is the fixed
    ``interconnect_latency`` plus bits divided by ``interconnect_bandwidth``;
    energy is bits multiplied by ``interconnect_energy``. All three constants
    come from ``HARDWARE_CONFIG``.
    """
    link = HARDWARE_CONFIG
    n_bits = bytes_transferred * 8
    wire_time = float(n_bits / link['interconnect_bandwidth'])
    return {
        'runtime': float(link['interconnect_latency'] + wire_time),
        'energy': float(n_bits * link['interconnect_energy']),
    }

def build_onnx_from_json(json_nodes, json_edges):
    """Build an ONNX model that mirrors a node/edge description of a network.

    Args:
        json_nodes: iterable of dicts with at least ``'id'`` and ``'op_type'``.
            A node may carry ``'weight_shape'`` (one shape) or
            ``'weight_shapes'`` (a list of shapes).
        json_edges: iterable of dicts with ``'source'``, ``'destination'`` and
            ``'tensor_shape'``.

    Returns:
        An ``onnx.ModelProto`` with one graph node per entry of *json_nodes*.
        Node ``i`` is named ``node_i`` and produces ``node_i_output``; its
        inputs are the outputs of every edge source pointing at it, in edge
        order, followed by its weight initializers (if any).

    Notes:
        * Weight initializers are filled with random float32 values of the
          recorded shape; only the structure is meaningful. The old code
          tested for ``'param_shapes'`` but read ``'weight_shapes'``, so no
          initializer was ever emitted for nodes that carry ``'weight_shape'``.
        * ``op_type`` is copied verbatim, so the result is a structural
          reconstruction and is not guaranteed to pass the ONNX checker.
        * Edge endpoints that do not match any node id become graph inputs
          (unknown source) or graph outputs (unknown destination), one
          value-info per distinct endpoint.
    """
    # Group producers by consumer once instead of rescanning every edge per node.
    sources_by_destination = {}
    for edge in json_edges:
        sources_by_destination.setdefault(edge['destination'], []).append(edge['source'])

    graph_nodes = []
    initializers = []
    known_ids = set()

    for node in json_nodes:
        node_id = node['id']
        node_name = f"node_{node_id}"
        known_ids.add(node_id)

        input_names = [
            f"node_{source}_output"
            for source in sources_by_destination.get(node_id, [])
        ]

        # Accept both the plural list form and the singular form.
        weight_shapes = node.get('weight_shapes')
        if not weight_shapes and node.get('weight_shape'):
            weight_shapes = [node['weight_shape']]

        for idx, shape in enumerate(weight_shapes or []):
            param_name = f"{node_name}_param_{idx}"
            initializers.append(helper.make_tensor(
                name=param_name,
                data_type=onnx.TensorProto.FLOAT,
                dims=shape,
                vals=np.random.rand(*shape).astype(np.float32).flatten()
            ))
            # Wire the weight into the node so it is not a dangling initializer.
            input_names.append(param_name)

        graph_nodes.append(helper.make_node(
            op_type=node['op_type'],
            inputs=input_names,
            outputs=[f"{node_name}_output"],
            name=node_name
        ))

    graph_inputs = []
    graph_outputs = []
    seen_inputs = set()
    seen_outputs = set()
    for edge in json_edges:
        source, destination = edge['source'], edge['destination']
        if source not in known_ids and source not in seen_inputs:
            seen_inputs.add(source)
            graph_inputs.append(helper.make_tensor_value_info(
                f"node_{source}_output",
                onnx.TensorProto.FLOAT,
                edge['tensor_shape']
            ))
        if destination not in known_ids and destination not in seen_outputs:
            seen_outputs.add(destination)
            graph_outputs.append(helper.make_tensor_value_info(
                f"node_{destination}_output",
                onnx.TensorProto.FLOAT,
                edge['tensor_shape']
            ))

    graph = helper.make_graph(
        nodes=graph_nodes,
        name="ReconstructedGraph",
        inputs=graph_inputs,
        outputs=graph_outputs,
        initializer=initializers
    )

    return helper.make_model(graph, producer_name="json_to_onnx")

class EnhancedDAGExtractor:
    """Trace a model with forward hooks and record it as a DAG of nodes and edges.

    One node is recorded per hooked module call, in execution order. An edge is
    recorded whenever an input tensor of a call is the very tensor object that
    an earlier hooked call produced.

    Attributes:
        nodes: list of node dicts (id, name, op_type, shapes, FLOPs, estimates).
        edges: list of edge dicts (source, destination, tensor shape, estimates).
        node_count: number of nodes recorded so far; also the next node id.
        tensor_shapes: maps a produced tensor to ``(producer node id, shape)``.
            It keeps every recorded activation alive for the extractor's
            lifetime.
        model_name: name prefix for modules that carry no ``_name`` attribute.
    """

    def __init__(self, model_name: str = 'model'):
        self.nodes = []
        self.edges = []
        self.node_count = 0
        self.tensor_shapes = {}
        self.model_name = model_name

    def get_node_id(self) -> int:
        """Return the id the next node will get. Does not advance the counter."""
        return self.node_count

    def add_node(self, name: str, op_type: str, weight_shape: Tuple[int, ...],
                 flops: int, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> int:
        """Append a node record and return its id.

        ``weight_shape`` may be ``None``/empty for weightless ops. A ``None``
        input or output shape (no tensor found) is stored as an empty list.
        """
        node_id = self.get_node_id()
        self.node_count += 1
        weight_bytes = calculate_tensor_bytes(weight_shape) if weight_shape else 0
        compute_metrics = estimate_compute_metrics(flops)

        self.nodes.append({
            "id": node_id,
            "name": name,
            "op_type": op_type,
            "weight_shape": list(weight_shape) if weight_shape else [],
            "weight_bytes": weight_bytes,
            "flops": flops,
            "input_shape": list(input_shape) if input_shape is not None else [],
            "output_shape": list(output_shape) if output_shape is not None else [],
            "estimated_runtime": compute_metrics['runtime'],
            "estimated_energy": compute_metrics['energy']
        })
        return node_id

    def add_edge(self, source_id: int, dest_id: int, tensor_shape: Tuple[int, ...]):
        """Append an edge record for a tensor flowing from *source_id* to *dest_id*."""
        tensor_bytes = calculate_tensor_bytes(tensor_shape)
        comm_metrics = estimate_communication_metrics(tensor_bytes)

        self.edges.append({
            "source": source_id,
            "destination": dest_id,
            "tensor_shape": list(tensor_shape),
            "tensor_bytes": tensor_bytes,
            "estimated_latency": comm_metrics['runtime'],
            "estimated_energy": comm_metrics['energy']
        })

    def _extract_first_tensor_shape(self, data):
        """Return the shape of the first tensor in *data* (descending into index 0), else ``None``."""
        if isinstance(data, torch.Tensor):
            return tuple(data.shape)

        if isinstance(data, (tuple, list)) and len(data) > 0:
            return self._extract_first_tensor_shape(data[0])

        return None

    def _add_edges_for_nested_input(self, inp, dest_id):
        """Add an edge to *dest_id* for every known tensor inside a nested tuple/list."""
        if isinstance(inp, torch.Tensor):
            # Tensors hash by identity, so this only matches the exact object
            # a previous hook recorded. Unknown tensors are ignored.
            if inp in self.tensor_shapes:
                source_id, tensor_shape = self.tensor_shapes[inp]
                self.add_edge(source_id, dest_id, tensor_shape)
        elif isinstance(inp, (tuple, list)):
            for item in inp:
                self._add_edges_for_nested_input(item, dest_id)

    def _register_output(self, data, node_id: int):
        """Record *node_id* as the producer of every tensor found in *data*.

        Lists and tuples are walked recursively. Registering the container
        itself would either fail (lists are unhashable) or never be matched by
        a later per-tensor lookup.
        """
        if isinstance(data, torch.Tensor):
            self.tensor_shapes[data] = (node_id, tuple(data.shape))
        elif isinstance(data, (tuple, list)):
            for item in data:
                self._register_output(item, node_id)

    def hook_fn(self, module, input_tensor, output_tensor):
        """Forward hook: record a node for this call and edges from its producers."""
        node_id = self.get_node_id()
        op_type = module.__class__.__name__

        input_shape = self._extract_first_tensor_shape(input_tensor)
        output_shape = self._extract_first_tensor_shape(output_tensor)
        # `weight` can exist but be None (e.g. norm layers without affine parameters).
        weight = getattr(module, 'weight', None)
        weight_shape = tuple(weight.shape) if isinstance(weight, torch.Tensor) else None

        flops = count_flops(module, input_shape, output_shape)

        model_ident = getattr(module, '_name', self.model_name)

        self.add_node(
            name=f"{model_ident}_{op_type}_{node_id}",
            op_type=op_type,
            weight_shape=weight_shape,
            flops=flops,
            input_shape=input_shape,
            output_shape=output_shape
        )

        # Resolve the inputs BEFORE registering the output. An in-place module
        # (e.g. nn.ReLU(inplace=True)) returns its input tensor object; if the
        # output were registered first, it would overwrite the real producer
        # and this node would get an edge to itself.
        for inp in input_tensor:
            self._add_edges_for_nested_input(inp, node_id)

        self._register_output(output_tensor, node_id)

    def is_shape_tuple(self, x):
        """
        Returns True if x is a tuple/list of ints, e.g. (1, 3, 224, 224).
        Returns False otherwise.
        """
        if not isinstance(x, (tuple, list)):
            return False
        return all(isinstance(el, int) for el in x)

    def extract_dag(self, model: nn.Module, input_size: Union[Tuple[int, ...], List[torch.Tensor]],
                    onnx_path: str = "reconstructed_model_depthencoder.onnx"):
        """Run *model* once with hooks attached and return the recorded DAG.

        Args:
            model: the network to trace.
            input_size: either a shape (tuple/list of ints), for which a random
                input is generated, or the actual input object to pass to the
                model.
            onnx_path: where the reconstructed ONNX graph is written. The
                default is the file name this method has always used.

        Returns:
            ``{"nodes": ..., "edges": ..., "hardware_config": HARDWARE_CONFIG}``.

        Side effects:
            Prints which kind of input is used and writes *onnx_path*.
        """
        hooked_types = (
            nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU,
            nn.MaxPool2d, nn.AvgPool2d, nn.ReLU6, ResidualAdd, nn.GELU,
            depth_encoder.LayerNorm, depth_encoder.MatrixMultiply, depth_encoder.Softmax, depth_encoder.WeightedSum,
            depth_encoder.GammaMultiply, DropPath, depth_encoder.PosEncode,
            depth_encoder.Reshape3d, depth_encoder.Reshape4d, depth_encoder.Reshape5d,
            depth_encoder.Permute3d, depth_encoder.Permute4d, depth_encoder.Permute5d,
            depth_encoder.Transpose2d, depth_encoder.Normalize2d, depth_encoder.Cat,
            depth_encoder.Extract2dq, depth_encoder.Extract2dv, depth_encoder.Extract2dk,
            nn.ReflectionPad2d, nn.ZeroPad2d, nn.ELU, nn.Sigmoid, Upsampling, ExtractInitial,
            ExtractSecond, ExtractThird,
        )
        hooks = [
            module.register_forward_hook(self.hook_fn)
            for _, module in model.named_modules()
            if isinstance(module, hooked_types)
        ]

        try:
            if self.is_shape_tuple(input_size):
                print(f"Creating dummy input tensor of size: {input_size}")
                dummy_input = torch.randn(input_size)
            else:
                print("Using provided input tensor(s)")
                dummy_input = input_size

            # Only the structure of the forward pass is needed, so skip autograd bookkeeping.
            with torch.no_grad():
                model(dummy_input)
        finally:
            # Always detach the hooks, even if the forward pass raised, so the
            # model is not left instrumented.
            for hook in hooks:
                hook.remove()

        onnx_model = build_onnx_from_json(self.nodes, self.edges)
        onnx.save(onnx_model, onnx_path)

        return {
            "nodes": self.nodes,
            "edges": self.edges,
            "hardware_config": HARDWARE_CONFIG
        }

def analyze_model(model_name: str, model: nn.Module, input_size: Tuple[int, ...]):
    """Extract the DAG of *model*, print a short summary and dump it to JSON.

    The DAG is produced by ``EnhancedDAGExtractor.extract_dag`` and written to
    ``<model_name>_dag_enhanced.json`` (indent 2, NumPy values encoded through
    ``NumpyFloatEncoder``). Nothing is returned.
    """
    dag = EnhancedDAGExtractor(model_name=model_name).extract_dag(model, input_size)

    node_total = len(dag['nodes'])
    edge_total = len(dag['edges'])
    print(f"Extracted DAG for {model_name}")
    print(f"Num Nodes: {node_total}, Num Edges: {edge_total}")

    out_path = f'{model_name}_dag_enhanced.json'
    with open(out_path, 'w') as f:
        json.dump(dag, f, indent=2, cls=NumpyFloatEncoder)


# input_features = torch.randn(1, 6, 224, 224)

# # # Analyze ResNet18
# resnet18 = ResnetEncoder(num_layers=18, pretrained=False, num_input_images=2)

# encoder_features = resnet18(input_features)

# # # # resnet18 = models.resnet18(pretrained=False)
# # analyze_model('resnet18', resnet18, (1, 6, 224, 224))

# # # Analyze PoseDecoder
# # # num_ch_enc = np.array([64, 64, 128, 256, 512])  # Example encoder channels
# num_input_features = 1  # Add this parameter
# pose_decoder = PoseDecoder(
#     num_ch_enc=resnet18.num_ch_enc,
#     num_input_features=num_input_features, 
#     num_frames_to_predict_for=2
# )
# # # Create dummy input features list
# pose_input_features = [
#     encoder_features
# ]

# # print("Pose input tensor shape: ", len(encoder_features))

# analyze_model('pose_decoder', pose_decoder, pose_input_features)

# Analyze DepthDecoder
# depth_decoder = DepthDecoder(
#     num_ch_enc=num_ch_enc,
#     scales=range(4),
#     num_output_channels=1,
#     use_skips=True
# )
# # Create dummy input features list
# depth_input_features = [
#     torch.randn(1, 64, 56, 56),    # First encoder feature
#     torch.randn(1, 64, 28, 28),    # Second encoder feature
#     torch.randn(1, 128, 14, 14),   # Third encoder feature
#     torch.randn(1, 256, 7, 7),     # Fourth encoder feature
#     torch.randn(1, 512, 7, 7)      # Fifth encoder feature
# ]
# analyze_model('depth_decoder', depth_decoder, depth_input_features)

# Random single-image input. In the visible code it is only referenced by the
# commented-out experiments further down, not by the active calls.
input_features = torch.randn(1, 3, 224, 224)

# The depth network is instantiated here, but its analysis call is commented out below.
process_batch = ProcessBatch()

# analyze_model('process_batch', process_batch, (1, 3, 224, 224))

process_pose = ProcessPose()

# Trace the pose network with a random (1, 6, 224, 224) input. This writes
# 'process_pose_dag_enhanced.json' and the reconstructed ONNX file.
analyze_model('process_pose', process_pose, (1, 6, 224, 224))

# depthencoder = depth_encoder.LiteMono() 

# # analyze_model('depth_encoder', depthencoder, (1, 3, 224, 224))

# output_encoder = depthencoder(input_features)

# depth_decoder = DepthDecoder(depthencoder.num_ch_enc, scales = range(3))

# analyze_model('depth_decoder', depth_decoder, output_encoder)

# print(depth_decoder(output_encoder))



Creating dummy input tensor of size: (1, 6, 224, 224)
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Ran ResidualAdd
Extracted DAG for process_pose
Num Nodes: 74, Num Edges: 81


[W111 14:04:51.055985295 NNPACK.cpp:61] Could not initialize NNPACK! Reason: Unsupported hardware.
