<a href="https://colab.research.google.com/github/taravatp/Multi_Spectral_Image_Segmentation/blob/main/networks/transunet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [7]:
# cd /content/drive/MyDrive/Vision_Impulse_Task

/content/drive/MyDrive/Vision_Impulse_Task


In [4]:
!pip install ml_collections

Collecting ml_collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/77.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ml_collections
  Building wheel for ml_collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml_collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94506 sha256=289170c0ed290b8d1340eb3e453906125bb34ae215ac88a43933b595df8377ad
  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe
Successfully built ml_collections
Installing collected packages: ml_collections
Successfully installed ml_collections-0.1.1


In [8]:
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
import math
import copy
import numpy as np
import vit_seg_configs as configs
import os

# Key names of the pretrained ViT weights inside the checkpoint

In [10]:
# Sub-keys of one transformer block inside the pretrained checkpoint.
# Block.load_from appends them to the block's "Transformer/encoderblock_<n>" root.
_ATTENTION_SCOPE = "MultiHeadDotProductAttention_1"
_MLP_SCOPE = "MlpBlock_3"

# Attention projections
ATTENTION_Q = f"{_ATTENTION_SCOPE}/query"
ATTENTION_K = f"{_ATTENTION_SCOPE}/key"
ATTENTION_V = f"{_ATTENTION_SCOPE}/value"
ATTENTION_OUT = f"{_ATTENTION_SCOPE}/out"

# Feed-forward (MLP) dense layers
FC_0 = f"{_MLP_SCOPE}/Dense_0"
FC_1 = f"{_MLP_SCOPE}/Dense_1"

# Layer norms applied before attention and before the MLP
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

# Activation functions and weight-conversion helper

In [11]:
def swish(x):
    """Element-wise swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

# Lookup table from activation name to the callable implementing it.
ACT2FN = {
    "gelu": torch.nn.functional.gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}

def np2th(weights, conv=False):
    """Wrap a NumPy array as a torch tensor.

    Args:
        weights: NumPy array to convert.
        conv: when true, the axes are first reordered as (3, 2, 0, 1) —
            presumably turning an HWIO convolution kernel into torch's
            OIHW layout (TODO confirm against the checkpoint format).

    Returns:
        Tensor created with ``torch.from_numpy`` from the (possibly transposed) array.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)

# UNet Encoder

In [12]:
def convolution_module(in_channels,out_channels):
  """Build a double 3x3 convolution block.

  The block is two repetitions of Conv2d(3x3, padding=1, no bias) -> BatchNorm2d
  -> ReLU(inplace). The first convolution maps ``in_channels`` to ``out_channels``;
  the second keeps ``out_channels``. Padding of 1 keeps the spatial size unchanged.

  Returns:
    nn.Sequential holding the six layers in order.
  """
  layers = []
  for stage_in in (in_channels, out_channels):
    layers += [
        nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
  return nn.Sequential(*layers)

In [13]:
class CNN_encoder(nn.Module):
  """Four-stage convolutional encoder with 2x2 max pooling between stages.

  Args:
    in_channels: number of channels of the input image (default 12).
    feat_channels: output channels of the four convolution stages, shallowest
      first. Must contain exactly four values. The default reproduces the
      previously hard-coded widths (32, 64, 128, 256).

  Attributes:
    out_channels: channel count of the deepest feature map (``feat_channels[-1]``).

  Raises:
    ValueError: if ``feat_channels`` does not contain exactly four entries.
  """

  def __init__(self, in_channels=12, feat_channels=(32, 64, 128, 256)):
    # Immutable default: a list default would be shared between all instances.
    feat_channels = tuple(feat_channels)
    if len(feat_channels) != 4:
      raise ValueError(
          f"feat_channels must contain exactly 4 values, got {len(feat_channels)}")

    super(CNN_encoder, self).__init__()

    c1, c2, c3, c4 = feat_channels

    # Encoder convolutions (attribute names kept so existing state_dicts still load)
    self.down_conv1 = convolution_module(in_channels, c1)
    self.down_conv2 = convolution_module(c1, c2)
    self.down_conv3 = convolution_module(c2, c3)
    self.down_conv4 = convolution_module(c3, c4)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.out_channels = c4

  def forward(self,x):
    """Encode ``x``.

    Returns:
      Tuple ``(x4, [x3, x2, x1])``: the deepest feature map (after three
      poolings) and the three earlier stage outputs, deepest first, for use
      as skip connections.
    """
    x1 = self.down_conv1(x)
    x2 = self.down_conv2(self.max_pool(x1))
    x3 = self.down_conv3(self.max_pool(x2))
    x4 = self.down_conv4(self.max_pool(x3))

    return x4,[x3,x2,x1]

# Transformer Modules

In [14]:
class Embeddings(nn.Module):
    """Turn a CNN feature map into a token sequence with learned position embeddings.

    The geometry is hard-coded: ``img_size=128``, ``feature_size=8``,
    ``patch_size=1`` and ``patch_size_real=16``. With these values both token
    counts computed below equal 64.

    Args:
        config: object exposing ``hidden_size`` and ``transformer["dropout_rate"]``.
        in_channels: channel count of the incoming feature map.
    """

    def __init__(self, config, in_channels):
        super(Embeddings, self).__init__()

        self.config = config

        self.img_size = 128
        self.feature_size = 8
        self.patch_size = 1
        self.patch_size_real = 16

        img_h, img_w = _pair(self.img_size)            # (128, 128)
        feat_h, feat_w = _pair(self.feature_size)      # (8, 8)
        patch = _pair(self.patch_size)                 # (1, 1)
        real_h, real_w = _pair(self.patch_size_real)   # (16, 16)

        # Token count when the feature map is split into patch_size patches.
        self.n_patches = (feat_h // patch[0]) * (feat_w // patch[1])
        # Token count when the full image is split into patch_size_real patches.
        self.n_patches_real = (img_h // real_h) * (img_w // real_w)

        # A convolution whose kernel and stride both equal the patch size embeds
        # each patch independently.
        self.patch_embeddings = nn.Conv2d(
            in_channels=in_channels,
            out_channels=config.hidden_size,
            kernel_size=patch,
            stride=patch,
        )
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.n_patches_real, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W) into tokens of shape (B, H'*W', hidden_size)."""
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        # Only the first n_patches position embeddings are used.
        positions = self.position_embeddings[:, :self.n_patches, :]
        return self.dropout(tokens + positions)

In [15]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object exposing ``hidden_size``, ``transformer["num_heads"]`` and
            ``transformer["attention_dropout_rate"]`` (the latter is used for both
            the attention-probability dropout and the output-projection dropout).
    """

    def __init__(self, config):
        super(Attention, self).__init__()

        hidden = config.hidden_size
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(hidden / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden, self.all_head_size)
        self.key = nn.Linear(hidden, self.all_head_size)
        self.value = nn.Linear(hidden, self.all_head_size)

        self.out = nn.Linear(hidden, hidden)
        drop_rate = config.transformer["attention_dropout_rate"]
        self.attn_dropout = nn.Dropout(drop_rate)
        self.proj_dropout = nn.Dropout(drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads: (B, N, heads*d) -> (B, heads, N, d)."""
        leading = x.size()[:-1]
        per_head = x.view(*leading, self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states`` of shape (B, N, hidden_size)."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # (B, heads, N, d) @ (B, heads, d, N): matmul batches over the leading dims.
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        probs = self.attn_dropout(self.softmax(scores))

        # Merge the heads back: (B, heads, N, d) -> (B, N, heads*d).
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        merged = context.view(*context.size()[:-2], self.all_head_size)

        return self.proj_dropout(self.out(merged))

In [16]:
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        config: object exposing ``hidden_size``, ``transformer["mlp_dim"]`` and
            ``transformer["dropout_rate"]``.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        hidden, inner = config.hidden_size, config.transformer["mlp_dim"]
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = nn.Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std=1e-6) normal biases for both layers."""
        layers = (self.fc1, self.fc2)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Expand to ``mlp_dim``, activate, and project back to ``hidden_size``."""
        expanded = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(expanded))

In [17]:
class Block(nn.Module):
    """Transformer encoder block with layer norm applied before each sub-layer.

    ``forward`` computes ``x + attn(norm(x))`` followed by ``x + ffn(norm(x))``;
    token count and hidden size are unchanged.

    Args:
        config: object exposing ``hidden_size`` plus the fields read by
            ``Mlp`` and ``Attention``.
    """

    def __init__(self, config):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, tokens, hidden_size)."""
        # Attention sub-layer with residual connection.
        x = x + self.attn(self.attention_norm(x))
        # Feed-forward sub-layer with residual connection.
        x = x + self.ffn(self.ffn_norm(x))
        return x

    def load_from(self, weights, n_block):
        """Copy this block's parameters from a pretrained checkpoint.

        Args:
            weights: mapping from "/"-separated key to NumPy array (e.g. the
                object returned by ``np.load`` on the ``.npz`` checkpoint).
            n_block: index of this block; selects the
                ``Transformer/encoderblock_<n_block>`` key prefix.

        Raises:
            KeyError: if an expected key is missing from ``weights``.
        """
        root = f"Transformer/encoderblock_{n_block}"

        def fetch(*parts):
            # Checkpoint keys are archive entry names, not filesystem paths, so
            # they always use "/" — os.path.join would emit "\\" on Windows.
            return np2th(weights["/".join((root,) + parts)])

        def as_square_linear(kernel):
            # Collapse the kernel to (hidden, hidden) and transpose it into
            # nn.Linear's (out_features, in_features) layout.
            return kernel.view(self.hidden_size, self.hidden_size).t()

        attention_layers = (
            (self.attn.query, ATTENTION_Q),
            (self.attn.key, ATTENTION_K),
            (self.attn.value, ATTENTION_V),
            (self.attn.out, ATTENTION_OUT),
        )
        mlp_layers = ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1))
        norm_layers = ((self.attention_norm, ATTENTION_NORM), (self.ffn_norm, MLP_NORM))

        with torch.no_grad():
            for linear, scope in attention_layers:
                linear.weight.copy_(as_square_linear(fetch(scope, "kernel")))
                linear.bias.copy_(fetch(scope, "bias").view(-1))

            for linear, scope in mlp_layers:
                linear.weight.copy_(fetch(scope, "kernel").t())
                linear.bias.copy_(fetch(scope, "bias").t())

            for norm, scope in norm_layers:
                norm.weight.copy_(fetch(scope, "scale"))
                norm.bias.copy_(fetch(scope, "bias"))



In [18]:
class Encoder(nn.Module):
    """Stack of ``config.transformer["num_layers"]`` transformer blocks plus a final LayerNorm.

    Args:
        config: object exposing ``hidden_size``, ``transformer["num_layers"]`` and
            the fields read by ``Block``.
    """

    def __init__(self, config):
        super(Encoder, self).__init__()
        # Each Block(config) call already creates independent parameters, so no
        # deep copy of the freshly built block is needed.
        self.layer = nn.ModuleList(
            [Block(config) for _ in range(config.transformer["num_layers"])]
        )
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        """Run the tokens through every block in order, then normalise the result."""
        for layer_block in self.layer:
            hidden_states = layer_block(hidden_states)

        return self.encoder_norm(hidden_states)

# CNN decoder

In [19]:
class CNN_decoder(nn.Module):
  """U-Net style decoder that upsamples transformer tokens back to a 3-channel map.

  A 1x1 convolution maps ``config.hidden_size`` channels to 256. Three stages
  follow, each doing a stride-2 transposed convolution, concatenation with a
  skip feature, and a double 3x3 convolution (256 -> 128 -> 64 -> 32 channels).
  A final 1x1 convolution produces 3 output channels.

  Args:
    config: object exposing ``hidden_size``.
  """

  def __init__(self,config):
    super(CNN_decoder,self).__init__()

    self.config = config
    self.first_conv = nn.Conv2d(config.hidden_size,256, kernel_size=1, stride = 1, padding=0)

    self.upsample1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
    self.conv1 = convolution_module(256, 128)

    self.upsample2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    self.conv2 = convolution_module(128, 64)

    self.upsample3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
    self.conv3 = convolution_module(64, 32)

    self.out = nn.Conv2d(32,3,kernel_size=1)

  def crop_tensor(self,target_tensor,input_tensor):
    """Center-crop ``input_tensor`` to the spatial size of ``target_tensor``.

    Height and width are handled independently and the result always has
    exactly the target's spatial size, including when the size difference is
    odd (the extra row/column is then dropped from the bottom/right).

    Raises:
      ValueError: if ``input_tensor`` is spatially smaller than ``target_tensor``.
    """
    target_h, target_w = target_tensor.size()[2:4]
    input_h, input_w = input_tensor.size()[2:4]
    if input_h < target_h or input_w < target_w:
      raise ValueError(
          f"cannot crop a {input_h}x{input_w} tensor to the larger size {target_h}x{target_w}")
    top = (input_h - target_h) // 2
    left = (input_w - target_w) // 2
    return input_tensor[:, :, top:top + target_h, left:left + target_w]

  def forward(self,hidden_states, features):
    """Decode transformer output into a dense prediction.

    Args:
      hidden_states: tensor of shape (batch, n_patch, hidden_size); ``n_patch``
        must be a perfect square because the tokens are laid out on a square grid.
      features: list of three skip tensors, deepest first (the order returned
        by ``CNN_encoder.forward``).

    Returns:
      Tensor with 3 channels at 8x the token-grid resolution.
    """
    # Reshape (B, n_patch, hidden) -> (B, hidden, H, W).
    batch_size, n_patch, hidden_size = hidden_states.size()
    H, W = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
    hidden_states = hidden_states.permute(0, 2, 1) # (batch_size, hidden_size, n_patch)
    hidden_states = hidden_states.contiguous().view(batch_size, hidden_size, H, W)

    # One 1x1 convolution to change the number of channels.
    x = self.first_conv(hidden_states)
    skip_deep, skip_mid, skip_shallow = features

    stages = (
        (self.upsample1, self.conv1, skip_deep),
        (self.upsample2, self.conv2, skip_mid),
        (self.upsample3, self.conv3, skip_shallow),
    )
    for upsample, conv, skip in stages:
      x = upsample(x)
      # Upsampled features come first in the channel concatenation.
      x = conv(torch.cat([x, self.crop_tensor(x, skip)], 1))

    x = self.out(x)
    return x



# Building TransUNET

In [22]:
class Trans_Unet(nn.Module):
  """Segmentation network: CNN encoder -> token embedding -> transformer -> CNN decoder.

  Args:
    config: object exposing the fields read by ``Embeddings``, ``Encoder`` and
      ``CNN_decoder`` (``hidden_size`` and the ``transformer`` settings).
  """

  def __init__(self,config):
    super(Trans_Unet, self).__init__()

    self.config = config
    self.conv_encoder = CNN_encoder()
    self.embedding = Embeddings(config, in_channels=self.conv_encoder.out_channels)
    self.transformer_encoder = Encoder(config)
    self.conv_decoder = CNN_decoder(config)

  def forward(self,x):
    """Run the full pipeline on an image batch and return the decoder output."""
    x, features = self.conv_encoder(x)
    embeddings = self.embedding(x)
    hidden_states = self.transformer_encoder(embeddings)
    output = self.conv_decoder(hidden_states, features)
    return output

  @staticmethod
  def _fit_position_embedding(posemb, n_tokens):
    """Adapt a pretrained (1, N, D) position embedding to ``n_tokens`` tokens.

    Returns the tensor unchanged when the counts already match, drops one
    leading token when the source has exactly one extra, and otherwise resizes
    the square token grid bilinearly. Returns ``None`` when the counts cannot
    be arranged as square grids.
    """
    n_source = posemb.size(1)
    if n_source == n_tokens:
      return posemb
    if n_source - 1 == n_tokens:
      return posemb[:, 1:]

    # A count that is one more than a perfect square is treated as a square
    # grid preceded by one extra (class) token, which is dropped.
    if math.isqrt(n_source) ** 2 != n_source and math.isqrt(n_source - 1) ** 2 == n_source - 1:
      posemb = posemb[:, 1:]
      n_source -= 1

    side_src, side_dst = math.isqrt(n_source), math.isqrt(n_tokens)
    if side_src ** 2 != n_source or side_dst ** 2 != n_tokens:
      return None

    # (1, N, D) -> (1, D, side, side) so the grid can be resized spatially.
    grid = posemb.reshape(1, side_src, side_src, -1).permute(0, 3, 1, 2)
    grid = nn.functional.interpolate(
        grid, size=(side_dst, side_dst), mode="bilinear", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, n_tokens, -1)

  def load_from(self,weights):
    """Initialise the embedding and transformer from a pretrained ViT checkpoint.

    Args:
      weights: mapping from "/"-separated key to NumPy array (e.g. the object
        returned by ``np.load`` on the ``.npz`` checkpoint).

    Loads the patch-embedding convolution, the final encoder norm, the position
    embedding (fitted to this model's token count) and every transformer block.
    The CNN encoder and decoder are left untouched.
    """
    with torch.no_grad():
      # Keep only as many input channels of the pretrained kernel as the CNN encoder provides.
      embedding_kernel = np2th(weights["embedding/kernel"], conv=True)[:,0:self.conv_encoder.out_channels,:,:]
      embedding_bias = np2th(weights["embedding/bias"])

      self.embedding.patch_embeddings.weight.copy_(embedding_kernel)
      self.embedding.patch_embeddings.bias.copy_(embedding_bias)

      self.transformer_encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
      self.transformer_encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

      # Previously the pretrained position embedding was read but never applied.
      posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
      fitted = self._fit_position_embedding(posemb, self.embedding.position_embeddings.size(1))
      if fitted is not None:
        self.embedding.position_embeddings.copy_(fitted)
      # else: incompatible token layout — keep the freshly initialised embedding.

      # Transformer blocks; ModuleList child names are the block indices "0", "1", ...
      for block_name, block in self.transformer_encoder.layer.named_children():
        block.load_from(weights, n_block=block_name)

In [23]:
if __name__ == "__main__":

  # Smoke test: build the model, run one forward pass on random data, then
  # load the pretrained transformer weights.
  CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
  }

  config = CONFIGS['ViT-B_16']
  image = torch.rand((8,12,64,64))  # (batch, channels, height, width)
  model = Trans_Unet(config)

  output = model(image)
  print('output:',output.shape)

  path_weights = '/content/drive/MyDrive/Vision_Impulse_Task/R50+ViT-B_16.npz'
  # np.load on an .npz returns a lazily-read archive that keeps the file open;
  # the context manager closes it once the weights have been copied.
  with np.load(path_weights) as weights:
    model.load_from(weights)

output: torch.Size([8, 3, 64, 64])
