<a href="https://colab.research.google.com/github/rediahmds/eco-sort/blob/main/train/train_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Prepare dataset

In [1]:
!pip install kagglehub



### Download

In [9]:
import kagglehub

alistair_ds = kagglehub.dataset_download("alistairking/recyclable-and-household-waste-classification")
print("Path to dataset files:", alistair_ds)

mostafa_ds = kagglehub.dataset_download("mostafaabla/garbage-classification")
print("Path to dataset files:", mostafa_ds)

joe_ds = kagglehub.dataset_download("joebeachcapital/realwaste")
print("Path to dataset files:", joe_ds)

glhdamar_ds = kagglehub.dataset_download("glhdamar/new-trash-classfication-dataset")
print("Path to dataset files:", glhdamar_ds)

Downloading from https://www.kaggle.com/api/v1/datasets/download/glhdamar/new-trash-classfication-dataset?dataset_version_number=3...


100%|██████████| 391M/391M [00:05<00:00, 68.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/glhdamar/new-trash-classfication-dataset/versions/3


Show directory tree

In [11]:
from pathlib import Path

def print_directory_tree(root: Path, prefix: str = ""):
    """
    Mencetak struktur direktori dengan tampilan seperti pohon.
    Hanya menampilkan folder (tanpa file).
    """
    subdirs = sorted([p for p in root.iterdir() if p.is_dir()])
    for i, subdir in enumerate(subdirs):
        connector = "└── " if i == len(subdirs) - 1 else "├── "
        print(f"{prefix}{connector}{subdir.name}")
        extension = "    " if i == len(subdirs) - 1 else "│   "
        print_directory_tree(subdir, prefix + extension)

# Path ke folder utama
alistair_path = Path(alistair_ds) / "images" / "images"
mostafa_path = Path(mostafa_ds) / "garbage_classification"
joe_path = Path(joe_ds) / "realwaste-main" / "RealWaste"
glhdamar_path = Path(glhdamar_ds) / "new-dataset-trash-type-v2"

# Cetak pohon direktori
print(alistair_path.name)
print_directory_tree(alistair_path)

print(mostafa_path.name)
print_directory_tree(mostafa_path)

print(joe_path.name)
print_directory_tree(joe_path)

print(glhdamar_ds)
print_directory_tree(glhdamar_path)


/root/.cache/kagglehub/datasets/glhdamar/new-trash-classfication-dataset/versions/3
├── cardboard
├── e-waste
├── glass
├── metal
├── organic
├── paper
├── plastic
├── textile
└── trash


### Copy Dataset

In [4]:
from pathlib import Path
import shutil
import random

def copy_n_files(src_dir, dst_dir, n, randomize=False):
    src_path = Path(src_dir)
    dst_path = Path(dst_dir)

    # Buat folder tujuan jika belum ada
    dst_path.mkdir(parents=True, exist_ok=True)

    # Ambil semua file dari direktori sumber
    all_files = [f for f in src_path.iterdir() if f.is_file()]

    # Pastikan n tidak lebih besar dari jumlah file
    n = min(n, len(all_files))

    # Tentukan file mana yang akan disalin
    if randomize:
        files_to_copy = random.sample(all_files, n)
    else:
        files_to_copy = sorted(all_files)[:n]

    # Copy file satu per satu
    for file in files_to_copy:
        shutil.copy(file, dst_path)
        print(f"Copied: {file.name}")

    print(f"\nTotal {n} files copied from '{src_dir}' to '{dst_dir}' (random: {randomize}).")


#### Customize Alistair Dataset

In [5]:
from pathlib import Path
import shutil
from tqdm import tqdm

source_root = alistair_path
target_root = Path("dataset/train")
target_root.mkdir(parents=True, exist_ok=True)

class_map = {
    "food_waste": "organic",
    "eggshells": "organic",
    "coffee_grounds": "organic",
    "tea_bags": "organic",
    "plastic_soda_bottles": "plastic",
    "plastic_trash_bags": "plastic",
    "plastic_food_containers": "plastic",
    "plastic_shopping_bags": "plastic",
    "plastic_straws": "plastic",
    "plastic_water_bottles": "plastic",
    "plastic_detergent_bottles": "plastic",
    "plastic_cup_lids": "plastic",
    "glass_food_jars": "glass",
    "glass_beverage_bottles": "glass",
    "glass_cosmetic_containers": "glass",
    "aluminum_soda_cans": "metal",
    "aluminum_food_cans": "metal",
    "steel_food_cans": "metal",
    "aerosol_cans": "metal",
    "cardboard_boxes": "paper",
    "cardboard_packaging": "paper",
    "magazines": "paper",
    "newspaper": "paper",
    "office_paper": "paper",
    "paper_cups": "paper",
    "styrofoam_cups": "styrofoam",
    "styrofoam_food_containers": "styrofoam",
    "clothing": "textiles",
    "shoes": "textiles"
}

print("🚀 Memulai pengelompokan dataset dengan penamaan ulang...\n")

for class_name, parent_class in class_map.items():
    for subset in ["default", "real_world"]:
        class_dir = source_root / class_name / subset
        if class_dir.exists():
            img_list = list(class_dir.glob("*.*"))
            print(f"📁 Menyalin {len(img_list)} gambar dari '{class_name}/{subset}' ke '{parent_class}'")
            for i, img in enumerate(tqdm(img_list, desc=f"{class_name}/{subset}", leave=False)):
                dest_dir = target_root / parent_class
                dest_dir.mkdir(parents=True, exist_ok=True)

                # Penamaan ulang
                ext = img.suffix
                new_name = f"{subset}_{class_name}_{i:04d}{ext}"
                shutil.copy(img, dest_dir / new_name)

print("\n✅ Pengelompokan selesai tanpa konflik penamaan.")
print("📂 Dataset tersimpan di:", target_root.resolve())


🚀 Memulai pengelompokan dataset dengan penamaan ulang...

📁 Menyalin 250 gambar dari 'food_waste/default' ke 'organic'




📁 Menyalin 250 gambar dari 'food_waste/real_world' ke 'organic'




📁 Menyalin 250 gambar dari 'eggshells/default' ke 'organic'




📁 Menyalin 250 gambar dari 'eggshells/real_world' ke 'organic'




📁 Menyalin 250 gambar dari 'coffee_grounds/default' ke 'organic'




📁 Menyalin 250 gambar dari 'coffee_grounds/real_world' ke 'organic'




📁 Menyalin 250 gambar dari 'tea_bags/default' ke 'organic'




📁 Menyalin 250 gambar dari 'tea_bags/real_world' ke 'organic'


                                                            

📁 Menyalin 250 gambar dari 'plastic_soda_bottles/default' ke 'plastic'








📁 Menyalin 250 gambar dari 'plastic_soda_bottles/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_trash_bags/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_trash_bags/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_food_containers/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_food_containers/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_shopping_bags/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_shopping_bags/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_straws/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_straws/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_water_bottles/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_water_bottles/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_detergent_bottles/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_detergent_bottles/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_cup_lids/default' ke 'plastic'




📁 Menyalin 250 gambar dari 'plastic_cup_lids/real_world' ke 'plastic'




📁 Menyalin 250 gambar dari 'glass_food_jars/default' ke 'glass'




📁 Menyalin 250 gambar dari 'glass_food_jars/real_world' ke 'glass'




📁 Menyalin 250 gambar dari 'glass_beverage_bottles/default' ke 'glass'




📁 Menyalin 250 gambar dari 'glass_beverage_bottles/real_world' ke 'glass'




📁 Menyalin 250 gambar dari 'glass_cosmetic_containers/default' ke 'glass'




📁 Menyalin 250 gambar dari 'glass_cosmetic_containers/real_world' ke 'glass'




📁 Menyalin 250 gambar dari 'aluminum_soda_cans/default' ke 'metal'




📁 Menyalin 250 gambar dari 'aluminum_soda_cans/real_world' ke 'metal'




📁 Menyalin 250 gambar dari 'aluminum_food_cans/default' ke 'metal'




📁 Menyalin 250 gambar dari 'aluminum_food_cans/real_world' ke 'metal'




📁 Menyalin 250 gambar dari 'steel_food_cans/default' ke 'metal'




📁 Menyalin 250 gambar dari 'steel_food_cans/real_world' ke 'metal'




📁 Menyalin 250 gambar dari 'aerosol_cans/default' ke 'metal'




📁 Menyalin 250 gambar dari 'aerosol_cans/real_world' ke 'metal'




📁 Menyalin 250 gambar dari 'cardboard_boxes/default' ke 'paper'




📁 Menyalin 250 gambar dari 'cardboard_boxes/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'cardboard_packaging/default' ke 'paper'




📁 Menyalin 250 gambar dari 'cardboard_packaging/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'magazines/default' ke 'paper'




📁 Menyalin 250 gambar dari 'magazines/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'newspaper/default' ke 'paper'




📁 Menyalin 250 gambar dari 'newspaper/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'office_paper/default' ke 'paper'




📁 Menyalin 250 gambar dari 'office_paper/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'paper_cups/default' ke 'paper'




📁 Menyalin 250 gambar dari 'paper_cups/real_world' ke 'paper'




📁 Menyalin 250 gambar dari 'styrofoam_cups/default' ke 'styrofoam'




📁 Menyalin 250 gambar dari 'styrofoam_cups/real_world' ke 'styrofoam'




📁 Menyalin 250 gambar dari 'styrofoam_food_containers/default' ke 'styrofoam'




📁 Menyalin 250 gambar dari 'styrofoam_food_containers/real_world' ke 'styrofoam'




📁 Menyalin 250 gambar dari 'clothing/default' ke 'textiles'




📁 Menyalin 250 gambar dari 'clothing/real_world' ke 'textiles'




📁 Menyalin 250 gambar dari 'shoes/default' ke 'textiles'




📁 Menyalin 250 gambar dari 'shoes/real_world' ke 'textiles'


                                                                     


✅ Pengelompokan selesai tanpa konflik penamaan.
📂 Dataset tersimpan di: /content/dataset/train




In [17]:
# Uncomment all for first run

copy_n_files(f"{mostafa_ds}/garbage_classification/paper", "dataset/train/paper", 500, randomize=True)
copy_n_files(f"{mostafa_ds}/garbage_classification/cardboard", "dataset/train/paper", 500, randomize=True)

copy_n_files(f"{mostafa_ds}/garbage_classification/white-glass", "dataset/train/glass", 600, randomize=True)
copy_n_files(f"{mostafa_ds}/garbage_classification/brown-glass", "dataset/train/glass", 600, randomize=True)
copy_n_files(f"{mostafa_ds}/garbage_classification/green-glass", "dataset/train/glass", 600, randomize=True)

copy_n_files(f"{mostafa_ds}/garbage_classification/clothes", "dataset/train/textiles", 1500, randomize=True)
copy_n_files(f"{mostafa_ds}/garbage_classification/shoes", "dataset/train/textiles", 1500, randomize=True)

copy_n_files(f"{mostafa_ds}/garbage_classification/metal", "dataset/train/metal", 750, randomize=True)
copy_n_files(f"{joe_ds}/realwaste-main/RealWaste/Metal", "dataset/train/metal", 750, randomize=True)

copy_n_files(f"{mostafa_ds}/garbage_classification/biological", "dataset/train/organic", 980, randomize=True)
copy_n_files(f"{glhdamar_ds}/new-dataset-trash-type-v2/organic", "dataset/train/organic", 960, randomize=True)



Copied: biological588.jpg
Copied: Vegetation_210.jpg
Copied: biological256.jpg
Copied: Vegetation_320.jpg
Copied: Vegetation_92.jpg
Copied: Food Organics_145.jpg
Copied: Vegetation_187.jpg
Copied: biological756.jpg
Copied: Vegetation_344.jpg
Copied: Vegetation_39.jpg
Copied: biological273.jpg
Copied: biological509.jpg
Copied: organic_008875_photo.jpg
Copied: biological823.jpg
Copied: biological678.jpg
Copied: biological66.jpg
Copied: biological973.jpg
Copied: biological212.jpg
Copied: organic_002456_photo.jpg
Copied: biological52.jpg
Copied: organic_008724_photo.jpg
Copied: Vegetation_84.jpg
Copied: organic_003474_photo.jpg
Copied: Food Organics_149.jpg
Copied: Vegetation_387.jpg
Copied: biological118.jpg
Copied: organic_007215_photo.jpg
Copied: biological587.jpg
Copied: Food Organics_208.jpg
Copied: biological946.jpg
Copied: biological166.jpg
Copied: Vegetation_269.jpg
Copied: Food Organics_369.jpg
Copied: Vegetation_403.jpg
Copied: biological517.jpg
Copied: biological47.jpg
Copied: b

#### TODO: Check for duplicates

In [23]:
!pip install imagehash

Collecting imagehash
  Downloading ImageHash-4.3.2-py2.py3-none-any.whl.metadata (8.4 kB)
Downloading ImageHash-4.3.2-py2.py3-none-any.whl (296 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.7/296.7 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imagehash
Successfully installed imagehash-4.3.2


In [24]:
from pathlib import Path
from PIL import Image
import imagehash
from collections import defaultdict

def remove_duplicate_images(dataset_dir: Path | str, delete: bool = True, hash_func=imagehash.average_hash) -> list[Path]:
    """
    Deteksi dan hapus gambar duplikat berdasarkan hash visual.

    Args:
        dataset_dir (Path or str): Direktori utama dataset.
        delete (bool): Jika True, hapus gambar duplikat. Jika False, hanya laporkan.
        hash_func (function): Fungsi hash dari imagehash (default: average_hash).

    Returns:
        List[Path]: Daftar path gambar yang terdeteksi sebagai duplikat.
    """
    dataset_dir = Path(dataset_dir)
    hashes = defaultdict(list)
    duplicates = []

    print(f"🔍 Mendeteksi duplikat di dalam: {dataset_dir.resolve()}")

    for img_path in dataset_dir.rglob("*.*"):
        if img_path.suffix.lower() not in [".jpg", ".jpeg", ".png"]:
            continue
        try:
            img = Image.open(img_path).convert("RGB")
            h = hash_func(img)
            if h in hashes:
                print(f"⚠️ Duplikat ditemukan: {img_path.name} ↔ {hashes[h][0].name}")
                duplicates.append(img_path)
            hashes[h].append(img_path)
        except Exception as e:
            print(f"❌ Gagal membuka {img_path}: {e}")

    if delete:
        for dup_path in duplicates:
            print(f"🗑️ Menghapus: {dup_path}")
            try:
                dup_path.unlink()
            except Exception as e:
                print(f"❌ Gagal menghapus {dup_path}: {e}")

    print(f"✅ Total duplikat ditemukan: {len(duplicates)}")
    return duplicates


In [25]:
remove_duplicate_images("dataset/train", delete=False)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
⚠️ Duplikat ditemukan: default_glass_cosmetic_containers_0150.png ↔ real_world_glass_cosmetic_containers_0120.png
⚠️ Duplikat ditemukan: green-glass471.jpg ↔ green-glass544.jpg
⚠️ Duplikat ditemukan: green-glass345.jpg ↔ brown-glass449.jpg
⚠️ Duplikat ditemukan: default_glass_cosmetic_containers_0062.png ↔ real_world_glass_cosmetic_containers_0125.png
⚠️ Duplikat ditemukan: real_world_glass_food_jars_0018.png ↔ default_glass_food_jars_0035.png
⚠️ Duplikat ditemukan: default_glass_cosmetic_containers_0116.png ↔ real_world_glass_cosmetic_containers_0066.png
⚠️ Duplikat ditemukan: real_world_glass_cosmetic_containers_0226.png ↔ real_world_glass_cosmetic_containers_0148.png
⚠️ Duplikat ditemukan: real_world_glass_cosmetic_containers_0051.png ↔ default_glass_cosmetic_containers_0228.png
⚠️ Duplikat ditemukan: real_world_glass_cosmetic_containers_0026.png ↔ default_glass_cosmetic_containers_0088.png
⚠️ Duplikat ditemukan: real_

[PosixPath('dataset/train/organic/real_world_eggshells_0090.png'),
 PosixPath('dataset/train/organic/default_coffee_grounds_0225.png'),
 PosixPath('dataset/train/organic/real_world_eggshells_0067.png'),
 PosixPath('dataset/train/organic/default_eggshells_0169.png'),
 PosixPath('dataset/train/organic/default_food_waste_0158.png'),
 PosixPath('dataset/train/organic/real_world_food_waste_0052.png'),
 PosixPath('dataset/train/organic/real_world_coffee_grounds_0217.png'),
 PosixPath('dataset/train/organic/real_world_food_waste_0161.png'),
 PosixPath('dataset/train/organic/real_world_food_waste_0176.png'),
 PosixPath('dataset/train/organic/default_food_waste_0029.png'),
 PosixPath('dataset/train/organic/real_world_coffee_grounds_0104.png'),
 PosixPath('dataset/train/organic/real_world_tea_bags_0239.png'),
 PosixPath('dataset/train/organic/real_world_coffee_grounds_0183.png'),
 PosixPath('dataset/train/organic/real_world_food_waste_0081.png'),
 PosixPath('dataset/train/organic/default_tea_bag

### Create Validation Dataset

This dataset will be created by moving some files from training dataset.

In [7]:
def move_validation_split_custom(train_dir, val_dir, per_class_counts: dict, random_select=True):
    train_dir = Path(train_dir)
    val_dir = Path(val_dir)
    val_dir.mkdir(parents=True, exist_ok=True)

    for class_name, n in per_class_counts.items():
        class_dir = train_dir / class_name
        if not class_dir.exists():
            print(f"⚠️ Folder tidak ditemukan: {class_dir}")
            continue

        images = sorted([p for p in class_dir.glob("*.*") if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])
        selected = random.sample(images, min(n, len(images))) if random_select else images[:n]
        val_class_dir = val_dir / class_name
        val_class_dir.mkdir(parents=True, exist_ok=True)

        print(f"📁 {class_name}: Memindahkan {len(selected)} file...")
        for img in tqdm(selected, desc=f"  Pindah {class_name}", leave=False):
            shutil.move(str(img), str(val_class_dir / img.name))

    print("\n✅ Selesai membuat validasi set proporsional.")


In [None]:
# 15% validasi per kelas
per_class_counts = {
    "glass": 495,
    "metal": 525,
    "organic": 300,
    "paper": 600,
    "plastic": 600,
    "styrofoam": 150,
    "textiles": 600
}

move_validation_split_custom("dataset/train", "dataset/test", per_class_counts, random_select=True)

📁 glass: Memindahkan 495 file...




📁 metal: Memindahkan 525 file...




📁 organic: Memindahkan 300 file...




📁 paper: Memindahkan 600 file...




📁 plastic: Memindahkan 600 file...




📁 styrofoam: Memindahkan 150 file...




📁 textiles: Memindahkan 600 file...


                                                          


✅ Selesai membuat validasi set proporsional.




### Data Distribution checking

In [18]:
from collections import Counter
from torchvision.datasets import ImageFolder

train_dataset = ImageFolder("dataset/train")
label_counts = Counter([label for _, label in train_dataset])
print("Label mapping:", train_dataset.class_to_idx)
print("Distribusi kelas:", label_counts)

test_dataset = ImageFolder("dataset/test")
label_counts = Counter([label for _, label in test_dataset])
print("Distribusi kelas:", label_counts)


Label mapping: {'glass': 0, 'metal': 1, 'organic': 2, 'paper': 3, 'plastic': 4, 'styrofoam': 5, 'textiles': 6}
Distribusi kelas: Counter({3: 4000, 4: 4000, 6: 4000, 2: 3536, 1: 3500, 0: 3300, 5: 1000})


FileNotFoundError: [Errno 2] No such file or directory: 'dataset/test'

### Connect to Google Drive (Optional)

This functionality allows to save the trained models to current Google account

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Training and Evaluation

#### Setup ClearML

Go ahead and sign-up/sign-in to [AI Infrastructure Platform | Maximize AI Performance & Scalability | ClearML](https://clear.ml/)

After that, go to Settings -> Workspace -> Create new credentials

The new credentials will be created and shows two options:

Local Python (Recommended)
Jupyter Notebook
Both actually are the same things, it only differs on how to use the new credentials.

This time, use the clearml CLI app to consume the credentials, when prompted, paste it.

In [None]:
!pip install clearml



In [None]:
!clearml-init

ClearML SDK setup process
Configuration file already exists: /root/clearml.conf
Leaving setup, feel free to edit the configuration file.


In [None]:
!pip install torch torchvision matplotlib



In [None]:
#@title <b>Time Out Preventer (Advanced) </b></strong>
%%capture
AUTO_RECONNECT = True #@param {type:"boolean"}
#@markdown **Run this code to prevent Google Colab from Timeout**
from os import makedirs
makedirs("/root/.config/rclone", exist_ok = True)
if AUTO_RECONNECT:
  import IPython
  from google.colab import output

  display(IPython.display.Javascript('''
  function ClickConnect(){
    btn = document.querySelector("colab-connect-button")
    if (btn != null){
      console.log("Click colab-connect-button");
      btn.click()
      }

    btn = document.getElementById('ok')
    if (btn != null){
      console.log("Click reconnect");
      btn.click()
      }
    }

  setInterval(ClickConnect,60000)
  '''))

### Training Options (Choose one of these)

#### (1) Training with early stopping - Recommended

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight
import time
import os
from clearml import Task, Logger, OutputModel


# 🔁 Transformasi
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 📁 Load dataset
train_dataset = datasets.ImageFolder("dataset/train", transform=transform)
val_dataset = datasets.ImageFolder("dataset/test", transform=transform)
class_names = train_dataset.classes

print("Label mapping:", train_dataset.class_to_idx)
print("Train distribusi:", Counter([label for _, label in train_dataset]))
print("Val distribusi:", Counter([label for _, label in val_dataset]))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# ⚙️ Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# Inbalanced dataset
labels = [label for _, label in train_dataset]
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(labels),
                                     y=labels)

weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# 🔁 Training loop
epochs = 32
patience = 12
train_accs, val_accs = [], []
train_losses, val_losses = [], []
best_val_acc = 0
early_stop_counter = 0

# Model Storage
model_name = model.__class__.__name__
best_model_path = f"/content/drive/MyDrive/{model_name}_best_model.pt"
latest_model_path = f"/content/drive/MyDrive/{model_name}_latest_model.pt"
os.makedirs(os.path.dirname("/content/drive/MyDrive/Cool Lee Yeah/8th Semester/Skripsi/AI Models"), exist_ok=True)

task = Task.init(
    project_name="EcoSort CNN",
    task_name=f"{model_name} Training {time.strftime('%a, %b %-d, %Y - %H:%M:%S')}",
    task_type=Task.TaskTypes.training
)
logger = task.get_logger()

# 🔍 Logging sample predictions
def log_predictions(images, labels, preds, class_names, epoch):
    fig, axs = plt.subplots(1, 5, figsize=(15, 3))
    for i in range(min(5, len(images))):
        img = images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5  # unnormalize
        axs[i].imshow(img.numpy())
        axs[i].axis('off')
        axs[i].set_title(f"T: {class_names[labels[i]]}\nP: {class_names[preds[i]]}")
    logger.report_matplotlib_figure(title="Sample Predictions", series="Validation", figure=fig, iteration=epoch)
    plt.close(fig)

# 🏃 Training starts
try:
  for epoch in range(epochs):
      start_time = time.time()

      model.train()
      train_loss = 0
      correct, total = 0, 0

      for images, labels in train_loader:
          images, labels = images.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(images)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

          train_loss += loss.item()
          _, preds = torch.max(outputs, 1)
          correct += torch.sum(preds == labels)
          total += labels.size(0)

      train_acc = correct / total
      train_accs.append(train_acc.item())
      train_losses.append(train_loss / len(train_loader))

      # ClearML Logging (Train)
      logger.report_scalar("Accuracy", "Train", value=train_acc.item(), iteration=epoch)
      logger.report_scalar("Loss", "Train", value=train_loss / len(train_loader), iteration=epoch)
      logger.report_scalar("LR", "Learning Rate", value=optimizer.param_groups[0]['lr'], iteration=epoch)

      # 🔍 Validasi
      model.eval()
      val_loss = 0
      correct, total = 0, 0
      y_true, y_pred = [], []
      last_batch_images, last_batch_labels, last_batch_preds = None, None, None

      with torch.no_grad():
          for images, labels in val_loader:
              images, labels = images.to(device), labels.to(device)
              outputs = model(images)
              loss = criterion(outputs, labels)

              val_loss += loss.item()
              _, preds = torch.max(outputs, 1)
              correct += torch.sum(preds == labels)
              total += labels.size(0)

              y_true.extend(labels.cpu().numpy())
              y_pred.extend(preds.cpu().numpy())

              last_batch_images = images
              last_batch_labels = labels
              last_batch_preds = preds

      val_acc = correct / total
      val_accs.append(val_acc.item())
      val_losses.append(val_loss / len(val_loader))

      # ClearML Logging (Val)
      logger.report_scalar("Accuracy", "Validation", value=val_acc.item(), iteration=epoch)
      logger.report_scalar("Loss", "Validation", value=val_loss / len(val_loader), iteration=epoch)

      # Classification Report
      report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
      for class_name in class_names:
          logger.report_scalar("Precision", class_name, value=report[class_name]["precision"], iteration=epoch)
          logger.report_scalar("Recall", class_name, value=report[class_name]["recall"], iteration=epoch)
          logger.report_scalar("F1-Score", class_name, value=report[class_name]["f1-score"], iteration=epoch)
      logger.report_scalar("F1-Score", "Macro Avg", value=report["macro avg"]["f1-score"], iteration=epoch)

      # Confusion Matrix
      cm = confusion_matrix(y_true, y_pred)
      fig, ax = plt.subplots(figsize=(6, 6))
      disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
      disp.plot(cmap="Blues", ax=ax)
      plt.title("Confusion Matrix")
      logger.report_matplotlib_figure(title="Confusion Matrix", series="Validation", figure=fig, iteration=epoch)
      plt.close(fig)

      # Sample Prediction Logging
      log_predictions(last_batch_images, last_batch_labels, last_batch_preds, class_names, epoch)
      print(f"🔁 Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.4f} - Val Acc: {val_acc:.4f}")

      # Save model
      if val_acc >= best_val_acc:
          best_val_acc = val_acc
          early_stop_counter = 0
          torch.save(model.state_dict(), best_model_path)
          print(f"🏆 (Best) Model saved")

      else:
          torch.save(model.state_dict(), latest_model_path)
          print(f"📦 (Latest) Model saved")
          early_stop_counter += 1
          if early_stop_counter >= patience:
              print("⏹️ Early stopping triggered.")
              break

      duration = time.time() - start_time
      eta = (epochs - epoch - 1) * duration
      print(f"⏱️ Epoch time: Took {duration:.2f}s - ETA: ~{eta/60:.1f} min")
      logger.report_scalar("Epoch Time (sec)", "Duration", value=duration, iteration=epoch)
      print("\n")

  # 🎉 Done
  print("=== Final Classification Report ===")
  print(classification_report(y_true, y_pred, target_names=class_names))

  task.mark_completed()

except Exception:
    task.mark_failed()


ClearML Task: created new task id=b32342711e1d410c997b8345a8047d0c
2025-07-19 07:52:39,981 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/06fa9058610d4908bdf75f3a0a10c4b2/experiments/b32342711e1d410c997b8345a8047d0c/output/log
Label mapping: {'glass': 0, 'metal': 1, 'organic': 2, 'paper': 3, 'plastic': 4, 'styrofoam': 5, 'textiles': 6}
Train distribusi: Counter({3: 3400, 4: 3400, 6: 3400, 1: 2975, 0: 2805, 2: 1700, 5: 850})
Val distribusi: Counter({3: 600, 4: 600, 6: 600, 1: 525, 0: 495, 2: 300, 5: 150})
2025-07-19 07:54:41,080 - clearml.model - INFO - Selected model id: 19f8b6976a0344a0ae63d4f6f4dc3be5
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start
Epoch 1/32 - Train Acc: 0.6450 - Val Acc: 0.6786
2025-07-19 07:59:01,379 - clearml.frameworks - INFO - Found existing registered model id=d733f9cedb084ef884057ff923d1bfd9 [/content/drive/MyDrive/best_model.pt] reus

#### (2) Training without early stopping

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import Counter
import time

# Transform
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Dataset
train_dataset = datasets.ImageFolder("dataset/train", transform=transform)
class_names = train_dataset.classes
print("Label mapping:", train_dataset.class_to_idx)
print("Distribusi:", Counter([label for _, label in train_dataset]))

val_dataset = datasets.ImageFolder("dataset/test", transform=transform)
print("Label mapping:", val_dataset.class_to_idx)
print("Distribusi:", Counter([label for _, label in val_dataset]))

# Split
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training
epochs = 6
train_accs, val_accs = [], []
best_val_acc = 0

for epoch in range(epochs):
    model.train()
    correct, total, loss_total = 0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels)
        total += labels.size(0)

    train_acc = correct / total
    train_accs.append(train_acc.item())
    print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.4f}")

    # Validasi
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels)
            total += labels.size(0)

    val_acc = correct / total
    val_accs.append(val_acc.item())
    print(f"            → Val Acc: {val_acc:.4f}")

# Simpan
torch.save(model.state_dict(), "model_cnn.pt")
print("✅ Model disimpan.")

### Evaluation

In [None]:
# 📊 Evaluasi dengan laporan & Confusion Matrix
model.eval()
y_true, y_pred = [], []

for images, labels in val_loader:
    images = images.to(device)
    outputs = model(images)
    _, preds = torch.max(outputs, 1)

    y_true.extend(labels.numpy())
    y_pred.extend(preds.cpu().numpy())

print("\n=== Classification Report ===")
print(classification_report(y_true, y_pred, target_names=class_names))

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

# Plot akurasi training & val
plt.plot(train_accs, label="Train")
plt.plot(val_accs, label="Validation")
plt.ylabel("Accuracy")
plt.xlabel("Epoch")
plt.title("Train vs Validation Accuracy")
plt.legend()
plt.grid(True)
plt.show()