In [9]:
from generator import Generator
from discriminator import Discriminator
from BaseColor import *

# --- Inicijalizacija ---
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import cv2
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.nn import L1Loss, BCELoss

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Seed torch's RNG before building the networks so weight initialisation is
# reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Adam hyperparameters, shared by the generator and discriminator optimizers
# (previously duplicated as literals in both calls).
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)

gen = Generator().to(device)
disc = Discriminator().to(device)

gen_opt = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
disc_opt = torch.optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [14]:

# Adversarial criterion. nn.BCELoss expects probabilities in [0, 1], so this
# assumes the Discriminator's output is already sigmoid-activated — TODO confirm
# (if it returns raw logits, nn.BCEWithLogitsLoss is the numerically safer choice).
bce_loss = nn.BCELoss()
# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
    """Generator objective: adversarial BCE plus LAMBDA-weighted L1 reconstruction.

    Args:
        disc_generated_output: discriminator output for the generated samples.
        gen_output: generator prediction, compared element-wise to ``target``.
        target: ground-truth tensor for ``gen_output``.

    Returns:
        Tuple ``(total_gen_loss, gan_loss, l1_loss)`` of scalar tensors, where
        ``total_gen_loss = gan_loss + LAMBDA * l1_loss``.
    """
    # Adversarial term: the generator wants its samples to be scored as real (1).
    gan_loss = bce_loss(disc_generated_output, torch.ones_like(disc_generated_output))

    # L1 term (mean absolute error). torch.nn.functional is reached through the
    # already-imported `nn`; no separate `F` alias exists in this notebook.
    l1_loss = nn.functional.l1_loss(gen_output, target)

    total_gen_loss = gan_loss + LAMBDA * l1_loss
    return total_gen_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator objective: BCE pushing real scores to 1 and fake scores to 0.

    Args:
        disc_real_output: discriminator output for real (ground-truth) samples.
        disc_generated_output: discriminator output for generated samples.

    Returns:
        Scalar tensor: sum of the real-sample and fake-sample BCE terms.
    """
    real_loss = bce_loss(disc_real_output, torch.ones_like(disc_real_output))
    fake_loss = bce_loss(disc_generated_output, torch.zeros_like(disc_generated_output))
    total_disc_loss = real_loss + fake_loss
    return total_disc_loss


In [15]:
def train_step(input_l, target_ab, generator, discriminator,
               gen_optimizer, disc_optimizer):
    """Run one conditional-GAN training step: a discriminator update, then a generator update.

    Args:
        input_l: L-channel batch, presumably (N, 1, H, W) — e.g. (1, 1, 256, 256)
            per the original notes; TODO confirm against the dataloader.
        target_ab: ground-truth ab-channel batch, presumably (N, 2, H, W).
        generator: model mapping ``input_l`` to predicted ab channels.
        discriminator: model called as ``discriminator(input_img=..., target_img=...)``.
        gen_optimizer: optimizer over the generator's parameters.
        disc_optimizer: optimizer over the discriminator's parameters.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` as Python floats.

    Uses the notebook-level ``device``, ``generator_loss`` and ``discriminator_loss``.
    """
    generator.train()
    discriminator.train()

    input_l = input_l.to(device)
    target_ab = target_ab.to(device)

    # The discriminator's conditioning image is the L channel tiled to 3 channels.
    # It is identical for all three discriminator calls, so build it once.
    cond_img = input_l.repeat(1, 3, 1, 1)

    real_lab = torch.cat([input_l, target_ab], dim=1)

    # ----- Generator forward -----
    fake_ab = generator(input_l)
    fake_lab = torch.cat([input_l, fake_ab], dim=1)

    # ----- Discriminator update -----
    # fake_lab is detached so this update does not backpropagate into the generator.
    disc_real_output = discriminator(input_img=cond_img, target_img=real_lab)
    disc_fake_output = discriminator(input_img=cond_img, target_img=fake_lab.detach())
    disc_loss = discriminator_loss(disc_real_output, disc_fake_output)

    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # ----- Generator update (re-scored by the freshly updated discriminator) -----
    # Freeze the discriminator so gen_loss.backward() computes only what the generator
    # needs, instead of also accumulating unused gradients into the discriminator's
    # parameters. The original flags are restored even if something raises.
    disc_params = list(discriminator.parameters())
    saved_flags = [p.requires_grad for p in disc_params]
    for p in disc_params:
        p.requires_grad_(False)
    try:
        disc_fake_output = discriminator(input_img=cond_img, target_img=fake_lab)
        gen_loss, _, _ = generator_loss(disc_fake_output, fake_ab, target_ab)

        gen_optimizer.zero_grad()
        gen_loss.backward()
        gen_optimizer.step()
    finally:
        for p, flag in zip(disc_params, saved_flags):
            p.requires_grad_(flag)

    return gen_loss.item(), disc_loss.item()


In [None]:
# Number of full passes over the training data.
num_epochs: int = 50

In [None]:
# Train for `num_epochs` passes, reporting the mean generator/discriminator loss per epoch.
# TODO: `dataloader` is not defined in this notebook yet; it must yield
# (input_l, target_ab) batches.
for epoch in range(num_epochs):
    gen_loss_sum = 0.0
    disc_loss_sum = 0.0
    n_batches = 0
    for input_l, target_ab in dataloader:
        # train_step's signature is (input_l, target_ab, generator, discriminator,
        # gen_optimizer, disc_optimizer). It moves tensors to `device` itself and
        # returns a (gen_loss, disc_loss) tuple of floats.
        gen_loss, disc_loss = train_step(input_l, target_ab, gen, disc, gen_opt, disc_opt)
        gen_loss_sum += gen_loss
        disc_loss_sum += disc_loss
        n_batches += 1
    if n_batches == 0:
        raise RuntimeError("dataloader yielded no batches")
    print(f"Ep {epoch}: Gen={gen_loss_sum / n_batches:.3f}, D={disc_loss_sum / n_batches:.3f}")
