In [14]:
from pathlib import Path
import re
import xml.etree.ElementTree as ET
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
The blurbs are labelled with eight categories: Literatur & Unterhaltung, Sachbuch, Kinderbuch & Jugendbuch, Ratgeber, Ganzheitliches Bewusstsein, Glaube & Ethik, Architektur & Garten, and Künste.

In [15]:
def fix_xml(entry):
    """Repair a raw ``<book>`` snippet so ElementTree can parse it.

    Args:
        entry (str): Raw XML text of a single ``<book>...</book>`` element.

    Returns:
        str: The snippet with stray ampersands escaped and every ``%0A``
            sequence removed.
    """
    # Escape every "&" that does not start a reference XML itself understands:
    # the five predefined entities and numeric character references (&#228; / &#xE4;).
    # Any other "&name;" (e.g. an HTML entity like "&nbsp;") is undefined in XML and
    # would make ET.fromstring raise, so it is escaped and kept as literal text.
    entry = re.sub(r'&(?!(?:amp|lt|gt|quot|apos|#[0-9]+|#x[0-9a-fA-F]+);)', '&amp;', entry)
    # Remove percent-encoded newlines anywhere in the snippet
    # (presumably they only occur inside URLs).
    entry = entry.replace('%0A', '')
    return entry

def _first_topic(root):
    """Return the text of the first <topic> child of the first <category>, or None."""
    first_category = root.find('.//category')
    if first_category is None:
        return None
    first_topic = first_category.find('topic')
    if first_topic is None:
        return None
    return first_topic.text


def parse_books(file_path):
    """Parse every ``<book>`` record in a blurbs file into a list of dicts.

    The file is not parsed as one XML document. Instead, each
    ``<book>...</book>`` snippet is extracted with a regex, repaired with
    ``fix_xml`` and parsed on its own. One malformed record therefore does not
    prevent the others from loading.

    Args:
        file_path (str | pathlib.Path): Path to a UTF-8 encoded blurbs file.

    Returns:
        list[dict]: One dict per successfully parsed book, with keys
            ``'text'`` (title and body separated by a blank line) and
            ``'category'`` (text of the first ``<topic>`` of the first
            ``<category>``). Records that fail to parse, or that have no
            topic text, are reported with ``print`` and skipped.
    """
    with open(file_path, encoding='utf-8') as f:
        text = f.read()

    # `\b` stops the pattern from also matching a longer tag name such as `<books>`.
    book_entries = re.findall(r'<book\b.*?</book>', text, re.DOTALL)
    books = []
    for entry in book_entries:
        try:
            root = ET.fromstring(fix_xml(entry))
        except ET.ParseError as e:
            print(f"Error parsing entry: {e}")
            continue

        category = _first_topic(root)
        if category is None:
            # Previously an uncaught IndexError here aborted the whole file.
            print(f"Skipping entry without a category topic: {root.findtext('title')!r}")
            continue

        # findtext returns None for a missing element; avoid a literal "None" in the text.
        title = root.findtext('title') or ''
        body = root.findtext('body') or ''
        books.append({
            'text': f"{title}\n\n{body}",
            'category': category,
        })
    return books

In [16]:
def stratified_sample(df, category_col, n_total, min_per_category, random_state=42):
    """Draw a sample that guarantees a minimum number of rows per category.

    First, up to ``min_per_category`` rows are drawn from every category (all of
    its rows if the category is smaller). The remaining budget up to ``n_total``
    is then filled with a uniform random sample of the rows not yet chosen, so
    no row appears twice.

    Args:
        df (pd.DataFrame): Input data; its index labels are assumed to be unique.
        category_col (str): Column to stratify on.
        n_total (int): Target number of rows in the result.
        min_per_category (int): Minimum rows to take from each category.
        random_state (int): Seed passed to every ``DataFrame.sample`` call.

    Returns:
        pd.DataFrame: The sampled rows with a fresh ``RangeIndex``. If ``df``
            has fewer than ``n_total`` rows, all rows are returned. If
            ``n_total`` is smaller than the guaranteed per-category total, the
            guaranteed rows are truncated to ``n_total`` (see note below).
    """
    # Guaranteed part: up to `min_per_category` rows from every category.
    sampled_indices = []
    for _, group in df.groupby(category_col):
        n = min(len(group), min_per_category)
        sampled_indices.extend(group.sample(n=n, random_state=random_state).index.tolist())

    remaining_n = n_total - len(sampled_indices)
    if remaining_n > 0:
        # Fill the rest from rows not drawn yet, so no row is selected twice.
        remaining_df = df.drop(index=sampled_indices)
        # `sample` raises ValueError when n exceeds the available rows, so cap the request.
        remaining_n = min(remaining_n, len(remaining_df))
        additional_indices = remaining_df.sample(n=remaining_n, random_state=random_state).index.tolist()
        all_indices = sampled_indices + additional_indices
    else:
        # NOTE: the budget is smaller than the guaranteed rows. Truncation follows
        # groupby order (sorted categories), so trailing categories may be dropped.
        all_indices = sampled_indices[:n_total]

    return df.loc[all_indices].reset_index(drop=True)

In [17]:
# Development split: one record per parsed book (text + category).
dev_data = Path("data") / "blurbs_dev.txt"
dev_books = parse_books(dev_data)
dev_df = pd.DataFrame(data=dev_books)

In [18]:
# Test split: one record per parsed book (text + category).
test_data = Path("data") / "blurbs_test.txt"
test_books = parse_books(test_data)
test_df = pd.DataFrame(data=test_books)

In [19]:
# Merge the dev split into the test split. Rebuilding from `test_books`, rather
# than concatenating onto `test_df` itself, keeps this cell idempotent: re-running
# it does not append the dev rows a second time.
test_df = pd.concat([pd.DataFrame(test_books), dev_df], ignore_index=True)
test_df['category'].value_counts()

category
Literatur & Unterhaltung      3200
Sachbuch                       895
Kinderbuch & Jugendbuch        820
Ratgeber                       694
Ganzheitliches Bewusstsein     285
Glaube & Ethik                 232
Architektur & Garten            60
Künste                          50
Name: count, dtype: int64

In [20]:
# Draw 1700 rows, taking up to 50 from every category before filling the rest at random.
test_df_sampled = stratified_sample(
    df=test_df,
    category_col='category',
    n_total=1700,
    min_per_category=50,
)
test_df_sampled['category'].value_counts()

category
Literatur & Unterhaltung      726
Sachbuch                      237
Kinderbuch & Jugendbuch       234
Ratgeber                      194
Ganzheitliches Bewusstsein    107
Glaube & Ethik                101
Architektur & Garten           51
Künste                         50
Name: count, dtype: int64

In [21]:
# Down-sample the dominant category "Literatur & Unterhaltung" to at most 250 rows;
# all other categories are kept exactly as sampled above.
MAJORITY_CATEGORY = 'Literatur & Unterhaltung'
TEST_MAJORITY_CAP = 250
is_majority = test_df_sampled['category'] == MAJORITY_CATEGORY
test_df_sampled2 = test_df_sampled[~is_majority]
# `sample` raises ValueError when n exceeds the available rows, so cap the request.
test_df_sampled3 = test_df_sampled[is_majority].sample(
    n=min(TEST_MAJORITY_CAP, int(is_majority.sum())), random_state=42
)
test_df_sampled_final = pd.concat([test_df_sampled2, test_df_sampled3], ignore_index=True)

In [22]:
# Train split: one record per parsed book (text + category).
train_data = Path("data") / "blurbs_train.txt"
train_books = parse_books(train_data)
train_df = pd.DataFrame(data=train_books)
train_df['category'].value_counts()

category
Literatur & Unterhaltung      7622
Sachbuch                      1999
Kinderbuch & Jugendbuch       1897
Ratgeber                      1630
Ganzheitliches Bewusstsein     638
Glaube & Ethik                 502
Künste                         133
Architektur & Garten           127
Name: count, dtype: int64

In [23]:
test_df.shape[0] + train_df.shape[0]

20784

In [24]:
# Draw 4000 rows, taking up to 120 from every category before filling the rest at random.
train_df_sampled = stratified_sample(
    df=train_df,
    category_col='category',
    n_total=4000,
    min_per_category=120,
)
train_df_sampled['category'].value_counts()

category
Literatur & Unterhaltung      1777
Kinderbuch & Jugendbuch        537
Sachbuch                       531
Ratgeber                       459
Ganzheitliches Bewusstsein     244
Glaube & Ethik                 207
Künste                         124
Architektur & Garten           121
Name: count, dtype: int64

In [25]:
# Down-sample the dominant category "Literatur & Unterhaltung" to at most 550 rows;
# all other categories are kept exactly as sampled above.
MAJORITY_CATEGORY = 'Literatur & Unterhaltung'
TRAIN_MAJORITY_CAP = 550
is_majority = train_df_sampled['category'] == MAJORITY_CATEGORY
train_df_sampled2 = train_df_sampled[~is_majority]
# `sample` raises ValueError when n exceeds the available rows, so cap the request.
train_df_sampled3 = train_df_sampled[is_majority].sample(
    n=min(TRAIN_MAJORITY_CAP, int(is_majority.sum())), random_state=42
)
train_df_sampled_final = pd.concat([train_df_sampled2, train_df_sampled3], ignore_index=True)

In [26]:
def plot_column_distribution(dataframe, column_name):
    """
    Plot the row count of each distinct value in a column, most frequent first.

    Args:
        dataframe (pd.DataFrame): The DataFrame containing the data.
        column_name (str): The name of the column to plot.
    """
    # `sns.set` is only a legacy alias of `set_theme`. Like the alias, this call
    # changes the global style for every later figure too.
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.countplot(
        data=dataframe,
        x=column_name,
        # Order the bars by descending frequency.
        order=dataframe[column_name].value_counts().index,
        ax=ax,
    )
    ax.set(title=f"Distribution of {column_name}", xlabel=column_name, ylabel="Count")
    # Vertical labels keep long category names from overlapping.
    ax.tick_params(axis="x", labelrotation=90)
    fig.tight_layout()
    plt.show()

In [None]:
plot_column_distribution(dataframe=train_df_sampled_final, column_name="category")

In [None]:
plot_column_distribution(dataframe=test_df_sampled_final, column_name="category")

In [27]:
test_df_sampled_final.shape[0]

1224

In [None]:
# Save the final samples as Parquet files. Train and test are deliberately swapped:
# we do not need many train examples for few-shot training. The sample drawn from
# the original train split is written as our test set, and the sample drawn from the
# original test (+dev) splits is written as our train set.
for frame, parquet_output_path in (
    (train_df_sampled_final, Path("data/germeval_test.parquet")),
    (test_df_sampled_final, Path("data/germeval_train.parquet")),
):
    frame.to_parquet(parquet_output_path, index=False)