In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transform
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.optim as optim
import torch.nn as nn
import numpy as np

In [None]:
class CustomAugmentedDataset(ImageFolder):
    """ImageFolder variant that chooses the image transform per class.

    Every sample is loaded with ``self.loader`` and then transformed with the
    pipeline registered for its class name in ``transform_dict``. Classes that
    have no entry fall back to ``transform``.
    """

    def __init__(self, root, transform_dict, transform=None, target_transform=None):
        """Index the image folder and remember the per-class transforms.

        Args:
            root: Dataset folder, forwarded to ``ImageFolder``.
            transform_dict: Mapping of class name -> transform applied to
                samples of that class. ``None`` is treated as an empty mapping.
            transform: Fallback transform for classes that are not in
                ``transform_dict``. Defaults to ``None``, in which case those
                samples are returned exactly as ``self.loader`` produced them
                (typically a PIL image, which the default DataLoader collate
                cannot batch), so callers normally want to supply one.
            target_transform: Optional callable applied to the integer label.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.transform_dict = transform_dict if transform_dict is not None else {}

    def __getitem__(self, index):
        """Load sample ``index``, apply its class's transform, return ``(sample, target)``."""
        path, target = self.samples[index]
        sample = self.loader(path)

        # Class-specific transform wins; otherwise use the dataset-wide fallback.
        class_name = self.classes[target]
        sample_transform = self.transform_dict.get(class_name, self.transform)

        # Compare against None rather than truthiness so that a transform
        # object that happens to be falsy (e.g. an empty container) still runs.
        if sample_transform is not None:
            sample = sample_transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

In [None]:
# Per-channel normalisation statistics shared by both pipelines (the usual
# ImageNet values expected by torchvision's pretrained models).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Deterministic pipeline: resize, centre crop to 224, tensor, normalise.
default_transform = transform.Compose([
    transform.Resize(256),
    transform.CenterCrop(224),
    transform.ToTensor(),
    transform.Normalize(mean=_channel_mean, std=_channel_std),
])

# Heavier, randomised pipeline: random crop/flip/rotation/colour jitter
# before the same tensor conversion and normalisation.
_strong_steps = [
    transform.Resize(256),
    transform.RandomResizedCrop(224),
    transform.RandomHorizontalFlip(),
    transform.RandomRotation(30),
    transform.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transform.ToTensor(),
    transform.Normalize(mean=_channel_mean, std=_channel_std),
]
strong_transform = transform.Compose(_strong_steps)

In [None]:
# Classes with few samples (here: red spot) get the stronger augmentation.
small_classes = ['red-spot']
transform_dict = {cls: strong_transform for cls in small_classes}

# Training set: per-class augmentation from transform_dict.
train_dataset = CustomAugmentedDataset(root='../../dataset/Train/', transform_dict=transform_dict)
# Set the fallback transform on the instance; __getitem__ reads self.transform
# for classes without an entry in transform_dict. Without it those samples are
# returned unconverted and cannot be batched into tensors.
train_dataset.transform = default_transform

# Validation set: no random augmentation for any class, so the validation
# metrics are deterministic and comparable between epochs.
val_dataset = CustomAugmentedDataset(root='../../dataset/Valid/', transform_dict={})
val_dataset.transform = default_transform

# Per-class sample counts; minlength keeps the array aligned with the class list.
train_labels = [label for _, label in train_dataset.samples]
class_counts = np.bincount(train_labels, minlength=len(train_dataset.classes))
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)

# Sampler that counters class imbalance by giving samples of rare classes a
# higher probability of being drawn. The weights are looked up in one
# vectorised indexing step and stored as plain floats.
sample_weights = class_weights[torch.as_tensor(train_labels, dtype=torch.long)].tolist()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Data loaders
train_dataloader = DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=6, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

print(f"Classes detected: {train_dataset.classes}")

# Reports the training-time augmentation chosen for each class.
print("Augmentation summary per class:")
for cls in train_dataset.classes:
    print(f"{cls.ljust(15)} → {'Strong' if cls in transform_dict else 'Default'}")

In [None]:
# Load the pretrained network and put it in eval mode.
# CHANGE THIS to match the model you want to train.
model = models.vgg19(pretrained=True).eval()
# Bare expression as the last line so the notebook displays the architecture.
model

In [None]:
# Replace the last classifier layer so its output size matches the number of
# classes found in the dataset. ADJUST the layer index to the model in use.
num_classes = len(train_dataset.classes)
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(head_in_features, num_classes)

# Use the GPU when one is available, otherwise stay on the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ok else "cpu")
print('is cuda available?', cuda_ok)
model = model.to(device)


In [None]:
# Cross-entropy loss for multi-class classification.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, with a small learning rate and weight decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [None]:
num_epochs = 30

print(f"Total batches: {len(train_dataloader)}")

# Best-epoch tracking: lowest validation loss seen so far and a CPU copy of
# the weights that produced it. Restore them after training with
# `model.load_state_dict(best_model_state)` if the last epoch is not the best.
best_val_loss = float('inf')
best_model_state = None

for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch+1}/{num_epochs}")

    # ---- training pass ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_dataloader):
        if batch_idx % 100 == 0:
            print(f"Processing batch {batch_idx+1}/{len(train_dataloader)}")
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += batch_size

    avg_loss = running_loss / total
    train_accuracy = (correct / total) * 100
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {train_accuracy:.2f}%")

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            val_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += batch_size

    avg_val_loss = val_loss / val_total
    val_accuracy = (val_correct / val_total) * 100
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    # Remember the best epoch. The tensors are cloned to CPU so later
    # optimizer steps cannot modify the snapshot.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"New best validation loss: {best_val_loss:.4f} (epoch {epoch+1})")

In [None]:
# Write the model's current weights (state dict only) to disk.
checkpoint_path = "vggnet19.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)
print("Model saved successfully!")