This notebook continues from the notebook `grib_test__FOR_0220_dataset_(grid-sampling_and_local-inspection).ipynb` and tries to speed up the computation of `ts_2mTemp` by parallelizing it.

In [1]:
import os

import pygrib

# Location of the 0220 GRIB dataset. Set the GRIB_0220_PATH environment variable to use a
# different copy of the file; the original local path is kept as the default.
GRIB_PATH = os.environ.get(
    "GRIB_0220_PATH", r"C:\SUSTech\datasets_of_graduation_project\0220.grib"
)
# GRIB file handle used by the rest of the notebook; released later with `msgs.close()`.
msgs = pygrib.open(GRIB_PATH)

## Grid-sampling
Given the lat-lon range of the dataset (a rectangular region), grid-sampling selects points from it at uniform latitude and longitude steps.

In [2]:
LAT_RANGE = (90, 66.5) # latitude extent of the entire map, north edge first: from 90 degrees down to 66.5 degrees, closed interval
LON_RANGE = (-180, 179.5) # longitude extent of the entire map: from -180 degrees to 179.5 degrees (same closed-interval convention)

In [3]:
LAT_RESOL = 5  # grid resolution of the sampled latitudes: two neighbouring sampled lats differ by 5 degrees
LON_RESOL = 30  # grid resolution of the sampled longitudes, in degrees
# Each grid point is replaced by the average over its vicinity to reduce noise.
# These are the vicinity radii in degrees (1 degree here).
LAT_VICINITY_R = 1
LON_VICINITY_R = 1

# Fail fast on invalid settings: the resolutions are later used as divisors / step sizes,
# and each vicinity radius must not exceed the matching resolution.
assert LAT_RESOL > 0 and LON_RESOL > 0, "grid resolutions must be positive"
assert 0 <= LAT_VICINITY_R <= LAT_RESOL, "LAT_VICINITY_R should not exceed LAT_RESOL"
assert 0 <= LON_VICINITY_R <= LON_RESOL, "LON_VICINITY_R should not exceed LON_RESOL"

In [4]:
def _grid_axis(start, stop, step):
    """Sample one axis from `start` towards `stop` in increments of `step` degrees.

    The number of steps is rounded down, so the last sample never overshoots `stop`
    (it may fall short of it). Works for descending axes (e.g. latitudes from 90 down
    to 66.5) as well as ascending ones (e.g. longitudes from -180 up to 179.5).
    """
    direction = 1 if stop >= start else -1
    n_steps = int(abs(stop - start) // step)  # round down when sampling points
    return [start + direction * i * step for i in range(n_steps + 1)]


# Start from the first endpoint of each configured range (not a hard-coded 90 / -180),
# so that editing LAT_RANGE / LON_RANGE actually changes the sampled grid.
DISTINCT_LATS = _grid_axis(LAT_RANGE[0], LAT_RANGE[1], LAT_RESOL)
DISTINCT_LONS = _grid_axis(LON_RANGE[0], LON_RANGE[1], LON_RESOL)

In [5]:
import numpy as np
from bidict import bidict
msg = msgs[1]

# Read the coordinate axes once: every msg[...] access goes through pygrib, so reading
# them inside the loops below would repeat the same work for every grid point.
grid_lats = np.asarray(msg["distinctLatitudes"])
grid_lons = np.asarray(msg["distinctLongitudes"])


def _axis_index(axis, value, axis_name):
    """Return the position of `value` in the 1-D coordinate array `axis`.

    Compares with a small absolute tolerance (in degrees) instead of exact float
    equality. Raises ValueError naming the missing coordinate, instead of an opaque
    IndexError, when `value` is not a point of the dataset grid.
    """
    hits = np.flatnonzero(np.isclose(axis, value, rtol=0, atol=1e-6))
    if hits.size == 0:
        raise ValueError(f"{axis_name} {value} is not a grid point of the dataset")
    return hits[0]


# (lat, lon) -> (lat_idx, lon_idx)
latlon_to_latlonIdx = bidict()
for lat in DISTINCT_LATS:
    for lon in DISTINCT_LONS:
        latlon_to_latlonIdx[(lat, lon)] = (_axis_index(grid_lats, lat, "latitude"),
                                           _axis_index(grid_lons, lon, "longitude"))

print(latlon_to_latlonIdx)

# Grid spacing (degrees) of each axis, derived from the data instead of assuming 0.25.
# Assumes a regularly spaced grid with at least two points per axis — TODO confirm for other datasets.
lat_step = abs(float(grid_lats[1] - grid_lats[0]))
lon_step = abs(float(grid_lons[1] - grid_lons[0]))
# Vicinity radii converted from degrees to a whole number of grid cells (rounded down).
# True division plus a tiny epsilon avoids float floor-division surprises such as 1 // 0.1 == 9.0.
lat_r_cells = int(np.floor(LAT_VICINITY_R / lat_step + 1e-9))
lon_r_cells = int(np.floor(LON_VICINITY_R / lon_step + 1e-9))

# (lat_idx, lon_idx) -> ((lat1_idx, lat2_idx), (lon1_idx, lon2_idx)):
# half-open index windows, i.e. values[lat1_idx:lat2_idx, lon1_idx:lon2_idx] is the vicinity.
# Windows are clipped to the array bounds. NOTE(review): longitudes do not wrap around the
# ±180 degree seam, so points at the first/last longitude get a one-sided vicinity.
latlonIdx_to_vicinityIdx = bidict()
for lat_idx, lon_idx in latlon_to_latlonIdx.values():
    lat1_idx = max(0, int(lat_idx) - lat_r_cells)
    lat2_idx = min(len(grid_lats), int(lat_idx) + lat_r_cells + 1)
    lon1_idx = max(0, int(lon_idx) - lon_r_cells)
    lon2_idx = min(len(grid_lons), int(lon_idx) + lon_r_cells + 1)
    latlonIdx_to_vicinityIdx[(lat_idx, lon_idx)] = ((lat1_idx, lat2_idx), (lon1_idx, lon2_idx))

print(latlonIdx_to_vicinityIdx)


bidict({(90, -180): (0, 0), (90, -150): (0, 120), (90, -120): (0, 240), (90, -90): (0, 360), (90, -60): (0, 480), (90, -30): (0, 600), (90, 0): (0, 720), (90, 30): (0, 840), (90, 60): (0, 960), (90, 90): (0, 1080), (90, 120): (0, 1200), (90, 150): (0, 1320), (85, -180): (20, 0), (85, -150): (20, 120), (85, -120): (20, 240), (85, -90): (20, 360), (85, -60): (20, 480), (85, -30): (20, 600), (85, 0): (20, 720), (85, 30): (20, 840), (85, 60): (20, 960), (85, 90): (20, 1080), (85, 120): (20, 1200), (85, 150): (20, 1320), (80, -180): (40, 0), (80, -150): (40, 120), (80, -120): (40, 240), (80, -90): (40, 360), (80, -60): (40, 480), (80, -30): (40, 600), (80, 0): (40, 720), (80, 30): (40, 840), (80, 60): (40, 960), (80, 90): (40, 1080), (80, 120): (40, 1200), (80, 150): (40, 1320), (75, -180): (60, 0), (75, -150): (60, 120), (75, -120): (60, 240), (75, -90): (60, 360), (75, -60): (60, 480), (75, -30): (60, 600), (75, 0): (60, 720), (75, 30): (60, 840), (75, 60): (60, 960), (75, 90): (60, 1080)

## Compute the `ts_2mTemp` for local inspection (parallelized)
### Trial 1

In [6]:
# from concurrent.futures import ThreadPoolExecutor

# if "msg" in locals(): del msg
# ts_2mTemp = dict()
# for (lat, lon) in latlon_to_latlonIdx.keys():
#     ts_2mTemp[(lat, lon)] = np.array([])

In [7]:
# msgs.rewind()
# msgs_len = sum(1 for _ in msgs)
# print(msgs_len)

In [8]:
# BATCH_SIZE = 100
# batches_heads = list(range(1, msgs_len+1, BATCH_SIZE)) # the first batch starts with the 1st msg, the second batch starts with the 101st msg, and so on

In [9]:
# def compute_avg_for_latlon(latlon, batch_head):
#     lat, lon = latlon
#     lat1_idx, lat2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][0]
#     lon1_idx, lon2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][1]
    
#     # compute the average of 2m-temps in the vicinity of (lat, lon) for this batch
#     batch_ts_2mTemp = np.array([np.mean(msgs[t]["values"][lat1_idx:lat2_idx, lon1_idx:lon2_idx]) 
#                    for t in range(batch_head, min(batch_head+BATCH_SIZE, msgs_len+1))])
    
#     # Store the result directly in the ts_2mTemp dict
#     ts_2mTemp[(lat, lon)] = np.concatenate((ts_2mTemp[(lat, lon)], batch_ts_2mTemp))

# def compute_avg_for_msg(batch_head):
#     # Use a thread pool to compute the averages of all (lat, lon) points of this batch in parallel, updating ts_2mTemp directly
#     with ThreadPoolExecutor() as executor:
#         # Process each (lat, lon) pair in parallel
#         futures = [executor.submit(compute_avg_for_latlon, latlon, batch_head) 
#                    for latlon in latlon_to_latlonIdx.keys()]
        
#         # Wait for all tasks to finish
#         for future in futures:
#             future.result()  # ensures each computation task has finished

In [10]:
# msgs.rewind()

# from tqdm import tqdm
# for batch_head in tqdm(batches_heads, desc="Processing batches", total=len(batches_heads), unit="batch"):
#     compute_avg_for_msg(batch_head)
#     if batch_head == 101: print("The first batch has been processed.")

### Trial 2

In [11]:
# import numpy as np
# import concurrent.futures

# if "msg" in locals(): del msg

# def process_latlon(lat, lon):
#     lat1_idx, lat2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][0]
#     lon1_idx, lon2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][1]
#     return (lat, lon, np.array([np.mean(msg["values"][lat1_idx:lat2_idx, lon1_idx:lon2_idx]) for msg in msgs]))

# def parallel_processing():
#     ts_2mTemp = dict()

#     # Use ThreadPoolExecutor for parallel processing
#     with concurrent.futures.ThreadPoolExecutor() as executor:
#         futures = [
#             executor.submit(process_latlon, lat, lon)
#             for (lat, lon) in latlon_to_latlonIdx.keys()
#         ]
        
#         for future in concurrent.futures.as_completed(futures):
#             lat, lon, result = future.result()
#             ts_2mTemp[(lat, lon)] = result

#     return ts_2mTemp

# ts_2mTemp = parallel_processing()

### Trial 3

In [12]:
# import numpy as np
# import concurrent.futures

# if "msg" in locals(): del msg

# def process_latlon(lat, lon):
#     # Use .get() to access the dict safely
#     latlon_idx = latlon_to_latlonIdx.get((lat, lon))
#     if latlon_idx is None:
#         return (lat, lon, np.array([]))  # return an empty array if no matching latlon_idx is found
    
#     lat1_idx, lat2_idx = latlonIdx_to_vicinityIdx[latlon_idx][0]
#     lon1_idx, lon2_idx = latlonIdx_to_vicinityIdx[latlon_idx][1]
    
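#     # Compute the means; avoid referencing the global variable msg directly, pass msg to the child process
#     try:
#         result = np.array([np.mean(msg["values"][lat1_idx:lat2_idx, lon1_idx:lon2_idx]) for msg in msgs])
#     except Exception as e:
#         result = np.array([])  # return an empty array if an error occurs during the computation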
#         print(f"Error processing {lat}, {lon}: {e}")
    
#     return (lat, lon, result)

# def parallel_processing():
#     ts_2mTemp = dict()

#     # Use ProcessPoolExecutor to parallelize the CPU-bound work
#     with concurrent.futures.ProcessPoolExecutor() as executor:
#         futures = [
#             executor.submit(process_latlon, lat, lon)
#             for (lat, lon) in latlon_to_latlonIdx.keys()
#         ]
        
#         for future in concurrent.futures.as_completed(futures):
#             lat, lon, result = future.result()
#             if result.size > 0:  # store only valid (non-empty) results
#                 ts_2mTemp[(lat, lon)] = result

#     return ts_2mTemp

# # Call the parallel processing function
# ts_2mTemp = parallel_processing()


### Trial 4 (the only runnable one, but still too slow)

In [13]:
# from concurrent.futures import ThreadPoolExecutor

# if "msg" in locals(): del msg
# ts_2mTemp = dict()
# for (lat, lon) in latlon_to_latlonIdx.keys():
#     ts_2mTemp[(lat, lon)] = np.array([])

# def compute_avg_for_latlon(latlon, msg):
#     lat, lon = latlon
#     lat1_idx, lat2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][0]
#     lon1_idx, lon2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][1]
    
#     # Compute the mean over this region
#     avg_value = np.mean(msg["values"][lat1_idx:lat2_idx, lon1_idx:lon2_idx])
    
#     # Store the result directly in the ts_2mTemp dict
#     ts_2mTemp[(lat, lon)] = np.append(ts_2mTemp[(lat, lon)], avg_value)

# def compute_avg_for_msg(msg):
#     # Use a thread pool to compute the averages of all (lat, lon) points of this msg in parallel, updating ts_2mTemp directly
#     with ThreadPoolExecutor(max_workers=60) as executor:
#         # Process each (lat, lon) pair in parallel
#         futures = [executor.submit(compute_avg_for_latlon, latlon, msg) 
#                    for latlon in latlon_to_latlonIdx.keys()]
        
#         # Wait for all tasks to finish
#         for future in futures:
#             future.result()  # ensures each computation task has finished
    

# msgs.rewind()

# from tqdm import tqdm
# for msg in tqdm(msgs, desc="Processing msgs", unit="msg"):
#     compute_avg_for_msg(msg)

### Trial 5

In [None]:
# import numpy as np
# from multiprocessing import Pool

# # Initialize ts_2mTemp dictionary
# ts_2mTemp = dict()

# def compute_for_latlon(latlon):
#     lat, lon = latlon
#     msgs.rewind()
#     lat1_idx, lat2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][0]
#     lon1_idx, lon2_idx = latlonIdx_to_vicinityIdx[latlon_to_latlonIdx[(lat, lon)]][1]
#     return (lat, lon), np.array([np.mean(msg["values"][lat1_idx:lat2_idx, lon1_idx:lon2_idx]) for msg in msgs])

# # List of latlon pairs
# latlon_pairs = list(latlon_to_latlonIdx.keys())

# # Use multiprocessing Pool to parallelize the computation
# with Pool(processes=4) as pool:  # Adjust the number of processes as needed
#     results = pool.map(compute_for_latlon, latlon_pairs)

# # Update ts_2mTemp with the results
# for (lat, lon), values in results:
#     ts_2mTemp[(lat, lon)] = values

### Trial 6

In [7]:
# test for multiprocessing: start several child processes, wait for all of them, then report.
# NOTE(review): with the "spawn" start method (the default on Windows), children must import
# fun1 from __main__, which is usually not possible for functions defined in a notebook cell;
# their prints also do not reach the notebook output. If children fail, move fun1 into a .py module.
from multiprocessing import Process


def fun1(name):
    """Print a line identifying this multiprocessing test."""
    print('test %s multi-processes' %name)


def join_processes(processes):
    """Block until every process in `processes` has finished; return their exit codes in order."""
    for proc in processes:
        proc.join()
    return [proc.exitcode for proc in processes]


if __name__ == '__main__':
    process_list = []
    for _ in range(5):  # start 5 child processes, each running fun1
        p = Process(target=fun1, args=('Python',))  # instantiate the process object
        p.start()
        process_list.append(p)

    # Join each started process, not just the last one created in the loop above.
    join_processes(process_list)

    print('main process end')

main process end


In [None]:
# Same smoke test via a worker pool: Pool.map calls fun1 once per element of the argument list.
# NOTE(review): under the "spawn" start method (the default on Windows), pool workers must be
# able to import fun1; functions defined in a notebook cell usually are not, which can make this
# cell fail or hang. If so, move fun1 into a .py module and import it here.
from multiprocessing import Pool

def fun1(name):
    """Print a line identifying this multiprocessing test (redefines fun1 from the previous cell)."""
    text = 'test %s multi-processes' % name
    print(text)

if __name__ == '__main__':
    names = ['Python'] * 5
    with Pool(5) as worker_pool:
        worker_pool.map(fun1, names)

    print('main process end')

In [None]:
# Release the GRIB file handle `msgs` opened at the top of the notebook.
# Any cell that reads from `msgs` must run before this one.
msgs.close()

In [None]:
# Scratch check: np.concatenate accepts a plain Python list alongside an ndarray.
a = np.array([1, 2, 3])
b = list(range(4, 7))
np.concatenate([a, b])

array([1, 2, 3, 4, 5, 6])

In [None]:
# msgs.rewind()

# import time

# start = time.time()
# t = 1
# values = None
# for msg in msgs:
#     values = msg["values"]
#     t += 1
#     if t == 100: break
# end = time.time()
# print(end-start, "s")

# start = time.time()
# t = 1
# for t in range(1, 101):
#     values = msgs[t]["values"]
# end = time.time()
# print(end-start, "s")