In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import KFold
import numpy as np
from PIL import Image
import os

In [28]:
device = torch.device("mps")
base_dir = "/Users/saahil/Desktop/Coding_Projects/DL/MicroscopicFungi/archive-2"
FACTOR = 4
BATCH_SIZE = 8 * FACTOR
EPOCHS = 100
LEARNING_RATE = 1e-3 * FACTOR
HEIGHT = 224
WIDTH = 224
CHANNELS = 3
VALIDATION_SPLIT = 0.1
SEED = 40
PATIENCE = 10  # For Early Stopping
NUM_FOLDS = 5  # K-Fold Cross Validation

In [29]:
class BoneBreakDataset(Dataset):
    def __init__(self, root_dir, transform=None, subset='train'):
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform
        self.classes = ['H1', 'H2', 'H3', 'H5', 'H6']
        self.class_map = {
            'H1': 'Candida albicans',
            'H2': 'Aspergillus niger',
            'H3': 'Trichophyton rubrum',
            'H5': 'Trichophyton mentagrophytes',
            'H6': 'Epidermophyton floccosum'
        }
        self.image_paths = []
        self.labels = []

        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.root_dir, cls)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path):  # Ensure it's a file
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [30]:
transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop((HEIGHT, WIDTH), scale=(0.8, 1.0))
])

In [31]:
dataset = BoneBreakDataset(base_dir, transform=transform, subset='train')
kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)


In [32]:
class CustomCNN(nn.Module):
    def __init__(self, num_classes):
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=CHANNELS, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.fc1 = nn.Linear(512 * (HEIGHT // 16) * (WIDTH // 16), 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        
        x = x.view(-1, 512 * (HEIGHT // 16) * (WIDTH // 16))
        
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

In [33]:
fold = 0
for train_idx, val_idx in kf.split(np.arange(len(dataset))):
    fold += 1
    print(f"Fold {fold}/{NUM_FOLDS}")

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    val_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
    
    model = CustomCNN(num_classes=len(dataset.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader.dataset)
        train_accuracy = 100 * correct / total
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        val_loss /= len(val_loader.dataset)
        val_accuracy = 100 * correct / total
        
        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}, Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
        
        scheduler.step(val_loss)

        # Early Stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f'torchmodel_fold{fold}_best.pth')
        else:
            patience_counter += 1
        
        if patience_counter >= PATIENCE:
            print("Early stopping triggered")
            break

print('Training Complete')

Fold 1/5
Epoch 1/100, Loss: 2.5484, Acc: 32.75%, Val Loss: 0.2671, Val Acc: 40.90%
Epoch 2/100, Loss: 1.0402, Acc: 44.77%, Val Loss: 0.2585, Val Acc: 45.40%
Epoch 3/100, Loss: 1.0416, Acc: 46.38%, Val Loss: 0.2440, Val Acc: 51.90%
Epoch 4/100, Loss: 1.0417, Acc: 45.95%, Val Loss: 0.2411, Val Acc: 51.40%
Epoch 5/100, Loss: 0.9906, Acc: 49.08%, Val Loss: 0.2378, Val Acc: 53.00%


KeyboardInterrupt: 