In [1]:
# !wget -O faces_dataset.tar.gz "https://www.dropbox.com/scl/fi/7dv71y3nxrcdrpmwntr8e/faces_aligned_small_mirrored_co_aligned_cropped_cleaned.tar.gz?rlkey=h03r92h1mdr9yet2tkqosqq1k&dl=1"

In [None]:
# import tarfile

# with tarfile.open("faces_dataset.tar.gz", "r:gz") as tar:
#     tar.extractall("faces_dataset")

In [None]:
# import os
# file_names = os.listdir("/kaggle/working/faces_dataset/faces_aligned_small_mirrored_co_aligned_cropped_cleaned/M")
# len(file_names)

17673

In [None]:
import subprocess
import sys
from pathlib import Path

# Fetch the helper code that provides `gda_functions` and make it importable.
# The clone target is resolved against the current working directory (which is
# where a bare `git clone` would place it) instead of a hard-coded absolute
# path, so the cell also works outside of /kaggle/working.
GDA_REPO_URL = "https://github.com/seyedsaberi/DDPM-for-GDA.git"
GDA_REPO_DIR = Path.cwd() / "DDPM-for-GDA"

# Skip the clone when the checkout already exists so the cell can be re-run
# without git failing on "destination path already exists".
if not GDA_REPO_DIR.exists():
    clone = subprocess.run(
        ["git", "clone", GDA_REPO_URL, str(GDA_REPO_DIR)],
        capture_output=True,
        text=True,
    )
    # Echo git's output into the cell, then fail loudly if the clone did not work
    # (otherwise the failure would only surface later as an ImportError).
    print(clone.stdout + clone.stderr, end="")
    clone.check_returncode()

# Avoid piling up duplicate sys.path entries on repeated execution.
if str(GDA_REPO_DIR) not in sys.path:
    sys.path.append(str(GDA_REPO_DIR))

In [None]:
import torch
from gda_functions import (
    Config, set_seed_all, HalfSpaceDenoiser, NoisedMixtureStream, 
    pretrain_w3, train_diffusion_simple
)

# Build the run configuration and select the compute device.
cfg = Config()

use_gpu = torch.cuda.is_available()
if use_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
if use_gpu:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    # CPU-only run: cap the step count and the batch size.
    cfg.steps = min(cfg.steps, 500)
    cfg.batch_size = min(cfg.batch_size, 2048)

# Apply cfg.seed before the stream and the model are created further down.
set_seed_all(cfg.seed)

# Step 1: fit w3 / b3 so that they predict the mixture index i from x_t.
print(
    "\n" + "="*60,
    "STEP 1: Pre-training w3 and b3 to predict mixture index i",
    "="*60,
    sep="\n",
)

# The same stream object is reused by the diffusion training stage below.
stream = NoisedMixtureStream(cfg, device=device)
pretrain_result = pretrain_w3(cfg, stream, pretrain_steps=500, lr=1e-2)
pretrained_w3, pretrained_b3 = pretrain_result

# Report the pre-trained direction, its L2 norm, and the bias.
print(f"\nPre-trained w3: {pretrained_w3.cpu()}")
print(f"||w3||: {pretrained_w3.norm(p=2).item():.4f}")
print(f"Pre-trained b3: {pretrained_b3.item():.4f}")

# Step 2: train the diffusion model, handing it the pre-trained w3 / b3.
print(
    "\n" + "="*60,
    "STEP 2: Training diffusion model with pre-trained w3",
    "="*60,
    sep="\n",
)

# Instantiate the denoiser from the values held in the config.
denoiser_kwargs = dict(
    d=cfg.d,
    use_extended=cfg.use_extended,
    gate_slope_k=cfg.gate_slope_k,
)
model = HalfSpaceDenoiser(**denoiser_kwargs)

# Run the simple training loop; `model` is rebound to whatever it returns.
training_kwargs = dict(
    cfg=cfg,
    stream=stream,
    model=model,
    pretrained_w3=pretrained_w3,
    pretrained_b3=pretrained_b3,
    steps=cfg.steps,
    lr=cfg.lr,
    log_every=50,
)
model = train_diffusion_simple(**training_kwargs)

# Extract and show learned parameters
print("\n" + "="*60)
print("LEARNED PARAMETERS")
print("="*60)
# Detached CPU copies of every state_dict entry, so printing does not touch
# the live model tensors.
state = {k: v.detach().cpu() for k, v in model.state_dict().items()}

# Entries are looked up with .get(), so a key missing from the state_dict
# prints as None instead of raising.
print("w0:", state.get("w0"))
print("w1:", state.get("w1"))
if cfg.use_extended:
    learned_w3 = state.get("w3")
    learned_b3 = state.get("b3")
    print("w2:", state.get("w2"))
    print("w3:", learned_w3)
    # Guard the norm the same way the b3 value is guarded: without this, a
    # missing "w3" entry would abort the report with an AttributeError.
    if learned_w3 is not None:
        print(f"||w3||: {learned_w3.norm(p=2).item():.4f}")
    print("b3:", learned_b3)
    if learned_b3 is not None:
        print(f"b3 value: {learned_b3.item():.4f}")

# Parameters of the data family held by the stream, for comparison with the
# learned weights above.
print("\n" + "="*60)
print("DATA FAMILY PARAMETERS")
print("="*60)
print("mu0:", stream.mu0.detach().cpu())
print("u:", stream.u.detach().cpu())
print(f"||mu0||: {stream.mu0.norm(p=2).item():.4f}")
print(f"||u||: {stream.u.norm(p=2).item():.4f}")


ModuleNotFoundError: No module named 'torch'

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from gda_functions import NoisedMixtureStream

# Re-seed and rebuild the stream so that mu0 and u match the earlier run.
set_seed_all(cfg.seed)
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
stream = NoisedMixtureStream(cfg, device=device)

# NumPy copies of the two vectors that define the family.
mu0_np = stream.mu0.detach().cpu().numpy()
u_np = stream.u.detach().cpu().numpy()

print(f"mu0 = {mu0_np}")
print(f"u = {u_np}")
print(f"||mu0|| = {np.linalg.norm(mu0_np):.4f}")
print(f"||u|| = {np.linalg.norm(u_np):.4f}")

# The picture is only drawn for two-dimensional vectors.
if len(mu0_np) != 2:
    print(f"\nVisualization skipped (d={len(mu0_np)} != 2)")
else:
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Origin marker
    ax.scatter([0], [0], c='black', s=100, marker='o', label='Origin', zorder=5)

    # Arrows for mu0 and u, both anchored at the origin and drawn in data units
    arrow_style = dict(angles='xy', scale_units='xy', scale=1, width=0.01, zorder=4)
    ax.quiver(0, 0, mu0_np[0], mu0_np[1], color='blue', label=r'$\mu_0$', **arrow_style)
    ax.quiver(0, 0, u_np[0], u_np[1], color='red', label=r'$u$', **arrow_style)

    # Mixture centers mu_i = mu0 + i*u for i = 0..N; only the first and the
    # last one get a text label.
    N = cfg.N
    for i in range(N + 1):
        mu_i = mu0_np + i * u_np
        ax.scatter(mu_i[0], mu_i[1], c='green', s=50, alpha=0.6, zorder=3)
        if i in (0, N):
            ax.text(mu_i[0], mu_i[1], f'  i={i}', fontsize=9, va='bottom')

    ax.set_xlabel('Dimension 1', fontsize=12)
    ax.set_ylabel('Dimension 2', fontsize=12)
    ax.set_title(r'Data Family: $\mu_i = \mu_0 + i \cdot u$ for $i=0,\ldots,N$', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.axis('equal')
    plt.tight_layout()
    plt.show()



In [None]:
import matplotlib.pyplot as plt
import numpy as np
from gda_functions import sample_from_diffusion_with_i, sample_true_distribution

# Choose which mixture index to generate
target_i = 8  # You can change this to any value from 0 to cfg.N
# Fail early on an index outside the documented 0..cfg.N range instead of
# sampling and plotting for a component that is not part of the family.
if not 0 <= target_i <= cfg.N:
    raise ValueError(f"target_i must be between 0 and cfg.N={cfg.N}, got {target_i}")


def _plot_sample_panel(ax, clouds, centre, index, title):
    """Draw one 2-D scatter panel with the +/- centre reference markers.

    Args:
        ax: Matplotlib axes to draw on.
        clouds: Iterable of ``(points, colour, label)`` tuples; ``points`` is
            indexed as ``points[:, 0]`` / ``points[:, 1]``, and ``label`` may
            be ``None`` to keep that cloud out of the legend.
        centre: Two-element vector; markers are placed at ``centre`` and
            ``-centre``.
        index: Mixture index shown in the legend entry of the markers.
        title: Panel title.
    """
    for points, colour, label in clouds:
        ax.scatter(points[:, 0], points[:, 1], alpha=0.3, s=10, c=colour, label=label)
    ax.scatter([centre[0], -centre[0]], [centre[1], -centre[1]],
               c='red', s=100, marker='x', linewidths=3,
               label=f'True centers (±μ_{index})')
    ax.set_xlabel('Dimension 1', fontsize=12)
    ax.set_ylabel('Dimension 2', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axis('equal')


print(f"Generating samples for mixture index i={target_i}")
print("="*60)

# Generate samples from the trained diffusion model
num_samples = 1000

# Model is already on the correct device
model.eval()

print(f"Sampling {num_samples} samples from diffusion model...")
generated_samples = sample_from_diffusion_with_i(
    model,
    cfg,
    target_i,
    pretrained_w3,
    num_samples,
    device
)

print(f"Sampling {num_samples} samples from true distribution...")
true_samples = sample_true_distribution(stream, target_i, num_samples)

# Convert to numpy for plotting
gen_np = generated_samples.detach().cpu().numpy()
true_np = true_samples.detach().cpu().numpy()

print(f"\nGenerated samples mean: {gen_np.mean(axis=0)}")
print(f"True samples mean: {true_np.mean(axis=0)}")
print(f"\nGenerated samples std: {gen_np.std(axis=0)}")
print(f"True samples std: {true_np.std(axis=0)}")

# Compute mu_i for reference
mu_i = (stream.mu0 + target_i * stream.u).detach().cpu().numpy()
print(f"\nTrue mu_{target_i}: {mu_i}")

# If d=2, plot the distributions
if cfg.d == 2:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Generated only, true only, then both overlaid. Only the overlay labels
    # the individual clouds, because that is where they need telling apart.
    _plot_sample_panel(axes[0], [(gen_np, 'blue', None)],
                       mu_i, target_i, f'Generated Samples (i={target_i})')
    _plot_sample_panel(axes[1], [(true_np, 'green', None)],
                       mu_i, target_i, f'True Samples (i={target_i})')
    _plot_sample_panel(axes[2],
                       [(gen_np, 'blue', 'Generated'), (true_np, 'green', 'True')],
                       mu_i, target_i, f'Overlay (i={target_i})')

    plt.tight_layout()
    plt.show()

    # Plot histograms for each dimension
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for dim in range(2):
        # Axis labels and titles are 1-based so that they agree with the
        # 'Dimension 1' / 'Dimension 2' labels of the scatter plots; the
        # legend entries keep the 0-based array index of mu_i.
        axes[dim].hist(gen_np[:, dim], bins=50, alpha=0.5, label='Generated', color='blue', density=True)
        axes[dim].hist(true_np[:, dim], bins=50, alpha=0.5, label='True', color='green', density=True)
        axes[dim].axvline(mu_i[dim], color='red', linestyle='--', linewidth=2, label=f'μ_{target_i}[{dim}]')
        axes[dim].axvline(-mu_i[dim], color='orange', linestyle='--', linewidth=2, label=f'-μ_{target_i}[{dim}]')
        axes[dim].set_xlabel(f'Dimension {dim + 1}', fontsize=12)
        axes[dim].set_ylabel('Density', fontsize=12)
        axes[dim].set_title(f'Distribution along Dimension {dim + 1}', fontsize=14)
        axes[dim].legend()
        axes[dim].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print(f"\nVisualization skipped (d={cfg.d} != 2)")
    print("For high-dimensional data, consider using dimensionality reduction (PCA, t-SNE, etc.)")

