In [1]:
from tqdm.notebook import tqdm

In [2]:
import sqlite3
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd


def form_connection_params(local: bool, notebook: bool = False) -> Dict[str, Any]:
    """
    Forms connection parameters based on the local and notebook flags.

    Args:
        local (bool): Flag indicating whether the connection is local.
        notebook (bool, optional): Flag indicating whether the notebook
            environment is used. Defaults to False.

    Returns:
        Dict[str, str]: A dictionary containing connection parameters based on the flags.
    """
    if local:
        if notebook:
            return {"file_path": "/kaggle/working/local_backup (9).db"}


def connect_to_database(
    connection_params: Dict[str, str], local: bool
) -> Union[sqlite3.Connection]:
    """
    Connects to a database based on the provided connection parameters and local flag.

    Args:
        connection_params (Dict[str, str]): The parameters required to establish the database connection.
        local (bool): Flag indicating whether the connection should be local (SQLite) or remote (MySQL).

    Returns:
        Union[sqlite3.Connection, mysql.connector.MySQLConnection]: The database connection object if successful.

    Raises:
        Error: If an error occurs during the connection process.
    """
    if local:
        return sqlite3.connect(connection_params["file_path"])


def execute_queries(
    connection: Union[sqlite3.Connection],
    queries: List[Union[str, Tuple[str, List[Tuple[Any]]]]],
) -> None:
    """
    Executes a list of SQL queries on the provided database connection.

    Args:
        connection: The database connection object.
        queries (List[Union[str, Tuple[str, List[Tuple[Any]]]]): A list of SQL
            queries to execute, where each query can be a string or a tuple
            containing the SQL query and parameters.

    Returns:
        None
    """
    cursor = connection.cursor()
    for query in queries:
        if isinstance(query, str):
            cursor.execute(query)
        else:
            sql, params = query
            cursor.executemany(sql, params)
    connection.commit()


def get_dataframe_from_query(
    query: str,
    connection_params: Dict[str, str],
    local: bool,
    dtypes: Optional[Dict[str, str]] = None,
    index_col: Optional[str] = None,
    parse_dates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Retrieves a pandas DataFrame by executing a query on a database connection.

    Args:
        query (str): The SQL query to execute.
        connection_params (Dict[str, str]): The parameters required to establish the database connection.
        local (bool): Flag indicating whether the connection is local (SQLite) or remote (MySQL).
        dtypes (Optional[Dict[str, str]], optional): Dictionary specifying column data types.
            Defaults to None.
        index_col (Optional[str], optional): Name of the column to set as the index.
            Defaults to None.
        parse_dates (Optional[List[str]], optional): List of columns to parse as dates.
            Defaults to None.

    Returns:
        pd.DataFrame: A pandas DataFrame containing the results of the query.
    """
    if parse_dates:
        dtypes = {k: v for k, v in (dtypes or {}).items() if k not in parse_dates}

    with sqlite3.connect(connection_params["file_path"]) as connection:
        return pd.read_sql_query(
            query,
            connection,
            dtype=dtypes,
            index_col=index_col,
            parse_dates=parse_dates,
        )


In [3]:
import contextlib
import gc
import os
import re
import string
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import pandas as pd
import tensorflow as tf
from scipy.special import softmax
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Load the tokenizer and model
model_name = "cardiffnlp/twitter-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(
    "cardiffnlp/twitter-roberta-base-sentiment"
)
if physical_devices := tf.config.list_physical_devices("GPU"):
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        print(e)

device = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"


def process_batch(texts):
    """
    Apply sentiment analysis to a batch of texts using a pre-trained transformer model.

    Parameters:
    texts (list): A list of texts to analyze.

    Returns:
    list: A list of ranked sentiment labels for the input texts.

    This function uses a pre-trained transformer model for sequence classification to analyze the sentiment of the texts in the given batch. It returns a list of ranked sentiment labels for the input texts, where the first label is the most likely sentiment and the subsequent labels are less likely sentiments.
    """
    # Get the maximum sequence length for the model
    encoded_input = tokenizer(
        texts, return_tensors="tf", padding=True, truncation=True, max_length=512
    )
    with tf.device(device):
        output = model(encoded_input)

    scores = output[0].numpy()
    scores = softmax(scores, axis=1)

    sentiment_scores = scores[:, 2] - scores[:, 0]
    return sentiment_scores.tolist()


def clear_gpu_memory():
    tf.keras.backend.clear_session()  # Clear the current session
    with contextlib.suppress(AttributeError):
        tf.compat.v1.reset_default_graph()  # For TensorFlow 1.x compatibility
    gc.collect()  # Explicitly call the garbage collector


def apply_sentiment_analysis(
    df: pd.DataFrame, text_column: str, batch_size=128, max_workers=4
):
    """
    Apply sentiment analysis to the given DataFrame using a pre-trained transformer model.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the text column to analyze.
    text_column (str): The name of the column in the DataFrame containing the text to analyze.
    batch_size (int): The number of texts to process in each batch. Default is 128.
    max_workers (int): The maximum number of worker threads to use for parallel processing. Default is 4.

    Returns:
    pd.DataFrame: The input DataFrame with an additional 'sentiment' column containing the sentiment analysis results.

    This function uses a pre-trained transformer model for sequence classification
    to analyze the sentiment of the texts in the given DataFrame. It applies the
    sentiment analysis in parallel using multiple worker threads to improve performance.
    The results are then added to the input DataFrame as a new 'sentiment' column.
    """
    texts = df[text_column].tolist()
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            results.extend(executor.submit(process_batch, batch).result())
            clear_gpu_memory()
    df["sentiment"] = results
    return df


def clean_mentions(tweet):
    # Remove RT since it has no meaning
    tweet = re.sub(r"^RT ", "", tweet)
    # Remove URLs since they should not impact anything
    tweet = re.sub(r"http\S+|www\S+|https\S+", "", tweet)
    # Remove user mentions
    tweet = re.sub(r"@\w+", "", tweet)
    # Remove hashtag symbols but keep the words
    tweet = re.sub(r"#", "", tweet)
    # Remove unnecessary punctuation while keeping emoticons and important punctuation
    tweet = tweet.translate(
        str.maketrans(
            "",
            "",
            string.punctuation.replace("!", "").replace("?", "").replace("#", ""),
        )
    )
    # Strip unnecessary whitespace
    return tweet.strip()


def update_sentiment_scores(
    batch: List[Tuple[str, str]], connection_params: dict, local: bool
) -> None:
    """
    Update full_text values for a batch of data in the local SQLite database.

    Args:
        batch: List of (full_text, tweet_id) pairs.
        db_path: The path to the SQLite database file.
    """
    update_query = "UPDATE Tweets SET sentiment_score = ? WHERE tweet_id = ?"
    if not local:
        update_query = update_query.replace("?", "%s")
    with connect_to_database(connection_params, local) as connection:
        execute_queries(connection, [(update_query, batch)])
        connection.commit()


def get_batches(df: pd.DataFrame, batch_size: int = 1000) -> List[pd.DataFrame]:
    """
    Split DataFrame into batches of DataFrames with specified batch size.

    Args:
        df: The DataFrame containing tweet data.
        batch_size: The size of each batch.

    Returns:
        A list of DataFrames, each containing a batch of rows.
    """
    return [df.iloc[i : i + batch_size] for i in range(0, len(df), batch_size)]


def convert_to_list(df: pd.DataFrame) -> List[List]:
    """
    Convert DataFrame with tweet_id as index to a list of lists containing sentiment and tweet_id.

    Args:
        df: The DataFrame with tweet_id as index and sentiment as a column.

    Returns:
        A list of lists containing tweet_id and sentiment.
    """
    tweet_ids = df.index.to_numpy()
    sentiments = df["sentiment"].to_numpy()
    return tuple(zip(sentiments, tweet_ids))

2024-06-16 07:07:47.774383: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-16 07:07:47.774504: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-16 07:07:47.915977: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/747 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

tf_model.h5:   0%|          | 0.00/501M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.


Physical devices cannot be modified after being initialized


In [4]:
local = True
connection_params = form_connection_params(local, True)
query = """
SELECT tweet_id, full_text
FROM Tweets
WHERE sentiment_score IS NULL;
"""
batch_size = 10_000

In [5]:
!cp "/kaggle/input/dbl-sentiment-update/local_backup (9).db" "/kaggle/working/"

  pid, fd = os.forkpty()


In [6]:
test_data = get_dataframe_from_query(query, connection_params, local, index_col="tweet_id")

In [7]:
# Apply the cleaning function to the DataFrame
test_data['cleaned_text'] = test_data['full_text'].apply(clean_mentions)

In [8]:
test_data

Unnamed: 0_level_0,full_text,cleaned_text
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1212599633509277697,This is so crazy😩,This is so crazy😩
1212599644984926210,"RT @andrewkimmel: Dear @AmericanAir,\n\nAfter ...",Dear \n\nAfter arriving back to LA from Indone...
1212599647157526530,RT @JamesHasson20: The lack of self-awareness ...,The lack of selfawareness in this thread is hi...
1212599702933258240,RT @blexijade: This thread lol https://t.co/vj...,This thread lol
1212599704401395717,"RT @andrewkimmel: Dear @AmericanAir,\n\nAfter ...",Dear \n\nAfter arriving back to LA from Indone...
...,...,...
1244696703690772485,RT @jfergo86: Me parece a mí o el avión es más...,Me parece a mí o el avión es más grande que el...
1244696708983984131,Today’s random pic of the day is the one of Vo...,Today’s random pic of the day is the one of Vo...
1244696710447800320,RT @SchipholWatch: @spbverhagen @markduursma @...,Nog niet aan de orde? Als in er is nog geen st...
1244696713350217728,RT @wiltingklaas: Tweede Kamer stemt over vlie...,Tweede Kamer stemt over vliegtaks via Of ze ...


In [9]:
data_batches = get_batches(test_data[["cleaned_text"]], batch_size)

In [10]:
for batch in tqdm(data_batches, desc="Updating sentiment_score: "):
    # feel free to change the 
    df_sentiment = apply_sentiment_analysis(batch, "cleaned_text", 128, 2)
    update_sentiment_scores(convert_to_list(df_sentiment), connection_params, local)

Updating sentiment_score:   0%|          | 0/2294 [00:00<?, ?it/s]

KeyboardInterrupt: 