In [1]:
from init import *
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from easydl import clear_output
import torch
from torch import nn
import torch.nn.functional as F
from modules.operates.ops import conv3x3,conv1x1,residualBlock
from modules.blocks.yolo_blocks import C3, Conv

RUNDIR: runs/log/80971
[easydl] tensorflow not available!


In [2]:
class Fusion2Backbone(nn.Module):
    """Two-stream backbone with shared weights and cross-stream feature mixing.

    Both inputs are run through the *same* per-stage modules. After each stage
    the two streams exchange features: each stream concatenates its own
    features (scaled by ``self_weight``) with the other stream's features
    (scaled by ``cross_weight``). A fusion ``Conv`` then maps the ``2*c``
    channels back to ``c``, followed by a per-stream ``Cov_Act`` head.

    Args:
        c1: number of input channels.
        ns: number of ``C3`` blocks in each stage.
        ss: stride passed to each stage's ``Conv``. ``zip(ns, ss)`` defines
            the stages, so the shorter sequence sets the stage count.
        c_size: if True (default) the channel count stays ``c1`` for every
            stage; if False it is multiplied by the stage stride after each
            stage.
        self_weight: scale on a stream's own features in the mix (0.8, as
            previously hard-coded).
        cross_weight: scale on the other stream's features in the mix (0.2,
            as previously hard-coded).
    """

    def __init__(self, c1, ns=(), ss=(), c_size=True, self_weight=0.8, cross_weight=0.2):
        super(Fusion2Backbone, self).__init__()
        self.self_weight = self_weight
        self.cross_weight = cross_weight

        # Module names below are kept unchanged so existing state_dicts still load.
        self.seq = nn.Sequential()   # per stage: stack of C3 blocks
        self.seq1 = nn.Sequential()  # per stage: Conv(k=3, s=stride, p=1)
        self.seq2 = nn.Sequential()  # per stage: Conv mapping 2*c -> c after the mix
        c = c1
        for i, (n, s) in enumerate(zip(ns, ss)):
            subseq = nn.Sequential()
            for j in range(n):
                subseq.add_module('c3_' + str(i) + '_' + str(j), C3(c1, c, 1))

            if not c_size:
                c *= s
            self.seq.add_module('subseq_' + str(i), subseq)
            self.seq1.add_module('c1_' + str(i), Conv(c1, c, k=3, s=s, p=1))
            c1 = c

            self.seq2.add_module('c1_' + str(i), Conv(c1 * 2, c1, k=3, s=1, p=1))
        # NOTE(review): these heads are built for the *final* channel count but
        # forward() applies them after every stage, so this only works when the
        # channel count is constant across stages (c_size=True or all strides 1).
        # Confirm whether they were meant to run once after the stage loop.
        self.cov_act1 = Cov_Act(c1, c, 3, 1, 1)
        self.cov_act2 = Cov_Act(c1, c, 3, 1, 1)

    def forward(self, x1, x2=None):
        """Run both streams through every stage and return ``(x1, x2)``.

        If ``x2`` is None, the same input is used for both streams.
        """
        if x2 is None:
            x2 = x1
        for m, m1, m2 in zip(self.seq, self.seq1, self.seq2):
            # Shared-weight stage applied to each stream.
            y1 = m1(m(x1))
            y2 = m1(m(x2))

            # Cross-stream exchange: each stream keeps mostly its own features.
            x1 = torch.cat([y1 * self.self_weight, y2 * self.cross_weight], 1)
            x2 = torch.cat([y2 * self.self_weight, y1 * self.cross_weight], 1)

            x1 = m2(x1)
            x2 = m2(x2)
            x1 = self.cov_act1(x1)
            # Fixed: the second head previously consumed x1, which discarded
            # the second stream and made x2 a function of x1 only.
            x2 = self.cov_act2(x2)
        return x1, x2
    
class Cov_Act(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by a SiLU activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding.
    """

    def __init__(self, c1, c2, k, s, p):
        super().__init__()
        # No normalization layer: the original BatchNorm2d line is commented out.
        self.conv = nn.Conv2d(
            in_channels=c1,
            out_channels=c2,
            kernel_size=k,
            stride=s,
            padding=p,
            groups=1,
            bias=False,
        )
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``SiLU(conv(x))``."""
        return self.act(self.conv(x))

In [1]:
import torchvision
from torchvision.transforms import InterpolationMode

# --- training configuration -------------------------------------------------
epochs = 500
device = 'cuda:6'
loss_scale = 10  # multiplier applied to each stream's mean-absolute-error term

# Targets are resized to 320x320 before being compared with the backbone outputs.
# InterpolationMode.BILINEAR is the enum form of the deprecated integer code 2.
resize = torchvision.transforms.Resize((320, 320), interpolation=InterpolationMode.BILINEAR)

fb = Fusion2Backbone(3, ns=[6, 1, 1, 6, 1, 1, 6, 1, 1], ss=[1, 1, 1, 1, 1, 1, 1, 1, 1])
# Move/cast the model before constructing the optimizer so the optimizer is
# built over the final (device, dtype) parameters.
fb = fb.train().to(device).half()
n_claess, train_loader, val_loder = get_dataloader(cfg0)
optimizer, scheduler = get_optimizer_scheduler(fb)

for epoch in range(epochs):
    losses = []
    for images in train_loader:
        optimizer.zero_grad()

        # NOTE(review): presumably t1/t2 are the paired inputs and t1_b/t2_b
        # their reconstruction targets — confirm against get_dataloader.
        t1 = images['t1'].to(device).half()
        t2 = images['t2'].to(device).half()
        t1_b = images['t1_b'].to(device).half()
        t2_b = images['t2_b'].to(device).half()

        r1, r2 = fb(t1, t2)
        t1_b, t2_b = resize(t1_b), resize(t2_b)
        # Mean L1 per stream. Avoid torch.sum(...)/numel(): on float16 tensors
        # the intermediate sum overflows (max 65504) for realistic batch
        # sizes, making the logged loss inf. l1_loss averages without that overflow.
        loss = F.l1_loss(r1, t1_b) * loss_scale + F.l1_loss(r2, t2_b) * loss_scale
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per batch, not per epoch — confirm intended.
        scheduler.step()

        losses.append(loss.item())
    print('epoch'+str(epoch)+' :', np.mean(losses))


In [63]:
# Save the current `img` array to test.png (cv2.imwrite returns True on success).
# NOTE(review): `img` is not created by any cell above this one, so this fails on a
# fresh kernel; it relies on leftover interactive state (presumably the result of
# show(...) — confirm).
cv2.imwrite('test.png',img)

True

In [11]:
def show(t1_b):
    """Convert the first image of a batch tensor into an 8-bit HWC numpy array.

    Despite the name, nothing is displayed; the array is returned (e.g. for
    ``cv2.imwrite``).

    Args:
        t1_b: batch tensor; ``t1_b[0]`` must be 3-D, channels first (C, H, W).

    Returns:
        np.ndarray of dtype uint8 with shape (H, W, C), min-max scaled so the
        image minimum maps to 0 and its maximum to 255. A constant image maps to
        all zeros instead of dividing by zero.
    """
    # float32 copy so float16 tensors are normalised at full precision.
    img = t1_b[0].detach().cpu().float().numpy().transpose(1, 2, 0)
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span > 0:
        # Divide by the range, not by the max: dividing by max gave values > 1
        # whenever the min was negative, which then wrapped around in uint8.
        img = (img - lo) / span
    else:
        img = np.zeros_like(img)
    return (img * 255).astype(np.uint8)

In [18]:
# Cast the global `img` to uint8 (rebinding `img`) and save it to test.png.
# NOTE(review): relies on `img` from earlier interactive state — it is not defined by
# any cell above on a fresh kernel. show() already returns uint8, so the cast only
# matters for arrays produced some other way.
img = img.astype(np.uint8)
cv2.imwrite('test.png',img)

True

In [18]:
# from torchstat import stat
# stat(fb, (3, 224, 224))

In [6]:
# from torchsummary import summary
# summary(fb.to(device),input_size=(3,640,640),batch_size=-1)

In [23]:
from thop import profile
import torchprof
from fvcore.nn import FlopCountAnalysis, parameter_count_table,flop_count_str,flop_count_table,ActivationCountAnalysis

def params_count(model, trainable_only=False):
    """Return the total number of parameter elements in ``model`` as an int.

    Args:
        model: an ``nn.Module``.
        trainable_only: if True, count only parameters with ``requires_grad``.

    Returns:
        int: summed ``numel()`` over ``model.parameters()``; 0 for a model with
        no parameters. The builtin ``sum`` is used because ``np.sum([]).item()``
        returned the float 0.0 in that case.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

def analysis_model(model, x=None):
    """Return fvcore's per-module parameter/FLOP table for ``model`` as a string.

    Args:
        model: module to analyse.
        x: example input(s) used to trace ``model`` — a tensor, or a tuple of
           positional inputs for multi-input models (e.g. ``(imgs, imgs)``).

    Returns:
        str: the table produced by ``flop_count_table``.

    Raises:
        ValueError: if ``x`` is None. Tracing needs example inputs, and passing
            None previously failed deep inside fvcore with an obscure error.
    """
    if x is None:
        raise ValueError('analysis_model needs example inputs: pass x=tensor or x=(tensor, ...)')
    table = flop_count_table(FlopCountAnalysis(model, x))
    return table

def analysis_model1(model, x=None):
    """Profile one forward pass of ``model`` with torchprof (``use_cuda=False``).

    Args:
        model: module to profile.
        x: input tensor, or a tuple of positional inputs. If None, falls back to
           the notebook-global ``imgs`` (the previous behaviour, kept for
           compatibility) — prefer passing ``x`` explicitly.

    Returns:
        tuple: ``(trace, event_lists_dict)`` from ``prof.raw()``.
    """
    if x is None:
        # Backward compatibility: the original ignored `x` and always used this global.
        x = imgs
    inputs = x if isinstance(x, tuple) else (x,)
    # The thop.profile call that used to run here is removed: its flops/params
    # results were computed and then discarded, costing an extra forward pass.
    with torchprof.Profile(model, use_cuda=False) as prof:
        model(*inputs)
    trace, event_lists_dict = prof.raw()
    return trace, event_lists_dict

# NOTE(review): fails on a fresh kernel — `imgs` is only created in the next cell and
# `model` is not defined anywhere in this notebook view (the two-input call suggests a
# Fusion2Backbone instance; confirm which network is meant).
print(analysis_model(model,x=(imgs,imgs)))

In [24]:
# All-ones dummy batch (1, 3, 640, 640) used only as tracing input for the FLOP table.
imgs = torch.ones((1,3,640,640))

# NOTE(review): the FLOP table is computed for `model`, but the parameter count is for
# `fb`. They appear to be different networks: the table below lists 36 parameters per
# C3 block and 0.576K for `seq`, i.e. 16 blocks, while fb is built with 24. Confirm
# which model should be reported; `model` is not defined in this notebook view.
print(analysis_model(model,x=(imgs,imgs)))
print('params_count: ',params_count(fb))

| module                     | #parameters or shape   | #flops     |
|:---------------------------|:-----------------------|:-----------|
| model                      | 2.268K                 | 2.487G     |
|  seq                       |  0.576K                |  0.747G    |
|   seq.subseq_0             |   0.216K               |   0.28G    |
|    seq.subseq_0.c3_0_0     |    36                  |    46.694M |
|    seq.subseq_0.c3_0_1     |    36                  |    46.694M |
|    seq.subseq_0.c3_0_2     |    36                  |    46.694M |
|    seq.subseq_0.c3_0_3     |    36                  |    46.694M |
|    seq.subseq_0.c3_0_4     |    36                  |    46.694M |
|    seq.subseq_0.c3_0_5     |    36                  |    46.694M |
|   seq.subseq_1.c3_1_0      |   36                   |   46.694M  |
|    seq.subseq_1.c3_1_0.cv1 |    5                   |    6.554M  |
|    seq.subseq_1.c3_1_0.cv2 |    5                   |    6.554M  |
|    seq.subseq_1.c3_1_0.cv3 |    