In [1]:
import os
import pickle

import torch

import tensorflow_datasets as tfds
import tensorflow as tf

from flair.data import Sentence
from flair.embeddings import BertEmbeddings, DocumentPoolEmbeddings
from segtok.segmenter import split_single

# Embedding Model

In [2]:
# Load ALBERT (albert-base-v2) through flair's BertEmbeddings and wrap it in
# DocumentPoolEmbeddings, which pools the token embeddings into a single vector
# per Sentence (the sanity-check cell below prints a 3072-dim result).
albert = BertEmbeddings(bert_model_or_path="albert-base-v2")

albert_embedding = DocumentPoolEmbeddings([albert])

In [3]:
# Sanity check: embed one sentence and inspect the pooled document vector.
sent = Sentence("Berlin and Munich are nice cities .")
albert_embedding.embed(sent)

# get_embedding() returns a torch tensor; the printed output shows it lives on
# the GPU and still carries autograd history (grad_fn).
embedd_result = sent.get_embedding()
print(embedd_result.shape)
print(embedd_result)

torch.Size([3072])
tensor([-0.6863, -0.5820,  1.0685,  ...,  0.7118,  0.6721,  0.5402],
       device='cuda:0', grad_fn=<CatBackward>)


# Dataset

In [4]:
# Load CNN/DailyMail without naming a split, so tfds returns all splits keyed by
# split name. The log below shows the default "plain_text" config was used and the
# data was read from the local tensorflow_datasets cache.
cnn_dailymail = tfds.load(name="cnn_dailymail")

INFO:absl:No config specified, defaulting to first: cnn_dailymail/plain_text
INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset cnn_dailymail (/home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0)
INFO:absl:Constructing tf.data.Dataset for split None, from /home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0


In [5]:
# Pull the three standard splits out of the loaded dataset dict.
train_tds, test_tds, val_tds = (
    cnn_dailymail[split_name] for split_name in ("train", "test", "validation")
)

## Prepare Dataset
- For faster training, we clean the data, compute the ALBERT-base embedding of every sentence in each article, and save the results to files, so that this work does not have to be repeated during training.

In [6]:
def normalize_text(text):
    """Lower-case a TensorFlow string tensor and unwrap a single-quoted span.

    The pattern ``'(.*)'`` is greedy, so for each match the text between the first
    and the last single quote (RE2's ``.`` does not cross newlines, so per line) is
    kept while those two outer quote characters are dropped. Apostrophes between
    them are left in place. Text without a pair of single quotes is only lower-cased.
    """
    lowered = tf.strings.lower(text)
    return tf.strings.regex_replace(lowered, "'(.*)'", r"\1")


def map_func(features):
    """Normalize one dataset example and return (article, highlights) as Python str.

    ``features`` is indexed with the "article" and "highlights" keys; each value is
    passed through ``normalize_text`` and then decoded from UTF-8 bytes. ``.numpy()``
    is called on the tensors, so this must run eagerly (not inside ``tf.function``).
    """
    return tuple(
        normalize_text(features[field]).numpy().decode('UTF-8')
        for field in ("article", "highlights")
    )
        

In [7]:
def get_embedding_of_article(article, i, max_chars=750):
    """Embed each sentence of ``article`` with ``albert_embedding``.

    Args:
        article: Article text; it is split into sentences with segtok's
            ``split_single``.
        i: Index of the article within its split. It is currently unused and
            kept only so existing callers keep working.
        max_chars: Sentences longer than this many characters are truncated
            before embedding, to keep the input to the model bounded.

    Returns:
        A list with one CPU numpy array per embedded sentence, in article order.
        Segments that are empty, whitespace-only, or a single character are skipped.
    """
    list_embedding = []
    for sentence in split_single(article):
        # Strip first so whitespace-only segments cannot slip past the length
        # check and produce a Sentence with no tokens.
        sentence = sentence.strip()[:max_chars]
        if len(sentence) <= 1:
            continue

        sent = Sentence(sentence)
        # Inference only: without no_grad the pooled embedding keeps an autograd
        # graph (grad_fn) alive on the GPU for every sentence.
        with torch.no_grad():
            albert_embedding.embed(sent)
            embedding = sent.get_embedding()
        list_embedding.append(embedding.detach().cpu().numpy())
    return list_embedding

In [8]:
def embedd_ds(ds):
    """Turn every example of ``ds`` into a dict of sentence embeddings plus raw text.

    Each record has the keys "article" (list of per-sentence embeddings from
    ``get_embedding_of_article``), "article_text" and "highlights" (normalized
    strings from ``map_func``). All records are collected in one in-memory list.
    """
    records = []
    for index, example in enumerate(ds):
        article_text, highlights_text = map_func(example)
        sentence_vectors = get_embedding_of_article(article_text, index)
        records.append({
            "article": sentence_vectors,
            "article_text": article_text,
            "highlights": highlights_text,
        })
    return records

In [9]:
def save_file(data, filename, protocol=pickle.HIGHEST_PROTOCOL):
    """Pickle ``data`` to ``filename``, overwriting any existing file.

    Uses the highest available pickle protocol by default: protocols below 4
    (the default before Python 3.8) cannot serialize objects larger than 4 GiB,
    which a fully embedded dataset split can exceed.
    """
    with open(filename, "wb") as f:
        pickle.dump(data, f, protocol=protocol)


def load_file(filename):
    """Load and return the object pickled in ``filename``.

    Only load files this notebook produced: unpickling untrusted data can execute
    arbitrary code.
    """
    with open(filename, "rb") as f:
        return pickle.load(f)

In [10]:
# Embed each split once and cache it as a pickle in the working directory.
# Re-running this cell reuses existing files; delete them to force recomputation.
# NOTE(review): embedd_ds keeps a whole split in memory before it is pickled, which
# may be very large for the train split; consider sharding if memory runs out.
def _embed_and_cache(tds, filename):
    """Return the embedded split, loading it from ``filename`` if it was already computed."""
    if os.path.exists(filename):
        return load_file(filename)
    embedded = embedd_ds(tds)
    save_file(embedded, filename)
    return embedded


train_ds = _embed_and_cache(train_tds, "train.pkl")
test_ds = _embed_and_cache(test_tds, "test.pkl")
val_ds = _embed_and_cache(val_tds, "val.pkl")