#### `DATA NOISE DISTRIBUTION & FEATURE MAP ANALYSIS`

### `GAUSSUNET`


```
- 모델을 pypi 패키지화 하기로 한다.
- 우선은 모델을 encoder - skip connection - decoder 이렇게 나눠야함
  - modules.py 파일 안에 모든 layer나 필요한 module들을 넣어 주어야 함
  - 예를 들면 dense_res_block, res_block, ssa_module, pyramid_attention, scse등
```



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenoiserModel(nn.Module):
  """Denoiser assembled from four pluggable stages: head -> encoder -> decoder -> tail.

  Args:
    head: callable applied to the raw input before the encoder.
    encoder: callable returning a sequence of feature maps. The last entry is
      used as the decoder input and the remaining entries as skip features.
    decoder: callable invoked as ``decoder(deepest_feature, skip_features)``.
    tail: callable invoked as ``tail(decoder_output, original_input)``.
  """

  def __init__(self,
               head,
               encoder,
               decoder,
               tail):
    super(DenoiserModel, self).__init__()
    self.head = head
    self.encoder = encoder
    self.decoder = decoder
    self.tail = tail

  def forward(self, x):
    """Run the full pipeline and return the tail output."""
    # Keep an untouched copy of the input for the residual path in the tail.
    residual = x.clone()
    stem = self.head(x)
    features = self.encoder(stem)
    # The decoder consumes the deepest feature map plus the list of shallower
    # maps as skips. Splatting the list (``*features``) does not match the
    # two-argument ``forward(x, features)`` signature of the decoder.
    decoded = self.decoder(features[-1], features[:-1])
    return self.tail(decoded, residual)

  @torch.no_grad()
  def predict(self, x):
    """Inference helper: accepts a single un-batched image or a batch.

    A 3-D input gets a leading batch dimension added. The module is switched
    to eval mode (and stays in eval mode afterwards) and gradients are disabled.
    """
    if x.dim() == 3:
      x = x.unsqueeze(0)

    self.eval()
    return self.forward(x)


In [None]:
def build_head(head_params):
  """Instantiate a HeadBlock from a mapping of keyword arguments."""
  head = HeadBlock(**head_params)
  return head

def build_encoder(encoder_params):
  """Instantiate an Encoder from a mapping of keyword arguments."""
  encoder = Encoder(**encoder_params)
  return encoder

def build_decoder(decoder_params):
  """Instantiate a Decoder from a mapping of keyword arguments."""
  decoder = Decoder(**decoder_params)
  return decoder

def build_tail(tail_params):
  """Instantiate a TailBlock from a mapping of keyword arguments."""
  tail = TailBlock(**tail_params)
  return tail

#### `MODULES`

In [None]:
"""
base_modules.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

############# HEAD & TAIL BLOCKS #####################
##### (1) HEAD BLOCK
class HeadBlock(nn.Module):
  """Stem of the network: BatchNorm -> Tanh -> 1x1 convolution.

  Maps ``in_channels`` input channels to ``out_channels`` without changing the
  spatial size (kernel 1, stride 1, no padding).
  """

  def __init__(self, in_channels = 1, out_channels = 8):
    super().__init__()
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.Tanh(),
        nn.Conv2d(in_channels, out_channels, kernel_size = 1, padding = 0, stride = 1),
    ]
    self.head = nn.Sequential(*layers)

  def forward(self, x):
    stem = self.head(x)
    return stem

##### (2) TAIL BLOCK
class TailBlock(nn.Module):
  """Output stage: adds a scaled 3x3-conv correction to the original input.

  The result ``input + tail_rate * conv(x)`` is clamped to the range [-2, 2].
  """

  def __init__(self, in_channels, out_channels, tail_rate = 0.1):
    super().__init__()
    self.tail = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1)
    self.tail_rate = tail_rate

  def forward(self, x, input):
    # Scale the predicted correction before adding it back onto the input.
    correction = self.tail(x) * self.tail_rate
    return (input + correction).clamp(-2, 2)

############# EXTRACTION BLOCKS ######################
### (1) NOISE EXTRACTION BLOCK
### TODO: MUST BE FIXED AFTER ANALYZING THE NOISE LEVELS OF THE FEATURE MAPS
class NoiseExtractionBlock(nn.Module):
  """Single 3x3 convolution that estimates a noise map to subtract from features.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels of the estimated noise map.

  The convolution uses ``padding = 1`` so the output keeps the input's spatial
  size; callers subtract the result element-wise from a same-sized feature map.
  """

  def __init__(self, in_channels, out_channels):
    # The original called the non-existent ``__init()``, which raised
    # AttributeError before any layer was created.
    super(NoiseExtractionBlock, self).__init__()
    self.block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1)
    )

  def forward(self, x):
    return self.block(x)

### (2) FEATURE EXTRACTION BLOCK ####
class FeatureExtractionBlock(nn.Module):
  """Residual-branch feature extractor selected by ``name`` (case-insensitive).

  Modes:
    - ``'conv1'``: learnable 1x1 convolution, ``in_channels -> out_channels``.
    - ``'conv3'``: learnable 3x3 convolution (padding 1), ``in_channels -> out_channels``.
    - ``'sobel'``: returns ``x + sobel_x(x) + sobel_y(x)`` with fixed kernels.
    - ``'laplacian'``: returns ``x + laplacian(x)`` with a fixed kernel.

  The fixed-kernel modes are applied per channel with padding 1, so they keep
  both the spatial size and the channel count of the input (``out_channels``
  is not used by them).

  Raises:
    ValueError: if ``name`` is not one of the supported modes.
  """

  def __init__(self, name,in_channels, out_channels):
    super(FeatureExtractionBlock, self).__init__()
    self.name = name.lower()
    if self.name == 'conv1':
      # No down/up-sampling is expected to happen inside the convolution.
      self.block = nn.Conv2d(in_channels,out_channels, kernel_size = 1, padding = 0, stride = 1)

    elif self.name == 'conv3':
      self.block = nn.Conv2d(in_channels, out_channels,kernel_size = 3, padding = 1, stride = 1)

    elif self.name == 'sobel':
      # Kernels stay stored as (3, 3) frozen parameters; they are reshaped to
      # a 4-D depthwise weight at call time.
      self.xfilter = nn.Parameter(
          torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]), requires_grad = False)
      # Vertical Sobel kernel: the top-centre weight is -2 so the kernel sums
      # to zero (the original had +2 there, which responds to flat regions).
      self.yfilter = nn.Parameter(
          torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]), requires_grad = False)

    elif self.name == 'laplacian':
      # A zero-sum Laplacian would use a centre weight of -4. The centre is
      # kept at -3 as in the original, whose note says a weaker centre worked
      # better; normally a Gaussian blur would precede this filter.
      self.filter = nn.Parameter(
          torch.Tensor([[0, 1, 0], [1, -3, 1], [0, 1, 0]]), requires_grad = False)

    else:
      raise ValueError(
          "Unknown feature extraction mode: {!r} "
          "(expected 'conv1', 'conv3', 'sobel' or 'laplacian')".format(name))

  @staticmethod
  def _depthwise(x, kernel):
    """Apply one fixed 3x3 kernel to every channel of ``x``, keeping its size."""
    channels = x.shape[1]
    # F.conv2d needs a 4-D weight; with groups == channels each channel is
    # filtered independently by its own copy of the kernel.
    weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    return F.conv2d(x, weight, padding = 1, groups = channels)

  def forward(self, x):
    if self.name == 'sobel':
      grad_x = self._depthwise(x, self.xfilter)
      grad_y = self._depthwise(x, self.yfilter)
      return x + grad_x + grad_y

    elif self.name == 'laplacian':
      edge = self._depthwise(x, self.filter)
      return x + edge

    else:
      return self.block(x)


###################### CONVOLUTION BLOCK #################################
########### BASE BLOCK FOR EXPERIMENT 1 ######################
class ConvBlock(nn.Module):
  """Basic convolution block used by the decoder.

  Main path: three (InstanceNorm -> Tanh -> Conv) stages with kernel sizes
  1, 3 and 1, with an optional attention module between the second and third.
  A feature-extraction branch computed from the block input is always added
  back as a residual. An optional noise-extraction branch, also computed from
  the block input, is subtracted from the main path. A final Tanh is applied
  to the sum.

  Args:
    in_channels: channels of the block input.
    out_channels: channels of the block output.
    feature_ext: mode name forwarded to FeatureExtractionBlock.
    noise_ext: truthy to enable the noise-extraction branch; ``None`` or
      ``False`` disables it.
    attention: attention name forwarded to Attention (``None`` for identity).
  """

  def __init__(self, in_channels, out_channels,
               feature_ext = 'conv1', noise_ext = None, attention = None):
    super(ConvBlock, self).__init__()
    self.conv1 = nn.Sequential(
        nn.InstanceNorm2d(in_channels, affine = True), nn.Tanh(),
        nn.Conv2d(in_channels, out_channels, kernel_size = 1, bias = False),
    )
    self.conv2 = nn.Sequential(
        nn.InstanceNorm2d(out_channels, affine = True), nn.Tanh(),
        nn.Conv2d(out_channels, out_channels, kernel_size = 3, bias = False, padding = 1, stride =1)
    )
    self.attention = Attention(attention, out_channels)
    self.conv3 = nn.Sequential(
        nn.InstanceNorm2d(out_channels, affine = True), nn.Tanh(),
        nn.Conv2d(out_channels, out_channels, kernel_size = 1)
    )

    self.feat_ext = FeatureExtractionBlock(name = feature_ext, in_channels = in_channels, out_channels = out_channels)
    # ``noise_ext is not False`` treated the default ``None`` as "enabled";
    # a plain truthiness test makes both None and False disable the branch.
    self.noise_ext = NoiseExtractionBlock(in_channels = in_channels, out_channels = out_channels) if noise_ext else None
    self.final_act = nn.Tanh()

  def forward(self, x):
    out = self.conv3(self.attention(self.conv2(self.conv1(x))))

    identity = self.feat_ext(x)
    if self.noise_ext is not None:
      # Out-of-place subtraction: no tensor that autograd may still need is
      # modified in place.
      out = out - self.noise_ext(x)
    return self.final_act(out + identity)
    
    

#### SPATIAL & CHANNEL SQUEEZE MODULE 
class SCSEModule(nn.Module):
  """Concurrent spatial and channel squeeze-and-excitation gate.

  Channel branch: global average pool -> 1x1 conv (reduce) -> Tanh ->
  1x1 conv (expand) -> Sigmoid, giving one gate per channel.
  Spatial branch: 1x1 conv to a single channel -> Sigmoid, giving one gate
  per pixel. The output is ``channel_gate * x + spatial_gate * x``.

  Args:
    in_channels: channels of the input feature map.
    reduction: divisor for the channel-branch bottleneck width.
  """

  def __init__(self, in_channels, reduction = 16):
    super(SCSEModule, self).__init__()
    # Floor division gives 0 when in_channels < reduction, which would build
    # a zero-width bottleneck; keep at least one hidden channel.
    hidden = max(1, in_channels // reduction)
    self.channel = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(in_channels, hidden, kernel_size = 1),
        nn.Tanh(),
        nn.Conv2d(hidden, in_channels, kernel_size = 1),
        nn.Sigmoid()
    )
    self.spatial = nn.Sequential(
        nn.Conv2d(in_channels, 1, kernel_size = 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    ch = self.channel(x)
    sp = self.spatial(x)
    out = ch * x + sp * x
    return out

class Attention(nn.Module):
  """Attention selector used inside the simple convolution blocks.

  Args:
    name: ``None`` (or the string ``'none'``, case-insensitive) for an
      identity pass-through, or ``'scse'`` for an SCSEModule.
    in_channels: channel count handed to the attention module.

  Raises:
    ValueError: if ``name`` is not a supported attention type.
  """

  def __init__(self, name, in_channels):
    super(Attention, self).__init__()
    key = None if name is None else name.lower()
    if key is None or key == 'none':
      self.attention = nn.Identity()
    elif key == 'scse':
      self.attention = SCSEModule(in_channels)
    else:
      # Without this branch an unknown name left ``self.attention`` undefined
      # and the failure only surfaced later, inside forward().
      raise ValueError(
          "Unknown attention type: {!r} (expected None, 'none' or 'scse')".format(name))

  def forward(self, x):
    return self.attention(x)


In [None]:
# Sanity check: global average pooling collapses the spatial dims to 1x1.
pool = nn.AdaptiveAvgPool2d(1)
feat = pool(torch.rand((2, 32,64,64)))
print(feat.shape)

torch.Size([2, 32, 1, 1])


In [None]:
"""
skip_modules.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
### FOR EXPERIMENT 1 ####################
#### (1) FiLM BASED SKIP CONNECTION
class ConvFiLM(nn.Module):
  """Produces per-pixel FiLM coefficients from a skip feature map.

  A 1x1 convolution maps the skip features to ``film_channels`` channels, the
  result is L2-normalised along the last dimension, and the channels are split
  into two halves returned as ``(gamma, beta)``.
  """

  def __init__(self, in_channels, film_channels):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, film_channels, kernel_size = 1)

  def forward(self, skip):
    projected = F.normalize(self.conv(skip), p = 2, dim = -1)
    # First half of the channels is the scale, second half is the shift.
    gamma, beta = projected.chunk(2, dim = 1)
    return gamma, beta

class LinearFiLM(nn.Module):
  """Produces per-channel FiLM coefficients from a skip feature map.

  The skip features are globally average-pooled, passed through a linear
  layer to ``film_channels`` values, L2-normalised, and split into two halves
  that are returned as ``(gamma, beta)``, each shaped (B, film_channels // 2, 1, 1).
  """

  def __init__(self, in_channels, film_channels):
    super().__init__()
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(in_channels, film_channels)

    self.film_channels = film_channels

  def forward(self, skip):
    pooled = self.pool(skip)
    batch, channels, _, _ = pooled.shape
    coeffs = self.fc(pooled.view(batch, channels))
    coeffs = F.normalize(coeffs, p = 2, dim = -1)

    half = self.film_channels // 2
    scale, shift = torch.chunk(coeffs, chunks = 2, dim = 1)
    # Reshape so the coefficients broadcast over the spatial dimensions.
    gamma = scale.view(batch, half, 1, 1)
    beta = shift.view(batch, half, 1, 1)
    return gamma, beta

class FiLMSkipConnection(nn.Module):
  """Skip connection that modulates the upsampled input with FiLM coefficients.

  ``film = 'conv'`` selects ConvFiLM; any other value selects LinearFiLM.
  The FiLM generator is built with ``in_channels * 2`` outputs, so gamma and
  beta each have ``in_channels`` channels.
  """

  def __init__(self, skip_channels, in_channels, film = 'conv'):
    super().__init__()
    film_cls = ConvFiLM if film == 'conv' else LinearFiLM
    self.film = film_cls(skip_channels, in_channels * 2)

  def forward(self, x, skip):
    # Upsample by 2 (nearest) before modulating and concatenating with skip.
    upsampled = F.interpolate(x, scale_factor = 2, mode = 'nearest')
    gamma, beta = self.film(skip)
    modulated = gamma * upsampled + beta
    return torch.cat([modulated, skip], dim = 1)



#### (2) SIMPLE SKIP CONNECTION (Just Concat)
class SimpleSkipConnection(nn.Module):
  """Plain skip connection: nearest-neighbour 2x upsample, then channel concat."""

  def forward(self, x, skip):
    upsampled = F.interpolate(x, scale_factor = 2, mode = 'nearest')
    merged = torch.cat((upsampled, skip), dim = 1)
    return merged


#### (3) SUB SPACE ATTENTION SKIP CONNECTION 
class SimpleConv(nn.Module):
  """Two-stage (InstanceNorm -> ReLU -> 3x3 conv) block with a 1x1 conv shortcut.

  The first stage keeps ``in_channels``; the second maps to ``out_channels``.
  The shortcut projects the input to ``out_channels`` with a 1x1 convolution
  and is added to the main path.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    def norm_relu_conv(width_out):
      # Both stages take in_channels as input and differ only in output width.
      return nn.Sequential(
          nn.InstanceNorm2d(in_channels), nn.ReLU(),
          nn.Conv2d(in_channels, width_out, kernel_size = 3, padding = 1, stride = 1)
      )

    self.conv1 = norm_relu_conv(in_channels)
    self.conv2 = norm_relu_conv(out_channels)
    self.identity = nn.Conv2d(in_channels, out_channels, kernel_size = 1)

  def forward(self, x):
    main = self.conv2(self.conv1(x))
    shortcut = self.identity(x)
    return main + shortcut

class SubNet(nn.Module):
  """Stack of ``layer_n`` channel-preserving SimpleConv blocks applied in sequence.

  Used by the SSA skip connection on the skip feature map.
  """

  def __init__(self, ch_in, layer_n):
    super().__init__()
    layers = []
    for _ in range(layer_n):
      layers.append(SimpleConv(ch_in, ch_in))
    self.block = nn.ModuleList(layers)

  def forward(self, x):
    for layer in self.block:
      x = layer(x)

    return x

class SSASkipConnection(nn.Module):
  """Subspace-attention (SSA) skip connection.

  The decoder input is upsampled by 2, the skip feature is refined by a SubNet,
  and a SimpleConv over their concatenation produces ``subspace_dim`` basis
  maps. The refined skip feature is then projected through those basis maps
  and concatenated with the upsampled input.

  Args:
    in_channels: channels of the (pre-upsampling) decoder input ``x``.
    skip_channels: channels of the skip feature map.
    conv_layer_n: number of SimpleConv layers in the SubNet applied to the skip.
    subspace_dim: number of basis maps produced by the subspace block.
  """
  def __init__(self, in_channels, skip_channels, conv_layer_n, subspace_dim = 16):
    super(SSASkipConnection, self).__init__()
    # Author notes (translated from the string below):
    # - With subspace_dim = 32 the model becomes too complex and trains poorly.
    # - Larger subspace_dim tends to over-smooth the output.
    # - The <bridge> fed to SSA first passes through a skip path of several conv blocks.
    """
    - Encoder Decoder의 skip connection에서 사용이 되는 SSA의 subspace_dim은 32일 경우 모델이 너무 복잡해져 학습이 잘 되지 않는다.
    - 또한, subspace_dim이 늘어나면 over smoothing이 너무 강하게 되는 경향을 보인다.
    - SSA에 입력되는 <bridge>는 ConvBlock을 여러개 연결한 Skip Connection을 거치게 된다.
    """
    self.sub_dim = subspace_dim
    self.subnet = SubNet(skip_channels, conv_layer_n)
    self.conv_block = SimpleConv(in_channels + skip_channels, subspace_dim)

  def forward(self, x, skip):
    # x : feature map to be upscaled
    # bridge = self.skip_conv(bridge)
    x = F.interpolate(x, scale_factor = 2, mode = 'nearest') ## x has not been upscaled yet, so double its spatial size here.
    
    B, C, H, W = skip.shape
    skip = self.subnet(skip)
    #print(skip.shape, x.shape)
    ## SSA Module Starts ##
    # (1) Concat
    out = torch.cat([skip, x], dim = 1) ## first concatenation: refined skip + upsampled input
    # (2) Sub Space Block (=Conv-Block)
    sub = self.conv_block(out)
    # (3) Basis Vectors: flatten the spatial dims -> (B, sub_dim, H*W)
    V_t = sub.reshape(B, self.sub_dim, H*W)
    # (4) Projection: L1-normalise each basis vector
    V_t = V_t / (1e-6 + torch.abs(V_t).sum(axis = 2, keepdims = True))  # |.| sum is non-negative and 1e-6 is added, so the divisor cannot be zero

    V = torch.transpose(V_t, 1, 2)
    ## Gram matrix of the basis vectors
    mat = torch.matmul(V_t, V) ## (B, sub_dim, sub_dim); 16 with the default subspace_dim
    # Take |det| so the value is non-negative, then clamp away from zero.
    det = torch.clamp(torch.abs(torch.linalg.det(mat)), min = 1e-6)
    ## Author note: unsure which epsilon is best against division by zero (a small constant used to be added to det instead).

    # NOTE(review): this is the transposed Gram matrix divided by the clamped |det|,
    # not a true matrix inverse — confirm this approximation is intended.
    mat_inv = torch.div(mat.permute(2, 1, 0), det).permute(2, 0, 1)

    proj_mat = torch.matmul(mat_inv, V_t) ## Author note (paraphrased): keeping this product well-scaled is what prevents the loss from becoming NaN.
    skip_ = skip.reshape(B, C, W*H)
    proj_feat = torch.matmul(proj_mat, torch.transpose(skip_, 1, 2))
    skip = torch.matmul(V, proj_feat)
    skip = torch.transpose(skip, 1, 2).reshape(B, C, H, W)
    
    out = torch.cat([x, skip], 1) ## concatenate the subspace-projected skip feature map with the upscaled input
    #print(out.shape, "OUT")
    return out 




#### `ENCODER`

In [None]:
# pretrained Encoder을 사용하고, 그 주어진 feature map의 크기에 맞게 decoder이 upsampling을 한다.
import torch
import torch.nn as nn
import torch.nn.functional as F

###### ENCODER BLOCK ##################
#######################################
class EncoderBlock(nn.Module):
  """Encoder stage: three conv + SCSE-attention steps, a residual branch, then 2x downsample.

  Each step is InstanceNorm -> Tanh -> Conv (kernel sizes 3, 1, 3) followed by
  an 'scse' attention module. A feature-extraction branch computed from the
  block input is added to the result, an optional noise-extraction branch is
  subtracted, and the sum is downsampled by 0.5 with nearest interpolation.

  NOTE(review): ``final_act`` is registered but is not applied in ``forward``.
  """

  def __init__(self,
               in_channels,
               out_channels,
               feature_ext = 'conv1',
               noise_ext = False):
    super().__init__()
    self.conv1 = nn.Sequential(
        nn.InstanceNorm2d(in_channels),
        nn.Tanh(),
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1),
    )
    self.attention1 = Attention('scse', out_channels)
    self.conv2 = nn.Sequential(
        nn.InstanceNorm2d(out_channels),
        nn.Tanh(),
        nn.Conv2d(out_channels, out_channels, kernel_size = 1, padding = 0, stride = 1),
    )
    self.attention2 = Attention('scse', out_channels)
    self.conv3 = nn.Sequential(
        nn.InstanceNorm2d(out_channels),
        nn.Tanh(),
        nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, stride = 1),
    )
    self.attention3 = Attention('scse', out_channels)

    self.feat_ext = FeatureExtractionBlock(feature_ext, in_channels, out_channels)
    if noise_ext:
      self.noise_ext = NoiseExtractionBlock(in_channels, out_channels)
    else:
      self.noise_ext = None
    self.final_act = nn.Tanh()

  def forward(self, x):
    block_input = x.clone()
    stages = (
        (self.conv1, self.attention1),
        (self.conv2, self.attention2),
        (self.conv3, self.attention3),
    )
    for conv, attend in stages:
      x = attend(conv(x))

    residual = self.feat_ext(block_input)
    if self.noise_ext:
      x -= self.noise_ext(block_input)
    x += residual
    # Halve the spatial resolution for the next encoder stage.
    return F.interpolate(x, scale_factor = 0.5, mode = 'nearest')

###############################################
class Encoder(nn.Module):
  """Chain of EncoderBlocks that returns every intermediate feature map.

  One EncoderBlock is built per consecutive pair in ``encoder_channels``.
  ``forward`` returns a list holding the input, the output of each block, and
  finally the last output downsampled once more by 0.5 (nearest).
  ``layer_n`` is accepted but not used.
  """

  def __init__(self, encoder_channels = [8, 16, 32, 64], layer_n = 4):
    super().__init__()
    stages = []
    for ch_in, ch_out in zip(encoder_channels, encoder_channels[1:]):
      stages.append(EncoderBlock(ch_in, ch_out))

    self.blocks = nn.ModuleList(stages)

  def forward(self, x):
    features = [x]
    for block in self.blocks:
      x = block(x)
      features.append(x)
    # Extra, further-downsampled copy of the deepest feature map.
    bottleneck = F.interpolate(x, scale_factor = 0.5, mode = 'nearest')
    features.append(bottleneck)
    return features


In [None]:
# Smoke test: run the default encoder on a dummy batch and list the feature shapes.
encoder = Encoder()
## The head is expected to lift the image from 1 to 8 channels, so the dummy input has 8.
x = torch.rand((2, 8, 512, 512))
out = encoder(x)
print(*(o.shape for o in out), sep = '\n')

torch.Size([2, 8, 512, 512])
torch.Size([2, 16, 256, 256])
torch.Size([2, 32, 128, 128])
torch.Size([2, 64, 64, 64])
torch.Size([2, 64, 32, 32])


#### `DECODER BLOCK`

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

##### DECODER BLOCK ###############
###################################
def build_skipconnection(skip_name, ch_in, skip_ch, conv_n):
  """Create the skip-connection module named by ``skip_name`` (case-insensitive).

  Args:
    skip_name: ``'concat'``, ``'film'`` or ``'ssa'``.
    ch_in: channels of the decoder input entering the skip connection.
    skip_ch: channels of the encoder skip feature map.
    conv_n: number of conv layers for the SSA sub-network (``'ssa'`` only).

  Returns:
    The constructed skip-connection module.

  Raises:
    ValueError: if ``skip_name`` is not a supported mode. The original fell
      through and returned ``None``, which only failed later at call time.
  """
  mode = skip_name.lower()
  if mode == 'concat':
    return SimpleSkipConnection()
  if mode == 'film':
    return FiLMSkipConnection(skip_ch, ch_in)
  if mode == 'ssa':
    return SSASkipConnection(ch_in, skip_ch, conv_layer_n = conv_n)
  raise ValueError(
      "Unknown skip connection mode: {!r} (expected 'concat', 'film' or 'ssa')".format(skip_name))

class DecoderBlock(nn.Module):
  """Decoder stage: skip connection (which also upsamples) followed by a ConvBlock.

  Args:
    in_channels: channels of the decoder input entering the skip connection.
    skip_channels: channels of the encoder skip feature map.
    out_channels: channels produced by this stage.
    feat_ext_mode: feature-extraction mode forwarded to ConvBlock.
    noise_ext: whether ConvBlock should build the noise-extraction branch.
    attention_mode: ``None`` or ``'scse'``.
      NOTE(review): accepted but not forwarded to ConvBlock, so the body
      currently runs without attention — confirm whether that is intended.
    skip_mode: ``'concat'``, ``'ssa'`` or ``'film'``.
    skip_conv_n: conv-layer count handed to the skip-connection builder.
  """

  def __init__(
      self,
      in_channels,
      skip_channels,
      out_channels,
      feat_ext_mode,
      noise_ext,
      attention_mode,
      skip_mode,
      skip_conv_n
    ):
    super().__init__()
    self.in_channels = in_channels
    self.skip_connection = build_skipconnection(skip_mode, in_channels, skip_channels, skip_conv_n)
    # The skip connection concatenates along channels, so the body sees both widths.
    merged_channels = in_channels + skip_channels
    self.body = ConvBlock(merged_channels, out_channels, feature_ext = feat_ext_mode, noise_ext = noise_ext)

  def forward(self, x, skip):
    ## Upsampling happens inside the skip connection.
    merged = self.skip_connection(x, skip)
    return self.body(merged)

class Decoder(nn.Module):
  """Decoder: a two-conv center block followed by a chain of DecoderBlocks.

  Args:
    encoder_channels: channel widths of the encoder features. The last two
      entries size the center block; the first ``n_blocks`` entries
      (reversed, skipping the deepest) are the skip widths.
    decoder_channels: decoder widths. Consecutive pairs give each block's
      input and output channels, and the first entry is the center output.
    feat_ext_mode: feature-extraction mode forwarded to every DecoderBlock.
    noise_ext: whether the blocks build a noise-extraction branch.
    attention_mode: attention name forwarded to every DecoderBlock.
    skip_mode: ``'concat'``, ``'film'`` or ``'ssa'``.
    n_blocks: upper bound on the number of decoder blocks; also sets the
      per-block conv counts passed to the skip connections (n_blocks .. 1).
  """

  def __init__(
    self,
    encoder_channels=(8, 16, 32, 64, 128),
    decoder_channels=(128, 64, 32, 16, 8, 1),
    feat_ext_mode='conv1',
    noise_ext=False,
    attention_mode='scse',
    skip_mode= 'ssa', # 'film', # 'ssa'
    n_blocks=5,
    ):
    super(Decoder, self).__init__()
    self.skip_conv_nums = list(range(n_blocks, 0, -1))
    # Store private copies: the old list defaults were shared by every
    # instance, so mutating ``decoder.decoder_channels`` leaked into later ones.
    self.encoder_channels = list(encoder_channels[:n_blocks])
    self.decoder_channels = list(decoder_channels)

    self.center = nn.Sequential(
        nn.InstanceNorm2d(encoder_channels[-2]), nn.Tanh(),
        nn.Conv2d(encoder_channels[-2], encoder_channels[-1], kernel_size = 3, padding = 1, stride = 1),
        nn.InstanceNorm2d(encoder_channels[-1]), nn.Tanh(),
        nn.Conv2d(encoder_channels[-1], decoder_channels[0], kernel_size = 3, padding = 1, stride = 1)
    )

    # zip stops at the shortest sequence, so the block count is bounded by the
    # number of available skip widths.
    skip_widths = self.encoder_channels[::-1][1:]
    block_specs = zip(self.decoder_channels, skip_widths, self.decoder_channels[1:], self.skip_conv_nums)
    blocks = [
        DecoderBlock(
            ch_in, ch_skip, ch_out,
            feat_ext_mode = feat_ext_mode,
            noise_ext = noise_ext,
            attention_mode = attention_mode,
            skip_mode = skip_mode,
            skip_conv_n = skip_n,
        )
        for (ch_in, ch_skip, ch_out, skip_n) in block_specs
    ]
    self.blocks = nn.ModuleList(blocks)

  def forward(self, x, features):
    """Decode ``x`` using ``features`` (ordered shallow -> deep) as skips."""
    x = self.center(x)
    # Blocks run deep -> shallow, so walk the skip list from its end.
    features = features[::-1]
    for idx, block in enumerate(self.blocks):
      x = block(x, features[idx])

    return x




  






In [None]:
# Smoke test: encoder -> decoder round trip on a dummy batch.
x = torch.rand((2, 8, 512,512))
# Placeholder feature pyramid; it is replaced by the real encoder output below.
features = [
    torch.rand((2, channels, side, side))
    for channels, side in ((8, 512), (16, 256), (32, 128), (64, 64))
]
encoder = Encoder()
features = encoder(x)
for feature_map in features:
  print(feature_map.shape)
decoder = Decoder()
# The deepest map is the decoder input; the remaining maps are the skips.
deepest, skips = features[-1], features[:-1]
print(decoder(deepest, skips).shape)

torch.Size([2, 8, 512, 512])
torch.Size([2, 16, 256, 256])
torch.Size([2, 32, 128, 128])
torch.Size([2, 64, 64, 64])
torch.Size([2, 64, 32, 32])
torch.Size([2, 8, 512, 512])


In [None]:
# Scratch check: slicing with [:-1] drops the last element.
a = list(range(1, 4))
a[:-1]

[1, 2]

#### `SETUP`

In [None]:
from setuptools import setup

# Minimal packaging metadata for the `gaussunet` package.
setup(
    name = 'gaussunet',        # distribution name
    packages = ['gaussunet'],  # single package directory to include
    version = '0.0.1.dev'      # development pre-release version string
)

# 이제 project root directory에서 `pip install -e .` 이라고 입력하면 gaussunet 패키지가 editable(개발) 모드로 설치된다.

#### `SMP UNET & UNET++ TEST`

In [None]:
!pip install segmentation_models_pytorch

In [None]:
import segmentation_models_pytorch as smp

# Builds a U-Net from segmentation_models_pytorch for comparison.
# NOTE(review): encoder_name is an empty string, which looks like an unfinished
# placeholder — choose the intended backbone before running this cell.
unet = smp.Unet(
    encoder_name = ''
)