In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# ================== Configuration Parameters ==================
# Compute device: the default CUDA device when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data geometry (MNIST: 28x28 single-channel digits, 10 classes).
image_size = 28
channels = 1
num_classes = 10

# Optimisation settings.
batch_size = 128
lr = 1e-4
epochs = 50

# File the trained weights are written to and loaded from.
model_save_path = 'FMmodel.pth'

# ================== Data Loading ==================

# Normalize image to [-1,1]
def normalize_img(x):
    """Affinely rescale values from [0, 1] (the ``ToTensor`` range) to [-1, 1].

    Args:
        x: Tensor or number with values in [0, 1].

    Returns:
        ``x`` mapped so that 0 -> -1, 0.5 -> 0 and 1 -> 1.
    """
    return x * 2 - 1

# Preprocessing: PIL image -> float tensor in [0, 1] -> rescaled to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(normalize_img),
    ]
)

# ================== Model Architecture ==================
class ConditionedDoubleConv(nn.Module):
    """Two (3x3 conv -> GroupNorm -> SiLU) stages with the condition injected in between.

    The conditioning tensor is tiled over the spatial grid and concatenated
    channel-wise onto the first stage's output. The second convolution
    therefore sees ``out_channels + cond_dim`` input channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions. Must be divisible
            by 8 because of ``GroupNorm(8, ...)``.
        cond_dim: Channel count of the conditioning tensor.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        # Attribute names are part of the saved state_dict layout; keep them stable.
        # kernel_size=3 with padding=1 keeps H and W unchanged.
        # TODO: explore even vs odd kernel sizes.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # TODO: compare GroupNorm with BatchNorm / LayerNorm.
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels + cond_dim, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)

    def forward(self, x, cond):
        """Apply both stages to ``x``; ``cond`` is expected as (B, cond_dim, 1, 1)."""
        # TODO: try other activations, or a conv -> act -> norm ordering.
        hidden = F.silu(self.norm1(self.conv1(x)))
        height, width = hidden.size(2), hidden.size(3)
        # Broadcast the per-sample condition vector to every pixel.
        tiled_cond = cond.expand(-1, -1, height, width)
        merged = torch.cat((hidden, tiled_cond), dim=1)
        return F.silu(self.norm2(self.conv2(merged)))


class Down(nn.Module):
    """Halve H and W with 2x2 max-pooling, then apply a conditioned double conv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the double convolution.
        cond_dim: Channel count of the conditioning tensor.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        # Attribute names match the saved state_dict layout; do not rename.
        # TODO: try other downsampling methods (e.g. average pooling).
        self.maxpool = nn.MaxPool2d(2)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x, cond):
        """Pool ``x`` by a factor of 2, then run the conditioned double conv."""
        pooled = self.maxpool(x)
        return self.conv(pooled, cond)


class Up(nn.Module):
    """Upsample 2x, align to the skip connection, concatenate, then apply a conditioned double conv.

    Args:
        in_channels: Channels after concatenation (skip channels + upsampled channels).
        out_channels: Channels produced by the double convolution.
        cond_dim: Channel count of the conditioning tensor.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        # TODO: consider a transposed convolution instead of bilinear upsampling.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x1, x2, cond):
        """
        Args:
            x1: Coarser feature map to be upsampled.
            x2: Skip-connection feature map whose H/W the result is aligned to.
            cond: Conditioning tensor forwarded to the double conv.
        """
        upsampled = self.up(x1)
        # Pad so H/W match x2 (sizes can differ, e.g. after pooling an odd size);
        # an odd difference puts the extra pixel on the right/bottom.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        left = pad_w // 2
        top = pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        return self.conv(torch.cat([x2, upsampled], dim=1), cond)


#TODO: Diffusion/Vision Transformers instead of UNet
class ConditionalUNet(nn.Module):
    """Two-level UNet that predicts a velocity field conditioned on time and class label.

    Conditioning works as follows:
    - The time ``t`` goes through a small MLP to a 16-dim vector.
    - The class label goes through ``nn.Embedding`` to a 16-dim vector.
    - The two are concatenated into a 32-channel vector, which every
      ``ConditionedDoubleConv`` tiles over its feature map.

    For a 28x28 input, the encoder runs 28 -> 14 -> 7 and the decoder runs
    7 -> 14 -> 28. Input/output channel counts follow the module-level
    ``channels`` setting.
    """

    def __init__(self):
        super().__init__()
        # Unified condition encoding dimensions
        self.t_dim = 16
        self.label_dim = 16
        self.cond_dim = self.t_dim + self.label_dim  # 32

        # Time embedding: scalar t -> t_dim
        self.time_embed = nn.Sequential(
            nn.Linear(1, 32),
            nn.SiLU(),
            nn.Linear(32, self.t_dim),
        )
        # Label embedding: class index -> label_dim
        self.label_embed = nn.Embedding(num_classes, self.label_dim)

        # Encoder path (uses the configured image channel count rather than a literal 1)
        self.inc = ConditionedDoubleConv(channels, 64, self.cond_dim)
        self.down1 = Down(64, 128, self.cond_dim)
        self.down2 = Down(128, 256, self.cond_dim)

        # Decoder path: in_channels = upsampled channels + skip channels
        self.up1 = Up(256 + 128, 128, self.cond_dim)
        self.up2 = Up(128 + 64, 64, self.cond_dim)
        self.outc = nn.Conv2d(64, channels, kernel_size=1)

    def forward(self, x, t, labels):
        """Predict the velocity field at state ``x`` and time ``t``.

        Args:
            x: Image batch of shape (B, channels, H, W).
            t: Time tensor. Either one value per sample (shape (B,) or (B, 1))
               or a single value (0-dim or one element) shared by the batch.
            labels: Long tensor of class indices, shape (B,).

        Returns:
            Tensor of shape (B, channels, H, W).
        """
        batch = x.size(0)
        # ``reshape`` (unlike ``view``) also accepts non-contiguous/expanded
        # inputs. The cast lets non-float times (e.g. int or float64) feed the
        # Linear layer.
        t = t.to(x.dtype).reshape(-1, 1)
        if t.size(0) == 1 and batch != 1:
            # A single shared time value is broadcast to the whole batch.
            t = t.expand(batch, 1)

        t_emb = self.time_embed(t)  # [B, 16]
        lbl_emb = self.label_embed(labels)  # [B, 16]
        cond = torch.cat([t_emb, lbl_emb], dim=1)  # [B, 32]
        cond = cond.unsqueeze(-1).unsqueeze(-1)  # [B, 32, 1, 1]

        # Encoder
        x1 = self.inc(x, cond)
        x2 = self.down1(x1, cond)
        x3 = self.down2(x2, cond)

        # Decoder with skip connections
        x = self.up1(x3, x2, cond)
        x = self.up2(x, x1, cond)
        return self.outc(x)


# ================== Training and Generation ==================
# Initialize model and optimizer here for global access (especially for generate_with_label)
model = ConditionalUNet().to(device)
# Adam rather than plain (momentum-free) SGD: the configured lr=1e-4 is an
# Adam-scale step size. With SGD at that rate, the logged training loss was
# still falling steadily after 50 epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


@torch.no_grad()
def generate_with_label(label, num_samples=16, net=None):
    """
    Generate samples of one class by integrating the learned velocity field.

    Starts from standard Gaussian noise at t=0 and solves dx/dt = v(x, t, label)
    up to t=1 with torchdiffeq's adaptive ``dopri5`` solver.

    Args:
        label (int): Digit label to generate, in ``[0, num_classes)``.
        num_samples (int): Number of samples to generate.
        net (nn.Module, optional): Model to sample from. Defaults to the
            module-level ``model``, looked up at call time.
    Returns:
        torch.Tensor: CPU tensor with values in [0, 1]. Shape is
        (num_samples, image_size, image_size) when ``channels == 1``.
    Raises:
        ValueError: If ``label`` is outside ``[0, num_classes)``.
    """
    # Fail early with a clear message instead of an embedding index error
    # (which on CUDA surfaces as an opaque device-side assert).
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes}), got {label}")
    net = model if net is None else net

    was_training = net.training
    net.eval()
    try:
        # Initial noise and a constant label for every sample
        x0 = torch.randn(num_samples, channels, image_size, image_size, device=device)
        labels = torch.full((num_samples,), label, device=device, dtype=torch.long)

        def ode_func(t, x):
            # The solver passes a single time value; give every sample a copy.
            return net(x, t.expand(x.size(0)), labels)

        # Integrate from noise (t=0) to data (t=1)
        t_span = torch.tensor([0.0, 1.0], device=device)
        # TODO: compare other ODE solvers
        trajectory = odeint(
            ode_func,
            x0,
            t_span,
            rtol=1e-5,
            atol=1e-5,
            method='dopri5',
        )
    finally:
        # Restore the caller's train/eval mode even if integration raises.
        net.train(was_training)

    # Final state, mapped from [-1, 1] to [0, 1]
    images = (trajectory[-1].clamp(-1, 1) + 1) / 2
    return images.cpu().squeeze(1)


def visualize_train(epoch):
    """
    Save a 10x10 grid of samples from the current model to ``epoch{epoch}.jpg``.

    Column ``c`` holds 10 samples conditioned on digit ``c``. The digit is
    shown as the title of the top cell of its column.
    """
    print("Generating training visualization samples...")
    plt.figure(figsize=(10, 10))
    plt.subplots_adjust(wspace=0.05, hspace=0.05)  # tight spacing between cells

    rows = 10
    for digit in range(num_classes):
        samples = generate_with_label(label=digit, num_samples=rows).numpy()
        for row, sample in enumerate(samples):
            # 1-based row-major subplot index: the row picks the sample,
            # the column picks the digit.
            axis = plt.subplot(rows, num_classes, row * num_classes + digit + 1)
            plt.imshow(sample, cmap='gray', vmin=0, vmax=1)
            axis.axis('off')
            if row == 0:
                axis.set_title(str(digit), fontsize=16, pad=5)

    plt.suptitle("Generated Samples During Training", fontsize=20, y=0.97)
    plt.savefig(f"epoch{epoch}.jpg")
    plt.close()


def hundred_image(model_path=model_save_path, num_sample=5):
    """
    Load weights from ``model_path`` and save ``num_sample`` 10x10 sample grids.

    Each grid is saved as ``generated_image{k}.jpg`` (k = 1..num_sample). It has
    one column per digit and 10 samples per column.

    ``generate_with_label`` samples from the module-level ``model``, so that
    global is temporarily replaced by the loaded network. It is always
    restored afterwards, even if generation raises.

    Args:
        model_path: Path to a state_dict, as saved by ``train``.
        num_sample: Number of grid images to generate.
    """
    global model
    print(f"Loading model from {model_path} and generating {num_sample} images...")
    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}. Cannot generate images.")
        return

    loaded_model = ConditionalUNet().to(device)
    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval()  # Evaluation mode

    original_model = model
    model = loaded_model  # Temporarily replace global model with loaded model
    try:
        for k in range(num_sample):
            fig = plt.figure(figsize=(10, 10))
            try:
                plt.subplots_adjust(wspace=0.05, hspace=0.05)

                print(f"Generating image {k + 1}...")
                for label in tqdm(range(num_classes), desc="Generating images for each digit"):
                    # 10 samples for the current digit, plotted down its column
                    generated_images = generate_with_label(
                        label=label,
                        num_samples=10
                    ).numpy()

                    for i in range(10):
                        ax = plt.subplot(10, num_classes, (i * num_classes) + label + 1)
                        plt.imshow(generated_images[i], cmap='gray', vmin=0, vmax=1)
                        ax.axis('off')
                        # Label the column in its first row
                        if i == 0:
                            ax.set_title(str(label), fontsize=16, pad=5)

                plt.suptitle(f"Final Generated Samples (Generation {k + 1})", fontsize=20, y=0.97)
                plt.savefig(f"generated_image{k + 1}.jpg")
                print(f"Generated image saved to: generated_image{k + 1}.jpg")
            finally:
                # Release the figure even if generation fails midway.
                plt.close(fig)
    finally:
        # Restore global model reference
        model = original_model


def train(num_epochs=100):
    """
    Train the global ``model`` with the conditional flow-matching objective.

    For each batch:
    - Sample noise x0 ~ N(0, I) and t ~ U[0, 1).
    - Form x_t = (1 - t) * x0 + t * x1 between noise and data x1.
    - Regress the model output onto the target velocity x1 - x0 with MSE.

    A sample grid is saved after the first epoch and after every 10th epoch.
    The final weights are written to ``model_save_path``.

    Args:
        num_epochs: Number of passes over the module-level ``train_loader``
            (created in the ``__main__`` block).
    """
    print("Starting training...")

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        model.train()  # Training mode
        running_loss = 0.0

        for step, (images, labels) in enumerate(progress_bar, start=1):
            images = images.to(device)
            labels = labels.to(device)

            # Fresh noise and times every batch
            noise = torch.randn_like(images)
            t = torch.rand(images.size(0), device=device)
            t_b = t.view(-1, 1, 1, 1)  # broadcast over (C, H, W)
            # x_t = (1-t) * x0 + t * x1
            xt = (1 - t_b) * noise + t_b * images

            # Model predicts the velocity field v_t; the true one is x1 - x0
            vt_pred = model(xt, t, labels)
            loss = F.mse_loss(vt_pred, images - noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping to prevent explosion   #TODO: avoid clipping
            optimizer.step()

            running_loss += loss.item()
            # Display the mean per-batch loss so far, not the running sum: the sum
            # grows with the number of batches and is not an MSE.
            progress_bar.set_postfix({"Loss": f"{running_loss / step:.4f}"})

        # Sample grid after the first epoch and every 10th epoch
        epoch_num = epoch + 1
        if epoch_num == 1 or epoch_num % 10 == 0:
            visualize_train(epoch_num)

    # Save model after training
    torch.save(model.state_dict(), model_save_path)
    print(f"Training complete. Model saved to: {model_save_path}")


if __name__ == "__main__":
    # MNIST training split, downloaded into ./data on first use.
    train_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    # Must stay a module-level name: train() reads the global train_loader.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    # Comment out the train(...) call when reusing an existing checkpoint.
    train(epochs)
    hundred_image(model_save_path)

Starting training...


Epoch 1: 100%|██████████| 469/469 [00:24<00:00, 19.13it/s, Loss=925.8238]


Generating training visualization samples...


Epoch 2: 100%|██████████| 469/469 [00:24<00:00, 19.07it/s, Loss=689.1275]
Epoch 3: 100%|██████████| 469/469 [00:24<00:00, 19.11it/s, Loss=590.8281]
Epoch 4: 100%|██████████| 469/469 [00:24<00:00, 18.97it/s, Loss=521.7590]
Epoch 5: 100%|██████████| 469/469 [00:24<00:00, 18.88it/s, Loss=460.5462]
Epoch 6: 100%|██████████| 469/469 [00:24<00:00, 19.13it/s, Loss=405.8646]
Epoch 7: 100%|██████████| 469/469 [00:24<00:00, 19.04it/s, Loss=361.4963]
Epoch 8: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=325.7300]
Epoch 9: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s, Loss=300.4889]
Epoch 10: 100%|██████████| 469/469 [00:24<00:00, 19.16it/s, Loss=281.4976]


Generating training visualization samples...


Epoch 11: 100%|██████████| 469/469 [00:24<00:00, 19.13it/s, Loss=266.2280]
Epoch 12: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=255.5439]
Epoch 13: 100%|██████████| 469/469 [00:24<00:00, 19.17it/s, Loss=246.8708]
Epoch 14: 100%|██████████| 469/469 [00:24<00:00, 19.08it/s, Loss=240.5798]
Epoch 15: 100%|██████████| 469/469 [00:24<00:00, 19.20it/s, Loss=234.5871]
Epoch 16: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=230.1819]
Epoch 17: 100%|██████████| 469/469 [00:24<00:00, 19.16it/s, Loss=224.9269]
Epoch 18: 100%|██████████| 469/469 [00:24<00:00, 19.19it/s, Loss=221.1591]
Epoch 19: 100%|██████████| 469/469 [00:24<00:00, 19.11it/s, Loss=217.9509]
Epoch 20: 100%|██████████| 469/469 [00:24<00:00, 19.08it/s, Loss=214.3570]


Generating training visualization samples...


Epoch 21: 100%|██████████| 469/469 [00:24<00:00, 19.19it/s, Loss=211.5842]
Epoch 22: 100%|██████████| 469/469 [00:24<00:00, 19.11it/s, Loss=208.6239]
Epoch 23: 100%|██████████| 469/469 [00:24<00:00, 19.12it/s, Loss=206.8436]
Epoch 24: 100%|██████████| 469/469 [00:24<00:00, 19.22it/s, Loss=203.2784]
Epoch 25: 100%|██████████| 469/469 [00:24<00:00, 19.04it/s, Loss=201.6637]
Epoch 26: 100%|██████████| 469/469 [00:48<00:00,  9.71it/s, Loss=199.1758]
Epoch 27: 100%|██████████| 469/469 [00:36<00:00, 12.79it/s, Loss=197.5957]
Epoch 28: 100%|██████████| 469/469 [00:50<00:00,  9.20it/s, Loss=194.9864]
Epoch 29: 100%|██████████| 469/469 [00:47<00:00,  9.97it/s, Loss=193.4212]
Epoch 30: 100%|██████████| 469/469 [00:24<00:00, 18.86it/s, Loss=192.2274]


Generating training visualization samples...


Epoch 31: 100%|██████████| 469/469 [00:24<00:00, 19.08it/s, Loss=189.6132]
Epoch 32: 100%|██████████| 469/469 [00:24<00:00, 19.01it/s, Loss=187.9532]
Epoch 33: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=186.5910]
Epoch 34: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=185.0160]
Epoch 35: 100%|██████████| 469/469 [00:24<00:00, 19.11it/s, Loss=183.4117]
Epoch 36: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s, Loss=182.1165]
Epoch 37: 100%|██████████| 469/469 [00:24<00:00, 19.04it/s, Loss=180.8234]
Epoch 38: 100%|██████████| 469/469 [00:24<00:00, 19.12it/s, Loss=179.3385]
Epoch 39: 100%|██████████| 469/469 [00:24<00:00, 19.16it/s, Loss=177.9864]
Epoch 40: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s, Loss=176.0617]


Generating training visualization samples...


Epoch 41: 100%|██████████| 469/469 [00:24<00:00, 19.07it/s, Loss=175.8308]
Epoch 42: 100%|██████████| 469/469 [00:24<00:00, 19.07it/s, Loss=174.2884]
Epoch 43: 100%|██████████| 469/469 [00:24<00:00, 19.21it/s, Loss=173.4384]
Epoch 44: 100%|██████████| 469/469 [00:24<00:00, 19.09it/s, Loss=171.8537]
Epoch 45: 100%|██████████| 469/469 [00:24<00:00, 19.21it/s, Loss=171.2573]
Epoch 46: 100%|██████████| 469/469 [00:24<00:00, 19.15it/s, Loss=170.0746]
Epoch 47: 100%|██████████| 469/469 [00:24<00:00, 19.17it/s, Loss=168.8070]
Epoch 48: 100%|██████████| 469/469 [00:24<00:00, 19.20it/s, Loss=167.9851]
Epoch 49: 100%|██████████| 469/469 [00:24<00:00, 19.22it/s, Loss=167.0389]
Epoch 50: 100%|██████████| 469/469 [00:24<00:00, 19.22it/s, Loss=166.2858]


Generating training visualization samples...
Training complete. Model saved to: FMmodel.pth
Loading model from FMmodel.pth and generating 5 images...
Generating image 1...


Generating images for each digit: 100%|██████████| 10/10 [00:02<00:00,  4.97it/s]


Generated image saved to: generated_image1.jpg
Generating image 2...


Generating images for each digit: 100%|██████████| 10/10 [00:01<00:00,  5.60it/s]


Generated image saved to: generated_image2.jpg
Generating image 3...


Generating images for each digit: 100%|██████████| 10/10 [00:02<00:00,  4.84it/s]


Generated image saved to: generated_image3.jpg
Generating image 4...


Generating images for each digit: 100%|██████████| 10/10 [00:01<00:00,  5.57it/s]


Generated image saved to: generated_image4.jpg
Generating image 5...


Generating images for each digit: 100%|██████████| 10/10 [00:01<00:00,  5.61it/s]


Generated image saved to: generated_image5.jpg
