In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from einops import rearrange
from swin import Swin_transformer
from transformer import Decoder
from detectron2.layers import ShapeSpec
from position_encoding import PositionEmbeddingSine

# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:100% !important; }</style>"))

## Backbone
* Swin Transformer tiny version을 사용하였다.
* 코드는 https://github.com/tinnunculus/SwinTransformer/blob/main/swin.ipynb 여기 있음.

In [2]:
backbone_config = {}
backbone_config['backbone_patch_size'] = 4
backbone_config['backbone_window_size'] = 8
backbone_config['backbone_merge_size'] = 2
backbone_config['backbone_model_dim'] = 96
backbone_config['backbone_num_layers_in_stage'] = [2, 2, 6, 2]

In [3]:
backbone = Swin_transformer(
    patch_size = backbone_config["backbone_patch_size"], 
    window_size = backbone_config["backbone_window_size"], 
    merge_size = backbone_config["backbone_merge_size"], 
    model_dim = backbone_config["backbone_model_dim"], 
    num_layers_in_stage = backbone_config["backbone_num_layers_in_stage"]
).cuda()

In [4]:
imgs = torch.randn((2, 3, 512, 512)).cuda()
features = backbone(imgs)

In [5]:
print(features['res2'].shape)
print(features['res3'].shape)
print(features['res4'].shape)
print(features['res5'].shape)

torch.Size([2, 96, 128, 128])
torch.Size([2, 192, 64, 64])
torch.Size([2, 384, 32, 32])
torch.Size([2, 768, 16, 16])


## Pixel Decoder
* Maskformer의 Pixel Decoder와 다르게 Deformable Transformer의 Encoder를 사용하였다.
* Deformable Transformer는 https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/pixel_decoder/msdeformattn.py 에서 가져왔다.

In [6]:
'''
Deformable Transformer Encoder는 
https://github.com/facebookresearch/Mask2Former/mask2former/modeling/pixel_decoder/msdeformattn.py 에서 가져왔다.
'''

'''
    MSDeformAttnPixelDecoder :
    
        Args:
            input_shape:    
                'res2':(channels =  96, height=None, width=None, stride =  4)
                'res3':(channels = 192, height=None, width=None, stride =  8)
                'res4':(channels = 384, height=None, width=None, stride = 16)
                'res5':(channels = 768, height=None, width=None, stride = 32)
            transformer_dropout: 0.0
            transformer_nheads: 8
            transformer_dim_feedforward: 1024
            transformer_enc_layers: 6
            conv_dims: 256
            mask_dim: 256
            norm (str or callable): 'GN'
            transformer_in_features: ['res3', 'res4', 'res5']
            common_stride: 4
            
        Input:
            features : 
                'res2': torch.Size([1,  96, 328, 200])
                'res3': torch.Size([1, 192, 164, 100])
                'res4': torch.Size([1, 384,  82,  50])
                'res5': torch.Size([1, 768,  41,  25])
            
        Output:
            mask_features : torch tensor [1, 256, res2_h, res2_w],
            transformer_encoder_features : torch tensor [1, 256, res5_h, res5_w],
            multi_scale_features : [
                torch tensor [1, 256, res5_h, res5_w],
                torch tensor [1, 256, res4_h, res4_w],
                torch tensor [1, 256, res3_h, res3_w]
            ]
'''

from msdeformattn import MSDeformAttnPixelDecoder

In [7]:
## Pixel Decoder configuration
pixel_decoder_config = {}
pixel_decoder_config['input_shape'] = {}
pixel_decoder_config['input_shape']['res2'] = ShapeSpec(channels=96, height=None, width=None, stride=4)
pixel_decoder_config['input_shape']['res3'] = ShapeSpec(channels=192, height=None, width=None, stride=8)
pixel_decoder_config['input_shape']['res4'] = ShapeSpec(channels=384, height=None, width=None, stride=16)
pixel_decoder_config['input_shape']['res5'] = ShapeSpec(channels=768, height=None, width=None, stride=32)

pixel_decoder_config['transformer_dropout'] = 0.0
pixel_decoder_config['transformer_nheads'] = 8
pixel_decoder_config['transformer_dim_feedforward'] = 1024
pixel_decoder_config['transformer_enc_layers'] = 6
pixel_decoder_config['conv_dims'] = 256
pixel_decoder_config['mask_dim'] = 256
pixel_decoder_config['norm'] = 'GN'
pixel_decoder_config['transformer_in_features'] = ['res3', 'res4', 'res5']
pixel_decoder_config['common_stride'] = 4

In [8]:
pixel_decoder = MSDeformAttnPixelDecoder(
    input_shape = pixel_decoder_config['input_shape'], 
    transformer_dropout = pixel_decoder_config['transformer_dropout'],
    transformer_nheads = pixel_decoder_config['transformer_nheads'],
    transformer_dim_feedforward = pixel_decoder_config['transformer_dim_feedforward'],
    transformer_enc_layers = pixel_decoder_config['transformer_enc_layers'],
    conv_dim = pixel_decoder_config['conv_dims'],
    mask_dim = pixel_decoder_config['mask_dim'],
    norm = pixel_decoder_config['norm'],
    transformer_in_features = pixel_decoder_config['transformer_in_features'],
    common_stride = pixel_decoder_config['common_stride'],
).cuda()

In [9]:
mask_features, transformer_encoder_features, multi_scale_features = pixel_decoder.forward_features(features)

In [10]:
print(mask_features.shape)
print(transformer_encoder_features.shape)
for feature in multi_scale_features:
    print(feature.shape)

torch.Size([2, 256, 128, 128])
torch.Size([2, 256, 16, 16])
torch.Size([2, 256, 16, 16])
torch.Size([2, 256, 32, 32])
torch.Size([2, 256, 64, 64])


## Transformer decoder
* 기존의 cross attention을 mask cross attention으로 대체.
* self attention 과 cross attention의 위치 변경
* Learnable한 query vectors
* pixel decoder에서 Transformer decoder로 들어가는 feature map은 linear mapping을 한번 거친다.
* nn.MultiheadAttention에서는 N, B, C 순서의 query, key, value type이니 주의해야한다.

In [11]:
class Masked_attention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super().__init__()
        self.mh_attention = nn.MultiheadAttention(embed_dim = model_dim, num_heads = num_heads)
        self.norm = nn.LayerNorm(model_dim)
        
    def forward(self, query, value, key_pos, attn_mask):
        key = value + key_pos
        
        out = self.mh_attention(
            query = query,
            key = key,
            value = value,
            attn_mask = attn_mask
        )[0]
        
        return self.norm(out + query)
    
class Self_attention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super().__init__()
        self.mh_attention = nn.MultiheadAttention(embed_dim = model_dim, num_heads = num_heads)
        self.norm = nn.LayerNorm(model_dim)
        
    def forward(self, query):        
        out = self.mh_attention(
            query = query,
            key = query,
            value = query
        )[0]
        
        return self.norm(out + query)
    
class FFN(nn.Module):
    def __init__(self, model_dim, inter_dim):
        super(FFN, self).__init__()
        
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, inter_dim),
            nn.ReLU(),
            nn.Linear(inter_dim, model_dim)
        )
        
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        x = self.ffn(x) + x
        return self.norm(x)
    
class MLP(nn.Module):
    def __init__(self, model_dim = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.mlp(x)
    
class Transformer_decoder_block(nn.Module):
    def __init__(self, model_dim = 256, num_heads = 8):
        super().__init__()
        
        self.masked_attention = Masked_attention(model_dim, num_heads)
        self.self_attention = Self_attention(model_dim, num_heads)
        self.ffn = FFN(model_dim, 2*model_dim)
        
    def forward(self, query, value, key_pos, attn_mask):
        query = self.masked_attention(query, value, key_pos, attn_mask)
        out = self.self_attention(query)
        out = self.ffn(out)
        
        return out
    
class Transformer_decoder(nn.Module):
    def __init__(self, n_class = 10, L = 3, num_query = 100, num_features = 3, model_dim = 256, num_heads = 8):
        super().__init__()
        
        self.num_features = num_features
        self.num_heads = num_heads
        self.transformer_block = nn.ModuleList([Transformer_decoder_block(model_dim=model_dim, num_heads=num_heads) for _ in range(L * 3)])
        self.query = nn.Parameter(torch.rand(num_query, 1, model_dim))
        
        self.from_features_linear = nn.ModuleList([nn.Conv2d(model_dim, model_dim, kernel_size=1) for _ in range(num_features)])
        self.from_features_bias = nn.ModuleList([nn.Embedding(1, model_dim) for _ in range(num_features)])
        self.pos_emb = PositionEmbeddingSine(model_dim // 2, normalize=True)
        
        self.decoder_norm = nn.LayerNorm(model_dim)
        self.classfication_module = nn.Linear(model_dim, n_class)
        self.segmentation_module = MLP(model_dim)
        
    def forward_prediction_heads(self, mask_embed, pix_emb, decoder_layer_size=None):
        mask_embed = self.decoder_norm(mask_embed)
        mask_embed = mask_embed.transpose(0, 1) # b, 100, 256
        outputs_class = self.classfication_module(mask_embed)
        mask_embed = self.segmentation_module(mask_embed)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, pix_emb)
        
        if decoder_layer_size is not None:
            attn_mask = F.interpolate(outputs_mask, size=decoder_layer_size, mode="bilinear", align_corners=False)
            attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() # head 수 만큼 복사한다. bool 형으로 넣어야 한다. True인 곳이 무시할 픽셀
            attn_mask = attn_mask.detach()
        else:
            attn_mask = None
        return outputs_class, outputs_mask, attn_mask

        
    def forward(self, features, pix_emb):
        query = self.query.expand(self.query.shape[0], features[0].shape[0], self.query.shape[2]) # batch 만큼 복사
        
        predictions_class = []
        predictions_mask = []
        
        for i in range(self.num_features):
            b, c, h, w = features[i].shape
                                
            kv = self.from_features_linear[i](features[i])  + self.from_features_bias[i].weight[:, :, None, None]
            kv = rearrange(kv, 'b c h w-> (h w) b c')
            
            key_pos = self.pos_emb(b, h, w, features[i].device, None)
            key_pos = rearrange(key_pos, 'b c h w -> (h w) b c')
            
            for j in range(3):
                outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(query, mask_features, decoder_layer_size=(h, w))
                # axial training을 위해 중간 결과를 저장한다.
                predictions_class.append(outputs_class)
                predictions_mask.append(outputs_mask)
                
                attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # 중간 추출된 mask가 아무것도 가리키지 않을 경우 global context attention으로 처리한다.
                query = self.transformer_block[i * 3 + j](query, kv, key_pos, attn_mask)
                
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(query, mask_features, decoder_layer_size=None)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)
                
        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'aux_outputs': {
                'pred_logits' : predictions_class,
                'pred_masks': predictions_mask,
            }
        }
        return out

In [12]:
transformer_decoder_config = {}
transformer_decoder_config['n_class'] = 10
transformer_decoder_config['L'] = 3
transformer_decoder_config['num_query'] = 100
transformer_decoder_config['num_features'] = 3
transformer_decoder_config['model_dim'] = 256
transformer_decoder_config['num_heads'] = 8

In [13]:
transformer_decoder = Transformer_decoder(
    n_class = transformer_decoder_config['n_class'] + 1, 
    L = transformer_decoder_config['L'], 
    num_query = transformer_decoder_config['num_query'], 
    num_features = transformer_decoder_config['num_features'], 
    model_dim = transformer_decoder_config['model_dim'], 
    num_heads = transformer_decoder_config['num_heads']
).cuda()

In [14]:
out = transformer_decoder(multi_scale_features, mask_features)

In [15]:
out.keys()

dict_keys(['pred_logits', 'pred_masks', 'aux_outputs'])

In [16]:
print(out['pred_logits'].shape)
print(out['pred_masks'].shape)
print(len(out['aux_outputs']['pred_logits']), len(out['aux_outputs']['pred_masks']))

torch.Size([2, 100, 11])
torch.Size([2, 100, 128, 128])
10 10


## Matching
* DETR과 동일하게 bipartite maching을 한다.
* MaskFormer와는 다르게 focal loss 대신에 cross entropy를 사용한다.
* MaskFormer와는 다르게 모든 픽셀에 대해서 distance를 계산하지 않고 임의의 추출된 픽셀에 대해서만 계산한다.
* 포인트를 임의로 추출하는 것은 모든 이미지에 대해서 동일한 위치의 픽셀을 추출한다.
* 112 * 112 개의 픽셀을 추출하는데 feature map의 크기가 112, 112 보다 작을 수도 있다. 그렇기에 중복 추출을 허용한다. F.grid_sample 함수를 이용
* matching 하는데 있어서는 마지막 단 layer만을 이용한다.

In [17]:
from time import time

class HungarianMatcher(nn.Module):
    def __init__(self, n_sample = 112 * 112, w_class: float = 1, w_ce: float = 1, w_dice: float = 1):
        super().__init__()
        self.n_sample = n_sample
        self.w_class = w_class
        self.w_ce = w_ce
        self.w_dice = w_dice
        
    @torch.no_grad()
    def dice_cost(self, predict, target):
        # predict : b * n_queries, n_sample_points
        # target : b * n_obj, n_sample_points
        numerator = 2 * (predict[:, None, :] * target[None, :, :]).sum(-1)
        denominator = predict.sum(-1)[:, None] + target.sum(-1)[None, :]
        cost_dice = 1 - (numerator + 1) / (denominator + 1)
        return cost_dice
    
    @torch.no_grad()
    def ce_cost(self, predict, target):
        # predict : b * n_queries, n_sample_points
        # target : b * n_obj, n_sample_points
        predict = predict[:, None, :].expand((predict.shape[0], target.shape[0], predict.shape[1]))
        target = target[None, :, :].expand((predict.shape[0], target.shape[0], target.shape[1]))
        ce = F.binary_cross_entropy_with_logits(predict, target, reduction='none')
        
        return ce.mean(-1)
        
    @torch.no_grad()
    def forward(self, out, targets):
        pred_logits = out["pred_logits"] # b, n, class + 1
        pred_masks = out["pred_masks"] # b, n, h, w
        target_logits = targets["labels"] # [ m_i for i in b]
        target_masks = targets["masks"] # [ m_i, h, w for i in b]
        bs, num_queries = pred_logits.shape[:2]
        device = pred_logits.device
        
        out_prob = pred_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        tgt_ids = torch.cat([v for v in target_logits]) # [batch_size * num_obj]
        
        
        out_mask = pred_masks.flatten(0, 1).unsqueeze(1)  # [batch_size * num_queries, 1, h, w]
        tgt_mask = torch.cat([v for v in target_masks]).unsqueeze(1) # [batch_size * num_obj, 1, h, w]
        grid = torch.rand((1, 1, self.n_sample, 2), device=out_mask.device) * 2 - 1
        out_grid = grid.expand(out_mask.shape[0], *grid.shape[1:])
        out_mask = F.grid_sample(out_mask, out_grid, mode='nearest', align_corners=False).squeeze()  # [batch_size * num_queries, n_sample_points]
        tgt_grid = grid.expand(tgt_mask.shape[0], *grid.shape[1:])
        tgt_mask = F.grid_sample(tgt_mask, tgt_grid, mode='nearest' , align_corners=False).squeeze()  # [batch_size * num_obj, n_sample_points]

        # cost :
        #     row : pred_querys
        #     col : target_obj
        cost_class = -out_prob[:, tgt_ids]                   # [batch_size * num_queries, batch_size * num_obj]
        cost_dice = self.dice_cost(out_mask, tgt_mask)       # [batch_size * num_queries, batch_size * num_obj] 
        cost_ce = self.ce_cost(out_mask, tgt_mask)           # [batch_size * num_queries, batch_size * num_obj]
        
        # Final cost matrix
        C = self.w_dice * cost_dice + self.w_class * cost_class + self.w_ce * cost_ce
        C = C.view(bs, num_queries, -1).cpu() # [batch_size, num_queries, batch_size * num_obj]
        
        sizes = [len(v) for v in target_masks]
        
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        result = []
        for i, j in indices:
            i = torch.as_tensor(i, dtype=torch.int64, device=device)
            j = torch.as_tensor(j, dtype=torch.int64, device=device)
            result.append(i[j])
        return result

In [18]:
matcher_config = {}
matcher_config['n_sample'] = 112 * 112
matcher_config['w_class'] = 1.0
matcher_config['w_ce'] = 20.0
matcher_config['w_dice'] = 1.0

In [19]:
matcher = HungarianMatcher(
    n_sample = matcher_config['n_sample'],
    w_class = matcher_config['w_class'],
    w_ce = matcher_config['w_ce'],
    w_dice = matcher_config['w_dice']
).cuda()

In [20]:
## 예시 정답 데이터
target = {}
target['labels'] = [torch.zeros((15), dtype=torch.long).cuda(), torch.ones((4), dtype=torch.long).cuda()]
target['masks'] = [torch.zeros((15, 128, 128)).cuda(), torch.ones((4, 128, 128)).cuda()]

In [21]:
match_indexs = matcher(out, target)

In [22]:
print(match_indexs)

[tensor([ 0, 61, 26, 11, 89, 67, 20, 32, 99, 25, 68, 28, 74, 45, 22],
       device='cuda:0'), tensor([56, 80, 53,  5], device='cuda:0')]


## Loss
* 전체적으로 MaskFormer 의 Loss function 과 동일하다.
* Mask loss에 focal loss 대신 cross entropy를 사용한다.
* 모든 매칭에 대해서 classfication loss를 적용한다.
* object가 있는 매칭에 대해서만 mask loss를 적용한다.
* macher와 마찬가지로 point sampling을 하는데, Uniform 하게 샘플링 했던 matcher와는 정답 픽셀(foreground)에서 더 많이 뽑도록 한다. 이를 위해서 detectron2의 get_uncertain_point_coords_with_randomness 함수를 이용한다.

In [23]:
from detectron2.projects.point_rend.point_features import get_uncertain_point_coords_with_randomness

class Maskformer_loss(nn.Module):
    def __init__(self, n_sample = 112 * 112, w_ce = 1., w_dice = 1., w_class = 1., w_noobj = 1., oversample_ratio = 3.0, importance_sample_ratio = 0.75):
        super(Maskformer_loss, self).__init__()
        self.n_sample = n_sample
        self.w_class = w_class
        self.w_ce = w_ce
        self.w_dice = w_dice
        self.w_noobj = w_noobj
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        
    def class_loss(self, pred_logits, target_logits, match_indexs):
        device = pred_logits.device
        target_labels = torch.zeros(pred_logits.shape[:2], dtype=torch.int64, device=device)
        cost_no_obj = torch.ones(pred_logits.shape[2], device=device)
        cost_no_obj[0] *= self.w_noobj
        
        for i, match_index in enumerate(match_indexs):
            target_labels[i, match_index] = target_logits[i]
        
        class_loss = F.cross_entropy(pred_logits.flatten(0, 1), target_labels.flatten(0, 1), cost_no_obj)
        return class_loss
        
    def ce_loss(self, predict, target, gamma = 2.0, alpha = 0.25):
        # predict : b * n_queries, h * w
        # target : b * n_obj, h * w
        ce = F.binary_cross_entropy_with_logits(predict, target, reduction='none')

        return ce.mean()
        
    def dice_loss(self, predict, target):
        numerator = 2 * (predict * target).sum(-1)
        denominator = predict.sum(-1) + target.sum(-1)
        loss_dice = 1 - (numerator + 1) / (denominator + 1)
        return loss_dice.mean()
    
    def calculate_uncertainty(self, logits):
        assert logits.shape[1] == 1
        gt_class_logits = logits.clone()
        return -(torch.abs(gt_class_logits))

    def forward(self, out, targets, match_indexs):
        pred_logits = out["pred_logits"] # b, n, class + 1
        pred_masks = out["pred_masks"] # b, n, h, w
        target_logits = targets["labels"] # [ m_i for i in b]
        target_boxes = targets["masks"] # [ m_i, h, w for i in b]
        
        tgt_mask = torch.cat([v for v in target_boxes]).unsqueeze(1) # [batch_size * num_obj, 1, h, w]
        out_mask = pred_masks  # [batch_size, num_queries, h, w]
        out_mask = torch.cat([out_mask[i, match_index, :] for i, match_index in enumerate(match_indexs)]).unsqueeze(1)  # [batch_size * num_obj, 1, h, w]
        
        with torch.no_grad():
            point_coords = get_uncertain_point_coords_with_randomness(
                out_mask,
                lambda logits: self.calculate_uncertainty(logits),
                self.n_sample,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).unsqueeze(1)
            
            tgt_mask = F.grid_sample(tgt_mask, point_coords, mode='nearest', align_corners=False).squeeze(1) # [batch_size * num_queries, n_sample_points]
        out_mask = F.grid_sample(out_mask, point_coords, mode='nearest', align_corners=False).squeeze(1) # [batch_size * num_queries, n_sample_points]

        class_loss = self.class_loss(pred_logits, target_logits, match_indexs) * self.w_class
        ce_loss = self.ce_loss(out_mask, tgt_mask) * self.w_ce
        dice_loss = self.dice_loss(out_mask, tgt_mask) * self.w_dice
        
        return class_loss + ce_loss + dice_loss

In [24]:
loss_config = {}
loss_config['n_sample'] = 112 * 112
loss_config['w_class'] = 1.0
loss_config['w_ce'] = 20.0
loss_config['w_dice'] = 1.0
loss_config['w_noobj'] = 0.1
loss_config['oversample_ratio'] = 3.0
loss_config['importance_sample_ratio'] = 0.75

In [25]:
Loss = Maskformer_loss(
    n_sample = loss_config['n_sample'] , 
    w_ce = loss_config['w_class'] , 
    w_dice = loss_config['w_ce'] , 
    w_class = loss_config['w_dice'], 
    w_noobj = loss_config['w_noobj'], 
    oversample_ratio = loss_config['oversample_ratio'], 
    importance_sample_ratio = loss_config['importance_sample_ratio']
).cuda()

In [26]:
loss = Loss(out, target, match_indexs)
print('loss: %.5f' % loss)

loss: 19.00613


In [27]:
loss.backward()