In [1]:
# Mount Google Drive into the Colab runtime so the project directory used by
# the next cell (and the checkpoint/dataset paths below) is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
cd drive/MyDrive/Neuromatch_project/

/content/drive/MyDrive/Neuromatch_project


In [3]:
!pip install torch torchvision matplotlib seaborn tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import random
import pickle
import os

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"📱 Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)

print("✅ Setup complete!")

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [4]:
class SimpleCNN(nn.Module):
    """Small three-convolution image classifier.

    The layer sizes fix the expected input to single-channel 28x28 images:
    every 3x3 conv uses padding=1 (spatial size preserved) and the two 2x2
    max-pools take 28 -> 14 -> 7, which is why ``fc1`` consumes 128 * 7 * 7
    flattened features.

    Args:
        num_classes: number of logits produced by ``fc2``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional trunk.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Dense head. Attribute names must stay stable so saved state_dicts load.
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        penultimate = self.extract_features(x)
        return self.fc2(self.dropout2(penultimate))

    def extract_features(self, x):
        """Return the post-ReLU ``fc1`` activations, shape (batch, 512)."""
        h = x
        # conv -> ReLU -> 2x2 pool, twice
        for conv in (self.conv1, self.conv2):
            h = self.pool(F.relu(conv(h)))
        # third conv is not pooled; dropout is applied before flattening
        h = self.dropout1(F.relu(self.conv3(h)))
        h = h.view(h.size(0), -1)
        return F.relu(self.fc1(h))

# CNN Feature Extractor (frozen)
class CNNFeatureExtractor(nn.Module):
    """Frozen wrapper around a pretrained ``SimpleCNN`` that returns fc1 features.

    The wrapped CNN's parameters have ``requires_grad=False``, its forward pass
    runs under ``torch.no_grad()``, and it is kept in eval mode even when the
    enclosing model is switched to training mode (see ``train``).

    Args:
        pretrained_cnn_path: path of a ``SimpleCNN`` state_dict. It is read
            with ``torch.load`` (pickle-based) — only load files you trust.
        feature_dim: width reported via ``self.feature_dim`` for downstream
            layers. It is not checked against the CNN; ``SimpleCNN.fc1``
            outputs 512 features, matching the default.
    """

    def __init__(self, pretrained_cnn_path, feature_dim=512):
        super(CNNFeatureExtractor, self).__init__()
        self.cnn = SimpleCNN()
        self.cnn.load_state_dict(torch.load(pretrained_cnn_path, map_location=device))
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.cnn.eval()
        self.feature_dim = feature_dim

    def train(self, mode=True):
        """Set train/eval mode on this wrapper but keep the frozen CNN in eval.

        ``nn.Module.train`` recurses into child modules, so without this
        override a parent's ``.train()`` would re-enable the CNN's dropout and
        make the supposedly frozen features stochastic during training.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    def forward(self, x):
        """Return ``SimpleCNN.extract_features(x)`` without tracking gradients."""
        with torch.no_grad():
            features = self.cnn.extract_features(x)
        return features

# Visual Memory Model
class VisualMemoryModel(nn.Module):
    """Frozen CNN encoder + recurrent core + MLP head for 2-way sequence classification.

    Each frame of the input sequence is encoded by ``CNNFeatureExtractor``,
    projected to ``projection_dim`` by a two-layer MLP, and the projected
    sequence is run through a batch-first LSTM or GRU. The RNN output at the
    last time step is mapped to 2 logits by ``memory_classifier``.

    Args:
        pretrained_cnn_path: path of the ``SimpleCNN`` state_dict to freeze.
        rnn_hidden_dim: hidden size of the recurrent layers.
        projection_dim: width of each projected frame feature.
        rnn_type: ``'LSTM'`` or ``'GRU'``.
        num_layers: number of stacked recurrent layers.
        dropout: dropout rate used in the MLPs (halved for their second
            dropout) and, when ``num_layers > 1``, between recurrent layers.

    Raises:
        ValueError: if ``rnn_type`` is not ``'LSTM'`` or ``'GRU'``.
    """

    def __init__(self, pretrained_cnn_path, rnn_hidden_dim=256,
                 projection_dim=128, rnn_type='LSTM', num_layers=2, dropout=0.3):
        super(VisualMemoryModel, self).__init__()

        # Fail fast, before any checkpoint is read. An unsupported type would
        # otherwise leave ``self.rnn`` unset and only break later in forward().
        if rnn_type not in ('LSTM', 'GRU'):
            raise ValueError(f"rnn_type must be 'LSTM' or 'GRU', got {rnn_type!r}")

        self.cnn_features = CNNFeatureExtractor(pretrained_cnn_path)
        cnn_feature_dim = self.cnn_features.feature_dim

        self.feature_projection = nn.Sequential(
            nn.Linear(cnn_feature_dim, projection_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(projection_dim * 2, projection_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        self.rnn_type = rnn_type
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            projection_dim, rnn_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout only applies with >1 layer (PyTorch warns otherwise).
            dropout=dropout if num_layers > 1 else 0
        )

        self.memory_classifier = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(64, 2)
        )

        self.rnn_hidden_dim = rnn_hidden_dim
        self.num_layers = num_layers
        self.projection_dim = projection_dim

    def _project_sequence(self, x):
        """Encode and project every frame; returns (batch, seq_len, projection_dim).

        ``x`` is indexed as (batch, seq_len, *frame_shape); each frame must be
        a valid input for the frozen CNN.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # Fold time into the batch dimension so the CNN sees one frame per row.
        # reshape (unlike view) also accepts non-contiguous inputs.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        cnn_features = self.cnn_features(frames)
        projected_features = self.feature_projection(cnn_features)
        return projected_features.view(batch_size, seq_len, -1)

    def forward(self, x):
        """Return (batch, 2) logits computed from the last RNN time step."""
        projected_features = self._project_sequence(x)
        # LSTM returns (output, (h, c)) and GRU (output, h); only output is used.
        rnn_output, _ = self.rnn(projected_features)
        final_output = rnn_output[:, -1, :]
        return self.memory_classifier(final_output)

    def get_hidden_states(self, x):
        """Return ``(rnn_output, projected_features)`` for analysis.

        ``rnn_output`` is (batch, seq_len, rnn_hidden_dim) — the top RNN
        layer's output at every step; ``projected_features`` is
        (batch, seq_len, projection_dim).
        """
        projected_features = self._project_sequence(x)
        rnn_output, _ = self.rnn(projected_features)
        return rnn_output, projected_features

print("✅ Model architecture loaded!")

✅ Model architecture loaded!


In [6]:
# Paths to the frozen CNN weights and the trained visual-memory checkpoint.
PRETRAINED_CNN_PATH = 'small_cnn_model.pth'
VISUAL_MEMORY_MODEL_PATH = './RNN_model/visual_memory_GRU_noise2.pth' # change this

try:
    # These hyperparameters must match the ones the checkpoint was trained
    # with, otherwise load_state_dict() fails on missing/mis-shaped keys.
    model = VisualMemoryModel(
        pretrained_cnn_path=PRETRAINED_CNN_PATH,
        rnn_hidden_dim=256,
        projection_dim=128,
        rnn_type='GRU',
        num_layers=2,
        dropout=0.3
    ).to(device)

    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(VISUAL_MEMORY_MODEL_PATH, map_location=device))
    model.eval()

except FileNotFoundError as e:
    print(f"❌ Error loading model: {e}")
    print("📝 Make sure these files exist:")
    print(f"   - {PRETRAINED_CNN_PATH}")
    print(f"   - {VISUAL_MEMORY_MODEL_PATH}")
    print("\n💡 Tip: Update the file paths in the cell above to match your files")
    # Re-raise: every later cell needs `model`. Continuing would fail with a
    # confusing NameError, or silently reuse a stale `model` from an earlier run.
    raise

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✅ Model loaded: {total_params:,} parameters ({trainable_params:,} trainable)")

In [7]:

TEST_DATASET_PATH = './my_datasets/test_1000samples_2dist_20250725_0016.pkl'  # Update this!

import glob


class SimpleTestDataset:
    """Map-style dataset over a pickled test set.

    Accepts either a dict with 'samples', 'labels' and 'metadata' keys (plus an
    optional 'num_distractors'), or an object exposing the same attributes.
    ``dataset[idx]`` returns ``(sample, label)`` so it can feed a DataLoader.
    """

    def __init__(self, data_dict):
        if isinstance(data_dict, dict):
            self.samples = data_dict['samples']
            self.labels = data_dict['labels']
            self.metadata = data_dict['metadata']
            # NOTE(review): falls back to 3 when the key is missing, but the
            # dataset filename says "2dist" — confirm the intended default.
            self.num_distractors = data_dict.get('num_distractors', 3)
        else:
            # If it's the dataset object itself
            self.samples = data_dict.samples
            self.labels = data_dict.labels
            self.metadata = data_dict.metadata
            self.num_distractors = data_dict.num_distractors

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

    def get_sample_with_metadata(self, idx):
        """Return ``(sample, label, metadata)`` for one trial."""
        return self.samples[idx], self.labels[idx], self.metadata[idx]


# glob allows TEST_DATASET_PATH to be a wildcard pattern; glob's own order is
# arbitrary, so sort to make "first match" deterministic.
test_files = sorted(glob.glob(TEST_DATASET_PATH))
if not test_files:
    raise FileNotFoundError(f"No test dataset found matching {TEST_DATASET_PATH!r}")

test_path = test_files[0]  # Use the first matching file
print(f"📁 Loading test dataset from: {test_path}")

# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

test_dataset = SimpleTestDataset(test_data)
print(f"✅ Test dataset loaded: {len(test_dataset)} samples")

# Create DataLoader (order preserved so predictions line up with dataset indices)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
print(f"📊 Test DataLoader created: {len(test_loader)} batches")

📁 Loading test dataset from: ./my_datasets/test_1000samples_2dist_20250725_0016.pkl
✅ Test dataset loaded: 1000 samples
📊 Test DataLoader created: 32 batches


In [8]:
def test_model_accuracy(model, test_loader, device, show_progress=True):
    """Evaluate a match / no-match classifier over a loader of labelled batches.

    Args:
        model: module mapping a batch of sequences to (batch, num_classes)
            logits; the prediction is the argmax over dim 1.
        test_loader: iterable of ``(sequences, targets)`` batches (e.g. a
            DataLoader); ``targets`` is an integer label tensor where 1 is
            treated as "match" and 0 as "no-match".
        device: device each batch is moved to before the forward pass.
        show_progress: wrap the loader in a tqdm progress bar (default True).

    Returns:
        dict with overall / match / no-match accuracy (percent; 0.0 when there
        are no samples of that kind), sample counts, the mean top-1 softmax
        confidence (NaN for an empty loader) and per-sample predictions,
        targets and confidences.
    """
    model.eval()
    correct = total = 0
    match_correct = match_total = 0
    nomatch_correct = nomatch_total = 0

    all_predictions = []
    all_targets = []
    all_confidences = []

    batches = tqdm(test_loader, desc="Testing") if show_progress else test_loader
    with torch.no_grad():
        for sequences, targets in batches:
            sequences, targets = sequences.to(device), targets.to(device)
            outputs = model(sequences)
            probabilities = F.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Store predictions and confidences
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(probabilities.max(1)[0].cpu().numpy())

            # Separate accuracy for match vs no-match
            match_mask = (targets == 1)
            nomatch_mask = (targets == 0)

            if match_mask.sum() > 0:
                match_correct += predicted[match_mask].eq(targets[match_mask]).sum().item()
                match_total += match_mask.sum().item()

            if nomatch_mask.sum() > 0:
                nomatch_correct += predicted[nomatch_mask].eq(targets[nomatch_mask]).sum().item()
                nomatch_total += nomatch_mask.sum().item()

    # max(..., 1) avoids ZeroDivisionError for an empty loader / absent class.
    overall_acc = 100. * correct / max(total, 1)
    match_acc = 100. * match_correct / max(match_total, 1)
    nomatch_acc = 100. * nomatch_correct / max(nomatch_total, 1)
    # np.mean([]) would warn and return NaN; return NaN explicitly instead.
    avg_confidence = np.mean(all_confidences) if all_confidences else float('nan')

    return {
        'overall_accuracy': overall_acc,
        'match_accuracy': match_acc,
        'nomatch_accuracy': nomatch_acc,
        'total_samples': total,
        'match_samples': match_total,
        'nomatch_samples': nomatch_total,
        'average_confidence': avg_confidence,
        'predictions': all_predictions,
        'targets': all_targets,
        'confidences': all_confidences
    }

# The name starts with "test_" but this is an evaluation helper, not a pytest test.
test_model_accuracy.__test__ = False

# Chance level for this task; assumes balanced classes (the current test set
# has 500 match / 500 no-match trials).
CHANCE_ACCURACY = 50.0

if 'test_dataset' in locals() and test_dataset is not None:
    # Run test
    results = test_model_accuracy(model, test_loader, device)
    overall = results['overall_accuracy']

    print(f"\n📊 TEST RESULTS:")
    print(f"=" * 40)
    print(f"Overall Accuracy: {overall:.2f}%")
    print(f"Match Trials Accuracy: {results['match_accuracy']:.2f}% ({results['match_samples']} samples)")
    print(f"No-Match Trials Accuracy: {results['nomatch_accuracy']:.2f}% ({results['nomatch_samples']} samples)")
    print(f"Average Confidence: {results['average_confidence']:.3f}")
    print(f"Total Samples: {results['total_samples']}")

    # Performance interpretation
    print(f"\n💡 PERFORMANCE INTERPRETATION:")
    if overall > 90:
        print("🎉 Excellent performance!")
    elif overall > 80:
        print("✅ Good performance!")
    elif overall > 70:
        print("📈 Decent performance")
    elif overall > CHANCE_ACCURACY:
        print("⚠️ Poor performance")
    else:
        # Exactly-at-chance is not "worse than random", so say both.
        print("❌ Very poor performance - at or below chance level!")

    print(f"🎯 Random baseline: {CHANCE_ACCURACY:.0f}%")
    print(f"📊 Your model: {overall:.1f}%")
    # :+ emits the correct sign; a hardcoded "+" printed "+-x" below chance.
    print(f"📈 Improvement: {overall - CHANCE_ACCURACY:+.1f} percentage points")

else:
    print("❌ Cannot run tests without test dataset")


Testing: 100%|██████████| 32/32 [00:00<00:00, 33.44it/s]


📊 TEST RESULTS:
Overall Accuracy: 94.10%
Match Trials Accuracy: 98.40% (500 samples)
No-Match Trials Accuracy: 89.80% (500 samples)
Average Confidence: 0.960
Total Samples: 1000

💡 PERFORMANCE INTERPRETATION:
🎉 Excellent performance!
🎯 Random baseline: 50%
📊 Your model: 94.1%
📈 Improvement: +44.1 percentage points





In [None]:
NUM_INSPECT_SAMPLES = 10  # number of random test trials whose RNN states are kept

model.eval()
dataset = test_dataset

indices = np.random.choice(len(dataset), NUM_INSPECT_SAMPLES, replace=False)

# Keep every sampled trial. The loop previously overwrote `hidden_states` on
# each iteration, discarding all but the last sample. The loop variables still
# hold the last sample afterwards, so cells that read `hidden_states` still work.
inspected_samples = []
with torch.no_grad():
    for idx in indices:
        sequence, true_label = dataset[idx]
        # Add a batch dimension of 1 for the model.
        sequence_batch = sequence.unsqueeze(0).to(device)
        hidden_states, projected_features = model.get_hidden_states(sequence_batch)
        inspected_samples.append({
            'index': int(idx),
            'label': true_label,
            # [0] drops the batch dim: (seq_len, rnn_hidden_dim)
            'hidden_states': hidden_states[0].cpu(),
            # (seq_len, projection_dim)
            'projected_features': projected_features[0].cpu(),
        })


In [None]:
# Drop the batch dim of 1: the result is (seq_len, rnn_hidden_dim) for the last sampled trial.
hidden_states[0].size()

torch.Size([5, 256])