In [2]:
import math
import logging
from functools import partial
from collections import OrderedDict
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
from timm.models.layers import trunc_normal_, lecun_normal_, to_2tuple
from timm.models.registry import register_model
import timm

# from helpers import complement_idx
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

In [None]:
class Downsample(nn.Module):
    def __init__(self, in_channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = in_channels
        out_channels = out_channels or in_channels
        if use_conv:
            # downsamples by 1/2
            self.downsample = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1, )
        else:
            assert in_channels == out_channels
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, time_embed=None):
        assert x.shape[1] == self.channels
        return self.downsample(x)


class Upsample(nn.Module):
    def __init__(self, in_channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = in_channels
        self.use_conv = use_conv
        # uses upsample then conv to avoid checkerboard artifacts
        # self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        if use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, x, time_embed=None):
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


In [24]:
class DAFF(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 kernel_size=3, with_bn=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # pointwise
        self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1, padding=0)
        # depthwise
        self.conv2 = nn.Conv2d(
            hidden_features, hidden_features, kernel_size=kernel_size, stride=1,
            padding=(kernel_size - 1) // 2, groups=hidden_features)
        
        # pointwise
        self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1, padding=0)
        self.act = act_layer()
        
        self.bn1 = nn.BatchNorm2d(hidden_features)
        self.bn2 = nn.BatchNorm2d(hidden_features)
        self.bn3 = nn.BatchNorm2d(out_features)
        
        # The reduction ratio is always set to 4
        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        self.compress = nn.Linear(in_features, in_features//4)
        self.excitation = nn.Linear(in_features//4, in_features)
                
    def forward(self, x):
        B, N, C = x.size()
        cls_token, tokens = torch.split(x, [1, N - 1], dim=1)
        print(cls_token.shape)
        print(tokens.shape)
        x = tokens.reshape(B, int(math.sqrt(N - 1)), int(math.sqrt(N - 1)), C).permute(0, 3, 1, 2)
        print(x.shape)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)

        shortcut = x
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)
        x = shortcut + x

        x = self.conv3(x)
        x = self.bn3(x)

        weight = self.squeeze(x).flatten(1).reshape(B, 1, C)
        weight = self.excitation(self.act(self.compress(weight)))
        cls_token = cls_token * weight
        
        tokens = x.flatten(2).permute(0, 2, 1)
        out = torch.cat((cls_token, tokens), dim=1)
        
        return out
    

In [25]:
dim= 768
x_test = torch.rand(1, 197, 768)

In [26]:
daff = DAFF(in_features=dim)

In [27]:
y_test = daff(x_test)

torch.Size([1, 1, 768])
torch.Size([1, 196, 768])
torch.Size([1, 768, 14, 14])


In [11]:
y_test.shape

torch.Size([1, 197, 768])

In [46]:
import torch
import torch.nn as nn

# Define a custom convolutional layer with 3x3 kernel
def conv3x3(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        
        # Define the layers for converting to patches
        self.patch_extraction = nn.Unfold(kernel_size=16, stride=16)  # Extract patches of size (16, 16)
        self.patch_embedding = nn.Linear(16 * 16 * 256, 256)  # Embed patches into the desired dimension

        # Define the layers for upsampling back to (256, 256)
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Double spatial dimensions
            nn.GELU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Double spatial dimensions
            nn.GELU(),
            nn.ConvTranspose2d(64, 256, kernel_size=4, stride=2, padding=1),  # Double spatial dimensions
        )

    def forward(self, x):
        # Extract patches and reshape
        patches = self.patch_extraction(x)
        print(patches.shape)
        # Get the number of patches and the embedding dimension
        num_patches, embedding_dim = patches.size(1), 256
        batch_size = x.size(0)
        # Reshape patches to match the embedding dimension
        patches = patches.view(batch_size, num_patches, embedding_dim)
        patches = self.patch_embedding(patches)

        # Reshape and apply deconvolutions
        patches = patches.view(batch_size, 256, 16, 16)
        x = self.upsample(patches)
        return x

# Example usage with batch size 4:
input_tensor = torch.randn(4, 256, 768)
model = MyNetwork()
output = model(input_tensor)
print(output.shape)  # Should print torch.Size([4, 256, 256])


torch.Size([1024, 768])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3072x256 and 65536x256)