In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from collections import Counter

def visualize_memory_samples(dataset, num_samples=5, figsize=(16, 3), max_cols=6):
    """
    Visualize memory task samples with detailed annotations.

    Each row shows one trial: the target frame (blue title and frame), the
    distractor frames (grey "NOISE j" titles) and the probe frame (green on a
    match, red otherwise). The trial metadata is also printed.

    Args:
        dataset: Visual memory dataset exposing ``get_sample_with_metadata(idx)``
            -> ``(sequence, label, metadata)``, where ``metadata`` holds
            ``'target_digit'``, ``'probe_digit'`` and ``'is_match'``.
        num_samples: Number of samples (rows) to show; clamped to
            ``len(dataset)`` when the dataset has a length.
        figsize: ``(total_width, height_per_sample)`` in inches.
        max_cols: Maximum number of frames per row (at least 2). Longer
            sequences show their first ``max_cols - 1`` frames followed by the
            probe, so the probe is never hidden.
    """
    if hasattr(dataset, '__len__'):
        num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples to visualize.")
        return
    # The target (first frame) and the probe (last frame) must both fit.
    max_cols = max(2, max_cols)

    samples = [dataset.get_sample_with_metadata(i) for i in range(num_samples)]
    longest = max(sequence.shape[0] for sequence, _, _ in samples)
    n_cols = max(1, min(max_cols, longest))

    # squeeze=False keeps `axes` 2-D for any number of rows/columns.
    fig, axes = plt.subplots(num_samples, n_cols,
                             figsize=(figsize[0], figsize[1] * num_samples),
                             squeeze=False)
    fig.suptitle('Visual Memory Task Samples', fontsize=16, y=0.98)

    for i, (sequence, label, metadata) in enumerate(samples):
        seq_len = sequence.shape[0]

        target_digit = metadata['target_digit']
        probe_digit = metadata['probe_digit']
        is_match = metadata['is_match']

        print(f"\n📋 Sample {i+1}:")
        print(f"  🎯 Target digit: {target_digit}")
        print(f"  🔍 Probe digit: {probe_digit}")
        print(f"  ✅ Result: {'MATCH' if is_match else 'NO MATCH'}")
        print(f"  🏷️ Label: {label}")
        print(f"  📏 Sequence length: {seq_len}")

        # Show every frame if they fit; otherwise the first n_cols - 1 frames
        # plus the probe, so the probe is always visible.
        if seq_len <= n_cols:
            frame_ids = list(range(seq_len))
        else:
            frame_ids = list(range(n_cols - 1)) + [seq_len - 1]

        for col, j in enumerate(frame_ids):
            ax = axes[i, col]

            # Accept tensors (possibly on GPU / requiring grad) or arrays.
            img = sequence[j]
            if isinstance(img, torch.Tensor):
                img = img.detach().cpu().numpy()
            if img.ndim == 3:
                img = img.squeeze()

            # NOTE: pixel values outside [0, 1] (e.g. mean/std-normalized
            # digits) are clipped to that range for display.
            ax.imshow(img, cmap='gray', vmin=0, vmax=1)
            ax.axis('off')

            if j == 0:
                # Target image
                ax.set_title(f'TARGET\n(digit: {target_digit})',
                             fontweight='bold', color='blue', fontsize=10)
                ax.add_patch(plt.Rectangle((0, 0), img.shape[1]-1, img.shape[0]-1,
                                           fill=False, edgecolor='blue', linewidth=2))
            elif j == seq_len - 1:
                # Probe image
                match_text = 'MATCH' if is_match else 'NO MATCH'
                color = 'green' if is_match else 'red'
                ax.set_title(f'PROBE\n(digit: {probe_digit})\n{match_text}',
                             fontweight='bold', color=color, fontsize=9)
                ax.add_patch(plt.Rectangle((0, 0), img.shape[1]-1, img.shape[0]-1,
                                           fill=False, edgecolor=color, linewidth=2))
            else:
                # Distractor/noise image
                ax.set_title(f'NOISE {j}', color='gray', fontsize=9)

        # Hide unused subplots in this row.
        for col in range(len(frame_ids), n_cols):
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import Counter, defaultdict
import matplotlib.pyplot as plt

class OptimizedNoiseGenerator:
    """Distractor-image generators used to build visual-memory sequences.

    Randomness comes from torch's global RNG (gaussian / salt-and-pepper /
    patch base noise), Python's ``random`` module (patch placement, scramble
    decisions) and NumPy's global RNG (pixel shuffling); seed all three for
    reproducible output.
    """

    def __init__(self):
        # Candidate block side lengths (pixels) for scrambled_mnist_batch.
        self.patch_sizes = [2, 3, 4]

    def gaussian_noise(self, shape, noise_level=0.5):
        """Return zero-mean Gaussian noise (std ``noise_level``) clamped to [0, 1].

        Because the noise is centred at 0 before clamping, roughly half of the
        pixels end up exactly 0.
        """
        noise = torch.randn(shape) * noise_level
        return torch.clamp(noise, 0, 1)

    def salt_pepper_noise(self, shape, noise_density=0.3):
        """Return salt-and-pepper noise over a random mid-gray background.

        The background is uniform in [0.25, 0.75); about ``noise_density / 2``
        of the pixels are set to 1.0 (salt) and as many to 0.0 (pepper).
        """
        noise = torch.rand(shape)
        result = torch.rand(shape) * 0.5 + 0.25  # Base gray level

        # Vectorized salt and pepper assignment
        salt_mask = noise < noise_density/2
        pepper_mask = noise > 1 - noise_density/2

        result[salt_mask] = 1.0
        result[pepper_mask] = 0.0

        return result

    def random_patches(self, shape, num_patches=10, patch_size=5):
        """Return a [0.35, 0.65) noise image with ``num_patches`` flat squares.

        Each square has side ``patch_size``, a random position over the last
        two dimensions and a uniform random intensity in [0, 1). Squares are
        only drawn when they fit inside the image; a square exactly as large
        as the image is allowed.
        """
        noise = torch.rand(shape) * 0.3 + 0.35
        h, w = shape[-2:]

        if patch_size <= 0 or patch_size > h or patch_size > w:
            return noise

        max_x = w - patch_size
        max_y = h - patch_size
        for _ in range(num_patches):
            x = random.randint(0, max_x)
            y = random.randint(0, max_y)
            intensity = random.random()
            noise[..., y:y+patch_size, x:x+patch_size] = intensity

        return noise

    def scrambled_mnist_batch(self, images, scramble_intensity=0.8):
        """Return block-scrambled copies of ``images`` without modifying them.

        With probability ``scramble_intensity`` an image is cut into
        non-overlapping square blocks (side drawn from ``self.patch_sizes``)
        and the pixels inside each block are shuffled; otherwise the image is
        returned unchanged. A trailing strip narrower than one block is left
        in place.

        Args:
            images: Iterable of image tensors whose squeezed form is 2-D,
                e.g. ``(1, H, W)`` or ``(H, W)``.
            scramble_intensity: Probability that a given image is scrambled.

        Returns:
            list of tensors; scrambled ones are float tensors with the same
            shape as their input.
        """
        results = []
        for image in images:
            if random.random() > scramble_intensity:
                results.append(image)
                continue

            # .numpy() shares memory with the tensor: copy so shuffling does
            # not corrupt the caller's image in place.
            img_np = image.detach().cpu().squeeze().numpy().copy()
            h, w = img_np.shape

            block_size = random.choice(self.patch_sizes)
            # "+ 1" so the last full block in each dimension is included.
            for i in range(0, h - block_size + 1, block_size):
                for j in range(0, w - block_size + 1, block_size):
                    block = img_np[i:i+block_size, j:j+block_size].flatten()
                    np.random.shuffle(block)
                    img_np[i:i+block_size, j:j+block_size] = block.reshape(block_size, block_size)

            results.append(torch.from_numpy(img_np).reshape(image.shape).float())

        return results


class OptimizedVisualMemoryDataset(Dataset):
    """Delayed match-to-sample dataset built from an MNIST-style dataset.

    Each sample is ``torch.stack([target, *distractors, probe])`` of length
    ``num_distractors + 2``. The label is 1 when the probe shows the target's
    digit (a different image of it whenever one exists) and 0 when it shows a
    different digit. All samples are generated eagerly in ``__init__`` and kept
    in memory.

    Args:
        mnist_dataset: Indexable dataset whose items are ``(image, digit)``.
        num_samples: Number of sequences to generate.
        num_distractors: Number of noise frames between target and probe.
        noise_types: Distractor kinds, drawn uniformly per frame; any of
            ``'gaussian'``, ``'salt_pepper'``, ``'patches'``, ``'scrambled'``.
        match_probability: Fraction of match trials (the count is
            ``int(num_samples * match_probability)``).
        batch_size: Number of samples assembled per generation step
            (values below 1 are treated as 1).

    Raises:
        ValueError: on an unknown noise type, empty ``noise_types`` while
            distractors are requested, negative ``num_samples`` or
            ``num_distractors``, ``match_probability`` outside [0, 1], an empty
            source dataset, or no-match trials requested from a dataset with a
            single digit class.
    """

    # Distractor kinds understood by _make_noise.
    SUPPORTED_NOISE_TYPES = ('gaussian', 'salt_pepper', 'patches', 'scrambled')

    def __init__(self, mnist_dataset, num_samples=1000, num_distractors=3,
                 noise_types=('gaussian', 'salt_pepper', 'patches', 'scrambled'),
                 match_probability=0.5, batch_size=100):

        self.mnist_dataset = mnist_dataset
        self.num_samples = num_samples
        self.num_distractors = num_distractors
        # Copy so the instance never aliases a caller-owned list.
        self.noise_types = list(noise_types)
        self.match_probability = match_probability
        self.noise_generator = OptimizedNoiseGenerator()

        # Fail fast, before any expensive work, on configurations that would
        # otherwise crash late or produce mislabeled trials.
        self._validate_arguments()

        print(f"🚀 Generating {num_samples} optimized visual memory samples...")

        # Step 1: group dataset indices by digit.
        print("📋 Pre-indexing MNIST data by digit...")
        self._preindex_mnist_data()

        # Step 2: draw every label, target index and probe index up front.
        print("🎯 Pre-sampling required indices...")
        self._presample_indices()

        # Step 3: assemble the sequences batch by batch.
        print("⚡ Batch generating samples...")
        self.samples = []
        self.labels = []
        self.metadata = []

        # A non-positive step would make range() in _batch_generate_samples fail.
        self._batch_generate_samples(max(1, int(batch_size)))

        self._print_statistics()

    def _validate_arguments(self):
        """Raise ValueError for argument combinations the generator cannot honor."""
        unknown = [t for t in self.noise_types if t not in self.SUPPORTED_NOISE_TYPES]
        if unknown:
            raise ValueError(
                f"Unknown noise type(s) {unknown}; supported: {self.SUPPORTED_NOISE_TYPES}")
        if self.num_samples < 0:
            raise ValueError(f"num_samples must be >= 0, got {self.num_samples}")
        if self.num_distractors < 0:
            raise ValueError(f"num_distractors must be >= 0, got {self.num_distractors}")
        if self.num_distractors > 0 and not self.noise_types:
            raise ValueError("noise_types must be non-empty when num_distractors > 0")
        if not 0.0 <= self.match_probability <= 1.0:
            raise ValueError(
                f"match_probability must be in [0, 1], got {self.match_probability}")
        if self.num_samples > 0 and len(self.mnist_dataset) == 0:
            raise ValueError("mnist_dataset is empty")

    def _fast_labels(self):
        """Return the dataset's labels without loading images, or None if unsafe.

        torchvision's MNIST stores its labels in a ``targets`` attribute, so
        they can be read directly instead of decoding and transforming every
        image. The shortcut is skipped when a ``target_transform`` is set
        (``targets`` would then differ from what ``__getitem__`` returns), when
        the length does not match the dataset, or when labels are not ints.
        """
        targets = getattr(self.mnist_dataset, 'targets', None)
        if targets is None or getattr(self.mnist_dataset, 'target_transform', None) is not None:
            return None
        labels = targets.tolist() if hasattr(targets, 'tolist') else list(targets)
        if len(labels) != len(self.mnist_dataset):
            return None
        if not all(isinstance(label, int) for label in labels):
            return None
        return labels

    def _preindex_mnist_data(self):
        """Build ``self.digit_indices`` (digit -> list of indices) and a
        per-index digit lookup ``self._index_to_digit``."""
        labels = self._fast_labels()
        if labels is None:
            # Generic fallback: load each item once just to read its label.
            labels = [self.mnist_dataset[idx][1] for idx in range(len(self.mnist_dataset))]

        digit_indices = defaultdict(list)
        for idx, digit in enumerate(labels):
            digit_indices[digit].append(idx)
        # Plain dict: looking up an unknown digit raises instead of silently
        # inserting an empty list.
        self.digit_indices = dict(digit_indices)
        self._index_to_digit = labels

        print(f"   Indexed {len(self.digit_indices)} digit classes")
        for digit, indices in self.digit_indices.items():
            print(f"   Digit {digit}: {len(indices)} samples")

    def _sample_same_digit(self, digit, exclude_idx):
        """Uniformly pick an index of ``digit`` other than ``exclude_idx``.

        Returns ``exclude_idx`` itself when it is the only example of that
        digit. Rejection sampling avoids copying the (large) per-digit index
        list for every trial.
        """
        candidates = self.digit_indices[digit]
        if len(candidates) < 2:
            return exclude_idx
        while True:
            idx = random.choice(candidates)
            if idx != exclude_idx:
                return idx

    def _presample_indices(self):
        """Draw the label, target index and probe index of every sample.

        Exactly ``int(num_samples * match_probability)`` samples are match
        trials, in shuffled order. Match probes are another image of the
        target's digit; no-match probes come from a uniformly chosen other
        digit.

        Raises:
            ValueError: if no-match trials are needed but the dataset holds
                fewer than two digit classes.
        """
        num_matches = int(self.num_samples * self.match_probability)
        num_nonmatches = self.num_samples - num_matches
        self.match_labels = [1] * num_matches + [0] * num_nonmatches
        random.shuffle(self.match_labels)

        if num_nonmatches > 0 and len(self.digit_indices) < 2:
            # Otherwise a "no-match" probe would have to reuse the target digit.
            raise ValueError(
                "No-match trials require at least two digit classes; "
                f"found {len(self.digit_indices)}")

        self.target_indices = random.choices(range(len(self.mnist_dataset)),
                                             k=self.num_samples)

        digit_classes = list(self.digit_indices)
        self.probe_indices = []
        for target_idx, is_match in zip(self.target_indices, self.match_labels):
            target_digit = self._index_to_digit[target_idx]
            if is_match:
                probe_idx = self._sample_same_digit(target_digit, target_idx)
            else:
                other_digit = random.choice([d for d in digit_classes if d != target_digit])
                probe_idx = random.choice(self.digit_indices[other_digit])
            self.probe_indices.append(probe_idx)

    def _batch_generate_samples(self, batch_size):
        """Generate all samples in chunks of ``batch_size`` and store them."""
        for batch_start in tqdm(range(0, self.num_samples, batch_size),
                                desc="Generating batches"):
            batch_end = min(batch_start + batch_size, self.num_samples)
            batch_samples, batch_labels, batch_metadata = self._create_batch(
                batch_start, batch_end)

            self.samples.extend(batch_samples)
            self.labels.extend(batch_labels)
            self.metadata.extend(batch_metadata)

    def _load_items(self, indices):
        """Load ``(image, digit)`` items for ``indices`` as two parallel lists."""
        images, digits = [], []
        for idx in indices:
            img, digit = self.mnist_dataset[idx]
            images.append(img)
            digits.append(digit)
        return images, digits

    def _create_batch(self, start_idx, end_idx):
        """Assemble samples ``start_idx`` .. ``end_idx - 1``.

        Returns:
            (sequences, labels, metadata) lists of equal length; each sequence
            is ``torch.stack([target, *distractors, probe])``.
        """
        target_indices_batch = self.target_indices[start_idx:end_idx]
        probe_indices_batch = self.probe_indices[start_idx:end_idx]

        target_images, target_digits = self._load_items(target_indices_batch)
        probe_images, probe_digits = self._load_items(probe_indices_batch)

        batch_size = end_idx - start_idx
        noise_images_batch = self._generate_noise_batch(
            target_images[0].shape, batch_size * self.num_distractors)

        batch_samples, batch_labels, batch_metadata = [], [], []
        for i in range(batch_size):
            noise_start = i * self.num_distractors
            noise_imgs = noise_images_batch[noise_start:noise_start + self.num_distractors]

            # Sequence layout: [target] + distractors + [probe].
            sequence = [target_images[i]] + noise_imgs + [probe_images[i]]

            global_idx = start_idx + i
            label = self.match_labels[global_idx]
            batch_samples.append(torch.stack(sequence))
            batch_labels.append(label)
            batch_metadata.append({
                'target_digit': target_digits[i],
                'probe_digit': probe_digits[i],
                'is_match': bool(label),
                'target_idx': target_indices_batch[i],
                'probe_idx': probe_indices_batch[i],
                'sequence_length': len(sequence),
            })

        return batch_samples, batch_labels, batch_metadata

    def _make_noise(self, noise_type, shape, count):
        """Generate ``count`` distractor images of one ``noise_type`` as a list."""
        gen = self.noise_generator
        if noise_type == 'scrambled':
            # Scrambled distractors are built from random real images.
            random_indices = random.choices(range(len(self.mnist_dataset)), k=count)
            random_images = [self.mnist_dataset[idx][0] for idx in random_indices]
            return gen.scrambled_mnist_batch(random_images)

        batch_shape = (count,) + tuple(shape)
        if noise_type == 'gaussian':
            batch_noise = gen.gaussian_noise(batch_shape, 0.6)
        elif noise_type == 'salt_pepper':
            batch_noise = gen.salt_pepper_noise(batch_shape, 0.4)
        elif noise_type == 'patches':
            batch_noise = torch.stack([gen.random_patches(shape, 15, 4) for _ in range(count)])
        else:
            raise ValueError(f"Unknown noise type: {noise_type!r}")
        return list(batch_noise.unbind(0))

    def _generate_noise_batch(self, shape, total_noise_images):
        """Return ``total_noise_images`` distractors of the given ``shape``.

        A noise type is drawn uniformly per image; images of the same type are
        generated together and then returned in the drawn order.
        """
        noise_assignments = [random.choice(self.noise_types)
                             for _ in range(total_noise_images)]
        noise_counts = Counter(noise_assignments)

        noise_by_type = {}
        for noise_type in dict.fromkeys(self.noise_types):  # dedup, keep order
            count = noise_counts[noise_type]
            if count:
                noise_by_type[noise_type] = iter(self._make_noise(noise_type, shape, count))

        return [next(noise_by_type[noise_type]) for noise_type in noise_assignments]

    def _print_statistics(self):
        """Print a short summary of the generated dataset."""
        total = len(self.labels)
        match_count = sum(self.labels)
        nomatch_count = total - match_count

        def pct(count):
            # Guard against an empty dataset (num_samples == 0).
            return count / total * 100 if total else 0.0

        print(f"\n✅ Dataset created:")
        print(f"  Total samples: {total}")
        print(f"  Match trials: {match_count} ({pct(match_count):.1f}%)")
        print(f"  No-match trials: {nomatch_count} ({pct(nomatch_count):.1f}%)")
        print(f"  Sequence length: {self.num_distractors + 2} (1 target + {self.num_distractors} distractors + 1 probe)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for sample ``idx``."""
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        """Return ``(sequence, label, metadata_dict)`` for sample ``idx``."""
        return self.samples[idx], self.labels[idx], self.metadata[idx]

    def get_statistics(self):
        """Return label counts and digit distributions as a dict.

        ``match_percentage`` is 0.0 for an empty dataset.
        """
        total = len(self.labels)
        match_count = sum(self.labels)

        return {
            'total_samples': total,
            'match_trials': match_count,
            'nomatch_trials': total - match_count,
            'match_percentage': match_count / total * 100 if total else 0.0,
            'target_digit_distribution': dict(Counter(m['target_digit'] for m in self.metadata)),
            'probe_digit_distribution': dict(Counter(m['probe_digit'] for m in self.metadata)),
            'sequence_length': self.num_distractors + 2
        }


def create_optimized_datasets_for_training(train_size=3000, test_size=1000,
                                           num_distractors=3, batch_size=100,
                                           loader_batch_size=32, data_root='./data',
                                           seed=None):
    """Build MNIST-based train/test visual-memory datasets and their DataLoaders.

    Args:
        train_size: Number of training sequences.
        test_size: Number of test sequences.
        num_distractors: Noise frames between target and probe.
        batch_size: Generation batch size passed to OptimizedVisualMemoryDataset
            (not the DataLoader batch size).
        loader_batch_size: Batch size of the returned DataLoaders.
        data_root: Directory where torchvision stores/downloads MNIST.
        seed: If given, seeds Python's ``random``, NumPy and torch before the
            datasets are generated, so the result is reproducible.

    Returns:
        (train_dataset, test_dataset, train_loader, test_loader); the train
        loader shuffles, the test loader does not.

    Note:
        MNIST frames are mean/std-normalized here, while the gaussian,
        salt-and-pepper and patch distractors are generated in [0, 1], so
        digit frames and those distractors have different value ranges.
    """
    print("="*60)
    print("CREATING OPTIMIZED TRAINING DATASETS")
    print("="*60)

    # Imported lazily so the rest of the notebook works without torchvision.
    from torchvision import datasets, transforms

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    print("📁 Loading MNIST datasets...")
    mnist_train = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(data_root, train=False, download=True, transform=transform)

    # Generation settings shared by both splits (tuple: nothing can mutate it).
    dataset_kwargs = dict(
        num_distractors=num_distractors,
        noise_types=('gaussian', 'salt_pepper', 'patches', 'scrambled'),
        match_probability=0.5,
        batch_size=batch_size,
    )

    print("\n🚂 Creating optimized training dataset...")
    train_memory_dataset = OptimizedVisualMemoryDataset(
        mnist_train, num_samples=train_size, **dataset_kwargs)

    print("\n🧪 Creating optimized testing dataset...")
    test_memory_dataset = OptimizedVisualMemoryDataset(
        mnist_test, num_samples=test_size, **dataset_kwargs)

    train_loader = DataLoader(train_memory_dataset, batch_size=loader_batch_size, shuffle=True)
    test_loader = DataLoader(test_memory_dataset, batch_size=loader_batch_size, shuffle=False)

    print("\n🎉 Optimized datasets ready for training!")
    print(f"Training samples: {len(train_memory_dataset)}")
    print(f"Testing samples: {len(test_memory_dataset)}")
    print(f"Sequence length: {num_distractors + 2}")

    return train_memory_dataset, test_memory_dataset, train_loader, test_loader


# Compatibility function with original interface
def create_datasets_for_training(train_size=3000, test_size=1000, num_distractors=3):
    """Backward-compatible wrapper around create_optimized_datasets_for_training.

    The generation batch size is ``train_size // 10`` capped at 100 and never
    below 1: OptimizedVisualMemoryDataset uses it as a ``range`` step, which
    must be positive (``train_size < 10`` previously produced 0).

    Returns:
        (train_dataset, test_dataset, train_loader, test_loader)
    """
    generation_batch_size = max(1, min(100, train_size // 10))
    return create_optimized_datasets_for_training(
        train_size=train_size,
        test_size=test_size,
        num_distractors=num_distractors,
        batch_size=generation_batch_size,
    )


# Example usage: see the next cell, which builds small train/test datasets.


In [None]:
# Seed every RNG used during generation (Python `random`, NumPy, torch) so
# re-running the notebook reproduces the same datasets.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_dataset, test_dataset, train_loader, test_loader = create_datasets_for_training(
    train_size=1000,
    test_size=300,
    num_distractors=2
)


CREATING OPTIMIZED TRAINING DATASETS
📁 Loading MNIST datasets...

🚂 Creating optimized training dataset...
🚀 Generating 1000 optimized visual memory samples...
📋 Pre-indexing MNIST data by digit...
   Indexed 10 digit classes
   Digit 5: 5421 samples
   Digit 0: 5923 samples
   Digit 4: 5842 samples
   Digit 1: 6742 samples
   Digit 9: 5949 samples
   Digit 2: 5958 samples
   Digit 3: 6131 samples
   Digit 6: 5918 samples
   Digit 7: 6265 samples
   Digit 8: 5851 samples
🎯 Pre-sampling required indices...
⚡ Batch generating samples...


Generating batches: 100%|██████████| 10/10 [00:00<00:00, 10.98it/s]



✅ Dataset created:
  Total samples: 1000
  Match trials: 500 (50.0%)
  No-match trials: 500 (50.0%)
  Sequence length: 4 (1 target + 2 distractors + 1 probe)

🧪 Creating optimized testing dataset...
🚀 Generating 300 optimized visual memory samples...
📋 Pre-indexing MNIST data by digit...
   Indexed 10 digit classes
   Digit 7: 1028 samples
   Digit 2: 1032 samples
   Digit 1: 1135 samples
   Digit 0: 980 samples
   Digit 4: 982 samples
   Digit 9: 1009 samples
   Digit 5: 892 samples
   Digit 6: 958 samples
   Digit 3: 1010 samples
   Digit 8: 974 samples
🎯 Pre-sampling required indices...
⚡ Batch generating samples...


Generating batches: 100%|██████████| 3/3 [00:00<00:00, 11.62it/s]


✅ Dataset created:
  Total samples: 300
  Match trials: 150 (50.0%)
  No-match trials: 150 (50.0%)
  Sequence length: 4 (1 target + 2 distractors + 1 probe)

🎉 Optimized datasets ready for training!
Training samples: 1000
Testing samples: 300
Sequence length: 4
Expected speedup: 10-50x faster than original

🚀 Dataset creation completed successfully!
Ready for training!





In [None]:
visualize_memory_samples(dataset=train_dataset, num_samples=3)