In [113]:
# !pip install tqdm albumentations

In [114]:
import os
import shutil
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import KFold
from albumentations import Compose, RandomResizedCrop, HorizontalFlip, Normalize, RandomRotate90, ShiftScaleRotate, CoarseDropout
from albumentations.pytorch import ToTensorV2
from torch.utils.tensorboard import SummaryWriter


In [127]:
# Device configuration
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters and configurations
config = {
    "base_dir": "/Users/saahil/Desktop/Coding_Projects/DL/MicroscopicFungi/archive-2",
    "batch_size": 32,
    "epochs": 10,
    "learning_rate": 4e-3,
    "height": 224,
    "width": 224,
    "channels": 3,
    "num_folds": 3,
    "patience": 10,
    "seed": 40,
    "log_dir": "./logs",
}




In [128]:
log_dir = config["log_dir"]

# Clear the log directory
if os.path.exists(log_dir):
    shutil.rmtree(log_dir)
os.makedirs(log_dir)

In [129]:
writer = SummaryWriter(config["log_dir"])



In [130]:
class FungiDataset(Dataset):
    def __init__(self, root_dir, transform=None, subset='train'):
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform
        self.classes = ['H1', 'H2', 'H3', 'H5', 'H6']
        self.image_paths, self.labels = self._load_dataset()

    def _load_dataset(self):
        image_paths, labels = [], []
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.root_dir, cls)
            if not os.path.exists(cls_dir):
                raise FileNotFoundError(f"Directory {cls_dir} does not exist.")
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path):
                    image_paths.append(img_path)
                    labels.append(label)
        return image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image=np.array(image))['image']
        return image, label


In [131]:
class CustomCNN(nn.Module):
    def __init__(self, num_classes):
        super(CustomCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(config["channels"], 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * (config["height"] // 16) * (config["width"] // 16), 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

In [132]:
def get_transforms():
    return Compose([
        RandomResizedCrop(config["height"], config["width"], scale=(0.8, 1.0)),
        HorizontalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30),
        CoarseDropout(max_holes=8, max_height=32, max_width=32),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])


In [133]:
def save_checkpoint(model, optimizer, fold, epoch, best=False):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    filename = f'checkpoint_fold{fold}_epoch{epoch}{"_best" if best else ""}.pth'
    torch.save(state, filename)

In [134]:
# def load_checkpoint(model, optimizer, filename):
#     checkpoint = torch.load(filename)
#     model.load_state_dict(checkpoint['model'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     return checkpoint['epoch']



In [135]:
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, labels in tqdm(dataloader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_accuracy = 100 * correct / total
    return epoch_loss, epoch_accuracy

In [136]:
def validate_epoch(model, dataloader, criterion):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = running_loss / len(dataloader.dataset)
    val_accuracy = 100 * correct / total
    return val_loss, val_accuracy

In [137]:
# def train_model():
#     dataset = FungiDataset(config["base_dir"], transform=get_transforms(), subset='train')
#     kf = KFold(n_splits=config["num_folds"], shuffle=True, random_state=config["seed"])

#     for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(dataset))), 1):
#         print(f"Fold {fold}/{config['num_folds']}")

#         train_sampler = SubsetRandomSampler(train_idx)
#         val_sampler = SubsetRandomSampler(val_idx)
#         train_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=train_sampler)
#         val_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=val_sampler)

#         model = CustomCNN(num_classes=len(dataset.classes)).to(device)
#         criterion = nn.CrossEntropyLoss()
#         optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
        
        
#         scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

#         best_val_loss, patience_counter = float('inf'), 0
#         best_model_path = f'checkpoint_fold{fold}_best.pth'

#         for epoch in range(1, config["epochs"] + 1):
#             print(f"Epoch {epoch}/{config['epochs']}")

#             train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer)
#             val_loss, val_accuracy = validate_epoch(model, val_loader, criterion)

#             print(f"Train Loss: {train_loss:.4f}, Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

#             writer.add_scalar('Loss/train', train_loss, epoch)
#             writer.add_scalar('Loss/val', val_loss, epoch)
#             writer.add_scalar('Accuracy/train', train_accuracy, epoch)
#             writer.add_scalar('Accuracy/val', val_accuracy, epoch)
#             writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

#             # Update the scheduler based on the validation loss
#             scheduler.step(val_loss)

#             if val_loss < best_val_loss:
#                 best_val_loss = val_loss
#                 patience_counter = 0
#                 print(f"New best model found for fold {fold} at epoch {epoch}, saving model...")
#                 torch.save({
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                     'epoch': epoch,
#                     'best_val_loss': best_val_loss,
#                 }, best_model_path)
#             else:
#                 patience_counter += 1

#             if patience_counter >= config["patience"]:
#                 print("Early stopping triggered")
#                 break

#     writer.close()


In [None]:
from sklearn.model_selection import StratifiedKFold

def train_model():
    dataset = FungiDataset(config["base_dir"], transform=get_transforms(), subset='train')
    strat_kf = StratifiedKFold(n_splits=config["num_folds"], shuffle=True, random_state=config["seed"])

    for fold, (train_idx, val_idx) in enumerate(strat_kf.split(np.arange(len(dataset)), dataset.labels), 1):
        print(f"Fold {fold}/{config['num_folds']}")

        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=train_sampler)
        val_loader = DataLoader(dataset, batch_size=config["batch_size"], sampler=val_sampler)

        model = CustomCNN(num_classes=len(dataset.classes)).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
        
        
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

        best_val_loss, patience_counter = float('inf'), 0
        best_model_path = f'checkpoint_fold{fold}_best.pth'

        for epoch in range(1, config["epochs"] + 1):
            print(f"Epoch {epoch}/{config['epochs']}")

            train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_accuracy = validate_epoch(model, val_loader, criterion)

            print(f"Train Loss: {train_loss:.4f}, Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', train_accuracy, epoch)
            writer.add_scalar('Accuracy/val', val_accuracy, epoch)
            writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

            # Update the scheduler based on the validation loss
            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                print(f"New best model found for fold {fold} at epoch {epoch}, saving model...")
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                }, best_model_path)
            else:
                patience_counter += 1

            if patience_counter >= config["patience"]:
                print("Early stopping triggered")
                break

    writer.close()


In [138]:
train_model()

Fold 1/3
Epoch 1/10


                                                           

Train Loss: 4.1833, Acc: 30.87%, Val Loss: 0.4620, Val Acc: 34.49%
New best model found for fold 1 at epoch 1, saving model...
Epoch 2/10


                                                           

Train Loss: 0.8770, Acc: 42.21%, Val Loss: 0.4127, Val Acc: 46.07%
New best model found for fold 1 at epoch 2, saving model...
Epoch 3/10


                                                           

Train Loss: 0.8564, Acc: 43.53%, Val Loss: 0.4258, Val Acc: 42.05%
Epoch 4/10


                                                           

Train Loss: 0.8388, Acc: 46.05%, Val Loss: 0.4224, Val Acc: 47.09%
Epoch 5/10


                                                           

Train Loss: 0.8616, Acc: 45.45%, Val Loss: 0.4021, Val Acc: 48.35%
New best model found for fold 1 at epoch 5, saving model...
Epoch 6/10


                                                           

Train Loss: 0.8226, Acc: 47.70%, Val Loss: 0.4059, Val Acc: 49.31%
Epoch 7/10


                                                           

Train Loss: 0.8309, Acc: 47.91%, Val Loss: 0.4473, Val Acc: 43.55%
Epoch 8/10


                                                           

Train Loss: 0.8242, Acc: 48.48%, Val Loss: 0.3889, Val Acc: 53.33%
New best model found for fold 1 at epoch 8, saving model...
Epoch 9/10


                                                           

Train Loss: 0.7987, Acc: 50.20%, Val Loss: 0.4031, Val Acc: 49.49%
Epoch 10/10


                                                           

Train Loss: 0.8036, Acc: 49.26%, Val Loss: 0.3825, Val Acc: 53.21%
New best model found for fold 1 at epoch 10, saving model...
Fold 2/3
Epoch 1/10


                                                           

Train Loss: 2.5105, Acc: 19.44%, Val Loss: 0.5366, Val Acc: 19.20%
New best model found for fold 2 at epoch 1, saving model...
Epoch 2/10


                                                           

Train Loss: 1.0733, Acc: 19.80%, Val Loss: 0.5367, Val Acc: 19.68%
Epoch 3/10


                                                           

Train Loss: 1.0732, Acc: 19.11%, Val Loss: 0.5367, Val Acc: 19.68%
Epoch 4/10


                                                           

Train Loss: 1.0732, Acc: 19.23%, Val Loss: 0.5367, Val Acc: 19.20%
Epoch 5/10


                                                           

Train Loss: 1.0731, Acc: 19.47%, Val Loss: 0.5367, Val Acc: 19.68%
Epoch 6/10


                                                           

Train Loss: 1.0729, Acc: 20.16%, Val Loss: 0.5367, Val Acc: 19.68%
Epoch 7/10


                                                          

KeyboardInterrupt: 