In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [6]:


class Encoder2D(nn.Module):
    def __init__(self, repr_dim, input_size=65):
        super().__init__()
        self.repr_dim = repr_dim
        self.output_side = int(math.sqrt(repr_dim))  # Calculate the side of the 2D embedding

        # Determine the number of convolutional blocks required
        self.num_conv_blocks = int(math.log2(input_size / self.output_side))
        if 2 ** self.num_conv_blocks * self.output_side != 2 ** int(math.log2(input_size)):
            raise ValueError("Cannot evenly reduce input_size to output_side using stride-2 convolutions.")

        layers = []
        in_channels = 2  # Input has 2 channels (agent and wall)
        out_channels = 32  # Start with 32 output channels
        for i in range(self.num_conv_blocks):
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=0,
                ) if i == 0 else
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )  # Halve the spatial dimensions
            layers.append(nn.ReLU())
            in_channels = out_channels
            out_channels = min(out_channels * 2, 256)  # Cap channels at 256

        # Final convolution to reduce to single-channel output
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))  # Single-channel embedding

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Input: (B, 2, 65, 65)
        x = self.conv(x)  # Dynamically reduce to (B, 1, output_side, output_side)
        return x  # Output shape: (B, 1, output_side, output_side)

# Instantiate the Encoder2D with input size 65x65 and repr_dim 256
encoder = Encoder2D(repr_dim=256, input_size=65)
print(encoder)

# Test with a dummy input
input_tensor = torch.randn(1, 2, 65, 65)  # Batch size of 1, 2 channels, 65x65 input
output = encoder(input_tensor)

output.shape

Encoder2D(
  (conv): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


torch.Size([1, 1, 16, 16])

In [None]:
import timm
import torch
import torch.nn as nn
import math

class FlexibleEncoder2D(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.repr_dim = config.embed_dim

        # Ensure output size is consistent
        self.output_side = int(math.sqrt(self.repr_dim))  # Should be 16 for config.embed_dim=256
        if self.output_side != 16:
            raise ValueError("Output side must be 16 for config.embed_dim=256.")

        # Dynamically select backbone based on config.encoder_backbone
        self.backbone = timm.create_model(
            config.encoder_backbone,  # Example: 'resnet18.a1_in1k'
            pretrained=False,  # No pretraining allowed
            num_classes=0,  # No classifier head
            in_chans=2,  # Input has 2 channels
            features_only=True,  # Extract spatial features
        )

        # Inspect available feature maps
        self.feature_channels = [info['num_chs'] for info in self.backbone.feature_info]
        self.feature_shapes = [info['reduction'] for info in self.backbone.feature_info]  # Spatial size reductions

        # Select the layer closest to 16x16
        self.closest_layer_index = self._find_closest_layer()

        # Final adjustment to 16x16
        self.adjust_to_target = nn.Sequential(
            nn.Conv2d(self.feature_channels[self.closest_layer_index], 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        )

    def _find_closest_layer(self):
        # Find the layer whose spatial size is closest to output_side
        input_size = 65  # Assumes input spatial dimensions (H, W) = 65x65
        reductions = [input_size // red for red in self.feature_shapes]
        closest_index = min(range(len(reductions)), key=lambda i: abs(reductions[i] - self.output_side))
        return closest_index

    def forward(self, x):
        # Pass input through the backbone and select the appropriate layer
        features = self.backbone(x)
        x = features[self.closest_layer_index]  # Closest layer to 16x16

        # Adjust to target shape
        x = self.adjust_to_target(x)
        return x


# Define the configuration class
class Config:
    embed_dim = 256  # Output embedding size
    encoder_backbone = 'resnet18.a1_in1k'  # Use ResNet-18 as the backbone

# Instantiate and test the model
config = Config()
model = FlexibleEncoder2D(config)

# Generate a random input tensor
input_tensor = torch.randn(4, 2, 65, 65)  # Example input (B, 2, 65, 65)

# Run the model and check the output size
output = model(input_tensor)
output.shape


torch.Size([4, 1, 17, 17])

In [29]:
import torch.nn as nn

def print_param_count(model: nn.Module):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

In [None]:
import timm
import torch
import torch.nn as nn
import math

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnext = timm.create_model(
            'resnet18.a1_in1k', 
            pretrained=False, 
            num_classes=0, 
            in_chans=2, 
            features_only=True
        )
        feature_info = self.convnext.feature_info
        self.convnext2 = timm.create_model(
            'resnet18.a1_in1k', 
            pretrained=False, 
            num_classes=0, 
            in_chans=2, 
            features_only=True, 
            out_indices=(1,)  # Only keep up to layer1
        )
        print_param_count(self.convnext)
        print_param_count(self.convnext2)
    
    def forward(self, x):
        # Reshape input to merge batch and trajectory dimensions
        original_shape = x.shape
        x = x.view(-1, *original_shape[-3:])  # Reshape to [batch*trajectory, channels, height, width]
        features = self.convnext(x)[1]
        
        # Reshape features back to original trajectory structure
        features1 = features.view(original_shape[0], original_shape[1], *features.shape[-3:])

        features = self.convnext2(x)[0]
        
        # Reshape features back to original trajectory structure
        features2 = features.view(original_shape[0], original_shape[1], *features.shape[-3:])
        return features1, features2


# Define the configuration class
class Config:
    embed_dim = 256  # Output embedding size
    encoder_backbone = 'resnet18.a1_in1k'  # Use ResNet-18 as the backbone

# Instantiate and test the model
config = Config()
model = Encoder()

# Generate a random input tensor
input_tensor = torch.ones(4, 4, 2, 65, 65)  # Example input (B, 2, 65, 65)

# Run the model and check the output size
features1, features2 = model(input_tensor)
print(features1.shape, features2.shape)

print(features1[0][0][0][0])
print(features2[0][0][0][0])


Total parameters: 11173376
Trainable parameters: 11173376
Total parameters: 0
Trainable parameters: 0
torch.Size([4, 4, 64, 33, 33]) torch.Size([4, 4, 2, 65, 65])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward0>)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


In [8]:
import torch
import timm

# Function to calculate the number of parameters in a PyTorch model (in millions)
def calculate_model_parameters(model):
    """
    Calculate the number of parameters in a PyTorch model in millions.

    Args:
        model (torch.nn.Module): PyTorch model to calculate the parameters for.

    Returns:
        float: Number of parameters in millions.
    """
    total_params = sum(p.numel() for p in model.parameters())
    return total_params / 1e6  # Convert to millions

# Full ResNet-18 model
full_model = timm.create_model('resnet18.a1_in1k', pretrained=False, in_chans=2)

# Truncated model with features_only=True
truncated_model = timm.create_model(
    'resnet18.a1_in1k', 
    pretrained=False, 
    in_chans=2, 
    features_only=True, 
    out_indices=(1,)  # Only keep up to layer1
)

# Calculate and print number of parameters
full_model_params = calculate_model_parameters(full_model)
truncated_model_params = calculate_model_parameters(truncated_model)

print(f"Full model parameters: {full_model_params:.2f} million")
print(f"Truncated model parameters: {truncated_model_params:.2f} million")


Full model parameters: 11.69 million
Truncated model parameters: 0.15 million


In [37]:
class Predictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.predictor = nn.Sequential(
            nn.Conv2d(input_dim, input_dim-2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(input_dim-2, output_dim, kernel_size=3, padding=1)
        )
    
    def forward(self, encoded_o_t, action):
        # Reshape inputs
        batch_size, trajectory_length = encoded_o_t.shape[:2]
        
        # Reshape action to match encoded_o_t dimensions
        action = action.view(batch_size, trajectory_length-1, 2, 1, 1)
        action = action.repeat(1, 1, 1, encoded_o_t.size(3), encoded_o_t.size(4))
        
        # Prepare inputs for prediction
        predictions = []
        for t in range(trajectory_length - 1):
            # Concatenate current encoded state with action
            x = torch.cat([encoded_o_t[:, t], action[:, t]], dim=1)
            pred = self.predictor(x)
            predictions.append(pred)
        
        return torch.stack(predictions, dim=1)
    
# Generate a random action tensor
action_tensor = torch.randn(4, 3, 2)  # Example action (B, T-1, 2)

# Pass the encoded output and action through the predictor
predictor = Predictor(input_dim=66, output_dim=64)  # Assuming input_dim=3 (encoded_o_t channels + action channels) and output_dim=1
predicted_output = predictor(output, action_tensor)

# Check the output size
predicted_output.shape

torch.Size([4, 3, 64, 17, 17])

In [40]:
convnext = timm.create_model('resnet18.a1_in1k', pretrained=False, num_classes=0, in_chans=2, features_only=True)

In [41]:
convnext

FeatureListNet(
  (conv1): Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=

In [2]:
import timm
import torch
from tqdm import tqdm

# Initialize an empty list to store model details
model_details = []

# Get all available models in timm
model_names = timm.list_models()

# Loop through each model
for model_name in tqdm(model_names):
    try:
        # Create the model without a classifier head
        model = timm.create_model(model_name, pretrained=False, num_classes=0)
        # Get the number of parameters
        num_params = sum(p.numel() for p in model.parameters())
        # Add to the list
        model_details.append((model_name, num_params))
    except Exception as e:
        # Some models may fail to initialize; skip them
        print(f"Error with model {model_name}: {e}")

# Sort models by size
sorted_models = sorted(model_details, key=lambda x: x[1])

# Display the models and their sizes
for name, size in sorted_models:
    print(f"{name}: {size} parameters")


  0%|          | 2/1170 [00:12<2:02:20,  6.28s/it]


KeyboardInterrupt: 

In [12]:
import timm
import torch
import torch.nn as nn

def print_param_count(model: nn.Module):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

def set_identities_after(module, target_name_parts):
    """
    In-place modifies `module` so that all layers after the path specified by 
    `target_name_parts` are replaced with nn.Identity().
    """
    if not target_name_parts:
        # No more parts, means we've reached the target module.
        return

    part = target_name_parts[0]

    # Check if part is a digit (index in sequential) or a named submodule
    if part.isdigit():
        idx = int(part)
        # Handle sequential-like containers
        if isinstance(module, (nn.Sequential, nn.ModuleList)):
            for i, (n, m) in enumerate(module.named_children()):
                if i == idx:
                    # Recurse into the correct child
                    set_identities_after(m, target_name_parts[1:])
                elif i > idx:
                    # Replace all subsequent modules with Identity
                    setattr(module, n, nn.Identity())
        else:
            raise ValueError(f"Expected a sequential-like module at '{part}', got {type(module)}")
    else:
        # Handle named submodules
        if not hasattr(module, part):
            raise ValueError(f"Module '{part}' not found in {module}")

        found = False
        for n, m in module.named_children():
            if n == part:
                # Recurse into the matching submodule
                set_identities_after(m, target_name_parts[1:])
                found = True
            elif found:
                # Replace all subsequent layers after the target module
                setattr(module, n, nn.Identity())

def build_minimal_model(base_model: nn.Module, selected_feature_layer: str):
    """
    Build a minimal model from base_model by replacing all layers after 
    `selected_feature_layer` with nn.Identity().
    """
    if not selected_feature_layer:
        return base_model  # No modification needed

    # Split the layer path by '.' to navigate the hierarchy
    path_parts = selected_feature_layer.split('.')
    set_identities_after(base_model, path_parts)
    return base_model


# Example usage:
feature_index = 1  # Index of the feature map to extract
model_name = 'efficientnet_b0.ra_in1k'  # Can also try 'resnet18.a1_in1k'

print(model_name)
full_model = timm.create_model(
    model_name,
    pretrained=False,
    num_classes=0,
    in_chans=2,
    features_only=True
)

feature_info = full_model.feature_info
selected_feature_layer = feature_info[feature_index]['module']

# Create the base model (full)
base_model = timm.create_model(
    model_name,
    pretrained=False,
    num_classes=0,
    in_chans=2
)

# Build the minimal model
minimal_model = build_minimal_model(base_model, selected_feature_layer)
del base_model

# Compare parameter counts
print("Full model parameters:")
print_param_count(full_model)

print("Minimal model parameters:")
print_param_count(minimal_model)

# Input tensor
input_tensor = torch.ones(4, 2, 65, 65)  # Example input (B, 2, 65, 65)

# Pass through the full model
full_model_outputs = full_model(input_tensor)
full_model_output = full_model_outputs[feature_index]

# Pass through the minimal model
minimal_model_output = minimal_model(input_tensor)

# Print output sizes
print("Full Model Output Size:", full_model_output.shape)
print("Minimal Model Output Size:", minimal_model_output.shape)

print("Feature info:")
for i in range(len(feature_info)):
    print(feature_info[i])

print(full_model)

efficientnet_b0.ra_in1k


Full model parameters:
Total parameters: 3595100
Trainable parameters: 3595100
Minimal model parameters:
Total parameters: 18802
Trainable parameters: 18802
Full Model Output Size: torch.Size([4, 24, 17, 17])
Minimal Model Output Size: torch.Size([4, 24, 17, 17])
Feature info:
{'stage': 1, 'reduction': 2, 'module': 'blocks.0', 'num_chs': 16, 'index': 0}
{'stage': 2, 'reduction': 4, 'module': 'blocks.1', 'num_chs': 24, 'index': 1}
{'stage': 3, 'reduction': 8, 'module': 'blocks.2', 'num_chs': 40, 'index': 2}
{'stage': 5, 'reduction': 16, 'module': 'blocks.4', 'num_chs': 112, 'index': 3}
{'stage': 7, 'reduction': 32, 'module': 'blocks.6', 'num_chs': 320, 'index': 4}
EfficientNetFeatures(
  (conv_stem): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): Depth

In [61]:
print(minimal_model)

Sequential(
  (conv_stem): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (aa): Identity()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2):

In [47]:
full_model

FeatureListNet(
  (conv1): Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=

In [16]:
base_model = timm.create_model(
    model_name,
    pretrained=False,
    num_classes=0,
    in_chans=2,
    features_only=True
)

for x in base_model.named_modules():
    print(x)

('', EfficientNetFeatures(
  (conv_stem): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (aa): Identity()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [18]:
from models import ActionRegularizationJEPA2Dv0

model = ActionRegularizationJEPA2Dv0()
model.load_state_dict(torch.load('../weights/best_expert_model_epoch_4_train_iter_76_normal_loss_21.21573_wall_loss_19.79489_expert_loss_85.14044.pt'))

ModuleNotFoundError: No module named 'models'