In [1]:
import argparse
import datetime
import json
import random
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
import datasets
import util.misc as utils
import datasets.samplers as samplers
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model



In [5]:
import torch
import torch.nn.functional as F
from torch import nn
import math
import numpy as np
from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                       accuracy, get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)

from models.backbone import build_backbone
from models.matcher import build_matcher
from models.segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm,
                           dice_loss, sigmoid_focal_loss)
from models.deformable_transformer import build_deforamble_transformer
import copy
from torchvision.ops.boxes import batched_nms 


### Build the output heads.


In [6]:
class MLP(nn.Module):
    """Simple multi-layer perceptron built from ``nn.Linear`` layers.

    The network has ``num_layers`` linear layers. Every intermediate layer has
    width ``hidden_dim``; the last layer maps to ``output_dim``. ReLU is applied
    after every layer except the last one.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of every intermediate layer.
        output_dim: size of the last dimension of the output.
        num_layers: total number of linear layers (including the output layer).
    """

    def __init__(self,input_dim,hidden_dim,output_dim,num_layers):
        super().__init__()
        self.num_layers = num_layers
        # The last layer is the output layer, so only num_layers-1 hidden widths are needed.
        hidden_dims = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip([input_dim] + hidden_dims, hidden_dims + [output_dim])
        )

    def forward(self,x):
        """Apply the layers in order and return the output of the last one."""
        # enumerate() yields (index, layer) pairs; both must be unpacked.
        for i, layer in enumerate(self.layers):
            # No ReLU on the final (output) layer.
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
        

In [13]:
class IpsHead(nn.Module):
    """Per-query scalar head: ``Linear(hidden_dim, 1)`` followed by a sigmoid.

    The input is flattened over its first two dimensions, projected to a single
    value per row, squashed to (0, 1) with a sigmoid, and then reshaped back so
    the first two dimensions of the output match those of the input.

    Args:
        hidden_dim: size of the last dimension of the input.
    """

    def __init__(self,hidden_dim):
        super().__init__()
        self.flatten = nn.Flatten(0,1)  # merges dim 0 and dim 1 of the input
        self.linear2 = nn.Linear(hidden_dim,1)
        self.activate = nn.Sigmoid()  # the head outputs a probability
        # nn.init.constant_(self.linear2.weight,0)
        # nn.init.constant_(self.linear2.bias,0)

    def fressze_obj_head(self):
        """Put this head into evaluation mode.

        The previous implementation called ``self.obj_head.eval()``, but this
        module has no ``obj_head`` attribute, so it always raised
        ``AttributeError``. The misspelled method name is kept for backward
        compatibility. Note that this only switches to eval mode; it does not
        disable gradients.
        """
        self.eval()

    # Correctly spelled alias of the method above.
    freeze_obj_head = fressze_obj_head

    def forward(self,x):
        """Return sigmoid(linear(x)) with shape ``x.shape[:2] + (1,)``."""
        out = self.flatten(x)
        out = self.linear2(out)
        out = self.activate(out)
        # Restore the two leading dimensions that were merged by flatten.
        out = out.unflatten(0,x.shape[:2])
        return out

In [None]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class Deformable(nn.Module):
    """Detector that wires a backbone, a transformer and per-level prediction heads.

    Args:
        backbone: module called as ``backbone(samples)`` returning ``(features, pos)``.
            It must expose ``strides`` and ``num_channels``, and ``backbone[1]`` is
            used to compute position encodings for extra feature levels.
        transformer: module exposing ``d_model`` and ``decoder.num_layers``; called as
            ``transformer(srcs, masks, pos, query_embeds)``.
        num_classes: number of output logits of the classification head.
        num_queries: number of learned queries (only used when ``two_stage`` is False).
        num_feature_levels: number of feature levels fed to the transformer.
        aux_loss: if True, the outputs of the intermediate levels are returned too.
        with_box_refine: if True, every prediction level gets its own (cloned) heads
            and the box heads are also handed to the decoder.
        two_stage: if True, no query embedding is created and the class heads are
            handed to the decoder; the encoder proposals are returned in the output.
    """

    def __init__(self,backbone,transformer,num_classes,num_queries,num_feature_levels,aux_loss=True,with_box_refine=True,two_stage=True):
        super().__init__()
        self.num_feature_levels = num_feature_levels
        self.backbone = backbone
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model  # feature width used by every head and projection
        self.class_embed = nn.Linear(hidden_dim,num_classes)
        self.bbox_embed  = MLP(hidden_dim,hidden_dim,4,3)
        # NOTE(review): object_head is created (and cloned below) but never used in forward().
        self.object_head = IpsHead(hidden_dim)
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries,hidden_dim*2)

        # --- input projections: bring every feature level to hidden_dim channels ---
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)  # number of feature maps the backbone returns
            input_proj_list = []
            for idx in range(num_backbone_outs):
                in_channels = backbone.num_channels[idx]
                input_proj_list.append(
                    nn.Sequential(nn.Conv2d(in_channels,hidden_dim,kernel_size=1),nn.GroupNorm(32,hidden_dim),)
                )
            # Extra levels are produced by stride-2 convolutions; the first one reads the
            # last backbone map (in_channels still holds its channel count), later ones
            # read the previous extra level (hidden_dim channels).
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(nn.Conv2d(in_channels,hidden_dim,kernel_size=3,stride=2,padding=1),nn.GroupNorm(32,hidden_dim),))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([nn.Sequential(nn.Conv2d(backbone.num_channels[0],hidden_dim,kernel_size=1),nn.GroupNorm(32,hidden_dim))])

        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage= two_stage

        # --- head initialisation ---
        prior_prob =0.01
        bias_value = -math.log((1-prior_prob)/prior_prob)
        # sigmoid(bias_value) == prior_prob, so the initial class probabilities are small
        self.class_embed.bias.data = torch.ones(num_classes)*bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data,0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data,0)
        for proj in self.input_proj:
            # proj[0] is the Conv2d of each projection
            nn.init.xavier_uniform_(proj[0].weight,gain=1)
            nn.init.constant_(proj[0].bias,0)

        # One set of heads per prediction level (one extra level in two-stage mode).
        num_pred = (transformer.decoder.num_layers+1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            # Independent heads for every level.
            self.class_embed = _get_clones(self.class_embed,num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed,num_pred)
            self.object_head = _get_clones(self.object_head,num_pred)
            # Start the first level's width/height logits at a small value.
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:],-2.0)
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:],-2.0)
            # The same head instance is shared by every level.
            self.class_embed = nn.ModuleList([self.class_embed for _ in range (num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range (num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                # [2:] selects the width/height biases ([2:0] was an empty slice, i.e. a no-op).
                nn.init.constant_(box_embed.layers[-1].bias.data[2:],0.0)

    def forward(self,samples: "NestedTensor"):
        """The forward expects a NestedTensor, which consists of:
               - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

        Returns a dict with:
               - 'pred_logits': class logits of the last prediction level
               - 'pred_boxes': sigmoid-normalised box coordinates of the last level
               - 'aux_outputs': (only if aux_loss) list of dicts with the same two keys,
                 one per earlier level
               - 'enc_outputs': (only if two_stage) dict with the encoder's
                 'pred_logits' and sigmoid-normalised 'pred_boxes'
        """
        if not isinstance(samples,NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        for l,feat in enumerate(features):
            src,mask = feat.decompose()
            assert mask is not None
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
        if self.num_feature_levels > len(srcs):
            # Build the extra feature levels that the backbone does not provide.
            _len_srcs = len(srcs)
            for l in range(_len_srcs,self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                # Resize the padding mask to the resolution of the new level.
                mask = F.interpolate(m[None].float(),size= src.shape[-2:]).to(torch.bool)[0]
                # backbone[1] (digit one) computes the position encoding;
                # presumably the backbone is a Sequential(backbone, position_embedding) — TODO confirm
                pos_l = self.backbone[1](NestedTensor(src,mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)
        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight
        # Pass the full list of masks (one per level), not just the last one.
        hs,init_reference,inter_references,enc_outputs_class,enc_outputs_coord_unact = self.transformer(srcs,masks,pos,query_embeds)

        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl-1]
            # Move the reference back to logit space so the head output can be added to it.
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] ==4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                # 2-d reference: only the first two box coordinates are offset by it.
                tmp[...,:2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        out = {'pred_logits':outputs_class[-1],'pred_boxes':outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class,outputs_coord)
        if self.two_stage:
            # sigmoid maps the unactivated encoder coordinates into (0, 1)
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits':enc_outputs_class,'pred_boxes':enc_outputs_coord}
        return out

    @torch.jit.unused
    def _set_aux_loss(self,outputs_class,outputs_coord):
        """Pair up the predictions of every level except the last into a list of dicts."""
        return [{'pred_logits': a,'pred_boxes': b}
                for a,b in zip(outputs_class[:-1],outputs_coord[:-1])]
    
    

In [None]:
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)

    Args:
        num_classes: number of object classes; also used as the "no-object" index
            when building classification targets.
        matcher: callable ``matcher(outputs, targets)`` returning, per batch element,
            a pair ``(src_idx, tgt_idx)`` of index tensors.
        weight_dict: stored on the instance; not used inside this class.
        losses: list of loss names to compute; each must be one of
            'labels', 'cardinality', 'boxes', 'masks'.
        focal_alpha: ``alpha`` passed to the focal loss in ``loss_labels``.
    """
    def __init__(self,num_classes,matcher,weight_dict,losses,focal_alpha=0.25):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

    def loss_labels(self,outputs,targets,indices,num_boxes,log=True):
        """Classification loss (sigmoid focal loss).

        targets dicts must contain the key "labels". Returns ``{'loss_ce': ...}`` and,
        when ``log`` is True, also ``'class_error'`` computed on the matched queries.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        # idx is a (batch_idx, src_idx) pair of flat index tensors over all matched queries
        idx = self._get_src_permutation_idx(indices)
        # each element of `indices` is unpacked as (src_idx, tgt_idx); J selects the matched labels
        target_classes_o = torch.cat([t["labels"][J] for t,(_,J) in zip(targets,indices)])
        # Every query starts as "no-object" (index num_classes); matched queries get their label.
        target_classes = torch.full(src_logits.shape[:2],self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o
        # One extra column receives the "no-object" hits. Create the tensor with the dtype,
        # layout and device of the logits so the loss also works when they live off the CPU.
        target_classes_onehot = torch.zeros([src_logits.shape[0],src_logits.shape[1],src_logits.shape[2]+1],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2,target_classes.unsqueeze(-1),1)
        # Drop the extra column: unmatched queries end up with an all-zero target row.
        target_classes_onehot = target_classes_onehot[:,:,:-1]
        loss_ce = sigmoid_focal_loss(src_logits,target_classes_onehot,num_boxes,alpha=self.focal_alpha,gamma=2)*src_logits.shape[1]
        losses = {'loss_ce':loss_ce}

        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx],target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self,outputs,targets,indices,num_boxes):
        """Box losses: L1 ('loss_bbox') and GIoU ('loss_giou'), both divided by num_boxes.

        targets dicts must contain the key "boxes". Boxes are converted with
        ``box_cxcywh_to_xyxy`` before the GIoU is computed.
        """
        assert 'pred_boxes' in outputs
        # (batch_idx, src_idx): selects every matched prediction across the batch at once
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t,(_,i) in zip(targets,indices)],dim=0)
        loss_bbox = F.l1_loss(src_boxes,target_boxes,reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum()/num_boxes

        # Only the diagonal (prediction i vs. its own target i) of the pairwise matrix is needed.
        loss_giou = 1- torch.diag(box_ops.generalized_box_iou(box_ops.box_cxcywh_to_xyxy(src_boxes),box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum()/num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten the per-image source indices into a (batch_idx, src_idx) pair."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten the per-image target indices into a (batch_idx, tgt_idx) pair."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self,loss,outputs,targets,indices,num_boxes,**kwargs):
        """Dispatch to the loss function registered under the name ``loss``."""
        loss_map = {
            'labels' : self.loss_labels,
            'cardinality' : self.loss_cardinality,
            'boxes' : self.loss_boxes,
            'masks' : self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs,targets,indices,num_boxes,**kwargs)

    def forward(self,outputs,targets):
        """Compute every requested loss for the final, auxiliary and encoder outputs.

        Args:
            outputs: dict of model outputs; may contain 'aux_outputs' (list of dicts)
                and 'enc_outputs' (dict) in addition to the final predictions.
            targets: list of per-image target dicts (each must contain "labels").

        Returns:
            dict mapping loss names to values. Auxiliary losses get the suffix
            ``_{i}`` (level index) and encoder losses the suffix ``_enc``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
        # Match the final predictions against the targets.
        indices = self.matcher(outputs_without_aux,targets)

        # Average number of target boxes per process, used for normalisation.
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes],dtype=torch.float,device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            # Sum over all processes first; otherwise the division below shrinks the local count.
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes/get_world_size(),min=1).item()

        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss,outputs,targets,indices,num_boxes,**kwargs))

        # Auxiliary outputs: repeat matching and every loss for each intermediate level.
        if 'aux_outputs' in outputs:
            for i,aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs,targets)
                for loss in self.losses:
                    if loss =='masks':
                        # mask losses are skipped for intermediate levels
                        continue
                    kwargs = {}
                    if loss =='labels':
                        # class_error is only logged for the final level
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss,aux_outputs,targets,indices,num_boxes,**kwargs)
                    l_dict = {k + f'_{i}': v for k,v in l_dict.items()}
                    losses.update(l_dict)

        # Encoder (two-stage) outputs: every target label is replaced by class 0 for matching.
        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt['labels'] = torch.zeros_like(bt['labels'])
            indices = self.matcher(enc_outputs,bin_targets)
            for loss in self.losses:
                if loss == 'masks':
                    continue
                kwargs = {}
                if loss == 'labels':
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                l_dict = {k + f'_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        # Always return the collected losses, whether or not encoder outputs were present.
        return losses
            

In [16]:
# Smoke test 1: a plain linear classification head, 256 input features -> 4 outputs,
# applied to a random (256, 256) input.
class_embed = nn.Linear(256,4)
input_tensor = torch.randn(256, 256)
output_tensor = class_embed(input_tensor)
print(output_tensor)
# Smoke test 2: IpsHead flattens the first two dims, applies Linear(hidden_dim, 1) + sigmoid
# and restores them, so a (3, 20, 256) input gives a (3, 20, 1) output with values in (0, 1).
class_object_head = IpsHead(256)
input_tensor2 = torch.randn(3,20,256)
output_tensor2 = class_object_head(input_tensor2)
print(output_tensor2)

tensor([[-1.0155,  0.7537,  0.1728,  1.0055],
        [ 0.6616, -0.3418,  1.4101,  1.0494],
        [-0.2863, -0.4926, -0.6691,  0.4535],
        ...,
        [-0.6731, -0.6795, -0.7950, -0.0796],
        [ 0.1980,  0.5483, -0.5475,  0.4446],
        [-0.7517, -0.4201, -0.2257, -0.9833]], grad_fn=<AddmmBackward0>)
tensor([[[0.6128],
         [0.2261],
         [0.4610],
         [0.5335],
         [0.5371],
         [0.4147],
         [0.5092],
         [0.5341],
         [0.3849],
         [0.5527],
         [0.4893],
         [0.6785],
         [0.2209],
         [0.4640],
         [0.6745],
         [0.2606],
         [0.4794],
         [0.6315],
         [0.2708],
         [0.5134]],

        [[0.6860],
         [0.7088],
         [0.6086],
         [0.4202],
         [0.3929],
         [0.5026],
         [0.6284],
         [0.4770],
         [0.4755],
         [0.6593],
         [0.4052],
         [0.3331],
         [0.3536],
         [0.7010],
         [0.7606],
         [0.4901]

'이거머임'