In [1]:
from blocks import UnetEncodeLayer,UnetUpscaleLayer,UnetForwardDecodeLayer, conv3x3
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2
import torchvision.transforms.functional as functional
import math
from transformers import AutoModel, AutoImageProcessor, SegformerForSemanticSegmentation, SegformerConfig
from blocks import UnetEncodeLayer,UnetUpscaleLayer,UnetForwardDecodeLayer
from torchvision.models.segmentation import deeplabv3_resnet101,deeplabv3_mobilenet_v3_large

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class VisionTransformer(nn.Module):
	"""One cross-attention transformer block: LayerNorm -> multi-head attention -> MLP.

	Unlike a full ViT, this block has no patch projection and no positional
	encoding; it expects inputs that are already D-dimensional embeddings.

	Parameters:
		+ D <embedding dimension; must be divisible by num_heads, as required
		  by nn.MultiheadAttention>
		+ num_heads <number of attention heads>
	"""
	def __init__(self, D, num_heads):
		super().__init__()
		self.layer_norm1_1 = nn.LayerNorm(D)  # normalises the query
		self.layer_norm1_2 = nn.LayerNorm(D)  # shared by key and value
		self.layer_norm2 = nn.LayerNorm(D)    # applied before the MLP
		self.MHA = nn.MultiheadAttention(embed_dim=D, num_heads=num_heads, batch_first=True)
		self.mlp = nn.Sequential(
			# using D*4 hidden size according to original vision transformer paper
			nn.Linear(D, D*4),
			nn.GELU(),
			nn.Linear(D*4, D)
		)

	def forward(self, s: torch.Tensor):
		"""
		Parameters:
			+ s <query, key and value stacked along dim 0, so s[0], s[1], s[2]
			  are q, k, v; each is fed to a batch_first attention layer and
			  must end in a dimension of size D>

		Returns:
			A tensor stacked along dim 0 as (block output, normalised key,
			normalised value), i.e. the same layout as the input, so blocks
			can be chained inside an nn.Sequential.
		"""
		q, k, v = s[0], s[1], s[2]

		# Residuals are kept in locals (not on self) so activations and their
		# autograd graph are released as soon as forward() returns.
		residual = q

		q = self.layer_norm1_1(q)
		k = self.layer_norm1_2(k)
		v = self.layer_norm1_2(v)

		attended = self.MHA(q, k, v)[0]  # [0] = attention output, [1] = weights
		residual = attended + residual

		# NOTE(review): layer_norm2 sees only the attention output, not the
		# residual sum as in the standard pre-norm ViT block. Kept as-is to
		# preserve the behaviour of any trained weights - confirm intent.
		x = self.mlp(self.layer_norm2(attended))

		# NOTE(review): the *normalised* k and v are forwarded, so a following
		# block applies its own LayerNorm on top of them - confirm intent.
		return torch.stack((x + residual, k, v), dim=0)

def vision_transformer(D,num_heads):
	"""Factory: build one VisionTransformer block with embedding size D and num_heads heads."""
	block = VisionTransformer(D, num_heads)
	return block
	
class VisionTransformerEncoder(nn.Module):
	"""Chain of `layers` transformer blocks applied one after another.

	The blocks are created through vision_transformer(D, num_heads) and run
	in order by an nn.Sequential stored in `self.stack`.
	"""
	def __init__(self, D, num_heads, layers):
		super().__init__()
		# Plain list kept for direct access; the modules are registered
		# through self.stack below.
		self.layers = []
		for _ in range(layers):
			self.layers.append(vision_transformer(D, num_heads))
		self.stack = nn.Sequential(*self.layers)

	def forward(self, q,k,v):
		"""Stack q, k, v along a new leading dim and pass the result through every block.

		q, k and v must share one shape, since they are combined with
		torch.stack. The returned tensor keeps the stacked layout; index 0
		along dim 0 is the query slot.
		"""
		packed = torch.stack([q, k, v], dim=0)
		result = self.stack(packed)
		return result

In [21]:
class FUnet(nn.Module):
	"""U-Net with two parallel encoders fused by cross-attention at the bottleneck.

	A "patch" encoder and a dilated "context" encoder each produce five
	feature levels. The deepest patch features act as the attention query and
	the deepest context features as key/value. The decoder upsamples the
	fused result and, at every level, concatenates the skip features of
	BOTH encoders before convolving.

	Parameters:
		+ num_classes <number of output channels (one logit per class)>
		+ D <attention embedding size; must equal h*w of the bottleneck
		  feature map and be divisible by the 4 attention heads. The default
		  196 corresponds to a 14x14 bottleneck, e.g. a 224x224 input
		  pooled four times - assuming the layers imported from `blocks`
		  preserve spatial size with the paddings used here (TODO confirm)>

	Note: intermediate activations are stored on the instance (self.x1 ...
	self.x5, self.x1_c ... self.x5_c, self.fused_features) because
	embedding_fusion() and decode() read them; they are overwritten on every
	forward pass.
	"""
	def __init__(self, num_classes, D=196):
		super().__init__()
		# Flags not read inside this class; presumably consumed by external
		# training / inference code (verify against callers).
		self.requires_context = True
		self.wrapper = False
		self.returns_logits = True
		# -----------------PATCH ENCODER-----------------------
		self.D = D
		self.encode1 = nn.Sequential(
			UnetEncodeLayer(3, 64, padding=1),
			UnetEncodeLayer(64, 64, padding=1), ## keep dimensions unchanged
		)
		self.encode2 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(64, 128, padding=1),
			UnetEncodeLayer(128, 128, padding=1),
		)
		self.encode3 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(128, 256, padding=1),
			UnetEncodeLayer(256, 256, padding=1),
		)
		self.encode4 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(256, 512, padding=1),
			UnetEncodeLayer(512, 512, padding=1),
		)
		self.encode5 = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(512, 1024, padding=1),
			UnetEncodeLayer(1024, 1024, padding=1),
		)
		# -----------------CONTEXT ENCODER--------------------
		# Same layout as the patch encoder, but with dilation=3 (and the
		# matching padding=3).
		self.encode1_c = nn.Sequential(
			UnetEncodeLayer(3, 64, padding=3, dilation=3),
			UnetEncodeLayer(64, 64, padding=3, dilation=3), ## keep dimensions unchanged
		)
		self.encode2_c = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(64, 128, padding=3, dilation=3),
			UnetEncodeLayer(128, 128, padding=3, dilation=3),
		)
		self.encode3_c = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(128, 256, padding=3, dilation=3),
			UnetEncodeLayer(256, 256, padding=3, dilation=3),
		)
		self.encode4_c = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(256, 512, padding=3, dilation=3),
			UnetEncodeLayer(512, 512, padding=3, dilation=3),
		)
		self.encode5_c = nn.Sequential(
			nn.MaxPool2d(kernel_size=2, stride=2),
			UnetEncodeLayer(512, 1024, padding=3, dilation=3),
			UnetEncodeLayer(1024, 1024, padding=3, dilation=3),
		)

		# ---------------DECODER-----------------
		# Each decode_forward* input width is the sum of three tensors:
		# patch skip + context skip + upsampled features
		# (e.g. 512 + 512 + 512 = 1536 at the first level).

		self.upscale1 = nn.Sequential(
			nn.ConvTranspose2d(1024, 512,kernel_size=2, stride=2)
		)
		self.decode_forward1 = nn.Sequential(
			UnetForwardDecodeLayer(1536,512, padding=1)
		)
		self.upscale2 = nn.Sequential(
			nn.ConvTranspose2d(512, 256,kernel_size=2, stride=2)
		)
		self.decode_forward2 = nn.Sequential(
			UnetForwardDecodeLayer(768, 256, padding=1)
		)
		self.upscale3 = nn.Sequential(
			nn.ConvTranspose2d(256, 128,kernel_size=2, stride=2)
		)
		self.decode_forward3 = nn.Sequential(
			UnetForwardDecodeLayer(384,128,padding=1)
		)
		self.upscale4 = nn.Sequential(
			nn.ConvTranspose2d(128, 64,kernel_size=2, stride=2)
		)
		self.decode_forward4 = nn.Sequential(
			UnetForwardDecodeLayer(192,64, padding=1),
			nn.Conv2d(64, num_classes, kernel_size=1) # final conv 1x1
			# Model output has num_classes channels: every pixel gets one
			# logit per class.
		)

		# 4 attention heads, 2 stacked transformer blocks.
		self.transformer = VisionTransformerEncoder(self.D, 4, 2)

	def encode_patch(self, x: torch.Tensor):
		"""Run the patch encoder, caching every level in self.x1 ... self.x5; returns self.x5."""
		self.x1 = self.encode1(x)
		self.x2 = self.encode2(self.x1)
		self.x3 = self.encode3(self.x2)
		self.x4 = self.encode4(self.x3)
		self.x5 = self.encode5(self.x4)
		return self.x5

	def encode_context(self, x:torch.Tensor):
		"""Run the dilated context encoder, caching every level in self.x1_c ... self.x5_c; returns self.x5_c."""
		self.x1_c = self.encode1_c(x)
		self.x2_c = self.encode2_c(self.x1_c)
		self.x3_c = self.encode3_c(self.x2_c)
		self.x4_c = self.encode4_c(self.x3_c)
		self.x5_c = self.encode5_c(self.x4_c)
		return self.x5_c

	def embedding_fusion(self):
		"""Fuse the two bottlenecks with cross-attention and store the result in self.fused_features.

		Each channel's flattened h*w map is one token of size D: the patch
		bottleneck (self.x5) supplies the queries and the context bottleneck
		(self.x5_c) the keys and values. Must be called after encode_patch()
		and encode_context().

		Raises:
			ValueError: if h*w of the bottleneck differs from the D the
			transformer was built with.
		"""
		N,L,h,w = self.x5.shape
		D = h*w
		if D != self.D:
			# Fail early with an actionable message instead of an opaque
			# LayerNorm shape error from inside the transformer.
			raise ValueError(
				f"bottleneck feature map is {h}x{w} (h*w={D}) but the transformer "
				f"was built with D={self.D}; change the input size or pass D={D}"
			)
		q = self.x5.reshape(N,L,D)
		k = v = self.x5_c.reshape(N,L,D)
		fused = self.transformer(q,k,v)
		# Index 0 of the stacked output is the query slot, i.e. the fused features.
		self.fused_features = fused[0].reshape(N,L,h,w)

	def decode(self):
		"""Decode self.fused_features into a map with num_classes channels, using the cached skip features of both encoders."""
		y1 = self.upscale1(self.fused_features)

		c1 = torch.concat((self.x4, self.x4_c, y1), 1)
		y2 = self.decode_forward1(c1)

		y2 = self.upscale2(y2)
		c2 = torch.concat((self.x3, self.x3_c, y2), 1)
		y3 = self.decode_forward2(c2)

		y3 = self.upscale3(y3)
		# Crop to the skip map's (H, W). Passing a single int would force a
		# square crop and break the concatenation for non-square inputs.
		y3 = functional.center_crop(y3, list(self.x2.shape[-2:]))
		c3 = torch.concat((self.x2, self.x2_c, y3), 1)
		y4 = self.decode_forward3(c3)

		y4 = self.upscale4(y4)
		c4 = torch.concat((self.x1, self.x1_c, y4), 1)
		return self.decode_forward4(c4)

	def forward(self, x: torch.Tensor, context):
		"""Segment `x` using `context` as the second encoder's input.

		Parameters:
			+ x <input to the patch encoder; 3 channels>
			+ context <input to the context encoder; 3 channels. Its
			  bottleneck is reshaped with the patch bottleneck's dimensions,
			  so it must yield the same number of elements>

		Returns:
			The decoder output: a map with num_classes channels.
		"""
		# The encoders cache their feature maps on self; the return values
		# are not needed here.
		self.encode_patch(x)
		self.encode_context(context)
		self.embedding_fusion()
		return self.decode()

In [22]:
# Instantiate the network with 16 output classes (D keeps its default of 196).
net = FUnet(num_classes=16)

In [None]:
# Smoke test: one forward pass on a random batch of two 3x224x224 images,
# feeding the same tensor to both the patch and the context encoder.
# NOTE(review): 224x224 is presumably chosen so the bottleneck is 14x14 = 196,
# matching the default D - confirm against the `blocks` layers.
test = torch.rand(2, 3, 224, 224)
out = net(test, context=test)