In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import medmnist
from medmnist import INFO
from torchvision.utils import save_image

# Dataset selection: which MedMNIST entry to load and whether to fetch it
download = True
data_flag = 'chestmnist'

# Look up the dataset's metadata and resolve the class that implements it
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Preprocessing applied to EVERY split (train, val and test alike):
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
# The one-element tuples assume single-channel images.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalize])

# Build the three splits in train -> val -> test order with the shared transform
train_dataset, val_dataset, test_dataset = (
    DataClass(split=split_name, transform=transform, download=download)
    for split_name in ('train', 'val', 'test')
)

# Create DataLoaders: only the training split is shuffled; the validation
# and test splits are iterated in their stored order.
batch_size = 32
eval_loader_options = dict(batch_size=batch_size, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, **eval_loader_options)
test_loader = DataLoader(test_dataset, **eval_loader_options)

# Class labels, one per output index (14 entries)
class_names = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural_Thickening',
    'Hernia',
]

# Define CNN Model
class CNN(nn.Module):
    """Three-block convolutional classifier that returns raw logits.

    Each block is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``, so the spatial size is halved (rounding down) three
    times. The first fully connected layer expects a ``128 x 3 x 3`` feature
    map, i.e. inputs whose height and width lie in 24..31 (28 x 28 included).

    No sigmoid/softmax is applied to the output; pair the model with a loss
    that consumes logits (such as ``nn.BCEWithLogitsLoss``).

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits (one per label).
        dropout: Drop probability applied after the first fully connected
            layer.
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_classes: int = 14,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # Layer names and creation order are kept stable so previously saved
        # state_dicts still load and seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout)

        # 128 channels x 3 x 3 spatial positions remain after three poolings
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images ``(N, in_channels, H, W)`` to ``(N, num_classes)`` logits."""
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        # flatten() copies when needed, so unlike view() it also accepts
        # non-contiguous (e.g. channels_last) feature maps and empty batches.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Initialize Model on the GPU when CUDA is available, otherwise on the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
cnn = CNN()
cnn.to(device)  # moves the module's parameters in place

# Define Loss and Optimizer
# The model returns raw logits, so the loss applies the sigmoid itself;
# each output is scored as an independent binary label (multi-label setup).
optimizer = optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

# Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    # Training mode: dropout is active and batch norm uses batch statistics
    cnn.train()
    total_loss = 0

    for batch_images, batch_targets in train_loader:
        batch_images = batch_images.to(device)
        # Targets are cast to float before being passed to the loss
        batch_targets = batch_targets.float().to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        logits = cnn(batch_images)
        loss = loss_fn(logits, batch_targets)
        loss.backward()
        optimizer.step()

        # Accumulate the scalar batch loss for the per-epoch average
        total_loss += loss.item()

    # Reported loss is the mean of the per-batch losses for this epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")

# Save Model Weights (parameters and buffers only, not the full module)
weights_path = "chestmnist.pth"
torch.save(cnn.state_dict(), weights_path)
print("Model weights saved successfully!")

# Accuracy Evaluation on Test Set
# Evaluation mode: dropout is disabled and batch norm uses running statistics
cnn.eval()
correct = 0
total = 0

# Gradients are not needed for evaluation
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.float().to(device)

        outputs = cnn(images)
        # Turn logits into probabilities, then into hard 0/1 predictions
        probabilities = torch.sigmoid(outputs)
        predictions = (probabilities > 0.5).float()

        # Accuracy is counted per individual label entry, not per sample
        matches = predictions == labels
        correct += matches.sum().item()
        total += labels.numel()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")
