In [20]:
%matplotlib inline
import numpy as np
import scipy.io as sio
import os
import cv2
import skimage.io
import matplotlib.pyplot as plt
from skimage import morphology
import pandas as pd
import skimage.transform
import copy

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F    
import torch.optim as optim
import gc

In [21]:

class Net(nn.Module):
    """Two-headed fully-convolutional network.

    A shared encoder (``Conv1``..``Conv6``; 2x2 max-pools after ``Conv1``..``Conv5``)
    feeds two heads:

    * a dense head that projects the encoder features at strides 32, 16 and 8
      to 2 channels and fuses them in the FCN-8s manner (upsample x2, add,
      upsample x2, add, upsample x8), giving a 2-channel sigmoid map at the
      input resolution;
    * a "left" branch that re-processes the stride-2 features, concatenating
      the encoder feature map of matching stride at every scale, and ends in a
      ``(4, 2)`` convolution producing 672 sigmoid outputs.

    The input must have 3 channels. H and W should be multiples of 32 so the
    skip additions line up, and at least 512 x 256 for the left branch; a
    512 x 256 input yields a second output of shape ``(B, 672, 1, 1)``.
    What the 2 dense channels and the 672 outputs encode is not visible here
    (presumably segmentation classes and normalized landmark coordinates --
    TODO confirm against the training code).

    Args:
        shared_head_bn: If True, reproduce the legacy wiring in which the
            side-output projections (``ConPre1``..``ConPre3``) and the
            transposed-conv outputs share ``bnct1``..``bnct3`` (and the legacy
            ``state_dict`` keys). Use only to load/evaluate checkpoints trained
            that way; the default gives each projection its own BatchNorm
            (``bnpre1``..``bnpre3``).
    """

    def __init__(self, shared_head_bn: bool = False):
        super(Net, self).__init__()
        # --- Encoder: 3x3 convs with padding 1 keep H and W; Conv6 is 1x1. ---
        self.Conv1 = nn.Conv2d(3, 96, (3, 3), stride=1, padding=1)
        self.Conv2 = nn.Conv2d(96, 256, (3, 3), stride=1, padding=1)
        self.Conv3 = nn.Conv2d(256, 384, (3, 3), stride=1, padding=1)
        self.Conv4 = nn.Conv2d(384, 256, (3, 3), stride=1, padding=1)
        self.Conv5 = nn.Conv2d(256, 1024, (3, 3), stride=1, padding=1)
        self.Conv6 = nn.Conv2d(1024, 1024, 1)

        self.Pool = nn.MaxPool2d(2)

        # --- Dense head: 1x1 projections to 2 channels of the stride-32/16/8 features. ---
        self.ConPre1 = nn.Conv2d(1024, 2, 1)
        self.ConPre2 = nn.Conv2d(256, 2, 1)
        self.ConPre3 = nn.Conv2d(384, 2, 1)

        # Learned upsampling: x2, x2, then x8 back to the input resolution.
        self.Convtrans1 = nn.ConvTranspose2d(2, 2, (2, 2), stride=2)
        self.Convtrans2 = nn.ConvTranspose2d(2, 2, (2, 2), stride=2)
        self.Convtrans3 = nn.ConvTranspose2d(2, 2, (8, 8), stride=8)

        # One PReLU (a single learnable slope) and one Dropout2d are reused after every conv.
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout2d(0.25)
        self.tanh = nn.Tanh()  # not used by forward(); kept so existing references still resolve
        self.sigmoid = nn.Sigmoid()

        # --- Left branch: in-channels of leftcon2..6 are doubled by the concatenated encoder map. ---
        self.leftcon1 = nn.Conv2d(96, 256, (3, 3), stride=1, padding=1)
        self.leftcon2 = nn.Conv2d(256 * 2, 384, (3, 3), stride=1, padding=1)
        self.leftcon3 = nn.Conv2d(384 * 2, 256, (3, 3), stride=1, padding=1)
        self.leftcon4 = nn.Conv2d(256 * 2, 1024, (3, 3), stride=1, padding=1)
        self.leftcon5 = nn.Conv2d(1024 * 2, 1024, (3, 3), stride=1, padding=1)
        self.leftcon6 = nn.Conv2d(1024 * 2, 1024, (3, 3), stride=1, padding=1)
        self.leftcon7 = nn.Conv2d(1024, 1024, (3, 3), stride=1, padding=1)
        self.leftcon8 = nn.Conv2d(1024, 672, (4, 2))

        self.bnc1 = nn.BatchNorm2d(96)
        self.bnc2 = nn.BatchNorm2d(256)
        self.bnc3 = nn.BatchNorm2d(384)
        self.bnc4 = nn.BatchNorm2d(256)
        self.bnc5 = nn.BatchNorm2d(1024)
        self.bnc6 = nn.BatchNorm2d(1024)

        # BatchNorms applied to the transposed-conv outputs.
        self.bnct1 = nn.BatchNorm2d(2)
        self.bnct2 = nn.BatchNorm2d(2)
        self.bnct3 = nn.BatchNorm2d(2)

        self.bnl1 = nn.BatchNorm2d(256)
        self.bnl2 = nn.BatchNorm2d(384)
        self.bnl3 = nn.BatchNorm2d(256)
        self.bnl4 = nn.BatchNorm2d(1024)
        self.bnl5 = nn.BatchNorm2d(1024)
        self.bnl6 = nn.BatchNorm2d(1024)
        self.bnl7 = nn.BatchNorm2d(1024)

        # Dedicated BatchNorms for the ConPre side outputs. Reusing bnct{i} here would mix
        # two different activations into one running mean/var (and one gamma/beta), so
        # eval-mode normalization would match neither. Registered last so the random
        # initialization of every conv above is unchanged.
        if shared_head_bn:
            # Plain (unregistered) attributes: state_dict keeps exactly the legacy keys.
            self.bnpre1 = None
            self.bnpre2 = None
            self.bnpre3 = None
        else:
            self.bnpre1 = nn.BatchNorm2d(2)
            self.bnpre2 = nn.BatchNorm2d(2)
            self.bnpre3 = nn.BatchNorm2d(2)

    def _conv_block(self, x, conv, bn, pool=True):
        """Apply conv -> BatchNorm -> Dropout2d -> PReLU, then an optional 2x2 max-pool."""
        x = self.prelu(self.dropout(bn(conv(x))))
        return self.Pool(x) if pool else x

    def _head_bn(self, idx):
        """Return the BatchNorm used after ``ConPre{idx}``.

        Falls back to the shared ``bnct{idx}`` when there is no dedicated
        ``bnpre{idx}`` (``shared_head_bn=True``, or an instance pickled before
        that module existed).
        """
        bn = getattr(self, f"bnpre{idx}", None)
        return bn if bn is not None else getattr(self, f"bnct{idx}")

    def forward(self, x):
        """Run both heads.

        Args:
            x: Image batch of shape (B, 3, H, W); see the class docstring for
                the constraints on H and W.

        Returns:
            Tuple ``(seg, lx)``: ``seg`` has shape (B, 2, H, W) and ``lx`` has
            672 channels ((B, 672, 1, 1) for a 512 x 256 input). Both pass
            through a sigmoid, so values lie in [0, 1].
        """
        # Encoder; trailing comments give channels and stride relative to the input.
        x1 = self._conv_block(x, self.Conv1, self.bnc1)                 # 96 ch,   /2
        x2 = self._conv_block(x1, self.Conv2, self.bnc2)                # 256 ch,  /4
        x3 = self._conv_block(x2, self.Conv3, self.bnc3)                # 384 ch,  /8
        x4 = self._conv_block(x3, self.Conv4, self.bnc4)                # 256 ch,  /16
        x5 = self._conv_block(x4, self.Conv5, self.bnc5)                # 1024 ch, /32
        x6 = self._conv_block(x5, self.Conv6, self.bnc6, pool=False)    # 1024 ch, /32

        # Left branch: downsample, then concatenate the encoder map of the same stride.
        lx = self._conv_block(x1, self.leftcon1, self.bnl1)             # /4
        lx = torch.cat([lx, x2], dim=1)
        lx = self._conv_block(lx, self.leftcon2, self.bnl2)             # /8
        lx = torch.cat([lx, x3], dim=1)
        lx = self._conv_block(lx, self.leftcon3, self.bnl3)             # /16
        lx = torch.cat([lx, x4], dim=1)
        lx = self._conv_block(lx, self.leftcon4, self.bnl4)             # /32
        lx = torch.cat([lx, x5], dim=1)
        lx = self._conv_block(lx, self.leftcon5, self.bnl5, pool=False) # /32
        lx = torch.cat([lx, x6], dim=1)
        lx = self._conv_block(lx, self.leftcon6, self.bnl6)             # /64
        lx = self._conv_block(lx, self.leftcon7, self.bnl7)             # /128
        lx = self.sigmoid(self.leftcon8(lx))

        # Dense head: 2-channel projections at strides 32, 16 and 8.
        s32 = self._conv_block(x6, self.ConPre1, self._head_bn(1), pool=False)
        s16 = self._conv_block(x4, self.ConPre2, self._head_bn(2), pool=False)
        s8 = self._conv_block(x3, self.ConPre3, self._head_bn(3), pool=False)

        # FCN-8s-style fusion; out-of-place adds avoid mutating tensors autograd may track.
        seg = self.bnct1(self.Convtrans1(s32)) + s16
        seg = self.bnct2(self.Convtrans2(seg)) + s8
        seg = self.sigmoid(self.bnct3(self.Convtrans3(seg)))
        return seg, lx
    