In [97]:
%matplotlib inline
import os, sys, re
import glob

import pandas as pd
import numpy as np
import torch
import torch.utils.data
import torch.nn

from random import randrange
from PIL import Image
import matplotlib.pyplot as plt

In [98]:
import argparse
""" Training and hyperparameter search configurations """
curr_dir = os.getcwd()

parser = argparse.ArgumentParser(description='Alzheimer Classification Tester')
parser.add_argument('--oasis2_path', type=str, default='/Users/valenetjong/Downloads/OAS2_RAW_PART1',
                    help='directory to oasis 2 download')
parser.add_argument('--img_dir', type=str, default='/Users/valenetjong/alzheimer-classification/oasis2',
                    help='directory for image storage')
parser.add_argument('--oasis2csv_path', type=str, default='/Users/valenetjong/alzheimer-classification/datacsv/oasis_longitudinal.csv',
                    help='path to oasis 2 csv')
parser.add_argument('--process_flag', type=bool, default=False,
                    help="extract files from disk if True, use already extracted files, if False")
parser.add_argument('--create_dataset', type=bool, default=True,
                    help="create dataset from scratch if True, load in processed dataset if False")
parser.add_argument('--best_custom_model_path', type=str, default='/Users/valenetjong/alzheimer-classification/models/DeepCNNModel_epoch33.pt',
                    help='path to best custom model for testing')
parser.add_argument('--best_resnet_model_path', type=str, default='/Users/valenetjong/alzheimer-classification/models/ResNet.pt',
                    help='path to best resnet model for testing')
parser.add_argument('--num_classes', type=int, default=3,
                    help='number of classes')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
args = parser.parse_args('')
# Set random seed to reproduce results
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1275292b0>

## Data Loading and Processing

In [99]:
import nibabel as nib
import os
import glob

DEMENTIA_MAP = {
    '0.0': "nondemented",
    '0.5': "mildly demented",
    '1.0': 'moderately demented',
}

def convert_and_rename_hdr_img_to_nifti(base_dir, output_dir, oasis2_csv_path, slice_idx=140):
    """
    Convert and rename .hdr/.img files to .nifti format.

    Parameters:
    base_dir (str): Base directory containing the subdirectories.
    output_dir (str): Directory where .nifti files will be saved.
    oasis2_csv_path: Path to the CSV file containing Oasis 2 metadata.
    """
    oasis_df = pd.read_csv(oasis2_csv_path)

    # Check if the output directory exists, if not, create it
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Search for the specific subdirectories
    subdirectories = glob.glob(os.path.join(base_dir, 'OAS2_*_MR1'))

    for sub_dir in subdirectories:
        # Construct the path to the RAW folder
        raw_dir = os.path.join(sub_dir, 'RAW')

        # Check if RAW directory exists
        if os.path.exists(raw_dir):
            # Construct the file path for hdr file
            hdr_file = os.path.join(raw_dir, 'mpr-2.nifti.hdr')
            pid = os.path.basename(sub_dir)
            row = oasis_df.loc[oasis_df['MRI ID'] == pid]
            dementia_type = str(row['CDR'].item())

            # Check if hdr file exists
            if os.path.exists(hdr_file):
                # Load the image (this should automatically include the associated .img file)
                img = nib.load(hdr_file)
                data = img.get_fdata()
                if len(data.shape) == 4:
                    data = data[:, :, :, 0]
                slice_2d = data[:, slice_idx, :]
                normalized_slice = (slice_2d - np.min(slice_2d)) / (np.max(slice_2d) - np.min(slice_2d))
                scaled_slice = (255 * normalized_slice).astype(np.uint8)

                # Save the slice as a PNG
                image = Image.fromarray(scaled_slice)
                png_file_name = os.path.basename(hdr_file).replace('.nifti.hdr', '.png')
                processed_dir = os.path.join(output_dir, DEMENTIA_MAP[dementia_type])
                os.makedirs(processed_dir, exist_ok=True)
                png_file_path = os.path.join(processed_dir, f'{pid}_{png_file_name}')
                image.save(png_file_path)
                print(f'Processed and saved {png_file_path}')
            else:
                print(f'No .hdr file found in {raw_dir}')
        else:
            print(f'No RAW directory found in {sub_dir}')

if args.process_flag:
    in_dir = args.oasis2_path
    out_dir = os.path.join(args.img_dir, 'raw')
    oasis2_csv = args.oasis2csv_path
    convert_and_rename_hdr_img_to_nifti(in_dir, out_dir, oasis2_csv)

In [100]:
import cv2 as cv
import tempfile
import shutil
import skimage.exposure

""" Pre-processing Functions """
# Pre-determined max dimensions of cropped images
CONV_WIDTH = 137
CONV_HEIGHT = 167

def convert_to_grayscale(img):
    """
    Converts an image to grayscale. Handles images with alpha channel.
    """
    if img.mode in ["RGBA", "LA"] or (img.mode == "P" and 'transparency' in img.info):
        # Use alpha channel as mask
        alpha = img.split()[-1]
        bg = Image.new("RGB", img.size, (255, 255, 255))
        bg.paste(img, mask=alpha)
        return np.array(bg.convert('L'))
    else:
        return np.array(img.convert('L'))

def normalize_intensity(img):
    """
    Normalizes the intensity of an image to the range [0, 255].

    Parameters:
    img: The image to be normalized.

    Returns:
    Normalized image.
    """
    img_min = img.min()
    img_max = img.max()
    normalized_img = (img - img_min) / (img_max - img_min) * 255
    return normalized_img.astype(np.uint8)

def apply_low_pass_filter(img, kernel_size=3):
    """
    Applies a Gaussian blur (low-pass filter) to the image.

    Parameters:
    img: The image to be filtered.
    kernel_size: Size of the Gaussian kernel.

    Returns:
    Filtered image.
    """
    return cv.GaussianBlur(img, (kernel_size, kernel_size), 0)

def increase_intensity(img):
    """
    Increases the intensity of an image using contrast stretching.

    Parameters:
    img: The image whose intensity is to be increased.

    Returns:
    Image with increased intensity.
    """
    p2, p98 = np.percentile(img, (2, 98))
    return skimage.exposure.rescale_intensity(img, in_range=(p2, p98))

def pad_image_to_size(img, width, height):
    """
    Pads an image with zeros to the specified width and height.

    Parameters:
    img: The image to be padded.
    width: The desired width.
    height: The desired height.

    Returns:
    Padded image.
    """
    padded_img = np.zeros((height, width), dtype=img.dtype)
    y_offset = (height - img.shape[0]) // 2
    x_offset = (width - img.shape[1]) // 2
    padded_img[y_offset:y_offset+img.shape[0], x_offset:x_offset+img.shape[1]] = img
    return padded_img

def crop_black_boundary(mri_image, kernel_size=50):
    """
    Crops the black boundary from an MRI image, while ignoring small noise within the black regions.

    Parameters:
    mri_image: Input MRI image.

    Returns:
    Cropped MRI image with black boundaries removed.
    """
    # Thresholding to get the binary image for contour detection
    _, thresh = cv.threshold(mri_image, 1, 255, cv.THRESH_BINARY)

    # Apply morphological operations to remove small noise
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    cleaned = cv.morphologyEx(thresh, cv.MORPH_OPEN, kernel)

    # Finding contours
    contours, _ = cv.findContours(cleaned, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)

    # If no contours found, return original image
    if not contours:
        return mri_image

    largest_contour = max(contours, key=cv.contourArea)
    x, y, w, h = cv.boundingRect(largest_contour)
    cropped_image = mri_image[y:y+h, x:x+w]
    return cropped_image

def process_image(fn, target_subdir):
    """
    Processes a single MRI image file.
    """
    with Image.open(fn) as img:
        img_gray = convert_to_grayscale(img)

    # Apply low-pass filter
    img_filtered = apply_low_pass_filter(img_gray)

    # Increase the intensity
    img_enhanced = increase_intensity(img_filtered)

    # Crop the black boundary
    img_cropped = crop_black_boundary(img_enhanced)

    img_height, img_width = img_cropped.shape

    os.makedirs(target_subdir, exist_ok=True)
    target_path = os.path.join(target_subdir, os.path.basename(fn))
    cv.imwrite(target_path, img_cropped)

    return img_height, img_width

def extract_files(base_dir, target_dir):
    """
    Extracts and processes MRI files from a given directory and its subdirectories.

    Parameters:
    base_dir: Directory containing MRI files.
    target_dir: Directory where processed files will be saved.
    """
    HEIGHT = 0
    WIDTH = 0

    for subdir, _, files in os.walk(base_dir):
        for file in files:
            if file.lower().endswith('.png'):
                source_path = os.path.join(subdir, file)
                relative_path = os.path.relpath(subdir, base_dir)
                target_subdir = os.path.join(target_dir, relative_path)
                
                img_height, img_width = process_image(source_path, target_subdir)
                HEIGHT = max(HEIGHT, img_height)
                WIDTH = max(WIDTH, img_width)
    return HEIGHT, WIDTH

if args.process_flag:
    in_dir = os.path.join(args.img_dir, 'raw')
    out_dir = os.path.join(args.img_dir, 'processed')
    extract_files(in_dir, out_dir)

In [101]:
import cv2 as cv
import os
import numpy as np
import skimage.exposure

CONV_WIDTH = 137
CONV_HEIGHT = 167

def increase_contrast(image):
    """
    Increases the contrast of an image using adaptive histogram equalization.
    """
    # Convert image to float and scale to range 0-1
    img_float = image.astype(np.float32) / 255
    # Apply adaptive equalization
    img_contrast = skimage.exposure.equalize_adapthist(img_float)
    # Scale back to range 0-255 and return as uint8
    return (img_contrast * 255).astype(np.uint8)

def process_directory(input_dir, output_dir):
    """
    Reads, enhances contrast, and resizes images from the input directory,
    and saves them to the output directory while preserving the subdirectory structure.
    """
    for subdir, _, files in os.walk(input_dir):
        for filename in files:
            if filename.lower().endswith('.png'):
                file_path = os.path.join(subdir, filename)

                # Create corresponding subdirectory in output directory
                relative_path = os.path.relpath(subdir, input_dir)
                output_subdir = os.path.join(output_dir, relative_path)
                os.makedirs(output_subdir, exist_ok=True)

                # Process the image
                img = cv.imread(file_path, cv.IMREAD_GRAYSCALE)
                img_enhanced = increase_contrast(img)
                img_resized = cv.resize(img_enhanced, (CONV_WIDTH, CONV_HEIGHT), interpolation=cv.INTER_AREA)

                # Save the modified image
                output_path = os.path.join(output_subdir, filename)
                cv.imwrite(output_path, img_resized)

if args.process_flag:
    in_dir = os.path.join(args.img_dir, 'processed')
    out_dir = os.path.join(args.img_dir, 'modified')
    process_directory(in_dir, out_dir)

In [102]:
import os
import torch
from torchvision import transforms
from PIL import Image
from collections import Counter

# Assuming args is defined and has num_classes attribute
# args.num_classes = 2 or 3 based on your requirement

LABEL_MAP = {
    "nondemented": 0,
    "mildly demented": 1,
    'moderately demented': 1 if args.num_classes == 2 else 2
}

def load_dataset(base_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    all_images = []
    all_labels = []
    all_pids = []
    class_counts = Counter()

    # Automatically find all subdirectories in base_dir
    for folder_name in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, folder_name)
        if os.path.isdir(folder_path):  # Check if it's a directory
            class_label = LABEL_MAP[folder_name]
            for image_file in os.listdir(folder_path):
                if image_file == '.DS_Store':
                    continue  # Skip .DS_Store files
                image_path = os.path.join(folder_path, image_file)
                if os.path.isfile(image_path):
                    pid = '_'.join(os.path.basename(image_path).split('_')[:3])
                    all_pids.append(pid)
                    with Image.open(image_path) as img:
                        img_tensor = transform(img)
                        all_images.append(img_tensor)
                        all_labels.append(class_label)
                        class_counts[folder_name] += 1

    X = torch.stack(all_images)
    y = torch.tensor(all_labels, dtype=torch.long)  # Changed to long for integer labels
    return X, y, class_counts, all_pids

# Example usage
# Set args values or replace args.img_dir and args.create_dataset with appropriate values
if args.create_dataset:
    X, y, class_counts, all_pids = load_dataset(os.path.join(args.img_dir, 'modified'))

    print(f"Combined Tensor Size: {X.size()}")
    print(f"Labels Tensor Size: {y.size()}")
    print(f"Class Counts: {class_counts}")

Combined Tensor Size: torch.Size([54, 1, 167, 137])
Labels Tensor Size: torch.Size([54])
Class Counts: Counter({'nondemented': 38, 'mildly demented': 13, 'moderately demented': 3})


In [103]:
if args.create_dataset:
    print(f"Number of nondemented in train dataset as percentage: {((y == 0).sum() / (X.shape[0])) * 100:0.2f}%")
    print(f"Number of mildly demented in train dataset as percentage: {((y == 1).sum() / (X.shape[0])) * 100:0.2f}%")
    print(f"Number of moderately demented in train dataset as percentage: {((y == 2).sum() / (X.shape[0])) * 100:0.2f}%")

Number of nondemented in train dataset as percentage: 70.37%
Number of mildly demented in train dataset as percentage: 24.07%
Number of moderately demented in train dataset as percentage: 5.56%


## Models

In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepCNNModel(nn.Module):
    def __init__(self, fc_size, conv_in_size, conv_hid_size, conv_out_size, dropout, num_classes=3):
        super(DeepCNNModel, self).__init__()
        
        # Convolutional Block 1
        self.conv1 = nn.Conv2d(1, conv_in_size, kernel_size=3, padding=1)  
        self.bn1 = nn.BatchNorm2d(conv_in_size)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 2
        self.conv2 = nn.Conv2d(conv_in_size, conv_hid_size, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(conv_hid_size)
        self.pool2 = nn.MaxPool2d(kernel_size=3)
        
        # Convolutional Block 3
        self.conv3 = nn.Conv2d(conv_hid_size, conv_hid_size, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(conv_hid_size)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 4
        self.conv4 = nn.Conv2d(conv_hid_size, conv_out_size, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(conv_out_size)
        self.pool4 = nn.MaxPool2d(kernel_size=3)

        # Compute the flattened size for the fully connected layer
        self._to_linear = None
        self._forward_conv(torch.randn(1, 1, 137, 167))

        # Fully connected layers
        self.fc1 = nn.Linear(self._to_linear, fc_size)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(fc_size, num_classes)
        self.dropout2 = nn.Dropout(p=dropout)

    def _forward_conv(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self._forward_conv(x)
        x = x.view(-1, self._to_linear)  # Flatten the output for the fully connected layers
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return F.log_softmax(x, dim=1)

## Testing

In [105]:
import torch
from torch.utils.data import TensorDataset, DataLoader

def test(model, test_dataset, criterion, batch_size=8, device='cpu'):
    """
    Test the PyTorch model and gather predictions for each class.

    Parameters:
    model (torch.nn.Module): The trained PyTorch model.
    test_dataset (torch.utils.data.Dataset): Dataset for testing.
    criterion (torch.nn.modules.loss): Loss function.
    device (str): Device to run the test ('cuda' or 'cpu').

    Returns:
    float: The average loss over the test dataset.
    float: The overall accuracy over the test dataset.
    list: A list of tuples, each containing a batch of true labels and predicted labels.
    """
    model.eval()  # Set the model to evaluation mode
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    all_predictions = []
    all_targets = []
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():  # Disable gradient computation
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total_predictions += targets.size(0)
            correct_predictions += (predicted == targets).sum().item()

            # Save the predictions and the targets for each batch
            all_targets.append(targets.cpu())
            all_predictions.append(predicted.cpu())

    avg_loss = total_loss / len(test_loader)
    accuracy = correct_predictions / total_predictions * 100

    return avg_loss, accuracy, all_targets, all_predictions

In [106]:
def calculate_class_weights(y):
    # Count the frequency of each class
    class_counts = Counter(y.numpy())
    total_samples = sum(class_counts.values())

    # Calculate weights: Inverse of frequency
    weights = {class_id: total_samples/class_counts[class_id] for class_id in class_counts}

    # Convert to a list in the order of class ids
    weights_list = [weights[i] for i in sorted(weights)]
    return torch.tensor(weights_list, dtype=torch.float32)

### Load in config from best run

In [107]:
import json

config_file_path = args.best_custom_model_path.replace('.pt', '.config')
with open(config_file_path, 'r') as file:
    config_dict = json.load(file)

print(config_dict)
fc_size = config_dict["fc_size"]["value"]
conv_in_size = config_dict["conv_in_size"]["value"]
conv_hid_size = config_dict["conv_hid_size"]["value"]
conv_out_size = config_dict["conv_out_size"]["value"]
dropout = config_dict["dropout"]["value"]
batch_size = config_dict["batch_size"]["value"]

print("fc_size:", fc_size)
print("conv_in_size:", conv_in_size)
print("conv_hid_size:", conv_hid_size)
print("conv_out_size:", conv_out_size)
print("dropout:", dropout)
print("batch_size:", batch_size)

{'lr': {'desc': None, 'value': 0.0001}, '_wandb': {'desc': None, 'value': {'t': {'1': [1, 5, 41, 53, 55], '2': [1, 5, 41, 53, 55], '3': [23, 37], '4': '3.11.4', '5': '0.16.1', '8': [1, 5], '13': 'darwin-x86_64'}, 'framework': 'torch', 'start_time': 1702631742.827628, 'cli_version': '0.16.1', 'is_jupyter_run': True, 'python_version': '3.11.4', 'is_kaggle_kernel': False}}, 'dropout': {'desc': None, 'value': 0.2}, 'fc_size': {'desc': None, 'value': 32}, 'batch_size': {'desc': None, 'value': 8}, 'max_epochs': {'desc': None, 'value': 250}, 'hidden_size': {'desc': None, 'value': 8}, 'conv_in_size': {'desc': None, 'value': 256}, 'conv_hid_size': {'desc': None, 'value': 32}, 'conv_out_size': {'desc': None, 'value': 16}}
fc_size: 32
conv_in_size: 256
conv_hid_size: 32
conv_out_size: 16
dropout: 0.2
batch_size: 8


### Test

In [108]:
model = DeepCNNModel(fc_size, conv_in_size, conv_hid_size, conv_out_size, dropout)
model.load_state_dict(torch.load(args.best_custom_model_path))
test_set = TensorDataset(X, y)
criterion = nn.CrossEntropyLoss(weight=calculate_class_weights(y))
avg_loss, accuracy, all_targets, all_predictions = test(model, test_set, criterion, batch_size)

print(all_targets)
print(all_predictions)

[tensor([2, 2, 2, 1, 1, 1, 1, 1]), tensor([1, 1, 1, 1, 1, 1, 1, 1]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0])]
[tensor([1, 1, 1, 1, 1, 0, 0, 1]), tensor([0, 0, 1, 0, 1, 1, 1, 0]), tensor([1, 1, 0, 0, 0, 1, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 1, 0]), tensor([0, 1, 1, 0, 1, 0, 1, 0]), tensor([1, 0, 1, 1, 0, 1, 1, 1]), tensor([1, 0, 1, 0, 0, 0])]


In [109]:
flat_targets = torch.cat(all_targets).flatten()
flat_predictions = torch.cat(all_predictions).flatten()
mismatch_indices = torch.where(flat_targets != flat_predictions)[0].tolist()
print(f"Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

Average Loss: 1.0207, Accuracy: 53.70%


In [110]:
print(flat_targets)
print(flat_predictions)
print(mismatch_indices)

tensor([2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,
        1, 0, 1, 0, 0, 0])
[0, 1, 2, 5, 6, 8, 9, 11, 15, 16, 17, 21, 30, 33, 34, 36, 38, 40, 42, 43, 45, 46, 47, 48, 50]


In [111]:
print(all_pids)

['OAS2_0044_MR1', 'OAS2_0066_MR1', 'OAS2_0071_MR1', 'OAS2_0080_MR1', 'OAS2_0063_MR1', 'OAS2_0043_MR1', 'OAS2_0037_MR1', 'OAS2_0089_MR1', 'OAS2_0079_MR1', 'OAS2_0040_MR1', 'OAS2_0032_MR1', 'OAS2_0016_MR1', 'OAS2_0075_MR1', 'OAS2_0058_MR1', 'OAS2_0002_MR1', 'OAS2_0028_MR1', 'OAS2_0077_MR1', 'OAS2_0073_MR1', 'OAS2_0094_MR1', 'OAS2_0029_MR1', 'OAS2_0090_MR1', 'OAS2_0004_MR1', 'OAS2_0013_MR1', 'OAS2_0097_MR1', 'OAS2_0017_MR1', 'OAS2_0070_MR1', 'OAS2_0067_MR1', 'OAS2_0049_MR1', 'OAS2_0054_MR1', 'OAS2_0069_MR1', 'OAS2_0047_MR1', 'OAS2_0027_MR1', 'OAS2_0030_MR1', 'OAS2_0034_MR1', 'OAS2_0057_MR1', 'OAS2_0022_MR1', 'OAS2_0035_MR1', 'OAS2_0031_MR1', 'OAS2_0008_MR1', 'OAS2_0045_MR1', 'OAS2_0052_MR1', 'OAS2_0056_MR1', 'OAS2_0041_MR1', 'OAS2_0078_MR1', 'OAS2_0051_MR1', 'OAS2_0068_MR1', 'OAS2_0042_MR1', 'OAS2_0036_MR1', 'OAS2_0018_MR1', 'OAS2_0001_MR1', 'OAS2_0005_MR1', 'OAS2_0061_MR1', 'OAS2_0091_MR1', 'OAS2_0086_MR1']


In [112]:
print([all_pids[i] for i in mismatch_indices])

['OAS2_0044_MR1', 'OAS2_0066_MR1', 'OAS2_0071_MR1', 'OAS2_0043_MR1', 'OAS2_0037_MR1', 'OAS2_0079_MR1', 'OAS2_0040_MR1', 'OAS2_0016_MR1', 'OAS2_0028_MR1', 'OAS2_0077_MR1', 'OAS2_0073_MR1', 'OAS2_0004_MR1', 'OAS2_0047_MR1', 'OAS2_0034_MR1', 'OAS2_0057_MR1', 'OAS2_0035_MR1', 'OAS2_0008_MR1', 'OAS2_0052_MR1', 'OAS2_0041_MR1', 'OAS2_0078_MR1', 'OAS2_0068_MR1', 'OAS2_0042_MR1', 'OAS2_0036_MR1', 'OAS2_0018_MR1', 'OAS2_0005_MR1']


### Test ResNet

In [113]:
import torchvision.transforms as transforms

class GrayscaleToRGBTransform:
    def __call__(self, tensor):
        # Check if the tensor has one channel (grayscale)
        if tensor.shape[0] == 1:
            # Repeat the tensor across 3 channels
            tensor = tensor.repeat(3, 1, 1)
        return tensor

res_transform = transforms.Compose([
    GrayscaleToRGBTransform(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def apply_all_transforms(X, transform):
    transformed_data = []
    for x in X:
        x = transform(x) 
        transformed_data.append(x)
    return torch.stack(transformed_data)


X_test_resnet = apply_all_transforms(X, transform=res_transform)

In [117]:
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, args.num_classes) 
model.load_state_dict(torch.load(args.best_resnet_model_path))

batch_size = 16
test_set = TensorDataset(X_test_resnet, y)
criterion = nn.CrossEntropyLoss(weight=calculate_class_weights(y))
print("batch_size:", batch_size)
avg_loss, accuracy, all_targets, all_predictions = test(model, test_set, criterion, batch_size)

batch_size: 16


In [118]:
print(f"Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

Average Loss: 1.1389, Accuracy: 16.67%


In [119]:
print(all_targets)
print(all_predictions)

[tensor([2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0])]
[tensor([2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 1, 2, 2]), tensor([2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1]), tensor([1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2]), tensor([1, 2, 2, 1, 2, 1])]
