# 自定义一个backbone，最后的tensor.size = [B,1024,24,24]

In [2]:
from timm.models.resnetv2 import ResNetV2

In [12]:
from timm.models.layers import StdConv2dSame
backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=3,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)

In [13]:
print(backbone)

ResNetV2(
  (stem): Sequential(
    (conv): StdConv2dSame(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
    (norm): GroupNormAct(
      32, 64, eps=1e-05, affine=True
      (act): ReLU(inplace=True)
    )
    (pool): MaxPool2dSame(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=True)
  )
  (stages): Sequential(
    (0): ResNetStage(
      (blocks): Sequential(
        (0): Bottleneck(
          (downsample): DownsampleConv(
            (conv): StdConv2dSame(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm): GroupNormAct(
              32, 256, eps=1e-05, affine=True
              (act): Identity()
            )
          )
          (conv1): StdConv2dSame(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): GroupNormAct(
            32, 64, eps=1e-05, affine=True
            (act): ReLU(inplace=True)
          )
          (conv2): StdConv2dSame(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 

In [48]:
import timm
resnetv2 = timm.create_model('resnetv2_50x1_bitm',pretrained = True)

In [49]:
resnetv2.reset_classifier(2)

In [77]:
VIT = timm.create_model('vit_base_resnet50_384',pretrained = True)
VIT.reset_classifier(2)

In [82]:
import torch.nn as nn
 
import torch.nn.functional as F
 
class Model(nn.Module):
    def __init__(self,stem,stages0,stages1,stages2):
        
        super(Model,self).__init__()
        self.stem=stem
        self.stages0=stages0
        self.stages1=stages1
        self.stages2=stages2
 
    def forward(self,x):
        stem_out = self.stem(x)
        stages0_out = self.stages0(stem_out)
        stages1_out = self.stages1(stages0_out)
        
        return self.stages2(stages1_out)

In [83]:
backbone = Model(resnetv2.stem,resnetv2.stages[0],resnetv2.stages[1],resnetv2.stages[2])

In [84]:
print(backbone)

Model(
  (stem): Sequential(
    (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (stages0): ResNetStage(
    (blocks): Sequential(
      (0): PreActBottleneck(
        (downsample): DownsampleConv(
          (conv): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm): Identity()
        )
        (norm1): GroupNormAct(
          32, 64, eps=1e-05, affine=True
          (act): ReLU(inplace=True)
        )
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): GroupNormAct(
          32, 64, eps=1e-05, affine=True
          (act): ReLU(inplace=True)
        )
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (norm3): GroupNormAct(
          32, 64, eps=1e-05, affine=True


In [85]:
print(resnetv2.stages[1])
VIT.patch_embed.backbone=backbone

ResNetStage(
  (blocks): Sequential(
    (0): PreActBottleneck(
      (downsample): DownsampleConv(
        (conv): StdConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (norm): Identity()
      )
      (norm1): GroupNormAct(
        32, 256, eps=1e-05, affine=True
        (act): ReLU(inplace=True)
      )
      (conv1): StdConv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm2): GroupNormAct(
        32, 128, eps=1e-05, affine=True
        (act): ReLU(inplace=True)
      )
      (conv2): StdConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm3): GroupNormAct(
        32, 128, eps=1e-05, affine=True
        (act): ReLU(inplace=True)
      )
      (conv3): StdConv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (drop_path): Identity()
    )
    (1): PreActBottleneck(
      (norm1): GroupNormAct(
        32, 512, eps=1e-05, affine=True
        (act): ReLU(inplace=True)
      )
      (conv1)

In [88]:
import glob
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2 as cv
from config import load_config
from torchvision import transforms

train_transforms = A.Compose(
    [
        A.Resize(height=384,width=384),
    ]
)

class DeeperDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
        self.as_tensor = transforms.Compose([
            transforms.ToTensor(),
            ])

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        #print(img_path)
        img_transformed = self.transform(image = np.array(img))
        img_transformed = self.as_tensor(img_transformed['image'])
        label = img_path.split("/")[-1].split(".")[0].split("_")[-1]
        label = 1 if label == "fake" else 0

        return img_transformed, label

def load_data():

    train_dir = './'
    train_list = glob.glob(os.path.join(train_dir,'*.png'))
    train_data = DeeperDataset(train_list, transform=train_transforms)
    train_loader = DataLoader(dataset = train_data, batch_size=1, shuffle=True)

    return train_loader

train_loader = load_data()
for data,label in train_loader:
    stem = resnetv2.stem(data)
    stages0 = resnetv2.stages[0](stem)
    stages1 = resnetv2.stages[1](stages0)
    stages2 = resnetv2.stages[2](stages1)
    stages3 = resnetv2.stages[3](stages2)
    norm = resnetv2.norm(stages3)
    head = resnetv2.head(norm)
    print(stem.shape)
    print("0:",stages0.shape)
    print("1:",stages1.shape)
    print("2:",stages2.shape)
    print(stages3.shape)
    print(norm.shape)
    print(head.shape)
    
    output = VIT(data)
    VIT_patch = VIT.patch_embed.backbone(data)
    print(output)

torch.Size([1, 64, 96, 96])
0: torch.Size([1, 256, 96, 96])
1: torch.Size([1, 512, 48, 48])
2: torch.Size([1, 1024, 24, 24])
torch.Size([1, 2048, 12, 12])
torch.Size([1, 2048, 12, 12])
torch.Size([1, 2, 1, 1])
tensor([[0.3996, 0.4507]], grad_fn=<AddmmBackward>)
