In [1]:
# same as cycleGAN
import os
import numpy as np
import math
import itertools
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


import torch.nn as nn
import torch.nn.functional as F
import torch

import glob

In [2]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Build the dataset (the DataLoaders that wrap it are created later)
class ImageDataset(Dataset):
    """Paired edge-map / shoe-photo dataset read from the ``edges2shoes`` folder.

    Every ``*.jpg`` file is treated as a side-by-side pair: the crop with
    x in [0, 256) is returned as the edge map and the crop with x in
    [256, 512) as the shoe photo (both 256 pixels high).

    Args:
        root: Accepted for interface compatibility but not used; files are
            always globbed from ``edges2shoes/train`` or ``edges2shoes/val``
            relative to the current working directory.
        transforms_: Sequence of callables that is wrapped in
            ``transforms.Compose`` and applied to both crops. ``None`` or an
            empty sequence leaves the crops untouched (PIL images).
        mode: ``"train"`` selects the training folder; any other value
            selects the validation folder.
    """

    def __init__(self, root, transforms_=None, mode="train"):
        # A ``None`` pipeline cannot be applied to an image, so the default
        # argument is treated as "no transform" instead of being composed.
        self.transform = transforms.Compose(transforms_) if transforms_ else None

        split = "train" if mode == "train" else "val"
        # glob returns files in filesystem order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.files = sorted(glob.glob("edges2shoes/%s/*.jpg" % split))

    def __getitem__(self, index):
        """Return ``{"E": edges, "S": shoes}`` for ``index`` (wraps around).

        Raises:
            IndexError: if no image files were found for the selected split.
        """
        if not self.files:
            raise IndexError("ImageDataset is empty: no .jpg files were found")

        path = self.files[index % len(self.files)]
        # The context manager closes the file handle; convert() returns an
        # in-memory copy and guarantees three channels for the networks.
        with Image.open(path) as img:
            pair = img.convert("RGB")

        # Left half: edge map, right half: photo (each 256x256).
        edges = pair.crop((0, 0, 256, 256))
        shoes = pair.crop((256, 0, 512, 256))

        if self.transform is not None:
            edges = self.transform(edges)
            shoes = self.transform(shoes)
        return {"E": edges, "S": shoes}

    def __len__(self):
        """Number of image files found for the selected split."""
        return len(self.files)
    



In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch

def weights_init_normal(m):
    """Re-initialise convolution and batch-norm parameters in place.

    Meant to be passed to ``nn.Module.apply``. The module is selected by its
    class name:

    * name contains ``"Conv"``: weight drawn from N(mean=0.0, std=0.02)
    * name contains ``"BatchNorm2d"``: weight drawn from N(mean=1.0, std=0.02)
      and bias set to 0.0

    Any other module is left untouched. A matching module whose ``weight`` or
    ``bias`` is missing or ``None`` (for example a batch-norm layer built with
    ``affine=False``) is skipped for that parameter instead of raising.

    Args:
        m: The module visited by ``apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

class UNetDown(nn.Module):
    """Encoder stage of the U-Net.

    Layer pipeline, in order:
    ``Conv2d(kernel 4, stride 2, padding 1, no bias)`` ->
    optional ``InstanceNorm2d`` -> ``LeakyReLU(0.2)`` -> optional ``Dropout``.
    For even input sizes the convolution halves height and width.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: Insert the instance-norm layer when true.
        dropout: Dropout probability; a falsy value (0.0) omits the layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        # Optional stages are written as None and filtered out below.
        stages = [
            nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_size) if normalize else None,
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout) if dropout else None,
        ]
        self.model = nn.Sequential(*(stage for stage in stages if stage is not None))

    def forward(self, x):
        """Apply the stage to ``x`` and return the down-sampled feature map."""
        return self.model(x)

class UNetUp(nn.Module):
    """Decoder stage of the U-Net with a skip connection.

    Layer pipeline, in order:
    ``ConvTranspose2d(kernel 4, stride 2, padding 1, no bias)`` ->
    ``InstanceNorm2d`` -> ``ReLU`` -> optional ``Dropout``.
    The result is concatenated with ``skip_input`` along the channel axis, so
    the output has ``out_size`` plus the skip tensor's channels.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced before concatenation.
        dropout: Dropout probability; a falsy value (0.0) omits the layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            # Transposed convolution ("deconvolution") that up-samples by 2.
            nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Up-sample ``x`` and append ``skip_input`` on the channel dimension."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), dim=1)

class GeneratorUNet(nn.Module):
    """U-Net generator: six down-sampling stages, five up-sampling stages
    with skip connections, and a final up-sample + convolution + ``Tanh``.

    The spatial size is halved six times, so height and width should be
    multiples of 64 for the skip connections to line up; the output has the
    same height and width as the input.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    # NOTE (from the original author): too much dropout may cause problems.
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder. The numbering has gaps (no down5/down6, up2/up3); the names
        # are kept because they are part of the state-dict keys.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder. From up4 on, the input width is doubled by the skip
        # tensor concatenated in the previous stage.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Output head: x2 up-sample, pad one pixel on the left/top, then a
        # 4x4 convolution squashed into [-1, 1] by Tanh.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the input batch ``x`` into a generated image batch."""
        # Encoder path; every intermediate map is kept for a skip connection.
        enc1 = self.down1(x)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        enc4 = self.down4(enc3)
        enc5 = self.down7(enc4)
        bottleneck = self.down8(enc5)

        # Decoder path; each stage receives the mirrored encoder output.
        dec = self.up1(bottleneck, enc5)
        dec = self.up4(dec, enc4)
        dec = self.up5(dec, enc3)
        dec = self.up6(dec, enc2)
        dec = self.up7(dec, enc1)
        return self.final(dec)

class Discriminator(nn.Module):
    """Conditional patch discriminator.

    The two input images are concatenated on the channel axis and passed
    through four stride-2 convolution blocks (64, 128, 256, 512 filters; the
    first one without instance norm), followed by a zero pad, a one-channel
    4x4 convolution and a sigmoid. The output is a map of per-patch scores in
    [0, 1] whose height and width are 1/16 of the input's.

    Args:
        in_channels: Channels of EACH input image; the first convolution
            therefore sees ``2 * in_channels`` channels.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # (input filters, output filters, add InstanceNorm2d?)
        block_plan = [
            (in_channels * 2, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]

        layers = []
        for n_in, n_out, use_norm in block_plan:
            # Each block halves the spatial resolution.
            layers.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False),
            # Sigmoid head (added by the original author) bounds the scores.
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair ``(img_A, img_B)``; both must share batch and size."""
        # Concatenate image and condition image by channels to form the input.
        stacked = torch.cat((img_A, img_B), dim=1)
        return self.model(stacked)

In [5]:
# import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


import torch.nn as nn
import torch.nn.functional as F
import torch


# Restrict CUDA to physical GPUs 7 and 4; inside this process they are
# addressed through the logical indices listed in ``device_ids``.
os.environ["CUDA_VISIBLE_DEVICES"] = "7,4"
device_ids = [0, 1]

# ---------------
#  Configuration
# ---------------
# ``epoch`` is the epoch to resume from (0 trains from scratch with freshly
# initialised weights); the training loop runs over range(epoch, n_epochs).
# NOTE(review): with epoch == n_epochs that range is empty, so this setup
# only restores the checkpoints - confirm that this is intended.
epoch = 80
n_epochs = 80
dataset_name = "Generate_1"  # names the output folders under images_wg/ and saved_models/
batch_size = 1024  # original note: remember to adjust this size
lr = 0.0002  # Adam learning rate
b1 = 0.5  # Adam beta1
b2 = 0.999  # Adam beta2
decay_epoch = 50  # not referenced by the code visible in this notebook
sample_interval = 20  # only referenced by commented-out sampling code
checkpoint_interval = 20  # save checkpoints when epoch % interval == 0 (-1 disables)
img_height = 64
img_width = 64
channels = 3

os.makedirs("images_wg/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

cuda = torch.cuda.is_available()

# Loss functions
# criterion_D is created (and moved to the GPU) but the training loop visible
# in this notebook computes its adversarial terms directly.
criterion_D = torch.nn.BCELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Output shape of the image discriminator (PatchGAN): four stride-2 stages
# reduce each side by a factor of 2 ** 4.
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    # Replicate both networks across the visible GPUs; the master copy
    # lives on the first device.
    generator = torch.nn.DataParallel(generator, device_ids=device_ids)
    generator = generator.cuda(device=device_ids[0])

    discriminator = torch.nn.DataParallel(discriminator, device_ids=device_ids)
    discriminator = discriminator.cuda(device=device_ids[0])

    criterion_D.cuda()
    criterion_pixelwise.cuda()
    print('cuda over')

if epoch != 0:
    # Load pretrained models. The checkpoint is deserialised onto the CPU so
    # that loading does not depend on the device it was saved from and does
    # not allocate a second copy of the weights on the GPU; load_state_dict
    # then copies the values into the existing parameters.
    # NOTE(review): the key names must match the current wrapping (the
    # training loop saves the state dict of whatever ``generator`` /
    # ``discriminator`` are here, i.e. DataParallel-wrapped when cuda is on).
    generator.load_state_dict(
        torch.load("saved_models/%s/generator_%d.pth" % (dataset_name, epoch), map_location="cpu")
    )
    discriminator.load_state_dict(
        torch.load("saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch), map_location="cpu")
    )
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Configure dataloaders
# NOTE(review): normalisation to [-1, 1] is disabled, so targets stay in the
# [0, 1] range produced by ToTensor while the generator ends in Tanh.
transforms_ = [
    transforms.Resize((img_height, img_width), Image.BICUBIC),
    transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    ImageDataset("%s" % dataset_name, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
)

val_dataloader = DataLoader(
    ImageDataset("%s" % dataset_name, transforms_=transforms_, mode="val"),
    batch_size=10,
    shuffle=False,
)

# Tensor type used to move batches onto the training device
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def sample_images(epoch):
    """Save a generated sample from the validation set.

    Takes the first batch of ``val_dataloader``, runs the generator on the
    edge maps and writes one image to
    ``images_wg/<dataset_name>/<epoch>.png`` in which each column stacks the
    input edges, the generated image and the real photo vertically.

    Args:
        epoch: Value used as the output file name (any ``%s``-formattable
            object).
    """
    imgs = next(iter(val_dataloader))
    real_A = imgs["E"].type(Tensor)
    real_B = imgs["S"].type(Tensor)

    # Inference only: without no_grad the forward pass would record an
    # autograd graph and keep its activations alive on the GPU.
    # The generator's train/eval mode is deliberately left as it is.
    with torch.no_grad():
        fake_B = generator(real_A)

    # Stack along the height axis: edges on top, fake in the middle, real below.
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    save_image(img_sample, "images_wg/%s/%s.png" % (dataset_name, epoch), nrow=5, normalize=True)




cuda over


In [6]:
# ----------
#  Training
# ----------
# Alternates one generator step and one discriminator step per batch.
# The loop variable reuses the global ``epoch`` (the resume epoch), so after
# the loop ``epoch`` holds the last epoch that was run.

prev_time = time.time()

for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs: edge maps (condition) and the matching real photos
        real_A = batch["E"].type(Tensor)
        real_B = batch["S"].type(Tensor)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Adversarial term E->S: the generator raises the discriminator's
        # mean score on generated images (critic-style objective, no BCE).
        fake_B = generator(real_A)
        G_loss = torch.mean(-discriminator(fake_B, real_A))

        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = G_loss + lambda_pixel * loss_pixel

        loss_G.backward()

        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the gradients the generator step left on the
        # discriminator's parameters.
        optimizer_D.zero_grad()

        # Score of real pairs
        pred_real = discriminator(real_B, real_A)

        # Score of generated pairs; detach so no gradient reaches the generator
        pred_fake = discriminator(fake_B.detach(), real_A)

        # Total loss: push fake scores down and real scores up
        loss_D = torch.mean(pred_fake - pred_real)

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left from the duration of the last batch
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log (carriage return keeps it on a single, updating line)
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, ] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                time_left,
            )
        )

    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints and a validation sample for this epoch
        torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (dataset_name, epoch))
        torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch))
        sample_images(epoch)

In [7]:
# Write ten sample grids, saved as 101.png ... 110.png in the sample folder.
sample_tags = range(101, 111)
for i in sample_tags:
    sample_images(i)

RuntimeError: CUDA error: out of memory (malloc at /opt/conda/conda-bld/pytorch_1579022027550/work/c10/cuda/CUDACachingAllocator.cpp:260)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7f968c826627 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1ea4a (0x7f968ca6aa4a in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1ff2e (0x7f968ca6bf2e in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: THCStorage_resize + 0xa3 (0x7f96976476b3 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #4: at::native::empty_strided_cuda(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) + 0x626 (0x7f9698ff4f56 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x417bcea (0x7f9697558cea in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x1b0ec41 (0x7f9694eebc41 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #7: <unknown function> + 0x366cf70 (0x7f9696a49f70 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x1b0ec41 (0x7f9694eebc41 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #9: <unknown function> + 0x187765e (0x7f9694c5465e in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #10: at::native::to(at::Tensor const&, c10::TensorOptions const&, bool, bool, c10::optional<c10::MemoryFormat>) + 0x245 (0x7f9694c556b5 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #11: <unknown function> + 0x1bbcb5a (0x7f9694f99b5a in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #12: <unknown function> + 0x38a2826 (0x7f9696c7f826 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #13: <unknown function> + 0x1c075a2 (0x7f9694fe45a2 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #14: torch::cuda::scatter(at::Tensor const&, c10::ArrayRef<long>, c10::optional<std::vector<long, std::allocator<long> > > const&, long, c10::optional<std::vector<c10::optional<c10::cuda::CUDAStream>, std::allocator<c10::optional<c10::cuda::CUDAStream> > > > const&) + 0x710 (0x7f9697952f20 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch.so)
frame #15: <unknown function> + 0x9e6662 (0x7f96c3a31662 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #16: <unknown function> + 0x28b8a7 (0x7f96c32d68a7 in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #17: PyCFunction_Call + 0x56 (0x562493c7a006 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #18: _PyObject_MakeTpCall + 0x21f (0x562493c3b03f in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #19: _PyEval_EvalFrameDefault + 0x5307 (0x562493cdfd87 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #20: _PyEval_EvalCodeWithName + 0x1dc (0x562493c85cec in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #21: _PyFunction_Vectorcall + 0x1c5 (0x562493c86da5 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x4d78 (0x562493cdf7f8 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #23: _PyFunction_Vectorcall + 0xfb (0x562493c86cdb in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #24: PyVectorcall_Call + 0x6f (0x562493c3a86f in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #25: THPFunction_apply(_object*, _object*) + 0xb2f (0x7f96c36bfd1f in /home1/lisl/anaconda3/envs/Pyt/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #26: PyCFunction_Call + 0xdb (0x562493c7a08b in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #27: _PyObject_MakeTpCall + 0x21f (0x562493c3b03f in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #28: _PyEval_EvalFrameDefault + 0x5307 (0x562493cdfd87 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #29: _PyEval_EvalCodeWithName + 0x955 (0x562493c86465 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #30: _PyFunction_Vectorcall + 0x21e (0x562493c86dfe in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #31: <unknown function> + 0x186a56 (0x562493c79a56 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #32: PyIter_Next + 0xe (0x562493c3cc3e in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #33: PySequence_Tuple + 0xfb (0x562493c8513b in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x5c1c (0x562493ce069c in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x955 (0x562493c86465 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #36: _PyFunction_Vectorcall + 0x21e (0x562493c86dfe in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #37: _PyEval_EvalFrameDefault + 0x6e4 (0x562493cdb164 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #38: _PyEval_EvalCodeWithName + 0x955 (0x562493c86465 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #39: _PyFunction_Vectorcall + 0x21e (0x562493c86dfe in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #40: _PyEval_EvalFrameDefault + 0x6e4 (0x562493cdb164 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #41: _PyEval_EvalCodeWithName + 0x1dc (0x562493c85cec in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #42: _PyFunction_Vectorcall + 0x268 (0x562493c86e48 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #43: _PyEval_EvalFrameDefault + 0x147b (0x562493cdbefb in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #44: _PyFunction_Vectorcall + 0xfb (0x562493c86cdb in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #45: <unknown function> + 0x16ff85 (0x562493c62f85 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #46: _PyEval_EvalFrameDefault + 0x4d78 (0x562493cdf7f8 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #47: _PyEval_EvalCodeWithName + 0x1dc (0x562493c85cec in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #48: _PyFunction_Vectorcall + 0x21e (0x562493c86dfe in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #49: <unknown function> + 0x17004b (0x562493c6304b in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #50: PyVectorcall_Call + 0x6f (0x562493c3a86f in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #51: _PyEval_EvalFrameDefault + 0x1f14 (0x562493cdc994 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #52: _PyEval_EvalCodeWithName + 0x1dc (0x562493c85cec in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #53: _PyObject_FastCallDict + 0x1b8 (0x562493c87358 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #54: _PyObject_Call_Prepend + 0x63 (0x562493c87623 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #55: <unknown function> + 0x19472a (0x562493c8772a in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #56: _PyObject_MakeTpCall + 0x21f (0x562493c3b03f in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #57: _PyEval_EvalFrameDefault + 0x4cbd (0x562493cdf73d in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #58: _PyFunction_Vectorcall + 0xfb (0x562493c86cdb in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #59: _PyEval_EvalFrameDefault + 0x6e4 (0x562493cdb164 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #60: _PyEval_EvalCodeWithName + 0x1dc (0x562493c85cec in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #61: PyEval_EvalCodeEx + 0x44 (0x562493c86ba4 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #62: PyEval_EvalCode + 0x1c (0x562493c86bcc in /home1/lisl/anaconda3/envs/Pyt/bin/python)
frame #63: <unknown function> + 0x206128 (0x562493cf9128 in /home1/lisl/anaconda3/envs/Pyt/bin/python)
