In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/MyDrive/Neuromatch_project/

In [None]:
!pip install torch torchvision matplotlib seaborn tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import random
import pickle
import os

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"📱 Using device: {device}")

# Seed PyTorch, NumPy and Python's `random` with the same value for reproducibility
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)

# Matplotlib defaults used by every figure in the notebook
plt.style.use('default')
plt.rcParams.update({'figure.figsize': (12, 8)})

print("✅ Setup complete!")

In [None]:
class SimpleCNN(nn.Module):
    """Three-convolution CNN classifier over single-channel images.

    The 3x3 convolutions use padding 1 (spatial size preserved) and only the
    first two are followed by 2x2 max pooling, so ``fc1``'s expected input of
    ``128 * 7 * 7`` values corresponds to 28x28 input images.

    Args:
        num_classes: Width of the final linear layer (number of output logits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names and creation order are kept fixed: they determine the
        # state_dict keys and the order in which initial weights are drawn.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``[batch, num_classes]``."""
        hidden = self.extract_features(x)
        return self.fc2(self.dropout2(hidden))

    def extract_features(self, x):
        """Return the ReLU-activated ``fc1`` output, shape ``[batch, 512]``.

        This is the forward pass up to, but excluding, ``dropout2`` and ``fc2``.
        """
        # (convolution, whether max pooling follows it)
        stages = ((self.conv1, True), (self.conv2, True), (self.conv3, False))
        for conv, followed_by_pool in stages:
            x = F.relu(conv(x))
            if followed_by_pool:
                x = self.pool(x)
        x = self.dropout1(x)
        flattened = x.view(x.size(0), -1)
        return F.relu(self.fc1(flattened))

# CNN Feature Extractor (frozen)
class CNNFeatureExtractor(nn.Module):
    """Frozen ``SimpleCNN`` exposing its ``extract_features`` output.

    The wrapped CNN's parameters have ``requires_grad = False`` and the CNN is
    kept in eval mode even when a parent module calls ``.train()``, so its
    dropout layer never perturbs the features.

    Args:
        pretrained_cnn_path: Path to a ``SimpleCNN`` state dict readable by
            ``torch.load``.
        feature_dim: Stored as ``self.feature_dim`` for downstream layers to
            size themselves; it is not checked against the CNN's actual
            output width.
    """

    def __init__(self, pretrained_cnn_path, feature_dim=512):
        super().__init__()
        self.cnn = SimpleCNN()
        # Load onto the CPU: load_state_dict copies the values into this
        # module's own parameters, so no notebook-global device is needed here.
        # A later `.to(...)` on the parent moves the weights.
        self.cnn.load_state_dict(torch.load(pretrained_cnn_path, map_location='cpu'))
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.cnn.eval()
        self.feature_dim = feature_dim

    def train(self, mode=True):
        """Set the training flag on this wrapper but keep the frozen CNN in eval mode.

        Without this override, ``parent.train()`` would recurse into
        ``self.cnn`` and re-enable its dropout.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, x):
        """Return ``self.cnn.extract_features(x)`` computed without gradient tracking."""
        with torch.no_grad():
            return self.cnn.extract_features(x)

# Visual Memory Model
class VisualMemoryModel(nn.Module):
    """Frozen CNN features -> trainable projection -> LSTM/GRU -> 2-way classifier.

    Input is a batch of image sequences ``[batch, seq_len, *image_shape]``.
    Every frame is encoded by the frozen CNN, projected to ``projection_dim``,
    and the sequence is fed to the recurrent network. The recurrent output at
    the last timestep is classified into two logits.

    Args:
        pretrained_cnn_path: State-dict path handed to ``CNNFeatureExtractor``.
        rnn_hidden_dim: Hidden size of the recurrent network.
        projection_dim: Size of the projected per-frame features.
        rnn_type: ``'LSTM'`` or ``'GRU'``.
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability for the projection and classifier heads
            (halved after their second layer) and between recurrent layers
            when ``num_layers > 1``.

    Raises:
        ValueError: If ``rnn_type`` is neither ``'LSTM'`` nor ``'GRU'``.
    """

    def __init__(self, pretrained_cnn_path, rnn_hidden_dim=256,
                 projection_dim=128, rnn_type='LSTM', num_layers=2, dropout=0.3):
        super().__init__()

        # Fail fast, before loading any checkpoint. Previously an unknown type
        # left `self.rnn` undefined and only failed later inside forward().
        recurrent_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_type not in recurrent_classes:
            raise ValueError(
                f"rnn_type must be one of {sorted(recurrent_classes)}, got {rnn_type!r}"
            )

        self.cnn_features = CNNFeatureExtractor(pretrained_cnn_path)
        cnn_feature_dim = self.cnn_features.feature_dim

        self.feature_projection = nn.Sequential(
            nn.Linear(cnn_feature_dim, projection_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(projection_dim * 2, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        self.rnn_type = rnn_type
        self.rnn = recurrent_classes[rnn_type](
            projection_dim, rnn_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # PyTorch warns about recurrent dropout on a single-layer network.
            dropout=dropout if num_layers > 1 else 0
        )

        self.memory_classifier = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 2)
        )

        self.rnn_hidden_dim = rnn_hidden_dim
        self.num_layers = num_layers
        self.projection_dim = projection_dim

    def _encode_sequence(self, x):
        """Run CNN, projection and recurrent network; shared by the public methods.

        Returns:
            ``(rnn_output, projected_features)`` with shapes
            ``[batch, seq_len, rnn_hidden_dim]`` and
            ``[batch, seq_len, projection_dim]``.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        cnn_features = self.cnn_features(frames)
        projected_features = self.feature_projection(cnn_features)
        projected_features = projected_features.reshape(batch_size, seq_len, -1)

        # LSTM and GRU both return (output, state); only the output is used.
        rnn_output, _ = self.rnn(projected_features)
        return rnn_output, projected_features

    def forward(self, x):
        """Return ``[batch, 2]`` logits computed from the last timestep's recurrent output."""
        rnn_output, _ = self._encode_sequence(x)
        final_output = rnn_output[:, -1, :]
        return self.memory_classifier(final_output)

    def get_hidden_states(self, x):
        """Return ``(rnn_output, projected_features)`` for every timestep."""
        return self._encode_sequence(x)

# Reached only after the class definitions in this cell executed without error.
print("✅ Model architecture loaded!")

In [None]:
class VanillaRNNVisualMemoryModel(nn.Module):
    """Frozen CNN features -> trainable projection -> stacked tanh RNNs -> classifier.

    The recurrent stack is a ``ModuleList`` of single-layer ``nn.RNN`` modules
    with dropout applied manually between layers in ``forward``. Attribute
    names (and therefore state-dict keys such as ``rnn_layers.0.*``) are part
    of the checkpoint format and must not change.

    Args:
        pretrained_cnn_path: State-dict path handed to ``CNNFeatureExtractor``.
        rnn_hidden_dim: Hidden size of every RNN layer.
        projection_dim: Size of the projected per-frame features.
        num_layers: Number of stacked RNN layers; at least one layer is always created.
        dropout: Dropout probability for the projection and classifier heads
            (halved after their second layer) and between RNN layers.
    """

    def __init__(self, pretrained_cnn_path, rnn_hidden_dim=256,
                 projection_dim=128, num_layers=2, dropout=0.3):
        super().__init__()

        # Frozen CNN feature extractor
        self.cnn_features = CNNFeatureExtractor(pretrained_cnn_path)
        cnn_feature_dim = self.cnn_features.feature_dim

        # Trainable feature projection
        self.feature_projection = nn.Sequential(
            nn.Linear(cnn_feature_dim, projection_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(projection_dim * 2, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        self.rnn_type = 'Vanilla'
        self.num_layers = num_layers
        self.rnn_hidden_dim = rnn_hidden_dim

        # First layer consumes the projected features; every further layer
        # consumes the previous layer's hidden states.
        layer_input_dims = [projection_dim] + [rnn_hidden_dim] * (num_layers - 1)
        self.rnn_layers = nn.ModuleList([
            nn.RNN(input_dim, rnn_hidden_dim, num_layers=1,
                   batch_first=True, nonlinearity='tanh')
            for input_dim in layer_input_dims
        ])

        # Dropout between layers
        self.rnn_dropout = nn.Dropout(dropout)

        # Memory comparison and classification head
        self.memory_classifier = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 2)  # Binary: match (1) vs no-match (0)
        )

        self.projection_dim = projection_dim

    def _project_sequence(self, x):
        """Encode every frame with the CNN and projection.

        Returns:
            Tensor of shape ``[batch_size, seq_len, projection_dim]``.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        cnn_features = self.cnn_features(frames)
        projected_features = self.feature_projection(cnn_features)
        return projected_features.reshape(batch_size, seq_len, -1)

    def _run_rnn_stack(self, projected_features, inter_layer_dropout):
        """Feed the sequence through each RNN layer in turn.

        ``inter_layer_dropout`` applies ``self.rnn_dropout`` to the output of
        every layer except the last.
        """
        rnn_output = projected_features
        last_layer_index = len(self.rnn_layers) - 1
        for layer_index, rnn_layer in enumerate(self.rnn_layers):
            rnn_output, _ = rnn_layer(rnn_output)
            if inter_layer_dropout and layer_index < last_layer_index:
                rnn_output = self.rnn_dropout(rnn_output)
        return rnn_output

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, channels, height, width]
        Returns:
            output: [batch_size, 2] logits for match/no-match classification
        """
        projected_features = self._project_sequence(x)
        rnn_output = self._run_rnn_stack(projected_features, inter_layer_dropout=True)

        # Classify from the last timestep's hidden state.
        final_output = rnn_output[:, -1, :]  # [batch_size, rnn_hidden_dim]
        return self.memory_classifier(final_output)  # [batch_size, 2]

    def get_hidden_states(self, x):
        """Get hidden states at each timestep for analysis.

        Unlike ``forward``, no dropout is applied between RNN layers here; in
        eval mode the two paths give identical hidden states.

        Returns:
            ``(rnn_output, projected_features)`` with shapes
            ``[batch_size, seq_len, rnn_hidden_dim]`` and
            ``[batch_size, seq_len, projection_dim]``.
        """
        projected_features = self._project_sequence(x)
        rnn_output = self._run_rnn_stack(projected_features, inter_layer_dropout=False)
        return rnn_output, projected_features

In [None]:
def test_model_accuracy(model, test_loader, device):
        """Test overall model accuracy"""
        model.eval()
        correct = 0
        total = 0
        match_correct = 0
        match_total = 0
        nomatch_correct = 0
        nomatch_total = 0

        all_predictions = []
        all_targets = []
        all_confidences = []

        with torch.no_grad():
            for sequences, targets in tqdm(test_loader, desc="Testing"):
                sequences, targets = sequences.to(device), targets.to(device)
                outputs = model(sequences)
                probabilities = F.softmax(outputs, dim=1)
                _, predicted = outputs.max(1)

                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # Store predictions and confidences
                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_confidences.extend(probabilities.max(1)[0].cpu().numpy())

                # Separate accuracy for match vs no-match
                match_mask = (targets == 1)
                nomatch_mask = (targets == 0)

                if match_mask.sum() > 0:
                    match_correct += predicted[match_mask].eq(targets[match_mask]).sum().item()
                    match_total += match_mask.sum().item()

                if nomatch_mask.sum() > 0:
                    nomatch_correct += predicted[nomatch_mask].eq(targets[nomatch_mask]).sum().item()
                    nomatch_total += nomatch_mask.sum().item()

        overall_acc = 100. * correct / total
        match_acc = 100. * match_correct / max(match_total, 1)
        nomatch_acc = 100. * nomatch_correct / max(nomatch_total, 1)
        avg_confidence = np.mean(all_confidences)

        return {
            'overall_accuracy': overall_acc/100,
            'match_accuracy': match_acc,
            'nomatch_accuracy': nomatch_acc,
            'total_samples': total,
            'match_samples': match_total,
            'nomatch_samples': nomatch_total,
            'average_confidence': avg_confidence,
            'predictions': all_predictions,
            'targets': all_targets,
            'confidences': all_confidences
        }

In [None]:
import glob
import pickle
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

data_num = [1, 2, 3, 5, 7, 10]
PRETRAINED_CNN_PATH = 'small_cnn_model.pth'

# Dictionary to store results
results_dict = {}
accuracy_scores = []
noise_levels = []


class SimpleTestDataset:
    """Index-able wrapper around an unpickled test set.

    ``data_dict`` is either a dict with 'samples', 'labels' and 'metadata'
    keys (and optionally 'num_distractors'), or an object exposing attributes
    of the same names. ``default_num_distractors`` is used only when a dict
    lacks the 'num_distractors' key.
    """

    def __init__(self, data_dict, default_num_distractors=None):
        if isinstance(data_dict, dict):
            self.samples = data_dict['samples']
            self.labels = data_dict['labels']
            self.metadata = data_dict['metadata']
            self.num_distractors = data_dict.get('num_distractors', default_num_distractors)
        else:
            # If it's the dataset object itself
            self.samples = data_dict.samples
            self.labels = data_dict.labels
            self.metadata = data_dict.metadata
            self.num_distractors = data_dict.num_distractors

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        return self.samples[idx], self.labels[idx], self.metadata[idx]


def plot_accuracy_summary(data_num, results_dict, noise_levels, accuracy_scores, save_path=None):
    """Print per-noise accuracies and plot accuracy against the number of distractors.

    Not called automatically (this code was disabled in the original cell);
    call it after the loop below to display the summary. When ``save_path`` is
    given the figure is also written to that file.
    """
    print(f"\n{'='*60}")
    print("SUMMARY OF RESULTS")
    print(f"{'='*60}")
    for num in data_num:
        if num in results_dict:
            acc = results_dict[num]['overall_accuracy']
            print(f"Noise {num:2d}: {acc:.4f} ({acc*100:.2f}%)")
        else:
            print(f"Noise {num:2d}: No results (missing files)")

    if not (accuracy_scores and noise_levels):
        print("\n❌ No results to plot - no models were successfully tested")
        return

    plt.figure(figsize=(10, 6))
    plt.plot(noise_levels, accuracy_scores, 'bo-', linewidth=2, markersize=8)
    plt.xlabel('Number of Distractors (Noise Level)', fontsize=12)
    plt.ylabel('Overall Accuracy', fontsize=12)
    plt.title('Model Performance along with Number of Distractors', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.ylim(0.5, 1)
    plt.xticks(noise_levels)

    # Add value labels on points
    for x, y in zip(noise_levels, accuracy_scores):
        plt.annotate(f'{y:.3f}', (x, y), textcoords="offset points",
                     xytext=(0, 10), ha='center', fontsize=10)

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n📊 Plot saved as '{save_path}'")
    plt.show()


# Loop through each number in data_num
for num in data_num:
    print(f"\n{'='*50}")
    print(f"Testing with noise level: {num}")
    print(f"{'='*50}")

    # Define paths for current iteration
    VISUAL_MEMORY_MODEL_PATH = f'./RNN_model/visual_memory_Vanilla_noise{num}.pth'
    TEST_DATASET_PATH = f'./my_datasets/test_1000samples_*{num}dist_*.pkl'

    # Build the model and load its checkpoint; skip this noise level on failure.
    try:
        # To evaluate the LSTM/GRU variant instead, build
        # VisualMemoryModel(..., rnn_type='LSTM') with the same hyper-parameters
        # and point VISUAL_MEMORY_MODEL_PATH at the matching checkpoint.
        model = VanillaRNNVisualMemoryModel(
            pretrained_cnn_path=PRETRAINED_CNN_PATH,
            rnn_hidden_dim=256,
            projection_dim=128,
            num_layers=2,
            dropout=0.3
        ).to(device)
        model.load_state_dict(torch.load(VISUAL_MEMORY_MODEL_PATH, map_location=device))
        model.eval()
        print(f"✅ Model loaded: {VISUAL_MEMORY_MODEL_PATH}")

    except FileNotFoundError as e:
        # Either checkpoint (pretrained CNN or RNN model) may be the missing one.
        print(f"❌ Model not found: {e.filename or VISUAL_MEMORY_MODEL_PATH}")
        continue
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        continue

    # Find and load test dataset; sorting makes the "first match" deterministic.
    test_files = sorted(glob.glob(TEST_DATASET_PATH))
    if not test_files:
        print(f"❌ No test dataset found matching: {TEST_DATASET_PATH}")
        continue

    test_path = test_files[0]  # Use the first matching file
    print(f"📁 Loading test dataset from: {test_path}")

    try:
        with open(test_path, 'rb') as f:
            # NOTE: unpickling executes code embedded in the file; only load trusted datasets.
            test_data = pickle.load(f)

        test_dataset = SimpleTestDataset(test_data, default_num_distractors=num)
        print(f"✅ Test dataset loaded: {len(test_dataset)} samples")

        # Create data loader
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        # Test model and get results
        results = test_model_accuracy(model, test_loader, device)
        overall_accuracy = results['overall_accuracy']

        # Store results
        results_dict[num] = {
            'overall_accuracy': overall_accuracy,
            'full_results': results
        }
        accuracy_scores.append(overall_accuracy)
        noise_levels.append(num)

        print(f"🎯 Overall Accuracy for noise {num}: {overall_accuracy:.4f}")

    except Exception as e:
        print(f"❌ Error processing dataset: {e}")
        continue

### Get RDM correlation

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import dendrogram, linkage
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


def extract_hidden_states_dataset(model, dataset, device, num_samples=100):
    """Extract hidden states for multiple samples to analyze geometry.

    Args:
        model: Module providing ``get_hidden_states(batch)`` that returns
            ``(hidden_states, projected_features)`` with a leading batch axis.
        dataset: Index-able collection of ``(sequence, label)`` pairs. If it
            offers ``get_sample_with_metadata(i)``, the returned metadata
            supplies 'target_digit', 'probe_digit' and 'is_match'.
        device: Device each sequence is moved to before the forward pass.
        num_samples: Upper bound on how many leading samples are processed.

    Returns:
        ``(all_hidden_states, all_metadata)``: one dict per processed sample
        in each list. The first holds the per-timestep ``hidden_states`` and
        ``projected_features`` arrays plus the raw ``sequence`` and labels;
        the second holds only the sample index, digits and match flag.
    """
    model.eval()
    all_hidden_states = []
    all_metadata = []

    sample_count = min(num_samples, len(dataset))
    print(f"🔍 Extracting hidden states from {sample_count} samples...")

    with torch.no_grad():
        for i in tqdm(range(sample_count)):
            sequence, label = dataset[i]

            try:
                _, _, metadata = dataset.get_sample_with_metadata(i)
                target_digit = metadata['target_digit']
                probe_digit = metadata['probe_digit']
                is_match = metadata['is_match']
            except Exception:
                # Best-effort fallback when metadata is missing or malformed.
                # `Exception` (not a bare except) lets KeyboardInterrupt through.
                target_digit = -1
                probe_digit = -1
                is_match = (label == 1)

            # Get hidden states for this sequence (batch of one)
            sequence_batch = sequence.unsqueeze(0).to(device)
            hidden_states, projected_features = model.get_hidden_states(sequence_batch)

            # Drop the batch axis and move to NumPy
            hidden_np = hidden_states[0].cpu().numpy()  # [seq_len, hidden_dim]
            projected_np = projected_features[0].cpu().numpy()  # [seq_len, proj_dim]

            all_hidden_states.append({
                'sample_idx': i,
                'target_digit': target_digit,
                'probe_digit': probe_digit,
                'is_match': is_match,
                'true_label': label,
                'hidden_states': hidden_np,
                'projected_features': projected_np,
                # .cpu() so samples already stored on an accelerator also convert.
                'sequence': sequence.detach().cpu().numpy()
            })
            all_metadata.append({
                'sample_idx': i,
                'target_digit': target_digit,
                'probe_digit': probe_digit,
                'is_match': is_match
            })

    print(f"✅ Extracted hidden states for {len(all_hidden_states)} samples")
    return all_hidden_states, all_metadata

In [None]:


import glob
import pickle
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr
from scipy.stats import pearsonr
import warnings

data_num = [1, 2, 3, 5, 7, 10]
PRETRAINED_CNN_PATH = 'small_cnn_model.pth'

# Dictionary to store results
# NOTE: the 'overall_accuracy' key and the `accuracy_scores` list hold the
# Target/Probe RDM correlation in this cell; the names are kept because the
# summary cell reads them.
results_dict = {}
accuracy_scores = []
noise_levels = []


class SimpleTestDataset:
    """Index-able wrapper around an unpickled test set.

    ``data_dict`` is either a dict with 'samples', 'labels' and 'metadata'
    keys (and optionally 'num_distractors'), or an object exposing attributes
    of the same names. ``default_num_distractors`` is used only when a dict
    lacks the 'num_distractors' key.
    """

    def __init__(self, data_dict, default_num_distractors=None):
        if isinstance(data_dict, dict):
            self.samples = data_dict['samples']
            self.labels = data_dict['labels']
            self.metadata = data_dict['metadata']
            self.num_distractors = data_dict.get('num_distractors', default_num_distractors)
        else:
            # If it's the dataset object itself
            self.samples = data_dict.samples
            self.labels = data_dict.labels
            self.metadata = data_dict.metadata
            self.num_distractors = data_dict.num_distractors

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        return self.samples[idx], self.labels[idx], self.metadata[idx]


def group_states_by_target_digit(samples, timestep):
    """Map each target digit (0-9) present in ``samples`` to its stacked hidden states at ``timestep``."""
    grouped = {}
    for digit in range(10):
        states = [s['hidden_states'][timestep] for s in samples if s['target_digit'] == digit]
        if states:
            grouped[digit] = np.array(states)
    return grouped


def cosine_rdm(states_by_digit):
    """Return the cosine-dissimilarity matrix between per-digit mean states.

    Rows and columns follow ascending digit order. Returns ``None`` when
    fewer than two digits are available.
    """
    digits = sorted(states_by_digit)
    if len(digits) < 2:
        return None
    mean_states = np.array([np.mean(states_by_digit[d], axis=0) for d in digits])  # [n_digits, hidden_dim]
    # Cosine dissimilarity = 1 - cosine similarity
    return 1 - cosine_similarity(mean_states)


# Loop through each number in data_num
for num in data_num:
    print(f"\n{'='*50}")
    print(f"Testing with noise level: {num}")
    print(f"{'='*50}")

    # Define paths for current iteration
    VISUAL_MEMORY_MODEL_PATH = f'./RNN_model/visual_memory_Vanilla_noise{num}.pth'
    TEST_DATASET_PATH = f'./my_datasets/test_1000samples_*{num}dist_*.pkl'

    # Build the model and load its checkpoint; skip this noise level on failure.
    try:
        # To analyse the LSTM/GRU variant instead, build
        # VisualMemoryModel(..., rnn_type='LSTM') with the same hyper-parameters
        # and point VISUAL_MEMORY_MODEL_PATH at the matching checkpoint.
        model = VanillaRNNVisualMemoryModel(
            pretrained_cnn_path=PRETRAINED_CNN_PATH,
            rnn_hidden_dim=256,
            projection_dim=128,
            num_layers=2,
            dropout=0.3
        ).to(device)

        model.load_state_dict(torch.load(VISUAL_MEMORY_MODEL_PATH, map_location=device))
        model.eval()
        print(f"✅ Model loaded: {VISUAL_MEMORY_MODEL_PATH}")

    except FileNotFoundError as e:
        # Either checkpoint (pretrained CNN or RNN model) may be the missing one.
        print(f"❌ Model not found: {e.filename or VISUAL_MEMORY_MODEL_PATH}")
        continue
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        continue

    # Find and load test dataset; sorting makes the "first match" deterministic.
    test_files = sorted(glob.glob(TEST_DATASET_PATH))
    if not test_files:
        print(f"❌ No test dataset found matching: {TEST_DATASET_PATH}")
        continue

    test_path = test_files[0]  # Use the first matching file
    print(f"📁 Loading test dataset from: {test_path}")

    try:
        with open(test_path, 'rb') as f:
            # NOTE: unpickling executes code embedded in the file; only load trusted datasets.
            test_data = pickle.load(f)

        test_dataset = SimpleTestDataset(test_data, default_num_distractors=num)
        print(f"✅ Test dataset loaded: {len(test_dataset)} samples")

        # Not consumed below; retained so the name stays bound to the current dataset.
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
        hidden_states_data, metadata = extract_hidden_states_dataset(
            model, test_dataset, device, 1000
        )

        # Both timesteps are grouped by the *target* digit: 'Target' uses the
        # first timestep's hidden state, 'Probe' the last one.
        digit_states = {
            'Target': group_states_by_target_digit(hidden_states_data, 0),
            'Probe': group_states_by_target_digit(hidden_states_data, -1),
        }

        rdms = {}
        for timestep_name, states_by_digit in digit_states.items():
            rdm = cosine_rdm(states_by_digit)
            if rdm is not None:
                rdms[timestep_name] = rdm

        if 'Target' not in rdms or 'Probe' not in rdms:
            # Previously this surfaced as an opaque KeyError (or reused a stale `corr`).
            print(f"❌ Error processing dataset: fewer than two target digits available for noise {num}")
            continue

        # Correlate the off-diagonal upper triangles of the two (symmetric) RDMs.
        rdm1_upper = rdms['Target'][np.triu_indices_from(rdms['Target'], k=1)]
        rdm2_upper = rdms['Probe'][np.triu_indices_from(rdms['Probe'], k=1)]
        corr, _ = spearmanr(rdm1_upper, rdm2_upper)

        print(corr)
        results_dict[num] = {
            'overall_accuracy': corr,
            # Store what this cell actually computed instead of a leftover
            # `results` variable from the accuracy cell.
            'full_results': {
                'spearman_r': corr,
                'rdms': rdms,
                'num_samples': len(hidden_states_data),
            }
        }
        accuracy_scores.append(corr)
        noise_levels.append(num)

    except Exception as e:
        print(f"❌ Error processing dataset: {e}")
        continue

In [None]:
# Summary of the Target/Probe RDM correlations gathered above. The values live
# under the 'overall_accuracy' key / `accuracy_scores` list for historical
# reasons, but they are correlation coefficients, so they are not shown as percentages.
print(f"\n{'='*60}")
print("SUMMARY OF RESULTS")
print(f"{'='*60}")
for num in data_num:
    if num in results_dict:
        acc = results_dict[num]['overall_accuracy']
        print(f"Noise {num:2d}: {acc:.4f}")
    else:
        print(f"Noise {num:2d}: No results (missing files)")

# Plot results
if accuracy_scores and noise_levels:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(noise_levels, accuracy_scores, 'bo-', linewidth=2, markersize=8)
    ax.set_xlabel('Number of Distractors (Noise Level)', fontsize=12)
    ax.set_ylabel('Correlation Value', fontsize=12)
    ax.set_title('Correlation between Target RDM and Probe RDM', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Keep the default 0.5 floor, but extend it downward when a correlation is
    # lower so that no point falls outside the axes (NaNs are ignored).
    finite_scores = [score for score in accuracy_scores if np.isfinite(score)]
    y_floor = min([0.5] + [score - 0.05 for score in finite_scores])
    ax.set_ylim(y_floor, 1)
    ax.set_xticks(noise_levels)

    # Add value labels on points
    for x, y in zip(noise_levels, accuracy_scores):
        ax.annotate(f'{y:.3f}', (x, y), textcoords="offset points",
                    xytext=(0, 10), ha='center', fontsize=10)

    fig.tight_layout()
    #plt.savefig('model_accuracy_vs_noise.png', dpi=300, bbox_inches='tight')
    plt.show()