In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, embed_dim=96, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        
        self.window_size = window_size
        self.relative_position_bias = nn.Parameter(torch.zeros(
            (2*window_size-1)*(2*window_size-1), num_heads))
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class SwinBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class SwinTransformer(nn.Module):
    def __init__(self, img_size=48, in_channels=3, out_channels=1, embed_dim=96, depths=[2, 2], num_heads=[3, 6], window_size=7):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, embed_dim)
        self.layers = nn.ModuleList()
        dim = embed_dim
        for i in range(len(depths)):
            blocks = [SwinBlock(dim, num_heads[i], window_size) for _ in range(depths[i])]
            self.layers.append(nn.Sequential(*blocks))
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Conv2d(dim, out_channels, kernel_size=1)
    
    def forward(self, x):
        print(x.shape)
        x = self.patch_embed(x)
        print(x.shape)
        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        print(x.shape)
        for layer in self.layers:
            x = layer(x)
            print(x.shape)
        x = self.norm(x)
        print(x.shape)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        print(x.shape)
        return self.head(x)


In [40]:

x = torch.rand(1, 3, 224, 224)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = x.type(torch.FloatTensor).to(device)
model = Swin().to(device)
print(model(x).shape)

torch.Size([1, 5])


In [41]:
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import numpy as np


class SwinEmbedding(nn.Module):
    def __init__(self, patch_size=4, emb_size=96):
        super().__init__()
        # Use a convolutional layer with stride=1 to preserve spatial dimensions
        self.linear_embedding = nn.Conv2d(3, emb_size, kernel_size=patch_size, stride=1, padding=patch_size//2)
        
    def forward(self, x):
        x = self.linear_embedding(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x
    

class ShiftedWindowMSA(nn.Module):
    def __init__(self, emb_size, num_heads, window_size=7, shifted=True):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.window_size = window_size
        self.shifted = shifted
        self.linear1 = nn.Linear(emb_size, 3*emb_size)
        self.linear2 = nn.Linear(emb_size, emb_size)
        self.pos_embeddings = nn.Parameter(torch.randn(window_size*2 - 1, window_size*2 - 1))
        self.indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
        self.relative_indices = self.indices[None, :, :] - self.indices[:, None, :]
        self.relative_indices += self.window_size - 1

    def forward(self, x):
        h_dim = self.emb_size / self.num_heads
        height = width = int(np.sqrt(x.shape[1]))
        x = self.linear1(x)
        
        x = rearrange(x, 'b (h w) (c k) -> b h w c k', h=height, w=width, k=3, c=self.emb_size)
        
        if self.shifted:
            x = torch.roll(x, (-self.window_size//2, -self.window_size//2), dims=(1,2))
        
        x = rearrange(x, 'b (Wh w1) (Ww w2) (e H) k -> b H Wh Ww (w1 w2) e k', w1=self.window_size, w2=self.window_size, H=self.num_heads)            
        
        Q, K, V = x.chunk(3, dim=6)
        Q, K, V = Q.squeeze(-1), K.squeeze(-1), V.squeeze(-1)
        wei = (Q @ K.transpose(4,5)) / np.sqrt(h_dim)
        
        rel_pos_embedding = self.pos_embeddings[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        wei += rel_pos_embedding
        
        if self.shifted:
            row_mask = torch.zeros((self.window_size**2, self.window_size**2)).cuda()
            row_mask[-self.window_size * (self.window_size//2):, 0:-self.window_size * (self.window_size//2)] = float('-inf')
            row_mask[0:-self.window_size * (self.window_size//2), -self.window_size * (self.window_size//2):] = float('-inf')
            column_mask = rearrange(row_mask, '(r w1) (c w2) -> (w1 r) (w2 c)', w1=self.window_size, w2=self.window_size)
            wei[:, :, -1, :] += row_mask
            wei[:, :, :, -1] += column_mask
        
        wei = F.softmax(wei, dim=-1) @ V
        
        x = rearrange(wei, 'b H Wh Ww (w1 w2) e -> b (Wh w1) (Ww w2) (H e)', w1=self.window_size, w2=self.window_size, H=self.num_heads)
        x = rearrange(x, 'b h w c -> b (h w) c')
        
        return self.linear2(x)
    
class MLP(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.ff = nn.Sequential(
                         nn.Linear(emb_size, 4*emb_size),
                         nn.GELU(),
                         nn.Linear(4*emb_size, emb_size),
                  )
    
    def forward(self, x):
        return self.ff(x)
    
class SwinEncoder(nn.Module):
    def __init__(self, emb_size, num_heads, window_size=7):
        super().__init__()
        self.WMSA = ShiftedWindowMSA(emb_size, num_heads, window_size, shifted=False)
        self.SWMSA = ShiftedWindowMSA(emb_size, num_heads, window_size, shifted=True)
        self.ln = nn.LayerNorm(emb_size)
        self.MLP = MLP(emb_size)
        
    def forward(self, x):
        # Window Attention
        x = x + self.WMSA(self.ln(x))
        x = x + self.MLP(self.ln(x))
        # shifted Window Attention
        x = x + self.SWMSA(self.ln(x))
        x = x + self.MLP(self.ln(x))
        
        return x
    
class SwinPDE(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, emb_size=96):
        super().__init__()
        self.Embedding = SwinEmbedding(patch_size=4, emb_size=emb_size)
        
        # Swin stages without PatchMerging
        self.stage1 = SwinEncoder(emb_size, num_heads=3)
        self.stage2 = SwinEncoder(emb_size, num_heads=6)
        self.stage3 = nn.ModuleList([SwinEncoder(emb_size, num_heads=12) for _ in range(3)])
        self.stage4 = SwinEncoder(emb_size, num_heads=24)
        
        # Final projection to output channels
        self.final_proj = nn.Conv2d(emb_size, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.Embedding(x)
        x = self.stage1(x)
        x = self.stage2(x)
        for stage in self.stage3:
            x = stage(x)
        x = self.stage4(x)
        
        # Reshape back to (batch, channels, height, width)
        height = width = int(np.sqrt(x.shape[1]))
        x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)
        
        # Final projection to output channels
        x = self.final_proj(x)
        return x


# Example usage
if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224)  # Input with 3 channels
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = x.to(device)
    model = SwinPDE(in_channels=3, out_channels=1).to(device)  # Output with 1 channel
    print(model(x).shape)  # Should output (1, 1, 224, 224)

EinopsError:  Error while processing rearrange-reduction pattern "b (Wh w1) (Ww w2) (e H) k -> b H Wh Ww (w1 w2) e k".
 Input tensor shape: torch.Size([1, 225, 225, 96, 3]). Additional info: {'w1': 7, 'w2': 7, 'H': 3}.
 Shape mismatch, can't divide axis of length 225 in chunks of 7

In [44]:
import random
