In [None]:
import torch
import os
import torchvision.transforms as transforms
from PIL import Image
import random
import torchvision.transforms.functional as TF
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from torch.utils.data import DataLoader

In [None]:
# Class-name -> class-index lookup read from objectInfo150.txt
# (tab-separated; only the 'Name' and 'Idx' columns are used here).
data = pd.read_csv('objectInfo150.txt', sep='\t', lineterminator='\n')
# Build from every row rather than a hard-coded range(150): a shorter file no
# longer raises KeyError and no trailing rows are silently ignored. Duplicate
# names keep the last occurrence, as before.
name2idx = dict(zip(data['Name'], data['Idx']))

   
   
class Opt_train:
    # Options consumed by Dataset(train=True): training split, augmentation on.
    train = True
    aug = True

    # Dataset root (presumably a bedroom-only ADE20K subset, per the path name).
    full_data_dir = '../../data/ade20k/full_data_bedroom/'

    # Side length the whole scene (image and semantic map) is resized to.
    bg_size = 512

    # Side lengths of the per-instance foreground crops.
    fg_img_size = 256  # image crop
    fg_seg_size = 256  # instance-map crop
    fg_sem_size = 128  # semantic-map crop around the enlarged box

   
   
class Opt_test:
    # Options consumed by Dataset(train=False): validation split, no augmentation.
    train = False
    aug = False

    # Dataset root (presumably a bedroom-only ADE20K subset, per the path name).
    full_data_dir = '../../data/ade20k/full_data_bedroom/'

    # Side length the whole scene (image and semantic map) is resized to.
    bg_size = 512

    # Side lengths of the per-instance foreground crops.
    fg_img_size = 256  # image crop
    fg_seg_size = 256  # instance-map crop
    fg_sem_size = 128  # semantic-map crop around the enlarged box


In [None]:

  

def get_box(mask):
    """Return the tight bounding box of the foreground of a 2D mask.

    Args:
        mask: 2D np.ndarray; every non-zero / True entry counts as foreground.

    Returns:
        (x1, y1, x2, y2) as ints, where x2/y2 are the *inclusive* max column/row
        of the foreground; (0, 0, 0, 0) when the mask has no foreground (the
        object was cropped out during augmentation).
    """
    # Use one test for both emptiness and location, so a mask whose foreground
    # is not exactly 1 (e.g. 0/255) cannot reach .min() on an empty array.
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return 0, 0, 0, 0  # means this object is cropped out during aug
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def enlarge_box(x1, y1, x2, y2, width, height, ratio):
    """Replace a box by a square around its centre, clipped to the image.

    The square's half-side is ``int(max(box_w, box_h) * ratio / 2)``, so before
    clipping its side is about ``ratio`` times the box's longer edge.

    Args:
        x1, y1, x2, y2: corners of the input box.
        width, height: image size; the result is clipped to [0, width] x [0, height].
        ratio: square side relative to the longer box edge.

    Returns:
        (x1, y1, x2, y2) of the clipped square.
    """
    half_side = int(max(x2 - x1, y2 - y1) * (ratio / 2))
    cx = int((x1 + x2) / 2)
    cy = int((y1 + y2) / 2)

    left, right = max(0, cx - half_side), min(width, cx + half_side)
    top, bottom = max(0, cy - half_side), min(height, cy + half_side)
    return left, top, right, bottom


In [None]:



   

class Dataset(torch.utils.data.Dataset):
    """Scene dataset yielding a resized full scene plus per-instance foreground crops.

    Each item is ``(bg_data, fg_data)``:
      * ``bg_data``: ``{'bg_img': (1, 3, bg_size, bg_size) float tensor in [-1, 1],
        'bg_sem': (1, 1, bg_size, bg_size) long tensor}``.
      * ``fg_data``: ``{class_name: [instance dict, ...]}`` for every class in
        ``fg_classes`` that has at least one instance surviving the resize to
        ``bg_size``; see ``instance_process`` for the instance-dict keys.
    """

    def __init__(self, fg_classes, train):
        """
        Args:
            fg_classes: class names (must be keys of the global ``name2idx``)
                to extract as foreground instances.
            train: True -> ``Opt_train`` options and the 'training' split;
                False -> ``Opt_test`` options and the 'validation' split.
        """
        opt = Opt_train if train else Opt_test

        self.train = train
        self.full_data_dir = opt.full_data_dir
        self.aug = opt.aug

        self.bg_size = opt.bg_size
        self.fg_img_size = opt.fg_img_size
        self.fg_sem_size = opt.fg_sem_size
        self.fg_seg_size = opt.fg_seg_size

        split = 'training' if self.train else 'validation'

        # Image / semantic / instance files are paired purely by sorted
        # filename order, so the three folders must hold matching names.
        self.img_files = self._list_sorted('images', split)
        self.sem_files = self._list_sorted('annotations', split)
        self.ins_files = self._list_sorted('annotations_instance', split)
        assert len(self.img_files) == len(self.sem_files) == len(self.ins_files), \
            'images / annotations / annotations_instance file counts differ'

        self.fg_classes = fg_classes
        for name in fg_classes:
            assert name in name2idx, f'unknown class name: {name!r}'

        # Per-image mapping {class_name: [instance ids]}, indexed by dataset
        # position; assumed to follow the same sorted order as the file lists.
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        ins_path = os.path.join(self.full_data_dir, 'ins_of_each_sem_' + split + '.txt')
        with open(ins_path, 'rb') as fp:
            self.ins_of_each_sem = pickle.load(fp)

    def _list_sorted(self, subdir, split):
        """Return the sorted full paths of every entry in ``<full_data_dir>/<subdir>/<split>``."""
        folder = os.path.join(self.full_data_dir, subdir, split)
        return sorted(os.path.join(folder, item) for item in os.listdir(folder))

    @staticmethod
    def _open_image(path, mode=None):
        """Load an image fully and release its file handle.

        ``Image.open`` is lazy and keeps the file open until the pixels are
        read; loading inside ``with`` makes the release deterministic.
        ``mode`` converts the image (e.g. 'RGB'); None keeps the stored mode.
        """
        with Image.open(path) as im:
            return im.convert(mode) if mode is not None else im.copy()

    def exist_check(self, x1, y1, x2, y2, W, H):
        """Map a box from the W x H input to the bg_size x bg_size scene.

        Returns:
            (False, None) if the scaled box has zero width or height (the object
            is too small to exist in the final scene); otherwise
            (True, {'new_box': [x1, y1, x2, y2], 'new_size': [h, w]}) in scaled
            coordinates.
        """
        new_x1 = int(self.bg_size / W * x1)
        new_x2 = int(self.bg_size / W * x2)
        new_y1 = int(self.bg_size / H * y1)
        new_y2 = int(self.bg_size / H * y2)
        if new_x1 >= new_x2 or new_y1 >= new_y2:
            return False, None

        return True, {'new_box': [new_x1, new_y1, new_x2, new_y2],
                      'new_size': [new_y2 - new_y1, new_x2 - new_x1]}

    def instance_process(self, img, sem, ins, ins_idxs):
        """Crop every requested instance of one class out of a scene.

        Args:
            img, sem, ins: PIL image, semantic map and instance map, all at the
                (possibly augmented) original resolution.
            ins_idxs: list of instance ids (pixel values of ``ins``) wanted for
                this class.

        Returns:
            A list with one dict per instance that passes ``exist_check``,
            holding ``'new_box'`` / ``'new_size'`` plus:
              * ``'img'``: (1, C, fg_img_size, fg_img_size) float tensor in
                [-1, 1], the box crop of ``img``;
              * ``'seg'``: (1, 1, fg_seg_size, fg_seg_size) float 0/1 mask of
                this instance inside the box;
              * ``'sem'``: (1, 1, fg_sem_size, fg_sem_size) long tensor, the
                semantic map around the box enlarged 2x for context.
        """
        W, H = img.size
        ins_array = np.array(ins)

        results = []
        for idx in ins_idxs:
            x1, y1, x2, y2 = get_box(ins_array == idx)

            # Skip instances that would vanish at the final scene resolution.
            exist, result = self.exist_check(x1, y1, x2, y2, W, H)
            if not exist:
                continue

            # Crop into fresh names: rebinding img/sem here would make every
            # later instance crop from the previous instance's resized crop.
            # NOTE(review): get_box returns inclusive x2/y2 while PIL crop
            # treats the right/lower edge as exclusive, so the instance's last
            # row/column is excluded -- confirm whether that is intended.
            box = [x1, y1, x2, y2]
            fg_img = img.crop(box).resize((self.fg_img_size, self.fg_img_size), Image.NEAREST)
            fg_seg = ins.crop(box).resize((self.fg_seg_size, self.fg_seg_size), Image.NEAREST)

            # Enlarge the box (hardcoded 2x) to give sem more global context.
            ex1, ey1, ex2, ey2 = enlarge_box(x1, y1, x2, y2, W, H, 2)
            fg_sem = sem.crop([ex1, ey1, ex2, ey2]).resize((self.fg_sem_size, self.fg_sem_size), Image.NEAREST)

            result['img'] = ((TF.to_tensor(fg_img) - 0.5) / 0.5).unsqueeze(0)
            result['sem'] = torch.tensor(np.array(fg_sem)).unsqueeze(0).unsqueeze(0).long()
            result['seg'] = torch.tensor((np.array(fg_seg) == idx) * 1).unsqueeze(0).unsqueeze(0).float()

            results.append(result)

        return results

    def main_process(self, img, sem, ins, ins_sem):
        """Build background tensors and per-class foreground instance data.

        Args:
            img, sem, ins: (possibly augmented) PIL image, semantic map and
                instance map at original resolution.
            ins_sem: mapping class name -> list of instance ids in this image.

        Returns:
            (bg_data, fg_data) as described in the class docstring.
        """
        # Scene at the resolution used for the background; bg_sem is used both
        # in bg and in the final SPADE stage.
        bg_img = img.resize((self.bg_size, self.bg_size), Image.NEAREST)
        bg_sem = sem.resize((self.bg_size, self.bg_size), Image.NEAREST)

        fg_data = {}
        for class_name in self.fg_classes:
            if class_name not in ins_sem:  # this image has no such semantic class
                continue
            instances = self.instance_process(img, sem, ins, ins_sem[class_name])
            if instances:  # empty => every instance was cropped out / too small
                fg_data[class_name] = instances

        bg_data = {
            'bg_img': ((TF.to_tensor(bg_img) - 0.5) / 0.5).unsqueeze(0),
            'bg_sem': torch.tensor(np.array(bg_sem)).unsqueeze(0).unsqueeze(0).long(),
        }
        return bg_data, fg_data

    def transform(self, img, sem, ins):
        """Apply the same random augmentation to img, sem and ins.

        No-op unless ``self.aug``. Otherwise: horizontal flip with probability
        ~0.5, then a random crop whose width and height are independently
        80-100% of the input's.
        """
        if not self.aug:
            return img, sem, ins

        if random.random() > 0.5:
            img, sem, ins = TF.hflip(img), TF.hflip(sem), TF.hflip(ins)

        W, H = img.size
        new_w, new_h = int(W * random.uniform(0.8, 1)), int(H * random.uniform(0.8, 1))

        # One set of crop params shared by all three maps keeps them aligned.
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(new_h, new_w))
        img = TF.crop(img, i, j, h, w)
        sem = TF.crop(sem, i, j, h, w)
        ins = TF.crop(ins, i, j, h, w)
        return img, sem, ins

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(bg_data, fg_data)``."""
        # Force 3-channel RGB so grayscale/palette/CMYK files give consistent
        # tensor shapes; sem and ins keep their stored mode because their pixel
        # values are used as ids.
        img = self._open_image(self.img_files[idx], 'RGB')
        sem = self._open_image(self.sem_files[idx])
        ins = self._open_image(self.ins_files[idx])

        ins_sem = self.ins_of_each_sem[idx]

        img, sem, ins = self.transform(img, sem, ins)
        return self.main_process(img, sem, ins, ins_sem)

    def __len__(self):
        return len(self.img_files)








def _identity_collate(batch):
    """Return the list of dataset items unchanged (no tensor stacking)."""
    return batch


def get_dataloader(fg_classes, train, batch_size=32, shuffle=True, drop_last=True):
    """Build a DataLoader over ``Dataset(fg_classes, train)``.

    Each batch is a plain list of ``(bg_data, fg_data)`` items, because the
    per-image dicts are passed through untouched instead of being collated.

    Do not set ``num_workers`` when running on V8: multiprocessing data loading
    hit bugs there, and the workarounds discussed in
    https://github.com/pytorch/pytorch/issues/973 did not help.
    """
    loader = DataLoader(
        Dataset(fg_classes, train),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        collate_fn=_identity_collate,
    )
    return loader