# In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

# -------------------- Config --------------------
# Training hyper-parameters and the input location; `train_dir` must be set
# to the training directory before the script is run.
train_dir = ""  # path to training directory
batch_size = 16
epochs = 30
learning_rate = 1e-4

# Run on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -------------------- Class Labels --------------------
# Every immediate sub-directory of `train_dir` is treated as one class.
# Sorting the names makes the name -> index assignment independent of the
# order in which the directory happens to be listed.
class_names = sorted(
    entry
    for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
num_classes = len(class_names)
class_to_idx = dict(zip(class_names, range(num_classes)))

print("Detected Classes:", class_names)

# -------------------- Dataset --------------------
class MultiClassDataset(Dataset):
    """Folder-per-class image dataset that yields 3-channel grayscale samples.

    ``root_dir`` is expected to contain one sub-directory per class, each
    holding that class's image files (``.png``, ``.jpg`` or ``.jpeg``,
    matched case-insensitively). Other files are ignored.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the grayscale PIL image
            before the channel expansion.
        class_to_idx: Optional mapping from class (sub-directory) name to
            integer label. When omitted, it is built from the sorted
            sub-directory names found in ``root_dir``.

    Attributes:
        images: File paths of all collected images.
        labels: Integer class label for each entry of ``images``.
    """

    # Accepted image file suffixes (compared against the lower-cased name).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, class_to_idx=None):
        self.root_dir = root_dir
        self.transform = transform

        if class_to_idx is None:
            # No mapping supplied: one class per sub-directory, indexed in
            # sorted-name order.
            discovered = sorted(
                name for name in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, name))
            )
            class_to_idx = {name: idx for idx, name in enumerate(discovered)}
        self.class_to_idx = class_to_idx

        self.images = []
        self.labels = []

        # Walk the mapping itself instead of a module-level class list, so
        # the dataset only depends on its own arguments.
        for class_name, class_idx in self.class_to_idx.items():
            sub_path = os.path.join(root_dir, class_name)

            # Sorted so the sample order does not depend on the filesystem's
            # listing order.
            for fname in sorted(os.listdir(sub_path)):
                if fname.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(sub_path, fname))
                    self.labels.append(class_idx)

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A ``(img, label)`` pair. ``label`` is a scalar ``torch.long``
            tensor. ``img`` is the transformed image repeated three times
            along dimension 0, or an RGB PIL image if no transform turned
            it into a tensor.
        """
        img_path = self.images[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # The context manager releases the file handle as soon as the
        # grayscale copy has been made.
        with Image.open(img_path) as src:
            img = src.convert("L")  # grayscale

        if self.transform is not None:
            img = self.transform(img)

        # Convert grayscale (1 channel) → 3 channels for ResNet
        if isinstance(img, Image.Image):
            # Still a PIL image (no tensor-producing transform was applied);
            # it has no `.repeat`, so expand the channels through PIL.
            img = img.convert("RGB")
        else:
            img = img.repeat(3, 1, 1)

        return img, label

# -------------------- Transforms --------------------
# Preprocessing applied to every grayscale PIL image: resize to 224x224,
# convert to a tensor, then normalize the single channel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# -------------------- Data Loader --------------------
# Shuffled mini-batches over the whole training directory.
train_dataset = MultiClassDataset(
    root_dir=train_dir,
    transform=transform,
    class_to_idx=class_to_idx,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# -------------------- ResNet50 --------------------
# ResNet-50 built without pretrained weights; its final fully connected layer
# is replaced so the network emits one score per detected class.
# NOTE(review): newer torchvision releases reportedly deprecate `pretrained=`
# in favour of `weights=` — confirm against the installed version.
model = models.resnet50(pretrained=False)
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)
model.to(device)

# -------------------- Loss & Optimizer --------------------
# Cross-entropy for the multi-class objective; Adam updates every parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# -------------------- Accuracy Function --------------------
def accuracy(outputs, targets):
    """Count correct top-1 predictions in a batch.

    Args:
        outputs: Per-class scores; the index of the largest score along
            dimension 1 is taken as the prediction for each row.
        targets: Ground-truth class indices, compared element-wise with the
            predictions.

    Returns:
        A ``(num_correct, num_samples)`` tuple, where ``num_samples`` is
        ``targets.size(0)``.
    """
    predicted = torch.max(outputs, dim=1).indices
    num_correct = predicted.eq(targets).sum().item()
    num_samples = targets.size(0)
    return num_correct, num_samples

# -------------------- Training Loop --------------------
for epoch in range(epochs):
    model.train()

    # Per-epoch accumulators: loss summed over samples, number of correct
    # predictions, and number of samples seen.
    running_loss, correct_total, total_samples = 0, 0, 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # One optimisation step: clear old gradients, forward pass, backward
        # pass, parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # sample count below gives a per-sample average even when the last
        # batch is smaller.
        batch_correct, batch_total = accuracy(outputs, labels)
        running_loss += loss.item() * images.shape[0]
        correct_total += batch_correct
        total_samples += batch_total

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_total / total_samples

    print(f"Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

# -------------------- Save --------------------
# Persist only the learned parameters (the state dict), not the whole module
# object; `save_path` must be set to the destination file first.
save_path = ""  # path to save the model
torch.save(obj=model.state_dict(), f=save_path)
print("Model saved to:", save_path)
