In [252]:
import json
import boto3
import pandas as pd

from tqdm import tqdm
from ast import literal_eval
from collections import defaultdict
from postprocess_cpu_model_outputs import convert_current_dict_to_previous_one, get_predictions_all

tqdm.pandas()

In [253]:
# Ground-truth label columns of the dataset; they are read from the CSV as
# stringified Python lists and parsed back into lists further down.
# NOTE(review): the variable name is misspelled ("cateogory"); it is kept
# because a later cell refers to it by this exact name.
cateogory_names = [
    "sectors",
    "subpillars_2d",
    "subpillars_1d",
    "geo_location",
    "specific_needs_groups",
    "severity",
    "age",
    "gender",
    "reliability",
    "affected_groups_level_0",
    "affected_groups_level_1",
    "affected_groups_level_2",
    "affected_groups_level_3",
]

In [254]:
# Load the evaluation data. The path is relative, so the CSV must sit in the
# notebook's working directory.
# NOTE(review): presumably the v0.7.1 test split, going by the file name.
dataset_path = "test_v0.7.1.csv"
df = pd.read_csv(dataset_path)

In [255]:
def rename_keys(d):
    """Append ``"_pred"`` to every key of ``d`` in place and return ``d``.

    The renamed mapping is built in full before ``d`` is touched, so a key
    that already ends in ``"_pred"`` cannot be overwritten by the renamed
    version of another key (e.g. ``"a"`` and ``"a_pred"`` both survive, as
    ``"a_pred"`` and ``"a_pred_pred"``).

    Args:
        d: Mutable mapping (``dict`` / ``defaultdict``) to rename in place.

    Returns:
        The same object ``d``, with keys suffixed and insertion order kept.

    Note:
        The function is not idempotent: calling it again on the result adds
        a second ``"_pred"`` suffix.
    """
    renamed = {f"{k}_pred": v for k, v in d.items()}
    # Swap the contents instead of rebinding so callers holding a reference
    # to ``d`` (and the mapping's own type) are preserved.
    d.clear()
    d.update(renamed)
    return d

In [256]:
class Predictions:
    """Small client that sends one excerpt at a time to a model endpoint.

    The excerpt is wrapped in a single-row DataFrame, serialised as
    pandas "split"-oriented JSON and posted through the
    ``sagemaker-runtime`` client; the decoded JSON response is returned.
    """

    def __init__(self, endpoint_name="main-model-cpu", region_name="us-east-1"):
        """Create the runtime client.

        Args:
            endpoint_name: Name of the endpoint to invoke.
            region_name: AWS region passed to the boto3 client.

        The boto3 session is created without arguments, so credentials are
        resolved by boto3's default mechanism.
        """
        self.endpoint_name = endpoint_name
        self.sg_client = boto3.session.Session().client(
            "sagemaker-runtime", region_name=region_name
        )

    def create_df(self, entry):
        """Build the JSON request body for a single excerpt.

        Args:
            entry: The excerpt text to send.

        Returns:
            A JSON string (``orient="split"``) describing a one-row frame
            with the excerpt and the fixed request options set below.
        """
        df = pd.DataFrame({"excerpt": entry, "index": [0]})
        # NOTE(review): "analyis" is misspelled in the next two values/keys.
        # They are runtime strings sent to the endpoint and are kept as-is;
        # presumably the endpoint expects these exact names - confirm before
        # correcting the spelling.
        df["return_type"] = "default_analyis"
        df["analyis_framework_id"] = "all"

        df["interpretability"] = False
        df["ratio_interpreted_labels"] = 0.5
        df["return_prediction_labels"] = True

        df["output_backbone_embeddings"] = False
        return df.to_json(orient="split")

    def invoke_endpoint(self, backbone_inputs_json):
        """Post ``backbone_inputs_json`` to the endpoint and decode the reply.

        Args:
            backbone_inputs_json: Request body, as produced by ``create_df``.

        Returns:
            The response body parsed from JSON.
        """
        response = self.sg_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=backbone_inputs_json,
            ContentType="application/json; format=pandas-split"
        )
        # Decode as UTF-8 (a superset of ASCII): an ASCII decode raises
        # UnicodeDecodeError as soon as the body holds one non-ASCII byte.
        response_json = json.loads(response["Body"].read().decode("utf-8"))
        return response_json

    def get_predictions(self, excerpt):
        """Return the endpoint's decoded response for one excerpt."""
        df_json = self.create_df(excerpt)
        predictions = self.invoke_endpoint(df_json)
        return predictions

In [257]:
# Instantiating the client opens a boto3 session and a "sagemaker-runtime"
# client, so AWS credentials must be resolvable when this cell runs.
predictions = Predictions()

In [258]:
# Number of rows per project_id, sorted from most to least frequent.
counts = df["project_id"].value_counts()

In [261]:
# Keep only the projects that have more than this many rows in the dataset.
min_rows_per_project = 1500
has_enough_rows = counts > min_rows_per_project
valid_project_ids = counts[has_enough_rows].index

In [262]:
# Display the project ids that passed the row-count filter.
valid_project_ids

Float64Index([2028.0, 2170.0, 2098.0], dtype='float64')

In [263]:
# Restrict the dataset to rows belonging to the retained projects.
df_new = df[df["project_id"].isin(valid_project_ids)]

In [264]:
# Sanity check: number of rows left after the project filter.
len(df_new)

6291

In [265]:
# Preview the filtered rows; the original row labels from df are still in place here.
df_new.head()

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,affected_groups_level_1,affected_groups_level_2,affected_groups_level_3,source_type,url,website,lang,translation_en,translation_fr,translation_es
6,187423,"[9thNov 2020,Nigeria]With the latest update, N...",1306,43202.0,2170.0,False,[],[],['Covid-19->Cases'],['Nigeria'],...,[],[],[],website,https://www.premiumtimesng.com/news/headlines/...,www.premiumtimesng.com,en,,"[9thnov 2020, Nigeria] Avec la dernière mise à...","[9thnov 2020, Nigeria] Con la última actualiza..."
13,187425,"[9thNov 2020,Nigeria]The 94 new cases were rep...",1306,43202.0,2170.0,False,[],[],['Covid-19->Cases'],"['Edo', 'Kwara', 'Federal Capital Territory', ...",...,[],[],[],website,https://www.premiumtimesng.com/news/headlines/...,www.premiumtimesng.com,en,,"[9thnov 2020, Nigeria] Les 94 nouveaux cas ont...","[9thnov 2020, Nigeria] Los 94 nuevos casos se ..."
19,187426,"[5thNov 2020,Nigeria]Abuja, Nigeria’s capital,...",1306,43202.0,2170.0,False,[],[],"['Covid-19->Cases', 'Covid-19->Testing']",['Nigeria'],...,[],[],[],website,https://www.premiumtimesng.com/news/headlines/...,www.premiumtimesng.com,en,,"[5THNOV 2020, Nigeria] Abuja, capitale du Nigé...","[5thnov 2020, Nigeria] Abuja, la capital de Ni..."
40,187463,"[10th Nov 2020,NorthEast Nigeria] The North-Ea...",1306,43198.0,2170.0,False,"['Shelter', 'Cross']",['Impact->Impact On People'],['Displacement->Type/Numbers/Movements'],"['Adamawa', 'Nigeria', 'Borno', 'Yobe']",...,[],[],[],website,https://reliefweb.int/report/nigeria/strengthe...,www.reliefweb.int,en,,"[10 novembre 2020, nord-est du Nigéria] Les Ét...","[10 de noviembre de 2020, Northeast Nigeria] L..."
51,243576,"[January 9, Syria] The areas in northeastern S...",1306,47112.0,2028.0,False,['Health'],['Priority Interventions->Expressed By Populat...,[],['Syrian Arab Republic'],...,[],[],[],website,https://english.enabbaladi.net/archives/2021/0...,english.enabbaladi.net,en,,"[9 janvier, la Syrie] Les régions du nord-est ...","[9 de enero, Siria] Las áreas en el noreste de..."


In [267]:
# Draw 50 random rows from every retained project and renumber them 0..n-1.
# The sample is seeded so that a fresh "Restart & Run All" selects the same
# excerpts; outputs recorded from an earlier, unseeded run will show
# different rows.
sampled_df_per_project = (
    df_new.groupby("project_id")
    .sample(n=50, random_state=42)
    .reset_index(drop=True)
)

In [268]:
# Sanity check: total number of sampled rows across all projects.
len(sampled_df_per_project)

150

In [269]:
# Preview the first two sampled rows.
sampled_df_per_project.head(2)

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,affected_groups_level_1,affected_groups_level_2,affected_groups_level_3,source_type,url,website,lang,translation_en,translation_fr,translation_es
0,174380,"[6th september,Overall Syria] Riza said there ...",1306,40125.0,2028.0,False,['Health'],['Humanitarian Conditions->Physical And Mental...,['Casualties->Dead'],['Syrian Arab Republic'],...,[],[],[],website,https://www.middleeasteye.net/news/coronavirus...,www.middleeasteye.net,en,,"[6 septembre, la Syrie globale] Riza a déclaré...","[6 de septiembre, General Siria] Riza dijo que..."
1,224488,"[December 21, As-Suwayda] Fear of the Coronavi...",1306,46147.0,2028.0,False,"['Protection', 'Education']",['Humanitarian Conditions->Physical And Mental...,[],['As-Sweida'],...,[],[],[],website,http://tishreen.news.sy/?p=599768,tishreen.news.sy,en,,"[21 décembre, As-Suwayda] La peur du Coronavir...","[21 de diciembre, AS-SUWAYDA] El miedo al Coro..."


In [270]:
# Query the endpoint once per sampled excerpt (sequential, one network call
# per row). progress_apply is available because tqdm.pandas() was called in
# the import cell. The result is a Series holding one decoded response per row.
endpoint_outputs = sampled_df_per_project.excerpt.progress_apply(predictions.get_predictions)

100%|██████████| 150/150 [00:54<00:00,  2.77it/s]


In [271]:
# Collects, per key returned by get_predictions_all, one value per endpoint
# response (i.e. per sampled excerpt), in response order.
x = defaultdict(list)

eval_batches = endpoint_outputs.to_list()
for eval_batch in eval_batches:
    # Each response is indexed by "raw_predictions" (iterated below, one
    # item per entry) and "thresholds".
    output_ratios = eval_batch["raw_predictions"]

    thresholds = eval_batch["thresholds"]

    # NOTE(review): clean_thresholds is computed but never used afterwards
    # in this cell; confirm whether get_predictions_all was meant to get it.
    clean_thresholds = convert_current_dict_to_previous_one(thresholds)

    # Convert every entry's predictions with the helper imported from
    # postprocess_cpu_model_outputs.
    clean_outputs = [
        convert_current_dict_to_previous_one(one_entry_preds)
        for one_entry_preds in output_ratios
    ]

    final_predictions = get_predictions_all(clean_outputs)
    # Only the first element of each value is kept.
    # NOTE(review): presumably safe because every request carries a single
    # excerpt, so each value has exactly one element - confirm.
    for k, v in final_predictions.items():
        x[k].append(v[0])

In [272]:
# Suffix every key with "_pred" so the prediction columns cannot clash with
# the ground-truth columns of the dataset.
# NOTE: rename_keys mutates x in place, so re-running this cell without
# re-running the previous one appends "_pred" a second time.
x = rename_keys(x)

In [273]:
# One column per "<key>_pred", one row per endpoint response.
test_df = pd.DataFrame(x)

In [274]:
# Put the predictions next to the sampled rows. Column-wise concat aligns on
# the index: sampled_df_per_project was renumbered with reset_index(drop=True)
# and test_df is built from plain lists, so both use positions 0..n-1.
merged_df = pd.concat([sampled_df_per_project, test_df], axis=1)

In [275]:
# Inspect the merged frame; the ground-truth label columns are still strings here.
merged_df

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,translation_fr,translation_es,sectors_pred,subpillars_2d_pred,subpillars_1d_pred,age_pred,gender_pred,affected_groups_pred,specific_needs_groups_pred,severity_pred
0,174380,"[6th september,Overall Syria] Riza said there ...",1306,40125.0,2028.0,False,['Health'],['Humanitarian Conditions->Physical And Mental...,['Casualties->Dead'],['Syrian Arab Republic'],...,"[6 septembre, la Syrie globale] Riza a déclaré...","[6 de septiembre, General Siria] Riza dijo que...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[],[],[],[],[Critical]
1,224488,"[December 21, As-Suwayda] Fear of the Coronavi...",1306,46147.0,2028.0,False,"['Protection', 'Education']",['Humanitarian Conditions->Physical And Mental...,[],['As-Sweida'],...,"[21 décembre, As-Suwayda] La peur du Coronavir...","[21 de diciembre, AS-SUWAYDA] El miedo al Coro...","[Education, Health]","[Humanitarian Conditions->Living Standards, Im...",[],[Children/Youth (5 to 17 years old)],[],[],[],[Major]
2,245365,"[November 2020, North-west Syria] By November,...",1306,47945.0,2028.0,False,[],[],['Covid-19->Cases'],['Northwest'],...,"[Novembre 2020, Syrie du Nord-Ouest] En novemb...","[De noviembre de 2020, Siria noroeste] Para no...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Covid-19->Cases],[],[],[IDP],[],[Of Concern]
3,326857,"(Idlib, 03/2021) The Director General of the G...",1306,51351.0,2028.0,False,['WASH'],"['Impact->Impact On Systems, Services And Netw...",[],[],...,"(Idlib, 03/2021) Le directeur général de l'Org...","(IDLIB, 03/2021) El Director General de la Org...",[WASH],"[Humanitarian Conditions->Living Standards, Im...",[],[],[],[],[],[Major]
4,164076,Trend analysis and overview: The number of con...,1306,38063.0,2028.0,False,['Health'],['Humanitarian Conditions->Physical And Mental...,['Casualties->Dead'],['Syrian Arab Republic'],...,Analyse des tendances et aperçu: le nombre de ...,Análisis y descripción general de la tendencia...,[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[],[],[],[],[Critical]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145,405617,"[9th - 15th Aug 2021, Borno State] Case Manage...",1306,63139.0,2170.0,False,[],[],['Covid-19->Cases'],['Borno'],...,"[9ème - 15 août 2021, État de Borno] Gestion d...","[9º - 15 de agosto 2021, estado BORNO] Gestión...",[],[],[Covid-19->Cases],[],[],[],[],[]
146,165963,"As of 17 September 2020, the number of confirm...",1306,39937.0,2170.0,False,['Health'],['Humanitarian Conditions->Physical And Mental...,[],['Nigeria'],...,"Au 17 septembre 2020, le nombre de cas de Covi...","A partir del 17 de septiembre de 2020, el núme...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Injured],[],[],[],[],[Of Concern]
147,172255,"[04/10/2020-10/10/2020,Nigeria]According to NC...",1306,40932.0,2170.0,False,['Health'],['Humanitarian Conditions->Physical And Mental...,[],['Nigeria'],...,"[04 / 10/2020-10/10/2020, Nigéria] Selon la NC...","[04/10 / 2020-10 / 10/2020, Nigeria] Según NCD...",[Health],[Humanitarian Conditions->Physical And Mental ...,[],[],[],[],[],[Of Concern]
148,160718,"As markets in the neighboring states of Borno,...",1306,39430.0,2170.0,False,"['Nutrition', 'Cross', 'Food Security', 'Livel...","['At Risk->Risk And Vulnerabilities', 'Impact-...",[],"['Jigawa', 'Borno', 'Bauchi', 'Gombe', 'Kano']",...,"Comme les marchés des États-Unis de Borno, Gom...",Como los mercados en los estados vecinos de Bo...,"[Food Security, Livelihoods]",[At Risk->Risk And Vulnerabilities],[],[],[],"[Host, IDP]",[],[]


In [276]:
def _parse_label_list(value):
    """Parse a stringified Python literal; pass any non-string value through."""
    return literal_eval(value) if isinstance(value, str) else value


# Turn the ground-truth label columns from stringified lists into real lists.
# Only strings are parsed: literal_eval raises ValueError on values that are
# already lists (this cell run twice) and on missing values (NaN), so those
# are left untouched and the cell is safe to re-run.
for category in cateogory_names:
    merged_df[category] = merged_df[category].apply(_parse_label_list)

In [277]:
# Inspect the merged frame again after the label columns have been parsed.
merged_df

Unnamed: 0,entry_id,excerpt,analysis_framework_id,lead_id,project_id,verified,sectors,subpillars_2d,subpillars_1d,geo_location,...,translation_fr,translation_es,sectors_pred,subpillars_2d_pred,subpillars_1d_pred,age_pred,gender_pred,affected_groups_pred,specific_needs_groups_pred,severity_pred
0,174380,"[6th september,Overall Syria] Riza said there ...",1306,40125.0,2028.0,False,[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[Syrian Arab Republic],...,"[6 septembre, la Syrie globale] Riza a déclaré...","[6 de septiembre, General Siria] Riza dijo que...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[],[],[],[],[Critical]
1,224488,"[December 21, As-Suwayda] Fear of the Coronavi...",1306,46147.0,2028.0,False,"[Protection, Education]",[Humanitarian Conditions->Physical And Mental ...,[],[As-Sweida],...,"[21 décembre, As-Suwayda] La peur du Coronavir...","[21 de diciembre, AS-SUWAYDA] El miedo al Coro...","[Education, Health]","[Humanitarian Conditions->Living Standards, Im...",[],[Children/Youth (5 to 17 years old)],[],[],[],[Major]
2,245365,"[November 2020, North-west Syria] By November,...",1306,47945.0,2028.0,False,[],[],[Covid-19->Cases],[Northwest],...,"[Novembre 2020, Syrie du Nord-Ouest] En novemb...","[De noviembre de 2020, Siria noroeste] Para no...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Covid-19->Cases],[],[],[IDP],[],[Of Concern]
3,326857,"(Idlib, 03/2021) The Director General of the G...",1306,51351.0,2028.0,False,[WASH],"[Impact->Impact On Systems, Services And Netwo...",[],[],...,"(Idlib, 03/2021) Le directeur général de l'Org...","(IDLIB, 03/2021) El Director General de la Org...",[WASH],"[Humanitarian Conditions->Living Standards, Im...",[],[],[],[],[],[Major]
4,164076,Trend analysis and overview: The number of con...,1306,38063.0,2028.0,False,[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[Syrian Arab Republic],...,Analyse des tendances et aperçu: le nombre de ...,Análisis y descripción general de la tendencia...,[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Dead],[],[],[],[],[Critical]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145,405617,"[9th - 15th Aug 2021, Borno State] Case Manage...",1306,63139.0,2170.0,False,[],[],[Covid-19->Cases],[Borno],...,"[9ème - 15 août 2021, État de Borno] Gestion d...","[9º - 15 de agosto 2021, estado BORNO] Gestión...",[],[],[Covid-19->Cases],[],[],[],[],[]
146,165963,"As of 17 September 2020, the number of confirm...",1306,39937.0,2170.0,False,[Health],[Humanitarian Conditions->Physical And Mental ...,[],[Nigeria],...,"Au 17 septembre 2020, le nombre de cas de Covi...","A partir del 17 de septiembre de 2020, el núme...",[Health],[Humanitarian Conditions->Physical And Mental ...,[Casualties->Injured],[],[],[],[],[Of Concern]
147,172255,"[04/10/2020-10/10/2020,Nigeria]According to NC...",1306,40932.0,2170.0,False,[Health],[Humanitarian Conditions->Physical And Mental ...,[],[Nigeria],...,"[04 / 10/2020-10/10/2020, Nigéria] Selon la NC...","[04/10 / 2020-10 / 10/2020, Nigeria] Según NCD...",[Health],[Humanitarian Conditions->Physical And Mental ...,[],[],[],[],[],[Of Concern]
148,160718,"As markets in the neighboring states of Borno,...",1306,39430.0,2170.0,False,"[Nutrition, Cross, Food Security, Livelihoods]","[At Risk->Risk And Vulnerabilities, Impact->Im...",[],"[Jigawa, Borno, Bauchi, Gombe, Kano]",...,"Comme les marchés des États-Unis de Borno, Gom...",Como los mercados en los estados vecinos de Bo...,"[Food Security, Livelihoods]",[At Risk->Risk And Vulnerabilities],[],[],[],"[Host, IDP]",[],[]


In [278]:
# Persist the sampled rows together with their predictions (row index omitted).
# List-valued cells are written as their string representation.
merged_df.to_csv("sampled_data_with_predictions_testset.csv", index=False)
