In [None]:
import io
import os
import glob 
import random
import itertools
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import KLDivLoss
import torch.optim as optim
from torchvision.utils import make_grid
from torchvision import transforms, datasets, models
import torchvision.transforms as tt
from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
# from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm_notebook
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from torchsummary import summary


In [None]:
# Seed every random number generator the notebook imports so that repeated
# runs are comparable. Python's own `random` module is seeded as well, since
# it is imported above and would otherwise stay unseeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Restrict the process to the first GPU.
# NOTE(review): this only takes effect if CUDA has not been initialised yet.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Load the metadata table. The path is a placeholder to be replaced by the
# user; it is written as a string so the cell at least parses.
df = pd.read_csv('<your_metadata.csv>')
# Keep only rows with label > 0 (BI-RADS 1-5). `.copy()` makes the filtered
# frame independent so the assignment below does not write into a view.
df = df[df['label'] > 0].copy()
df['label'] = df['label'] - 1  # Rescoring to 0-4
file_paths = df['name'].to_list()
labels = df['label'].to_list()
print(file_paths)
print(labels)

In [None]:
def extract_identifier_view(file_path):
    """Split the base name of ``file_path`` at its first hyphen.

    Args:
        file_path: Path to an image file; only the final path component is used.

    Returns:
        A ``(identifier, view)`` tuple: the text before the first ``-`` and
        everything after it. When the base name has no hyphen, the identifier
        is the whole base name and the view is an empty string.
    """
    identifier, _, view = os.path.basename(file_path).partition('-')
    return identifier, view

# Create a mapping from file paths to labels
file_path_to_label = dict(zip(file_paths, labels))

# Grouping the file paths by identifier
grouped_images = defaultdict(list)
for file_path in file_paths:
    identifier, view = extract_identifier_view(file_path)
    grouped_images[identifier].append((view, file_path))

# Pairing the images and creating new labels.
# For each identifier and each breast side, a CC/MLO pair is emitted in both
# orders (CC, MLO) and (MLO, CC), both labelled with the CC image's label.
# A side is skipped when either view is missing, instead of raising KeyError
# on the lookup as an unconditional `view_dict['L-CC.png']` access would.
paired_images = []
new_labels = []
for identifier, images in grouped_images.items():
    view_dict = dict(images)
    for side in ('L', 'R'):
        cc_path = view_dict.get(f'{side}-CC.png')
        mlo_path = view_dict.get(f'{side}-MLO.png')
        if cc_path is None or mlo_path is None:
            continue  # incomplete pair for this side

        pair_label = file_path_to_label[cc_path]
        if pair_label < 0:
            # Diagnostic for unexpected negative labels (none are expected
            # after the 0-4 rescoring of the metadata).
            print(view_dict)
            continue

        paired_images.append((cc_path, mlo_path))
        new_labels.append(pair_label)

        paired_images.append((mlo_path, cc_path))
        new_labels.append(pair_label)

In [None]:
# Custom dataset class for mammography images
class MammographyDataset(Dataset):
    """Dataset yielding two images and one label per sample.

    Args:
        image_pairs: Indexable collection of ``(path1, path2)`` tuples.
        labels: Indexable collection of labels, aligned with ``image_pairs``.
        transform: Optional callable applied independently to each image.
    """

    def __init__(self, image_pairs, labels, transform=None):
        self.image_pairs, self.labels, self.transform = image_pairs, labels, transform

    def __len__(self):
        """Number of image pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``((img1, img2), label)`` for the pair stored at ``idx``."""
        first_path, second_path = self.image_pairs[idx]
        # Both files are opened and converted to 3-channel RGB.
        views = [Image.open(path).convert('RGB') for path in (first_path, second_path)]

        if self.transform:
            views = [self.transform(view) for view in views]

        return (views[0], views[1]), self.labels[idx]

# Define transformations: fixed 256x256 resize, tensor conversion, then
# per-channel normalisation (these mean/std values match the commonly used
# ImageNet statistics).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Hold out 20% of the pairs for testing. `stratify` must be the array of class
# labels to balance on (a bare `True` is rejected by scikit-learn), so the
# pair labels are passed explicitly.
# NOTE(review): the split is made over individual pairs, and every CC/MLO pair
# exists in both orders, so the same images (and the same patient) can end up
# in both train and test. Consider splitting by patient identifier instead.
train_pairs, test_pairs, train_labels, test_labels = train_test_split(
    paired_images,
    new_labels,
    test_size=0.2,
    random_state=42,
    stratify=new_labels,
)

train_dataset = MammographyDataset(train_pairs, train_labels, transform=transform)
test_dataset = MammographyDataset(test_pairs, test_labels, transform=transform)

# Only the training loader is shuffled; the test order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
class ResNetFC(nn.Module):
    """Pretrained ResNet-50 backbone followed by three linear layers.

    The backbone's own ``fc`` layer is replaced with an identity, so the
    backbone outputs its pooled features, which are then mapped
    ``in_features -> 100 -> 50 -> num_classes``.

    NOTE(review): there is no activation between the three linear layers, so
    together they are equivalent to a single linear map — confirm this is
    intended.

    Args:
        num_classes: Size of the output logit vector (default 5).
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Remember the feature width before the original classifier is removed.
        self.fc_in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.base_model = backbone
        self.fc1 = nn.Linear(self.fc_in_features, 100)
        self.fc2 = nn.Linear(100, 50)
        self.fc3 = nn.Linear(50, num_classes)

    def forward(self, img):
        """Return class logits for a batch of images."""
        hidden = self.base_model(img)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = layer(hidden)
        return hidden


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Build the classifier and place it on the chosen device.
model = ResNetFC().to(device)
# Supervised classification loss on the logits.
criterion = nn.CrossEntropyLoss()
# Module-form KL divergence with per-batch averaging.
kl_div = KLDivLoss(reduction="batchmean")

In [None]:
# Hyperparameters for the run, collected in one place.
config = dict(
    architecture='feedforward',
    lr=0.0001,
    scheduler_factor=0.5,
    scheduler_patience=2,
    scheduler_min_lr=1e-6,
    epochs=40,
)
optimizer = optim.Adam(model.parameters(), lr=config['lr'])
# Lower the learning rate when the monitored value stops decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config['scheduler_factor'],
    patience=config['scheduler_patience'],
    min_lr=config['scheduler_min_lr'],
)

In [None]:
# Training loop
def _paired_view_loss(out1, out2, labels):
    """Cross-entropy on the first view plus a symmetric KL consistency term.

    Args:
        out1: Logits for the first image of each pair.
        out2: Logits for the second image of each pair.
        labels: Target class indices.

    Returns:
        ``CE(out1, labels) + (KL(p1 || p2) + KL(p2 || p1)) / 2`` where ``p1``
        and ``p2`` are the softmax distributions of ``out1`` and ``out2``.
    """
    ce_loss = criterion(out1, labels)
    kl_loss1 = F.kl_div(F.log_softmax(out2, dim=1), F.softmax(out1, dim=1), reduction='batchmean')
    kl_loss2 = F.kl_div(F.log_softmax(out1, dim=1), F.softmax(out2, dim=1), reduction='batchmean')
    return ce_loss + (kl_loss1 + kl_loss2) / 2


# Start from +inf so the first epoch always produces a checkpoint; the
# comparison below uses the loss summed over batches, which can exceed any
# fixed starting threshold on a large test set.
min_loss = float('inf')
for epoch in tqdm_notebook(range(config['epochs'])):
    model.train()
    running_loss = 0.0
    for (images1, images2), labels in train_dataloader:
        images1, images2, labels = images1.to(device), images2.to(device), labels.to(device)
        optimizer.zero_grad()
        total_loss = _paired_view_loss(model(images1), model(images2), labels)
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    print(f"Epoch [{epoch + 1}/{config['epochs']}], Train Loss: {running_loss / len(train_dataloader):.6f}")

    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for (images1, images2), labels in test_dataloader:
            images1, images2, labels = images1.to(device), images2.to(device), labels.to(device)
            total_loss = _paired_view_loss(model(images1), model(images2), labels)
            test_loss += total_loss.item()
    # The scheduler monitors the summed test loss of this epoch.
    scheduler.step(test_loss)

    print(f"Test Loss: {test_loss / len(test_dataloader):.6f}")

    if test_loss < min_loss:
        # Only the model weights are persisted, which is what the reload cell
        # passes to `load_state_dict`.
        weight_path = f"<your_dir>/resnet50_pretrained_paired_aux_{epoch}.pth"
        torch.save(model.state_dict(), weight_path)
        min_loss = test_loss

In [None]:
# Rebuild the network and restore the selected checkpoint.
model = ResNetFC()
weight_path = '<your_best_weight_path>'
# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(weight_path, map_location=device))
model = model.to(device)

In [None]:
# Evaluate on the test loader using only the first image of each pair, and
# collect predictions and labels for the confusion matrix and accuracy cells.
model.eval()
test_loss = 0.0
cm_labels, cm_preds = [], []
with torch.no_grad():
    for images, labels in test_dataloader:
        image1 = images[0].to(device)
        labels = labels.to(device)
        outputs = model(image1)
        # Predicted class is the index of the largest logit.
        preds = outputs.argmax(dim=1)
        cm_preds.extend(preds.cpu().numpy())
        cm_labels.extend(labels.cpu().numpy())
        loss = criterion(outputs, labels)
        test_loss += loss.item()
print(f"Test Loss: {test_loss / len(test_dataloader):.6f}")

In [None]:
def denormalize(tensor, mean, std):
    """Undo a per-channel normalisation in place.

    Each slice along the first dimension of ``tensor`` is multiplied by the
    matching ``std`` entry and then shifted by the matching ``mean`` entry.

    Args:
        tensor: Tensor whose first dimension indexes channels; it is modified
            in place.
        mean: Per-channel offsets.
        std: Per-channel scales.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    for channel, (offset, scale) in zip(tensor, zip(mean, std)):
        channel.mul_(scale)
        channel.add_(offset)
    return tensor

In [None]:
def grad_plot(image_tensors, single_label):
    """Plot two views of a sample next to their Grad-CAM overlays.

    Uses the notebook-level ``model`` and ``denormalize``. The figure is a 2x2
    grid: the two de-normalised input images on top, and the Grad-CAM overlays
    (computed for class ``single_label`` on the last block of
    ``model.base_model.layer4``) underneath, each titled with the model's
    predicted class.

    Args:
        image_tensors: Pair of normalised CHW image tensors (no batch dim).
        single_label: Zero-based class index; titles display it as ``+1``.

    NOTE(review): the titles call the first view "CC" and the second "MLO",
    but the pairing cell also emits pairs in (MLO, CC) order, so the titles
    can be swapped for those samples.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Run on whatever device the model currently lives on, rather than
    # assuming CUDA is available.
    model_device = next(model.parameters()).device
    batched_inputs = [image_tensors[i].unsqueeze(0).to(model_device) for i in (0, 1)]

    # Grad-CAM heatmaps for the requested class, one per view.
    target_layers = [model.base_model.layer4[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(single_label)]
    heatmaps = [cam(input_tensor=batch, targets=targets)[0, :] for batch in batched_inputs]

    # Undo the normalisation on a copy, go CHW -> HWC, and convert to uint8.
    # Clipping to [0, 1] first keeps values from wrapping around in the cast.
    rgb_images = []
    for i in (0, 1):
        restored = denormalize(image_tensors[i].clone(), mean, std)
        hwc = restored.permute(1, 2, 0).cpu().numpy()
        rgb_images.append((np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8))
    overlays = [
        show_cam_on_image(rgb / 255.0, heatmap, use_rgb=True)
        for rgb, heatmap in zip(rgb_images, heatmaps)
    ]

    # Predicted class for each view.
    model.eval()
    with torch.no_grad():
        predicted_labels = [torch.max(model(batch), 1)[1].item() for batch in batched_inputs]

    panels = [
        (rgb_images[0], f"Original CC: BIRADS {single_label+1}"),
        (rgb_images[1], f"Original MLO: BIRADS {single_label+1}"),
        (overlays[0], f"Grad-CAM CC: BIRADS {predicted_labels[0]+1}"),
        (overlays[1], f"Grad-CAM MLO: BIRADS {predicted_labels[1]+1}"),
    ]

    plt.figure(figsize=(6, 6))
    for position, (picture, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis('off')

    plt.show()

In [None]:
# Plot up to 40 images, but feel free to adjust the number as needed.
count = 0
for count, (img, label) in enumerate(test_dataset, start=1):
    grad_plot(img, label)
    # Stop once the 40th sample has been drawn.
    if count >= 40:
        break

In [None]:
# Plot the confusion matrix using seaborn.
# The class indices are passed explicitly so the matrix is always 5x5 and
# lines up with the five tick labels, even if a class is absent from the
# test labels and predictions.
class_indices = list(range(5))
cm = confusion_matrix(cm_labels, cm_preds, labels=class_indices)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=['1', '2', '3', '4', '5'],
            yticklabels=['1', '2', '3', '4', '5'])

plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

In [None]:
import re

# The input text: paste the training log printed by the training cell here.
text = """ Your model output, for example :
Epoch [1/40], Train Loss: 1.367641, learning rate : 0.0001
Test Loss: 1.330273
Epoch [2/40], Train Loss: 0.992598, learning rate : 0.0001
Test Loss: 1.575209
Epoch [3/40], Train Loss: 0.595229, learning rate : 0.0001
Test Loss: 1.421532
"""

# Regular expressions to extract epochs, train losses, and validation losses
epoch_pattern = re.compile(r"Epoch \[(\d+)/\d+\]")
train_loss_pattern = re.compile(r"Train Loss: ([\d\.]+)")
val_loss_pattern = re.compile(r"Test Loss: ([\d\.]+)")

# Extracting the data
epochs = [int(epoch) for epoch in epoch_pattern.findall(text)]
train_losses = [float(train_loss) for train_loss in train_loss_pattern.findall(text)]
val_losses = [float(val_loss) for val_loss in val_loss_pattern.findall(text)]

# Output the extracted data
print("Epochs:", epochs)
print("Train Losses:", train_losses)
print("Validation Losses:", val_losses)

# Both curves share the epoch axis, so the three lists must be equally long.
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
def calculate_accuracy(all_preds, all_labels):
    """Return the fraction of predictions that equal their label.

    Args:
        all_preds: Sequence of predicted classes.
        all_labels: Sequence of true classes, aligned with ``all_preds``.

    Returns:
        Number of matching positions divided by ``len(all_labels)``, or
        ``0.0`` when ``all_labels`` is empty (instead of dividing by zero).
    """
    total_samples = len(all_labels)
    if total_samples == 0:
        return 0.0
    correct_predictions = sum(1 for pred, label in zip(all_preds, all_labels) if pred == label)
    return correct_predictions / total_samples

# Overall accuracy of the predictions collected in the evaluation cell
# (`cm_preds` / `cm_labels`, which are built from the first image of each pair).
accuracy = calculate_accuracy(cm_preds, cm_labels)
print("Accuracy:", accuracy)
