In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms
import os
import matplotlib.pyplot as plt
import torch
import os
from tqdm import tqdm
from torch.utils.data import random_split
import matplotlib.pyplot as plt


In [31]:
# --- PARAMETRI DI CONFIGURAZIONE GLOBALI ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Questi valori sono per un training leggero (adatta al tuo hardware)
LATENT_CHANNELS = 4 # Output VAE: 4 canali
FEATURES = [128, 256, 512]
TIMESTEPS = 1000
BATCH_SIZE = 32
LEARNING_RATE = 1e-4 
EPOCHS = 5 # Solo per un test iniziale
SAVE_FOLDER = "../weights/"
MODEL_NAME = "ldm_unet"
DATA_FOLDER = "../data/val2017"
IMAGE_SIZE = 128
VALIDATION_SPLIT_RATIO = 0.2 


In [32]:
# --- 2. VAE ENCODING & DECODING UTILITY ---
# Caricamento del VAE e definizione delle utility (come discusso in precedenza)
try:
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE).eval()
    VAE_SCALE_FACTOR = vae.config.scaling_factor
except Exception as e:
    print(f"Errore nel caricamento del VAE. Assicurati di avere 'diffusers' installato. {e}")
    VAE_SCALE_FACTOR = 0.18215 # Valore fallback

def encode_to_latent(pixels: torch.Tensor):
    """Converte pixel [0, 1] in latenti (4, H/8, W/8)."""
    pixels = (pixels * 2) - 1.0 # Scala [0, 1] a [-1, 1]
    with torch.no_grad():
        posterior = vae.encode(pixels).latent_dist
    latents = posterior.sample()
    return latents * VAE_SCALE_FACTOR

def decode_from_latent(latents: torch.Tensor):
    """Converte latenti in pixel [0, 1]."""
    latents = latents / VAE_SCALE_FACTOR
    with torch.no_grad():
        image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1)

In [33]:
# --- 3. CLASSI ESSENZIALI (Definizione Minimali) ---
# Necessario inserire qui il codice corretto di ResBlock e UNet per far funzionare lo script!
# Ho usato il codice corretto per garantire l'esecuzione.

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        
        device = time.device
        half_dim = self.dim // 2
        
        # 'embeddings' will be (B, half_dim)
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :] 
        # broadcasting (B, half_dim)
        
        # Sine and Cosine
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) # (B, dim)
        
        # if dim is odd, pad with one zero vector
        if self.dim % 2 != 0:
            embeddings = F.pad(embeddings, (0, 1), mode='constant', value=0)
            
        return embeddings
    


# Residual Block
class ResBlock(nn.Module):
    """
    A simple Residual Block with two convolutional layers.
    """

    def __init__(self, channels, time_embed_dim = 128, num_groups=8, self_attention=False):
        super(ResBlock, self).__init__()

        # Time Embedding
        self.time_proj1 = nn.Linear(time_embed_dim, channels * 2)
        self.time_proj2 = nn.Linear(time_embed_dim, channels * 2)

        # Input Number of channels (128, 256, 512) x h (16, 32) x w (16, 32)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
        self.act = nn.SiLU(inplace=True)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups = num_groups, num_channels = channels)
        
        # Output channels = Input channels 

        # Self-attention layer (optional)
        if self_attention:
            self.attention = nn.MultiheadAttention(embed_dim=channels, num_heads=4)


    def forward(self, x, t_emb):
        identity = x
        # Adaptive Normalization with Time Embedding

        # Time embedding 1 projection 
        t_proj_1 = self.time_proj1(t_emb).chunk(2, dim=-1)
        gamma1, beta1 = t_proj_1[0].unsqueeze(-1).unsqueeze(-1), t_proj_1[1].unsqueeze(-1).unsqueeze(-1)




        out = self.conv1(x)
        out = self.norm1(out) * (1 + gamma1) + beta1
        out = self.act(out)

        # Time embedding 2 projection 
        t_proj_2 = self.time_proj2(t_emb).chunk(2, dim=-1)
        gamma2, beta2 = t_proj_2[0].unsqueeze(-1).unsqueeze(-1), t_proj_2[1].unsqueeze(-1).unsqueeze(-1)

        out = self.conv2(out)
        out = self.norm2(out) * (1 + gamma2) + beta2
        out = self.act(out)

     
        
        if hasattr(self, "attention"):
            b, c, h, w = out.size()
            out_reshaped = out.view(b, c, h * w).permute(2, 0, 1)  # (h*w, b, c)
            out_attended, _ = self.attention(out_reshaped, out_reshaped, out_reshaped)
            out = out_attended.permute(1, 2, 0).view(b, c, h, w)

        out += identity # Residual connection

        return out



# UNet Model
class UNet(nn.Module):
    """
    LDM UNet model skeleton.
    """

    def __init__(self, in_channels, out_channels, num_blocks = 2, time_emb_dim = 128, features=[128, 256, 512]):
        super(UNet, self).__init__()



        #The Variational autoencoder reduces the input image of size 3x128x128 to a latent representation of size 4 x 32 x 32.
        # Input
        # 4 x 32 x 32
        
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_proj = SinusoidalPositionEmbeddings(dim=time_emb_dim)

        # Time Embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4), 
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim) 
        )
        self.time_emb_dim = time_emb_dim



        # Initial Convolution out = [H_in + 2*padding - dilation*(kernel_size-1) -1]/stride +1
        # Output
        # 128 x 32 x 32
        self.init_conv = nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1)

        # Encoder
        self.enc_layers = nn.ModuleList()
        self.downsamples = nn.ModuleList()


        current_channels = features[0]

        for next_channels in features[1:]:
            level_blocks = nn.ModuleList()
            for _ in range(num_blocks):
                block = ResBlock(current_channels, time_embed_dim=time_emb_dim, num_groups = min(current_channels//32, 32))
                level_blocks.append(block)

            # Output size halved (DownSampling Layer)
            downsample = nn.Conv2d(current_channels, next_channels, kernel_size=4, stride=2, padding=1)
            self.downsamples.append(downsample)
            self.enc_layers.append(level_blocks)
            current_channels = next_channels

        # Bottleneck
        self.bottleneck = nn.ModuleList([
        ResBlock(features[-1], time_embed_dim=time_emb_dim, num_groups = min(features[-1]//32, 32), self_attention=True),
        ResBlock(features[-1], time_embed_dim=time_emb_dim, num_groups = min(features[-1]//32, 32), self_attention=True)])


        # Decoder
        self.dec_layers = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        reversed_features = features[::-1]
        
        # Modified to Allow Skip Connections
        for i in range(len(reversed_features) - 1):
            level_blocks = nn.ModuleList()
            
            out_channels_level = reversed_features[i+1]

            in_channels_up = reversed_features[i]

            # UpSampling: Reduces channels while doubling spatial size.
            upsample = nn.ConvTranspose2d(in_channels_up, out_channels_level, kernel_size=4, stride=2, padding=1)
            self.upsamples.append(upsample)
            
            # After UpSampling + Skip Connection, the channel count will be N_current * 2
            block_in_channels_after_skip = out_channels_level * 2

            
            # To solve the channel mismatch after concatenation with skip connections,
            # we introduce an adapter convolutional layer.
            adapter_conv = nn.Conv2d(block_in_channels_after_skip, out_channels_level, kernel_size=1)
            level_blocks.append(adapter_conv) # Adapter is the first operation in this decoder level

            for _ in range(num_blocks):

                block = ResBlock(out_channels_level, time_embed_dim=time_emb_dim, num_groups = min(out_channels_level//32, 32))
                level_blocks.append(block)
            
            self.dec_layers.append(level_blocks)


        in_channels = features[0]
        self.out_conv = nn.Sequential(
        nn.GroupNorm(num_groups= min(in_channels//32, 32), num_channels=in_channels),
        nn.SiLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        


    def forward(self, x, time):

        time_sin = self.time_proj(time)         
        t_emb = self.time_mlp(time_sin)
        x = self.init_conv(x)

        # skip connections
        skips = []

        # encoder
        for blocks, down in zip(self.enc_layers, self.downsamples):
            for blk in blocks:
                x = blk(x, t_emb)
            skips.append(x)
            x = down(x)
        
        # skip = [Encoder 1 --> 128 x 32 x 32 , Encoder 2 --> 256 x 16 x 16]
        # bottleneck
        for layer in self.bottleneck:
            x = layer(x, t_emb)

        # decoder for skip connections
        for up, blocks, skip in zip(self.upsamples, self.dec_layers, reversed(skips)):
            x = up(x)
            # # # if shapes mismatch due to odd sizes, center-crop skip
            if x.shape[-2:] != skip.shape[-2:]:
                    # simple interpolate to match
                    #_, _, h, w = x.shape
                x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
            
            #concat along channels
            x = torch.cat([x, skip], dim=1)
            for blk in blocks:
                if isinstance(blk, nn.Conv2d):
                    # Adapter conv
                    x = blk(x)
                else:
                    x = blk(x, t_emb)

        # final conv
        out = self.out_conv(x)


        return out

class LatentDataset(torch.utils.data.Dataset):
    # ... Inserire qui il codice della classe LatentDataset fornito
    def __init__(self, data_dir, image_size=128):
        self.data_dir = data_dir
        self.image_paths = [
            os.path.join(data_dir, f) 
            for f in os.listdir(data_dir) 
            if f.endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except:
            # Semplice fallback per dataset di test
            return self.__getitem__((idx + 1) % len(self))
            
        return self.transform(image)


In [35]:
class ForwardConfig:
    """
    Parameters configuartion for the forward process
    """
    def __init__(
        self,
        input_path="latents.pt",
        output_path="latents_noised.pt",
        t: float = 0.7,
        final: bool = True,
        eps: float = 1e-5,
        closed_formula: bool = True,
        seed: int = 42,
        beta_min: float = 0.1,
        beta_max: float = 20.0,
        N: int = 1000,
    ):
        self.input_path = input_path
        self.output_path = output_path
        self.t = t
        self.final = final
        self.eps = eps
        self.closed_formula = closed_formula
        self.seed = seed
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.N = N


In [48]:
import torch
import numpy as np

class subVP_SDE:
    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct the sub-VP SDE

        Args:
        beta_min: value of beta(0)
        beta_max: value of beta(1)
        N: number of discretization steps

        Attributes:
        beta_0: minimum noise scale at t=0 for the linear schedule.
        beta_1: maximum noise scale at t=1 for the linear schedule.
        N: stored grid size, which is usually not used by closed-form routines below.
        """
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.N = N

    def beta_linear(self, t: torch.Tensor) -> torch.Tensor:
        """Return β(t) = β_0 + t*(β_1 - β_0)
        Used for closed form forward application
        """
        return self.beta_0 + t * (self.beta_1 - self.beta_0)

    def beta_exponential(self, t: torch.Tensor) -> torch.Tensor:
        """β(t) grows exponentially from β_0 to β_1 across t in [0, 1]"""
        sequence = torch.log(torch.tensor(self.beta_1/self.beta_0, device = t.device, dtype = t.dtype))
        return self.beta_0 * torch.exp(sequence * t)

    # Instanteneous SDE coefficients
    def sde(self, x, t):
        """Returns instantaneous coefficients of the SDE evaluated at (x,t).
        This function do not integrate but it provides the per-time drift and diffusion values.

        Args:
        x: (B,C,H,W), t: (B,)
        
        Details:
        beta(t) = beta_0 + t * (beta_1 - beta_0)
        ∫_0^t beta(s) ds = beta_0 * t + 0.5 * (beta_1 - beta_0) * t^2
        discount := 1 - exp(-2 * ∫_0^t beta(s) ds) = 1 - exp(-2 * beta_0 * t - (beta_1 - beta_0) * t^2)
        g(t) = sqrt(beta(t) * discount)
        """
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None, None] * x #because x is (B, C, H, W), where B is the batch size
        discount = 1.0 - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) #development of the integral beta(s) ds
        diffusion = torch.sqrt(beta_t * discount)
        return drift, diffusion

    # Closed form formula
    def marginal_prob(self, x_0, t):
        """Closed form for X_t given X_0 = x_O
        
        Distribution:
        X_t|X_0 ~ N (mean_coeff*X_0, std^2 * I)

        Args:
        x0: (B,C,H,W), t: (B,)
        Returns mean (B,1,1,1)*x0 and std (B,)

        Details:
        ∫_0^t beta(s) ds = beta_0 * t + 0.5 * (beta_1 - beta_0) * t^2
        log_mean_coeff = -0.5 * ∫_0^t beta(s) ds
                        = -0.5 * beta_0 * t - 0.25 * (beta_1 - beta_0) * t^2
        mean = exp(log_mean_coeff) * x_0
        std  = sqrt(1 - exp(2 * log_mean_coeff))
        
        """
        
        log_mean_coeff = -0.5 * self.beta_0 * t + -0.25 * (t ** 2) * (self.beta_1 - self.beta_0) #log exp(-1/2 * integral ( beta(s) ds)
        mean = torch.exp(log_mean_coeff)[:, None, None, None] * x_0
        std = torch.sqrt(1.0 - torch.exp(2 * log_mean_coeff)) #double check
        return mean, std

    def perturb_closed(self, x_0, t, noise = None):
        """Sample X_t by perturbing X_0 with gaussian noise

        Operation:
        1. Compute closed-form mean and std of X_t | X_0.
        2. Draw epsilon ~ N(0, I) with the same shape as x_0 if not provided.
        3. Return x_t = mean + std * epsilon, along with epsilon and std.

         Notes:
          - Deterministic for fixed x_0, t, and noise.
          - Suitable for training score/ε-predictor networks with known std.
          
        Args:
        x_0: (B,C,H,W), t:(B,)"""
        
        mean, std = self.marginal_prob(x_0, t)
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = mean + std[:,None, None, None] * noise
        return x_t, noise, std

    def perturb_simulate_path(self, x_0: torch.Tensor, t_end: float = 0.5, steps: int = 500, seed: int = 42, eps: float = 1e-12):
        """Sample X_t by perturbing X_0 with gaussian noise at time t

        Operation:
        1. Compute simulate path for of X_t | X_t-1 and updating X_t values for steps time
        2. omputing the mean and std at time t
        3. Calculating the implied eps

         Notes:
          - Deterministic for fixed x_0, t, and noise.
          - Suitable for training score/ε-predictor networks with known std.
          
        Args:
        x_0: (B,C,H,W), t:(B,)"""
        t_scalar = float(t_end)
        
        device = x_0.device
        dtype = x_0.dtype
        B = x_0.shape[0]
    
        gen = torch.Generator(device = device)
        gen.manual_seed(seed)

        t_grid = torch.linspace(0.0, t_scalar, steps+1, device = device, dtype = dtype)
        x = x_0.clone()
        
        for k in range(steps):
            t_k = t_grid[k].expand(B)
            dt = (t_grid[k+1] - t_grid[k]).item() # we return a scalar value
            drift, diffusion = self.sde(x, t_k)
            noise = noise = torch.randn(x.shape, device=x.device, dtype=x.dtype, generator=gen) # we generate Gaussian Noise, with same device and dtype as x
            x = x + drift * dt + diffusion[:, None, None, None] * (dt ** 0.5) * noise
        
        t_tensor = torch.full((B,), t_scalar, device = device, dtype = dtype)
        mean_t, std_t = self.marginal_prob(x_0, t_tensor)
        eps_implied = (x - mean_t) / (std_t[:, None, None, None] +1e-12) #noise tensor
        return x, eps_implied, std_t
    
    def get_integral_beta(self, t: torch.Tensor) -> torch.Tensor:
        """Computes the integral of beta(s) from 0 to t."""
        # ∫_0^t beta(s) ds = beta_0 * t + 0.5 * (beta_1 - beta_0) * t^2
        return self.beta_0 * t + 0.5 * (self.beta_1 - self.beta_0) * t ** 2
    
            
    def get_g_squared(self, t: torch.Tensor) -> torch.Tensor:
        """
        Computes the coefficient g(t)^2.
        
        λ(t) is used in the SDE definition as the diffusion coefficient squared.
        λ(t) = g(t)^2 = β(t) * (1 - exp(-2 * ∫_0^t β(s) ds))
        """


        # 1. β(t)
        beta_t = self.beta_linear(t)
        
        # 2. ∫_0^t β(s) ds
        integral_beta = self.get_integral_beta(t)
        
        # 3. 'discount' factor (1 - exp(-2 * Integrale))
        discount_factor = 1.0 - torch.exp(-2 * integral_beta)
        
        # 4. g(t)^2 = β(t) * discount_factor
        g_squared = beta_t * discount_factor
        
        return g_squared
    

        
    def get_lambda_original(self, t: torch.Tensor) -> torch.Tensor:
        """
        Computes λ(t) = (1 - exp(-∫_0^t β(s) ds))^2
        """
        integral_beta = self.get_integral_beta(t)
        # 1 - exp(-Integral) è un fattore comune nel formalismo DDPM/SDE
        alpha_t_factor = torch.exp(-integral_beta) 
        return (1.0 - alpha_t_factor) ** 2

In [52]:
from typing import Optional, Tuple

# import torch
# from subVP_SDE import subVP_SDE
# from Configurations import ForwardConfig

class ForwardProcess:
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0, N: int = 1000):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.N = N
        self.sde_model = subVP_SDE(beta_min=beta_min, beta_max=beta_max, N=N)


    @torch.no_grad()
    def get_noised_latents(self, z0: torch.Tensor, t: torch.Tensor, final: bool = False, eps: float = 1e-5, closed_formula : bool = True, steps: int = 500, seed: int = 42, sde_cfg: ForwardConfig = ForwardConfig()) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return noised latents z_t, along with the exact epsilon used and std(t).
        
        Inputs:
        z0: encoded latents. Device and dtype define outputs.
        t: scalar in (0, 1). If None and final=False, defaults to 0.5 for a mid-horizon corruption level.
        final: if True, overrides t and uses t = 1 - eps to avoid t=1 exactly for numerical stability when computing σ(t).
        eps: small offset so that final-time evaluation uses 1 - eps instead of 1.0. Prevents sqrt(1 - exp(…)) from degenerating.
        sde_cfg: ForwardProcess instance carrying beta schedule and N. Used to build a subVP_SDE with matching parameters.

        Operations:
        1. Builds subVP_SDE(beta_min, beta_max, N) on the same device as z0.
        2. Broadcasts scalar t to a batch vector (B,) for the SDE call.
        3. Calls closed-form perturbation: z_t = μ(t|z0) + σ(t) * ε, where ε ~ N(0, I) if not supplied internally by subVP_SDE.
        4. Returns (z_t, ε, σ(t)), where σ(t) has shape (B,).

        Notes:
        - Deterministic given z0, t, and a fixed epsilon.
        - Useful for reproducible corruption by reusing returned epsilon.
        """
        if isinstance(t, torch.Tensor):
            t_tensor = t

            if not closed_formula:
                raise ValueError("Training requires closed_formula=True when t is a tensor.")
        
        
        else: 
            if final:
                t_val = 1.0 - float(eps)
            else:
                t_val = 0.5 if t is None else float(t)


        # Building the SDE on the same device of the latent vector
        # sde = subVP_SDE(beta_min=sde_cfg.beta_min, beta_max=sde_cfg.beta_max, N=sde_cfg.N)

        if closed_formula:
            #t_tensor = torch.full((z0.size(0),), t_tensor, device=z0.device, dtype=z0.dtype)
            z_t, epsilon, std = self.sde_model.perturb_closed(z0, t_tensor)
        else:
            # t_tensor = torch.tensor([t_val], device=z0.device, dtype=z0.dtype)
            z_t, epsilon, std = self.sde_model.perturb_simulate_path(z0, t_val, steps=steps, seed=seed)        
        
        
        return z_t, epsilon, std

    @torch.no_grad()
    def main():
        cfg = ForwardConfig()
        z0 = torch.load(cfg.input_path, map_location="cpu")
    
        sde_cfg = ForwardProcess(cfg.beta_min,cfg.beta_max, cfg.N)
    
        z_t, epsilon, std = get_noised_latents(z0, t = cfg.t, final = cfg.final, eps = cfg.eps, closed_formula = cfg.closed_formula, steps = cfg.N, seed = cfg.seed, sde_cfg = sde_cfg)
    
        torch.save(z_t, cfg.output_path)

In [38]:
def show_tensor_image(tensor_img, title=""):
    """
    Converte un tensore (C, H, W) nel range [0, 1] in un'immagine numpy e la mostra.
    """
    if tensor_img.device.type != 'cpu':
        tensor_img = tensor_img.cpu()
        
    # Se il tensore è un batch (B, C, H, W), prendiamo il primo elemento
    if tensor_img.ndim == 4:
        tensor_img = tensor_img[0]
        
    # Clampa per sicurezza e converte in NumPy (H, W, C)
    img_np = tensor_img.clamp(0, 1).permute(1, 2, 0).numpy()
    
    plt.figure(figsize=(4, 4))
    plt.imshow(img_np)
    plt.title(title)
    plt.axis('off')
    plt.show()

In [39]:
# # --- 4. FUNZIONE DI VISUALIZZAZIONE E SETUP ---
# def setup():
#     """Setup del modello e loop di training."""
    
#     print(f"Inizializzazione su {DEVICE}...")

#     # A. Data Loading
#     try:
#         dataset = LatentDataset(data_dir=DATA_FOLDER, image_size=IMAGE_SIZE)
#         dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
#         print(f"Dataset caricato: {len(dataset)} immagini.")
#     except Exception as e:
#         print(f"ERRORE: Impossibile trovare o caricare i dati in {DATA_FOLDER}. Verifica il path. {e}")
#         return

#     # Visualizzazione e Test Trasformazione
#     example_img = dataset[0]
#     print(f"Immagine originale (pixel) shape: {example_img.shape}")

#     show_tensor_image(example_img, title="1. Immagine Originale (Pixels)")

#     # B. Model Setup
#     model = UNet(in_channels=LATENT_CHANNELS, out_channels=LATENT_CHANNELS, features=FEATURES).to(DEVICE)
#     diffusion = GaussianDiffusion(timesteps=TIMESTEPS)
#     optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
#     criterion = nn.MSELoss() # La loss standard per DDPM è MSELoss tra rumore predetto e rumore reale

#     print("--- Inizio Test VAE e UNet ---")
    
#     # 1. Encoding
#     dummy_pixels = example_img.unsqueeze(0).to(DEVICE)
#     latents = encode_to_latent(dummy_pixels)
#     print(f"Latent (encoded) shape: {latents.shape}") # Dovrebbe essere (1, 4, 16, 16) per 128x128
    
#     # 2. Diffusion Forward Test (t=500)
#     test_t = torch.tensor([500]).to(DEVICE)
#     noisy_latents, true_noise = diffusion.forward_process(latents, test_t)
#     print(f"Latent rumoroso (t=500) shape: {noisy_latents.shape}")
#     show_tensor_image(noisy_latents, title=f"3. Immagine Rumorosa (t={test_t.item()})")

#     # 3. UNet Prediction Test
#     predicted_noise = model(noisy_latents, test_t)
#     print(f"Rumore predetto (output UNet) shape: {predicted_noise.shape}")
#     assert predicted_noise.shape == true_noise.shape, "Shape mismatch tra rumore predetto e reale!"

#     print("Test UNet OK. Inizio Training.")
    
#     print("-" * 30)

#     return model, diffusion, dataloader, optimizer, criterion



# model, diffusion, dataloader, optimizer, criterion = setup()
    
    


In [40]:
# from torch.utils.data import Subset #
# NUM_SAMPLES_FOR_SANITY_CHECK = 5000 

# def small_dataset():
    
#     # A. Data Loading
#     try:
#         # Inizializza il Dataset Completo (con la logica di caricamento file/classi esistente)
#         full_dataset = LatentDataset(data_dir=DATA_FOLDER, image_size=IMAGE_SIZE)
        
        
#         if len(full_dataset) > NUM_SAMPLES_FOR_SANITY_CHECK:
#             # 1. Genera N indici casuali (da 0 alla dimensione totale del dataset)
#             indices = torch.randperm(len(full_dataset))[:NUM_SAMPLES_FOR_SANITY_CHECK].tolist()
            
#             # 2. Crea il Subset utilizzando solo questi indici
#             subset_dataset = Subset(full_dataset, indices)
            
#             print(f"Dataset completo caricato ({len(full_dataset)} immagini).")
#             print(f"Utilizzo Subset di {len(subset_dataset)} immagini per il Sanity Check.")
#         else:
#             subset_dataset = full_dataset
#             print("Dataset troppo piccolo. Utilizzo tutte le immagini.")

#         # 3. Passa il Subset (o il full_dataset) al DataLoader
#         dataloader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)
        
#         return dataloader
        
#     except Exception as e:
#         print(f"ERRORE: Impossibile trovare o caricare i dati in {DATA_FOLDER}. Verifica il path. {e}")
#         return

# small_dataloader = small_dataset()



In [41]:
def calculate_importance_sampling_probabilities(sde_model, N_timesteps, device):
    """
    Calcola il tensore di probabilità p_IS per l'Importance Sampling.
    p(t) ∝ g(t)^2 / λ_orig(t)
    """
    T_max = 1.0
    epsilon = 1e-8 # Per stabilità numerica (evitare divisioni per zero)
    
    # 1. Crea il vettore di timestep continui da [eps, 1.0]
    timesteps = torch.linspace(epsilon, T_max, N_timesteps, device=device)
    
    # 2. Calcola i pesi necessari (g(t)^2 e λ_orig(t))
    g_squared = sde_model.get_g_squared(timesteps)
    lambda_original = sde_model.get_lambda_original(timesteps)
    
    # 3. Calcola il peso non normalizzato p(t) ∝ g(t)^2 / λ_orig(t)
    # Aggiungiamo epsilon al denominatore per sicurezza.
    sampling_weights = g_squared / (lambda_original + epsilon)
    
    # 4. Normalizza per ottenere la distribuzione di probabilità (somma = 1)
    probabilities = sampling_weights / torch.sum(sampling_weights)
    
    return probabilities

In [53]:
# --- NUOVE COSTANTI GLOBALI NECESSARIE ---
def setup():
    """Setup del modello e preparazione dei data loader di training e validazione."""
    
    print(f"Inizializzazione su {DEVICE}...")

    # A. Data Loading & Splitting
    try:
        # 1. Carica il Dataset Completo
        full_dataset = LatentDataset(data_dir=DATA_FOLDER, image_size=IMAGE_SIZE)
        
        # 2. Definisci le dimensioni dello split
        val_size = int(VALIDATION_SPLIT_RATIO * len(full_dataset))
        train_size = len(full_dataset) - val_size
        
        # 3. SPLIT DETERMINISTICO (Usiamo un seed per la riproducibilità!)
        torch.manual_seed(42)
        train_dataset, val_dataset = random_split(
            full_dataset, [train_size, val_size]
        )
        
        # 4. Crea i DataLoader separati
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

        print(f"Dataset caricato: Totale {len(full_dataset)} immagini.")
        print(f" -> Train Loader: {len(train_dataset)} immagini ({len(train_loader)} batches)")
        print(f" -> Validation Loader: {len(val_dataset)} immagini ({len(val_loader)} batches)")

    except Exception as e:
        print(f"ERRORE: Impossibile trovare o caricare i dati in {DATA_FOLDER}. Verifica il path. {e}")
        return

    # B. Model Setup
    model = UNet(in_channels=LATENT_CHANNELS, out_channels=LATENT_CHANNELS, features=FEATURES).to(DEVICE)
    forward_process = ForwardProcess(beta_min=0.1, beta_max=20.0, N=1000)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.MSELoss() 

    importance_sampling_probabilities = calculate_importance_sampling_probabilities(
    forward_process.sde_model, 
    forward_process.N, 
    DEVICE
)

    # 5. Restituisce il modello, la diffusione e I DUE DATALOADER
    return model, train_loader, val_loader, optimizer, forward_process, criterion, importance_sampling_probabilities

model, train_loader, val_loader, optimizer, forward_process, criterion, importance_sampling_probabilities = setup()

Inizializzazione su cuda...
Dataset caricato: Totale 5000 immagini.
 -> Train Loader: 4000 immagini (125 batches)
 -> Validation Loader: 1000 immagini (32 batches)


In [54]:
EPOCHS = 5 
def train(model, train_loader, val_loader, optimizer, forward_process, criterion, vae_scale_factor = 0.18215, importance_sampling_probabilities =  None, device = "cuda"):
    """
    Esegue il loop di training DDPM, salvando i pesi e i metadati.
    Include la validazione inline (senza funzione esterna).
    """

    # Crea la cartella di salvataggio se non esiste
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    
    # Inizializza la cronologia delle loss PRIMA del ciclo delle Epoche
    loss_history = {'train_loss': [], 'val_loss': []}

    for epoch in range(EPOCHS):
        
        # 1. --- FASE DI TRAINING (Modalità: model.train()) ---
        model.train() 
        total_train_loss = 0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} (TRAIN)")
        
        for step, batch in enumerate(train_bar):
            optimizer.zero_grad()
            
            # Logica Forward/Backward
            x_start_pixels = batch.to(device)
            x_start_latents = encode_to_latent(x_start_pixels) * vae_scale_factor
            batch_size = x_start_latents.shape[0]

            if importance_sampling_probabilities is not None:
                indices = torch.multinomial(importance_sampling_probabilities, num_samples=batch_size, replacement=True)
                t_float_values = indices.float() / forward_process.N 
                t = t_float_values.to(device)
            else: 
                t = torch.rand(x_start_latents.shape[0], device=device)


            true_noise = torch.randn_like(x_start_latents)

            x_t, _, std_t = forward_process.get_noised_latents(x_start_latents, t = t, eps = true_noise)           
            
            #x_t, true_noise = diffusion.forward_process(x_start_latents, t, noise)
            
            predicted_noise = model(x_t, t)


            # Not weighted loss per sample
            per_sample_loss = criterion(predicted_noise, true_noise)

            # Weighting factor λ(t)
            weighting_factor = forward_process.sde_model.get_g_squared(t)[:, None, None, None]
            
            # Total weighted loss
            weighted_loss = per_sample_loss * weighting_factor
            
            loss = torch.mean(weighted_loss)

            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            avg_train_loss = total_train_loss / (step + 1)
            train_bar.set_postfix(loss=f'{avg_train_loss:.4f}') # Aggiorna la barra TQDM
            
        final_avg_train_loss = total_train_loss / len(train_loader) 
            
        # 2. --- FASE DI VALIDAZIONE (Dopo il ciclo di training) ---
        
        # Disattiva il calcolo del gradiente per la valutazione
        with torch.no_grad():
            model.eval() # Imposta il modello in modalità valutazione (disattiva dropout/batchnorm)
            total_val_loss = 0

            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [VAL]")

            for val_batch in val_bar:
                x_start_pixels = val_batch.to(device)
                x_start_latents = encode_to_latent(x_start_pixels) * vae_scale_factor


                batch_size = x_start_latents.shape[0]

                if importance_sampling_probabilities is not None:
                    indices = torch.multinomial(importance_sampling_probabilities, num_samples=batch_size, replacement=True)
                    t_float_values = indices.float() / forward_process.N 
                    t = t_float_values.to(device)
                else: 
                    t = torch.rand(x_start_latents.shape[0], device=device)


                true_noise = torch.randn_like(x_start_latents)


                x_t, _, std_t = forward_process.get_noised_latents(x_start_latents, t = t, eps = true_noise)           
                
                
                                
                predicted_noise = model(x_t, t)

                # Not weighted loss per sample
                per_sample_loss = criterion(predicted_noise, true_noise)

                # Weighting factor λ(t)
                weighting_factor = forward_process.sde_model.get_g_squared(t)[:, None, None, None]
                
                # Total weighted loss
                weighted_loss = per_sample_loss * weighting_factor
            
                loss = torch.mean(weighted_loss)

                
                total_val_loss += loss.item()

                avg_val_loss = total_val_loss / (val_bar.n + 1) 
                val_bar.set_postfix(loss=f'{avg_val_loss:.4f}')

            avg_val_loss = total_val_loss / len(val_loader)
            
        model.train() # Riporta il modello in modalità training per l'epoca successiva
        
        # --- LOGGING E SALVATAGGIO ---
        
        loss_history['train_loss'].append(final_avg_train_loss)
        loss_history['val_loss'].append(avg_val_loss)

        
        
        # Checkpoint: Salva lo stato completo del modello e la history
        # checkpoint_data = {
        #      'epoch': epoch + 1,
        #      'model_state_dict': model.state_dict(),
        #      'optimizer_state_dict': optimizer.state_dict(),
        #      'avg_val_loss': avg_val_loss,
        #      'loss_history': loss_history, 
        # }
        
        # checkpoint_path = os.path.join(
        #     SAVE_FOLDER, 
        #     f"{MODEL_NAME}_epoch_{epoch+1:03d}_val_{avg_val_loss:.4f}.pth"
        # )
        # torch.save(checkpoint_data, checkpoint_path)
        # print(f"Checkpoint salvato in: {checkpoint_path}")

    # --- LOGICA DI SALVATAGGIO FINALE (Fuori dal ciclo epoch) ---
    print("Training completato!")
    lr_str = str(LEARNING_RATE).replace('.', '') 
    hyper_suffix = f"T{TIMESTEPS}_LR{lr_str}_E{EPOCHS}"
    final_model_filename = f"{MODEL_NAME}_final_{hyper_suffix}.pth"
    final_model_path = os.path.join(SAVE_FOLDER, final_model_filename)

    torch.save({'state_dict': model.state_dict()}, final_model_path)
    
    print(f"Pesi e cronologia finali salvati in: {final_model_path}")
    return loss_history

loss_history = train(model, train_loader=train_loader, val_loader=val_loader, optimizer=optimizer, forward_process=forward_process, criterion=criterion, importance_sampling_probabilities=importance_sampling_probabilities, device=DEVICE)

Epoch 1/5 (TRAIN): 100%|██████████| 125/125 [00:48<00:00,  2.59it/s, loss=5.3009]
Epoch 1/5 [VAL]: 100%|██████████| 32/32 [00:11<00:00,  2.71it/s, loss=5.2453]
Epoch 2/5 (TRAIN): 100%|██████████| 125/125 [00:44<00:00,  2.81it/s, loss=5.1260]
Epoch 2/5 [VAL]: 100%|██████████| 32/32 [00:10<00:00,  3.09it/s, loss=5.0363]
Epoch 3/5 (TRAIN): 100%|██████████| 125/125 [00:45<00:00,  2.74it/s, loss=5.3119]
Epoch 3/5 [VAL]: 100%|██████████| 32/32 [00:10<00:00,  3.13it/s, loss=4.7602]
Epoch 4/5 (TRAIN): 100%|██████████| 125/125 [00:44<00:00,  2.83it/s, loss=4.9173]
Epoch 4/5 [VAL]: 100%|██████████| 32/32 [00:10<00:00,  3.07it/s, loss=5.3259]
Epoch 5/5 (TRAIN): 100%|██████████| 125/125 [00:45<00:00,  2.76it/s, loss=5.2827]
Epoch 5/5 [VAL]: 100%|██████████| 32/32 [00:10<00:00,  3.10it/s, loss=4.9270]

Training completato!
Pesi e cronologia finali salvati in: ../weights/ldm_unet_final_T1000_LR00001_E5.pth





In [None]:

def plot_loss_history(history: dict):
    """
    Visualizza l'andamento delle loss di training e validazione.

    Args:
        history (dict): Dizionario contenente le liste 'train_loss' e 'val_loss'.
    """
    
    # Verifica che le chiavi necessarie siano presenti
    if 'train_loss' not in history or 'val_loss' not in history:
        print("Errore: Il dizionario history deve contenere le chiavi 'train_loss' e 'val_loss'.")
        return

    train_losses = history['train_loss']
    val_losses = history['val_loss']
    
    # Genera l'asse X (numero di epoche)
    epochs = range(1, len(train_losses) + 1) 

    plt.figure(figsize=(10, 6))
    
    # Plotta la Training Loss
    plt.plot(epochs, train_losses, 'b-o', label='Training Loss')
    
    # Plotta la Validation Loss
    plt.plot(epochs, val_losses, 'r-s', label='Validation Loss')
    
    # Aggiunge i dettagli al grafico
    plt.title('Andamento della Loss (DDPM)')
    plt.xlabel('Epoca')
    plt.ylabel('Loss Media (MSE)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    
    # Forza i tick dell'asse X a essere interi (utile se hai poche epoche)
    plt.xticks(epochs) 
    
    plt.show()

# --- Esempio di Utilizzo ---

# 1. Esegui il training e ottieni la history
# loss_history = train(model, diffusion, train_loader, val_loader, optimizer, criterion)
# print("Training completato. Generazione del grafico...")

# 2. Chiama la funzione di plot
# plot_loss_history(loss_history)