<a href="https://colab.research.google.com/github/satwiksps/Wafer_Defect_Classification/blob/main/wafer_detection_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 📦 Imports
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from timm import create_model
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 📁 Dataset
class WaferMapDataset(Dataset):
    """Wafer-map images paired with integer class ids, read from an ``.npz`` archive.

    The archive is read under the keys ``'arr_0'`` (images) and ``'arr_1'``
    (label rows).  Every distinct label row is treated as one class: the
    distinct rows are sorted and numbered from 0, and ``targets`` holds the
    resulting class id of each sample.

    Args:
        npz_path: Path of the ``.npz`` archive to load.
        transform: Optional callable applied to the RGB ``PIL.Image`` that is
            built for each sample.

    Attributes:
        images: Array stored under ``'arr_0'``.
        labels: Array stored under ``'arr_1'``.
        targets: ``torch`` tensor with one class id per label row.
        num_classes: Number of distinct label rows found in the archive.
        transform: The ``transform`` argument, stored unchanged.
    """

    def __init__(self, npz_path, transform=None):
        # An NpzFile keeps the underlying zip file open until it is closed, so
        # read both arrays inside a ``with`` block and release it right away.
        with np.load(npz_path) as data:
            self.images = data['arr_0']  # original note: (N, 52, 52) -- TODO confirm
            self.labels = data['arr_1']  # original note: (N, 8) -- TODO confirm

        # Convert each label row to a hashable tuple once and reuse it for
        # both building the class map and looking up every sample's id.
        rows = [tuple(row) for row in self.labels]
        label_map = {label: idx for idx, label in enumerate(sorted(set(rows)))}
        self.targets = torch.tensor([label_map[row] for row in rows])
        self.num_classes = len(label_map)
        self.transform = transform

    def __len__(self):
        """Return the number of images in the archive."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, class_id)`` for the sample at position ``idx``.

        The stored array is cast to ``uint8``, wrapped in a ``PIL.Image`` and
        converted to RGB; ``transform`` is applied afterwards when one was given.
        """
        img = self.images[idx].astype(np.uint8)
        img = Image.fromarray(img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[idx]

# 🧠 Model Loader
def get_model(name, num_classes=38):
    """Create a pretrained timm model and move it to ``device``.

    Args:
        name: Model identifier forwarded to ``create_model``.
        num_classes: Number of output classes requested from ``create_model``.

    Returns:
        The model returned by ``create_model`` after ``.to(device)``.
    """
    net = create_model(name, pretrained=True, num_classes=num_classes)
    return net.to(device)

# Short experiment key -> model identifier handed to get_model / create_model.
model_names = dict(
    mobilevit="mobilevitv2_050",
    tinyvit="tiny_vit_5m_224",
    swiftformer="swiftformer_s",
    levit="levit_128s",
)

# 🔁 Training & Validation
def train_one_epoch(model, loader, optimizer):
    """Train ``model`` for one pass over ``loader`` and return the accuracy.

    Args:
        model: Network to optimise; it is switched to training mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer that is stepped once per batch.

    Returns:
        Fraction of samples whose arg-max prediction matched the target,
        computed from the outputs produced during the training pass.
        ``0.0`` when ``loader`` yields no samples.
    """
    model.train()
    total_correct, total = 0, 0
    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = F.cross_entropy(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = out.argmax(1)
        total_correct += (preds == y).sum().item()
        total += y.size(0)

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total == 0:
        return 0.0
    return total_correct / total

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A tuple ``(acc, all_preds, all_targets)``: the fraction of correct
        arg-max predictions (``0.0`` when ``loader`` yields no samples), and
        two flat lists holding the predicted and true class id of every
        sample, in loader order.
    """
    model.eval()
    total_correct, total = 0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Validation", leave=False):
            x, y = x.to(device), y.to(device)
            out = model(x)
            preds = out.argmax(1)
            # tolist() copies to host memory and yields plain Python ints.
            all_preds.extend(preds.tolist())
            all_targets.extend(y.tolist())
            total_correct += (preds == y).sum().item()
            total += y.size(0)

    # An empty loader would otherwise end in a ZeroDivisionError.
    acc = total_correct / total if total else 0.0
    return acc, all_preds, all_targets

# 🧪 Full Experiment
def run_experiment(model_key, npz_path="Wafer_Map_Datasets.npz", epochs=5, batch_size=64, seed=42):
    """Fine-tune one model on the wafer-map data and report validation results.

    The dataset is split 80/20 into training and validation subsets, the
    model named by ``model_names[model_key]`` is trained with AdamW
    (``lr=3e-4``), and the predictions of the last validation pass are
    summarised as a classification report and a confusion-matrix heatmap.

    Args:
        model_key: Key into ``model_names`` selecting the architecture.
        npz_path: Path of the ``.npz`` archive given to ``WaferMapDataset``.
        epochs: Number of training epochs.  With ``epochs <= 0`` no training
            happens and the model is validated once as loaded.
        batch_size: Batch size of both data loaders.
        seed: Seed for the train/validation split, so that every model is
            evaluated on the same held-out samples.  Pass ``None`` to use
            the global random state instead.

    Returns:
        Validation accuracy of the last validation pass.
    """
    print(f"\n=== Running {model_key.upper()} ===")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    dataset = WaferMapDataset(npz_path, transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # A dedicated, seeded generator keeps the split identical across runs
    # without touching the global RNG used for shuffling and weight init.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    # Class ids are numbered 0..K-1 by the dataset, so K = max id + 1.  Size
    # the head from the data instead of relying on get_model's default.
    num_classes = int(dataset.targets.max()) + 1
    model = get_model(model_names[model_key], num_classes=num_classes)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(epochs):
        train_acc = train_one_epoch(model, train_loader, optimizer)
        val_acc, preds, targets = validate(model, val_loader)
        print(f"Epoch {epoch+1}: Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")

    if epochs <= 0:
        # The loop body never ran, so nothing has been validated yet; do it
        # once so the report below has predictions to work with.
        val_acc, preds, targets = validate(model, val_loader)

    # Final Evaluation
    print("\n🔍 Classification Report:")
    print(classification_report(targets, preds, digits=3))

    print("\n📊 Confusion Matrix:")
    cm = confusion_matrix(targets, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=False, cmap="Blues")
    plt.title(f"Confusion Matrix - {model_key.upper()}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()

    return val_acc

# 🚀 Run All Models
# Train and evaluate each architecture in turn, keeping its final
# validation accuracy under the same key.
results = {}
for model_key in ("mobilevit", "tinyvit", "swiftformer", "levit"):
    results[model_key] = run_experiment(model_key, epochs=5)

print("\n🏁 Final Accuracy Comparison:")
for name, accuracy in results.items():
    print(f"{name.upper()}: {accuracy:.4f}")



=== Running MOBILEVIT ===


BadZipFile: File is not a zip file