In [37]:
import collections
import glob
import os
import pickle
import warnings
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd
from tqdm import tqdm

In [38]:
# Task-specific data locations, relative to root_path.
task_data_path = 'data/processed_tasks/c2_muse_topic'
transcription_path = task_data_path + '/transcription_segments'
# Machine-specific absolute project root. The trailing '/' matters: the
# helpers below build paths by plain string concatenation with it.
root_path = 'C:/Users/Dell/Desktop/Imarticus-learning/Capstone/DL-capstone/NLP/graphbasedTM/'

In [39]:
def get_partition(root: Optional[str] = None,
                  task_path: Optional[str] = None) -> collections.defaultdict:
    """
    get_partition maps each partition name to the sample IDs it contains

    Only sample IDs that have a label file under
    ``<root><task_path>/label_segments/arousal/`` are kept. Any label type
    would do; arousal is simply used to enumerate the available samples.

    :param root: project root prefix; defaults to the module-level
        ``root_path``. It is concatenated directly, so it must end with '/'.
    :param task_path: task data directory relative to ``root``; defaults to
        the module-level ``task_data_path``.
    :return: defaultdict(set) mapping each value of the ``Proposal`` column of
        ``<root>data/processed_tasks/metadata/partition.csv`` to the set of
        integer sample IDs assigned to it
    """
    if root is None:
        root = root_path
    if task_path is None:
        task_path = task_data_path

    label_glob = root + task_path + '/label_segments/arousal/*.csv'
    # basename() copes with the mixed '/' and '\\' separators glob can return
    # on Windows; the file stem is the integer sample ID.
    sample_ids = {int(os.path.basename(name).split('.')[0])
                  for name in glob.glob(label_glob)}

    df = pd.read_csv(root + "data/processed_tasks/metadata/partition.csv", delimiter=",")

    partition_to_id = collections.defaultdict(set)
    for raw_id, partition in zip(df["Id"], df["Proposal"]):
        sample_id = int(raw_id)
        # Skip samples listed in the metadata that have no label file.
        if sample_id in sample_ids:
            partition_to_id[partition].add(sample_id)

    return partition_to_id




In [40]:
def read_classification_classes(label_file):
    """
    read_classification_classes loads the class_id column of a label file

    :param label_file: path to a comma-separated file with a ``class_id`` column
    :return: the ``class_id`` values in file order, as a plain Python list
    """
    return pd.read_csv(label_file, sep=",", usecols=['class_id'])['class_id'].tolist()




In [41]:
def sort_trans_files(elem):
    """
    sort_trans_files computes the sort key for a transcription segment file

    The key is the integer after the last '_' and before the first '.' of the
    file name, e.g. ``.../12/12_3.csv`` -> 3. Only the base name is parsed, so
    underscores or dots in directory names cannot leak into the key, and a
    file name without '_' (``.../7/5.csv``) yields its leading number.

    :param elem: a file name or path
    :return: file weight used in sorting
    """
    file_name = os.path.basename(elem)
    return int(file_name.split('_')[-1].split('.')[0])


In [42]:
def _read_segment_text(file: str) -> str:
    """
    _read_segment_text joins the ``word`` column of one segment CSV with spaces

    :param file: path to a transcription segment CSV with a ``word`` column
    :return: the segment's words joined by single spaces
    """
    # dtype=str + keep_default_na=False keep every token verbatim. By default
    # pandas turns words such as "None", "NA", "null" or "nan" into NaN, and an
    # all-numeric column into ints; either makes " ".join raise TypeError.
    df = pd.read_csv(file, delimiter=',', dtype=str, keep_default_na=False)
    return " ".join(df['word'].tolist())


def prepare_data() -> dict:
    """
    prepare_data creates a dict of segment-level transcripts and their topic labels

    Partitions come from get_partition(). Within a partition, samples are
    processed in ascending ID order and each sample's segment files are ordered
    by sort_trans_files. Texts and labels are concatenated in that order and are
    expected to line up one-to-one.

    :return: dict mapping each partition name to
        ``{'text': [segment strings], 'labels': [class ids]}``
    """
    # Reading transcriptions on SEGMENT-level, sep. in train, develop, test of the official challenge
    partition_to_id = get_partition()

    data = {}

    # training with test labels available
    for partition in tqdm(partition_to_id.keys()):
        text = []
        y = []

        for sample_id in tqdm(sorted(partition_to_id[partition])):
            sample_dir = root_path + transcription_path + '/' + str(sample_id)
            segment_files = sorted(glob.glob(sample_dir + '/*.csv'), key=sort_trans_files)
            segment_texts = [_read_segment_text(file) for file in segment_files]

            # extracting labels available
            label_file = root_path + task_data_path + '/label_segments/topic/' + str(sample_id) + ".csv"
            labels = read_classification_classes(label_file)

            if len(segment_texts) != len(labels):
                # Texts and labels are flattened per partition, so a mismatch
                # here shifts the pairing for every later sample as well.
                warnings.warn(
                    f"sample {sample_id}: {len(segment_texts)} transcript segments "
                    f"but {len(labels)} labels; texts and labels may be misaligned"
                )

            text.extend(segment_texts)
            y.extend(labels)

        data[partition] = {'text': text, 'labels': y}

    return data





In [43]:
def zip_data(text: list, labels: list) -> Tuple[tuple, tuple]:
    """
    zip_data drops very short segments and keeps their labels aligned

    A segment is kept only if it contains more than two purely alphabetic
    whitespace-separated tokens.

    :param text: segment strings
    :param labels: labels aligned with ``text`` (items beyond the shorter of
        the two sequences are ignored, as with ``zip``)
    :return:
        - the kept segments, as a tuple
        - their labels, as a tuple
        Both are empty tuples when no segment is kept.
    """
    kept = [(segment, label) for segment, label in zip(text, labels)
            if sum(1 for word in segment.split() if word.isalpha()) > 2]

    # zip(*[]) yields nothing, so unpacking it into two names would raise.
    if not kept:
        return (), ()

    new_data, new_data_labels = zip(*kept)
    return new_data, new_data_labels


In [None]:
_MUSE_TRAIN_TEXT = "data/train_text.pickle"
_MUSE_TRAIN_LABEL = "data/train_label.pickle"
_MUSE_TEST_TEXT = "data/test_text.pickle"
_MUSE_TEST_LABEL = "data/test_label.pickle"


def _read_non_blank_lines(path: str) -> list:
    """
    _read_non_blank_lines returns the lines of a text file without newlines

    Lines consisting only of "\\n" are skipped.

    :param path: text file to read
    :return: list of line strings
    """
    with open(path) as f:
        return [line.replace("\n", "") for line in f if line != "\n"]


def _load_pickle(path: str):
    """
    _load_pickle loads one locally generated cache file

    pickle can execute arbitrary code; only use this on caches written by
    _dump_pickle, never on files from an untrusted source.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _dump_pickle(obj, path: str) -> None:
    """_dump_pickle writes obj to the cache file path."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _load_muse(get_test_data: bool) -> Tuple[list, list, list, list]:
    """
    _load_muse returns raw MuSe texts and labels, using the pickle cache when complete

    The cache is used only if every pickle needed for this call exists;
    otherwise the data are rebuilt with prepare_data() and re-cached.

    :param get_test_data: whether the test partition is needed
    :return: train texts and labels (train + devel partitions), then test
        texts and labels (None, None when get_test_data is False)
    """
    needed = [_MUSE_TRAIN_TEXT, _MUSE_TRAIN_LABEL]
    if get_test_data:
        needed += [_MUSE_TEST_TEXT, _MUSE_TEST_LABEL]

    if all(Path(p).is_file() for p in needed):
        train_text = _load_pickle(_MUSE_TRAIN_TEXT)
        train_label = _load_pickle(_MUSE_TRAIN_LABEL)
        if get_test_data:
            return train_text, train_label, _load_pickle(_MUSE_TEST_TEXT), _load_pickle(_MUSE_TEST_LABEL)
        return train_text, train_label, None, None

    data_twain = prepare_data()

    # The devel partition is folded into training.
    train_text = data_twain['train']['text'] + data_twain['devel']['text']
    train_label = data_twain['train']['labels'] + data_twain['devel']['labels']
    _dump_pickle(train_text, _MUSE_TRAIN_TEXT)
    _dump_pickle(train_label, _MUSE_TRAIN_LABEL)

    test_text, test_label = None, None
    # Cache the test split whenever it was built, so a later call with
    # get_test_data=True does not find the train cache but no test cache.
    if get_test_data or 'test' in data_twain:
        test_text = data_twain['test']['text']
        test_label = data_twain['test']['labels']
        _dump_pickle(test_text, _MUSE_TEST_TEXT)
        _dump_pickle(test_label, _MUSE_TEST_LABEL)

    if get_test_data:
        return train_text, train_label, test_text, test_label
    return train_text, train_label, None, None


def get_data(data_set: str, get_test_data: bool) \
        -> Tuple[list, list, list, list]:
    """
    get_data collects the training data and, optionally, the test data

    :param data_set: name of the data set, either "MUSE" or "CRR"
    :param get_test_data: for MUSE, whether to also load and filter the test
        partition; CRR always returns its test split
    :raises ValueError: if data_set is neither "MUSE" nor "CRR"
    :return:
        - training data - for MUSE, the train and devel partitions combined
          and filtered by zip_data
        - training labels - aligned with the training data (-1 for CRR)
        - testing data - None for MUSE when get_test_data is False
        - testing labels - aligned with the testing data (-1 for CRR)
    """
    if data_set == "CRR":
        # Citysearch Restaurant Reviews corpus: unlabelled, so every label is -1.
        test_data = _read_non_blank_lines("data/restaurant/test.txt")
        data = _read_non_blank_lines("data/restaurant/train.txt")
        return data, [-1] * len(data), test_data, [-1] * len(test_data)

    if data_set != "MUSE":
        raise ValueError(f"unknown data_set {data_set!r}; expected 'MUSE' or 'CRR'")

    train_text, train_label, test_text, test_label = _load_muse(get_test_data)

    X_train, Y_train = zip_data(train_text, train_label)

    if get_test_data:
        X_test, Y_test = zip_data(test_text, test_label)
    else:
        X_test, Y_test = (None, None)

    return X_train, Y_train, X_test, Y_test

