In [None]:
# !pip install tqdm albumentations

In [24]:
import os
import shutil
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import KFold
from albumentations import Compose, RandomResizedCrop, HorizontalFlip, Normalize, RandomRotate90, ShiftScaleRotate, CoarseDropout
from albumentations.pytorch import ToTensorV2
from torch.utils.tensorboard import SummaryWriter


In [25]:
# Device configuration
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters and configurations
config = {
    "base_dir": "/Users/saahil/Desktop/Coding_Projects/DL/MicroscopicFungi/archive-2",
    "batch_size": 32,
    "epochs": 15,
    "learning_rate": 1e-4,
    "height": 224,
    "width": 224,
    "channels": 3,
    "num_folds": 5,
    "patience": 10,
    "seed": 40,
    "log_dir": "./logs",
}




In [26]:
log_dir = config["log_dir"]

# Clear the log directory
if os.path.exists(log_dir):
    shutil.rmtree(log_dir)
os.makedirs(log_dir)

In [27]:
writer = SummaryWriter(config["log_dir"])



In [28]:
class FungiDataset(Dataset):
    def __init__(self, root_dir, transform=None, subset='train'):
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform
        self.classes = ['H1', 'H2', 'H3', 'H5', 'H6']  # List of class names
        self.image_paths, self.labels = self._load_dataset()

    def _load_dataset(self):
        image_paths, labels = [], []
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.root_dir, cls)
            if not os.path.exists(cls_dir):
                raise FileNotFoundError(f"Directory {cls_dir} does not exist.")
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path):
                    image_paths.append(img_path)
                    labels.append(label)
        return image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image=np.array(image))['image']
        return image, label

In [29]:
class CustomCNN(nn.Module):
    def __init__(self, num_classes):
        super(CustomCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(config["channels"], 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * (config["height"] // 16) * (config["width"] // 16), 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

In [30]:
def get_transforms():
    return Compose([
        RandomResizedCrop(config["height"], config["width"], scale=(0.8, 1.0)),
        HorizontalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30),
        CoarseDropout(max_holes=8, max_height=32, max_width=32),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])


In [31]:
def save_checkpoint(model, optimizer, fold, epoch, best=False):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    filename = f'checkpoint_fold{fold}_epoch{epoch}{"_best" if best else ""}.pth'
    torch.save(state, filename)

In [32]:
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, labels in tqdm(dataloader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_accuracy = 100 * correct / total
    return epoch_loss, epoch_accuracy

In [33]:
def validate_epoch(model, dataloader, criterion):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = running_loss / len(dataloader.dataset)
    val_accuracy = 100 * correct / total
    return val_loss, val_accuracy

In [34]:
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight

def train_model():
    dataset = FungiDataset(config["base_dir"], transform=get_transforms(), subset='train')
    num_classes = len(dataset.classes)
    
    # Compute class weights for handling class imbalance
    class_weights = compute_class_weight('balanced', classes=np.arange(len(dataset.classes)), y=dataset.labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    # Use StratifiedKFold to ensure each fold has a similar class distribution
    skf = StratifiedKFold(n_splits=config["num_folds"], shuffle=True, random_state=config["seed"])

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.arange(len(dataset)), dataset.labels), 1):

        print(f"Fold {fold}/{config['num_folds']}")
        
        # Extract the labels for the train and validation indices
        train_labels = np.array(dataset.labels)[train_idx]
        val_labels = np.array(dataset.labels)[val_idx]
        print(f"Fold {fold} - Train Class Distribution: {np.bincount(train_labels)}")
        print(f"Fold {fold} - Val Class Distribution: {np.bincount(val_labels)}")

        # Set up the data samplers and loaders
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=train_sampler)
        val_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=val_sampler)

        # Reinitialize the model for each fold
        model = CustomCNN(num_classes=len(dataset.classes)).to(device)
        optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

        best_val_loss, patience_counter = float('inf'), 0
        best_model_path = f'checkpoint_fold{fold}_best.pth'

        for epoch in range(1, config["epochs"] + 1):
            print(f"Epoch {epoch}/{config['epochs']}")

            train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_accuracy = validate_epoch(model, val_loader, criterion)

            print(f"Train Loss: {train_loss:.4f}, Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', train_accuracy, epoch)
            writer.add_scalar('Accuracy/val', val_accuracy, epoch)
            writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

            # Update the learning rate based on validation loss
            scheduler.step(val_loss)

            # Save the model if it has the best validation loss so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                print(f"New best model found for fold {fold} at epoch {epoch}, saving model...")
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                }, best_model_path)
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= config["patience"]:
                print("Early stopping triggered")
                break

    writer.close()


In [35]:
# train_model()

In [36]:
def load_fold_models(num_folds, model_class, num_classes, model_paths):
    models = []
    for fold in range(1, num_folds + 1):
        model = model_class(num_classes=num_classes).to(device)
        checkpoint = torch.load(model_paths[fold - 1])
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()  # Set the model to evaluation mode
        models.append(model)
    return models


In [37]:
def predict_ensemble(models, image):
    with torch.no_grad():
        outputs = [model(image) for model in models]
        avg_output = torch.mean(torch.stack(outputs), dim=0)
        _, predicted = torch.max(avg_output, 1)
    return predicted.item()


In [41]:
# Create an instance of FungiDataset to get the classes
dataset = FungiDataset(config["base_dir"], transform=get_transforms(), subset='train')
num_classes = len(dataset.classes)

# Define the absolute base path to where the checkpoints are saved
base_model_path = "/Users/saahil/Desktop/Coding_Projects/DL/MicroscopicFungi"

# Define the paths to the saved models for each fold using the absolute path
model_paths = [os.path.join(base_model_path, f'checkpoint_fold{fold}_best.pth') for fold in range(1, config["num_folds"] + 1)]

# Load all the saved models
models = load_fold_models(config["num_folds"], CustomCNN, num_classes, model_paths)

# Load and preprocess the image
def preprocess_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    image = transform(image=np.array(image))['image']
    image = image.unsqueeze(0)  # Add batch dimension
    return image.to(device)

# Single image path
image_path = "/Users/saahil/Desktop/Coding_Projects/DL/MicroscopicFungi/archive-2/test/H2/H2_1a_12.jpg.jpg"
image = preprocess_image(image_path, get_transforms())

# Get the ensemble prediction
predicted_class = predict_ensemble(models, image)
print(f"The predicted class for the image is: {dataset.classes[predicted_class]}")


  checkpoint = torch.load(model_paths[fold - 1])


The predicted class for the image is: H2
