In [4]:
import torch
# Print the installed PyTorch version.
installed_version = torch.__version__
print(installed_version)

2.4.1


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleSpecificTransformerEncoder(nn.Module):
    """Post-norm transformer encoder layer applied to the tokens of one scale.

    Self-attention followed by a ReLU feed-forward network; each sub-layer is
    wrapped in dropout, a residual add, and then LayerNorm.

    Args:
        d_model: Embedding size of each token.
        nhead: Number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``d_model``).
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside attention and after each
            sub-layer.
        verbose: If True (the default), print the output shape on every
            forward pass.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src`` and return a tensor of the same shape.

        ``self_attn`` is built with the default ``batch_first=False``, so
        ``src`` is sequence-first: (seq_len, batch, d_model).
        """
        # The attention weights are discarded, so don't request them:
        # need_weights=False skips building the (batch, seq_len, seq_len)
        # averaged weight tensor and allows the fused attention kernel.
        attn_out, _ = self.self_attn(src, src, src, need_weights=False)
        src = self.norm1(src + self.dropout1(attn_out))

        # Position-wise feed-forward network.
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))

        if self.verbose:
            print(f"ScaleSpecificTransformerEncoder output shape: {src.shape}")
        return src

class CrossScaleAttention(nn.Module):
    """Multi-head attention in which one token sequence attends to another.

    Args:
        d_model: Embedding size shared by query, key and value tokens.
        nhead: Number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``d_model``).
        dropout: Dropout probability on the attention weights.
        verbose: If True (the default), print the output shape on every
            forward pass.
    """

    def __init__(self, d_model, nhead, dropout=0.1, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Inputs are sequence-first (the default ``batch_first=False``):
        query (L_q, batch, d_model), key/value (L_kv, batch, d_model).
        Returns a tensor shaped like ``query``.
        """
        # Only the attended values are used, so skip computing the weights.
        output, _ = self.multihead_attn(query, key, value, need_weights=False)
        if self.verbose:
            print(f"CrossScaleAttention output shape: {output.shape}")
        return output

class HierarchicalPooling(nn.Module):
    """Average a feature map over its two trailing dims, then project it.

    For a (batch, in_features, H, W) input, ``forward`` returns
    (batch, out_features).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Reduce the trailing two dims to a single 1x1 cell.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        pooled = self.pool(x)
        # Drop the 1x1 spatial dims, keeping the batch axis.
        flat = pooled.flatten(start_dim=1)
        projected = self.fc(flat)
        print(f"HierarchicalPooling output shape: {projected.shape}")
        return projected

class HierarchicalTransformer(nn.Module):
    """Classifier that fuses a list of feature maps across scales.

    Each feature map is flattened into a token sequence and encoded by its own
    ``ScaleSpecificTransformerEncoder``. Then, for i from the last scale down
    to 1, the tokens of scale ``i-1`` attend to the tokens of scale ``i``. The
    attention output is averaged over its tokens, projected, and added to
    every token of scale ``i-1``. Finally, scale 0 is mean-pooled over its
    tokens and classified. (In the example inputs, index 0 is the
    highest-resolution map.)

    Args:
        num_scales: Number of input feature maps (must be >= 1).
        d_model: Channel count every feature map must have; also the token size.
        nhead: Attention heads per attention layer.
        num_classes: Size of the output logits.
        verbose: If True (the default), print intermediate shapes in
            ``forward``. Sub-modules keep their own debug printing.

    Raises:
        ValueError: if ``num_scales`` < 1.
    """

    def __init__(self, num_scales, d_model, nhead, num_classes, verbose=True):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.num_scales = num_scales
        self.d_model = d_model
        self.verbose = verbose

        # One encoder per scale; parameters are not shared across scales.
        self.transformers = nn.ModuleList([
            ScaleSpecificTransformerEncoder(d_model, nhead)
            for _ in range(num_scales)
        ])

        # cross_attentions[i-1] lets scale i-1 (query) attend to scale i.
        self.cross_attentions = nn.ModuleList([
            CrossScaleAttention(d_model, nhead)
            for _ in range(num_scales - 1)
        ])

        # hierarchical_pooling[i-1] summarises cross_attentions[i-1]'s output.
        self.hierarchical_pooling = nn.ModuleList([
            HierarchicalPooling(d_model, d_model)
            for _ in range(num_scales - 1)
        ])

        # Final classification layer
        self.classifier = nn.Linear(d_model, num_classes)

    def _debug(self, message):
        """Print ``message`` when this model was built with ``verbose=True``."""
        if self.verbose:
            print(message)

    def _check_features(self, features):
        """Raise ValueError unless there is one (batch, d_model, ...) map per scale."""
        if len(features) != self.num_scales:
            raise ValueError(
                f"expected {self.num_scales} feature maps, got {len(features)}"
            )
        for idx, feature in enumerate(features):
            if feature.dim() < 3 or feature.shape[1] != self.d_model:
                raise ValueError(
                    f"feature map {idx} must have shape (batch, {self.d_model}, ...), "
                    f"got {tuple(feature.shape)}"
                )

    def forward(self, features):
        """Classify ``num_scales`` feature maps.

        Args:
            features: Sequence of tensors shaped (batch, d_model, *spatial).
                Spatial sizes may differ per scale, but the batch size must match.

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: if the number of maps or their channel count is wrong.
        """
        features = list(features)
        self._check_features(features)
        self._debug(f"Input features shapes: {[f.shape for f in features]}")

        # (B, C, *spatial) -> (prod(spatial), B, C): sequence-first tokens, as
        # nn.MultiheadAttention (batch_first=False) expects.
        transformed_features = [
            transformer(feature.flatten(2).permute(2, 0, 1))
            for transformer, feature in zip(self.transformers, features)
        ]
        self._debug(f"Transformed features shapes: {[tf.shape for tf in transformed_features]}")

        # Apply cross-scale attention and hierarchical pooling
        for i in range(self.num_scales - 1, 0, -1):
            cross_attn = self.cross_attentions[i-1](
                transformed_features[i-1],
                transformed_features[i],
                transformed_features[i]
            )
            self._debug(f"Cross attention output shape at scale {i}: {cross_attn.shape}")

            # (L, B, C) -> (B, C, L, 1) so the 2-D adaptive pool averages over L.
            pooled = self.hierarchical_pooling[i-1](cross_attn.permute(1, 2, 0).unsqueeze(-1))
            self._debug(f"Pooled shape at scale {i}: {pooled.shape}")

            # pooled is 2-D (B, C), so it cannot be permuted with three axes.
            # A leading length-1 sequence axis lets it broadcast onto every
            # (L, B, C) token of scale i-1.
            transformed_features[i-1] = transformed_features[i-1] + pooled.unsqueeze(0)
            self._debug(f"Updated transformed feature shape at scale {i-1}: {transformed_features[i-1].shape}")

        # Mean over scale 0's tokens -> (B, C), then classify.
        output = self.classifier(transformed_features[0].mean(dim=0))
        self._debug(f"Final output shape: {output.shape}")

        return output

# Example usage and debug
if __name__ == "__main__":
    # Example parameters
    num_scales = 3
    d_model = 256
    nhead = 8
    num_classes = 10
    batch_size = 4

    # Create model
    model = HierarchicalTransformer(num_scales, d_model, nhead, num_classes)
    # This is a shape-debugging run, not training. eval() disables dropout.
    # no_grad() avoids keeping activations for a backward pass that never
    # happens. Self-attention here spans up to 64*64 = 4096 tokens, and its
    # cost grows with the square of that count.
    model.eval()

    # Create dummy input (every scale already has d_model channels)
    features = [
        torch.randn(batch_size, d_model, 64, 64),  # Scale 1
        torch.randn(batch_size, d_model, 32, 32),  # Scale 2
        torch.randn(batch_size, d_model, 16, 16)   # Scale 3
    ]

    # Forward pass
    with torch.no_grad():
        output = model(features)
    print(f"Final classification output shape: {output.shape}")

: 

In [6]:
# Example usage
if __name__ == "__main__":
    # Assuming 3 scales, with feature dimensions of 256, 512, and 1024
    num_scales = 3
    d_model = 256  # We'll project all features to this dimension
    nhead = 8
    num_classes = 10  # For CIFAR-10
    in_channels = [256, 512, 1024]
    # The 56x56 map yields 3136 tokens, and self-attention cost grows with the
    # square of that. At batch 32, fully materialised float32 attention
    # scores alone would need ~10 GB, so keep the demo batch small.
    batch_size = 2

    # Create dummy input features
    features = [
        torch.randn(batch_size, 256, 56, 56),
        torch.randn(batch_size, 512, 28, 28),
        torch.randn(batch_size, 1024, 14, 14)
    ]

    # The model requires every scale to carry d_model channels. Without this
    # projection, the 512- and 1024-channel maps fail inside
    # nn.MultiheadAttention. A 1x1 convolution maps each scale to d_model.
    projections = nn.ModuleList([
        nn.Conv2d(channels, d_model, kernel_size=1) for channels in in_channels
    ])

    # Initialize the model
    model = HierarchicalTransformer(num_scales, d_model, nhead, num_classes)
    model.eval()

    # Forward pass (inference only, so skip autograd bookkeeping)
    with torch.no_grad():
        projected = [proj(feat) for proj, feat in zip(projections, features)]
        output = model(projected)
    print(f"Output shape: {output.shape}")

print("Hierarchical Transformer implementation complete.")