In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, RocCurveDisplay
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm 
from facenet_pytorch import InceptionResnetV1 

In [15]:
data_transforms = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_data = datasets.ImageFolder('Data/train', transform=data_transforms)
test_data = datasets.ImageFolder('Data/test', transform=data_transforms)

train_loader = DataLoader(train_data, batch_size = 128, shuffle=True)
test_loader = DataLoader(test_data, batch_size = 128, shuffle=False)

class_names = train_data.classes

In [16]:
device = torch.device('cuda')

facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
for param in facenet.parameters():
    param.requires_grad = False 

model = nn.Sequential(
    facenet,
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Dropout(0.3),
    
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.3),

    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 2)
).to(device)

In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model[1:].parameters(), lr=1e-4)

In [None]:
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/total:.4f}, Acc: {correct/total:.2%}")

# Save model
torch.save(model.state_dict(), 'deep1.pth')

Epoch 1/30: 100%|██████████| 736/736 [1:08:56<00:00,  5.62s/it]


Epoch 1/30 - Loss: 0.2628, Acc: 88.96%


Epoch 2/30: 100%|██████████| 736/736 [16:15<00:00,  1.33s/it]


Epoch 2/30 - Loss: 0.2628, Acc: 89.11%


Epoch 3/30: 100%|██████████| 736/736 [05:10<00:00,  2.37it/s]


Epoch 3/30 - Loss: 0.2599, Acc: 89.04%


Epoch 4/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 4/30 - Loss: 0.2582, Acc: 89.18%


Epoch 5/30: 100%|██████████| 736/736 [04:36<00:00,  2.66it/s]


Epoch 5/30 - Loss: 0.2579, Acc: 89.16%


Epoch 6/30: 100%|██████████| 736/736 [04:37<00:00,  2.66it/s]


Epoch 6/30 - Loss: 0.2565, Acc: 89.19%


Epoch 7/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 7/30 - Loss: 0.2553, Acc: 89.26%


Epoch 8/30: 100%|██████████| 736/736 [04:38<00:00,  2.65it/s]


Epoch 8/30 - Loss: 0.2515, Acc: 89.37%


Epoch 9/30: 100%|██████████| 736/736 [04:38<00:00,  2.65it/s]


Epoch 9/30 - Loss: 0.2520, Acc: 89.36%


Epoch 10/30: 100%|██████████| 736/736 [04:37<00:00,  2.66it/s]


Epoch 10/30 - Loss: 0.2519, Acc: 89.42%


Epoch 11/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 11/30 - Loss: 0.2509, Acc: 89.54%


Epoch 12/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 12/30 - Loss: 0.2484, Acc: 89.57%


Epoch 13/30: 100%|██████████| 736/736 [04:36<00:00,  2.66it/s]


Epoch 13/30 - Loss: 0.2488, Acc: 89.63%


Epoch 14/30: 100%|██████████| 736/736 [04:36<00:00,  2.66it/s]


Epoch 14/30 - Loss: 0.2463, Acc: 89.66%


Epoch 15/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 15/30 - Loss: 0.2448, Acc: 89.92%


Epoch 16/30: 100%|██████████| 736/736 [04:36<00:00,  2.66it/s]


Epoch 16/30 - Loss: 0.2451, Acc: 89.71%


Epoch 17/30: 100%|██████████| 736/736 [04:38<00:00,  2.65it/s]


Epoch 17/30 - Loss: 0.2459, Acc: 89.75%


Epoch 18/30: 100%|██████████| 736/736 [04:37<00:00,  2.65it/s]


Epoch 18/30 - Loss: 0.2441, Acc: 89.77%


Epoch 19/30: 100%|██████████| 736/736 [04:38<00:00,  2.65it/s]


Epoch 19/30 - Loss: 0.2448, Acc: 89.79%


Epoch 20/30:  98%|█████████▊| 719/736 [04:32<00:06,  2.63it/s]

In [None]:
device = torch.device('cuda')
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
for param in facenet.parameters():
    param.requires_grad = False

model = nn.Sequential(
    facenet,
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Dropout(0.3),
    
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.3),

    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 2)
).to(device)

model.load_state_dict(torch.load("deep.pth"))
model.eval()

In [None]:
model.eval()
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)

        probs = F.softmax(outputs, dim=1)
        _, preds = torch.max(probs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs[:, 1].cpu().numpy()) 

# ---- Classification Report ----
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))

# ---- Confusion Matrix ----
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.savefig('conmat1.png')
plt.show()

# ---- ROC Curve + AUC ----
fpr, tpr, _ = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.2f})", linewidth=2)
plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.savefig("roc1.png")
plt.show()