In [2]:
import matplotlib

matplotlib.use("Agg")
import argparse
import colorsys
import json
import os
import sys
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.stats import multivariate_normal
from pathlib import Path
sys.path.append(str(Path().resolve().parent.parent))
from experiments.procece_data import procece_data
from src.utils import metrics

In [3]:
# Root of the simulation outputs: the "data" folder two levels above the
# current working directory.
DATA_DIR = Path.cwd().resolve().parents[1] / "data"

In [110]:
# Identifier of the simulation run to visualise; it is also the name of the
# run's sub-directory inside DATA_DIR.
sim_id = "20250119191850"

In [111]:
# Open the two NetCDF files of the selected run: "data.nc" (read below as
# X and Z per iteration) and "params.nc" (m, beta, W, s_alpha per iteration).
data, params = (
    xr.open_dataset(DATA_DIR / sim_id / filename)
    for filename in ("data.nc", "params.nc")
)


In [112]:
# Display the loaded parameter dataset to inspect its variables and dimensions.
params

In [113]:
# Plot configuration: K is the length of the `k` axis of the parameters, S the
# length of the `s` axis of the data; the axis limits start as a fixed square.
K, S = len(params.k), len(data.s)
x_lim = y_lim = (-30, 30)

# Generate colors for clusters
def generate_colors(n, saturation=0.7, value=0.9):
    """Return ``n`` RGB colors whose hues are evenly spaced on the HSV wheel.

    Parameters
    ----------
    n : int
        Number of colors to generate. ``n <= 0`` yields an empty list.
    saturation : float, optional
        HSV saturation shared by every color (default 0.7).
    value : float, optional
        HSV value (brightness) shared by every color (default 0.9).

    Returns
    -------
    list of tuple
        ``n`` ``(r, g, b)`` tuples as returned by ``colorsys.hsv_to_rgb``;
        color ``i`` has hue ``i / n``.
    """
    return [colorsys.hsv_to_rgb(i / n, saturation, value) for i in range(n)]

# One color per cluster; cluster_colors[k] is used for everything drawn for cluster k.
cluster_colors = generate_colors(K)


In [114]:
# Show the number of clusters found in the parameter dataset.
K

4

In [115]:
# Axis limits for the animation: a square window centred on the origin, sized
# from the largest absolute value found anywhere in `params.m` (all iterations),
# plus a 20% margin. This replaces the fixed default limits set earlier.
# The reductions are converted to plain floats so that numpy and matplotlib
# receive ordinary scalars rather than 0-d DataArrays.
m_max = float(params.m.max())
m_min = float(params.m.min())
abs_max = max(abs(m_max), abs(m_min))
x_lim = y_lim = (-abs_max * 1.2, abs_max * 1.2)




def update(i):
    """Draw frame ``i`` of the animation on the module-level ``ax``.

    The axes are cleared, the points ``data.X`` of iteration ``i`` are
    scattered in the color of their highest-weight cluster (argmax of
    ``data.Z`` along axis 1), and for every cluster ``k`` and every index
    ``s`` a Gaussian with mean ``params.m`` and covariance
    ``inv(params.beta * params.W)`` is outlined and lightly filled at the
    density level ``peak * exp(-0.5)`` (Mahalanobis distance 1).

    Parameters
    ----------
    i : int
        Position along the ``iter`` dimension of ``data`` and ``params``.

    Returns
    -------
    list
        The scatter and contour objects created for this frame.
    """
    ax.clear()
    artists = []

    # Observations and assignment weights of the current iteration.
    X = data.X.isel(iter=i)
    Z = data.Z.isel(iter=i)

    # Hard-assign every point to its highest-weight cluster and scatter each
    # cluster in its own color.
    hard_labels = np.argmax(Z.values, axis=1)
    for cluster in range(K):
        mask = hard_labels == cluster
        if not np.any(mask):
            continue  # nothing assigned to this cluster in this frame
        members = X[mask]
        artists.append(
            ax.scatter(members[:, 0], members[:, 1],
                       c=[cluster_colors[cluster]], s=2, alpha=0.2)
        )

    # The evaluation grid depends only on the axis limits, so it is built once
    # per frame instead of once per (k, s) pair.
    grid_x, grid_y = np.meshgrid(np.linspace(*x_lim, 100),
                                 np.linspace(*y_lim, 100))
    grid_points = np.column_stack([grid_x.flat, grid_y.flat])

    for k in range(K):
        # Normalised s_alpha of cluster k; it drives the contour opacity below
        # and does not depend on s, so it is computed once per cluster.
        s_alpha = params.s_alpha.isel(iter=i, k=k).values
        s_alpha_norm = s_alpha / np.sum(s_alpha)

        for s in range(S):
            mean = params.m.isel(iter=i, k=k, s=s).values
            precision = (params.beta.isel(iter=i, k=k, s=s)
                         * params.W.isel(iter=i, k=k, s=s)).values

            # np.linalg.inv raises on an exactly singular matrix rather than
            # returning inf/nan, so both failure modes are skipped here.
            try:
                covar = np.linalg.inv(precision)
            except np.linalg.LinAlgError:
                continue
            if not np.all(np.isfinite(covar)):
                continue

            # Opacity decays with the normalised s_alpha and is clamped to [0.3, 1.0].
            alpha = np.exp(-np.linalg.norm(s_alpha_norm[s]) * 0.5)
            alpha = max(0.3, min(1.0, alpha))

            # A single frozen distribution gives both the grid densities and
            # the peak density at the mean.
            rv = multivariate_normal(mean, covar)
            density = rv.pdf(grid_points).reshape(grid_x.shape)
            peak = rv.pdf(mean)
            level = peak * np.exp(-0.5)

            artists.append(
                ax.contour(grid_x, grid_y, density, levels=[level],
                           colors=[cluster_colors[k]], alpha=alpha)
            )
            # Fill from the contour level up to the peak density. A fixed upper
            # bound of 1 is not always above `level` (contourf requires
            # increasing levels) and would leave a hole wherever density > 1.
            artists.append(
                ax.contourf(grid_x, grid_y, density, levels=[level, peak],
                            colors=[cluster_colors[k]], alpha=alpha * 0.2)
            )

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    # Lay out the figure that owns `ax`, not whichever figure is current.
    ax.figure.tight_layout()

    return artists

# Create the figure the frames are drawn on. `update` reads `ax` as a global,
# so both names must exist before the animation is rendered.
fig, ax = plt.subplots()

# blit=False: `update` clears and redraws the whole axes on every frame, so
# blitting has nothing to reuse between frames.
ani = animation.FuncAnimation(fig, update, frames=len(data.iter),
                              interval=100, blit=False)
ani.save(DATA_DIR / sim_id / 'animation.gif', writer='pillow')
# Close this specific figure to release its memory.
plt.close(fig)