<a href="https://colab.research.google.com/github/vaibhavdubey7/snn/blob/main/siamese_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### RUN

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import imageio as iio

from torch.autograd import Variable
from torch.utils.data import DataLoader


In [None]:
import warnings

# Silence every warning category for the rest of the notebook session.
warnings.simplefilter('ignore')


In [None]:
# Root folders of the training and testing images.
TRAIN_DIR = './dataset_all/training/'
TEST_DIR = './dataset_all/testing/'

# Output locations: model checkpoints, figures, and per-pair test results.
WEIGHT_PATH = './weights/steel/'
SAVE_IMAGE_PATH = './assets/'
SAVE_TEST_RESULTS_PATH = './assets/test_output'

# Prefer the first CUDA device when one is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


In [None]:
# Training hyper-parameters.
BATCH_SIZE = 32
N_EPOCHS = 100
LR = 0.0005  # learning rate handed to the Adam optimizer
N_GPU = 1

# When True the dataset converts images to single-channel luminance ('L'),
# so the network's first convolution takes 1 input channel instead of 3.
SET_L = True
INPUT_D = 1 if SET_L else 3


In [None]:
def show_img(img, text=None, figsize=(20,20), save_into=None):
    """
    Display a channels-first image tensor with matplotlib, optionally saving it.

    Args:
        img: image tensor laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``. It is detached and
            moved to the CPU first, so GPU / grad-tracking tensors also work.
        text: optional caption drawn in a white box near the top-left corner.
        figsize: matplotlib figure size in inches.
        save_into: optional output file path. If given, the figure is saved
            there (missing parent directories are created) before it is shown.
    """
    img = img.detach().cpu().numpy()
    plt.figure(figsize=figsize)
    plt.axis('off')
    if text:
        plt.text(50, 8, text, bbox={'facecolor': 'white', 'alpha': 1})
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(img, (1, 2, 0)))

    if save_into:
        # savefig does not create directories; the asset folder may not
        # exist yet the first time a figure is saved.
        parent = os.path.dirname(save_into)
        if parent:
            os.makedirs(parent, exist_ok=True)
        plt.savefig(save_into)
    plt.show()


In [None]:
def explain_databatch(batch):
    """
    Print the number of elements in a batch and the shape of each element.

    batch: tensor batch from dataset (a sequence whose items expose ``.shape``)
    """
    print('Each batch of training data is a tuple of {} elements.'.format(len(batch)))
    print('Shape of each element of a tuple in the batch:')
    for position, element in enumerate(batch):
        print('element#{}: {}'.format(position, element.shape))


In [None]:
class SiameseDataset(torch.utils.data.Dataset):
    """
    Serve random image pairs for training or evaluating a Siamese network.

    Pairs are drawn from ``datasets.imgs``, a list of ``(path, class_index)``
    tuples such as the one exposed by ``torchvision.datasets.ImageFolder``.
    The ``index`` passed to ``__getitem__`` is ignored: every call draws a
    fresh random pair, which is a same-class pair with probability 1/2 and a
    different-class pair otherwise.

    Args:
        datasets: object with an ``imgs`` attribute as described above.
        train: if True, items are ``(imageA, imageB, label)`` where ``label``
            is a float32 tensor of shape (1,): 0.0 for the same class and
            1.0 for different classes. If False, items are
            ``(imageA, pathA, classA, imageB, pathB, classB)``.
        transform: optional callable applied to each PIL image last.
        should_invert: invert pixel values (``ImageOps.invert``) before
            ``transform``.
        set_luminance: convert images to single-channel ``'L'`` mode;
            otherwise they are converted to ``'RGB'`` (as torchvision's
            default loader does), which also keeps ``ImageOps.invert``
            usable for palette/RGBA files.

    Raises:
        ValueError: from ``__getitem__`` when ``datasets.imgs`` spans fewer
            than two classes, because no different-class pair exists (a
            rejection loop would otherwise never terminate).
    """

    def __init__(self,
                 datasets,
                 train=True,
                 transform=None,
                 should_invert=True,
                 set_luminance=True):

        self.datasets = datasets
        self.train = train
        self.transform = transform
        self.should_invert = should_invert
        self.set_luminance = set_luminance

        # Index samples by class once so a same-class partner can be picked
        # directly instead of by repeated rejection sampling.
        self._by_label = {}
        for sample in datasets.imgs:
            self._by_label.setdefault(sample[1], []).append(sample)

    def _sample_pair(self):
        """Return two ``(path, label)`` tuples; same class with probability 1/2."""
        if len(self._by_label) < 2:
            raise ValueError('SiameseDataset needs images from at least two '
                             'classes to build different-class pairs')
        samples = self.datasets.imgs
        first = random.choice(samples)
        if random.randint(0, 1):
            # Uniform over the samples of first's class (may be first itself).
            second = random.choice(self._by_label[first[1]])
        else:
            # Terminates with probability 1 because a second class exists.
            second = random.choice(samples)
            while second[1] == first[1]:
                second = random.choice(samples)
        return first, second

    def _load_image(self, path):
        """Open ``path`` and apply mode conversion, inversion and ``transform``."""
        # The ``with`` block releases the file handle right away; ``convert``
        # returns a fully loaded copy that stays valid after the file closes.
        with Image.open(path) as raw:
            image = raw.convert('L' if self.set_luminance else 'RGB')
        if self.should_invert:
            image = ImageOps.invert(image)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __getitem__(self, index):
        (pathA, classA), (pathB, classB) = self._sample_pair()
        imageA = self._load_image(pathA)
        imageB = self._load_image(pathB)

        if self.train:
            # 1.0 marks a dissimilar pair, 0.0 a similar one.
            label = torch.tensor([float(classA != classB)], dtype=torch.float32)
            return imageA, imageB, label
        return imageA, pathA, classA, imageB, pathB, classB

    def __len__(self):
        return len(self.datasets.imgs)

In [None]:
class RandRotateTransform:
    """
    Rotate an image by an angle picked uniformly at random from ``angles``.

    Args:
        angles: sequence of rotation angles (degrees) forwarded to
            ``torchvision.transforms.functional.rotate``.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        return transforms.functional.rotate(x, random.choice(self.angles))


In [None]:
# Training-time preprocessing: fixed 100x100 size, random right-angle
# rotation and horizontal flip, then tensor conversion and scaling of each
# channel with mean 0.5 / std 0.5.
_train_steps = [
    transforms.Resize((100, 100)),
    RandRotateTransform(angles=[0, 90, 180, 270]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_train_steps)

In [None]:
# Training images, one sub-folder per class under TRAIN_DIR.
# SiameseDataset needs an object exposing ``.imgs`` (a list of
# ``(path, label)`` tuples), which ImageFolder provides; a single image read
# with imageio is a plain array and has no such attribute.
# The fully qualified name is used because this assignment shadows the
# ``torchvision.datasets`` module imported as ``datasets``.
datasets = torchvision.datasets.ImageFolder(root=TRAIN_DIR)

In [None]:
# Same augmentation as the training pipeline but without normalisation, so
# the tensors can be shown directly as images.
_visual_steps = [
    transforms.Resize((100, 100)),
    RandRotateTransform(angles=[0, 90, 180, 270]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_visualize = transforms.Compose(_visual_steps)

siamese_datasets = SiameseDataset(
    datasets=datasets,
    train=True,
    transform=transform_visualize,
    should_invert=False,
    set_luminance=SET_L,
)
data_loader = DataLoader(siamese_datasets, batch_size=BATCH_SIZE, shuffle=True)

example_batch = next(iter(data_loader))
print('Total Datasets:', len(datasets), '\n')

explain_databatch(example_batch)

print('\nEach batch of the dataset 3 elements: 2 images and their similarity label')


AttributeError: ignored

In [None]:
# Lay out the first 16 "A" images followed by their 16 "B" partners as one grid.
left_images = example_batch[0][:16]
right_images = example_batch[1][:16]
concatenated = torch.cat((left_images, right_images), 0)

preview_path = os.path.join(SAVE_IMAGE_PATH, 'datasets_images.png')
show_img(torchvision.utils.make_grid(concatenated), save_into=preview_path)

NameError: ignored

In [None]:
# Training pairs use the normalised ``transform`` pipeline.
siamese_datasets = SiameseDataset(
    datasets=datasets,
    transform=transform,
    should_invert=False,
    set_luminance=SET_L,
)

# Shuffled mini-batches, loaded in the main process (no worker processes).
train_loader = DataLoader(
    siamese_datasets,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)


AttributeError: ignored

In [None]:
# Un-normalised copy of the training pairs, used only for display.
siamese_datasets_ = SiameseDataset(
    datasets=datasets,
    transform=transform_visualize,
    should_invert=False,
    set_luminance=SET_L,
)
train_loader_ = DataLoader(
    siamese_datasets_,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

example_batch = next(iter(train_loader_))
explain_databatch(example_batch)

# Four "A" images followed by their four "B" partners.
concatenated = torch.cat((example_batch[0][:4], example_batch[1][:4]), 0)
preview_grid = torchvision.utils.make_grid(concatenated)

show_img(preview_grid, save_into=os.path.join(SAVE_IMAGE_PATH, 'train_images.png'))
print(preview_grid.shape)


In [None]:
class SiamaseNet(nn.Module):
    """
    Siamese embedding network: one conv + MLP tower shared by both inputs.

    The class name (including its spelling) is kept because the training loop
    pickles the whole module with ``torch.save(siamese, ...)`` and reloads it
    with ``torch.load``; pickle resolves the class by name.

    Args:
        in_channels: number of input image channels. ``None`` (default) uses
            the module-level ``INPUT_D``, read at construction time.
        input_size: spatial size of the inputs, as an int (square) or a
            ``(height, width)`` pair. It must match the ``Resize`` in the data
            transforms. Defaults to 100 (100x100 images).
        embedding_dim: length of the output embedding. Defaults to 5.
    """

    def __init__(self, in_channels=None, input_size=100, embedding_dim=5):
        super(SiamaseNet, self).__init__()
        if in_channels is None:
            in_channels = INPUT_D
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # Each ZeroPad2d(1) + 3x3 conv keeps the spatial size unchanged, so the
        # flattened feature map has 8 * height * width values. The layer order
        # is unchanged, so state_dict keys of existing checkpoints still match.
        self.conv_layer = nn.Sequential(
            nn.ZeroPad2d(1),
            nn.Conv2d(in_channels, 4, kernel_size=3),
            nn.ReLU(),
            nn.ZeroPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(),
            nn.ZeroPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(),
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(8 * height * width, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, embedding_dim),
        )

    def forward_once(self, x):
        """Embed a batch ``x`` of shape (batch, in_channels, height, width)."""
        x = self.conv_layer(x)
        x = torch.flatten(x, 1)
        return self.fc_layer(x)

    def forward(self, imageA, imageB):
        """Embed both inputs with the shared weights; return ``(resultA, resultB)``."""
        resultA = self.forward_once(imageA)
        resultB = self.forward_once(imageB)
        return resultA, resultB


In [None]:
# Build the network and move its parameters to the selected device.
siamese = SiamaseNet()
siamese = siamese.to(device)


In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss over pairs of embeddings.

    ``label`` is 0 for a similar pair and 1 for a dissimilar pair (the
    convention produced by ``SiameseDataset``). With ``d`` equal to *half* the
    Euclidean distance between the two embeddings, the per-pair loss is
    ``(1 - label) * d**2 + label * max(0, margin - d)**2``, averaged over
    the batch.

    Args:
        margin: distance ``d`` beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, resultA, resultB, label):
        # Half of the pairwise L2 distance; keepdim keeps a trailing axis so
        # it lines up with ``label`` (assumes (batch, dim) embeddings).
        dist = F.pairwise_distance(resultA, resultB, keepdim=True) * 0.5

        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * (self.margin - dist).clamp(min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()


In [None]:
# Loss criterion with the default margin spelled out explicitly.
contrastive_loss = ContrastiveLoss(margin=2.0)

In [None]:
# Adam over all network parameters.
optimizer = optim.Adam(params=siamese.parameters(), lr=LR)

In [None]:
# Per-epoch mean training loss is appended to ``train_loss_history`` and a
# full-model checkpoint is written to WEIGHT_PATH after every epoch.
train_loss = 0
train_loss_history = []

N = len(datasets)

# Create the checkpoint directory once, not on every epoch.
os.makedirs(WEIGHT_PATH, exist_ok=True)

for epoch in range(N_EPOCHS):
    # Put the model in training mode before this epoch's updates.
    siamese.train()
    n_batches = 0

    for imageA, imageB, label in train_loader:
        imageA, imageB, label = imageA.to(device), imageB.to(device), label.to(device)

        optimizer.zero_grad()
        resultA, resultB = siamese(imageA, imageB)
        loss = contrastive_loss(resultA, resultB, label)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1

    # Average over the batches actually processed: this counts a final
    # partial batch and avoids dividing by zero when N < BATCH_SIZE.
    train_loss = train_loss / max(n_batches, 1)
    train_loss_history.append(train_loss)

    print('Train Epoch:{}\t\tLoss:{:.6f}'.format(epoch+1, train_loss))

    # Whole-module pickle; the evaluation cell reloads it with torch.load.
    torch.save(siamese, os.path.join(WEIGHT_PATH, 'siamese_{:03}.pt'.format(epoch+1)))
    train_loss = 0

In [None]:
# Persist the per-epoch losses to training_history.npy, then read them back
# so the plot below works from the saved file.
np.save("training_history", np.asarray(train_loss_history))
training_data = np.load("training_history.npy")


In [None]:
plt.figure(figsize=(10,10))
plt.title("Training Loss for one-shot recognition of Defects in Steel surfaces")
# The training loop reports epochs starting at 1; label the x-axis the same way.
epochs = np.arange(1, len(training_data) + 1)
plt.plot(epochs, training_data, label="Train Loss")

plt.xlabel("Epoch")
plt.grid(True)
plt.ylabel("Loss")
plt.legend()

os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)

# ``quality`` is a JPEG-only savefig option that Matplotlib deprecated in 3.3
# and newer releases reject; it has no meaning for SVG, so it is not passed.
plt.savefig(os.path.join(SAVE_IMAGE_PATH, 'final_train_loss.svg'),
            dpi=400, pad_inches=0, bbox_inches='tight')
plt.show()

In [None]:
# TEST_DIR comes from the configuration cell at the top of the notebook.
class_list = sorted(os.listdir(TEST_DIR))
test_datasets = torchvision.datasets.ImageFolder(root=TEST_DIR)
siamese_test_datasets = SiameseDataset(datasets=test_datasets,
                                       train=False,
                                       transform=transform,
                                       should_invert=False,
                                       set_luminance=SET_L)

# Reload the checkpoint written after the final training epoch.
# The whole module was pickled with ``torch.save(siamese, ...)``, so
# ``weights_only=False`` is required on PyTorch >= 2.6 (where it defaults to
# True). Only load checkpoints you trust, because unpickling can run code.
# ``map_location`` lets a GPU-trained checkpoint load on a CPU-only machine.
siamese = torch.load(os.path.join(WEIGHT_PATH, 'siamese_{:03}.pt'.format(N_EPOCHS)),
                     map_location=device,
                     weights_only=False)
siamese.eval()


In [None]:
# One shuffled loader is enough: SiameseDataset draws a random pair per item,
# and each batch holds a single pair for the evaluation loop that follows.
test_loader = DataLoader(siamese_test_datasets,
                         num_workers=0,
                         batch_size=1,
                         shuffle=True)
data_iter = iter(test_loader)

SAVE_TEST_RESULTS_PATH = "./assets/paper"
SAVE_TEST_RESULTS = "./assets/paper"

# Running tallies updated by the evaluation loop.
correct_ = 0
count = 0
for i in range(len(siamese_test_datasets)):
    try:
        count += 1
        imageA, pathA, classA, imageB, pathB, classB = next(data_iter)
    except:
        break

    with torch.no_grad():
        resultA, resultB = siamese(Variable(imageA).to(device), Variable(imageB).to(device))
        euclidean_distance = F.pairwise_distance(resultA, resultB)

        if euclidean_distance>=2. and classA!=classB:
            correct_ += 1
        elif euclidean_distance<2. and classA==classB:
            correct_ += 1

    concatenated = torch.cat((imageA, imageB),0)

    # Uncomment below lines to save figures
    image_file_name = 'test_output_{}.png'.format(i)
    save_results_in = os.path.join(SAVE_TEST_RESULTS_PATH, image_file_name)