### Simplified Code for Image Recognition b/w Cat & Flower Images

In [71]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt

from PIL import Image  # Original import

# Import for precision, recall, and F1 score
from sklearn.metrics import precision_score, recall_score, f1_score  # Original

# Pretrained SwinTransformer import
import timm  # New addition

In [72]:
# Define the SwinTransformer model (using pretrained model)
class SwinTransformer(nn.Module):  # Updated class definition
    def __init__(self, num_classes=2):
        super(SwinTransformer, self).__init__()
        # Use pre-trained SwinTransformer from timm
        self.model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=num_classes)  # New addition

    def forward(self, x):
        return self.model(x)  # New addition

In [73]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [74]:
# Data transforms with augmentation (New addition)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Original
    transforms.RandomHorizontalFlip(),  # New addition
    transforms.RandomRotation(10),  # New addition
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # New addition
    transforms.ToTensor(),  # Original
])

# This portion of the code can be further expanded or put in more detail for image segmentation, contour detection or image quality assessment.

In [75]:
# Load data
train_data = datasets.ImageFolder('C:/Users/sharm/Downloads/Cat_train', transform=transform)  # Original
test_data = datasets.ImageFolder('C:/Users/sharm/Downloads/Cat_test', transform=transform)  # Original

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)  # Original
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)  # Original


In [76]:
# Initialize model, loss, and optimizer
model = SwinTransformer().to(device)  # Updated model to use pretrained SwinTransformer
criterion = nn.CrossEntropyLoss()  # Original
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.3)  # Original learning rate and weight decay retained


In [77]:
# Training loop
def train(epochs):
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')  # Original

In [78]:
# Test function
def test():
    model.eval()  # Original
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on test images: {100 * correct / total}%')  # Original


In [79]:
# Train the model
train(epochs=10)  # Original

# Test the model
test()  # Original

Epoch [1/10], Loss: 0.7400
Epoch [2/10], Loss: 0.5016
Epoch [3/10], Loss: 0.2741
Epoch [4/10], Loss: 0.1330
Epoch [5/10], Loss: 0.1064
Epoch [6/10], Loss: 0.0407
Epoch [7/10], Loss: 0.0301
Epoch [8/10], Loss: 0.0293
Epoch [9/10], Loss: 0.0154
Epoch [10/10], Loss: 0.0236
Accuracy on test images: 66.66666666666667%


In [80]:
# Function to predict on a single image
def predict_image(image_path):
    image = Image.open(image_path)  # Original
    image = transform(image).unsqueeze(0).to(device)  # Original
    
    model.eval()  # Set the model to evaluation mode (Original)
    with torch.no_grad():  # Original
        output = model(image)
        _, predicted = torch.max(output, 1)
    
    return train_data.classes[predicted.item()]  # Original

In [81]:
# Example usage
custom_image_path = "C:/Users/sharm/Downloads/Cat_test/class_2/unnamed.png"  # Original; other option for test image is 's1'.
prediction = predict_image(custom_image_path)  # Original
print(f"Prediction: {prediction}")  # Original

Prediction: class_2


class_1 denotes cat image 

class_2 denotes flower/miscellaneous image

We can also have a small function to view the image before running the prediction model to see whether the image is of cat or flower.


In [82]:
all_preds = []
all_labels = []

In [83]:
# Collect predictions for the test set
with torch.no_grad():  # Original
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

In [84]:
# Compute and print recall and F1 score
recall = recall_score(all_labels, all_preds, average='weighted')  # Original
f1 = f1_score(all_labels, all_preds, average='weighted')  # Original

print(f'Recall: {recall}')  # Original
print(f'F1 Score: {f1}')  # Original

Recall: 0.6666666666666666
F1 Score: 0.6926406926406927
