In [1]:
%pip install pyriksdagen
from lxml import etree
import progressbar
from pyparlaclarin.read import paragraph_iterator, speeches_with_name
from pyriksdagen.utils import protocol_iterators, download_corpus
import pyriksdagen
# We need a parser for reading in XML data
parser = etree.XMLParser(remove_blank_text=True)
import pandas as pd
from sklearn.model_selection import train_test_split

Note: you may need to restart the kernel to use updated packages.


In [2]:
print(pyriksdagen.__spec__)
download_corpus(partitions=["politicians", "records"])

ModuleSpec(name='pyriksdagen', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7473401ad5d0>, origin='/home/laurinemeier/anaconda3/lib/python3.11/site-packages/pyriksdagen/__init__.py', submodule_search_locations=['/home/laurinemeier/anaconda3/lib/python3.11/site-packages/pyriksdagen'])


politicians.zip: 100%|██████████| 2.20M/2.20M [00:00<00:00, 28.9MiB/s]




records.zip: 100%|██████████| 1.60G/1.60G [00:20<00:00, 84.8MiB/s]


In [2]:
protocols = list(protocol_iterators(corpus_root="data/", start=1867, end=202122))

In [4]:
def oppna_data_to_dict(input_dict):
    """
    Load protocols with the new XML / HTML structure (from 2013 onwards)
    and convert it to a python dict with contents.
    """
    data = {}
    data["paragraphs"] = []

    # Metadata
    session = input_dict["dokumentstatus"]["dokument"]["rm"]
    session = session.replace("/", "")
    pid = input_dict["dokumentstatus"]["dokument"]["nummer"]
    date = input_dict["dokumentstatus"]["dokument"]["datum"]
    html = input_dict["dokumentstatus"]["dokument"]["html"]
    html_tree = clean_html(html)
    year = int(date.split("-")[0])
    protocol_id = f"prot-{session}--{pid}"

    data["protocol_id"] = protocol_id
    data["date"] = date.split(" ")[0]
    data["session"] = session

    # New HTML structure with div[@class='Section1']
    section1 = html_tree.xpath(".//div[@class='Section1']")
    for elements in section1:
        for elem in elements:
            if elem.tag in ["p", "h1", "h2"]:
                elemtext = "".join(elem.itertext())
                linebreak = elemtext.strip() == "" and "\n" in elemtext
                if linebreak:
                    pass
                else:
                    paragraph = elemtext.strip()
                    paragraph = paragraph.replace("\n", " ")
                    paragraph = re.sub("\\s+", " ", paragraph)
                    data["paragraphs"].append(paragraph)

    if len(data["paragraphs"]) == 0:
        tree = html_tree

        # Old data structure 1990-2003
        pres = tree.findall(".//pre")
        if len(pres) > 0:
            for pre in pres:
                if pre.text is not None:
                    tblocks = re.sub("([a-zß-ÿ,])- ?\n ?([a-zß-ÿ])", "\\1\\2", pre.text)
                    tblocks = re.sub("([a-zß-ÿ,]) ?\n ?([a-zß-ÿ])", "\\1 \\2", tblocks)
                    for paragraph in tblocks.split("\n"):
                        paragraph = paragraph.replace("\n", " ")
                        paragraph = paragraph.replace("\n", " ")
                        data["paragraphs"].append(paragraph)

        # Standard HTML structure, roughly 2003-2013
        elif len(tree.xpath("//div[@class='indrag']")) > 0:
            tree = tree.xpath("//body")[0]
            for elem in tree:
                elemtext = "".join(elem.itertext())
                linebreak = elemtext.strip() == "" and "\n" in elemtext
                if elem.tag == "br" or linebreak:
                    pass
                else:
                    paragraph = elemtext.strip()
                    paragraph = paragraph.replace("\n", " ")
                    paragraph = re.sub("\\s+", " ", paragraph)
                    data["paragraphs"].append(paragraph)
    return data


In [8]:
protocol_in_question = protocols[12]
root = etree.parse(protocol_in_question, parser).getroot()

In [13]:

for elem in list(paragraph_iterator(root, output="lxml"))[:7]:
  print(" ".join(elem.itertext()))



          RIKSDAGENS Ar PROTOKOLL
        

          1955 ANDRA KAMMAREN Nr 13
        

          13—15 april
        

          Debatter m. m.
        

          Onsdagen den 13 april Sid.
        

          Familjerådgivning «... 5 Interpellation av herr Ericsson : i
          Näs ang. de minskade perioderna
        


In [3]:
data=[]

for i in range(len(protocols)):
  protocol_in_question = protocols[i]
  root = etree.parse(protocol_in_question, parser).getroot()
  element_str=""
  for elem in list(paragraph_iterator(root, output="lxml")):
    element_str += " ".join(elem.itertext()).replace("\n","")
    
  data.append({"protocole": i,"texte": "".join(element_str.split())})

df=pd.DataFrame(data)
print(df)


       protocole                                              texte
0              0  Sedan,ikraftafRiketsRegeringsform,lagtimaRiksd...
1              1  24Den19Januari.Lördagenden19Januari,KL!/,11f.m...
2              2  Den21Januari.25mandeaftidenfördessaval,erhålli...
3              3  Den22Januari.41Tisdagenden22Januari.Kl.10£m.g1...
4              4  Den23Januari.-55Onsdagenden23Januari.Kl.10f.m....
...          ...                                                ...
17637      17637  §1JusteringavprotokollProtokolletförden7juniju...
17638      17638  §1JusteringavprotokollProtokolletförden8juniju...
17639      17639  §1JusteringavprotokollProtokollenförden9,10,13...
17640      17640  §1AnmälanomåtertagandeavplatsiriksdagenTalmann...
17641      17641  §1AnmälanomsubsidiaritetsprövningarTalmannenan...

[17642 rows x 2 columns]


In [4]:
df.to_pickle("swerick_data_long.pkl")

In [21]:
from datasets import load_dataset
swerick_dataset = load_dataset("pandas",data_files="swerick_data.pkl")

Generating train split: 0 examples [00:00, ? examples/s]

In [22]:
print(swerick_dataset)

DatasetDict({
    train: Dataset({
        features: ['protocole', 'texte'],
        num_rows: 130
    })
})


In [5]:
df_train,df_test = train_test_split(df,test_size=0.2,random_state=42)
df_train.to_pickle("swerick_data_long_train.pkl")
df_test.to_pickle("swerick_data_long_test.pkl")

In [6]:
def tokenize_function(examples,tokenizer):
    result = tokenizer(examples["texte"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result

In [7]:
def group_texts(examples,chunk_size):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

In [8]:
model_checkpoint = "KBLab/bert-base-swedish-cased"

In [10]:
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
import torch
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from transformers import TrainingArguments
from transformers import Trainer
import math
from torch.utils.data import DataLoader
from transformers import default_data_collator
from torch.optim import AdamW
from accelerate import Accelerator
from transformers import get_scheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [21]:
tokenizer.model_max_length

1000000000000000019884624838656

In [13]:
#datasest
data_files = {"train": "swerick_data_train.pkl", "test": "swerick_data_test.pkl"}
swerick_dataset = load_dataset("pandas",data_files=data_files)
print(swerick_dataset)


DatasetDict({
    train: Dataset({
        features: ['protocole', 'texte', '__index_level_0__'],
        num_rows: 104
    })
    test: Dataset({
        features: ['protocole', 'texte', '__index_level_0__'],
        num_rows: 26
    })
})


In [18]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)#add the MASK term

In [15]:
tokenized_datasets = swerick_dataset.map(
      lambda examples: tokenize_function(examples, tokenizer), batched=True, remove_columns=["texte", "protocole",'__index_level_0__']
)
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids'],
        num_rows: 104
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids'],
        num_rows: 26
    })
})

In [17]:
chunk_size =128
lm_datasets = tokenized_datasets.map( lambda examples: group_texts(examples,chunk_size), batched=True) #dataset with chunk
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 57211
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 12561
    })
})

In [22]:
lm_datasets = lm_datasets.remove_columns(["word_ids","token_type_ids"])
batch_size = 64
train_dataloader = DataLoader(
    lm_datasets["train"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)
train_dataloader = [
    inputs.to("cpu") for inputs in train_dataloader
]

In [23]:
len(train_dataloader)

894