# Overfitting Single Sample Analysis

## 1. Introduction

In this AI research project, we're developing a general posterior estimator for complex scientific models. Our approach involves building the entire pipeline from the simplest possible case - overfitting on a single sample - and gradually increasing complexity. This methodical process allows us to ensure each component works as expected before adding more layers.

The main goals of this notebook are:

1. To demonstrate the process of overfitting on a single sample using two approaches:
   - Flow Matching (from `flow_matching_single_example.py`)
   - Conditional Case (from `pipeline_small.py`)

2. To verify that the implemented general posterior estimator in the conditional case functions properly.

3. To prepare for future comparisons with "ground truth" models using MCMC or nested sampling for performance evaluation.

By starting with overfitting on a single sample, we can debug and test each level of the pipeline, ensuring a solid foundation for more complex implementations.

## 2. Flow Matching Single Example

### 2.1 Code Explanation

The flow_matching_single_example.py script implements a flow matching algorithm to overfit on a single data point from the "moons" dataset. Let's break down the main components:

```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import make_moons
from tqdm import tqdm  # progress bar used in the training loop below
from zuko.utils import odeint

# ... (rest of the imports)

class MLP(nn.Sequential):
    def __init__(self, in_features, out_features, hidden_features=[64, 64]):
        layers = []
        for a, b in zip((in_features, *hidden_features), (*hidden_features, out_features)):
            layers.extend([nn.Linear(a, b), nn.ELU()])
        super().__init__(*layers[:-1])

class CNF(nn.Module):
    def __init__(self, features, freqs=3, **kwargs):
        super().__init__()
        self.net = MLP(2 * freqs + features, features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t, x):
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)
        return self.net(torch.cat((t, x), dim=-1))

    # ... (rest of the CNF class methods)

class FlowMatchingLoss(nn.Module):
    def __init__(self, v):
        super().__init__()
        self.v = v

    def forward(self, x):
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x
        return (self.v(t.squeeze(-1), y) - u).square().mean()

# Main training loop
if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    noise_levels = [0.01, 0.05, 0.1]
    for i, noise in enumerate(noise_levels):
        data, _ = make_moons(1, noise=noise)
        data = torch.from_numpy(data).float()
        
        batch_size = 32
        data_batch = data.expand(batch_size, -1, -1)

        for epoch in tqdm(range(10000), ncols=88, desc=f'Training with noise={noise}'):
            x = data_batch
            loss(x).backward()
            optimizer.step()
            optimizer.zero_grad()

        # Sampling and visualization
        with torch.no_grad():
            z = torch.randn(10, 2)
            x = flow.decode(z)

        plt.figure(figsize=(4.8, 4.8), dpi=150)
        plt.scatter(*data.T, color='red', label='Ground Truth')
        plt.scatter(*x.T, color='blue', alpha=0.5, label='Samples')
        plt.legend()
        plt.xlim(-1.5, 2.5)
        plt.ylim(-1, 1.5)
        plt.title(f'Overfitting on Single Example with Noise={noise}')
        plt.savefig(f'experiments/plots_fm/moons_fm_single_example_noise_{noise}.pdf')

        # Log-likelihood calculation
        with torch.no_grad():
            log_p = flow.log_prob(data)
        print(f'Log probability for noise {noise}: {log_p.item()}')
```

This code defines the MLP, CNF, and FlowMatchingLoss classes, then implements the training loop, sampling, and visualization.
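
For reference, the objective evaluated in `FlowMatchingLoss.forward` can be written out as below. This is a direct transcription of the code with $\sigma = 10^{-4}$, not an additional result; the symbol $v_\phi$ for the network and the subscript notation are introduced here for readability. The code averages the squared error over components, so it matches the last line up to a constant factor.

```latex
\begin{aligned}
t &\sim \mathcal{U}(0, 1), \qquad z \sim \mathcal{N}(0, I), \qquad \sigma = 10^{-4}, \\
y_t &= (1 - t)\, x + \bigl(\sigma + (1 - \sigma)\, t\bigr)\, z, \\
u_t &= \frac{\mathrm{d} y_t}{\mathrm{d} t} = (1 - \sigma)\, z - x, \\
\mathcal{L} &= \mathbb{E}_{t, z}\, \bigl\lVert v_\phi(t, y_t) - u_t \bigr\rVert^2 .
\end{aligned}
```

At $t = 1$ the path is pure noise ($y_1 = z$), and at $t = 0$ it sits at $x + \sigma z$. With a single training point $x$, a model that fit this path exactly would therefore sample from a Gaussian of standard deviation $\sigma$ centered on that point.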

### 2.2 Results Analysis

Let's look at the output of the run with `noise=0.05`:

Training with noise=0.05: 100%|█████████████████| 10000/10000 [00:04<00:00, 2187.57it/s]

Log probability for noise 0.05: 7.608288764953613

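These outputs show that:

1. Training ran for all 10,000 iterations at the noise level of 0.05.
2. The model assigns the training point a log-density of about 7.61.

Because `log_prob` returns the log of a probability density rather than of a probability, it can be positive. A value of 7.61 corresponds to a density of roughly 2,000 at the training point, which is only possible if the model concentrates almost all of its mass in a very small region around that point. This indicates that the model has successfully overfit to the single example.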
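
To get a feel for the size of this number, one can ask how narrow an isotropic 2D Gaussian would have to be to reach the same peak log-density. The sketch below is only an analogy: it assumes a Gaussian shape, which the trained flow need not have, and the helper names `equivalent_gaussian_std` and `peak_log_density` are made up for illustration. The comparison value `1e-4` is the minimum noise scale hard-coded in `FlowMatchingLoss`.

```python
import math

def equivalent_gaussian_std(log_density: float, dim: int = 2) -> float:
    """Std of an isotropic Gaussian whose log-density at its mean equals `log_density`.

    At the mean, log p = -(dim / 2) * log(2 * pi * std**2); solve for std.
    """
    return math.exp(-log_density / dim) / math.sqrt(2 * math.pi)

def peak_log_density(std: float, dim: int = 2) -> float:
    """Log-density of an isotropic Gaussian evaluated at its mean."""
    return -(dim / 2) * math.log(2 * math.pi * std**2)

reported = 7.608288764953613  # log-density printed for noise=0.05
std = equivalent_gaussian_std(reported)
print(f"equivalent std: {std:.4f}")
print(f"peak log-density at std=1e-4: {peak_log_density(1e-4):.2f}")
```

Under this analogy the reported value corresponds to a spread of roughly 0.009, which is small next to the axis range of the plot but still well above the `1e-4` floor of the training path, whose peak log-density would be about 16.6.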

### 2.3 Plot Interpretation

Let's analyze the plot "Overfitting on Single Example with Noise=0.05":

![Overfitting on Single Example with Noise=0.05](plots_fm/moons_fm_single_example_noise_0.05.png)

In this plot:

1. The red dot represents the ground truth (original data point).
2. The blue dots represent samples generated by the model.
3. We can observe that the generated samples cluster tightly around the ground truth, indicating successful overfitting.

The tight clustering shows that the model has learned to generate points very close to the single training example, rather than capturing the general shape of the "moons" distribution. This is exactly what we want in this overfitting exercise.
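
The visual impression of tight clustering can be backed by a single number. A minimal sketch, assuming the generated samples and the training point are available as arrays: the function name `rms_distance` and the synthetic stand-in samples are illustrative and do not come from the script.

```python
import numpy as np

def rms_distance(samples: np.ndarray, point: np.ndarray) -> float:
    """Root-mean-square Euclidean distance between each sample and a reference point."""
    diffs = samples - point                      # shape (n_samples, dim)
    return float(np.sqrt((diffs ** 2).sum(axis=-1).mean()))

# Synthetic stand-ins: in the script these would be `x` from flow.decode(z)
# and the single training point `data[0]`.
rng = np.random.default_rng(0)
point = np.array([1.0, 0.25])
tight = point + 0.01 * rng.standard_normal((10, 2))
loose = point + 0.50 * rng.standard_normal((10, 2))

print(rms_distance(tight, point) < rms_distance(loose, point))
```

Tracking this distance across the three noise levels would make the overfitting check comparable between runs without relying on the plots.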

## 3. Conditional Case

### 3.1 Code Explanation
The pipeline_small.py script implements a conditional flow matching approach. Let's break down the key components:

```python
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import corner
import numpy as np
from lampe.plots import nice_rc
from lampe.utils import GDStep
from lampe.inference import FMPE, FMPELoss
from generate_burst_data import simulate_burst

# ... (rest of the imports and utility functions)

# Initialize models, loss functions, and optimizers
estimator = FMPE(theta_dim=1 * max_ncomp, x_dim=1000, freqs=5).to(device)
loss = FMPELoss(estimator)
optimizer = optim.Adam(estimator.parameters(), lr=1e-4)
step = GDStep(optimizer, clip=1.0)

# Generate a single sample for overfitting
ncomp = 2
burstparams = generate_burst_params(ncomp)
ymodel, ycounts = simulate_burst(time, ncomp, burstparams, ybkg*10, return_model=True, noise_type='gaussian')
fixed_t0 = torch.from_numpy(burstparams[:ncomp]).float().to(device)
fixed_x = torch.from_numpy(ycounts).float().to(device)

# Transform the time series data using Fourier Transform
import torch.fft as fft
fixed_x_fft = torch.abs(fft.fft(fixed_x))

# Training loop for overfitting
estimator.train()

batch_size = 1024
fixed_t0_batch = fixed_t0.repeat(batch_size, 1)
fixed_x_fft_batch = fixed_x_fft.repeat(batch_size, 1)

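num_epochs = 50000
overfit_loss_values = []  # loss history, needed for the loss curve plotted below

for epoch in range(num_epochs):
    optimizer.zero_grad()
    loss_value = loss(fixed_t0_batch, fixed_x_fft_batch)
    loss_value.backward()
    optimizer.step()
    overfit_loss_values.append(loss_value.item())

    if epoch % 1000 == 0:
        print(f"Overfitting Epoch {epoch+1}, Loss: {loss_value.item()}")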

# Evaluation after overfitting
estimator.eval()

with torch.no_grad():
    num_samples = 1000
    samples_t0_given_x_N = estimator.flow(fixed_x_fft_batch).sample((num_samples,))
    
    mean_sample = samples_t0_given_x_N.mean(dim=0)
    std_sample = samples_t0_given_x_N.std(dim=0)
    
    batch_mean_sample = samples_t0_given_x_N.mean(dim=[0, 1])
    batch_std_sample = samples_t0_given_x_N.std(dim=[0, 1])
    
    print("Fixed Input t0:", fixed_t0_batch[0])
    print("Mean of sampled t0s from posterior after overfitting:", mean_sample)
    print("Standard deviation of sampled t0s from posterior after overfitting:", std_sample)
    print("Batch mean of sampled t0s from posterior after overfitting:", batch_mean_sample)
    print("Batch standard deviation of sampled t0s from posterior after overfitting:", batch_std_sample)

    # Plotting the samples using corner
    figure = corner.corner(samples_t0_given_x_N.view(-1, samples_t0_given_x_N.size(-1)).cpu().numpy(), 
                           labels=[f"t0_{i}" for i in range(samples_t0_given_x_N.size(-1))],
                           truths=fixed_t0_batch[0].cpu().numpy(), title="Posterior Samples vs Fixed Input t0")
    figure.savefig('plots_small/posterior_samples_fixed_t0_corner_fft_overfit.png')
    plt.close(figure)

    # Plotting the overfitting loss curve
    plt.figure(figsize=(10, 6))
    plt.plot(overfit_loss_values, label='Overfitting Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Overfitting Loss Curve')
    plt.legend()
    plt.savefig('plots_small/loss_curve_fft_overfit.png')
    plt.close()

    # Plotting the conditioning data (time series)
    plt.figure(figsize=(10, 6))
    plt.plot(time, ycounts, label='Conditioning Time Series Data')
    plt.xlabel('Time')
    plt.ylabel('Counts')
    plt.title('Conditioning Time Series Data')
    plt.legend()
    plt.savefig('plots_small/conditioning_data_time_series.png')
    plt.close()
```

This code sets up the FMPE model, generates a single burst data sample, transforms it using FFT, and then trains the model to overfit on this single sample. It then evaluates the model and generates various plots for analysis.
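
One property of this representation is worth keeping in mind when moving beyond a single sample. Taking `torch.abs` of the FFT keeps only the magnitude spectrum, which is unchanged by circular time shifts of the signal, and for real-valued input it is also mirror-symmetric, so about half of the 1000 values are redundant. The NumPy sketch below illustrates both points on a synthetic pulse; the pulse shape and the reading of `t0` as an arrival time are assumptions, since `simulate_burst` is not shown here.

```python
import numpy as np

n = 1000
t = np.arange(n)

# Synthetic stand-in for a burst: a Gaussian pulse centred at sample 300
pulse = np.exp(-0.5 * ((t - 300) / 20.0) ** 2)
shifted = np.roll(pulse, 250)  # same pulse, circularly shifted to sample 550

mag = np.abs(np.fft.fft(pulse))
mag_shifted = np.abs(np.fft.fft(shifted))

# 1) The magnitude spectrum does not change under a circular shift
print(np.allclose(mag, mag_shifted))

# 2) For real input the spectrum is mirror-symmetric: |X[k]| == |X[n - k]|
print(np.allclose(mag[1:], mag[1:][::-1]))

# rfft keeps only the non-redundant half (n // 2 + 1 values)
print(np.fft.rfft(pulse).shape)
```

For a single fixed observation this does not matter, because the conditioning input never changes. If `t0` does encode arrival times, though, a magnitude-only input may make different `t0` values hard to distinguish once training moves to many samples, so keeping phase information (for example the real and imaginary parts) could be worth testing.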

### 3.2 Results Analysis
Let's examine the running outputs:

Using device: cuda

Overfitting Epoch 1, Loss: 15840.744140625

Overfitting Epoch 1001, Loss: 0.9846465587615967

Overfitting Epoch 2001, Loss: 1.0064780712127686

...

Overfitting Epoch 48001, Loss: 0.2650120258331299

Overfitting Epoch 49001, Loss: 0.2915732264518738

Training completed in 95.00 seconds.

Fixed Input t0: tensor([0.5347, 0.3227], device='cuda:0')

Mean of sampled t0s from posterior after overfitting: tensor([[0.5523, 0.3232],
        [0.5653, 0.3281],
        [0.5659, 0.3096],
        ...,
        [0.5642, 0.3147],
        [0.5519, 0.3207],
        [0.5561, 0.3293]], device='cuda:0')

Standard deviation of sampled t0s from posterior after overfitting: tensor([[0.2339, 0.2265],
        [0.2301, 0.2214],
        [0.2270, 0.2220],
        ...,
        [0.2352, 0.2259],
        [0.2213, 0.2184],
        [0.2293, 0.2329]], device='cuda:0')

Batch mean of sampled t0s from posterior after overfitting: tensor([0.5579, 0.3230], device='cuda:0')

Batch standard deviation of sampled t0s from posterior after overfitting: tensor([0.2292, 0.2272], device='cuda:0')


These results show:

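1. Training ran for the full 50,000 epochs.
2. The loss fell from about 15,841 at the first epoch to between roughly 0.27 and 0.29 at the end. The decrease is not monotonic (0.985 at epoch 1001, 1.006 at epoch 2001), which is expected because the flow matching loss draws a new random time and noise sample at every step.
3. The fixed input t0 values are [0.5347, 0.3227].
4. The batch mean of the sampled t0s, [0.5579, 0.3230], is close to the input: the offsets are about 0.023 and 0.0003.
5. The batch standard deviation, [0.2292, 0.2272], is large compared with those offsets and comparable in size to the parameter values themselves.

These results suggest that while the model has learned to generate samples centered near the input, the individual samples are still widely dispersed around it.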
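
The shapes of the printed tensors follow from how the samples are drawn: `estimator.flow(fixed_x_fft_batch).sample((num_samples,))` is conditioned on a batch of 1024 identical observations, so the result has shape `(1000, 1024, 2)`. Reducing over `dim=0` leaves one row per batch copy, which is why the first two printed tensors contain many near-identical rows. The NumPy sketch below reproduces the reductions on synthetic data with the same layout; the random numbers are placeholders, not the model's output.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, batch_size, theta_dim = 1000, 1024, 2

# Placeholder for samples_t0_given_x_N: (num_samples, batch_size, theta_dim)
samples = 0.5 + 0.23 * rng.standard_normal((num_samples, batch_size, theta_dim))

per_copy_mean = samples.mean(axis=0)           # (1024, 2): one row per batch copy
pooled_mean = samples.mean(axis=(0, 1))        # (2,): pooled over samples and copies
pooled_std = samples.std(axis=(0, 1), ddof=1)  # ddof=1 matches torch.std's default

print(per_copy_mean.shape, pooled_mean.shape, pooled_std.shape)
```

Since every row of the batch carries the same observation, conditioning on a single copy (for example `fixed_x_fft_batch[:1]`) should give the same information with a more readable printout.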

### 3.3 Plot Interpretation
Let's analyze the generated plots:

1. Posterior Samples vs Fixed Input t0 (Corner Plot):

![Posterior Samples vs Fixed Input t0 (Corner Plot)](plots_small/posterior_samples_fixed_t0_corner_fft_overfit.png)

This corner plot shows:

* The distribution of sampled t0 values in 2D.
* Blue lines represent the true input t0 values.
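* Contours show nested regions of the 2D sample density. The call does not pass `levels`, so corner falls back to its default sigma-based levels, which in 2D do not correspond to 68%, 95%, and 99.7% of the samples.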
* Histograms on the diagonal show the marginal distributions for each t0.
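
Which fraction of the samples each 2D contour encloses depends on the `levels` argument of `corner.corner`. The sketch below evaluates the formula that corner documents as its default, `1 - exp(-0.5 * s**2)` for `s` in 0.5, 1, 1.5, 2, and the same formula at 1, 2, and 3 sigma for comparison. The helper name `mass_inside_sigma` is illustrative. To draw contours at the 68%, 95%, and 99.7% mass fractions, they would have to be passed explicitly, for example `levels=(0.68, 0.95, 0.997)`.

```python
import numpy as np

def mass_inside_sigma(sigma) -> np.ndarray:
    """Probability mass of a 2D Gaussian inside its `sigma`-level contour."""
    return 1.0 - np.exp(-0.5 * np.asarray(sigma, dtype=float) ** 2)

default_levels = mass_inside_sigma(np.arange(0.5, 2.1, 0.5))  # corner's documented default
one_two_three = mass_inside_sigma([1.0, 2.0, 3.0])

print(np.round(default_levels, 3))
print(np.round(one_two_three, 3))
```

With the defaults, the outermost contour encloses about 86% of the samples, so a noticeable share of points outside it is expected and is not by itself a sign of poor fitting.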

2. Overfitting Loss Curve:

![Overfitting Loss Curve](plots_small/loss_curve_fft_overfit.png)

This plot shows how the loss decreased over the training epochs.
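
Because the flow matching loss draws a fresh time and noise sample at every step, the raw curve is noisy, and its range (from above 10^4 down to below 1) makes a linear axis hard to read. A minimal sketch of a smoothed view, assuming the per-epoch losses are collected in a list; the helper `moving_average` and the synthetic loss values are illustrative only.

```python
import numpy as np

def moving_average(values, window: int = 100) -> np.ndarray:
    """Average each run of `window` consecutive values (output is shorter by window - 1)."""
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(values, dtype=float), kernel, mode="valid")

# Synthetic stand-in for the recorded losses: a decaying trend with multiplicative noise
rng = np.random.default_rng(0)
epochs = np.arange(5000)
fake_losses = (0.3 + 100.0 * np.exp(-epochs / 200.0)) * rng.uniform(0.5, 1.5, size=epochs.size)

smoothed = moving_average(fake_losses, window=100)
print(len(fake_losses), len(smoothed))
```

Plotting `smoothed` with `plt.semilogy` instead of `plt.plot` would show both the initial drop and the late plateau in one figure.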

3. Conditioning Time Series Data:

![Conditioning Time Series Data](plots_small/conditioning_data_time_series.png)

This plot shows the single burst data sample used for conditioning.

### 3.4 Issues with Overfitting Procedure

Despite the long training process, there are still some issues with the overfitting:

1. Spread of Samples: The posterior samples are more spread out than expected for perfect overfitting. Ideally, they should cluster very tightly around the true values.
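2. Credible Regions: The contours in the corner plot span a wide range of t0 values. By construction, a fixed fraction of the samples falls inside each contour, so the useful diagnostic is the size of the regions rather than how many samples they contain. For true overfitting, every region should shrink to a small neighborhood of the true values.
3. Uncertainty: The relatively high standard deviations (0.2292 and 0.2272) indicate that the learned posterior is still broad, even though the model saw only one example repeatedly.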

These issues suggest that the model might be struggling to fully capture the complexity of the single example, or that the training process could be further optimized.

## 4. Conclusion

This notebook has demonstrated the process of overfitting on a single sample using both flow matching and conditional approaches. These experiments represent the initial steps in building a robust pipeline for a general posterior estimator in complex scientific models. The flow matching example overfits its training point as intended, whereas the conditional estimator still produces a broad posterior despite seeing only one example repeatedly. This suggests that further optimization of the conditional training process is necessary to improve the model's performance.

Key takeaways:

1. The flow matching example showed tighter clustering around the true value, indicating successful overfitting on a single sample.

2. The conditional case, while showing promising results, still exhibits some spread in its posterior samples. This suggests that further refinement may be needed in the model architecture, loss function, or training process.

3. The issues identified in the conditional case provide valuable insights for improving the general posterior estimator. Addressing these challenges will be crucial as we move towards more complex implementations.

Next steps in the research project:

1. Refine the conditional case to achieve tighter overfitting on a single sample. This may involve:
   - Experimenting with different model architectures
   - Adjusting the learning rate or using learning rate schedules
   - Increasing the number of training epochs
   - Exploring alternative loss functions

2. Gradually increase the complexity of the input data and model:
   - Test with multiple burst components
   - Introduce variability in the background noise
   - Experiment with different noise types (e.g., Poisson noise)

3. Implement and compare with "ground truth" models using MCMC or nested sampling:
   - Develop MCMC and nested sampling implementations for the burst model
   - Compare the posterior distributions obtained from these methods with our flow-based approach

4. Evaluate the performance of the general posterior estimator against these benchmark methods:
   - Compare computational efficiency
   - Assess accuracy of posterior estimates
   - Analyze the ability to capture multi-modal or complex posterior distributions
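
As a sketch of what the MCMC reference in step 3 could look like, the example below runs random-walk Metropolis on a toy problem whose posterior is known in closed form (Gaussian likelihood with known noise, Gaussian prior). The burst model itself is not shown in this notebook, so the toy likelihood, the prior, the step size, and all function names are assumptions chosen only to illustrate the comparison workflow.

```python
import numpy as np

def log_posterior(theta, y, noise_std=1.0, prior_std=10.0):
    """Unnormalized log-posterior: Gaussian likelihood around theta, zero-mean Gaussian prior."""
    log_like = -0.5 * np.sum((y - theta) ** 2) / noise_std**2
    log_prior = -0.5 * theta**2 / prior_std**2
    return log_like + log_prior

def metropolis(log_prob, init, n_steps, step_size, rng):
    """Random-walk Metropolis sampler for a scalar parameter."""
    chain = np.empty(n_steps)
    current, current_lp = init, log_prob(init)
    for i in range(n_steps):
        proposal = current + step_size * rng.standard_normal()
        proposal_lp = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(current))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
        chain[i] = current
    return chain

rng = np.random.default_rng(0)
y = 0.5 + rng.standard_normal(20)          # 20 toy observations with true theta = 0.5

chain = metropolis(lambda th: log_posterior(th, y), init=0.0,
                   n_steps=20000, step_size=0.5, rng=rng)
posterior_samples = chain[2000:]           # discard burn-in

# Closed-form posterior for this conjugate toy problem
precision = len(y) / 1.0**2 + 1.0 / 10.0**2
exact_mean = y.sum() / 1.0**2 / precision
exact_std = precision ** -0.5

print(f"MCMC  mean={posterior_samples.mean():.3f} std={posterior_samples.std():.3f}")
print(f"exact mean={exact_mean:.3f} std={exact_std:.3f}")
```

With the real burst simulator, `log_posterior` would be replaced by the model's likelihood and prior, and the resulting chain could be overlaid on the flow-based samples in the same corner plot.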

By continuing this methodical approach of testing and refinement at each level of complexity, we can build a robust and accurate general posterior estimator for complex scientific models.

## 5. Further Investigations

Based on our current results, here are some specific areas we could investigate to improve the conditional case:

1. Model Capacity: The current FMPE model might not have sufficient capacity to capture the complexity of the burst data. We could experiment with:
   - Increasing the number of layers or neurons in the neural network
   - Using more sophisticated architectures like residual networks or attention mechanisms

2. Training Dynamics: The training process might be sub-optimal. We could try:
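   - Actually applying gradient clipping to handle potential exploding gradients: the script constructs `step = GDStep(optimizer, clip=1.0)`, but the training loop calls `optimizer.step()` directly, so no clipping takes place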
   - Using a learning rate scheduler to adjust the learning rate during training
   - Exploring different optimizers (e.g., AdamW, RMSprop)

3. Data Representation: The current FFT representation might not be optimal. We could investigate:
   - Different data transformations (e.g., wavelet transforms)
   - Normalizing or scaling the input data differently

4. Loss Function: The current loss function might not be ideal for this specific problem. We could:
   - Experiment with different variants of the flow matching loss
   - Incorporate additional regularization terms

5. Sampling Process: The current sampling process might not be capturing the true posterior effectively. We could:
   - Increase the number of samples drawn
   - Implement Markov Chain Monte Carlo (MCMC) sampling on top of the flow-based model

6. Numerical Stability: There might be numerical stability issues affecting the results. We should:
   - Check for any NaN or infinity values during training
   - Implement robust handling of edge cases in the model
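
One concrete experiment for the data representation item: the first reported loss (about 15,841) hints that the raw FFT magnitudes may be on a much larger scale than the network handles comfortably, although the notebook does not confirm this. Below is a minimal sketch of log-compressing and standardizing a magnitude spectrum before it is passed to the estimator; the helper name `preprocess_spectrum` and the synthetic input are assumptions.

```python
import numpy as np

def preprocess_spectrum(magnitudes: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Log-compress a non-negative spectrum, then shift and scale to zero mean, unit std."""
    compressed = np.log1p(magnitudes)            # tames the large DC and low-frequency bins
    centered = compressed - compressed.mean()
    return centered / (compressed.std() + eps)   # eps guards against a constant input

# Synthetic stand-in for |FFT| of a count time series with a large constant background
rng = np.random.default_rng(0)
counts = 100.0 + 5.0 * rng.standard_normal(1000)
magnitudes = np.abs(np.fft.fft(counts))

features = preprocess_spectrum(magnitudes)
print(f"raw max: {magnitudes.max():.0f}, processed max: {features.max():.2f}")
```

When training on many samples, the mean and standard deviation would normally be computed once over the training set and reused at inference time, rather than per observation as in this single-sample sketch.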

By systematically addressing these areas, we can work towards improving the overfitting performance on a single sample in the conditional case, setting a strong foundation for more complex scenarios in our general posterior estimator.