In [3]:
import pandas as pd
import boto3
import json


from tqdm import tqdm
# Register tqdm's pandas integration so Series/DataFrame gain
# `progress_apply`, which the endpoint loop further down uses.
tqdm.pandas()

In [4]:
dataset_path = "test_v0.7.1.csv"
# Raw test split. The file carries a leftover index column, which pandas
# surfaces as "Unnamed: 0" (visible in the previews below).
df = pd.read_csv(filepath_or_buffer=dataset_path)

In [5]:
class Embeddings:
    """Client for a SageMaker endpoint that returns backbone embeddings.

    Each excerpt is wrapped in a one-row DataFrame carrying the request
    options, serialized with ``to_json(orient="split")`` and posted to the
    endpoint with content type ``application/json; format=pandas-split``.
    """

    def __init__(self, endpoint_name="main-model-cpu", region_name="us-east-1"):
        """Create the SageMaker runtime client.

        Args:
            endpoint_name: SageMaker endpoint to invoke. Defaults to the
                endpoint this notebook was written against.
            region_name: AWS region hosting the endpoint.
        """
        self.endpoint_name = endpoint_name
        self.sg_client = boto3.session.Session().client(
            "sagemaker-runtime", region_name=region_name
        )

    def create_df(self, entry):
        """Build the JSON request body for a single excerpt.

        Args:
            entry: Excerpt text to embed.

        Returns:
            str: A one-row frame (excerpt plus request options) serialized in
            pandas "split" orientation.
        """
        df = pd.DataFrame({"excerpt": entry, "index": [0]})
        # NOTE(review): "analyis" is misspelled in the value and key below, but
        # these strings are sent verbatim to the endpoint. Confirm what the
        # endpoint expects before correcting them.
        df["return_type"] = "default_analyis"
        df["analyis_framework_id"] = "all"

        df["interpretability"] = False
        df["ratio_interpreted_labels"] = 0.5
        df["return_prediction_labels"] = False

        # Judging by their names, these options request mean-pooled backbone
        # embeddings returned as plain lists. TODO confirm against the endpoint.
        df["output_backbone_embeddings"] = True
        df["pooling_type"] = "['mean_pooling']"
        df["finetuned_task"] = "['first_level_tags', 'subpillars']"
        df["embeddings_return_type"] = "list"

        return df.to_json(orient="split")

    def invoke_endpoint(self, backbone_inputs_json):
        """Send a request body to the endpoint and return the first embedding.

        Args:
            backbone_inputs_json: JSON string, as produced by ``create_df``.

        Returns:
            The first element of the response's ``"output_backbone"`` list.
            This is the embedding for the single row that ``create_df`` builds.
        """
        response = self.sg_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=backbone_inputs_json,
            ContentType="application/json; format=pandas-split"
        )
        # JSON is UTF-8 encoded. Decoding as ASCII would raise on any
        # non-ASCII byte in the response (e.g. accented text).
        embeddings = json.loads(response["Body"].read().decode("utf-8"))
        return embeddings["output_backbone"][0]

    def get_embeddings(self, excerpt):
        """Return the endpoint's embedding for one excerpt (one request per call)."""
        df_json = self.create_df(excerpt)
        embeddings = self.invoke_endpoint(df_json)
        return embeddings

In [6]:
# One shared endpoint client, reused for every excerpt below.
# NOTE: despite its name, this is the `Embeddings` client object, not the
# embedding vectors themselves.
embeddings = Embeddings()

In [9]:
#df["embeddings"] = df.excerpt.apply(embeddings.get_embeddings)

In [11]:
# Persist the full test set together with its embeddings. `df` only has an
# "embeddings" column if the (currently commented-out) full-dataset embedding
# cell above has been run. Otherwise this would write the raw data under a
# misleading name and overwrite any previously computed embeddings file.
if "embeddings" in df.columns:
    df.to_csv("data_with_embeddings.csv")
else:
    print("Skipping data_with_embeddings.csv: df has no 'embeddings' column.")

In [15]:
# Sanity check on one row's embedding length (768 in an earlier run). Guarded
# because `df` has no "embeddings" column unless the full-dataset embedding
# cell was run; the bare attribute access would otherwise raise AttributeError.
len(df["embeddings"].loc[3614]) if "embeddings" in df.columns else None

768

In [7]:
# Number of entries per project, most frequent first.
counts = df.loc[:, "project_id"].value_counts()

In [8]:
# Keep projects with more than 1000 entries. The per-project sample of 1000
# rows further down needs every kept project to have at least that many.
valid_project_ids = counts.loc[counts.gt(1000)].index

In [9]:
# Show the project ids that passed the size threshold. The ids display as
# floats, presumably because some rows have a missing project_id; verify.
valid_project_ids

Float64Index([2028.0, 2170.0, 2098.0, 2099.0, 2225.0, 2311.0], dtype='float64')

In [10]:
# Restrict the test set to rows belonging to the large projects selected above.
df_new = df.loc[
    df["project_id"].isin(valid_project_ids)
]

In [11]:
# Row count after filtering to large projects.
df_new.shape[0]

10419

In [12]:
# Preview the first five filtered rows.
df_new.iloc[:5]

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,affected_groups_level_1,affected_groups_level_2,affected_groups_level_3,source_type,url,website,lang,translation_en,translation_fr,translation_es
1,489433,Primary and secondary net enrollment rates are...,1306,67488.0,2225.0,False,['Education'],['Humanitarian Conditions->Living Standards'],['Context->Socio Cultural'],['République démocratique du Congo'],...,['Affected'],[],[],website,https://blogs.worldbank.org/education/free-pri...,blogs.worldbank.org,en,,Les taux de scolarisation nets primaires et se...,Las tasas de inscripción netas primarias y sec...
4,489430,"Like few other countries globally, the majorit...",1306,67488.0,2225.0,False,[],[],['Context->Socio Cultural'],['République démocratique du Congo'],...,['Affected'],[],[],website,https://blogs.worldbank.org/education/free-pri...,blogs.worldbank.org,en,,"Comme peu d'autres pays à l'échelle mondiale, ...",Al igual que algunos otros países a nivel mund...
5,489438,And the policy is already having significant i...,1306,67488.0,2225.0,False,['Education'],"['Impact->Impact On Systems, Services And Netw...",['Context->Socio Cultural'],['République démocratique du Congo'],...,['Affected'],[],[],website,https://blogs.worldbank.org/education/free-pri...,blogs.worldbank.org,en,,Et la politique a déjà des impacts significati...,Y la política ya está teniendo impactos signif...
6,187423,"[9thNov 2020,Nigeria]With the latest update, N...",1306,43202.0,2170.0,False,[],[],['Covid-19->Cases'],['Nigeria'],...,[],[],[],website,https://www.premiumtimesng.com/news/headlines/...,www.premiumtimesng.com,en,,"[9thnov 2020, Nigeria] Avec la dernière mise à...","[9thnov 2020, Nigeria] Con la última actualiza..."
7,489437,"But almost 2 years on, it is safe to say that ...",1306,67488.0,2225.0,False,['Education'],"['Capacities & Response->National Response', '...",['Context->Socio Cultural'],['République démocratique du Congo'],...,['Affected'],[],[],website,https://blogs.worldbank.org/education/free-pri...,blogs.worldbank.org,en,,"Mais près de 2 ans, il est prudent de dire que...","Pero casi 2 años, es seguro decir que la RDC p..."


In [13]:
# Draw 1000 rows from each selected project. `random_state` pins the draw so
# re-running the notebook yields the same test subset. The endpoint pass below
# takes ~48 minutes, so an unseeded sample is both irreproducible and costly
# to regenerate. Any fixed seed works.
sampled_df_per_project = df_new.groupby("project_id").sample(n=1000, random_state=42)

In [14]:
# Total sampled rows: 1000 per selected project.
sampled_df_per_project.shape[0]

6000

In [15]:
# Peek at the first two sampled rows.
sampled_df_per_project.iloc[:2]

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,affected_groups_level_1,affected_groups_level_2,affected_groups_level_3,source_type,url,website,lang,translation_en,translation_fr,translation_es
2559,220648,"[9th Dec, Overall Syria] It is understood that...",1306,45655.0,2028.0,False,[],[],"['Covid-19->Deaths', 'Covid-19->Cases']",[],...,[],[],[],website,https://reliefweb.int/sites/reliefweb.int/file...,reliefweb.int,en,,"[9ème décembre, la Syrie globale] Il est enten...","[9 de diciembre, Siria general] Se entiende qu..."
5542,492314,"[August 2021, Overall Syria] In light of the d...",1306,67914.0,2028.0,False,"['Cross', 'Livelihoods']","['Impact->Impact On Systems, Services And Netw...",[],"['Hama', 'Syrian Arab Republic', 'Damascus', '...",...,"['Affected', 'Affected', 'Affected', 'Affected...","['Displaced', 'Non Displaced', 'Non Displaced'...","['None', 'Host', 'IDP']",website,https://fscluster.org/sites/default/files/docu...,fscluster.org,en,,"[Août 2021, Syrie globale] À la lumière de la ...","[De agosto de 2021, Siria general] A la luz de..."


In [16]:
# One endpoint request per excerpt, with a tqdm progress bar (slow: ~48 min
# for 6000 rows in the recorded run).
sampled_df_per_project["embeddings"] = sampled_df_per_project["excerpt"].progress_apply(
    embeddings.get_embeddings
)

100%|██████████| 6000/6000 [48:13<00:00,  2.07it/s]    


In [17]:
# Persist the sampled test set with embeddings. The list-valued "embeddings"
# column is written as text, so readers must parse it back
# (e.g. with ast.literal_eval).
sampled_df_per_project.to_csv(path_or_buf="sampled_data_with_embeddings_testset.csv")
