# In [None]:
## 1. Setup and Imports

# Import required modules
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import json
import copy
import time

# Pick the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


## 2. Data Preparation


# Dataset root folder and one sub-folder per split
data_dir = './COVID-19_Radiography_Dataset/Dataset'
train_dir = f'{data_dir}/train'
valid_dir = f'{data_dir}/validation'
test_dir = f'{data_dir}/test'

# Per-channel mean / std used to normalize every split (presumably the
# ImageNet statistics matching the pretrained backbone -- confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Deterministic pipeline shared by the validation and test splits."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


# Training adds random augmentation (rotation, random crop, flips); the
# evaluation splits only resize and center-crop to the same 224x224 input.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
    valid=_eval_transform(),
    test=_eval_transform(),
)

# Load datasets
# Each split key is mapped to the directory variable defined for it rather
# than being joined onto data_dir: the key 'valid' does not match the
# 'validation' folder used by valid_dir, and `os` is not imported in this file.
split_dirs = {'train': train_dir, 'valid': valid_dir, 'test': test_dir}
image_datasets = {
    split: datasets.ImageFolder(split_dirs[split], data_transforms[split])
    for split in ['train', 'valid', 'test']
}

# Create dataloaders
batch_size = 64
# Only the training split is shuffled; validation and test are read in a
# fixed order so evaluation runs are reproducible.
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                   shuffle=(x == 'train'), num_workers=4)
    for x in ['train', 'valid', 'test']
}

# Number of samples in each split (used to average loss / accuracy per epoch)
dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}
print("Dataset sizes:", dataset_sizes)

# Class labels as discovered by ImageFolder on the training split
class_names = image_datasets['train'].classes
print("Classes:", class_names)


## 3. Model Architecture


def build_classifier(num_in_features, hidden_layers, num_out_features):
    """
    Build a custom classifier for the DenseNet model.

    Args:
        num_in_features: Width of the feature vector fed into the classifier.
        hidden_layers: Sequence of hidden layer widths, or None / empty for a
            single linear layer mapping features directly to the outputs.
        num_out_features: Number of output logits (classes).

    Returns:
        An ``nn.Sequential``. Without hidden layers it contains only ``fc0``.
        Otherwise it is ``fc0 -> relu0 -> drop0`` followed by
        ``fc{i} -> relu{i} -> drop{i}`` for every further hidden layer and a
        final ``output`` linear layer.
    """
    classifier = nn.Sequential()

    # No hidden layers: a single linear head sized by num_out_features
    # (previously hard-coded to 4 regardless of the argument).
    if not hidden_layers:
        classifier.add_module('fc0', nn.Linear(num_in_features, num_out_features))
        return classifier

    # First hidden layer uses a slightly stronger dropout than the rest.
    classifier.add_module('fc0', nn.Linear(num_in_features, hidden_layers[0]))
    classifier.add_module('relu0', nn.ReLU())
    classifier.add_module('drop0', nn.Dropout(.6))

    # Remaining hidden layers. Names are unique per layer: re-using an
    # existing name in add_module replaces that module in its old position,
    # which would leave the new linear layer without an activation after it.
    layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
    for i, (h1, h2) in enumerate(layer_sizes, start=1):
        classifier.add_module(f'fc{i}', nn.Linear(h1, h2))
        classifier.add_module(f'relu{i}', nn.ReLU())
        classifier.add_module(f'drop{i}', nn.Dropout(.5))

    classifier.add_module('output', nn.Linear(hidden_layers[-1], num_out_features))
    return classifier

# Initialize model
model = models.densenet201(pretrained=True)
# Read the feature width from the stock classifier head instead of
# hard-coding it, so it stays correct if the backbone is swapped.
num_in_features = model.classifier.in_features

# Freeze parameters: the pretrained backbone is used as a fixed feature extractor
for param in model.parameters():
    param.requires_grad = False

# Build and set classifier (created after freezing, so its weights stay trainable)
classifier = build_classifier(num_in_features, hidden_layers=None, num_out_features=4)
model.classifier = classifier

# Set up loss and optimizer
criterion = nn.CrossEntropyLoss()
# Only the new classifier head is trainable, so only its parameters are
# handed to the optimizer.
optimizer = optim.Adadelta(model.classifier.parameters(), lr=0.1)
# Multiply the learning rate by 0.1 every 5 scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


## 4. Training Function


def train_model(model, criterion, optimizer, scheduler, num_epochs=20):
    """Train ``model`` and return it with its best validation weights loaded.

    Every epoch runs a 'train' pass followed by a 'valid' pass. Batches come
    from the module-level ``dataloaders`` dict, are moved to the module-level
    ``device``, and per-epoch averages are computed with the module-level
    ``dataset_sizes``.

    Args:
        model: Network to optimize; expected to already be on ``device``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch after the
            training pass; pass None to keep the learning rate fixed.
        num_epochs: Number of train/valid cycles to run.

    Returns:
        ``model`` with the state dict from the epoch that achieved the
        highest validation accuracy.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only tracked while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean, so weight it by the batch size
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            # Advance the learning-rate schedule once per epoch, after the
            # optimizer updates of the training pass. The scheduler used to
            # be accepted but never stepped, so the rate never decayed.
            if is_training and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights whenever validation accuracy improves
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model


## 5. Train the Model


# Place the network on the selected device, then run the 20-epoch training
# loop; the returned model carries the best validation weights.
model = train_model(model.to(device), criterion, optimizer, scheduler, num_epochs=20)


## 6. Evaluation


def evaluate_model():
    """Print the accuracy of the module-level ``model`` on the test split.

    Iterates ``dataloaders['test']`` on the module-level ``device`` and
    reports correct predictions divided by the number of samples seen.
    Counting samples (rather than averaging per-batch accuracies) keeps the
    result exact when the last batch is smaller than the others.
    """
    model.eval()
    correct = 0
    total = 0

    # Inference only: no gradient bookkeeping needed
    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.max(1)[1]
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty test split has no defined accuracy
    accuracy = correct / total if total else float('nan')
    print(f"Test accuracy: {accuracy:.3f}")


evaluate_model()


## 7. Save and Load Model


def save_checkpoint(filepath='model.pth', epochs=20):
    """Serialize the module-level model and its training state to disk.

    Args:
        filepath: Destination file; defaults to the original 'model.pth'.
        epochs: Number of training epochs recorded in the checkpoint. The
            default matches the ``num_epochs=20`` used for training in this
            file (the previously referenced global ``epochs`` was never
            defined).

    The checkpoint keeps the keys read back by ``load_checkpoint``
    ('model', 'classifier', 'state_dict', 'class_to_idx').
    """
    model.class_to_idx = image_datasets['train'].class_to_idx
    checkpoint = {
        'input_size': 1920,
        'output_size': 4,
        'epochs': epochs,
        'batch_size': batch_size,
        # NOTE(review): a fresh pretrained backbone is instantiated and
        # pickled whole here; the trained weights live in 'state_dict'.
        'model': models.densenet201(pretrained=True),
        'classifier': classifier,
        'scheduler': scheduler,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx
    }
    torch.save(checkpoint, filepath)

def load_checkpoint(filepath):
    """Rebuild a frozen model from a checkpoint written by ``save_checkpoint``.

    Args:
        filepath: Path of the checkpoint file.

    Returns:
        ``(model, class_to_idx)``: the restored model with every parameter
        frozen, loaded onto the CPU (move it with ``.to(device)`` as needed),
        and the class-name -> index mapping stored in the checkpoint.

    SECURITY: the checkpoint contains pickled module objects, so loading it
    executes pickle code. Only load files from a trusted source.
    """
    import inspect  # local import: only needed for the compatibility probe

    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine.
    load_kwargs = {'map_location': 'cpu'}
    # Newer torch.load versions accept (and may default to) weights_only=True,
    # which rejects the pickled modules stored under 'model' / 'classifier'.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(filepath, **load_kwargs)

    model = checkpoint['model']
    model.classifier = checkpoint['classifier']
    # Load the saved weights. Assigning to model.state_dict would replace the
    # method with a dict and leave the weights untouched.
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']

    # The restored model is meant for inference: freeze everything
    for param in model.parameters():
        param.requires_grad = False

    return model, checkpoint['class_to_idx']


## 8. Inference Functions


def process_image(image):
    """Turn a PIL image into a normalized tensor ready for the model.

    The image is converted to RGB, resized to 256, center-cropped to
    224x224, converted to a tensor and normalized per channel.
    """
    rgb_image = image.convert('RGB')
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(rgb_image)

def predict(image_path, model, topk=4):
    """Predict the class of an image using a trained deep learning model.

    Args:
        image_path: Path of the image file to classify.
        model: Trained network; expected to live on the module-level ``device``.
        topk: Number of highest-probability classes to return (capped at the
            number of classes the model outputs).

    Returns:
        ``(probabilities, class_indices)``: two lists of equal length, sorted
        by decreasing probability. Both are always lists, even for ``topk=1``.
        The indices are positions in the model output, not class names.
    """
    # Close the file handle as soon as the pixels have been converted
    with Image.open(image_path) as img:
        batch = process_image(img).unsqueeze(0)  # add the batch dimension

    model.eval()
    # Inference only: skip gradient tracking
    with torch.no_grad():
        logits = model(batch.to(device))
        ps = F.softmax(logits, dim=1).cpu()

    k = min(topk, ps.size(1))
    top_probs, top_indices = ps.topk(k)

    return top_probs.squeeze(0).tolist(), top_indices.squeeze(0).tolist()

def view_classify(img_path, prob, classes):
    """View an image and its predicted classes.

    Args:
        img_path: Path of the image; the name of its parent folder is used
            as the title of the image panel.
        prob: Sequence of class probabilities, one bar per entry.
        classes: Labels for the bars, in the same order as ``prob``.
    """
    from pathlib import Path  # local import: not part of the file's imports

    image = Image.open(img_path)

    fig, (ax1, ax2) = plt.subplots(figsize=(6, 10), ncols=1, nrows=2)

    # Top panel: the image, titled with its parent folder name
    disease_name = Path(img_path).parent.name
    ax1.set_title(disease_name)
    ax1.imshow(image)
    ax1.axis('off')

    # Bottom panel: horizontal bars, first entry of `prob` at the top
    y_pos = np.arange(len(prob))
    ax2.barh(y_pos, prob, align='center')
    ax2.set_yticks(y_pos)
    # Label the bars with the `classes` argument (the previously referenced
    # name `disease_classes` is not defined anywhere).
    ax2.set_yticklabels(classes)
    ax2.invert_yaxis()
    ax2.set_title('Class Probability')


## 9. Example Usage


# Load and process an image
img_path = './test/COVID/example.jpg'
probs, classes = predict(img_path, model.to(device))

# predict() returns class indices; translate them to the folder-derived
# class names collected from the training set (the previously used
# `cat_to_name` mapping is not defined in this file).
labels = [class_names[int(cls)] for cls in classes]

# Display results
for prob, label in zip(probs, labels):
    print(f"Probability of {label} is {prob*100:.2f}%")

# Visualize results
view_classify(img_path, probs, labels)

