# Trening generator + diskriminator

In [1]:
import os
from PIL import Image
import numpy as np
from skimage import color
import torch
from skimage.color import lab2rgb
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
from generator import Generator
from discriminator import Discriminator
from BaseColor import *

In [2]:

def find_images_recursive(folder, extensions=('.JPEG', '.jpeg', '.jpg', '.JPG')):
    img_paths = []
    for root, dirs, files in os.walk(folder):
        for f in files:
            if f.endswith(extensions):
                img_paths.append(os.path.abspath(os.path.join(root, f)))
    return img_paths


In [3]:
class ColorizationDataset(Dataset, BaseColor):
    def __init__(self, root_dir, HW=(256, 256), extensions=('.JPEG', '.jpeg', '.jpg', '.JPG')):
        Dataset.__init__(self)
        BaseColor.__init__(self)
        self.HW = HW
        self.img_paths = find_images_recursive(root_dir, extensions)
        print(f"Pronađeno {len(self.img_paths)} slika u {root_dir} i podfolderima.")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]

        try:
            # img_rgb_orig = load_img(img_path)

            # tens_l, tens_ab = preprocess_img(img_rgb_orig, HW=self.HW)

            # return tens_l, tens_ab
            img_rgb_orig = load_img(img_path)

            tens_orig_l, tens_rs_l = preprocess_img(img_rgb_orig, HW=self.HW)

            img_rgb_rs = resize_img(img_rgb_orig, HW=self.HW, resample=Image.BILINEAR)
            img_lab_rs = color.rgb2lab(img_rgb_rs).astype(np.float32)

            img_ab_rs = img_lab_rs[:, :, 1:3]
            tens_ab = torch.tensor(img_ab_rs).permute(2,0,1).float()

            tens_l_norm = self.normalize_l(tens_rs_l.squeeze(0))
            tens_ab_norm = self.normalize_ab(tens_ab)

            return tens_l_norm, tens_ab_norm

        except Exception as e:
            print(f"Preskačem fajl zbog greške: {img_path}\nGreška: {e}")
            return self.__getitem__((idx + 1) % len(self))


In [4]:
dataset = ColorizationDataset("images_pfe/8/train.X1/n01440764", HW=(256, 256))
input_l, target_ab = dataset[0]

Pronađeno 1300 slika u images_pfe/8/train.X1/n01440764 i podfolderima.


In [5]:
target_ab.shape, input_l.shape

(torch.Size([2, 256, 256]), torch.Size([1, 256, 256]))

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 4
num_epochs = 10
learning_rate = 1e-3
HW = (256, 256) 


In [7]:

generator = Generator().to(device)
discriminator = Discriminator().to(device)

gen_opt = torch.optim.Adam(generator.parameters(), learning_rate, betas=(0.5, 0.999))
disc_opt = torch.optim.Adam(discriminator.parameters(), learning_rate, betas=(0.5, 0.999))


In [8]:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

# Funkcije

In [9]:
def numpy_to_tensor(np_img):
    if isinstance(np_img, np.ndarray):
        tensor = torch.from_numpy(np_img.astype(np.float32))
        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1)
        elif tensor.ndim == 4:
            tensor = tensor.permute(0, 3, 1, 2)
        return tensor
    else:
        raise TypeError("Ulaz nije NumPy niz")
    
def lab_to_rgb_image(input_l, out_ab):
    base_color = BaseColor()
    l_denorm = base_color.unnormalize_l(input_l)
    ab_denorm = base_color.unnormalize_ab(out_ab)

    lab = torch.cat([l_denorm, ab_denorm], dim=1)
    lab_np = lab[0].permute(1, 2, 0).cpu().numpy()

    rgb = lab2rgb(lab_np)
    return rgb

def safe_lab2rgb(lab):
    lab_clipped = lab.copy()
    lab_clipped[..., 0] = np.clip(lab_clipped[..., 0], 0, 100)     
    lab_clipped[..., 1] = np.clip(lab_clipped[..., 1], -110, 110)  
    lab_clipped[..., 2] = np.clip(lab_clipped[..., 2], -110, 110)  

    rgb = lab2rgb(lab_clipped)
    return rgb

def show_images(real, fake):
    real = real.permute(1, 2, 0).detach().cpu().numpy()
    fake = fake.permute(1, 2, 0).detach().cpu().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(real)
    axs[0].set_title("Originalna slika")
    axs[0].axis('off')

    axs[1].imshow(fake)
    axs[1].set_title("Generisana slika")
    axs[1].axis('off')

    plt.tight_layout()
    plt.show()

# Checkpoint

In [10]:
start_epoch = 0
checkpoint_path = 'checkpoint.pth1'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    gen_opt.load_state_dict(checkpoint['optimizer_g_dict'])
    disc_opt.load_state_dict(checkpoint['optimizer_d_dict'])
    start_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint.get('global_step', 0)
    print(f"Nastavljam od epohe {start_epoch}, global_step = {global_step}")
else:
    global_step = 0

# Generator loss

In [11]:
bce_loss = nn.BCEWithLogitsLoss()
LAMBDA = 100

def delta_e_loss(input_l, ab_pred, ab_true):
    base_color = BaseColor()
    l_denorm = base_color.unnormalize_l(input_l)
    ab_pred_denorm = base_color.unnormalize_ab(ab_pred)
    ab_true_denorm = base_color.unnormalize_ab(ab_true)

    lab_pred = torch.cat([l_denorm, ab_pred_denorm], dim=1)
    lab_true = torch.cat([l_denorm, ab_true_denorm], dim=1)

    delta_e = torch.norm(lab_pred - lab_true, dim=1)  
    return delta_e.mean()

def colorfulness_loss(ab_pred):
    return -ab_pred.abs().mean()

def gan_loss(disc_generated_output):
    gan_loss = bce_loss(torch.ones_like(disc_generated_output), disc_generated_output)
    return gan_loss

def generator_loss(input_l,ab_pred,ab_true,disc_generated_output):
    color_loss = delta_e_loss(input_l, ab_pred, ab_true)
    boost_color = colorfulness_loss(ab_pred)
    gan_loss = gan_loss(disc_generated_output)
    loss = color_loss + 0.1 * boost_color + gan_loss
    return loss


# Diskriminator loss

In [12]:
bce_loss = nn.BCEWithLogitsLoss()
LAMBDA = 100

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = bce_loss(disc_real_output, torch.ones_like(disc_real_output))
    fake_loss = bce_loss(disc_generated_output, torch.zeros_like(disc_generated_output))
    total_disc_loss = real_loss + fake_loss
    return total_disc_loss

In [13]:
import warnings
import matplotlib.pyplot as plt

# warnings.filterwarnings("ignore", message="Conversion from CIE-LAB, via XYZ to sRGB color space resulted in .* negative Z values.*")

epoch_losses = [] 
global_step = 0

for epoch in range(start_epoch, num_epochs):
    G.train()
    epoch_loss = 0.0
    start_time = time.time()

    for batch_idx, (input_l, target_ab) in enumerate(dataloader):
        input_l = input_l.to(device)        
        target_ab = target_ab.to(device) 
        g_optimizer.zero_grad()

        output_ab = G(input_l)         

        color_loss = delta_e_loss(input_l, output_ab, target_ab)
        boost_color = colorfulness_loss(output_ab)
        loss = color_loss + 0.1 * boost_color
        loss.backward()
        g_optimizer.step()

        epoch_loss += loss.item()

        if global_step % 100 == 0:
            with torch.no_grad():
                fake_rgb_np = lab_to_rgb_image(input_l, output_ab)
                real_rgb_np = lab_to_rgb_image(input_l, target_ab)

                fake_rgb = numpy_to_tensor(fake_rgb_np).to(device)  
                real_rgb = numpy_to_tensor(real_rgb_np).to(device)

                fake_rgb_vis = fake_rgb.clone().clamp(0, 1)
                real_rgb_vis = real_rgb.clone().clamp(0, 1)

            N = min(4, fake_rgb_vis.size(0))
            if fake_rgb_vis.ndim == 3:  
                fake_rgb_vis = fake_rgb_vis.unsqueeze(0)  
            if real_rgb_vis.ndim == 3:
                real_rgb_vis = real_rgb_vis.unsqueeze(0)
    
            fake_rgb_vis_uint8 = (fake_rgb_vis * 255).clamp(0, 255).to(torch.uint8)
            real_rgb_vis_uint8 = (real_rgb_vis * 255).clamp(0, 255).to(torch.uint8)

            if global_step % 1000 == 0:
                show_images(real_rgb_vis_uint8[0], fake_rgb_vis_uint8[0])

            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

        global_step += 1

    avg_epoch_loss = epoch_loss / len(dataloader)
    epoch_losses.append(avg_epoch_loss) 
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1} završena. Prosečni loss: {avg_epoch_loss:.4f}. Vreme epohe: {elapsed:.1f} s")
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': G.state_dict(),
        'optimizer_state_dict':g_optimizer.state_dict(),
        'global_step': global_step
    }, checkpoint_path)

print("Trening završen.")

plt.figure(figsize=(8, 5))
plt.plot(range(start_epoch + 1, start_epoch + 1 + len(epoch_losses)), epoch_losses, marker='o')
plt.title('Loss tokom epoha')
plt.xlabel('Epoka')
plt.ylabel('Prosečni loss')
plt.grid(True)
plt.show()



NameError: name 'G' is not defined

In [14]:
input_l.shape,target_ab.shape

(torch.Size([1, 256, 256]), torch.Size([2, 256, 256]))

In [15]:
input_l.max(),input_l.min(),target_ab.max(),target_ab.min()

(tensor(0.5000), tensor(-0.4994), tensor(0.5773), tensor(-0.2923))

In [16]:
input_l = input_l.unsqueeze(0)    # (1, 1, 256, 256)
target_ab = target_ab.unsqueeze(0)                # (1, 2, 256, 256)
input_l.shape, target_ab.shape

(torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 256, 256]))

In [17]:

input_l = input_l.to(device)        
target_ab = target_ab.to(device)
fake_ab = generator(input_l)

# Potpun trening

In [18]:
fake_ab.shape

torch.Size([1, 2, 256, 256])

In [19]:
fake_ab.max(),fake_ab.min()

(tensor(0.1463, grad_fn=<MaxBackward1>),
 tensor(-0.1693, grad_fn=<MinBackward1>))

In [20]:
fake_lab = torch.cat([input_l, fake_ab], dim=1) 
real_lab = torch.cat([input_l, target_ab], dim=1)
input_rgb = input_l.repeat(1, 3, 1, 1)  

In [21]:
fake_lab.shape,real_lab.shape,input_rgb.shape

(torch.Size([1, 3, 256, 256]),
 torch.Size([1, 3, 256, 256]),
 torch.Size([1, 3, 256, 256]))

In [22]:

        
disc_real_out = discriminator(input_rgb, real_lab)
disc_fake_out = discriminator(input_rgb, fake_lab.detach())


In [23]:
disc_fake_out.shape, disc_real_out.min()

(torch.Size([1, 30, 30]), tensor(-0.9526, grad_fn=<MinBackward1>))

In [24]:
loss = discriminator_loss(disc_real_out, disc_fake_out)
loss

tensor(1.4242, grad_fn=<AddBackward0>)

In [None]:
import warnings
import matplotlib.pyplot as plt

# warnings.filterwarnings("ignore", message="Conversion from CIE-LAB, via XYZ to sRGB color space resulted in .* negative Z values.*")

epoch_losses = [] 
global_step = 0

for epoch in range(start_epoch, num_epochs):
    generator.train()
    discriminator.train()
    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0
    start_time = time.time()
    print("pocetak treninga")
    for batch_idx, (input_l, target_ab) in enumerate(dataloader):
        print("prva epoha")
        input_l = input_l.to(device)        
        target_ab = target_ab.to(device)


        fake_ab = generator(input_l)
        input_l = input_l.unsqueeze(0)    # (1, 1, 256, 256)
        target_ab = target_ab.unsqueeze(0) # 1 2 256 256
        fake_lab = torch.cat([input_l, fake_ab], dim=1) 
        real_lab = torch.cat([input_l, target_ab], dim=1)
        input_rgb = input_l.repeat(1, 3, 1, 1)  
        
        # trening diskriminatora
        
        disc_real_out = discriminator(input_rgb, real_lab)
        disc_fake_out = discriminator(input_rgb, fake_lab.detach())

        d_loss = discriminator_loss(disc_real_out, disc_fake_out)

        disc_opt.zero_grad()
        d_loss.backward()
        disc_opt.step()

    #   trening generatora
        disc_fake_out = discriminator(input_rgb, fake_lab)
        g_loss = generator_loss(input_l,target_ab,fake_ab,disc_fake_out)

        gen_opt.zero_grad()
        g_loss.backward()
        gen_opt.step()

        epoch_gen_loss += g_loss.item()
        epoch_disc_loss += d_loss.item()


        if global_step % 100 == 0:
            with torch.no_grad():
                fake_rgb_np = lab_to_rgb_image(input_l, fake_ab)
                real_rgb_np = lab_to_rgb_image(input_l, target_ab)

                fake_rgb = numpy_to_tensor(fake_rgb_np).to(device)  
                real_rgb = numpy_to_tensor(real_rgb_np).to(device)

                fake_rgb_vis = fake_rgb.clone().clamp(0, 1)
                real_rgb_vis = real_rgb.clone().clamp(0, 1)

            if fake_rgb_vis.ndim == 3:  
                fake_rgb_vis = fake_rgb_vis.unsqueeze(0)  
            if real_rgb_vis.ndim == 3:
                real_rgb_vis = real_rgb_vis.unsqueeze(0)

            fake_rgb_vis_uint8 = (fake_rgb_vis * 255).clamp(0, 255).to(torch.uint8)
            real_rgb_vis_uint8 = (real_rgb_vis * 255).clamp(0, 255).to(torch.uint8)

            if global_step % 1000 == 0:
                show_images(real_rgb_vis_uint8[0], fake_rgb_vis_uint8[0])

            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], "
                  f"Gen Loss: {g_loss.item():.4f}, Disc Loss: {d_loss.item():.4f}, "
                  f"GAN: {g_loss.item():.4f}")

        global_step += 1

    avg_gen_loss = epoch_gen_loss / len(dataloader)
    avg_disc_loss = epoch_disc_loss / len(dataloader)
    epoch_losses.append(avg_gen_loss) 

    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1} završena. G-Loss: {avg_gen_loss:.4f}, D-Loss: {avg_disc_loss:.4f}. Vreme: {elapsed:.1f} s")

    # ============ 5. ČUVANJE MODEL STATE ============
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_g': gen_opt.state_dict(),
        'optimizer_d': disc_opt.state_dict(),
        'global_step': global_step
    }, checkpoint_path)

print("Trening završen.")


plt.figure(figsize=(8, 5))
plt.plot(range(start_epoch + 1, start_epoch + 1 + len(epoch_losses)), epoch_losses, marker='o')
plt.title('Generator Loss tokom epoha')
plt.xlabel('Epoha')
plt.ylabel('Prosečni generator loss')
plt.grid(True)
plt.show()


pocetak treninga


