In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import json
from typing import Dict, List, Tuple
import numpy as np
from pose_decoder import PoseDecoder
from depth_decoder import DepthDecoder
from typing import Union

class NumpyFloatEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python values.

    The standard encoder rejects most NumPy types. This subclass maps them to
    the closest built-in type so integer values stay integers in the output.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Raises:
            TypeError: (from the base class) if ``obj`` is not a supported
                NumPy type.
        """
        # Integers are kept as ints: converting to float would change the
        # emitted JSON type and lose precision above 2**53.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# Assumed hardware characteristics (example values, not measurements).
# Insertion order is kept because the table is written to the JSON output.
HARDWARE_CONFIG = dict(
    compute_throughput=10e12,      # 10 TFLOPS
    compute_efficiency=5e-12,      # 5 pJ per FLOP
    memory_bandwidth=900e9,        # 900 GB/s
    memory_energy=20e-12,          # 20 pJ per byte
    interconnect_bandwidth=400e9,  # 400 Gbps
    interconnect_latency=100e-9,   # 100ns base latency
    interconnect_energy=1e-12,     # 1 pJ per bit
)

def _split_nchw(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
    """Return ``(batch, channels, height, width)`` for a 3-D or 4-D shape."""
    if len(shape) == 3:
        # No batch axis: treat the tensor as a single sample.
        return (1, *shape)
    batch, channels, height, width = shape  # ValueError for any other rank
    return batch, channels, height, width


def _pool_window(kernel_size) -> int:
    """Return the number of elements covered by one pooling window."""
    if isinstance(kernel_size, tuple):
        return int(np.prod(kernel_size))
    return kernel_size ** 2


def count_flops(module: nn.Module, in_shape: Tuple[int, ...], out_shape: Tuple[int, ...]) -> int:
    """Estimate the FLOPs of one forward call of ``module``.

    Counts cover the whole call, i.e. they scale with the batch dimension for
    every supported layer type (the ReLU branch always did; the other branches
    previously ignored the batch axis).

    Args:
        module: Layer to analyse. Conv2d, Linear, BatchNorm2d, ReLU/ReLU6,
            MaxPool2d, AvgPool2d and the PoseDecoder/DepthDecoder containers
            are recognised; anything else counts as 0.
        in_shape: Shape of the layer input, ``(C, H, W)`` or ``(N, C, H, W)``
            for the 2-D layers.
        out_shape: Shape of the layer output, same conventions.

    Returns:
        The estimated FLOP count as a plain ``int``. Any error while computing
        it is reported with a printed warning and yields 0 (best effort).
    """
    try:
        if isinstance(module, nn.Conv2d):
            batch_size, in_channels, _, _ = _split_nchw(in_shape)
            _, out_channels, out_h, out_w = _split_nchw(out_shape)
            kernel_h, kernel_w = module.kernel_size
            # Per output element: K multiplies and K - 1 additions.
            taps = kernel_h * kernel_w * (in_channels // module.groups)
            flops = batch_size * (2 * taps - 1) * out_h * out_w * out_channels

        elif isinstance(module, nn.Linear):
            # Every leading axis is a batch axis for nn.Linear.
            rows = int(np.prod(in_shape[:-1]))
            flops = rows * (2 * module.in_features - 1) * module.out_features

        elif isinstance(module, nn.BatchNorm2d):
            batch_size, channels, height, width = _split_nchw(in_shape)
            flops = 2 * batch_size * channels * height * width

        elif isinstance(module, (nn.ReLU, nn.ReLU6)):
            flops = np.prod(in_shape)

        elif isinstance(module, nn.MaxPool2d):
            batch_size, channels, height, width = _split_nchw(out_shape)
            # window - 1 comparisons per output element.
            flops = (_pool_window(module.kernel_size) - 1) * batch_size * channels * height * width

        elif isinstance(module, nn.AvgPool2d):
            batch_size, channels, height, width = _split_nchw(out_shape)
            flops = _pool_window(module.kernel_size) * batch_size * channels * height * width

        elif isinstance(module, (PoseDecoder, DepthDecoder)):
            # For decoders, sum up the FLOPs of their submodules.
            # NOTE(review): every submodule is evaluated with the decoder's
            # outer in/out shapes, so this is only a rough figure; callers that
            # also count the leaf layers individually will count them twice.
            flops = sum(count_flops(m, in_shape, out_shape) for m in module.modules()
                        if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)))
        else:
            flops = 0

    except Exception as e:
        # Deliberate best effort: a layer that cannot be analysed must not
        # abort the whole extraction.
        print(f"Warning: Error calculating FLOPs for {type(module)}: {str(e)}")
        flops = 0

    return int(flops)  # Convert to int to avoid numpy types
def calculate_tensor_bytes(shape: Tuple[int, ...], dtype=torch.float32) -> int:
    """Calculate memory size in bytes for a tensor.

    Args:
        shape: Tensor dimensions; an empty shape counts as one element.
        dtype: Element type. The per-element size is taken from PyTorch itself
            so that every dtype (e.g. float64, int64) is sized correctly
            instead of being assumed to be 4 bytes wide.

    Returns:
        Number of elements times the element size, as a plain ``int``.
    """
    element_size = 4  # fallback for values PyTorch cannot size
    if isinstance(dtype, torch.dtype):
        try:
            # A zero-length tensor is the cheapest way to ask PyTorch for the
            # width of an arbitrary dtype.
            element_size = torch.empty(0, dtype=dtype).element_size()
        except RuntimeError:
            pass  # keep the 4-byte fallback for dtypes that cannot be allocated
    return int(np.prod(shape) * element_size)  # Convert to int

def estimate_compute_metrics(flops: int) -> Dict[str, float]:
    """Turn a FLOP count into a runtime and energy estimate.

    Runtime is ``flops`` divided by the configured compute throughput; energy
    is ``flops`` times the configured energy per FLOP. Both are returned as
    plain floats under the keys ``'runtime'`` and ``'energy'``.
    """
    throughput = HARDWARE_CONFIG['compute_throughput']
    seconds = float(flops / throughput)

    energy_per_flop = HARDWARE_CONFIG['compute_efficiency']
    joules = float(flops * energy_per_flop)

    metrics = {}
    metrics['runtime'] = seconds
    metrics['energy'] = joules
    return metrics

def estimate_communication_metrics(bytes_transferred: int) -> Dict[str, float]:
    """Turn a transfer size into a latency and energy estimate.

    The payload is converted to bits because the interconnect bandwidth and
    energy figures in ``HARDWARE_CONFIG`` are expressed per bit. The reported
    runtime is the fixed base latency plus the serialisation time.
    """
    bit_count = bytes_transferred * 8

    serialisation_time = float(bit_count / HARDWARE_CONFIG['interconnect_bandwidth'])
    latency = float(HARDWARE_CONFIG['interconnect_latency'] + serialisation_time)
    joules = float(bit_count * HARDWARE_CONFIG['interconnect_energy'])

    metrics = {}
    metrics['runtime'] = latency
    metrics['energy'] = joules
    return metrics

class EnhancedDAGExtractor:
    """Record a model's forward pass as a graph of layer nodes and tensor edges.

    Forward hooks are attached to the supported layer types; each hook call
    adds one node (with FLOP, runtime and energy estimates) and one edge per
    input tensor that was produced by a previously recorded node. Nodes and
    edges accumulate across calls to :meth:`extract_dag` on the same instance.
    """

    def __init__(self):
        self.nodes = []
        self.edges = []
        self.node_count = 0
        # Maps each produced tensor to (producer node id, tensor shape).
        # The keys keep those tensors alive for as long as the extractor lives.
        self.tensor_shapes = {}

    def get_node_id(self) -> int:
        """Return the next unused node id and advance the counter."""
        node_id = self.node_count
        self.node_count += 1
        return node_id

    def add_node(self, name: str, op_type: str, weight_shape: Tuple[int, ...],
                 flops: int, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> int:
        """Append a node record and return the id that was assigned to it.

        Args:
            name: Display name of the node.
            op_type: Class name of the layer.
            weight_shape: Shape of the layer weight, or ``None``/empty if the
                layer has none.
            flops: Estimated FLOPs of the layer call.
            input_shape: Shape of the layer input.
            output_shape: Shape of the layer output.
        """
        node_id = self.get_node_id()

        weight_bytes = calculate_tensor_bytes(weight_shape) if weight_shape else 0
        compute_metrics = estimate_compute_metrics(flops)

        self.nodes.append({
            "id": node_id,
            "name": name,
            "op_type": op_type,
            "weight_shape": list(weight_shape) if weight_shape else [],
            "weight_bytes": weight_bytes,
            "flops": flops,
            "input_shape": list(input_shape),
            "output_shape": list(output_shape),
            "estimated_runtime": compute_metrics['runtime'],
            "estimated_energy": compute_metrics['energy']
        })
        return node_id

    def add_edge(self, source_id: int, dest_id: int, tensor_shape: Tuple[int, ...]):
        """Append an edge record for a tensor flowing from one node to another."""
        tensor_bytes = calculate_tensor_bytes(tensor_shape)
        comm_metrics = estimate_communication_metrics(tensor_bytes)

        self.edges.append({
            "source": source_id,
            "destination": dest_id,
            "tensor_shape": list(tensor_shape),
            "tensor_bytes": tensor_bytes,
            "estimated_latency": comm_metrics['runtime'],
            "estimated_energy": comm_metrics['energy']
        })

    @staticmethod
    def _iter_tensors(value):
        """Yield every tensor inside a possibly nested list/tuple/dict."""
        if isinstance(value, torch.Tensor):
            yield value
        elif isinstance(value, dict):
            for item in value.values():
                yield from EnhancedDAGExtractor._iter_tensors(item)
        elif isinstance(value, (list, tuple)):
            for item in value:
                yield from EnhancedDAGExtractor._iter_tensors(item)

    def hook_fn(self, module, input_tensor, output_tensor):
        """Forward hook: record ``module`` as a node and link it to its producers.

        Args:
            module: The layer that just ran.
            input_tensor: Positional inputs of the call (a tuple; entries may
                themselves be lists of tensors).
            output_tensor: Return value of the call (a tensor or a container
                of tensors).
        """
        op_type = module.__class__.__name__

        # Inputs/outputs may be nested containers (e.g. a list of feature maps
        # in, a dict of maps out), so flatten them before reading shapes.
        inputs = list(self._iter_tensors(input_tensor))
        outputs = list(self._iter_tensors(output_tensor))
        input_shape = tuple(inputs[0].shape) if inputs else ()
        output_shape = tuple(outputs[0].shape) if outputs else ()

        # Some layers expose ``weight`` but leave it set to None.
        weight = getattr(module, 'weight', None)
        weight_shape = tuple(weight.shape) if weight is not None else None

        flops = count_flops(module, input_shape, output_shape)

        # add_node() allocates the id, and it will be the current counter
        # value. Using one id for the name, the node record and the edges keeps
        # edge endpoints pointing at nodes that actually exist.
        node_id = self.add_node(
            name=f"{op_type}_{self.node_count}",
            op_type=op_type,
            weight_shape=weight_shape,
            flops=flops,
            input_shape=input_shape,
            output_shape=output_shape
        )

        # Resolve producers BEFORE registering the outputs: an in-place layer
        # returns its input object, which would otherwise map to this node
        # itself and yield a self-edge.
        for inp in inputs:
            producer = self.tensor_shapes.get(inp)
            if producer is not None:
                source_id, tensor_shape = producer
                self.add_edge(source_id, node_id, tensor_shape)

        for out in outputs:
            self.tensor_shapes[out] = (node_id, tuple(out.shape))

    def extract_dag(self, model: nn.Module, input_size: Union[Tuple[int, ...], List[torch.Tensor]]):
        """Run ``model`` once and return the recorded graph.

        Args:
            model: Network to trace.
            input_size: Either a shape for a random dummy input, or ready-made
                input (a tensor, or a list/tuple of tensors) that is passed to
                the model as a single argument.

        Returns:
            A dict with the keys ``"nodes"``, ``"edges"`` and
            ``"hardware_config"``.
        """
        hooked_types = (nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU,
                        nn.MaxPool2d, nn.AvgPool2d, nn.ReLU6,
                        PoseDecoder, DepthDecoder)
        hooks = [module.register_forward_hook(self.hook_fn)
                 for module in model.modules()
                 if isinstance(module, hooked_types)]

        # Handle ready-made tensors as well as a shape for a random input.
        if isinstance(input_size, torch.Tensor):
            dummy_input = input_size
        elif (isinstance(input_size, (tuple, list)) and input_size
                and isinstance(input_size[0], torch.Tensor)):
            dummy_input = input_size  # Use provided tensors directly
        else:
            dummy_input = torch.randn(input_size)  # Create new tensor

        try:
            model(dummy_input)
        finally:
            # Always detach the hooks, even when the forward pass fails, so
            # the model is not left instrumented.
            for hook in hooks:
                hook.remove()

        return {
            "nodes": self.nodes,
            "edges": self.edges,
            "hardware_config": HARDWARE_CONFIG
        }

def analyze_model(model_name: str, model: nn.Module, input_size: Tuple[int, ...]):
    """Extract the DAG of ``model`` and save it as ``<model_name>_dag_enhanced.json``.

    A fresh extractor is used for every call, and the file is written to the
    current working directory with two-space indentation.
    """
    graph = EnhancedDAGExtractor().extract_dag(model, input_size)

    output_path = f'{model_name}_dag_enhanced.json'
    with open(output_path, 'w') as handle:
        json.dump(graph, handle, indent=2, cls=NumpyFloatEncoder)

# Analyze ResNet18
# Built without arguments: the network is randomly initialised by default, and
# this avoids the `pretrained` keyword that newer torchvision releases deprecate.
resnet18 = models.resnet18()
analyze_model('resnet18', resnet18, (1, 3, 224, 224))

# Shapes of the dummy encoder feature maps shared by both decoder analyses.
# NOTE(review): if DepthDecoder upsamples by 2 before each skip concatenation,
# these spatial sizes will not line up -- confirm against depth_decoder.py.
encoder_feature_shapes = [
    (1, 64, 56, 56),    # First encoder feature
    (1, 64, 28, 28),    # Second encoder feature
    (1, 128, 14, 14),   # Third encoder feature
    (1, 256, 7, 7),     # Fourth encoder feature
    (1, 512, 7, 7),     # Fifth encoder feature
]

# Analyze PoseDecoder
num_ch_enc = np.array([64, 64, 128, 256, 512])  # Example encoder channels
num_input_features = 512  # Add this parameter
pose_decoder = PoseDecoder(
    num_ch_enc=num_ch_enc,
    num_input_features=num_input_features
)
# Create dummy input features list (fresh random tensors per decoder)
pose_input_features = [torch.randn(*shape) for shape in encoder_feature_shapes]
analyze_model('pose_decoder', pose_decoder, pose_input_features)

# Analyze DepthDecoder
depth_decoder = DepthDecoder(
    num_ch_enc=num_ch_enc,
    scales=range(4),
    num_output_channels=1,
    use_skips=True
)
# Create dummy input features list (fresh random tensors per decoder)
depth_input_features = [torch.randn(*shape) for shape in encoder_feature_shapes]
analyze_model('depth_decoder', depth_decoder, depth_input_features)

  from .autonotebook import tqdm as notebook_tqdm


TypeError: randn(): argument 'size' (position 1) must be tuple of ints, but found element of type ResNet at pos 0