Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.models as models

  from .autonotebook import tqdm as notebook_tqdm


Torch Version

In [5]:
# Print the installed PyTorch version for reference.
print(torch.__version__)

2.2.2


GPU Check

In [6]:
# Select the compute device referenced by later cells as `mps_device`.
# The name is always bound: it is the Apple MPS device when that backend is
# available and the CPU otherwise, so cells that use `mps_device` no longer
# raise NameError on machines without MPS.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Allocate a one-element tensor on the device as a quick smoke test.
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    mps_device = torch.device("cpu")
    print("MPS device not found.")

tensor([1.], device='mps:0')


In [7]:
# Restrict PyTorch to a single CPU thread to keep CPU utilization low.
torch.set_num_threads(1)

AlexNet

In [10]:
# Build an AlexNet without loading any pretrained checkpoint.
# `weights=None` is the current spelling of the deprecated `pretrained=False`
# flag, which emits a deprecation warning on recent torchvision releases.
AlexNet = models.alexnet(weights=None)
num_classes = 2
# Swap the last classifier layer so the network emits `num_classes` scores
# while keeping the input width of the layer it replaces.
AlexNet.classifier[6] = torch.nn.Linear(AlexNet.classifier[6].in_features, num_classes)

In [11]:
# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

In [12]:
# One ImageFolder dataset per split, all sharing the same `transform`.
train_set, test_set, val_set = (
    datasets.ImageFolder(root=split_dir, transform=transform)
    for split_dir in ('real-vs-fake/train', 'real-vs-fake/test', 'real-vs-fake/valid')
)

# Only the training loader shuffles; the two evaluation loaders keep a fixed order.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader, val_loader = (
    DataLoader(split, batch_size=32, shuffle=False)
    for split in (test_set, val_set)
)

# Size of each split, read back through the loaders.
num_train_images, num_test_images, num_val_images = (
    len(loader.dataset)
    for loader in (train_loader, test_loader, val_loader)
)

print("Number of training images:", num_train_images)
print("Number of testing images:", num_test_images)
print("Number of validating images:", num_val_images)

# Class names as discovered by ImageFolder for the training split.
train_classes = train_loader.dataset.classes
print("Classes in the dataset:", train_classes)

Number of training images: 159957
Number of testing images: 75992
Number of validating images: 25998
Classes in the dataset: ['fake', 'real']


In [3]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return `(mean_loss, accuracy_percent)`.

    When `optimizer` is given, the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    evaluation mode and gradient tracking is disabled.

    Args:
        model: network called as `model(images)`; its output is treated as
            per-class scores.
        loader: iterable yielding `(images, labels)` batches.
        criterion: loss callable invoked as `criterion(outputs, labels)`.
        device: device each batch is moved to before the forward pass.
        optimizer: optimizer stepped after each batch, or None to evaluate only.

    Returns:
        Tuple of (loss averaged over samples, accuracy in percent).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # Assumes criterion returns the batch mean (the default for
            # nn.CrossEntropyLoss); weighting by batch size keeps the epoch
            # average per-sample even when the last batch is smaller.
            loss_sum += loss.item() * images.size(0)
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_seen += labels.size(0)

    if num_seen == 0:
        raise ValueError("loader produced no samples")
    return loss_sum / num_seen, 100.0 * num_correct / num_seen


# Train the AlexNet defined above on `mps_device`, validating after every epoch
# and keeping the weights with the best validation accuracy on disk.
# NOTE(review): `.to()` moves the existing AlexNet object in place, so
# re-running this cell continues training the same weights rather than
# starting from scratch.
model = AlexNet.to(mps_device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
best_val_accuracy = 0.0

for epoch in range(num_epochs):
    mean_train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, mps_device, optimizer
    )
    mean_val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, mps_device)

    print(
        f"Epoch {epoch + 1}/{num_epochs}, "
        f"Train Loss: {mean_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
        f"Validation Loss: {mean_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%"
    )

    # Checkpoint only on a strict improvement in validation accuracy.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "best_cnn_model.pth")

Epoch 1/10, Train Loss: 0.3932, Train Accuracy: 79.95%, Validation Loss: 0.3775, Validation Accuracy: 81.71%
Epoch 2/10, Train Loss: 0.2966, Train Accuracy: 86.52%, Validation Loss: 0.2847, Validation Accuracy: 87.49%
Epoch 3/10, Train Loss: 0.2355, Train Accuracy: 88.06%, Validation Loss: 0.2566, Validation Accuracy: 89.90%
Epoch 4/10, Train Loss: 0.2287, Train Accuracy: 90.66%, Validation Loss: 0.2304, Validation Accuracy: 91.00%
Epoch 5/10, Train Loss: 0.1986, Train Accuracy: 91.93%, Validation Loss: 0.2420, Validation Accuracy: 92.12%
Epoch 6/10, Train Loss: 0.2091, Train Accuracy: 92.55%, Validation Loss: 0.2275, Validation Accuracy: 93.77%
Epoch 7/10, Train Loss: 0.2081, Train Accuracy: 93.71%, Validation Loss: 0.2071, Validation Accuracy: 94.01%
Epoch 8/10, Train Loss: 0.2017, Train Accuracy: 94.09%, Validation Loss: 0.1547, Validation Accuracy: 95.92%
Epoch 9/10, Train Loss: 0.1435, Train Accuracy: 95.19%, Validation Loss: 0.1385, Validation Accuracy: 96.82%
Epoch 10/10, Train 

In [1]:
# Evaluate on the test split using the checkpoint that scored best on
# validation, rather than whatever weights the last training epoch left in
# memory. The checkpoint holds only a state dict, so `weights_only=True` is
# sufficient and avoids unpickling arbitrary objects.
best_state = torch.load("best_cnn_model.pth", map_location=mps_device, weights_only=True)
model.load_state_dict(best_state)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(mps_device), labels.to(mps_device)
        outputs = model(images)
        # Predicted class is the index of the largest score in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on test set: {(100 * correct / total):.2f}%")

Accuracy on test set: 97.54%
