In [None]:
from pathlib import Path
from typing import List

import pandas as pd
from convokit import Corpus, download
from sklearn.model_selection import train_test_split

In [2]:
# Output location for the train/test parquet splits.
# NOTE(review): the file names spell "coursediscourse" while the directory is
# "coarsediscourse" -- presumably a typo, left unchanged because renaming would
# orphan any files already written under these names.
datasets_path = Path("./datasets/coarsediscourse")
train_path = datasets_path.joinpath("coursediscourse_train.parquet")
test_path = datasets_path.joinpath("coursediscourse_test.parquet")

In [None]:
# Fetch the Reddit Coarse Discourse corpus via convokit, then load it from the returned path.
corpus_path = download("reddit-coarse-discourse-corpus")
corpus = Corpus(filename=corpus_path)

In [None]:
def _clean_text(raw_text: str) -> str:
    """Flatten an utterance's text into a single whitespace-normalized line.

    The text is split on newlines and every line with fewer than two
    whitespace-separated tokens is dropped. The surviving tokens are joined
    with single spaces, so tabs, repeated spaces and carriage returns all
    collapse to one space. ``None`` or empty text yields ``""``.

    NOTE(review): dropping one-token lines also empties one-word utterances
    (e.g. "Thanks!"), which still keep their label -- confirm this is intended.
    """
    if not raw_text:
        return ""
    kept_tokens: List[str] = []
    for line in raw_text.split("\n"):
        line_tokens = line.split()
        if len(line_tokens) > 1:
            kept_tokens.extend(line_tokens)
    return " ".join(kept_tokens)


def _normalize_label(raw_label) -> str:
    """Map a raw ``majority_type`` value onto the label vocabulary used here.

    A missing label (``None``) becomes ``"other"`` and ``"negativereaction"`` is
    rewritten as ``"negative reaction"``; every other value passes through.
    """
    if raw_label is None:
        return "other"
    if raw_label == "negativereaction":
        return "negative reaction"
    return raw_label


# One entry per conversation: the cleaned text of each utterance, and its label.
sentences_list: List[List[str]] = []
labels_list: List[List[str]] = []  # each inner list holds one str label per utterance

for conversation in corpus.iter_conversations():
    sentences = []
    labels = []
    for utterance in conversation.iter_utterances():
        sentences.append(_clean_text(utterance.text))
        # .get() yields None when the key is absent; _normalize_label maps that to "other".
        labels.append(_normalize_label(utterance.meta.get("majority_type")))

    assert len(sentences) == len(labels), "Number of labels and sentences do not match"
    sentences_list.append(sentences)
    labels_list.append(labels)

assert len(sentences_list) == len(labels_list), "Number of labels and sentences do not match"

# create dataframe: one row per conversation, with list-valued columns
df = pd.DataFrame({"sentences": sentences_list, "labels": labels_list})

# unique labels -- sorted so the printed output does not depend on set/hash ordering
unique_labels = set().union(*labels_list)
print(sorted(unique_labels))

In [None]:
# Preview the first five conversation rows (five is head()'s default).
df.head(n=5)

In [None]:
# Split the conversation-level frame (one row per conversation) into train/test sets.
# Named constants make the 60/40 split and the seed easy to find and tune.
TEST_SIZE = 0.4
RANDOM_STATE = 42

train_df, test_df = train_test_split(df, test_size=TEST_SIZE, random_state=RANDOM_STATE)
assert len(train_df) + len(test_df) == len(df), "train/test split lost or duplicated rows"

In [None]:
# Save both splits. to_parquet does not create missing parent directories,
# so make sure the output directory exists first (a fresh checkout has none).
for split_df, split_path in ((train_df, train_path), (test_df, test_path)):
    split_path.parent.mkdir(parents=True, exist_ok=True)
    split_df.to_parquet(split_path)

In [None]:
# Reload both splits from disk, so the cells that follow work from the persisted files.
# NOTE(review): depending on the parquet engine, list-valued columns may come back
# as numpy arrays rather than Python lists -- confirm downstream code handles that.
train_df, test_df = (pd.read_parquet(split_path) for split_path in (train_path, test_path))