In [1]:
import torch
from torch import nn
from torch.utils import data
import os
import numpy as np
from torch.utils.data import Sampler, BatchSampler,SequentialSampler, RandomSampler, DataLoader
from skimage import io, transform
import matplotlib.pyplot as plt
from torchvision import transforms, utils, models

In [2]:
def listdir_fullpath(d):
    return [os.path.join(d, f) for f in os.listdir(d)]

In [3]:
def create_filelists(data_folder):
    identities = os.listdir(data_folder)
    identities.sort()
    n_identities = len(identities)
    data_source = []
    Filelist = []
    idx = 0
    for identity in identities:
        filenames = listdir_fullpath(data_folder + identity)
        filenames.sort()
        samples = [{'filename':filenames[i],'idx':idx+i} for i in range(len(filenames))]
        data_source.append(samples)
        Filelist.extend(filenames)
        idx = idx + len(filenames)
    return data_source, Filelist, identities

In [4]:
data_folder = '/media/jmukherjee/2TB_Store/DNN_Projects/Adversarial/data/VGGFACE2_100/Train/'
data_source , Filelist, identity_list = create_filelists(data_folder)

In [5]:
print(data_source[-1][-1])

{'idx': 28183, 'filename': '/media/jmukherjee/2TB_Store/DNN_Projects/Adversarial/data/VGGFACE2_100/Train/n006996/0373_02.jpg'}


In [6]:
class PKSampler(Sampler):
    def __init__(self, data_source, p , k):
        self.data_source = data_source #list of lists containing id and filenames and index in total filelist
        self.p = p
        self.k = k
        self.dataset_size = data_source[-1][-1]['idx']
        self.length = self.dataset_size // ((len(self.data_source) // self.p) * self.p )
        
    def __iter__(self):
        id_list_len = len(self.data_source)
        batch = []
        
        for n_pass in range(self.dataset_size // (((id_list_len // self.p) * self.p ) * self.k)):
            random_id_list = np.random.permutation(id_list_len)
            id_idx = 0
            for r_id in range((id_list_len//self.p)-1):
                if id_idx>=id_list_len:
                    break
                for p_i in range(self.p):
                    id_idx = id_idx + 1
                    random_sample_list = np.random.permutation(len(self.data_source[random_id_list[id_idx]]))
                    for k_i in range(self.k):
                        batch.append(self.data_source[random_id_list[id_idx]][random_sample_list[k_i]]['idx'])
                yield batch
                batch = []
                
    def __len__(self):
        return self.length


In [7]:
class VGGFace2_100_PK(data.Dataset):
    def __init__(self, Filelist, cropsize=[224,224] ,transform=None):
        self.transform = transform
        self.cropsize = cropsize
        self.Filelist = Filelist
                    
    def __len__(self):
        return len(self.Filelist)

    def __getitem__(self, index):
        img_name = self.Filelist[index]
        image = io.imread(img_name)/255.0
              
        h,w = image.shape[:2]
#         print(h,w)
        min_dim = min(h,w)
#         print(min_dim)
        
        if min_dim <= 224:
            image = transform.rescale(image, 228.0/min_dim)
            h,w = image.shape[:2]
#             print("resized = ",h,w)
        
        
        new_h,new_w = self.cropsize
        
#         print(h,w,new_h,new_w)
        
        top = np.random.randint(0, h - new_h )
        left = np.random.randint(0, w - new_w )
            
        image = image[ top : top + new_h,
                      left : left + new_w ]
        
#         plt.figure(),
#         plt.imshow(image)
        
        image = image.transpose(2,0,1)
            
        id_name = img_name.split("/")[-2]

        return image,id_name

In [8]:
def class_hist(class_id_list, batch_ids):
    counts = { class_id_list[i]:0 for i in range(len(class_id_list)) }
    for i in range(len(batch_ids)):
        counts[batch_ids[i]]+= 1
    print("\n**Class Histogram**\n")
    for i in range(len(class_id_list)):
        if(counts[class_id_list[i]]>0):
            print(class_id_list[i],counts[class_id_list[i]])

In [9]:
class VGG16convnet(nn.Module):
            def __init__(self):
                super(VGG16convnet, self).__init__()
                self.model = models.vgg16_bn(pretrained=True).cuda().float()
                self.model.classifier = nn.Sequential(*list(self.model.classifier.children())[:4])
                
            def forward(self, x):
                return self.model.forward(x)


In [13]:
def train():
    net = VGG16convnet().cuda()
    VGGTest = VGGFace2_100_PK(Filelist)
    pksampler = PKSampler(data_source, 5 , 4)
    dataloader = DataLoader(VGGTest, batch_sampler=pksampler,
                        shuffle=False, num_workers=1)
    
    epoch_count=2
    
    for epoch_no in range(epoch_count):
        print("Epoch: "+str(epoch_no))
        for i, (images, ids) in enumerate(dataloader):
            embeddings = net(images.float().cuda())
            print(embeddings.shape)
            print(ids)
            break
            

In [14]:
train()

Epoch: 0


  warn("The default mode, 'constant', will be changed to 'reflect' in "


torch.Size([20, 4096])
('n004524', 'n004524', 'n004524', 'n004524', 'n000415', 'n000415', 'n000415', 'n000415', 'n006369', 'n006369', 'n006369', 'n006369', 'n004431', 'n004431', 'n004431', 'n004431', 'n001403', 'n001403', 'n001403', 'n001403')
Epoch: 1


  warn("The default mode, 'constant', will be changed to 'reflect' in "


RuntimeError: CUDA error: out of memory