# Imports

In [1]:
import os
import sqlite3
import pandas as pd
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from scipy.special import softmax
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
from sqlalchemy import create_engine
# v1
from tqdm import tqdm
from collections import defaultdict


2024-06-02 19:26:50.948661: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


# Constants

In [2]:
def check_given_var(env_var_str: str) -> str:
    """
    Check if the given environment variable is set and return its value.

    Args:
        env_var_str (str): The name of the environment variable to check.

    Returns:
        str: The value of the environment variable.

    Raises:
        AssertionError: If the environment variable is not found.
    """

    env_var = os.getenv(env_var_str)
    assert (
        env_var is not None
    ), f"{env_var_str} is required but not found in environment variables"
    return env_var


def check_env_vars() -> (str, str, str, str):  # type: ignore
    user = check_given_var("DBL_USER")
    database = check_given_var("DBL_DATABASE")
    password = check_given_var("DBL_PASSWORD")
    host = check_given_var("DBL_HOST")
    return user, database, password, host


USER, DATABASE, PASSWORD, HOST = check_env_vars()
# USER, DATABASE = "nezox2um_test", "nezox2um_test"
QUERY_ALL = """
SELECT 
    Users.user_id AS user_id, 
    Users.creation_time AS user_creation_time, 
    Tweets.creation_time AS tweet_creation_time,
    Tweets.tweet_id,
    Tweets.full_text,
    Tweets.lang,
    Tweets.replied_tweet_id
FROM Users
INNER JOIN Tweets ON Users.user_id = Tweets.user_id;
"""


DTYPES = {
"user_id": "object",
"tweet_id": "object",
"full_text": "object",
"lang": "category",
"replied_tweet_id": "object",
}

COMPANY_NAME_TO_ID = {
    "Klm": "56377143",
    "Air France": "106062176",
    "British Airways": "18332190",
    "American Air": "22536055",
    "Lufthansa": "124476322",
    "Air Berlin": "26223583",
    "Air Berlin assist": "2182373406",
    "easyJet": "38676903",
    "Ryanair": "1542862735",
    "Singapore Airlines": "253340062",
    "Qantas": "218730857",
    "Etihad Airways": "45621423",
    "Virgin Atlantic": "20626359",
}

COMPANY_ID_TO_NAME = {v: k for k, v in COMPANY_NAME_TO_ID.items()}

# Helper functions

In [3]:
def get_local_data(query: str, path: str, dtype: bool = True) -> pd.DataFrame:
    # Connect to the SQLite database using a context manager
    with sqlite3.connect(path) as connection:
        # Read the data into a DataFrame
        if dtype:
            df = pd.read_sql_query(query, connection,
                                   dtype=DTYPES,
                                   index_col='tweet_id')
            df['tweet_creation_time'] = pd.to_datetime(df['tweet_creation_time'])
            df['user_creation_time'] = pd.to_datetime(df['user_creation_time'])
        else:
            df = pd.read_sql_query(query, connection)
    
    return df


def fetch_data(query: str, dtype: bool = True) -> pd.DataFrame:
    engine = create_engine(f"mysql://{USER}:{PASSWORD}@{HOST}:3306/{DATABASE}")
    if dtype:
        return pd.read_sql_query(query, engine,
                                 dtype=DTYPES, index_col='tweet_id')
    return pd.read_sql_query(query, engine)


# Loading

In [4]:
# Server
# test_data = fetch_data(QUERY_ALL)
# Local
path =  os.path.join(
    os.path.dirname(
        os.path.dirname(
            os.getcwd()
            )
        ),
    "data_processed", "local_backup.db")

test_data = get_local_data(QUERY_ALL, path)

# Conversations

In [5]:
convo_special = test_data[["user_id", "replied_tweet_id", "tweet_creation_time"]]
convo_special

Unnamed: 0_level_0,user_id,replied_tweet_id,tweet_creation_time
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1131172858951024641,393374091,,2019-05-22 12:20:00+00:00
1130922003702177800,880417607865815040,1130615560910254080,2019-05-21 19:43:11+00:00
1131172864147808257,3420691215,,2019-05-22 12:20:01+00:00
1131172867985485824,394376606,1131032916232826881,2019-05-22 12:20:02+00:00
1131030279278063616,227687574,,2019-05-22 02:53:26+00:00
...,...,...,...
1244696703690772485,278698748,,2020-03-30 18:43:14+00:00
1244696708983984131,246520593,,2020-03-30 18:43:15+00:00
1244696710447800320,109284383,,2020-03-30 18:43:15+00:00
1244696713350217728,1223576386432126976,,2020-03-30 18:43:16+00:00


In [6]:
class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, conversation):
        node = self.root
        for tweet_id in conversation:
            node = node.children[tweet_id]
        node.is_end = True

    def is_subset(self, conversation):
        node = self.root
        for tweet_id in conversation:
            if tweet_id not in node.children:
                return False
            node = node.children[tweet_id]
        return True

def trace_conversation(start_tweet_id: str, tweet_dict: dict):
    convo = []
    current_tweet_id = start_tweet_id
    users_in_conversation = set()
    local_processed_tweet_ids = set()  # Local set to track the current conversation
    while current_tweet_id:
        if current_tweet_id not in tweet_dict or current_tweet_id in local_processed_tweet_ids:
            break
        tweet_info = tweet_dict[current_tweet_id]
        convo.append(current_tweet_id)
        users_in_conversation.add(tweet_info['user_id'])
        local_processed_tweet_ids.add(current_tweet_id)
        if len(users_in_conversation) > 2:
            return convo[:-1][::-1]  # As soon as the third user appears, we delete his tweet and return
        current_tweet_id = tweet_info['replied_tweet_id']
    return convo[::-1] if len(users_in_conversation) == 2 else None

def extract_and_filter_conversations(df: pd.DataFrame):
    df = df.sort_values("tweet_creation_time", ascending=False)
    df.index = df.index.astype(str)
    tweet_dict = df.to_dict('index')
    conversations = []
    trie = Trie()  # Initialize trie for subset checks

    # Start tracing conversations from tweets that are replies
    for tweet_id in tqdm(df[df['replied_tweet_id'].notnull()].index, desc="Extracting all conversations"):
        if conversation := trace_conversation(tweet_id, tweet_dict):
            if not trie.is_subset(conversation):
                trie.insert(conversation)
                conversations.append(conversation)

    return conversations



In [7]:
conversations = extract_and_filter_conversations(convo_special)

Extracting all conversations: 100%|██████████| 1795409/1795409 [01:54<00:00, 15746.77it/s]


In [8]:
data = []
for convo_num, convo in enumerate(conversations, start=1):
    data.extend((convo_num, tweet_id) for tweet_id in convo)
# Create a DataFrame
df_conversations = pd.DataFrame(data, columns=['Conversation', 'Tweet_ID'])

# Set MultiIndex
df_conversations

Unnamed: 0,Conversation,Tweet_ID
0,1,1244694453190897664
1,1,1244696682979303426
2,2,1244677304598609923
3,2,1244696641401163776
4,3,1244648694454026240
...,...,...
2712242,1064150,451125255294443521
2712243,1064151,430790355962052608
2712244,1064151,430792524043931648
2712245,1064152,248528541157834752


In [9]:

# Merge the conversation DataFrame with the test_data DataFrame
df_conversations_full = df_conversations.merge(test_data, left_on='Tweet_ID', right_index=True, how='left')

# Set the MultiIndex again with Conversation and Tweet_ID
df_conversations_full.set_index(['Conversation', 'Tweet_ID'], inplace=True)
df_conversations_full


Unnamed: 0_level_0,Unnamed: 1_level_0,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,1244694453190897664,521835883,2012-03-12 01:11:22+00:00,2020-03-30 18:34:17+00:00,@nealrach @VirginAtlantic Siiiigh.... Still no...,en,1243885949697888263
1,1244696682979303426,20626359,2009-02-11 20:50:56+00:00,2020-03-30 18:43:09+00:00,@Jade_Velveteese Hi Jade. We have an ‘Away fro...,en,1244694453190897664
2,1244677304598609923,396021583,2011-10-22 16:35:05+00:00,2020-03-30 17:26:09+00:00,@VirginAtlantic Sod off your primary sharehold...,en,1244669964289806338
2,1244696641401163776,832964639436701696,2017-02-18 14:47:00+00:00,2020-03-30 18:42:59+00:00,"@Boyde11 @VirginAtlantic Get your facts right,...",en,1244677304598609923
3,1244648694454026240,1233410199500791809,2020-02-28 15:14:56+00:00,2020-03-30 15:32:27+00:00,@flavioArCab @Chapux0204 @chechiffss @aeronaut...,es,1244643427515535360
...,...,...,...,...,...,...,...
1064150,451125255294443521,22536055,2009-03-02 21:23:05+00:00,2014-04-01 22:33:37+00:00,@lanaupdates_ Your information has been forwar...,en,451124070730719233
1064151,430790355962052608,64327804,2009-08-10 03:34:27+00:00,2014-02-04 19:49:59+00:00,"@AmericanAir phew, they finally turned on the ...",en,
1064151,430792524043931648,22536055,2009-03-02 21:23:05+00:00,2014-02-04 19:58:36+00:00,@benjy_greenberg It looks like we'll have you ...,en,430790355962052608
1064152,248528541157834752,19911051,2009-02-02 15:17:02+00:00,2012-09-19 21:06:36+00:00,Un-fucking believable!\nThanks @BritishAirways...,en,


In [10]:
airline_conversation = df_conversations_full.loc[df_conversations_full.index.get_level_values('Conversation').isin(df_conversations_full[df_conversations_full['user_id'] == COMPANY_NAME_TO_ID["Lufthansa"]].index.get_level_values('Conversation'))]
airline_conversation

Unnamed: 0_level_0,Unnamed: 1_level_0,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
140,1244658051740774400,124476322,2010-03-19 14:30:32+00:00,2020-03-30 16:09:38+00:00,@thick_daddy The online cancellation tool will...,en,1244650317494566913
140,1244688840033546245,119901222,2010-03-04 22:18:43+00:00,2020-03-30 18:11:59+00:00,@lufthansa @thick_daddy I received the same em...,en,1244658051740774400
140,1244691330837762051,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:21:53+00:00,"@NateThomasNOLA @thick_daddy At the moment, my...",en,1244688840033546245
194,1244684192040071173,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,@lufthansa had an email stating changes to my ...,en,
194,1244688444162486273,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in t...",en,1244684192040071173
...,...,...,...,...,...,...,...
1064013,1095473044573700096,1339792290,2013-04-09 17:56:30+00:00,2019-02-13 00:01:41+00:00,@lufthansa Read on @YahooNews you sued a guy f...,en,1094985662384594944
1064019,1090912573304774662,124476322,2010-03-19 14:30:32+00:00,2019-01-31 10:00:00+00:00,"With 297 seats, the #A340-600 is next in line ...",en,
1064019,1090913775417479168,18631142,2009-01-05 13:13:10+00:00,2019-01-31 10:04:46+00:00,"@lufthansa 747-400, then 747-8, then A380. :)",en,1090912573304774662
1064028,1083376699780280321,478699784,2012-01-30 15:32:40+00:00,2019-01-10 14:55:08+00:00,Hi @lufthansa how come there is no 'Ms' option...,en,


In [11]:
airline_conversation = airline_conversation.reset_index()
airline_conversation['New_Conversation'] = pd.factorize(airline_conversation['Conversation'])[0] + 1
airline_conversation = airline_conversation.set_index(['New_Conversation', 'Tweet_ID'])
airline_conversation = airline_conversation.sort_index(level='New_Conversation')
airline_conversation

Unnamed: 0_level_0,Unnamed: 1_level_0,Conversation,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id
New_Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,1244658051740774400,140,124476322,2010-03-19 14:30:32+00:00,2020-03-30 16:09:38+00:00,@thick_daddy The online cancellation tool will...,en,1244650317494566913
1,1244688840033546245,140,119901222,2010-03-04 22:18:43+00:00,2020-03-30 18:11:59+00:00,@lufthansa @thick_daddy I received the same em...,en,1244658051740774400
1,1244691330837762051,140,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:21:53+00:00,"@NateThomasNOLA @thick_daddy At the moment, my...",en,1244688840033546245
2,1244684192040071173,194,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,@lufthansa had an email stating changes to my ...,en,
2,1244688444162486273,194,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in t...",en,1244684192040071173
...,...,...,...,...,...,...,...,...
14271,1095473044573700096,1064013,1339792290,2013-04-09 17:56:30+00:00,2019-02-13 00:01:41+00:00,@lufthansa Read on @YahooNews you sued a guy f...,en,1094985662384594944
14272,1090912573304774662,1064019,124476322,2010-03-19 14:30:32+00:00,2019-01-31 10:00:00+00:00,"With 297 seats, the #A340-600 is next in line ...",en,
14272,1090913775417479168,1064019,18631142,2009-01-05 13:13:10+00:00,2019-01-31 10:04:46+00:00,"@lufthansa 747-400, then 747-8, then A380. :)",en,1090912573304774662
14273,1083376699780280321,1064028,478699784,2012-01-30 15:32:40+00:00,2019-01-10 14:55:08+00:00,Hi @lufthansa how come there is no 'Ms' option...,en,


# Categorization

In [22]:
from transformers import pipeline
test_airline_conversation = airline_conversation.head(10)

classifier = pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli")

candidate_labels = [
    "Flight Delays / Cancellations",
    "Booking",
    "Check-in",
    "Customer Service Complaints",
    "Seating / Boarding",
    "In-flight Experience",
    "General Flight Information",
    "Refunds",
    "Frequent Flyer Program",
    "Safety / Security",
    "Special Assistance",
    "Food / Beverage",
    "Overbooking",
    "Technical Difficulties",
    "Promotions / Offers",
    "Lost Luggage",
    "Baggage Issues"
]
def classify_conversation(conversation_text):
    result = classifier(conversation_text, candidate_labels)
    return result['labels'][0]

conversation_groups = test_airline_conversation.groupby('New_Conversation')
first_tweet_texts = conversation_groups.apply(lambda group: group['full_text'].iloc[0])
categories = first_tweet_texts.apply(classify_conversation)
test_airline_conversation['category'] = test_airline_conversation['New_Conversation'].map(categories)


All model checkpoint layers were used when initializing TFXLMRobertaForSequenceClassification.

All the layers of TFXLMRobertaForSequenceClassification were initialized from the model checkpoint at joeddav/xlm-roberta-large-xnli.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFXLMRobertaForSequenceClassification for predictions without further training.


KeyError: 'New_Conversation'

In [20]:
pd.set_option('display.max_colwidth', None)
test_airline_conversation


Unnamed: 0_level_0,Unnamed: 1_level_0,Conversation,user_id,user_creation_time,tweet_creation_time,full_text,lang,replied_tweet_id,category
New_Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,1244658051740774400,140,124476322,2010-03-19 14:30:32+00:00,2020-03-30 16:09:38+00:00,"@thick_daddy The online cancellation tool will only refund according to the fare conditions, Penelope. /Mac",en,1.244650317494567e+18,cancellations
1,1244688840033546245,140,119901222,2010-03-04 22:18:43+00:00,2020-03-30 18:11:59+00:00,"@lufthansa @thick_daddy I received the same email after cancelling and requesting a refund online (my flight segments had been cancelled by airline). Are you able to check refund status, or do we have to call the service center?",en,1.2446580517407744e+18,cancellations
1,1244691330837762051,140,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:21:53+00:00,"@NateThomasNOLA @thick_daddy At the moment, my colleagues in the Service Center are working on all re-bookings for the next 72 hours. Please reach out to them at a later stage to inquire about your refund status https://t.co/eRWyrKTFGQ. /Susi",en,1.2446888400335462e+18,cancellations
2,1244684192040071173,194,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,"@lufthansa had an email stating changes to my reservation ( you have cancelled my last leg and leaves me stranded in frankfurt ) states about rearranging via phone, but how can I when can't get through, no mention of refund tho, however you've cancelled so im entitled !! Advise",en,,booking issues
2,1244688444162486273,194,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in the Service Center are working on all re-bookings for the next 72 hours. Please reach out to them at a later stage to request a refund https://t.co/eRWyrKTFGQ. /Susi",en,1.2446841920400712e+18,booking issues
2,1244689066567794688,194,62555545,2009-08-03 16:35:21+00:00,2020-03-30 18:12:53+00:00,"@lufthansa My fight is in 3 days !!! last time I spoke to a rep 3 weeks ago, they said they would refund me in full, I got an email stating this would happen and nothing, why was I lied to ? They're rebooking for next 72hours so if I call in 4 days I will get a refund ?",en,1.2446884441624863e+18,booking issues
3,1244685953983225857,195,562252389,2012-04-24 16:49:01+00:00,2020-03-30 18:00:31+00:00,@lufthansa I requested a refund back on 16/03 and I’m yet to receive any updates. I appreciate your busy I’ve tried calling numerous times for 2hrs plus. Could you give me an update?,en,,refunds
3,1244689043952209921,195,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:12:47+00:00,"@AndyHall52 Due to the sheer amount of refund requests needed to be initiated manually by my colleagues, we are unable to predict you with a time frame at the moment, unfortunately. Your patience is highly appreciated, as my colleagues will refund your ticket as soon as possible. /Susi",en,1.2446859539832259e+18,refunds
4,1244684192040071173,201,62555545,2009-08-03 16:35:21+00:00,2020-03-30 17:53:31+00:00,"@lufthansa had an email stating changes to my reservation ( you have cancelled my last leg and leaves me stranded in frankfurt ) states about rearranging via phone, but how can I when can't get through, no mention of refund tho, however you've cancelled so im entitled !! Advise",en,,booking issues
4,1244688444162486273,201,124476322,2010-03-19 14:30:32+00:00,2020-03-30 18:10:24+00:00,"@Holgate1987 At the moment, my colleagues in the Service Center are working on all re-bookings for the next 72 hours. Please reach out to them at a later stage to request a refund https://t.co/eRWyrKTFGQ. /Susi",en,1.2446841920400712e+18,booking issues


### Fine-tuning the model

In [23]:
import tensorflow as tf
from transformers import XLMRobertaTokenizer, TFXLMRobertaForSequenceClassification
from sklearn.preprocessing import MultiLabelBinarizer

tokenizer = XLMRobertaTokenizer.from_pretrained('joeddav/xlm-roberta-large-xnli')
model = TFXLMRobertaForSequenceClassification.from_pretrained('joeddav/xlm-roberta-large-xnli')


#This is the example data you should feed into the model
labeled_data = [
    {"tweet": "full_text 1", "labels": ["category1"]},
    {"tweet": "full_text 2", "labels": ["category3"]},
]

inputs = tokenizer([item["tweet"] for item in labeled_data], padding=True, truncation=True, return_tensors="tf")

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform([item["labels"] for item in labeled_data])
labels = tf.convert_to_tensor(labels, dtype=tf.int32)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5), loss=tf.keras.losses.BinaryCrossentropy(), metrics=['accuracy'])

#The model automatically saves in to your local directory after you run it once. For better performance
#it is possible to adjust epoch size and batch_size, but i would not recommend it.

model.fit(inputs, labels, epochs=3, batch_size=8)



# Sentiment Analysis

In [14]:
#Load the tokenizer and model once
model_name = "cardiffnlp/twitter-roberta-base-sentiment"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")

#Check if GPU is available
if tf.test.is_gpu_available():
    device = '/GPU:0'
else:
    device = '/CPU:0'

#Set the labels for sentiment results
labels = {
    0 : 'negative',
    1 : 'neutral',
    2 : 'positive'
}

def process_batch(texts):
    """
    Apply sentiment analysis to a batch of texts using a pre-trained transformer model.

    Parameters:
    texts (list): A list of texts to analyze.

    Returns:
    list: A list of ranked sentiment labels for the input texts.

    This function uses a pre-trained transformer model for sequence classification to analyze the sentiment of the texts in the given batch. It returns a list of ranked sentiment labels for the input texts, where the first label is the most likely sentiment and the subsequent labels are less likely sentiments.
    """
    encoded_input = tokenizer(texts, return_tensors='tf', padding=True, truncation=True)
    with tf.device(device):
        output = model(encoded_input)

    scores = output[0].numpy()
    scores = softmax(scores, axis=1)

    sentiment_scores = scores[:, 2] - scores[:, 0]
    return sentiment_scores.tolist()

def apply_sentiment_analysis(df, text_column, batch_size=64, max_workers=4):
    """
    Apply sentiment analysis to the given DataFrame using a pre-trained transformer model.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing the text column to analyze.
    text_column (str): The name of the column in the DataFrame containing the text to analyze.
    batch_size (int): The number of texts to process in each batch. Default is 128.
    max_workers (int): The maximum number of worker threads to use for parallel processing. Default is 4.

    Returns:
    pd.DataFrame: The input DataFrame with an additional 'sentiment' column containing the sentiment analysis results.

    This function uses a pre-trained transformer model for sequence classification to analyze the sentiment of the texts in the given DataFrame. It applies the sentiment analysis in parallel using multiple worker threads to improve performance. The results are then added to the input DataFrame as a new 'sentiment' column.
    """
    texts = df[text_column].tolist()
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            results.extend(executor.submit(process_batch, batch).result())

    df['sentiment'] = results
    return df


All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.


Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.


In [15]:
pd.set_option('display.max_colwidth', None)  
full_text_df = airline_conversation[['full_text']].copy()
airline_conversation_sample = full_text_df.head(5)
airline_conversation_sample = apply_sentiment_analysis(airline_conversation_sample, 'full_text')
airline_conversation_sample


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['sentiment'] = results


Unnamed: 0_level_0,Unnamed: 1_level_0,full_text,sentiment
New_Conversation,Tweet_ID,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1244658051740774400,"@thick_daddy The online cancellation tool will only refund according to the fare conditions, Penelope. /Mac",-0.012933
1,1244688840033546245,"@lufthansa @thick_daddy I received the same email after cancelling and requesting a refund online (my flight segments had been cancelled by airline). Are you able to check refund status, or do we have to call the service center?",-0.467777
1,1244691330837762051,"@NateThomasNOLA @thick_daddy At the moment, my colleagues in the Service Center are working on all re-bookings for the next 72 hours. Please reach out to them at a later stage to inquire about your refund status https://t.co/eRWyrKTFGQ. /Susi",-0.02974
2,1244684192040071173,"@lufthansa had an email stating changes to my reservation ( you have cancelled my last leg and leaves me stranded in frankfurt ) states about rearranging via phone, but how can I when can't get through, no mention of refund tho, however you've cancelled so im entitled !! Advise",-0.652843
2,1244688444162486273,"@Holgate1987 At the moment, my colleagues in the Service Center are working on all re-bookings for the next 72 hours. Please reach out to them at a later stage to request a refund https://t.co/eRWyrKTFGQ. /Susi",-0.124847


In [16]:
import mysql.connector
from mysql.connector import Error
from typing import Tuple, List
import pandas as pd
from tqdm import tqdm


def connect_to_database(user: str, database: str, password: str, host: str):
    """
    Establish a connection to the database.
    
    Args:
        user: The database user.
        database: The name of the database.
        password: The password for the database.
        host: The database host.
    
    Returns:
        A connection object to the MySQL database.
    """
    try:
        connection = mysql.connector.connect(
            user=user,
            password=password,
            host=host,
            database=database
        )
        if connection.is_connected():
            return connection
    except Error as e:
        print(f"Error while connecting to MySQL: {e}")
    return None


def get_batches(df: pd.DataFrame, batch_size: int = 1000) -> List[pd.DataFrame]:
    """
    Split DataFrame into batches of DataFrames with specified batch size.
    
    Args:
        df: The DataFrame containing tweet data.
        batch_size: The size of each batch.
        
    Returns:
        A list of DataFrames, each containing a batch of rows.
    """
    return [df.iloc[i:i + batch_size] for i in range(0, len(df), batch_size)]

def update_sentiment_batch(user: str, database: str, password: str, host: str, batch: List[Tuple[str, float]]):
    """
    Update sentiment values for a batch of data in the database.
    
    Args:
        user: The database user.
        database: The name of the database.
        password: The password for the database.
        host: The database host.
        batch: A list of tuples containing tweet_id and sentiment.
    """
    try:
        connection = connect_to_database(user, database, password, host)
        if connection is None:
            return
        cursor = connection.cursor()
        update_query = "UPDATE Tweets SET sentiment_score = %s WHERE tweet_id = %s"
        cursor.executemany(update_query, batch)
        connection.commit()
        cursor.close()
        connection.close()
    except Error as e:
        print(f"Error updating batch: {e}")

def upload_sentiment(batch: Tuple[Tuple[float, str]],
                     user: str, database: str, password: str,
                     host: str) -> None:
    """
    Create and insert batches of tweets into the database in parallel.
    
    Args:
        batches_list: Tuple of (sentiment, tweet_id) pairs.
        user: The database user.
        database: The name of the database.
        password: The password for the database.
        host: The database host.
    """
    connection = connect_to_database(user, database, password, host)
    if connection is None:
        return
    cursor = connection.cursor()
    update_query = "UPDATE Tweets SET sentiment_score = %s WHERE tweet_id = %s"
    cursor.executemany(update_query, batch)
    connection.commit()
    cursor.close()
    connection.close()

    print("Sentiment values updated successfully")

In [17]:
def convert_to_list(df: pd.DataFrame) -> List[List]:
    """
    Convert DataFrame with tweet_id as index to a list of lists containing sentiment and tweet_id.
    
    Args:
        df: The DataFrame with tweet_id as index and sentiment as a column.
        
    Returns:
        A list of lists containing tweet_id and sentiment.
    """
    tweet_ids = df.index.to_numpy()
    sentiments = df['sentiment'].to_numpy()
    return tuple(zip(sentiments, tweet_ids))





In [18]:
data_batches = get_batches(test_data[["full_text"]], 10_000)

KeyboardInterrupt: 

In [None]:
user, database, password, host = os.getenv("DBL_USER"), os.getenv("DBL_DATABASE"), os.getenv("DBL_PASSWORD"), os.getenv("DBL_HOST")
for batch in tqdm(data_batches):
    sentiment_with_score = apply_sentiment_analysis(batch, "full_text")
    upload_sentiment(convert_to_list(sentiment_with_score), user, database, password, host)