In [None]:
import itertools

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import models

from constants import *


class CycleGAN(nn.Module):
    '''CycleGAN trainer: least-squares adversarial loss plus L1 cycle-consistency loss.

    Naming convention used throughout:
        model_G_A: generator B -> A      model_D_A: discriminator for domain A
        model_G_B: generator A -> B      model_D_B: discriminator for domain B

    The four networks may be passed to the constructor or assigned as attributes
    afterwards. The Adam optimizers are created as soon as all four exist: either
    in ``__init__`` or lazily on the first call to ``update``.

    Args:
        config_data: Project configuration; stored on ``self.config_data`` but not
            read by this class.
        model_G_A, model_G_B, model_D_A, model_D_B: Optional networks (see above).
        lr: Adam learning rate for both optimizers.
        beta: Adam ``beta1`` (``beta2`` is fixed at 0.999).
        lambda_A: Weight of the A -> B -> A cycle loss.
        lambda_B: Weight of the B -> A -> B cycle loss.
    '''

    def __init__(self, config_data, *, model_G_A=None, model_G_B=None,
                 model_D_A=None, model_D_B=None, lr=2e-4, beta=0.5,
                 lambda_A=10.0, lambda_B=10.0):
        super().__init__()
        self.config_data = config_data
        # Defaults follow the reference CycleGAN training options.
        self.lr = lr
        self.beta = beta
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B
        # Loss modules hold no parameters/buffers, so no device move is needed.
        self.criterionGAN = nn.MSELoss()
        self.criterionCycle = nn.L1Loss()
        # Buffers follow ``self.to(...)`` and therefore also define ``self.device``.
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        # Assigning an nn.Module registers it as a submodule; None stays a plain attribute.
        self.model_G_A = model_G_A
        self.model_G_B = model_G_B
        self.model_D_A = model_D_A
        self.model_D_B = model_D_B
        self.optimizerG = None
        self.optimizerD = None
        if None not in (model_G_A, model_G_B, model_D_A, model_D_B):
            self.build_optimizers()

    @property
    def device(self):
        '''Device of the label buffers; changes when the module is moved with ``.to``.'''
        return self.real_label.device

    def build_optimizers(self):
        '''Create the generator and discriminator Adam optimizers.

        Raises:
            ValueError: if any of the four networks has not been set.
        '''
        nets = {
            'model_G_A': self.model_G_A,
            'model_G_B': self.model_G_B,
            'model_D_A': self.model_D_A,
            'model_D_B': self.model_D_B,
        }
        missing = [name for name, net in nets.items() if net is None]
        if missing:
            raise ValueError(f"CycleGAN networks not set: {', '.join(missing)}")
        betas = (self.beta, 0.999)
        self.optimizerG = torch.optim.Adam(
            itertools.chain(self.model_G_A.parameters(), self.model_G_B.parameters()),
            lr=self.lr, betas=betas)
        self.optimizerD = torch.optim.Adam(
            itertools.chain(self.model_D_A.parameters(), self.model_D_B.parameters()),
            lr=self.lr, betas=betas)

    def forward(self, input_image):
        '''G_A is generating A from B and G_B is generating B from A.

        Args:
            input_image: mapping with keys 'A' and 'B' holding a batch from each domain.

        Stores real_A/real_B, fake_A/fake_B and recreate_A/recreate_B on ``self``.
        '''
        self.real_A = input_image['A'].to(self.device)
        self.real_B = input_image['B'].to(self.device)

        # G_A(B)
        self.fake_A = self.model_G_A(self.real_B)

        # G_B(A)
        self.fake_B = self.model_G_B(self.real_A)

        # G_A(G_B(A))
        self.recreate_A = self.model_G_A(self.fake_B)

        # G_B(G_A(B))
        self.recreate_B = self.model_G_B(self.fake_A)

    def basic_D_backward(self, model_D, real, fake):
        '''Backpropagate the least-squares discriminator loss for one discriminator.

        Real samples are pushed toward ``real_label`` and fake samples toward
        ``fake_label``. ``fake`` is detached so no gradient reaches the generator.
        Returns the averaged loss tensor.
        '''
        pred_real = model_D(real)
        loss_D_real = self.criterionGAN(pred_real, self.real_label.expand_as(pred_real))
        pred_fake = model_D(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, self.fake_label.expand_as(pred_fake))
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def D_A_backward(self, real=None, fake=None):
        '''Backpropagate D_A's loss (D_A discriminates domain A).

        ``real``/``fake`` default to ``self.real_A``/``self.fake_A`` from ``forward``.
        '''
        real = self.real_A if real is None else real
        fake = self.fake_A if fake is None else fake
        self.loss_D_A = self.basic_D_backward(self.model_D_A, real, fake)

    def D_B_backward(self, real=None, fake=None):
        '''Backpropagate D_B's loss (D_B discriminates domain B).

        ``real``/``fake`` default to ``self.real_B``/``self.fake_B`` from ``forward``.
        '''
        real = self.real_B if real is None else real
        fake = self.fake_B if fake is None else fake
        self.loss_D_B = self.basic_D_backward(self.model_D_B, real, fake)

    def backward_G(self):
        '''Backpropagate the combined generator loss.

        The loss is the adversarial loss of both generators (each fake should be
        scored as real) plus the lambda-weighted L1 cycle-consistency losses.
        '''
        check_G_A = self.model_D_A(self.fake_A)
        self.loss_G_A = self.criterionGAN(check_G_A, self.real_label.expand_as(check_G_A))
        check_G_B = self.model_D_B(self.fake_B)
        self.loss_G_B = self.criterionGAN(check_G_B, self.real_label.expand_as(check_G_B))
        self.loss_cycle_A = self.criterionCycle(self.recreate_A, self.real_A) * self.lambda_A
        self.loss_cycle_B = self.criterionCycle(self.recreate_B, self.real_B) * self.lambda_B
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
        self.loss_G.backward()

    def update(self, input_image):
        '''Run one training step: a generator update, then a discriminator update.

        Args:
            input_image: mapping with keys 'A' and 'B' (see ``forward``).
        '''
        if self.optimizerG is None or self.optimizerD is None:
            self.build_optimizers()
        self.forward(input_image)
        discriminators = [self.model_D_A, self.model_D_B]
        # Freeze D so backward_G does not compute gradients for D's parameters.
        self.set_model_grad(discriminators, False)
        self.optimizerG.zero_grad()
        self.backward_G()
        self.optimizerG.step()
        self.set_model_grad(discriminators, True)
        self.optimizerD.zero_grad()
        self.D_A_backward()
        self.D_B_backward()
        self.optimizerD.step()

    def set_model_grad(self, nets, requires):
        '''Set ``requires_grad`` on every parameter of each non-None net in ``nets``.'''
        for net in nets:
            if net is not None:
                net.requires_grad_(requires)
        
        
        
        
def get_model(config_data):
    '''Build the model described by ``config_data``.

    Placeholder: not implemented yet; ``config_data`` is ignored and the
    function always returns ``None``.
    '''
    # TODO: construct and return the configured model.
    return None