In [2]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
from IPython import display
import matplotlib.pyplot as plt
import glob
%matplotlib inline
device = "cuda"

In [3]:
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
class DownSamp(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DownSamp, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2,padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1,padding=1)
        self.conv3 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1,padding=1)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu_(x)
        x = self.conv2(x)
        x = F.relu_(x)
        x = self.conv3(x)
        x = F.relu_(x)
        return x
class UpSamp(nn.Module):
    def __init__(self, in_ch, cat_ch, out_ch):
        super(UpSamp, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_ch,in_ch//2,kernel_size=4, stride=2,padding=1)
        #forwardでcatする catするサイズはアップサンプル後
        self.conv1 = nn.Conv2d(in_ch//2+cat_ch, out_ch, kernel_size=3, stride=1,padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1,padding=1)
    def forward(self,x, cat):
        x = self.deconv(x)
        x = F.relu_(x)
        x = torch.cat((x, cat),axis=1)
        x = self.conv1(x)
        x = F.relu_(x)
        x = self.conv2(x)
        x = F.relu_(x)
        return x
d = DownSamp(32,64)
u = UpSamp(64,32,32)
d(torch.zeros((1,32,64,64)))
u(torch.zeros((1,64,32,32)),torch.zeros((1,32,64,64))).shape
class UNet(nn.Module):
    def __init__(self):
        super(UNet,self).__init__()
        self.in1 = nn.Conv2d(3, 64, kernel_size=7,stride=1,padding=3)
        self.in2 = nn.Conv2d(64, 64, kernel_size=3,stride=1,padding=1)
        self.down1 = DownSamp(64,128)
        self.down2 = DownSamp(128,256)
        self.down3 = DownSamp(256,512)
        #self.down4 = DownSamp(512,1024)
        #self.up4 = UpSamp(1024,512,512)
        self.up3 = UpSamp(512,256,256)
        self.up2 = UpSamp(256,128,128)
        self.up1 = UpSamp(128,64,64)
        self.out = nn.Conv2d(64,3, kernel_size=7, stride=1, padding=3)
    def forward(self,x):
        x = self.in1(x)
        x = F.relu_(x)
        x1 = self.in2(x)
        x1 = F.relu_(x1)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        #x5 = self.down4(x4)
        #x = self.up4(x5,x4)
        x = self.up3(x4, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        x = self.out(x)
        x = F.tanh(x)
        return x


In [5]:
from torch.nn.utils.spectral_norm import spectral_norm

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()        
        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(3, 64, stride=2, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.1),
            spectral_norm(nn.Conv2d(64, 128, stride=2, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.1),
            spectral_norm(nn.Conv2d(128, 256, stride=2, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.1),
            spectral_norm(nn.Conv2d(256, 512, stride=2, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512,1,stride=2,kernel_size=1,padding=1)
            
        )
    def forward(self, x):
        return self.main(x).squeeze_(1)
d = Discriminator()
d(torch.zeros((1,3,224,224))).shape

torch.Size([1, 8, 8])

In [7]:

g = UNet().to(device)
d = Discriminator().to(device)
fake = g(torch.zeros((1,3,224,224)).to(device))
real = torch.zeros((1,3,224,224)).to(device)
fake.shape
real.shape
g_loss = -torch.sum(d(fake))

d_loss = -torch.mean(torch.min(0,-1+d(real)) + torch.min(0, -1 - d(fake)))

TypeError: min() received an invalid combination of arguments - got (int, Tensor), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, Tensor out)
 * (Tensor input, int dim, bool keepdim, tuple of Tensors out)
