In [36]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )   


class UNet(nn.Module):

    def __init__(self):
        super().__init__()
                
        self.dconv_down1 = double_conv(3, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)        

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(64 + 128, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)
        
        self.conv_last = nn.Conv2d(32, 1, 1)
        
        
    def forward(self, x):
        print(x.shape)
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        print(x.shape)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        print(x.shape)
        
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   
        print(x.shape)
        
        x = self.dconv_down4(x)
        print(x.shape)
        
        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)
        print(x.shape)
        
        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)       
        print(x.shape)

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)   
        print(x.shape)
        
        x = self.dconv_up1(x)
        print(x.shape)
        
        out = self.conv_last(x)
        print('out', out.shape)
        
        return out
x = torch.randn(1, 3, 256, 256)
model = UNet()
model(x)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params/ 1e6)

torch.Size([1, 3, 256, 256])
torch.Size([1, 32, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 384, 64, 64])
torch.Size([1, 192, 128, 128])
torch.Size([1, 96, 256, 256])
torch.Size([1, 32, 256, 256])
out torch.Size([1, 1, 256, 256])
1.946881


In [38]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )   



class GrayScaleConverterModel(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, dim=32, channels_mult=[2, 2, 2]):
        super().__init__()
        self.down_blocks = []
        self.down_blocks.append(double_conv(in_channels, dim))
        last_dim = dim
        down_dims = []
        down_dims.append(last_dim)
        for ch_mult in channels_mult:
            new_dim = last_dim*ch_mult
            self.down_blocks.append(double_conv(last_dim, new_dim))
            last_dim = new_dim
            down_dims.append(last_dim)
            
        self.up_blocks = []
        print(last_dim)
        for ch_mult in channels_mult:
            new_dim = last_dim//ch_mult
            self.up_blocks.append(double_conv(last_dim + new_dim, new_dim))
            last_dim = new_dim        
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.maxpool = nn.MaxPool2d(2)
        self.conv_last = nn.Conv2d(dim, 1, 1)

    def forward(self, x):
        print(x.shape)
        residuals = []
        for idx, b in enumerate(self.down_blocks):
            x = b(x)
            if idx < len(self.down_blocks)-1:
                residuals.append(x)
                x = self.maxpool(x)

        residuals = residuals[::-1]
        for idx, b in enumerate(self.up_blocks):
            x = self.upsample(x)
            x = torch.cat([residuals[idx], x], dim=1)
            x = b(x)
        return self.conv_last(x)

x = torch.randn(1, 3, 256, 256)
model = GrayScaleConverterModel()
x = model(x)
print(x.shape)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params/ 1e6)

# torch.Size([1, 3, 256, 256])
# torch.Size([1, 32, 128, 128])
# torch.Size([1, 64, 64, 64])
# torch.Size([1, 128, 32, 32])
# torch.Size([1, 256, 32, 32])
# torch.Size([1, 384, 64, 64])
# torch.Size([1, 192, 128, 128])
# torch.Size([1, 96, 256, 256])
# torch.Size([1, 32, 256, 256])
# torch.Size([1, 32, 256, 256])

256
torch.Size([1, 3, 256, 256])
torch.Size([1, 32, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 128, 128])
torch.Size([1, 32, 256, 256])
torch.Size([1, 1, 256, 256])
3.3e-05
