# Imports

In [None]:
import pandas as pd
from os import getenv
from sqlalchemy import create_engine
import seaborn as sns
import matplotlib.pyplot as plt
import pycountry
import contextlib
import numpy as np

# Constants

In [None]:
def check_given_var(env_var_str: str) -> str:
    """
    Check if the given environment variable is set and return its value.

    Args:
        env_var_str (str): The name of the environment variable to check.

    Returns:
        str: The value of the environment variable.

    Raises:
        AssertionError: If the environment variable is not found.
    """
    env_var = getenv(env_var_str)
    if env_var is None:
        # Raised explicitly instead of through an `assert` statement so the
        # check is not stripped when Python runs with -O. AssertionError is
        # kept so callers observe the same exception type as before.
        raise AssertionError(
            f"{env_var_str} is required but not found in environment variables"
        )
    return env_var


def check_env_vars() -> tuple[str, str, str, str]:
    """
    Read the database connection settings from the environment.

    The variables are looked up in this order: DBL_USER, DBL_DATABASE,
    DBL_PASSWORD, DBL_HOST.

    Returns:
        tuple[str, str, str, str]: ``(user, database, password, host)``.

    Raises:
        AssertionError: Propagated from ``check_given_var`` when any of the
            variables is missing.
    """
    user = check_given_var("DBL_USER")
    database = check_given_var("DBL_DATABASE")
    password = check_given_var("DBL_PASSWORD")
    host = check_given_var("DBL_HOST")
    return user, database, password, host


# Connection settings are read once, when this cell runs; a missing variable
# aborts here with an AssertionError.
USER, DATABASE, PASSWORD, HOST = check_env_vars()
# NOTE: the test database is much faster to query (though still slow).
# Uncomment the next line to point the notebook at it.
# USER, DATABASE = "nezox2um_test", "nezox2um_test"

# Every tweet joined with its author's profile columns, capped at 2,000,000 rows.
# NOTE(review): there is no ORDER BY, so which rows fall under the LIMIT is
# presumably up to the database - confirm if reproducibility matters.
QUERY_ALL = """
SELECT 
    Users.user_id AS user_id, 
    Users.creation_time AS user_creation_time, 
    Users.verified,
    Users.followers_count,
    Users.friends_count,
    Users.statuses_count,
    Users.default_profile,
    Users.default_profile_image,
    Tweets.creation_time AS tweet_creation_time,
    Tweets.tweet_id,
    Tweets.full_text,
    Tweets.lang,
    Tweets.country_code,
    Tweets.favorite_count,
    Tweets.retweet_count,
    Tweets.possibly_sensitive,
    Tweets.replied_tweet_id,
    Tweets.reply_count,
    Tweets.quoted_status_id,
    Tweets.quote_count
FROM Users
INNER JOIN Tweets ON Users.user_id = Tweets.user_id
LIMIT 2000000;

"""
# Column -> dtype mapping applied when QUERY_ALL results are loaded.
DTYPES = {
    # user columns
    "user_id": "object",
    "user_creation_time": "datetime64[ns]",
    "verified": "bool",
    "followers_count": "int32",
    "friends_count": "int32",
    "statuses_count": "int32",
    "default_profile": "bool",
    "default_profile_image": "bool",
    # tweet columns
    "tweet_creation_time": "datetime64[ns]",
    "tweet_id": "object",
    "full_text": "object",
    "lang": "category",
    "country_code": "category",
    "favorite_count": "int32",
    "retweet_count": "int32",
    "possibly_sensitive": "bool",
    "replied_tweet_id": "object",
    "reply_count": "int32",
    "quoted_status_id": "object",
    "quote_count": "int32",
}

# Airline display name -> account user_id (ids are kept as strings).
COMPANY_NAME_TO_ID = {
    "Klm": "56377143",
    "Air France": "106062176",
    "British Airways": "18332190",
    "American Air": "22536055",
    "Lufthansa": "124476322",
    "Air Berlin": "26223583",
    "Air Berlin assist": "2182373406",
    "easyJet": "38676903",
    "Ryanair": "1542862735",
    "Singapore Airlines": "253340062",
    "Qantas": "218730857",
    "Etihad Airways": "45621423",
    "Virgin Atlantic": "20626359",
}

# Reverse lookup: account user_id -> airline display name.
COMPANY_ID_TO_NAME = dict(
    zip(COMPANY_NAME_TO_ID.values(), COMPANY_NAME_TO_ID.keys())
)

# Helper functions

In [None]:
def fetch_data(query: str, dtype: bool = True) -> pd.DataFrame:
    """
    Run a SQL query against the configured MySQL database and return the rows.

    Args:
        query (str): SQL text passed to ``pd.read_sql_query``.
        dtype (bool): When True, the result is cast with ``DTYPES`` and indexed
            by ``tweet_id`` (so the query must return the columns listed in
            ``DTYPES``). When False, pandas' defaults are used.

    Returns:
        pd.DataFrame: The query result.
    """
    engine = create_engine(f"mysql://{USER}:{PASSWORD}@{HOST}:3306/{DATABASE}")
    try:
        if dtype:
            return pd.read_sql_query(query, engine, dtype=DTYPES, index_col='tweet_id')
        return pd.read_sql_query(query, engine)
    finally:
        # A new engine (with its own connection pool) is created on every call;
        # dispose of it so pooled connections are not left open afterwards.
        engine.dispose()


def get_size_of(size_bytes: float | int) -> str:
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    # Using numpy to calculate the logarithm base 1024
    i = int(np.floor(np.log(size_bytes) / np.log(1024)))
    # Using numpy to calculate power of 1024
    p = np.power(1024, i)
    # Computing the size division
    s = round(size_bytes / p, 2)
    return f"{s} {size_name[i]}"


def identify_dtype(column):
    """
    Identifies the most suitable data type for a pandas Series without loss of information.

    Args:
    column (pd.Series): The pandas Series for which the data type needs to be identified.

    Returns:
    str: Suggested data type as a string: 'float', 'int8', 'int16', 'int32',
    'int64', 'datetime', 'category' or 'object'.
    """
    if pd.api.types.is_numeric_dtype(column):
        non_null = column.dropna()
        if not (non_null % 1 == 0).all():
            return 'float'
        if len(non_null) < len(column):
            # Plain integer dtypes cannot represent missing values, so an
            # integral column that contains NaN has to stay floating point.
            return 'float'
        # Pick the narrowest signed integer type whose range covers the data
        min_val, max_val = column.min(), column.max()
        for int_type in (np.int8, np.int16, np.int32):
            limits = np.iinfo(int_type)
            if limits.min <= min_val and max_val <= limits.max:
                return np.dtype(int_type).name
        return 'int64'
    # Check if the column can be converted to datetime; a failed conversion
    # is not an error here, it just means "not a datetime column".
    with contextlib.suppress(ValueError, TypeError):
        pd.to_datetime(column)
        return 'datetime'
    # Object columns where fewer than half of the values are distinct are
    # suggested as categorical.
    if pd.api.types.is_object_dtype(column):
        num_total_values = len(column)
        if num_total_values and len(column.unique()) / num_total_values < 0.5:
            return 'category'

    # Default to object type if none of the above conditions are met
    return 'object'


def get_full_language_name(language_code, default="Unknown Language"):
    """
    Resolve a two-letter ISO 639-1 language code to the language's full name.

    Parameters:
    language_code (str): The two-letter ISO 639-1 language code. The special
        label "Other languages" is passed through unchanged.
    default (str): Value returned when the code is not known to pycountry.

    Returns:
    str: The full language name, or `default` when the lookup finds nothing.
    """
    if language_code != "Other languages":
        match = pycountry.languages.get(alpha_2=language_code, default=default)
        # The lookup hands back `default` itself when nothing matches
        return match.name if match != default else match
    return language_code


def get_country_name(country_code, default="Unknown Country"):
    """
    Resolve a two-letter ISO 3166-1 alpha-2 country code to the country's full name.

    Parameters:
    country_code (str): The two-letter ISO 3166-1 alpha-2 country code.
    default (str): Value returned when the code is not known to pycountry.

    Returns:
    str: The full country name, or `default` when the lookup finds nothing.
    """
    match = pycountry.countries.get(alpha_2=country_code, default=default)
    # The lookup hands back `default` itself when nothing matches
    return match.name if match != default else match
    


# Loading

In [None]:

# NOTE: the test database is much faster to query, although still slow.
# USER, DATABASE = "nezox2um_test", "nezox2um_test"
# import pandas as pd
# from sqlalchemy import create_engine
# from tqdm.notebook import tqdm

# def process_data_chunks(query):
#     engine = create_engine(f"mysql://{USER}:{PASSWORD}@{HOST}:3306/{DATABASE}")
#     chunk_size = 100_000
#     offset = 0
#     df_result = pd.DataFrame()

#     # Estimate the total number of rows
#     total_rows_query = f"SELECT COUNT(*) FROM ({query}) AS total"
#     total_rows = pd.read_sql_query(total_rows_query, engine).iloc[0, 0]
#     total_chunks = (total_rows // chunk_size) + 1

#     # Use tqdm to display the progress bar
#     with tqdm(total=total_chunks, desc="Processing Data Chunks") as pbar:
#         while True:
#             chunk_query = f"{query} LIMIT {chunk_size} OFFSET {offset}"
#             df_chunk = pd.read_sql_query(chunk_query, engine)
#             if df_chunk.empty:
#                 break
#             # Append the chunk to the list
#             df_result = pd.concat([df_result, df_result], ignore_index=True)
#             # Update the offset
#             offset += chunk_size
#             # Update the progress bar
#             pbar.update(1)
    
#     return df_result

# # Use the function to get and process data
# full_data = process_data_chunks(QUERY_ALL)
# Load the joined Users/Tweets rows selected by QUERY_ALL; with the default
# dtype=True the frame is cast using DTYPES and indexed by tweet_id.
test_data = fetch_data(QUERY_ALL)


In [None]:
# Print the pandas summary (columns, dtypes, non-null counts) of the loaded frame.
test_data.info()

In [None]:
# Total deep memory usage of the columns (index excluded), formatted for reading.
get_size_of(test_data.memory_usage(index=False, deep=True).sum())

In [None]:
# Descriptive statistics of the loaded frame.
test_data.describe()

# Visualisations

# Database size

In [None]:
# Bar chart: hard-coded "possible tweets" figure vs. the number of rows loaded.
# NOTE(review): 6511404 is presumably the line count of the raw input - confirm.
total_lines = 6511404
tweets_right_now = len(test_data)

labels = ["Number of possible tweets", "Number of stored tweets"]
values = [total_lines, tweets_right_now]

plt.figure(figsize=(10, 8))
bars = plt.bar(labels, values)

# Write each bar's thousands-separated height just above it
for rect in bars:
    height = rect.get_height()
    text_x = rect.get_x() + rect.get_width()/2 - 0.1
    plt.text(text_x, height + 5000, f'{height:,}', fontsize=12, weight='bold')

# Titles, axis label and tick styling
plt.title('Comparison of tweets provided vs stored', fontsize=16, weight='bold')
plt.ylabel('Number of Tweets', fontsize=14, weight='bold')
plt.xticks(fontsize=12, weight='bold')
plt.yticks(fontsize=12, weight='bold');

In [None]:
# TODO: quantify the exact effect of each filter, e.g. how removing tweets
# without a user_id changed the storage size
# Hard-coded number of discarded records for each rejection reason
data = [258, 414685, 190928, 2326, 15]
labels = ['Not a tweet', 'Duplicate tweet', 'Inhuman language', 'No tweet id', "Invalid user"]

# Every count needs exactly one label
assert len(data) == len(labels), "Data and labels must be the same length."

# Largest category first; the pairs are re-split into parallel tuples
sorted_data_labels = sorted(zip(data, labels), reverse=True)
data, labels = zip(*sorted_data_labels)

plt.figure(figsize=(20, 8))
bars = plt.bar(labels, data)

# Write each bar's thousands-separated height just above it
for rect in bars:
    height = rect.get_height()
    text_x = rect.get_x() + rect.get_width()/2 - 0.1
    plt.text(text_x, height + 5000, f'{height:,}', fontsize=12, weight='bold')

# Titles, axis label and tick styling
plt.title('Numbers of potential tweets not considered per category', fontsize=16, weight='bold')
plt.ylabel('Number of Tweets', fontsize=14, weight='bold')
plt.xticks(fontsize=12, weight='bold')
plt.yticks(fontsize=12, weight='bold');

In [None]:
# Bar chart of storage needed before and after filtering, in GB (hard-coded).
# NOTE(review): this rebinds total_lines / tweets_right_now, which hold tweet
# counts in the database-size cell; here they hold gigabytes.
total_lines = 35
tweets_right_now = 2.2

labels = ["Raw data", "Filtered data"]
values = [total_lines, tweets_right_now]

plt.figure(figsize=(10, 8))
plt.bar(labels, values)

# Add labels on top of each bar
# for bar in bars:
#     yval = bar.get_height()
#     plt.text(bar.get_x() + bar.get_width()/2 - 0.1, yval, f'{yval:,}', fontsize=12, weight='bold')

# Titles, axis label and tick styling
plt.title('Comparison of storage required', fontsize=16, weight='bold')
plt.ylabel('Storage, GB', fontsize=14, weight='bold')
plt.xticks(fontsize=12, weight='bold')
plt.yticks(fontsize=12, weight='bold');

## Language related

In [None]:
# Number of tweets per language code, most frequent first; only languages
# that actually occur are kept (observed=True).
# TODO: change "und" to "un" later
lang_popularity_df = (
    test_data.reset_index()
    .groupby('lang', observed=True)[['tweet_id']]
    .count()
    .sort_values('tweet_id', ascending=False)
)
# lang_popularity_df.index = lang_popularity_df.index.map(get_full_language_name)
lang_popularity_df.head()

In [None]:
# Keep the 5 most used languages and fold every other language into one row.
# NOTE: the top_10_* name is historical - nlargest(5, ...) keeps five rows.
top_10_languages = lang_popularity_df.nlargest(5, 'tweet_id')

# Everything that did not make the cut
in_top = lang_popularity_df.index.isin(top_10_languages.index)
other_languages_df = lang_popularity_df.loc[~in_top]

# Column-wise total of the remaining languages, labelled as a single entry
other_languages_agg = other_languages_df.sum().rename('Other languages')

# Top rows followed by the aggregated remainder
final_df = pd.concat([top_10_languages, other_languages_agg.to_frame().T])
final_df.index.name = 'Language'
final_df

In [None]:
# Pie chart of the tweet share per language, using the rows currently in
# final_df with language codes replaced by full language names.
top_5_popularity_lang_full = final_df.copy()
top_5_popularity_lang_full.index = top_5_popularity_lang_full.index.map(get_full_language_name)

labels = top_5_popularity_lang_full.index
sizes = top_5_popularity_lang_full["tweet_id"]
plt.figure(figsize=(10, 8))
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140, textprops={'fontsize': 14, 'weight': 'bold'})
plt.title('Number of tweets by 5 most popular languages', fontsize=16, weight='bold')
# The legend lists languages, so it is titled "Languages" (it used to say
# "Countries", copied from the country charts).
plt.legend(labels, title="Languages", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12);

## Country of origin

In [None]:
# Number of tweets per country code, most frequent first, with the codes
# replaced by country names (codes pycountry cannot resolve become
# "Unknown Country").
# TODO: change "und" to "un" later
# NOTE(review): the TODO above looks copied from the language cell - confirm it applies here.
country_df = (
    test_data.reset_index()
    .groupby('country_code', observed=True)[['tweet_id']]
    .count()
    .sort_values('tweet_id', ascending=False)
)
country_df.index = country_df.index.map(get_country_name)
country_df

In [None]:

# Pie chart: the single largest country entry against all others combined.
# NOTE: the top_10_* name is historical - nlargest(1, ...) keeps one row.
top_10_countries = country_df.nlargest(1, 'tweet_id')
in_top = country_df.index.isin(top_10_countries.index)
other_countries_df = country_df.loc[~in_top]
other_countries_agg = other_countries_df.sum().rename('Other countries')

final_df = pd.concat([top_10_countries, other_countries_agg.to_frame().T])
final_df.index.name = 'Country'

labels = final_df.index
sizes = final_df['tweet_id']
plt.figure(figsize=(10, 8))
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140, textprops={'fontsize': 14, 'weight': 'bold'})
plt.title("Number of tweets per known countries", fontsize=16, weight='bold')
plt.legend(labels, title="Countries", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12);

In [None]:
# Same breakdown, but without the rows labelled "Unknown Country".
df_plot_by_country = country_df[country_df.index != "Unknown Country"].copy()

# The 5 countries with the most tweets
# NOTE: the top_10_* name is historical - nlargest(5, ...) keeps five rows.
top_10_countries = df_plot_by_country.nlargest(5, 'tweet_id')

# All remaining countries, summed into a single "Other countries" entry
in_top = df_plot_by_country.index.isin(top_10_countries.index)
other_countries_df = df_plot_by_country.loc[~in_top]
other_countries_agg = other_countries_df.sum().rename('Other countries')

# Top rows followed by the aggregated remainder
final_df = pd.concat([top_10_countries, other_countries_agg.to_frame().T])
final_df.index.name = 'Country'
# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

# Pie chart of the resulting shares
labels = final_df.index
sizes = final_df['tweet_id']
plt.figure(figsize=(10, 8))
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140, textprops={'fontsize': 14, 'weight': 'bold'})
plt.title("Distribution of tweets per known countries", fontsize=16, weight='bold')
plt.legend(labels, title="Countries", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12);

## Tweets from main accounts of the airlines

In [None]:
# Despite its name, avia_names holds the airlines' account ids (the values of
# COMPANY_NAME_TO_ID), which is what user_id is compared against.
avia_names = set(COMPANY_NAME_TO_ID.values())

# Vectorised membership test instead of a Python-level any(...) per row.
is_airline_account = test_data['user_id'].isin(avia_names)

# Count the tweets written by each airline account, most active first.
replies_to_avia_companies_df = (
    test_data.loc[is_airline_account]
    .reset_index()
    .groupby("user_id")[['tweet_id']]  # count only the column that is kept
    .count()
    .sort_values('tweet_id', ascending=False)
    .reset_index()
)
# Replace account ids with airline names; ids without a name stay as they are.
replies_to_avia_companies_df["user_id"] = replies_to_avia_companies_df["user_id"].apply(lambda user_id: COMPANY_ID_TO_NAME.get(user_id, user_id))
replies_to_avia_companies_df = replies_to_avia_companies_df.set_index("user_id")
replies_to_avia_companies_df

In [None]:
# Bar chart of the number of tweets per airline account.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(25,8))
sns.barplot(x='user_id', y='tweet_id', data=replies_to_avia_companies_df, ax=ax)

# Titles, axis labels (x label intentionally blank) and tick styling
plt.title('Number of tweets by airline company', fontsize=16, weight='bold')
plt.xlabel('', fontsize=14, weight='bold')
plt.ylabel('Number of tweets', fontsize=14, weight='bold')
plt.xticks(fontsize=12, weight='bold')
plt.yticks(fontsize=12, weight='bold');

## Replies to company posts

In [None]:
# Pair every reply (t1) with the tweet it answers (t2) via a self-join on
# Tweets; the join is done in SQL rather than in pandas.
reply_pairs_query = """
SELECT 
    t1.tweet_id AS tweet_id,
    t1.creation_time AS tweet_creation_time,
    t1.user_id AS user_id,
    t2.tweet_id AS original_tweet_id,
    t2.creation_time AS original_tweet_creation_time,
    t2.user_id AS original_user_id
FROM 
    Tweets t1
INNER JOIN 
    Tweets t2
ON 
    t1.replied_tweet_id = t2.tweet_id;
"""
# dtype=False: this result does not have the QUERY_ALL columns, so the DTYPES
# cast is skipped and the index is set explicitly.
df_reply = fetch_data(reply_pairs_query, dtype=False).set_index("tweet_id")


In [None]:
# Time elapsed between the original tweet and the reply to it
df_reply["response_time"] = df_reply["tweet_creation_time"].sub(
    df_reply["original_tweet_creation_time"]
)
df_reply

In [None]:
# Name the airline behind the replying account and behind the original
# account; accounts that are not a listed airline get NaN.
for name_column, id_column in (
    ("airline", "user_id"),
    ("original_airline", "original_user_id"),
):
    df_reply[name_column] = df_reply[id_column].map(COMPANY_ID_TO_NAME)
df_reply

In [None]:
# Make sure the time columns have datetime / timedelta dtypes
df_reply['tweet_creation_time'] = pd.to_datetime(df_reply['tweet_creation_time'])
df_reply['original_tweet_creation_time'] = pd.to_datetime(df_reply['original_tweet_creation_time'])
df_reply['response_time'] = pd.to_timedelta(df_reply['response_time'])

# Mean time an airline takes to reply (rows where the replying account is an airline)
average_response_time_airline = df_reply[df_reply['airline'].notnull()].groupby('airline')['response_time'].mean()

# Mean time other accounts take to reply to an airline (rows where the original tweet is the airline's)
average_response_time_reactions = df_reply[df_reply['original_airline'].notnull()].groupby('original_airline')['response_time'].mean()

# Every airline that appears on either side, computed once and reused below
all_airlines = average_response_time_airline.index.union(average_response_time_reactions.index)

# One row per airline; an airline missing from one of the two means gets NaT there
combined_df = pd.DataFrame({
    'Airline': all_airlines,
    'Airline Response Time': average_response_time_airline.reindex(all_airlines),
    'User Reaction Time': average_response_time_reactions.reindex(all_airlines)
}).reset_index(drop=True)

# The values stay timedeltas here (the former "convert to seconds" lines only
# assigned each column to itself); unit conversion is left to the plotting code.
combined_df.sort_values(by=['Airline Response Time', 'User Reaction Time'], ascending=[True, False], inplace=True)

combined_df

In [None]:
# Air Berlin is excluded from both plots.
combined_df_plot = combined_df[combined_df["Airline"] != "Air Berlin"].copy()
combined_df_plot['Airline Response Time, hours'] = combined_df_plot['Airline Response Time'].dt.total_seconds() / 3600
# A day has 86,400 seconds. The divisor used to be written as 86.400 (i.e. 86.4),
# which made the values on the "days" axis 1000 times too large.
combined_df_plot['User Reaction Time, days'] = combined_df_plot['User Reaction Time'].dt.total_seconds() / 86_400
# Plotting: airline reply time on top, user reaction time below
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(18, 12))
sns.barplot(combined_df_plot, x='Airline', y='Airline Response Time, hours', ax=ax[0])
ax[0].set_title('Average airline customer service response time', fontsize=16, weight='bold')
ax[0].set_ylabel('Airline Response Time, hours', fontsize=14, weight='bold')
sns.barplot(combined_df_plot, x='Airline', y='User Reaction Time, days', ax=ax[1])
ax[1].set_title('Average user reaction time to airline tweet', fontsize=16, weight='bold')
ax[1].set_ylabel('User Reaction Time, days', fontsize=14, weight='bold')

In [None]:
# Group the reply pairs by replying airline and show summary statistics per group.
# NOTE(review): the "_luft" suffix suggests Lufthansa, but the grouping covers
# every value of the "airline" column.
df_reply_luft = df_reply.groupby("airline")
# df_reply_luft.reset_index().set_index("original_tweet_id")
df_reply_luft.describe()

## Companies' activity and popularity in social media

In [None]:
# tweet_id becomes a regular column again (reset_index already returns a new frame).
test_data_no_index = test_data.reset_index()

# Airline account ids are derived here, so this cell does not depend on a
# variable defined in another cell.
airline_ids = set(COMPANY_NAME_TO_ID.values())

# Per airline account: number of tweets and total engagement counters.
# isin() replaces the per-row Python any(...) membership test.
popularity_by_airlines = (
    test_data_no_index.loc[test_data_no_index['user_id'].isin(airline_ids)]
    .groupby("user_id")
    .agg(
        tweet_number=("tweet_id", "count"),
        retweet_count=("retweet_count", "sum"),
        favorite_count=("favorite_count", "sum"),
        reply_count=("reply_count", "sum"),
        quote_count=("quote_count", "sum"),
    )
    .reset_index()
)
# Replace account ids with airline names; ids without a name stay as they are.
popularity_by_airlines["user_id"] = popularity_by_airlines["user_id"].apply(lambda user_id: COMPANY_ID_TO_NAME.get(user_id, user_id))
popularity_by_airlines = popularity_by_airlines.set_index("user_id").sort_values("retweet_count", ascending=False)
popularity_by_airlines

In [None]:
# One bar chart per metric column of popularity_by_airlines, stacked vertically.
metric_columns = popularity_by_airlines.columns
fig, ax = plt.subplots(
    nrows=len(metric_columns),
    figsize=(20, 6 * len(metric_columns)),
)

for axis, column in zip(ax, metric_columns):
    sns.barplot(data=popularity_by_airlines, x='user_id', y=column, ax=axis)
    axis.set_title(f"{column} for each airline", fontsize=16, weight='bold')
    axis.set_xlabel("")
# Extra vertical room so titles do not collide with the chart above
fig.subplots_adjust(hspace=0.4)

## Information regarding users

In [None]:
# One row per user. groupby does not modify its input, so the frame is grouped
# directly instead of being copied first.
df_users = test_data_no_index.groupby("user_id").agg(
    # Profile attributes, reduced to a single value per user
    user_creation_time=("user_creation_time", "min"),
    verified=("verified", "min"),
    followers_count=("followers_count", "min"),
    friends_count=("friends_count", "min"),
    statuses_count=("statuses_count", "min"),
    default_profile=("default_profile", "min"),
    # NOTE(review): this is the only profile attribute reduced with "max" - confirm it is intended.
    default_profile_image=("default_profile_image", "max"),
    # Activity window covered by the user's tweets in the data
    first_tweet=("tweet_creation_time", "min"),
    last_tweet=("tweet_creation_time", "max"),
    # Totals over the user's tweets
    possibly_sensitive=("possibly_sensitive", "sum"),
    favorite_count=("favorite_count", "sum"),
    retweet_count=("retweet_count", "sum"),
    reply_count=("reply_count", "sum"),
    quote_count=("quote_count", "sum"),
    # Language of the first tweet encountered for the user
    lang=("lang", "first")
)
df_users.head()

In [None]:
df_users.describe()

### Custom user "trustworthiness" classification

In [None]:
# Number of users per value of the "verified" flag, shown as a bar chart and a table
df_verified = (
    df_users.groupby("verified")[["user_creation_time"]]
    .count()
    .rename(columns={"user_creation_time": "verified"})
)
df_verified.plot(kind="bar", title="Verified user ratio", legend=False, figsize=(12, 6))
df_verified

In [None]:
# Number of users per value of the "default_profile" flag, shown as a bar chart and a table
default_profile = (
    df_users.groupby("default_profile")[["user_creation_time"]]
    .count()
    .rename(columns={"user_creation_time": "default_profile"})
)
default_profile.plot(kind="bar", title="Has default profile user ratio", legend=False, figsize=(12, 6))
default_profile

In [None]:
# Number of users per value of the "default_profile_image" flag, shown as a bar chart and a table
default_profile_image = (
    df_users.groupby("default_profile_image")[["user_creation_time"]]
    .count()
    .rename(columns={"user_creation_time": "default_profile_image"})
)
default_profile_image.plot(kind="bar", title="Has default profile image user ratio", legend=False, figsize=(12, 6))
default_profile_image

In [None]:
# Delay between account creation and the user's earliest tweet in the data
df_users["time_to_tweet"] = df_users["first_tweet"].sub(df_users["user_creation_time"])
df_users["time_to_tweet"].describe()

## Tweets information

In [None]:
# Number of tweets per value of the "possibly_sensitive" flag
# (counted through the non-null user_id entries)
df_sensitive = test_data.groupby("possibly_sensitive")[["user_id"]].count()
df_sensitive

In [None]:
# Bar chart of the counts above; the trailing semicolon hides the returned Axes repr.
df_sensitive.plot(kind="bar", title="Possibly sensitive tweets", legend=False, figsize=(12, 6));