In [None]:
### Inputs: traces_dict, node_details_dict and trace_details_dict
# Node details dict = nid: [nis, type]
### Config file: DB split and SL-type split
### Outputs: updated_node_details, node_split_output.json and all_trace_packets.json

In [1]:
import pickle
import yaml
import random
import json

import networkx as nx
import numpy as np

In [2]:
def pkl_to_dict(file_path):
    """Load and return the object stored in the pickle file at ``file_path``.

    NOTE: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

def save_dict_as_pkl(traces_dict, file_name):
    """Pickle ``traces_dict`` to ``<file_name>.pkl``.

    The ``.pkl`` extension is appended here, so ``file_name`` must be given
    without it.
    """
    target_path = file_name + '.pkl'
    with open(target_path, 'wb') as handle:
        pickle.dump(traces_dict, handle)

def save_dict_as_json(traces_dict, file_name):
    """Serialise ``traces_dict`` as JSON into ``<file_name>.json``.

    The ``.json`` extension is appended here, so ``file_name`` must be given
    without it.
    """
    target_path = file_name + '.json'
    with open(target_path, 'w') as handle:
        json.dump(traces_dict, handle)

def read_yaml(file):
    """Parse the YAML document at path ``file`` with the safe loader and return the result."""
    with open(file, 'r') as stream:
        return yaml.safe_load(stream)

def build_digraph_from_tracesdict(traces_dict):
    """Build one aggregate directed graph from every trace's edge list.

    ``traces_dict`` maps a trace id to a list of edges; the edges of all
    traces are concatenated (in dict order) and added to a single
    ``nx.DiGraph``, so an edge that occurs in several traces appears once.
    """
    merged_edges = [
        edge
        for trace_edges in traces_dict.values()
        for edge in trace_edges
    ]
    aggregate = nx.DiGraph()
    aggregate.add_edges_from(merged_edges)
    return aggregate

def prune_node_details(traces_dict, node_dets):
    """Keep only the node details of nodes that appear in at least one trace.

    Args:
        traces_dict: maps trace id -> list of edges; the first two items of
            each edge are the two endpoint node ids.
        node_dets: maps node id -> details for that node.

    Returns:
        A new dict with the entries of ``node_dets`` (same order, same value
        objects) whose key is an endpoint of some edge in ``traces_dict``.
    """
    # A set gives constant-time membership checks; the previous list-based
    # de-duplication rescanned the list for every edge endpoint.
    nodes_from_traces = set()
    for e_list in traces_dict.values():
        for e in e_list:
            nodes_from_traces.add(e[0])
            nodes_from_traces.add(e[1])

    return {node: details for node, details in node_dets.items()
            if node in nodes_from_traces}

In [10]:
# Read in configs: the 'Databases' section maps a database name to a dict
# that carries (at least) a 'percentage' entry, read further down.
config = read_yaml('enrichment_config.yaml')
databases = config['Databases']

'''
NODE ENRICHMENT
---------------
Input: Traces_dict, Node_details_dict
Output: Node split output
        node_split_output = {'sf_split': {DB1: {'count': 30, 'nodes_list': [nid1, nid2, ...]}, ...},..}
                             'sl_split': ,,}
'''
# Node details dict = nid: [nis, SF, DB_name] (or) [nis, SL, SL_type]
# (the third element is appended later in this cell by assign_nodes_to_types)
traces_dict = pkl_to_dict('traces/exp_495nodes_100ktraces.pkl')
# Debug aid: uncomment the two lines below to restrict the run to selected traces.
# selected_keys = ['0b5218f615919497680352000ed6c1']
# traces_dict = {key: traces_dict[key] for key in selected_keys if key in traces_dict}

# Load the node details and drop every node that never appears in a trace.
node_dets = pkl_to_dict('new_node_details_data.pkl')
node_dets = prune_node_details(traces_dict, node_dets)

# Split node ids by the type field (index 1): "db" nodes are the SF nodes,
# everything else is an SL node (presumably stateful / stateless - confirm).
sf_arr = [nid for nid, n_info in node_dets.items() if n_info[1] == "db"]
sl_arr = [nid for nid, n_info in node_dets.items() if n_info[1] != "db"]


sf_count = len(sf_arr)
print("Number of SF nodes in trace graph:", sf_count)
sl_count = len(sl_arr)
print("Number of SL nodes in trace graph:", sl_count)
total_nodes = sf_count + sl_count

# Requested split per type as [name, value] pairs: percentages for the
# databases, and a single 'Python' type that takes every SL node.
db_split_arr = [[db_name, info['percentage']] for db_name, info in databases.items()]# [[DB1, 30], ...] (percentages)
sl_type_split = [['Python', sl_count]]
print("Database split Input:", db_split_arr)

def percent_to_count(arr, count):
    """Convert ``[name, percentage]`` pairs into ``[name, count]`` pairs.

    Each percentage is turned into a rounded share of ``count``. Rounding
    can make the shares sum to more or less than ``count``; the difference
    is then handed out one unit at a time, round-robin from the first entry,
    until the shares sum to exactly ``count``.

    Args:
        arr: list of ``[name, percentage]`` pairs. It is updated IN PLACE
            (each entry is replaced by ``[name, share]``) and also returned.
        count: total number of items to distribute (non-negative).

    Returns:
        The same list object, now holding ``[name, share]`` pairs.

    Raises:
        ValueError: if ``count`` is negative, or if ``arr`` is empty while
            ``count`` is non-zero (there is nothing to distribute over).
    """
    if count < 0:
        raise ValueError("count must be non-negative, got %r" % (count,))
    if not arr:
        if count:
            raise ValueError("cannot distribute a non-zero count over an empty split")
        return arr

    raw_counts = [round(count * (i[1] / 100)) for i in arr]
    diff = count - sum(raw_counts)
    n_entries = len(raw_counts)

    idx = 0
    # Shares sum to less than count: hand out the missing units round-robin.
    while diff > 0:
        raw_counts[idx] += 1
        diff -= 1
        idx = (idx + 1) % n_entries
    # Shares sum to more than count: take units back round-robin, but never
    # from an entry that is already at zero, so no share can become negative.
    # (While diff < 0 the shares sum to more than count >= 0, so at least one
    # positive entry always exists and the loop terminates.)
    while diff < 0:
        if raw_counts[idx] > 0:
            raw_counts[idx] -= 1
            diff += 1
        idx = (idx + 1) % n_entries

    for idx, i in enumerate(arr):
        arr[idx] = [i[0], raw_counts[idx]]

    return arr

# Turn the database percentages into absolute SF-node counts that sum to sf_count.
db_split_arr = percent_to_count(db_split_arr, sf_count) # now [[DB_name, node_count], ...]
print("Database split output:", db_split_arr)


# Per-type bookkeeping: the requested count plus the (still empty) list of
# node ids that will be assigned to that type.
sf_split_info = {ntype: {"count": value, "nodes_list": []} for ntype, value in db_split_arr}
sl_split_info = {'Python': {"count": sl_count, "nodes_list": []}}

def assign_nodes_to_types(split_arr, sfsl_arr, split_info, node_details=None):
    """Randomly assign node ids to types according to the requested counts.

    For every ``[type_name, count]`` pair in ``split_arr``, ``count`` node ids
    are drawn without replacement from ``sfsl_arr`` using the ``random``
    module. For each drawn node the type name is appended to its details
    list and the node id is appended to ``split_info[type_name]["nodes_list"]``.

    Args:
        split_arr: list of ``[type_name, count]`` pairs (e.g. Mongo, Redis).
        sfsl_arr: candidate node ids; this list itself is not modified.
        split_info: maps type name -> ``{"count": ..., "nodes_list": [...]}``;
            updated in place.
        node_details: maps node id -> details list; updated in place. When
            omitted, the notebook-level ``node_dets`` dict is used, which is
            what the function always did before this parameter existed.

    Returns:
        ``(node_details, split_info)`` - the same objects that were updated.

    Raises:
        ValueError: if the counts request more nodes than ``sfsl_arr`` holds.
            This is checked up front so nothing is modified on failure.
    """
    if node_details is None:
        node_details = node_dets  # notebook-level node details dict

    remaining = list(sfsl_arr)
    requested = sum(max(entry[1], 0) for entry in split_arr)
    if requested > len(remaining):
        raise ValueError(
            "split requests %d nodes but only %d are available"
            % (requested, len(remaining)))

    # Assign nodes to db and sl types
    for entry in split_arr:
        name = entry[0]  # type name: eg: Mongo, Redis, Relay
        for _ in range(entry[1]):
            nid = remaining.pop(random.randint(0, len(remaining) - 1))
            node_details[nid].append(name)  # add type to node details
            split_info[name]["nodes_list"].append(nid)  # add node to list of nodes for that type
    return node_details, split_info

# Seed the RNG before drawing: assign_nodes_to_types picks nodes with the
# `random` module, so without a fixed seed the node-to-type split (and the
# JSON written below) would differ on every Restart & Run All.
NODE_SPLIT_SEED = 50
random.seed(NODE_SPLIT_SEED)

# SF nodes are spread over the databases; every SL node gets the 'Python' type.
node_dets, sf_split_info = assign_nodes_to_types(db_split_arr, sf_arr, sf_split_info)
node_dets, sl_split_info = assign_nodes_to_types(sl_type_split, sl_arr, sl_split_info)

# Saving node split output
node_split_output = {'sf_split': sf_split_info, 'sl_split': sl_split_info}
print("Nodes Split Output:", node_split_output)
save_dict_as_json(node_split_output, 'enrichment_runs/dmix3_redis_heavy/node_split_output')

Number of SF nodes in trace graph: 279
Number of SL nodes in trace graph: 216
Database split Input: [['MongoDB', 20], ['Redis', 60], ['Postgres', 20]]
Database split output: [['MongoDB', 56], ['Redis', 167], ['Postgres', 56]]
Nodes Split Output: {'sf_split': {'MongoDB': {'count': 56, 'nodes_list': ['n6268', 'n3613', 'n3636', 'n7033', 'n1361', 'n6660', 'n1343', 'n4365', 'n6454', 'n6474', 'n8049', 'n5410', 'n5601', 'n7222', 'n6420', 'n3570', 'n6487', 'n1937', 'n2530', 'n7731', 'n6617', 'n2764', 'n5952', 'n5681', 'n2661', 'n5724', 'n5482', 'n5401', 'n3634', 'n5416', 'n7555', 'n4101', 'n4667', 'n7960', 'n5206', 'n6840', 'n8301', 'n6182', 'n2557', 'n5541', 'n2599', 'n7476', 'n6469', 'n412', 'n136', 'n7870', 'n5459', 'n3514', 'n7695', 'n434', 'n2521', 'n4726', 'n804', 'n3492', 'n6918', 'n224']}, 'Redis': {'count': 167, 'nodes_list': ['n4797', 'n7084', 'n6035', 'n5628', 'n7422', 'n5691', 'n3223', 'n4388', 'n1787', 'n3207', 'n6576', 'n3460', 'n6803', 'n3427', 'n5227', 'n4550', 'n6300', 'n429',

In [None]:
# 0b5218f615919497680352000ed6c1 1 5: Mongo testing
# 0b5218f615919497680352000ed6c1 1 5:
# 0b521d5415919436596543000e0312 1 5:

# ticker = 0
# for tid in traces_dict:
#     e_list = traces_dict[tid]
#     sf_needed = 0
#     t_nodes = []
#     for e in e_list:
#         if e[1] not in t_nodes:
#             t_nodes.append(e[1])
#         if e[0] not in t_nodes:
#             t_nodes.append(e[0])
#         if e[1] in sf_arr:
#             sf_needed += 1
#     if sf_needed == 2 and len(t_nodes) < 35:
#         ticker += 1
#         print(tid, sf_needed, len(t_nodes))


In [11]:
'''
Object id Enrichment
Output: Trace packets.
        Trace packets = [t_node_calls_dict, t_data_ops_dict]
        t_node_calls_dict = Key: dm node, Value: list of [dm node, op_id]
        t_data_ops_dict = Key: data op id, Value: data op packet
'''

class Wl_config:
    """
    Workload configuration plus the per-object data derived from it.

    Format: record_count, record_size_dist,
                 data_access_pattern, rw_ratio, async_sync_ratio, seed

    Attributes set on construction:
        obj_ids_list      - numpy array of object ids 1..record_count
        object_sizes_dict - object id -> randomly drawn object size
        probabilities     - per-object access weights (non-negative, sum to 1)

    NOTE: constructing an instance re-seeds the GLOBAL ``numpy.random`` and
    ``random`` generators with ``seed``.
    """
    def __init__(self, record_count, record_size_dist,\
                 data_access_pattern, rw_ratio, async_sync_ratio, seed):
        self.record_count = record_count
        self.record_size_dist = record_size_dist          # 'lognormal' or 'uniform'
        self.data_access_pattern = data_access_pattern    # 'zipfian' or 'uniform'
        self.rw_ratio = rw_ratio
        self.async_sync_ratio = async_sync_ratio
        self.seed = seed

        # Setting seed
        np.random.seed(self.seed)
        random.seed(self.seed)
        # Generate object sizes and data access pattern
        # (sizes are drawn first, so the draw order below is part of the
        # seeded, reproducible sequence).
        self.obj_ids_list = np.arange(1, self.record_count + 1)
        self.object_sizes_dict = self.generate_object_sizes()
        self.probabilities = self.generate_data_access_pattern()

    def generate_object_sizes(self):
        """Draw one size per object id and return them as ``{obj_id: size}``.

        Raises:
            ValueError: if ``record_size_dist`` is neither 'lognormal' nor 'uniform'.
        """
        if self.record_size_dist == 'lognormal':
            # Both the mean and sigma of the underlying normal are log(record_count).
            obj_sizes = np.random.lognormal(mean=np.log(self.record_count), \
                                                 sigma=np.log(self.record_count), \
                                                 size=self.record_count)
        elif self.record_size_dist == 'uniform':
            obj_sizes = np.random.uniform(low=1, high=self.record_count, size=self.record_count)
        else:
            raise ValueError('Invalid record size distribution, only lognormal & uniform are allowed for now')
        return dict(zip(self.obj_ids_list, obj_sizes))

    def generate_data_access_pattern(self):
        """Return one access weight per object id; the weights sum to 1.

        Raises:
            ValueError: if ``data_access_pattern`` is neither 'zipfian' nor 'uniform'.
        """
        if self.data_access_pattern == 'zipfian':
            alpha = 1.2
            # Zipf draws with alpha this close to 1 are heavy-tailed and can be
            # close to the int64 maximum; summing them as integers can wrap
            # around and give a negative total (hence negative "probabilities").
            # Convert to float before normalising.
            draws = np.random.zipf(alpha, len(self.obj_ids_list)).astype(np.float64)
            probabilities = draws / draws.sum()
        elif self.data_access_pattern == 'uniform':
            probabilities = np.ones(len(self.obj_ids_list)) / len(self.obj_ids_list)
        else:
            raise ValueError('Invalid data access pattern, only zipfian & uniform are allowed for now.')
        return probabilities


def gen_sfnode_dataops(sf_node, wl_config, traces_dict, node_dets):
    '''
    For a given sf node, generate one data op per call made to it.

    The number of ops equals the number of edges, over all traces, whose
    destination (edge[1]) is ``sf_node``.

    Args:
        sf_node: id of the sf (database) node to generate ops for.
        wl_config: workload config; ``rw_ratio`` and ``record_count`` are read.
            An op is a write with probability rw_ratio / (1 + rw_ratio).
        traces_dict: trace id -> list of edges.
        node_dets: node id -> details list; index 2 holds the node's db name.

    Return: ops_dict= Key: op_id, Value: op_packet
    op_packet = {'op_id': op_id, 'op_type': op_type, 'op_obj_id': op_obj_id,\
                 'db': sf_node_db}
    op ids run from 1 to the number of calls, i.e. they are unique per sf
    node only, not across nodes.
    '''
    w_prob = wl_config.rw_ratio / (1 + wl_config.rw_ratio)
    sf_node_db = node_dets[sf_node][2]

    # Count total dm calls to THIS sf node. The comparison must use the
    # sf_node argument; it previously read a notebook-level variable named
    # `node`, which only worked when the caller's loop variable had that name.
    total_ops = sum(1 for e_list in traces_dict.values()
                    for e in e_list if e[1] == sf_node)

    # generate ops for sf node
    ops_dict = {}   # key: op_id, value: op_packet
    for op_id in range(1, total_ops + 1):
        op_type = 'write' if random.random() < w_prob else 'read'
        # The object is picked uniformly at random; the access-pattern
        # probabilities held by wl_config are not applied here.
        op_obj_id = random.randrange(1, wl_config.record_count + 1)
        ops_dict[op_id] = {'op_id': op_id, 'op_type': op_type,
                           'op_obj_id': f"key_{op_obj_id}",
                           'db': sf_node_db}  # op_packet

    return ops_dict


# convert edges_list to node_calls_dict format 
def gen_node_calls_dict(edges_list, async_sync_ratio):
    '''
    Group a trace's edges by calling node.

    Return: node_calls_dict = Key: calling node, Value: list of [dm node, op_id, async_flag]
            op_id is always -1 here (the caller fills it in later for sf nodes);
            async_flag is 1 with probability async_sync_ratio / (1 + async_sync_ratio),
            otherwise 0. One random number is drawn per edge.
    '''
    calls_by_caller = {}
    for edge in edges_list:
        outgoing = calls_by_caller.setdefault(edge[0], [])
        p_async = async_sync_ratio / (1 + async_sync_ratio)
        is_async = int(random.random() < p_async)
        outgoing.append([edge[1], -1, is_async])  # [dm node, op_id, async/sync]
    return calls_by_caller


# Reading enrichment config file: pull the workload parameters out of the
# 'WorkloadConfig' section.
enrichment_config = read_yaml('enrichment_config.yaml')
record_count = enrichment_config['WorkloadConfig']['record_count']
record_size_dist = enrichment_config['WorkloadConfig']['record_size_dist']
data_access_pattern = enrichment_config['WorkloadConfig']['data_access_pattern']
rw_ratio = enrichment_config['WorkloadConfig']['rw_ratio']
async_sync_ratio = enrichment_config['WorkloadConfig']['async_sync_ratio']
# Format: record_count, record_size_dist, data_access_pattern, rw_ratio, async_sync_ratio, seed
# Constructing Wl_config also seeds the global numpy and `random` generators,
# so everything random after this line is reproducible for a fixed seed.
wl1 = Wl_config(record_count, record_size_dist, data_access_pattern, rw_ratio, async_sync_ratio, seed=50) # TODO: seed to be read from config file

# Generate data op packets for each sf node: every node of type 'db' that is
# present in the aggregate trace graph gets its own pool of data ops.
G_agg = build_digraph_from_tracesdict(traces_dict)
overall_data_ops = {}   # key: sf_node, value: ops_dict
check = 0  # NOTE(review): not read anywhere in the visible cells
for node in node_dets:
    if node in G_agg.nodes() and node_dets[node][1] == 'db':
        overall_data_ops[node] = \
            gen_sfnode_dataops(node, wl1, traces_dict, node_dets)

def get_pop_first_dict_item(d):
    """Remove the first (oldest-inserted) entry of ``d`` and return it.

    Returns:
        ``(key, value)`` of the removed entry.

    Raises:
        IndexError: if ``d`` is empty.
    """
    # Fetch only the first key instead of copying the whole key list on
    # every call; this function is called once per data op.
    try:
        first_key = next(iter(d))
    except StopIteration:
        raise IndexError("cannot pop the first item of an empty dict") from None
    return first_key, d.pop(first_key)

def get_node_type(node_id, data):
    '''
    Look up which type a node was assigned to.

    data: node_split_output.json content, i.e.
          {split_name: {type_name: {'nodes_list': [...], ...}, ...}, ...}
    Returns the first type_name whose 'nodes_list' contains node_id,
    or None when the node is listed nowhere.
    '''
    for services in data.values():
        for service_name, service_data in services.items():
            if node_id in service_data['nodes_list']:
                return service_name
    return None

def remove_self_node_calls(node_call_dict):
    """Drop every call whose target is the calling node itself.

    ``node_call_dict`` is updated in place (each caller gets a new, filtered
    call list) and the same dict is returned. Callers whose calls were all
    self-calls are kept with an empty list.
    """
    for caller in list(node_call_dict):
        kept_calls = []
        for call in node_call_dict[caller]:
            if call[0] != caller:
                kept_calls.append(call)
        node_call_dict[caller] = kept_calls
    return node_call_dict

def get_leaf_nodes(node_call_dict):
    '''Returns: the set of leaf nodes in a request call graph, i.e. nodes
    that are called by some node but are not themselves a key (caller) of
    node_call_dict.'''
    callees = {call[0] for calls in node_call_dict.values() for call in calls}
    return callees.difference(node_call_dict)

def get_logger_nodes_for_request_call_graph(node_call_dict):
    '''Returns: list of nodes that log for the request call graph
                SL leaf nodes and SL node predecessor to SF leaf nodes.

    A call entry is [dm node, op_id, ...]. A call to a leaf with
    op_id != -1 is treated as a call to an SF leaf, so its caller logs;
    a call to a leaf with op_id == -1 is treated as a call to an SL leaf,
    so the leaf itself logs.

    The result contains each node once and is sorted (by string value), so
    the output is identical from run to run; the order of a plain
    list(set(...)) of strings changes with Python's per-process hash seed.
    '''
    leaf_nodes = get_leaf_nodes(node_call_dict)  # find all leaf nodes
    logger_nodes = set()
    # One pass over all calls with a set lookup, instead of rescanning the
    # whole call graph once per leaf node.
    for node, calls in node_call_dict.items():
        for call in calls:
            if call[0] not in leaf_nodes:
                continue
            if call[1] != -1:  # Leaf SF node: its caller logs
                logger_nodes.add(node)
            else:              # Leaf SL node: the leaf logs
                logger_nodes.add(call[0])
    return sorted(logger_nodes, key=str)

def has_cycle(graph):
    """Return True if the call graph contains a directed cycle (self-loops included).

    ``graph`` maps a node to a list of call entries whose first element is
    the called node. Nodes that only appear as call targets are treated as
    having no outgoing calls.
    """
    seen = set()     # every node the search has entered so far
    on_path = set()  # nodes on the current depth-first path

    def explore(current):
        # A node that was entered before needs no second visit.
        if current in seen:
            return False
        seen.add(current)
        on_path.add(current)
        for entry in graph.get(current, []):
            target = entry[0]
            # Reaching a node that is still on the current path closes a cycle.
            if target in on_path:
                return True
            if target not in seen and explore(target):
                return True
        on_path.discard(current)
        return False

    # Start a search from every node so disconnected parts are covered too.
    return any(explore(start) for start in graph)

'''
Making the trace packet:
trace_packet = [t_node_calls_dict, t_data_ops_dict, t_ini_node, t_ini_node_type]
t_node_calls_dict = Key: dm node, Value: list of [dm node, op_id]
t_data_ops_dict = Key: data op id, Value: data op packet
'''
# Every artefact of one enrichment run is read from and written to the same
# directory. The node split used to be loaded from a different run's folder
# (dmix2_mongo_heavy) than the one this notebook writes (dmix3_redis_heavy),
# so node types could come from another run's random split.
RUN_DIR = 'enrichment_runs/dmix3_redis_heavy'

trace_details_data = pkl_to_dict('new_trace_details_data.pkl')
with open(f'{RUN_DIR}/node_split_output.json') as split_file:
    node_split_output = json.load(split_file)

all_trace_packets = {}
cycle_ctr = 0  # number of traces skipped because their call graph has a cycle
for tid in traces_dict:
    t_node_calls_dict = gen_node_calls_dict(traces_dict[tid], async_sync_ratio)
    t_data_ops_dict = {} # key: data op id, value: data op packet
    for t_node in t_node_calls_dict:
        for idx in range(len(t_node_calls_dict[t_node])):
            dm_node = t_node_calls_dict[t_node][idx][0]
            if node_dets[dm_node][1] == 'db': # ie if dm node is a sf node
                if len(overall_data_ops[dm_node]) == 0:
                    # Skip only this call; the caller's remaining calls may
                    # target other sf nodes that still have data ops.
                    print("Error: No data ops for sf node", dm_node)
                    continue
                # Select a data op id from the data ops dict and pop it
                op_id, op_packet = get_pop_first_dict_item(overall_data_ops[dm_node])
                t_node_calls_dict[t_node][idx][1] = op_id
                # NOTE(review): op ids are numbered per sf node, so two
                # different sf nodes called in the same trace can hand out
                # the same op_id and overwrite each other here - confirm
                # whether ids need to be unique within a trace.
                t_data_ops_dict[op_id] = op_packet
    t_ini_node = trace_details_data[tid][2] # getting initial node
    t_ini_node_type = get_node_type(t_ini_node, node_split_output)
    t_node_calls_dict = remove_self_node_calls(t_node_calls_dict) # Remove self calls in node_calls_dict
    # Traces whose call graph is cyclic are counted and left out of the output.
    # (The counter used to sit after the `continue`, so it always stayed at 0.)
    if has_cycle(t_node_calls_dict):
        cycle_ctr += 1
        continue
    # Get log nodes for this request call graph
    t_logger_nodes = get_logger_nodes_for_request_call_graph(t_node_calls_dict)

    trace_packet = {"tid": tid, "node_calls_dict": t_node_calls_dict, "data_ops_dict": t_data_ops_dict,\
                     "initial_node": t_ini_node, "initial_node_type": t_ini_node_type, "logger_nodes": t_logger_nodes}
    all_trace_packets[tid] = trace_packet
print("Cycle Ctr: ", cycle_ctr)
save_dict_as_json(all_trace_packets, f'{RUN_DIR}/all_trace_packets')

Cycle Ctr:  0
