In [2]:
import numpy as np
import torch
import torch.nn as nn

In [3]:
def UNetBlock(inChannels,outChannels):
    """Build the conv-BN-ReLU stack used for every encoder/decoder stage.

    The stack is three repetitions of
    ``Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d -> ReLU(inplace=True)``.
    Only the first convolution changes the channel count
    (``inChannels -> outChannels``); the other two map ``outChannels`` to
    ``outChannels``. A 3x3 kernel with padding 1 and the default stride 1
    keeps the spatial size unchanged.

    Args:
        inChannels: channel count of the incoming feature map.
        outChannels: channel count produced by the block.

    Returns:
        An ``nn.Sequential`` of 9 layers (state_dict keys prefixed "0".."8").
    """
    layers = []
    for stage in range(3):
        # Only the first convolution of the block adapts the channel count.
        stageIn = inChannels if stage == 0 else outChannels
        layers.append(nn.Conv2d(stageIn, outChannels, kernel_size=3, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(num_features=outChannels))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

In [5]:
class BioFaceUNet(nn.Module):
    """UNet-style network returning lighting/``b`` predictions and four per-pixel maps.

    Architecture (as built in ``__init__``):

    * Encoder: five ``UNetBlock`` stages with ``features`` = 32, 64, 128, 256,
      512 channels, separated by 2x2 max-pooling.
    * Decoder: four ``UNetBlock`` stages; each resizes the previous output to
      the spatial size of the matching encoder stage and concatenates that
      skip connection on the channel axis.
    * FCN head on the bottleneck: conv(4x4) -> BN -> ReLU -> conv(1x1) -> BN ->
      ReLU -> conv(5x5, ``dims`` output channels, no bias).

    ``forward`` returns ``(lightingParameters, b, fmel, fblood, shading, specmask)``.
    """

    def __init__(self):
        super().__init__()
        self.inputChannels = 3
        # Plain Python ints: layer constructors expect int sizes, and this
        # avoids needing numpy just to hold five constants.
        self.features = [32,64,128,256,512]
        self.nclasses = 4  # NOTE(review): not used anywhere in this class.
        self.lightVectorSize = 15
        self.bSize = 2
        self.dims = self.lightVectorSize + self.bSize

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encBlock1 = UNetBlock(self.inputChannels,self.features[0])
        self.encBlock2 = UNetBlock(self.features[0],self.features[1])
        self.encBlock3 = UNetBlock(self.features[1],self.features[2])
        self.encBlock4 = UNetBlock(self.features[2],self.features[3])
        self.encBlock5 = UNetBlock(self.features[3],self.features[4])

        # Fixed-size / fixed-factor upsamplers. The sizes match the encoder
        # resolutions of a 218x178 input. forward() no longer uses them (it
        # resizes to each skip connection instead); they hold no parameters
        # and are kept only so existing attribute access keeps working.
        self.upsample1 = nn.Upsample(size = (27,22))
        self.upsample2 = nn.Upsample(size = (54,44))
        self.upsample3 = nn.Upsample(size = (109,89))
        self.upsample4 = nn.Upsample(size = (218,178))
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        # Each decoder stage consumes [upsampled deeper features, skip features].
        self.decBlock1 = UNetBlock(self.features[4]+self.features[3],self.features[3])
        self.decBlock2 = UNetBlock(self.features[3]+self.features[2],self.features[2])
        self.decBlock3 = UNetBlock(self.features[2]+self.features[1],self.features[1])
        self.decBlock4 = UNetBlock(self.features[1]+self.features[0],self.features[0])

        self.convFCN1 = nn.Conv2d(in_channels=self.features[4], out_channels=self.features[4], kernel_size = 4)
        self.bnormFCN1 = nn.BatchNorm2d(num_features=self.features[4])
        self.reluFCN1 = nn.ReLU(inplace=True)
        self.convFCN2 = nn.Conv2d(in_channels=self.features[4], out_channels=self.features[4], kernel_size = 1)
        self.bnormFCN2 = nn.BatchNorm2d(num_features=self.features[4])
        self.reluFCN2 = nn.ReLU(inplace=True)

        # One output channel per lighting value plus the bSize values of b.
        self.pred = nn.Conv2d(in_channels = self.features[4],out_channels = self.dims, kernel_size = 5,bias=False)

    @staticmethod
    def _upsampleTo(x, skip):
        """Bilinearly resize ``x`` to the spatial size of ``skip``.

        Matching the skip's size (instead of a fixed 2x factor) keeps the
        channel concatenation valid when pooling floored an odd dimension
        (e.g. 27 -> 13 -> 26 != 27). When the skip is exactly twice the size
        of ``x``, this equals 2x bilinear upsampling with align_corners=False.
        """
        return nn.functional.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)

    def _decode(self, y, skips):
        """Run the decoder from bottleneck ``y`` over ``skips`` (deepest first)."""
        decoders = (self.decBlock1, self.decBlock2, self.decBlock3, self.decBlock4)
        x = y
        for decBlock, skip in zip(decoders, skips):
            x = self._upsampleTo(x, skip)
            x = decBlock(torch.cat((x, skip), 1))
        return x

    def forward(self,x):
        """Run the network.

        Args:
            x: input batch. ``encBlock1`` starts with ``Conv2d(3, ...)``, so x
                must be channels-first, (batch, 3, H, W).

        Returns:
            lightingParameters: channels ``0:lightVectorSize`` of the head
                output, with its spatial dims kept.
            b: channels ``lightVectorSize:dims`` of the head output at spatial
                position (0, 0), shape (batch, bSize).
            fmel, fblood, shading, specmask: channels 0, 1, 2, 3 of the decoder
                output, each of shape (batch, H, W).

        NOTE(review): the head shrinks the bottleneck by 7 pixels per side
        (4x4, 1x1, 5x5 convs, no padding), so its output is 1x1 only for an 8x8
        bottleneck (e.g. a 128x128 input). A bottleneck smaller than 8x8
        (e.g. a 64x64 input) makes the 5x5 conv fail. For larger inputs,
        lightingParameters keeps a spatial map while b reads only position
        (0, 0) — confirm the intended input size.
        NOTE(review): the four maps are the first four channels of a single
        shared 32-channel ReLU output — confirm this is intended (separate
        per-map decoders/heads may have been planned).
        """
        # Encoder
        encBlock1 = self.encBlock1(x)
        encBlock2 = self.encBlock2(self.pool(encBlock1))
        encBlock3 = self.encBlock3(self.pool(encBlock2))
        encBlock4 = self.encBlock4(self.pool(encBlock3))
        y = self.encBlock5(self.pool(encBlock4))

        # Decoder. It is evaluated once: re-running the same shared-weight
        # decoder on the same inputs only recomputes an identical tensor. In
        # training mode it would also apply extra BatchNorm running-stat updates.
        decoded = self._decode(y, (encBlock4, encBlock3, encBlock2, encBlock1))

        # FCN head on the bottleneck
        fcn1 = self.reluFCN1(self.bnormFCN1(self.convFCN1(y)))
        fcn2 = self.reluFCN2(self.bnormFCN2(self.convFCN2(fcn1)))
        predictions = self.pred(fcn2)

        lightingParameters = predictions[:,0:self.lightVectorSize,:,:]
        b = predictions[:,self.lightVectorSize:self.dims,0,0]
        fmel = decoded[:,0,:,:]
        fblood = decoded[:,1,:,:]
        shading = decoded[:,2,:,:]
        specmask = decoded[:,3,:,:]

        return lightingParameters,b,fmel,fblood,shading,specmask