In [15]:
import json
import os
import re
from collections import Counter

import matplotlib.pyplot as plt
import torch
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import BertForTokenClassification, BertTokenizer, AdamW
from wordcloud import WordCloud

In [6]:
# Root folder of the dataset; AnnotationDataset expects an 'Annotations' folder inside it.
# Defaults to the Kaggle input mount. Set the GUTBRAINIE_DATA_DIR environment variable
# to point the notebook at a local copy instead of editing this cell.
DATA_DIR = os.environ.get(
    "GUTBRAINIE_DATA_DIR",
    "/kaggle/input/d/andreaschrter/gutbrainie2025/gutbrainie2025",
)

In [43]:
class AnnotationDataset(Dataset):
    """Dataset of annotated articles loaded from JSON annotation files.

    Files are read from ``<root_path>/Annotations/<split>/...``. Every sample is an
    ``(article_id, annotations)`` tuple built from the top-level items of a JSON file.

    Args:
        root_path: Directory that contains the ``Annotations`` folder.
        tokenizer: Optional tokenizer exposing ``tokenize(text)``. It is only used by
            ``plot_abstract_lengths``; whitespace splitting is used when it is omitted.
        split: ``'Train'`` or ``'Dev'``.
        quality_filter: Quality sub-folders read for the ``'Train'`` split (ignored for
            ``'Dev'``). The default leaves out bronze quality since it contains
            autogenerated annotations. The list is only iterated, never mutated.

    Raises:
        ValueError: If ``split`` is neither ``'Train'`` nor ``'Dev'``.
        FileNotFoundError: If the ``json_format`` folder of the ``'Dev'`` split is missing.
    """

    def __init__(self, root_path, tokenizer=None, split='Train', quality_filter=['platinum_quality', 'gold_quality', 'silver_quality']):
        # Validate first so that a bad split fails with the intended message instead of
        # an unrelated error while building the path.
        if split not in ('Train', 'Dev'):
            raise ValueError("Specify a split, must be either 'Train' or 'Dev'!")

        self.samples = []
        self.tokenizer = tokenizer
        annotations_dir = os.path.join(root_path, 'Annotations', split)

        if split == 'Train':
            for quality in quality_filter:
                json_format_dir = os.path.join(annotations_dir, quality, 'json_format')
                if not os.path.exists(json_format_dir):
                    # A missing quality folder is reported but does not abort loading.
                    print(f"No folder {json_format_dir} was found!")
                    continue
                self.samples.extend(self._load_json_dir(json_format_dir))
        else:  # 'Dev'
            json_format_dir = os.path.join(annotations_dir, 'json_format')
            if not os.path.exists(json_format_dir):
                raise FileNotFoundError(f"No folder {json_format_dir} was found!")
            self.samples.extend(self._load_json_dir(json_format_dir))

    @staticmethod
    def _load_json_dir(json_format_dir):
        """Return the ``(article_id, annotations)`` pairs of every ``.json`` file in a folder.

        File names are sorted because ``os.listdir`` returns entries in arbitrary order;
        sorting keeps the sample order reproducible across runs and filesystems.
        """
        pairs = []
        for file_name in sorted(os.listdir(json_format_dir)):
            if not file_name.endswith('.json'):
                continue
            file_path = os.path.join(json_format_dir, file_name)
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # JSON keys are strings, so this is a lexicographic sort of the article
            # identifiers within one file (not a numeric one).
            pairs.extend(sorted(data.items(), key=lambda item: item[0]))
        return pairs

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one data point: the ``(article_id, annotations)`` tuple at ``idx``."""
        return self.samples[idx]

    def _abstract_length(self, abstract):
        """Count the tokens of one abstract with the given tokenizer, else by whitespace."""
        if self.tokenizer:
            return len(self.tokenizer.tokenize(abstract))
        # white space tokenization (just as an overview, baselines use NLTK tokenizer)
        return len(abstract.split())

    def plot_abstract_lengths(self):
        """
        Plots the distribution of abstract lengths, counted in tokenizer tokens when a
        tokenizer was given and in whitespace-separated words otherwise.

        Samples without metadata or without an abstract count as length 0. Nothing is
        plotted (a message is printed instead) when the dataset holds no samples.
        """
        tokenizer_type = "BERT Tokenized" if self.tokenizer else "Whitespace Tokenized"

        abstract_lengths = []
        for _, data in self.samples:
            # `or` fallbacks also cover keys that are present but hold None.
            metadata = data.get('metadata') or {}
            abstract = metadata.get('abstract') or ''
            abstract_lengths.append(self._abstract_length(abstract))

        if not abstract_lengths:
            print("No abstracts to plot: the dataset contains no samples.")
            return

        print("Maximum number of tokens per abstract: ", max(abstract_lengths))
        plt.figure(figsize=(8, 4))
        plt.hist(abstract_lengths, bins=30, color='#E6E6FA', edgecolor='#D1C8E3')
        plt.title(f"Distribution of Abstract Lengths ({tokenizer_type})", fontsize=14, fontweight='bold')
        plt.xlabel("Token Count" if self.tokenizer else "Word Count", fontsize=12, fontweight='medium')
        plt.ylabel("Frequency", fontsize=12, fontweight='medium')
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tick_params(axis='both', which='major', labelsize=12, length=6, width=1.2, direction='in', grid_alpha=0.5)

        plt.tight_layout()
        plt.show()

    def get_text_data(self):
        """
        Extracts title and abstract text from the dataset.

        Returns:
            A ``(titles, abstracts)`` tuple of two strings: all non-empty titles joined
            by single spaces, and all non-empty abstracts joined by single spaces.
        """
        all_titles = []
        all_abstracts = []

        for _, data in self.samples:
            metadata = data.get('metadata') or {}
            title = metadata.get('title')
            abstract = metadata.get('abstract')
            if title:
                all_titles.append(title)
            if abstract:
                all_abstracts.append(abstract)

        return " ".join(all_titles), " ".join(all_abstracts)

    def build_vocab(self):  # important for vocabulary coverage check
        """
        Tokenizes the dataset and builds a vocabulary.

        Returns:
            A ``Counter`` mapping each lower-cased word found in the titles and
            abstracts to its number of occurrences.
        """
        all_titles, all_abstracts = self.get_text_data()  # get the raw text

        # tokenize text (based on whitespace and punctuation)
        word_pattern = r'\b\w+\b'
        words = re.findall(word_pattern, all_titles.lower()) + re.findall(word_pattern, all_abstracts.lower())

        return Counter(words)  # count all word occurrences

In [None]:
def create_label_wordcloud(dataset, entity_type, background_color='white', max_words=200, width=7200, height=4800):
    """
    Create, save and show a word cloud of all entity spans of one entity type.

    Args:
        dataset: Object with a ``samples`` list of ``(article_id, annotations)`` tuples;
            entity annotations are read from ``annotations['entities']``, each with a
            ``'label'`` and a ``'text_span'``.
        entity_type: Entity label to select (compared case-insensitively).
        background_color: Background colour of the word cloud.
        max_words: Maximum number of distinct spans drawn.
        width: Width of the word cloud canvas in pixels.
        height: Height of the word cloud canvas in pixels.

    The ``width``/``height`` defaults equal the 12x8 inch figure at 600 dpi, i.e. the
    canvas size this function always rendered before these arguments were honoured.
    The figure is written to ``word_cloud_<entity_type>`` in the working directory.

    Raises:
        ValueError: If the dataset has no span with the requested entity type.
    """
    wanted_label = entity_type.lower()

    # Count how often each span text occurs for the requested entity type.
    freqs = Counter(
        ent['text_span']
        for _, data in dataset.samples
        for ent in data.get('entities', [])
        if ent['label'].lower() == wanted_label
    )
    if not freqs:
        # Fail with an actionable message instead of an opaque error from wordcloud.
        raise ValueError(f"No entity spans with label '{entity_type}' found in the dataset.")

    wc = WordCloud(background_color=background_color, collocations=False, colormap='RdPu_r',
                   max_words=max_words, width=width, height=height).generate_from_frequencies(freqs)

    plt.figure(figsize=(12, 8), dpi=600)
    plt.imshow(wc, interpolation='bilinear')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"word_cloud_{entity_type}", dpi=600)
    plt.show()

# Load the development split and render the word cloud of its 'human' entity spans.
# NOTE(review): `tokenizer` is not defined in any cell visible here; it presumably comes
# from another cell - confirm this cell still runs after Restart Kernel -> Run All.
dev_dataset = AnnotationDataset(
    root_path=DATA_DIR,
    tokenizer=tokenizer,
    split="Dev",
)
create_label_wordcloud(dataset=dev_dataset, entity_type='human')