In [None]:
import sys, os
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

sys.path.insert(0, os.path.join(os.getcwd(), "../../"))
sys.path.insert(0, "/usr/local/cuda/targets/x86_64-linux/lib/stubs/") # 
from analysis.utils import histogram, abs_err, gmae, GPU_COUNT
from analysis.memory_bw_utils import *

# Translation table (code point -> code point) that turns ASCII digits into
# Unicode superscript digits; used to render exponents in the 2^k tick labels.
superscript = {
    ord(digit): ord(raised)
    for digit, raised in zip("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
}
# Communication collectives analyzed throughout this notebook.
collectives = [
    'all_to_allv',
    'all_reduce',
]
num_of_collectives = len(collectives)

In [None]:
# Uncomment to override the GPU count imported from analysis.utils:
# GPU_COUNT = 4

# Directory holding the benchmark result files, relative to this notebook.
dir_prefix = "../../3rdparty/param/train/comms/pt/bench_results"

# Load the results for every collective; later cells index the result as
# data[collective]['size' | 'latency' | 'alg_bw' | 'bus_bw'].
data = process_param_data(prefix=dir_prefix, collectives=collectives, num_gpus=GPU_COUNT)

In [None]:
# One row per collective: latency on the left, algorithm/bus bandwidth on the right.
# The power-of-two tick positions and their "2^k" labels are the same for every
# panel, so they are built once instead of twice per loop iteration.
tick_exponents = range(2, 33)
size_ticks = [2**i for i in tick_exponents]
size_tick_labels = ["2{}".format(str(j).translate(superscript)) for j in tick_exponents]

fig = plt.figure(figsize=(12, 18))
for idx, collective in enumerate(collectives):
    sizes = data[collective]['size']

    # Left panel: latency vs. message size.
    ax = fig.add_subplot(num_of_collectives, 2, idx * 2 + 1)
    ax.set_title('{} latency'.format(collective))
    ax.plot(sizes, data[collective]['latency'], marker='o')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_ylim([1e1, 1e6])
    ax.set_xticks(size_ticks)
    ax.set_xticklabels(size_tick_labels)
    ax.set_xlabel("Message size (bytes)")
    ax.set_ylabel("Latency (us)")

    # Right panel: algorithm and bus bandwidth vs. message size.
    ax = fig.add_subplot(num_of_collectives, 2, idx * 2 + 2)
    ax.set_title('{} BW'.format(collective))
    # Label each curve explicitly so the legend cannot drift out of sync with plot order.
    ax.plot(sizes, data[collective]['alg_bw'], marker='o', label='alg_bw')
    ax.plot(sizes, data[collective]['bus_bw'], marker='o', label='bus_bw')
    ax.legend()
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_ylim([1e-3, 1e3])
    ax.set_xticks(size_ticks)
    ax.set_xticklabels(size_tick_labels)
    ax.set_xlabel("Message size (bytes)")
    ax.set_ylabel("BW (GB/s)")

fig.tight_layout(pad=0.5)

In [None]:
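# NOTE: disabled cell. It overlays all collectives on three shared axes (latency, bus BW, alg BW) and saves latency_bw.png.
# fig = plt.figure(figsize=(8, 8))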
# ax1 = fig.add_subplot(3, 1, 1)
# ax2 = fig.add_subplot(3, 1, 2)
# ax3 = fig.add_subplot(3, 1, 3)
# for idx, collective in enumerate(collectives):
    
#     ax1.plot(data[collective]['size'], data[collective]['latency'], marker='o')
#     ax1.set_xscale('log')
#     ax1.set_yscale('log')
#     ax1.set_ylim([1e1, 1e6])
#     ax1.set_xticks([2**i for i in range(2, 33)])
#     ax1.set_xticklabels(["2{}".format(str(j).translate(superscript)) for j in range(2, 33)])
#     ax1.set_xlabel("Message size (bytes)", fontsize=12)
#     ax1.set_ylabel("Latency (us)", fontsize=12)

#     ax2.plot(data[collective]['size'], data[collective]['bus_bw'], marker='o')
#     ax2.set_xscale('log')
#     ax2.set_yscale('log')
#     ax2.set_ylim([1e-3, 1e2])
#     ax2.set_xticks([2**i for i in range(2, 33)])
#     ax2.set_xticklabels(["2{}".format(str(j).translate(superscript)) for j in range(2, 33)])
#     ax2.set_xlabel("Message size (bytes)", fontsize=12)
#     ax2.set_ylabel("Bus BW (GB/s)", fontsize=12)

#     ax3.plot(data[collective]['size'], data[collective]['alg_bw'], marker='o')
#     ax3.set_xscale('log')
#     ax3.set_yscale('log')
#     ax3.set_ylim([1e-3, 1e2])
#     ax3.set_xticks([2**i for i in range(2, 33)])
#     ax3.set_xticklabels(["2{}".format(str(j).translate(superscript)) for j in range(2, 33)])
#     ax3.set_xlabel("Message size (bytes)", fontsize=12)
#     ax3.set_ylabel("Alg BW (GB/s)", fontsize=12)

# ax1.legend(collectives)
# ax2.legend(collectives)
# ax3.legend(collectives)
# fig.suptitle("Communication collectives microbenchmark on a quad-V100 DGX-1", fontsize=14)
# fig.tight_layout(pad=0.5)
# plt.savefig('../../3rdparty/param/train/comms/pt/latency_bw.png', bbox_inches='tight')

In [None]:
# BW curve fitting. Input: total message size in bytes, output: BW in GB/s.
# For each collective: extract the memory characteristics, fit the sigmoid BW
# predictor, then overlay the predicted bus BW on the measured bus BW.
mem_chs = {}
sigmoid_params = {}
# squeeze=False always returns a 2-D array of axes, so the same indexing works
# whether there is one collective or several.
fig, axs = plt.subplots(1, len(collectives), figsize=(8*len(collectives), 6), squeeze=False)
for idx, collective in enumerate(collectives):
    ax = axs[0, idx]
    mem_chs[collective] = get_memory_characteristics(data[collective])
    sigmoid_params[collective] = fit_sigmoid_bw_predictor(data[collective], mem_chs[collective])
    # Keyword arguments shared by every predict_bus_bw call for this collective.
    d = {
        "mul_factor": MUL_FACTOR_FUNCS[collective](GPU_COUNT),
        "mem_ch": mem_chs[collective],
        "sigmoid_param": sigmoid_params[collective],
    }
    f1 = lambda x: predict_bus_bw(x, **d)
    ax.plot(data[collective]['size'], data[collective]['bus_bw'], label='measured bus_bw')
    ax.plot(data[collective]['size'], data[collective]['size'].apply(f1), label='predicted bus_bw')
    # Mark the two message sizes 2**ln_p and 2**sats_p taken from the memory characteristics.
    ax.axvline(2 ** mem_chs[collective]["ln_p"], linestyle='--', color='green')
    ax.axvline(2 ** mem_chs[collective]["sats_p"], linestyle='--', color='green')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_title(collective)
    # Both curves are bandwidths, so the axis is labelled as BW (it was previously labelled as latency).
    ax.set_ylabel('Bus BW (GB/s)', va='center', rotation=90, fontsize=14)
    ax.set_xlabel("Message Size (in bytes)", fontsize=14)
    ax.legend()
    ax.grid()
    # print(data[collective]['bus_bw'])
    # print(data[collective]['size'].apply(f1))

    # # Prediction
    # print("----- {} -----".format(collective))
    # for idx, size in enumerate(data[collective]['size']):
    #     f_mul_factor = MUL_FACTOR_FUNCS[collective]
    #     f_sigmoid_bw = sigmoid_params[collective](GPU_COUNT)
    #     print("{:.2f}, {:.2f}, {:.2f}".format(
    #         data[collective]['latency'][idx],
    #         predict_linear(size, f_mul_factor, *mem_chs[collective]),
    #         predict_data_movement_time(size, f_mul_factor, mem_ch, sigmoid_param),
    #     ))

# Lay out the figure once, after every panel has been drawn.
fig.tight_layout()

Current limitations:
- Is there a way to get the min/max BW directly from the device connection configuration, without benchmarking? Can the bus BW be derived from the algorithm BW, which seems to follow a pattern (50, 75, 87.5 GB/s)?

e.g. 4 GPUs, all_to_all: each GPU sends 1/4 of its elements to each of the other GPUs
- --b/--e (in bytes per rank): 16, 32, 64...
- allSizes (in bytes per rank): 16, 32, 64...
- memSize / size (B) in printed results (in bytes per rank): 16, 32, 64...
- num-elements in printed results (in elements COMMUNICATED per rank-pair): 1, 2, 4...

commsParams.element_size: 4 (float)
comm_fn: backendFuncs.collectiveFunc[commsParams.collective]
comms.py comm op line 1202 calls runColl line 258
--z/commsParamsHolderBase's blockingFlag/~asyncOp 1: non-blocking, 0: blocking
gatherBenchTime (line 767): gathers the bench times stored in tensors on each device into a list of tensors on rank 0.

param pytorch_dist_backend: all_to_all line 163 calls dist.all_to_all_single line 170, wait function called at line 389
dlrm extend_distributed: alltoall line 597 calls All2All_Req line 404, which calls dist.all_to_all_single line 429 (the list of local tensors is concatenated and flattened to 1D)

e.g. batched_emb
dist.all_to_all_single: input_split_sizes (how tables are distributed to devices, e.g. 13 tables and [2,3,3,5] on 4 GPUs), output_split_sizes (how batches are distributed to devices; set to None for equal distribution, e.g. batch size 2048 -> 512 per GPU)
common case: input_split_sizes not None, output_split_sizes None

- reduce scatter: memSize measures the INPUT size in bytes per rank (equal to total OUTPUT size on all devices)
- all gather: memSize measures the OUTPUT size in bytes per rank (equal to total INPUT size on all devices)

In [None]:
# Pull out the all_to_allv entries computed in the curve-fitting cell; the
# cells below pass them to predict_data_movement_time.
sigmoid_param = sigmoid_params['all_to_allv']
mem_ch = mem_chs['all_to_allv']
mul_factor = MUL_FACTOR_FUNCS['all_to_allv'](GPU_COUNT)

# Load the general all-to-all benchmark results and preview the first rows.
df = process_general_a2a_param_data(num_gpus=GPU_COUNT, prefix=dir_prefix)
df.head()

In [None]:
# Exploration
def get_adjusted_size(s, t, num_gpus=None, bytes_per_element=4):
    """Turn a "B,T1-T2-...,D" descriptor into an adjusted message size in bytes.

    Args:
        s: Comma-separated descriptor "<B>,<tables>,<D>", where <tables> is a
            dash-separated list of integers, e.g. "2048,2-3-3-5,64".
            (Presumably batch size, tables per GPU and embedding dim -- TODO confirm.)
        t: How the per-entry table counts are collapsed into one factor T:
            "sum"        -> sum of all counts;
            "max_of_max" -> max over entries n of max(total - n, n * (num_gpus - 1));
            "max_of_sum" -> max over entries n of (total - n) + n * (num_gpus - 1).
        num_gpus: Number of GPUs. Defaults to the notebook-level GPU_COUNT.
        bytes_per_element: Size of one element in bytes (default 4, i.e. float32).

    Returns:
        (B // num_gpus) * T * D * bytes_per_element, as an int.

    Raises:
        ValueError: if `t` is not one of the recognized modes, or if a field of
            `s` is not an integer.
    """
    if num_gpus is None:
        # Resolved at call time so a notebook-level override of GPU_COUNT is honored.
        num_gpus = GPU_COUNT

    fields = s.split(',')
    batch_per_gpu = int(fields[0]) // num_gpus
    dim = int(fields[2])
    # A distinct loop name avoids shadowing the `t` parameter.
    tables = [int(n) for n in fields[1].split('-')]
    total = sum(tables)
    peers = num_gpus - 1

    if t == "sum":
        T = total
    elif t == "max_of_max":
        T = max(max(total - n, n * peers) for n in tables)
    elif t == "max_of_sum":
        T = max(total - n + n * peers for n in tables)
    else:
        # ValueError is still an Exception, so existing `except Exception` callers keep working.
        raise ValueError("Unrecognized max_type: {!r}".format(t))
    return batch_per_gpu * T * dim * bytes_per_element

# Try each way of collapsing the table distribution into one message size,
# report the prediction error against the measured latency, and plot
# measured vs. predicted latency over the adjusted size.
for tt in ['sum', 'max_of_max', 'max_of_sum']:
    adjusted_size = df['btd'].apply(get_adjusted_size, args=(tt,))
    y1 = adjusted_size.apply(predict_data_movement_time, args=(mul_factor, mem_ch, sigmoid_param))
    error1 = abs_err(y1, df['latency'])
    print("All to allv ({}): GMAE: {:.2f}%, mean: {:.2f}%, std: {:.2f}%".format(tt, gmae(error1) * 100.0, error1.mean() * 100.0, error1.std() * 100.0))
    _ = histogram(error1)

    # Keep the prediction next to its size so sorting keeps the rows aligned and
    # the predictor does not have to be evaluated a second time for the plot.
    sorted_df = pd.DataFrame({
        "size": adjusted_size,
        "latency": df["latency"],
        "predicted": y1,
    }).sort_values(['size'], ascending=True)

    # Explicit figure/axes handles instead of relying on pyplot's "current figure".
    fig, ax = plt.subplots()
    ax.scatter(sorted_df['size'], sorted_df['latency'], s=1)
    ax.plot(sorted_df['size'], sorted_df['predicted'], color='orange', linewidth=4)
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.axvline(2 ** mem_ch["ln_p"], linestyle='--', color='green')
    ax.axvline(2 ** mem_ch["sats_p"], linestyle='--', color='green')
    ax.set_ylabel('Latency (us)', va='center', rotation=90, fontsize=14)
    ax.set_xlabel("Message Size (in bytes)", fontsize=14)
    ax.grid()
    fig.tight_layout()
    # Only the max_of_max variant is exported to disk.
    if tt == "max_of_max":
        fig.savefig('./a2a.pdf', bbox_inches='tight')
        fig.savefig('./a2a.png', bbox_inches='tight')