In [None]:
from utils import *
from flowmatching_reflow import *

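# Flow matching with reflow

This notebook does the following:
- Trains a flow-matching velocity model on synthetic 3-D sphere data.
- Uses that model to generate (x0, x1) endpoint pairs.
- Trains a reflow model on those pairs.

Samples from each model are compared with the training data via radius histograms and 2-D scatter plots.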

## Generate data

In [None]:

# 1. Generate Data (on CPU initially)
# The second return value is used later as the reference radii in the
# radius histograms.
M = 3  # Data dimensionality
num_samples = 32000
print(f"Generating {num_samples} samples...")
data_cpu, original_radii = generate_3d_sphere_data(num_samples)
data_cpu = data_cpu.float()

# 2. Configure Training
config = Config(
    M=M,
    nhidden=1024,
    nlayers=2,
    batch_size=128,
    learning_rate=1e-3,
    epochs=2000,            # Epochs for initial training
    epochs_reflow=50,       # Epochs for reflow training
    time_embed_dim=64,
    ode_steps=50,           # Steps for SAMPLING (initial and reflow)
    reflow_ode_steps=10,    # Steps for generating TRAJECTORIES for reflow training
    animation_steps=100,    # Steps/frames for animation
    animation_samples=500,  # Number of points in animation
    epsilon=1e-5,
)

# 3. Move Original Dataset to Target Device
# If the transfer fails, raise rather than call exit(). Inside a Jupyter kernel,
# exit() ends the kernel and discards all state. An exception simply stops
# "Run All" at this cell with a readable traceback.
if config.device.type == 'cuda':
    dataset_mb = data_cpu.nelement() * data_cpu.element_size() / 1024**2
    print(f"Attempting to move dataset ({dataset_mb:.2f} MB) to {config.device}...")
    try:
        data_on_device = data_cpu.to(config.device)
    except RuntimeError as e:
        raise RuntimeError(f"ERROR moving full dataset to GPU: {e}") from e
    print(f"Successfully moved dataset to {data_on_device.device}")
else:
    print("Running on CPU, keeping data on CPU.")
    data_on_device = data_cpu


## Initial training of the model

In [None]:
# --- Initial Training ---
# Set to True to (re)train the initial flow-matching model. Otherwise the
# checkpoint previously saved in the config directory is loaded below.
train_initial_model_flag = False
config_dir = get_config_directory(config)
os.makedirs(config_dir, exist_ok=True)
initial_model_filename = "model_fm_initial.pth"
reflow_model_filename = "model_fm_reflow.pth"
initial_model = None
epoch_losses_initial = []

if train_initial_model_flag:
    print("\n--- Starting Initial Flow Matching Training ---")
    initial_model, epoch_losses_initial = train_flow_matching(data_on_device, config)

    # Compare against None explicitly: the truthiness of a model object is not
    # a reliable "training succeeded" signal.
    if initial_model is not None:
        save_model(initial_model, config, filename=initial_model_filename)
        if epoch_losses_initial:
            plt.figure(figsize=(10, 5))
            plt.plot(epoch_losses_initial, label='Initial Training Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss (MSE)')
            plt.yscale('log')
            plt.grid(True, which='both', linestyle='--', linewidth=0.5)
            plt.legend()
            plt.title(f'Initial Flow Matching Training Loss\n{get_config_description(config)}')
            loss_fig = plt.gcf()
            save_plot(loss_fig, config, "training_loss_fm_initial.png")
    else:
        print("Initial training failed, model not saved.")

# --- Load Initial Model if not trained in this run ---
# A missing checkpoint raises FileNotFoundError instead of calling exit(), which
# would end the Jupyter kernel rather than just stopping this cell.
if initial_model is None:
    initial_model_path = os.path.join(config_dir, initial_model_filename)
    if not os.path.exists(initial_model_path):
        raise FileNotFoundError(
            f"Initial model file {initial_model_path} not found and training was "
            "skipped/failed. Cannot proceed with reflow."
        )
    print(f"\n--- Loading Initial Model from {initial_model_path} ---")
    initial_model = load_model(VelocityMLP, config, filename=initial_model_filename)



### Sampling and plotting of the initial model

In [None]:
# Sample from the initial model and compare the generated data with the
# training data. The comparison uses a radius histogram and a scatter plot of
# the first two features.
if initial_model is None:
    # Only reachable if cells ran out of order. Raise rather than call exit(),
    # which inside a Jupyter kernel ends the kernel and discards all state.
    raise RuntimeError("No initial model available for sampling.")

print("\n--- Generating Samples using Initial ODE Solver ---")
num_generated_samples_initial = 4000  # Same count as the reflow cell, for a like-for-like comparison
generated_data_initial = sample_flow(initial_model, config, num_samples=num_generated_samples_initial)

if generated_data_initial is not None and generated_data_initial.shape[0] > 0:
    print(f"Generated {generated_data_initial.shape[0]} samples using Initial model.")

    # --- Analysis and Plotting (Initial) ---
    print("--- Analyzing and Plotting Initial Model Results ---")
    # Call detach()/cpu() before numpy(). The samples may live on the GPU (the
    # scatter code below also moves them with .cpu()), and Tensor.numpy() only
    # works on CPU tensors that do not require grad.
    generated_radii_initial = torch.norm(generated_data_initial, dim=1).detach().cpu().numpy()
    original_radii_np = original_radii.cpu().numpy()

    # Plot Radii Histogram (Initial)
    plt.figure(figsize=(10, 6))
    hist_range = (0.0, 1.5)
    plt.hist(original_radii_np, bins=50, range=hist_range, density=True, alpha=0.6, label='Original Data Radii')
    plt.hist(generated_radii_initial, bins=50, range=hist_range, density=True, alpha=0.6, label='Generated Data Radii (Initial)')
    plt.xlabel('Radius')
    plt.ylabel('Density')
    plt.legend()
    plt.title(f'Histogram of Data Radii (Initial Flow Matching)\n{get_config_description(config)}')
    plt.grid(True, linestyle='--', linewidth=0.5)
    histogram_fig_initial = plt.gcf()
    save_plot(histogram_fig_initial, config, "radii_histogram_fm_initial.png")

    # Plot Scatter (Initial, if M >= 2): only the first two features are shown.
    if M >= 2:
        plt.figure(figsize=(8, 8))
        num_points_to_plot = min(1000, data_on_device.shape[0], generated_data_initial.shape[0])
        orig_data_cpu = data_on_device[:num_points_to_plot].cpu()
        gen_data_initial_cpu = generated_data_initial[:num_points_to_plot].detach().cpu()
        plt.scatter(orig_data_cpu[:, 0], orig_data_cpu[:, 1], alpha=0.5, s=10, label='Original Data Sample')
        plt.scatter(gen_data_initial_cpu[:, 0], gen_data_initial_cpu[:, 1], alpha=0.5, s=10, label='Generated Data Sample (Initial)')
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')
        plt.legend()
        plt.title(f'Data Scatter Plot (Initial Flow Matching)\n{get_config_description(config)}')
        plt.axis('equal')
        plt.grid(True, linestyle='--', linewidth=0.5)
        scatter_fig_initial = plt.gcf()
        save_plot(scatter_fig_initial, config, "scatter_plot_2d_fm_initial.png")

    print(f"\nInitial model results saved in directory: {config_dir}")
else:
    print("No data generated from initial model or sampling failed.")



## Reflow

### Generate trajectories using the initial model

In [None]:

# --- Generate Trajectories for Reflow ---
# Generate trajectories with the initial model. Keep only each trajectory's
# first and last points; these become the (x0, x1) training pairs for reflow.
print("\n--- Generating Trajectories for Reflow ---")
num_reflow_samples = num_samples  # Or adjust based on memory/time
trajectories = generate_trajectories(initial_model, config, num_samples=num_reflow_samples)
# NOTE(review): assumed layout is [time_steps, num_reflow_samples, M], with time
# on axis 0. The number of steps is presumably config.reflow_ode_steps; confirm
# against generate_trajectories.

# When `trajectories` is a single stacked tensor, trajectories[0] and
# trajectories[-1] are views that share its storage. `del trajectories` alone
# would then free nothing. clone() gives the endpoints their own small buffers,
# so the intermediate time steps can actually be released.
x0_reflow = trajectories[0].clone()   # Shape: [num_reflow_samples, M]
x1_reflow = trajectories[-1].clone()  # Shape: [num_reflow_samples, M]
del trajectories
print(f"Generated {x0_reflow.shape[0]} (x0, x1) pairs for reflow training on device {x0_reflow.device}.")


### Train reflow model

In [None]:

# --- Reflow Training ---
# Set to True to train the reflow model on the (x0_reflow, x1_reflow) pairs.
# Otherwise the reflow checkpoint previously saved in the config directory is
# loaded below.
train_reflow_model_flag = False
reflow_model = None
epoch_losses_reflow = []

if train_reflow_model_flag:
    print("\n--- Starting Reflow Matching Training ---")
    # Train a new model from scratch on the reflow pairs. To fine-tune the
    # initial model instead, pass initial_model=initial_model.
    reflow_model, epoch_losses_reflow = train_reflow_matching(x0_reflow, x1_reflow, config, initial_model=None)

    # Compare against None explicitly rather than relying on model truthiness.
    if reflow_model is not None:
        save_model(reflow_model, config, filename=reflow_model_filename)
        if epoch_losses_reflow:
            plt.figure(figsize=(10, 5))
            plt.plot(epoch_losses_reflow, label='Reflow Training Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss (MSE)')
            plt.yscale('log')
            plt.grid(True, which='both', linestyle='--', linewidth=0.5)
            plt.legend()
            plt.title(f'Reflow Matching Training Loss\n{get_config_description(config)}')
            loss_fig = plt.gcf()
            save_plot(loss_fig, config, "training_loss_fm_reflow.png")
    else:
        print("Reflow training failed, model not saved.")

# --- Load Reflow Model if not trained in this run ---
if reflow_model is None:
    reflow_model_path = os.path.join(config_dir, reflow_model_filename)
    if os.path.exists(reflow_model_path):
        print(f"\n--- Loading Reflow Model from {reflow_model_path} ---")
        reflow_model = load_model(VelocityMLP, config, filename=reflow_model_filename)
    else:
        # Not fatal: reflow_model stays None, and the sampling cell then skips
        # reflow sampling. Assigning reflow_model = initial_model here would make
        # that cell sample from the initial model instead.
        print(f"\nReflow model file {reflow_model_path} not found and training was skipped/failed. Cannot perform reflow sampling.")


### Sampling and plotting of the reflow model

In [None]:
# Sample from the reflow model and compare the generated data with the training
# data. The comparison uses a radius histogram and a scatter plot of the first
# two features. The cell also writes an animation (flow_animation_reflow.gif).
if reflow_model is not None:
    print("\n--- Generating Samples using Reflowed ODE Solver ---")
    num_generated_samples = 4000
    generated_data_reflow = sample_flow(reflow_model, config, num_samples=num_generated_samples, isreflow=True)

    if generated_data_reflow is not None and generated_data_reflow.shape[0] > 0:
        print(f"Generated {generated_data_reflow.shape[0]} samples using Reflow model.")

        # --- Analysis and Plotting (Reflow) ---
        print("--- Analyzing and Plotting Reflow Results ---")
        # Call detach()/cpu() before numpy(). The samples may live on the GPU
        # (the scatter code below also moves them with .cpu()), and
        # Tensor.numpy() only works on CPU tensors that do not require grad.
        generated_radii_reflow = torch.norm(generated_data_reflow, dim=1).detach().cpu().numpy()
        original_radii_np = original_radii.cpu().numpy()

        # Plot Radii Histogram (Reflow)
        plt.figure(figsize=(10, 6))
        hist_range = (0.0, 1.5)
        plt.hist(original_radii_np, bins=50, range=hist_range, density=True, alpha=0.6, label='Original Data Radii')
        plt.hist(generated_radii_reflow, bins=50, range=hist_range, density=True, alpha=0.6, label='Generated Data Radii (Reflow)')
        plt.xlabel('Radius')
        plt.ylabel('Density')
        plt.legend()
        plt.title(f'Histogram of Data Radii (Reflow Matching)\n{get_config_description(config)}')
        plt.grid(True, linestyle='--', linewidth=0.5)
        histogram_fig = plt.gcf()
        save_plot(histogram_fig, config, "radii_histogram_fm_reflow.png")

        # Plot Scatter (Reflow, if M >= 2): only the first two features are shown.
        if M >= 2:
            plt.figure(figsize=(8, 8))
            num_points_to_plot = min(1000, data_on_device.shape[0], generated_data_reflow.shape[0])
            orig_data_cpu = data_on_device[:num_points_to_plot].cpu()
            gen_data_reflow_cpu = generated_data_reflow[:num_points_to_plot].detach().cpu()
            plt.scatter(orig_data_cpu[:, 0], orig_data_cpu[:, 1], alpha=0.5, s=10, label='Original Data Sample')
            plt.scatter(gen_data_reflow_cpu[:, 0], gen_data_reflow_cpu[:, 1], alpha=0.5, s=10, label='Generated Data Sample (Reflow)')
            plt.xlabel('Feature 1')
            plt.ylabel('Feature 2')
            plt.legend()
            plt.title(f'Data Scatter Plot (Reflow Matching)\n{get_config_description(config)}')
            plt.axis('equal')
            plt.grid(True, linestyle='--', linewidth=0.5)
            scatter_fig = plt.gcf()
            save_plot(scatter_fig, config, "scatter_plot_2d_fm_reflow.png")

        # --- Generate Animation (Reflow Model) ---
        create_flow_animation(reflow_model, config, config.animation_samples, filename="flow_animation_reflow.gif")

        print(f"\nReflow results saved in directory: {config_dir}")
    else:
        print("No data generated from reflow model or sampling failed.")
else:
    print("\nNo reflow model available for sampling.")

