In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import glob

In [None]:
# --- Notebook configuration ---------------------------------------------
# Compute device: use CUDA when torch reports it as available, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset root; the loader below looks for images in its "test" sub-directory.
DATA_PATH = "/kaggle/input/autism-image-data/AutismDataset"
# Checkpoint file whose state dict is loaded into the DenseNet201 model.
MODEL_PATH = "/kaggle/input/autism-spectrum-detection-from-kaggle-zenodo/Model results/Model results/best_densenet201_autism.pth"

# Output directory for the notebook.
OUTPUT_PATH = "/kaggle/working"

In [None]:
def load_test_dataset():
    """Scan the test split and group image paths by the class in each filename.

    Images are searched directly inside ``<DATA_PATH>/test`` (non-recursive)
    for jpg/jpeg/png/bmp files, in lower- and upper-case extension variants.
    The class is read from the file's base name: names containing
    ``Non_Autistic.`` are non-autistic, names containing ``Autistic.``
    (without the ``Non_`` prefix) are autistic.

    A short summary is printed, and up to 5 randomly drawn preview paths per
    class are included in the result.

    Returns:
        dict | None: ``None`` when the test directory does not exist,
        otherwise a dict with the keys ``all_images``, ``autistic_images``,
        ``non_autistic_images``, ``sample_autistic`` and
        ``sample_non_autistic`` (each a list of file paths).
    """
    # Local import: `random` is not part of the notebook's top import cell.
    import random

    test_path = os.path.join(DATA_PATH, "test")

    if not os.path.exists(test_path):
        print(f"Test directory not found: {test_path}")
        return None

    # Collect all image files
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
    matches = []
    for ext in image_extensions:
        matches.extend(glob.glob(os.path.join(test_path, ext)))
        matches.extend(glob.glob(os.path.join(test_path, ext.upper())))

    # On case-insensitive filesystems the lower- and upper-case patterns hit
    # the same files; drop duplicates while keeping first-seen order.
    all_images = list(dict.fromkeys(matches))

    def is_non_autistic(path):
        return 'Non_Autistic.' in os.path.basename(path)

    def is_autistic(path):
        # 'Autistic.' is a substring of 'Non_Autistic.', so the negative class
        # has to be excluded explicitly or it would be counted twice.
        return 'Autistic.' in os.path.basename(path) and not is_non_autistic(path)

    autistic_images = [img for img in all_images if is_autistic(img)]
    non_autistic_images = [img for img in all_images if is_non_autistic(img)]

    print("Test Dataset Summary:")
    print(f"  Total images: {len(all_images)}")
    print(f"  Autistic images: {len(autistic_images)}")
    print(f"  Non-Autistic images: {len(non_autistic_images)}")

    # Preview samples: at most 5 per class, fewer when the class is smaller.
    sample_autistic = random.sample(autistic_images, min(5, len(autistic_images)))
    sample_non_autistic = random.sample(non_autistic_images, min(5, len(non_autistic_images)))

    return {
        'all_images': all_images,
        'autistic_images': autistic_images,
        'non_autistic_images': non_autistic_images,
        'sample_autistic': sample_autistic,
        'sample_non_autistic': sample_non_autistic
    }

In [None]:
# Load the dataset metadata up front; which images get displayed is decided
# later, once model predictions are available.
import random

test_dataset = load_test_dataset()

# Placeholder, filled in by the prediction/selection cell further down.
selected_image_paths = []

if test_dataset:
    print("Dataset loaded. Run model + prediction cells to generate selection.")
else:
    print("No dataset information available. Please check the data path.")

In [None]:
def get_input_transform():
    """Build the preprocessing pipeline applied to an image before inference.

    The pipeline resizes the image to 224x224 (a plain resize, no cropping),
    converts it to a float tensor and normalizes each channel with the
    ImageNet mean / standard deviation.

    Returns:
        transforms.Compose: the composed transform.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)

def get_input_tensors(img):
    """Preprocess one image and return it as a batch containing that image.

    Args:
        img: image accepted by the pipeline from ``get_input_transform()``.

    Returns:
        The transformed tensor with a new leading dimension of size 1.
    """
    preprocessed = get_input_transform()(img)
    # Add the leading batch dimension so the result is a batch of one
    return preprocessed.unsqueeze(0)

In [None]:
def create_densenet201_model(num_classes=2, pretrained=True, dropout_rate=0.5):
    """
    Create a DenseNet201 model with a custom classification head.

    The stock classifier is replaced by a three-layer MLP
    (in_features -> 512 -> 256 -> num_classes) with BatchNorm, ReLU and
    progressively smaller dropout (rate, rate/2, rate/4).

    Args:
        num_classes (int): Number of output classes
        pretrained (bool): Whether to initialise the backbone with ImageNet
            weights (this triggers a download if they are not cached)
        dropout_rate (float): Dropout rate for regularization

    Returns:
        The torchvision DenseNet201 module with the replaced classifier.
    """
    # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
    # Use the new API when it exists and fall back for older installations.
    # IMAGENET1K_V1 is the weight set the legacy `pretrained=True` maps to.
    weights_enum = getattr(models, "DenseNet201_Weights", None)
    if weights_enum is not None:
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.densenet201(weights=weights)
    else:
        model = models.densenet201(pretrained=pretrained)

    # Get the number of features feeding the original classifier
    num_features = model.classifier.in_features

    # Replace the classifier with our custom head
    model.classifier = nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(num_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_rate/2),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_rate/4),
        nn.Linear(256, num_classes)
    )

    return model

In [None]:
# Every parameter and buffer is supplied by the checkpoint (load_state_dict is
# strict by default), so skip the ImageNet weight download: it would be
# overwritten immediately and fails on kernels without internet access.
model = create_densenet201_model(num_classes=2, pretrained=False)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
# Inference notebook: switch dropout / batch-norm layers to evaluation mode.
model.eval()
print("model loaded successfully")

In [None]:
def select_samples_by_prediction(results, total=6, ratio_correct=0.5, shuffle=True, seed=None):
    """Select image records with a target mix of correct / incorrect predictions.

    Args:
        results: list of dicts, each with at least the keys ``'correct'``
            (truthy when the prediction was right) and ``'path'``.
        total: number of records to select; negative values are treated as 0.
        ratio_correct: desired fraction of correctly predicted records;
            clamped to the range [0, 1].
        shuffle: when true, shuffle both pools before picking and shuffle the
            final selection; when false, records are taken in input order.
        seed: seed for the private random generator (``None`` = unseeded).

    Returns:
        tuple: ``(selected_records, selected_paths)``. If one pool has too few
        records, the shortfall is filled from the other pool where possible,
        so fewer than ``total`` records are returned only when ``results``
        itself is too small.
    """
    if not results:
        return [], []
    import random
    rnd = random.Random(seed)
    correct = [r for r in results if r['correct']]
    incorrect = [r for r in results if not r['correct']]
    if shuffle:
        rnd.shuffle(correct)
        rnd.shuffle(incorrect)

    # Out-of-range arguments would yield negative slice bounds below and
    # return more than `total` records, so clamp them first.
    total = max(0, total)
    ratio_correct = min(1.0, max(0.0, ratio_correct))

    need_correct = int(round(total * ratio_correct))
    need_incorrect = total - need_correct
    sel_correct = correct[:need_correct]
    sel_incorrect = incorrect[:need_incorrect]

    # Top up from the other pool when one side cannot fill its quota.
    missing_correct = need_correct - len(sel_correct)
    if missing_correct > 0:
        sel_incorrect.extend(incorrect[need_incorrect:need_incorrect + missing_correct])
    missing_incorrect = need_incorrect - len(sel_incorrect)
    if missing_incorrect > 0:
        sel_correct.extend(correct[need_correct:need_correct + missing_incorrect])

    selected = sel_correct + sel_incorrect
    if shuffle:
        rnd.shuffle(selected)
    return selected, [r['path'] for r in selected]

In [None]:
# Helper: compute predictions over all test images
from typing import List, Dict

def compute_predictions(test_dataset, batch_size=16):
    """Run the global ``model`` over every test image and collect per-image results.

    Args:
        test_dataset: dict with an ``'all_images'`` list of file paths
            (as returned by ``load_test_dataset``); falsy input returns ``[]``.
        batch_size: number of images sent through the model per forward pass.

    Returns:
        list[dict]: one record per successfully loaded image with the keys
        ``path``, ``true`` (0 = non-autistic, 1 = autistic, ``None`` when the
        filename carries no label), ``pred`` (arg-max class index), ``conf``
        (soft-max probability of the predicted class), ``prob_autistic``
        (soft-max probability of class 1) and ``correct``. Images without a
        label are recorded with ``correct=False``.

    Images that cannot be opened or transformed are reported and skipped.
    Errors raised by the model itself are not swallowed.
    """
    if not test_dataset:
        print("No dataset info provided.")
        return []
    image_paths = test_dataset['all_images']
    results = []

    # Same preprocessing steps as get_input_transform()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    softmax = torch.nn.Softmax(dim=1)
    model.eval()

    def true_label_from_name(p):
        # 'Non_Autistic.' must be tested first: it contains 'Autistic.'
        name = os.path.basename(p)
        if 'Non_Autistic.' in name:
            return 0
        if 'Autistic.' in name:
            return 1
        return None

    batch = []
    batch_meta = []

    def flush_batch():
        """Run the pending batch through the model and append its records."""
        if not batch:
            return
        try:
            with torch.no_grad():
                tensor_batch = torch.stack(batch).to(DEVICE)
                logits = model(tensor_batch)
                probs = softmax(logits)
                confs, preds = torch.max(probs, 1)
                for meta, pred, conf, prob_vec in zip(batch_meta, preds.cpu(), confs.cpu(), probs.cpu()):
                    pred_idx = int(pred.item())
                    results.append({
                        'path': meta['path'],
                        'true': meta['true'],
                        'pred': pred_idx,
                        'conf': float(conf.item()),
                        'prob_autistic': float(prob_vec[1].item()),
                        'correct': (meta['true'] is not None and pred_idx == meta['true'])
                    })
        finally:
            # Always reset, even on failure, so the buffers never grow past
            # batch_size or fall out of sync.
            batch.clear()
            batch_meta.clear()

    for p in image_paths:
        tlabel = true_label_from_name(p)
        # Only loading / preprocessing is best-effort; the forward pass stays
        # outside the try so a model error is not reported as a bad image.
        try:
            with Image.open(p) as img:  # close the file handle promptly
                tensor = transform(img.convert('RGB'))
        except Exception as e:
            print(f"Skipping {p}: {e}")
            continue
        batch.append(tensor)
        batch_meta.append({'path': p, 'true': tlabel})
        if len(batch) >= batch_size:
            flush_batch()
    flush_batch()

    return results

In [None]:
# Class-index convention used by this notebook: 0 = non-autistic, 1 = autistic
# (the same encoding compute_predictions derives from the filenames).
# NOTE(review): confirm this matches the label order the checkpoint was trained with.
idx2label = ["non-autistic", "autistic"]
# Derived from idx2label so the two mappings cannot disagree
# (previously hard-coded with the indices swapped).
cls2idx = {label: idx for idx, label in enumerate(idx2label)}
cls2label = {"non-autistic": "non-autistic", "autistic": "autistic"}

In [None]:
model.eval()  # put model in eval mode

# The selection list starts out empty and is only filled by the
# prediction/selection cell, so say so instead of silently doing nothing.
if not selected_image_paths:
    print("No selected images yet. Run the prediction/selection cell first, then re-run this cell.")

for img in selected_image_paths:
    # Convert to RGB so grayscale / RGBA files match the 3-channel
    # normalization, and close the file handle once the tensor is built.
    with Image.open(img) as pil_img:
        img_t = get_input_tensors(pil_img.convert('RGB')).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    # arg-max over the class dimension gives the predicted class index
    _, predicted = torch.max(output, 1)
    predicted_label = idx2label[predicted.item()]
    print(f"Image: {os.path.basename(img)} | Predicted: {predicted_label}")

In [None]:
# Compute predictions for the whole test set, then pick a small display sample
prediction_records = compute_predictions(test_dataset)
print(f"Computed predictions for {len(prediction_records)} images")

# Fixed seed so a re-run over the same prediction records shows the same
# sample; set to None to draw a fresh random sample on every run.
SELECTION_SEED = 42

sample_records, selected_image_paths = select_samples_by_prediction(
    prediction_records,
    total=6,
    ratio_correct=0.8,  # aim for ~80% correctly classified examples
    shuffle=True,
    seed=SELECTION_SEED,
)

print("Selected sample set (filename | prob_autistic | correct):")
for r in sample_records:
    status = 'OK' if r['correct'] else 'WRONG'
    print(f"  {os.path.basename(r['path'])} | {r['prob_autistic']:.3f} | {status}")

In [None]:
# Visualize selected samples with prediction status
if selected_image_paths:
    n_images = len(selected_image_paths)
    cols = min(3, n_images)
    rows = (n_images + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2-D array of axes, so no special-casing
    # of single-row / single-cell grids is needed before flattening.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()
    meta_by_path = {r['path']: r for r in sample_records}
    for ax, p in zip(axes, selected_image_paths):
        # Close the file handle once the pixel data has been handed to
        # matplotlib; show the image in RGB, the mode the model is fed.
        with Image.open(p) as img:
            ax.imshow(img.convert('RGB'))
        rec = meta_by_path.get(p, {})
        status = 'OK' if rec.get('correct') else 'WRONG'
        prob = rec.get('prob_autistic')
        prob_txt = f"p_aut={prob:.2f}" if prob is not None else ''
        ax.set_title(f"{os.path.basename(p)}\n{status} {prob_txt}", fontsize=10)
        ax.axis('off')
    # Hide the unused cells of the grid
    for ax in axes[n_images:]:
        ax.axis('off')
    plt.suptitle('Selected Samples (Post-Prediction)', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("No selected images to visualize yet.")