# In [1]:
import torch
from torch import nn

class ResNetBlock(nn.Module): # <1>
    """Residual block: ``forward(x)`` returns ``x + F(x)``.

    ``F`` is the ``nn.Sequential`` produced by :meth:`build_conv_block`. Both
    convolutions map ``dim`` -> ``dim`` channels, and each 3x3 convolution
    (stride 1, no zero padding) is preceded by a reflection pad of 1, so
    ``F(x)`` has the same shape as ``x`` and can be added to it.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Build the residual branch for ``dim`` channels.

        Layer order: pad, conv, norm, ReLU, pad, conv, norm. The second
        stage has no activation. The order fixes the ``state_dict`` indices
        (convolutions at positions 1 and 5).
        """
        def padded_conv_norm():
            # Reflection pad of 1 + unpadded 3x3 conv keeps H and W unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        layers = padded_conv_norm()
        layers.append(nn.ReLU(inplace=True))
        layers += padded_conv_norm()
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)  # <2>
        return x + residual
    

class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout of ``self.model`` (a single ``nn.Sequential``):

    * 7x7 stem convolution (``input_nc`` -> ``ngf`` channels),
    * ``n_downsampling`` stride-2 3x3 convolutions, each doubling channels,
    * ``n_blocks`` ``ResNetBlock`` modules at width ``ngf * 2**n_downsampling``,
    * ``n_downsampling`` stride-2 transposed convolutions, each halving channels,
    * 7x7 output convolution (``ngf`` -> ``output_nc``) followed by ``tanh``.

    With the default ``n_downsampling=2`` the layer order, and therefore the
    ``state_dict`` keys, is the same as before the parameter existed. For H
    and W divisible by ``2**n_downsampling`` the output spatial size equals
    the input's.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the output image.
        ngf: number of filters in the stem convolution.
        n_blocks: number of residual blocks at the bottleneck (must be >= 0).
        n_downsampling: number of down/upsampling stages (must be >= 0).
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9,
                 n_downsampling=2): # <3>
        assert n_blocks >= 0
        assert n_downsampling >= 0
        super().__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_downsampling = n_downsampling

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Stage i maps ngf * 2**i -> ngf * 2**(i + 1) channels and
        # (stride 2, kernel 3, padding 1) halves H and W, rounding up.
        for i in range(n_downsampling):
            mult = 2 ** i
            model.extend([
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                          stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(True),
            ])

        # Bottleneck width after n_downsampling doublings. This must be a power
        # (2 ** n), not a product (2 * n); the two only coincide at n == 2.
        mult = 2 ** n_downsampling

        for _ in range(n_blocks):
            model.append(ResNetBlock(ngf * mult))

        # Mirror of the downsampling path: each transposed conv
        # (kernel 3, stride 2, padding 1, output_padding 1) doubles H and W
        # and halves the channel count.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model.extend([
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=True),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(True),
            ])

        model.append(nn.ReflectionPad2d(3))
        model.append(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))
        # tanh bounds every output value to [-1, 1].
        model.append(nn.Tanh())

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        """Run ``input`` (an image tensor with ``input_nc`` channels) through
        the generator and return a ``tanh``-bounded tensor with ``output_nc``
        channels."""
        return self.model(input)