# Projekt_3

## Moduły

In [None]:
import os
import json
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import sys
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.models as models
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np

## Upload datasetu z dysku

In [None]:
# Mount Google Drive so the project folder under /content/drive is reachable.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Root folder of the project on the mounted Google Drive; the dataset,
# checkpoints, evaluation samples and metric reports are all placed under it.
ROOT_PROJECT_FOLDER = '/content/drive/MyDrive/projekt_3'

## Dataset

In [None]:
class PhongDataset(Dataset):
    """Rendered Phong-shading images paired with their normalised scene parameters.

    ``root_dir`` holds one ``.json`` description per sample plus the image each
    description names in its ``"file_name"`` field. A sample is
    ``(input_tensor, image_tensor)``:

    * ``input_tensor`` - float32 vector built by :meth:`normalize_data` from the
      relative light vector, relative view vector, diffuse colour and shininess.
    * ``image_tensor`` - the RGB image as a ``(3, H, W)`` float tensor scaled to
      [-1, 1] (``ToTensor`` gives [0, 1], ``Normalize(0.5, 0.5)`` maps it on).
    """

    # Theoretical range of every relative-vector component:
    # [-MAX_VECTOR_RANGE, MAX_VECTOR_RANGE].
    MAX_VECTOR_RANGE = 30.0

    # Expected range of the material shininess value.
    SHININESS_MIN = 3.0
    SHININESS_MAX = 20.0

    def __init__(self, root_dir):
        self.root_dir = root_dir

        # Sorted so the index -> sample mapping (and therefore any seeded
        # random_split over this dataset) is reproducible; os.listdir order is arbitrary.
        self.files = sorted(f for f in os.listdir(root_dir) if f.endswith('.json'))

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.files)

    @staticmethod
    def normalize_data(rel_light, rel_view, diffuse, shininess):
        """Scale raw scene parameters to a flat list of values in roughly [-1, 1].

        * relative light / view vector components: [-30, 30] -> [-1, 1]
        * diffuse colour components: [0, 1] -> [-1, 1]
        * shininess: [SHININESS_MIN, SHININESS_MAX] -> [-1, 1]

        Values outside these ranges are scaled linearly, not clipped.

        Returns:
            list of floats: scaled ``rel_light``, then ``rel_view``, then
            ``diffuse``, then the scaled shininess as the last element.
        """
        scale = PhongDataset.MAX_VECTOR_RANGE
        norm_vector = [x / scale for x in rel_light]
        norm_vector += [x / scale for x in rel_view]
        norm_vector += [(x * 2.0) - 1.0 for x in diffuse]

        s_min, s_max = PhongDataset.SHININESS_MIN, PhongDataset.SHININESS_MAX
        s_01 = (shininess - s_min) / (s_max - s_min)
        norm_vector.append((s_01 * 2.0) - 1.0)

        return norm_vector

    def __getitem__(self, idx):
        json_path = os.path.join(self.root_dir, self.files[idx])

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        input_list = PhongDataset.normalize_data(
            rel_light=data["relative_light_vector"],
            rel_view=data["relative_view_vector"],
            diffuse=data["material_diffuse"],
            shininess=data["material_shininess"]
        )
        input_tensor = torch.tensor(input_list, dtype=torch.float32)

        img_path = os.path.join(self.root_dir, data["file_name"])
        # Image.open is lazy and keeps the file open; using it as a context
        # manager closes the handle once convert() has produced a fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image_tensor = self.transform(image)

        return input_tensor, image_tensor

## Model

In [None]:
import torch
import torch.nn as nn


PARAM_DIM = 10          # Size of the input parameter vector
IMG_CHANNELS = 3        # RGB
IMG_SIZE = 128          # 128x128 (NOTE(review): not referenced by the models below, whose layer stack fixes the size)

class Generator(nn.Module):
    """Conditional generator: parameter vector -> RGB image.

    The ``labels`` vector (``PARAM_DIM`` values) is expanded by an MLP into a
    512 x 4 x 4 feature map. Five stride-2 transposed convolutions then upsample
    it to ``IMG_CHANNELS`` x 128 x 128. The final ``Tanh`` keeps pixels in
    [-1, 1].

    NOTE: layer order and attribute names define the ``state_dict`` keys in
    saved checkpoints; changing them breaks checkpoint loading.
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.input_dim = PARAM_DIM

        # MLP: PARAM_DIM -> 512 -> 512 -> 512*4*4 (reshaped to 512x4x4 in forward()).
        self.initial_layer = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),

            nn.Linear(512, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.ReLU(True)
        )

        # Each ConvTranspose2d(kernel 4, stride 2, padding 1) doubles the spatial size.
        self.model = nn.Sequential(

            # 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 32x32 -> 64x64
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # 64x64 -> 128x128 (Output)
            nn.ConvTranspose2d(32, IMG_CHANNELS, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh() # Output in [-1, 1]
        )

    def forward(self, noise, labels):
        """Generate a batch of images from ``labels``.

        ``noise`` is unused; every caller in this notebook passes ``None``.
        Returns a tensor of shape (batch, IMG_CHANNELS, 128, 128).
        """

        x = self.initial_layer(labels)

        x = x.view(-1, 512, 4, 4)
        img = self.model(x)
        return img

class Discriminator(nn.Module):
    """Conditional discriminator: (image, parameter vector) -> probability the image is real.

    An ``IMG_CHANNELS`` x 128 x 128 image is reduced by four stride-2
    convolutions to a 128 x 8 x 8 feature map and flattened. The ``PARAM_DIM``
    vector is embedded to 128 features. Both are concatenated and classified
    by an MLP that ends in ``Sigmoid``, so the output pairs with
    ``nn.BCELoss`` rather than ``BCEWithLogitsLoss``.

    NOTE: layer order and attribute names define the ``state_dict`` keys in
    saved checkpoints.
    """
    def __init__(self):
        super(Discriminator, self).__init__()

        # Parameter vector -> 128-d embedding.
        self.label_embedding = nn.Sequential(
            nn.Linear(PARAM_DIM, 128),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.image_processing = nn.Sequential(
            # Input: 3 x 128 x 128

            nn.Conv2d(IMG_CHANNELS, 16, kernel_size=4, stride=2, padding=1, bias=False), # 128 -> 64, starts at 16 channels
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1, bias=False), # 64 -> 32, 32 channels
            nn.InstanceNorm2d(32, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False), # 32 -> 16, 64 channels
            nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # 16 -> 8, 128 channels
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Image features: 128 channels * 8 * 8 pixels = 8192
        self.flatten_size = 128 * 8 * 8

        # Classifier over [image features | label embedding].
        self.classifier = nn.Sequential(
            nn.Linear(self.flatten_size + 128, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(0.5),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Return a (batch, 1) tensor of real-probabilities for ``img`` given ``labels``."""
        features = self.image_processing(img)
        features_flat = features.view(features.size(0), -1)

        label_emb = self.label_embedding(labels)
        concat_input = torch.cat((features_flat, label_emb), dim=1)

        validity = self.classifier(concat_input)
        return validity

def weights_init(m):
    """Initialise a single module DCGAN-style; meant for ``Module.apply``.

    * Class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...): weights ~ N(0, 0.02).
    * Class name contains ``BatchNorm``: weights ~ N(1, 0.02), bias = 0.
    * Anything else (Linear, InstanceNorm, ...) keeps PyTorch's default init.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' not in layer_kind:
        return

    nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    nn.init.constant_(m.bias.data, 0)

Generator output: torch.Size([5, 3, 128, 128])
Discriminator output: torch.Size([5, 1])


## Parametry treningu

In [None]:
LEARNING_RATE_GEN = 0.0001
LEARNING_RATE_DISC = 0.000005
BATCH_SIZE = 32
NUM_EPOCHS = 289      # epochs run by the training loop, numbered from START_EPOCH + 1
L1_LAMBDA = 20.0      # multiplier of the foreground-weighted L1 term in the generator loss

# When LOAD_MODEL is set, {gen,disc}_epoch_{START_EPOCH}.pth are loaded from checkpoint_dir.
# NOTE(review): with START_EPOCH = 0 this expects *_epoch_0.pth files, which the
# training loop here never writes (it starts at START_EPOCH + 1) - confirm they exist.
LOAD_MODEL = True
START_EPOCH =0

# train_fn trains only D while epoch < START_EPOCH + DISC_WARMUP_EPOCHS.
DISC_WARMUP_EPOCHS = 0


dataset_dir = ROOT_PROJECT_FOLDER + "/dataset"
checkpoint_dir = ROOT_PROJECT_FOLDER + "/model_2_checkpoints_v1"
evaluation_dir = ROOT_PROJECT_FOLDER + "/model_2_evaluation_samples_v1" #v3 - dynamic

## Trening

In [None]:
def get_loaders(root_dir, batch_size, test_len=600, val_len=200, seed=42):
    """Build train/val/test DataLoaders over the ``PhongDataset`` in ``root_dir``.

    The split is deterministic: ``random_split`` is driven by a dedicated
    ``torch.Generator`` seeded with ``seed``, so the same samples land in the
    same split on every run. With the defaults and 3000 samples this gives
    Train=2200, Val=200, Test=600.

    Args:
        root_dir: folder passed to ``PhongDataset``.
        batch_size: batch size for all three loaders.
        test_len: number of test samples (default 600).
        val_len: number of validation samples (default 200).
        seed: seed for the split generator (default 42).

    Returns:
        (train_loader, val_loader, test_loader). Only the training loader
        shuffles. ``num_workers`` is left at its default of 0, so loading
        happens in the main process.

    Raises:
        ValueError: if the dataset is too small to leave at least one training sample.
    """
    dataset = PhongDataset(root_dir)

    total_len = len(dataset)
    train_len = total_len - test_len - val_len
    if train_len < 1:
        raise ValueError(
            f"Dataset in {root_dir!r} has {total_len} samples; more than "
            f"{test_len + val_len} are needed to leave room for a training split."
        )

    generator = torch.Generator().manual_seed(seed)

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=generator
    )

    print(f"Podział danych: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Global definitions (needed for VGG input normalisation)
VGG_WEIGHT = 0.6       # NOTE(review): not referenced in the visible code
WARMUP_EPOCHS = 60     # train_fn trains G on L1 only while epoch < WARMUP_EPOCHS
# ImageNet per-channel mean / std, shaped (1, 3, 1, 1) to broadcast over (N, 3, H, W) batches.
VGG_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
VGG_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)

def setup_vgg_model(device):
    """Return frozen slices of ImageNet-pretrained VGG16 for a perceptual loss.

    The VGG16 ``features`` stack is cut at indices 4, 9 and 16 into three
    consecutive sub-networks: ``features[0:4]``, ``features[4:9]`` and
    ``features[9:16]``. In torchvision's layout these end after relu1_2,
    relu2_2 and relu3_3. Each slice is put in eval mode with gradients
    disabled and the slices are returned as an ``nn.ModuleList`` on ``device``.
    """
    backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device)

    cut_points = [4, 9, 16]
    starts = [0] + cut_points[:-1]

    slices = []
    for start, stop in zip(starts, cut_points):
        segment = nn.Sequential(*backbone[start:stop]).eval()
        segment.requires_grad_(False)  # frozen feature extractor
        slices.append(segment)

    return nn.ModuleList(slices).to(device)

def calculate_vgg_loss(vgg_model, fake_img, real_img, from_tanh_range=False):
    """Perceptual (VGG-feature) L1 loss between two image batches.

    Both batches are resized to 224x224 and standardised with the ImageNet
    statistics ``VGG_MEAN`` / ``VGG_STD``. They are then passed through each
    block of ``vgg_model`` in turn, and the L1 distance between the fake and
    real activations is summed over blocks.

    Args:
        vgg_model: iterable of feature blocks applied sequentially (e.g. the
            ModuleList returned by ``setup_vgg_model``).
        fake_img: generated batch of shape (N, 3, H, W).
        real_img: target batch of shape (N, 3, H, W).
        from_tanh_range: the ImageNet statistics assume pixels in [0, 1], but
            this notebook's dataset (``Normalize(0.5, 0.5)``) and generator
            (``Tanh``) produce [-1, 1]. Pass ``True`` to map such inputs to
            [0, 1] first. Defaults to ``False``, which keeps the old behaviour.

    Returns:
        Scalar tensor: the sum of per-block mean absolute differences.
    """
    if from_tanh_range:
        fake_img = (fake_img + 1.0) * 0.5
        real_img = (real_img + 1.0) * 0.5

    fake = nn.functional.interpolate(fake_img, size=(224, 224), mode='bilinear', align_corners=False)
    real = nn.functional.interpolate(real_img, size=(224, 224), mode='bilinear', align_corners=False)

    # Follow the images' device so the loss also works off the global DEVICE
    # (.to() is a no-op when the device already matches).
    mean = VGG_MEAN.to(fake.device)
    std = VGG_STD.to(fake.device)
    fake = (fake - mean) / std
    real = (real - mean) / std

    loss = 0.0
    for block in vgg_model:
        fake = block(fake)
        real = block(real)

        loss += torch.nn.functional.l1_loss(fake, real)

    return loss

# Make sure the checkpoint and evaluation-sample folders exist before training.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(evaluation_dir, exist_ok=True)



def train_cooldown_fn(gen, loader, opt_gen, l1_loss, epoch):
    """Run one generator-only epoch on the foreground-weighted L1 loss.

    No discriminator is involved. Each batch is moved to ``DEVICE``, passed
    through ``gen`` and optimised with ``calculate_masked_loss`` at
    ``lambda_val=1.0``. Per the original note, the value does not matter
    because this is the only loss term. ``l1_loss`` and ``epoch`` are unused.
    """
    progress = tqdm(loader, leave=True)

    for inputs, real_images in progress:
        inputs = inputs.to(DEVICE)
        real_images = real_images.to(DEVICE)

        generated = gen(None, inputs)
        g_loss, raw_l1, _ = calculate_masked_loss(generated, real_images, lambda_val=1.0)

        gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        progress.set_postfix(Mode="COOLDOWN", L1=f"{raw_l1.item():.5f}")

def calculate_masked_loss(fake, real, lambda_val, threshold=-0.98, fg_weight=10.0):
    """Foreground-weighted L1 loss between generated and target images.

    Elements of ``real`` above ``threshold`` count as foreground. With images in
    [-1, 1], -0.98 separates the near-black background from the rendered
    object; the original note calls this foreground "the sphere". The absolute
    error of each element is weighted by ``1 + fg_weight`` on the foreground
    and by 1 elsewhere. The weighted errors are averaged and the mean is
    multiplied by ``lambda_val``.

    Args:
        fake: generated tensor.
        real: target tensor of the same shape as ``fake``.
        lambda_val: multiplier applied to the weighted mean.
        threshold: value above which an element of ``real`` is foreground.
        fg_weight: extra weight added to foreground elements.

    Returns:
        (loss, raw_l1, mask_pct):
        * ``loss`` - the scaled weighted loss (tensor);
        * ``raw_l1`` - the unweighted mean absolute error (tensor);
        * ``mask_pct`` - the fraction of foreground elements (Python float).
    """
    l1_diff = torch.abs(fake - real)

    mask = (real > threshold).float()
    mask_pct = mask.mean().item()  # fraction of the image classed as foreground

    weights = 1.0 + mask * fg_weight

    loss = (l1_diff * weights).mean() * lambda_val

    return loss, l1_diff.mean(), mask_pct

def save_checkpoint(model, optimizer, filename="checkpoint.pth"):
    """Write ``model`` and ``optimizer`` state dicts to ``filename`` with ``torch.save``.

    The file holds a dict with the keys ``"state_dict"`` (model weights) and
    ``"optimizer"`` (optimizer state).
    """
    print(f"=> Zapisywanie checkpointu do {filename}")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr=None, map_location=None):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        checkpoint_file: path to a dict with ``"state_dict"`` and
            ``"optimizer"`` keys.
        model: module whose weights are replaced.
        optimizer: optimizer whose state is replaced.
        lr: if given, overrides the learning rate of every param group after
            loading; the saved optimizer state otherwise restores the LR it
            was saved with. ``None`` keeps the stored LR.
        map_location: forwarded to ``torch.load``; defaults to the global ``DEVICE``.
    """
    print(f"=> Wczytywanie checkpointu {checkpoint_file}")
    if map_location is None:
        map_location = DEVICE
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    if lr is not None:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

def check_accuracy(val_loader, gen, disc, device, epoch, l1_loss_fn, bce_loss_fn, folder=evaluation_dir):
    """Evaluate G and D on ``val_loader``, print a report and return the averages.

    Both models are put in eval mode for the pass. They are switched back to
    train mode afterwards, even if an exception is raised. For every batch:

    * D is scored on real and generated images without label smoothing. The D
      loss weights the real term 1.5 and the fake term 0.5, as in
      ``train_dynamic_fn``.
    * The G loss is BCE(D(fake), 1) plus ``calculate_masked_loss`` scaled by
      ``L1_LAMBDA``.
    * D accuracy counts real scores > 0.5 and fake scores < 0.5.

    The first batch is saved to ``{folder}/epoch_{epoch}.png`` as a grid of up
    to 16 targets, then the matching generated images, then their foreground
    masks (black = background, white = foreground).

    All averages are per sample, so a smaller final batch does not skew them.

    Args:
        l1_loss_fn: unused; kept for call compatibility.

    Returns:
        dict with keys ``epoch``, ``g_loss``, ``d_loss``, ``raw_l1``,
        ``masked_l1``, ``real_score``, ``fake_score``, ``acc_real`` and
        ``acc_fake``; ``None`` if the loader yielded no samples.
    """
    total_raw_l1 = 0.0
    total_masked_l1 = 0.0
    total_g_loss = 0.0
    total_d_loss = 0.0
    total_real_score = 0.0
    total_fake_score = 0.0
    correct_real = 0
    correct_fake = 0
    total_samples = 0

    print(f"\n=== RAPORT VALIDATION (EPOCH {epoch}) ===")

    gen.eval()
    disc.eval()
    try:
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                batch_size = inputs.shape[0]

                fake_images = gen(None, inputs)

                # Discriminator: no label smoothing during evaluation.
                d_real = disc(targets, inputs)
                d_fake = disc(fake_images, inputs)
                loss_d_real = bce_loss_fn(d_real, torch.ones_like(d_real))
                loss_d_fake = bce_loss_fn(d_fake, torch.zeros_like(d_fake))
                d_loss = (loss_d_real * 1.5 + loss_d_fake * 0.5) / 2

                # Generator: fool D + foreground-weighted L1.
                loss_g_gan = bce_loss_fn(d_fake, torch.ones_like(d_fake))
                masked_l1, raw_l1, _ = calculate_masked_loss(fake_images, targets, L1_LAMBDA)
                g_loss = loss_g_gan + masked_l1

                # Accumulate per-sample sums (batch means * batch size).
                total_d_loss += d_loss.item() * batch_size
                total_g_loss += g_loss.item() * batch_size
                total_masked_l1 += masked_l1.item() * batch_size
                total_raw_l1 += raw_l1.item() * batch_size
                total_real_score += d_real.sum().item()
                total_fake_score += d_fake.sum().item()
                correct_real += (d_real > 0.5).sum().item()
                correct_fake += (d_fake < 0.5).sum().item()
                total_samples += batch_size

                if i == 0:
                    # Map the {0, 1} mask to {-1, 1} so the shared "* 0.5 + 0.5"
                    # below renders it as black/white like the images.
                    debug_mask = (targets > -0.98).float() * 2.0 - 1.0
                    img_grid = torch.cat((targets[:16], fake_images[:16], debug_mask[:16]), dim=0)
                    save_image(img_grid * 0.5 + 0.5, f"{folder}/epoch_{epoch}.png")
    finally:
        gen.train()
        disc.train()

    if total_samples == 0:
        print("No validation samples - report skipped.")
        return None

    metrics = {
        "epoch": epoch,
        "g_loss": total_g_loss / total_samples,
        "d_loss": total_d_loss / total_samples,
        "raw_l1": total_raw_l1 / total_samples,
        "masked_l1": total_masked_l1 / total_samples,
        "real_score": total_real_score / total_samples,
        "fake_score": total_fake_score / total_samples,
        "acc_real": correct_real / total_samples,
        "acc_fake": correct_fake / total_samples,
    }

    print("-" * 75)
    print(f"LOSSES  => G: {metrics['g_loss']:.4f} | D: {metrics['d_loss']:.4f}")
    print(f"L1      => Raw: {metrics['raw_l1']:.4f} | Weighted: {metrics['masked_l1']:.4f}")
    print(f"SCORES  => Real(D): {metrics['real_score']:.4f} | Fake(D): {metrics['fake_score']:.4f}")
    print(f"ACC     => Real: {metrics['acc_real']:.2%} | Fake: {metrics['acc_fake']:.2%}")
    print("=" * 75)

    return metrics


# Thresholds for the adaptive D/G step scheduling in train_dynamic_fn.
ACC_TARGET_MIN = 0.65   # D stops repeating once its batch accuracy exceeds this (and it is not blind to real images)
ACC_TARGET_MAX = 0.90   # D accuracy below this makes D eligible for a step; at/above it G gets MAX_G_REPEATS steps
MAX_D_REPEATS = 3       # max D optimiser steps per batch
MAX_G_REPEATS = 2       # G steps per batch when D is too strong

CONSTANT_NOISE_STD = 0.01   # std of the Gaussian noise added to images fed to D

def train_dynamic_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, epoch, vgg_criterion):
    """Train D and G for one epoch with accuracy-driven step scheduling.

    For every batch:

    * Discriminator phase (at most ``MAX_D_REPEATS`` optimiser steps).
      - Fakes are regenerated without gradients.
      - The same Gaussian noise sample (std ``CONSTANT_NOISE_STD``) is added
        to real and fake images.
      - D is scored against smoothed real labels (0.9) and fake labels (0).
        The D loss weights the real term 1.5 and the fake term 0.5.
      - D takes a step when its batch accuracy is below ``ACC_TARGET_MAX``,
        when its real accuracy is below 0.45, or when real and fake
        accuracies differ by more than 0.4. It never steps when accuracy
        exceeds 0.98.
      - Repetition stops once accuracy exceeds ``ACC_TARGET_MIN`` and D is not
        blind to real images.
      - Accuracies are measured *before* each step.
    * Generator phase: ``MAX_G_REPEATS`` steps if the last measured D accuracy
      is at least ``ACC_TARGET_MAX``, otherwise one step. The G loss is
      BCE(D(fake + noise), 1) plus ``calculate_masked_loss`` scaled by
      ``L1_LAMBDA``.

    ``l1_loss``, ``epoch`` and ``vgg_criterion`` are accepted but unused. At
    the end it prints boost counts and epoch D accuracies, taken from the last
    D evaluation of each batch. Returns ``None``.
    """
    loop = tqdm(loader, leave=True)

    d_total_samples = 0
    d_real_correct = 0
    d_fake_correct = 0


    total_d_boosts = 0
    total_g_boosts = 0
    total_d_steps = 0
    total_g_steps = 0

    for idx, (inputs, real_images) in enumerate(loop):
        inputs = inputs.to(DEVICE)
        real_images = real_images.to(DEVICE)
        batch_size = inputs.shape[0]


        d_loops = 0

        # --- Discriminator phase ---
        while d_loops < MAX_D_REPEATS:
            with torch.no_grad():
                fake_images = gen(None, inputs)

            # One noise sample shared by the real and fake inputs.
            noise = torch.randn_like(real_images) * CONSTANT_NOISE_STD

            real_noisy = real_images + noise
            fake_noisy = fake_images.detach() + noise # detached + noise

            d_real = disc(real_noisy, inputs)
            loss_d_real = bce(d_real, torch.ones_like(d_real) * 0.9)

            d_fake = disc(fake_noisy, inputs)
            loss_d_fake = bce(d_fake, torch.zeros_like(d_fake))

            loss_d = (loss_d_real * 1.5 + loss_d_fake * 0.5) /2 # real term weighted 3x the fake term (original tag: 211)

            # Accuracies of D *before* this iteration's (possible) update.
            acc_real = (d_real > 0.5).float().mean().item()
            acc_fake = (d_fake < 0.5).float().mean().item()
            current_acc = (acc_real + acc_fake) / 2



            is_unbalanced = abs(acc_real - acc_fake) > 0.4
            is_blind_to_real = acc_real < 0.45
            is_weak_overall = current_acc < ACC_TARGET_MAX

            # Decide whether D trains in this round
            should_train_d = (is_weak_overall or is_blind_to_real or is_unbalanced)

            # Never train an almost-perfect D.
            if current_acc > 0.98: should_train_d = False

            if should_train_d:
                disc.zero_grad()
                loss_d.backward()
                opt_disc.step()
                d_loops += 1
                total_d_steps += 1
            else:
                break
            # Good enough and not ignoring real images: stop repeating D steps.
            if current_acc > ACC_TARGET_MIN and not is_blind_to_real:
                break

        # Statistics (from the last D evaluation of this batch)
        d_real_correct += (d_real > 0.5).sum().item()
        d_fake_correct += (d_fake < 0.5).sum().item()
        d_total_samples += batch_size

        if d_loops > 1: total_d_boosts += 1


        # --- Generator phase ---
        # If D is too strong (acc >= ACC_TARGET_MAX), G gets MAX_G_REPEATS steps
        g_repeats = MAX_G_REPEATS if current_acc >= ACC_TARGET_MAX else 1
        if g_repeats > 1: total_g_boosts += 1

        for i in range(g_repeats):
            fake_for_g = gen(None, inputs)
            noise_g = torch.randn_like(fake_for_g) * CONSTANT_NOISE_STD
            fake_g_noisy = fake_for_g + noise_g

            d_fake_g = disc(fake_g_noisy, inputs)

            loss_g_gan = bce(d_fake_g, torch.ones_like(d_fake_g))
            loss_g_l1, raw_l1, _ = calculate_masked_loss(fake_for_g, real_images, L1_LAMBDA)

            g_loss = loss_g_gan + loss_g_l1

            # NOTE: this backward also accumulates gradients in D's parameters;
            # every opt_disc.step() above is preceded by disc.zero_grad().
            gen.zero_grad()
            g_loss.backward()
            opt_gen.step()
            total_g_steps += 1


        # A G boost label overrides a D boost label.
        status_msg = "OK"
        if d_loops > 1: status_msg = f"D+{d_loops-1}"
        if g_repeats > 1: status_msg = f"G+{g_repeats-1}"

        loop.set_postfix(
            St=status_msg,
            D=f"{loss_d.item():.3f}",
            G=f"{g_loss.item():.2f}",
            L1=f"{raw_l1.item():.4f}",
            L_W=f"{loss_g_l1.item():.2f}",
            AccR=f"{acc_real:.2f}",
            AccF=f"{acc_fake:.2f}"
        )

    final_acc_real = d_real_correct / d_total_samples if d_total_samples > 0 else 0
    final_acc_fake = d_fake_correct / d_total_samples if d_total_samples > 0 else 0

    print(f"\n[Dynamic Stats] D_Boosts: {total_d_boosts} | G_Boosts: {total_g_boosts} || Total Steps: D={total_d_steps}, G={total_g_steps}")
    print(f"[Epoch Stats]   Acc Real: {final_acc_real:.2%} | Acc Fake: {final_acc_fake:.2%}")

def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, epoch, vgg_criterion):
    """Train for one epoch with a fixed one-D-step / one-G-step schedule.

    * ``epoch < WARMUP_EPOCHS``: only the generator trains, on the
      foreground-weighted L1 (``calculate_masked_loss`` x ``L1_LAMBDA``). D is
      untouched.
    * Otherwise D takes one step per batch (real labels smoothed to 0.9, fake
      labels 0, loss = mean of both terms). Then, unless ``freeze_gen``, G
      takes one step on BCE(D(fake), 1) + weighted L1, using D as already
      updated in this iteration.
    * ``freeze_gen`` ("REHAB_DISC" mode) holds while
      ``epoch < START_EPOCH + DISC_WARMUP_EPOCHS``. Fakes are then produced
      without gradients and only D learns.

    D accuracies in the progress bar are running totals over the epoch.
    ``l1_loss`` and ``vgg_criterion`` are unused. NOTE: the training loop in
    this notebook calls ``train_dynamic_fn``, not this function.
    """
    loop = tqdm(loader, leave=True)

    d_real_correct = 0
    d_fake_correct = 0
    d_total = 0

    for idx, (inputs, real_images) in enumerate(loop):
        inputs = inputs.to(DEVICE)
        real_images = real_images.to(DEVICE)
        batch_size = inputs.shape[0]

        # Warm-up: generator only, L1 loss
        if epoch < WARMUP_EPOCHS:
            fake_images = gen(None, inputs)

            loss_g_l1, raw_l1, mask_pct = calculate_masked_loss(fake_images, real_images, L1_LAMBDA)

            # L1 only
            g_loss = loss_g_l1

            gen.zero_grad()
            g_loss.backward()
            opt_gen.step()

            loop.set_postfix(Mode="WARMUP_L1", L1=f"{raw_l1:.4f}", TotalG=f"{g_loss.item():.2f}")

        else:
            freeze_gen = epoch < (START_EPOCH + DISC_WARMUP_EPOCHS)
            mode_name = "REHAB_DISC" if freeze_gen else "GAN_TRAIN"

            # Discriminator
            # NOTE: gen stays in train mode here, so its BatchNorm running
            # statistics are updated even when gradients are disabled.
            with torch.set_grad_enabled(not freeze_gen):
                fake_images = gen(None, inputs)

            # Real
            d_real = disc(real_images, inputs)
            loss_d_real = bce(d_real, torch.ones_like(d_real) * 0.9) # Label smoothing
            d_real_correct += (d_real > 0.5).sum().item()

            # Fake (detached: the D step must not reach the generator)
            d_fake = disc(fake_images.detach(), inputs)
            loss_d_fake = bce(d_fake, torch.zeros_like(d_fake))
            d_fake_correct += (d_fake < 0.5).sum().item()

            loss_d = (loss_d_real + loss_d_fake) / 2

            disc.zero_grad()
            loss_d.backward()
            opt_disc.step()

            # Generator
            if not freeze_gen:
                # G is trained only in GAN_TRAIN mode

                d_fake_for_gen = disc(fake_images, inputs)
                loss_g_gan = bce(d_fake_for_gen, torch.ones_like(d_fake_for_gen))

                loss_g_l1, raw_l1, mask_pct = calculate_masked_loss(fake_images, real_images, L1_LAMBDA)

                g_loss = loss_g_gan + loss_g_l1

                gen.zero_grad()
                g_loss.backward()
                opt_gen.step()


                log_g_loss = g_loss.item()
                log_l1 = raw_l1.item()
            else:
                # Frozen G: compute L1 for logging only.
                with torch.no_grad():
                     _, log_l1, _ = calculate_masked_loss(fake_images, real_images, L1_LAMBDA)
                log_g_loss = 0.0

            # Logging (running accuracies over the epoch)
            d_total += batch_size
            acc_real = d_real_correct / d_total if d_total > 0 else 0
            acc_fake = d_fake_correct / d_total if d_total > 0 else 0

            loop.set_postfix(
                Mode=mode_name,
                D=f"{loss_d.item():.4f}",
                G=f"{log_g_loss:.2f}",
                L1=f"{log_l1:.4f}",
                AccR=f"{acc_real:.2f}",
                AccF=f"{acc_fake:.2f}"
            )

In [None]:
print(f"Urządzenie: {DEVICE}")
print(f"Lambda L1: {L1_LAMBDA}")

gen = Generator().to(DEVICE)
disc = Discriminator().to(DEVICE)

# DCGAN-style init of conv / batch-norm layers (see weights_init).
gen.apply(weights_init)
disc.apply(weights_init)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE_GEN, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999)) # D's LR is much lower than G's (20x with the current constants)

# Frozen VGG16 feature slices; passed to the training functions, which do not use them for a loss.
vgg_criterion = setup_vgg_model(DEVICE)
print("Zainicjalizowano Generator VGG Loss (VGG16).")

bce = nn.BCELoss()
l1_loss = nn.L1Loss()

train_loader, val_loader, test_loader = get_loaders(dataset_dir, BATCH_SIZE)
# Resume: restore both networks and optimizers saved at START_EPOCH, then
# force the configured learning rates onto every param group.
if LOAD_MODEL:
    current_lr_gen = LEARNING_RATE_GEN
    current_lr_disc = LEARNING_RATE_DISC
    load_checkpoint(f"{checkpoint_dir}/gen_epoch_{START_EPOCH}.pth", gen, opt_gen, current_lr_gen)
    load_checkpoint(f"{checkpoint_dir}/disc_epoch_{START_EPOCH}.pth", disc, opt_disc, current_lr_disc)
    print(f"--> Wznowiono trening od epoki {START_EPOCH}. Nowe LR: G={current_lr_gen}, D={current_lr_disc}")


Urządzenie: cpu
Lambda L1: 20.0
Zainicjalizowano Generator VGG Loss (VGG16).
Podział danych: Train=2200, Val=200, Test=600


In [None]:
# Per-epoch validation results (whatever check_accuracy returns).
train_info = []

os.makedirs(checkpoint_dir, exist_ok=True)
train_params = f"\n\n### PARAMETRY TRENINGU ### LEARNING_RATE_GEN={LEARNING_RATE_GEN}\nLEARNING_RATE_DISC={LEARNING_RATE_DISC}\nBATCH_SIZE={BATCH_SIZE}\nNUM_EPOCHS={NUM_EPOCHS}\nL1_LAMBDA={L1_LAMBDA}"
with open(f"{checkpoint_dir}/train_params.txt", "w", encoding="utf-8") as f:
    f.write(train_params)

# Epochs are numbered START_EPOCH + 1 ... START_EPOCH + NUM_EPOCHS, so
# checkpoint file names continue from a resumed run.
for epoch in range(START_EPOCH + 1, START_EPOCH + 1 + NUM_EPOCHS):
    print(f"Epoch [{epoch}/{START_EPOCH + NUM_EPOCHS}]")

    train_dynamic_fn(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, bce, epoch, vgg_criterion)
    info = check_accuracy(val_loader, gen, disc, DEVICE, epoch, l1_loss, bce)

    train_info.append(info)

    # Save at a few milestone epochs and at every epoch from 197 on. The tests
    # must be OR-ed: one epoch can never equal several milestones at once, so
    # an AND chain would never save anything.
    if epoch in (50, 100, 150, 180) or epoch >= 197:
        save_checkpoint(gen, opt_gen, filename=f"{checkpoint_dir}/gen_epoch_{epoch}.pth")
        save_checkpoint(disc, opt_disc, filename=f"{checkpoint_dir}/disc_epoch_{epoch}.pth")

## Wyniki

In [None]:

# Reload the generator saved at the last training epoch. The training loop
# ends at START_EPOCH + NUM_EPOCHS, which equals NUM_EPOCHS only when
# START_EPOCH == 0.
if LOAD_MODEL:
    current_lr_gen = LEARNING_RATE_GEN
    final_epoch = START_EPOCH + NUM_EPOCHS
    load_checkpoint(f"{checkpoint_dir}/gen_epoch_{final_epoch}.pth", gen, opt_gen, current_lr_gen)

=> Wczytywanie checkpointu /content/drive/MyDrive/projekt_3/model_2_checkpoints_v1/gen_epoch_289.pth


In [None]:
# Saved test-set images: ground truth in target/, generator output in pred/.
# calc_metrics pairs the two folders by file name.
test_ds_folder = f"{ROOT_PROJECT_FOLDER}/test_results/model_2_{NUM_EPOCHS}"
# Where the metric report (metryki.txt) is written.
result_folder = f"{ROOT_PROJECT_FOLDER}/model_2_metryki"

target_folder = f"{test_ds_folder}/target"
pred_folder = f"{test_ds_folder}/pred"

### Metryki

In [None]:
!pip install lpips
!pip install flip-evaluator

from skimage.metrics import structural_similarity as ssim
import lpips
from flip_evaluator import evaluate
from scipy.spatial.distance import directed_hausdorff
import cv2
from torchvision.utils import make_grid

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4
Collecting flip-evaluator
  Downloading flip_evaluator-1.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Downloading flip_evaluator-1.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (415 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m415.2/415.2 kB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: flip-evaluator
Successfully installed flip-evaluator-1.7


In [None]:
def _otsu_canny(image):
    """Canny edge map of an 8-bit RGB image with thresholds derived from Otsu.

    The high threshold is the Otsu threshold of the grayscale image and the
    low threshold is 30% of it.
    """
    otsu_threshold, _ = cv2.threshold(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.Canny(image, threshold1=otsu_threshold * 0.3, threshold2=otsu_threshold)


def calc_hausdorff(orginal: np.ndarray, proccessed: np.ndarray) -> float:
    """Symmetric Hausdorff distance, in pixels, between the edge maps of two RGB images.

    Edges are extracted from each image with its own Otsu-derived Canny
    thresholds. The result is the maximum of the two directed Hausdorff
    distances between the edge-pixel coordinate sets.

    Empty edge sets are handled explicitly rather than relying on scipy's
    handling of empty point sets, which can report 0 (a "perfect" match).
    The function returns ``0.0`` when neither image has edges and ``inf``
    when exactly one of them has none.
    """
    orginal_points = np.column_stack(np.where(_otsu_canny(orginal) > 0))
    proccessed_points = np.column_stack(np.where(_otsu_canny(proccessed) > 0))

    if len(orginal_points) == 0 or len(proccessed_points) == 0:
        return 0.0 if len(orginal_points) == len(proccessed_points) else float("inf")

    return float(max(
        directed_hausdorff(orginal_points, proccessed_points)[0],
        directed_hausdorff(proccessed_points, orginal_points)[0],
    ))

In [None]:
def read_image(path: str):
    """Load an image file as an 8-bit RGB array of shape (H, W, 3).

    ``cv2.imread`` returns ``None`` instead of raising when a file is missing
    or unreadable. That case is turned into ``FileNotFoundError`` so the
    failing path is reported here, not later as an AttributeError on ``None``.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR_RGB)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image

In [None]:
# File names (not full paths) found in the target and prediction folders.
# os.listdir returns them in arbitrary order; the lists are not sorted.
target_content = os.listdir(target_folder)
pred_content = os.listdir(pred_folder)

In [None]:
# Keep only the PNG files from each folder listing (order unchanged).
target_content = list(filter(lambda name: name.endswith(".png"), target_content))
pred_content = list(filter(lambda name: name.endswith(".png"), pred_content))

In [None]:
# Sanity check: both folders should hold the same number of PNG images.
print(len(target_content))
print(len(pred_content))

600
600


In [None]:
def as_tensor(image):
    """Convert an (H, W, C) image with values in [0, 1] to a float32 (1, C, H, W) tensor in [-1, 1].

    The array is cast to float32 *before* rescaling, because integer arithmetic
    would wrap around for ``uint8`` input. It is also made contiguous, because
    ``torch.from_numpy`` rejects negative-stride views such as ``img[..., ::-1]``.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(image)).to(torch.float32)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    return tensor * 2 - 1

In [None]:
def calc_metrics(gen, device, lpips_net='squeeze', result_folder=None):
    """Compute SSIM, LPIPS, FLIP and edge-Hausdorff over the saved test images.

    For every name in ``target_content``, the image with that name is read
    from ``target_folder`` and from ``pred_folder``, and the two are compared.
    Per-image values are printed. If ``result_folder`` is given, the
    per-image lines and the averages are also written to
    ``{result_folder}/metryki.txt``.

    Args:
        gen: generator; it is only switched to eval mode (the metrics use
            images already on disk). Kept for call compatibility.
        device: device for the LPIPS network and its input tensors.
        lpips_net: backbone name passed to ``lpips.LPIPS``.
        result_folder: optional output folder for the report.

    Returns:
        (avg_ssim, avg_lpips, avg_flip, avg_hausdorff). The Hausdorff average
        includes only finite, positive distances; it is NaN if there are none.
    """
    print(f"Urządzenie: {device}")

    gen.eval()

    loss_lpips = lpips.LPIPS(net=lpips_net).to(device)

    ssim_metric = []
    lpips_metric = []
    flip_metric = []
    hausdorff_metric = []
    report_lines = []

    dynamic_range = "LDR"
    with torch.no_grad():
        for i, target_name in enumerate(target_content):
            fullpath_target = target_folder + "/" + target_name
            fullpath_pred = pred_folder + "/" + target_name

            target_img = read_image(fullpath_target)
            pred_img = read_image(fullpath_pred)

            target_img_f = target_img.astype(np.float32) / 255.0
            pred_img_f = pred_img.astype(np.float32) / 255.0

            print(f"\ttarget_img_f: min={target_img_f.min().item()}, max={target_img_f.max().item()}")
            print(f"\tpred_img_f: min={pred_img_f.min().item()}, max={pred_img_f.max().item()}")

            ## ------- SSIM ----------
            # Computed on the [0, 1] float images so data_range=1 matches the
            # data; the uint8 arrays span 0..255.
            ssim_metric.append(ssim(target_img_f, pred_img_f, full=False, channel_axis=-1, data_range=1))

            ## ------- LPIPS ----------
            # (1, 3, H, W) tensors in [-1, 1], on the same device as the network.
            target_tensor = as_tensor(target_img_f).to(device)
            pred_tensor = as_tensor(pred_img_f).to(device)

            print(f"\ttarget_tensor: min={target_tensor.min().item()}, max={target_tensor.max().item()}")
            print(f"\tpred_tensor: min={pred_tensor.min().item()}, max={pred_tensor.max().item()}")

            lpips_metric.append(loss_lpips(target_tensor, pred_tensor).item())

            ## ------- FLIP ----------
            # evaluate() returns a tuple; index 1 is the value recorded as the FLIP score.
            flip_info = evaluate(target_img_f, pred_img_f, dynamic_range)
            flip_metric.append(flip_info[1])

            ## ------- HAUSDORFF ----------
            hausdorff = calc_hausdorff(target_img, pred_img)

            # Degenerate distances (no edges in one or both images) are excluded from the average.
            if np.isfinite(hausdorff) and hausdorff > 0:
                hausdorff_metric.append(hausdorff)

            ## ------- REPORT ----------
            line = f"Obraz i={i+1}, ssim={ssim_metric[-1]}, lpips={lpips_metric[-1]}, flip={flip_metric[-1]}, hausdorff={hausdorff}\n"
            report_lines.append(line)

            print(line)

    avg_ssim = np.mean(ssim_metric)
    avg_lpips = np.mean(lpips_metric)
    avg_flip = np.mean(flip_metric)
    # np.mean([]) would also give NaN, but with a RuntimeWarning.
    avg_hausdorff = np.mean(hausdorff_metric) if hausdorff_metric else float("nan")

    if result_folder:
        os.makedirs(result_folder, exist_ok=True)
        report = "".join(report_lines)
        report += f"\n\n### METRYKI DLA ZBIORU TESTOWEGO ###\n\n\tssim={avg_ssim}, lpips={avg_lpips}, flip={avg_flip}, hausdorff={avg_hausdorff}"
        with open(f"{result_folder}/metryki.txt", "w", encoding="utf-8") as f:
            f.write(report)

    return avg_ssim, avg_lpips, avg_flip, avg_hausdorff

In [None]:
# Evaluate the saved test predictions; per-image and average metrics are also written to result_folder.
avg_ssim, avg_lpips, avg_flip, avg_hausdorff = calc_metrics(gen, DEVICE, result_folder=result_folder)
print(f"Metryki dla zbioru testowego: ssim={avg_ssim}, lpips={avg_lpips}, flip={avg_flip}, hausdorff={avg_hausdorff}")

Urządzenie: cpu
Setting up [LPIPS] perceptual loss: trunk [squeeze], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/squeeze.pth
	target_img_f: min=0.0, max=1.0
	pred_img_f: min=0.0, max=0.21960784494876862
	target_tensor: min=-1.0, max=1.0
	pred_tensor: min=-1.0, max=-0.5607843399047852
Obraz i=1, ssim=0.3040780961780197, lpips=0.23678556084632874, flip=0.069606713950634, hausdorff=48.08326112068523





	target_img_f: min=0.0, max=1.0
	pred_img_f: min=0.0, max=0.9882352948188782
	target_tensor: min=-1.0, max=1.0
	pred_tensor: min=-1.0, max=0.9764705896377563
Obraz i=2, ssim=0.6895739749614718, lpips=0.17902058362960815, flip=0.07837072014808655, hausdorff=50.99019513592785

	target_img_f: min=0.0, max=1.0
	pred_img_f: min=0.0, max=0.5411764979362488
	target_tensor: min=-1.0, max=1.0
	pred_tensor: min=-1.0, max=0.08235299587249756
Obraz i=3, ssim=0.6497898402087622, lpips=0.17668937146663666, flip=0.024603012949228287, hausdorff=58.180752831155424

	target_img_f: min=0.0, max=1.0
	pred_img_f: min=0.0, max=0.6078431606292725
	target_tensor: min=-1.0, max=1.0
	pred_tensor: min=-1.0, max=0.21568632125854492
Obraz i=4, ssim=0.6657296011869136, lpips=0.2374034970998764, flip=0.15834884345531464, hausdorff=86.2670273047588

	target_img_f: min=0.0, max=1.0
	pred_img_f: min=0.0, max=0.8509804010391235
	target_tensor: min=-1.0, max=1.0
	pred_tensor: min=-1.0, max=0.7019608020782471
Obraz i=5, s