In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from collections import Counter
from torch.optim import Adam
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os

## Dataset

In [19]:
human_faces_dir = '/kaggle/input/human-faces/Humans'
flowers_dir = '/kaggle/input/face-dataset/human-swap/'

human_faces_files = len(os.listdir(human_faces_dir))
flowers_files = len(os.listdir(flowers_dir))

print(f"Number of files in human_faces_dir: {human_faces_files}")
print(f"Number of files in flowers_dir: {flowers_files}")

Number of files in human_faces_dir: 7219
Number of files in flowers_dir: 6676


In [20]:
my_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

class CustomDataset(Dataset):
    def __init__(self, human_faces_dir, flowers_dir, transform=None):
        self.human_faces_dir = human_faces_dir
        self.flowers_dir = flowers_dir
        
        human_faces_images = [os.path.join(human_faces_dir, fname) for fname in os.listdir(human_faces_dir)]
        self.human_faces_images = random.sample(human_faces_images, 300)
        flowers_images = [os.path.join(flowers_dir, fname) for fname in os.listdir(flowers_dir)]
        self.flowers_images = random.sample(flowers_images, 200)
        
        self.all_images = self.human_faces_images + self.flowers_images
        self.labels = [1] * len(self.human_faces_images) + [0] * len(self.flowers_images)
        self.transform = transform

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        img_path = self.all_images[idx]
        image = Image.open(img_path).convert("RGB")
        
        label = self.labels[idx]
        
        if self.transform: image = self.transform(image)
        return image, label

dataset = CustomDataset(human_faces_dir, flowers_dir, transform=my_transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
for images, labels in dataloader:
    print(images.shape)
    print(labels)
    break

torch.Size([128, 3, 128, 128])
tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0,
        1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0,
        0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0])


In [22]:
class TestDataset(Dataset):
    def __init__(self, test_dir, transform=None):
        self.test_dir = test_dir
        test_images = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir) if fname.endswith(('.jpg', '.png', '.jpeg'))]
        test_images = random.sample(test_images, 128)
        self.all_images = test_images
        self.transform = transform

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        img_path = self.all_images[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.transform: image = self.transform(image)
        return image

test_dir = "/kaggle/input/face-dataset/human-swap/"
test_dataset = TestDataset(test_dir, transform=my_transform)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [23]:
for images in test_dataloader:
    print(images.shape)
    break

torch.Size([128, 3, 128, 128])


## Capsule

In [24]:
def debug_print(debug, message, tensor=None):
    if debug:
        if tensor is not None:
            print(f"{message}: {tensor.shape}")
        else:
            print(message)

def squash(input_tensor, epsilon=1e-7):
    squared_norm = (input_tensor ** 2 + epsilon).sum(-1, keepdim=True)
    output_tensor = (squared_norm / (1. + squared_norm)) *  (input_tensor / torch.sqrt(squared_norm))
    return output_tensor

class ConvLayer(nn.Module):
    def __init__(self, in_channels=3, out_channels=256, kernel_size=9, debug=False):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1)
        self.debug = debug

    def forward(self, x):
        x = self.conv(x)
        debug_print(self.debug, "x after conv", x)
        x = F.relu(x)
        debug_print(self.debug, "x after ReLU", x)
        return x


class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, debug=False):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList()
        for _ in range(num_capsules):
            capsule = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=9, stride=2, padding=0),
                nn.Conv2d(in_channels=128, out_channels=64, kernel_size=9, stride=2, padding=0),
                nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=9, stride=3, padding=0)
            )
            self.capsules.append(capsule)
        self.debug = debug
    
    def forward(self, x):
        stacked_capsules = [capsule(x) for capsule in self.capsules]
        if self.debug: 
            print("capsule_out:")
            for capsule_out in stacked_capsules:
                print("\t", capsule_out.shape)
        stacked_capsules = torch.stack(stacked_capsules, dim=1)
        debug_print(self.debug, "stacked_capsules", stacked_capsules)
        flattened_capsules = stacked_capsules.view(x.size(0), 32 * 6 * 6, -1)
        debug_print(self.debug, "flattened_capsules", flattened_capsules)
        squashed_output = squash(flattened_capsules)
        debug_print(self.debug, "squashed_output", squashed_output)
        return squashed_output


class DigitCaps(nn.Module):
    def __init__(self, num_capsules=2, num_routes=32 * 6 * 6, in_channels=8, out_channels=16, debug=False):
        super(DigitCaps, self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))
        self.debug = debug

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
        debug_print(self.debug, "x after stacking", x)

        W = torch.cat([self.W] * batch_size, dim=0)
        debug_print(self.debug, "W", W)
        
        u_hat = torch.matmul(W, x)
        debug_print(self.debug, "u_hat", u_hat)

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        b_ij = b_ij.to(device)
        debug_print(self.debug, "b_ij", b_ij)

        num_iter = 3
        for i in range(num_iter):
            if self.debug: print()
            c_ij = F.softmax(b_ij, dim=1)
            debug_print(self.debug, "c_ij", c_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
            debug_print(self.debug, "c_ij after repeat", c_ij)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            debug_print(self.debug, "s_j", s_j)

            v_j = squash(s_j)
            debug_print(self.debug, "v_j", v_j)
            
            a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
            debug_print(self.debug, "a_ij", a_ij)
            b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)
            debug_print(self.debug, "b_ij updated", b_ij)

        return v_j.squeeze(1)

class Decoder(nn.Module):
    def __init__(self, debug=False):
        super(Decoder, self).__init__()
        self.debug = debug
        
    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        debug_print(self.debug, "classes before softmax", classes)
        classes = F.softmax(classes, dim=0)
        debug_print(self.debug, "classes after softmax", classes)
        
        _, max_length_indices = classes.max(dim=1)
        debug_print(self.debug, "max_length_indices", max_length_indices)
        masked = Variable(torch.sparse.torch.eye(2))
        debug_print(self.debug, "masked", masked)
        masked = masked.to(device)
        masked = masked.index_select(dim=0, index=Variable(max_length_indices.squeeze(1).data))
        debug_print(self.debug, "masked after index_select", masked)
        
        return masked

In [25]:
class CapsuleNet(nn.Module):
    def __init__(self, debug=False):
        super(CapsuleNet, self).__init__()
        self.conv_layer = ConvLayer(debug=debug)
        self.primary_capsules = PrimaryCaps(debug=debug)
        self.digit_capsules = DigitCaps(debug=debug)
        self.decoder = Decoder(debug=debug)
        self.mse_loss = nn.MSELoss()
        self.debug = debug
        
    def forward(self, data):
        debug_print(self.debug, f"Input data", data)
        debug_print(self.debug, "\nCONV")
        output = self.conv_layer(data)
        debug_print(self.debug, "\nPRIMARY")
        output = self.primary_capsules(output)
        debug_print(self.debug, "\nDIGIT")
        output = self.digit_capsules(output)
        debug_print(self.debug, "\nDECODER")
        masked = self.decoder(output, data)
        debug_print(self.debug, "\nOUTPUT", output)
        return output, masked

    def loss(self, x, target):
        return self.margin_loss(x, target)
    
    def margin_loss(self, x, labels):
        batch_size = x.size(0)
        v_k = torch.sqrt((x**2).sum(dim=2, keepdim=True))
        left = F.relu(0.9 - v_k).view(batch_size, -1)
        right = F.relu(v_k - 0.1).view(batch_size, -1)
        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()
        return loss

In [26]:
capsule_net = CapsuleNet(debug=True).to(device) 
data, target = next(iter(dataloader))
target = torch.sparse.torch.eye(2).index_select(dim=0, index=target)
data, target = Variable(data), Variable(target)
data, target = data.to(device), target.to(device)
output, masked = capsule_net(data)

Input data: torch.Size([128, 3, 128, 128])

CONV
x after conv: torch.Size([128, 256, 120, 120])
x after ReLU: torch.Size([128, 256, 120, 120])

PRIMARY
capsule_out:
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
	 torch.Size([128, 32, 6, 6])
stacked_capsules: torch.Size([128, 8, 32, 6, 6])
flattened_capsules: torch.Size([128, 1152, 8])
squashed_output: torch.Size([128, 1152, 8])

DIGIT
x after stacking: torch.Size([128, 1152, 2, 8, 1])
W: torch.Size([128, 1152, 2, 16, 8])
u_hat: torch.Size([128, 1152, 2, 16, 1])
b_ij: torch.Size([1, 1152, 2, 1])

c_ij: torch.Size([1, 1152, 2, 1])
c_ij after repeat: torch.Size([128, 1152, 2, 1, 1])
s_j: torch.Size([128, 1, 2, 16, 1])
v_j: torch.Size([128, 1, 2, 16, 1])
a_ij: torch.Size([128, 1152, 2, 1, 1])
b_ij updated: torch.Size([1, 1152, 2, 1])

c_ij: torch.Size([1, 1152, 2, 1])
c_ij after

In [27]:
capsule_net = CapsuleNet().to(device) 
optimizer = Adam(capsule_net.parameters())

n_epochs = 20

for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0
    correct_train = 0
    total_train = 0
    for batch_id, (data, target) in enumerate(tqdm(dataloader)):
        target = torch.sparse.torch.eye(2).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output, masked = capsule_net(data)
        loss = capsule_net.loss(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        preds = np.argmax(masked.data.cpu().numpy(), axis=1)
        targets = np.argmax(target.data.cpu().numpy(), axis=1)
        correct_train += np.sum(preds == targets)
        total_train += len(targets)
    
    train_accuracy = correct_train / total_train
    avg_train_loss = train_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{n_epochs} - Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy:.4f}")
    

100%|██████████| 4/4 [01:10<00:00, 17.55s/it]


Epoch 1/20 - Loss: 0.8335 - Accuracy: 0.5120


100%|██████████| 4/4 [01:07<00:00, 16.76s/it]


Epoch 2/20 - Loss: 0.7427 - Accuracy: 0.5540


100%|██████████| 4/4 [01:06<00:00, 16.68s/it]


Epoch 3/20 - Loss: 0.6956 - Accuracy: 0.5280


100%|██████████| 4/4 [01:06<00:00, 16.66s/it]


Epoch 4/20 - Loss: 0.6626 - Accuracy: 0.5260


100%|██████████| 4/4 [01:06<00:00, 16.68s/it]


Epoch 5/20 - Loss: 0.6410 - Accuracy: 0.5560


100%|██████████| 4/4 [01:06<00:00, 16.62s/it]


Epoch 6/20 - Loss: 0.6244 - Accuracy: 0.5420


100%|██████████| 4/4 [01:06<00:00, 16.58s/it]


Epoch 7/20 - Loss: 0.6123 - Accuracy: 0.5400


100%|██████████| 4/4 [01:06<00:00, 16.54s/it]


Epoch 8/20 - Loss: 0.6016 - Accuracy: 0.5520


100%|██████████| 4/4 [01:06<00:00, 16.63s/it]


Epoch 9/20 - Loss: 0.5933 - Accuracy: 0.5620


100%|██████████| 4/4 [01:06<00:00, 16.61s/it]


Epoch 10/20 - Loss: 0.5844 - Accuracy: 0.5600


100%|██████████| 4/4 [01:06<00:00, 16.59s/it]


Epoch 11/20 - Loss: 0.5753 - Accuracy: 0.5840


100%|██████████| 4/4 [01:06<00:00, 16.58s/it]


Epoch 12/20 - Loss: 0.5675 - Accuracy: 0.5820


100%|██████████| 4/4 [01:06<00:00, 16.59s/it]


Epoch 13/20 - Loss: 0.5571 - Accuracy: 0.6000


100%|██████████| 4/4 [01:06<00:00, 16.59s/it]


Epoch 14/20 - Loss: 0.5501 - Accuracy: 0.6360


100%|██████████| 4/4 [01:06<00:00, 16.59s/it]


Epoch 15/20 - Loss: 0.5416 - Accuracy: 0.6280


100%|██████████| 4/4 [01:06<00:00, 16.63s/it]


Epoch 16/20 - Loss: 0.5280 - Accuracy: 0.6380


100%|██████████| 4/4 [01:06<00:00, 16.60s/it]


Epoch 17/20 - Loss: 0.5115 - Accuracy: 0.6580


100%|██████████| 4/4 [01:06<00:00, 16.63s/it]


Epoch 18/20 - Loss: 0.4895 - Accuracy: 0.6500


100%|██████████| 4/4 [01:06<00:00, 16.59s/it]


Epoch 19/20 - Loss: 0.4659 - Accuracy: 0.6080


100%|██████████| 4/4 [01:05<00:00, 16.49s/it]

Epoch 20/20 - Loss: 0.4521 - Accuracy: 0.6240





In [29]:
capsule_net.eval()
test_loss = 0
correct_test = 0
total_test = 0

with torch.no_grad():
    for data in test_dataloader:
        target = torch.sparse.torch.eye(2).index_select(dim=0, index=torch.tensor([0] * 128))
        data, target = Variable(data), Variable(target)
        data, target = data.to(device), target.to(device)
    
        output, masked = capsule_net(data)
        loss = capsule_net.loss(output, target)
    
        test_loss += loss.item()
    
        preds = np.argmax(masked.data.cpu().numpy(), axis=1)
        targets = np.argmax(target.data.cpu().numpy(), axis=1)
        print(preds)
        print(targets)
        correct_test += np.sum(preds == targets)
        total_test += len(targets)
    
    test_accuracy = correct_test / total_test
    avg_test_loss = test_loss / len(test_dataloader)
    print(f"Loss: {avg_test_loss:.4f} - Accuracy: {test_accuracy:.4f}")

[1 0 1 1 0 1 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 1 1 1 0 0 1 0 0 1 1 0 1 1 0 1 1
 0 1 0 0 1 0 1 1 1 0 1 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 1 1 0 1 1 1 1 1 1 0
 1 0 0 0 0 1 0 0 1 0 1 1 1 1 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 0 0 1 0 1 1 0
 1 0 1 1 1 0 1 1 0 0 0 0 1 1 1 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Loss: 0.8182 - Accuracy: 0.4609
