# In [None]:
import medmnist
from medmnist import INFO
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Set dataset flag
data_flag = 'chestmnist'
download = True

# Look up the registry entry for this flag and resolve the dataset class it names
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Training transform: tensor conversion, random augmentation, normalization.
# `transform` keeps its original name and contents so any later code that
# refers to it still gets the augmenting pipeline.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.Normalize((0.5,), (0.5,))
])

# Evaluation transform: identical conversion and normalization, but WITHOUT
# the random flip/rotation. Random augmentation at evaluation time makes
# validation/test metrics noisy and different on every run.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the three splits; only the training split is augmented
train_dataset = DataClass(split='train', transform=transform, download=download)
val_dataset = DataClass(split='val', transform=eval_transform, download=download)
test_dataset = DataClass(split='test', transform=eval_transform, download=download)

# Create DataLoaders.
# drop_last=True is kept for training only (uniform batch size per step).
# The val/test loaders keep their final partial batch so that every sample is
# evaluated; the model flattens per sample, so a smaller last batch is fine.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# NOTE: inputs are normalized once, by transforms.Normalize in the pipeline above; no further normalization is applied here.

# Define CNN Model
class CNN(nn.Module):
    """Three-stage convolutional classifier that returns raw logits.

    Each stage is ``Conv2d(kernel_size=3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)`` with 32, 64 and 128 output channels respectively. The
    pooled feature map is flattened and passed through three linear layers
    (-> 256 -> 128 -> ``num_classes``), with dropout (p=0.5) after the first.
    No sigmoid/softmax is applied to the output.

    The defaults reproduce the original hard-coded architecture
    (1 input channel, ``fc1`` input of ``128 * 3 * 3``, 14 outputs), so
    ``CNN()`` builds a model with the same parameter names and shapes.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        image_size: Side length of the (square) input images. Used only to
            size ``fc1``. 28 is assumed to be the dataset's image side —
            confirm if a different resolution is loaded.

    Raises:
        ValueError: If ``image_size`` is too small to survive three 2x2
            poolings (i.e. smaller than 8).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 14,
                 image_size: int = 28) -> None:
        super().__init__()

        # The 3x3 convolutions use padding=1 and therefore keep the spatial
        # size; only the three MaxPool2d(2, 2) stages shrink it, each one
        # halving the side with floor rounding.
        pooled_side = image_size
        for _ in range(3):
            pooled_side //= 2
        if pooled_side < 1:
            raise ValueError(
                f"image_size={image_size} is too small for three 2x2 poolings; "
                "need at least 8"
            )

        # Layers are created in the original order and under the original
        # attribute names so existing state_dict checkpoints remain loadable.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Flattened feature count after the last pooling stage
        self.fc1 = nn.Linear(128 * pooled_side * pooled_side, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one row of ``num_classes`` raw logits per input sample."""
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))

        # Flatten everything except the batch dimension, so any batch size
        # (including a short final batch) works.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Training configuration
num_epochs = 15
learning_rate = 0.001

# Build the network with its default constructor arguments
cnn = CNN()

# Multi-label objective: binary cross-entropy computed on the model's raw
# (un-sigmoided) outputs
loss_fn = nn.BCEWithLogitsLoss()

# Adam over every model parameter, starting at `learning_rate`
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

# Step schedule: multiply the learning rate by 0.5 after every 5 scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Training Loop: one optimisation pass over train_loader per epoch, then a
# gradient-free pass over val_loader to report the validation loss.
for epoch in range(num_epochs):
    # ---- training phase (dropout active, batch-norm uses batch statistics) ----
    cnn.train()
    total_loss = 0

    for images, labels in train_loader:
        # Cast both inputs and targets to float32 before the forward pass
        images, labels = images.float(), labels.float()

        optimizer.zero_grad()
        outputs = cnn(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch mean loss as a plain Python float
        total_loss += loss.item()

    # The scheduler is advanced once per epoch, after all optimizer steps
    scheduler.step()

    # ---- validation phase (no gradients, inference-mode layers) ----
    cnn.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.float(), labels.float()
            outputs = cnn(images)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()

    # Report the average of the per-batch losses for both phases
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")

# Evaluate Model on the test split and save the trained weights.
# Put dropout / batch-norm layers into inference behaviour first.
cnn.eval()
correct, total = 0, 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.float(), labels.float()
        outputs = cnn(images)
        # Turn logits into probabilities and threshold each label at 0.5
        predictions = torch.sigmoid(outputs).gt(0.5).float()
        # Count agreement over every individual (sample, label) entry
        correct += torch.eq(predictions, labels).sum().item()
        total += labels.numel()

# Share of individual label entries predicted correctly, printed as a percentage
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.4f}")

# Persist only the parameters (state_dict), not the whole module object.
# NOTE(review): the filename looks like a typo for 'chestMNIST.pth'; it is kept
# unchanged because other code may load the checkpoint from this exact path.
torch.save(cnn.state_dict(), 'chectMNIST.pth')