In [1]:
# General purpose libraries
import boto3
import copy
import csv
import datetime
import json
import os
import numpy as np
import pandas as pd
import s3fs
from collections import defaultdict
import time
import re
import random
from sentence_transformers import SentenceTransformer
import sentencepiece
from scipy.spatial import distance
from json import JSONEncoder
import sys
sys.path.append("/Users/dafirebanks/Projects/policy-data-analyzer/")
sys.path.append("C:/Users/jordi/Documents/GitHub/policy-data-analyzer/")
from tasks.data_loading.src.utils import *

### 1. Set up AWS

In [2]:
def aws_credentials_from_file(f_name):
    """Return the AWS ``(id, secret)`` pair stored in a JSON credentials file.

    The file must contain a top-level ``"aws"`` object with ``"id"`` and
    ``"secret"`` fields; any other top-level entries are ignored.
    """
    with open(f_name, "r") as creds_file:
        aws_section = json.load(creds_file)["aws"]
    key_id, secret = aws_section["id"], aws_section["secret"]
    return key_id, secret

def aws_credentials(path, filename):
    """Read an AWS access-key pair from a JSON file of the form ``{key_id: secret}``.

    If the object holds several entries, the last one in file order is
    returned. This matches the previous loop-based behaviour.

    Args:
        path: Directory that contains the credentials file. A trailing
            separator is optional.
        filename: Name of the credentials file inside ``path``.

    Returns:
        Tuple ``(key_id, secret)``.

    Raises:
        ValueError: If the JSON object has no entries. Previously this case
            surfaced as an UnboundLocalError.
    """
    import os

    # os.path.join works with or without a trailing separator on `path`,
    # unlike plain string concatenation.
    file_path = os.path.join(path, filename)
    with open(file_path, 'r') as key_file:  # avoid shadowing the builtin `dict`
        key_dict = json.load(key_file)
    if not key_dict:
        raise ValueError(f"No credentials found in {file_path}")
    # Last entry wins, as with the original for-loop.
    key, secret = list(key_dict.items())[-1]
    return key, secret

### 2. Optimized full loop

In [3]:
def aws_credentials(path, filename):
    """Read an AWS access-key pair from a JSON file of the form ``{key_id: secret}``.

    If the object holds several entries, the last one in file order is
    returned. This matches the previous loop-based behaviour.

    Args:
        path: Directory that contains the credentials file. A trailing
            separator is optional.
        filename: Name of the credentials file inside ``path``.

    Returns:
        Tuple ``(key_id, secret)``.

    Raises:
        ValueError: If the JSON object has no entries. Previously this case
            surfaced as an UnboundLocalError.
    """
    import os

    # os.path.join works with or without a trailing separator on `path`,
    # unlike plain string concatenation.
    file_path = os.path.join(path, filename)
    with open(file_path, 'r') as key_file:  # avoid shadowing the builtin `dict`
        key_dict = json.load(key_file)
    if not key_dict:
        raise ValueError(f"No credentials found in {file_path}")
    # Last entry wins, as with the original for-loop.
    key, secret = list(key_dict.items())[-1]
    return key, secret

def aws_credentials_from_file(f_name):
    """Return the AWS ``(id, secret)`` pair stored in a JSON credentials file.

    The file must contain a top-level ``"aws"`` object with ``"id"`` and
    ``"secret"`` fields; any other top-level entries are ignored.
    """
    with open(f_name, "r") as creds_file:
        aws_section = json.load(creds_file)["aws"]
    key_id, secret = aws_section["id"], aws_section["secret"]
    return key_id, secret

def load_all_sentences(language, s3, bucket_name, init_doc, end_doc):
    """Download and merge the sentence JSON files for a slice of documents in S3.

    Objects under ``<language>_documents/sentences/`` are enumerated in listing
    order. Those whose listing index ``i`` satisfies ``init_doc <= i < end_doc``
    are downloaded and parsed. The index counts every listed object, including
    "folder" placeholder keys ending in ``/``; those are skipped but still
    consume an index.

    Args:
        language: Folder-name prefix, e.g. ``"english"``. Previously this
            argument was ignored and English was always read.
        s3: A ``boto3`` S3 service resource.
        bucket_name: Name of the bucket to read from.
        init_doc: First listing index to include (inclusive).
        end_doc: Listing index at which to stop (exclusive).

    Returns:
        The result of ``labeled_sentences_from_dataset`` on the merged
        documents (sentence id -> sentence record).
    """
    policy_dict = {}
    # Honour `language` instead of a hard-coded "english" prefix; the trailing
    # slash restricts matches to keys inside the sentences folder.
    sents_prefix = f"{language}_documents/sentences/"

    for i, obj in enumerate(s3.Bucket(bucket_name).objects.filter(Prefix=sents_prefix)):
        if i >= end_doc:
            # Indices only increase, so no later object can fall in range.
            break
        if i < init_doc or obj.key.endswith("/"):
            continue

        serialized_object = obj.get()['Body'].read()
        # Update in place rather than rebuilding the dict on every file
        # (quadratic). Later files still override earlier ones on duplicate keys.
        policy_dict.update(json.loads(serialized_object))

    return labeled_sentences_from_dataset(policy_dict)

def save_results_as_separate_csv(results_dictionary, queries_dictionary, init_doc, results_limit, aws_id, aws_secret,
                                 path="s3://wri-nlp-policy/english_documents/assisted_labeling"):
    """Write the top results for each query to its own CSV file.

    Files are named ``query_<label>_<query index>_results_<init_doc>.csv``.
    ``<label>`` is ``queries_dictionary[query]`` and ``<query index>`` is the
    query's position in ``results_dictionary``.

    Args:
        results_dictionary: Maps query text to a list of
            ``[sentence_id, similarity_score, text]`` rows. Only the first
            ``results_limit`` rows are written, so callers should sort first.
        queries_dictionary: Maps query text to the label used in the filename.
        init_doc: First document index of this run. It is included in the
            filename, presumably so runs over different document ranges do
            not overwrite each other.
        results_limit: Maximum number of rows written per query.
        aws_id: AWS access key id, forwarded to pandas as ``storage_options``.
        aws_secret: AWS secret key, forwarded to pandas as ``storage_options``.
        path: Destination prefix. Defaults to the English assisted-labeling
            folder; pass another prefix for other languages.
    """
    col_headers = ["sentence_id", "similarity_score", "text"]
    # Loop-invariant: the same credentials are used for every file.
    storage_options = {"key": aws_id, "secret": aws_secret}
    for i, query in enumerate(results_dictionary.keys()):
        filename = f"{path}/query_{queries_dictionary[query]}_{i}_results_{init_doc}.csv"
        results_df = pd.DataFrame(results_dictionary[query], columns=col_headers)
        results_df.head(results_limit).to_csv(filename, storage_options=storage_options)

def labeled_sentences_from_dataset(dataset):
    """Flatten ``{doc_id: {'sentences': {...}}}`` into one sentence mapping.

    Each document's ``'sentences'`` mapping is merged in document order.
    When two documents share a sentence id, the later one wins.
    """
    merged_sentences = dict()
    for per_doc_sentences in (doc['sentences'] for doc in dataset.values()):
        merged_sentences.update(per_doc_sentences)
    return merged_sentences

In [None]:
# Set up AWS (credentials-file variant).
# The credentials path can be overridden with the AWS_CREDENTIALS_FILE
# environment variable instead of editing this machine-specific default.
credentials_file = os.environ.get('AWS_CREDENTIALS_FILE', '/Users/dafirebanks/Documents/credentials.json')
aws_id, aws_secret = aws_credentials_from_file(credentials_file)
region = 'us-east-1'

s3 = boto3.resource(
    service_name='s3',
    region_name=region,
    aws_access_key_id=aws_id,
    aws_secret_access_key=aws_secret,
)

In [4]:
# Set up AWS (key-pair-file variant).
# Override the machine-specific default directory with the AWS_KEYS_DIR
# environment variable (keep a trailing slash: the path is concatenated).
path = os.environ.get("AWS_KEYS_DIR", "C:/Users/jordi/Documents/claus/")
filename = "AWS_S3_keys_wri.json"
aws_id, aws_secret = aws_credentials(path, filename)
region = 'us-east-1'

bucket = 'wri-nlp-policy'

s3 = boto3.resource(
    service_name='s3',
    region_name=region,
    aws_access_key_id=aws_id,
    aws_secret_access_key=aws_secret,
)

In [5]:
# Define params
# Process S3 listing indices [init_at_doc, end_at_doc) of the sentences folder
# (see load_all_sentences for how the index is counted).
init_at_doc = 14778
end_at_doc = 16420

# Keep only sentence/query pairs whose rounded cosine similarity is strictly above this.
similarity_threshold = 0
# Maximum number of rows written to each query's CSV.
search_results_limit = 500

language = "english"
bucket_name = 'wri-nlp-policy'

transformer_name = 'xlm-r-bert-base-nli-stsb-mean-tokens'
model = SentenceTransformer(transformer_name)


# Get all sentence documents (sentence_id -> sentence record with a 'text' field)
sentences = load_all_sentences(language, s3, bucket_name, init_at_doc, end_at_doc)

# Define queries: query sentence -> policy instrument label
path = "../../input/"
filename = "English_queries.xlsx"
file = path + filename
# Blank spreadsheet rows arrive as NaN and would crash `query.lower()` below.
df = (
    pd.read_excel(file, engine='openpyxl', sheet_name="Hoja1", usecols="A:C")
    .dropna(subset=['Query sentence'])
)
# As with the original row loop, a repeated query sentence keeps its last label.
queries = dict(zip(df['Query sentence'], df['Policy instrument']))

# Calculate and store query embeddings (lower-cased, like the sentences)
query_embeddings = {query: model.encode(query.lower(), show_progress_bar=False) for query in queries}

# For each sentence, calculate its embedding, and store the similarity to every query
query_similarities = defaultdict(list)

for i, (sentence_id, sentence) in enumerate(sentences.items(), start=1):
    sentence_text = sentence['text']
    sentence_embedding = model.encode(sentence_text.lower(), show_progress_bar=False)
    if i % 100 == 0:
        print(i)  # progress indicator

    for query_text, query_embedding in query_embeddings.items():
        # scipy's `cosine` is a distance, so similarity = 1 - distance.
        score = round(1 - distance.cosine(sentence_embedding, query_embedding), 4)
        if score > similarity_threshold:
            query_similarities[query_text].append([sentence_id, score, sentence_text])

# Sort each query's results by similarity score, best first (in place)
for results in query_similarities.values():
    results.sort(key=lambda row: row[1], reverse=True)

# Store results
save_results_as_separate_csv(query_similarities, queries, init_at_doc, search_results_limit, aws_id, aws_secret)


100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18200
18300
18400
1850

133000
133100
133200
133300
133400
133500
133600
133700
133800
133900
134000
134100
134200
134300
134400
134500
134600
134700
134800
134900
135000
135100
135200
135300
135400
135500
135600
135700
135800
135900
136000
136100
136200
136300
136400
136500
136600
136700
136800
136900
137000
137100
137200
137300
137400
137500
137600
137700
137800
137900
138000
138100
138200
138300
138400
138500
138600
138700
138800
138900
139000
139100
139200
139300
139400
139500
139600
139700
139800
139900
140000
140100
140200
140300
140400
140500
140600
140700
140800
140900
141000
141100
141200
141300
141400
141500
141600
141700
141800
141900
142000
142100
142200
142300
142400
142500
142600
142700
142800
142900
143000
143100
143200
143300
143400
143500
143600
143700
143800
143900
144000
144100
144200
144300
144400
144500
144600
144700
144800
144900
145000
145100
145200
145300
145400
145500
145600
145700
145800
145900
146000
146100
146200
146300
146400
146500
146600
146700
146800
146900
147000
147100
147200

250100
250200
250300
250400
250500
250600
250700
250800
250900
251000
251100
251200
251300
251400
251500
251600
251700
251800
251900
252000
252100
252200
252300
252400
252500
252600
252700
252800
252900
253000
253100
253200
253300
253400
253500
253600
253700
253800
253900
254000
254100
254200
254300
254400
254500
254600
254700
254800
254900
255000
255100
255200
255300
255400
255500
255600
255700
255800
255900
256000
256100
256200
256300
256400
256500
256600
256700
256800
256900
257000
257100
257200
257300
257400
257500
257600
257700
257800
257900
258000
258100
258200
258300
258400
258500
258600
258700
258800
258900
259000
259100
259200
259300
259400
259500
259600
259700
259800
259900
260000
260100
260200
260300
260400
260500
260600
260700
260800
260900
261000
261100
261200
261300
261400
261500
261600
261700
261800
261900
262000
262100
262200
262300
262400
262500
262600
262700
262800
262900
263000
263100
263200
263300
263400
263500
263600
263700
263800
263900
264000
264100
264200
264300

367200
367300
367400
367500
367600
367700
367800
367900
368000
368100
368200
368300
368400
368500
368600
368700
368800
368900
369000
369100
369200
369300
369400
369500
369600
369700
369800
369900
370000
370100
370200
370300
370400
370500
370600
370700
370800
370900
371000
371100
371200
371300
371400
371500
371600
371700
371800
371900
372000
372100
372200
372300
372400
372500
