<a href="https://colab.research.google.com/github/yavuzkayacan/my_colab/blob/main/pytorch_att_conv_pixelshuffle_r1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [122]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch import Tensor
import math
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.models as models
import pandas as pd
import os
import time
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

In [123]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [124]:
class Patches(nn.Module):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size


    def forward(self, images):
        batch_size, channels, height, width = images.size()

        patches = images.unfold(2, self.patch_size[0], self.patch_size[0]).unfold(3, self.patch_size[1], self.patch_size[1])
        patches = patches.contiguous().view(batch_size, channels, -1, self.patch_size[0], self.patch_size[1])


        return patches

In [125]:
class PatchComb(nn.Module):

    def __init__(self, patch_size, num_patches, ch):
        super(PatchComb, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.ch = ch

    def forward(self,patches):
        reshaped_patches = torch.reshape(patches, (-1, self.num_patches[0], self.num_patches[1], self.patch_size[0], self.patch_size[1], self.ch))
        reshaped_patches = reshaped_patches.permute(0, 5, 1, 3, 2, 4)
        reshaped_patches = torch.reshape(reshaped_patches, (-1, self.ch, self.num_patches[0]*self.patch_size[0], self.num_patches[1]*self.patch_size[1]))
        return reshaped_patches

In [126]:
class ConvAttention(nn.Module):
    def __init__(self, n_channels):
        super(ConvAttention, self).__init__()
        self.n_channels = n_channels
        self.query = nn.Conv3d(self.n_channels, self.n_channels, kernel_size=3, padding='same')
        self.key = nn.Conv3d(self.n_channels, self.n_channels, kernel_size=3, padding='same')
        self.value = nn.Conv3d(self.n_channels, self.n_channels, kernel_size=3, padding='same')
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        size = x.shape
        f, g, h = self.query(x), self.key(x), self.value(x)
        mat_mul = torch.matmul(f.permute(0, 1, 2, 3, 4), g)
        beta = F.softmax(mat_mul, dim=1)
        o = self.gamma * torch.matmul(h, beta) + x

        return mat_mul

In [151]:
X = torch.rand((1,16,256,64))

In [152]:
input_shape = (1, 256, 64)
patch_size = [4, 4]
num_patches = (input_shape[1] // patch_size[0]) * (input_shape[2] // patch_size[1])
projection_dim = 64
num_patch = np.array(((input_shape[1] // patch_size[0]) , (input_shape[2] // patch_size[1])))

In [153]:
ch_in = 1

In [154]:
X = nn.Conv2d(X.shape[1], ch_in ,kernel_size=3, padding='same')(X)

In [155]:
X.shape

torch.Size([1, 1, 256, 64])

In [156]:
patch_size

[4, 4]

In [157]:
X = Patches(patch_size)(X)

In [158]:
X.shape

torch.Size([1, 1, 1024, 4, 4])

In [161]:
X = ConvAttention(ch_in)(X)

In [162]:
X.shape

torch.Size([1, 1, 1024, 4, 4])

In [None]:
X = PatchComb(patch_size, num_patch, ch_in)(X)

In [136]:
X.shape

torch.Size([1, 1, 256, 64])

In [101]:
nn.PixelUnshuffle(2)(X).shape

torch.Size([1, 4, 128, 32])

In [102]:
class Conv_Transformers(nn.Module):
    def __init__(self, patch_size, num_patch, in_ch, ch, depth):
        super(Conv_Transformers, self).__init__()

        self.patch_size = patch_size
        self.num_patch = num_patch
        self.in_ch = in_ch
        self.ch = ch

        self.depth = depth


        self.conv = nn.Conv2d(in_channels=self.in_ch, out_channels=self.ch, kernel_size=1,padding='same')

        self.patching = Patches(self.patch_size)  # Assuming Patches is defined
        self.conv_att = ConvAttention(self.ch)  # Assuming ConvAttention is defined
        self.patch_comb = PatchComb(self.patch_size, self.num_patch, self.ch)  # Assuming PatchComb is defined

        self.ds = nn.PixelUnshuffle(2)

    def forward(self, X):

        X = self.conv(X)
        X = self.patching(X)
        X = self.conv_att(X)
        X = self.patch_comb(X)
        X = self.ds(X)


        return X


In [103]:
num_patch

array([64, 16])

In [104]:
X.shape

torch.Size([1, 1, 256, 64])

In [105]:
Conv_Transformers(patch_size=patch_size,ch=64,in_ch=1,num_patch=num_patch,depth=1)(X).shape

torch.Size([1, 256, 128, 32])

In [119]:
class Conv_Transformers(nn.Module):
    def __init__(self, patch_size, num_patch, in_ch, ch, depth):
        super(Conv_Transformers, self).__init__()

        self.patch_size = patch_size
        self.num_patch = num_patch
        self.in_ch = in_ch
        self.ch = ch

        self.depth = depth

        self.conv = nn.Conv2d(in_channels=self.in_ch, out_channels=self.ch, kernel_size=1,padding='same')
        self.layers_down = nn.ModuleList()
        self.layers_up = nn.ModuleList()
        self.conv_out = nn.Conv2d(in_channels=self.ch*(self.depth+1), out_channels=1, kernel_size=1,padding='same')

        self.down_ch = []

        for _ in range(depth):


          self.layers_down.append(nn.Sequential(
          Patches(self.patch_size),  # Assuming Patches is defined
          ConvAttention(self.ch),  # Assuming ConvAttention is defined
          PatchComb(self.patch_size, self.num_patch, self.ch),  # Assuming PatchComb is defined
          nn.PixelUnshuffle(2)
          ))

          self.down_ch.append(self.ch)
          self.ch = self.ch*4
          self.num_patch =  self.num_patch // 2

        self.down_ch = self.down_ch[::-1]

        for _ in range(depth):

          self.layers_up.append(nn.Sequential(
          Patches(self.patch_size),  # Assuming Patches is defined
          ConvAttention(self.ch),  # Assuming ConvAttention is defined
          PatchComb(self.patch_size, self.num_patch, self.ch),  # Assuming PatchComb is defined
          nn.PixelShuffle(2)
          ))
          print(self.ch//4 , self.down_ch[_])
          self.ch = self.ch//4 + self.down_ch[_]

          print(self.ch)
          self.num_patch =  self.num_patch * 2

    def forward(self,X):
        X_skip = []
        X = self.conv(X)

        for i in range(self.depth):
            X_skip.append(X)
            X = self.layers_down[i](X)

        X_skip = X_skip[::-1]

        for i in range(self.depth):

            X = self.layers_up[i](X)
            print(X.shape)
            X_con = torch.cat([X,X_skip[i]],1)
            X = X_con
            print(X.shape)

        out = self.conv_out(X)

        return out


In [120]:
X = torch.rand((1,1,256,64))

In [121]:
x_s = Conv_Transformers(patch_size=patch_size,ch=4,in_ch=1,num_patch=num_patch,depth=2)(X)

16 16
32
8 4
12
torch.Size([1, 16, 128, 32])
torch.Size([1, 32, 128, 32])
torch.Size([1, 8, 256, 64])
torch.Size([1, 12, 256, 64])


In [41]:
x_s.shape

torch.Size([2, 1, 256, 64])

In [31]:
Conv_Transformers(patch_size=patch_size,ch=16,in_ch=1,num_patch=num_patch,depth=3)

Conv_Transformers(
  (conv): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (layers_down): ModuleList(
    (0): Sequential(
      (0): Patches()
      (1): ConvAttention(
        (query): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
        (key): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
        (value): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      )
      (2): PatchComb()
      (3): PixelUnshuffle(downscale_factor=2)
    )
    (1): Sequential(
      (0): Patches()
      (1): ConvAttention(
        (query): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
        (key): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
        (value): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      )
      (2): PatchComb()
      (3): PixelUnshuffle(downscale_factor=2)
    )
    (2): Sequential(
      (0): Patches()
      (1): ConvAt

In [None]:
x_s.shape

torch.Size([1, 4, 256, 64])

In [None]:
x_s[0].shape

torch.Size([5, 256, 64])

In [None]:
x_s[1].shape

torch.Size([1, 64, 128, 32])

In [None]:
x_s[2].shape

torch.Size([1, 16, 256, 64])