In [37]:
import project_path

%load_ext autoreload
%autoreload 2

import math
import numpy as np
import pandas as pd
from pandas import DataFrame
from tqdm.auto import tqdm

import nltk
nltk.download('punkt')

from datasets import load_from_disk, Dataset
from sentence_transformers import SentenceTransformer

from src.text_split import extract_paragraphs, split_long_paragraphs, collapse_paragraphs_iteratively
from src.paths import datap

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[nltk_data] Downloading package punkt to /home/justy/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [38]:
tqdm.pandas()

In [39]:
path = "/home/justy/workspace/ea/ea-forum-announcements/data/labeled_posts"
data = load_from_disk(path).to_pandas()
data.shape

(343, 19)

### Split by paragraph

In [40]:
max_n_words = 400
data['paragraphs'] = data.body.progress_map(extract_paragraphs)
data['paragraphs'] = data.paragraphs.progress_map(lambda p: split_long_paragraphs(p, max_n_words=max_n_words))
data = data[~data.apply(lambda x: x.paragraphs.empty, axis=1)]
data['paragraphs_split'] = data.paragraphs.progress_map(lambda x: collapse_paragraphs_iteratively(x, max_n_words=max_n_words))

  0%|          | 0/343 [00:00<?, ?it/s]

  0%|          | 0/343 [00:00<?, ?it/s]

  0%|          | 0/336 [00:00<?, ?it/s]

In [41]:
paragraph_split = pd.concat([
    DataFrame({"postId": r._id, "text": r.paragraphs_split.text.values, "label": r.label})
    for pid,r in data.iterrows()
], ignore_index=True)
paragraph_split.shape

(2136, 3)

## Settings

In [42]:
embeddings = "e5-base" # other option "ea-forum"
aggregation = "" # other option "mean" if you want to have a mean paragraph embedding per post

### Embed the posts

In [43]:
# e5 base embeddings

import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

In [44]:
# EA forum embeddings
if embeddings == "ea-forum":
    model_path = "/home/justy/workspace/ea/ea-forum-announcements/models/sbert_v1/sbert:v1"
    model = SentenceTransformer.load(model_path)
    paragraphs_embeddings = model.encode(paragraph_split.text[:1], batch_size=1, show_progress_bar=True)

In [None]:
# e5 base embeddings
# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
if embeddings == "e5-base":
    paragraph_split.text = paragraph_split.text.apply(lambda x: "query: " + x)
    tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base')
    model = AutoModel.from_pretrained('intfloat/e5-base')
    batch_size = 1
    n_batches = math.ceil(len(paragraph_split) / batch_size)
    paragraphs_embeddings = np.zeros((n_batches, 768))
    for idx in tqdm(range(n_batches)):
        idx1, idx2 = idx * batch_size, (idx + 1) * batch_size
        idx2 = min(len(paragraph_split), idx2)
        batch_dict = tokenizer(paragraph_split.text[idx1:idx2].tolist(), max_length=400, padding=True, truncation=True, return_tensors='pt')
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        paragraphs_embeddings[idx, :] = F.normalize(embeddings, p=2, dim=1).detach()

  0%|          | 0/2136 [00:00<?, ?it/s]

In [None]:
paragraph_split = pd.concat((paragraph_split, pd.DataFrame(paragraphs_embeddings)), axis=1)

In [None]:
paragraph_split.shape

In [None]:
if aggregation == "mean":
    dataset = paragraph_split.groupby("postId").mean(numeric_only=True)
dataset.shape

In [None]:
dataset.label = dataset.label.astype("int8")

In [None]:
dataset = dataset.sample(frac=1)

In [None]:
dataset.to_csv(datap("labeled_paragraphs_embedded_e5base.csv"))

### No preprocessing

In [19]:
dataset = data[["body", "label"]]
dataset.head()

Unnamed: 0,body,label
0,\n\nOur programs exist to have a positive impa...,2.0
1,\n\n## The most important century is the one w...,2.0
2,\n\nMeet us at the Karbach Biergarten for an I...,2.0
3,"\n\nDisclaimer: We (Sam Nolan, Hannah Rokebran...",2.0
4,"\n\nAt Founders Pledge, we just launched a new...",0.0


In [13]:
dataset = dataset.rename({"body": "text"}, axis=1)

In [15]:
dataset = Dataset.from_pandas(dataset).train_test_split(test_size=0.5)

In [16]:
dataset['train'].save_to_disk(datap("all_paragraphs_labeled_only/train"))

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

In [17]:
dataset['test'].save_to_disk(datap("all_paragraphs_labeled_only/test"))

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

### Choose only the first paragraph

In [82]:
first_paragraphs = paragraph_split.drop_duplicates(subset="postId", keep="first")
first_paragraphs.head()

Unnamed: 0,postId,text,label
0,70,Meet us at the Karbach Biergarten for an Intro...,2.0
1,182,"At Founders Pledge, we just launched a new add...",0.0
4,255,"Today we are both launching our organization, ...",0.0
5,274,Summary\nWe’re excited to announce VIVID - a n...,0.0
11,317,We are pleased to introduce Cause Innovation B...,1.0


In [83]:
first_paragraphs.label.value_counts()

0.0    41
1.0    32
2.0    22
Name: label, dtype: int64

In [84]:
first_paragraphs.loc[:, "label"] = first_paragraphs.label.astype("int8")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  first_paragraphs.loc[:, "label"] = first_paragraphs.label.astype("int8")
  first_paragraphs.loc[:, "label"] = first_paragraphs.label.astype("int8")


In [85]:
first_paragraphs.shape

(95, 3)

In [86]:
first_paragraphs = Dataset.from_pandas(first_paragraphs).train_test_split(test_size=0.5)

In [87]:
first_paragraphs

DatasetDict({
    train: Dataset({
        features: ['postId', 'text', 'label', '__index_level_0__'],
        num_rows: 47
    })
    test: Dataset({
        features: ['postId', 'text', 'label', '__index_level_0__'],
        num_rows: 48
    })
})

In [88]:
first_paragraphs['train'].save_to_disk(datap("first_paragraphs_labeled_only/train"))

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

In [89]:
first_paragraphs['test'].save_to_disk(datap("first_paragraphs_labeled_only/test"))

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

In [67]:
# Dataset.from_pandas(first_paragraphs).save_to_disk(datap("first_paragraphs_labeled_only"))

In [10]:
# pars_encoded = model.encode(par_split_df.text, show_progress_bar=True)

In [None]:
# DataFrame(pars_encoded).groupby(par_split_df.postId.values).mean().to_csv(datap("posts_encoded.csv"))