In [1]:
from pathlib import Path
if Path.cwd().parent.stem == 't2':
    %cd ..
%config Completer.use_jedi = False

/home/step/Personal/UCH/2021-sem1/VisionComp/t2


## Seed

In [2]:
import torch
import pytorch_lightning as pl

pl.seed_everything(hash("setting a random seeds") % 2**32 - 1)
torch.backends.cudnn.benchmark = True

Global seed set to 2014041450


In [13]:
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import transforms, datasets
import skimage.morphology as morph

from typing import Tuple, Any

class ErosionReplicate(object):
    """Applies Erosion to one channel and replicates the channel"""
    def __call__(self, tensor):
        one_channel = tensor[0, :, :]
        one_channel = torch.from_numpy(morph.erosion(one_channel, morph.square(3)))    
        tensor[0, :, :] = one_channel
        tensor[1, :, :] = one_channel
        tensor[2, :, :] = one_channel
        return tensor
        

class SiameseDataset(Dataset):
    def __init__(self, base_dir_sk: str,  base_dir_ph: str, triplet: bool = False):
        super(Dataset, self).__init__()
        
        # base_dirs
        self.base_dir_sk = str(Path(base_dir_sk) / 'png_w256')
        self.base_dir_ph = base_dir_ph 
        
        # transforms
        base_tr = [
            transforms.Resize([224,224]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        sketch_tr = [ErosionReplicate()]
        normalize_tr = [transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))]
        
        self.transforms_sk = transforms.Compose(base_tr + sketch_tr + normalize_tr)
        self.transforms_ph = transforms.Compose(base_tr + normalize_tr)
        
        
        # init two ImageFolders        
        self.dataset_sk = datasets.ImageFolder(self.base_dir_sk)
        self.dataset_ph = datasets.ImageFolder(self.base_dir_ph)
        
        self.n_classes = len(self.dataset_sk.classes)
        self.idx_to_class = {idx: label for label, idx in self.dataset_sk.class_to_idx.items()}
    
    
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path_sk, target_sk = self.dataset_sk.samples[index]
        path_ph, target_ph = self.dataset_ph.samples[index]
        
        sample_sk = self.dataset_sk.loader(path_sk)
        sample_ph = self.dataset_ph.loader(path_ph)
        
        sample_sk = self.transforms_sk(sample_sk)
        sample_ph = self.transforms_ph(sample_ph)
        
        
        
        return sample, target

In [14]:
SiameseDataset(base_dir_ph='data/Flickr15K/', base_dir_sk='data/Sketch_EITZ/').

<__main__.SiameseDataset at 0x7f8ce4041eb0>

Contrastive Model        |  Triplet Model
:-------------------------:|:-------------------------:
![alt text](../informe/contrastive_arch.png "Title") | ![alt text](../informe/triplet_arch.png "Title")

In [None]:
from torch.nn import functional as F
from torch import nn

class Normalize(nn.Module):
    def __init__(self, p: float = 2, dim: int = 1, eps: float = 1e-12):
        self.p = p
        self.dim = dim
        self.eps = eps
        
    def forward(self, input: torch.Tensor):
        return F.normalize(input, self.p, self.dim, self.eps)


class SimpleLearner(pl.LightningModule):
    def __init__(self, models: Type[Union[nn.Module, dict]], n_classes: int, mode: str):
        super().__init__()
        
        # shared weights
        shared_mlp = nn.Sequential(
            nn.Linear(model.fc.in_features, 512), # fc1
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512) # fc2
        )
        
        # shared weights for classification
        classifier_mlp = nn.Sequential(
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, n_classes)
        )
        
        # normalization before contrastive or triplet
        siamese_out = nn.Sequential(
            nn.Normalize()
        )

        ce_criterion = nn.CrossEntropyLoss()
        
        if mode == 'contrastive':
            siamese_criterion = ContrastiveLoss()

        elif mode == 'triplet':
            siamese_criterion = nn.TripletMarginLoss(0.5)
            
        self._losses = (ce_criterion, siamese_criterion)

    def forward(self, x):
        if mode == 'contrastive':
            # anchor is sketch, positive is photo
            x_an, x_pos = x
            emb_an = models['sketch'](x_an)
            emb_pos = models['photo'](x_pos)
        
        elif mode == 'triplet':
            # anchor is sketch, positive and negative is photo
            x_anc, x_pos, x_neg = x
            emb_anc = shared_mlp(models['sketch'](x_anc))
            emb_pos = shared_mlp(models['photo'](x_pos))
            emb_neg = shared_mlp(models['photo'](x_pos))
            
            n_emb_anc = torch.unsqueeze(siamese_out(emb_anc), 1)
            n_emb_pos = torch.unsqueeze(siamese_out(emb_pos), 1)
            n_emb_neg = torch.unsqueeze(siamese_out(emb_neg), 1)
            
                        
            y_anc = torch.unsqueeze(classifier_mlp(emb_anc), 1)
            y_pos = torch.unsqueeze(classifier_mlp(emb_pos), 1)
            y_neg = torch.unsqueeze(classifier_mlp(emb_neg), 1)
            
            
            embeddings = torch.cat((n_emb_anc, n_emb_pos, n_emb_neg), 1)
            y_hats = torch.cat((y_anc, y_pos, y_neg), 1)
            return embeddings, y_hats

    def training_step(self, batch,  batch_idx):
        samples, targets = batch[0], batch[1]
        
        embeddings, y_hats = self.forward(samples)
        out = [embeddings, y_hats]
        
        losses = []
        losses_names = []
        for l, pred in zip(self._losses, out):
            loss = l.loss(pred, samples)
#             effective_loss = l.weight * loss

            losses_names.append(l.name)
            losses.append(loss)

        losses_dict = {n: l for n, l in zip(losses_names, losses)}
        if len(losses_names) > 1:
            self.log(losses_dict, prog_bar=True, logger=False)
        
        losses_dict['loss'] = sum(losses)
        return losses_dict

    def validation_step(self, batch, batch_ixd):
        img_lr = batch['lr']
        img_hr = batch['hr']
        img_sr = self.forward(img_lr)
        
        val_loss = self.criterion(img_sr, img_hr)
        
        self.log('val_loss', val_loss)
        return {'val_loss':val_loss }

    def configure_optimizers(self):
        # you define one or multiple optimizer, learning rate schedulers, etc.
        # links provided on the train.py about setting up 2 opts for GAN
        return [optim.Adam(self.parameters(), lr=PARAMS['learning_rate'])]
