In [1]:
# %%

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

# %%
# ================== Experiment Configuration ==================
# Name of this run; it determines the output directory below.
EXPERIMENT_NAME = "exp6_cfg"  # Change this for each experiment
# Every artifact this file writes (config, logs, figures, weights) goes here.
RESULTS_DIR = f"results/{EXPERIMENT_NAME}"
# exist_ok=True lets this be re-run without failing on an existing directory.
os.makedirs(RESULTS_DIR, exist_ok=True)

# %%
# ================== Configuration Parameters ==================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 32  # CIFAR-10 is 32x32
channels = 3  # CIFAR-10 is RGB
batch_size = 128
# Number of real class labels. The model allocates one extra embedding row
# (index == num_classes) as the "null" label for unconditional predictions.
num_classes = 10
# The model's state_dict is saved to this path.
model_save_path = os.path.join(RESULTS_DIR, 'FMmodel.pth')

# %%
# Learning rate handed to the optimizer
lr = 5e-4
# Number of training epochs
epochs = 100
# Channel width of the first UNet stage; deeper stages use 2x, 4x and 8x this value
BASE_CHANNELS = 64
# Output width of the learned time embedding
TIME_EMBED_DIM = 64
# Solver name passed as `method` to torchdiffeq's odeint when sampling
ODE_METHOD = 'dopri5'  # adaptive Runge-Kutta
# Only used when ODE_METHOD is 'euler', 'rk4' or 'midpoint': the time grid then
# has ODE_STEPS + 1 evenly spaced points in [0, 1]. For any other method only
# the endpoints t=0 and t=1 are passed to the solver.
ODE_STEPS = 50
# Loss Function: b/w u^target and u^theta ('mse', 'l1' or 'huber'; anything else -> MSE)
LOSS_TYPE = 'mse'  # Mean Squared Error
# Normalization ('groupnorm', 'batchnorm' or 'layernorm'; anything else -> group norm)
NORM_TYPE = 'groupnorm'
# Activation Function ('silu' or 'gelu'; anything else -> SiLU)
ACTIVATION = 'silu'  # SiLU/Swish
# Optimizer ('adam', 'adamw' or 'sgd'; anything else -> Adam)
OPTIMIZER_TYPE = 'adamw'  # AdamW

# ================== Classifier-Free Guidance Parameters ==================
CFG_DROPOUT_PROB = 0.1  # Probability of dropping conditioning during training
# Default guidance scale for generation (higher = more conditioning influence).
# A scale of exactly 1.0 skips the unconditional pass and uses the conditional
# prediction alone.
CFG_GUIDANCE_SCALE = 2.0  # Guidance scale for generation (higher = more conditioning influence)

# Save experiment config
# Snapshot of the settings above, written next to the results so a run can be
# identified later. `device` is stringified because torch.device is not JSON
# serializable.
config = {
    'experiment_name': EXPERIMENT_NAME,
    'lr': lr,
    'epochs': epochs,
    'batch_size': batch_size,
    'base_channels': BASE_CHANNELS,
    'time_embed_dim': TIME_EMBED_DIM,
    'ode_method': ODE_METHOD,
    'ode_steps': ODE_STEPS,
    'loss_type': LOSS_TYPE,
    'norm_type': NORM_TYPE,
    'activation': ACTIVATION,
    'optimizer_type': OPTIMIZER_TYPE,
    'image_size': image_size,
    'channels': channels,
    'device': str(device),
    'cfg_dropout_prob': CFG_DROPOUT_PROB,
    'cfg_guidance_scale': CFG_GUIDANCE_SCALE
}

# Overwrites any config.json left by a previous run with the same experiment name.
with open(os.path.join(RESULTS_DIR, 'config.json'), 'w') as f:
    json.dump(config, f, indent=4)

# %%
# ================== Data Loading ==================
def normalize_img(x):
    """Rescale image values with the affine map ``2 * x - 1``.

    Inputs in [0, 1] therefore land in [-1, 1].

    Args:
        x: Image values (anything supporting ``2 * x`` and ``- 1``).

    Returns:
        The rescaled values, same type as the arithmetic on ``x`` produces.
    """
    doubled = 2 * x
    return doubled - 1

# Per-image preprocessing pipeline: convert the image to a tensor, then
# rescale it with normalize_img (2 * x - 1).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_img)
])

# ================== Helper Functions ==================
def get_activation():
    """Select the functional activation named by ``ACTIVATION``.

    Returns:
        ``F.gelu`` when ``ACTIVATION == 'gelu'``; ``F.silu`` for ``'silu'``
        and for any unrecognized value.
    """
    # SiLU is both an explicit option and the fallback, so only GELU needs
    # its own check.
    return F.gelu if ACTIVATION == 'gelu' else F.silu

def get_norm_layer(num_channels):
    """Build the normalization layer selected by ``NORM_TYPE``.

    Args:
        num_channels: Number of feature channels the layer normalizes.

    Returns:
        * ``'batchnorm'`` -> ``nn.BatchNorm2d(num_channels)``
        * ``'layernorm'`` -> ``nn.GroupNorm(1, num_channels)`` (a single group
          spanning all channels)
        * ``'groupnorm'`` or any unrecognized value -> ``nn.GroupNorm`` with
          the largest group count <= 32 that divides ``num_channels``.

    Note:
        ``nn.GroupNorm`` requires ``num_channels`` to be divisible by the
        group count. Using ``min(32, num_channels)`` directly fails for
        widths such as 48 or 80, so the count is lowered until it divides
        evenly. Widths that already worked (<= 32, or multiples of 32) get
        exactly the same group count as before.
    """
    if NORM_TYPE == 'batchnorm':
        return nn.BatchNorm2d(num_channels)
    if NORM_TYPE == 'layernorm':
        return nn.GroupNorm(1, num_channels)

    # 'groupnorm' and the fallback share this path.
    num_groups = min(32, num_channels)
    # Step down to the nearest divisor; 1 always divides, so this terminates.
    while num_groups > 1 and num_channels % num_groups:
        num_groups -= 1
    return nn.GroupNorm(num_groups, num_channels)

# %%
# ================== Model Architecture ==================
class ConditionedDoubleConv(nn.Module):
    """Two conv -> norm -> activation stages with the condition injected between them.

    The first 3x3 convolution sees only the input features. Its activated
    output is concatenated channel-wise with the spatially broadcast
    condition, and the second 3x3 convolution maps the result back to
    ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        cond_dim: Number of condition channels appended before the second
            convolution.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.norm1 = get_norm_layer(out_channels)
        # The second conv consumes the features plus the condition channels.
        self.conv2 = nn.Conv2d(out_channels + cond_dim, out_channels, **conv_kwargs)
        self.norm2 = get_norm_layer(out_channels)
        self.activation = get_activation()

    def forward(self, x, cond):
        """Apply both stages.

        ``cond`` is expanded over the spatial dimensions of the feature map,
        so it is expected to have singleton height and width.
        """
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.activation(hidden)

        # -1 keeps the batch and channel sizes; only height/width are tiled.
        height, width = hidden.size(2), hidden.size(3)
        cond_map = cond.expand(-1, -1, height, width)
        hidden = torch.cat([hidden, cond_map], dim=1)

        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)
        return self.activation(hidden)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ConditionedDoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution block.
        cond_dim: Number of condition channels passed to the convolution block.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x, cond):
        """Pool ``x``, then run the conditioned convolutions on the result."""
        pooled = self.maxpool(x)
        features = self.conv(pooled, cond)
        return features

class Up(nn.Module):
    """Decoder stage: bilinear 2x upsampling, skip concatenation, conditioned convs.

    Args:
        in_channels: Channels of the concatenated (skip + upsampled) tensor.
        out_channels: Channels produced by the convolution block.
        cond_dim: Number of condition channels passed to the convolution block.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x1, x2, cond):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, and fuse the two.

        Args:
            x1: Lower-resolution features coming up the decoder.
            x2: Skip-connection features from the encoder.
            cond: Condition tensor forwarded to the convolution block.
        """
        upsampled = self.up(x1)

        # Spatial size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad order is (left, right, top, bottom); any odd remainder goes
        # to the right/bottom edge.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features first, then the upsampled ones.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged, cond)

class ConditionalUNet(nn.Module):
    """UNet conditioned on time and class label, with a null label for CFG.

    The condition vector is the concatenation of a learned time embedding
    (``TIME_EMBED_DIM`` wide) and a label embedding (``label_dim`` wide). The
    label table has ``num_classes + 1`` rows; the last row
    (``null_label_id``) stands for "no label" and is used for unconditional
    predictions in classifier-free guidance.
    """

    def __init__(self):
        super().__init__()
        # Widths of the two halves of the condition vector.
        self.label_dim = 32
        self.cond_dim = TIME_EMBED_DIM + self.label_dim

        # Scalar time -> TIME_EMBED_DIM via a two-layer MLP.
        hidden_dim = TIME_EMBED_DIM * 2
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, TIME_EMBED_DIM),
        )

        # One embedding per class plus one extra row for the null label.
        self.label_embed = nn.Embedding(num_classes + 1, self.label_dim)
        self.null_label_id = num_classes

        # Channel widths at the four resolutions.
        c1, c2, c4, c8 = (BASE_CHANNELS * mult for mult in (1, 2, 4, 8))

        # Encoder.
        self.inc = ConditionedDoubleConv(channels, c1, self.cond_dim)
        self.down1 = Down(c1, c2, self.cond_dim)
        self.down2 = Down(c2, c4, self.cond_dim)
        self.down3 = Down(c4, c8, self.cond_dim)

        # Decoder; each input width is (features from below + skip features).
        self.up1 = Up(c8 + c4, c4, self.cond_dim)
        self.up2 = Up(c4 + c2, c2, self.cond_dim)
        self.up3 = Up(c2 + c1, c1, self.cond_dim)
        self.outc = nn.Conv2d(c1, channels, kernel_size=1)

    def forward(self, x, t, labels):
        """Run the network on ``x`` at time ``t`` under ``labels``.

        Args:
            x: Input images.
            t: Time values; viewed as a column so each entry feeds the time MLP.
            labels: Integer class ids (``null_label_id`` for unconditional).

        Returns:
            A tensor with ``channels`` output channels produced by the final
            1x1 convolution.
        """
        time_vec = self.time_embed(t.view(-1, 1))
        label_vec = self.label_embed(labels)
        cond = torch.cat([time_vec, label_vec], dim=1)
        # Add two trailing singleton dims so blocks can broadcast spatially.
        cond = cond.unsqueeze(-1).unsqueeze(-1)

        # Encoder: keep each resolution's output for the skip connections.
        skip1 = self.inc(x, cond)
        skip2 = self.down1(skip1, cond)
        skip3 = self.down2(skip2, cond)
        out = self.down3(skip3, cond)

        # Decoder: consume the skips from deepest to shallowest.
        for up_block, skip in ((self.up1, skip3), (self.up2, skip2), (self.up3, skip1)):
            out = up_block(out, skip, cond)
        return self.outc(out)

# %%
# ================== Training and Generation ==================
# The single model instance used by training and sampling.
model = ConditionalUNet().to(device)

# Build the optimizer named by OPTIMIZER_TYPE. Plain Adam is both the
# explicit 'adam' choice and the fallback for unrecognized names, so the two
# cases share the final branch.
if OPTIMIZER_TYPE == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
elif OPTIMIZER_TYPE == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def compute_loss(pred, target):
    """Evaluate the loss selected by ``LOSS_TYPE`` between ``pred`` and ``target``.

    Args:
        pred: Model prediction.
        target: Regression target matching ``pred``.

    Returns:
        ``F.l1_loss`` for ``'l1'``, ``F.huber_loss`` for ``'huber'``, and
        ``F.mse_loss`` for ``'mse'`` or any unrecognized value, each called
        with its default arguments.
    """
    if LOSS_TYPE == 'l1':
        return F.l1_loss(pred, target)
    if LOSS_TYPE == 'huber':
        return F.huber_loss(pred, target)
    # 'mse' is both the explicit option and the fallback.
    return F.mse_loss(pred, target)

@torch.no_grad()
def generate_with_label(label, num_samples=16, guidance_scale=CFG_GUIDANCE_SCALE):
    """Generate samples with specified label using Classifier-Free Guidance.

    Starts from Gaussian noise at t=0 and integrates the model's predicted
    velocity field to t=1 with ``odeint``. The module-level ``model`` is put
    in eval mode for the duration of the call, and its previous train/eval
    mode is restored afterwards even if the solver or the model raises.

    Args:
        label: Integer class id to condition on.
        num_samples: Number of images to generate.
        guidance_scale: CFG weight. Exactly ``1.0`` uses the conditional
            prediction alone (one forward pass per evaluation). Any other
            value combines conditional and unconditional predictions as
            ``uncond + scale * (cond - uncond)`` (two forward passes).

    Returns:
        A CPU tensor of shape ``(num_samples, channels, image_size,
        image_size)`` with values clamped and rescaled into [0, 1].
    """
    was_training = model.training
    model.eval()
    try:
        x0 = torch.randn(num_samples, channels, image_size, image_size, device=device)
        labels = torch.full((num_samples,), label, device=device, dtype=torch.long)

        use_guidance = guidance_scale != 1.0
        # The null-label batch is the same for every function evaluation, so
        # build it once instead of on every solver call.
        null_labels = torch.full_like(labels, model.null_label_id) if use_guidance else None

        def ode_func(t: torch.Tensor, x: torch.Tensor):
            # The solver passes a scalar time; the model wants one value per sample.
            t_expanded = t.expand(x.size(0))

            vt_cond = model(x, t_expanded, labels)
            if not use_guidance:
                # No guidance, just use conditional prediction
                return vt_cond

            # Unconditional prediction (using null label)
            vt_uncond = model(x, t_expanded, null_labels)

            # Apply guidance: vt = vt_uncond + guidance_scale * (vt_cond - vt_uncond)
            return vt_uncond + guidance_scale * (vt_cond - vt_uncond)

        # These three methods get an explicit uniform time grid; every other
        # method only gets the two endpoints.
        if ODE_METHOD in ['euler', 'rk4', 'midpoint']:
            t_eval = torch.linspace(0.0, 1.0, ODE_STEPS + 1, device=device)
        else:
            t_eval = torch.tensor([0.0, 1.0], device=device)

        generated = odeint(
            ode_func,
            x0,
            t_eval,
            rtol=1e-5,
            atol=1e-5,
            method=ODE_METHOD
        )
    finally:
        # Always hand the model back in the mode the caller left it in.
        model.train(was_training)

    # generated[-1] is the state at the last time point (t=1); map [-1, 1] -> [0, 1].
    images = (generated[-1].clamp(-1, 1) + 1) / 2
    return images.cpu()

def visualize_train(epoch):
    """Save a grid of generated samples for the given epoch.

    Each column is one class label and each of the 10 rows is one sample
    generated with the default guidance scale. The figure is written to
    ``RESULTS_DIR/epoch{epoch}.jpg``.

    Args:
        epoch: Epoch number, used in the title and the file name.
    """
    print(f"Generating visualization for epoch {epoch}...")
    rows = 10
    plt.figure(figsize=(12, 12))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for label in range(num_classes):
        samples = generate_with_label(label=label, num_samples=10)
        for row in range(rows):
            # Subplot indices are 1-based and row-major; column == label.
            cell = row * num_classes + label + 1
            ax = plt.subplot(rows, num_classes, cell)
            # CHW tensor -> HWC array for imshow.
            ax.imshow(samples[row].permute(1, 2, 0).numpy(), vmin=0, vmax=1)
            ax.axis('off')
            if row == 0:
                ax.set_title(str(label), fontsize=14, pad=5)

    plt.suptitle(f"Generated Samples - Epoch {epoch} (CFG scale={CFG_GUIDANCE_SCALE})",
                 fontsize=18, y=0.98)
    out_path = os.path.join(RESULTS_DIR, f"epoch{epoch}.jpg")
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close()

def generate_final_samples(num_sample=5):
    """Save sample grids for several guidance scales.

    For each scale in ``[1.0, 1.5, 2.0, 3.0, 5.0]`` this writes ``num_sample``
    independent grids (10 rows of samples, one column per class label) to
    ``RESULTS_DIR/generated_grid{k}_cfg{scale}.jpg``.

    Args:
        num_sample: Number of grids to generate per guidance scale.
    """
    guidance_scales = [1.0, 1.5, 2.0, 3.0, 5.0]
    rows = 10

    for scale in guidance_scales:
        print(f"\nGenerating samples with guidance scale = {scale}")
        for k in range(num_sample):
            plt.figure(figsize=(12, 12))
            plt.subplots_adjust(wspace=0.1, hspace=0.1)
            print(f"Generating grid {k + 1}/{num_sample} (scale={scale})...")

            for label in tqdm(range(num_classes), desc=f"Grid {k+1}"):
                samples = generate_with_label(
                    label=label, num_samples=10, guidance_scale=scale)
                for row in range(rows):
                    # Subplot indices are 1-based and row-major; column == label.
                    cell = row * num_classes + label + 1
                    ax = plt.subplot(rows, num_classes, cell)
                    # CHW tensor -> HWC array for imshow.
                    ax.imshow(samples[row].permute(1, 2, 0).numpy(), vmin=0, vmax=1)
                    ax.axis('off')
                    if row == 0:
                        ax.set_title(str(label), fontsize=14, pad=5)

            plt.suptitle(f"Generated Samples (Grid {k + 1}, CFG scale={scale})",
                         fontsize=18, y=0.98)
            save_path = os.path.join(RESULTS_DIR, f"generated_grid{k + 1}_cfg{scale}.jpg")
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Saved: {save_path}")
            plt.close()
# %%
def train(num_epochs=100):
    """Train the module-level ``model`` with label dropout for CFG.

    Per batch: each label is replaced by the model's null label with
    probability ``CFG_DROPOUT_PROB``. A time ``t`` is sampled uniformly in
    [0, 1) and ``xt = (1 - t) * noise + t * images`` is formed. The model is
    then regressed onto the velocity ``images - noise``. Gradients are
    clipped to norm 1.0 before each optimizer step.

    Per epoch: the running loss log is rewritten to ``training_log.json``,
    samples are visualized after epoch 1 and every 10th epoch, and the
    model's state_dict is saved. The loss curve is plotted when training ends.

    Args:
        num_epochs: Number of passes over ``train_loader``.
    """
    print(f"Starting training for experiment: {EXPERIMENT_NAME}")
    print(f"Results will be saved to: {RESULTS_DIR}")
    print(f"Configuration: {config}")
    print(f"CFG Dropout Probability: {CFG_DROPOUT_PROB}")

    log_path = os.path.join(RESULTS_DIR, 'training_log.json')
    training_log = []

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        running_loss = 0
        batches_seen = 0

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            # Classifier-free guidance: pick the samples that lose their label
            # this step and point them at the null label instead.
            drop_mask = torch.rand(labels.size(0), device=device) < CFG_DROPOUT_PROB
            labels_with_dropout = torch.where(
                drop_mask, torch.full_like(labels, model.null_label_id), labels)

            # Straight-line interpolation between noise (t=0) and data (t=1).
            noise = torch.randn_like(images)
            t = torch.rand(images.size(0), device=device)
            t_col = t.view(-1, 1, 1, 1)
            xt = (1 - t_col) * noise + t_col * images

            # The target velocity along that line is constant: images - noise.
            vt_pred = model(xt, t, labels_with_dropout)
            loss = compute_loss(vt_pred, images - noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            batches_seen += 1
            progress_bar.set_postfix({"Loss": f"{loss_value:.4f}"})

        avg_loss = running_loss / batches_seen
        training_log.append({'epoch': epoch + 1, 'loss': avg_loss})
        print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

        # Rewrite the whole log each epoch so it is usable if the run stops early.
        with open(log_path, 'w') as f:
            json.dump(training_log, f, indent=4)

        # Sample after the first epoch and then every tenth one.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            visualize_train(epoch + 1)

        # Overwrite the checkpoint with the latest weights.
        torch.save(model.state_dict(), model_save_path)

    print(f"Training complete. Model saved to: {model_save_path}")

    plot_training_curve(training_log)

def plot_training_curve(training_log):
    """Plot average loss per epoch and save it as ``training_curve.png``.

    Args:
        training_log: List of dicts, each with an ``'epoch'`` and a ``'loss'``
            entry.
    """
    epochs_list = []
    losses = []
    for entry in training_log:
        epochs_list.append(entry['epoch'])
        losses.append(entry['loss'])

    curve_path = os.path.join(RESULTS_DIR, 'training_curve.png')

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs_list, losses, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(f'Training Loss - {EXPERIMENT_NAME}', fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.savefig(curve_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Training curve saved to: {curve_path}")

# %%
# Load CIFAR-10
# The training split is downloaded to ./data if it is not already there, and
# every image goes through `transform`.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
# train() reads this loader as a module-level global, so it must exist before
# train() is called.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)

print(f"Dataset loaded: {len(train_dataset)} training images")
print(f"Image shape: {channels} x {image_size} x {image_size}")

# Train model
train(epochs)

# Generate final samples with different guidance scales
generate_final_samples(num_sample=5)

print(f"\nExperiment '{EXPERIMENT_NAME}' completed!")
print(f"All results saved in: {RESULTS_DIR}")
# NOTE: this list is hard-coded and must be kept in sync with the
# guidance_scales list inside generate_final_samples.
print(f"\nGenerated samples with guidance scales: [1.0, 1.5, 2.0, 3.0, 5.0]")

Dataset loaded: 50000 training images
Image shape: 3 x 32 x 32
Starting training for experiment: exp6_cfg
Results will be saved to: results/exp6_cfg
Configuration: {'experiment_name': 'exp6_cfg', 'lr': 0.0005, 'epochs': 100, 'batch_size': 128, 'base_channels': 64, 'time_embed_dim': 64, 'ode_method': 'dopri5', 'ode_steps': 50, 'loss_type': 'mse', 'norm_type': 'groupnorm', 'activation': 'silu', 'optimizer_type': 'adamw', 'image_size': 32, 'channels': 3, 'device': 'cuda', 'cfg_dropout_prob': 0.1, 'cfg_guidance_scale': 2.0}
CFG Dropout Probability: 0.1


Epoch 1/100: 100%|██████████| 391/391 [01:29<00:00,  4.39it/s, Loss=0.2188]


Epoch 1 - Average Loss: 0.2865
Generating visualization for epoch 1...


Epoch 2/100: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s, Loss=0.2020]


Epoch 2 - Average Loss: 0.2276


Epoch 3/100: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s, Loss=0.2221]


Epoch 3 - Average Loss: 0.2162


Epoch 4/100: 100%|██████████| 391/391 [01:28<00:00,  4.40it/s, Loss=0.2105]


Epoch 4 - Average Loss: 0.2091


Epoch 5/100: 100%|██████████| 391/391 [01:29<00:00,  4.39it/s, Loss=0.1861]


Epoch 5 - Average Loss: 0.2072


Epoch 6/100: 100%|██████████| 391/391 [01:28<00:00,  4.39it/s, Loss=0.2312]


Epoch 6 - Average Loss: 0.2034


Epoch 7/100: 100%|██████████| 391/391 [01:29<00:00,  4.39it/s, Loss=0.1910]


Epoch 7 - Average Loss: 0.2013


Epoch 8/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.2090]


Epoch 8 - Average Loss: 0.2000


Epoch 9/100: 100%|██████████| 391/391 [01:24<00:00,  4.63it/s, Loss=0.2069]


Epoch 9 - Average Loss: 0.1988


Epoch 10/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.2027]


Epoch 10 - Average Loss: 0.1978
Generating visualization for epoch 10...


Epoch 11/100: 100%|██████████| 391/391 [01:29<00:00,  4.39it/s, Loss=0.2038]


Epoch 11 - Average Loss: 0.1974


Epoch 12/100: 100%|██████████| 391/391 [01:29<00:00,  4.39it/s, Loss=0.1755]


Epoch 12 - Average Loss: 0.1966


Epoch 13/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1945]


Epoch 13 - Average Loss: 0.1951


Epoch 14/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1824]


Epoch 14 - Average Loss: 0.1941


Epoch 15/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1951]


Epoch 15 - Average Loss: 0.1934


Epoch 16/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1839]


Epoch 16 - Average Loss: 0.1927


Epoch 17/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1953]


Epoch 17 - Average Loss: 0.1921


Epoch 18/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.2053]


Epoch 18 - Average Loss: 0.1924


Epoch 19/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2007]


Epoch 19 - Average Loss: 0.1917


Epoch 20/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.2181]


Epoch 20 - Average Loss: 0.1924
Generating visualization for epoch 20...


Epoch 21/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1832]


Epoch 21 - Average Loss: 0.1910


Epoch 22/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1771]


Epoch 22 - Average Loss: 0.1914


Epoch 23/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1857]


Epoch 23 - Average Loss: 0.1897


Epoch 24/100: 100%|██████████| 391/391 [01:25<00:00,  4.59it/s, Loss=0.1993]


Epoch 24 - Average Loss: 0.1901


Epoch 25/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.2074]


Epoch 25 - Average Loss: 0.1894


Epoch 26/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1777]


Epoch 26 - Average Loss: 0.1891


Epoch 27/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1875]


Epoch 27 - Average Loss: 0.1897


Epoch 28/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1819]


Epoch 28 - Average Loss: 0.1886


Epoch 29/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2051]


Epoch 29 - Average Loss: 0.1887


Epoch 30/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1833]


Epoch 30 - Average Loss: 0.1888
Generating visualization for epoch 30...


Epoch 31/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1690]


Epoch 31 - Average Loss: 0.1894


Epoch 32/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2018]


Epoch 32 - Average Loss: 0.1886


Epoch 33/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1815]


Epoch 33 - Average Loss: 0.1872


Epoch 34/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1725]


Epoch 34 - Average Loss: 0.1875


Epoch 35/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2092]


Epoch 35 - Average Loss: 0.1867


Epoch 36/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.2034]


Epoch 36 - Average Loss: 0.1880


Epoch 37/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.2038]


Epoch 37 - Average Loss: 0.1867


Epoch 38/100: 100%|██████████| 391/391 [01:28<00:00,  4.44it/s, Loss=0.1708]


Epoch 38 - Average Loss: 0.1873


Epoch 39/100: 100%|██████████| 391/391 [01:26<00:00,  4.52it/s, Loss=0.1715]


Epoch 39 - Average Loss: 0.1872


Epoch 40/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1714]


Epoch 40 - Average Loss: 0.1860
Generating visualization for epoch 40...


Epoch 41/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2009]


Epoch 41 - Average Loss: 0.1865


Epoch 42/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1823]


Epoch 42 - Average Loss: 0.1862


Epoch 43/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1848]


Epoch 43 - Average Loss: 0.1862


Epoch 44/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1752]


Epoch 44 - Average Loss: 0.1861


Epoch 45/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1718]


Epoch 45 - Average Loss: 0.1854


Epoch 46/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1756]


Epoch 46 - Average Loss: 0.1855


Epoch 47/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1787]


Epoch 47 - Average Loss: 0.1851


Epoch 48/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1932]


Epoch 48 - Average Loss: 0.1852


Epoch 49/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1939]


Epoch 49 - Average Loss: 0.1848


Epoch 50/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1929]


Epoch 50 - Average Loss: 0.1859
Generating visualization for epoch 50...


Epoch 51/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1870]


Epoch 51 - Average Loss: 0.1850


Epoch 52/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1664]


Epoch 52 - Average Loss: 0.1849


Epoch 53/100: 100%|██████████| 391/391 [01:25<00:00,  4.60it/s, Loss=0.1855]


Epoch 53 - Average Loss: 0.1845


Epoch 54/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1728]


Epoch 54 - Average Loss: 0.1842


Epoch 55/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2107]


Epoch 55 - Average Loss: 0.1838


Epoch 56/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1862]


Epoch 56 - Average Loss: 0.1843


Epoch 57/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1879]


Epoch 57 - Average Loss: 0.1843


Epoch 58/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1678]


Epoch 58 - Average Loss: 0.1847


Epoch 59/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2006]


Epoch 59 - Average Loss: 0.1844


Epoch 60/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2114]


Epoch 60 - Average Loss: 0.1847
Generating visualization for epoch 60...


Epoch 61/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1858]


Epoch 61 - Average Loss: 0.1842


Epoch 62/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1668]


Epoch 62 - Average Loss: 0.1827


Epoch 63/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1782]


Epoch 63 - Average Loss: 0.1839


Epoch 64/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1658]


Epoch 64 - Average Loss: 0.1832


Epoch 65/100: 100%|██████████| 391/391 [01:29<00:00,  4.38it/s, Loss=0.1712]


Epoch 65 - Average Loss: 0.1842


Epoch 66/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1770]


Epoch 66 - Average Loss: 0.1837


Epoch 67/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1800]


Epoch 67 - Average Loss: 0.1837


Epoch 68/100: 100%|██████████| 391/391 [01:25<00:00,  4.57it/s, Loss=0.1847]


Epoch 68 - Average Loss: 0.1827


Epoch 69/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1773]


Epoch 69 - Average Loss: 0.1837


Epoch 70/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.2047]


Epoch 70 - Average Loss: 0.1839
Generating visualization for epoch 70...


Epoch 71/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1879]


Epoch 71 - Average Loss: 0.1829


Epoch 72/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1899]


Epoch 72 - Average Loss: 0.1832


Epoch 73/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1906]


Epoch 73 - Average Loss: 0.1835


Epoch 74/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1853]


Epoch 74 - Average Loss: 0.1823


Epoch 75/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1815]


Epoch 75 - Average Loss: 0.1831


Epoch 76/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.2034]


Epoch 76 - Average Loss: 0.1829


Epoch 77/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1533]


Epoch 77 - Average Loss: 0.1827


Epoch 78/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1805]


Epoch 78 - Average Loss: 0.1826


Epoch 79/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1817]


Epoch 79 - Average Loss: 0.1826


Epoch 80/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1888]


Epoch 80 - Average Loss: 0.1828
Generating visualization for epoch 80...


Epoch 81/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1799]


Epoch 81 - Average Loss: 0.1829


Epoch 82/100: 100%|██████████| 391/391 [01:26<00:00,  4.54it/s, Loss=0.1777]


Epoch 82 - Average Loss: 0.1827


Epoch 83/100: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s, Loss=0.1883]


Epoch 83 - Average Loss: 0.1823


Epoch 84/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1681]


Epoch 84 - Average Loss: 0.1823


Epoch 85/100: 100%|██████████| 391/391 [01:30<00:00,  4.34it/s, Loss=0.1770]


Epoch 85 - Average Loss: 0.1821


Epoch 86/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1703]


Epoch 86 - Average Loss: 0.1824


Epoch 87/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1882]


Epoch 87 - Average Loss: 0.1825


Epoch 88/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1702]


Epoch 88 - Average Loss: 0.1818


Epoch 89/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1775]


Epoch 89 - Average Loss: 0.1820


Epoch 90/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1750]


Epoch 90 - Average Loss: 0.1825
Generating visualization for epoch 90...


Epoch 91/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.2198]


Epoch 91 - Average Loss: 0.1823


Epoch 92/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1705]


Epoch 92 - Average Loss: 0.1816


Epoch 93/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1762]


Epoch 93 - Average Loss: 0.1818


Epoch 94/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1611]


Epoch 94 - Average Loss: 0.1818


Epoch 95/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1711]


Epoch 95 - Average Loss: 0.1815


Epoch 96/100: 100%|██████████| 391/391 [01:29<00:00,  4.35it/s, Loss=0.1696]


Epoch 96 - Average Loss: 0.1812


Epoch 97/100: 100%|██████████| 391/391 [01:25<00:00,  4.58it/s, Loss=0.1800]


Epoch 97 - Average Loss: 0.1822


Epoch 98/100: 100%|██████████| 391/391 [01:29<00:00,  4.37it/s, Loss=0.1965]


Epoch 98 - Average Loss: 0.1814


Epoch 99/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1853]


Epoch 99 - Average Loss: 0.1819


Epoch 100/100: 100%|██████████| 391/391 [01:29<00:00,  4.36it/s, Loss=0.1606]


Epoch 100 - Average Loss: 0.1821
Generating visualization for epoch 100...
Training complete. Model saved to: results/exp6_cfg/FMmodel.pth
Training curve saved to: results/exp6_cfg/training_curve.png

Generating samples with guidance scale = 1.0
Generating grid 1/5 (scale=1.0)...


Grid 1: 100%|██████████| 10/10 [00:20<00:00,  2.09s/it]


Saved: results/exp6_cfg/generated_grid1_cfg1.0.jpg
Generating grid 2/5 (scale=1.0)...


Grid 2: 100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


Saved: results/exp6_cfg/generated_grid2_cfg1.0.jpg
Generating grid 3/5 (scale=1.0)...


Grid 3: 100%|██████████| 10/10 [00:20<00:00,  2.03s/it]


Saved: results/exp6_cfg/generated_grid3_cfg1.0.jpg
Generating grid 4/5 (scale=1.0)...


Grid 4: 100%|██████████| 10/10 [00:20<00:00,  2.10s/it]


Saved: results/exp6_cfg/generated_grid4_cfg1.0.jpg
Generating grid 5/5 (scale=1.0)...


Grid 5: 100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


Saved: results/exp6_cfg/generated_grid5_cfg1.0.jpg

Generating samples with guidance scale = 1.5
Generating grid 1/5 (scale=1.5)...


Grid 1: 100%|██████████| 10/10 [00:33<00:00,  3.32s/it]


Saved: results/exp6_cfg/generated_grid1_cfg1.5.jpg
Generating grid 2/5 (scale=1.5)...


Grid 2: 100%|██████████| 10/10 [00:32<00:00,  3.24s/it]


Saved: results/exp6_cfg/generated_grid2_cfg1.5.jpg
Generating grid 3/5 (scale=1.5)...


Grid 3: 100%|██████████| 10/10 [00:30<00:00,  3.09s/it]


Saved: results/exp6_cfg/generated_grid3_cfg1.5.jpg
Generating grid 4/5 (scale=1.5)...


Grid 4: 100%|██████████| 10/10 [00:32<00:00,  3.28s/it]


Saved: results/exp6_cfg/generated_grid4_cfg1.5.jpg
Generating grid 5/5 (scale=1.5)...


Grid 5: 100%|██████████| 10/10 [00:31<00:00,  3.11s/it]


Saved: results/exp6_cfg/generated_grid5_cfg1.5.jpg

Generating samples with guidance scale = 2.0
Generating grid 1/5 (scale=2.0)...


Grid 1: 100%|██████████| 10/10 [00:32<00:00,  3.22s/it]


Saved: results/exp6_cfg/generated_grid1_cfg2.0.jpg
Generating grid 2/5 (scale=2.0)...


Grid 2: 100%|██████████| 10/10 [00:32<00:00,  3.30s/it]


Saved: results/exp6_cfg/generated_grid2_cfg2.0.jpg
Generating grid 3/5 (scale=2.0)...


Grid 3: 100%|██████████| 10/10 [00:34<00:00,  3.48s/it]


Saved: results/exp6_cfg/generated_grid3_cfg2.0.jpg
Generating grid 4/5 (scale=2.0)...


Grid 4: 100%|██████████| 10/10 [00:32<00:00,  3.29s/it]


Saved: results/exp6_cfg/generated_grid4_cfg2.0.jpg
Generating grid 5/5 (scale=2.0)...


Grid 5: 100%|██████████| 10/10 [00:33<00:00,  3.37s/it]


Saved: results/exp6_cfg/generated_grid5_cfg2.0.jpg

Generating samples with guidance scale = 3.0
Generating grid 1/5 (scale=3.0)...


Grid 1: 100%|██████████| 10/10 [00:35<00:00,  3.56s/it]


Saved: results/exp6_cfg/generated_grid1_cfg3.0.jpg
Generating grid 2/5 (scale=3.0)...


Grid 2: 100%|██████████| 10/10 [00:34<00:00,  3.43s/it]


Saved: results/exp6_cfg/generated_grid2_cfg3.0.jpg
Generating grid 3/5 (scale=3.0)...


Grid 3: 100%|██████████| 10/10 [00:33<00:00,  3.38s/it]


Saved: results/exp6_cfg/generated_grid3_cfg3.0.jpg
Generating grid 4/5 (scale=3.0)...


Grid 4: 100%|██████████| 10/10 [00:33<00:00,  3.32s/it]


Saved: results/exp6_cfg/generated_grid4_cfg3.0.jpg
Generating grid 5/5 (scale=3.0)...


Grid 5: 100%|██████████| 10/10 [00:31<00:00,  3.19s/it]


Saved: results/exp6_cfg/generated_grid5_cfg3.0.jpg

Generating samples with guidance scale = 5.0
Generating grid 1/5 (scale=5.0)...


Grid 1: 100%|██████████| 10/10 [00:36<00:00,  3.61s/it]


Saved: results/exp6_cfg/generated_grid1_cfg5.0.jpg
Generating grid 2/5 (scale=5.0)...


Grid 2: 100%|██████████| 10/10 [00:36<00:00,  3.68s/it]


Saved: results/exp6_cfg/generated_grid2_cfg5.0.jpg
Generating grid 3/5 (scale=5.0)...


Grid 3: 100%|██████████| 10/10 [00:33<00:00,  3.36s/it]


Saved: results/exp6_cfg/generated_grid3_cfg5.0.jpg
Generating grid 4/5 (scale=5.0)...


Grid 4: 100%|██████████| 10/10 [00:36<00:00,  3.66s/it]


Saved: results/exp6_cfg/generated_grid4_cfg5.0.jpg
Generating grid 5/5 (scale=5.0)...


Grid 5: 100%|██████████| 10/10 [00:33<00:00,  3.32s/it]


Saved: results/exp6_cfg/generated_grid5_cfg5.0.jpg

Experiment 'exp6_cfg' completed!
All results saved in: results/exp6_cfg

Generated samples with guidance scales: [1.0, 1.5, 2.0, 3.0, 5.0]
