In [1]:
import os
import sqlite3

from collections import defaultdict
from typing import List, Tuple

import mysql
import mysql.connector
import pandas as pd
from mysql.connector import Error
from tqdm.notebook import tqdm

In [2]:
class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, conversation):
        node = self.root
        for tweet_id in conversation:
            node = node.children[tweet_id]
        node.is_end = True

    def is_subset(self, conversation):
        node = self.root
        for tweet_id in conversation:
            if tweet_id not in node.children:
                return False
            node = node.children[tweet_id]
        return True

def trace_conversation(start_tweet_id: str, tweet_dict: dict):
    convo = []
    current_tweet_id = start_tweet_id
    users_in_conversation = set()
    local_processed_tweet_ids = set()  # Local set to track the current conversation
    while current_tweet_id:
        if current_tweet_id not in tweet_dict or current_tweet_id in local_processed_tweet_ids:
            break
        tweet_info = tweet_dict[current_tweet_id]
        convo.append(current_tweet_id)
        users_in_conversation.add(tweet_info['user_id'])
        local_processed_tweet_ids.add(current_tweet_id)
        if len(users_in_conversation) > 2:
            return convo[:-1][::-1], users_in_conversation  # As soon as the third user appears, we delete his tweet and return
        current_tweet_id = tweet_info['replied_tweet_id']
    return (convo[::-1], users_in_conversation) if len(users_in_conversation) == 2 else (None, users_in_conversation)

def extract_and_filter_conversations(df: pd.DataFrame, user_ids: list):
    df = df.sort_values("tweet_creation_time", ascending=False)
    df.index = df.index.astype(str)
    tweet_dict = df.to_dict('index')
    conversations = []
    trie = Trie()  # Initialize trie for subset checks
    user_ids_set = set(user_ids)  # Convert list to set for faster membership checking

    # Start tracing conversations from tweets that are replies
    for tweet_id in tqdm(df[df['replied_tweet_id'].notnull()].index, desc="Extracting all conversations"):
        conversation, users_in_conversation = trace_conversation(tweet_id, tweet_dict)
        if conversation and user_ids_set.intersection(users_in_conversation):
            if not trie.is_subset(conversation):
                trie.insert(conversation)
                conversations.append(conversation)

    return conversations


def get_local_data(query: str, path: str, dtype: bool = True) -> pd.DataFrame:
    # Connect to the SQLite database using a context manager
    with sqlite3.connect(path) as connection:
        # Read the data into a DataFrame
        if dtype:
            df = pd.read_sql_query(query, connection,
                                   dtype=DTYPES,
                                   index_col='tweet_id')
            df['tweet_creation_time'] = pd.to_datetime(df['tweet_creation_time'])
            df['user_creation_time'] = pd.to_datetime(df['user_creation_time'])
        else:
            df = pd.read_sql_query(query, connection)
    
    return df


# def fetch_data(query: str, dtype: bool = True) -> pd.DataFrame:
#     engine = create_engine(f"mysql://{USER}:{PASSWORD}@{HOST}:3306/{DATABASE}")
#     if dtype:
#         return pd.read_sql_query(query, engine,
#                                  dtype=DTYPES, index_col='tweet_id')
#     return pd.read_sql_query(query, engine)



In [3]:
def check_given_var(env_var_str: str) -> str:
    """
    Check if the given environment variable is set and return its value.

    Args:
        env_var_str (str): The name of the environment variable to check.

    Returns:
        str: The value of the environment variable.

    Raises:
        AssertionError: If the environment variable is not found.
    """

    env_var = os.getenv(env_var_str)
    assert (
        env_var is not None
    ), f"{env_var_str} is required but not found in environment variables"
    return env_var


def check_env_vars() -> (str, str, str, str):  # type: ignore
    user = check_given_var("DBL_USER")
    database = check_given_var("DBL_DATABASE")
    password = check_given_var("DBL_PASSWORD")
    host = check_given_var("DBL_HOST")
    return user, database, password, host


USER, DATABASE, PASSWORD, HOST = check_env_vars()
# USER, DATABASE = "nezox2um_test", "nezox2um_test"
QUERY_ALL = """
SELECT 
    Users.user_id AS user_id, 
    Users.creation_time AS user_creation_time, 
    Tweets.creation_time AS tweet_creation_time,
    Tweets.tweet_id,
    Tweets.full_text,
    Tweets.lang,
    Tweets.replied_tweet_id
FROM Users
INNER JOIN Tweets ON Users.user_id = Tweets.user_id;
"""


DTYPES = {
"user_id": "object",
"tweet_id": "object",
"full_text": "object",
"lang": "category",
"replied_tweet_id": "object",
}
COMPANY_NAME_TO_ID = {
    "Klm": "56377143",
    "Air France": "106062176",
    "British Airways": "18332190",
    "American Air": "22536055",
    "Lufthansa": "124476322",
    "Air Berlin": "26223583",
    "Air Berlin assist": "2182373406",
    "easyJet": "38676903",
    "Ryanair": "1542862735",
    "Singapore Airlines": "253340062",
    "Qantas": "218730857",
    "Etihad Airways": "45621423",
    "Virgin Atlantic": "20626359",
}

# Extracting all the IDs into a list
company_ids = list(COMPANY_NAME_TO_ID.values())

In [4]:
# Server
# test_data = fetch_data(QUERY_ALL)
# Local
path =  os.path.join(
        os.path.dirname(
            os.getcwd()
        ),
    "data_processed", "local_backup.db")

test_data = get_local_data(QUERY_ALL, path)

In [5]:
convo_special = test_data[["user_id", "replied_tweet_id", "tweet_creation_time"]]
convo_special

Unnamed: 0_level_0,user_id,replied_tweet_id,tweet_creation_time
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1131172858951024641,393374091,,2019-05-22 12:20:00+00:00
1130922003702177800,880417607865815040,1130615560910254080,2019-05-21 19:43:11+00:00
1131172864147808257,3420691215,,2019-05-22 12:20:01+00:00
1131172867985485824,394376606,1131032916232826881,2019-05-22 12:20:02+00:00
1131030279278063616,227687574,,2019-05-22 02:53:26+00:00
...,...,...,...
1244696703690772485,278698748,,2020-03-30 18:43:14+00:00
1244696708983984131,246520593,,2020-03-30 18:43:15+00:00
1244696710447800320,109284383,,2020-03-30 18:43:15+00:00
1244696713350217728,1223576386432126976,,2020-03-30 18:43:16+00:00


In [6]:
conversations = extract_and_filter_conversations(convo_special, company_ids)

Extracting all conversations:   0%|          | 0/1795409 [00:00<?, ?it/s]

In [7]:
conversations

[['1244694453190897664', '1244696682979303426'],
 ['1244677304598609923', '1244696641401163776'],
 ['1244644204132909060', '1244696371900436481'],
 ['1242875007270891523', '1244696352090656770'],
 ['1244663027452071936', '1244696298638450696'],
 ['1244550514970329088', '1244553548668579852', '1244696257781805056'],
 ['1244683000195022855', '1244696213552758787'],
 ['1244542518987014144', '1244542697366532096', '1244696138512556033'],
 ['1244695110639632386', '1244696125833175041'],
 ['1244542518987014144', '1244696104341471234'],
 ['1239668797218402305',
  '1241034899156545542',
  '1241147545096765443',
  '1242202731919675397',
  '1242339956263182336',
  '1244683643416698880',
  '1244689590549647361',
  '1244695969918222338'],
 ['1241498515039272960', '1244695952151240705'],
 ['1238774159498522624', '1244695827269984256'],
 ['1244689059072573440', '1244695718792699905'],
 ['1244692213369577473', '1244695542711693312', '1244695647271485448'],
 ['1244679105427181570', '124468914755980492

In [8]:
data = []
for convo_num, convo in enumerate(conversations, start=1):
    data.extend((convo_num, tweet_id) for tweet_id in convo)
# Create a DataFrame
df_conversations = pd.DataFrame(data, columns=['Conversation', 'Tweet_ID'])

# Set MultiIndex
df_conversations

Unnamed: 0,Conversation,Tweet_ID
0,1,1244694453190897664
1,1,1244696682979303426
2,2,1244677304598609923
3,2,1244696641401163776
4,3,1244644204132909060
...,...,...
1346561,493694,452657442057646080
1346562,493695,451124070730719233
1346563,493695,451125255294443521
1346564,493696,430790355962052608


In [9]:

# Merge the conversation DataFrame with the test_data DataFrame
df_conversations_full = df_conversations.merge(test_data, left_on='Tweet_ID', right_index=True, how='left')

# Set the MultiIndex again with Conversation and Tweet_ID
df_conversations_full.set_index(['Conversation', 'Tweet_ID'], inplace=True)
df_conversations_full

Unnamed: 0_level_0,Unnamed: 1_level_0,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,1244694453190897664,521835883,2012-03-12 01:11:22+00:00,2020-03-30 18:34:17+00:00,@nealrach @VirginAtlantic Siiiigh.... Still no...,en,1243885949697888263
1,1244696682979303426,20626359,2009-02-11 20:50:56+00:00,2020-03-30 18:43:09+00:00,@Jade_Velveteese Hi Jade. We have an ‘Away fro...,en,1244694453190897664
2,1244677304598609923,396021583,2011-10-22 16:35:05+00:00,2020-03-30 17:26:09+00:00,@VirginAtlantic Sod off your primary sharehold...,en,1244669964289806338
2,1244696641401163776,832964639436701696,2017-02-18 14:47:00+00:00,2020-03-30 18:42:59+00:00,"@Boyde11 @VirginAtlantic Get your facts right,...",en,1244677304598609923
3,1244644204132909060,274980475,2011-03-31 11:55:53+00:00,2020-03-30 15:14:37+00:00,@easyJet Please reply to my DM!,en,1244643452589088771
...,...,...,...,...,...,...,...
493694,452657442057646080,2198564846,2013-11-16 23:24:47+00:00,2014-04-06 04:01:58+00:00,@AmericanAir They cannot hear my screams,en,452657293017227265
493695,451124070730719233,701977520,2012-07-17 23:34:18+00:00,2014-04-01 22:28:54+00:00,@AmericanAir i was kidding thanks for the foll...,en,451123990304522240
493695,451125255294443521,22536055,2009-03-02 21:23:05+00:00,2014-04-01 22:33:37+00:00,@lanaupdates_ Your information has been forwar...,en,451124070730719233
493696,430790355962052608,64327804,2009-08-10 03:34:27+00:00,2014-02-04 19:49:59+00:00,"@AmericanAir phew, they finally turned on the ...",en,


# Uploading conversations

In [10]:
# from itertools import islice

# def split_list_itertools(lst: list, batch_size: int):
#     it = iter(lst)
#     return iter(lambda: list(islice(it, batch_size)), [])

# def connect_to_database(user: str, database: str, password: str, host: str):
#     """
#     Establish a connection to the database.

#     Args:
#         user: The database user.
#         database: The name of the database.
#         password: The password for the database.
#         host: The database host.

#     Returns:
#         A connection object to the MySQL database.
#     """
#     try:
#         connection = mysql.connector.connect(
#             user=user, password=password, host=host, database=database
#         )
#         if connection.is_connected():
#             return connection
#     except Error as e:
#         print(f"Error while connecting to MySQL: {e}")
#     return None


# def insert_conversations(
#     batch_list, user: str, database: str, password: str, host: str
# ) -> None:
#     """
#     Create and insert batches of tweets into the database in parallel.

#     Args:
#         batches_list: Tuple of (sentiment, tweet_id) pairs.
#         user: The database user.
#         database: The name of the database.
#         password: The password for the database.
#         host: The database host.
#     """
#     connection = connect_to_database(user, database, password, host)
#     if connection is None:
#         return
#     cursor = connection.cursor()
#     insertion_conversations = """
#     INSERT IGNORE INTO Conversations(first_tweet_id, last_tweet_id, conversation)
#     VALUES(%s, %s, %s);
#     """
#     cursor.executemany(insertion_conversations, batch_list)
#     connection.commit()
#     cursor.close()
#     connection.close()


# def upload_data(conversations: pd.DataFrame, batch_size: int):
#     """
#     Convert DataFrame with tweet_id as index to a list of lists containing sentiment and tweet_id.

#     Args:
#         df: The DataFrame with tweet_id as index and sentiment as a column.

#     Returns:
#         A list of lists containing tweet_id and sentiment.
#     """
#     conversations_upload = [
#         [conversation[0], conversation[-1], ",".join(conversation[1:-1])]
#         for conversation in conversations
#     ]
#     for batch in tqdm(split_list_itertools(conversations_upload, batch_size)):
#         insert_conversations(batch, USER, DATABASE, PASSWORD, HOST)
    

In [11]:
# upload_data(conversations, 10000)

In [12]:
rows = []

# Loop through each conversation
for conv_id, conv in enumerate(conversations, start=1):
    # Loop through each tweet in the conversation
    rows.extend(
        (conv_id, order, tweet_id)
        for order, tweet_id in enumerate(conv, start=1)
    )
# Create a DataFrame from the rows
df = pd.DataFrame(rows, columns=['conversation_id', 'tweet_order', 'tweet_id'])
df

Unnamed: 0,conversation_id,tweet_order,tweet_id
0,1,1,1244694453190897664
1,1,2,1244696682979303426
2,2,1,1244677304598609923
3,2,2,1244696641401163776
4,3,1,1244644204132909060
...,...,...,...
1346561,493694,3,452657442057646080
1346562,493695,1,451124070730719233
1346563,493695,2,451125255294443521
1346564,493696,1,430790355962052608


In [18]:
import sqlite3

def connect_to_db(db_path):
    return sqlite3.connect(db_path)

def delete_conversations(db_name: str, table_name):
    """
    Connects to the SQLite database and deletes the specified table.
    
    Args:
    db_name (str): The name of the SQLite database.
    table_name (str): The name of the table to delete.
    
    Returns:
    str: A message indicating the result of the operation.
    """
    try:
        # Connect to the SQLite database
        conn = connect_to_db(db_name)
        cursor = conn.cursor()
        
        # Create the DROP TABLE SQL statement
        drop_table_sql = "DROP TABLE IF EXISTS ConversationsCategory;"
        # drop_table_sql = f"DROP TABLE IF EXISTS {table_name};"
        
        # Execute the SQL statement
        cursor.execute(drop_table_sql)
        
        # Commit the changes
        conn.commit()
        
        return f"Table '{table_name}' deleted successfully."
    
    except sqlite3.Error as error:
        return f"Error while deleting table: {error}"
    
    finally:
        # Close the database connection
        if conn:
            conn.close()


def create_conversations(db_name: str):
    """
    Connects to the SQLite database and deletes the specified table.
    
    Args:
    db_name (str): The name of the SQLite database.
    table_name (str): The name of the table to delete.
    
    Returns:
    str: A message indicating the result of the operation.
    """
    try:
        # Connect to the SQLite database
        conn = connect_to_db(db_name)
        cursor = conn.cursor()
        
        # Create the DROP TABLE SQL statement
        drop_table_sql = """
        CREATE TABLE IF NOT EXISTS Conversations(
        conversation_id INTEGER,
        tweet_order INTEGER,
        tweet_id VARCHAR(20),
        PRIMARY KEY (conversation_id, tweet_order),
        FOREIGN KEY (tweet_id) REFERENCES Tweets(tweet_id)
        )
"""
        
        # Execute the SQL statement
        cursor.execute(drop_table_sql)
        
        # Commit the changes
        conn.commit()
        
        return "Table created successfully."
    
    except sqlite3.Error as error:
        return f"Error while creating table: {error}"
    
    finally:
        # Close the database connection
        if conn:
            conn.close()

def create_categories(db_name: str):
    """
    Connects to the SQLite database and deletes the specified table.
    
    Args:
    db_name (str): The name of the SQLite database.
    table_name (str): The name of the table to delete.
    
    Returns:
    str: A message indicating the result of the operation.
    """
    try:
        # Connect to the SQLite database
        conn = connect_to_db(db_name)
        cursor = conn.cursor()
        
        # Create the DROP TABLE SQL statement
        drop_table_sql = """
        CREATE TABLE IF NOT EXISTS ConversationsCategory (
        conversation_id INTEGER PRIMARY KEY,
        category VARCHAR(255),
        confidence FLOAT,
        FOREIGN KEY (conversation_id) REFERENCES Conversations(conversation_id)
);
"""
        
        # Execute the SQL statement
        cursor.execute(drop_table_sql)
        
        # Commit the changes
        conn.commit()
        
        return "Table created successfully."
    
    except sqlite3.Error as error:
        return f"Error while creating table: {error}"
    
    finally:
        # Close the database connection
        if conn:
            conn.close()

# Example usage
table_name = "Conversations"

delete_conversations(path, table_name)
create_conversations(path)
create_categories(path)

'Table created successfully.'

In [27]:
from itertools import islice

def split_list_itertools(lst: list, batch_size: int):
    it = iter(lst)
    return iter(lambda: list(islice(it, batch_size)), [])


def insert_conversations(batch_list, db_path: str) -> None:
  """
  Create and insert batches of tweets into the database in parallel.

  Args:
      batch_list: Tuple of (sentiment, tweet_id) pairs.
      db_path: Path to the database file.
  """

  with connect_to_db(db_path) as connection:
    cursor = connection.cursor()
    # insertion_conversations = """
    #   INSERT OR IGNORE INTO Conversations(conversation_id, tweet_order, tweet_id)
    #   VALUES(?, ?, ?);
    # """
    insert_category = """
      INSERT OR IGNORE INTO ConversationsCategory(conversation_id, category, confidence)
      VALUES(?, ?, ?);
    """
    # print(batch_list)
    # cursor.executemany(insertion_conversations, batch_list)
    cursor.executemany(insert_category, batch_list)
    connection.commit()


def upload_data(conversations: pd.DataFrame, batch_size: int, db_path):
    """
    Convert DataFrame with tweet_id as index to a list of lists containing sentiment and tweet_id.

    Args:
        df: The DataFrame with tweet_id as index and sentiment as a column.

    Returns:
        A list of lists containing tweet_id and sentiment.
    """
    conversations_upload = conversations.values.tolist()
    categories_to_upload = [[row, "No Category", 0] for row in list({row[0] for row in conversations_upload})]
    for batch in tqdm(split_list_itertools(categories_to_upload, batch_size)):
        insert_conversations(batch, db_path)
    

In [28]:
upload_data(df, 10_000, path)

0it [00:00, ?it/s]

In [16]:
# list({row[0] for row in df.values.tolist()})

[1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
 185