In [1]:
from functools import partial
from glob import glob
import os
import time
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm, tqdm_notebook

# from dstorch.data import calc_data_stats

%matplotlib inline

# Comma-separated ids of the GPUs this kernel is allowed to use.
# Later cells split this string on ',' to derive the batch size and worker count.
# Single-GPU alternative:
# USE_GPUS = '1'
# NOTE(review): the ids are separated by ", " (with a space) — confirm the CUDA
# runtime parses the padded entries as intended.
USE_GPUS = '1, 2, 3'
# Exported through the environment; presumably this must happen before CUDA is
# first initialised in this process — TODO confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = USE_GPUS

In [2]:
# Root folders of the training and validation splits. DogsCatsDataset looks for
# 'dogs/*.jpg' and 'cats/*.jpg' underneath each of them.
train_path = './data/train/'
val_path = './data/valid/'

In [3]:
class DogsCatsDataset(Dataset):
    """Dogs-vs-cats images served as (center patch, context patch) pairs.

    Each item is built from one image: the largest centered square is cropped and
    resized to ``resize`` x ``resize``; a central patch and one of eight neighbouring
    "context" patches are cut out of it and both are resized to ``input_size``.

    Args:
        data_path: Folder containing ``dogs/`` and ``cats/`` sub-folders of ``*.jpg`` files.
        category: ``'dog'`` to serve only dogs, any other non-None value to serve only
            cats, ``None`` to pick dog or cat at random for every item.
        resize: Side length of the square working image.
        input_size: Side length of the two patches that are returned.
        crop_ratio: Patch side as a fraction of the working image side.
        degenerate: If True, both patches are downsampled to a random size and
            upsampled back to ``input_size``.
        color_drop: If True, two of the three color channels of each patch are
            replaced by low-variance Gaussian noise.
        transform: Optional callable applied to each patch tensor.

    Returns (per item):
        ``((img_center, img_context), idx, label)`` where the images are float
        tensors of shape (3, input_size, input_size), ``idx`` in [0, 8) is the
        position of the context patch and ``label`` is 1 for dog, 0 for cat.
    """

    # (row, col) direction of the context patch relative to the center, indexed by
    # the position label returned by get_context_img.
    NEIGHBOR_OFFSETS = (
        (-1, -1), (-1, 0), (-1, 1), (0, 1),
        (1, 1), (1, 0), (1, -1), (0, -1),
    )
    # Distance in pixels between the center patch and the context patch, and the
    # standard deviation of the random jitter added to it.
    CONTEXT_OFFSET = 150
    CONTEXT_JITTER_STD = 7

    def __init__(self, data_path, category=None, resize=500, input_size=224,
                 crop_ratio=0.25, degenerate=True, color_drop=True, transform=None):
        self.data_path = data_path
        self.category = category
        self.resize = resize
        self.input_size = input_size
        self.crop_ratio = crop_ratio
        self.degenerate = degenerate
        self.color_drop = color_drop

        self.transform = transform

        # os.path.join works with or without a trailing separator on data_path;
        # sorting makes the index -> file mapping reproducible across runs.
        self.dog_filename_list = sorted(glob(os.path.join(self.data_path, 'dogs', '*.jpg')))
        self.cat_filename_list = sorted(glob(os.path.join(self.data_path, 'cats', '*.jpg')))

    def __getitem__(self, index):
        filename, label = self._pick_sample(index)

        # Only the image that is actually used gets decoded.
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread reports unreadable files by returning None.
            raise IOError("Could not read image: {}".format(filename))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Crop largest square image
        img = self.crop_center(img, 1)
        # Resize image
        img = cv2.resize(img, (self.resize, self.resize))
        # Get center image
        img_center = self.crop_center(img, self.crop_ratio)
        # Get context image and idx
        img_context, idx = self.get_context_img(img)
        # Resize both patches to the network input size
        img_center = cv2.resize(img_center, (self.input_size, self.input_size))
        img_context = cv2.resize(img_context, (self.input_size, self.input_size))

        if self.degenerate:
            # At least 1 pixel: cv2.resize rejects a zero-sized target.
            degen_size = max(1, int(np.random.rand() * self.input_size))
            img_center = self._degrade(img_center, degen_size)
            img_context = self._degrade(img_context, degen_size)

        if self.color_drop:
            img_center = self._drop_colors(img_center)
            img_context = self._drop_colors(img_context)

        # Convert to torch tensor (HWC in [0, 255] -> CHW in [0, 1])
        img_center = self._to_tensor(img_center)
        img_context = self._to_tensor(img_context)

        if self.transform is not None:
            img_center = self.transform(img_center)
            img_context = self.transform(img_context)

        return (img_center, img_context), idx, label

    def __len__(self):
        if self.category is None:
            # Either list may be indexed, so stay within the shorter one.
            return min(len(self.dog_filename_list), len(self.cat_filename_list))
        if self.category == 'dog':
            return len(self.dog_filename_list)
        return len(self.cat_filename_list)

    def _pick_sample(self, index):
        """Return ``(filename, label)`` for ``index``; label is 1 for dog, 0 for cat."""
        if self.category is not None:
            is_dog = self.category == 'dog'
        else:
            is_dog = bool(np.random.choice([0, 1]))

        if is_dog:
            return self.dog_filename_list[index], 1
        return self.cat_filename_list[index], 0

    def _degrade(self, img, degen_size):
        """Downsample ``img`` to ``degen_size`` and upsample it back to ``input_size``."""
        small = cv2.resize(img, (degen_size, degen_size))
        return cv2.resize(small, (self.input_size, self.input_size))

    @staticmethod
    def _drop_colors(img):
        """Return a copy of ``img`` in which two random channels are replaced by noise.

        The noise is Gaussian, centered on the mean of the channel it replaces, with
        a standard deviation of 1% of the standard deviation of the channel that is
        kept. Values are rounded and clipped to [0, 255] before being cast back to
        the dtype of ``img`` so integer images cannot wrap around.
        """
        img = img.copy()
        keep = int(np.random.randint(3))
        noise_std = img[:, :, keep].std() * 0.01

        for color_idx in range(3):
            if color_idx == keep:
                continue
            channel_mean = img[:, :, color_idx].mean()
            noise = np.random.normal(channel_mean, noise_std, size=img.shape[:2])
            img[:, :, color_idx] = np.clip(np.rint(noise), 0, 255).astype(img.dtype)

        return img

    @staticmethod
    def _to_tensor(img):
        """Scale an HWC image by 1/255 and return it as a float CHW tensor."""
        chw = np.ascontiguousarray(np.rollaxis(img / 255, 2))
        return torch.from_numpy(chw).float()

    @staticmethod
    def crop_center(img, crop_ratio):
        """Return the centered square of ``img`` with side ``min(rows, cols) * crop_ratio``."""
        rows, cols, _ = img.shape

        side = int(min(rows, cols) * crop_ratio)

        row_start = max(rows // 2 - (side // 2), 0)
        col_start = max(cols // 2 - (side // 2), 0)

        return img[row_start:row_start + side, col_start:col_start + side, :]

    def get_context_img(self, img):
        """Cut a random neighbouring patch out of ``img``.

        Returns:
            ``(patch, idx)`` where ``idx`` indexes ``NEIGHBOR_OFFSETS`` and ``patch``
            has side ``min(rows, cols) * crop_ratio``.
        """
        rows, cols, _ = img.shape
        patch_size = int(min(rows, cols) * self.crop_ratio)

        idx = int(np.random.choice(len(self.NEIGHBOR_OFFSETS)))
        context_factor = self.CONTEXT_OFFSET + int(np.random.normal(0, self.CONTEXT_JITTER_STD))
        d_row, d_col = self.NEIGHBOR_OFFSETS[idx]

        row_start = rows // 2 + context_factor * d_row - patch_size // 2
        col_start = cols // 2 + context_factor * d_col - patch_size // 2

        # Keep the window inside the image: a negative start would be read as an
        # offset from the end by numpy and yield a wrong (possibly empty) patch.
        row_start = max(0, min(row_start, rows - patch_size))
        col_start = max(0, min(col_start, cols - patch_size))

        patch = img[row_start:row_start + patch_size, col_start:col_start + patch_size, :]

        return patch, idx

In [4]:
# data_path_list = glob(train_path + 'dogs/*.jpg') + glob(train_path + 'cats/*.jpg')

# bgr2rgb = partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB)

# img_list = list(map(bgr2rgb, [cv2.imread(path) for path in tqdm_notebook(data_path_list)]))

In [5]:
# row_list, col_list = [], []

# for img in img_list:
#     num_rows, num_cols, num_channels = img.shape
    
#     row_list.append(num_rows)
#     col_list.append(num_cols)
        
# print("Rows: Max: {}, Min: {}, Median: {}".format(np.max(row_list), np.min(row_list), 
#                                                   np.median(row_list)))
# print("Cols: Max: {}, Min: {}, Median: {}".format(np.max(col_list), np.min(col_list), 
#                                                   np.median(col_list)))


In [6]:
# One entry per GPU id listed in USE_GPUS.
num_gpus = len(USE_GPUS.split(','))

# 24 samples per GPU.
batch_size = 24 * num_gpus

# Worker processes are only configured when CUDA is available (4 per GPU);
# otherwise the DataLoader defaults are used.
if torch.cuda.is_available():
    kwargs = {'num_workers': num_gpus * 4, 'pin_memory': False}
else:
    kwargs = {}

# Per-channel normalisation is currently disabled; kept for reference.
# mean_tuple = (0.4901, 0.4546, 0.4159)
# std_tuple = (0.1171, 0.1086, 0.1122)

# transform_train = transforms.Compose([
# #                                       transforms.RandomHorizontalFlip(),
#                                       transforms.Normalize(mean_tuple, std_tuple)
#                                      ])

# transform_val = transforms.Compose([
#                                      transforms.Normalize(mean_tuple, std_tuple)
#                                     ])

# Both splits use the dataset defaults (random dog/cat choice, degradation and
# color dropping enabled, no transform).
training_set = DogsCatsDataset(train_path)
validation_set = DogsCatsDataset(val_path)

# Only the training split is shuffled.
train_loader = DataLoader(
    training_set,
    batch_size=batch_size,
    shuffle=True,
    **kwargs
)
valid_loader = DataLoader(
    validation_set,
    batch_size=batch_size,
    shuffle=False,
    **kwargs
)

In [7]:
# def calc_data_stats(data_loader, dtype='image', repeat=10):
#     batch_size = data_loader.batch_size

#     sample_size = 0
#     mean_list = []
    
#     for _ in tqdm_notebook(range(repeat)):
#         for data in data_loader:
#             x, y = data
#             x = x.view(x.size(0), x.size(1), -1)
#             mean_list.append(x.mean(2).mean(0))

#             sample_size += x.shape[0] 
        
#     mean = torch.stack(mean_list).mean(dim=0)
#     se = torch.stack(mean_list).std(dim=0)
#     std = se * torch.sqrt(torch.tensor(float(batch_size)))

#     print("=============================")
#     print("Dataset size: {}".format(sample_size))
#     print("Batch size: {}".format(batch_size))
#     print("Mean: {}".format(mean))
#     print("STD: {}".format(std))

#     return mean, std

In [8]:
# calc_data_stats(train_loader)

In [9]:
class ContextNet(nn.Module):
    """Predicts which of 8 positions a context patch has relative to a center patch.

    Both patches go through the same ResNet-50 trunk (all children except the last
    one); the two feature vectors are fused and fed to a two-layer classifier that
    ends in ``LogSoftmax``, so the output pairs with ``nn.NLLLoss``.

    Args:
        pretrained: Forwarded to ``models.resnet50(pretrained=...)``.
        fusion: How the two feature vectors are combined.
            ``'sum'`` (default, the original behaviour) adds them element-wise.
            Note that a sum is symmetric: swapping the two patches gives the same
            output, so the classifier cannot tell which patch is the center.
            ``'concat'`` concatenates them (center first), which preserves that
            ordering and doubles the classifier input width.

    Raises:
        ValueError: If ``fusion`` is neither ``'sum'`` nor ``'concat'``.
    """

    FEATURE_DIM = 2048

    def __init__(self, pretrained=True, fusion='sum'):
        super().__init__()
        if fusion not in ('sum', 'concat'):
            raise ValueError("fusion must be 'sum' or 'concat', got {!r}".format(fusion))
        self.fusion = fusion

        # Shared trunk: everything in ResNet-50 except its final child module.
        self.resnet50 = nn.Sequential(*list(models.resnet50(pretrained=pretrained).children())[:-1])

        in_features = self.FEATURE_DIM * (2 if fusion == 'concat' else 1)
        self.classifier = nn.Sequential(
                                        nn.Linear(in_features, 2048),
                                        nn.Dropout(0.5),
                                        nn.ReLU(inplace=True),
                                        nn.Linear(2048, 8),
                                        nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """Return log-probabilities of shape (batch, 8).

        Args:
            x: Pair ``(img_center, img_compare)`` of image batches.
        """
        img_center, img_compare = x

        # Same weights for both patches; flatten everything but the batch axis.
        feat_center = self.resnet50(img_center)
        feat_center = feat_center.view(feat_center.size(0), -1)
        feat_compare = self.resnet50(img_compare)
        feat_compare = feat_compare.view(feat_compare.size(0), -1)

        if self.fusion == 'concat':
            fused = torch.cat((feat_center, feat_compare), dim=1)
        else:
            fused = feat_center + feat_compare

        return self.classifier(fused)

In [10]:
# Build the network; True is passed as ContextNet's `pretrained` argument.
model = ContextNet(True)
# Wrap the model in nn.DataParallel for multi-GPU execution.
model = nn.DataParallel(model)

In [None]:
# Per-batch loss history over the whole run (detached tensors):
# loss_list for training batches, loss_test_list for validation batches.
loss_list, loss_test_list = [], []

epochs = 450
lr = 0.0001

# Move the model once, before the optimizer is built from its parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

loss_func = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0003)


def _batch_to_device(data, idx, device):
    """Move one ``((img_center, img_context), idx)`` batch to ``device``."""
    img_center, img_context = data
    return (img_center.to(device, non_blocking=True),
            img_context.to(device, non_blocking=True),
            idx.to(device, non_blocking=True))


def _evaluate(model, loader, loss_func, device):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        ``(num_correct, num_rows, batch_losses)`` — two Python ints and the list
        of per-batch loss tensors.
    """
    num_correct, num_rows, batch_losses = 0, 0, []

    with torch.no_grad():
        for data, idx, _label in loader:
            img_center, img_context, idx = _batch_to_device(data, idx, device)

            output = model((img_center, img_context))
            batch_losses.append(loss_func(output, idx))

            pred = output.argmax(dim=1)
            num_correct += (pred == idx).sum().item()
            num_rows += img_center.size(0)

    return num_correct, num_rows, batch_losses


start_time = time.time()

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    model.train()

    for batch_idx, (data, idx, _label) in enumerate(train_loader):
        img_center, img_context, idx = _batch_to_device(data, idx, device)

        optimizer.zero_grad()

        output = model((img_center, img_context))

        loss = loss_func(output, idx)

        loss.backward()
        optimizer.step()

        loss_list.append(loss.detach())

        if (batch_idx + 1) % 20 == 0:
            print("Epoch: {0}, It: {1}, Loss: {2:.4f}".format(epoch, batch_idx, loss.item()))

    model.eval()

    # Training set (a full extra pass; the dataset re-samples its random patches)
    correct_train, num_rows_train = _evaluate(model, train_loader, loss_func, device)[:2]

    # Validation set
    correct, num_rows, epoch_test_losses = _evaluate(model, valid_loader, loss_func, device)

    # Keep the full history, but report the mean of THIS epoch only; averaging
    # the ever-growing history would hide any change between epochs.
    loss_test_list.extend(epoch_test_losses)
    if epoch_test_losses:
        loss_test_mean = torch.stack(epoch_test_losses).mean().item()
    else:
        loss_test_mean = float('nan')

#     # Save model
#     model_path = 'models/model_{}.pth'.format(epoch - 1)
#     torch.save(model.state_dict(), model_path)
#     print("Model saved at {}".format(model_path))

    # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
    print("=====================================================================")
    print("Training set acc: {:.4f}%".format((100. * correct_train) / max(num_rows_train, 1)))
    print("Loss: {0:.4f}, Test acc: {1:.4f}%".format(loss_test_mean, (100. * correct) / max(num_rows, 1)))
    print("Epoch took {} seconds.".format(time.time() - epoch_start_time))
    print("=====================================================================")

print("Training took {} seconds\n".format(round(time.time() - start_time, 2)))

Epoch: 1, It: 4, Loss: 2.1139
Epoch: 1, It: 9, Loss: 2.0591
Epoch: 1, It: 14, Loss: 2.1368
Epoch: 1, It: 19, Loss: 2.0839
Epoch: 1, It: 24, Loss: 2.0468
Epoch: 1, It: 29, Loss: 1.9869
Epoch: 1, It: 34, Loss: 1.9175
Epoch: 1, It: 39, Loss: 2.2743
Epoch: 1, It: 44, Loss: 1.9528
Epoch: 1, It: 49, Loss: 2.1454
Epoch: 1, It: 54, Loss: 1.9786
Epoch: 1, It: 59, Loss: 2.0264
Epoch: 1, It: 64, Loss: 1.9734
Epoch: 1, It: 69, Loss: 1.9406
Epoch: 1, It: 74, Loss: 1.9951
Epoch: 1, It: 79, Loss: 1.9732
Epoch: 1, It: 84, Loss: 2.1117
Epoch: 1, It: 89, Loss: 1.8681
Epoch: 1, It: 94, Loss: 1.8485
Epoch: 1, It: 99, Loss: 2.0536
Epoch: 1, It: 104, Loss: 1.9380
Epoch: 1, It: 109, Loss: 2.0091
Epoch: 1, It: 114, Loss: 1.8484
Epoch: 1, It: 119, Loss: 1.8370
Epoch: 1, It: 124, Loss: 2.0358
Epoch: 1, It: 129, Loss: 2.0042
Epoch: 1, It: 134, Loss: 1.9677
Epoch: 1, It: 139, Loss: 1.8727
Epoch: 1, It: 144, Loss: 2.0972
Epoch: 1, It: 149, Loss: 1.8399
Epoch: 1, It: 154, Loss: 1.7145
Epoch: 1, It: 159, Loss: 1.872

Epoch: 7, It: 159, Loss: 1.4652
Loss: 1.8251, Acc: 32.9
Epoch took 192.5610854625702 seconds.
Epoch: 8, It: 4, Loss: 1.5034
Epoch: 8, It: 9, Loss: 1.5391
Epoch: 8, It: 14, Loss: 1.3655
Epoch: 8, It: 19, Loss: 1.5633
Epoch: 8, It: 24, Loss: 1.3461
Epoch: 8, It: 29, Loss: 1.6193
Epoch: 8, It: 34, Loss: 1.2702
Epoch: 8, It: 39, Loss: 1.3866
Epoch: 8, It: 44, Loss: 1.4748
Epoch: 8, It: 49, Loss: 1.5085
Epoch: 8, It: 54, Loss: 1.4975
Epoch: 8, It: 59, Loss: 1.3945
Epoch: 8, It: 64, Loss: 1.2965
Epoch: 8, It: 69, Loss: 1.4090
Epoch: 8, It: 74, Loss: 1.4160
Epoch: 8, It: 79, Loss: 1.2893
Epoch: 8, It: 84, Loss: 1.5098
Epoch: 8, It: 89, Loss: 1.4054
Epoch: 8, It: 94, Loss: 1.3378
Epoch: 8, It: 99, Loss: 1.4296
Epoch: 8, It: 104, Loss: 1.5927
Epoch: 8, It: 109, Loss: 1.5504
Epoch: 8, It: 114, Loss: 1.3393
Epoch: 8, It: 119, Loss: 1.3872
Epoch: 8, It: 124, Loss: 1.3502
Epoch: 8, It: 129, Loss: 1.3158
Epoch: 8, It: 134, Loss: 1.2761
Epoch: 8, It: 139, Loss: 1.5017
Epoch: 8, It: 144, Loss: 1.3507


Epoch: 14, It: 99, Loss: 1.4527
Epoch: 14, It: 104, Loss: 1.2785
Epoch: 14, It: 109, Loss: 1.2331
Epoch: 14, It: 114, Loss: 1.2881
Epoch: 14, It: 119, Loss: 1.1714
Epoch: 14, It: 124, Loss: 1.3134
Epoch: 14, It: 129, Loss: 1.4112
Epoch: 14, It: 134, Loss: 1.1426
Epoch: 14, It: 139, Loss: 1.5264
Epoch: 14, It: 144, Loss: 1.3115
Epoch: 14, It: 149, Loss: 1.3799
Epoch: 14, It: 154, Loss: 1.3431
Epoch: 14, It: 159, Loss: 1.4040
Loss: 1.8174, Acc: 29.2
Epoch took 192.24722695350647 seconds.
Epoch: 15, It: 4, Loss: 1.2634
Epoch: 15, It: 9, Loss: 1.2661
Epoch: 15, It: 14, Loss: 1.1249
Epoch: 15, It: 19, Loss: 1.2944
Epoch: 15, It: 24, Loss: 1.2867
Epoch: 15, It: 29, Loss: 1.4029
Epoch: 15, It: 34, Loss: 1.1330
Epoch: 15, It: 39, Loss: 1.4956
Epoch: 15, It: 44, Loss: 1.3098
Epoch: 15, It: 49, Loss: 1.4585
Epoch: 15, It: 54, Loss: 1.2980
Epoch: 15, It: 59, Loss: 1.2665
Epoch: 15, It: 64, Loss: 1.5352
Epoch: 15, It: 69, Loss: 1.2852
Epoch: 15, It: 74, Loss: 1.2989
Epoch: 15, It: 79, Loss: 1.2503

Epoch: 21, It: 29, Loss: 1.4540
Epoch: 21, It: 34, Loss: 1.2422
Epoch: 21, It: 39, Loss: 1.1568
Epoch: 21, It: 44, Loss: 1.2351
Epoch: 21, It: 49, Loss: 1.1834
Epoch: 21, It: 54, Loss: 1.3804
Epoch: 21, It: 59, Loss: 1.5201
Epoch: 21, It: 64, Loss: 1.2787
Epoch: 21, It: 69, Loss: 1.2660
Epoch: 21, It: 74, Loss: 1.2991
Epoch: 21, It: 79, Loss: 1.1951
Epoch: 21, It: 84, Loss: 1.3023
Epoch: 21, It: 89, Loss: 1.1448
Epoch: 21, It: 94, Loss: 1.0200
Epoch: 21, It: 99, Loss: 1.5332
Epoch: 21, It: 104, Loss: 1.4796
Epoch: 21, It: 109, Loss: 1.2523
Epoch: 21, It: 114, Loss: 1.4069
Epoch: 21, It: 119, Loss: 1.0840
Epoch: 21, It: 124, Loss: 1.1735
Epoch: 21, It: 129, Loss: 1.2687
Epoch: 21, It: 134, Loss: 1.4035
Epoch: 21, It: 139, Loss: 1.3449
Epoch: 21, It: 144, Loss: 1.3809
Epoch: 21, It: 149, Loss: 1.2041
Epoch: 21, It: 154, Loss: 1.3316
Epoch: 21, It: 159, Loss: 1.2068
Loss: 1.7989, Acc: 32.4
Epoch took 197.2409589290619 seconds.
Epoch: 22, It: 4, Loss: 1.3614
Epoch: 22, It: 9, Loss: 1.3075
