# RLDT

## Step 1: Import the necessary libraries:

In [1]:
import numpy as np
import pandas as pd
import networkx as nx
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

## Step 2: Define the environment:

### Step 2.1: Devices

#### *Gloabl variables*

In [2]:
num_IOT_devices = 5

voltages_frequencies_IOT = [
    (1e6, 1.8),
    (2e6, 2.3),
    (4e6, 2.7),
    (8e6, 4.0),
    (16e6, 5.0),
    (32e6, 6.5),
]
num_MEC_devices = 10

voltages_frequencies_MEC = [
    (6 * 1e8, 0.8),
    (7.5 * 1e8, 0.825),
    (10 * 1e8, 1.0),
    (15 * 1e8, 1.2),
    (30 * 1e8, 2),
    (40 * 1e8, 3.1),
]

task_kinds = [1,2,3,4]

min_num_nodes_dag = 4
max_num_nodes_dag = 20
max_num_parents_dag = 5

num_dag_generations = 100

#### *IOT*

In [3]:
devices_data_IOT = []
for i in range(num_IOT_devices):
    cpu_cores = np.random.choice([4, 6, 8])
    device_info = {
        "id": i,
        "number_of_cpu_cores": cpu_cores,
        "occupied_cores": [np.random.choice([0, 1]) for _ in range(cpu_cores)],
        "voltages_frequencies": [
            voltages_frequencies_IOT[i]
            for i in np.random.choice(6, size=4, replace=False)
        ],
        "ISL": np.random.randint(10, 21),
        "capacitance": [np.random.uniform(2, 3) * 1e-9 for _ in range(cpu_cores)],
        "powerIdle": [
            np.random.choice([700, 800, 900]) * 1e-6 for _ in range(cpu_cores)
        ],
        "batteryLevel": np.random.randint(36, 41) * 1e9,
        "errorRate": np.random.randint(1, 6) / 100,
        "accetableTasks": np.random.choice(
            task_kinds, size=np.random.randint(2, 5), replace=False
        ),
        "handleSafeTask": np.random.choice([0, 1], p=[0.25, 0.75]),
    }
    devices_data_IOT.append(device_info)

IoTdevices = pd.DataFrame(devices_data_IOT)

IoTdevices.set_index("id", inplace=True)
IoTdevices

Unnamed: 0_level_0,number_of_cpu_cores,occupied_cores,voltages_frequencies,ISL,capacitance,powerIdle,batteryLevel,errorRate,accetableTasks,handleSafeTask
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,4,"[0, 0, 1, 1]","[(16000000.0, 5.0), (2000000.0, 2.3), (8000000...",15,"[2.8393484748462677e-09, 2.3013173786224465e-0...","[0.0009, 0.0009, 0.0009, 0.0007999999999999999]",37000000000.0,0.03,"[3, 1]",0
1,6,"[1, 1, 1, 0, 1, 0]","[(32000000.0, 6.5), (16000000.0, 5.0), (100000...",11,"[2.22187929212393e-09, 2.7955996035069807e-09,...","[0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007]",37000000000.0,0.04,"[4, 2]",0
2,4,"[1, 1, 0, 1]","[(8000000.0, 4.0), (1000000.0, 1.8), (2000000....",18,"[2.4786499180052465e-09, 2.4907512446789654e-0...","[0.0007999999999999999, 0.0007999999999999999,...",39000000000.0,0.03,"[3, 4, 2, 1]",1
3,4,"[1, 0, 1, 0]","[(32000000.0, 6.5), (2000000.0, 2.3), (1000000...",13,"[2.6502320013282e-09, 2.1964597476417556e-09, ...","[0.0007, 0.0007, 0.0007, 0.0007]",39000000000.0,0.01,"[2, 4, 3, 1]",1
4,8,"[1, 1, 1, 1, 0, 0, 0, 0]","[(4000000.0, 2.7), (8000000.0, 4.0), (2000000....",15,"[2.800776192794575e-09, 2.0393758414378526e-09...","[0.0009, 0.0009, 0.0007, 0.0007, 0.0009, 0.000...",38000000000.0,0.05,"[4, 3, 2, 1]",0


#### *MEC*

In [4]:
devices_data_MEC = []
for i in range(num_MEC_devices):
    cpu_cores = np.random.choice([16,32,64])
    device_info = {
        "id": i,
        "number_of_cpu_cores": cpu_cores,
        "occupied_cores": [np.random.choice([0, 1]) for _ in range(cpu_cores)],
        "voltages_frequencies": [
            voltages_frequencies_MEC[i]
            for i in np.random.choice(6, size=4, replace=False)
        ],
        "capacitance": [np.random.uniform(1.5, 2) * 1e-9 for _ in range(cpu_cores)],
        "powerIdle": [np.random.choice([9, 9, 10]) * 1e-5 for _ in range(cpu_cores)],
        "errorRate": np.random.randint(5, 11) / 100,
        "accetableTasks": np.random.choice(
            task_kinds, size=np.random.randint(2, 5), replace=False
        ),
        "handleSafeTask": np.random.choice([0, 1], p=[0.75, 0.25]),
    }
    devices_data_MEC.append(device_info)

MECDevices = pd.DataFrame(devices_data_MEC)

MECDevices.set_index("id", inplace=True)
MECDevices

Unnamed: 0_level_0,number_of_cpu_cores,occupied_cores,voltages_frequencies,capacitance,powerIdle,errorRate,accetableTasks,handleSafeTask
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0,16,"[0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1]","[(3000000000.0, 2), (4000000000.0, 3.1), (7500...","[1.8355634530230644e-09, 1.539026369897655e-09...","[9e-05, 9e-05, 9e-05, 0.0001, 0.0001, 0.0001, ...",0.08,"[1, 2, 3]",1
1,64,"[0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, ...","[(750000000.0, 0.825), (1000000000.0, 1.0), (3...","[1.531574105278295e-09, 1.7236392489420708e-09...","[9e-05, 9e-05, 0.0001, 9e-05, 0.0001, 9e-05, 9...",0.08,"[2, 3, 4]",0
2,32,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, ...","[(600000000.0, 0.8), (3000000000.0, 2), (75000...","[1.7916878539861075e-09, 1.9105045622527975e-0...","[0.0001, 9e-05, 9e-05, 0.0001, 0.0001, 0.0001,...",0.1,"[4, 2, 3]",0
3,64,"[1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, ...","[(4000000000.0, 3.1), (1000000000.0, 1.0), (30...","[1.942670819383203e-09, 1.7158349841866995e-09...","[9e-05, 0.0001, 9e-05, 9e-05, 0.0001, 9e-05, 0...",0.09,"[3, 2, 4]",0
4,32,"[1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, ...","[(3000000000.0, 2), (750000000.0, 0.825), (400...","[1.8405550207350166e-09, 1.6574394008836663e-0...","[0.0001, 9e-05, 9e-05, 9e-05, 9e-05, 9e-05, 9e...",0.1,"[3, 4]",1
5,64,"[0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, ...","[(750000000.0, 0.825), (1000000000.0, 1.0), (3...","[1.7597609420885992e-09, 1.7788393088657065e-0...","[9e-05, 9e-05, 9e-05, 9e-05, 9e-05, 9e-05, 0.0...",0.09,"[4, 2, 1, 3]",0
6,32,"[0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, ...","[(750000000.0, 0.825), (3000000000.0, 2), (600...","[1.5680527444152568e-09, 1.9832816462821643e-0...","[9e-05, 9e-05, 9e-05, 0.0001, 0.0001, 9e-05, 9...",0.05,"[1, 3, 4]",0
7,32,"[1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, ...","[(1500000000.0, 1.2), (1000000000.0, 1.0), (40...","[1.9276235962285845e-09, 1.9673710192983593e-0...","[0.0001, 9e-05, 9e-05, 9e-05, 9e-05, 9e-05, 9e...",0.05,"[4, 2, 1]",0
8,32,"[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, ...","[(1500000000.0, 1.2), (600000000.0, 0.8), (100...","[1.6111603328049387e-09, 1.993591016927708e-09...","[9e-05, 9e-05, 9e-05, 0.0001, 9e-05, 9e-05, 9e...",0.09,"[4, 1]",0
9,16,"[1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1]","[(750000000.0, 0.825), (600000000.0, 0.8), (15...","[1.6867791969600372e-09, 1.865984859948367e-09...","[0.0001, 0.0001, 0.0001, 9e-05, 9e-05, 0.0001,...",0.09,"[1, 4]",0


#### *CLOUD*

In [5]:
cloud = [(2.8e9, 1), (3.9e9, 2)]

### Step 2.2: Application

#### *helper function : generate_random_dag*

In [6]:
def generate_random_dag(num_nodes):
    dag = nx.DiGraph()

    nodes = [f"t{i+1}" for i in range(num_nodes)]
    dag.add_nodes_from(nodes)

    available_parents = {node: list(nodes[:i]) for i, node in enumerate(nodes)}

    for i in range(2, num_nodes + 1):
       
        num_parents = min(
            random.randint(1, min(i, max_num_parents_dag)), len(available_parents[f"t{i}"])
        )

        # select parents
        parent_nodes = random.sample(available_parents[f"t{i}"], num_parents)
        # add parents
        dag.add_edges_from((parent_node, f"t{i}") for parent_node in parent_nodes)

        # update available parents
        available_parents[f"t{i}"] = list(nodes[:i])

    return dag

#### *Generate task DAGs*

In [7]:
tasks_data = []

start_node_number = 1
for run in range(num_dag_generations):

    num_nodes = random.randint(min_num_nodes_dag, max_num_nodes_dag)

    random_dag = generate_random_dag(num_nodes)

    mapping = {
        f"t{i}": f"t{i + start_node_number - 1}" for i in range(1, num_nodes + 1)
    }

    random_dag = nx.relabel_nodes(random_dag, mapping)

    for node in random_dag.nodes:
        parents = list(random_dag.predecessors(node))
        task_info = {
            "id": node,
            "dependency": parents,
            "mobility": np.random.randint(1, 10),
            "kind": np.random.choice(task_kinds),
            "safe": np.random.choice([0, 1], p=[0.95, 0.05]),
            "computationalLoad": int(np.random.uniform(1, 100) * 1e4),
            "dataEntrySize": (np.random.randint(10, 100) // 10)
            * (10 ** np.random.choice([3, 6])),
            "returnDataSize": (np.random.randint(10, 100) // 10)
            * (10 ** np.random.choice([3, 6])),
            "status": "READY",
        }
        tasks_data.append(task_info)
    start_node_number += num_nodes

np.random.shuffle(tasks_data)
tasks = pd.DataFrame(tasks_data)

tasks.set_index("id", inplace=True)

tasks

Unnamed: 0_level_0,dependency,mobility,kind,safe,computationalLoad,dataEntrySize,returnDataSize,status
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
t155,[t154],9,4,0,346245,1000000,2000000,READY
t113,[t112],3,2,0,347400,8000000,7000,READY
t833,[t831],9,2,0,585702,7000,6000000,READY
t474,"[t468, t471, t473]",2,4,0,22830,8000,6000000,READY
t944,[t943],8,4,0,840595,9000000,3000000,READY
...,...,...,...,...,...,...,...,...
t721,[t720],3,4,0,39520,9000000,4000,READY
t830,"[t824, t828]",4,4,0,799135,9000,1000,READY
t60,"[t58, t59]",6,2,0,410841,1000,7000000,READY
t449,[t447],4,3,0,208234,2000000,4000,READY


In [16]:
tasks

Unnamed: 0_level_0,dependency,mobility,kind,safe,computationalLoad,dataEntrySize,returnDataSize,status,cluster
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
t155,[t154],9,4,0,346245,1000000,2000000,READY,7
t113,[t112],3,2,0,347400,8000000,7000,READY,3
t833,[t831],9,2,0,585702,7000,6000000,READY,3
t474,"[t468, t471, t473]",2,4,0,22830,8000,6000000,READY,7
t944,[t943],8,4,0,840595,9000000,3000000,READY,7
...,...,...,...,...,...,...,...,...,...
t721,[t720],3,4,0,39520,9000000,4000,READY,7
t830,"[t824, t828]",4,4,0,799135,9000,1000,READY,7
t60,"[t58, t59]",6,2,0,410841,1000,7000000,READY,3
t449,[t447],4,3,0,208234,2000000,4000,READY,5


## Step 3: Preprocessing

### Step 3.1: Clustering

#### *Tasks*

In [8]:
x = 0
for kind in task_kinds:
    for safe in (0, 1):
        selected_tasks = tasks.loc[(tasks["kind"] == kind) & (tasks["safe"] == safe), :]
        x += 1
        tasks.loc[selected_tasks.index, "cluster"] = x
        
tasks["cluster"] = tasks["cluster"].astype(int)
tasks

Unnamed: 0_level_0,dependency,mobility,kind,safe,computationalLoad,dataEntrySize,returnDataSize,status,cluster
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
t155,[t154],9,4,0,346245,1000000,2000000,READY,7
t113,[t112],3,2,0,347400,8000000,7000,READY,3
t833,[t831],9,2,0,585702,7000,6000000,READY,3
t474,"[t468, t471, t473]",2,4,0,22830,8000,6000000,READY,7
t944,[t943],8,4,0,840595,9000000,3000000,READY,7
...,...,...,...,...,...,...,...,...,...
t721,[t720],3,4,0,39520,9000000,4000,READY,7
t830,"[t824, t828]",4,4,0,799135,9000,1000,READY,7
t60,"[t58, t59]",6,2,0,410841,1000,7000000,READY,3
t449,[t447],4,3,0,208234,2000000,4000,READY,5


#### *IOT & MEC Devices*

In [9]:
x = 1
IoTdevices["cluster"] = [[] for _ in range(len(IoTdevices))]

for kind in task_kinds:
    for safe in (0, 1):
        selected_IoTdevices = IoTdevices[
            IoTdevices["accetableTasks"].apply(lambda lst: kind in lst)
        ]

        condition = (IoTdevices.index.isin(selected_IoTdevices.index)) & (
            (IoTdevices["handleSafeTask"] == 0) & (safe == 0)
            | (IoTdevices["handleSafeTask"] == 1)
        )

        IoTdevices.loc[condition, "cluster"] = IoTdevices.loc[
            condition, "cluster"
        ].apply(lambda lst: lst + [x])

        x += 1
# ---------------------------------
x = 1
MECDevices["cluster"] = [[] for _ in range(len(MECDevices))]

for kind in task_kinds:
    for safe in (0, 1):
        selected_MECDevices = MECDevices[
            MECDevices["accetableTasks"].apply(lambda lst: kind in lst)
        ]

        condition = (MECDevices.index.isin(selected_MECDevices.index)) & (
            (MECDevices["handleSafeTask"] == 0) & (safe == 0)
            | (MECDevices["handleSafeTask"] == 1)
        )

        MECDevices.loc[condition, "cluster"] = MECDevices.loc[
            condition, "cluster"
        ].apply(lambda lst: lst + [x])

        x += 1

MECDevices
IoTdevices

Unnamed: 0_level_0,number_of_cpu_cores,occupied_cores,voltages_frequencies,ISL,capacitance,powerIdle,batteryLevel,errorRate,accetableTasks,handleSafeTask,cluster
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0,4,"[0, 0, 1, 1]","[(16000000.0, 5.0), (2000000.0, 2.3), (8000000...",15,"[2.8393484748462677e-09, 2.3013173786224465e-0...","[0.0009, 0.0009, 0.0009, 0.0007999999999999999]",37000000000.0,0.03,"[3, 1]",0,"[1, 5]"
1,6,"[1, 1, 1, 0, 1, 0]","[(32000000.0, 6.5), (16000000.0, 5.0), (100000...",11,"[2.22187929212393e-09, 2.7955996035069807e-09,...","[0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007]",37000000000.0,0.04,"[4, 2]",0,"[3, 7]"
2,4,"[1, 1, 0, 1]","[(8000000.0, 4.0), (1000000.0, 1.8), (2000000....",18,"[2.4786499180052465e-09, 2.4907512446789654e-0...","[0.0007999999999999999, 0.0007999999999999999,...",39000000000.0,0.03,"[3, 4, 2, 1]",1,"[1, 2, 3, 4, 5, 6, 7, 8]"
3,4,"[1, 0, 1, 0]","[(32000000.0, 6.5), (2000000.0, 2.3), (1000000...",13,"[2.6502320013282e-09, 2.1964597476417556e-09, ...","[0.0007, 0.0007, 0.0007, 0.0007]",39000000000.0,0.01,"[2, 4, 3, 1]",1,"[1, 2, 3, 4, 5, 6, 7, 8]"
4,8,"[1, 1, 1, 1, 0, 0, 0, 0]","[(4000000.0, 2.7), (8000000.0, 4.0), (2000000....",15,"[2.800776192794575e-09, 2.0393758414378526e-09...","[0.0009, 0.0009, 0.0007, 0.0007, 0.0009, 0.000...",38000000000.0,0.05,"[4, 3, 2, 1]",0,"[1, 3, 5, 7]"


### Step 3.2: Queueing

In [10]:
sorted_tasks = tasks.sort_values(by="mobility").copy()
proccessingQueue = set()
for index, row in sorted_tasks.iterrows():
    if row["status"] == "Queued":
        continue
    for dep in row["dependency"]:
        if sorted_tasks.loc[dep, "status"] != "Queued":
            proccessingQueue.add(dep)
            sorted_tasks.loc[dep, "status"] = "Queued"
            dependecny_meet = False

    proccessingQueue.add(index)
    sorted_tasks.loc[index, "status"] = "Queued"

proccessingQueue = list(proccessingQueue)
proccessingQueue

['t682',
 't1046',
 't877',
 't459',
 't942',
 't35',
 't344',
 't718',
 't757',
 't704',
 't66',
 't189',
 't519',
 't830',
 't197',
 't216',
 't1040',
 't958',
 't391',
 't164',
 't476',
 't824',
 't1011',
 't724',
 't486',
 't835',
 't1051',
 't215',
 't1009',
 't1018',
 't100',
 't965',
 't260',
 't1029',
 't73',
 't427',
 't375',
 't219',
 't423',
 't356',
 't597',
 't1003',
 't419',
 't449',
 't368',
 't317',
 't159',
 't1135',
 't621',
 't933',
 't110',
 't316',
 't1007',
 't80',
 't21',
 't1047',
 't1024',
 't804',
 't666',
 't643',
 't919',
 't196',
 't688',
 't229',
 't210',
 't303',
 't237',
 't951',
 't27',
 't330',
 't71',
 't1017',
 't422',
 't538',
 't864',
 't399',
 't1026',
 't1066',
 't772',
 't867',
 't581',
 't981',
 't971',
 't734',
 't477',
 't801',
 't870',
 't472',
 't401',
 't160',
 't12',
 't698',
 't652',
 't355',
 't516',
 't348',
 't607',
 't892',
 't53',
 't191',
 't584',
 't1090',
 't320',
 't1068',
 't98',
 't1098',
 't212',
 't814',
 't1118',
 't924',
 

## Step 4 : DDT

### Step 4.1:  Initializing The tree

In [11]:

class DDTNode(nn.Module):
    def __init__(self, feature_size, num_classes, depth, max_depth):
        """
        Initializes the DifferentiableDecisionTreeNode.

        Args:
        - feature_size (int): The size of the input features.
        - num_classes (int): The number of output classes.
        - depth (int): The depth of the current node in the decision tree.
        - max_depth (int): The maximum depth allowed for the decision tree.
        """
        super(DDTNode, self).__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.depth = depth
        self.max_depth = max_depth
        # learnable parameters
        self.weight = nn.Parameter(torch.randn(feature_size))
        self.bias = nn.Parameter(torch.randn(1))
        self.prob_distribution = nn.Parameter(torch.zeros(num_classes))
        # If not at leaf, create left and right child nodes
        if self.depth < self.max_depth:
            self.left_child = DDTNode(feature_size, num_classes, depth+1, max_depth)
            self.right_child = DDTNode(feature_size, num_classes, depth+1, max_depth)

    def forward(self, x):
        """
        Forward pass of the DifferentiableDecisionTreeNode.

        Args:
        - x (torch.Tensor): Input data of shape (batch_size, feature_size).

        Returns:
        - torch.Tensor: Output probabilities for each class for each instance in the batch.
        """
        # Check if we are at a leaf node
        if self.depth == self.max_depth:
            # We are at a leaf, return the softmax probabilities for all instances in the batch
            probs = F.softmax(self.prob_distribution, dim=0)
            print(probs)
            return probs.expand(x.size(0), -1)  
        else:
            # Process each item in the batch individually (not efficient!) TODO make it vectorized
            decisions = torch.sigmoid(torch.matmul(x, self.weight) + self.bias)
            batch_results = torch.zeros(x.size(0), self.num_classes)
            for i in range(x.size(0)):
                decision = decisions[i]
                if decision > 0.5:
                    batch_results[i] = self.right_child(x[i:i+1])
                else:
                    batch_results[i] = self.left_child(x[i:i+1])
            return batch_results


class DDT(nn.Module):
    def __init__(self, feature_size, num_classes, max_depth):
        """
        Initializes the DDT (Differentiable Decision Tree) model.

        Args:
        - feature_size (int): The size of the input features.
        - num_classes (int): The number of output classes.
        - max_depth (int): The maximum depth allowed for the decision tree.
        """
        super(DDT, self).__init__()
        self.root = DDTNode(feature_size, num_classes, 0, max_depth)

    def forward(self, x):
        """
        Forward pass of the DDT model.

        Args:
        - x (torch.Tensor): Input data of shape (batch_size, feature_size).

        Returns:
        - torch.Tensor: Output probabilities for each class for each instance in the batch.
        """
        return self.root(x)
    


### Step 4.2: Custom data generation functions

In [12]:
def create_dataset(num_samples, num_features, num_classes):
    """
    Function to create a synthetic dataset.

    Args:
    - num_samples (int): Number of samples in the dataset.
    - num_features (int): Number of features for each sample.
    - num_classes (int): Number of classes for classification.

    Returns:
    - torch.Tensor: Features tensor of shape (num_samples, num_features).
    - torch.Tensor: Labels tensor of shape (num_samples,) containing class labels.
    """
    features = torch.rand(num_samples, num_features)  # Generate random features
    labels = torch.randint(0, num_classes, (num_samples,))  # Generate random class labels
    return features, labels


def create_imbalanced_dataset(num_samples, num_features, num_classes, focus_class):
    """
    Generate an imbalanced dataset with a focus on a specific class.

    Args:
        num_samples (int): Number of samples in the dataset.
        num_features (int): Number of features for each sample.
        num_classes (int): Number of classes for the labels.
        focus_class (int): The class to focus on.

    Returns:
        torch.Tensor: Features tensor of shape (num_samples, num_features).
        torch.Tensor: Labels tensor of shape (num_samples,).
    """
    # Generate random features
    features = torch.rand(num_samples, num_features)
    
    # Generate random labels
    labels = torch.randint(0, num_classes, (num_samples,))
    
    # Assign labels focusing on the desired class
    focus_indices = random.sample(range(num_samples), int(0.8 * num_samples))
    labels[focus_indices] = focus_class
    
    # Assign labels randomly to other classes
    other_classes = [c for c in range(num_classes) if c != focus_class]
    for idx in range(num_samples):
        if idx not in focus_indices:
            labels[idx] = random.choice(other_classes)
    
    # Shuffle the dataset
    indices = list(range(num_samples))
    random.shuffle(indices)
    features = features[indices]
    labels = labels[indices]
    
    return features, labels


### Step 4.3: Train & Test functions

In [13]:
# Define a function to train the model
def train(model, train_loader, criterion, optimizer, epochs):
    """
    Function to train the model.

    Args:
    - model (torch.nn.Module): The model to be trained.
    - train_loader (torch.utils.data.DataLoader): DataLoader for training dataset.
    - criterion (torch.nn.Module): Loss function.
    - optimizer (torch.optim.Optimizer): Optimization algorithm.
    - epochs (int): Number of epochs for training.
    """
    model.train()  # Set the model to training mode
    for epoch in range(epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()  # Clear gradients
            outputs = model(inputs.float())  # Forward pass
            loss = criterion(outputs, labels.long())  # Calculate loss
            loss.backward()  # Backward pass
            optimizer.step()  # Update parameters
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Define a function to test the model
def test(model, test_loader, criterion):
    """
    Function to evaluate the model on the test set.

    Args:
    - model (torch.nn.Module): The trained model.
    - test_loader (torch.utils.data.DataLoader): DataLoader for test dataset.
    - criterion (torch.nn.Module): Loss function.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.float())
            test_loss += criterion(outputs, labels.long()).item()
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f"Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)")

# Generate the datasets
train_features, train_labels = create_dataset(800, 7, 3)
test_features, test_labels = create_dataset(200, 7, 3)

# Wrap the datasets in TensorDataset
train_dataset = TensorDataset(train_features, train_labels)
test_dataset = TensorDataset(test_features, test_labels)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

# Initialize the model, loss function, and optimizer
model = DDT(7, 3, 3)  # Create a DDT model with input size 7, 3 classes, and maximum depth of 3
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer with learning rate 0.001

# Train and test the model
epochs = 10
train(model, train_loader, criterion, optimizer, epochs)  # Train the model
test(model, test_loader, criterion)  # Test the trained model


tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3331, 0.3338, 0.3331], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3331, 0.3338, 0.3331], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 

### Additional Step: Testing Probability Dist weights update process using a focused dataset

In [14]:
# Generate the imbalanced dataset
num_samples = 2000
num_features = 7
num_classes = 3
focus_class = 0

train_features, train_labels = create_imbalanced_dataset(num_samples, num_features, num_classes, focus_class)
test_features, test_labels = create_dataset(200, num_features, num_classes)
# test_features, test_labels = create_imbalanced_dataset(200, num_features, num_classes,focus_class)


# Wrap the datasets in TensorDataset
train_dataset = TensorDataset(train_features, train_labels)
test_dataset = TensorDataset(test_features, test_labels)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

model = DDT(7, 3, 3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train and test the model
epochs = 10
train(model, train_loader, criterion, optimizer, epochs)
test(model, test_loader, criterion)

tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3333, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 0.3330], grad_fn=<SoftmaxBackward0>)
tensor([0.3338, 0.3331, 0.3331], grad_fn=<SoftmaxBackward0>)
tensor([0.3333, 0.3337, 

### Step 4.4: Predict Function

In [15]:
def predict(model, input_data):
    """
    Function to perform inference using the trained model.

    Args:
    - model (torch.nn.Module): The trained model.
    - input_data (torch.Tensor): Input data for prediction of shape (batch_size, num_features).

    Returns:
    - torch.Tensor: Predicted classes for each input instance.
    """
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        output = model(input_data.float())  # Forward pass
        predicted_classes = output.argmax(dim=1)  # Get the index of the maximum value (predicted class)
    return predicted_classes

# Example input data for prediction
input_data = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], [0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4]]) 

# Perform prediction using the trained model
predicted_classes = predict(model, input_data)

# Print the predicted classes
print("Predicted classes:", predicted_classes)


tensor([0.8810, 0.0588, 0.0602])
tensor([0.8810, 0.0588, 0.0602])
Predicted classes: tensor([0, 0])
