In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

class TimeEmbedding(nn.Module):
    """Lift a scalar diffusion timestep to a ``dim``-sized feature vector.

    The mapping is a small MLP: ``Linear(1, dim) -> SiLU -> Linear(dim, dim)``.

    Args:
        dim: Width of the hidden layer and of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # The attribute name and sub-module order are kept so that state-dict
        # keys (``layer.0.*``, ``layer.2.*``) stay the same.
        stages = [
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: Tensor of timesteps of any numeric dtype.

        Returns:
            Float32 tensor of shape ``t.shape + (dim,)``.
        """
        # Linear layers need a float input with an explicit feature axis.
        as_column = t.to(torch.float32)[..., None]
        return self.layer(as_column)

class EELSDataset(Dataset):
    """Dataset of spectra that yields ``(noisy, clean)`` pairs.

    Args:
        data: Array-like or tensor whose first axis indexes spectra. It is
            converted to a float32 tensor on ``device``; when no conversion is
            needed the tensor may share memory with ``data``.
        noise_level: Standard deviation of the added Gaussian noise, expressed
            as a fraction of each spectrum's own standard deviation.
        device: Device on which the spectra are stored.
    """

    def __init__(self, data, noise_level=0.1, device='cpu'):
        # torch.as_tensor accepts ndarrays, nested sequences and tensors of
        # any numeric dtype; the legacy torch.FloatTensor constructor rejects
        # tensors that are not already CPU float32.
        self.data = torch.as_tensor(data, dtype=torch.float32, device=device)
        self.noise_level = noise_level
        self.device = device

    def __len__(self):
        """Return the number of spectra."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(noisy_spectrum, spectrum)`` for the spectrum at ``idx``.

        Fresh noise is drawn on every call, so repeated access to the same
        index gives different noisy spectra.
        """
        spectrum = self.data[idx]
        # Noise amplitude scales with the spectrum's own spread, so a constant
        # spectrum (std == 0) is returned unchanged.
        noisy_spectrum = spectrum + torch.randn_like(spectrum) * self.noise_level * torch.std(spectrum)
        return noisy_spectrum, spectrum

class DiffusionModel:
    """Forward and reverse diffusion process around a noise-predicting model.

    Args:
        model: Callable ``model(x, t)`` returning the predicted noise for the
            batch ``x`` at the per-sample timesteps ``t``.
        n_steps: Number of diffusion steps in the schedule.
        beta_start: First value of the linear ``beta`` schedule.
        beta_end: Last value of the linear ``beta`` schedule.
        device: Device that holds the schedule tensors and on which the
            diffusion arithmetic runs.
    """

    def __init__(self, model, n_steps=1000, beta_start=1e-4, beta_end=2e-2, device='cpu'):
        self.model = model
        self.n_steps = n_steps
        self.device = device

        # Linear beta schedule; alpha_bar[t] is the cumulative product of
        # alpha[0..t].  All three are float32 tensors on ``device``.
        self.beta = torch.linspace(
            beta_start, beta_end, n_steps, dtype=torch.float32, device=device
        )
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def diffusion_step(self, x, t):
        """Noise ``x`` to timestep ``t`` in a single jump.

        Args:
            x: Batch of clean samples, batch axis first.  Any number of
                trailing axes is supported.
            t: Integer-valued timesteps, one per sample, each in
                ``[0, n_steps)``.

        Returns:
            Tuple ``(x_t, noise)``: the noised samples and the Gaussian noise
            that was mixed in, both on ``self.device``.
        """
        # Keep every operand on the schedule's device so that a batch handed
        # in from elsewhere does not trigger a device-mismatch error.
        x = x.to(self.device)
        noise = torch.randn_like(x)
        t = t.long().to(self.device)  # indexing requires integer timesteps

        # One coefficient per sample, with enough trailing unit axes to
        # broadcast over x.  For inputs of rank <= 2 this is (-1, 1).
        coeff_shape = (-1,) + (1,) * max(x.dim() - 1, 1)
        alpha_bar_t = self.alpha_bar[t].view(coeff_shape)
        return torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise, noise

    def denoise(self, noisy_data, n_steps=None):
        """Run the reverse process on ``noisy_data``.

        The input is treated as the sample at step ``n_steps - 1`` and is
        stepped back to step 0 under ``torch.no_grad()``.

        Args:
            noisy_data: Batch of samples to denoise, batch axis first.
            n_steps: Number of reverse steps to run; defaults to the full
                schedule.  Values of 0 or less return the input unchanged.

        Returns:
            The denoised batch as a CPU tensor.

        Raises:
            ValueError: If ``n_steps`` exceeds the schedule length.
        """
        if n_steps is None:
            n_steps = self.n_steps
        if n_steps > self.n_steps:
            # Otherwise the schedule lookup below fails with an opaque
            # IndexError on the first iteration.
            raise ValueError(
                f"n_steps ({n_steps}) exceeds the schedule length ({self.n_steps})"
            )

        with torch.no_grad():
            x = noisy_data.to(self.device)
            for t in range(n_steps - 1, -1, -1):
                # The model receives the same timestep for every sample.
                t_tensor = torch.full(
                    (x.shape[0],), float(t), device=self.device, dtype=torch.float32
                )
                predicted_noise = self.model(x, t_tensor)
                alpha_t = self.alpha[t]
                alpha_bar_t = self.alpha_bar[t]

                # No fresh noise is injected on the final step (t == 0).
                noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

                x = (1 / torch.sqrt(alpha_t)) * (
                    x - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * predicted_noise
                ) + torch.sqrt(1 - alpha_t) * noise

        return x.cpu()

class SpectralUNet(nn.Module):
    """1-D U-Net that predicts the noise in a batch of spectra.

    There is one encoder stage per entry of ``hidden_dims`` (two
    conv/GroupNorm/SiLU layers followed by 2x average pooling), a middle block
    at the deepest width, and ``len(hidden_dims) - 1`` decoder stages that
    upsample by linear interpolation and concatenate the matching encoder
    output.  A final 1x1 convolution maps back to a single channel.

    Args:
        in_channels: Number of channels of the input spectra.
        hidden_dims: Channel width of each encoder stage, shallowest first.
            Every width must be divisible by 8 because each normalisation
            layer is ``GroupNorm(8, width)``.
    """

    def __init__(self, in_channels=1, hidden_dims=(16, 32, 64)):
        super().__init__()
        # Immutable default plus a private copy: a shared mutable default list
        # could be modified through ``self.hidden_dims`` and leak into later
        # instances.
        hidden_dims = list(hidden_dims)
        self.time_embed = TimeEmbedding(hidden_dims[0])
        self.hidden_dims = hidden_dims

        self.time_mlp = nn.Sequential(
            nn.Linear(hidden_dims[0], hidden_dims[0]),
            nn.SiLU()
        )

        # Encoder: each stage halves the spectral length.
        self.encoders = nn.ModuleList()
        current_channels = in_channels
        for dim in hidden_dims:
            self.encoders.append(nn.Sequential(
                nn.Conv1d(current_channels, dim, 3, padding='same'),
                nn.GroupNorm(8, dim),
                nn.SiLU(),
                nn.Conv1d(dim, dim, 3, padding='same'),
                nn.GroupNorm(8, dim),
                nn.SiLU(),
                nn.AvgPool1d(2, padding=0)
            ))
            current_channels = dim

        # Middle
        deepest = hidden_dims[-1]
        self.middle = nn.Sequential(
            nn.Conv1d(deepest, deepest, 3, padding='same'),
            nn.GroupNorm(8, deepest),
            nn.SiLU(),
            nn.Conv1d(deepest, deepest, 3, padding='same'),
            nn.GroupNorm(8, deepest),
            nn.SiLU()
        )

        # Decoder: stage i consumes the upsampled features (width
        # reversed[i]) concatenated with the skip of width reversed[i + 1].
        self.decoders = nn.ModuleList()
        hidden_dims_reversed = hidden_dims[::-1]

        for i in range(len(hidden_dims) - 1):
            dec_in = hidden_dims_reversed[i] + hidden_dims_reversed[i + 1]
            dec_out = hidden_dims_reversed[i + 1]

            self.decoders.append(nn.Sequential(
                nn.Conv1d(dec_in, dec_out, 3, padding='same'),
                nn.GroupNorm(8, dec_out),
                nn.SiLU(),
                nn.Conv1d(dec_out, dec_out, 3, padding='same'),
                nn.GroupNorm(8, dec_out),
                nn.SiLU()
            ))

        # Output
        self.output = nn.Conv1d(hidden_dims[0], 1, 1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at diffusion step ``t``.

        Args:
            x: Spectra of shape ``(batch, length)`` or
                ``(batch, in_channels, length)``.  ``length`` should be at
                least ``2 ** len(hidden_dims)`` so that no pooling stage
                reduces the signal to zero length.
            t: Diffusion timestep of each sample, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, length)``.
        """
        # Add channel dimension if needed
        if x.dim() == 2:
            x = x.unsqueeze(1)

        # Remember the input length; pooling floors odd lengths, so the final
        # upsampling has to target it explicitly.
        original_size = x.shape[-1]

        # Time embedding, shape (batch, hidden_dims[0])
        t_emb = self.time_mlp(self.time_embed(t))

        # Encoder path with skip connections
        skips = []
        for stage, encoder in enumerate(self.encoders):
            x = encoder(x)
            if stage == 0:
                # Condition the network on the timestep.  The embedding has
                # hidden_dims[0] features, which matches the channel count
                # right after the first stage, so it is added there as a
                # per-channel bias broadcast along the spectral axis.  The
                # embedding used to be computed and then discarded, leaving
                # the noise prediction independent of the diffusion step.
                x = x + t_emb.unsqueeze(-1)
            skips.append(x)

        # Middle
        x = self.middle(x)

        # The deepest encoder output is the middle block's own input, so it
        # is not used as a skip; the rest are consumed deepest-first.
        skips = skips[:-1][::-1]

        for skip, decoder in zip(skips, self.decoders):
            # Upsample to match skip connection size
            x = F.interpolate(x, size=skip.shape[-1], mode='linear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = decoder(x)

        # Final upsampling to original size
        x = F.interpolate(x, size=original_size, mode='linear', align_corners=False)
        x = self.output(x)

        return x.squeeze(1)

def train_diffusion_model(SI_data, device='cuda', n_epochs=1, batch_size=8,
                          lr=1e-4, n_steps=50):
    """Train a noise-prediction diffusion model on the spectra in ``SI_data``.

    Args:
        SI_data: Array whose last axis is the spectral axis; all leading axes
            are flattened into a list of individual spectra.
        device: Device used for the data, the model and the schedule.
        n_epochs: Number of passes over the spectra.
        batch_size: Number of spectra per optimisation step.
        lr: AdamW learning rate.
        n_steps: Number of diffusion steps in the noise schedule.

    Returns:
        The ``DiffusionModel`` wrapping the trained network.

    Raises:
        ValueError: If ``SI_data`` contains no spectra.
        Exception: Any error raised during setup or training is printed and
            then re-raised unchanged.
    """
    # Prepare data: one row per spectrum.
    data = SI_data.reshape(-1, SI_data.shape[-1])
    if len(data) == 0:
        raise ValueError("SI_data contains no spectra to train on")

    try:
        # Create dataset and dataloader
        dataset = EELSDataset(data, device=device)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Initialize model
        model = SpectralUNet(
            in_channels=1,
            hidden_dims=[16, 32, 64]
        ).to(device)

        # Initialize diffusion
        diffusion = DiffusionModel(model, n_steps=n_steps, device=device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

        # Training loop
        for epoch in range(n_epochs):
            total_loss = 0.0
            # Only the clean spectra are used: the training noise comes from
            # the diffusion schedule, not from the dataset's noisy copy.
            for _, clean in dataloader:
                clean = clean.float()
                t = torch.randint(0, diffusion.n_steps, (clean.shape[0],), device=device)

                x_t, noise = diffusion.diffusion_step(clean, t)

                predicted_noise = model(x_t, t.float())
                loss = F.mse_loss(predicted_noise, noise)

                optimizer.zero_grad()
                loss.backward()
                # Clip the global gradient norm to keep updates bounded.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.6f}")

        return diffusion

    except Exception as e:
        print(f"Error occurred: {e}")
        # A bare ``raise`` re-raises the active exception without adding this
        # frame to the traceback a second time.
        raise

def denoise_EELS_data(SI_data, diffusion_model=None, device='cuda', batch_size=32):
    """Denoise every spectrum in ``SI_data`` with a diffusion model.

    Args:
        SI_data: Array whose last axis is the spectral axis; all leading axes
            are flattened into a list of spectra and restored afterwards.
        diffusion_model: Object exposing ``denoise(batch)`` that returns a
            tensor.  When ``None``, a new model is trained on ``SI_data``.
        device: Device to which each batch is moved before denoising (also
            used for training when no model is supplied).
        batch_size: Number of spectra denoised per call to the model.

    Returns:
        NumPy array with the same shape as ``SI_data``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    print("Starting denoising process...")

    if diffusion_model is None:
        print("Training new diffusion model...")
        diffusion_model = train_diffusion_model(SI_data, device)

    # Reshape data for processing: one row per spectrum.
    original_shape = SI_data.shape
    data = SI_data.reshape(-1, original_shape[-1])

    # Process in batches
    denoised_batches = []
    for start in range(0, len(data), batch_size):
        batch = torch.as_tensor(
            data[start:start + batch_size], dtype=torch.float32, device=device
        )
        denoised_batch = diffusion_model.denoise(batch)
        # .cpu() is kept in case a caller-supplied model returns its result
        # on an accelerator.
        denoised_batches.append(denoised_batch.cpu().numpy())

    if denoised_batches:
        denoised_data = np.concatenate(denoised_batches, axis=0)
    else:
        # np.concatenate rejects an empty list; an input without spectra
        # simply yields an empty result.
        denoised_data = np.empty((0, original_shape[-1]), dtype=np.float32)
    denoised_data = denoised_data.reshape(original_shape)

    print("Denoising complete.")
    return denoised_data
     
# Notebook driver: pick a device, denoise the spectrum image held in
# `SI_raw`, and plot one randomly chosen pixel before and after denoising.
# NOTE(review): `SI_raw` and `hs` are neither defined nor imported in this
# cell; presumably they come from earlier notebook cells (`hs` looks like
# `hyperspy.api`) - confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Ensure input data is float32 (this replaces SI_raw.data in place)
SI_raw.data = SI_raw.data.astype(np.float32)

try:
    # Print input shape
    print(f"Input data shape: {SI_raw.data.shape}")
    
    # Denoise the data; no model is passed, so a new one is trained first
    denoised_data = denoise_EELS_data(SI_raw.data, device=device)
    
    # Convert back to HyperSpy signal
    SI_denoised = hs.signals.Signal1D(denoised_data)
    
    # Plot comparison at a random position along the first two axes
    # (assumes the data has at least two leading axes before the spectral
    # axis - TODO confirm)
    x, y = np.random.randint(0, SI_raw.data.shape[0]), np.random.randint(0, SI_raw.data.shape[1])
    plt.figure(figsize=(12, 6))
    plt.plot(SI_raw.data[x,y], label='Original', alpha=0.7)
    plt.plot(denoised_data[x,y], label='Denoised', alpha=0.7)
    plt.legend()
    plt.title(f'Comparison at pixel ({x}, {y})')
    plt.show()
    
# Only RuntimeError is caught and reported; the cell then finishes without
# re-raising.  Any other exception type propagates.
except RuntimeError as e:
    print(f"Error occurred: {str(e)}")