In [None]:
import os
import time
import torch
import threading
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt

# Setup

In [None]:
# Partition id -> partition tag used when building data and log paths.
part_id2str = dict(enumerate(
    ["01F", "01M", "02F", "02M", "03F", "03M", "04F", "04M", "05F", "05M"],
    start=1,
))
# Partition analysed by this notebook run.
part_id = 6
label_path = f"./Data/iemocap/iemocap_{part_id2str[part_id]}.test.csv"

# Format of PL_Log_Dict:
#     {PL_ID: Log_timestamp, ..., PL_ID: Log_timestamp}
PL_Log_Dict = dict(
    S1="2022-05-02_18-59",
    S2="2022-05-02_19-01",
    T1="2022-05-02_17-12",
    T2="2022-05-02_17-12",
    M0="2022-05-02_18-03",
)
# Maximum number of PLouts kept per PLset (interflow mode).
PL_Set_Size = 4
# Number of epochs for which logits / dev accuracies are read.
epoch_nums = 100

# Label

In [None]:
# Emotion tag -> integer class id.
label2id = {tag: class_id for class_id, tag in enumerate(('e0', 'e1', 'e2', 'e3'))}
# Read the "emotion" column of the test CSV and encode each entry as its class id.
# An emotion tag missing from label2id raises KeyError.
label = np.array(
    pd.read_csv(label_path)["emotion"]
    .map(lambda emo: label2id[emo])
    .values
)
print("load label done! total", len(label), "samples.")

# Vote

### Combination

In [None]:
def Cal_Combination(iter_index=0, Comb_List=[]):
    """Recursively enumerate every non-empty subset of the global ``PL_List``.

    Each element of ``PL_List`` from position ``iter_index`` onwards is first
    left out and then included; once the end of the list is reached, the
    accumulated subset is appended to the global ``Combination`` unless it is
    empty.

    Args:
        iter_index: position in ``PL_List`` being decided on this call.
        Comb_List: elements chosen so far (never mutated here; copies are
            passed down to the recursive calls).
    """
    if iter_index < len(PL_List):
        next_index = iter_index + 1
        # Branch 1: leave PL_List[iter_index] out.
        Cal_Combination(next_index, Comb_List[:])
        # Branch 2: take PL_List[iter_index].
        Cal_Combination(next_index, Comb_List[:] + [PL_List[iter_index]])
    elif Comb_List != []:
        # All positions decided: record the subset (the empty one is skipped).
        Combination.append(Comb_List)

# Every PL-ID, in the insertion order of PL_Log_Dict.
PL_List = list(PL_Log_Dict.keys())
# Filled by Cal_Combination with all non-empty subsets of PL_List.
Combination = []
Cal_Combination()
print("cal combination done! total", len(Combination), "combinations.")
# print(Combination)

### synchro mode

In [None]:
mode_decoder = {1:"Baseline", 2:"2-Cross", 3:"3-Cross", 4:"4-Cross", 5:"5-Cross"}
# Synchro mode: for every PL combination and every epoch, sum the same-epoch
# logits of all member PLs per sample and score the argmax against `label`.
# NOTE: the loop variable rebinds the global PL_List to the current combination.
for PL_List in tqdm(Combination):
    PL_nums = len(PL_List)
    mode = part_id2str[part_id]+"/"+"Analyze/Synchro/"+mode_decoder[PL_nums]
    combination_ID = "".join(PL_List)  # concat all PL-IDs, e.g. S1S2.
    # one logits directory per PL of this combination
    logits_root_path_list = ["./Logs/"+part_id2str[part_id]+"/"+PL_ID+"/Logits/"+PL_Log_Dict[PL_ID]+"/" for PL_ID in PL_List]
    # acc init
    max_acc = 0
    acc_result = []
    for epoch in range(1,epoch_nums+1):
        # number of correctly classified samples in this epoch
        acc_num = 0
        # load logits: one "<epoch>.csv" per PL, ordered by the "index" column
        logits_path = [root_path+str(epoch)+".csv" for root_path in logits_root_path_list]
        df_logits_list = [pd.read_csv(path, index_col="index")["logits"].sort_index() for path in logits_path]
        # analyze logits
        for index in range(len(label)):
            logits = np.zeros((4))
            for df_logits in df_logits_list:
                # Each cell is a string such as "[a,b,c,d]": strip the brackets and parse.
                # np.float64 is used because the np.float_ alias was removed in NumPy 2.0.
                logits += np.array(df_logits[index][1:-1].split(",")).astype(np.float64)
            prediction = np.argmax(logits)
            if label[index] == prediction:
                acc_num += 1
        acc = 100*acc_num/len(label)
        max_acc = max(max_acc, acc)
        acc_result.append(acc)
    # logging (exist_ok avoids the separate exists-then-create check)
    log_dir = "./Logs/"+mode+"/"
    os.makedirs(log_dir, exist_ok=True)
    df_acc_result = pd.DataFrame(acc_result, columns=["acc"])
    df_acc_result.to_csv(log_dir+combination_ID+"_acc.txt", index=False)
    # plot the per-epoch accuracy curve; the title carries the best accuracy
    plt.figure(num=combination_ID, figsize=(8,5), dpi=100)
    plt.title(mode_decoder[PL_nums]+" "+combination_ID+"("+str(np.round(max_acc, 2))+"%)", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(acc_result)
    plt.savefig(log_dir+combination_ID+".png")
    # plt.show()

### interflow mode

In [None]:
# PLsets
#     Each PL has one PLset in PLsets, N PLsets in total.
#     Each PLset holds several PLouts (M in total) plus a "minacc" field that
#     records the lowest acc among its PLouts, so a candidate can be judged quickly.
#     Each PLout is identified by its epoch_index and rated by its acc.
#     Layout:
#         PLsets: {
#             PLset0:{
#                 minacc:xx
#                 PLout0:{epoch_index:xx,acc:xx},
#                 PLout1:{epoch_index:xx,acc:xx},
#                 ...
#                 PLoutM:{epoch_index:xx,acc:xx}
#             },
#             ...
#             PLsetN:{
#                 minacc:xx
#                 PLout0:{epoch_index:xx,acc:xx},
#                 PLout1:{epoch_index:xx,acc:xx},
#                 ...
#                 PLoutM:{epoch_index:xx,acc:xx}
#             },
#         }
# A PLout is rated by its eval acc alone. When a better PLout arrives, it
# replaces the worst one.
PLsets = {}

# PLoutUpdate
#     A new PLout requires the interflow acc to be recalculated; PLoutUpdate is
#     the waiting queue for that work.
#     e.g.
#         PLSet0 has been updated with PLout10,
#         so "(0,10)" is pushed into PLoutUpdate, meaning PL0 has a newer PLout
#         whose epoch_index is 10, and the acc must be calculated for
#         PLSet0:PLout10 combined with the PLouts of all other PLs.
PLoutUpdate = []

# PLset_available_out
#     Records the epoch of every available PLout, grouped per PL.
#     This is all that is needed to derive the epoch combinations and then
#     calculate the acc.
#     e.g.
#         [
#             [PL0_epochxx,PL0_epochxx,PL0_epochxx,PL0_epochxx],  # the available PLouts' epoch_index of PL0
#             [PL1_epochxx],                                   # the available PLouts' epoch_index of PL1
#             ...
#         ]
PLset_available_out = []

# Epoch_Combination
#     The combinations of all PLs' available outs: one PLout is picked from each
#     PLset_available_out[PL_index].
#     e.g.
#         [
#             [PL0_epochxx, PL1_epochxx, PL2_epochxx, ... , PLN_epochxx],
#             [PL0_epochxx, PL1_epochxx, PL2_epochxx, ... , PLN_epochxx],
#             ...
#         ]
Epoch_Combination = []

# Scratch state shared with the accuracy worker threads.
acc_list = []
thread_acc = None

def Cal_Acc(update_PL_index, update_epoch, acc_list):
    """Score every epoch combination that contains a freshly updated PLout.

    The updated PL contributes only ``update_epoch``; every other PL in the
    global ``PL_List`` contributes all epochs currently stored in its PLset.
    For each resulting combination the per-sample logits of all PLs are summed,
    the argmax is compared with the global ``label``, and the accuracy (in
    percent) is appended to ``acc_list``.

    Args:
        update_PL_index: position in ``PL_List`` of the PL whose PLset changed.
        update_epoch: epoch index of the new PLout.
        acc_list: list mutated in place; one accuracy per combination is appended.

    Returns:
        None. Results are delivered through ``acc_list``.
    """
    # Available epochs per PL (local list; the module-level name is not touched).
    PLset_available_out = [[] for _ in range(len(PL_List))]
    for PL_index in range(len(PL_List)):
        if update_PL_index == PL_index:
            PLset_available_out[PL_index].append(update_epoch)
        else:
            # A PLset holds the "minacc" field plus PLouts under integer keys 1..k,
            # so PLout keys run from 1 to len(PLset)-1.
            for PLout_index in range(1, len(PLsets[PL_index])):
                PLset_available_out[PL_index].append(PLsets[PL_index][PLout_index]["epoch_index"])
    Epoch_Combination = Cal_Epoch_Combination(PLset_available_out, Epoch_Combination=[])
    # The logits directories do not depend on the combination: build them once.
    logits_root_path_list = ["./Logs/"+part_id2str[part_id]+"/"+PL_ID+"/Logits/"+PL_Log_Dict[PL_ID]+"/" for PL_ID in PL_List]
    # cal acc
    for PLout_comb in Epoch_Combination:
        acc_num = 0
        # one "<epoch>.csv" per PL, the epoch being taken from the combination
        logits_path = [
            logits_root_path_list[PL_index]+str(PLout_comb[PL_index])+".csv"
            for PL_index in range(len(PL_List))
        ]
        # load logits, ordered by the "index" column
        df_logits_list = [pd.read_csv(path, index_col="index")["logits"].sort_index() for path in logits_path]
        # analyze logits
        for index in range(len(label)):
            logits = np.zeros((4))
            for df_logits in df_logits_list:
                # Each cell is a string such as "[a,b,c,d]": strip the brackets and parse.
                # np.float64 is used because the np.float_ alias was removed in NumPy 2.0.
                logits += np.array(df_logits[index][1:-1].split(",")).astype(np.float64)
            prediction = np.argmax(logits)
            if label[index] == prediction:
                acc_num += 1
        acc = 100*acc_num/len(label)
        acc_list.append(acc)

def Update_PLSets(epoch, df_dev_acc_list):
    """Offer the dev accuracy of ``epoch`` to every PL's PLset.

    For each PL, the dev accuracy at ``epoch`` either fills a free slot of the
    PLset or, when the PLset already holds ``PL_Set_Size`` PLouts, replaces the
    lowest-accuracy PLout if that one is strictly worse. Every accepted PLout is
    queued in the global ``PLoutUpdate`` as ``(PL_index, epoch)``.

    Args:
        epoch: 1-based epoch index; row ``epoch-1`` of each dataframe is read.
        df_dev_acc_list: one dataframe per PL; column 0 is read as the dev accuracy.

    Returns:
        None. The globals ``PLsets`` and ``PLoutUpdate`` are mutated in place.
    """
    global PLsets
    global PLoutUpdate
    for PL_index, df_dev_acc in enumerate(df_dev_acc_list):
        dev_acc = df_dev_acc.iloc[epoch-1, 0]  # dataframe index starts at 0.
        PLset = PLsets[PL_index]
        # The PLset is not full yet (its length includes the "minacc" field): just push.
        if len(PLset) <= PL_Set_Size:
            PLset[len(PLset)] = {"epoch_index": epoch, "acc": dev_acc}
            PLset["minacc"] = min(PLset["minacc"], dev_acc)
            PLoutUpdate.append((PL_index, epoch))
        else:
            # Quick check against the cached minimum before scanning the PLset.
            if PLset["minacc"] <= dev_acc:
                # Find the worst PLout that is strictly worse than the candidate.
                obsolete_index = None
                worst_acc = dev_acc
                for PLout_index in range(1,PL_Set_Size+1):
                    if PLset[PLout_index]["acc"] < worst_acc:
                        worst_acc = PLset[PLout_index]["acc"]
                        obsolete_index = PLout_index
                # update
                if obsolete_index is not None:
                    PLset[obsolete_index]["epoch_index"] = epoch
                    PLset[obsolete_index]["acc"] = dev_acc
                    # Recompute the cached minimum from the PLouts that remain.
                    # Storing the evicted accuracy would leave "minacc" stale.
                    PLset["minacc"] = min(
                        PLset[PLout_index]["acc"] for PLout_index in range(1, PL_Set_Size+1)
                    )
                    PLoutUpdate.append((PL_index, epoch))

def Cal_Epoch_Combination(PLset_available_out, Epoch_Combination=None, iter_index=0, Comb_List=None):
    """Build every combination that picks one available epoch per PL.

    Recurses over the PLs of the global ``PL_List``; at depth ``iter_index`` each
    epoch in ``PLset_available_out[iter_index]`` is tried in turn.

    Args:
        PLset_available_out: list with one list of available epoch indices per PL.
        Epoch_Combination: accumulator the combinations are appended to. A fresh
            list is created when omitted, so results never leak between calls
            through a shared default.
        iter_index: position in ``PL_List`` handled by this call.
        Comb_List: epochs chosen so far (not mutated).

    Returns:
        The accumulator list, with one ``[epoch_PL0, epoch_PL1, ...]`` entry per
        combination. It is empty if any PL has no available epoch.
    """
    if Epoch_Combination is None:
        Epoch_Combination = []
    if Comb_List is None:
        Comb_List = []
    if iter_index >= len(PL_List):
        # One epoch has been chosen for every PL: record the combination.
        if Comb_List != []:
            Epoch_Combination.append(Comb_List)
        return Epoch_Combination
    for available_out in PLset_available_out[iter_index]:
        Cal_Epoch_Combination(
            PLset_available_out, Epoch_Combination, iter_index+1, Comb_List + [available_out]
        )
    return Epoch_Combination
        
def Update_Acc():
    """Drain ``PLoutUpdate`` and return the best accuracy among the new combinations.

    One ``Cal_Acc`` worker thread is started per queued ``(PL_index, epoch)``
    pair; all workers append their accuracies to a fresh global ``acc_list``.
    Every worker is joined before the results are read.

    Returns:
        The maximum accuracy found, or 0 when the queue was empty.
    """
    global PLsets
    global PLoutUpdate
    global acc_list
    global thread_acc
    acc_list = []
    workers = []
    # check PLoutUpdate to recalculate acc.
    while PLoutUpdate != []:
        (update_PL_index, update_epoch) = PLoutUpdate.pop()
        thread_acc = threading.Thread(target=Cal_Acc, args=(update_PL_index, update_epoch, acc_list))
        thread_acc.start()
        workers.append(thread_acc)
    # Wait for all workers, not only the last one started; otherwise acc_list
    # could be read while earlier workers are still appending. Joining an empty
    # list is also safe when nothing was queued.
    for worker in workers:
        worker.join()
    acc_max = 0
    for acc in acc_list:
        acc_max = max(acc_max, acc)
    return acc_max

In [None]:
# "5-Cross" was misspelled "5-Corss" here, which put 5-PL results in a directory
# (and plot title) inconsistent with the synchro-mode cell.
mode_decoder = {1:"Baseline", 2:"2-Cross", 3:"3-Cross", 4:"4-Cross", 5:"5-Cross"}
# Interflow mode: for every PL combination, walk through the epochs, keep the
# best PLouts per PL (by dev acc) and score the combinations of those PLouts.
# NOTE: the loop variable rebinds the global PL_List, which Cal_Acc and
# Cal_Epoch_Combination read.
for PL_List in tqdm(Combination):
    # init PL: one empty PLset (only the "minacc" field) per PL of this combination
    PL_nums = len(PL_List)
    for PL_index in range(PL_nums):
        PLsets[PL_index] = {"minacc":float("inf")}
    mode = part_id2str[part_id]+"/"+"Analyze/Interflow/"+mode_decoder[PL_nums]
    combination_ID = "".join(PL_List)  # concat all PL-IDs, e.g. S1S2.
    # dev_acc path
    dev_acc_path_list = ["./Logs/"+part_id2str[part_id]+"/"+PL_ID+"/"+PL_Log_Dict[PL_ID]+"_dev_acc.txt" for PL_ID in PL_List]
    # dev_acc: headerless files, one dataframe per PL
    df_dev_acc_list = [pd.read_csv(path, header=None) for path in dev_acc_path_list]
    # acc init
    max_acc = 0
    acc_result = []
    for epoch in range(1,epoch_nums+1):
        # update PLsets
        Update_PLSets(epoch, df_dev_acc_list)
        # update acc
        acc = Update_Acc()
        max_acc = max(max_acc, acc)
        # the running best is logged, so the curve never decreases
        acc_result.append(max_acc)
    # logging (exist_ok avoids the separate exists-then-create check)
    log_dir = "./Logs/"+mode+"/"
    os.makedirs(log_dir, exist_ok=True)
    df_acc_result = pd.DataFrame(acc_result, columns=["acc"])
    df_acc_result.to_csv(log_dir+combination_ID+"_acc.txt", index=False)
    # plot
    plt.figure(num=combination_ID, figsize=(8,5), dpi=100)
    plt.title(mode_decoder[PL_nums]+" "+combination_ID+"("+str(np.round(max_acc, 2))+"%)", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(acc_result)
    plt.savefig(log_dir+combination_ID+".png")
    plt.show()