In [1]:
from functools import partial
from glob import glob
import os
import time
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm, tqdm_notebook

from dstorch.utils import random_weight_init

%matplotlib inline

# GPUs exposed to this process. CUDA expects CUDA_VISIBLE_DEVICES to be a
# comma-separated list of integers, so whitespace typed between the ids is
# stripped before exporting. This must run before CUDA is initialised.
# USE_GPUS = '1'
USE_GPUS = ','.join(gpu_id.strip() for gpu_id in '0, 1, 2, 3'.split(','))
os.environ['CUDA_VISIBLE_DEVICES'] = USE_GPUS

In [2]:
# Dataset roots; each is expected to hold `dogs/` and `cats/` folders of .jpg files.
train_path, val_path = './data/train/', './data/valid/'

In [3]:
class DogsCatsDataset(Dataset):
    """Context-prediction dataset built from a folder of dog and cat images.

    Each item picks a dog or a cat image, crops its largest centred square,
    resizes that to ``resize`` x ``resize`` and cuts two square patches out of
    it: a centre patch and a "context" patch displaced in one of 8 directions.
    The training target ``idx`` is the direction index (0-7) of the context
    patch relative to the centre patch.

    Args:
        data_path: Directory containing ``dogs/*.jpg`` and ``cats/*.jpg``.
        category: ``'dog'`` to always use dog images (label 1); any other
            non-None value uses cat images (label 0); ``None`` picks one of the
            two at random for every item.
        resize: Side length the square crop is resized to before patching.
        input_size: Side length of the patches handed to the model.
        crop_ratio: Patch side as a fraction of the resized image side.
        degenerate: With probability 0.1, down- then up-sample both patches.
        color_drop: Replace two randomly chosen colour channels of each patch
            with near-constant noise around that channel's mean.
        swap_labels_flag: With probability 0.1, swap centre and context patches
            and map the direction label to the opposite direction.
        transform: Optional callable applied to each normalised patch tensor.
        context_offset: Mean distance in pixels (in the resized image) between
            the centres of the centre and context patches.
        context_jitter: Standard deviation in pixels of the Gaussian jitter
            added to ``context_offset``.

    ``__getitem__`` returns ``((img_center, img_context), idx, label)``.
    """

    # Per-channel normalisation statistics of the training data (RGB, [0, 1] scale).
    CENTER_MEAN = (0.4918, 0.4464, 0.4091)
    CONTEXT_MEAN = (0.4868, 0.4494, 0.4100)
    CENTER_STD = (0.2016, 0.1920, 0.1911)
    CONTEXT_STD = (0.1939, 0.1853, 0.1876)

    # Direction index -> (row step, col step) of the context patch relative to
    # the centre patch. Opposite directions differ by 4 (used by swap_labels).
    CANDIDATE_DIRECTIONS = {
        0: (-1, -1),
        1: (-1, 0),
        2: (-1, 1),
        3: (0, 1),
        4: (1, 1),
        5: (1, 0),
        6: (1, -1),
        7: (0, -1),
    }

    def __init__(self, data_path, category=None, resize=500, input_size=224,
                 crop_ratio=0.25, degenerate=True, color_drop=True, swap_labels_flag=True,
                 transform=None, context_offset=150, context_jitter=7):
        self.data_path = data_path
        self.category = category
        self.resize = resize
        self.input_size = input_size
        self.crop_ratio = crop_ratio
        self.degenerate = degenerate
        self.color_drop = color_drop
        self.swap_labels_flag = swap_labels_flag
        self.context_offset = context_offset
        self.context_jitter = context_jitter

        self.transform = transform

        # glob() returns files in filesystem order; sort so that the
        # index -> (dog, cat) file pairing is reproducible across runs.
        self.dog_filename_list = sorted(glob(os.path.join(data_path, 'dogs', '*.jpg')))
        self.cat_filename_list = sorted(glob(os.path.join(data_path, 'cats', '*.jpg')))

    def __getitem__(self, index):
        label = self._choose_label()
        # Only decode the image that is actually used for this item.
        if label == 1:
            filename = self.dog_filename_list[index]
        else:
            filename = self.cat_filename_list[index]
        img = self._load_rgb(filename)

        # Crop the largest centred square, then bring it to a fixed size.
        img = self.crop_center(img, 1)
        img = cv2.resize(img, (self.resize, self.resize))

        img_center = self.crop_center(img, self.crop_ratio)
        img_context, idx = self.get_context_img(img)

        img_center = cv2.resize(img_center, (self.input_size, self.input_size))
        img_context = cv2.resize(img_context, (self.input_size, self.input_size))

        if self.degenerate and np.random.choice([0, 1], p=[0.9, 0.1]):
            img_center, img_context = self.degenerate_img(img_center), self.degenerate_img(img_context)

        if self.color_drop:
            img_center, img_context = self.color_drop_img(img_center), self.color_drop_img(img_context)

        if self.swap_labels_flag and np.random.choice([0, 1], p=[0.9, 0.1]):
            img_center, img_context, idx = self.swap_labels(img_center, img_context, idx)

        img_center = self._to_normalized_tensor(img_center, self.CENTER_MEAN, self.CENTER_STD)
        img_context = self._to_normalized_tensor(img_context, self.CONTEXT_MEAN, self.CONTEXT_STD)

        if self.transform is not None:
            img_center = self.transform(img_center)
            img_context = self.transform(img_context)

        return (img_center, img_context), idx, label

    def __len__(self):
        # With category=None an index may be used for either list, so it must be
        # valid in both; otherwise only the selected list matters.
        if self.category is None:
            return min(len(self.dog_filename_list), len(self.cat_filename_list))
        if self.category == 'dog':
            return len(self.dog_filename_list)
        return len(self.cat_filename_list)

    def _choose_label(self):
        """Return 1 (dog) or 0 (cat) according to ``self.category``; random when it is None."""
        if self.category is None:
            return int(np.random.choice([0, 1]))
        return 1 if self.category == 'dog' else 0

    @staticmethod
    def _load_rgb(filename):
        """Read an image with OpenCV and convert it from BGR to RGB.

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError("Could not read image file: {}".format(filename))
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _to_normalized_tensor(img, mean, std):
        """Turn an HxWx3 image with values in [0, 255] into a normalised 3xHxW float tensor."""
        tensor = torch.from_numpy(np.rollaxis(img / 255, 2).copy()).float()
        tensor -= torch.tensor(mean).view(3, 1, 1)
        tensor /= torch.tensor(std).view(3, 1, 1)
        return tensor

    @staticmethod
    def swap_labels(img_center, img_context, idx):
        """Swap the two patches and return the opposite direction label, (idx - 4) mod 8."""
        return img_context, img_center, (idx - 4) % 8

    @staticmethod
    def degenerate_img(img, low=0.1):
        """Blur an image by resizing it to a random fraction in [low, 1) of its size and back.

        Assumes a square image: only ``img.shape[0]`` is used as the side length.
        """
        input_size = img.shape[0]

        p = np.random.uniform(low=low)
        # Never let the intermediate image collapse to 0 pixels.
        degen_size = max(1, int(p * input_size))

        img = cv2.resize(img, (degen_size, degen_size))
        return cv2.resize(img, (input_size, input_size))

    @staticmethod
    def color_drop_img(img, std=0.01):
        """Keep one random colour channel and replace the other two with noise.

        Each dropped channel is filled with Gaussian noise centred on that
        channel's mean, with standard deviation ``std`` times the std of the
        kept channel. Returns a copy; the input is not modified.
        """
        img = img.copy()

        drop_set = set(np.random.choice((0, 1, 2), size=2, replace=False).tolist())
        idx_remained = ({0, 1, 2} - drop_set).pop()

        color_std = img[:, :, idx_remained].std() * std

        for color_idx in sorted(drop_set):
            mean = img[:, :, color_idx].mean()
            channel_data = np.random.normal(mean, color_std, size=img.shape[:2])
            if np.issubdtype(img.dtype, np.integer):
                # Clip before writing into an integer image; out-of-range samples
                # would otherwise wrap around (e.g. -1 -> 255 for uint8).
                info = np.iinfo(img.dtype)
                channel_data = np.clip(channel_data, info.min, info.max)
            img[:, :, color_idx] = channel_data

        return img

    @staticmethod
    def crop_center(img, crop_ratio):
        """Return the centred square crop whose side is ``crop_ratio`` x the shorter image side."""
        rows, cols, _ = img.shape
        side = int(min(rows, cols) * crop_ratio)

        row_start = max(rows // 2 - side // 2, 0)
        col_start = max(cols // 2 - side // 2, 0)

        return img[row_start:row_start + side, col_start:col_start + side, :]

    def get_context_img(self, img):
        """Cut the context patch and return ``(patch, direction_idx)``.

        The patch has the same size as the centre patch. Its centre is shifted
        from the image centre by ``context_offset`` (+ Gaussian jitter) pixels
        along a randomly chosen direction from ``CANDIDATE_DIRECTIONS``. The
        window is clamped to the image so the patch is always full-sized.
        """
        rows, cols, _ = img.shape
        patch_size = int(min(rows, cols) * self.crop_ratio)

        idx = int(np.random.choice(range(len(self.CANDIDATE_DIRECTIONS))))
        offset = self.context_offset + int(np.random.normal(0, self.context_jitter))
        row_step, col_step = self.CANDIDATE_DIRECTIONS[idx]

        row_start = rows // 2 + offset * row_step - patch_size // 2
        col_start = cols // 2 + offset * col_step - patch_size // 2

        # Negative starts would wrap around in NumPy slicing and starts past the
        # edge would truncate the crop; keep the window fully inside the image.
        row_start = min(max(row_start, 0), rows - patch_size)
        col_start = min(max(col_start, 0), cols - patch_size)

        return img[row_start:row_start + patch_size, col_start:col_start + patch_size, :], idx

In [4]:
# data_path_list = glob(train_path + 'dogs/*.jpg') + glob(train_path + 'cats/*.jpg')

# bgr2rgb = partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB)

# img_list = list(map(bgr2rgb, [cv2.imread(path) for path in tqdm_notebook(data_path_list)]))

In [5]:
# row_list, col_list = [], []

# for img in img_list:
#     num_rows, num_cols, num_channels = img.shape
    
#     row_list.append(num_rows)
#     col_list.append(num_cols)
        
# print("Rows: Max: {}, Min: {}, Median: {}".format(np.max(row_list), np.min(row_list), 
#                                                   np.median(row_list)))
# print("Cols: Max: {}, Min: {}, Median: {}".format(np.max(col_list), np.min(col_list), 
#                                                   np.median(col_list)))


In [6]:
# Batch size and loader worker count both scale with the number of visible GPUs.
num_gpus = len(USE_GPUS.split(','))
batch_size = 8 * num_gpus


def seed_numpy_worker(worker_id):
    """Give each DataLoader worker its own NumPy RNG stream.

    DogsCatsDataset draws all of its randomness from ``np.random``. Forked
    workers inherit the parent's NumPy state, and depending on the PyTorch
    version that state is not reseeded, so workers (and epochs) could replay
    identical "random" crops and labels. ``torch.initial_seed()`` inside a
    worker is the per-worker torch seed, so it is used to derive a NumPy seed.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


kwargs = {'num_workers': num_gpus * 4, 'pin_memory': False,
          'worker_init_fn': seed_numpy_worker} if torch.cuda.is_available() else {}

training_set = DogsCatsDataset(train_path)
# Sets without degenerate / colour-drop / swap augmentation, used for accuracy
# reporting (the context direction and dog/cat choice are still random).
training_no_aug_set = DogsCatsDataset(train_path, degenerate=False,
                                      color_drop=False, swap_labels_flag=False)
validation_set = DogsCatsDataset(val_path, degenerate=False,
                                 color_drop=False, swap_labels_flag=False)

train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, **kwargs)
train_no_aug_loader = DataLoader(training_no_aug_set, batch_size=batch_size, shuffle=False, **kwargs)
valid_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, **kwargs)

In [7]:
# def calc_data_stats(data_loader, dtype='image', repeat=20):
#     batch_size = data_loader.batch_size

#     sample_size = 0
    
#     center_mean_list = []
#     context_mean_list = []
    
#     for _ in tqdm_notebook(range(repeat)):
#         for (img_center, img_context), idx, label in tqdm_notebook(data_loader):
#             if torch.cuda.is_available():
#                 img_center, img_context = img_center.cuda(non_blocking=True), img_context.cuda(non_blocking=True)
                
#             img_center = img_center.view(img_center.size(0), img_center.size(1), -1)
#             img_context = img_context.view(img_context.size(0), img_context.size(1), -1)
            
#             center_mean_list.append(img_center.mean(2).mean(0))
#             context_mean_list.append(img_context.mean(2).mean(0))

#             sample_size += img_center.shape[0] 
        
#     center_mean = torch.stack(center_mean_list).mean(dim=0)
#     context_mean = torch.stack(context_mean_list).mean(dim=0)
    
#     center_se = torch.stack(center_mean_list).std(dim=0)
#     context_se = torch.stack(context_mean_list).std(dim=0)
    
#     center_std = center_se * torch.sqrt(torch.tensor(float(batch_size)))
#     context_std = context_se * torch.sqrt(torch.tensor(float(batch_size)))
    

#     print("=============================")
#     print("Dataset size: {}".format(sample_size))
#     print("Batch size: {}".format(batch_size))
    
#     print("Center Mean: {}".format(center_mean))
#     print("Context Mean: {}".format(context_mean))
    
#     print("Center STD: {}".format(center_std))
#     print("ContextSTD: {}".format(context_std))

In [8]:
# calc_data_stats(train_loader)

In [9]:
class ContextNet(nn.Module):
    """Siamese ResNet-152 that classifies the relative position of a context patch.

    Both patches go through the same ResNet-152 trunk (final fc layer removed).
    The two feature vectors are fused and passed to an MLP that outputs
    log-probabilities over the 8 candidate directions.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet152``.
        fusion: ``'concat'`` (default) concatenates the two feature vectors,
            which keeps track of which input is the centre patch. ``'sum'``
            adds them; the output is then identical when the two inputs are
            swapped, which conflicts with label-swapping augmentation. ``'sum'``
            is kept for loading checkpoints trained with that layout.

    Raises:
        ValueError: if ``fusion`` is not ``'concat'`` or ``'sum'``.
    """

    def __init__(self, pretrained=True, fusion='concat'):
        super().__init__()
        if fusion not in ('concat', 'sum'):
            raise ValueError("fusion must be 'concat' or 'sum', got {!r}".format(fusion))
        self.fusion = fusion

        backbone = models.resnet152(pretrained=pretrained)
        feature_dim = backbone.fc.in_features
        # Keep everything up to (and including) global pooling; drop the fc head.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        classifier_in = 2 * feature_dim if fusion == 'concat' else feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_in, 2048),
            nn.Dropout(0.5),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 8),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Score a ``(img_center, img_compare)`` pair of image batches.

        Returns an ``(N, 8)`` tensor of log-probabilities over directions.
        """
        img_center, img_compare = x

        feat_center = self.resnet(img_center).flatten(1)
        feat_compare = self.resnet(img_compare).flatten(1)

        if self.fusion == 'concat':
            fused = torch.cat([feat_center, feat_compare], dim=1)
        else:
            fused = feat_center + feat_compare

        return self.classifier(fused)

In [10]:
# Build the network without ImageNet weights, re-initialise it with the
# project's random_weight_init, then wrap it for multi-GPU data parallelism.
base_model = ContextNet(pretrained=False)
random_weight_init(base_model)
model = nn.DataParallel(base_model)

In [None]:
loss_list, loss_test_list = [], []

epochs = 5000
lr = 0.001

# Put the model on its device once, before building the optimizer, instead of
# re-moving it at the start of every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
if device.type == 'cuda' and not next(model.parameters()).is_cuda:
    raise TypeError("model.cuda() is not working!")

loss_func = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)


def to_device(img_center, img_context, idx, device):
    """Move one batch (both patch tensors and the direction labels) to `device`."""
    return (img_center.to(device, non_blocking=True),
            img_context.to(device, non_blocking=True),
            idx.to(device, non_blocking=True))


def evaluate(model, loader, loss_func, device):
    """Return (per-sample mean loss, accuracy in %) of `model` over `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    total_loss, correct, num_rows = 0.0, 0, 0
    with torch.no_grad():
        for (img_center, img_context), idx, _ in loader:
            img_center, img_context, idx = to_device(img_center, img_context, idx, device)
            output = model((img_center, img_context))

            n = img_center.size(0)
            # NLLLoss averages within a batch; re-weight by batch size so a short
            # final batch does not skew the epoch mean.
            total_loss += loss_func(output, idx).item() * n
            correct += output.argmax(dim=1).eq(idx).sum().item()
            num_rows += n

    if num_rows == 0:
        raise ValueError("evaluation loader yielded no samples")
    return total_loss / num_rows, 100. * correct / num_rows


start_time = time.time()

for epoch in range(epochs):
    epoch_start_time = time.time()
    model.train()

    for batch_idx, ((img_center, img_context), idx, _) in enumerate(train_loader):
        img_center, img_context, idx = to_device(img_center, img_context, idx, device)

        optimizer.zero_grad()
        output = model((img_center, img_context))
        loss = loss_func(output, idx)
        loss.backward()
        optimizer.step()

        if batch_idx % 20 == 0:
            print("Epoch: {0}, It: {1}, Loss: {2:.4f}".format(epoch, batch_idx, loss.item()))

    loss_mean, train_acc = evaluate(model, train_no_aug_loader, loss_func, device)
    loss_test_mean, test_acc = evaluate(model, valid_loader, loss_func, device)

    # Save model. NOTE(review): this writes one full ResNet-152 checkpoint per
    # epoch; over many epochs disk usage grows quickly.
    model_path = 'models/model_{}.pth'.format(epoch)
    torch.save(model.state_dict(), model_path)
    print("Model saved at {}".format(model_path))

    print("=====================================================================")
    print("Training loss: {:.4f} Training acc: {:.4f}%".format(loss_mean, train_acc))
    print("Test loss: {0:.4f}, Test acc: {1:.4f}%".format(loss_test_mean, test_acc))
    print("Epoch took {} seconds.".format(time.time() - epoch_start_time))
    print("=====================================================================")

    loss_list.append(loss_mean)
    loss_test_list.append(loss_test_mean)

    # Learning curve on a log scale of the loss.
    fig, ax = plt.subplots(figsize=(14, 8))
    epochs_seen = range(len(loss_list))
    ax.plot(epochs_seen, np.log(loss_list), label="Training")
    ax.plot(epochs_seen, np.log(loss_test_list), label="Test")
    ax.set(xlabel="Epoch", ylabel="log(Loss)", title="Learning Curve")
    ax.legend()
    plt.show()
    # Release the figure; one is created every epoch.
    plt.close(fig)

print("Training took {} seconds\n".format(round(time.time() - start_time, 2)))

Epoch: 0, It: 0, Loss: 5.9142
Epoch: 0, It: 20, Loss: 2.5943
