In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from collections import Counter
from torch.optim import Adam
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os

## Dataset

In [4]:
human_faces_dir = '/kaggle/input/human-faces/Humans'
flowers_dir = '/kaggle/input/flowers-dataset/test'

# Seed the RNGs before anything random happens: Python's `random` drives the dataset
# subsampling (random.sample) and torch drives weight init and DataLoader shuffling.
# numpy is seeded defensively. GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

human_faces_files = len(os.listdir(human_faces_dir))
flowers_files = len(os.listdir(flowers_dir))

print(f"Number of files in human_faces_dir: {human_faces_files}")
print(f"Number of files in flowers_dir: {flowers_files}")

Number of files in human_faces_dir: 7219
Number of files in flowers_dir: 924


In [5]:
# Shared preprocessing for train and test images: fixed 128x128 size, tensor conversion,
# then (x - 0.5) / 0.5 (the single mean/std value is broadcast over every channel).
my_transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class CustomDataset(Dataset):
    """Binary image dataset: a random subset of human faces (label 1) and flowers (label 0).

    Args:
        human_faces_dir: directory scanned (non-recursively) for face images.
        flowers_dir: directory scanned (non-recursively) for flower images.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        num_faces: number of face images sampled without replacement.
        num_flowers: number of flower images sampled without replacement.

    Raises:
        ValueError: (from ``random.sample``) if a directory holds fewer images than requested.
    """

    # Same extensions TestDataset accepts; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, human_faces_dir, flowers_dir, transform=None, num_faces=300, num_flowers=200):
        self.human_faces_dir = human_faces_dir
        self.flowers_dir = flowers_dir

        self.human_faces_images = random.sample(self._list_images(human_faces_dir), num_faces)
        self.flowers_images = random.sample(self._list_images(flowers_dir), num_flowers)

        # Faces first, then flowers; labels are aligned index-by-index with all_images.
        self.all_images = self.human_faces_images + self.flowers_images
        self.labels = [1] * len(self.human_faces_images) + [0] * len(self.flowers_images)
        self.transform = transform

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted paths of image files directly inside ``directory``."""
        # Sorting removes the filesystem-dependent os.listdir order, so a seeded
        # random.sample picks the same files on every run.
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed when a transform was given."""
        img_path = self.all_images[idx]
        # The context manager closes the file; convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        return image, label

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training data: sampled faces (label 1) and flowers (label 0), reshuffled every epoch.
dataset = CustomDataset(human_faces_dir, flowers_dir, transform=my_transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)

In [6]:
# Peek at a single training batch: tensor shape and its labels.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.shape)
    print(labels)

torch.Size([128, 3, 128, 128])
tensor([1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0,
        1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1,
        1, 0, 1, 0, 0, 1, 0, 0])


In [7]:
class TestDataset(Dataset):
    """Unlabelled image dataset: a random sample of images from ``test_dir``.

    ``__getitem__`` returns only the (optionally transformed) image; there is no label.

    Args:
        test_dir: directory scanned (non-recursively) for .jpg/.jpeg/.png files (any case).
        transform: optional callable applied to each PIL image.
        num_samples: number of images sampled without replacement.

    Raises:
        ValueError: (from ``random.sample``) if ``test_dir`` holds fewer than ``num_samples`` images.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, test_dir, transform=None, num_samples=128):
        self.test_dir = test_dir
        # Case-insensitive so files like "IMG.JPG" are not silently skipped; sorting removes
        # os.listdir's filesystem-dependent order so a seeded random.sample is reproducible.
        test_images = sorted(
            os.path.join(test_dir, fname)
            for fname in os.listdir(test_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.all_images = random.sample(test_images, num_samples)
        self.transform = transform

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        img_path = self.all_images[idx]
        # The context manager closes the file; convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image

test_dir = "/kaggle/input/face-dataset/human-swap/"
test_dataset = TestDataset(test_dir, transform=my_transform)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [8]:
# Peek at the shape of a single test batch (images only, no labels).
first_test_batch = next(iter(test_dataloader), None)
if first_test_batch is not None:
    images = first_test_batch
    print(images.shape)

torch.Size([128, 3, 128, 128])


## Capsule network

In [None]:
def debug_print(debug, message, tensor=None):
    """Print ``message`` when ``debug`` is truthy, appending ``tensor.shape`` if a tensor is given."""
    if not debug:
        return
    line = message if tensor is None else f"{message}: {tensor.shape}"
    print(line)

def squash(input_tensor, epsilon=1e-7, dim=-1):
    """Capsule non-linearity: rescale each vector along ``dim`` to length |s|^2 / (1 + |s|^2).

    Direction is preserved; short vectors shrink toward 0 and long ones approach unit length.

    Args:
        input_tensor: tensor whose vectors are laid out along ``dim``.
        epsilon: added once to the squared norm inside the square root so zero vectors do
            not divide by zero (previously it was added per component, so its effect grew
            with the vector dimension).
        dim: axis holding the vector components (default: last axis).

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    squared_norm = (input_tensor ** 2).sum(dim, keepdim=True)
    scale = squared_norm / (1. + squared_norm)
    return scale * input_tensor / torch.sqrt(squared_norm + epsilon)

class ConvLayer(nn.Module):
    """First capsule-network stage: one stride-1 convolution followed by ReLU."""

    def __init__(self, in_channels=1, out_channels=256, kernel_size=9, debug=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
        )
        self.debug = debug

    def forward(self, x):
        """Return ``relu(conv(x))``, printing shapes along the way in debug mode."""
        features = self.conv(x)
        debug_print(self.debug, "x after conv", features)
        activated = F.relu(features)
        debug_print(self.debug, "x after ReLU", activated)
        return activated


class PrimaryCaps(nn.Module):
    """Primary capsule layer.

    Runs ``num_capsules`` parallel convolutions (stride 2, no padding) over the same input.
    The ``num_capsules`` values at one (channel, row, column) position, one from each
    convolution, form one capsule vector, which is then squashed.

    Args:
        num_capsules: number of parallel convolutions = dimensionality of each capsule vector.
        in_channels, out_channels, kernel_size: forwarded to every ``nn.Conv2d``.
        debug: print intermediate shapes when True.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, debug=False):
        super(PrimaryCaps, self).__init__()
        self.num_capsules = num_capsules
        self.capsules = nn.ModuleList()
        for _ in range(num_capsules):
            capsule = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
            self.capsules.append(capsule)
        self.debug = debug

    def forward(self, x):
        """Return squashed capsules of shape (batch, out_channels * H_out * W_out, num_capsules).

        The capsule count follows from the input's spatial size rather than a fixed constant,
        so any input resolution that survives the convolutions is accepted.
        """
        stacked_capsules = [capsule(x) for capsule in self.capsules]
        if self.debug:
            print("capsule_out:")
            for capsule_out in stacked_capsules:
                print("\t", capsule_out.shape)
        # Stack on a new LAST axis -> (batch, out_channels, H_out, W_out, num_capsules), so each
        # capsule vector takes one value from every convolution. Stacking on dim 1 and then
        # viewing would group neighbouring values of a single convolution instead.
        stacked_capsules = torch.stack(stacked_capsules, dim=-1)
        debug_print(self.debug, "stacked_capsules", stacked_capsules)
        flattened_capsules = stacked_capsules.view(x.size(0), -1, self.num_capsules)
        debug_print(self.debug, "flattened_capsules", flattened_capsules)
        squashed_output = squash(flattened_capsules)
        debug_print(self.debug, "squashed_output", squashed_output)
        return squashed_output


class DigitCaps(nn.Module):
    """Output capsule layer with dynamic routing-by-agreement.

    Args:
        num_capsules: number of output capsules (one per class).
        num_routes: number of input capsules; must equal ``x.size(1)`` in ``forward``.
        in_channels: dimensionality of each input capsule vector.
        out_channels: dimensionality of each output capsule vector.
        debug: print intermediate shapes when True.
        num_iterations: number of routing iterations (must be >= 1).

    Raises:
        ValueError: if ``num_iterations`` < 1.
    """

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16, debug=False,
                 num_iterations=3):
        super(DigitCaps, self).__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        # One (out_channels x in_channels) transform per (input capsule, output capsule) pair.
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))
        self.debug = debug

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: input capsules of shape (batch, num_routes, in_channels).

        Returns:
            Output capsules of shape (batch, num_capsules, out_channels, 1).
        """
        batch_size = x.size(0)
        # (B, routes, 1, in, 1): broadcasts against W's capsule axis instead of copying x.
        x = x[:, :, None, :, None]
        debug_print(self.debug, "x after stacking", x)
        debug_print(self.debug, "W", self.W)

        # (1, routes, caps, out, in) @ (B, routes, 1, in, 1) -> (B, routes, caps, out, 1)
        u_hat = torch.matmul(self.W, x)
        debug_print(self.debug, "u_hat", u_hat)

        # Routing logits are kept per sample; averaging them over the batch would make one
        # sample's output depend on the other samples in its batch.
        b_ij = torch.zeros(batch_size, self.num_routes, self.num_capsules, 1, 1,
                           device=u_hat.device, dtype=u_hat.dtype)
        debug_print(self.debug, "b_ij", b_ij)

        for _ in range(self.num_iterations):
            if self.debug: print()
            # Coupling coefficients: each input capsule distributes itself over the
            # output capsules (dim 2), not over the input capsules.
            c_ij = F.softmax(b_ij, dim=2)
            debug_print(self.debug, "c_ij", c_ij)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # (B, 1, caps, out, 1)
            debug_print(self.debug, "s_j", s_j)

            # Squash each out_channels-dim vector; the trailing axis has size 1, so squashing
            # over it directly would act element-wise.
            v_j = squash(s_j.squeeze(-1)).unsqueeze(-1)
            debug_print(self.debug, "v_j", v_j)

            # Agreement u_hat . v_j over the out_channels axis -> (B, routes, caps, 1, 1)
            a_ij = (u_hat * v_j).sum(dim=3, keepdim=True)
            debug_print(self.debug, "a_ij", a_ij)
            b_ij = b_ij + a_ij
            debug_print(self.debug, "b_ij updated", b_ij)

        return v_j.squeeze(1)

class Decoder(nn.Module):
    """Turn output capsules into a one-hot mask of the predicted class.

    No reconstruction network is implemented here despite the name: ``forward`` only
    selects the output capsule with the largest length.
    """

    def __init__(self, debug=False):
        super(Decoder, self).__init__()
        self.debug = debug

    def forward(self, x, data):
        """
        Args:
            x: output capsules, (batch, num_classes, capsule_dim, 1); lengths are taken over dim 2.
            data: the network input; unused, kept for interface compatibility.

        Returns:
            Float tensor (batch, num_classes) with a 1 at each sample's longest capsule.
        """
        classes = torch.sqrt((x ** 2).sum(2))
        debug_print(self.debug, "classes before softmax", classes)
        # Normalise over the class axis; dim 0 would mix different samples of the batch
        # and could change which class wins.
        classes = F.softmax(classes, dim=1)
        debug_print(self.debug, "classes after softmax", classes)

        _, max_length_indices = classes.max(dim=1)
        debug_print(self.debug, "max_length_indices", max_length_indices)
        # Identity sized to the actual number of classes, created on the input's device.
        masked = torch.eye(x.size(1), device=x.device)
        debug_print(self.debug, "masked", masked)
        masked = masked.index_select(dim=0, index=max_length_indices.reshape(x.size(0)))
        debug_print(self.debug, "masked after index_select", masked)

        return masked


In [None]:
class CapsuleNet(nn.Module):
    """Capsule network: ConvLayer -> PrimaryCaps -> DigitCaps, plus a Decoder that returns a
    one-hot mask of the longest output capsule.

    Args:
        debug: print intermediate shapes in every sub-layer.
        in_channels: channels of the input images.
        num_classes: number of output capsules.
        image_size: side length of the square input images; determines how many primary
            capsules (routes) DigitCaps receives.

    The defaults build a 1-channel, 28x28, 10-class network (32 * 6 * 6 routes). For the
    3x128x128 face/flower data use ``CapsuleNet(in_channels=3, num_classes=2, image_size=128)``.
    That gives 32 * 56 * 56 routes, which is memory-heavy.

    Raises:
        ValueError: if ``image_size`` is too small for the two 9x9 convolutions.
    """

    def __init__(self, debug=False, in_channels=1, num_classes=10, image_size=28):
        super(CapsuleNet, self).__init__()
        grid = self._primary_grid_size(image_size)
        if grid < 1:
            raise ValueError(f"image_size={image_size} is too small for the capsule convolutions")
        self.conv_layer = ConvLayer(in_channels=in_channels, debug=debug)
        self.primary_capsules = PrimaryCaps(debug=debug)
        # 32 = PrimaryCaps' default out_channels; one route per (channel, row, column).
        self.digit_capsules = DigitCaps(num_capsules=num_classes, num_routes=32 * grid * grid, debug=debug)
        self.decoder = Decoder(debug=debug)
        self.mse_loss = nn.MSELoss()
        self.debug = debug

    @staticmethod
    def _primary_grid_size(image_size, conv_kernel=9, primary_kernel=9, primary_stride=2):
        """Spatial side length of PrimaryCaps' output for a square ``image_size`` input."""
        conv_size = image_size - conv_kernel + 1  # ConvLayer: stride 1, no padding
        return (conv_size - primary_kernel) // primary_stride + 1

    def forward(self, data):
        """Return ``(output_capsules, one_hot_prediction_mask)`` for a batch of images."""
        debug_print(self.debug, f"Input data", data)
        debug_print(self.debug, "\nCONV")
        output = self.conv_layer(data)
        debug_print(self.debug, "\nPRIMARY")
        output = self.primary_capsules(output)
        debug_print(self.debug, "\nDIGIT")
        output = self.digit_capsules(output)
        debug_print(self.debug, "\nDECODER")
        masked = self.decoder(output, data)
        debug_print(self.debug, "\nOUTPUT", output)
        return output, masked

    def loss(self, x, target):
        """Margin loss of output capsules ``x`` against one-hot ``target``."""
        return self.margin_loss(x, target)

    def margin_loss(self, x, labels):
        """Capsule margin loss with m+ = 0.9, m- = 0.1 and down-weighting 0.5 for absent classes.

        Args:
            x: output capsules (batch, num_classes, capsule_dim, ...); lengths taken over dim 2.
            labels: one-hot float targets (batch, num_classes).

        Returns:
            Scalar: per-sample sum over classes, averaged over the batch.
        """
        batch_size = x.size(0)
        v_k = torch.sqrt((x**2).sum(dim=2, keepdim=True))
        left = F.relu(0.9 - v_k).view(batch_size, -1)
        right = F.relu(v_k - 0.1).view(batch_size, -1)
        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()
        return loss

In [None]:
# NOTE(review): CapsuleNet() uses its default sub-layer settings (ConvLayer in_channels=1,
# DigitCaps num_routes=32*6*6), while `dataloader` yields 3x128x128 RGB batches. The network
# must be configured for that input before this cell can train.
capsule_net = CapsuleNet().to(device)
optimizer = Adam(capsule_net.parameters())

n_epochs = 20

for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    for data, labels in tqdm(dataloader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        output, masked = capsule_net(data)
        # One-hot targets sized to however many output capsules the network has.
        target = F.one_hot(labels, num_classes=output.size(1)).float()
        loss = capsule_net.loss(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        preds = masked.argmax(dim=1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    train_accuracy = correct_train / total_train
    avg_train_loss = train_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{n_epochs} - Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy:.4f}")
    

In [None]:
capsule_net.eval()
test_loss = 0.0
correct_test = 0
total_test = 0
# test_dataloader yields unlabelled images, so (as in the CNN evaluation) every test image
# is scored against one fixed label.
# NOTE(review): CustomDataset labels faces 1 and flowers 0, but test_dir is
# ".../face-dataset/human-swap/". If these are face images the expected label is probably 1.
test_label = 0

with torch.no_grad():
    for data in test_dataloader:
        data = data.to(device)

        output, masked = capsule_net(data)
        labels = torch.full((data.size(0),), test_label, dtype=torch.long, device=device)
        target = F.one_hot(labels, num_classes=output.size(1)).float()
        loss = capsule_net.loss(output, target)
        test_loss += loss.item()

        preds = masked.argmax(dim=1)
        correct_test += (preds == labels).sum().item()
        total_test += labels.size(0)

test_accuracy = correct_test / total_test
avg_test_loss = test_loss / len(test_dataloader)

print(f"Test Loss: {avg_test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}")

In [None]:
capsule_net.eval()

with torch.no_grad():
    # test_dataloader yields unlabelled image batches, so only predictions can be shown.
    data = next(iter(test_dataloader)).to(device)
    output, masked = capsule_net(data)
    preds = masked.argmax(dim=1).cpu().numpy()

fig, axes = plt.subplots(3, 5, figsize=(12, 6))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= data.size(0):
        continue  # fewer images than panels
    # Undo Normalize((0.5,), (0.5,)) from my_transform and move channels last (CHW -> HWC)
    # for imshow; squeeze() leaves (H, W) for single-channel input, rendered with cmap='gray'.
    image = (data[i].cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy().squeeze()
    ax.imshow(image, cmap='gray')
    ax.set_title(f"Pred: {preds[i]}")

plt.tight_layout()
plt.show()

## Convolutional neural network (CNN)

In [14]:
class CNN(nn.Module):
    """Binary image classifier: four 9x9 convolutions, a linear head and a sigmoid.

    Args:
        input_size: side length of the square input images; used to size the linear layer.
        in_channels: number of input image channels.

    ``forward`` returns probabilities of shape (batch, 1).
    """

    def __init__(self, input_size=128, in_channels=3):
        super(CNN, self).__init__()
        self.input_size = input_size
        self.in_channels = in_channels
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=9, stride=1, padding=4)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=9, stride=2, padding=4)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=9, stride=2, padding=4)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=9, stride=2, padding=4)
        self.fc_input_size = self._get_fc_input_size()
        self.fc = nn.Linear(self.fc_input_size, 1)
        self.sigmoid = nn.Sigmoid()

    def _features(self, x):
        """Convolutional trunk shared by ``forward`` and the size probe; returns flat features."""
        x = self.conv1(x)
        x = F.relu(F.max_pool2d(x, 2))
        x = F.relu(self.conv2(x))
        x = F.relu(F.max_pool2d(x, 2))
        x = F.relu(self.conv3(x))
        x = F.relu(F.max_pool2d(x, 2))
        x = F.relu(self.conv4(x))
        return x.view(x.size(0), -1)

    def _get_fc_input_size(self):
        """Number of flattened features produced for one ``input_size`` x ``input_size`` image."""
        # The dummy pass only measures shapes, so no autograd graph is needed.
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.in_channels, self.input_size, self.input_size)
            return self._features(dummy_input).size(1)

    def forward(self, x):
        """Return the probability of the positive class, shape (batch, 1)."""
        return self.sigmoid(self.fc(self._features(x)))

    def loss(self, output, target):
        """Binary cross-entropy between probabilities ``output`` (batch, 1) and 0/1 ``target`` (batch,).

        NOTE(review): sigmoid + binary_cross_entropy is less numerically stable than
        BCEWithLogitsLoss on raw logits; kept because ``forward`` returns probabilities.
        """
        return F.binary_cross_entropy(output, target.unsqueeze(1).float())


In [15]:
cnn = CNN().to(device)
optimizer = Adam(cnn.parameters())

n_epochs = 20
for epoch in range(n_epochs):
    cnn.train()
    train_loss, correct_train, total_train = 0, 0, 0

    for batch_id, (data, target) in enumerate(tqdm(dataloader)):
        data = data.to(device)
        target = target.to(device)

        # Standard step: clear gradients, forward, BCE loss, backprop, update.
        optimizer.zero_grad()
        output = cnn(data)
        loss = cnn.loss(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Threshold the sigmoid probabilities at 0.5 to get hard 0/1 predictions.
        preds = output.ge(0.5).float()
        correct_train += (preds == target.view_as(preds)).sum().item()
        total_train += len(target)

    train_accuracy = correct_train / total_train
    avg_train_loss = train_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{n_epochs} - Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy:.4f}")

100%|██████████| 4/4 [00:11<00:00,  2.99s/it]


Epoch 1/20 - Loss: 0.8237 - Accuracy: 0.4600


100%|██████████| 4/4 [00:11<00:00,  2.90s/it]


Epoch 2/20 - Loss: 0.6798 - Accuracy: 0.6040


100%|██████████| 4/4 [00:11<00:00,  2.92s/it]


Epoch 3/20 - Loss: 0.6371 - Accuracy: 0.7000


100%|██████████| 4/4 [00:11<00:00,  2.95s/it]


Epoch 4/20 - Loss: 0.5743 - Accuracy: 0.6720


100%|██████████| 4/4 [00:11<00:00,  2.88s/it]


Epoch 5/20 - Loss: 0.4327 - Accuracy: 0.8240


100%|██████████| 4/4 [00:11<00:00,  2.89s/it]


Epoch 6/20 - Loss: 0.3353 - Accuracy: 0.8680


100%|██████████| 4/4 [00:11<00:00,  2.94s/it]


Epoch 7/20 - Loss: 0.3127 - Accuracy: 0.8820


 75%|███████▌  | 3/4 [00:11<00:03,  3.86s/it]


KeyboardInterrupt: 

In [19]:
cnn.eval()
test_loss = 0
correct_test = 0
total_test = 0
# test_dataloader yields unlabelled images, so every test image is scored against one fixed label.
# NOTE(review): CustomDataset labels faces 1 and flowers 0, but test_dir is
# ".../face-dataset/human-swap/". If these are face images the expected label is probably 1,
# which would explain the very low accuracy. Confirm before interpreting the result.
test_label = 0.0

# Evaluation only: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    for data in test_dataloader:
        data = data.to(device)
        output = cnn(data)
        target = torch.full((output.size(0),), test_label, device=device)
        loss = cnn.loss(output, target)

        test_loss += loss.item()

        preds = (output >= 0.5).float()
        correct_test += preds.eq(target.view_as(preds)).sum().item()
        total_test += target.size(0)

test_accuracy = correct_test / total_test
avg_test_loss = test_loss / len(test_dataloader)
print(f"Loss: {avg_test_loss:.4f} - Accuracy: {test_accuracy:.4f}")

Loss: 1.9946 - Accuracy: 0.0859
