In [1]:
import torch
import torch.nn as nn

## Generator (U-Net)

In [6]:
class UNetDown(nn.Module):
    """Encoder stage of the U-Net generator: a stride-2 convolution that halves H and W.

    Layers, in order: Conv2d(kernel 4, stride 2, padding 1, no bias) ->
    optional InstanceNorm2d -> LeakyReLU(0.2) -> optional Dropout.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        normalize: if True, apply InstanceNorm2d after the convolution.
        dropout: dropout probability; 0 (the default) adds no Dropout layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()

        # bias=False is used even when normalize=False (e.g. the first/last encoder
        # stages of the generator) — NOTE(review): confirm that is intended.
        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))

        if dropout:
            layers.append(nn.Dropout(dropout))
        self.down = nn.Sequential(*layers)

    def forward(self, x):
        """Return the downsampled feature map, shape (N, out_channels, H // 2, W // 2)."""
        return self.down(x)

# Shape check for a single encoder stage: 256x256 -> 128x128, 3 -> 64 channels.
X = torch.randn(16, 3, 256, 256)
model = UNetDown(3, 64)
# Only shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    down_out = model(X)
assert down_out.shape == (16, 64, 128, 128)
print(down_out.shape)

torch.Size([16, 64, 128, 128])


In [7]:
class UNetUp(nn.Module):
    """Decoder stage of the U-Net generator: upsample 2x, then append a skip tensor.

    The upsampling path is ConvTranspose2d(kernel 4, stride 2, padding 1, no bias) ->
    InstanceNorm2d -> LeakyReLU (default negative slope, 0.01) -> optional Dropout.

    NOTE(review): the encoder (UNetDown) uses LeakyReLU(0.2); confirm the default
    slope here is intentional.

    Args:
        in_channels: channels of the tensor being upsampled.
        out_channels: channels produced by the transposed convolution (before the
            skip tensor is concatenated).
        dropout: dropout probability; 0 (the default) adds no Dropout layer.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.up = nn.Sequential(*stages)

    def forward(self, x, skip):
        """Upsample ``x`` and concatenate ``skip`` after it along the channel axis.

        ``skip`` must match the upsampled tensor in N, H and W (i.e. twice x's H and W).
        The result has out_channels + skip.shape[1] channels.
        """
        return torch.cat([self.up(x), skip], dim=1)

# Shape check for a single decoder stage. Reuses `down_out` (16, 64, 128, 128)
# from the UNetDown cell above as the skip tensor, so run that cell first.
X = torch.randn(16, 128, 64, 64)
model = UNetUp(128, 64)
with torch.no_grad():
    out = model(X, down_out)
# 64 upsampled channels + 64 skip channels.
assert out.shape == (16, 128, 128, 128)
print(out.shape)

torch.Size([16, 128, 128, 128])


In [12]:
# Generator: the model that produces fake images.
class GeneratorUNet(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages and 8 upsampling decoder stages with skips.

    Each encoder stage halves H and W and each decoder stage doubles them. For the skip
    concatenations to line up, H and W should be multiples of 256 (2**8). A 256x256
    input is reduced to 1x1 at the bottleneck (``d8``).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image. The output passes through Tanh,
            so values lie in [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        # Encoder. down8's output is 1x1 for a 256x256 input, which is presumably why
        # it skips InstanceNorm — confirm.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder. Every UNetUp appends its skip tensor, so each stage after up1 takes
        # (previous stage's out_channels + matching encoder channels) as input.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)
        # Final upsample: 64 (up7) + 64 (d1 skip) = 128 channels -> out_channels.
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) image to an (N, out_channels, H, W) image."""
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Each decoder stage is paired with the encoder output of the same resolution.
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        u8 = self.up8(u7)

        return u8

# End-to-end shape check: the generator must preserve the input's shape.
X = torch.randn(16, 3, 256, 256)
model = GeneratorUNet()
# A batch of 16 at 256x256 through 16 stages is memory-heavy; no_grad avoids
# keeping activations for a backward pass that never happens.
with torch.no_grad():
    out = model(X)
assert out.shape == X.shape
print(out.shape)

torch.Size([16, 3, 256, 256])


## Discriminator (PatchGAN)

In [18]:
class Dis_Block(nn.Module):
    """One downsampling stage of the discriminator.

    Conv2d(kernel 3, stride 2, padding 1, with bias) -> optional InstanceNorm2d ->
    LeakyReLU(0.2). Output spatial size is ((H + 1) // 2, (W + 1) // 2).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        normalize: if True, apply InstanceNorm2d after the convolution.
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        self.block = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.block(x)

# Shape check for one discriminator stage: H and W are halved, channels 64 -> 128.
X = torch.randn(16, 64, 64, 128)
model = Dis_Block(64, 128)
with torch.no_grad():
    out = model(X)
assert out.shape == (16, 128, 32, 64)
print(out.shape)

torch.Size([16, 128, 32, 64])


In [19]:
class Discriminator(nn.Module):
    """Conditional patch discriminator.

    The two inputs are concatenated along the channel axis and passed through four
    stride-2 Dis_Block stages (64 -> 128 -> 256 -> 512 channels). A final 3x3
    convolution then produces one score per spatial location, squashed by a sigmoid.
    A 256x256 input gives a 16x16 score map.

    Args:
        in_channels: channels of *each* input image; the first stage sees twice this.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        self.stage_1 = Dis_Block(2 * in_channels, 64, normalize=False)
        self.stage_2 = Dis_Block(64, 128)
        self.stage_3 = Dis_Block(128, 256)
        self.stage_4 = Dis_Block(256, 512)
        # Single-channel head: one score per spatial location (patch).
        self.patch = nn.Conv2d(512, 1, kernel_size=3, padding=1)

    def forward(self, a, b):
        """Score the image pair (a, b) and return an (N, 1, H', W') map of sigmoid scores.

        Presumably ``a`` is the conditioning image and ``b`` the real or generated
        target — confirm against the training loop. Both must share N, H and W so
        they can be concatenated on dim 1.
        """
        features = torch.cat([a, b], dim=1)
        for stage in (self.stage_1, self.stage_2, self.stage_3, self.stage_4):
            features = stage(features)
        return torch.sigmoid(self.patch(features))

# End-to-end shape check: a 256x256 pair yields a 16x16 map of patch scores.
X = torch.randn(16, 3, 256, 256)
model = Discriminator()
with torch.no_grad():
    out = model(X, X)
assert out.shape == (16, 1, 16, 16)
# Sigmoid output: every score must lie in [0, 1].
assert ((out >= 0) & (out <= 1)).all()
print(out.shape)

torch.Size([16, 1, 16, 16])
