In [3]:
from dataset import dataset
from models import *
from utils import save_img

import numpy as np
import os
from glob import glob
import matplotlib.pyplot as plt

import torch
import torchvision

In [4]:
BATCH_SIZE = 1
WH = (256,256)
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
BETAS = (0.5, 0.999)

In [5]:
train_dataset = dataset(root_dir="/jupyterdata/horse2zebra/trainA/", w=WH[0], h = WH[1])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH_SIZE, num_workers = 4, pin_memory=True, drop_last = True)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = Feature_Extraction().to(device)
generator = Generator().to(device)
discriminator = Discriminator(WH, [64,128,256,512]).to(device)

In [8]:
criterion_MSE = torch.nn.MSELoss()
criterion_L1 = torch.nn.L1Loss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr = LEARNING_RATE, betas=BETAS)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = LEARNING_RATE, betas=BETAS)

In [9]:
real_label = torch.autograd.Variable(torch.ones(BATCH_SIZE, *discriminator.output_size), requires_grad=False).to(device)
gene_label = torch.autograd.Variable(torch.zeros(BATCH_SIZE, *discriminator.output_size), requires_grad=False).to(device)

for epoch in range(NUM_EPOCHS):
    for batch, (real_lr_img, real_hr_img) in enumerate(train_loader):
        real_lr_img = real_lr_img.to(device)
        real_hr_img = real_hr_img.to(device)
        
        # training discrimiantor
        optimizer_D.zero_grad()
        gene_hr_img = generator(real_lr_img)
        real_hr_logit = discriminator(real_hr_img)
        gene_hr_logit = discriminator(gene_hr_img.detach())
         
        D_real_loss = criterion_MSE(real_hr_logit, real_label)
        D_gene_loss = criterion_MSE(gene_hr_logit, gene_label)
        
        D_loss = (D_real_loss + D_gene_loss) / 2
        D_loss.backward()
        optimizer_D.step()
        
        # training generator
        optimizer_G.zero_grad()
        adv_loss = criterion_MSE(discriminator(gene_hr_img), real_label)
        content_loss = criterion_L1(feature_extractor(gene_hr_img), feature_extractor(real_hr_img))
        loss_G = content_loss + 1e-3 * adv_loss
        loss_G.backward()
        optimizer_G.step()
        
        if batch % 50 == 0:
            print("D loss : {}, G loss : {}".format(D_loss, G_loss))
        
    if epoch % 5 == 0:
        save_img(gene_hr_img, "{}".format(epoch))        

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (8) must match the size of tensor b (42) at non-singleton dimension 3