# Imports

In [1]:
import contextlib
import os
import sqlite3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pycountry
import seaborn as sns
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from scipy.special import softmax
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from sqlalchemy import create_engine


  from .autonotebook import tqdm as notebook_tqdm


# Constants

In [2]:
def check_given_var(env_var_str: str) -> str:
    """
    Check if the given environment variable is set and return its value.

    Args:
        env_var_str (str): The name of the environment variable to check.

    Returns:
        str: The value of the environment variable (may be an empty string
        if the variable is set but empty).

    Raises:
        AssertionError: If the environment variable is not found.
    """
    env_var = os.getenv(env_var_str)
    # Raise explicitly rather than with an `assert` statement: asserts are
    # removed when Python runs with -O, which would let a missing variable
    # through as None. The exception type is kept for backward compatibility.
    if env_var is None:
        raise AssertionError(
            f"{env_var_str} is required but not found in environment variables"
        )
    return env_var


def check_env_vars() -> "tuple[str, str, str, str]":
    """
    Read the database connection settings from the environment.

    The return annotation is quoted so it is never evaluated at definition
    time; it replaces the previous `(str, str, str, str)`, which is a tuple
    of classes rather than a type and needed a `# type: ignore`.

    Returns:
        tuple[str, str, str, str]: ``(user, database, password, host)`` read
        from DBL_USER, DBL_DATABASE, DBL_PASSWORD and DBL_HOST respectively.

    Raises:
        AssertionError: Propagated from check_given_var when one of the four
            variables is not set.
    """
    user = check_given_var("DBL_USER")
    database = check_given_var("DBL_DATABASE")
    password = check_given_var("DBL_PASSWORD")
    host = check_given_var("DBL_HOST")
    return user, database, password, host


# Connection settings read from the DBL_* environment variables. They are
# consumed by fetch_data; note that this line runs (and fails when a variable
# is missing) even if only the local SQLite path is used afterwards.
USER, DATABASE, PASSWORD, HOST = check_env_vars()
# USER, DATABASE = "nezox2um_test", "nezox2um_test"
# Selects one row per tweet joined to its author: both creation timestamps
# are aliased so they do not clash, and tweets without a matching user are
# dropped by the INNER JOIN.
QUERY_ALL = """
SELECT 
    Users.user_id AS user_id, 
    Users.creation_time AS user_creation_time, 
    Tweets.creation_time AS tweet_creation_time,
    Tweets.tweet_id,
    Tweets.full_text,
    Tweets.lang,
    Tweets.replied_tweet_id
FROM Users
INNER JOIN Tweets ON Users.user_id = Tweets.user_id;
"""


# Column dtypes passed to pandas.read_sql_query by the loader functions.
# Identifier columns are read as generic Python objects; `lang` is stored as
# a pandas categorical.
DTYPES = dict(
    user_id="object",
    tweet_id="object",
    full_text="object",
    lang="category",
    replied_tweet_id="object",
)

# Airline display name -> user_id of that airline's account, stored as a
# string so it compares directly with the `user_id` column.
COMPANY_NAME_TO_ID = {
    "Klm": "56377143",
    "Air France": "106062176",
    "British Airways": "18332190",
    "American Air": "22536055",
    "Lufthansa": "124476322",
    "Air Berlin": "26223583",
    "Air Berlin assist": "2182373406",
    "easyJet": "38676903",
    "Ryanair": "1542862735",
    "Singapore Airlines": "253340062",
    "Qantas": "218730857",
    "Etihad Airways": "45621423",
    "Virgin Atlantic": "20626359",
}

# Reverse lookup (user_id -> airline name), built by pairing each id with its
# name in the same order as the mapping above.
COMPANY_ID_TO_NAME = dict(
    zip(COMPANY_NAME_TO_ID.values(), COMPANY_NAME_TO_ID.keys())
)

# Helper functions

In [3]:
def get_local_data(query: str, path: str, dtype: bool = True) -> pd.DataFrame:
    """
    Run a SQL query against a local SQLite database file and return the rows.

    Args:
        query (str): SQL statement to execute.
        path (str): Path of the SQLite database file.
        dtype (bool): When True (default), read with the module-level DTYPES
            mapping, use 'tweet_id' as the index and convert the
            'tweet_creation_time' and 'user_creation_time' columns with
            pd.to_datetime. When False, return the frame exactly as
            pd.read_sql_query produces it.

    Returns:
        pd.DataFrame: The query result.
    """
    # sqlite3's own context manager only commits or rolls back the pending
    # transaction; it does NOT close the connection. closing() guarantees the
    # handle is released, while the inner `with connection` keeps the original
    # commit/rollback behaviour.
    with contextlib.closing(sqlite3.connect(path)) as connection:
        with connection:
            if not dtype:
                return pd.read_sql_query(query, connection)
            df = pd.read_sql_query(query, connection,
                                   dtype=DTYPES,
                                   index_col='tweet_id')

    # Timestamp parsing needs no database access, so it runs after the
    # connection has been closed.
    df['tweet_creation_time'] = pd.to_datetime(df['tweet_creation_time'])
    df['user_creation_time'] = pd.to_datetime(df['user_creation_time'])
    return df


def fetch_data(query: str, dtype: bool = True) -> pd.DataFrame:
    """
    Run a SQL query against the remote MySQL database and return the rows.

    The connection uses the module-level USER, PASSWORD, HOST and DATABASE
    values on port 3306.

    Args:
        query (str): SQL statement to execute.
        dtype (bool): When True (default), read with the module-level DTYPES
            mapping and use 'tweet_id' as the index. When False, return the
            frame exactly as pd.read_sql_query produces it.

    Returns:
        pd.DataFrame: The query result.
    """
    from urllib.parse import quote_plus  # stdlib; only needed for the URL

    # Percent-encode the credentials so characters such as '@', ':' or '/'
    # in the user name or password cannot corrupt the connection URL.
    url = (
        f"mysql://{quote_plus(USER)}:{quote_plus(PASSWORD)}"
        f"@{HOST}:3306/{DATABASE}"
    )
    engine = create_engine(url)
    try:
        if dtype:
            return pd.read_sql_query(query, engine,
                                     dtype=DTYPES, index_col='tweet_id')
        return pd.read_sql_query(query, engine)
    finally:
        # A new engine is built on every call, so release its connection
        # pool instead of leaving the connections open.
        engine.dispose()


# Loading

In [4]:
# Server (remote MySQL) alternative:
# test_data = fetch_data(QUERY_ALL)

# Local: the SQLite backup is looked up two directory levels above the
# current working directory, inside data_processed/.
project_root = os.path.dirname(os.path.dirname(os.getcwd()))
path = os.path.join(project_root, "data_processed", "local_backup.db")

test_data = get_local_data(QUERY_ALL, path)

# Conversations

In [6]:
# Newest tweets first, keeping only the two columns that the conversation
# tracing below reads (author and parent tweet).
convo_special = (
    test_data
    .sort_values("tweet_creation_time", ascending=False)
    .loc[:, ["user_id", "replied_tweet_id"]]
)
convo_special

Unnamed: 0_level_0,user_id,replied_tweet_id
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1244696713765564416,56784613,
1244696713350217728,1223576386432126976,
1244696710447800320,109284383,
1244696708983984131,246520593,
1244696703690772485,278698748,
...,...,...
773181150,10812972,
773176947,10812972,
773176924,10812972,
773176134,10812972,


In [7]:
# v1
from tqdm import tqdm
from collections import defaultdict

class TrieNode:
    """One node of the conversation trie; outgoing edges are keyed by tweet ID."""

    def __init__(self):
        # Item access on a missing key creates a fresh child node.
        self.children = defaultdict(TrieNode)
        # Set to True when an inserted sequence ends exactly at this node.
        self.is_end = False


class Trie:
    """Prefix tree over sequences of tweet IDs."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, conversation):
        """Store `conversation` (an iterable of tweet IDs) as a path from the root."""
        cursor = self.root
        for tweet_id in conversation:
            cursor = cursor.children[tweet_id]
        cursor.is_end = True

    def is_subset(self, conversation):
        """Return True when `conversation` is a prefix of (or equal to) a stored path.

        An empty sequence is always reported as contained.
        """
        cursor = self.root
        for tweet_id in conversation:
            # .get() does not invoke the defaultdict factory, so a failed
            # lookup never adds nodes to the trie.
            cursor = cursor.children.get(tweet_id)
            if cursor is None:
                return False
        return True



def trace_conversation(start_tweet_id, tweet_dict):
    """
    Follow the reply chain upwards from `start_tweet_id`.

    Parameters:
    start_tweet_id: ID of the tweet to start from.
    tweet_dict (dict): Maps tweet ID to a record with at least the keys
        'user_id' and 'replied_tweet_id'.

    Returns:
    list | None: Tweet IDs ordered from the oldest reachable ancestor to
    `start_tweet_id` when exactly two distinct users wrote the chain,
    otherwise None. The walk stops at a falsy parent ID, at an ID missing
    from `tweet_dict`, or when an ID repeats (reply cycle).
    """
    chain = []            # collected newest -> oldest, reversed at the end
    visited = set()       # IDs already walked, guards against reply cycles
    participants = set()  # distinct authors seen so far
    cursor = start_tweet_id

    while cursor and cursor in tweet_dict and cursor not in visited:
        record = tweet_dict[cursor]
        chain.append(cursor)
        participants.add(record['user_id'])
        visited.add(cursor)
        if len(participants) > 2:
            # A third author joined: not an exclusive two-party conversation.
            return None
        cursor = record['replied_tweet_id']

    if len(participants) != 2:
        return None
    chain.reverse()
    return chain


def extract_and_filter_conversations(df):
    """
    Collect two-user conversations from a frame of tweets.

    Parameters:
    df (pd.DataFrame): Indexed by tweet ID, with 'user_id' and
        'replied_tweet_id' columns.

    Returns:
    list: Conversations as lists of tweet IDs (oldest first), in the order
    they were discovered.

    Every tweet that is a reply is traced back through its ancestors. A traced
    conversation is dropped when it is a prefix of one that was kept earlier,
    so the result depends on the row order of `df`: a shorter thread is only
    filtered out if a longer one containing it was seen first.
    """
    tweets_by_id = df.to_dict('index')
    seen_paths = Trie()  # kept conversations, for the prefix check
    kept = []

    # Only tweets that reply to something can end a conversation.
    reply_ids = df[df['replied_tweet_id'].notnull()].index
    for reply_id in tqdm(reply_ids, desc="Extracting all conversations"):
        thread = trace_conversation(reply_id, tweets_by_id)
        if not thread:
            # None (not a two-user thread) or empty: nothing to record.
            continue
        if seen_paths.is_subset(thread):
            continue
        seen_paths.insert(thread)
        kept.append(thread)

    return kept

In [8]:
# Trace and de-duplicate the two-user conversations found in the
# newest-first frame built above.
conversations = extract_and_filter_conversations(convo_special)

Extracting all conversations: 100%|██████████| 1795409/1795409 [00:08<00:00, 218684.12it/s]


In [9]:
# Flatten the conversations into (conversation number, tweet id) rows.
# Numbering starts at 1 and follows the order of `conversations`.
data = [
    (convo_num, tweet_id)
    for convo_num, convo in enumerate(conversations, start=1)
    for tweet_id in convo
]

# Long-format frame: one row per tweet with a default integer index
# (the MultiIndex is only set after the merge with the tweet data).
df_conversations = pd.DataFrame(data, columns=['Conversation', 'Tweet_ID'])
df_conversations

Unnamed: 0,Conversation,Tweet_ID
0,1,1244344799647449089
1,1,1244696491580628993
2,2,1244593729312362497
3,2,1244696406570475525
4,3,1242875007270891523
...,...,...
2206483,864591,451125255294443521
2206484,864592,430790355962052608
2206485,864592,430792524043931648
2206486,864593,248528541157834752


In [10]:

# Attach the tweet and user columns from test_data to every conversation row.
# The left join keeps all conversation rows, and the result is indexed by the
# (Conversation, Tweet_ID) pair.
df_conversations_full = (
    df_conversations
    .merge(test_data, how='left', left_on='Tweet_ID', right_index=True)
    .set_index(['Conversation', 'Tweet_ID'])
)
df_conversations_full


Unnamed: 0_level_0,Unnamed: 1_level_0,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,1244344799647449089,806967414193868800,2016-12-08 21:03:19+00:00,2020-03-29 19:24:53+00:00,How do you feel about your travel agency / air...,en,
1,1244696491580628993,2774375013,2014-09-18 09:48:49+00:00,2020-03-30 18:42:23+00:00,@adnansaleemiX @qatarairways @emirates @easyJe...,en,1244344799647449089
2,1244593729312362497,281046179,2011-04-12 14:30:18+00:00,2020-03-30 11:54:03+00:00,A group of 150 Irish citizens will arrive toda...,en,
2,1244696406570475525,907026711010836480,2017-09-10 23:43:14+00:00,2020-03-30 18:42:03+00:00,@AerLingus @dfatirl @British_Airways Any chanc...,en,1244593729312362497
3,1242875007270891523,22536055,2009-03-02 21:23:05+00:00,2020-03-25 18:04:27+00:00,We’re waiving change fees for customers who ha...,en,
...,...,...,...,...,...,...,...
864591,451125255294443521,22536055,2009-03-02 21:23:05+00:00,2014-04-01 22:33:37+00:00,@lanaupdates_ Your information has been forwar...,en,451124070730719233
864592,430790355962052608,64327804,2009-08-10 03:34:27+00:00,2014-02-04 19:49:59+00:00,"@AmericanAir phew, they finally turned on the ...",en,
864592,430792524043931648,22536055,2009-03-02 21:23:05+00:00,2014-02-04 19:58:36+00:00,@benjy_greenberg It looks like we'll have you ...,en,430790355962052608
864593,248528541157834752,19911051,2009-02-02 15:17:02+00:00,2012-09-19 21:06:36+00:00,Un-fucking believable!\nThanks @BritishAirways...,en,


In [11]:
# Conversation numbers in which the Lufthansa account wrote at least one tweet.
airline_tweets = df_conversations_full[
    df_conversations_full['user_id'] == COMPANY_NAME_TO_ID["Lufthansa"]
]
airline_conversation_ids = airline_tweets.index.get_level_values('Conversation')

# Keep every tweet (from either participant) of those conversations.
all_conversation_ids = df_conversations_full.index.get_level_values('Conversation')
airline_conversation = df_conversations_full.loc[
    all_conversation_ids.isin(airline_conversation_ids)
]
airline_conversation

Unnamed: 0_level_0,Unnamed: 1_level_0,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
144,1244684192040071173,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,@lufthansa had an email stating changes to my ...,en,
144,1244688444162486273,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in t...",en,1244684192040071173
144,1244689066567794688,62555545,2009-08-03 16:35:21+00:00,2020-03-30 18:12:53+00:00,@lufthansa My fight is in 3 days !!! last time...,en,1244688444162486273
145,1244685953983225857,562252389,2012-04-24 16:49:01+00:00,2020-03-30 18:00:31+00:00,@lufthansa I requested a refund back on 16/03 ...,en,
145,1244689043952209921,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:12:47+00:00,@AndyHall52 Due to the sheer amount of refund ...,en,1244685953983225857
...,...,...,...,...,...,...,...
864459,1095473044573700096,1339792290,2013-04-09 17:56:30+00:00,2019-02-13 00:01:41+00:00,@lufthansa Read on @YahooNews you sued a guy f...,en,1094985662384594944
864465,1090912573304774662,124476322,2010-03-19 14:30:32+00:00,2019-01-31 10:00:00+00:00,"With 297 seats, the #A340-600 is next in line ...",en,
864465,1090913775417479168,18631142,2009-01-05 13:13:10+00:00,2019-01-31 10:04:46+00:00,"@lufthansa 747-400, then 747-8, then A380. :)",en,1090912573304774662
864471,1083376699780280321,478699784,2012-01-30 15:32:40+00:00,2019-01-10 14:55:08+00:00,Hi @lufthansa how come there is no 'Ms' option...,en,


In [12]:
# Renumber the remaining conversations consecutively from 1 (in order of
# first appearance), keep the original number as a regular column, and index
# by (New_Conversation, Tweet_ID).
airline_conversation = (
    airline_conversation
    .reset_index()
    .assign(
        New_Conversation=lambda frame: pd.factorize(frame['Conversation'])[0] + 1
    )
    .set_index(['New_Conversation', 'Tweet_ID'])
    .sort_index(level='New_Conversation')
)
airline_conversation

Unnamed: 0_level_0,Unnamed: 1_level_0,Conversation,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
New_Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,1244684192040071173,144,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,@lufthansa had an email stating changes to my ...,en,
1,1244688444162486273,144,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in t...",en,1244684192040071173
1,1244689066567794688,144,62555545,2009-08-03 16:35:21+00:00,2020-03-30 18:12:53+00:00,@lufthansa My fight is in 3 days !!! last time...,en,1244688444162486273
2,1244685953983225857,145,562252389,2012-04-24 16:49:01+00:00,2020-03-30 18:00:31+00:00,@lufthansa I requested a refund back on 16/03 ...,en,
2,1244689043952209921,145,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:12:47+00:00,@AndyHall52 Due to the sheer amount of refund ...,en,1244685953983225857
...,...,...,...,...,...,...,...,...
12912,1095473044573700096,864459,1339792290,2013-04-09 17:56:30+00:00,2019-02-13 00:01:41+00:00,@lufthansa Read on @YahooNews you sued a guy f...,en,1094985662384594944
12913,1090912573304774662,864465,124476322,2010-03-19 14:30:32+00:00,2019-01-31 10:00:00+00:00,"With 297 seats, the #A340-600 is next in line ...",en,
12913,1090913775417479168,864465,18631142,2009-01-05 13:13:10+00:00,2019-01-31 10:04:46+00:00,"@lufthansa 747-400, then 747-8, then A380. :)",en,1090912573304774662
12914,1083376699780280321,864471,478699784,2012-01-30 15:32:40+00:00,2019-01-10 14:55:08+00:00,Hi @lufthansa how come there is no 'Ms' option...,en,


# Sentiment Analysis

In [None]:
# Load the tokenizer and model once; both come from the same checkpoint, so
# the name is defined in a single place.
model_name = "cardiffnlp/twitter-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)

# Run inference on the first GPU when TensorFlow can see one, else on the CPU.
# tf.test.is_gpu_available() is deprecated; listing the physical GPU devices
# is the documented replacement.
device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'

# Sentiment name for each class index of the model output.
labels = {
    0: 'negative',
    1: 'neutral',
    2: 'positive',
}

def process_batch(texts):
    """
    Score the sentiment of a batch of texts with the pre-trained model.

    Parameters:
    texts (list): A list of texts to analyze.

    Returns:
    list: One float per input text, in input order. Each value is the softmax
    probability of class index 2 minus that of class index 0 (named
    'positive' and 'negative' in `labels`), so it lies in [-1, 1]. An empty
    input returns an empty list.

    The batch is tokenized with padding and truncation, run through the model
    on the configured `device`, and the first element of the model output is
    converted to probabilities with a row-wise softmax.
    """
    # Nothing to score: return early instead of tokenizing an empty batch and
    # sending it through the model.
    if len(texts) == 0:
        return []

    encoded_input = tokenizer(texts, return_tensors='tf', padding=True, truncation=True)
    with tf.device(device):
        output = model(encoded_input)

    scores = output[0].numpy()
    scores = softmax(scores, axis=1)

    # positive probability minus negative probability
    sentiment_scores = scores[:, 2] - scores[:, 0]
    return sentiment_scores.tolist()

def apply_sentiment_analysis(df, text_column, batch_size=64, max_workers=4):
    """
    Add a 'sentiment' score column to the given DataFrame.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the text column to analyze.
    text_column (str): The name of the column in the DataFrame containing the text to analyze.
    batch_size (int): The number of texts to process in each batch. Default is 64.
    max_workers (int): The maximum number of worker threads to use for parallel processing. Default is 4.

    Returns:
    pd.DataFrame: The same DataFrame object that was passed in, with a
    'sentiment' column holding one process_batch score per row.

    Raises:
    ValueError: If batch_size is smaller than 1.

    Note: `df` is modified in place; pass a copy to keep the original intact.
    Batches are handed to process_batch from up to `max_workers` threads at
    the same time, so process_batch must tolerate concurrent calls; pass
    max_workers=1 to process one batch at a time.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    texts = df[text_column].tolist()
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

    results = []
    if batches:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() submits every batch before any result is awaited and
            # yields the results in input order. Calling .result() right
            # after each submit() would block on every batch and leave the
            # other workers idle.
            for batch_scores in executor.map(process_batch, batches):
                results.extend(batch_scores)

    df['sentiment'] = results
    return df


In [None]:
# Show full tweet texts instead of truncating them in the rendered table.
pd.set_option('display.max_colwidth', None)
full_text_df = airline_conversation[['full_text']].copy()
# apply_sentiment_analysis adds a column to the frame it receives, so give it
# an independent copy of the first 100 rows rather than a slice of
# full_text_df (which would raise SettingWithCopyWarning).
airline_conversation_sample = full_text_df.head(100).copy()
airline_conversation_sample = apply_sentiment_analysis(airline_conversation_sample, 'full_text')
airline_conversation_sample
