In [None]:
from torchvision import transforms
from PIL import Image 
import numpy as np
import os
import json
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, SequentialSampler, WeightedRandomSampler
#from torch.optim.lr_scheduler import CosineAnnealingLR

In [None]:
class DataProcessor:
    """Collect image folders that carry exactly one disease label of interest.

    The dataset is expected at ``<folder_path>/imgs/<folder_name>/`` where each
    folder may contain a ``detection.json`` (a list of dicts whose keys are
    label names) and a ``source.jpg`` image.

    Attributes:
        folder_path: Root directory of the dataset.
        disease_list: Label names that count as diseases.
        folders_with_diseases_labels: Maps folder name -> its single disease label.
        folder_name_with_diseases: Folder names, in the same order as the mapping.
        label_counts: ``Counter`` of label -> folder count, set by ``delete_folders``.
    """

    def __init__(self, folder_path, disease_list):
        self.folder_path = folder_path
        self.disease_list = disease_list
        self.folders_with_diseases_labels = {}
        self.folder_name_with_diseases = []
        self.label_counts = None

    def read_data(self):
        """Scan the ``imgs`` directory and record every single-disease folder.

        Folders without a ``detection.json`` are skipped. Folders whose
        detections contain zero, or more than one, label from ``disease_list``
        are ignored so that every kept folder has one unambiguous class.
        Calling this method again rebuilds the results from scratch.
        """
        # Start clean so that re-running does not duplicate list entries.
        self.folders_with_diseases_labels = {}
        self.folder_name_with_diseases = []

        for root, dirs, files in os.walk(os.path.join(self.folder_path, 'imgs')):
            # os.walk yields directories in arbitrary (filesystem) order; sorting
            # in place keeps the folder order, and therefore any seeded
            # train/validation split built on it, reproducible across machines.
            dirs.sort()
            for folder_name in dirs:
                detection_file_path = os.path.join(root, folder_name, 'detection.json')
                if not os.path.isfile(detection_file_path):
                    # Not an annotated image folder; nothing to read.
                    continue

                with open(detection_file_path, 'r') as detection_file:
                    detection_data = json.load(detection_file)

                disease_labels = [
                    label
                    for item in detection_data
                    for label in item.keys()
                    if label in self.disease_list
                ]

                # Keep only folders with exactly one disease label.
                if len(disease_labels) == 1:
                    self.folders_with_diseases_labels[folder_name] = disease_labels[0]
                    self.folder_name_with_diseases.append(folder_name)

    def delete_folders(self, max_count=3):
        """Drop every folder whose label occurs ``max_count`` times or fewer.

        ``label_counts`` is set to the label frequencies as they were *before*
        the rare folders were removed.

        Args:
            max_count: Labels with at most this many folders are removed.
        """
        self.label_counts = Counter(self.folders_with_diseases_labels.values())

        folders_to_delete = {
            folder_name
            for folder_name, label in self.folders_with_diseases_labels.items()
            if self.label_counts[label] <= max_count
        }

        for folder_name in folders_to_delete:
            del self.folders_with_diseases_labels[folder_name]
        # Single filtering pass instead of repeated list.remove calls.
        self.folder_name_with_diseases = [
            folder_name
            for folder_name in self.folder_name_with_diseases
            if folder_name not in folders_to_delete
        ]

    def get_training_data(self):
        """Return ``(path to source.jpg, label)`` for every recorded folder."""
        return [
            (os.path.join(self.folder_path, 'imgs', folder_name, 'source.jpg'), label)
            for folder_name, label in self.folders_with_diseases_labels.items()
        ]
    


# Dataset root, relative to the notebook's working directory.
folder_path = 'Slake1.0'

# we use 5 diseases related to chest images (heart and lungs)
disease_list = [
    'Pneumothorax',
    'Pneumonia',
    'Effusion',
    'Cardiomegaly',
    'Lung Cancer',
]

data_processor = DataProcessor(folder_path=folder_path, disease_list=disease_list)
data_processor.read_data()

# Optional steps, currently disabled:
#data_processor.delete_folders()
#data = data_processor.get_training_data()


In [None]:
class CustomDataset(Dataset):
    """Image dataset over the folders collected by a ``DataProcessor``.

    Each item is ``(image, class_index)`` where the image is read from
    ``<folder_path>/imgs/<folder_name>/source.jpg``.

    Args:
        data_processor: Object exposing ``folder_path`` and
            ``folders_with_diseases_labels`` (folder name -> label).
        folder_names: Folder names that make up this split.
        transform: Optional callable applied to the (augmented) PIL image.
        is_train: When true, random augmentation is applied before ``transform``.

    Attributes:
        label_to_index: Label -> integer class index (labels in sorted order).
        index_to_label: Inverse of ``label_to_index``.
        class_counts: ``Counter`` of class index -> number of samples in this split.
    """

    # A class with fewer samples than this counts as a minority class.
    MINORITY_THRESHOLD = 20

    def __init__(self, data_processor, folder_names, transform=None, is_train=True):
        self.data_processor = data_processor
        self.folder_names = folder_names
        self.transform = transform
        self.is_train = is_train

        # map labels to index
        # Sorting makes the mapping deterministic: iterating a bare set of
        # strings can yield a different order in every interpreter run, which
        # would silently change what each class index means between training
        # and inference.
        unique_labels = sorted(set(data_processor.folders_with_diseases_labels.values()))
        self.label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
        self.index_to_label = {idx: label for label, idx in self.label_to_index.items()}

        # Filled lazily by the ``class_counts`` property.
        self._class_counts = None

    @property
    def class_counts(self):
        """Per-class sample counts for this split (computed once, then cached)."""
        if self._class_counts is None:
            self._class_counts = Counter(
                self._label_index(folder_name) for folder_name in self.folder_names
            )
        return self._class_counts

    def _label_index(self, folder_name):
        """Return the integer class index of ``folder_name``."""
        label = self.data_processor.folders_with_diseases_labels[folder_name]
        return self.label_to_index[label]

    def _augment(self, image):
        """Apply the random training-time augmentations to a PIL image."""
        # Random rotation
        angle = np.random.uniform(-10, 10)
        image = image.rotate(angle)

        # Random x / y shift, applied in one affine call.
        shift_x = np.random.uniform(-10, 10)
        shift_y = np.random.uniform(-10, 10)
        image = transforms.functional.affine(
            image, angle=0, translate=(shift_x, shift_y), scale=1, shear=0
        )

        # Random crop
        if np.random.rand() > 0.5:
            width, height = image.size
            # Cropping a 256x256 window out of a smaller image raises, so the
            # crop is skipped for images that are too small.
            if width >= 256 and height >= 256:
                i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(256, 256))
                image = transforms.functional.crop(image, i, j, h, w)

        # Random contrast reduction
        if np.random.rand() > 0.5:
            contrast_factor = np.random.uniform(0.5, 0.8)
            image = transforms.functional.adjust_contrast(image, contrast_factor)

        return image

    def __len__(self):
        return len(self.folder_names)

    def __getitem__(self, idx):
        """Load, optionally augment, and transform the image at position ``idx``.

        Returns:
            ``(image, class_index)``; the image is always a tensor.
        """
        folder_name = self.folder_names[idx]

        # read images 'source.jpg' in each folder
        image_path = os.path.join(
            self.data_processor.folder_path, 'imgs', folder_name, 'source.jpg'
        )
        # The context manager closes the underlying file once converted.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')

        if self.is_train:
            image = self._augment(image)

        if self.transform:
            image = self.transform(image)

        if not torch.is_tensor(image):
            image = transforms.ToTensor()(image)

        return image, self._label_index(folder_name)

    def is_minority_class(self, label_index):
        """Return whether class ``label_index`` has fewer than ``MINORITY_THRESHOLD`` samples."""
        return self.class_counts[label_index] < self.MINORITY_THRESHOLD

    def print_class_mapping(self):
        """Print every class index together with its label."""
        print("Class Index to Label Mapping:")
        for label, index in self.label_to_index.items():
            print(f"Index {index}: {label}")

    def print_class_samples(self):
        """Print, for a training split, the sample count of each minority class.

        Only samples whose class is a minority class are counted; nothing is
        printed per class when ``is_train`` is false.
        """
        augmented_counts = Counter()

        for folder_name in self.folder_names:
            label_index = self._label_index(folder_name)
            if self.is_train and self.is_minority_class(label_index):
                augmented_counts[label_index] += 1

        print("Number of samples after data augmentation for each class:")
        for index, count in augmented_counts.items():
            print(f"Class '{self.index_to_label[index]}': {count} samples")

    def print_class_labels_count(self):
        """Print how many samples each class has in this split."""
        label_counts = Counter(
            self._label_index(folder_name) for folder_name in self.folder_names
        )

        print("Number of labels for each class in the dataset:")
        for index, count in label_counts.items():
            print(f"Class '{self.index_to_label[index]}': {count} labels")


In [None]:
# Preprocessing shared by the training and validation datasets.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Look every label up by folder name so the stratification labels are aligned
# with the folder list by construction rather than by dict/list ordering.
data_labels = [
    data_processor.folders_with_diseases_labels[folder_name]
    for folder_name in data_processor.folder_name_with_diseases
]
# 80% training and 20% validation sets
training_data, validation_data = train_test_split(data_processor.folder_name_with_diseases, test_size=0.2,
                                 random_state=42, shuffle = True, stratify=np.array(data_labels))


# CustomDataset for both training and validation
train_dataset = CustomDataset(data_processor, folder_names=training_data, transform=transform, is_train=True)
train_dataset.print_class_mapping()
#train_dataset.print_class_samples()
#train_dataset.print_class_labels_count()

val_dataset = CustomDataset(data_processor, folder_names=validation_data, transform=transform, is_train=False)
#val_dataset.print_class_labels_count()


# Class index of every training sample, read straight from the label mapping.
# Iterating the dataset instead would load and augment every image just to
# obtain an integer.
train_labels = [
    train_dataset.label_to_index[data_processor.folders_with_diseases_labels[folder_name]]
    for folder_name in training_data
]

# Weight each sample by the inverse frequency of its class so the sampler draws
# all classes about equally often.
train_label_counts = dict(Counter(train_labels))
train_weight_samples = [1/train_label_counts[x] for x in train_labels]

train_sampler = WeightedRandomSampler(train_weight_samples, num_samples=len(train_labels), replacement=True)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=8)

val_sampler = SequentialSampler(val_dataset)
val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=8)


In [None]:
class FineTunedAlexNet(nn.Module):
    """AlexNet feature extractor followed by a newly created classification head.

    The convolutional ``features`` and ``avgpool`` modules are taken from
    torchvision's AlexNet; the two hidden fully connected layers and the output
    layer ``fc`` are constructed here.

    Args:
        num_classes: Number of output logits.
        dropout_prob: Dropout probability used before each hidden linear layer.
    """

    def __init__(self, num_classes, dropout_prob=0.5):
        super().__init__()

        backbone = models.alexnet(pretrained=True)
        self.features = backbone.features
        self.avgpool = backbone.avgpool

        hidden_dim = 4096
        head_layers = [
            nn.Dropout(p=dropout_prob),
            nn.Linear(256 * 6 * 6, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything except the batch dimension.
        flattened = pooled.view(pooled.size(0), -1)
        hidden = self.classifier(flattened)
        return self.fc(hidden)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes= 5
# FineTunedAlexNet is the architecture the inference cell loads 'cnn_model.pth'
# into, so the checkpoint written during training must come from the same class.
model = FineTunedAlexNet(num_classes=num_classes).to(device)

# Hand-recorded number of images per disease label.
label_counts = {
    'Cardiomegaly': 33,
    'Pneumonia': 28,
    'Pneumothorax': 15,
    'Lung Cancer': 18,
    'Effusion': 17,
    }

# inverse class frequencies
total_samples = sum(label_counts.values())
class_weights = {label: total_samples / (len(label_counts) * count) for label, count in label_counts.items()}

# to normalize the weights
total_weights = sum(class_weights.values())
class_weights = {label: weight / total_weights for label, weight in class_weights.items()}

# Convert class_weights to a tensor.
# The loss indexes the weight tensor by class index, so each weight is placed
# at the index the dataset assigns to its label. The key order of
# `label_counts` is unrelated to those indices.
weights_by_index = [0.0] * num_classes
for label, index in train_dataset.label_to_index.items():
    weights_by_index[index] = class_weights[label]
class_weights_tensor = torch.tensor(weights_by_index).to(device)

# Alternative criteria that were tried:
# weighted CrossEntropyLoss
#criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

#criterion = nn.CrossEntropyLoss()

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    For every sample the cross-entropy ``ce`` is scaled by ``(1 - pt) ** gamma``
    where ``pt = exp(-ce)`` is the probability the model assigns to the true
    class, so well-classified samples contribute less. With ``gamma=0`` the
    result equals ``F.cross_entropy`` (weighted, if class weights are given).

    Args:
        class_weights: Optional 1-D tensor with one weight per class index.
        gamma: Focusing exponent; larger values down-weight easy samples more.
    """

    def __init__(self, class_weights=None, gamma=2):
        super(FocalLoss, self).__init__()
        self.class_weights = class_weights
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the scalar focal loss for logits ``inputs`` and class-index ``targets``."""
        # Per-sample, unweighted cross-entropy. The focal factor has to be
        # computed per sample and from the unweighted value, otherwise
        # exp(-ce) is not the true-class probability.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.class_weights is None:
            return focal_loss.mean()

        # Weighted mean, normalised by the total weight of the batch's targets
        # (the same convention F.cross_entropy uses for weighted 'mean').
        sample_weights = self.class_weights.to(inputs.device)[targets]
        return (sample_weights * focal_loss).sum() / sample_weights.sum()

# Focal loss using the per-class weights computed above.
criterion = FocalLoss(class_weights=class_weights_tensor)

# L2 regularisation strength passed to the optimizer.
weight_decay = 5e-2

# Plain SGD is used; Adam (lr=1e-3) and RMSprop (lr=0.03) were tried earlier
# with the same weight decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.0045,
    weight_decay=weight_decay,
)

def calculate_accuracy(outputs, targets):
    """Return the fraction of rows in ``outputs`` whose highest-scoring class equals ``targets``."""
    predicted = outputs.max(dim=1).indices
    y_true = targets.cpu().numpy()
    y_pred = predicted.cpu().numpy()
    return accuracy_score(y_true, y_pred)

def train(model, train_loader, optimizer, criterion, device=None):
    """Run one training epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``(outputs, targets) -> scalar loss``.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used, so the function does not
            depend on a notebook-global ``device``.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the loss averaged over all samples
        and the fraction of correctly classified samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the epoch average.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == targets).sum().item()

    epoch_loss = total_loss / total_samples
    epoch_accuracy = total_correct / total_samples

    return epoch_loss, epoch_accuracy

def validate(model, val_loader, criterion, device=None):
    """Evaluate the model on a validation loader without updating it.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``(outputs, targets) -> scalar loss``.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used, so the function does not
            depend on a notebook-global ``device``.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the loss averaged over all samples
        and the fraction of correctly classified samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Weight the batch loss by the batch size so a smaller final
            # batch does not skew the epoch average.
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == targets).sum().item()

    epoch_loss = total_loss / total_samples
    epoch_accuracy = total_correct / total_samples

    return epoch_loss, epoch_accuracy

# Per-epoch history, consumed by the plotting cell.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0
best_epoch = None
num_epochs = 20
# One list of validation predictions / targets per epoch.
all_val_predictions = []
all_val_targets = []

for epoch in range(num_epochs):

    train_loss, train_accuracy = train(model, train_loader, optimizer, criterion)
    val_loss, val_accuracy = validate(model, val_loader, criterion)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Collect this epoch's validation predictions for the confusion matrix
    model.eval()
    all_targets = []
    all_predictions = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)

            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    all_val_targets.append(all_targets)
    all_val_predictions.append(all_predictions)

    # Save the model if the current validation accuracy is better than the previous best.
    # The first epoch is always saved so a checkpoint exists even if the
    # validation accuracy never rises above 0.
    if best_epoch is None or val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch
        torch.save(model.state_dict(), 'cnn_model.pth')

# Confusion matrix of the best epoch, i.e. of the checkpoint that was saved.
# Pooling the predictions of every epoch would mix early, barely trained
# predictions into the counts.
class_names = sorted(train_dataset.label_to_index, key=train_dataset.label_to_index.get)
cm = confusion_matrix(
    np.asarray(all_val_targets[best_epoch]),
    np.asarray(all_val_predictions[best_epoch]),
    # Fix the label set so the matrix keeps one row/column per class even if
    # a class is absent from the validation targets and predictions.
    labels=list(range(len(class_names))),
)
print(cm)

# Display confusion matrix using seaborn
plt.figure(figsize=(10, 8))
# Tick labels follow the class-index order used by the datasets.
sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
# save files
save_path = '/storage/homefs/zh21i037/'
if not os.path.isdir(save_path):
    # Fall back to the working directory when the original output directory
    # does not exist on this machine.
    save_path = os.getcwd()
filename = 'confusion matrix.png'
save_filename = os.path.join(save_path, filename)

plt.savefig(save_filename)
plt.close()


In [None]:
# Loss (left) and accuracy (right) per epoch, for training and validation.
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(10, 5))

# training and validation loss
loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# training and validation accuracy
accuracy_ax.plot(train_accuracies, label='Train Accuracy')
accuracy_ax.plot(val_accuracies, label='Validation Accuracy')
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.set_title('Training and Validation Accuracy')
accuracy_ax.legend()

fig.tight_layout()

# to save files
save_path = '/storage/homefs/zh21i037/'
if not os.path.isdir(save_path):
    # Fall back to the working directory when the original output directory
    # does not exist on this machine.
    save_path = os.getcwd()
filename = 'losses and accuracies.png'
save_filename = os.path.join(save_path, filename)

fig.savefig(save_filename)
plt.close(fig)


In [None]:
# Get the probability distribution over the disease classes for a single image
class FineTunedAlexNetClassifier:
    """Load a trained ``FineTunedAlexNet`` checkpoint and classify single images.

    Args:
        num_classes: Number of classes the checkpoint was trained with.
        model_path: Path to the saved ``state_dict``.
        image_size: ``(height, width)`` images are resized to. Defaults to the
            256x256 size used by the training transform so inference sees
            inputs at the same scale as training.
        class_names: Optional names, one per class index, used in the output
            instead of the generic ``"Class i"`` labels.

    Raises:
        ValueError: If ``class_names`` is given but its length differs from
            ``num_classes``.
    """

    def __init__(self, num_classes, model_path, image_size=(256, 256), class_names=None):
        if class_names is not None and len(class_names) != num_classes:
            raise ValueError(
                f"expected {num_classes} class names, got {len(class_names)}"
            )

        self.num_classes = num_classes
        self.class_names = class_names
        self.model = FineTunedAlexNet(num_classes)
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine; the model itself stays on the CPU here.
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()

        # preprocess the input image
        self.preprocess = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def classify_image(self, image_path):
        """Return ``[(class_label, probability), ...]`` for the image at ``image_path``.

        Each probability is a 0-dim tensor; the probabilities sum to 1.
        """
        # The context manager closes the underlying file once converted.
        with Image.open(image_path) as source_image:
            input_image = source_image.convert('RGB')
        input_tensor = self.preprocess(input_image)
        # add batch dimension to the image
        input_batch = input_tensor.unsqueeze(0)

        # predictions
        with torch.no_grad():
            output = self.model(input_batch)

        # convert the class output to probability distribution using softmax
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # class labels and the corresponding probabilities
        if self.class_names is not None:
            class_labels = list(self.class_names)
        else:
            class_labels = [f"Class {i}" for i in range(self.num_classes)]
        class_probabilities = list(zip(class_labels, probabilities))

        return class_probabilities


# Checkpoint written by the training loop and the image to classify.
model_path = 'cnn_model.pth'
image_path = 'Slake1.0/imgs/xmlab333/source.jpg'

classifier = FineTunedAlexNetClassifier(num_classes=5, model_path=model_path)
predictions = classifier.classify_image(image_path)

# One line per class: label followed by its probability.
print("Probability Distribution of Different Classes")
for name, prob in predictions:
    print(f"{name}: {prob.item()}")
