In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from einops import rearrange
from swin import Swin_transformer
from transformer import Decoder

In [2]:
%%javascript
(function(on) {
const e=$( "<a>Setup failed</a>" );
const ns="js_jupyter_suppress_warnings";
var cssrules=$("#"+ns);
if(!cssrules.length) cssrules = $("<style id='"+ns+"' type='text/css'>div.output_stderr { } </style>").appendTo("head");
e.click(function() {
    var s='Showing';  
    cssrules.empty()
    if(on) {
        s='Hiding';
        cssrules.append("div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }");
    }
    e.text(s+' warnings (click to toggle)');
    on=!on;
}).click();
$(element).append(e);
})(true);

<IPython.core.display.Javascript object>

## Pixel Decoder
* FPN 모델 구조를 따왔음.
* FPN 모델 구조를 따왔지만 layer 마다 output을 뽑아서 학습을 하지는 않도록 했다.
* Swin transformer의 layer에서 feature map을 따올 때, 이미지 형태의 구조(b, c, h, w)로 따와야 한다.

In [3]:
class Pixel_decoder(nn.Module):
    def __init__(
        self, 
        in_channels: list = [96 * 8, 96 * 4, 96 * 2, 96], 
        channels: int = 256,
        n_groups: int = 16
    ):
        super(Pixel_decoder, self).__init__()
        self.num_stage = len(in_channels)
        self.from_encoder_projection_list = nn.ModuleList([])
        self.from_feature_projection_list = nn.ModuleList([])
        
        # 첫번째 layer는 encoder에서 들어오는 feature에 대한 projection이 필요가 없음.
        for i, in_channel in enumerate(in_channels):
            if i == 0:
                from_feature_projection = nn.Sequential(
                nn.Conv2d(in_channel, channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(n_groups, channels),
                nn.ReLU()
                )
                
                self.from_encoder_projection_list.append(None)
                self.from_feature_projection_list.append(from_feature_projection)
            else:
                from_feature_projection = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(n_groups, channels),
                nn.ReLU()
                )
                from_encoder_projection = nn.Sequential(
                    nn.Conv2d(in_channel, channels, kernel_size=1, stride=1),
                    nn.GroupNorm(n_groups, channels)
                )
                
                self.from_encoder_projection_list.append(from_encoder_projection)
                self.from_feature_projection_list.append(from_feature_projection)
        
        self.final_projection = nn.Conv2d(channels, channels, kernel_size=1, stride=1)

    def forward(self, features):
        '''
            features : dict keys : stage1, stage2, stage3, stage4
        '''
        feature = self.from_feature_projection_list[0](features['stage4'])
        
        for i, (encoder_projection, feature_projection) in enumerate(zip(self.from_encoder_projection_list[1:], self.from_feature_projection_list[1:])):
            feature = encoder_projection(features['stage' + str(3-i)]) + F.interpolate(feature, scale_factor=2, mode="nearest")
            feature = feature_projection(feature)
        
        return self.final_projection(feature)

## Transformer decoder
* DETR의 Transformer decoder와 동일하다.
* positional embedding을 DETR에서 사용한 방법이 아닌 Swin에서 사용한 방법을 쓰고 싶었지만 Window 기반이 아니기에 Table 크기가 너무 커져 비효율적인듯...?
* 모델에 저장되어 있는 쿼리는 배치가 1이기 때문에 이미지의 배치에 맞게 복사를 해야한다.
* 우선 auxiary 학습을 위해 Transformer의 모든 layer마다 출력을 한다.

In [4]:
class Transformer_decoder(Decoder):
    def __init__(
        self,
        n_query = 100,
        h = 8, 
        d_model = 256, 
        d_ff = 512, 
        dropout = 0.1, 
        N = 10,
        in_channels = 768,
        feature_size = (16, 16)
    ):
        super().__init__(h = h, d_model = d_model, d_ff = d_ff, dropout = dropout, N = N)
        self.n_query = n_query
        self.queries = nn.Parameter(torch.rand(1, n_query, d_model))
        self.linear = nn.Linear(in_channels, d_model)
        self.positional_encoding = nn.Parameter(torch.rand(1, in_channels, feature_size[0], feature_size[1]))

    def forward(self, feature):
        b, c, h, w = feature.shape
        positional_encoding = F.interpolate(self.positional_encoding, (h, w), mode="bilinear")
        keyvalue = rearrange(feature, 'b c h w -> b (h w) c') + rearrange(positional_encoding, 'b c h w -> b (h w) c')
        keyvalue = self.linear(keyvalue)
        queries = self.queries.expand(b, self.n_query, -1)
        return super().forward(queries, keyvalue)

## Segmentation Module
* Transformer decoder의 output을 가지고 segmentation vector와 classification vector를 출력한다.

In [5]:
class Segmentation_module(nn.Module):
    def __init__(
        self, 
        n_class,
        in_channels = 256,
        inter_channels = 256,
        out_channels = 256,
    ):
        super(Segmentation_module, self).__init__()
        self.mlp_segmentation_mask = nn.Sequential(
            nn.Linear(in_channels, inter_channels),
            nn.ReLU(),
            nn.Linear(inter_channels, out_channels),
        )
        self.classification = nn.Sequential(
            nn.Linear(in_channels, n_class + 1),
            nn.Softmax(dim=-1)
        )
        
    def forward(self, mask_embedded_vec):
        return self.mlp_segmentation_mask(mask_embedded_vec), self.classification(mask_embedded_vec)

## MaskFormer

In [6]:
class MaskFormer(nn.Module):
    def __init__(self, model_config):
        super(MaskFormer, self).__init__()
        self.backbone = Swin_transformer(
            patch_size = model_config["backbone_patch_size"], 
            window_size = model_config["backbone_window_size"], 
            merge_size = model_config["backbone_merge_size"], 
            model_dim = model_config["backbone_model_dim"], 
            num_layers_in_stage = model_config["backbone_num_layers_in_stage"]
        )
        
        in_channels = list(model_config["backbone_model_dim"] * 2 ** i for i in range(len(model_config["backbone_num_layers_in_stage"])))[::-1]
        self.pixel_decoder = Pixel_decoder(
            in_channels = in_channels,
            channels = model_config["pixel_decoder_channels"],
            n_groups = model_config["pixel_decoder_n_groups"]
        )
        
        in_channels = model_config["backbone_model_dim"] * 2 ** (len(model_config["backbone_num_layers_in_stage"]) - 1)
        self.transformer_decoder = Transformer_decoder(
            n_query = model_config["transformer_decoder_num_query"],
            h = model_config["transformer_decoder_num_head"], 
            d_model = model_config["transformer_decoder_dimension"],
            dropout = model_config["transformer_decoder_dropout"],
            N = model_config["transformer_decoder_num_layer"],
            in_channels = in_channels,
            feature_size = model_config["transformer_decoder_positional_size"]
        )
        
        self.segmentation_module = Segmentation_module(
            n_class = model_config["segmentation_module_num_class"], 
            in_channels = model_config["segmentation_module_in_channels"],
            out_channels = model_config["segmentation_module_out_channels"]
        )
    
    def forward(self, x):
        features = self.backbone(x)
        pixel_feature = self.pixel_decoder(features)
        b, C, H, W = pixel_feature.shape
        
        mask_embedded_vec = self.transformer_decoder(features['stage4'])[-1]
        
        segmentation_mask_vecs, classification_vecs = self.segmentation_module(mask_embedded_vec)
        
        segmentation_mask = torch.matmul(segmentation_mask_vecs, pixel_feature.view(b, C, -1)).view(b, -1, H, W)
        segmentation_mask = F.sigmoid(segmentation_mask)
        
        result = {}
        result["pred_masks"] = segmentation_mask
        result["pred_logits"] = classification_vecs
        return result

## matching
* DETR과 동일하게 bipartite maching을 한다.
* mask loss는 focal loss와 dice loss를 사용한다.

In [20]:
class HungarianMatcher(nn.Module):
    def __init__(self, w_class: float = 1, w_focal: float = 1, w_dice: float = 1):
        super().__init__()
        self.w_class = w_class
        self.w_focal = w_focal
        self.w_dice = w_dice
        
    @torch.no_grad()
    def dice_cost(self, predict, target):
        # predict : b * n_queries, h * w
        # target : b * n_obj, h * w
        numerator = 2 * (predict[:, None, :] * target[None, :, :]).sum(-1)
        denominator = predict.sum(-1)[:, None] + target.sum(-1)[None, :]
        cost_dice = 1 - (numerator + 1) / (denominator + 1)
        return cost_dice
    
    @torch.no_grad()
    def focal_cost(self, predict, target, gamma = 2., alpha = 0.25):
        # predict : b * n_queries, h * w
        # target : b * n_obj, h * w
        predict = predict[:, None, :].expand((predict.shape[0], target.shape[0], predict.shape[1]))
        target = target[None, :, :].expand((predict.shape[0], target.shape[0], target.shape[1]))
        ce = F.binary_cross_entropy_with_logits(predict, target, reduction='none')

        p_t = predict * target + (1 - predict) * (1 - target)
        focal_cost = ce * ((1 - p_t) ** gamma)

        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        focal_cost = alpha_t * focal_cost
        return focal_cost.mean(-1)
        
    @torch.no_grad()
    def forward(self, out, targets):
        pred_logits = out["pred_logits"] # b, n, class + 1
        pred_masks = out["pred_masks"] # b, n, h, w
        target_logits = targets["labels"] # [ m_i for i in b]
        target_masks = targets["masks"] # [ m_i, h, w for i in b]
        bs, num_queries = pred_logits.shape[:2]
        device = pred_logits.device
        
        out_prob = pred_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_mask = pred_masks.flatten(0, 1).flatten(1, 2)  # [batch_size * num_queries, h * w]

        tgt_ids = torch.cat([v for v in target_logits]) # [batch_size * num_obj]
        tgt_mask = torch.cat([v for v in target_masks]).flatten(1, 2) # [batch_size * num_obj, h * w]

        # cost :
        #     row : pred_querys
        #     col : target_obj
        cost_class = -out_prob[:, tgt_ids]                                                          # [batch_size * num_queries, batch_size * num_obj]
        cost_dice = self.dice_cost(out_mask, tgt_mask)                                              # [batch_size * num_queries, batch_size * num_obj] 
        cost_focal = self.focal_cost(out_mask, tgt_mask)                                            # [batch_size * num_queries, batch_size * num_obj]

        # Final cost matrix
        C = self.w_dice * cost_dice + self.w_class * cost_class + self.w_focal * cost_focal
        C = C.view(bs, num_queries, -1).cpu() # [batch_size, num_queries, batch_size * num_obj]

        sizes = [len(v) for v in target_masks]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        result = []
        for i, j in indices:
            i = torch.as_tensor(i, dtype=torch.int64, device=device)
            j = torch.as_tensor(j, dtype=torch.int64, device=device)
            result.append(i[j])
        return result

## Loss

In [32]:
class Maskformer_loss(nn.Module):
    def __init__(self, w_focal: float = 1., w_dice: float = 1., w_class: float = 1., w_noobj: float = 1.):
        super(Maskformer_loss, self).__init__()
        self.w_class = w_class
        self.w_focal = w_focal
        self.w_dice = w_dice
        self.w_noobj = w_noobj
        
    def class_loss(self, pred_logits, target_logits, match_indexs):
        device = pred_logits.device
        target_labels = torch.zeros(pred_logits.shape[:2], dtype=torch.int64, device=device)
        cost_no_obj = torch.ones(pred_logits.shape[2], device=device)
        cost_no_obj[0] *= self.w_noobj
        
        for i, match_index in enumerate(match_indexs):
            target_labels[i, match_index] = target_logits[i]
        
        class_loss = F.cross_entropy(pred_logits.flatten(0, 1), target_labels.flatten(0, 1), cost_no_obj)
        return class_loss
        
    def focal_loss(self, predict, target, gamma = 2.0, alpha = 0.25):
        # predict : b * n_queries, h * w
        # target : b * n_obj, h * w
        ce = F.binary_cross_entropy_with_logits(predict, target, reduction='none')

        p_t = predict * target + (1 - predict) * (1 - target)
        focal_cost = ce * ((1 - p_t) ** gamma)

        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        focal_loss = alpha_t * focal_cost
        return focal_loss.mean()
        
    def dice_loss(self, predict, target):
        numerator = 2 * (predict * target).sum(-1)
        denominator = predict.sum(-1) + target.sum(-1)
        loss_dice = 1 - (numerator + 1) / (denominator + 1)
        return loss_dice.mean()
        
    def forward(self, out, targets, match_indexs):
        pred_logits = out["pred_logits"] # b, n, class + 1
        pred_boxes = out["pred_masks"] # b, n, h, w
        target_logits = targets["labels"] # [ m_i for i in b]
        target_boxes = targets["masks"] # [ m_i, h, w for i in b]
        
        tgt_mask = torch.cat([v for v in target_boxes]).flatten(1, 2) # [batch_size * num_obj, h * w]
        out_mask = pred_boxes.flatten(2)  # [batch_size, num_queries, h * w]
        out_mask = torch.cat([out_mask[i, match_index, :] for i, match_index in enumerate(match_indexs)]) # [batch_size * num_obj, h * w]
        
        class_loss = self.class_loss(pred_logits, target_logits, match_indexs) * self.w_class
        focal_loss = self.focal_loss(out_mask, tgt_mask) * self.w_focal
        dice_loss = self.dice_loss(out_mask, tgt_mask) * self.w_dice
        
        return class_loss + focal_loss + dice_loss

## Create model

In [33]:
img_size = (512, 512)

In [34]:
model_config = {
    "backbone_patch_size" : 4,
    "backbone_window_size" : 8,
    "backbone_merge_size" : 2,
    "backbone_model_dim" : 96,
    "backbone_num_layers_in_stage" : [2, 2, 6, 2],
    "pixel_decoder_n_groups" : 16,
    "pixel_decoder_channels" : 256,
    "transformer_decoder_positional_size" : (16, 16),
    "transformer_decoder_num_query" : 100,
    "transformer_decoder_dimension" : 256,
    "transformer_decoder_num_head" : 8,
    "transformer_decoder_dropout" : 0,
    "transformer_decoder_num_layer" : 6,
    "segmentation_module_in_channels" : 256,
    "segmentation_module_num_class" : 10,
    "segmentation_module_out_channels" : 256
}

In [35]:
# swin transformer 때문에 아직 256의 배수 크기 밖에 안된다. 이거 나중에 고치자
img = torch.randn((2, 3, 512, 512)).cuda()
target = {}
target['labels'] = [torch.zeros((15), dtype=torch.long).cuda(), torch.ones((4), dtype=torch.long).cuda()]
target['masks'] = [torch.zeros((15, 128, 128)).cuda(), torch.ones((4, 128, 128)).cuda()]

In [36]:
maskformer = MaskFormer(model_config).cuda()

In [37]:
result = maskformer(img)

  "See the documentation of nn.Upsample for details.".format(mode)


In [38]:
result.keys()

dict_keys(['pred_masks', 'pred_logits'])

In [39]:
print(result["pred_masks"].shape, result["pred_logits"].shape)

torch.Size([2, 100, 128, 128]) torch.Size([2, 100, 11])


In [40]:
matcher = HungarianMatcher().cuda()
Loss = Maskformer_loss(w_noobj=0.1).cuda()

In [41]:
match_indexs = matcher(result, target)
loss = Loss(result, target, match_indexs)
print(loss)

tensor(3.3066, device='cuda:0', grad_fn=<AddBackward0>)


In [None]:
loss.backward()