In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, UnidentifiedImageError
import numpy as np
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, classification_report
import os
from tqdm import tqdm  # Progress bar
import matplotlib.pyplot as plt
import seaborn as sns

# Load pretrained ResNet50 model (without classifier)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])  # Remove last layer
resnet.eval()

# Image transformation
transform = transforms.Compose([
    transforms.Resize((600, 600)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to extract deep features with progress bar
def extract_features(folder):
    features = []
    filenames = []
    
    for filename in tqdm(os.listdir(folder), desc=f"Extracting features from {folder}"):
        img_path = os.path.join(folder, filename)
        
        # Skip non-image files
        if not (filename.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".gif"))):
            continue  

        try:
            img = Image.open(img_path).convert("RGB")
            img = transform(img).unsqueeze(0)

            with torch.no_grad():
                feat = resnet(img).flatten().numpy()  # Extract deep features
            
            features.append(feat)
            filenames.append(filename)  # Only store valid image filenames

        except (UnidentifiedImageError, OSError) as e:
            print(f"Skipping file {filename}: {e}")

    return np.array(features), filenames  # Return valid features and filenames

# Load and extract deep features with progress bars
X_train, _ = extract_features("data/train")
X_test_cat, test_filenames_cat = extract_features("data/test/cat")  # Normal class
X_test_other, test_filenames_other = extract_features("data/test/other")  # Anomaly class

# Normalize and train One-Class SVM
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test_cat = scaler.transform(X_test_cat)
X_test_other = scaler.transform(X_test_other)

print("\nTraining One-Class SVM...")
oc_svm = OneClassSVM(kernel="rbf", gamma=0.001, nu=0.08)
oc_svm.fit(X_train)

# Predict with progress bar
print("\nPredicting anomalies...")
predictions_cat = [oc_svm.predict([x])[0] for x in tqdm(X_test_cat, desc="Processing cat images")]
predictions_other = [oc_svm.predict([x])[0] for x in tqdm(X_test_other, desc="Processing other images")]

# Prepare labels
true_labels = [1] * len(predictions_cat) + [-1] * len(predictions_other)  # 1 = normal (cat), -1 = anomaly (other)
pred_labels = predictions_cat + predictions_other

# Compute confusion matrix
cm = confusion_matrix(true_labels, pred_labels, labels=[1, -1])

# Save confusion matrix
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Normal', 'Anomaly'], yticklabels=['Normal', 'Anomaly'])
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.savefig("confusion_matrix.png")
plt.close()

# Save classification report
report = classification_report(true_labels, pred_labels, target_names=['Normal (Cat)', 'Anomaly (Other)'])
with open("classification_report.txt", "w") as f:
    f.write(report)

# Save individual results
results = []
for filename, pred in zip(test_filenames_cat + test_filenames_other, pred_labels):
    status = "Normal (Cat)" if pred == 1 else "Anomaly (Other)"
    results.append(f"{filename}: {status}")
    print(f"{filename}: {status}")

with open("results.txt", "w") as f:
    f.write("\n".join(results))

# Display results
print("\nConfusion Matrix:")
print(cm)
print("\nClassification Report:")
print(report)

Extracting features from data/train: 100%|██████████| 12001/12001 [31:50<00:00,  6.28it/s]
Extracting features from data/test/cat: 100%|██████████| 501/501 [01:17<00:00,  6.47it/s]
Extracting features from data/test/other: 100%|██████████| 500/500 [01:17<00:00,  6.46it/s]



Training One-Class SVM...

Predicting anomalies...


Processing cat images: 100%|██████████| 500/500 [00:00<00:00, 530.01it/s]
Processing other images: 100%|██████████| 500/500 [00:00<00:00, 537.49it/s]


cat.12000.jpg: Normal (Cat)
cat.12001.jpg: Anomaly (Other)
cat.12002.jpg: Normal (Cat)
cat.12003.jpg: Normal (Cat)
cat.12004.jpg: Anomaly (Other)
cat.12005.jpg: Anomaly (Other)
cat.12006.jpg: Normal (Cat)
cat.12007.jpg: Normal (Cat)
cat.12008.jpg: Normal (Cat)
cat.12009.jpg: Normal (Cat)
cat.12010.jpg: Normal (Cat)
cat.12011.jpg: Anomaly (Other)
cat.12012.jpg: Normal (Cat)
cat.12013.jpg: Normal (Cat)
cat.12014.jpg: Normal (Cat)
cat.12015.jpg: Anomaly (Other)
cat.12016.jpg: Anomaly (Other)
cat.12017.jpg: Normal (Cat)
cat.12018.jpg: Anomaly (Other)
cat.12019.jpg: Normal (Cat)
cat.12020.jpg: Anomaly (Other)
cat.12021.jpg: Normal (Cat)
cat.12022.jpg: Normal (Cat)
cat.12023.jpg: Normal (Cat)
cat.12024.jpg: Normal (Cat)
cat.12025.jpg: Normal (Cat)
cat.12026.jpg: Normal (Cat)
cat.12027.jpg: Normal (Cat)
cat.12028.jpg: Anomaly (Other)
cat.12029.jpg: Normal (Cat)
cat.12030.jpg: Anomaly (Other)
cat.12031.jpg: Normal (Cat)
cat.12032.jpg: Normal (Cat)
cat.12033.jpg: Anomaly (Other)
cat.12034.jpg: 