## Feature Extractor

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvEmbedding(nn.Module):
    def __init__(self, in_channels=1, embed_dim=128):
        super(ConvEmbedding, self).__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=1)
        self.flatten = nn.Flatten(2)  # Flatten the spatial dimensions (H, W)
    
    def forward(self, x):
        x = self.conv(x)  # Apply convolution
        x = self.flatten(x)  # Flatten to (batch_size, embed_dim, H*W)
        x = x.transpose(1, 2)  # Transpose to (batch_size, H*W, embed_dim)
        return x

class FeatureExtractor(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, seq_len=30*30, add_cls_token=True):
        super(FeatureExtractor, self).__init__()
        self.embedding = ConvEmbedding(embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if add_cls_token else None
        self.positional_encoding = nn.Parameter(torch.zeros(1, seq_len+1 if add_cls_token else seq_len, embed_dim))
        self.attention = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=embed_dim * 4)
        self.transformer_encoder = nn.TransformerEncoder(self.attention, num_layers=1)
    
    def forward(self, x):
        # x shape: (batch_size, in_channels, H, W)
        x = x.squeeze(1)  # Remove in_channels dimension
        x = self.embedding(x)  # Convert to tokens
        print(x.shape)
        if self.cls_token is not None:
            batch_size = x.size(0)
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
            x = torch.cat((cls_tokens, x), dim=1)  # Prepend cls_token
        print(x.shape)
        x = x + self.positional_encoding[:, :x.size(1), :]  # Add positional encoding
        x = self.transformer_encoder(x)  # Apply self-attention
        print(x.shape)
        cls_feature = x[:, 0, :]  # Extract cls token feature
        token_features = x[:, 1:, :]  # Extract other token features
        
        return cls_feature, token_features

# Example usage:
input_tensor = torch.randn(10, 1, 1, 30, 30)  # Batch size of 1, single channel, 30x30 input
model = FeatureExtractor()
cls_feature, token_features = model(input_tensor)

print(cls_feature.shape)  # Expected output: torch.Size([1, 128])
print(token_features.shape)  # Expected output: torch.Size([1, 900, 128])


torch.Size([10, 900, 128])
torch.Size([10, 901, 128])
torch.Size([10, 901, 128])
torch.Size([10, 128])
torch.Size([10, 900, 128])


## Causal Inference - Cross Attention

In [9]:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, feature_dim=128, num_heads=4):
        super(CrossAttention, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(feature_dim, num_heads)
        self.new_cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # New Cls token
    
    def forward(self, example_input_cls, example_output_cls):
        # example_input_cls, example_output_cls: Shape (batch_size, embed_dim)
        
        # Expand new_cls_token to match batch size
        batch_size = example_input_cls.size(0)
        query = self.new_cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        
        # Reshape to (sequence_length, batch_size, embed_dim) as expected by MultiheadAttention
        query = query.permute(1, 0, 2)  # (1, batch_size, embed_dim)
        key = torch.stack([example_input_cls, example_output_cls], dim=0)  # (2, batch_size, embed_dim)
        value = torch.stack([example_input_cls, example_output_cls], dim=0)  # (2, batch_size, embed_dim)
        
        # Apply cross attention
        attn_output, _ = self.multihead_attn(query, key, value)
        
        # Return the output as the new Cls token, shape: (batch_size, embed_dim)
        return attn_output.squeeze(0)

# Example usage:
example_input_cls = torch.randn(1, 128)  # Example input Cls token
example_output_cls = torch.randn(1, 128)  # Example output Cls token

cross_attention_v1 = CrossAttention()
output_cls_v1 = cross_attention_v1(example_input_cls, example_output_cls)

print(output_cls_v1.shape)  # Expected output: torch.Size([1, 128])


torch.Size([1, 128])


## Causal Inference - Self Attention

In [10]:
import torch
import torch.nn as nn

class SelfAttentionWithThreeTokens(nn.Module):
    def __init__(self, feature_dim=128, num_heads=4):
        super(SelfAttentionWithThreeTokens, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(feature_dim, num_heads)
        self.new_cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # New Cls token
    
    def forward(self, example_input_cls, example_output_cls):
        # example_input_cls, example_output_cls: Shape (batch_size, embed_dim)
        
        # Expand new_cls_token to match batch size
        batch_size = example_input_cls.size(0)
        new_cls_token_expanded = self.new_cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, embed_dim)
        
        # Combine all Cls tokens: shape (3, batch_size, embed_dim)
        combined_cls_tokens = torch.cat([new_cls_token_expanded, 
                                         example_input_cls.unsqueeze(1), 
                                         example_output_cls.unsqueeze(1)], dim=1)
        print("c")
        print(combined_cls_tokens.shape)
        # Apply self-attention
        combined_cls_tokens = combined_cls_tokens.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        attn_output, _ = self.multihead_attn(combined_cls_tokens, combined_cls_tokens, combined_cls_tokens)
        
        # Return the output as the new Cls token, shape: (batch_size, embed_dim)
        return attn_output[0]  # The first token corresponds to the new_cls_token

# Example usage:
example_input_cls = torch.randn(10, 128)  # Example input Cls token
example_output_cls = torch.randn(10, 128)  # Example output Cls token

self_attention_v2 = SelfAttentionWithThreeTokens()
output_cls_v2 = self_attention_v2(example_input_cls, example_output_cls)

print(output_cls_v2.shape)  # Expected output: torch.Size([1, 128])


c
torch.Size([10, 3, 128])
torch.Size([10, 128])


## Combine Module

In [11]:
import torch
import torch.nn as nn

class CombineModule(nn.Module):
    def __init__(self, feature_dim):
        super(CombineModule, self).__init__()
        
        # Self-Attention Layer for combining causals
        self.self_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=4)
        self.new_cls_token = nn.Parameter(torch.zeros(1, 1, feature_dim))  # New Cls token
        # Fully connected layer to produce the final causal representation
        self.fc = nn.Linear(feature_dim, feature_dim)
        
    def forward(self, causals):
        # causals: shape (num_causals, batch_size, feature_dim)
        causals = causals.unsqueeze(1)  # Add a sequence dimension
        print("causals",causals.shape)
        new_cls_token_expanded = self.new_cls_token.expand(1, -1, -1)
        print("new_cls_token_expanded",new_cls_token_expanded.shape)
        
        # cls_causals = torch.cat([new_cls_token_expanded, 
        #                                  causals], dim=0)

        attn_output, _ = self.self_attention(new_cls_token_expanded, causals, causals)
        print("attn_output",attn_output.shape)
        # Mean pooling over the sequence dimension (num_causals)
        combined_causal = attn_output.squeeze(1)  # Shape: (batch_size, feature_dim)
        
        # Pass through a fully connected layer to get the final causal representation
        final_causal = self.fc(combined_causal)  # Shape: (batch_size, feature_dim)
        
        return final_causal

# Example usage:
feature_dim = 128
num_causals = 10

# Assume causals is the output from the Causal Inference module
causals = torch.randn(num_causals, feature_dim)

combine_module = CombineModule(feature_dim=feature_dim)
final_causal = combine_module(causals)

print(final_causal.shape)  # Expected output: torch.Size([batch_size, feature_dim])


causals torch.Size([10, 1, 128])
new_cls_token_expanded torch.Size([1, 1, 128])
attn_output torch.Size([1, 1, 128])
torch.Size([1, 128])


## Head

In [26]:
import torch
import torch.nn as nn

class Head(nn.Module):
    def __init__(self, embed_dim=128, output_dim=1, seq_len=30*30, num_classes=11):
        super(Head, self).__init__()
        self.num_classes = num_classes
        
        # FC layers to transform features
        self.fc1 = nn.Linear(embed_dim, embed_dim)  # Project final_causal to match cls_token
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.fc3 = nn.Linear(embed_dim*2, seq_len)  # Combine Cls and final_causal
        self.fc4 = nn.Linear(embed_dim + (seq_len * embed_dim), seq_len)  # Combine with flattened token features
        
        # Convolution layer to map to class logits
        self.conv = nn.Conv2d(1, num_classes, kernel_size=1)  # 1x1 Convolution for class logits
        
        # Output reshape (no upsample since the size is already 30x30)
        self.output_reshape = nn.Sequential(
            nn.Unflatten(1, (1, int(seq_len ** 0.5), int(seq_len ** 0.5)))  # (batch_size, 1, 30, 30)
        )
        
        # LogSoftmax for multi-class classification
        self.log_softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, cls_token, token_features, final_causal):
        # Project final_causal to the same dimension as cls_token
        final_causal_proj = self.fc1(final_causal)
        print("final_causal_proj",final_causal_proj.shape)
        # Combine cls_token and final_causal_proj
        #cls_combined = torch.cat((cls_token, final_causal_proj), dim=-1)
        
        cls_combined = self.fc2(cls_token)
        print("cls_combined",cls_combined.shape)
        # Flatten token features
        #token_features_flat = token_features.view(token_features.size(0), -1)
        
        # Combine cls_combined with token_features_flat
        #combined_features = torch.cat((cls_combined, token_features_flat), dim=-1)
        combined_features = torch.cat((final_causal_proj, cls_combined), dim=-1)
        print("combined_features",combined_features.shape)
        x = self.fc3(combined_features)  # (batch_size, seq_len)
        print("x",x.shape)
        # Reshape to (batch_size, 1, 30, 30)
        x = self.output_reshape(x)
        print("x",x.shape)
        # Apply convolution to get class logits
        logits = self.conv(x)  # (batch_size, num_classes, 30, 30)
        print("logits",logits.shape)
        # Apply log_softmax to get class probabilities
        output = self.log_softmax(logits)  # (batch_size, num_classes, 30, 30)
        
        return output

# Example usage:
cls_token = torch.randn(1, 128)  # Example Cls token
token_features = torch.randn(1, 30*30, 128)  # Example token features
final_causal = torch.randn(1, 128)  # Example Final Causal

head = Head()
output = head(cls_token, token_features, final_causal)

# Get the class with the highest probability for each pixel
predicted_classes = torch.argmax(output, dim=1, keepdim=True)  # (batch_size, 1, 30, 30)

print(predicted_classes.shape)  # Expected output: torch.Size([1, 1, 30, 30])
# print(predicted_classes)  # Prints the predicted class for each pixel


final_causal_proj torch.Size([1, 128])
cls_combined torch.Size([1, 128])
combined_features torch.Size([1, 256])
x torch.Size([1, 900])
x torch.Size([1, 1, 30, 30])
logits torch.Size([1, 11, 30, 30])
torch.Size([1, 1, 30, 30])


In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BWNet(nn.Module):
    def __init__(self, feature_dim=128, num_examples=5):
        super(BWNet, self).__init__()
        self.feature_extractor = FeatureExtractor(embed_dim=feature_dim)
        self.causal_inference = SelfAttentionWithThreeTokens(feature_dim=feature_dim)
        self.combine_module = CombineModule(feature_dim=feature_dim)
        self.head = Head()

    def forward(self, input_tensor, example_input, example_output):
        # 입력 및 예제 텐서를 30x30으로 패딩
        
        # Feature extraction
        cls_feature, _ = self.feature_extractor(input_tensor)
        ex_input_cls_feature, _ = self.feature_extractor(example_input)
        ex_output_cls_feature, _ = self.feature_extractor(example_output)
        
        # Causal inference
        causals = self.causal_inference(ex_input_cls_feature, ex_output_cls_feature)
        
        # Combine module
        final_causal = self.combine_module(causals)
        
        # Head
        output = self.head(cls_feature, _, final_causal)
        
        # Padding 제거
        # output = self.remove_padding(output)
        
        return output

    # def remove_padding(self, output):
    #     # 패딩 값을 제거하고 원본 크기로 보간
    #     mask = output != 0
    #     output_cleaned = output[mask].view(output.size(0), 1, -1)
    #     output_cleaned = F.interpolate(output_cleaned, size=(30, 30), mode='bilinear', align_corners=True)
    #     return output_cleaned

# 예시 사용법:
input_tensor = torch.randn(1, 1, 30, 30)  # 입력 텐서 예시
example_input = torch.randn(10, 1, 30, 30)  # 입력 텐서 예시
example_output = torch.randn(10, 1, 30, 30)  # 출력 텐서 예시

model = BWNet(feature_dim=128)
output = model(input_tensor, example_input, example_output)

print(output.shape)  # 최종 출력 크기를 확인

torch.Size([1, 900, 128])
torch.Size([1, 901, 128])
torch.Size([1, 901, 128])
torch.Size([10, 900, 128])
torch.Size([10, 901, 128])
torch.Size([10, 901, 128])
torch.Size([10, 900, 128])
torch.Size([10, 901, 128])
torch.Size([10, 901, 128])
c
torch.Size([10, 3, 128])
causals torch.Size([10, 1, 128])
new_cls_token_expanded torch.Size([1, 1, 128])
attn_output torch.Size([1, 1, 128])
final_causal_proj torch.Size([1, 128])
cls_combined torch.Size([1, 128])
combined_features torch.Size([1, 256])
x torch.Size([1, 900])
x torch.Size([1, 1, 30, 30])
logits torch.Size([1, 11, 30, 30])
torch.Size([1, 11, 30, 30])


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FEBlock(nn.Module):
    def __init__(self, embed_size=1):
        super(FEBlock, self).__init__()
        # 1x1 ~ 30x30 Convolution layers 생성 (Padding X)
        # self.convs = nn.ModuleList([nn.Conv2d(1, embed_size, kernel_size=n, padding=0) for n in range(1, 31)])
        # self.fc = nn.ModuleList([nn.Linear((30-n)**(30-n)embed_size, 30*30*embed_size) for n in range(1, 31)])
        self.numbers = [1, 2, 3, 5, 7, 9, 11, 13, 15, 25, 30]
        self.stages = nn.ModuleList([])
        for n in self.numbers:
            self.stages.append(nn.Sequential(
                nn.Conv2d(1, n, kernel_size=n, padding=0),
                nn.Flatten(1),
                nn.Linear((30-n+1)**2*n, 30*30*embed_size)
            ))

    def forward(self, x):
        features = []
        for stage in self.stages:
            features.append(stage(x).unsqueeze(1))  # (batch, 1, 1, n*n*embed_size)
            print(features[-1].shape)
        return torch.stack(features, dim=1)  # (batch, 30, n*n*embed_size)

class SelfAttentionBlock(nn.Module):
    def __init__(self, embed_size=1):
        super(SelfAttentionBlock, self).__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, 30*30*embed_size))
        self.self_attn = nn.MultiheadAttention(embed_dim=30*30*embed_size, num_heads=4)

    def forward(self, x):
        batch_size = x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # print(cls_tokens.shape)
        # print(x.shape)
        x = x.squeeze(2)
        # print(x.shape)
        x = torch.cat((cls_tokens, x), dim=1)  # (batch, 31, 30*30*embed_size)
        x, _ = self.self_attn(x, x, x)
        return x[:, 0].unsqueeze(1)  # (batch, 1, 30*30*embed_size)

class HeadBlock(nn.Module):
    def __init__(self, embed_size=1, num_classes=11):
        super(HeadBlock, self).__init__()
        self.fc1 = nn.Linear(30*30*embed_size, 30*30*embed_size*embed_size)
        self.fc2 = nn.Linear(30*30*embed_size*embed_size, 30*30*embed_size)
        self.fc3 = nn.Linear(30*30*embed_size, 30*30)
        self.conv = nn.Conv2d(1, num_classes, kernel_size=1)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 30, 30)  # Reshape to (batch, 1, 30, 30)
        x = self.conv(x)
        x = self.log_softmax(x)
        return x

class BWNet_MAML(nn.Module):
    def __init__(self, embed_size=1):
        super(BWNet_MAML, self).__init__()
        self.fe_block = FEBlock(embed_size=embed_size)
        self.self_attn_block = SelfAttentionBlock(embed_size=embed_size)
        self.head_block = HeadBlock(embed_size=embed_size)
        

    def forward(self, ex_input):
        features = self.fe_block(ex_input)  # (batch, 30, 30*30*embed_size)
        cls_feature = self.self_attn_block(features)  # (batch, 1, 30*30*embed_size)
        out = self.head_block(cls_feature)  # (batch, 1, 30, 30)
        #out = self.fc_final(out)  # (batch, 11, 30, 30)
        return out

input_tensor = torch.randn(1, 1, 30, 30)  # 입력 텐서 예시
example_input = torch.randn(10, 1, 30, 30)  # 입력 텐서 예시
# example_output = torch.randn(10, 1, 30, 30)  # 출력 텐서 예시

model = BWNet_MAML(embed_size=2)
output = model(example_input)

print(output.shape)  # 최종 출력 크기를 확인


torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 1, 1800])
torch.Size([10, 11, 30, 30])
