In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import Utils
import os

# 1. 데이터 로드
 - 데이터 셋 로드 및 데이터 로더 만들기
 

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=[512,512]),
    transforms.Normalize(mean=(0.5,), std=(0.5, ))
])

root_path = os.path.join(os.getcwd(), 'data')
dataset = Utils.vibrationData(root_path=root_path, transform=transform)

data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

In [3]:
sample = next(iter(data_loader))

In [4]:
signal, wavlet_img, corr_img, cls = sample

In [5]:
class Encoder(nn.Module):
    """
    DCGAN ENCODER NETWORK
    """

    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):
        """_summary_

        Args:
            isize (int): input image size
            nz (int): size of the latent z vector
            nc (int): input image channels
            ndf (int): _description_
            ngpu (int): number of GPUs to use
            n_extra_layers (int, optional): number of layers on gen and disc. Defaults to 0.
            add_final_conv (bool, optional): _description_. Defaults to True.
        """
        super(Encoder, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial-conv-{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 3, 2, 1, bias=False))
        main.add_module('initial-relu-{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize / 2, ndf

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}-{1}-conv'.format(t, cndf),
                            nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cndf),
                            nn.BatchNorm2d(cndf))
            main.add_module('extra-layers-{0}-{1}-relu'.format(t, cndf),
                            nn.LeakyReLU(0.2, inplace=True))

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 3, 2, 1, bias=False))
            main.add_module('pyramid-{0}-batchnorm'.format(out_feat),
                            nn.BatchNorm2d(out_feat))
            main.add_module('pyramid-{0}-relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize / 2

        # state size. K x 4 x 4
        if add_final_conv:
            main.add_module('final-{0}-{1}-conv'.format(cndf, 1),
                            nn.Conv2d(cndf, nz, 3, 1, 0, bias=False))

        self.main = main

    def forward(self, input):
        if self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output
    
class Decoder(nn.Module):
    """
    DCGAN DECODER NETWORK
    """
    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        """_summary_

        Args:
            isize (int): input image size
            nz (int): size of the latent z vector
            nc (int): input image channels
            ngf (_type_): _description_
            ngpu (int): number of GPUs to use
            n_extra_layers (int, optional): number of layers on gen and disc. Defaults to 0.
        """
        super(Decoder, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        
        cngf, tisize = ngf // 2, 4
        
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        # input is Z, going into a convolution
        main.add_module('initial-{0}-{1}-convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 3, 1, 0, bias=False))
        main.add_module('initial-{0}-batchnorm'.format(cngf),
                        nn.BatchNorm2d(cngf))
        main.add_module('initial-{0}-relu'.format(cngf),
                        nn.ReLU(True))

        csize, _ = 4, cngf
        while csize < isize // 2:
            main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2),
                            nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2),
                            nn.BatchNorm2d(cngf // 2))
            main.add_module('pyramid-{0}-relu'.format(cngf // 2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf),
                            nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cngf),
                            nn.BatchNorm2d(cngf))
            main.add_module('extra-layers-{0}-{1}-relu'.format(t, cngf),
                            nn.ReLU(True))

        main.add_module('final-{0}-{1}-convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final-{0}-tanh'.format(nc),
                        nn.Tanh())
        self.main = main

    def forward(self, input):
        if self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
    

In [37]:
encoder = Encoder(isize=512, nz=512, nc=4, ndf=64, ngpu=0)
z = encoder(wavlet_img)

print(z.size())

torch.Size([2, 512, 2, 2])


In [51]:
class Generator(nn.Module):
    def __init__(self, n_cls, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super(Generator, self).__init__()
        self.n_cls = n_cls
        self.pre_z = nn.Sequential(
            nn.Linear(2048, 3*nz),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.pre_c = nn.Sequential(
            nn.Linear(n_cls, 1*nz),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layers = Decoder(isize=isize, nz=nz, nc=nc, ngf=ngf, ngpu=ngpu, n_extra_layers=n_extra_layers)
        
    def forward(self, z, c):
        
        batch_size = z.size(0)
        
        z = z.view([batch_size, -1])
        z = self.pre_z(z)
        z = z.view([batch_size, -1, 2, 2])
        c = self.pre_c(c)
        c = c.view([batch_size, -1, 2, 2])
        
        zc = torch.cat([z,c], dim=1)

        return self.layers(zc)
    

In [48]:
generator = Generator(n_cls =2, isize=512, nz=512, nc=4, ngf=64, ngpu=0)

In [50]:
rand_z = 2 * torch.rand(wavlet_img.size(0), 512, 2, 2) - 1
x = generator(rand_z, cls)



torch.Size([2, 512, 2, 2])
torch.Size([2, 4, 512, 512])


In [26]:
x.size()

torch.Size([2, 4, 512, 512])

In [64]:

class Discriminator(nn.Module):
    def __init__(self,n_cls=2, isize=512, nz=100, nc=4, ngf=64, ngpu=0):
        super(Discriminator, self).__init__()
        
        self.pre_x = Encoder(isize=512, nz=100, nc=4, ndf=64, ngpu=0)
        # 2x2x100
        
        self.pre_z = nn.Sequential(
            nn.Linear(2048, 3*nz),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # flatten(2x2x75)
        
        
        self.pre_c = nn.Sequential(
            nn.Linear(n_cls, 1*nz),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # flatten(2x2x25)
        
        self.layers = nn.Sequential(
            nn.Linear(2448, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            
        
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            

            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        

        
    def forward(self, x, z, c):

        batch_size = x.size(0)
        
        
        
        
        
        x = self.pre_x(x)
        z = z.view(batch_size, -1)

        z = self.pre_z(z)
        
        
        
        z = z.view([batch_size, -1, 2, 2])
        c = self.pre_c(c)
        c = c.view([batch_size, -1, 2, 2])
        
        zc = torch.cat([z,c], dim=1)
        xzc = torch.cat([x,zc], dim=1)
        
        # 2x2x200
        xzc = xzc.view(batch_size, -1)
        
        result = self.layers(xzc)
        
    
        return result

        

In [65]:
dis = Discriminator()

In [66]:
# print(f'x : {x.size()}, z : {z.size()}, cls : {cls.size()}')

# dis(x, z, cls)

x : torch.Size([2, 4, 512, 512]), z : torch.Size([2, 512, 2, 2]), cls : torch.Size([2, 2])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x800 and 2448x1024)

In [67]:
def D_loss(DG, DE, eps=1e-6):
    loss = torch.log(DE + eps) + torch.log(1 - DG + eps)
    return -torch.mean(loss)


def EG_loss(DG, DE, eps=1e-6):
    loss = torch.log(DG + eps) + torch.log(1 - DE + eps)
    return -torch.mean(loss)

def initialize_weights(model):
    classname = model.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data,0)
    elif classname.find('Linear') != -1:
        nn.init.constant_(model.bias.data, 0)

In [68]:
n_epochs = 500
l_rate = 2e-5

E = Encoder(isize=512, nz=512, nc=4, ndf=64, ngpu=0)
G = Generator(n_cls =2, isize=512, nz=512, nc=4, ngf=64, ngpu=0)
D = Discriminator(n_cls=2, isize=512, nz=512, nc=4, ngf=64, ngpu=0)

E.apply(initialize_weights)
G.apply(initialize_weights)
D.apply(initialize_weights)

optimizer_EG = torch.optim.Adam(list(E.parameters()) + list(G.parameters()), 
                                lr=l_rate, betas=(0.5, 0.999), weight_decay=1e-5)
optimizer_D = torch.optim.Adam(D.parameters(), lr=l_rate, betas=(0.5, 0.999), weight_decay=1e-5)


data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

In [69]:
from tqdm import tqdm

device = 'cpu'
for epoch in range(n_epochs):
    D_loss_acc = 0.
    EG_loss_acc = 0.
    D.train()
    E.train()
    G.train()
        
#     scheduler_D.step()
#     scheduler_EG.step()
    
    for i, (signal, images, corr_img, c) in tqdm(enumerate(data_loader)):
        
        images = images.to(device)
        
        rand_z = 2 * torch.rand(wavlet_img.size(0), 512, 2, 2) - 1
        rand_z = rand_z.to(device)
        
        #compute G(z, c) and E(X)
        
        Gz = G(rand_z, c)
        EX = E(images)
        
        print(f'image : {images.size()}, Gz : {Gz.size()}')
        print(f'rand_z : {rand_z.size()}, Ex : {EX.size()}')
        DE = D(images, EX, c)
        DG = D(Gz, rand_z, c)
        
        
        loss_EG = EG_loss(DG, DE)
        #Encoder & Generator training
        optimizer_EG.zero_grad()
        loss_EG.backward()
        optimizer_EG.step()
        #compute D(G(z, c), z, c) and D(X, E(X), c)
        #compute losses
        loss_D = D_loss(DG, DE)
        
        D_loss_acc += loss_D.item()
        EG_loss_acc += loss_EG.item()
        
        #Discriminator training
        optimizer_D.zero_grad()
        loss_D.backward(retain_graph=True)
        optimizer_D.step()
        
        
        if (epoch + 1) % 10 == 0:
            print('Epoch [{}/{}], Avg_Loss_D: {:.4f}, Avg_Loss_EG: {:.4f}'
                    .format(epoch + 1, n_epochs, D_loss_acc / i, EG_loss_acc / i))
            n_show = 10

0it [00:00, ?it/s]

image : torch.Size([2, 4, 512, 512]), Gz : torch.Size([2, 4, 512, 512])
rand_z : torch.Size([2, 512, 2, 2]), Ex : torch.Size([2, 512, 2, 2])


0it [00:14, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.