# Imports

In [1]:
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from sklearn.utils.class_weight import compute_class_weight
from tqdm.notebook import tqdm

sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "_0_Constants_and_Utils"))


from category_utils import convert_to_list, get_batches, normalise_text
from database_utils import (
    connect_to_database,
    execute_queries,
    form_connection_params,
    get_dataframe_from_query,
)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Chekm\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Chekm\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Chekm\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


# Constants

In [2]:
QUERY_TEXT = """
SELECT 
    c.conversation_id, 
    CONCAT(t1.full_text, ' ', t2.full_text) AS combined_full_text
FROM 
    Conversations AS c
JOIN 
    Tweets AS t1
ON 
    c.tweet_id = t1.tweet_id
JOIN 
    ConversationsCategory AS cc
ON 
    c.conversation_id = cc.conversation_id
JOIN 
    Conversations AS c2
ON 
    c.conversation_id = c2.conversation_id 
AND 
    c2.tweet_order = 2
JOIN 
    Tweets AS t2
ON 
    c2.tweet_id = t2.tweet_id
WHERE 
    c.tweet_order = 1
"""

# Loading

In [3]:
# Set local = False if you want to query the online MySQL database
local = True
connection_params = form_connection_params(local, True)

In [4]:
test_data = get_dataframe_from_query(QUERY_TEXT, connection_params, local, index_col="conversation_id")
test_data

Unnamed: 0_level_0,combined_full_text
conversation_id,Unnamed: 1_level_1
1,@nealrach @VirginAtlantic Siiiigh.... Still no...
2,We’re waiving change fees for customers who ha...
3,@katiewithani Please be assured if your flight...
4,@Grenzmauer75 @elliotday @easyJet Exactly. Do ...
5,@RosamariaP3 Hola Rosa ✌ Siento que aún no hay...
...,...
458720,@airfrance j'ai mis une bombe dans un a avion...
458721,@Ryanair What if I make it into a Turban then?...
458722,@AmericanAir Please help me!! I've fallen on ...
458723,@AmericanAir i was kidding thanks for the foll...


In [5]:
test_data['cleaned_text'] = test_data['combined_full_text'].apply(normalise_text)

In [6]:
test_data

Unnamed: 0_level_0,combined_full_text,cleaned_text
conversation_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1,@nealrach @VirginAtlantic Siiiigh.... Still no...,siiiigh still idea theyre back officially...
2,We’re waiving change fees for customers who ha...,’ waiving change fee customer travel plan ...
3,@katiewithani Please be assured if your flight...,please assured flight cancelled contact...
4,@Grenzmauer75 @elliotday @easyJet Exactly. Do ...,exactly plan go bankrupt pay shareholder ...
5,@RosamariaP3 Hola Rosa ✌ Siento que aún no hay...,hola rosa ✌ siento que aún hayas recibido el ...
...,...,...
458720,@airfrance j'ai mis une bombe dans un a avion...,jai mi une bombe dans un avion n prenons ce m...
458721,@Ryanair What if I make it into a Turban then?...,make turban sorry harry luggage bi...
458722,@AmericanAir Please help me!! I've fallen on ...,please help ive fallen one plain please re...
458723,@AmericanAir i was kidding thanks for the foll...,kidding thanks follow tho information for...


# Categorization

### Prepare the data for training and testing

In [8]:
# Step 1: Load the data
data = pd.read_excel('clean_labels.xlsx').query("Category != 'Undefined category'")

# Step 3: Convert text data to numerical format using TF-IDF
tfidf_vectorizer = TfidfVectorizer(max_features=5000)
X = tfidf_vectorizer.fit_transform(data["text"]).toarray()
y = data['Category']

# Step 4: Addressing Imbalance using SMOTE
smote = SMOTE(random_state=42)  # Ensuring reproducibility
X_resampled, y_resampled = smote.fit_resample(X, y)

# Step 5: Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled,
                                                    test_size=0.05, random_state=42)

# Step 6: Encode the labels
label_encoder = LabelEncoder()
label_encoder.fit(y_resampled)  # Fit the encoder on the resampled labels
y_train_encoded = label_encoder.transform(y_train)
y_test_encoded = label_encoder.transform(y_test)

# Step 7: Compute class weights
class_weights = compute_class_weight('balanced', classes=label_encoder.classes_, y=y_resampled)
class_weight_dict = {i : class_weights[i] for i in range(len(class_weights))}

# Step 8: Create and train the SVM model with a linear kernel
svm_model = SVC(kernel='linear', probability=True, random_state=42, class_weight=class_weight_dict)
svm_model.fit(X_train, y_train_encoded)

# Step 9: Evaluate the model performance on the test set
y_pred = svm_model.predict(X_test)

# Ensuring the target names match the unique classes in the training set
unique_classes = label_encoder.classes_
print(classification_report(y_test_encoded, y_pred,
                            labels=range(len(unique_classes)),
                            target_names=unique_classes))

                                         precision    recall  f1-score   support

                    Baggage and Luggage       1.00      1.00      1.00        31
                                Booking       1.00      1.00      1.00        37
                               Check-in       0.98      1.00      0.99        47
Customer Service and Special Assistance       1.00      0.97      0.98        33
               Delays and Cancellations       1.00      0.96      0.98        25
            Flight Information Requests       0.96      1.00      0.98        23
                     Food and Beverages       1.00      1.00      1.00        36
                         Frequent Flyer       1.00      1.00      1.00        33
                   In-Flight Experience       0.88      0.88      0.88        24
                  Promotions and Offers       1.00      1.00      1.00        34
               Refunds and Transactions       1.00      1.00      1.00        35
                    Safety 

In [9]:
def predict_categories(batch):
    tweet_vectors = tfidf_vectorizer.transform(batch).toarray()
    predicted_labels = svm_model.predict(tweet_vectors)
    categories = label_encoder.inverse_transform(predicted_labels)
    return categories.tolist()

In [10]:
df = pd.read_excel('clean_labels.xlsx').query("Category != 'Undefined category'")
test = df.copy()
test['our_guess'] = predict_categories(test['text'].tolist())
accuracy = (test['Category'] == test['our_guess']).mean()*100


print(f"Accuracy after fine-tuning: {accuracy:.2f}%")
test

Accuracy after fine-tuning: 98.84%


Unnamed: 0,text,Category,our_guess
0,al hilo de la demostración de este fenómeno v...,In-Flight Experience,In-Flight Experience
1,hi rachel price isnt available possible m...,Booking,Booking
2,thank showing around airbus first class s...,In-Flight Experience,In-Flight Experience
3,ejuqv oeizc easyjet europe landed lowi plan...,Flight Information Requests,Flight Information Requests
4,ezydn gezwd easyjet landed lowi planespotti...,Flight Information Requests,Flight Information Requests
...,...,...,...
2602,frequent flyer im extremely annoyed looking...,Frequent Flyer,Frequent Flyer
2603,definitely come top worst airline experienc...,Frequent Flyer,Frequent Flyer
2604,current flight information updated regularly...,Frequent Flyer,Frequent Flyer
2605,often see get lot complaint deal freque...,Frequent Flyer,Frequent Flyer


In [11]:
def get_category(df, text_column, batch_size=128, max_workers=4):
    texts = df[text_column].tolist()
    labels = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(predict_categories, texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)]
        for future in futures:
            batch_labels = future.result()
            labels.extend(batch_labels)

    df["category"] = labels
    return df


def update_categories(
    batch: List[Tuple[str, str]], connection_params: dict, local: bool
) -> None:
    """
    Update full_text values for a batch of data in the local SQLite database.

    Args:
        batch: List of (full_text, tweet_id) pairs.
        db_path: The path to the SQLite database file.
    """
    update_query = "UPDATE ConversationsCategory SET category = ? WHERE conversation_id = ?"
    if not local:
        update_query = update_query.replace("?", "%s")
    with connect_to_database(connection_params, local) as connection:
        execute_queries(connection, [(update_query, batch)])

In [12]:
data_batches = get_batches(test_data[["cleaned_text"]], 10_000)

In [13]:
for batch in tqdm(data_batches, desc="Updating categories: "):
    df_categories = get_category(batch, "cleaned_text", 512, 10)
    update_categories(convert_to_list(df_categories), connection_params, local)

Updating categories:   0%|          | 0/46 [00:00<?, ?it/s]