In [65]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import albumentations as A

import cv2
import os
from glob import glob
from tqdm import tqdm
import numpy as np

from itertools import permutations
from scipy.spatial import distance

from easydict  import EasyDict 

## Model - AlexNet

In [66]:
class AlexNet(nn.Module) :
    """AlexNet-style CNN: five convolutional stages followed by a three-layer MLP head.

    Parameters
    ----------
    first_in : int
        Number of input channels of the first convolution.
    first_stride : int
        Stride of the first convolution.
    dropout_rate : float
        Dropout probability used after the first two fully connected layers.
    in_feature : int
        Size of the flattened output of the convolutional stages.
    out_feature : int
        Width of the two hidden fully connected layers.
    num_classes : int
        Size of the output layer.

    Weights of every layer that has a ``weight`` are drawn from N(0, 0.01).
    Biases are set to 1 for conv stages 2, 4, 5 and the fully connected
    layers, and to 0 for conv stages 1 and 3.
    """

    def __init__(self,
                 first_in=3,
                 first_stride=4,
                 dropout_rate=0.5,
                 in_feature=256*6*6,
                 out_feature=4608,
                 num_classes=1000
                ) :
        super().__init__()

        def norm_then_pool() :
            # Fresh local-response-norm + 3x3/stride-2 max-pool pair,
            # appended after the first two convolutions.
            return (
                nn.LocalResponseNorm(size=5, k=2, alpha=0.0001, beta=0.75),
                nn.MaxPool2d(kernel_size=3, stride=2),
            )

        # Stage 1: 11x11 convolution with a configurable stride.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(first_in, 96, kernel_size=11, stride=first_stride, padding=2),
            nn.ReLU(),
            *norm_then_pool(),
        )

        # Stage 2: 5x5 convolution, spatial size preserved by padding=2.
        self.conv_2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            *norm_then_pool(),
        )

        # Stages 3-5: 3x3 convolutions; only the last one is followed by pooling.
        self.conv_3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.conv_4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.conv_5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # The five stages chained in order; the index inside this container
        # is what `initialization` uses to pick the bias constant.
        self.conv_layers = nn.Sequential(
            self.conv_1,
            self.conv_2,
            self.conv_3,
            self.conv_4,
            self.conv_5,
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(in_feature, out_feature),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(out_feature, out_feature),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(out_feature, num_classes),
        )

        for stage_idx, stage in enumerate(self.conv_layers) :
            self.initialization(stage_idx, stage)
        self.initialization('fc', self.fc_layers)

    def initialization(self, idx, layers) :
        """Re-initialise every weighted layer found in ``layers``.

        ``idx`` is either the position of a conv stage inside
        ``conv_layers`` (0-4) or the string ``'fc'``. It selects the bias
        constant; for any other value only the weights are re-drawn.
        """
        ones_bias_groups = ('fc', 1, 3, 4)
        zero_bias_groups = (0, 2)

        for module in layers :
            # Activation / pooling / normalisation layers carry no weight.
            if 'weight' not in dir(module) :
                continue

            nn.init.normal_(module.weight, mean=0, std=0.01)

            if idx in ones_bias_groups :
                bias_value = 1
            elif idx in zero_bias_groups :
                bias_value = 0
            else :
                continue
            nn.init.constant_(module.bias.data, bias_value)

    def forward(self, x) :
        """Run the conv stages, flatten everything but the batch dim, apply the MLP head."""
        feature_maps = self.conv_layers(x)
        flat = feature_maps.flatten(1)
        return self.fc_layers(flat)
    


## Model - CFN

In [67]:
class CFN(nn.Module) :
    """Context-free-network style classifier built on the AlexNet conv stack.

    The input is expected to carry all puzzle tiles stacked along the
    channel axis (``in_channel`` channels in total). It is passed through
    the AlexNet convolutional stages and its first fully connected layer
    (fc6), then through a small classifier head that predicts the index of
    the applied permutation.

    Parameters
    ----------
    in_channel : int
        Number of input channels (default 27).
    strd : int
        Stride of the first convolution.
    in_feature : int
        Size of the flattened conv output fed to fc6.
    out_feature : int
        Output width of fc6.
    num_classes : int
        Number of permutation classes to predict.

    Note: the whole ``AlexNet`` instance stays registered as ``self.alexnet``,
    so its remaining fully connected layers are part of ``parameters()`` even
    though ``forward`` never uses them.
    """
    def __init__(self,
                 in_channel=27,
                 strd=2,
                 in_feature=256*3*3,
                 out_feature=4608,
                 num_classes=69) :
        super(CFN, self).__init__()
        self.alexnet = AlexNet(first_in=in_channel,
                               first_stride=strd,
                               in_feature=in_feature,
                               out_feature=out_feature
                              )
        # Reuse the layers that AlexNet already initialised with the scheme
        # described in the AlexNet paper.
        self.conv_layers = self.alexnet.conv_layers
        self.fc6 = self.alexnet.fc_layers[0]

        # fc7, fc8 and the output layer.
        self.classifier = nn.Sequential(
            nn.Linear(out_feature, 4096),
            nn.ReLU(),
            nn.Linear(4096, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes)
        )

        self.alexnet.initialization('fc', self.classifier)

    def forward(self, x) :
        """Return the permutation-class logits for a batch ``x``."""
        x = self.conv_layers(x)
        x = x.flatten(1)
        x = self.fc6(x)
        # fc6 is a bare Linear layer; without an activation here it would
        # collapse with the first Linear of `classifier` into one linear map.
        x = nn.functional.relu(x)
        x = self.classifier(x)
        return x

## Puzzle Pipeline

In [68]:
class jigsaw_pipeline() :
    """Turn one H x W x C image into a stack of nine jigsaw tiles.

    Steps performed by ``__call__``:
      1. resize so the shorter side is 256 px, keeping the aspect ratio;
      2. take a random 225 x 225 crop;
      3. split it into a 3 x 3 grid of 75 x 75 tiles (row-major order);
      4. take a random 64 x 64 crop of every tile (optionally followed by
         albumentations' default ``A.Normalize()``);
      5. concatenate the nine tiles along the channel axis.

    Parameters
    ----------
    normalize : bool
        Apply ``A.Normalize()`` to every tile after cropping.
    color_jitter : bool
        Stored on the instance; not used by this class.
    """
    def __init__(self, normalize=True, color_jitter=False) :
        self.color_jitter = color_jitter

        tile_steps = [A.RandomCrop(64, 64)]
        if normalize :
            tile_steps.append(A.Normalize())
        self.transforms = A.Compose(tile_steps)

        # Built once instead of on every call.
        self.random_crop = A.RandomCrop(225, 225)

    def __call__(self, img) :
        """Return an array of shape (64, 64, 9 * C) holding the nine tiles."""
        h, w, c = img.shape

        # step 1. resize the shorter side to 256 while preserving the aspect ratio.
        # cv2.resize expects dsize as (width, height), not (height, width).
        if h < w :
            dsize = (int(w / h * 256), 256)
        else :
            dsize = (256, int(h / w * 256))
        img = cv2.resize(img, dsize)

        # step 2. random crop of size 225 x 225
        img = self.random_crop(image = img)['image']

        # step 3 and 4. split into a 3 x 3 grid of 75 x 75 tiles and random crop 64 x 64
        tiles = []
        for i in range(3) :
            for j in range(3) :
                crop_img = img[i*75 : (i * 75) + 75, j*75 : (j * 75) + 75, :]
                tiles.append(self.transforms(image = crop_img)['image'])

        # step 5. stack the tiles along the channel axis in row-major order
        return np.concatenate(tiles, axis=2)

## Generate Permutation Sets

In [69]:
def generate_permutation_sets(number_permu_set=100, max_hamming=True, save=True, num_tiles=9) :
    """Greedily pick a set of tile permutations by Hamming distance.

    The first permutation is drawn at random (global ``np.random`` state).
    Each following one is the remaining candidate whose mean Hamming
    distance to all permutations chosen so far is largest
    (``max_hamming=True``) or smallest (``max_hamming=False``).

    Parameters
    ----------
    number_permu_set : int
        Number of permutations to select.
    max_hamming : bool
        Select by maximal (True) or minimal (False) mean Hamming distance.
    save : bool
        If True, write the result to ``permutation_{N}_sets.npy`` in the
        current working directory.
    num_tiles : int
        Number of tiles being permuted (9 for a 3 x 3 puzzle).

    Returns
    -------
    numpy.ndarray
        Integer array of shape (number_permu_set, num_tiles); every row is
        a permutation of ``1 .. num_tiles``.

    Raises
    ------
    ValueError
        If ``number_permu_set`` is smaller than 1 or larger than the number
        of available permutations.
    """
    N = number_permu_set  # number of permutations to select

    items = range(1, num_tiles + 1)
    candidates = np.array(list(permutations(items, num_tiles)))

    if not 1 <= N <= len(candidates) :
        raise ValueError(
            f"number_permu_set must be between 1 and {len(candidates)}, got {N}"
        )

    # Index of the next permutation to move from `candidates` to `chosen`.
    j = np.random.choice(len(candidates), 1, replace=False)[0]

    chosen = []
    for _ in tqdm(range(N)) :
        chosen.append(candidates[j])
        candidates = np.delete(candidates, j, axis=0)

        # Nothing left to select: skip the distance computation.
        if len(chosen) == N :
            continue

        # Mean Hamming distance of every remaining candidate to the chosen set.
        d = distance.cdist(np.array(chosen), candidates, metric='hamming').mean(axis=0)
        j = np.argmax(d) if max_hamming else np.argmin(d)

    p = np.array(chosen)

    if save :
        np.save(f'permutation_{N}_sets.npy', p)

    return p

## Dataset

In [70]:
class JIGSAW_DATASET(Dataset) :
    """Dataset producing (shuffled tile stack, permutation index) pairs.

    Parameters
    ----------
    img_list : sequence of str
        Paths of the images to load with ``cv2.imread``.
    permu_set : numpy.ndarray
        2-D array whose rows are 1-based tile permutations.
    normalize : bool
        Forwarded to ``jigsaw_pipeline``.
    """
    def __init__(self, img_list, permu_set, normalize=True) :
        self.img_list = img_list
        self.permu_set = permu_set
        self.pipeline = jigsaw_pipeline(normalize=normalize)

    def __len__(self) :
        return len(self.img_list)

    def shuffle_tiles(self, tile, permu):
        """Reorder the 3-channel tiles stacked along axis 2 of ``tile``.

        Output slot ``k`` (0-based) receives the tile numbered ``permu[k]``
        (1-based) of the input.
        """
        shuffled_tile = np.zeros_like(tile)
        for slot, num in enumerate(permu) :
            shuffled_tile[:, :, slot*3 : (slot+1)*3] = tile[:, :, (num-1) * 3 : num * 3]
        return shuffled_tile

    def __getitem__(self, idx) :
        """Load image ``idx``, tile it, apply a random permutation from ``permu_set``.

        Returns a float tensor in channel-first layout and the index of the
        applied permutation as an int64 tensor.

        Raises
        ------
        FileNotFoundError
            If ``cv2.imread`` cannot read the file (it returns ``None``
            instead of raising).
        """
        path = self.img_list[idx]
        img = cv2.imread(path)
        if img is None :
            raise FileNotFoundError(f"cv2.imread could not read an image from {path!r}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.pipeline(img)

        # The label is the row index of the permutation that was applied.
        permu_idx = np.random.randint(self.permu_set.shape[0])
        permu = self.permu_set[permu_idx]

        img = self.shuffle_tiles(img, permu)
        img = img.transpose(2,0,1)  # HWC -> CHW

        return torch.FloatTensor(img), torch.tensor(permu_idx, dtype=torch.int64)

## Training

In [75]:
def _accuracy(pred, label) :
    idx_pred = torch.argmax(pred, dim=1)
    acc = idx_pred == label
    return acc.sum() / acc.shape[0]
            

In [76]:
def training(args) :
    """Train a CFN on the jigsaw permutation-classification task.

    Parameters
    ----------
    args : object with attributes
        ``num_classes``, ``device``, ``learning_rate``, ``permutation_path``
        (``.npy`` file with the permutation set), ``img_path`` (glob
        pattern), ``batch_size`` and ``epochs``.

    Returns
    -------
    CFN
        The trained model, so it can be evaluated or saved by the caller.

    Raises
    ------
    FileNotFoundError
        If ``args.img_path`` matches no files.
    ValueError
        If the permutation set has more rows than ``args.num_classes``;
        labels would then fall outside the classifier's output range.
    """
    # in_feature=1024 is the flattened conv output (256 x 2 x 2) for
    # 64 x 64 inputs with the first-layer stride of 2 used by CFN.
    model = CFN(in_feature=1024, num_classes=args.num_classes).to(args.device)
    optimizer = torch.optim.SGD(params=model.parameters(), lr=args.learning_rate)
    criterion  = torch.nn.CrossEntropyLoss()

    ps = np.load(args.permutation_path)
    if ps.shape[0] > args.num_classes :
        raise ValueError(
            f"permutation set has {ps.shape[0]} rows but num_classes is {args.num_classes}"
        )

    # Sorted so the dataset order does not depend on the file system.
    img_list = sorted(glob(args.img_path))
    if not img_list :
        raise FileNotFoundError(f"no files match {args.img_path!r}")

    jigsaw_dataset = JIGSAW_DATASET(img_list, ps, normalize=False)
    dataloader = DataLoader(jigsaw_dataset, batch_size=args.batch_size, shuffle=True)

    for E in range(1, args.epochs+1) :
        model.train()
        avg_loss, avg_acc = 0, 0
        for img, label in dataloader:
            img = img.to(args.device)
            label = label.to(args.device)

            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()

            # Metrics only: detach so no autograd graph is carried to the CPU copy.
            acc = _accuracy(pred.detach().cpu(), label.cpu())
            avg_loss += loss.item()
            avg_acc += acc.item()

        # Both values are means of the per-batch values.
        print("Epochs :", E, f" || ACC : {avg_acc/len(dataloader):0.4f}  ||  Loss : {avg_loss/len(dataloader):0.4f}")

    return model

## Configurations

In [77]:
# Training configuration consumed by `training(args)`.
Options = {
    'permutation_path' : './permutation_30_sets.npy',
    'batch_size' : 64,
    'epochs' : 70,
    'learning_rate' : 0.01,
    # Fall back to the CPU when no CUDA device is available.
    'device' : "cuda:0" if torch.cuda.is_available() else "cpu",
    "img_path" : './data/*',
    "num_classes" : 30 # must equal the number of permutations in the permutation set
}
# EasyDict exposes the dictionary keys as attributes (args.batch_size, ...).
args = EasyDict(Options)

In [78]:
# Run training with the configuration in `args`; one progress line is printed per epoch.
training(args)

Epochs : 1  || ACC : 0.0384  ||  Loss : 3.3975
Epochs : 2  || ACC : 0.0156  ||  Loss : 3.4285
Epochs : 3  || ACC : 0.0192  ||  Loss : 3.4104
Epochs : 4  || ACC : 0.0348  ||  Loss : 3.4112
Epochs : 5  || ACC : 0.0365  ||  Loss : 3.4155
Epochs : 6  || ACC : 0.0104  ||  Loss : 3.4069
Epochs : 7  || ACC : 0.0244  ||  Loss : 3.3987
Epochs : 8  || ACC : 0.0104  ||  Loss : 3.4188
Epochs : 9  || ACC : 0.0365  ||  Loss : 3.3963
Epochs : 10  || ACC : 0.0367  ||  Loss : 3.4115
Epochs : 11  || ACC : 0.0244  ||  Loss : 3.4127
Epochs : 12  || ACC : 0.0156  ||  Loss : 3.4183
Epochs : 13  || ACC : 0.0156  ||  Loss : 3.4075
Epochs : 14  || ACC : 0.0260  ||  Loss : 3.4077
Epochs : 15  || ACC : 0.0296  ||  Loss : 3.4066
Epochs : 16  || ACC : 0.0260  ||  Loss : 3.3981
Epochs : 17  || ACC : 0.0384  ||  Loss : 3.4013
Epochs : 18  || ACC : 0.0365  ||  Loss : 3.4015
Epochs : 19  || ACC : 0.0348  ||  Loss : 3.4037
Epochs : 20  || ACC : 0.0419  ||  Loss : 3.4066
Epochs : 21  || ACC : 0.0296  ||  Loss : 3.4136
E

KeyboardInterrupt: 

In [None]:
# Scratch check: draw a single random integer from [0, 5).
np.random.randint(5)

In [22]:
# Scratch check of the arithmetic used in `_accuracy`:
# element-wise equality followed by (number of matches) / (number of elements).
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([2,2,3,4,5])

c = a == b  # boolean tensor, True where the two tensors agree
c.sum()/c.shape[0]

tensor(0.8000)

In [19]:
# Scratch check: sum of all elements of `a` (relies on `a` from the scratch cell above).
a.sum()

tensor(15)