In [None]:
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import gc
import re
import sqlite3
import pandas as pd
import tensorflow as tf
from scipy.special import softmax
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

In [12]:

def fetch_data(query: str, path: str) -> pd.DataFrame:
    with sqlite3.connect(path) as connection:
        return pd.read_sql_query(query, connection, index_col='tweet_id')


def process_batch(texts):
    """
    Apply sentiment analysis to a batch of texts using a pre-trained transformer model.

    Parameters:
    texts (list): A list of texts to analyze.

    Returns:
    list: A list of ranked sentiment labels for the input texts.

    This function uses a pre-trained transformer model for sequence classification to analyze the sentiment of the texts in the given batch. It returns a list of ranked sentiment labels for the input texts, where the first label is the most likely sentiment and the subsequent labels are less likely sentiments.
    """
    # Get the maximum sequence length for the model

    encoded_input = tokenizer(texts, return_tensors="tf", padding=True, truncation=True)
    with tf.device(device):
        output = model(encoded_input)

    scores = output[0].numpy()
    scores = softmax(scores, axis=1)

    sentiment_scores = scores[:, 2] - scores[:, 0]
    return sentiment_scores.tolist()

def clear_gpu_memory():
    tf.keras.backend.clear_session()  # Clear the current session
    try:
        tf.compat.v1.reset_default_graph()  # For TensorFlow 1.x compatibility
    except AttributeError:
        pass
    gc.collect()  # Explicitly call the garbage collector


def apply_sentiment_analysis(df, text_column, batch_size=128, max_workers=4):
    """
    Apply sentiment analysis to the given DataFrame using a pre-trained transformer model.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the text column to analyze.
    text_column (str): The name of the column in the DataFrame containing the text to analyze.
    batch_size (int): The number of texts to process in each batch. Default is 128.
    max_workers (int): The maximum number of worker threads to use for parallel processing. Default is 4.

    Returns:
    pd.DataFrame: The input DataFrame with an additional 'sentiment' column containing the sentiment analysis results.

    This function uses a pre-trained transformer model for sequence classification to analyze the sentiment of the texts in the given DataFrame. It applies the sentiment analysis in parallel using multiple worker threads to improve performance. The results are then added to the input DataFrame as a new 'sentiment' column.
    """
    texts = df[text_column].tolist()
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            results.extend(executor.submit(process_batch, batch).result())
            clear_gpu_memory()
    df["sentiment"] = results
    return df


def get_batches(df: pd.DataFrame, batch_size: int = 1000) -> List[pd.DataFrame]:
    """
    Split DataFrame into batches of DataFrames with specified batch size.

    Args:
        df: The DataFrame containing tweet data.
        batch_size: The size of each batch.

    Returns:
        A list of DataFrames, each containing a batch of rows.
    """
    return [df.iloc[i : i + batch_size] for i in range(0, len(df), batch_size)]

def convert_to_list(df: pd.DataFrame) -> List[List]:
    """
    Convert DataFrame with tweet_id as index to a list of lists containing sentiment and tweet_id.

    Args:
        df: The DataFrame with tweet_id as index and sentiment as a column.

    Returns:
        A list of lists containing tweet_id and sentiment.
    """
    tweet_ids = df.index.to_numpy()
    sentiments = df["sentiment"].to_numpy()
    return tuple(zip(sentiments, tweet_ids))

def connect_to_local_database(db_path: str):
    """
    Establish a connection to the local SQLite database.

    Args:
        db_path: The path to the SQLite database file.

    Returns:
        A connection object to the SQLite database.
    """
    try:
        return sqlite3.connect(db_path)
    except sqlite3.Error as e:
        print(f"Error while connecting to SQLite: {e}")
    return None


def update_text_local(
    batch: List[Tuple[str, str]], db_path: str
) -> None:
    """
    Update full_text values for a batch of data in the local SQLite database.

    Args:
        batch: List of (full_text, tweet_id) pairs.
        db_path: The path to the SQLite database file.
    """
    connection = connect_to_local_database(db_path)
    if connection is None:
        return
    try:
        cursor = connection.cursor()
        update_query = "UPDATE Tweets SET sentiment_score = ? WHERE tweet_id = ?"
        cursor.executemany(update_query, batch)
        connection.commit()
    except sqlite3.Error as e:
        print(f"Error updating batch: {e}")
    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()


In [5]:
print("Start the upload")
path = os.path.join("data_processed", "local_backup.db")
query = """
SELECT tweet_id, full_text
FROM Tweets
WHERE sentiment_score IS NULL;
"""
batch_size = 5_000
# Load the tokenizer and model

model_name = "cardiffnlp/twitter-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
if physical_devices := tf.config.list_physical_devices('GPU'):
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        print(e)

device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'

Start the upload




config.json:   0%|          | 0.00/747 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

tf_model.h5:   0%|          | 0.00/501M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.


Physical devices cannot be modified after being initialized


  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [7]:
# path = "/kaggle/working/local_backup.db"
path =  os.path.join(os.path.dirname(os.getcwd()),
                     "data_processed",
                     "local_backup.db")
test_data = fetch_data(query, path)

In [8]:
# Airlines list
airlines = [
    'KLM', "British_Airways", "airfrance", "AmericanAir", "lufthansa", 
    "airberlinAssist", "easyJet", "Ryanair", "SingaporeAir", "Qantas",
    "EtihadAirways", "VirginAtlantic", "airberlin"
]

# Function to clean mentions not in the airlines list
def clean_mentions(text):
    # Regex pattern to find mentions
    mention_pattern = r'@([A-Za-z0-9_]+)'
    mentions = re.findall(mention_pattern, text)
    
    # Check if each mention is in the airlines list
    valid_mentions = [f"@{mention}" for mention in mentions if mention in airlines]
    
    # Replace invalid mentions in the text
    for mention in mentions:
        if mention not in airlines:
            text = text.replace(f"@{mention}", "")
    
    return text

# Apply the cleaning function to the DataFrame
test_data['cleaned_text'] = test_data['full_text'].apply(clean_mentions)

In [9]:
test_data

Unnamed: 0_level_0,full_text,cleaned_text
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1151900551367614465,"""In 2019, I just find it the weirdest thing go...","""In 2019, I just find it the weirdest thing go..."
1151902427676925952,Delighted to see @IrishTimes contributors slow...,Delighted to see contributors slowly but sure...
1151902431015518208,@KLM @HeatherYemm Why don’t the passengers jus...,@KLM Why don’t the passengers just cover thei...
1151902438988947456,RT @clara_wichmann: Please ⁦@KLM⁩ mogen kleine...,RT : Please ⁦@KLM⁩ mogen kleine babytjes ook w...
1151902440360484865,@KLM @HeatherYemm WTF?? Permitted??? Det er (e...,@KLM WTF?? Permitted??? Det er (eneste) mad t...
...,...,...
1244696703690772485,RT @jfergo86: Me parece a mí o el avión es más...,RT : Me parece a mí o el avión es más grande q...
1244696708983984131,Today’s random pic of the day is the one of Vo...,Today’s random pic of the day is the one of Vo...
1244696710447800320,RT @SchipholWatch: @spbverhagen @markduursma @...,RT : @KLM Nog niet aan de orde? Als in: e...
1244696713350217728,RT @wiltingklaas: Tweede Kamer stemt over vlie...,RT : Tweede Kamer stemt over vliegtaks https:/...


In [25]:
data_batches = get_batches(test_data[["cleaned_text"]], batch_size)

In [26]:
for batch in tqdm(data_batches, desc="Updating text: "):
    df_sentiment = apply_sentiment_analysis(batch, "cleaned_text", 32, 2)
    update_text_local(convert_to_list(df_sentiment), path)

Updating text:   2%|▏         | 17/984 [38:45<36:52:25, 137.28s/it]Exception ignored in: <function _xla_gc_callback at 0x7cd54e355f30>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
Exception ignored in: <function _xla_gc_callback at 0x7cd54e355f30>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
Exception ignored in: <function _xla_gc_callback at 0x7cd54e355f30>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
Exception ignored in: <function _xla_gc_callback at 0x7cd54e355f30>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ja

KeyboardInterrupt: 