# In [1]:
import numpy as np
from scipy.integrate import odeint

def generate_trajectory_simple(time_span=200, num_points=20000):
    """
    Generate trajectory data for the harmonic oscillator equation.

    The trajectory is evaluated from a closed-form expression that superposes
    two oscillation modes with angular frequencies 3 and 1:

        x(t) = (cos 3t + sin 3t) - (cos t + sin t)
        y(t) = (cos 3t + sin 3t) + (cos t + sin t)

    Parameters:
    time_span (float): End time of the simulation; sampling starts at t = 0.
    num_points (int): Number of evenly spaced samples in [0, time_span],
        both endpoints included.

    Returns:
    np.ndarray: A 2D array of shape (num_points, 2) whose columns are x(t)
        and y(t).
    """
    # NOTE(review): this closed form starts at (x, y) = (0, 2) with velocity
    # (2, 4) -- confirm that these are the intended initial conditions.
    t = np.linspace(0, time_span, num_points)

    # Each mode is shared by both coordinates, so evaluate it only once.
    fast_mode = np.cos(3*t) + np.sin(3*t)
    slow_mode = np.cos(t) + np.sin(t)

    return np.stack([fast_mode - slow_mode, fast_mode + slow_mode], axis=-1)



def generate_trajectory_complex(time_span=200, num_points=20000):
    """
    Generates trajectory data for the nonlinear oscillator equation.

    Integrates, with scipy's odeint, the first-order system

        d/dt [X, V] = [V, -sin(K @ X)]

    where X is the 2-component position, V its velocity and K is the
    element-wise square of [[2, 1], [1, 2]].

    Parameters:
    time_span (float): End time of the integration; it starts at t = 0.
    num_points (int): Number of evenly spaced evaluation times in
        [0, time_span], both endpoints included.

    Returns:
    np.ndarray: A 2D array of shape (num_points, 4); columns 0-1 hold the
        position and columns 2-3 the velocity at each evaluation time.
    """
    omega = np.array([[2., 1.], [1., 2.]])
    # NOTE(review): `omega**2` squares element-wise, giving [[4, 1], [1, 4]];
    # a matrix square would be `omega @ omega` = [[5, 4], [4, 5]]. The
    # element-wise form is kept so the generated data is unchanged -- confirm
    # which one is intended.
    coupling = omega**2  # constant, so computed once instead of on every RHS call

    def sine_system(state, t):
        position, velocity = state[:2], state[2:]
        return np.hstack([velocity, -np.sin(coupling @ position)])

    # Initial conditions: position followed by velocity
    X0 = np.array([2., 0.])
    Xprime0 = np.array([0., np.sqrt(8.)])
    y0 = np.hstack([X0, Xprime0])

    # Generate time evaluation points
    t_eval = np.linspace(0, time_span, num_points)

    # Solve the system numerically using odeint
    return odeint(sine_system, y0, t_eval)


# Build both reference trajectories once and report their shapes.
U_s, U_m = generate_trajectory_simple(), generate_trajectory_complex()
print(U_s.shape, U_m.shape)

# Output: (20000, 2) (20000, 4)


# In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import os, time, csv, glob, cv2



# Generate a synthetic dataset
def generate_data(trajectory, test_size, random_state=42):
    """
    Build one-step-ahead training pairs from a trajectory and split them.

    Each sample pairs a state with its immediate successor, so a trajectory
    of N states yields N - 1 (input, target) pairs.

    Parameters:
    - trajectory: Sequence of states ordered in time (first axis is time).
    - test_size: Forwarded to train_test_split as `test_size`.
    - random_state: Forwarded to train_test_split as `random_state`.

    Returns:
    - The result of train_test_split on (inputs, targets); the caller in
      this file unpacks it as X_train, X_val, Y_train, Y_val.
    """
    current_states = trajectory[:-1]
    next_states = trajectory[1:]
    return train_test_split(
        current_states,
        next_states,
        test_size=test_size,
        random_state=random_state,
    )



# Define the neural network
class FourierNN(nn.Module):
    """
    Sum of two bias-free branches, each of the form Linear -> cos -> Linear.

    Parameters:
    - input_size: Number of input features of both inner layers.
    - hidden_size: Width of the cosine feature layer in each branch.
    - output_size: Number of output features of both outer layers.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the order inner1, outer1, inner2, outer2 so
        # that random initialisation and state_dict key order stay the same.
        self.inner1 = nn.Linear(input_size, hidden_size, bias=False)
        self.outer1 = nn.Linear(hidden_size, output_size, bias=False)
        self.inner2 = nn.Linear(input_size, hidden_size, bias=False)
        self.outer2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Return the sum of the two branch outputs for input x."""
        first_branch = self.outer1(torch.cos(self.inner1(x)))
        second_branch = self.outer2(torch.cos(self.inner2(x)))
        return first_branch + second_branch
 

def forecast(setup, model, validation_trajectory, plot_title="Trajectory Comparison"):
    """
    Recursively applies the model to generate a trajectory and compares it to a validation trajectory.

    Starting from the first validation state, the model's output is fed back
    as its next input until the rollout is as long as the validation
    trajectory. Both trajectories are then plotted, flattened, on one axis.

    Parameters:
    - setup: Name of the experiment setup. Currently unused; kept so the
      signature matches the other helpers in this file.
    - model: The neural network model (e.g., FourierNN). It is switched to
      eval mode and left in that mode.
    - validation_trajectory: A tensor or numpy array with the ground truth
      trajectory for comparison; the first axis indexes time steps.
    - plot_title: Title for the plot (default: "Trajectory Comparison").

    Returns:
    - generated_trajectory: numpy array with the same shape and dtype as the
      validation trajectory, holding the trajectory generated by the model.
    """
    # Work on a numpy view of the reference so tensors and arrays behave alike.
    if isinstance(validation_trajectory, torch.Tensor):
        reference = validation_trajectory.detach().cpu().numpy()
    else:
        reference = np.asarray(validation_trajectory)

    # Prepare storage for the generated trajectory
    steps = reference.shape[0]
    print(steps)
    state_shape = reference.shape[1:]
    generated_trajectory = np.zeros_like(reference)
    generated_trajectory[0] = reference[0]
    model.eval()

    # Generate the trajectory recursively. The state is a single sample, so
    # it is shaped (1, features): nn.Linear expects features on the last axis.
    current_state = torch.tensor(reference[0], dtype=torch.float32).reshape(1, -1)
    with torch.no_grad():  # pure inference; avoids growing an autograd graph
        for i in range(steps - 1):
            current_state = model(current_state)
            generated_trajectory[i + 1] = current_state.cpu().numpy().reshape(state_shape)

    # Plot the trajectories
    plt.figure(figsize=(10, 6))
    plt.plot(generated_trajectory.flatten(), label='Predicted Trajectory')
    plt.plot(reference.flatten(), linestyle='--', label='Validation Trajectory')

    plt.ylabel(r'$\mathbf{u}_n$')
    plt.xlabel(r'$n$')
    plt.title(plot_title)
    plt.legend()
    plt.show()

    return generated_trajectory


def train_model(model, criterion, optimizer, X_train, Y_train, num_epochs, interval=100):
    """
    Train the model with full-batch gradient steps.

    Every epoch performs exactly one optimizer step on the loss computed
    over the whole of (X_train, Y_train).

    Parameters:
    - model: Network to train; it is put into train mode.
    - criterion: Callable returning the loss for (predictions, targets).
    - optimizer: Optimizer bound to the model's parameters.
    - X_train, Y_train: Training inputs and targets.
    - num_epochs: Number of epochs (optimizer steps) to run.
    - interval: The progress bar's loss readout is refreshed every
      `interval` epochs and on the last epoch.
    """
    model.train()
    progress = tqdm(range(num_epochs), desc="Epochs", position=0)
    for epoch in progress:
        optimizer.zero_grad()
        predictions = model(X_train)
        loss = criterion(predictions, Y_train)
        loss.backward()
        optimizer.step()

        # Only touch the progress bar on reporting epochs.
        if not (epoch % interval == 0 or epoch == num_epochs - 1):
            continue
        progress.set_postfix(avg_loss=loss.item())



# Evaluate the network
def evaluate_model(model, X_val, Y_val):
    """
    Return the mean-squared error of the model on a validation set.

    The forward pass runs with gradient tracking disabled.

    Parameters:
    - model: Network to evaluate.
    - X_val: Validation inputs.
    - Y_val: Validation targets.

    Returns:
    - Scalar tensor holding the mean-squared error.
    """
    with torch.no_grad():
        predictions = model(X_val)
    return nn.functional.mse_loss(predictions, Y_val)


# Function to plot parameters
def plot_parameters(setup, model, epoch, train_loss, test_loss, display=False):
    """
    Load the checkpoint saved for `epoch` and plot the four weight matrices.

    The weights of inner1, inner2, outer1 and outer2 are flattened and drawn
    on a 2x2 grid; the figure is saved as a PNG in the setup's weight_plots
    folder and then closed.

    Parameters:
    - setup: Experiment name used to build the data paths.
    - model: Model whose state is overwritten with the loaded checkpoint.
    - epoch: Epoch number identifying the checkpoint and the output image.
    - train_loss, test_loss: Loss values shown in the figure title.
    - display: When True, also show the figure before closing it.
    """
    weight_file = f'../data/{setup}/training_history/model_{epoch}.pth'
    model.load_state_dict(torch.load(weight_file))  # Load the weights
    plot_folder = f'../data/{setup}/training_history/weight_plots'

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # Top row: inner layers; bottom row: outer layers.
    panels = (
        (axes[0, 0], model.inner1, 'Inner1 Weights'),
        (axes[0, 1], model.inner2, 'Inner2 Weights'),
        (axes[1, 0], model.outer1, 'Outer1 Weights'),
        (axes[1, 1], model.outer2, 'Outer2 Weights'),
    )
    for axis, layer, title in panels:
        axis.plot(layer.weight.detach().numpy().flatten())
        axis.set_title(title)
        axis.set_ylabel('Weight Value')

    fig.suptitle(f'Epoch = {epoch}, Train Loss={train_loss:.4f}, Test Loss={test_loss:.4f}', fontsize=13)
    plt.savefig(f'{plot_folder}/model_{epoch}.png', dpi=300)
    if display:
        plt.show()
    plt.close()



# Check for "grokking" with a plot
def check_grokking(setup, model, criterion, optimizer, X_train, Y_train, num_epochs, interval,
                   X_val, Y_val, batch_size, beta=7e-5):
    """
    Train with mini-batches and an L2 weight penalty while logging losses.

    On every `interval`-th epoch and on the last epoch the function records
    the average training loss and the validation loss (both including the
    weight penalty), saves a checkpoint, and appends a row to a CSV log.
    After training, the logged losses are plotted on a log scale, saved and
    shown.

    Parameters:
    - setup: Experiment name; outputs go to ../data/{setup}/training_history.
    - model: Model to train. It must expose the layers inner1, outer1,
      inner2 and outer2 (as FourierNN does).
    - criterion: Callable returning the data loss for (predictions, targets).
    - optimizer: Optimizer bound to the model's parameters.
    - X_train, Y_train: Training tensors.
    - num_epochs: Number of passes over the training set.
    - interval: Logging/checkpoint period in epochs.
    - X_val, Y_val: Validation tensors.
    - batch_size: Mini-batch size of the shuffled training loader.
    - beta: Coefficient of the squared-weight penalty (default 7e-5).
    """
    train_losses = []
    val_losses = []
    log_epochs = []
    save_folder = f'../data/{setup}/training_history'
    log_file = f'{save_folder}/training_log.csv'

    # Create data loader for SGD
    train_dataset = torch.utils.data.TensorDataset(X_train, Y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Ensure save folder exists
    os.makedirs(save_folder, exist_ok=True)

    # Initialize CSV file for logging
    with open(log_file, mode='w', newline='') as file:
        csv.writer(file).writerow(['Epoch', 'Train Loss', 'Test Loss', 'Time Taken'])

    model.train()
    epoch_progress = tqdm(range(num_epochs), desc="Epochs", position=0)
    for epoch in epoch_progress:
        start_time = time.time()
        epoch_train_loss = 0.0
        for X_batch, Y_batch in train_loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(X_batch), Y_batch) + beta * _weight_penalty(model)
            batch_loss.backward()
            optimizer.step()
            epoch_train_loss += batch_loss.item()

        avg_train_loss = epoch_train_loss / len(train_loader)

        # Log progress
        if epoch % interval == 0 or epoch == num_epochs - 1:
            epoch_progress.set_postfix(avg_loss=avg_train_loss)
            # Recompute the penalty from the current weights: the value used
            # for the last training batch predates the final optimizer step.
            with torch.no_grad():
                val_loss = (evaluate_model(model, X_val, Y_val)
                            + beta * _weight_penalty(model)).item()
            train_losses.append(avg_train_loss)
            val_losses.append(val_loss)
            log_epochs.append(epoch)
            # Save model weights
            torch.save(model.state_dict(), f"{save_folder}/model_{epoch}.pth")
            # Log to CSV; the time covers this epoch only
            time_taken = time.time() - start_time
            with open(log_file, mode='a', newline='') as file:
                csv.writer(file).writerow([epoch, avg_train_loss, val_loss, time_taken])

    # Plot training and validation loss. The first few logged points are
    # skipped, but only when enough points remain to draw a curve.
    k = 5 if len(log_epochs) > 5 else 0
    plt.figure(figsize=(8, 6))
    plt.semilogy(log_epochs[k:], train_losses[k:], label='Training Loss', linestyle='solid', linewidth=2)
    plt.semilogy(log_epochs[k:], val_losses[k:], label='Validation Loss', linestyle='dashed', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    plt.savefig('{}/{}-loss_curve.png'.format(save_folder, setup), bbox_inches='tight')
    plt.show()


def _weight_penalty(model):
    """Return the sum of squared entries of the model's four weight matrices."""
    return torch.sum(model.inner1.weight**2) + torch.sum(model.outer1.weight**2) \
         + torch.sum(model.inner2.weight**2) + torch.sum(model.outer2.weight**2)


def generate_weight_plots(setup, model, num_epochs, interval):
    """
    Render one weight plot per logged epoch of a finished training run.

    Reads the setup's training_log.csv and, for every `interval`-th epoch
    plus the last one, calls plot_parameters with the train and test losses
    logged for that epoch.

    Parameters:
    - setup: Experiment name used to build the data paths.
    - model: Model instance handed to plot_parameters.
    - num_epochs: Total number of epochs of the run.
    - interval: Logging period that was used during training.
    """
    plot_folder = f'../data/{setup}/training_history/weight_plots'
    log_file = f'../data/{setup}/training_history/training_log.csv'
    if not os.path.exists(plot_folder):
        os.makedirs(plot_folder)

    history = pd.read_csv(log_file)

    for epoch in range(num_epochs):
        # Only epochs on the logging grid (or the final one) have a log row.
        if not (epoch % interval == 0 or epoch == num_epochs - 1):
            continue
        row = history[history['Epoch']==epoch]
        plot_parameters(
            setup,
            model,
            epoch,
            float(row['Train Loss'].iloc[0]),
            float(row['Test Loss'].iloc[0]),
        )




def create_animation(setup, fps=10):
    """
    Assemble the saved weight plots of a setup into an MP4 video.

    Frames are the PNG files in the setup's weight_plots folder, ordered by
    the integer that follows the last underscore in the file name
    (e.g. model_200.png comes before model_1000.png). Files without such a
    number are placed last, in alphabetical order.

    Parameters:
    - setup: Experiment name used to build the data paths.
    - fps: Frames per second of the output video (default 10).

    Raises:
    - FileNotFoundError: If the plot folder contains no PNG files.
    - OSError: If an image cannot be read.
    """
    plot_folder = f'../data/{setup}/training_history/weight_plots'
    output_video_path = f'../data/{setup}/training_history/model_evolution.mp4'

    def frame_order(path):
        # A plain string sort would put "model_1000" before "model_200".
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            return (0, int(stem.rsplit('_', 1)[-1]), path)
        except ValueError:
            return (1, 0, path)

    image_files = sorted(glob.glob(f"{plot_folder}/*.png"), key=frame_order)
    if not image_files:
        raise FileNotFoundError(f"No PNG frames found in {plot_folder}")

    # The first frame fixes the video resolution.
    first_frame = cv2.imread(image_files[0])
    if first_frame is None:
        raise OSError(f"Could not read image {image_files[0]}")
    height, width = first_frame.shape[:2]

    video = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for image_file in image_files:
            frame = cv2.imread(image_file)
            if frame is None:
                raise OSError(f"Could not read image {image_file}")
            video.write(frame)
    finally:
        # Always finalise the file, even if a frame fails to load.
        video.release()



# Main script

# Prepare data
setup = 'harmonic'

if setup == 'harmonic':
    U = U_s
elif setup == 'complex':
    U = U_m
else:
    raise ValueError(f"Unknown setup {setup!r}; expected 'harmonic' or 'complex'")

# One-step-ahead pairs from the first 2000 states, split 50/50.
X_train, X_val, Y_train, Y_val = generate_data(U[:2000], test_size=0.5)
X_train, X_val = torch.tensor(X_train, dtype=torch.float32), torch.tensor(X_val, dtype=torch.float32)
Y_train, Y_val = torch.tensor(Y_train, dtype=torch.float32), torch.tensor(Y_val, dtype=torch.float32)

# NOTE(review): these loaders are not used by any call below
# (check_grokking builds its own loader from batch_size).
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_train, Y_train), batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_val, Y_val), batch_size=32, shuffle=False)

# Initialize model, criterion, and optimizer
input_size = X_train.shape[1]
hidden_size = 500
num_epochs = 10000
interval = 100
# The model predicts the next state, so its output width follows the targets
# (2 for 'harmonic', 4 for 'complex') instead of being fixed at 2.
output_size = Y_train.shape[1]
batch_size = 100
model = FourierNN(input_size, hidden_size, output_size)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


# Pipeline stages; uncomment the ones to run. generate_weight_plots reads the
# log and checkpoints written by an earlier check_grokking run.
# train_model(model, criterion, optimizer, X_train, Y_train, num_epochs)
# Check for grokking with a plot
# check_grokking(setup, model, criterion, optimizer, X_train, Y_train, num_epochs, interval, X_val, Y_val, batch_size)
generate_weight_plots(setup, model, num_epochs, interval)
# create_animation(setup)
# forecast(setup, model, U[:100])