## Feature Extractor

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvEmbedding(nn.Module):
    def __init__(self, in_channels=1, embed_dim=128):
        super(ConvEmbedding, self).__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=1)
        self.flatten = nn.Flatten(2)  # Flatten the spatial dimensions (H, W)
    
    def forward(self, x):
        x = self.conv(x)  # Apply convolution
        x = self.flatten(x)  # Flatten to (batch_size, embed_dim, H*W)
        x = x.transpose(1, 2)  # Transpose to (batch_size, H*W, embed_dim)
        return x

class FeatureExtractor(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, seq_len=30*30, add_cls_token=True):
        super(FeatureExtractor, self).__init__()
        self.embedding = ConvEmbedding(embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if add_cls_token else None
        self.positional_encoding = nn.Parameter(torch.zeros(1, seq_len+1 if add_cls_token else seq_len, embed_dim))
        self.attention = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=embed_dim * 4)
        self.transformer_encoder = nn.TransformerEncoder(self.attention, num_layers=1)
    
    def forward(self, x):
        # x shape: (batch_size, in_channels, H, W)
        x = self.embedding(x)  # Convert to tokens
        print(x.shape)
        if self.cls_token is not None:
            batch_size = x.size(0)
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
            x = torch.cat((cls_tokens, x), dim=1)  # Prepend cls_token
        print(x.shape)
        x = x + self.positional_encoding[:, :x.size(1), :]  # Add positional encoding
        x = self.transformer_encoder(x)  # Apply self-attention
        print(x.shape)
        cls_feature = x[:, 0, :]  # Extract cls token feature
        token_features = x[:, 1:, :]  # Extract other token features
        
        return cls_feature, token_features

# Example usage:
input_tensor = torch.randn(10, 1, 30, 30)  # Batch size of 1, single channel, 30x30 input
model = FeatureExtractor()
cls_feature, token_features = model(input_tensor)

print(cls_feature.shape)  # Expected output: torch.Size([1, 128])
print(token_features.shape)  # Expected output: torch.Size([1, 900, 128])


torch.Size([10, 900, 128])
torch.Size([10, 901, 128])
torch.Size([10, 901, 128])
torch.Size([10, 128])
torch.Size([10, 900, 128])


## Causal Inference - Cross Attention

In [37]:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, feature_dim=128, num_heads=4):
        super(CrossAttention, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(feature_dim, num_heads)
        self.new_cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # New Cls token
    
    def forward(self, example_input_cls, example_output_cls):
        # example_input_cls, example_output_cls: Shape (batch_size, embed_dim)
        
        # Expand new_cls_token to match batch size
        batch_size = example_input_cls.size(0)
        query = self.new_cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        
        # Reshape to (sequence_length, batch_size, embed_dim) as expected by MultiheadAttention
        query = query.permute(1, 0, 2)  # (1, batch_size, embed_dim)
        key = torch.stack([example_input_cls, example_output_cls], dim=0)  # (2, batch_size, embed_dim)
        value = torch.stack([example_input_cls, example_output_cls], dim=0)  # (2, batch_size, embed_dim)
        
        # Apply cross attention
        attn_output, _ = self.multihead_attn(query, key, value)
        
        # Return the output as the new Cls token, shape: (batch_size, embed_dim)
        return attn_output.squeeze(0)

# Example usage:
example_input_cls = torch.randn(1, 128)  # Example input Cls token
example_output_cls = torch.randn(1, 128)  # Example output Cls token

cross_attention_v1 = CrossAttention()
output_cls_v1 = cross_attention_v1(example_input_cls, example_output_cls)

print(output_cls_v1.shape)  # Expected output: torch.Size([1, 128])


torch.Size([1, 128])


## Causal Inference - Self Attention

In [17]:
import torch
import torch.nn as nn

class SelfAttentionWithThreeTokens(nn.Module):
    def __init__(self, feature_dim=128, num_heads=4):
        super(SelfAttentionWithThreeTokens, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(feature_dim, num_heads)
        self.new_cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # New Cls token
    
    def forward(self, example_input_cls, example_output_cls):
        # example_input_cls, example_output_cls: Shape (batch_size, embed_dim)
        
        # Expand new_cls_token to match batch size
        batch_size = example_input_cls.size(0)
        new_cls_token_expanded = self.new_cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        
        # Combine all Cls tokens: shape (3, batch_size, embed_dim)
        combined_cls_tokens = torch.cat([new_cls_token_expanded, 
                                         example_input_cls.unsqueeze(1), 
                                         example_output_cls.unsqueeze(1)], dim=1)
        print("c")
        print(combined_cls_tokens.shape)
        # Apply self-attention
        combined_cls_tokens = combined_cls_tokens.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        attn_output, _ = self.multihead_attn(combined_cls_tokens, combined_cls_tokens, combined_cls_tokens)
        
        # Return the output as the new Cls token, shape: (batch_size, embed_dim)
        return attn_output[0]  # The first token corresponds to the new_cls_token

# Example usage:
example_input_cls = torch.randn(10, 128)  # Example input Cls token
example_output_cls = torch.randn(10, 128)  # Example output Cls token

self_attention_v2 = SelfAttentionWithThreeTokens()
output_cls_v2 = self_attention_v2(example_input_cls, example_output_cls)

print(output_cls_v2.shape)  # Expected output: torch.Size([1, 128])


c
torch.Size([10, 3, 128])
torch.Size([10, 128])


In [22]:
import torch
import torch.nn as nn

class CombineModule(nn.Module):
    def __init__(self, feature_dim):
        super(CombineModule, self).__init__()
        
        # Self-Attention Layer for combining causals
        self.self_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        
        # Fully connected layer to produce the final causal representation
        self.fc = nn.Linear(feature_dim, feature_dim)
        
    def forward(self, causals):
        # causals: shape (num_causals, batch_size, feature_dim)
        causals = causals.unsqueeze(1)  # Add a sequence dimension
        attn_output, _ = self.self_attention(causals, causals, causals)
        print("attn_output",attn_output.shape)
        # Mean pooling over the sequence dimension (num_causals)
        combined_causal = attn_output.squeeze(1)  # Shape: (batch_size, feature_dim)
        
        # Pass through a fully connected layer to get the final causal representation
        final_causal = self.fc(combined_causal)  # Shape: (batch_size, feature_dim)
        
        return final_causal

# Example usage:
feature_dim = 128
num_causals = 10

# Assume causals is the output from the Causal Inference module
causals = torch.randn(num_causals, feature_dim)

combine_module = CombineModule(feature_dim=feature_dim)
final_causal = combine_module(causals)

print(final_causal.shape)  # Expected output: torch.Size([batch_size, feature_dim])


attn_output torch.Size([10, 1, 128])
torch.Size([10, 128])


## Head

In [2]:
import torch
import torch.nn as nn

class Head(nn.Module):
    def __init__(self, embed_dim=128, output_dim=1, seq_len=30*30, num_classes=11):
        super(Head, self).__init__()
        self.num_classes = num_classes
        
        # FC layers to transform features
        self.fc1 = nn.Linear(embed_dim, embed_dim)  # Project final_causal to match cls_token
        self.fc2 = nn.Linear(embed_dim * 2, embed_dim)  # Combine Cls and final_causal
        self.fc3 = nn.Linear(embed_dim + (seq_len * embed_dim), seq_len)  # Combine with flattened token features
        
        # Convolution layer to map to class logits
        self.conv = nn.Conv2d(1, num_classes, kernel_size=1)  # 1x1 Convolution for class logits
        
        # Output reshape (no upsample since the size is already 30x30)
        self.output_reshape = nn.Sequential(
            nn.Unflatten(1, (1, int(seq_len ** 0.5), int(seq_len ** 0.5)))  # (batch_size, 1, 30, 30)
        )
        
        # LogSoftmax for multi-class classification
        self.log_softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, cls_token, token_features, final_causal):
        # Project final_causal to the same dimension as cls_token
        final_causal_proj = self.fc1(final_causal)
        
        # Combine cls_token and final_causal_proj
        cls_combined = torch.cat((cls_token, final_causal_proj), dim=-1)
        cls_combined = self.fc2(cls_combined)
        
        # Flatten token features
        token_features_flat = token_features.view(token_features.size(0), -1)
        
        # Combine cls_combined with token_features_flat
        combined_features = torch.cat((cls_combined, token_features_flat), dim=-1)
        x = self.fc3(combined_features)  # (batch_size, seq_len)
        
        # Reshape to (batch_size, 1, 30, 30)
        x = self.output_reshape(x)
        
        # Apply convolution to get class logits
        logits = self.conv(x)  # (batch_size, num_classes, 30, 30)
        
        # Apply log_softmax to get class probabilities
        output = self.log_softmax(logits)  # (batch_size, num_classes, 30, 30)
        
        return output

# Example usage:
cls_token = torch.randn(1, 128)  # Example Cls token
token_features = torch.randn(1, 30*30, 128)  # Example token features
final_causal = torch.randn(1, 128)  # Example Final Causal

head = Head()
output = head(cls_token, token_features, final_causal)

# Get the class with the highest probability for each pixel
predicted_classes = torch.argmax(output, dim=1, keepdim=True)  # (batch_size, 1, 30, 30)

print(predicted_classes.shape)  # Expected output: torch.Size([1, 1, 30, 30])
# print(predicted_classes)  # Prints the predicted class for each pixel


torch.Size([1, 1, 30, 30])
tensor([[[[ 5, 10,  0, 10,  5,  5, 10,  5, 10,  5, 10,  5,  5, 10, 10, 10, 10,
            5,  5,  5, 10, 10, 10, 10, 10, 10, 10,  5,  5,  5],
          [10, 10, 10,  5, 10, 10,  5, 10,  5,  5,  5,  5,  5,  5,  5, 10,  5,
           10,  5,  5,  5, 10, 10,  5,  5,  5,  5, 10, 10,  5],
          [10,  5,  5, 10,  5, 10,  5,  5, 10, 10, 10, 10, 10, 10, 10,  5,  5,
            5, 10,  5,  5,  5, 10,  0,  5, 10, 10,  5, 10,  5],
          [10,  5,  0, 10,  5,  5, 10, 10,  5, 10,  5, 10,  5, 10, 10, 10,  5,
           10, 10,  5,  5, 10, 10, 10,  5, 10,  5,  5,  5,  5],
          [ 5, 10, 10,  5, 10, 10, 10,  5, 10,  5, 10,  5, 10,  5, 10,  5,  5,
           10, 10, 10, 10, 10, 10,  5,  5, 10,  5,  5, 10, 10],
          [10,  5,  5, 10, 10,  5,  5,  5,  5,  5, 10,  5,  5, 10,  5, 10,  5,
           10,  5,  5, 10,  5, 10,  0,  5,  5,  5, 10,  5,  5],
          [10,  5, 10, 10, 10,  5,  5,  5,  5, 10, 10,  5, 10, 10,  5,  5,  5,
            0, 10,  5,  5, 10,  5,  

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
#from bw_net import FeatureExtractor, CausalInference, CombineModule, Head

class BWNet(nn.Module):
    def __init__(self, feature_dim=128, num_examples=5):
        super(BWNet, self).__init__()
        self.feature_extractor = FeatureExtractor(embed_dim=feature_dim)
        self.causal_inference = SelfAttentionWithThreeTokens(feature_dim=feature_dim)
        self.combine_module = CombineModule(feature_dim=feature_dim)
        self.head = Head()

    def forward(self, input_tensor, example_input, example_output):
        # Pad inputs and example tensors to 30x30
        input_padded = F.pad(input_tensor, (0, 30 - input_tensor.size(3), 0, 30 - input_tensor.size(2)), mode='constant', value=0)
        example_input_padded = F.pad(example_input, (0, 30 - example_input.size(3), 0, 30 - example_input.size(2)), mode='constant', value=0)
        example_output_padded = F.pad(example_output, (0, 30 - example_output.size(3), 0, 30 - example_output.size(2)), mode='constant', value=0)
        
        # Feature extraction
        cls_feature, input_features = self.feature_extractor(input_padded)
        example_cls_feature, example_features = self.feature_extractor(example_input_padded)
        
        # Causal inference
        causals = self.causal_inference(example_features, example_output_padded)
        
        # Combine module
        final_causal = self.combine_module(causals)
        
        # Head
        output = self.head(cls_feature, input_features, final_causal)
        
        # Remove padding values for final output
        output = self.remove_padding(output)
        
        return output

    def remove_padding(self, output):
        # Assumes padding value is 0; modify if necessary
        mask = output != 0
        output_cleaned = output[mask].view(output.size(0), 1, -1)
        output_cleaned = F.interpolate(output_cleaned, size=(30, 30), mode='bilinear', align_corners=True)
        return output_cleaned

# Example usage:
input_tensor = torch.randn(1, 1, 20, 20)  # Example input tensor of size (1, 1, 20, 20)
example_input = torch.randn(4, 1, 20, 20)  # Example input tensor of size (1, 1, 20, 20)
example_output = torch.randn(4, 1, 20, 20)  # Example output tensor of size (1, 1, 20, 20)

model = BWNet(feature_dim=128)  # Set appropriate values for feature_dim and num_examples
output = model(input_tensor, example_input, example_output)

print(output.shape)  # Should output torch.Size([1, 1, 30, 30]) after processing




torch.Size([1, 900, 128])
torch.Size([1, 901, 128])
torch.Size([1, 901, 128])
torch.Size([4, 900, 128])
torch.Size([4, 901, 128])
torch.Size([4, 901, 128])


RuntimeError: Tensors must have same number of dimensions: got 3 and 4

In [None]:
# Example usage:
input_tensor = torch.randn(1, 1, 20, 20)  # Example input tensor of size (1, 1, 20, 20)
example_input = torch.randn(1, 1, 20, 20)  # Example input tensor of size (1, 1, 20, 20)
example_output = torch.randn(1, 1, 20, 20)  # Example output tensor of size (1, 1, 20, 20)

input_padded = F.pad(input_tensor, (0, 30 - input_tensor.size(3), 0, 30 - input_tensor.size(2)), mode='constant', value=0)
example_input_padded = F.pad(example_input, (0, 30 - example_input.size(3), 0, 30 - example_input.size(2)), mode='constant', value=0)
example_output_padded = F.pad(example_output, (0, 30 - example_output.size(3), 0, 30 - example_output.size(2)), mode='constant', value=0)