# NEU Surface Defect Image Classification

This notebook demonstrates how to train and evaluate an image classification model using the NEU surface defect dataset that is distributed in YOLO format. The workflow covers downloading the dataset, converting it into a folder layout that `torchvision` can consume, training a transfer learning model, and running inference on held-out images.

## 1. Environment setup
Install the required dependencies for PyTorch, data processing, and evaluation.

In [None]:
# If you are running this notebook on a CPU-only machine you can install the CPU wheels from PyTorch.
# Remove the index override below if you have a CUDA-enabled environment configured.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install -q pyyaml scikit-learn


## 2. Imports and configuration

In [None]:
import os
import random
import shutil
import zipfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from sklearn.metrics import classification_report, confusion_matrix
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms

# Default figure size for every plot in this notebook.
plt.rcParams.update({'figure.figsize': (8, 6)})

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _cuda_available else torch.device("cpu")

# Seed every random number generator the notebook relies on.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if _cuda_available:
    torch.cuda.manual_seed_all(SEED)

print(f"Using device: {DEVICE}")


## 3. Download the NEU YOLO dataset

The following cell downloads the dataset archive from Kaggle using the provided curl command. Make sure your Kaggle API credentials are configured in the environment before running this step.

In [None]:
dataset_dir = Path("data/neu_yolo")
dataset_dir.mkdir(parents=True, exist_ok=True)
zip_path = dataset_dir / "neu-yolo.zip"

if not zip_path.exists():
    print("Downloading dataset...")
    !curl -L -o ./data/neu_yolo/neu-yolo.zip   https://www.kaggle.com/api/v1/datasets/download/zymzym/neu-yolo
else:
    print("Dataset archive already exists, skipping download.")

print(f"Archive location: {zip_path}")


## 4. Extract and reorganize data for classification

The dataset ships in YOLO format with separate image and label folders. The helper functions below extract the archive and convert it into a standard folder layout (`split/class_name/image.jpg`) that can be consumed by `torchvision.datasets.ImageFolder`. Each image is assigned the class of the first annotation in its corresponding YOLO label file.

In [None]:
# Where the archive is unpacked, and where the ImageFolder-style copy is written.
raw_root = dataset_dir.joinpath("raw")
classification_root = Path("data") / "neu_classification"


def extract_archive(zip_file: Path, target_dir: Path) -> Path:
    """Unpack ``zip_file`` into ``target_dir`` unless that directory already has content.

    Args:
        zip_file: Path of the zip archive to extract.
        target_dir: Destination directory; created (with parents) when needed.

    Returns:
        ``target_dir``, whether or not an extraction took place.
    """
    already_populated = target_dir.exists() and any(target_dir.iterdir())
    if not already_populated:
        print(f"Extracting {zip_file} to {target_dir} ...")
        target_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zip_file) as archive:
            archive.extractall(target_dir)
        print("Extraction complete.")
    else:
        # A non-empty target is taken as proof of an earlier extraction.
        print(f"Using existing extracted contents in {target_dir}")
    return target_dir


def read_class_names(dataset_root: Path):
    """Return the list of class names declared in the dataset's ``data.yaml``.

    The first ``data.yaml`` found anywhere below ``dataset_root`` is used. Names
    are read from the ``names`` key, falling back to ``class_names``. A mapping
    is converted to a list ordered by its keys, with digit-only string keys
    compared as integers.

    Raises:
        FileNotFoundError: If no ``data.yaml`` exists below ``dataset_root``.
        ValueError: If the names entry is not a list, tuple, or mapping.
    """

    def _key_order(key):
        # "10" must sort after "2", so digit-only strings are compared numerically.
        if isinstance(key, str) and key.isdigit():
            return int(key)
        return key

    config_file = next(dataset_root.rglob("data.yaml"), None)
    if config_file is None:
        raise FileNotFoundError("Could not locate data.yaml inside the extracted dataset.")

    with open(config_file, "r") as handle:
        config = yaml.safe_load(handle)

    labels = config.get("names") or config.get("class_names")
    if isinstance(labels, dict):
        labels = [labels[key] for key in sorted(labels, key=_key_order)]
    if isinstance(labels, (list, tuple)):
        return list(labels)
    raise ValueError("Unable to parse class names from data.yaml")


def convert_to_classification_layout(dataset_root: Path, target_root: Path) -> None:
    """Copy YOLO-format images into an ``ImageFolder`` layout (``split/class_name/image``).

    Each image is filed under the class of the first non-blank annotation line
    in its label file. Label files that are empty, whose class id is not an
    integer inside the class list from ``data.yaml``, or that have no matching
    image file are skipped instead of aborting the whole conversion.

    The ``train`` / ``valid`` / ``test`` folders are looked up directly under
    ``dataset_root``; if none is found there, the directory containing
    ``data.yaml`` is used instead (the archive may unpack into a sub-folder).

    Args:
        dataset_root: Directory holding the extracted YOLO dataset.
        target_root: Directory that receives the ``split/class_name`` folders.

    Raises:
        FileNotFoundError, ValueError: Propagated from ``read_class_names``.
    """
    target_root.mkdir(parents=True, exist_ok=True)
    class_names = read_class_names(dataset_root)
    print(f"Detected classes: {class_names}")

    splits = ["train", "valid", "test"]
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

    # read_class_names() finds data.yaml recursively, so look for the splits next to it
    # when they are not directly under dataset_root.
    layout_root = dataset_root
    if not any((dataset_root / split / "images").exists() for split in splits):
        yaml_path = next(dataset_root.rglob("data.yaml"), None)
        if yaml_path is not None:
            layout_root = yaml_path.parent

    for split in splits:
        image_dir = layout_root / split / "images"
        label_dir = layout_root / split / "labels"
        if not image_dir.exists() or not label_dir.exists():
            print(f"Skipping missing split: {split}")
            continue

        copied = 0
        for label_file in sorted(label_dir.glob("*.txt")):
            with open(label_file, "r") as f:
                # Tolerate leading blank lines; an all-blank file means "no annotation".
                first_line = next((line.strip() for line in f if line.strip()), "")
            if not first_line:
                continue

            try:
                class_idx = int(first_line.split()[0])
            except ValueError:
                class_idx = -1
            # Reject negative ids explicitly: Python would otherwise wrap them to the last class.
            if not 0 <= class_idx < len(class_names):
                print(f"Skipping {label_file.name}: invalid class id in {first_line!r}")
                continue
            class_name = str(class_names[class_idx])

            candidate = image_dir / f"{label_file.stem}.jpg"
            if not candidate.exists():
                candidate = image_dir / f"{label_file.stem}.png"
            if not candidate.exists():
                # Only accept real image files, in a deterministic order.
                matches = sorted(
                    path
                    for path in image_dir.glob(f"{label_file.stem}.*")
                    if path.suffix.lower() in image_suffixes
                )
                if not matches:
                    continue
                candidate = matches[0]

            dest_dir = target_root / split / class_name
            dest_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(candidate, dest_dir / candidate.name)
            copied += 1

        print(f"{split}: copied {copied} images")

    print(f"Conversion complete. Classification data stored in {target_root}")


# Unpack the archive (a no-op when already extracted), then build the ImageFolder layout.
extract_archive(zip_file=zip_path, target_dir=raw_root)
convert_to_classification_layout(dataset_root=raw_root, target_root=classification_root)


## 5. Create datasets and data loaders
We apply data augmentation to the training split and create `DataLoader` objects for efficient batching.

In [None]:
image_size = 224
batch_size = 32

# Training images get light augmentation; the Normalize statistics must match inference_transform.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Deterministic preprocessing used for validation, test, and single-image inference.
inference_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dir = classification_root / "train"
valid_dir = classification_root / "valid"
test_dir = classification_root / "test"

if not train_dir.exists():
    raise RuntimeError("Training split was not created. Please verify the dataset conversion step.")

base_train_dataset = datasets.ImageFolder(train_dir)
class_names = base_train_dataset.classes
num_classes = len(class_names)


def _load_eval_split(split_dir, split_name):
    """Load an evaluation split and verify it uses the same label mapping as training.

    ImageFolder numbers classes from the sub-folders it finds, so a split that is
    missing (or adds) a class folder would silently get different label indices.
    """
    dataset = datasets.ImageFolder(split_dir, transform=inference_transform)
    if dataset.classes != class_names:
        raise RuntimeError(
            f"Class folders in the {split_name} split ({dataset.classes}) do not match the "
            f"training split ({class_names}); label indices would be inconsistent."
        )
    return dataset


val_ratio = 0.2 if len(base_train_dataset) > 5 else 0.0
if valid_dir.exists():
    valid_dataset = _load_eval_split(valid_dir, "valid")
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
else:
    # No validation split on disk: hold out a seeded random fraction of the training images.
    indices = torch.randperm(len(base_train_dataset), generator=torch.Generator().manual_seed(SEED)).tolist()
    val_size = int(len(base_train_dataset) * val_ratio)
    val_indices = indices[:val_size]
    train_indices = indices[val_size:]
    # Two ImageFolder instances over the same files so each subset gets its own transform.
    train_dataset = Subset(datasets.ImageFolder(train_dir, transform=train_transform), train_indices)
    valid_dataset = Subset(datasets.ImageFolder(train_dir, transform=inference_transform), val_indices) if val_indices else None

if test_dir.exists():
    test_dataset = _load_eval_split(test_dir, "test")
else:
    # NOTE: without a test split the validation data is reused, so the "test" metrics are
    # measured on the same images that were used to select the best epoch.
    test_dataset = valid_dataset

# Pinned host memory only helps host-to-GPU copies; on CPU-only machines it just triggers a warning.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=DEVICE.type == "cuda")

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs) if valid_dataset is not None else None
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs) if test_dataset is not None else None

print(f"Classes: {class_names}")
print(f"Train batches: {len(train_loader)}")
if valid_loader is not None:
    print(f"Validation batches: {len(valid_loader)}")
if test_loader is not None:
    print(f"Test batches: {len(test_loader)}")


### Preview a batch of training images

In [None]:
def show_batch(loader):
    """Plot up to 16 images from the first batch of ``loader`` in a 4x4 grid.

    The normalization applied by the transforms is undone before display, and
    each tile is titled with the class name looked up in ``class_names``.
    """
    images, labels = next(iter(loader))

    # Invert Normalize(mean, std): x * std + mean, broadcast over the channel axis.
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    restored = images * channel_std + channel_mean

    n_rows, n_cols = 4, 4
    n_cells = n_rows * n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8))
    for axis, picture, target in zip(axes.ravel(), restored[:n_cells], labels[:n_cells]):
        # Channels-last layout for matplotlib, clipped to the valid [0, 1] range.
        pixels = np.clip(picture.permute(1, 2, 0).cpu().numpy(), 0, 1)
        axis.imshow(pixels)
        axis.set_title(class_names[target])
        axis.axis("off")
    plt.tight_layout()


show_batch(train_loader)


## 6. Define the model and training utilities
We fine-tune a pretrained ResNet-18 model from `torchvision` by replacing its final classification layer.

In [None]:
# Load ResNet-18 with torchvision's default pretrained weights.
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)

# Replace the final classification layer with one sized for this dataset,
# then move the whole network (new head included) to the target device.
head_in_features = model.fc.in_features
model.fc = nn.Linear(in_features=head_in_features, out_features=num_classes)
model = model.to(DEVICE)

# Loss, optimizer, and a schedule that multiplies the learning rate by 0.1 every 3 epochs.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

print(model)


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, the loss being
        weighted by batch size.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller last batch does not skew the mean.
        loss_sum += batch_loss.item() * batch_images.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """Compute loss and accuracy of ``model`` over ``loader`` without updating weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches, or ``None``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, the loss being weighted by
        batch size. ``(None, None)`` when ``loader`` is ``None`` or yields no
        samples, so callers can treat "no data" uniformly.
    """
    if loader is None:
        return None, None
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty loader has no metrics; report that instead of dividing by zero.
    if total == 0:
        return None, None

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


## 7. Train the model

In [None]:
num_epochs = 5
best_valid_acc = 0.0
best_state = None  # snapshot of the weights from the epoch with the highest validation accuracy

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion, DEVICE)
    scheduler.step()

    msg = f"Epoch {epoch}/{num_epochs} - Train loss: {train_loss:.4f}, Train acc: {train_acc:.4f}"
    # evaluate() returns (None, None) when there is no validation data.
    if valid_acc is not None:
        msg += f" | Valid loss: {valid_loss:.4f}, Valid acc: {valid_acc:.4f}"
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            # state_dict() returns references to the live tensors, which later epochs keep
            # updating in place; clone them so the snapshot really is this epoch's weights.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    print(msg)

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Loaded best model weights with validation accuracy {best_valid_acc:.4f}")


## 8. Evaluate on the test split

In [None]:
if test_loader is None:
    raise RuntimeError("Test loader is not available. Ensure the dataset contains a test or validation split.")

# Collect predictions and ground-truth labels for the whole test loader.
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(DEVICE)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Pass the full label set explicitly so the report and matrix keep one row per class
# even when some class never occurs in the test labels or predictions.
label_ids = list(range(len(class_names)))
print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))
print("Confusion matrix:\n", confusion_matrix(all_labels, all_preds, labels=label_ids))


## 9. Run inference on a single image

In [None]:

import PIL.Image as Image

def predict_image(image_path: Path):
    """Classify a single image file with the trained model.

    The image is opened, converted to RGB, passed through
    ``inference_transform``, and evaluated as a batch of one.

    Returns:
        ``(class_name, confidence, probabilities)``: the top class name, its
        softmax probability, and the full per-class probability array.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    batch = inference_transform(rgb_image).unsqueeze(0).to(DEVICE)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()

    best = np.argmax(probabilities)
    return class_names[best], probabilities[best], probabilities

# Pick a random image from the test split, falling back to the validation split and
# finally to the training split when the earlier ones do not exist.
search_dir = next((d for d in (test_dir, valid_dir, train_dir) if d.exists()), train_dir)

# Match extensions case-insensitively and sort the result: the order returned by rglob is
# filesystem-dependent, which would make the seeded random choice irreproducible.
image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
candidates = sorted(
    path for path in search_dir.rglob("*")
    if path.is_file() and path.suffix.lower() in image_suffixes
)
if not candidates:
    raise RuntimeError(f"No images found in {search_dir}.")
random_image = random.choice(candidates)
label, confidence, probas = predict_image(random_image)

print(f"Predicted class: {label} (confidence: {confidence:.2%})")
print(f"Image path: {random_image}")

# Convert to RGB so single-channel images are not drawn with matplotlib's default colormap.
plt.imshow(Image.open(random_image).convert("RGB"))
plt.title(f"Predicted: {label} ({confidence:.1%})")
plt.axis('off')
plt.show()
