Written by Prokash Chandra Roy

In [13]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns
from torchvision import transforms

In [14]:
# Configuration: every tunable value for the notebook lives in this cell.
IMG_HEIGHT = 256   # images are resized to IMG_HEIGHT x IMG_WIDTH before entering the network
IMG_WIDTH = 256
BATCH_SIZE = 128
EPOCHS = 40
SEED = 42          # same value the train/test split uses as its random_state

# Seed PyTorch's global RNG so weight initialisation, dropout masks and
# DataLoader shuffling start from the same state on every Restart & Run All.
# (Some GPU kernels are still non-deterministic, so runs may differ slightly on CUDA.)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Custom Dataset class for loading 1-channel TIFF images and labels
class TiffDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from image files.

    Every image is opened with PIL and converted to mode ``'L'`` (a single
    grayscale channel) before the optional transform is applied.

    Args:
        image_paths: Indexable collection of image file paths.
        labels: Indexable collection of labels, positionally aligned with
            ``image_paths``.
        transform: Optional callable applied to the PIL image (for example a
            ``torchvision.transforms.Compose``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Raises:
            ValueError: If the image cannot be opened or converted; the
                original exception is attached as ``__cause__``.
        """
        img_path = self.image_paths[idx]
        try:
            # The context manager releases the file handle as soon as the
            # pixels are read; convert() returns an independent in-memory image.
            with Image.open(img_path) as source:
                img = source.convert('L')  # grayscale, 1 channel
        except Exception as e:
            # Chain explicitly so the traceback shows the real I/O failure.
            raise ValueError(f"Failed to load image {img_path}: {e}") from e

        # Compare against None so a transform object that happens to be falsy
        # (e.g. an empty container) is still honoured.
        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels[idx]

# Load data from folder and CSV
def load_data(image_folder, csv_path):
    """Build image paths and integer-encoded labels from a CSV manifest.

    The CSV must contain an ``index`` column (integer image id; the file is
    expected at ``<image_folder>/<id>.tiff``) and a ``class`` column.

    Args:
        image_folder: Directory that holds the ``.tiff`` images.
        csv_path: Path of the CSV manifest.

    Returns:
        A tuple ``(image_paths, labels, le)`` where ``image_paths`` is a list
        of file paths, ``labels`` is the array produced by
        ``LabelEncoder.fit_transform`` on the ``class`` column, and ``le`` is
        the fitted ``LabelEncoder``.

    Raises:
        ValueError: If the CSV cannot be read (original error chained as
            ``__cause__``) or lacks a required column.
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Explicit chaining keeps the pandas/OS error visible in the traceback.
        raise ValueError(f"Failed to read CSV {csv_path}: {e}") from e

    # Both columns are mandatory; fail before touching the filesystem.
    if 'index' not in df.columns or 'class' not in df.columns:
        raise ValueError("CSV must have 'index' and 'class' columns")

    # int() strips any float formatting (e.g. 12.0 -> "12.tiff").
    image_paths = [
        os.path.join(image_folder, f"{int(img_name)}.tiff")
        for img_name in df['index']
    ]
    labels = df['class'].values

    # Missing files are only reported here; their paths are still returned,
    # so loading them later will fail inside the dataset.
    missing_images = [img_path for img_path in image_paths if not os.path.exists(img_path)]
    if missing_images:
        print(f"Warning: {len(missing_images)} images not found, e.g., {missing_images[:5]}")

    # Map the raw class values to integer codes.
    le = LabelEncoder()
    labels = le.fit_transform(labels)

    return image_paths, labels, le


Using device: cuda


In [15]:
# Define CNN model for 1-channel input
class CNN(nn.Module):
    """Three-stage convolutional classifier for single-channel images.

    ``features`` is three Conv(3x3, padding=1) -> ReLU -> MaxPool(2) stages
    (1 -> 32 -> 64 -> 128 channels); each pool halves height and width, so the
    spatial size shrinks by a factor of 8 overall. ``classifier`` flattens the
    result and maps it through a 128-unit hidden layer with dropout to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        img_height: Height of the input images. Defaults to the notebook-wide
            ``IMG_HEIGHT`` when omitted.
        img_width: Width of the input images. Defaults to the notebook-wide
            ``IMG_WIDTH`` when omitted.

    Raises:
        ValueError: If either dimension is smaller than 8, which would leave
            no pixels after the three pooling stages.
    """

    def __init__(self, num_classes, img_height=None, img_width=None):
        super().__init__()

        # Fall back to the global configuration so existing CNN(num_classes)
        # calls build exactly the same network as before.
        if img_height is None:
            img_height = IMG_HEIGHT
        if img_width is None:
            img_width = IMG_WIDTH
        if img_height < 8 or img_width < 8:
            raise ValueError(
                f"Input size must be at least 8x8, got {img_height}x{img_width}"
            )

        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 1 input channel
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Three floor-halvings equal one floor-division by 8, so this matches
        # the feature-map size produced by the pooling layers above.
        flat_features = 128 * (img_height // 8) * (img_width // 8)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of single-channel images."""
        x = self.features(x)
        x = self.classifier(x)
        return x


In [16]:

def plot_training_history(train_losses,
                          val_losses,
                          train_accuracies,
                          val_accuracies,
                          save_path='training_history.png',
                          font_size=16):
    """Save a two-panel figure of per-epoch loss and accuracy curves.

    The left panel shows training vs. validation loss, the right panel
    training vs. validation accuracy. The figure is written to ``save_path``,
    its absolute path is printed, and the figure is closed.

    Args:
        train_losses, val_losses: Per-epoch loss values.
        train_accuracies, val_accuracies: Per-epoch accuracy values.
        save_path: Output image path.
        font_size: Base font size; titles use ``font_size + 2``.
    """
    # One entry per subplot: (title / y-label, [(series, legend label), ...])
    panels = [
        ('Loss', [(train_losses, 'Training Loss'),
                  (val_losses, 'Validation Loss')]),
        ('Accuracy', [(train_accuracies, 'Training Accuracy'),
                      (val_accuracies, 'Validation Accuracy')]),
    ]

    plt.figure(figsize=(12, 4))

    for column, (metric_name, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        for series, legend_label in curves:
            plt.plot(series, label=legend_label)
        plt.title(metric_name, fontsize=font_size + 2)
        plt.xlabel('Epoch', fontsize=font_size)
        plt.ylabel(metric_name, fontsize=font_size)
        plt.xticks(fontsize=font_size)
        plt.yticks(fontsize=font_size)
        plt.legend(fontsize=font_size)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Training history plot saved to: {os.path.abspath(save_path)}")
    plt.close()


def plot_confusion_matrix(y_true,
                          y_pred,
                          classes=None,
                          save_path='confusion_matrix.png',
                          font_size=16):
    """Render a confusion-matrix heatmap and save it to disk.

    Args:
        y_true, y_pred: Lists or arrays of true / predicted labels.
        classes: Tick labels for both axes; when ``None`` the names
            ``['Circular', 'Elliptical', 'Rock']`` are used.
        save_path: Output image path.
        font_size: Base font size for ticks, axis labels and cell annotations;
            the title uses ``font_size + 2``.
    """
    tick_names = ['Circular', 'Elliptical', 'Rock'] if classes is None else classes

    counts = confusion_matrix(y_true, y_pred)

    heatmap_options = dict(
        annot=True,                        # write the count inside each cell
        fmt='d',                           # counts are integers
        cmap='Blues',
        xticklabels=tick_names,
        yticklabels=tick_names,
        annot_kws={"size": font_size},
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(counts, **heatmap_options)

    plt.title('Confusion Matrix', fontsize=font_size + 2)
    plt.xlabel('Predicted Label', fontsize=font_size)
    plt.ylabel('True Label', fontsize=font_size)

    # Slanted x labels, upright y labels.
    plt.xticks(fontsize=font_size, rotation=45)
    plt.yticks(fontsize=font_size, rotation=0)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Confusion matrix plot saved to: {os.path.abspath(save_path)}")
    plt.close()

def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``.

    The loss is averaged per sample (each batch's mean loss is weighted by its
    size) so a short final batch does not skew the epoch average.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0
    for images, targets in loader:
        images, targets = images.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        _, preds = torch.max(outputs, 1)
        num_correct += (preds == targets).sum().item()
        num_seen += batch_size

    if num_seen == 0:
        raise ValueError("Training loader produced no samples")
    return loss_sum / num_seen, num_correct / num_seen


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns ``(mean loss, accuracy, y_true, y_pred)`` where the last two are
    lists of the true and predicted labels in loader order.
    """
    model.eval()
    loss_sum, num_correct = 0.0, 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            loss_sum += criterion(outputs, targets).item() * targets.size(0)
            _, preds = torch.max(outputs, 1)
            num_correct += (preds == targets).sum().item()
            y_true.extend(targets.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    if not y_true:
        raise ValueError("Evaluation loader produced no samples")
    return loss_sum / len(y_true), num_correct / len(y_true), y_true, y_pred


def main():
    """Train the CNN, log metrics, evaluate the best checkpoint and save artefacts.

    Steps: load the manifest, make a stratified 80/20 split, train for
    ``EPOCHS`` epochs while checkpointing the weights with the highest
    held-out accuracy, write the per-epoch log to CSV, save the last-epoch
    weights, reload the best checkpoint for the final report, and write the
    confusion-matrix and training-history plots plus the label-encoder classes.

    NOTE(review): the same 20% split is used both to select the best epoch and
    to report the final accuracy, so that figure is optimistically biased; a
    separate validation split would be needed for an unbiased estimate.
    """
    # Paths and filenames
    image_folder      = 'training/all'
    csv_path          = 'training.csv'
    best_model_path   = 'best_image_classification_model.pth'
    final_model_path  = 'final_image_classification_model.pth'
    log_csv_path      = 'train_log.csv'

    # 1. Load data
    image_paths, labels, label_encoder = load_data(image_folder, csv_path)

    # 2. Train/test split (stratified so class proportions match in both parts)
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        image_paths, labels,
        test_size=0.2,
        random_state=42,
        stratify=labels
    )
    print(f"Training set size: {len(train_paths)}, Test set size: {len(test_paths)}")

    # 3. Transforms & datasets
    transform = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]
    ])
    train_dataset = TiffDataset(train_paths, train_labels, transform=transform)
    test_dataset  = TiffDataset(test_paths,  test_labels,  transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

    # 4. Model, loss, optimizer
    num_classes = len(label_encoder.classes_)
    model       = CNN(num_classes).to(DEVICE)
    criterion   = nn.CrossEntropyLoss()
    optimizer   = optim.Adam(model.parameters(), lr=0.001)

    # 5. Training loop
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise an all-zero run would leave nothing (or a stale
    # file from an earlier run) for step 8 to load.
    best_val_acc = float('-inf')
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(EPOCHS):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc, _, _ = _evaluate(model, test_loader, criterion)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f"Epoch {epoch+1}/{EPOCHS} — "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} — "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            print(f"→ Saved best model (Val Acc: {best_val_acc:.4f}) at epoch {epoch+1}")

    # 6. Save training log (epoch column sized from the recorded history)
    df = pd.DataFrame({
        'epoch':      np.arange(1, len(train_losses) + 1),
        'train_loss': train_losses,
        'val_loss':   val_losses,
        'train_acc':  train_accuracies,
        'val_acc':    val_accuracies,
    })
    df.to_csv(log_csv_path, index=False)
    print(f"Training log saved to: {os.path.abspath(log_csv_path)}")

    # 7. Save the last-epoch weights now, before the best checkpoint overwrites
    #    the in-memory weights; saving afterwards would just duplicate the
    #    best-model file.
    torch.save(model.state_dict(), final_model_path)

    # 8. Load best model & evaluate on test
    print(f"\nLoading best model from {best_model_path} for final evaluation…")
    # map_location keeps the load working if the checkpoint was written on a
    # different device than the one in use now.
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    model.to(DEVICE)

    _, _, y_true, y_pred = _evaluate(model, test_loader, criterion)

    test_accuracy = accuracy_score(y_true, y_pred)
    print(f"Final Test Accuracy (Best Model): {test_accuracy:.4f}")

    # 9. Generate plots with custom labels and font size
    # The display names assume three classes in this encoder order; if the
    # data has a different class count, fall back to the encoder's own names
    # so the tick labels always match the matrix.
    class_names = ['Circular', 'Elliptical', 'Rock']
    if len(class_names) != num_classes:
        class_names = [str(name) for name in label_encoder.classes_]

    plot_confusion_matrix(
        y_true,
        y_pred,
        classes=class_names,
        save_path='confusion_matrix.png',
        font_size=16
    )
    plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies)

    # 10. Save the encoder classes so predictions can be mapped back to names
    np.save('label_encoder_classes.npy', label_encoder.classes_)

# Entry point: inside a notebook kernel __name__ is '__main__', so running this
# cell starts the full training pipeline; importing the code as a module does not.
if __name__ == '__main__':
    main()


Training set size: 14296, Test set size: 3574
Epoch 1/40 — Train Loss: 0.3252, Train Acc: 0.9028 — Val Loss: 0.0477, Val Acc: 0.9910
→ Saved best model (Val Acc: 0.9910) at epoch 1
Epoch 2/40 — Train Loss: 0.0542, Train Acc: 0.9849 — Val Loss: 0.0560, Val Acc: 0.9877
Epoch 3/40 — Train Loss: 0.0280, Train Acc: 0.9915 — Val Loss: 0.0063, Val Acc: 0.9989
→ Saved best model (Val Acc: 0.9989) at epoch 3
Epoch 4/40 — Train Loss: 0.0063, Train Acc: 0.9990 — Val Loss: 0.0013, Val Acc: 0.9997
→ Saved best model (Val Acc: 0.9997) at epoch 4
Epoch 5/40 — Train Loss: 0.0072, Train Acc: 0.9987 — Val Loss: 0.0018, Val Acc: 0.9992
Epoch 6/40 — Train Loss: 0.0063, Train Acc: 0.9990 — Val Loss: 0.0057, Val Acc: 0.9969
Epoch 7/40 — Train Loss: 0.0036, Train Acc: 0.9992 — Val Loss: 0.0103, Val Acc: 0.9983
Epoch 8/40 — Train Loss: 0.0038, Train Acc: 0.9993 — Val Loss: 0.0007, Val Acc: 0.9997
Epoch 9/40 — Train Loss: 0.0006, Train Acc: 1.0000 — Val Loss: 0.0029, Val Acc: 0.9992
Epoch 10/40 — Train Loss: 0