In [1]:
import sys
import config
sys.path.append(config.root)
import glob
import torch
import time
import multiprocessing as mp
import threading as td
from dataloader.read_preprocess import read_preprocess
from dataloader.parallel_loader import parallel_load_dset
from dataloader.loader import scene_tensor_dset
from queue import Queue
from utils.transforms import crop_scales
from utils.img_aug import rotate, flip, noise, missing

In [2]:
# ---------- Data paths ---------- #
paths_as = sorted(glob.glob(config.root + '/data/s1_ascend/*'))
paths_des = sorted(glob.glob(config.root + '/data/s1_descend/*'))
paths_truth = sorted(glob.glob(config.root + '/data/s1_truth/*'))
# An empty glob (wrong config.root, missing data directory) would otherwise only
# surface later, far from its cause (e.g. inside read_preprocess or on scene_list[0]).
assert paths_as and paths_des and paths_truth, (
    f'no input files found under {config.root}/data/ '
    '(expected files in s1_ascend/, s1_descend/ and s1_truth/)')


In [3]:
### ----- data read & pre-processing ----- ###
# One keyword per line keeps the pairing of each path list with its parameter explicit.
scene_list, truth_list = read_preprocess(
    paths_as=paths_as,
    paths_des=paths_des,
    paths_truth=paths_truth,
)


In [4]:
# ## single process/multiple threads: load 15*num_thread patches
# tra_dset = parallel_load_dset(scene_list[0:15], \
#                                 truth_list[0:15], num_thread=50)


In [5]:
## multiple process
# Training dataset over the pre-processed scene/truth tensors, using the
# training transforms defined in config.
tra_dset = scene_tensor_dset(
    scene_tensor_list=scene_list,
    truth_tensor_list=truth_list,
    transforms=config.transforms_tra,
)


torch.Size([4, 3306, 3632])

In [71]:
def scene2patch(q, scene, truth):
    '''Crop a scene/truth pair into multi-scale patches, augment them, and queue the result.

    Args:
        q: queue that receives the output through ``q.put``.
        scene: input passed, together with ``truth``, to
            ``crop_scales(scales=[2048, 512, 256])``.
        truth: ground truth paired with ``scene``.

    Side effect:
        puts ``(patches, ptruth)`` on ``q``; ``ptruth`` gets an extra leading
        dimension of size 1 from ``torch.unsqueeze(..., 0)``.
    '''
    augmentations = (
        rotate(p=1),
        flip(p=0.5),
        noise(p=0.5, std_min=0.001, std_max=0.1),
        missing(p=0.5, ratio_max=0.25),
    )
    patch_group, patch_truth = crop_scales(scales=[2048, 512, 256])(scene, truth)
    # Apply the augmentations in order; each maps (patches, truth) -> (patches, truth).
    for augment in augmentations:
        patch_group, patch_truth = augment(patch_group, patch_truth)
    q.put((patch_group, torch.unsqueeze(patch_truth, 0)))

q = Queue()
# perf_counter is monotonic and high-resolution; time.time() follows the wall
# clock and can jump if the system clock is adjusted, skewing short benchmarks.
start = time.perf_counter()
scene2patch(q, scene_list[0], truth_list[0])
print(f'time:{time.perf_counter() - start}')


time:0.054624080657958984


### Multiprocessing

In [7]:
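# ### Known to fail as written (kept for reference). Visible problems:
# ###  - scene2patch takes (q, scene, truth) but args= passes only two values;
# ###  - `os` is never imported in this notebook;
# ###  - any queue handed to the child must be a multiprocessing.Queue, because a
# ###    queue.Queue is not shared between processes.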

# print( "父进程启动id：%d" % os.getpid())
# p1 = mp.Process(target=scene2patch, args=(scene_list[0], truth_list[0]))
# p1.start()
# p1.join()



### Multi-threading


In [116]:
## define multi-thread job
def job(q, scene, truth):
    '''Thread worker: crop one scene/truth pair into a patch group, augment it, queue it.

    Same pipeline as ``scene2patch``.

    Args:
        q: ``queue.Queue`` (or anything with ``put``) that receives the result.
        scene: input passed, together with ``truth``, to
            ``crop_scales(scales=[2048, 512, 256])``.
        truth: ground truth paired with ``scene``.

    Side effect:
        puts ``(patches_group, ptruth)`` on ``q``; ``ptruth`` gains a leading
        dimension of size 1 via ``torch.unsqueeze``.
    '''
    transforms = [rotate(p=1), flip(p=0.5),
                  noise(p=0.5, std_min=0.001, std_max=0.1),
                  missing(p=0.5, ratio_max=0.25)]
    # Convert the image to a group of multi-scale patches.
    patches_group, ptruth = crop_scales(scales=[2048, 512, 256])(scene, truth)
    for transform in transforms:
        patches_group, ptruth = transform(patches_group, ptruth)
    q.put((patches_group, torch.unsqueeze(ptruth, 0)))

def parallel_read(scene, truth, num_thread=20):
    '''Run ``job`` on the same scene/truth pair in ``num_thread`` threads.

    Each thread is expected to put exactly one ``(patch, ptruth)`` tuple on a
    shared queue; results are collected after every thread has finished.

    Args:
        scene: passed unchanged to every ``job`` call.
        truth: passed unchanged to every ``job`` call.
        num_thread: number of worker threads (0 returns two empty lists).

    Returns:
        ``(patch_list, ptruth_list)``, one entry per thread, in queue order.

    Raises:
        RuntimeError: if some worker thread exited without queuing a result
            (for example because it raised; its traceback goes to stderr).
    '''
    q = Queue()
    threads = [td.Thread(target=job, args=(q, scene, truth))
               for _ in range(num_thread)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # All workers have exited, so the queue already holds every result it will
    # ever get. Draining non-blockingly cannot hang; a blocking q.get() per
    # thread would wait forever if a worker died before calling q.put.
    results = []
    while not q.empty():
        results.append(q.get_nowait())
    if len(results) != num_thread:
        raise RuntimeError(
            f'{num_thread - len(results)} of {num_thread} worker threads '
            'produced no result (check stderr for the thread traceback)')
    patch_list = [patch for patch, _ in results]
    ptruth_list = [ptruth for _, ptruth in results]
    return patch_list, ptruth_list

# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start = time.perf_counter()
patch_lists, ptruth_lists = [], []
for _ in range(15):
    patch_list, ptruth_list = parallel_read(scene_list[0], truth_list[0], num_thread=1)
    patch_lists += patch_list
    ptruth_lists += ptruth_list

print(f'time: {time.perf_counter() - start}')
len(patch_lists)


time: 0.901658296585083


15

time: 0.1555314064025879


10

In [123]:
## define multi-thread job
def job(q, scene_list, truth_list):
    '''Thread worker: turn every scene/truth pair into an augmented patch group.

    Args:
        q: queue that receives the result via ``q.put``.
        scene_list: scenes, consumed pairwise with ``truth_list`` via ``zip``
            (so the shorter list bounds how many pairs are processed).
        truth_list: ground truths paired with ``scene_list``.

    Side effect:
        puts a single ``(patch_list, ptruth_list)`` tuple on ``q`` with one entry
        per pair; each ptruth gains a leading dimension of size 1.
    '''
    transforms = [rotate(p=1), flip(p=0.5),
                  noise(p=0.5, std_min=0.001, std_max=0.1),
                  missing(p=0.5, ratio_max=0.25)]
    patch_list, ptruth_list = [], []
    for scene, truth in zip(scene_list, truth_list):
        # Convert the image to a group of multi-scale patches.
        patches_group, ptruth = crop_scales(scales=[2048, 512, 256])(scene, truth)
        for transform in transforms:
            patches_group, ptruth = transform(patches_group, ptruth)
        patch_list.append(patches_group)
        ptruth_list.append(torch.unsqueeze(ptruth, 0))
    q.put((patch_list, ptruth_list))

def parallel_read(scene_list, truth_list, num_thread=20):
    '''Run ``job`` over the same scene/truth lists in ``num_thread`` threads.

    Each thread is expected to put exactly one ``(patch_list, ptruth_list)``
    tuple on a shared queue; the per-thread lists are concatenated in queue order.

    Args:
        scene_list: passed unchanged to every ``job`` call.
        truth_list: passed unchanged to every ``job`` call.
        num_thread: number of worker threads (0 returns two empty lists).

    Returns:
        ``(patch_lists, ptruth_lists)``: flat concatenations of every thread's
        lists (with the list-based ``job``, one entry per pair per thread).

    Raises:
        RuntimeError: if some worker thread exited without queuing a result
            (for example because it raised; its traceback goes to stderr).
    '''
    q = Queue()
    threads = [td.Thread(target=job, args=(q, scene_list, truth_list))
               for _ in range(num_thread)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # All workers have exited, so the queue already holds every result it will
    # ever get. Draining non-blockingly cannot hang; a blocking q.get() per
    # thread would wait forever if a worker died before calling q.put.
    results = []
    while not q.empty():
        results.append(q.get_nowait())
    if len(results) != num_thread:
        raise RuntimeError(
            f'{num_thread - len(results)} of {num_thread} worker threads '
            'produced no result (check stderr for the thread traceback)')
    patch_lists, ptruth_lists = [], []
    for patch_list, ptruth_list in results:
        patch_lists += patch_list
        ptruth_lists += ptruth_list
    return patch_lists, ptruth_lists

# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start = time.perf_counter()
# With the list-based `job`, every thread processes all (up to 15) pairs, so this
# keeps (number of pairs) * num_thread patch groups in memory at once.
patches_list, ptruth_list = parallel_read(
    scene_list[0:15], truth_list[0:15], num_thread=1000)
print('time:{}'.format(time.perf_counter() - start))
len(patches_list)


In [48]:

# Training loader over `tra_dset`; shuffle=True reshuffles the samples every epoch.
tra_loader = torch.utils.data.DataLoader(
    tra_dset,
    batch_size=config.batch_size,
    shuffle=True,
)



time:0.0002493858337402344


In [13]:
start = time.time()
for patch, truth in tra_loader:
    a = 0
!free -m
print('time:{}'.format(time.time()-start))        


              total        used        free      shared  buff/cache   available
Mem:          64301       19865       27848           6       16587       43718
Swap:          2047         624        1423
time:1.3459398746490479
