In [1]:
!pip uninstall datasets modelscope -y
!pip install "datasets>=3.0.0,<4.0.0" modelscope[datasets]

Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
[0mCollecting datasets<4.0.0,>=3.0.0
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting modelscope[datasets]
  Downloading modelscope-1.33.0-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.3/43.3 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting addict (from modelscope[datasets])
  Downloading addict-2.4.0-py3-none-any.whl.metadata (1.0 kB)
Collecting oss2 (from modelscope[datasets])
  Downloading oss2-2.19.1.tar.gz (298 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m298.8/298.8 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting crcmod>=1.7 (from oss2->modelscope[datasets])
  Downloading crcmod-1.7.tar.gz (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.7/89.7 kB[0m [31m5.6 MB/s

In [None]:
# Install required packages
!pip install -q timm pillow matplotlib seaborn scikit-learn torch_xla mmcv-full mmcls

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/607.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m604.2/607.9 kB[0m [31m23.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m607.9/607.9 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.8/46.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.3/80.3 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m648.8/648.8 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m256.2/256.2 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# @title
"""
Garbage Classification using ModelScope ConvNeXt-Base
- Train a new model on validation split and show training process
- Evaluate the pre-trained model
- Classify images from URL
- Optimized for TPU
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from modelscope.msdatasets import MsDataset
from modelscope.utils.constant import DownloadMode
from modelscope.pipelines import pipeline
from modelscope.models import Model
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:
# @title
# Select the compute device: CUDA GPU first, then a TPU through torch_xla, else CPU.
# CUDA is checked before XLA because an installed torch_xla may return an XLA
# device even when no TPU is attached, which would shadow an available GPU.
use_tpu = False
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✓ Using GPU: {device}")
    print(f"  GPU Name: {torch.cuda.get_device_name(0)}")
else:
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        # Reject non-TPU XLA backends (e.g. XLA on CPU) when torch_xla can report
        # the hardware type; older/newer releases without the helper are trusted.
        device_hw = getattr(xm, 'xla_device_hw', None)
        if device_hw is not None and device_hw(device) != 'TPU':
            raise RuntimeError(f"XLA device {device} is not a TPU")
        use_tpu = True
        print(f"✓ Using TPU: {device}")
    except Exception:  # torch_xla missing or no usable TPU -> plain CPU
        device = torch.device('cpu')
        print(f"✓ Using CPU: {device}")
        print("  Note: Training will be slower on CPU")

In [None]:
# Explicit GPU override: whenever CUDA is available, train on it and clear the
# TPU flag so `use_tpu` stays consistent with `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    use_tpu = False
    print(f"✓ Using GPU: {device}")
    print(f"  GPU Name: {torch.cuda.get_device_name(0)}")

In [None]:
# @title
# =============================================================================
# 1. LOAD DATASET
# =============================================================================
print("\n=== Loading Dataset from ModelScope ===")
# The garbage265 'validation' split serves as the entire working set; it is
# divided into train/test subsets in a later cell.
ms_train_dataset = MsDataset.load(
    'garbage265',
    namespace='tany0699',
    subset_name='default',
    split='validation',
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
)

total_size = len(ms_train_dataset)
print(f"Dataset loaded: {total_size} samples")
print("Sample data:", next(iter(ms_train_dataset)))

# 80/20 split; the test share takes the remainder so both sizes sum to total_size.
train_size = int(0.8 * total_size)
test_size = total_size - train_size
print(f"\nSplitting into Train: {train_size}, Test: {test_size}")

In [None]:
# @title
# =============================================================================
# 2. PREPARE DATASET
# =============================================================================
class GarbageDataset(Dataset):
    """Map-style PyTorch dataset over ModelScope garbage-classification samples.

    Each sample is a mapping with an ``'image:FILE'`` entry (a file path, raw
    encoded image bytes, or a PIL image) and a ``'category'`` entry. Categories
    are re-indexed to contiguous integers following their sorted order.

    Args:
        ms_dataset: Iterable of sample mappings (e.g. an ``MsDataset``); it is
            fully materialized into a list.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, ms_dataset, transform=None):
        self.data = list(ms_dataset)
        self.transform = transform
        # Sorted unique categories -> contiguous indices. NOTE: an index is a
        # position in sorted order, not necessarily the dataset's own class id.
        self.unique_labels = sorted({item['category'] for item in self.data})
        self.label_to_idx = {label: idx for idx, label in enumerate(self.unique_labels)}
        self.idx_to_label = dict(enumerate(self.unique_labels))

        print(f"Found {len(self.unique_labels)} classes: {self.unique_labels[:10]}...")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(image):
        """Return ``image`` as an RGB PIL image (accepts a path, bytes, or a PIL image)."""
        if isinstance(image, Image.Image):
            # Normalise mode too (L / RGBA / P would break 3-channel Normalize).
            return image.convert('RGB')
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as img:
                return img.convert('RGB')
        with Image.open(BytesIO(image)) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(image, label_index)``; the image is transformed when a transform is set."""
        item = self.data[idx]
        image = self._load_rgb(item['image:FILE'])
        label = self.label_to_idx[item['category']]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# @title
# Define transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


class _TransformedSubset(Dataset):
    """Wrap a dataset split and apply a split-specific transform to each image.

    ``random_split`` returns Subsets that share ONE underlying dataset, so setting
    ``subset.dataset.transform`` for the test split would also change the train
    split. Each split therefore carries its own transform here.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Create the full dataset without a transform (gives class info); the splits
# below apply their own transforms.
full_dataset = GarbageDataset(ms_train_dataset, transform=None)
num_classes = len(full_dataset.unique_labels)
class_names = full_dataset.unique_labels

# Seeded split so train/test membership is reproducible across runs.
SPLIT_SEED = 42
train_subset, test_subset = random_split(
    full_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_dataset = _TransformedSubset(train_subset, train_transform)  # with augmentation
test_dataset = _TransformedSubset(test_subset, test_transform)     # deterministic

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# @title
# =============================================================================
# 3. LOAD PRE-TRAINED MODEL FOR EVALUATION
# =============================================================================
print("\n=== Loading Pre-trained Model ===")
# ModelScope ConvNeXt-Base garbage classifier, wrapped in an image-classification
# pipeline so it can be called directly on PIL images.
PRETRAINED_MODEL_ID = 'damo/cv_convnext-base_image-classification_garbage'
pretrained_model = Model.from_pretrained(PRETRAINED_MODEL_ID)
pretrained_classifier = pipeline('image-classification', model=pretrained_model)

In [None]:
# =============================================================================
# 4. CREATE AND TRAIN A NEW MODEL
# =============================================================================
print("\n=== Creating New Model ===")

# Load base ConvNeXt model (pretrained weights) with a new head for our number of classes
import timm
model = timm.create_model('convnext_base', pretrained=True, num_classes=num_classes)
model = model.to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Lower the LR when the validation loss stops improving for `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
num_epochs = 4 # 10

# Training history (one entry per completed epoch)
train_losses = []
train_accs = []
val_losses = []
val_accs = []


def _run_epoch(loader, desc, train):
    """Run one pass over ``loader`` and return ``(mean loss per sample, accuracy %)``.

    With ``train=True`` the model is updated after every batch; otherwise
    gradients are disabled. Uses the notebook globals ``model``, ``criterion``,
    ``optimizer`` and ``device``. Returns ``(nan, nan)`` for an empty loader.
    """
    model.train(train)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    on_xla = 'xla' in str(device)

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(train):
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            if train:
                loss.backward()
                if on_xla:
                    # No ParallelLoader is used, so the barrier must be issued here
                    # to make XLA execute each step instead of growing a lazy graph.
                    xm.optimizer_step(optimizer, barrier=True)
                else:
                    optimizer.step()

            n = labels.size(0)
            batch_loss = loss.item()
            # Weight by batch size so a short final batch does not skew the mean.
            running_loss += batch_loss * n
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += n

            if train:
                pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    if total == 0:
        return float('nan'), float('nan')
    return running_loss / total, 100. * correct / total


print(f"\n=== Training for {num_epochs} epochs ===")
print(f"Training on {len(train_dataset)} samples")
print(f"Validating on {len(test_dataset)} samples")

# Training loop. NOTE: the test split doubles as the validation set here (it also
# drives the LR scheduler), so its accuracy is not a fully held-out estimate.
for epoch in range(num_epochs):
    tag = f'Epoch {epoch+1}/{num_epochs}'

    train_loss, train_acc = _run_epoch(train_loader, f'{tag} [Train]', train=True)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    val_loss, val_acc = _run_epoch(test_loader, f'{tag} [Val]', train=False)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    scheduler.step(val_loss)

    print(f'{tag}:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

In [None]:
# =============================================================================
# 5. PLOT TRAINING HISTORY
# =============================================================================
print("\n=== Plotting Training History ===")


def _plot_history(ax, train_values, val_values, train_label, val_label, ylabel, title):
    """Draw per-epoch train (blue, circles) and validation (red, squares) curves on ``ax``."""
    # The x axis comes from the recorded values, not num_epochs, so an interrupted
    # training run still plots instead of raising a length-mismatch error.
    ax.plot(range(1, len(train_values) + 1), train_values, 'b-', label=train_label, marker='o')
    ax.plot(range(1, len(val_values) + 1), val_values, 'r-', label=val_label, marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
_plot_history(ax1, train_losses, val_losses, 'Train Loss', 'Val Loss',
              'Loss', 'Training and Validation Loss')

# Accuracy plot
_plot_history(ax2, train_accs, val_accs, 'Train Acc', 'Val Acc',
              'Accuracy (%)', 'Training and Validation Accuracy')

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# =============================================================================
# 6. EVALUATE PRE-TRAINED MODEL (4 Main Categories)
# =============================================================================
print("\n=== Evaluating Pre-trained Model ===")


def get_main_category(label_text):
    """Map a detailed label to a main-category index by its leading prefix.

    Returns 0 = kitchen waste, 1 = recyclable, 2 = other waste, 3 = hazardous
    waste; any label without a known prefix falls back to 2 (other waste).
    """
    # Checked in order; "other waste" accepts both spellings of its prefix.
    prefix_table = (
        (('厨余垃圾',), 0),            # Kitchen waste
        (('可回收物',), 1),            # Recyclable
        (('其他垃圾', '其它垃圾'), 2),  # Other waste
        (('有害垃圾',), 3),            # Hazardous waste
    )
    for prefixes, category in prefix_table:
        if label_text.startswith(prefixes):
            return category
    return 2  # Default to other waste

# Map dataset label indices to the 4 main categories
def map_dataset_label_to_main(label_idx, names=None):
    """Map a GarbageDataset label index to a main-category index (0-3).

    ``label_idx`` is a position in the sorted list of dataset categories
    (``class_names`` by default), which is NOT necessarily the dataset's own
    class id: string ids sort lexicographically ('10' before '2'), and classes
    absent from the split shift every later position. The index is therefore
    translated back to its category value before classification.

    Args:
        label_idx: Index into ``names``.
        names: Sorted category values; defaults to the notebook's ``class_names``.

    Returns:
        0 = kitchen waste, 1 = recyclable, 2 = other waste, 3 = hazardous waste.
    """
    if names is None:
        names = class_names
    category = names[label_idx]
    try:
        class_id = int(category)
    except (TypeError, ValueError):
        # Text category (a label with a main-category prefix): classify by prefix.
        return get_main_category(str(category))
    # Class-id layout of the 265 classes, as stated by the original notebook
    # (TODO verify against the dataset's class list):
    # 0-51 kitchen waste, 52-197 recyclable, 198-251 other waste, 252-264 hazardous.
    if class_id <= 51:
        return 0
    if class_id <= 197:
        return 1
    if class_id <= 251:
        return 2
    return 3

main_category_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
# English equivalents for figure tick labels: matplotlib's default font has no
# CJK glyphs, so the Chinese names would render as empty boxes in the heatmap.
main_category_names_en = ['Kitchen waste', 'Recyclable', 'Other waste', 'Hazardous waste']
# Always report all 4 categories, even if one is absent from labels/predictions.
main_category_ids = list(range(len(main_category_names)))

# Inverse of test_transform's Normalize (ImageNet mean/std); hoisted out of the loop.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
_to_pil = transforms.ToPILImage()

# Get predictions from pre-trained model
pretrained_preds = []
true_labels = []

for images, labels in tqdm(test_loader, desc='Evaluating Pre-trained'):
    # Map true labels to main categories
    true_labels.extend(map_dataset_label_to_main(l.item()) for l in labels)

    # The pipeline takes PIL images, so undo normalisation for each tensor.
    for img_tensor in images:
        img_tensor = torch.clamp(img_tensor.cpu() * _norm_std + _norm_mean, 0, 1)
        img_pil = _to_pil(img_tensor)

        result = pretrained_classifier(img_pil)
        # Map the top-1 detailed label to its main category
        pretrained_preds.append(get_main_category(result['labels'][0]))

# Calculate metrics
pretrained_acc = accuracy_score(true_labels, pretrained_preds)
print(f"\nPre-trained Model Accuracy (4 categories): {pretrained_acc*100:.2f}%")

# Confusion Matrix for Pre-trained Model (fixed 4x4 so tick labels always line up)
cm_pretrained = confusion_matrix(true_labels, pretrained_preds, labels=main_category_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_pretrained, annot=True, fmt='d', cmap='Blues',
            xticklabels=main_category_names_en,
            yticklabels=main_category_names_en)
plt.title(f'Confusion Matrix - Pre-trained Model (Acc: {pretrained_acc*100:.2f}%)')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('confusion_matrix_pretrained.png', dpi=300, bbox_inches='tight')
plt.show()

# Classification Report (explicit labels: target_names must match their count)
print("\nPre-trained Model Classification Report:")
print(classification_report(true_labels, pretrained_preds,
                            labels=main_category_ids,
                            target_names=main_category_names,
                            zero_division=0))


In [None]:
# =============================================================================
# 8. URL IMAGE CLASSIFICATION FUNCTION
# =============================================================================
print("\n=== URL Image Classification Ready ===")

def classify_from_url(image_url, use_pretrained=True, timeout=30):
    """
    Download an image from a URL, display it, and print the model's top predictions.

    Args:
        image_url: URL of the image
        use_pretrained: If True, use the pre-trained ModelScope pipeline, else use
            the newly trained model (top-5 classes)
        timeout: Seconds to wait for the HTTP download before giving up

    Any error (network, HTTP status, decoding, inference) is printed rather than
    raised, so a bad URL does not abort the notebook run.
    """
    try:
        # Download image; the timeout prevents an indefinite hang, and
        # raise_for_status stops an HTTP error page from being handed to PIL.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')

        # Display image
        plt.figure(figsize=(8, 8))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'Input Image\n{image_url}')
        plt.show()

        if use_pretrained:
            # Use pre-trained model
            result = pretrained_classifier(img)
            print("\n=== Pre-trained Model Prediction ===")
            for i, (label, score) in enumerate(zip(result['labels'], result['scores'])):
                print(f"{i+1}. {label}: {score:.4f}")
        else:
            # Use newly trained model
            img_tensor = test_transform(img).unsqueeze(0).to(device)

            model.eval()
            with torch.no_grad():
                outputs = model(img_tensor)
                # Move to CPU once so indexing/formatting below needs no device sync
                probs = torch.softmax(outputs, dim=1)[0].cpu()
            top_prob, top_idx = torch.topk(probs, min(5, len(class_names)))

            print("\n=== Newly Trained Model Prediction ===")
            for i, (idx, prob) in enumerate(zip(top_idx.tolist(), top_prob.tolist())):
                print(f"{i+1}. {class_names[idx]}: {prob:.4f}")

    except Exception as e:
        print(f"Error: {e}")

# Example usage
print("\n" + "="*60)
print("To classify an image from URL, use:")
print("classify_from_url('YOUR_IMAGE_URL', use_pretrained=True)  # Pre-trained model")
print("classify_from_url('YOUR_IMAGE_URL', use_pretrained=False) # Newly trained model")
print("="*60)

In [None]:
# Example classifications with sample URLs (you can change these)
sample_urls = [
    "https://thumbs.dreamstime.com/z/flattened-coca-cola-can-ground-discarded-disposable-coca-cola-can-crumpled-empty-single-use-fizzy-drink-coke-can-problem-213740333.jpg?ct=jpeg",
    "https://plus.unsplash.com/premium_photo-1724249989963-9286e126af81?q=80&w=1170&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
]
for sample_url in sample_urls:
    print(f"\nExample classification with: {sample_url}")
    classify_from_url(sample_url, use_pretrained=True)