In [None]:
# Mount Google Drive at /content/drive so the project folder used below is reachable.
# Presumably Colab-only: `google.colab` is not importable outside Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work from the project folder on Drive (mounted at /content/drive above).
# An absolute path keeps this cell idempotent: a relative `cd` fails on a
# second run because the working directory has already moved.
import os
PROJECT_DIR = '/content/drive/MyDrive/Neuromatch_project/'
os.chdir(PROJECT_DIR)

In [None]:
!pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Data-generation functions

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from collections import Counter

def visualize_memory_samples(dataset, num_samples=5, figsize=(16, 3)):
    """
    Plot memory-task trials (one row per trial) and print their metadata.

    Each row shows up to six frames: the target first, then noise
    distractors, then the probe. When a sequence is longer than six frames,
    the trailing distractors are skipped so the probe still appears in the
    last column.

    Args:
        dataset: Object supporting ``len()`` and ``get_sample_with_metadata(idx)``,
            which returns ``(sequence, label, metadata)``. ``metadata`` must contain
            ``'target_digit'``, ``'probe_digit'`` and ``'is_match'``.
        num_samples: Number of trials to show; capped at ``len(dataset)``.
        figsize: ``(width, height)`` of a single row; the figure height is
            ``height * num_samples``.
    """
    max_cols = 6
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples to visualize.")
        return

    fig, axes = plt.subplots(num_samples, max_cols,
                             figsize=(figsize[0], figsize[1] * num_samples))
    if num_samples == 1:
        # subplots() returns a 1-D array for a single row; keep [row, col] indexing.
        axes = axes.reshape(1, -1)

    fig.suptitle('Visual Memory Task Samples', fontsize=16, y=0.98)

    for i in range(num_samples):
        sequence, label, metadata = dataset.get_sample_with_metadata(i)
        seq_len = sequence.shape[0]

        target_digit = metadata['target_digit']
        probe_digit = metadata['probe_digit']
        is_match = metadata['is_match']

        print(f"\n📋 Sample {i+1}:")
        print(f"  🎯 Target digit: {target_digit}")
        print(f"  🔍 Probe digit: {probe_digit}")
        print(f"  ✅ Result: {'MATCH' if is_match else 'NO MATCH'}")
        print(f"  🏷️ Label: {label}")
        print(f"  📏 Sequence length: {seq_len}")

        # Frame indices to draw: all of them if they fit, otherwise the first
        # (max_cols - 1) frames plus the probe, which is always the last frame.
        if seq_len <= max_cols:
            frames = list(range(seq_len))
        else:
            frames = list(range(max_cols - 1)) + [seq_len - 1]

        for col, j in enumerate(frames):
            ax = axes[i, col]

            img = sequence[j]
            if len(img.shape) == 3:
                # Drop singleton dims, e.g. a (1, H, W) channel axis -> (H, W).
                img = img.squeeze()
            img = img.detach().cpu().numpy()

            # NOTE(review): frames built with the MNIST Normalize transform can
            # fall outside [0, 1]; imshow clips them to this display range.
            ax.imshow(img, cmap='gray', vmin=0, vmax=1)
            ax.axis('off')

            if j == 0:
                # Target image
                ax.set_title(f'TARGET\n(digit: {target_digit})',
                             fontweight='bold', color='blue', fontsize=10)
                ax.add_patch(plt.Rectangle((0, 0), img.shape[1]-1, img.shape[0]-1,
                                           fill=False, edgecolor='blue', linewidth=2))
            elif j == seq_len - 1:
                # Probe image
                match_text = 'MATCH' if is_match else 'NO MATCH'
                color = 'green' if is_match else 'red'
                ax.set_title(f'PROBE\n(digit: {probe_digit})\n{match_text}',
                             fontweight='bold', color=color, fontsize=9)
                ax.add_patch(plt.Rectangle((0, 0), img.shape[1]-1, img.shape[0]-1,
                                           fill=False, edgecolor=color, linewidth=2))
            else:
                # Distractor/noise image
                ax.set_title(f'NOISE {j}', color='gray', fontsize=9)

        # Hide unused columns in this row.
        for col in range(len(frames), max_cols):
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import Counter, defaultdict
import matplotlib.pyplot as plt

class OptimizedNoiseGenerator:
    """Builds distractor ("noise") images for the visual memory task.

    ``gaussian_noise`` and ``salt_pepper_noise`` accept any tensor shape, so a
    whole batch can be produced in one call; ``random_patches`` and
    ``scrambled_mnist_batch`` work image by image.
    """

    def __init__(self):
        # Candidate block side lengths (pixels) for scrambled_mnist_batch.
        self.patch_sizes = [2, 3, 4]

    def gaussian_noise(self, shape, noise_level=0.5):
        """Return zero-mean Gaussian noise of ``shape``, scaled by ``noise_level`` and clamped to [0, 1].

        Because the noise is centred on 0, every negative draw (about half of
        the pixels) is clamped to exactly 0.
        """
        noise = torch.randn(shape) * noise_level
        return torch.clamp(noise, 0, 1)

    def salt_pepper_noise(self, shape, noise_density=0.3):
        """Return a grey background in [0.25, 0.75) sprinkled with salt (1.0) and pepper (0.0).

        About ``noise_density / 2`` of the pixels become salt and the same
        fraction become pepper.
        """
        noise = torch.rand(shape)
        result = torch.rand(shape) * 0.5 + 0.25  # Base gray level

        salt_mask = noise < noise_density / 2
        pepper_mask = noise > 1 - noise_density / 2

        result[salt_mask] = 1.0
        result[pepper_mask] = 0.0

        return result

    def random_patches(self, shape, num_patches=10, patch_size=5):
        """Return a background in [0.35, 0.65) with ``num_patches`` square patches painted on it.

        Each patch has side ``patch_size`` and a single random intensity in
        [0, 1); patches may overlap. If the patch does not fit inside the last
        two dimensions of ``shape`` (or ``patch_size`` is not positive), the
        plain background is returned.
        """
        noise = torch.rand(shape) * 0.3 + 0.35
        h, w = shape[-2:]

        if patch_size <= 0 or patch_size > h or patch_size > w:
            return noise

        # Largest valid top-left corner; randint is inclusive, so an exact
        # fit (max == 0) still places the patch at (0, 0).
        max_x = w - patch_size
        max_y = h - patch_size
        for _ in range(num_patches):
            x = random.randint(0, max_x)
            y = random.randint(0, max_y)
            intensity = random.random()
            noise[..., y:y + patch_size, x:x + patch_size] = intensity

        return noise

    def scrambled_mnist_batch(self, images, scramble_intensity=0.8):
        """Return a (possibly) block-scrambled copy of each image in ``images``.

        With probability ``scramble_intensity`` an image is cut into
        non-overlapping square blocks (side drawn from ``self.patch_sizes``)
        and the pixels inside every complete block are shuffled; otherwise the
        image is returned unchanged (the same tensor object). Edge rows and
        columns that do not fill a complete block are left as they are. Input
        tensors are never modified.

        Args:
            images: Iterable of image tensors that squeeze to 2-D
                (presumably (1, H, W) MNIST images — TODO confirm for other inputs).
            scramble_intensity: Probability of scrambling each image.

        Returns:
            List with one tensor per input; scrambled entries are float tensors
            of shape (1, H, W).
        """
        results = []
        for image in images:
            if random.random() > scramble_intensity:
                results.append(image)
                continue

            # .numpy() shares memory with the tensor; copy so the caller's
            # image is not shuffled in place.
            img_np = image.detach().cpu().squeeze().numpy().copy()
            h, w = img_np.shape

            block_size = random.choice(self.patch_sizes)
            # `+ 1` so the last complete block along each axis is included.
            for i in range(0, h - block_size + 1, block_size):
                for j in range(0, w - block_size + 1, block_size):
                    block = img_np[i:i + block_size, j:j + block_size].flatten()
                    np.random.shuffle(block)
                    img_np[i:i + block_size, j:j + block_size] = block.reshape(block_size, block_size)

            results.append(torch.tensor(img_np).unsqueeze(0).float())

        return results


class OptimizedVisualMemoryDataset(Dataset):
    """Pre-generated match-to-sample dataset built on an MNIST-style dataset.

    Each sample is a stacked sequence ``[target] + distractors + [probe]`` of
    length ``num_distractors + 2``, labelled 1 when the probe has the same
    digit as the target and 0 otherwise. All samples are generated once in
    ``__init__`` and kept in memory.

    Args:
        mnist_dataset: Indexable dataset returning ``(image_tensor, digit)``.
            If it exposes a usable ``targets`` attribute (as torchvision's
            MNIST does), labels are read from it instead of loading every image.
        num_samples: Number of sequences to generate.
        num_distractors: Number of noise images between target and probe.
        noise_types: Distractor kinds to draw from; any of ``'gaussian'``,
            ``'salt_pepper'``, ``'patches'``, ``'scrambled'``.
        match_probability: Fraction of samples that are match trials.
        batch_size: Number of samples assembled per generation step (values
            below 1 are treated as 1).

    Raises:
        ValueError: If ``noise_types`` contains an unsupported name.
    """

    # Noise kinds understood by _generate_noise_batch.
    SUPPORTED_NOISE_TYPES = ('gaussian', 'salt_pepper', 'patches', 'scrambled')

    def __init__(self, mnist_dataset, num_samples=1000, num_distractors=3,
                 noise_types=('gaussian', 'salt_pepper', 'patches', 'scrambled'),
                 match_probability=0.5, batch_size=100):

        unknown_types = [t for t in noise_types if t not in self.SUPPORTED_NOISE_TYPES]
        if unknown_types:
            raise ValueError(
                f"Unsupported noise types {unknown_types}; "
                f"expected a subset of {list(self.SUPPORTED_NOISE_TYPES)}")

        self.mnist_dataset = mnist_dataset
        self.num_samples = num_samples
        self.num_distractors = num_distractors
        # Own copy: never alias the caller's list.
        self.noise_types = list(noise_types)
        self.match_probability = match_probability
        self.noise_generator = OptimizedNoiseGenerator()

        print(f"🚀 Generating {num_samples} optimized visual memory samples...")

        # Step 1: Pre-index MNIST data by digit for faster access
        print("📋 Pre-indexing MNIST data by digit...")
        self._preindex_mnist_data()

        # Step 2: Pre-sample all required indices
        print("🎯 Pre-sampling required indices...")
        self._presample_indices()

        # Step 3: Batch generate samples
        print("⚡ Batch generating samples...")
        self.samples = []
        self.labels = []
        self.metadata = []

        # A step of 0 would make range() raise, so always generate >= 1 per batch.
        self._batch_generate_samples(max(1, int(batch_size)))

        self._print_statistics()

    def _read_all_labels(self):
        """Return the label of every item of ``self.mnist_dataset`` as a list.

        The ``targets`` attribute is used when it exists, matches the dataset
        length and no ``target_transform`` is set. torchvision's MNIST returns
        ``int(targets[i])`` from ``__getitem__``, so the labels are identical.
        Otherwise the dataset is iterated, which loads and transforms every image.
        """
        dataset = self.mnist_dataset
        targets = getattr(dataset, 'targets', None)
        if (targets is not None
                and getattr(dataset, 'target_transform', None) is None
                and len(targets) == len(dataset)):
            return targets.tolist() if hasattr(targets, 'tolist') else list(targets)
        return [digit for _, digit in dataset]

    def _preindex_mnist_data(self):
        """Cache every image's digit and group dataset indices by digit.

        Sets ``self._labels`` (index -> digit) and ``self.digit_indices``
        (digit -> indices in dataset order).
        """
        self._labels = self._read_all_labels()

        self.digit_indices = defaultdict(list)
        for idx, digit in enumerate(self._labels):
            self.digit_indices[digit].append(idx)

        print(f"   Indexed {len(self.digit_indices)} digit classes")
        for digit, indices in self.digit_indices.items():
            print(f"   Digit {digit}: {len(indices)} samples")

    def _presample_indices(self):
        """Choose each sample's match label, target index and probe index.

        Match probes are a different image of the same digit (the target
        itself when it is the only image of its digit). No-match probes come
        from a randomly chosen other digit (any other image if the dataset has
        a single class).
        """
        num_matches = int(self.num_samples * self.match_probability)
        self.match_labels = [1] * num_matches + [0] * (self.num_samples - num_matches)
        random.shuffle(self.match_labels)

        dataset_len = len(self.mnist_dataset)
        self.target_indices = random.choices(range(dataset_len), k=self.num_samples)

        digits = list(self.digit_indices.keys())
        self.probe_indices = []
        for target_idx, is_match in zip(self.target_indices, self.match_labels):
            target_digit = self._labels[target_idx]

            if is_match:
                candidates = self.digit_indices[target_digit]
                if len(candidates) > 1:
                    # Rejection sampling stays uniform over the other same-digit
                    # images without copying the whole class list per sample.
                    probe_idx = target_idx
                    while probe_idx == target_idx:
                        probe_idx = random.choice(candidates)
                else:
                    probe_idx = target_idx  # Fallback to same image
            else:
                other_digits = [d for d in digits if d != target_digit]
                if other_digits:
                    different_digit = random.choice(other_digits)
                    probe_idx = random.choice(self.digit_indices[different_digit])
                elif dataset_len > 1:
                    # Single-class dataset: any other image.
                    probe_idx = target_idx
                    while probe_idx == target_idx:
                        probe_idx = random.randrange(dataset_len)
                else:
                    raise ValueError(
                        "Cannot build a no-match trial from a dataset with a single image")

            self.probe_indices.append(probe_idx)

    def _batch_generate_samples(self, batch_size):
        """Generate all samples in chunks of ``batch_size``, appending to samples/labels/metadata."""
        for batch_start in tqdm(range(0, self.num_samples, batch_size),
                                desc="Generating batches"):
            batch_end = min(batch_start + batch_size, self.num_samples)
            batch_samples, batch_labels, batch_metadata = self._create_batch(
                batch_start, batch_end)

            self.samples.extend(batch_samples)
            self.labels.extend(batch_labels)
            self.metadata.extend(batch_metadata)

    def _create_batch(self, start_idx, end_idx):
        """Build samples ``start_idx`` .. ``end_idx - 1``.

        Returns:
            (sequences, labels, metadata) lists; each sequence is
            ``torch.stack([target] + noise + [probe])``.
        """
        batch_samples = []
        batch_labels = []
        batch_metadata = []

        target_indices_batch = self.target_indices[start_idx:end_idx]
        probe_indices_batch = self.probe_indices[start_idx:end_idx]

        target_images, target_digits = [], []
        for idx in target_indices_batch:
            img, digit = self.mnist_dataset[idx]
            target_images.append(img)
            target_digits.append(digit)

        probe_images, probe_digits = [], []
        for idx in probe_indices_batch:
            img, digit = self.mnist_dataset[idx]
            probe_images.append(img)
            probe_digits.append(digit)

        batch_size = end_idx - start_idx
        noise_images_batch = self._generate_noise_batch(
            target_images[0].shape, batch_size * self.num_distractors)

        for i in range(batch_size):
            noise_start = i * self.num_distractors
            noise_imgs = noise_images_batch[noise_start:noise_start + self.num_distractors]

            # Create sequence: [target] + [noise images] + [probe]
            sequence = [target_images[i]] + noise_imgs + [probe_images[i]]
            sequence_tensor = torch.stack(sequence)

            global_idx = start_idx + i
            metadata = {
                'target_digit': target_digits[i],
                'probe_digit': probe_digits[i],
                'is_match': bool(self.match_labels[global_idx]),
                'target_idx': target_indices_batch[i],
                'probe_idx': probe_indices_batch[i],
                'sequence_length': len(sequence)
            }

            batch_samples.append(sequence_tensor)
            batch_labels.append(self.match_labels[global_idx])
            batch_metadata.append(metadata)

        return batch_samples, batch_labels, batch_metadata

    def _generate_noise_batch(self, shape, total_noise_images):
        """Return ``total_noise_images`` distractor tensors of image ``shape``.

        Each distractor's type is drawn uniformly from ``self.noise_types``.
        Images of one type are generated together, then returned in the order
        the types were drawn.

        Raises:
            ValueError: If a drawn type is not supported.
        """
        noise_assignments = [random.choice(self.noise_types)
                             for _ in range(total_noise_images)]
        noise_counts = Counter(noise_assignments)

        noise_by_type = {}
        # dict.fromkeys: iterate types once each, in self.noise_types order.
        for noise_type in dict.fromkeys(self.noise_types):
            count = noise_counts[noise_type]
            if count == 0:
                continue

            if noise_type == 'scrambled':
                # Scrambled distractors start from real images of this dataset.
                random_indices = random.choices(range(len(self.mnist_dataset)), k=count)
                random_images = [self.mnist_dataset[idx][0] for idx in random_indices]
                noise_by_type[noise_type] = self.noise_generator.scrambled_mnist_batch(
                    random_images)
                continue

            batch_shape = (count,) + tuple(shape)
            if noise_type == 'gaussian':
                batch_noise = self.noise_generator.gaussian_noise(batch_shape, 0.6)
            elif noise_type == 'salt_pepper':
                batch_noise = self.noise_generator.salt_pepper_noise(batch_shape, 0.4)
            elif noise_type == 'patches':
                batch_noise = torch.stack([
                    self.noise_generator.random_patches(shape, 15, 4)
                    for _ in range(count)
                ])
            else:
                # Never reuse a tensor generated for a different type.
                raise ValueError(f"Unknown noise type: {noise_type!r}")

            noise_by_type[noise_type] = list(batch_noise)

        # Re-interleave in the original draw order.
        streams = {noise_type: iter(images) for noise_type, images in noise_by_type.items()}
        return [next(streams[noise_type]) for noise_type in noise_assignments]

    def _print_statistics(self):
        """Print match/no-match counts and the sequence layout (safe for 0 samples)."""
        total = self.num_samples
        match_count = sum(self.labels)
        nomatch_count = total - match_count
        match_pct = match_count / total * 100 if total else 0.0
        nomatch_pct = nomatch_count / total * 100 if total else 0.0
        print(f"\n✅ Dataset created:")
        print(f"  Total samples: {total}")
        print(f"  Match trials: {match_count} ({match_pct:.1f}%)")
        print(f"  No-match trials: {nomatch_count} ({nomatch_pct:.1f}%)")
        print(f"  Sequence length: {self.num_distractors + 2} (1 target + {self.num_distractors} distractors + 1 probe)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        """Return ``(sequence, label, metadata_dict)`` for sample ``idx``."""
        return self.samples[idx], self.labels[idx], self.metadata[idx]

    def get_statistics(self):
        """Return label counts and digit distributions (match_percentage is 0.0 when empty)."""
        match_count = sum(self.labels)
        total = len(self.labels)

        target_counter = Counter(meta['target_digit'] for meta in self.metadata)
        probe_counter = Counter(meta['probe_digit'] for meta in self.metadata)

        return {
            'total_samples': total,
            'match_trials': match_count,
            'nomatch_trials': total - match_count,
            'match_percentage': match_count / total * 100 if total else 0.0,
            'target_digit_distribution': dict(target_counter),
            'probe_digit_distribution': dict(probe_counter),
            'sequence_length': self.num_distractors + 2
        }


def create_optimized_datasets_for_training(train_size=3000, test_size=1000,
                                           num_distractors=3, batch_size=100,
                                           loader_batch_size=32, data_root='./data',
                                           seed=None):
    """Build train/test visual-memory datasets from MNIST and wrap them in DataLoaders.

    Args:
        train_size: Number of training sequences (drawn from MNIST train split).
        test_size: Number of test sequences (drawn from MNIST test split).
        num_distractors: Noise images between target and probe.
        batch_size: Generation batch size passed to OptimizedVisualMemoryDataset
            (affects generation only, not the DataLoaders).
        loader_batch_size: Batch size of the returned DataLoaders.
        data_root: Directory where torchvision stores/downloads MNIST.
        seed: If not None, seeds Python's ``random``, NumPy and PyTorch before
            generation so the datasets (and train-loader shuffling) repeat.

    Returns:
        (train_dataset, test_dataset, train_loader, test_loader); the train
        loader shuffles, the test loader does not.
    """
    print("="*60)
    print("CREATING OPTIMIZED TRAINING DATASETS")
    print("="*60)

    # Import here to avoid issues if not available
    from torchvision import datasets, transforms

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # MNIST normalization (mean 0.1307, std 0.3081). Note the digits therefore
    # do not share the [0, 1] range of the synthetic noise distractors.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    noise_types = ('gaussian', 'salt_pepper', 'patches', 'scrambled')

    print("📁 Loading MNIST datasets...")
    mnist_train = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(data_root, train=False, download=True, transform=transform)

    print("\n🚂 Creating optimized training dataset...")
    train_memory_dataset = OptimizedVisualMemoryDataset(
        mnist_train,
        num_samples=train_size,
        num_distractors=num_distractors,
        noise_types=list(noise_types),
        match_probability=0.5,
        batch_size=batch_size
    )

    print("\n🧪 Creating optimized test dataset...")
    test_memory_dataset = OptimizedVisualMemoryDataset(
        mnist_test,
        num_samples=test_size,
        num_distractors=num_distractors,
        noise_types=list(noise_types),
        match_probability=0.5,
        batch_size=batch_size
    )

    train_loader = DataLoader(train_memory_dataset, batch_size=loader_batch_size, shuffle=True)
    test_loader = DataLoader(test_memory_dataset, batch_size=loader_batch_size, shuffle=False)

    return train_memory_dataset, test_memory_dataset, train_loader, test_loader


# Compatibility function with original interface
def create_datasets_for_training(train_size=3000, test_size=1000, num_distractors=3):
    """Backward-compatible wrapper around create_optimized_datasets_for_training.

    The generation batch size is a tenth of ``train_size``, capped at 100 and
    floored at 1. Without the floor, a dataset with ``train_size < 10`` would
    get a batch size of 0, and batch generation would crash.

    Returns:
        (train_dataset, test_dataset, train_loader, test_loader).
    """
    return create_optimized_datasets_for_training(
        train_size=train_size,
        test_size=test_size,
        num_distractors=num_distractors,
        batch_size=max(1, min(100, train_size // 10))  # Adaptive batch size
    )



## Generate the datasets

In [None]:
# Build the memory-task data: every sequence is target + 2 noise distractors + probe.
(train_dataset, test_dataset,
 train_loader, test_loader) = create_datasets_for_training(
    num_distractors=2,
    train_size=5000,
    test_size=1000,
)


## Dataset save/load utilities

In [None]:
# Save and Load Existing Visual Memory Datasets
import torch
import pickle
import os
from datetime import datetime

# =============================================================================
# OPTION 1: SIMPLE SAVE/LOAD (RECOMMENDED)
# =============================================================================

def _dump_visual_memory_dataset(dataset, path, dataset_type):
    """Pickle the fields needed to rebuild ``dataset`` into ``path`` (overwriting it).

    Args:
        dataset: Object with ``samples``, ``labels``, ``metadata``,
            ``num_samples``, ``num_distractors``, ``noise_types`` and
            ``match_probability`` attributes.
        path: Destination ``.pkl`` file.
        dataset_type: Tag stored under ``'dataset_type'`` (e.g. ``'training'``).
    """
    payload = {
        'samples': dataset.samples,
        'labels': dataset.labels,
        'metadata': dataset.metadata,
        'num_samples': dataset.num_samples,
        'num_distractors': dataset.num_distractors,
        'noise_types': dataset.noise_types,
        'match_probability': dataset.match_probability,
        'creation_time': datetime.now().isoformat(),
        'dataset_type': dataset_type
    }
    with open(path, 'wb') as f:
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)


def save_datasets_simple(train_dataset, test_dataset, save_dir='./my_datasets/'):
    """Pickle a train/test dataset pair into ``save_dir`` as plain dicts.

    Files are named ``train_<N>samples_<D>dist_<YYYYmmdd_HHMM>.pkl`` and
    ``test_<N>samples_<D>dist_<YYYYmmdd_HHMM>.pkl``. ``D`` is taken from
    ``train_dataset.num_distractors`` for both files. The timestamp has minute
    resolution, so saving same-sized datasets twice within one minute
    overwrites the earlier files.

    Returns:
        (train_path, test_path).
    """
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    num_distractors = train_dataset.num_distractors
    train_filename = f'train_{len(train_dataset)}samples_{num_distractors}dist_{timestamp}.pkl'
    test_filename = f'test_{len(test_dataset)}samples_{num_distractors}dist_{timestamp}.pkl'

    train_path = os.path.join(save_dir, train_filename)
    test_path = os.path.join(save_dir, test_filename)

    print(f"Saving datasets to {save_dir}")
    print(f"Training dataset: {train_filename}")
    print(f"Test dataset: {test_filename}")

    print("Saving training dataset...")
    _dump_visual_memory_dataset(train_dataset, train_path, 'training')

    print("Saving test dataset...")
    _dump_visual_memory_dataset(test_dataset, test_path, 'testing')

    train_mb = os.path.getsize(train_path) / (1024 * 1024)
    test_mb = os.path.getsize(test_path) / (1024 * 1024)

    print(f"✅ Datasets saved successfully!")
    print(f"Training dataset: {train_mb:.1f} MB")
    print(f"Test dataset: {test_mb:.1f} MB")
    print(f"Total size: {train_mb + test_mb:.1f} MB")

    return train_path, test_path

def load_datasets_simple(train_path, test_path):
    """Load a train/test pair written by save_datasets_simple.

    Both paths are checked for existence before either file is opened.

    Args:
        train_path: Path of the training ``.pkl`` file.
        test_path: Path of the test ``.pkl`` file.

    Returns:
        (train_dataset, test_dataset) as SimpleVisualMemoryDataset objects.

    Raises:
        FileNotFoundError: If either file is missing.

    Security: unpickling can execute arbitrary code — only load files you trust.
    """
    print("Loading datasets...")
    print(f"Training: {train_path}")
    print(f"Test: {test_path}")

    if not os.path.exists(train_path):
        raise FileNotFoundError(f"Training dataset not found: {train_path}")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found: {test_path}")

    def _read_pickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    print("Loading training dataset...")
    train_payload = _read_pickle(train_path)

    print("Loading test dataset...")
    test_payload = _read_pickle(test_path)

    train_dataset, test_dataset = (
        SimpleVisualMemoryDataset(payload) for payload in (train_payload, test_payload)
    )

    print("✅ Datasets loaded successfully!")
    print(f"Training: {len(train_dataset)} samples")
    print(f"Test: {len(test_dataset)} samples")
    print(f"Distractors: {train_dataset.num_distractors}")
    print(f"Creation time: {train_payload.get('creation_time', 'Unknown')}")

    return train_dataset, test_dataset

class SimpleVisualMemoryDataset:
    """Map-style dataset rebuilt from a dict written by save_datasets_simple.

    Offers the same access methods as the generated dataset class
    (``__len__``, ``__getitem__``, ``get_sample_with_metadata`` and
    ``get_statistics``), so it can be wrapped directly in a DataLoader.
    """
    def __init__(self, data_dict):
        self.samples = data_dict['samples']
        self.labels = data_dict['labels']
        self.metadata = data_dict['metadata']
        self.num_samples = data_dict['num_samples']
        self.num_distractors = data_dict['num_distractors']
        self.noise_types = data_dict['noise_types']
        self.match_probability = data_dict['match_probability']
        # Optional provenance fields; None when absent from the file.
        self.creation_time = data_dict.get('creation_time')
        self.dataset_type = data_dict.get('dataset_type')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        """Return ``(sequence, label, metadata_dict)`` for sample ``idx``."""
        return self.samples[idx], self.labels[idx], self.metadata[idx]

    def get_statistics(self):
        """Return label counts and digit distributions (match_percentage is 0.0 when empty)."""
        total = len(self.labels)
        match_count = sum(self.labels)
        return {
            'total_samples': total,
            'match_trials': match_count,
            'nomatch_trials': total - match_count,
            'match_percentage': match_count / total * 100 if total else 0.0,
            'target_digit_distribution': dict(Counter(m['target_digit'] for m in self.metadata)),
            'probe_digit_distribution': dict(Counter(m['probe_digit'] for m in self.metadata)),
            'sequence_length': self.num_distractors + 2
        }

# =============================================================================
# OPTION 2: USING DATASET'S BUILT-IN SAVE METHOD
# =============================================================================

def save_datasets_builtin(train_dataset, test_dataset, save_dir='./my_datasets/'):
    """Save via the datasets' own ``save_dataset(path)`` method when available.

    Only ``train_dataset`` is checked for ``save_dataset``; if it has one, the
    method is called on both datasets with paths of the form
    ``<split>_<N>_<D>dist_<YYYYmmdd_HHMM>.pkl``. Otherwise the call is
    delegated to save_datasets_simple and its result returned.

    Returns:
        (train_path, test_path).
    """
    os.makedirs(save_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    dist = train_dataset.num_distractors
    train_path = os.path.join(save_dir, f'train_{len(train_dataset)}_{dist}dist_{stamp}.pkl')
    test_path = os.path.join(save_dir, f'test_{len(test_dataset)}_{dist}dist_{stamp}.pkl')

    if not hasattr(train_dataset, 'save_dataset'):
        print("No built-in save method, using simple save...")
        return save_datasets_simple(train_dataset, test_dataset, save_dir)

    print("Using built-in save method...")
    train_dataset.save_dataset(train_path)
    test_dataset.save_dataset(test_path)
    return train_path, test_path

# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

def quick_save(train_dataset, test_dataset, name="my_experiment"):
    """Pickle both dataset objects as-is under ``./saved_datasets/<name>/``.

    Files are ``train_<YYYYmmdd_HHMM>.pkl`` and ``test_<YYYYmmdd_HHMM>.pkl``
    (default pickle protocol).

    NOTE(review): pickling a whole generated dataset also pickles whatever it
    references (e.g. its ``mnist_dataset`` attribute), so files can be large.

    Returns:
        (train_path, test_path).
    """
    save_dir = f'./saved_datasets/{name}/'
    os.makedirs(save_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    written = []
    for prefix, obj in (('train', train_dataset), ('test', test_dataset)):
        destination = os.path.join(save_dir, f'{prefix}_{stamp}.pkl')
        with open(destination, 'wb') as handle:
            pickle.dump(obj, handle)
        written.append(destination)

    print(f"✅ Quick save completed!")
    print(f"Saved to: {save_dir}")
    print(f"Training: train_{stamp}.pkl")
    print(f"Test: test_{stamp}.pkl")

    train_path, test_path = written
    return train_path, test_path

def quick_load(train_path, test_path):
    """Unpickle the two objects written by quick_save and return ``(train, test)``.

    Security: pickle.load can execute arbitrary code — only load trusted files.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    train_dataset = _unpickle(train_path)
    test_dataset = _unpickle(test_path)

    print(f"✅ Quick load completed!")
    print(f"Training: {len(train_dataset)} samples")
    print(f"Test: {len(test_dataset)} samples")

    return train_dataset, test_dataset

def list_saved_datasets(save_dir='./my_datasets/'):
    """Print and return details of every ``.pkl`` file directly inside ``save_dir``.

    Returns:
        List of dicts sorted by filename, each with ``'filename'``, ``'path'``,
        ``'size_mb'`` and ``'modified'`` (``'YYYY-mm-dd HH:MM:SS'``, local
        time). Empty when the directory is missing or holds no ``.pkl`` files.
    """
    if not os.path.exists(save_dir):
        print(f"Directory {save_dir} doesn't exist")
        return []

    pickle_names = sorted(name for name in os.listdir(save_dir) if name.endswith('.pkl'))
    if not pickle_names:
        print(f"No .pkl files found in {save_dir}")
        return []

    print(f"Saved datasets in {save_dir}:")
    print("=" * 50)

    entries = []
    for name in pickle_names:
        full_path = os.path.join(save_dir, name)
        size_mb = os.path.getsize(full_path) / (1024 * 1024)
        modified = datetime.fromtimestamp(os.path.getmtime(full_path)).strftime('%Y-%m-%d %H:%M:%S')

        print(f"📁 {name}")
        print(f"   Size: {size_mb:.1f} MB")
        print(f"   Modified: {modified}")

        entries.append({
            'filename': name,
            'path': full_path,
            'size_mb': size_mb,
            'modified': modified
        })
        print()

    return entries

# =============================================================================
# CREATE DATALOADERS FROM LOADED DATASETS
# =============================================================================

def create_dataloaders_from_datasets(train_dataset, test_dataset, batch_size=32):
    """Wrap loaded datasets in DataLoaders.

    The training loader shuffles and drops the final partial batch, so every
    training batch has exactly ``batch_size`` items. A training set smaller
    than ``batch_size`` therefore yields no batches. The test loader keeps the
    original order and keeps the partial batch.

    Returns:
        (train_loader, test_loader).
    """
    from torch.utils.data import DataLoader

    train_loader, test_loader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=is_train, drop_last=is_train)
        for dataset, is_train in ((train_dataset, True), (test_dataset, False))
    )

    print(f"✅ DataLoaders created!")
    print(f"Training: {len(train_loader)} batches of size {batch_size}")
    print(f"Test: {len(test_loader)} batches of size {batch_size}")

    return train_loader, test_loader

# =============================================================================
# COMPLETE WORKFLOW EXAMPLES
# =============================================================================

def example_save_workflow():
    """Print a short how-to for saving datasets (prints text only; saves nothing)."""
    banner = "=" * 60
    print(
        banner,
        "EXAMPLE: SAVE WORKFLOW",
        banner,
        "# Step 1: Save your datasets",
        "train_path, test_path = save_datasets_simple(train_dataset, test_dataset, './my_datasets/')",
        "",
        "# OR quick save:",
        "train_path, test_path = quick_save(train_dataset, test_dataset, 'experiment_1')",
        sep="\n",
    )

def example_load_workflow():
    """Print a short how-to for loading saved datasets (prints text only; loads nothing)."""
    banner = "=" * 60
    print(
        "\n" + banner,
        "EXAMPLE: LOAD WORKFLOW",
        banner,
        "# Step 1: List available datasets",
        "list_saved_datasets('./my_datasets/')",
        "",
        "# Step 2: Load specific datasets",
        "train_dataset, test_dataset = load_datasets_simple(",
        "    './my_datasets/train_1000samples_3dist_20241223_1430.pkl',",
        "    './my_datasets/test_300samples_3dist_20241223_1430.pkl'",
        ")",
        "",
        "# Step 3: Create DataLoaders",
        "train_loader, test_loader = create_dataloaders_from_datasets(",
        "    train_dataset, test_dataset, batch_size=32",
        ")",
        "",
        "# Step 4: Start training!",
        "# Now you can use train_loader and test_loader for training",
        sep="\n",
    )

# =============================================================================
# READY-TO-USE EXAMPLE
# =============================================================================

if __name__ == "__main__":
    # Print a usage guide for the save/load helpers in this cell. Inside a
    # notebook kernel __name__ is normally "__main__", so this presumably runs
    # every time the cell is executed.
    print("DATASET SAVE/LOAD UTILITY", "=" * 40, sep="\n")

    example_save_workflow()
    example_load_workflow()

    print(
        "\n" + "=" * 60,
        "READY-TO-USE CODE FOR YOUR SITUATION",
        "=" * 60,
        "# You currently have:",
        "# train_dataset, test_dataset, train_loader, test_loader = create_datasets_for_training(...)",
        "",
        "# To save them:",
        "train_path, test_path = save_datasets_simple(train_dataset, test_dataset)",
        "",
        "# To load them later:",
        "train_dataset, test_dataset = load_datasets_simple(train_path, test_path)",
        "train_loader, test_loader = create_dataloaders_from_datasets(train_dataset, test_dataset)",
        "",
        "# To see what you have saved:",
        "list_saved_datasets()",
        sep="\n",
    )

## Save the generated datasets

In [None]:
train_path, test_path = save_datasets_simple(
    train_dataset, test_dataset, save_dir='./my_datasets/'
)

## Load saved datasets (not needed if the data was generated above)

In [None]:
# Save and Load Existing Visual Memory Datasets
import torch
import pickle
import os
from datetime import datetime


def _dataset_payload(dataset, dataset_type, creation_time):
    """Collect the attributes SimpleVisualMemoryDataset reads into one picklable dict."""
    return {
        'samples': dataset.samples,
        'labels': dataset.labels,
        'metadata': dataset.metadata,
        'num_samples': dataset.num_samples,
        'num_distractors': dataset.num_distractors,
        'noise_types': dataset.noise_types,
        'match_probability': dataset.match_probability,
        'creation_time': creation_time,
        'dataset_type': dataset_type,
    }


def save_datasets_simple(train_dataset, test_dataset, save_dir='./my_datasets/'):
    """Pickle a train/test dataset pair as plain dicts under ``save_dir``.

    Files are named ``train_<N>samples_<D>dist_<YYYYmmdd_HHMM>.pkl`` and
    ``test_<N>samples_<D>dist_<YYYYmmdd_HHMM>.pkl``. D comes from
    ``train_dataset.num_distractors`` for both names. Saving again within the
    same minute with the same sizes overwrites the earlier files.

    Args:
        train_dataset, test_dataset: objects exposing ``samples``, ``labels``,
            ``metadata``, ``num_samples``, ``num_distractors``, ``noise_types``,
            ``match_probability`` and ``__len__``.
        save_dir: output directory (created if missing).

    Returns:
        (train_path, test_path) of the written files.
    """
    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Create descriptive filenames
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    train_size = len(train_dataset)
    test_size = len(test_dataset)
    num_distractors = train_dataset.num_distractors

    train_filename = f'train_{train_size}samples_{num_distractors}dist_{timestamp}.pkl'
    test_filename = f'test_{test_size}samples_{num_distractors}dist_{timestamp}.pkl'

    train_path = os.path.join(save_dir, train_filename)
    test_path = os.path.join(save_dir, test_filename)

    print(f"Saving datasets to {save_dir}")
    print(f"Training dataset: {train_filename}")
    print(f"Test dataset: {test_filename}")

    # One creation time for the pair, so both files record the same moment.
    creation_time = datetime.now().isoformat()
    splits = (
        ("training", train_dataset, train_path, 'training'),
        ("test", test_dataset, test_path, 'testing'),
    )
    for label, dataset, path, dataset_type in splits:
        print(f"Saving {label} dataset...")
        with open(path, 'wb') as f:
            pickle.dump(_dataset_payload(dataset, dataset_type, creation_time),
                        f, protocol=pickle.HIGHEST_PROTOCOL)

    # Print file sizes
    train_size_mb = os.path.getsize(train_path) / (1024 * 1024)
    test_size_mb = os.path.getsize(test_path) / (1024 * 1024)

    print("✅ Datasets saved successfully!")
    print(f"Training dataset: {train_size_mb:.1f} MB")
    print(f"Test dataset: {test_size_mb:.1f} MB")
    print(f"Total size: {train_size_mb + test_size_mb:.1f} MB")

    return train_path, test_path

def load_datasets_simple(train_path, test_path):
    """Load a train/test pair written by ``save_datasets_simple``.

    Both paths are checked before anything is read, so a missing file raises
    FileNotFoundError without partially loading. Each file is unpickled into a
    dict and wrapped in ``SimpleVisualMemoryDataset``.

    Note: ``pickle.load`` can execute arbitrary code; only load files you trust.

    Returns:
        (train_dataset, test_dataset)
    """
    print("Loading datasets...")
    print(f"Training: {train_path}")
    print(f"Test: {test_path}")

    roles_and_paths = (("Training", train_path), ("Test", test_path))

    # Validate both files up front
    for role, path in roles_and_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{role} dataset not found: {path}")

    loaded = []
    for role, path in roles_and_paths:
        print(f"Loading {role.lower()} dataset...")
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    train_data, test_data = loaded

    train_dataset = SimpleVisualMemoryDataset(train_data)
    test_dataset = SimpleVisualMemoryDataset(test_data)

    print("✅ Datasets loaded successfully!")
    print(f"Training: {len(train_dataset)} samples")
    print(f"Test: {len(test_dataset)} samples")
    print(f"Distractors: {train_dataset.num_distractors}")
    print(f"Creation time: {train_data.get('creation_time', 'Unknown')}")

    return train_dataset, test_dataset

class SimpleVisualMemoryDataset:
    """Map-style dataset rebuilt from a dict saved by ``save_datasets_simple``.

    ``dataset[i]`` yields ``(sample, label)`` so a DataLoader can batch it
    directly; ``get_sample_with_metadata(i)`` also returns that sample's metadata.
    A missing key in ``data_dict`` raises KeyError.
    """

    # Keys copied verbatim from the saved dict onto the instance, in this order.
    _FIELDS = ('samples', 'labels', 'metadata', 'num_samples',
               'num_distractors', 'noise_types', 'match_probability')

    def __init__(self, data_dict):
        for field in self._FIELDS:
            setattr(self, field, data_dict[field])

    def __len__(self):
        # Length follows the stored samples, not the saved num_samples value.
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        sample, label = self[idx]
        return sample, label, self.metadata[idx]

# =============================================================================
# OPTION 2: USING DATASET'S BUILT-IN SAVE METHOD
# =============================================================================

def save_datasets_builtin(train_dataset, test_dataset, save_dir='./my_datasets/'):
    """Save both datasets with their own ``save_dataset(path)`` method when available.

    The built-in method is used only if BOTH datasets provide it; otherwise this
    falls back to ``save_datasets_simple``, which chooses its own filenames.
    Filenames follow the same ``<split>_<N>samples_<D>dist_<timestamp>.pkl``
    pattern as ``save_datasets_simple``.

    Returns:
        (train_path, test_path)
    """
    os.makedirs(save_dir, exist_ok=True)

    # Generate filenames (same naming scheme as save_datasets_simple)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    train_size = len(train_dataset)
    test_size = len(test_dataset)
    num_distractors = train_dataset.num_distractors

    train_path = os.path.join(save_dir, f'train_{train_size}samples_{num_distractors}dist_{timestamp}.pkl')
    test_path = os.path.join(save_dir, f'test_{test_size}samples_{num_distractors}dist_{timestamp}.pkl')

    # Require the method on both objects; checking only train_dataset would
    # raise AttributeError when the test set is of a different type.
    if hasattr(train_dataset, 'save_dataset') and hasattr(test_dataset, 'save_dataset'):
        print("Using built-in save method...")
        train_dataset.save_dataset(train_path)
        test_dataset.save_dataset(test_path)
    else:
        print("No built-in save method, using simple save...")
        return save_datasets_simple(train_dataset, test_dataset, save_dir)

    return train_path, test_path

# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

def quick_save(train_dataset, test_dataset, name="my_experiment"):
    """Pickle both dataset objects whole into ``./saved_datasets/<name>/``.

    Files are named ``train_<YYYYmmdd_HHMM>.pkl`` / ``test_<YYYYmmdd_HHMM>.pkl``,
    so two saves under the same name within one minute overwrite each other.

    Returns:
        (train_path, test_path)
    """
    target_dir = f'./saved_datasets/{name}/'
    os.makedirs(target_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    paths = {}
    for prefix, obj in (('train', train_dataset), ('test', test_dataset)):
        paths[prefix] = os.path.join(target_dir, f'{prefix}_{stamp}.pkl')
        with open(paths[prefix], 'wb') as handle:
            pickle.dump(obj, handle)

    print("✅ Quick save completed!")
    print(f"Saved to: {target_dir}")
    print(f"Training: train_{stamp}.pkl")
    print(f"Test: test_{stamp}.pkl")

    return paths['train'], paths['test']

def quick_load(train_path, test_path):
    """Unpickle two dataset objects (e.g. written by ``quick_save``) and report sizes.

    Note: ``pickle.load`` can execute arbitrary code; only load files you trust.

    Returns:
        (train_dataset, test_dataset)
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    train_dataset = _unpickle(train_path)
    test_dataset = _unpickle(test_path)

    print("✅ Quick load completed!")
    print(f"Training: {len(train_dataset)} samples")
    print(f"Test: {len(test_dataset)} samples")

    return train_dataset, test_dataset

def list_saved_datasets(save_dir='./my_datasets/'):
    """Print and return information about every ``.pkl`` file in ``save_dir``.

    Returns:
        list of dicts sorted by filename, with keys 'filename', 'path',
        'size_mb' and 'modified' (formatted '%Y-%m-%d %H:%M:%S'). The list is
        empty when the directory is missing or holds no .pkl files.
    """
    if not os.path.exists(save_dir):
        print(f"Directory {save_dir} doesn't exist")
        return []

    pickle_names = sorted(name for name in os.listdir(save_dir) if name.endswith('.pkl'))
    if not pickle_names:
        print(f"No .pkl files found in {save_dir}")
        return []

    print(f"Saved datasets in {save_dir}:")
    print("=" * 50)

    # Named `entries` to avoid shadowing torchvision's `datasets` import.
    entries = []
    for name in pickle_names:
        full_path = os.path.join(save_dir, name)
        size_mb = os.path.getsize(full_path) / (1024 * 1024)
        modified = datetime.fromtimestamp(os.path.getmtime(full_path)).strftime('%Y-%m-%d %H:%M:%S')

        print(f"📁 {name}")
        print(f"   Size: {size_mb:.1f} MB")
        print(f"   Modified: {modified}")
        print()

        entries.append({
            'filename': name,
            'path': full_path,
            'size_mb': size_mb,
            'modified': modified,
        })

    return entries

# =============================================================================
# CREATE DATALOADERS FROM LOADED DATASETS
# =============================================================================

def create_dataloaders_from_datasets(train_dataset, test_dataset, batch_size=32):
    """Wrap loaded datasets in DataLoaders for training and evaluation.

    The training loader shuffles and drops the last incomplete batch so batch
    sizes stay consistent. The exception is a training set smaller than
    ``batch_size``: dropping there would leave zero batches and an epoch would
    train on nothing. The test loader keeps every sample, in order.

    Returns:
        (train_loader, test_loader)
    """
    from torch.utils.data import DataLoader

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        # For consistent batch sizes, unless that would drop every sample
        drop_last=len(train_dataset) >= batch_size
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False
    )

    print("✅ DataLoaders created!")
    print(f"Training: {len(train_loader)} batches of size {batch_size}")
    print(f"Test: {len(test_loader)} batches of size {batch_size}")

    return train_loader, test_loader

# =============================================================================
# COMPLETE WORKFLOW EXAMPLES
# =============================================================================

def example_save_workflow():
    """Print a copy-pasteable snippet showing how to save existing datasets.

    Only prints text; it does not touch any datasets.
    """
    rule = "=" * 60
    print("\n".join([
        rule,
        "EXAMPLE: SAVE WORKFLOW",
        rule,
        "# Step 1: Save your datasets",
        "train_path, test_path = save_datasets_simple(train_dataset, test_dataset, './my_datasets/')",
        "",
        "# OR quick save:",
        "train_path, test_path = quick_save(train_dataset, test_dataset, 'experiment_1')",
    ]))

def example_load_workflow():
    """Print a copy-pasteable snippet showing how to list, load and batch saved datasets.

    Only prints text; the paths shown are illustrative.
    """
    rule = "=" * 60
    print("\n".join([
        "",
        rule,
        "EXAMPLE: LOAD WORKFLOW",
        rule,
        "# Step 1: List available datasets",
        "list_saved_datasets('./my_datasets/')",
        "",
        "# Step 2: Load specific datasets",
        "train_dataset, test_dataset = load_datasets_simple(",
        "    './my_datasets/train_1000samples_3dist_20241223_1430.pkl',",
        "    './my_datasets/test_300samples_3dist_20241223_1430.pkl'",
        ")",
        "",
        "# Step 3: Create DataLoaders",
        "train_loader, test_loader = create_dataloaders_from_datasets(",
        "    train_dataset, test_dataset, batch_size=32",
        ")",
        "",
        "# Step 4: Start training!",
        "# Now you can use train_loader and test_loader for training",
    ]))

# =============================================================================
# READY-TO-USE EXAMPLE
# =============================================================================

if __name__ == "__main__":
    print("DATASET SAVE/LOAD UTILITY\n" + "=" * 40)

    # Walk through both example workflows (save first, then load)
    example_save_workflow()
    example_load_workflow()

    # Emit the ready-to-use cheat sheet as one block of text
    print("\n".join([
        "",
        "=" * 60,
        "READY-TO-USE CODE FOR YOUR SITUATION",
        "=" * 60,
        "# You currently have:",
        "# train_dataset, test_dataset, train_loader, test_loader = create_datasets_for_training(...)",
        "",
        "# To save them:",
        "train_path, test_path = save_datasets_simple(train_dataset, test_dataset)",
        "",
        "# To load them later:",
        "train_dataset, test_dataset = load_datasets_simple(train_path, test_path)",
        "train_loader, test_loader = create_dataloaders_from_datasets(train_dataset, test_dataset)",
        "",
        "# To see what you have saved:",
        "list_saved_datasets()",
    ]))

## list and load saved datasets

In [None]:
import os
# Show which saved dataset files exist (relative to the current working directory)
folder_name = 'my_datasets'
file_list = os.listdir(path=folder_name)
print(file_list)

In [None]:
# Specific saved files to reuse (pick names from the listing above)
train_path = './my_datasets/train_5000samples_3dist_20250724_2358.pkl'
test_path = './my_datasets/test_1000samples_3dist_20250724_2358.pkl'

train_dataset, test_dataset = load_datasets_simple(train_path=train_path, test_path=test_path)
train_loader, test_loader = create_dataloaders_from_datasets(train_dataset=train_dataset,
                                                             test_dataset=test_dataset)

## visualization functions

In [None]:
def plot_training_history(train_losses, train_accs, test_losses, test_accs,
                         train_match_accs, train_nomatch_accs,
                         test_match_accs, test_nomatch_accs):
    """Draw a 2x2 dashboard of per-epoch train vs. test curves.

    Panels: loss (top-left), overall accuracy (top-right), accuracy on match
    trials (bottom-left) and on no-match trials (bottom-right). Every argument
    is a per-epoch sequence; the x-axis runs 1..len(train_losses). Train curves
    use circle markers and test curves square markers.
    """
    epochs = range(1, len(train_losses) + 1)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Visual Memory Task Training History', fontsize=16)

    # (row, col, title, y-label, (train values, fmt, label), (test values, fmt, label))
    panels = [
        (0, 0, 'Loss Over Time', 'Loss',
         (train_losses, 'b-', 'Train Loss'), (test_losses, 'r-', 'Test Loss')),
        (0, 1, 'Overall Accuracy', 'Accuracy (%)',
         (train_accs, 'b-', 'Train Accuracy'), (test_accs, 'r-', 'Test Accuracy')),
        (1, 0, 'Match Trial Accuracy', 'Accuracy (%)',
         (train_match_accs, 'g-', 'Train Match'), (test_match_accs, 'darkgreen', 'Test Match')),
        (1, 1, 'No-Match Trial Accuracy', 'Accuracy (%)',
         (train_nomatch_accs, 'orange', 'Train No-Match'), (test_nomatch_accs, 'red', 'Test No-Match')),
    ]

    for row, col, title, ylabel, train_curve, test_curve in panels:
        ax = axes[row, col]
        for (values, fmt, label), marker in ((train_curve, 'o'), (test_curve, 's')):
            ax.plot(epochs, values, fmt, label=label, marker=marker)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(predictions, targets, epoch):
    """Show a 2x2 confusion-matrix heatmap and print a classification report.

    Labels are encoded 0 = No Match, 1 = Match. Both classes are always passed
    explicitly, so the matrix stays 2x2 and lines up with the two tick names even
    when a class is absent from both ``targets`` and ``predictions``. Without
    this, ``classification_report`` raises ValueError on the 2-item ``target_names``.

    Args:
        predictions: predicted labels (e.g. the list returned by test_memory_model).
        targets: true labels, same length and encoding.
        epoch: epoch number shown in the titles.
    """
    # Local imports: the global ones live in a later notebook cell.
    import seaborn as sns
    from sklearn.metrics import classification_report, confusion_matrix

    class_labels = [0, 1]
    class_names = ['No Match', 'Match']
    cm = confusion_matrix(targets, predictions, labels=class_labels)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title(f'Confusion Matrix - Epoch {epoch}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

    # Print classification report
    print(f"\nClassification Report - Epoch {epoch}:")
    print(classification_report(targets, predictions,
                                labels=class_labels,
                                target_names=class_names))

def visualize_model_predictions(model, test_dataset, device, num_samples=6):
    """Plot random test sequences annotated with the model's match/no-match prediction.

    Each row is one test sample; up to 6 frames are shown. Frame 0 is titled
    TARGET. The last frame is titled PROBE, coloured green if the prediction is
    correct and red otherwise. The other frames are titled NOISE. If the model
    provides ``get_attention_weights``, its per-frame weights appear in the
    titles and set the border width; otherwise uniform weights are used. A
    summary of each sample is also printed.

    Args:
        model: network mapping a [1, seq_len, ...] batch to [1, 2] logits.
        test_dataset: object with ``__len__`` and ``get_sample_with_metadata(idx)``
            returning ``(sequence, label, metadata)``; metadata needs
            'target_digit' and 'probe_digit'.
        device: device the model lives on.
        num_samples: number of rows; clamped to ``len(test_dataset)``.
    """
    model.eval()

    # np.random.choice(..., replace=False) cannot draw more items than exist.
    num_samples = min(num_samples, len(test_dataset))
    if num_samples == 0:
        print("No test samples to visualize.")
        return

    fig, axes = plt.subplots(num_samples, 6, figsize=(15, num_samples * 2.5))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    fig.suptitle('Model Predictions on Test Samples', fontsize=16)

    indices = np.random.choice(len(test_dataset), num_samples, replace=False)

    with torch.no_grad():
        for i, idx in enumerate(indices):
            sequence, true_label, metadata = test_dataset.get_sample_with_metadata(idx)
            seq_len = sequence.shape[0]

            # Get model prediction on a batch of one
            sequence_batch = sequence.unsqueeze(0).to(device)
            output = model(sequence_batch)
            probabilities = F.softmax(output, dim=1)
            predicted_label = output.argmax(1).item()
            confidence = probabilities[0][predicted_label].item()

            # Attention-like weights when the model offers them (e.g.
            # VanillaRNNVisualMemoryModel has no such method); uniform otherwise.
            try:
                attention_weights = model.get_attention_weights(sequence_batch)[0].cpu().numpy()
            except Exception:
                attention_weights = np.ones(seq_len) / seq_len

            target_digit = metadata['target_digit']
            probe_digit = metadata['probe_digit']

            print(f"\nSample {i+1}:")
            print(f"  Target: {target_digit}, Probe: {probe_digit}")
            print(f"  True: {'MATCH' if true_label == 1 else 'NO MATCH'}")
            print(f"  Predicted: {'MATCH' if predicted_label == 1 else 'NO MATCH'} (conf: {confidence:.3f})")
            print(f"  Attention weights: {attention_weights}")

            # Display sequence
            for j in range(min(seq_len, 6)):
                ax = axes[i, j]
                # detach/cpu so frames stored on a GPU can still be converted to numpy
                img = sequence[j].detach().cpu().squeeze().numpy()
                ax.imshow(img, cmap='gray')
                ax.axis('off')

                # Border thickness grows with the frame's weight
                alpha = attention_weights[j] if j < len(attention_weights) else 0.1
                border_width = int(alpha * 5) + 1
                # Border spans the whole frame (27x27 for 28x28 digits)
                box_w, box_h = img.shape[1] - 1, img.shape[0] - 1

                if j == 0:
                    ax.set_title(f'TARGET\n(att: {alpha:.2f})', color='blue', fontweight='bold')
                    ax.add_patch(plt.Rectangle((0, 0), box_w, box_h, fill=False,
                                               edgecolor='blue', linewidth=border_width))
                elif j == seq_len - 1:
                    color = 'green' if predicted_label == true_label else 'red'
                    pred_text = 'MATCH' if predicted_label == 1 else 'NO MATCH'
                    ax.set_title(f'PROBE\n{pred_text}\n(att: {alpha:.2f})',
                                 color=color, fontweight='bold')
                    ax.add_patch(plt.Rectangle((0, 0), box_w, box_h, fill=False,
                                               edgecolor=color, linewidth=border_width))
                else:
                    ax.set_title(f'NOISE {j}\n(att: {alpha:.2f})', color='gray')

            # Hide unused subplots
            for j in range(seq_len, 6):
                axes[i, j].axis('off')

    plt.tight_layout()
    plt.show()


## main functions (models, training and evaluation)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Set device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

class SimpleCNN(nn.Module):
    """Three-conv-layer CNN classifier whose fc1 activations double as frame features.

    Expects single-channel 28x28 input: conv1 takes 1 channel, and fc1 is sized
    for 128 x 7 x 7 after two 2x2 max-pools (28 -> 14 -> 7). Layer attribute
    names define the state_dict keys used by pretrained checkpoints.

    Args:
        num_classes: number of output logits (default 10).
    """
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits [batch, num_classes]; reuses extract_features so both paths stay in sync."""
        features = self.extract_features(x)
        return self.fc2(self.dropout2(features))

    def extract_features(self, x):
        """Return post-ReLU fc1 activations [batch, 512], i.e. the layer before dropout2/fc2."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # No pooling after conv3: the spatial size stays 7x7
        x = self.dropout1(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        return F.relu(self.fc1(x))

class CNNFeatureExtractor(nn.Module):
    """Frozen wrapper around a pretrained SimpleCNN that returns its fc1 features.

    All CNN parameters have ``requires_grad=False``, ``forward`` runs under
    ``torch.no_grad()``, and the inner CNN stays in eval mode even when a parent
    module calls ``.train()``.

    Args:
        pretrained_cnn_path: path to a SimpleCNN ``state_dict`` file.
        feature_dim: feature width reported to callers for sizing downstream
            layers. It is not checked here and should equal SimpleCNN.fc1's
            output size (512).
    """
    def __init__(self, pretrained_cnn_path, feature_dim=512):
        super(CNNFeatureExtractor, self).__init__()

        self.cnn = SimpleCNN()
        # map_location remaps checkpoint tensors (e.g. saved on GPU) onto `device`
        self.cnn.load_state_dict(torch.load(pretrained_cnn_path, map_location=device))

        # Freeze the CNN: only downstream layers are trained
        for param in self.cnn.parameters():
            param.requires_grad = False

        self.cnn.eval()
        self.feature_dim = feature_dim

    def train(self, mode=True):
        """Set the training mode as usual, but keep the frozen CNN in eval mode.

        A parent ``model.train()`` recurses into this module. Without this
        override it would switch the CNN back to train mode and re-enable its
        dropout while extracting features.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, x):
        """Return fc1 features of SimpleCNN for a batch of images, without tracking gradients."""
        with torch.no_grad():
            features = self.cnn.extract_features(x)
        return features


class VisualMemoryModel(nn.Module):
    """Frozen-CNN + LSTM/GRU classifier for the match/no-match memory task.

    Each frame goes through the frozen CNNFeatureExtractor, then a trainable
    projection MLP. The sequence is run through a unidirectional LSTM or GRU,
    and the output at the last timestep is classified into 2 logits
    (0 = no match, 1 = match).

    Args:
        pretrained_cnn_path: path to the pretrained SimpleCNN state_dict.
        rnn_hidden_dim: hidden size of the recurrent layers.
        projection_dim: size of the projected per-frame features fed to the RNN.
        rnn_type: 'LSTM' or 'GRU'.
        num_layers: number of stacked recurrent layers.
        dropout: dropout probability (halved for the last dropout in each MLP).

    Raises:
        ValueError: if ``rnn_type`` is not 'LSTM' or 'GRU'.
    """
    def __init__(self, pretrained_cnn_path, rnn_hidden_dim=256,
                 projection_dim=128, rnn_type='LSTM', num_layers=2, dropout=0.3):
        super(VisualMemoryModel, self).__init__()

        # Fail fast: any other value would leave self.rnn undefined until forward().
        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_type not in rnn_classes:
            raise ValueError(f"rnn_type must be 'LSTM' or 'GRU', got {rnn_type!r}")

        self.cnn_features = CNNFeatureExtractor(pretrained_cnn_path)
        cnn_feature_dim = self.cnn_features.feature_dim

        self.feature_projection = nn.Sequential(
            nn.Linear(cnn_feature_dim, projection_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(projection_dim * 2, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        self.rnn_type = rnn_type
        self.rnn = rnn_classes[rnn_type](
            projection_dim, rnn_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # recurrent dropout only applies between stacked layers
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False
        )

        self.memory_classifier = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 2)
        )

        self.rnn_hidden_dim = rnn_hidden_dim
        self.num_layers = num_layers
        self.projection_dim = projection_dim

    def _encode_sequence(self, x):
        """Return RNN outputs [batch, seq_len, rnn_hidden_dim] for x [batch, seq_len, ...frame dims]."""
        batch_size, seq_len = x.size(0), x.size(1)
        # Fold time into batch so every frame passes through the CNN at once;
        # reshape (unlike view) also accepts non-contiguous input.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        projected = self.feature_projection(self.cnn_features(frames))
        projected = projected.view(batch_size, seq_len, -1)
        # nn.LSTM and nn.GRU both return (output, state); the state is unused here.
        rnn_output, _ = self.rnn(projected)
        return rnn_output

    def forward(self, x):
        """Return [batch, 2] match/no-match logits from the last timestep's RNN output."""
        rnn_output = self._encode_sequence(x)
        final_output = rnn_output[:, -1, :]
        return self.memory_classifier(final_output)

    def get_attention_weights(self, x):
        """Return attention-like weights [batch, seq_len] showing what the model focuses on.

        This is not a learned attention layer. Each timestep's RNN output is
        dot-multiplied with the final timestep's output, and the scores are
        softmaxed over time. Gradients are tracked unless the caller uses
        torch.no_grad().
        """
        rnn_output = self._encode_sequence(x)
        final_state = rnn_output[:, -1:, :]  # [batch_size, 1, hidden_dim]
        attention_scores = torch.bmm(rnn_output, final_state.transpose(1, 2))  # [batch_size, seq_len, 1]
        return F.softmax(attention_scores.squeeze(-1), dim=1)  # [batch_size, seq_len]

def train_memory_model(model, train_loader, criterion, optimizer, device, epoch):
    """Train the model for one epoch and report loss and accuracies.

    ``model.train()`` is followed by ``model.cnn_features.eval()`` so the frozen
    CNN's dropout stays disabled. Gradients are clipped to a total norm of 1.0
    before each optimizer step.

    Args:
        model: network with a ``cnn_features`` submodule, mapping sequences to [batch, 2] logits.
        train_loader: iterable of ``(sequences, targets)`` batches; targets are
            0 (no match) or 1 (match).
        criterion: loss applied to ``(outputs, targets)``.
        optimizer: optimizer over the trainable parameters.
        device: device each batch is moved to.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        (epoch_loss, epoch_acc, match_acc, nomatch_acc): mean batch loss, and
        accuracies in percent over all, match and no-match trials. All are 0.0
        if the loader yields no batches (instead of raising ZeroDivisionError).
    """
    model.train()
    model.cnn_features.eval()

    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    match_correct = 0
    match_total = 0
    nomatch_correct = 0
    nomatch_total = 0

    train_bar = tqdm(train_loader, desc=f'Epoch {epoch} Training')

    for sequences, targets in train_bar:
        sequences, targets = sequences.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(sequences)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        num_batches += 1
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        hits = predicted.eq(targets)
        total += targets.size(0)
        correct += hits.sum().item()

        # Separate accuracy for match (1) vs no-match (0); empty masks add 0
        match_mask = targets == 1
        nomatch_mask = targets == 0
        match_correct += hits[match_mask].sum().item()
        match_total += match_mask.sum().item()
        nomatch_correct += hits[nomatch_mask].sum().item()
        nomatch_total += nomatch_mask.sum().item()

        train_bar.set_postfix({
            'Loss': f'{running_loss/num_batches:.4f}',
            'Acc': f'{100.*correct/total:.2f}%',
            'M': f'{100.*match_correct/max(match_total,1):.1f}%',
            'NM': f'{100.*nomatch_correct/max(nomatch_total,1):.1f}%'
        })

    epoch_loss = running_loss / max(num_batches, 1)
    epoch_acc = 100. * correct / max(total, 1)
    match_acc = 100. * match_correct / max(match_total, 1)
    nomatch_acc = 100. * nomatch_correct / max(nomatch_total, 1)

    return epoch_loss, epoch_acc, match_acc, nomatch_acc

def test_memory_model(model, test_loader, criterion, device, epoch):
    """Evaluate the model on ``test_loader`` in eval mode without gradient tracking.

    Args:
        model: network mapping sequences to [batch, 2] logits.
        test_loader: iterable of ``(sequences, targets)`` batches; targets are
            0 (no match) or 1 (match).
        criterion: loss applied to ``(outputs, targets)``.
        device: device each batch is moved to.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        (epoch_loss, epoch_acc, match_acc, nomatch_acc, all_predictions, all_targets):
        mean batch loss; accuracies in percent over all, match and no-match
        trials; and flat lists of predicted and true labels (numpy scalars),
        e.g. for plot_confusion_matrix. Loss and accuracies are 0.0 when the
        loader yields no batches.
    """
    model.eval()
    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    match_correct = 0
    match_total = 0
    nomatch_correct = 0
    nomatch_total = 0

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        test_bar = tqdm(test_loader, desc=f'Epoch {epoch} Testing')

        for sequences, targets in test_bar:
            sequences, targets = sequences.to(device), targets.to(device)

            outputs = model(sequences)
            loss = criterion(outputs, targets)

            num_batches += 1
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            hits = predicted.eq(targets)
            total += targets.size(0)
            correct += hits.sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            # Per-class counts; empty masks add 0
            match_mask = targets == 1
            nomatch_mask = targets == 0
            match_correct += hits[match_mask].sum().item()
            match_total += match_mask.sum().item()
            nomatch_correct += hits[nomatch_mask].sum().item()
            nomatch_total += nomatch_mask.sum().item()

            test_bar.set_postfix({
                # Running mean over the batches processed so far
                'Loss': f'{test_loss/num_batches:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    epoch_loss = test_loss / max(num_batches, 1)
    epoch_acc = 100. * correct / max(total, 1)
    match_acc = 100. * match_correct / max(match_total, 1)
    nomatch_acc = 100. * nomatch_correct / max(nomatch_total, 1)

    return epoch_loss, epoch_acc, match_acc, nomatch_acc, all_predictions, all_targets


In [None]:
class VanillaRNNVisualMemoryModel(nn.Module):
    """Frozen-CNN + stacked vanilla (tanh) RNN classifier for the match/no-match task.

    Same pipeline as VisualMemoryModel, except the recurrent core is a stack of
    single-layer ``nn.RNN`` modules. In ``forward``, dropout is applied between
    those layers. The output at the last timestep is classified into 2 logits
    (0 = no match, 1 = match).

    Args:
        pretrained_cnn_path: path to the pretrained SimpleCNN state_dict.
        rnn_hidden_dim: hidden size of every RNN layer.
        projection_dim: size of the projected per-frame features fed to the first layer.
        num_layers: number of stacked RNN layers (at least one is always built).
        dropout: dropout probability between RNN layers and in the MLPs.
    """
    def __init__(self, pretrained_cnn_path, rnn_hidden_dim=256,
                 projection_dim=128, num_layers=2, dropout=0.3):
        super(VanillaRNNVisualMemoryModel, self).__init__()

        # Frozen CNN feature extractor (same as VisualMemoryModel)
        self.cnn_features = CNNFeatureExtractor(pretrained_cnn_path)
        cnn_feature_dim = self.cnn_features.feature_dim  # 512

        # Trainable feature projection
        self.feature_projection = nn.Sequential(
            nn.Linear(cnn_feature_dim, projection_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(projection_dim * 2, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        # Vanilla RNN core (the key difference from the LSTM/GRU model)
        self.rnn_type = 'Vanilla'
        self.num_layers = num_layers
        self.rnn_hidden_dim = rnn_hidden_dim

        # First layer reads projected features; later layers read the previous layer's outputs
        self.rnn_layers = nn.ModuleList()
        self.rnn_layers.append(
            nn.RNN(projection_dim, rnn_hidden_dim, num_layers=1,
                   batch_first=True, nonlinearity='tanh')
        )
        for _ in range(num_layers - 1):
            self.rnn_layers.append(
                nn.RNN(rnn_hidden_dim, rnn_hidden_dim, num_layers=1,
                       batch_first=True, nonlinearity='tanh')
            )

        # Dropout between layers
        self.rnn_dropout = nn.Dropout(dropout)

        # Memory comparison and classification head
        self.memory_classifier = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 2)  # Binary: match (1) vs no-match (0)
        )

        self.projection_dim = projection_dim

    def _encode_frames(self, x):
        """Map frames [batch, seq_len, C, H, W] to projected features [batch, seq_len, projection_dim]."""
        batch_size, seq_len = x.size(0), x.size(1)
        # Fold time into batch so all frames go through the CNN at once;
        # reshape (unlike view) also accepts non-contiguous input.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        cnn_features = self.cnn_features(frames)           # [batch * seq_len, 512], frozen
        projected = self.feature_projection(cnn_features)  # [batch * seq_len, projection_dim]
        return projected.view(batch_size, seq_len, -1)

    def _run_rnn_stack(self, rnn_input, apply_dropout):
        """Run the RNN layers in order; optionally apply dropout between (not after) layers."""
        last_index = len(self.rnn_layers) - 1
        for i, rnn_layer in enumerate(self.rnn_layers):
            rnn_output, _ = rnn_layer(rnn_input)
            if apply_dropout and i < last_index:
                rnn_output = self.rnn_dropout(rnn_output)
            rnn_input = rnn_output
        return rnn_output

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, channels, height, width]
        Returns:
            output: [batch_size, 2] logits for match/no-match classification
        """
        projected_features = self._encode_frames(x)
        rnn_output = self._run_rnn_stack(projected_features, apply_dropout=True)
        final_output = rnn_output[:, -1, :]  # [batch_size, rnn_hidden_dim]
        return self.memory_classifier(final_output)  # [batch_size, 2]

    def get_hidden_states(self, x):
        """Return (last-layer RNN outputs at every timestep, projected input features) for analysis.

        Unlike ``forward``, no dropout is applied between RNN layers here.
        """
        projected_features = self._encode_frames(x)
        rnn_output = self._run_rnn_stack(projected_features, apply_dropout=False)
        return rnn_output, projected_features

In [None]:
def _build_visual_memory_model(rnn_type, pretrained_cnn_path, hidden_dim=256,
                               projection_dim=128, num_layers=2, dropout=0.3):
    """Instantiate the visual-memory model variant selected by ``rnn_type``.

    The mapping is:
      - ``'Vanilla'`` builds ``VanillaRNNVisualMemoryModel``.
      - ``'RIM'`` builds ``RIMVisualMemoryModel`` with 6 units, 3 active per
        timestep and LSTM cells.
      - Any other value is passed through as ``rnn_type`` to
        ``VisualMemoryModel`` (e.g. ``'GRU'``).

    The model is returned on whatever device its constructor places it. The
    caller is responsible for moving it.
    """
    if rnn_type == 'Vanilla':
        return VanillaRNNVisualMemoryModel(
            pretrained_cnn_path=pretrained_cnn_path,
            rnn_hidden_dim=hidden_dim,
            projection_dim=projection_dim,
            num_layers=num_layers,
            dropout=dropout,
        )
    if rnn_type == 'RIM':
        return RIMVisualMemoryModel(
            pretrained_cnn_path=pretrained_cnn_path,
            rim_hidden_dim=hidden_dim,   # Hidden dimension per unit
            projection_dim=projection_dim,
            num_units=6,                 # Total number of mechanisms
            k_active_units=3,            # Active mechanisms per timestep
            num_layers=num_layers,
            unit_type='LSTM',            # 'LSTM', 'GRU', or 'Vanilla'
            dropout=dropout,
        )
    return VisualMemoryModel(
        pretrained_cnn_path=pretrained_cnn_path,
        rnn_hidden_dim=hidden_dim,
        projection_dim=projection_dim,
        rnn_type=rnn_type,
        num_layers=num_layers,
        dropout=dropout,
    )


def train_visual_memory_model(train_loader, test_loader, test_dataset,
                              pretrained_cnn_path, rnn_type='GRU', num_epochs=15,
                              learning_rate=0.001, save_path='visual_memory_model_GRU.pth',
                              hidden_dim=256, projection_dim=128, num_layers=2,
                              dropout=0.3):
    """Build, train and checkpoint a visual-memory (match / no-match) model.

    Each epoch runs ``train_memory_model`` on ``train_loader`` and then
    ``test_memory_model`` on ``test_loader``. The learning rate is halved
    when test accuracy plateaus (``ReduceLROnPlateau``, patience=3). The
    weights with the highest test accuracy are saved to ``save_path`` and
    loaded back into the model before it is returned.

    Args:
        train_loader: loader consumed by ``train_memory_model``.
        test_loader: loader consumed by ``test_memory_model``.
        test_dataset: dataset passed to ``visualize_model_predictions``.
        pretrained_cnn_path: checkpoint path forwarded to the model constructor.
        rnn_type: ``'Vanilla'``, ``'RIM'``, or any ``rnn_type`` accepted by
            ``VisualMemoryModel`` (e.g. ``'GRU'``).
        num_epochs: number of epochs; must be >= 1.
        learning_rate: initial Adam learning rate (weight decay 1e-4).
        save_path: where the best state dict is written. Missing parent
            directories are created.
        hidden_dim: recurrent hidden size (``rim_hidden_dim`` for RIM).
        projection_dim: size of the trainable projection of CNN features.
        num_layers: number of stacked recurrent layers.
        dropout: dropout probability forwarded to the model.

    Returns:
        ``(model, history)``. ``model`` holds the best-epoch weights.
        ``history`` holds per-epoch loss/accuracy lists and ``best_test_acc``.
        Accuracies are presumably percentages, since they are printed with
        ``%``.

    Raises:
        ValueError: if ``num_epochs < 1``.
    """
    import copy
    import os
    import time

    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    print("="*60)
    print("VISUAL MEMORY TASK TRAINING")
    print("="*60)

    # Every variant (including RIM, which previously stayed put) goes to `device`.
    model = _build_visual_memory_model(
        rnn_type, pretrained_cnn_path,
        hidden_dim=hidden_dim, projection_dim=projection_dim,
        num_layers=num_layers, dropout=dropout,
    ).to(device)

    # Create the checkpoint folder now, so a bad path fails before training
    # rather than after the whole run.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    # Only parameters with requires_grad=True are optimised. Presumably the
    # frozen CNN backbone is excluded this way.
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate,
        weight_decay=1e-4
    )

    # mode='max' because the monitored metric is test accuracy. `verbose` is
    # deprecated in recent PyTorch, so LR drops are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    train_match_accs = []
    train_nomatch_accs = []
    test_match_accs = []
    test_nomatch_accs = []

    best_test_acc = 0.0
    best_model_state = None

    print(f"\nStarting training for {num_epochs} epochs...")
    start_time = time.time()

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 40)

        train_loss, train_acc, train_match_acc, train_nomatch_acc = train_memory_model(
            model, train_loader, criterion, optimizer, device, epoch
        )

        test_loss, test_acc, test_match_acc, test_nomatch_acc, predictions, targets = test_memory_model(
            model, test_loader, criterion, device, epoch
        )

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(test_acc)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        train_match_accs.append(train_match_acc)
        train_nomatch_accs.append(train_nomatch_acc)
        test_match_accs.append(test_match_acc)
        test_nomatch_accs.append(test_nomatch_acc)

        # state_dict() tensors share storage with the live parameters, and the
        # optimizer updates them in place. A shallow dict .copy() would
        # therefore track the latest weights, so deep-copy the snapshot. The
        # `is None` check guarantees a snapshot even if accuracy never beats 0.
        if best_model_state is None or test_acc > best_test_acc:
            best_test_acc = test_acc
            best_model_state = copy.deepcopy(model.state_dict())

        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%, "
              f"Match: {train_match_acc:.2f}%, No-Match: {train_nomatch_acc:.2f}%")
        print(f"Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.2f}%, "
              f"Match: {test_match_acc:.2f}%, No-Match: {test_nomatch_acc:.2f}%")

        if epoch % 5 == 0:
            plot_confusion_matrix(predictions, targets, epoch)

    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time:.1f} seconds")

    torch.save(best_model_state, save_path)
    print(f"Best model saved to {save_path} (Test Acc: {best_test_acc:.2f}%)")

    plot_training_history(train_losses, train_accs, test_losses, test_accs,
                          train_match_accs, train_nomatch_accs,
                          test_match_accs, test_nomatch_accs)

    # Restore the best epoch's weights before visualising and returning.
    model.load_state_dict(best_model_state)
    visualize_model_predictions(model, test_dataset, device, num_samples=1)

    print("\n" + "="*60)
    print("FINAL RESULTS")
    print("="*60)
    print(f"Best Test Accuracy: {best_test_acc:.2f}%")
    print(f"Final Test Match Accuracy: {test_match_accs[-1]:.2f}%")
    print(f"Final Test No-Match Accuracy: {test_nomatch_accs[-1]:.2f}%")
    print(f"Training Time: {training_time:.1f} seconds")

    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'best_test_acc': best_test_acc,
        'train_match_accs': train_match_accs,
        'train_nomatch_accs': train_nomatch_accs,
        'test_match_accs': test_match_accs,
        'test_nomatch_accs': test_nomatch_accs,
    }


## Train

In [None]:
# Train the GRU variant for 20 epochs; the best-epoch checkpoint goes to save_path.
model, history = train_visual_memory_model(
    train_loader=train_loader,
    test_loader=test_loader,
    test_dataset=test_dataset,
    pretrained_cnn_path='small_cnn_model.pth',
    rnn_type='GRU',
    num_epochs=20,
    learning_rate=0.001,
    save_path='RNN_model/visual_memory_GRU_noise3.pth',
)