# imports

In [None]:
from fastai.vision.all import *

In [None]:
import torchvision

In [None]:
# export
#from FLAI.detect_symbol.exp import databunch as databunch_detsym
from FLAI.detect_symbol.exp import resnet_ssd as resnet_ssd_detsym
from FLAI.detect_symbol.exp import anchors_loss_metrics as anchors_loss_metrics_detsym
from FLAI.detect_symbol.exp import optimizer as optimizer_detsym
#from FLAI.detect_symbol.exp import init_model as init_model_detsym
#from FLAI.detect_symbol.exp import tensorboard_callback
#from FLAI.detect_symbol.exp import scheduling_train

In [None]:
#最后会引用detect_symbol.databunch，ImageList找不到
# sys.path.append('../sick_tree_detection')
# from sick_tree_detection.exp import anchors_loss_metrics as anchors_loss_metrics_sicktree

# functions

## 应对无目标的情况

In [None]:
def bb_pad_intlbl(samples, pad_idx=0):
    """Pad the bboxes and labels of every sample to the longest label list in `samples`.

    Each sample is (image, bboxes, labels). The bboxes/labels are first passed
    through `clip_remove_empty`; bboxes are then padded with zero rows and labels
    with `pad_idx`. A sample without any label is handled separately: its empty
    label tensor is not an integer tensor (per the original author's note), so the
    padding is created with an explicit integer dtype to keep it usable as an index.
    """
    cleaned = [(item[0], *clip_remove_empty(*item[1:])) for item in samples]
    longest = max(len(item[2]) for item in cleaned)

    def _pad_one(img, bbox, lbl):
        missing_boxes = longest - bbox.shape[0]
        padded_bbox = torch.cat([bbox, bbox.new_zeros(missing_boxes, 4)])
        if lbl.shape[0] == 0:
            # No target at all: build the whole label row from scratch as integers.
            padded_lbl = lbl.new_zeros(longest, dtype=torch.int) + pad_idx
        else:
            filler = lbl.new_zeros(longest - lbl.shape[0]) + pad_idx
            padded_lbl = torch.cat([lbl, filler])
        return img, padded_bbox, padded_lbl

    return [_pad_one(*item) for item in cleaned]

# Bounding-box block whose batches are padded by `bb_pad_intlbl` (defined above), so that
# samples without any target can be batched.
# NOTE(review): this rebinds the name `BBoxBlock`; presumably it shadows the one brought in
# by `from fastai.vision.all import *` — confirm that is intended.
BBoxBlock = TransformBlock(type_tfms=TensorBBox.create, item_tfms=PointScaler, dls_kwargs = {'before_batch': bb_pad_intlbl})

## 获取BBox和label  
两个是分开进行的。并且BBox的顺序改成了先x后y，使用v1版的fastai的数据集的时候需要转换顺序。

In [None]:
#export
# Pre-compiled patterns shared by the DataFrame label helpers.
pat_coord = re.compile(r'\d+')  # one integer coordinate
pat_clas = re.compile(r'\w+')   # one class-name token
# Index key of an image: "<folder>/<digits>.<ext>" at the end of the path.
# The notebook used to assign a .png pattern and immediately overwrite it with a
# .jpg one; a single pattern now accepts both extensions.
pat_imgName = re.compile(r'(\w+/\d+\.(?:png|jpg))$')

def get_label_from_df(fn, df, pat_imgName, box_col, cat_col):
    """Return the list of class names recorded in `df` for image file `fn`.

    Parameters:
        fn: path of the image (str or Path); its tail is matched with `pat_imgName`.
        df: DataFrame indexed by the key extracted from `fn`.
        pat_imgName: compiled regex whose first match is the index key.
        box_col: unused here; kept so both getters share one signature.
        cat_col: name of the column holding the class names as text.
    Returns:
        list[str]: every `\\w+` token found in the `cat_col` cell.
    Raises:
        IndexError: if `pat_imgName` does not match `fn`.
        KeyError: if the extracted key is not in `df`.
    """
    key = pat_imgName.findall(str(fn))[0]
    return pat_clas.findall(df.loc[key, cat_col])

def get_boxes_from_df(fn, df, pat_imgName, box_col, cat_col):
    """Return the bounding boxes recorded in `df` for image file `fn`.

    The box cell is read as a flat run of integers, grouped four at a time. The
    stored order is (y, x, y, x) as in fastai v1; fastai v2 expects x before y,
    so every box is reordered to (x, y, x, y).

    Parameters:
        fn: path of the image (str or Path); its tail is matched with `pat_imgName`.
        df: DataFrame indexed by the key extracted from `fn`.
        pat_imgName: compiled regex whose first match is the index key.
        box_col: name of the column holding the box coordinates as text.
        cat_col: name of the column holding the class names as text (used only to
            check that there is one class per box).
    Returns:
        list[list[int]]: one [x, y, x, y] list per box.
    Raises:
        IndexError: if `pat_imgName` does not match `fn`.
        ValueError: if the number of coordinates is not a multiple of four.
        AssertionError: if the numbers of boxes and classes differ.
    """
    pat_num = re.compile(r'\d+')
    pat_cat = re.compile(r'\w+')
    key = pat_imgName.findall(str(fn))[0]

    coords = [int(tok) for tok in pat_num.findall(df.loc[key, box_col])]
    boxes = np.array(coords, dtype=np.int32).reshape(-1, 4)
    # Swap each (y, x) pair into (x, y).
    boxes = boxes[..., [1, 0, 3, 2]].tolist()

    cats = pat_cat.findall(df.loc[key, cat_col])
    assert len(boxes)==len(cats), 'length of bounding boxes and categories not equeal.'

    return boxes

## 生成DataBlock
作用相当于之前的DataBunch  
item_tfms=Resize(128) 作用类似v1里面的after_open，可以对图片进行一些处理，但是这个处理无法作用在y上,如果需要改变图片尺寸连带y一起改变，应该在aug_transforms里面指定size参数

In [None]:
def get_db():
    """Build the DataBlock (image -> bounding boxes + box labels) for the symbol dataset.

    Reads the module-level `df` and `pat_imgName`; boxes come from column 'box'
    and class names from column 'cls'. No item/batch transforms are set here
    (the commented-out options were Resize(128) and aug_transforms(size=(128,128))).
    """
    shared = dict(df=df, pat_imgName=pat_imgName, box_col='box', cat_col='cls')
    box_getter = partial(get_boxes_from_df, **shared)
    label_getter = partial(get_label_from_df, **shared)

    return DataBlock(
        blocks=(ImageBlock, BBoxBlock, BBoxLblBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(),
        get_y=[box_getter, label_getter],
        n_inp=1,
    )

## callback

In [None]:
class ExtValidCal(TrainEvalCallback):
    """Experimental callback that adds an 'ext_valid' metric column and logs each epoch end.

    The interactive `pdb.set_trace()` breakpoints that used to sit at the top of
    both hooks were removed: they halted every fit and every epoch, which makes
    unattended training impossible.
    """
    def before_fit(self):
        # Register the extra metric name before training starts.
        # NOTE(review): confirm the recorder of the installed fastai version
        # provides `add_metric_names`.
        self.recorder.add_metric_names(['ext_valid'])

    def after_epoch(self, **kwargs):
        print('on after epoch', kwargs)

    # Disabled experiment, kept for reference (it was parked in a string literal):
    # def __getattr__(self, k):
    #     if k in ['ext_valids', 'fld_names']:
    #         return self.ext_valids
    #     else:
    #         return getattr(self.learn, k)

## 网络-病树检测

### ssd_block

In [None]:
# export
class ssd_block(nn.Module):
    '''
    SSD head that predicts, from one feature map, three outputs per anchor:
    -- loc:  bbox centre offset, 2 values
    -- conf: objectness confidence, 1 value
    -- clas: class scores, n_clas values
    Compared with the ssd_block in detect_symbol, the width/height part is removed.
    ----------------------------------------
    Parameters:
    -- k: number of anchors per grid cell
    -- nin: number of channels of the input feature map
    -- n_clas: number of object classes
    '''
    def __init__(self, k, nin, n_clas):
        super().__init__()
        self.k = k

        def _head(values_per_anchor):
            # 3x3 convolution that keeps the spatial size (padding=1).
            return nn.Conv2d(nin, values_per_anchor * k, 3, padding=1)

        self.oconv_loc = _head(2)        # bbox centre
        self.oconv_conf = _head(1)       # confidence
        self.oconv_clas = _head(n_clas)  # classification

    def forward(self, x):
        # Returns (loc, conf, clas), each reshaped by flatten_grid_anchor.
        flatten = resnet_ssd_detsym.flatten_grid_anchor
        heads = (self.oconv_loc, self.oconv_conf, self.oconv_clas)
        return tuple(flatten(head(x), self.k) for head in heads)

### ResNetIsh_1SSD

In [None]:
class ResNetIsh_1SSD(resnet_ssd_detsym.ResNetIsh_1SSD):
    """Variant of the detect_symbol network whose heads return three outputs (loc, conf, clas)."""
    def forward(self, x):
        # `_forward_impl` yields one (loc, conf, clas, ...) group per prediction layer;
        # the first three entries of every group are concatenated along dim 1.
        per_layer = self._forward_impl(x)

        gathered = ([], [], [])
        for layer_out in per_layer:
            for slot, bucket in enumerate(gathered):
                bucket.append(layer_out[slot])

        return tuple(torch.cat(bucket, dim=1) for bucket in gathered)

### 模型

In [None]:
def get_resnet18_1ssd(layers4fpn = False, num_classes = 1):
    """Build the ResNet-18-style backbone with a single SSD head.

    Parameters:
        layers4fpn: whether to keep the last two layers for an FPN. Only False is
            implemented.
        num_classes: number of object classes predicted by the head.
    Returns:
        A `ResNetIsh_1SSD` with three residual stages ([2, 2, 2] BasicBlocks,
        channels 64/128/256), predicting from stage index 2 with one anchor per cell.
    Raises:
        NotImplementedError: if `layers4fpn` is true. (An `assert False` was used
            before; it is stripped under `python -O`, making the function return
            None silently.)
    """
    if layers4fpn:
        raise NotImplementedError('没有实现')

    return ResNetIsh_1SSD(block=torchvision.models.resnet.BasicBlock,
               layers=[2,2,2],
               chs=[64,128,256],
               strides=[1,2,2],
               pred_layerIds=[2],
               num_anchors=1,
               neck_block=resnet_ssd_detsym.cnv1x1_bn_relu,
               head_chin=256,
               head_block=ssd_block,
               num_classes=num_classes)

### anchor_loss_metrics

In [None]:
# Debug switch read by find_neibs: when True, the neighbour distance is forced to 5.
TEST_2020904_ = False

In [None]:
def find_neibs(idx, grids = (49, 49), dis = 1):
    '''
    Return the flat indices of the grid cells surrounding cell `idx`.

    Only the first grid level takes part in this task and anchors map one-to-one
    to cells, so a single grid size is enough.

    Parameters:
        idx: flat, row-major index of the target cell (position in the grid-anchor
            list); an int or a 0-d integer tensor.
        grids: (rows, cols) of the grid.
        dis: neighbourhood radius, in cells (overridden to 5 when the module-level
            debug flag TEST_2020904_ is set).
    Returns:
        List of int indices of all in-bounds cells within `dis` of `idx`, excluding
        `idx` itself, ordered column by column.
    '''
    if TEST_2020904_:
        dis = 5

    gh, gw = grids
    idx = int(idx)  # convert once so the arithmetic below runs on plain ints
    # The flat index is row-major (idx = y * gw + x), so the column is idx % gw.
    # The previous `idx % gh` was only right for square grids.
    x, y = idx % gw, idx // gw

    return [ny * gw + nx
            for nx in range(max(x - dis, 0), min(x + dis, gw - 1) + 1)
            for ny in range(max(y - dis, 0), min(y + dis, gh - 1) + 1)
            if (nx, ny) != (x, y)]
                

In [None]:
def get_y(pts):
    """Return the indices of the rows of `pts` whose absolute values do not sum to zero.

    All-zero rows are the padding added when batching, so the result selects the
    real targets.
    """
    row_mass = pts.abs().sum(dim=-1)
    nonzero_rows = row_mass.nonzero()
    return nonzero_rows[:, 0]

In [None]:
# A new GridAnchor_Funcs for centre-only prediction. Main changes versus the parent:
#   get_scores_hits -> get_hits, b2t -> b2c, t2b -> c2b;
#   LblPts selects whether the ground truth is labelled points (True) or bounding boxes (False).
class GridAnchor_Funcs(anchors_loss_metrics_detsym.GridAnchor_Funcs):
    '''
    Grid helper that matches targets to grid cells and converts between target
    centres and the raw (pre-sigmoid) offsets predicted by the network.

    Parameters:
        fig_hw: image (height, width) handed to get_grids_anchors.
        grids: list of grid sizes handed to get_grids_anchors; also stored as self.grids.
        device: forwarded to the parent constructor.
        LblPts: True when ground truth rows are (x, y) points, False when they are
            (x, y, x, y) boxes.

    NOTE(review): `self.gvs`, `self.ghs` and `self.gws` are expected to be set by the
    parent constructor; `gvs` columns are read here as (x_min, y_min, x_max, y_max)
    and `ghs`/`gws` as per-cell height/width — confirm against the parent class.
    '''
    def __init__(self, fig_hw, grids, device, LblPts = True):
        anchors = [[(0, 0)]]  # a single, size-less anchor per cell
        gvs, _, _, avs, _, _ = anchors_loss_metrics_detsym.get_grids_anchors(fig_hw, grids, anchors)
        self.grids = grids
        self.LblPts = LblPts
        super().__init__(gvs, avs, device)

    # The three parent methods below are no longer applicable; fail loudly if called.
    # (NotImplementedError instead of `assert False`, which is stripped under -O.)
    def get_scores_hits(self, gt_bb_or_lpts):
        raise NotImplementedError('deleted')
    def b2t(self, gt_bb_or_lpts,idx,eps=1):
        raise NotImplementedError('deleted')
    def t2b(self,t,idx,eps=1):
        raise NotImplementedError('deleted')

    def get_hits(self, gt_bb_or_lpts):
        '''
        Return a boolean (n_targets, n_cells) matrix that is True where the centre
        of a target falls inside a cell (left/top edge inclusive, right/bottom exclusive).
        '''
        if self.LblPts:
            gt_cx = gt_bb_or_lpts[:,0]
            gt_cy = gt_bb_or_lpts[:,1]
        else:
            gt_cx = gt_bb_or_lpts[:,[0,2]].mean(-1)
            gt_cy = gt_bb_or_lpts[:,[1,3]].mean(-1)

        hits = ((gt_cx[:,None] >= self.gvs[:,0][None]) &
                (gt_cx[:,None] <  self.gvs[:,2][None]) &
                (gt_cy[:,None] >= self.gvs[:,1][None]) &
                (gt_cy[:,None] <  self.gvs[:,3][None]))

        return hits

    def b2c(self, gt_bb_or_lpts,idx,eps=1):
        '''
        Encode target centres as raw offsets (tx, ty) relative to cells `idx`.

        The centre's position inside its cell is normalised by the cell size,
        squeezed by `eps`, and passed through the logit; `c2b` is the inverse.
        Returns a tensor of shape (n_targets, 2).
        '''
        x0,y0 = self.gvs[idx,0],self.gvs[idx,1]
        gh,gw = self.ghs[idx],self.gws[idx]

        if self.LblPts:
            bx = gt_bb_or_lpts[:,0]
            by = gt_bb_or_lpts[:,1]
        else:
            bx = (gt_bb_or_lpts[:,0] + gt_bb_or_lpts[:,2])/2 # x of center of box
            by = (gt_bb_or_lpts[:,1] + gt_bb_or_lpts[:,3])/2 # y of center of box

        # x is normalised by the cell width and y by the cell height, matching c2b.
        # (The two divisors used to be swapped, which only worked for square cells.)
        hatsig_tx = (bx - x0)/gw
        hatsig_ty = (by - y0)/gh

        sig_tx = (hatsig_tx+0.5*eps)/(1+eps)
        sig_ty = (hatsig_ty+0.5*eps)/(1+eps)

        tx = torch.log(sig_tx/(1-sig_tx))
        ty = torch.log(sig_ty/(1-sig_ty))

        return torch.stack([tx, ty]).t()

    def c2b(self,t,idx,eps=1):
        '''
        Decode raw offsets `t` (last dim: tx, ty) predicted for cells `idx`.

        Returns (x, y) points when LblPts is True, otherwise zero-size boxes
        (x, y, x, y) whose two corners coincide with the predicted centre. The
        coordinates are stacked on the last dimension.
        '''
        x0,y0 = self.gvs[idx,0],self.gvs[idx,1]
        gh,gw = self.ghs[idx],self.gws[idx]

        sig_tx = torch.sigmoid(t[...,0])
        sig_ty = torch.sigmoid(t[...,1])

        hatsig_tx = (1+eps)*(sig_tx-0.5) + 0.5
        hatsig_ty = (1+eps)*(sig_ty-0.5) + 0.5

        bx = hatsig_tx*gw + x0 # x of center of box
        by = hatsig_ty*gh + y0 # y of center of box

        coords = [bx, by] if self.LblPts else [bx, by, bx, by]
        return torch.stack(coords, dim=-1)
    

In [None]:
def clas_acc(pred_batch, *gt_batch, gaf):
    '''
    Classification accuracy over all real targets of the batch.

    For each target, the class predicted by the anchor whose cell contains the
    target is compared with the ground-truth class. Images without targets are skipped.
    '''
    n_correct = tensor(0.)
    n_targets = tensor(0.)
    for pred_clas, gt_pts, gt_clas in zip(pred_batch[2], *gt_batch):
        pick_rows = get_y if gaf.LblPts else anchors_loss_metrics_detsym.get_y
        keep = pick_rows(gt_pts)
        if keep.numel() == 0:
            continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_pts = (gt_pts[keep] + 1) / 2
        # The data pipeline puts a 'background' class at index 0; shift it away.
        gt_clas = gt_clas[keep] - 1

        anchor_idx = idx_fromHits(gaf.get_hits(gt_pts))
        guessed = pred_clas[anchor_idx].max(1)[1]

        n_correct += (guessed == gt_clas).sum().item()
        n_targets += gt_clas.shape[0]

    return n_correct / n_targets

In [None]:
def clas_L(pred_batch, *gt_batch, lambda_clas=1, clas_weights=None, gaf):
    '''
    Classification loss.

    The anchor responsible for a target (the one whose cell contains it) is
    trained to predict that target's class.

    Parameters:
        pred_batch: model output; only pred_batch[2] (per-anchor class logits) is used.
        *gt_batch: ground-truth targets and their class ids, one row per image.
        lambda_clas: multiplier applied to the averaged loss.
        clas_weights: optional per-class weights handed to F.cross_entropy.
        gaf: grid helper providing `LblPts` and `get_hits`.
    Returns:
        lambda_clas * summed cross entropy / number of targets. When the batch
        holds no target, a zero still attached to pred_batch[2] is returned
        (previously this raised ZeroDivisionError).
    '''
    loss = 0
    cnt = 0
    for pred_clas,gt_bb_or_lpts,gt_clas in zip(pred_batch[2], *gt_batch):
        if gaf.LblPts:
            keep = get_y(gt_bb_or_lpts)
        else:
            keep = anchors_loss_metrics_detsym.get_y(gt_bb_or_lpts)
        if keep.numel()==0: continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_bb_or_lpts = (gt_bb_or_lpts[keep] + 1) / 2
        # The data pipeline puts a 'background' class at index 0; shift it away.
        gt_clas = gt_clas[keep] - 1

        idx = idx_fromHits(gaf.get_hits(gt_bb_or_lpts))

        loss += F.cross_entropy(pred_clas[idx], gt_clas, weight=clas_weights, reduction='sum')
        cnt += gt_clas.shape[0]

    if cnt == 0:
        # Nothing to average over; keep the result connected to the graph.
        return pred_batch[2].sum() * 0
    return lambda_clas*loss/cnt

In [None]:
def cent_L(pred_batch, *gt_batch, lambda_cent=1, clas_weights=None, gaf):
    '''
    Bbox centre loss.

    The anchor responsible for a target is trained so that its predicted centre
    offset approaches the encoded centre of that target (L1 distance).

    Parameters:
        pred_batch: model output; only pred_batch[0] (per-anchor centre offsets) is used.
        *gt_batch: ground-truth targets and their class ids, one row per image.
        lambda_cent: multiplier applied to the averaged loss.
        clas_weights: optional 1-D tensor of per-class weights; each target's error
            is scaled by the weight of its class.
        gaf: grid helper providing `LblPts`, `get_hits` and `b2c`.
    Returns:
        lambda_cent * summed absolute error / number of targets. When the batch
        holds no target, a zero still attached to pred_batch[0] is returned
        (previously this raised ZeroDivisionError).
    '''
    loss = 0
    cnt = 0
    for pred_txy,gt_bb_or_lpts,gt_clas in zip(pred_batch[0], *gt_batch):
        if not gaf.LblPts:
            keep = anchors_loss_metrics_detsym.get_y(gt_bb_or_lpts)
        else:
            keep = get_y(gt_bb_or_lpts)
        if keep.numel()==0: continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_bb_or_lpts = (gt_bb_or_lpts[keep] + 1) / 2
        # Shift away the 'background' class at index 0.
        gt_clas = gt_clas[keep] - 1

        ws = clas_weights[gt_clas] if clas_weights is not None else None

        idx = idx_fromHits(gaf.get_hits(gt_bb_or_lpts))

        gt_t = gaf.b2c(gt_bb_or_lpts,idx,eps=1)
        diff = gt_t[...,:2] - pred_txy[idx]
        if ws is not None:
            diff = diff * ws[...,None]

        loss += diff.abs().sum()
        cnt += len(idx)

    if cnt == 0:
        # Nothing to average over; keep the result connected to the graph.
        return pred_batch[0].sum() * 0
    return lambda_cent*loss/cnt

In [None]:
def pConf_L(pred_batch, *gt_batch, lambda_pconf=1, clas_weights=None, gaf):
    '''
    Positive confidence loss.

    The anchor responsible for a target is trained so that its confidence logit
    approaches 1 (binary cross entropy with logits against a target of ones).

    Parameters:
        pred_batch: model output; only pred_batch[1] (per-anchor confidence logits) is used.
        *gt_batch: ground-truth targets and their class ids, one row per image.
        lambda_pconf: multiplier applied to the averaged loss.
        clas_weights: optional 1-D tensor of per-class weights; each target's loss
            is scaled by the weight of its class.
        gaf: grid helper providing `LblPts` and `get_hits`.
    Returns:
        lambda_pconf * summed loss / number of targets. When the batch holds no
        target, a zero still attached to pred_batch[1] is returned (previously
        this raised ZeroDivisionError).
    '''
    loss = 0
    cnt = 0
    for pred_conf,gt_bb_or_lpts,gt_clas in zip(pred_batch[1], *gt_batch):
        if not gaf.LblPts:
            keep = anchors_loss_metrics_detsym.get_y(gt_bb_or_lpts)
        else:
            keep = get_y(gt_bb_or_lpts)

        if keep.numel()==0: continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_bb_or_lpts = (gt_bb_or_lpts[keep] + 1) / 2
        # Shift away the 'background' class at index 0.
        gt_clas = gt_clas[keep] - 1

        ws = clas_weights[gt_clas] if clas_weights is not None else None

        idx = idx_fromHits(gaf.get_hits(gt_bb_or_lpts))

        conf_pos = pred_conf[idx]
        if ws is not None:
            tmp = F.binary_cross_entropy_with_logits(conf_pos,torch.ones_like(conf_pos),weight=ws[...,None],reduction='sum')
        else:
            tmp = F.binary_cross_entropy_with_logits(conf_pos,torch.ones_like(conf_pos),reduction='sum')

        loss += tmp
        cnt += len(idx)

    if cnt == 0:
        # Nothing to average over; keep the result connected to the graph.
        return pred_batch[1].sum() * 0
    return lambda_pconf*loss/cnt

In [None]:
def nConf_L(pred_batch, *gt_batch, gaf, conf_th=0.5, lambda_nconf=1):
    '''
    Negative confidence loss.

    Every anchor that is neither responsible for a target nor a direct neighbour
    (distance 1) of a responsible anchor is trained so that its confidence logit
    approaches 0. For an image without any target, all anchors are negative and
    all of them contribute, mirroring the negative term of `yolo_L` (such images
    used to be skipped here).

    Parameters:
        pred_batch: model output; only pred_batch[1] (per-anchor confidence logits) is used.
        *gt_batch: ground-truth targets and their class ids, one row per image.
        gaf: grid helper providing `LblPts`, `grids` and `get_hits`.
        conf_th: unused; kept for signature compatibility.
        lambda_nconf: multiplier applied to the averaged loss.
    Returns:
        lambda_nconf * summed loss / number of negative anchors. For an empty
        batch, a zero still attached to pred_batch[1] is returned (previously
        this raised ZeroDivisionError).
    '''
    loss = 0
    cnt = 0
    for pred_conf,gt_bb_or_lpts,_ in zip(pred_batch[1], *gt_batch):
        if not gaf.LblPts:
            keep = anchors_loss_metrics_detsym.get_y(gt_bb_or_lpts)
        else:
            keep = get_y(gt_bb_or_lpts)
        if keep.numel()==0:
            # No target in this image: every anchor is a negative one.
            loss += F.binary_cross_entropy_with_logits(pred_conf,torch.zeros_like(pred_conf),reduction='sum')
            cnt += len(pred_conf)
            continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_bb_or_lpts = (gt_bb_or_lpts[keep] + 1) / 2

        hits = gaf.get_hits(gt_bb_or_lpts)
        idx = idx_fromHits(hits)

        # 1 for every anchor that is hit by at least one target.
        occupied = (hits * 1).max(dim=0)[0]

        # Neighbours of the hit anchors are excluded from the negatives as well.
        discards = [i for hidx in idx for i in find_neibs(hidx, gaf.grids[0], dis = 1)]
        occupied[discards] = 1

        neg_idx = torch.where(occupied==0)[0] # what remains are the negative anchors

        conf_neg = pred_conf[neg_idx]
        loss += F.binary_cross_entropy_with_logits(conf_neg,torch.zeros_like(conf_neg),reduction='sum')
        cnt += len(neg_idx)

    if cnt == 0:
        # Nothing to average over; keep the result connected to the graph.
        return pred_batch[1].sum() * 0
    return lambda_nconf*loss/cnt

In [None]:
def yolo_L(pred_batch, *gt_batch, conf_th=0.5,
           lambda_cent=1, lambda_pconf=1, lambda_nconf=1, lambda_clas=1, clas_weights=None, gaf):
    '''
    Total detection loss: classification + centre + positive confidence + negative confidence.

    Differences from yolo_L in detect_symbol:
        no width/height loss is computed;
        the direct neighbours of responsible anchors (find_neibs) are removed from
        the negative anchors.

    Parameters:
        pred_batch: model output (centre offsets, confidence logits, class logits).
        *gt_batch: ground-truth targets and their class ids, one row per image.
        conf_th: unused; kept for signature compatibility.
        lambda_cent, lambda_pconf, lambda_nconf, lambda_clas: multipliers of the four terms.
        clas_weights: per-class weights used against class imbalance. A
            WeightedRandomSampler in the dataloader does not suit object detection,
            because (1) the label is not a single number and (2) one image may hold
            several targets of different classes. Instead, the loss of every target
            is multiplied by the weight of its class. None (default) disables
            weighting. Otherwise it must be a 1-D tensor with one entry per class;
            all ones is equivalent to no weighting, and the entries should sum to
            the number of classes, otherwise the whole loss is rescaled.
        gaf: grid helper providing `LblPts`, `grids`, `get_hits` and `b2c`.
    Returns:
        Sum of the four terms. The three positive terms are averaged over the
        number of targets (and left at 0 when there is none); the negative term is
        averaged over the number of negative anchors (and is a graph-connected
        zero for an empty batch, which previously raised ZeroDivisionError).
    '''
    clas_loss = 0
    cent_loss = 0
    pconf_loss = 0
    nconf_loss = 0
    pos_cnt = 0
    neg_cnt = 0

    for pred_txy,pred_conf,pred_clas,gt_bb_or_lpts,gt_clas in zip(*pred_batch, *gt_batch):
        if not gaf.LblPts:
            keep = anchors_loss_metrics_detsym.get_y(gt_bb_or_lpts)
        else:
            keep = get_y(gt_bb_or_lpts)
        if keep.numel()==0:
            # No target in this image: every anchor is negative and contributes.
            nconf_loss += F.binary_cross_entropy_with_logits(pred_conf,torch.zeros_like(pred_conf),reduction='sum')
            neg_cnt += len(pred_conf)
            continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_bb_or_lpts = (gt_bb_or_lpts[keep] + 1) / 2
        # The data pipeline puts a 'background' class at index 0; shift it away.
        gt_clas = gt_clas[keep] - 1

        ws = clas_weights[gt_clas] if clas_weights is not None else None

        hits = gaf.get_hits(gt_bb_or_lpts)
        idx = idx_fromHits(hits)

        # classification loss
        clas_loss += F.cross_entropy(pred_clas[idx], gt_clas, weight=clas_weights, reduction='sum')

        # bbox center loss
        gt_t = gaf.b2c(gt_bb_or_lpts,idx,eps=1)
        cent_diff = gt_t[...,:2] - pred_txy[idx]
        if ws is not None:
            cent_diff = cent_diff * ws[...,None]
        cent_loss += cent_diff.abs().sum()

        # positive confidence loss
        conf_pos = pred_conf[idx]
        if ws is not None:
            pconf_loss += F.binary_cross_entropy_with_logits(conf_pos,torch.ones_like(conf_pos),weight=ws[...,None],reduction='sum')
        else:
            pconf_loss += F.binary_cross_entropy_with_logits(conf_pos,torch.ones_like(conf_pos),reduction='sum')

        # 1 for every anchor that is hit by at least one target.
        occupied = (hits * 1).max(dim=0)[0]

        # Neighbours of the hit anchors are excluded from the negatives as well.
        discards = [i for hidx in idx for i in find_neibs(hidx, gaf.grids[0], dis = 1)]
        occupied[discards] = 1

        neg_idx = torch.where(occupied==0)[0] # what remains are the negative anchors

        conf_neg = pred_conf[neg_idx]
        nconf_loss += F.binary_cross_entropy_with_logits(conf_neg,torch.zeros_like(conf_neg),reduction='sum')

        pos_cnt += len(idx)
        neg_cnt += len(neg_idx)

    # A batch may consist of empty images only; then just the negative term exists.
    if pos_cnt > 0:
        clas_loss  = lambda_clas  * clas_loss  /pos_cnt
        cent_loss  = lambda_cent  * cent_loss  /pos_cnt
        pconf_loss = lambda_pconf * pconf_loss /pos_cnt
    if neg_cnt > 0:
        nconf_loss = lambda_nconf * nconf_loss /neg_cnt
    else:
        # Nothing to average over; keep the result connected to the graph.
        nconf_loss = pred_batch[1].sum() * 0

    return clas_loss + cent_loss + pconf_loss + nconf_loss

In [None]:
def bbox2c(b):
    '''
    Convert boxes from (left x, top y, right x, bottom y) to (centre x, centre y).
    The last dimension of the result has size 2; leading dimensions are unchanged.
    '''
    center_x = b[..., [0, 2]].mean(dim=-1, keepdim=True)
    center_y = b[..., [1, 3]].mean(dim=-1, keepdim=True)
    return torch.cat((center_x, center_y), dim=-1)

In [None]:
def idx_fromHits(hits):
    """For each row of `hits`, return the column index reported by `max` over dim 1."""
    as_numbers = hits * 1  # turn booleans into numbers so that max is defined
    _, col_idx = as_numbers.max(1)
    return col_idx

In [None]:
def cent_d(pred_batch, *gt_batch, gaf):
    '''
    Bbox centre difference: the absolute difference between decoded predicted
    centres and ground-truth centres, summed over x and y, divided by the number
    of targets and then by 2. Images without targets are skipped.
    '''
    total_dif = tensor(0.)
    total_cnt = tensor(0.)
    for pred_txy, gt_pts, _ in zip(pred_batch[0], *gt_batch):
        pick_rows = get_y if gaf.LblPts else anchors_loss_metrics_detsym.get_y
        keep = pick_rows(gt_pts)
        if keep.numel() == 0:
            continue

        # Drop padding rows, then rescale coordinates with (v + 1) / 2.
        gt_pts = (gt_pts[keep] + 1) / 2

        anchor_idx = idx_fromHits(gaf.get_hits(gt_pts))
        decoded = gaf.c2b(pred_txy[anchor_idx], anchor_idx)

        if gaf.LblPts:
            pred_c, gt_c = decoded, gt_pts
        else:
            # Boxes: compare centres only.
            pred_c = bbox2c(decoded)[..., :2]
            gt_c = bbox2c(gt_pts)[..., :2]

        total_dif += (gt_c - pred_c).abs().sum()
        total_cnt += len(anchor_idx)

    return total_dif / total_cnt / 2

In [None]:
def split_model(model):
    """Split `model` into two layer groups: [group0, group1].

    group0: the non-batchnorm layers (of `bias_types`) inside the pretrained part
            (conv1, bn1, res_blocks[:4]).
    group1: the batchnorm layers of the pretrained part, followed by the
            non-pretrained neck and head blocks.
    """
    pretrained_part = Sequential(model.conv1, model.bn1, model.res_blocks[:4])
    new_part = Sequential(model.neck_blocks, model.head_block)

    pretrained_group = ModuleList()
    bn_and_new_group = ModuleList()

    # Walk every module of the pretrained part and sort it by type.
    for layer in pretrained_part.modules():
        if isinstance(layer, bn_types):
            bn_and_new_group.append(layer)
            continue
        if isinstance(layer, bias_types):
            pretrained_group.append(layer)

    # The non-pretrained blocks join the batchnorm group.
    for child in new_part.children():
        bn_and_new_group.append(child)

    return [pretrained_group, bn_and_new_group]

# test

In [None]:
# Root folder of the dataset used by the test cells below (machine-specific absolute path).
src_path = '/home/dev/jupyter/detect_symbol/data/ds_20200429/'

In [None]:
# Image folder inside the dataset root (still a plain string at this point).
path = src_path + 'images'

In [None]:
# Turn the string into a Path object.
path = Path(path)

In [None]:
# Display the folder listing as a quick sanity check.
path.ls()

In [None]:
# Load the annotation table; the first CSV column is read as the index and then
# replaced by the 'image' column, which the label getters use as lookup key.
df = pd.read_csv(src_path + 'gends.csv',index_col=0)
df = df.set_index('image')
df.head()

In [None]:
# Build the DataBlock; get_db reads the module-level `df` loaded above.
syms = get_db()

在 docker 中，如果启动容器时没有设置 `--shm-size`，又没有指定 num_workers=0，DataLoader 会使用 _MultiProcessingDataLoaderIter（多进程加载），导致如下错误：
Unable to write to file </torch_18692_1954506624>
https://discuss.pytorch.org/t/unable-to-write-to-file-torch-18692-1954506624/9990

在fastai v1中对应的错误是内存溢出。

In [None]:
# num_workers = 0 is passed on purpose, see the Docker note above this cell.
dls = syms.dataloaders(path, bs = 16, num_workers = 0)

In [None]:
# Visual check of one sample with its boxes and labels.
dls.show_batch(max_n = 1)

In [None]:
#syms.summary(path)

In [None]:
#dts = syms.datasets(path)

In [None]:
#dts[1]

In [None]:
#df.loc['images/02364.jpg']

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
#device = torch.device('cpu')

## 模型和训练-符号检测

In [None]:
# Symbol-detection model built by the detect_symbol package helper.
model = resnet_ssd_detsym.get_resnet34_1ssd()

In [None]:
# Grid/anchor helper of the detect_symbol package, built from its get_ga666 result
# (only the first and fourth returned values are used).
gvs,_,_,avs,_,_ = anchors_loss_metrics_detsym.get_ga666()
gaf = anchors_loss_metrics_detsym.GridAnchor_Funcs(gvs,avs,device)

In [None]:
# detect_symbol loss with the helper bound; no class weights, negative-confidence term weighted by 10.
loss_func = partial(anchors_loss_metrics_detsym.yolo_L, gaf=gaf, conf_th=1, clas_weights=None, lambda_nconf=10)

In [None]:
# Instance of the experimental callback; it is not handed to the Learner below.
evc = ExtValidCal()

In [None]:
#learn = cnn_learner(dls, model, pretrained=False)
# Plain Learner with the custom loss; the `metrics = [evc]` argument is commented out.
learn = Learner(dls, model, loss_func = loss_func, device = device)#, metrics = [evc])

In [None]:
# Attach the learner to the callback by hand (the callback is not registered with the Learner).
evc.learn = learn

In [None]:
# Smoke test: train for a single epoch.
learn.fit(1)

### 用病树检测的模型试一下

In [None]:
# Sick-tree style model (centre-only head) with 17 classes.
model = get_resnet18_1ssd(num_classes = 17)

In [None]:
#model.load_state_dict(torch.load('../sick_tree_detection/models/pretrained_res18_1ssd_detsym17clas.pth'))

In [None]:
# IPython shell escape: list the model files of the sick_tree_detection project.
!ls ../sick_tree_detection/models

In [None]:
# Centre-only grid helper defined in this notebook: 776x776 images, one 49x49 grid.
gaf = GridAnchor_Funcs(fig_hw = (776,776)
                         , grids = [(49,49)]
                         , device = device)
gvs, avs = gaf.gvs, gaf.avs

In [None]:
# Per-class sample counts (17 entries) turned into class weights by the detect_symbol helper.
# NOTE(review): `weights` is computed here, but the loss partial in the next cell
# still passes clas_weights=None — confirm whether the weights were meant to be used.
clas_cnts = [11191, 712, 1362, 224, 8710, 1212, 1139, 8686, 857, 2176, 6175, 1869, 14794, 1435, 13628, 9618, 1462]
weights = anchors_loss_metrics_detsym.get_clasWeights(clas_cnts,10)
weights = tensor(weights).float().to(device)

In [None]:
# Loss of this notebook (no width/height term); no class weights, negative-confidence term weighted by 10.
loss_func = partial(yolo_L, gaf=gaf, conf_th=1, clas_weights=None, lambda_nconf=10)

In [None]:
# Learner for the centre-only model with the loss defined in the previous cell.
learn = Learner(dls, model, loss_func = loss_func, device = device)

In [None]:
# Run the learning-rate finder before training.
learn.lr_find()

In [None]:
# Train for 10 epochs.
learn.fit(10)

In [None]:
def dbg():
    """Debug helper: drop into pdb, then build a 128-px resized DataBlock and show a batch.

    Relies on the module-level `get_y1`, `get_y2` and `path`.
    """
    import pdb
    pdb.set_trace()
    resized_block = DataBlock(
        blocks=(ImageBlock, BBoxBlock, BBoxLblBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(),
        get_y=[get_y1, get_y2],
        item_tfms=Resize(128),
        n_inp=1,
    )
    resized_dls = resized_block.dataloaders(path)
    resized_dls.show_batch()
dbg()

In [None]:
# Fetch the COCO_TINY sample dataset and keep its local path.
coco_source = untar_data(URLs.COCO_TINY)

In [None]:
# Read the training annotations and map every image name to its annotation entry
# (entry[0] and entry[1] are consumed by get_y1 / get_y2 below).
images, lbl_bbox = get_annotations(coco_source/'train.json')
img2bbox = dict(zip(images, lbl_bbox))

In [None]:
# Display where the dataset was extracted.
coco_source

In [None]:
def get_y1(o):
    """Return the first element (the bounding boxes) of the annotation of image file `o`."""
    annotation = img2bbox[o.name]
    return annotation[0]

def get_y2(o):
    """Return the second element (the box labels) of the annotation of image file `o`."""
    annotation = img2bbox[o.name]
    return annotation[1]

# Reference DataBlock on COCO_TINY: images resized to 128 with the default augmentations.
coco = DataBlock(
    blocks=(ImageBlock, BBoxBlock, BBoxLblBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=[get_y1, get_y2],
    item_tfms=Resize(128),
    batch_tfms=aug_transforms(),
    n_inp=1,
)

In [None]:
# NOTE(review): hard-coded path; presumably the same folder as `coco_source` above — confirm.
cocodls = coco.dataloaders('/root/.fastai/data/coco_tiny')

In [None]:
# Visual check of a COCO_TINY batch.
cocodls.show_batch()


In [None]:
# pets = DataBlock(blocks = (ImageBlock, CategoryBlock),
#                  get_items=get_image_files, 
#                  splitter=RandomSplitter(seed=42),
#                  get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),
#                  item_tfms=Resize(460),
#                  batch_tfms=aug_transforms(size=224, min_scale=0.75))
# dls = pets.dataloaders(path/"images")