# In [None]:  (notebook cell marker left over from the export)
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import umap  # pip install umap-learn
from scipy.ndimage import rotate
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms

class MultiDigitRotatedMNISTDataset(Dataset):
    """
    Rotated-MNIST dataset built from one reference image per digit.

    For each digit in [0..9], the first occurrence in the MNIST training set
    is used as the reference image and rotated to ``num_angles_per_digit``
    evenly spaced angles in [0°, 360°).
    Every digit keeps all of its angles except ``gap_digit``, which has the
    angles inside [gap_start, gap_end] (degrees, inclusive) removed.

    Each item is a pair ``(angle_vec, image)`` where ``angle_vec`` is a
    float32 tensor ``[cos(θ), sin(θ)]`` of shape (2,) and ``image`` is the
    pre-rotated float32 image of shape (1, 28, 28).  The digit label is NOT
    part of the input.

    Args:
        num_angles_per_digit: number of evenly spaced angles per digit.
        gap_start: lower bound of the held-out angle range, in degrees.
        gap_end: upper bound of the held-out angle range, in degrees.
        gap_digit: the digit whose gap angles are dropped (default 4).

    Raises:
        ValueError: if some digit has no example in the MNIST training set.
    """
    def __init__(self, num_angles_per_digit=100, gap_start=140, gap_end=200,
                 gap_digit=4):
        super().__init__()

        # Load MNIST (training set).  The transform is not applied below
        # because the raw uint8 tensor ``mnist.data`` is read directly.
        mnist = datasets.MNIST(
            root='./data',
            train=True,
            download=True,
            transform=transforms.ToTensor()
        )

        # One reference image per digit: the first occurrence in the set.
        self.reference_images = {}
        for digit in range(10):
            # as_tuple=True always yields a 1-D index tensor, even when there
            # is exactly one match (``nonzero().squeeze()`` would give 0-d).
            idxs = torch.nonzero(mnist.targets == digit, as_tuple=True)[0]
            if idxs.numel() == 0:
                raise ValueError(f"No examples of digit {digit} found in MNIST!")
            self.reference_images[digit] = mnist.data[idxs[0]].float() / 255.0  # (28, 28)

        # (digit, angle_in_radians) pairs, in digit-major order.
        self.samples = []
        for digit in range(10):
            kept_angles = self._select_angles(
                num_angles_per_digit,
                gap_start,
                gap_end,
                skip_gap=(digit == gap_digit),
            )
            self.samples.extend((digit, a) for a in kept_angles)

        # Pre-rotate all images once so __getitem__ is a cheap lookup.
        self.rotated_images = []
        for (digit, angle_rad) in self.samples:
            angle_deg = angle_rad * 180.0 / np.pi
            rotated = rotate(
                self.reference_images[digit].numpy(),
                angle_deg,
                reshape=False,   # keep the 28x28 frame
                order=1,         # bilinear interpolation
                mode='constant',
                cval=0.0
            )
            self.rotated_images.append(torch.tensor(rotated, dtype=torch.float32))

        print(f"Created MultiDigit dataset with total {len(self.samples)} samples.")
        print(f"Digit={gap_digit} has gap in angles [{gap_start}°, {gap_end}°]. Others have full range.")

    @staticmethod
    def _select_angles(num_angles, gap_start, gap_end, skip_gap):
        """Return the kept angles in radians, dropping the gap when ``skip_gap`` is true."""
        angles_rad = np.linspace(0, 2*np.pi, num_angles, endpoint=False)
        if not skip_gap:
            return angles_rad
        # Test the gap on a grid built directly in degrees.  Converting the
        # radian grid back to degrees can round an angle that is nominally
        # on a gap boundary to either side of it.
        angles_deg = np.linspace(0.0, 360.0, num_angles, endpoint=False)
        in_gap = (angles_deg >= gap_start) & (angles_deg <= gap_end)
        return angles_rad[~in_gap]

    def __len__(self):
        """Return the total number of (digit, angle) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``([cos(θ), sin(θ)], image)`` for sample ``idx``."""
        _digit, angle_rad = self.samples[idx]
        img = self.rotated_images[idx]  # shape (28, 28)

        # Only (cos, sin) is fed to the model, so the model cannot tell the
        # digits apart from its input.  Feeding the digit too would need a
        # 3-D input or a digit embedding on the model side.
        angle_vec = torch.tensor(
            [np.cos(angle_rad), np.sin(angle_rad)], dtype=torch.float32
        )
        return (angle_vec, img.unsqueeze(0))


# Build the multi-digit dataset; this replaces the earlier
# RotatedMNISTDataset(digit=4, ...) construction.
dataset = MultiDigitRotatedMNISTDataset(num_angles_per_digit=100, gap_start=140, gap_end=200)


def evaluate_digit4_gap(model, device, num_points=20, store=True,
                        gap_start=140, gap_end=200):
    """
    Plot the model's outputs for angles inside the held-out gap region.

    The model is queried with ``[cos(θ), sin(θ)]`` for ``num_points`` angles
    evenly spaced over [gap_start, gap_end] degrees (both ends included), and
    the generated images are laid out on a two-row grid.

    The model does not take the digit as input, so it does NOT know that
    digit 4 is being asked for; it produces whatever rotation it has learned
    for that angle.

    Args:
        model: module mapping a (N, 2) tensor to images indexed as
            ``outputs[i, 0]``; it is moved to ``device`` and put in eval mode.
        device: device the inputs are created on.
        num_points: number of gap angles to render.
        store: if true, save the figure under ``./icml_figs_temp/``;
            otherwise show it.
        gap_start: lower bound of the evaluated range, in degrees.
        gap_end: upper bound of the evaluated range, in degrees.
    """
    # Keep model and inputs on the same device, as visualize_digit4_umap does.
    model.to(device)
    model.eval()

    gap_angles_deg = np.linspace(gap_start, gap_end, num_points)
    gap_angles_rad = gap_angles_deg * np.pi / 180.0

    with torch.no_grad():
        # One row [cos(θ), sin(θ)] per angle.
        cos_sin = np.stack([np.cos(gap_angles_rad), np.sin(gap_angles_rad)], axis=1)
        angles_t = torch.tensor(cos_sin, dtype=torch.float32, device=device)
        outputs = model(angles_t)  # shape: (num_points, 1, 28, 28)

    # Two rows; round the column count up so an odd num_points (or
    # num_points == 1) still fits on the grid.
    n_cols = max(1, (num_points + 1) // 2)

    plt.figure(figsize=(18, 4))
    for i in range(num_points):
        plt.subplot(2, n_cols, i + 1)
        plt.imshow(outputs[i, 0].cpu().numpy(), cmap='gray')
        plt.axis('off')
        plt.title(f"{gap_angles_deg[i]:.1f}° (digit=4 OOD)")

    plt.suptitle(f"Digit 4: OOD angles [{gap_start}°,{gap_end}°]", color='red')
    if store:
        out_path = './icml_figs_temp/interpolated_rotated_images_digit4_gap.png'
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, bbox_inches='tight')
    else:
        plt.show()


def visualize_digit4_umap(model, device='cpu', store=True,
                          gap_start=140, gap_end=200):
    """
    Plot a 2-D UMAP embedding of every layer's activations over a full turn.

    1) Sample 360 angles from 0..2π (both endpoints included) and feed
       ``[cos(θ), sin(θ)]`` to the model.
    2) Hook ``model.fc`` and each layer of ``model.decoder``, and collect
       their activations plus the final outputs.
    3) For each collected tensor, fit a UMAP embedding and colour the points
       as ID vs. OOD, where OOD means the angle is in [gap_start, gap_end]
       degrees.

    Layers are plotted top to bottom in the order their hooks fired during
    the forward pass, followed by the final output.

    Args:
        model: module exposing ``fc`` and an iterable ``decoder``; it is put
            in eval mode and moved to ``device``.
        device: device to run the forward pass on.
        store: if true, save the figure under ``./icml_figs_temp/``;
            otherwise show it.
        gap_start: lower bound of the OOD range, in degrees.
        gap_end: upper bound of the OOD range, in degrees.
    """
    import umap  # imported once here rather than on every loop iteration

    model.eval()
    model.to(device)

    # 360 angles.  NOTE: linspace includes both 0 and 2π, so the first and
    # last inputs are (numerically almost) the same point on the circle.
    num_samples = 360
    angles_rad = torch.linspace(0, 2*np.pi, num_samples, device=device)
    # For each angle, input is [cos(θ), sin(θ)]
    inputs = torch.stack([torch.cos(angles_rad), torch.sin(angles_rad)], dim=1)

    # Hook machinery: each hook stores its layer's output on the CPU.
    collected_activations = {}

    def make_hook(layer_name):
        def hook_fn(module, inp, out):
            collected_activations[layer_name] = out.detach().cpu()
        return hook_fn

    handles = []
    try:
        handles.append(model.fc.register_forward_hook(make_hook("fc")))
        for i, layer in enumerate(model.decoder):
            lname = f"decoder_{i}_{layer.__class__.__name__}"
            handles.append(layer.register_forward_hook(make_hook(lname)))

        with torch.no_grad():
            outputs = model(inputs)  # shape (360, 1, 28, 28)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # raised, so the model is not left instrumented.
        for h in handles:
            h.remove()

    # Also store final outputs
    collected_activations["final_output"] = outputs.detach().cpu()

    # ID vs. OOD mask
    angles_deg = angles_rad.cpu().numpy() * 180 / np.pi
    ood_mask = (angles_deg >= gap_start) & (angles_deg <= gap_end)
    id_mask = ~ood_mask

    # Dicts keep insertion order, which is the order the hooks fired.  A
    # lexicographic sort would put "fc" after the decoder layers and
    # "decoder_10" before "decoder_2".
    layer_names = list(collected_activations)
    fig, axes = plt.subplots(
        nrows=len(layer_names),
        ncols=1,
        figsize=(8, 4*len(layer_names)),
        squeeze=False,  # always a 2-D array of axes, even for a single row
    )

    for ax, layer_name in zip(axes[:, 0], layer_names):
        activation = collected_activations[layer_name]
        # reshape (not view) so non-contiguous activations also flatten.
        act = activation.reshape(activation.size(0), -1).numpy()  # (360, ?)

        # UMAP
        reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
        embedding = reducer.fit_transform(act)  # (360, 2)

        ax.scatter(embedding[id_mask, 0], embedding[id_mask, 1],
                   c='blue', label='ID', s=30, alpha=0.7)
        ax.scatter(embedding[ood_mask, 0], embedding[ood_mask, 1],
                   c='red', label='OOD', s=30, alpha=0.7)

        ax.set_title(f"Digit 4 only: UMAP of '{layer_name}'", fontsize=10)
        ax.set_xlabel("UMAP-1")
        ax.set_ylabel("UMAP-2")
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    if store:
        out_path = "./icml_figs_temp/umap_digit4_layers.png"
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, bbox_inches='tight')
    else:
        plt.show()