In [2]:
import os
from PIL import Image
import torchvision.transforms as transforms
import torch
import numpy as np

from src.model.fusion import AttentionFusionModule
from src.datamodule.datamodule import ImageForgeryDatamMdule
from timm.models.swin_transformer import SwinTransformer
from src.model.cnn_gru import HybridCNNGRU
from timm.models.layers import SelectAdaptivePool2d, ClassifierHead
import timm
from torch import nn
import torch_dct as dct
from src.model.cnn_gru import HybridCNNGRU

In [3]:
image = torch.rand(3, 3, 224, 224)

In [4]:
state_dict = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True).state_dict()
model = SwinTransformer(        
            drop_rate=0.9,
            proj_drop_rate=0.9,
            attn_drop_rate=0.9,
            drop_path_rate=0.9,
        )
model.load_state_dict(state_dict)
model.head = SelectAdaptivePool2d(output_size=768, pool_type='avg')
model(image).shape

torch.Size([3, 7, 768, 768])

In [5]:
model

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.9, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.9, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (drop1): 

In [6]:
SelectAdaptivePool2d

timm.layers.adaptive_avgmax_pool.SelectAdaptivePool2d

In [7]:
att = AttentionFusionModule(224, 512, 128)

In [8]:
x1 = torch.rand(1, 14, 14, 224)
x2 = torch.rand(1, 14, 14, 512)

In [9]:
(att(x1, x2)).shape

torch.Size([1, 14, 14, 128])

In [10]:
ClassifierHead(in_features=14, num_classes=2)(att(x1, x2))

tensor([[-0.1777,  0.0030]], grad_fn=<AddmmBackward0>)

In [11]:
tem = torch.rand(1, 2, 10)
torch.nn.Softmax(dim=1)(tem)

tensor([[[0.5063, 0.4341, 0.5391, 0.3359, 0.4139, 0.5473, 0.4205, 0.4878,
          0.5356, 0.4875],
         [0.4937, 0.5659, 0.4609, 0.6641, 0.5861, 0.4527, 0.5795, 0.5122,
          0.4644, 0.5125]]])

In [19]:
class CNN(nn.Module):
    def __init__(self, input_channels):
        super(CNN, self).__init__()
        self.module = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.AvgPool2d(kernel_size=2,),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AvgPool2d(kernel_size=2,),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.AvgPool2d(kernel_size=2,),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),
            nn.AvgPool2d(kernel_size=2,),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.AvgPool2d(kernel_size=2,),
        )

    def forward(self, x):
        return self.module(x)
CNN(3)(image).shape

torch.Size([3, 512, 7, 7])

torch.Size([3, 512, 109, 109])