# In [2]:
import tkinter as tk
from tkinter import messagebox
from PIL import ImageTk, Image
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from dataCleaningScript import appriviation_converter
from dataCleaningScript import clean_text
import csv
import random

# Load the fine-tuned stance model and the matching BERT tokenizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): 'bert_model_cb' is a fully pickled model object, so torch.load
# unpickles arbitrary code -- only load checkpoints from a trusted source.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True and will refuse a
# pickled nn.Module; pass weights_only=False there (trusted file only).
model = torch.load('bert_model_cb', map_location=device)
# Inference only: disable dropout so the same tweet always gets the same
# prediction (a module pickled while training is restored in train mode).
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Function to preprocess the tweet text
def preprocess_tweet(tweet_text, max_length=128):
    """Clean a raw tweet and encode it for the BERT classifier.

    The text is passed through the project's cleaning helpers
    (``appriviation_converter`` then ``clean_text``) and tokenized as a
    batch of one, padded/truncated to ``max_length`` tokens.

    Args:
        tweet_text: Raw tweet text as typed by the user or read from the CSV.
        max_length: Token length to pad/truncate to. Defaults to 128, the
            value previously hard-coded (presumably the length used during
            training -- confirm before changing it).

    Returns:
        ``(input_ids, attention_masks)``: the PyTorch tensors produced by
        ``tokenizer.batch_encode_plus`` for the single cleaned tweet.
    """
    # Apply data cleaning steps (abbreviation expansion, then text cleaning).
    cleaned = clean_text(appriviation_converter(tweet_text))

    inputs = tokenizer.batch_encode_plus(
        [cleaned],  # batch_encode_plus expects a list of texts
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return inputs['input_ids'], inputs['attention_mask']

# Function to classify the tweet
def classify_tweet():
    """Classify the tweet in ``tweet_entry`` and display the predicted stance.

    Reads the Text widget, runs the text through ``preprocess_tweet`` and the
    model, maps the arg-max logit to ``'oppose'``/``'support'`` and writes the
    result into ``predicted_label_var`` 500 ms later. The label is blanked
    first so a repeated click visibly refreshes. Empty input clears any
    previous prediction and shows a warning instead.
    """
    # Text.get(..., END) always includes a trailing newline, hence strip().
    tweet_text = tweet_entry.get("1.0", tk.END).strip()

    if not tweet_text:
        predicted_label_var.set("")
        messagebox.showwarning("No tweet", "Please enter a tweet to classify.")
        return

    input_ids, attention_masks = preprocess_tweet(tweet_text)

    # Move inputs to the same device as the model.
    input_ids = input_ids.to(device)
    attention_masks = attention_masks.to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_masks)[0]
        predicted_label = torch.argmax(logits, dim=1).item()

    class_mapping = {0: 'oppose', 1: 'support'}
    predicted_class = class_mapping[predicted_label]

    # Clear the previous prediction, then show the new one after 500 ms.
    predicted_label_var.set("")

    def show_prediction():
        # If the user cleared or edited the tweet while this callback was
        # pending, skip it; otherwise a stale prediction would reappear.
        if tweet_entry.get("1.0", tk.END).strip() == tweet_text:
            predicted_label_var.set(f"The Stance Classification for the given tweet is: {predicted_class}")

    window.after(500, show_prediction)

# Function to add a random tweet to the tweet_entry
def add_random_tweet():
    """Fill ``tweet_entry`` with a random tweet from ``tweetstance.csv``.

    Rows are read as (tweet text, labelled stance) from columns 0 and 1. The
    stance is shown in ``random_tweet_textbox`` so it can be compared with
    the model's prediction. Blank or single-column rows are skipped. If no
    usable row exists, the widgets are left unchanged.

    NOTE(review): a header row, if the file has one, is not skipped and can
    be picked like any other row.
    """
    # newline='' is required by the csv module (quoted fields may contain
    # line breaks); explicit UTF-8 so emoji/non-ASCII tweets load regardless
    # of the platform's default encoding.
    with open('tweetstance.csv', 'r', encoding='utf-8', newline='') as file:
        rows = [row for row in csv.reader(file) if len(row) >= 2]
    if not rows:
        return

    tweet_text, stance = random.choice(rows)[:2]
    tweet_entry.delete("1.0", tk.END)
    tweet_entry.insert(tk.END, tweet_text)
    random_tweet_textbox.delete(0, tk.END)
    random_tweet_textbox.insert(tk.END, f"The original stance is '{stance}'")

# Function to clear the tweet_entry
def clear_tweet():
    """Reset the form: empty the tweet box, the original-stance box and the prediction label."""
    # Text widgets index from "1.0", Entry widgets from 0.
    for widget, first_index in ((tweet_entry, "1.0"), (random_tweet_textbox, 0)):
        widget.delete(first_index, tk.END)
    predicted_label_var.set("")

# Main window: fixed 1600x1000 px with a Twitter-blue background.
window = tk.Tk()
window.title("FEEDS-Project-SS2023")
window.geometry("1600x1000")
window.configure(bg="#00aced")

# Container for all widgets. It is centred with frame.place(...) at the end
# of the script; packing it here as well would be pointless because place()
# takes the frame over from the pack geometry manager.
frame = tk.Frame(window, bg="#00aced")

# Title, subtitle and task headings, stacked top-to-bottom in `frame`.
project_label = tk.Label(frame, text="FEEDS",
                         font=("Britannic Bold", 72, "bold"), bg="#00aced", fg="white")
project_label.pack(pady=10)

# Own name for the subtitle so it no longer rebinds `project_label`.
subtitle_label = tk.Label(frame, text="Neural Representations for Multimodal Data",
                          font=("Britannic Bold", 32, "bold"), bg="#00aced", fg="white")
subtitle_label.pack(pady=10)

task_label = tk.Label(frame, text="Argumentative Stance (AS) as Classification",
                      font=("Britannic Bold", 24, "bold"), bg="#00aced", fg="white")
task_label.pack(pady=10)

# Darker-blue panel holding the Twitter logo, the prompt and the tweet box.
tweet_frame = tk.Frame(frame, bg="#0084b4")
tweet_frame.pack(pady=10)

# Twitter logo scaled to 110x100 px. `photo` stays bound at module level:
# Tk keeps no Python reference of its own, and a collected PhotoImage blanks
# the label.
image = Image.open("twitter.png").resize((110, 100))
photo = ImageTk.PhotoImage(image)
image_label = tk.Label(tweet_frame, image=photo, bg="#0084b4")
image_label.grid(row=0, column=0, rowspan=2, padx=10)  # logo spans both rows

# Prompt (row 0) above the multi-line tweet input (row 1), right of the logo.
tweet_label = tk.Label(
    tweet_frame,
    text="Enter the tweet text:",
    font=("Britannic Bold", 20, "bold"),
    bg="#0084b4",
    fg="white",
)
tweet_label.grid(row=0, column=1, sticky="w", padx=10)

tweet_entry = tk.Text(tweet_frame, height=4, width=60, font=("ARVO", 14))
tweet_entry.grid(row=1, column=1, sticky="w", padx=10, pady=(0, 10))

# Shared look for every button in the form.
_button_style = {"font": ("Britannic Bold", 14, "bold"), "bg": "#0084b4", "fg": "white"}

# "Classify" sits directly in the main frame, below the tweet panel.
classify_button = tk.Button(frame, text="Classify", command=classify_tweet, **_button_style)
classify_button.pack(pady=10)

# Row that lays the helper buttons out side by side.
buttons_frame = tk.Frame(frame, bg="#00aced")
buttons_frame.pack(pady=10)

random_tweet_button = tk.Button(
    buttons_frame, text="Add Random Tweet", command=add_random_tweet, **_button_style
)
random_tweet_button.pack(side=tk.LEFT, padx=5, pady=10)

clear_button = tk.Button(buttons_frame, text="Clear", command=clear_tweet, **_button_style)
clear_button.pack(side=tk.LEFT, padx=5, pady=10)

# One-line box showing the CSV stance of the tweet picked by "Add Random Tweet".
random_tweet_textbox = tk.Entry(frame, font=("ARVO", 14), width=60)
random_tweet_textbox.pack(pady=10)

# Prediction read-out; its text is driven through the StringVar.
predicted_label_var = tk.StringVar(value="")
predicted_label = tk.Label(
    frame,
    textvariable=predicted_label_var,
    font=("ARVO", 18, "bold"),
    bg="#00aced",
    fg="white",
)
predicted_label.pack(pady=10)

# Footer prompt naming the two topics (#GunControl, #Abortion) to write about.
tweet_opinion_label = tk.Label(
    frame,
    text="[Write your opinion on #GunControl or #Abortion]",
    font=("Britannic Bold", 18, "bold"),
    bg="#00aced",
    fg="white",
)
tweet_opinion_label.pack(pady=10)


# Centre the whole form in the window, then hand control to Tk's event loop;
# mainloop() blocks until the window is closed.
frame.place(anchor=tk.CENTER, relx=0.5, rely=0.5)
window.mainloop()