In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

In [11]:
# ================== Experiment Configuration ==================
# Identifier of this run. It names the output directory below, and every
# artifact is written there under a fixed file name, so reusing a name
# overwrites the previous run's files.
EXPERIMENT_NAME = "exp3"  # Change this for each experiment
# Directory that receives config.json, the training log, sample grids and weights.
RESULTS_DIR = f"results/{EXPERIMENT_NAME}"
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

In [12]:
# ================== Configuration Parameters ==================
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 32  # CIFAR-10 is 32x32
channels = 3     # CIFAR-10 is RGB
batch_size = 128  # mini-batch size handed to the training DataLoader
num_classes = 10  # size of the label-embedding table; also the number of grid columns
# Destination of the state_dict written once training has finished.
model_save_path = os.path.join(RESULTS_DIR, 'FMmodel.pth')

In [13]:
# Random seed. Seeding torch's global RNG fixes weight initialisation, the
# noise/time samples drawn during training and the initial noise used for
# generation, so a run can be repeated (up to nondeterministic GPU kernels).
SEED = 42
torch.manual_seed(SEED)

# Learning Rate
lr = 1e-4

# Number of Epochs
epochs = 100

# Model Size (channel width of the first UNet stage; deeper stages are multiples)
BASE_CHANNELS = 128

# Time Embedding Dimension
TIME_EMBED_DIM = 128

# ODE Solver (passed to torchdiffeq's odeint as `method`)
ODE_METHOD = 'dopri5'

# Number of ODE Steps (only used to build the time grid for the fixed-step
# methods 'euler', 'rk4' and 'midpoint'; other methods integrate over [0, 1] directly)
ODE_STEPS = 50

# Loss Function: b/w u^target and u^theta
LOSS_TYPE = 'huber'  # Huber Loss

# Normalization
NORM_TYPE = 'groupnorm'

# Activation Function
ACTIVATION = 'silu'

# Optimizer
OPTIMIZER_TYPE = 'adamw'

# Save experiment config so the run can be identified and repeated later
config = {
    'experiment_name': EXPERIMENT_NAME,
    'seed': SEED,
    'lr': lr,
    'epochs': epochs,
    'batch_size': batch_size,
    'base_channels': BASE_CHANNELS,
    'time_embed_dim': TIME_EMBED_DIM,
    'ode_method': ODE_METHOD,
    'ode_steps': ODE_STEPS,
    'loss_type': LOSS_TYPE,
    'norm_type': NORM_TYPE,
    'activation': ACTIVATION,
    'optimizer_type': OPTIMIZER_TYPE,
    'image_size': image_size,
    'channels': channels,
    'device': str(device)
}

# Explicit encoding keeps the file identical regardless of the platform default.
with open(os.path.join(RESULTS_DIR, 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=4)

In [14]:
# ================== Data Loading ==================
def normalize_img(x):
    """Map values from the [0, 1] range onto [-1, 1] by computing ``2 * x - 1``."""
    doubled = 2 * x
    return doubled - 1

# Preprocessing pipeline handed to the dataset: ToTensor (which, per the
# torchvision docs, yields float tensors in [0, 1]) followed by normalize_img.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(normalize_img)  # Lambda adapts the plain function to the transform interface
])

# ================== Helper Functions ==================
def get_activation():
    """Return the activation callable selected by the module-level ``ACTIVATION``.

    Returns:
        ``F.silu`` for ``'silu'``, ``F.gelu`` for ``'gelu'``, a new ``SwiGLU()``
        instance for ``'swiglu'``, and ``F.silu`` for any other value.

    Raises:
        NotImplementedError: if ``'swiglu'`` is requested but no ``SwiGLU``
            class is defined in this module's namespace.
    """
    if ACTIVATION == 'gelu':
        return F.gelu
    if ACTIVATION == 'swiglu':
        # Look the class up dynamically: referencing the bare name would fail
        # with an opaque NameError when no SwiGLU implementation is defined.
        swiglu_cls = globals().get('SwiGLU')
        if swiglu_cls is None:
            raise NotImplementedError(
                "ACTIVATION='swiglu' requires a SwiGLU module class to be defined; "
                "use 'silu' or 'gelu', or define SwiGLU before building the model."
            )
        return swiglu_cls()
    # 'silu' and every unrecognised name fall back to SiLU.
    return F.silu

def get_norm_layer(num_channels, max_groups=32):
    """Return the normalization layer selected by the module-level ``NORM_TYPE``.

    Args:
        num_channels: Number of feature channels the layer normalizes.
        max_groups: Upper bound on the number of GroupNorm groups.

    Returns:
        ``nn.BatchNorm2d`` for ``'batchnorm'``, a single-group ``nn.GroupNorm``
        for ``'layernorm'``, and a multi-group ``nn.GroupNorm`` for
        ``'groupnorm'`` or any unrecognised value.

    Raises:
        ValueError: if ``num_channels`` or ``max_groups`` is smaller than 1.
    """
    if num_channels < 1 or max_groups < 1:
        raise ValueError(
            f"num_channels and max_groups must be >= 1, got {num_channels} and {max_groups}"
        )

    if NORM_TYPE == 'batchnorm':
        return nn.BatchNorm2d(num_channels)
    if NORM_TYPE == 'layernorm':
        # For LayerNorm with images, normalize over C, H, W
        return nn.GroupNorm(1, num_channels)

    # GroupNorm requires num_channels to be divisible by the group count, so
    # use the largest divisor not exceeding max_groups. The loop always stops
    # (1 divides everything) and yields min(max_groups, num_channels) whenever
    # that value already divides num_channels.
    groups = min(max_groups, num_channels)
    while num_channels % groups:
        groups -= 1
    return nn.GroupNorm(groups, num_channels)

In [15]:
# ================== Model Architecture ==================
class ConditionedDoubleConv(nn.Module):
    """Two 3x3 conv -> norm -> activation stages with conditioning.

    The conditioning tensor is broadcast over the spatial dimensions and
    concatenated to the feature map between the first and second stage, which
    is why the second convolution takes ``out_channels + cond_dim`` inputs.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)  # padding=1 keeps H and W unchanged
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.norm1 = get_norm_layer(out_channels)
        self.conv2 = nn.Conv2d(out_channels + cond_dim, out_channels, **conv_kwargs)
        self.norm2 = get_norm_layer(out_channels)
        self.activation = get_activation()

    def forward(self, x, cond):
        """Apply both stages; ``cond`` is expanded to the feature map's H x W."""
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.activation(hidden)

        # Broadcast the condition over every spatial position.
        height, width = hidden.size(2), hidden.size(3)
        cond_map = cond.expand(-1, -1, height, width)

        merged = torch.cat([hidden, cond_map], dim=1)
        merged = self.conv2(merged)
        merged = self.norm2(merged)
        return self.activation(merged)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a conditioned double convolution."""

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x, cond):
        """Pool ``x`` and run the pooled map through the conditioned convolution."""
        return self.conv(self.maxpool(x), cond)


class Up(nn.Module):
    """Decoder stage: bilinear 2x upsampling, skip concatenation, conditioned convolution."""

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x1, x2, cond):
        """Upsample ``x1``, pad it to the size of the skip tensor ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Pad so the upsampled map matches the skip connection spatially;
        # any odd remainder goes to the right / bottom edge.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        fused = torch.cat([x2, upsampled], dim=1)  # skip features first, then upsampled
        return self.conv(fused, cond)


class ConditionalUNet(nn.Module):
    """UNet conditioned on a scalar time and a class label.

    A learned time embedding and a label embedding are concatenated into one
    conditioning vector that is passed to every encoder and decoder stage.
    The network has three downsampling and three upsampling stages with skip
    connections, and a final 1x1 convolution back to ``channels`` outputs.
    """

    def __init__(self):
        super().__init__()
        # Condition dimensions
        self.label_dim = 32
        self.cond_dim = TIME_EMBED_DIM + self.label_dim
        cond_dim = self.cond_dim

        # Time embedding: scalar t -> TIME_EMBED_DIM features via a small MLP
        hidden_dim = TIME_EMBED_DIM * 2
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, TIME_EMBED_DIM),
        )

        # Label embedding: one learned vector per class
        self.label_embed = nn.Embedding(num_classes, self.label_dim)

        # Channel widths of the four resolution levels
        c1 = BASE_CHANNELS
        c2 = BASE_CHANNELS * 2
        c4 = BASE_CHANNELS * 4
        c8 = BASE_CHANNELS * 8

        # Encoder path
        self.inc = ConditionedDoubleConv(channels, c1, cond_dim)
        self.down1 = Down(c1, c2, cond_dim)
        self.down2 = Down(c2, c4, cond_dim)
        self.down3 = Down(c4, c8, cond_dim)

        # Decoder path: each stage consumes the upsampled map plus its skip connection
        self.up1 = Up(c8 + c4, c4, cond_dim)
        self.up2 = Up(c4 + c2, c2, cond_dim)
        self.up3 = Up(c2 + c1, c1, cond_dim)
        self.outc = nn.Conv2d(c1, channels, kernel_size=1)

    def forward(self, x, t, labels):
        """Return the network output for input ``x`` at times ``t`` and classes ``labels``."""
        # Build the conditioning tensor with trailing 1x1 spatial dims
        time_features = self.time_embed(t.view(-1, 1))
        label_features = self.label_embed(labels)
        cond = torch.cat([time_features, label_features], dim=1)[:, :, None, None]

        # Encoder: keep every level's output for the skip connections
        skip1 = self.inc(x, cond)
        skip2 = self.down1(skip1, cond)
        skip3 = self.down2(skip2, cond)
        bottleneck = self.down3(skip3, cond)

        # Decoder: consume the skips from deepest to shallowest
        features = self.up1(bottleneck, skip3, cond)
        features = self.up2(features, skip2, cond)
        features = self.up3(features, skip1, cond)
        return self.outc(features)

In [16]:
# ================== Training and Generation ==================
# Build the network and move its parameters to the selected device.
model = ConditionalUNet().to(device)

# Select the optimizer named by OPTIMIZER_TYPE.
if OPTIMIZER_TYPE == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
elif OPTIMIZER_TYPE == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    # 'adam' and any unrecognised value both resolve to plain Adam.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def compute_loss(pred, target):
    """Return the loss between ``pred`` and ``target`` selected by ``LOSS_TYPE``.

    ``'l1'`` and ``'huber'`` pick the corresponding functional loss; ``'mse'``
    and any unrecognised value use mean squared error.
    """
    if LOSS_TYPE == 'l1':
        return F.l1_loss(pred, target)
    if LOSS_TYPE == 'huber':
        return F.huber_loss(pred, target)
    # 'mse' and the fallback case share the same loss.
    return F.mse_loss(pred, target)


@torch.no_grad()
def generate_with_label(label, num_samples=16):
    """Generate images of one class by integrating the model's velocity field.

    Starts from Gaussian noise at t=0 and integrates ``dx/dt = model(x, t, label)``
    to t=1 with torchdiffeq's ``odeint``.

    Args:
        label: Integer class index used to condition every sample.
        num_samples: Number of images to generate.

    Returns:
        A CPU tensor of shape ``(num_samples, channels, image_size, image_size)``
        with values mapped from [-1, 1] to [0, 1].

    The model is switched to eval mode for the solve, and its previous
    train/eval mode is restored afterwards, even if the solver raises.
    """
    was_training = model.training
    model.eval()

    try:
        x0 = torch.randn(num_samples, channels, image_size, image_size, device=device)
        labels = torch.full((num_samples,), label, device=device, dtype=torch.long)

        def ode_func(t: torch.Tensor, x: torch.Tensor):
            # The solver passes a scalar time; the model expects one time per sample.
            t_expanded = t.expand(x.size(0))
            return model(x, t_expanded, labels)

        # Fixed-step solvers step through the supplied grid, so give them
        # ODE_STEPS intervals; other solvers only need the two end points.
        if ODE_METHOD in ['euler', 'rk4', 'midpoint']:
            t_eval = torch.linspace(0.0, 1.0, ODE_STEPS + 1, device=device)
        else:
            t_eval = torch.tensor([0.0, 1.0], device=device)

        generated = odeint(
            ode_func,
            x0,
            t_eval,
            rtol=1e-5,
            atol=1e-5,
            method=ODE_METHOD
        )
    finally:
        # Restore the caller's mode even when generation fails, so a failed
        # visualization cannot leave the training loop running in eval mode.
        model.train(was_training)

    # The last time point is the sample at t=1; rescale [-1, 1] -> [0, 1].
    images = (generated[-1].clamp(-1, 1) + 1) / 2
    return images.cpu()


def _plot_class_grid(class_labels, title, save_path, samples_per_class=10):
    """Render one grid of generated samples (one column per class) and save it.

    Args:
        class_labels: Iterable of integer class indices to draw, e.g.
            ``range(num_classes)`` or a tqdm-wrapped version of it.
        title: Figure title.
        save_path: Destination image path.
        samples_per_class: Number of rows, i.e. samples generated per class.
    """
    plt.figure(figsize=(12, 12))
    try:
        plt.subplots_adjust(wspace=0.1, hspace=0.1)

        for label in class_labels:
            generated_images = generate_with_label(label=label, num_samples=samples_per_class)

            for row in range(samples_per_class):
                # Subplot indices are 1-based and row-major: the class picks the column.
                ax = plt.subplot(samples_per_class, num_classes, (row * num_classes) + label + 1)
                # Convert CHW to HWC for display
                img = generated_images[row].permute(1, 2, 0).numpy()
                ax.imshow(img, vmin=0, vmax=1)
                ax.axis('off')
                if row == 0:
                    ax.set_title(str(label), fontsize=14, pad=5)

        plt.suptitle(title, fontsize=18, y=0.98)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, otherwise a failed generation leaks it.
        plt.close()


def visualize_train(epoch):
    """Generate a sample grid for ``epoch`` and save it as ``epoch{epoch}.jpg`` in RESULTS_DIR."""
    print(f"Generating visualization for epoch {epoch}...")
    _plot_class_grid(
        range(num_classes),
        f"Generated Samples - Epoch {epoch}",
        os.path.join(RESULTS_DIR, f"epoch{epoch}.jpg"),
    )


def generate_final_samples(num_sample=5):
    """Generate ``num_sample`` sample grids and save them as ``generated_grid{k}.jpg``."""
    print(f"Generating {num_sample} final sample grids...")

    for k in range(num_sample):
        print(f"Generating grid {k + 1}/{num_sample}...")
        save_path = os.path.join(RESULTS_DIR, f"generated_grid{k + 1}.jpg")
        _plot_class_grid(
            tqdm(range(num_classes), desc=f"Grid {k+1}"),
            f"Final Generated Samples (Grid {k + 1})",
            save_path,
        )
        print(f"Saved: {save_path}")

In [17]:
def train(num_epochs=100):
    """Train the model with the flow-matching objective and log the results.

    For every batch a noise sample and a time ``t`` in [0, 1) are drawn, the
    model is evaluated at the interpolation ``(1 - t) * noise + t * image``,
    and its output is regressed onto ``image - noise``.

    Args:
        num_epochs: Number of passes over ``train_loader``.

    Side effects (all inside ``RESULTS_DIR`` unless noted):
        * ``training_log.json`` is rewritten after every epoch.
        * Sample grids are generated after epoch 1 and every 10th epoch, and a
          checkpoint ``checkpoint_latest.pth`` is written at the same cadence
          so a long run is not lost if it is interrupted.
        * The final weights go to ``model_save_path`` and the loss curve is plotted.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches in an epoch.
    """
    print(f"Starting training for experiment: {EXPERIMENT_NAME}")
    print(f"Results will be saved to: {RESULTS_DIR}")
    print(f"Configuration: {config}")

    training_log = []
    log_path = os.path.join(RESULTS_DIR, 'training_log.json')
    checkpoint_path = os.path.join(RESULTS_DIR, 'checkpoint_latest.pth')

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        total_loss = 0.0
        num_batches = 0

        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            # Sample a point on the straight path between noise (t=0) and data (t=1).
            noise = torch.randn_like(images)
            t = torch.rand(images.size(0), device=device)
            t_broadcast = t.view(-1, 1, 1, 1)
            xt = (1 - t_broadcast) * noise + t_broadcast * images

            # The regression target is the constant velocity of that path.
            vt_pred = model(xt, t, labels)
            loss = compute_loss(vt_pred, images - noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

        if num_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise RuntimeError(
                "train_loader yielded no batches; cannot compute an average loss"
            )

        avg_loss = total_loss / num_batches
        training_log.append({'epoch': epoch + 1, 'loss': avg_loss})
        print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

        # Save loss log
        with open(log_path, 'w') as f:
            json.dump(training_log, f, indent=4)

        # Generate samples and checkpoint the weights
        if (epoch + 1) % 10 == 0 or epoch == 0:
            visualize_train(epoch + 1)
            torch.save(model.state_dict(), checkpoint_path)

    # Save model
    torch.save(model.state_dict(), model_save_path)
    print(f"Training complete. Model saved to: {model_save_path}")

    # Plot training curve
    plot_training_curve(training_log)


def plot_training_curve(training_log):
    """Plot loss against epoch and save the figure as ``training_curve.png`` in RESULTS_DIR.

    Args:
        training_log: List of dicts with ``'epoch'`` and ``'loss'`` keys.
    """
    epochs_list = []
    for record in training_log:
        epochs_list.append(record['epoch'])
    losses = []
    for record in training_log:
        losses.append(record['loss'])

    curve_path = os.path.join(RESULTS_DIR, 'training_curve.png')

    plt.figure(figsize=(10, 6))
    plt.plot(epochs_list, losses, linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title(f'Training Loss - {EXPERIMENT_NAME}', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.savefig(curve_path, dpi=150, bbox_inches='tight')
    plt.close()  # release the figure once it is on disk
    print(f"Training curve saved to: {curve_path}")

In [18]:
# Load CIFAR-10 (downloaded into ./data on first use) with the preprocessing pipeline
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches; pinned memory is requested for the host-to-device copies
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

print(f"Dataset loaded: {len(train_dataset)} training images")
print(f"Image shape: {channels} x {image_size} x {image_size}")

# Train model for the configured number of epochs
train(epochs)

# Generate final samples from the trained model
generate_final_samples(num_sample=5)

print(f"\nExperiment '{EXPERIMENT_NAME}' completed!")
print(f"All results saved in: {RESULTS_DIR}")

Dataset loaded: 50000 training images
Image shape: 3 x 32 x 32
Starting training for experiment: exp3
Results will be saved to: results/exp3
Configuration: {'experiment_name': 'exp3', 'lr': 0.0001, 'epochs': 100, 'batch_size': 128, 'base_channels': 128, 'time_embed_dim': 128, 'ode_method': 'dopri5', 'ode_steps': 50, 'loss_type': 'huber', 'norm_type': 'groupnorm', 'activation': 'silu', 'optimizer_type': 'adamw', 'image_size': 32, 'channels': 3, 'device': 'cuda'}


Epoch 1/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.1306]


Epoch 1 - Average Loss: 0.1404
Generating visualization for epoch 1...


Epoch 2/100: 100%|██████████| 391/391 [07:01<00:00,  1.08s/it, Loss=0.0941]


Epoch 2 - Average Loss: 0.1080


Epoch 3/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.1070]


Epoch 3 - Average Loss: 0.1029


Epoch 4/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.1129]


Epoch 4 - Average Loss: 0.1006


Epoch 5/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0920]


Epoch 5 - Average Loss: 0.0985


Epoch 6/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0902]


Epoch 6 - Average Loss: 0.0976


Epoch 7/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0954]


Epoch 7 - Average Loss: 0.0960


Epoch 8/100: 100%|██████████| 391/391 [06:55<00:00,  1.06s/it, Loss=0.0931]


Epoch 8 - Average Loss: 0.0959


Epoch 9/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0948]


Epoch 9 - Average Loss: 0.0951


Epoch 10/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.0945]


Epoch 10 - Average Loss: 0.0946
Generating visualization for epoch 10...


Epoch 11/100: 100%|██████████| 391/391 [06:57<00:00,  1.07s/it, Loss=0.0909]


Epoch 11 - Average Loss: 0.0938


Epoch 12/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0982]


Epoch 12 - Average Loss: 0.0934


Epoch 13/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0893]


Epoch 13 - Average Loss: 0.0932


Epoch 14/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.1001]


Epoch 14 - Average Loss: 0.0927


Epoch 15/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0924]


Epoch 15 - Average Loss: 0.0930


Epoch 16/100: 100%|██████████| 391/391 [07:02<00:00,  1.08s/it, Loss=0.0910]


Epoch 16 - Average Loss: 0.0925


Epoch 17/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.0953]


Epoch 17 - Average Loss: 0.0920


Epoch 18/100: 100%|██████████| 391/391 [06:54<00:00,  1.06s/it, Loss=0.0910]


Epoch 18 - Average Loss: 0.0915


Epoch 19/100: 100%|██████████| 391/391 [06:57<00:00,  1.07s/it, Loss=0.0957]


Epoch 19 - Average Loss: 0.0919


Epoch 20/100: 100%|██████████| 391/391 [07:00<00:00,  1.08s/it, Loss=0.0899]


Epoch 20 - Average Loss: 0.0916
Generating visualization for epoch 20...


Epoch 21/100: 100%|██████████| 391/391 [06:54<00:00,  1.06s/it, Loss=0.0995]


Epoch 21 - Average Loss: 0.0917


Epoch 22/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.0912]


Epoch 22 - Average Loss: 0.0912


Epoch 23/100: 100%|██████████| 391/391 [06:56<00:00,  1.06s/it, Loss=0.0880]


Epoch 23 - Average Loss: 0.0912


Epoch 24/100: 100%|██████████| 391/391 [07:00<00:00,  1.08s/it, Loss=0.0888]


Epoch 24 - Average Loss: 0.0906


Epoch 25/100: 100%|██████████| 391/391 [07:00<00:00,  1.07s/it, Loss=0.0867]


Epoch 25 - Average Loss: 0.0913


Epoch 26/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0869]


Epoch 26 - Average Loss: 0.0907


Epoch 27/100: 100%|██████████| 391/391 [06:49<00:00,  1.05s/it, Loss=0.0881]


Epoch 27 - Average Loss: 0.0910


Epoch 28/100: 100%|██████████| 391/391 [07:00<00:00,  1.07s/it, Loss=0.0900]


Epoch 28 - Average Loss: 0.0909


Epoch 29/100: 100%|██████████| 391/391 [07:00<00:00,  1.07s/it, Loss=0.0896]


Epoch 29 - Average Loss: 0.0905


Epoch 30/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0963]


Epoch 30 - Average Loss: 0.0900
Generating visualization for epoch 30...


Epoch 31/100: 100%|██████████| 391/391 [06:54<00:00,  1.06s/it, Loss=0.1029]


Epoch 31 - Average Loss: 0.0904


Epoch 32/100: 100%|██████████| 391/391 [06:53<00:00,  1.06s/it, Loss=0.0908]


Epoch 32 - Average Loss: 0.0899


Epoch 33/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0817]


Epoch 33 - Average Loss: 0.0902


Epoch 34/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0837]


Epoch 34 - Average Loss: 0.0898


Epoch 35/100: 100%|██████████| 391/391 [07:00<00:00,  1.07s/it, Loss=0.0908]


Epoch 35 - Average Loss: 0.0899


Epoch 36/100: 100%|██████████| 391/391 [06:55<00:00,  1.06s/it, Loss=0.0820]


Epoch 36 - Average Loss: 0.0896


Epoch 37/100: 100%|██████████| 391/391 [06:52<00:00,  1.06s/it, Loss=0.0909]


Epoch 37 - Average Loss: 0.0901


Epoch 38/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0906]


Epoch 38 - Average Loss: 0.0894


Epoch 39/100: 100%|██████████| 391/391 [06:58<00:00,  1.07s/it, Loss=0.0839]


Epoch 39 - Average Loss: 0.0893


Epoch 40/100: 100%|██████████| 391/391 [06:59<00:00,  1.07s/it, Loss=0.0955]


Epoch 40 - Average Loss: 0.0893
Generating visualization for epoch 40...


Epoch 41/100: 100%|██████████| 391/391 [06:20<00:00,  1.03it/s, Loss=0.0787]


Epoch 41 - Average Loss: 0.0890


Epoch 42/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0923]


Epoch 42 - Average Loss: 0.0895


Epoch 43/100: 100%|██████████| 391/391 [05:16<00:00,  1.23it/s, Loss=0.0843]


Epoch 43 - Average Loss: 0.0887


Epoch 44/100: 100%|██████████| 391/391 [05:17<00:00,  1.23it/s, Loss=0.0874]


Epoch 44 - Average Loss: 0.0895


Epoch 45/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0932]


Epoch 45 - Average Loss: 0.0893


Epoch 46/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0936]


Epoch 46 - Average Loss: 0.0888


Epoch 47/100: 100%|██████████| 391/391 [05:09<00:00,  1.26it/s, Loss=0.0794]


Epoch 47 - Average Loss: 0.0886


Epoch 48/100: 100%|██████████| 391/391 [05:16<00:00,  1.23it/s, Loss=0.0857]


Epoch 48 - Average Loss: 0.0892


Epoch 49/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0926]


Epoch 49 - Average Loss: 0.0888


Epoch 50/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0873]


Epoch 50 - Average Loss: 0.0891
Generating visualization for epoch 50...


Epoch 51/100: 100%|██████████| 391/391 [05:10<00:00,  1.26it/s, Loss=0.0937]


Epoch 51 - Average Loss: 0.0889


Epoch 52/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0947]


Epoch 52 - Average Loss: 0.0888


Epoch 53/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0990]


Epoch 53 - Average Loss: 0.0887


Epoch 54/100: 100%|██████████| 391/391 [05:16<00:00,  1.23it/s, Loss=0.0893]


Epoch 54 - Average Loss: 0.0888


Epoch 55/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0911]


Epoch 55 - Average Loss: 0.0886


Epoch 56/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0810]


Epoch 56 - Average Loss: 0.0887


Epoch 57/100: 100%|██████████| 391/391 [05:09<00:00,  1.26it/s, Loss=0.0888]


Epoch 57 - Average Loss: 0.0886


Epoch 58/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0934]


Epoch 58 - Average Loss: 0.0882


Epoch 59/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0860]


Epoch 59 - Average Loss: 0.0885


Epoch 60/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0872]


Epoch 60 - Average Loss: 0.0880
Generating visualization for epoch 60...


Epoch 61/100: 100%|██████████| 391/391 [05:09<00:00,  1.26it/s, Loss=0.0914]


Epoch 61 - Average Loss: 0.0884


Epoch 62/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0784]


Epoch 62 - Average Loss: 0.0879


Epoch 63/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0894]


Epoch 63 - Average Loss: 0.0881


Epoch 64/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0896]


Epoch 64 - Average Loss: 0.0881


Epoch 65/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.1093]


Epoch 65 - Average Loss: 0.0880


Epoch 66/100: 100%|██████████| 391/391 [05:16<00:00,  1.24it/s, Loss=0.0786]


Epoch 66 - Average Loss: 0.0878


Epoch 67/100: 100%|██████████| 391/391 [05:08<00:00,  1.27it/s, Loss=0.0922]


Epoch 67 - Average Loss: 0.0883


Epoch 68/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.1018]


Epoch 68 - Average Loss: 0.0878


Epoch 69/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0983]


Epoch 69 - Average Loss: 0.0881


Epoch 70/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0796]


Epoch 70 - Average Loss: 0.0880
Generating visualization for epoch 70...


Epoch 71/100: 100%|██████████| 391/391 [05:08<00:00,  1.27it/s, Loss=0.0847]


Epoch 71 - Average Loss: 0.0879


Epoch 72/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0927]


Epoch 72 - Average Loss: 0.0877


Epoch 73/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0866]


Epoch 73 - Average Loss: 0.0878


Epoch 74/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0855]


Epoch 74 - Average Loss: 0.0880


Epoch 75/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0916]


Epoch 75 - Average Loss: 0.0882


Epoch 76/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0822]


Epoch 76 - Average Loss: 0.0878


Epoch 77/100: 100%|██████████| 391/391 [05:08<00:00,  1.27it/s, Loss=0.0826]


Epoch 77 - Average Loss: 0.0877


Epoch 78/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0930]


Epoch 78 - Average Loss: 0.0877


Epoch 79/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0949]


Epoch 79 - Average Loss: 0.0874


Epoch 80/100: 100%|██████████| 391/391 [05:15<00:00,  1.24it/s, Loss=0.0868]


Epoch 80 - Average Loss: 0.0876
Generating visualization for epoch 80...


Epoch 81/100: 100%|██████████| 391/391 [05:08<00:00,  1.27it/s, Loss=0.0848]


Epoch 81 - Average Loss: 0.0878


Epoch 82/100: 100%|██████████| 391/391 [05:13<00:00,  1.25it/s, Loss=0.0908]


Epoch 82 - Average Loss: 0.0875


Epoch 83/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0899]


Epoch 83 - Average Loss: 0.0874


Epoch 84/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0855]


Epoch 84 - Average Loss: 0.0874


Epoch 85/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0852]


Epoch 85 - Average Loss: 0.0873


Epoch 86/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0814]


Epoch 86 - Average Loss: 0.0875


Epoch 87/100: 100%|██████████| 391/391 [05:07<00:00,  1.27it/s, Loss=0.0983]


Epoch 87 - Average Loss: 0.0873


Epoch 88/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0758]


Epoch 88 - Average Loss: 0.0874


Epoch 89/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0800]


Epoch 89 - Average Loss: 0.0874


Epoch 90/100: 100%|██████████| 391/391 [05:14<00:00,  1.24it/s, Loss=0.0890]


Epoch 90 - Average Loss: 0.0873
Generating visualization for epoch 90...


Epoch 91/100: 100%|██████████| 391/391 [05:08<00:00,  1.27it/s, Loss=0.0963]


Epoch 91 - Average Loss: 0.0870


Epoch 92/100: 100%|██████████| 391/391 [04:33<00:00,  1.43it/s, Loss=0.0907]


Epoch 92 - Average Loss: 0.0869


Epoch 93/100: 100%|██████████| 391/391 [03:26<00:00,  1.89it/s, Loss=0.0822]


Epoch 93 - Average Loss: 0.0872


Epoch 94/100: 100%|██████████| 391/391 [03:25<00:00,  1.90it/s, Loss=0.0911]


Epoch 94 - Average Loss: 0.0871


Epoch 95/100: 100%|██████████| 391/391 [03:25<00:00,  1.90it/s, Loss=0.0889]


Epoch 95 - Average Loss: 0.0872


Epoch 96/100: 100%|██████████| 391/391 [03:25<00:00,  1.90it/s, Loss=0.0895]


Epoch 96 - Average Loss: 0.0872


Epoch 97/100: 100%|██████████| 391/391 [02:47<00:00,  2.33it/s, Loss=0.0870]


Epoch 97 - Average Loss: 0.0869


Epoch 98/100: 100%|██████████| 391/391 [01:43<00:00,  3.77it/s, Loss=0.0877]


Epoch 98 - Average Loss: 0.0869


Epoch 99/100: 100%|██████████| 391/391 [01:34<00:00,  4.14it/s, Loss=0.0849]


Epoch 99 - Average Loss: 0.0868


Epoch 100/100: 100%|██████████| 391/391 [01:34<00:00,  4.14it/s, Loss=0.0820]


Epoch 100 - Average Loss: 0.0869
Generating visualization for epoch 100...
Training complete. Model saved to: results/exp3/FMmodel.pth
Training curve saved to: results/exp3/training_curve.png
Generating 5 final sample grids...
Generating grid 1/5...


Grid 1: 100%|██████████| 10/10 [00:12<00:00,  1.25s/it]


Saved: results/exp3/generated_grid1.jpg
Generating grid 2/5...


Grid 2: 100%|██████████| 10/10 [00:11<00:00,  1.18s/it]


Saved: results/exp3/generated_grid2.jpg
Generating grid 3/5...


Grid 3: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it]


Saved: results/exp3/generated_grid3.jpg
Generating grid 4/5...


Grid 4: 100%|██████████| 10/10 [00:12<00:00,  1.23s/it]


Saved: results/exp3/generated_grid4.jpg
Generating grid 5/5...


Grid 5: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


Saved: results/exp3/generated_grid5.jpg

Experiment 'exp3' completed!
All results saved in: results/exp3
