In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import matplotlib.pyplot as plt


device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Parameters
data_dir = "dataset"
batch_size = 16
num_epochs = 30
img_size = 160
learning_rate = 1e-4
weight_decay = 1e-4
patience = 3
num_workers = 0

train_dir = "dataset"

# Transforms
train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

val_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Load dataset
full_dataset = datasets.ImageFolder(root=train_dir)
class_names = full_dataset.classes
print("Classes:", class_names)

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_data, val_data = random_split(full_dataset, [train_size, val_size])

# Wrappers to apply transforms
class TransformWrapper(torch.utils.data.Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        return self.transform(img), label

train_dataset = TransformWrapper(train_data, train_transform)
val_dataset = TransformWrapper(val_data, val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Model: Pretrained ResNet18
model = models.resnet18(pretrained=True)
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 2)
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

best_acc = 0.0
early_stop_counter = 0

print("Training started...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(train_loader)

    # Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu())
            all_labels.extend(labels.cpu())

    acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch+1}/{num_epochs} | Avg Loss: {avg_loss:.4f} | Val Accuracy: {acc:.4f}")

    # Early stopping
    if acc > best_acc:
        best_acc = acc
        early_stop_counter = 0
        torch.save(model.state_dict(), "best_resnet_blur_classifier.pth")
        print("Saved best model.")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping after {patience} epochs without improvement.")
            break

print("Training complete!")


Using device: mps
Classes: ['blur', 'clear']




Training started...


Epoch 1/30: 100%|██████████| 57/57 [00:06<00:00,  8.80it/s, loss=0.00363]


Epoch 1/30 | Avg Loss: 0.1900 | Val Accuracy: 0.9800
Saved best model.


Epoch 2/30: 100%|██████████| 57/57 [00:05<00:00, 10.95it/s, loss=2.45]   


Epoch 2/30 | Avg Loss: 0.0866 | Val Accuracy: 0.9800


Epoch 3/30: 100%|██████████| 57/57 [00:05<00:00, 11.03it/s, loss=0.000312]


Epoch 3/30 | Avg Loss: 0.0294 | Val Accuracy: 0.9800


Epoch 4/30: 100%|██████████| 57/57 [00:05<00:00,  9.65it/s, loss=0.00168] 


Epoch 4/30 | Avg Loss: 0.0135 | Val Accuracy: 1.0000
Saved best model.


Epoch 5/30: 100%|██████████| 57/57 [00:09<00:00,  5.79it/s, loss=0.00634] 


Epoch 5/30 | Avg Loss: 0.0092 | Val Accuracy: 0.9900


Epoch 6/30: 100%|██████████| 57/57 [00:08<00:00,  6.73it/s, loss=9.89e-5] 


Epoch 6/30 | Avg Loss: 0.0049 | Val Accuracy: 0.9900


Epoch 7/30: 100%|██████████| 57/57 [00:07<00:00,  7.42it/s, loss=0.00382] 


Epoch 7/30 | Avg Loss: 0.0087 | Val Accuracy: 0.9900
Early stopping after 3 epochs without improvement.
Training complete!


In [5]:
from PIL import Image

def classify_image(image_path, model, class_names, device):
    """
    Classify a single image.

    Args:
        image_path (str): Path to the image.
        model (torch.nn.Module): Trained PyTorch model.
        class_names (list): List of class names.
        device (torch.device): Device to run inference on.

    Returns:
        str: Predicted class label.
    """
    model.eval()
    
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])
    
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        prediction = torch.argmax(output, dim=1).item()
    
    return class_names[prediction]


In [6]:
# Load best model weights
model.load_state_dict(torch.load("best_resnet_blur_classifier.pth", map_location=device))

# Classify an image
image_path = "text_overlay_dataset/img_00007.jpg"
predicted_class = classify_image(image_path, model, class_names, device)
print("Predicted class:", predicted_class)

Predicted class: clear
