<a href="https://colab.research.google.com/github/simranbains9810/zero_shot_classifier/blob/main/zero_shot_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Classifying news headlines using Zero Shot Classification (BART - Bidirectional and Auto-Regressive Transformers)**

In [1]:
!python -V

Python 3.10.12


In [2]:
!pip install gensim
!pip install transformers
!pip install NLTK



In [3]:
!pip install streamlit



In [4]:
!pip install pyngrok



In [11]:
#Installing packages
#%%writefile app.py

import re
import textwrap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import pandas as pd
import gensim
from gensim import corpora

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk
from transformers import BartForSequenceClassification, BartTokenizer, pipeline

from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from wordcloud import WordCloud

import streamlit as st
import matplotlib.pyplot as plt

from google.colab import files

Cleaning dataset

In [6]:
uploaded = files.upload()

Saving zero_shot_subset.csv to zero_shot_subset (3).csv


In [8]:
test = pd.read_csv("zero_shot_subset.csv")
test

Unnamed: 0.1,Unnamed: 0,title,date,stock
0,963395.0,Top Stocks In The Surety & Title Insurance Ind...,2010-04-13 05:00:00-04:00,ORI
1,102011.0,Asterias Provides 6 Mo. Data Readout From Its ...,2018-07-17 07:05:00-04:00,AST
2,952152.0,ONEOK Partners to Participate in Bakken Day,2010-08-04 16:22:00-04:00,OKS
3,189463.0,Shares of Broadsoft to Resume Trade at 4:35PM EST,2013-11-04 16:12:00-05:00,BSFT
4,535387.0,"Courier Corp Holder Gamco Reports 6.97%, Up Fr...",2015-03-30 16:20:00-04:00,GBL
...,...,...,...,...
995,748895.0,Hearing Ares in Talks To Acquire Kayne Anderson,2015-06-10 15:41:00-04:00,KYN
996,132726.0,Alibaba becomes Major Olympics Sponsor through...,2017-01-19 06:04:00-05:00,BABA
997,1029212.0,A Peek Into The Market Before The Trading Starts,2011-06-07 07:31:00-04:00,PPL
998,622004.0,"Stocks Which Set New 52-Week Low Yesterday, Mo...",2018-11-27 11:26:00-05:00,HNRG


In [12]:
financial_news_taxonomy = {
    "Market Movements": [
        "Stock Trends", "Commodity Prices", "Currency Rates", "Bond Yields",
        "Crypto Prices", "Volatility Index", "Sector Rotation"
    ],
    "Company News": [
        "Earnings", "Product Launches", "M&A", "Legal & Regulatory", "Leadership",
        "Share Buybacks", "Bankruptcies", "Dividend Announcements", "Corporate Restructuring"
    ],
    "Economic Data": [
        "Employment", "Inflation", "GDP", "Consumer Sentiment",
        "Housing Market Data", "Retail Sales", "Trade Balance", "Manufacturing Output"
    ],
    "Policy & Regulation": [
        "Central Bank Actions", "Trade & Tax Policies", "Environmental Rules",
        "Interest Rate Decisions", "Fiscal Policies", "Antitrust Actions"
    ],
    "Global Events": [
        "Geopolitical Issues", "Elections", "Natural Disasters", "Health Crises",
        "Trade Wars", "Sanctions", "International Treaties", "Global Supply Chain Disruptions"
    ],
    "Sector Highlights": [
        "Tech Developments", "Banking News", "Energy Updates", "Healthcare Innovations", "Real Estate Trends",
        "Telecommunications", "Consumer Goods", "Automotive Industry", "Pharmaceutical Breakthroughs"
    ],
    "Investment Insights": [
        "Fund Activities", "Asset Trends", "Investment Strategies",
        "Private Equity", "Venture Capital", "Risk Management", "Hedge Fund Performance", "Sustainable Investing"
    ]
}

print(financial_news_taxonomy)

financial_news_terms = [item for sublist in financial_news_taxonomy.values() for item in sublist]
print(financial_news_terms)

{'Market Movements': ['Stock Trends', 'Commodity Prices', 'Currency Rates', 'Bond Yields', 'Crypto Prices', 'Volatility Index', 'Sector Rotation'], 'Company News': ['Earnings', 'Product Launches', 'M&A', 'Legal & Regulatory', 'Leadership', 'Share Buybacks', 'Bankruptcies', 'Dividend Announcements', 'Corporate Restructuring'], 'Economic Data': ['Employment', 'Inflation', 'GDP', 'Consumer Sentiment', 'Housing Market Data', 'Retail Sales', 'Trade Balance', 'Manufacturing Output'], 'Policy & Regulation': ['Central Bank Actions', 'Trade & Tax Policies', 'Environmental Rules', 'Interest Rate Decisions', 'Fiscal Policies', 'Antitrust Actions'], 'Global Events': ['Geopolitical Issues', 'Elections', 'Natural Disasters', 'Health Crises', 'Trade Wars', 'Sanctions', 'International Treaties', 'Global Supply Chain Disruptions'], 'Sector Highlights': ['Tech Developments', 'Banking News', 'Energy Updates', 'Healthcare Innovations', 'Real Estate Trends', 'Telecommunications', 'Consumer Goods', 'Automot

In [13]:
nltk.download("stopwords")
nltk.download("punkt")
nltk.download("wordnet")

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [14]:
#Process data to tokenise, remove stop words and clean it
tokenized_data = test['title'].apply(word_tokenize).tolist()
stop_words = set(stopwords.words('english'))
cleaned_data = [[word for word in doc if word.lower() not in stop_words and word.isalpha()] for doc in tokenized_data]

# Initialize lemmatizer and define English stopwords
lemmatizer = WordNetLemmatizer()

In [15]:
#Testing whether GPU is available, and set to either GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [16]:
#Function to process text data by tokenizing, removing stop words and lemmatizing. Function is applied to 'title' column of the test dataframe
def process_headline(headline):
    tokens = word_tokenize(headline)
    clean_tokens = [lemmatizer.lemmatize(token) for token in tokens if token.lower() not in stop_words]
    return " ".join(clean_tokens)

test['processed_headline'] = test['title'].apply(process_headline)

In [17]:
test.head()

Unnamed: 0.1,Unnamed: 0,title,date,stock,processed_headline
0,963395.0,Top Stocks In The Surety & Title Insurance Ind...,2010-04-13 05:00:00-04:00,ORI,Top Stocks Surety & Title Insurance Industry H...
1,102011.0,Asterias Provides 6 Mo. Data Readout From Its ...,2018-07-17 07:05:00-04:00,AST,Asterias Provides 6 Mo . Data Readout AST-OPC1...
2,952152.0,ONEOK Partners to Participate in Bakken Day,2010-08-04 16:22:00-04:00,OKS,ONEOK Partners Participate Bakken Day
3,189463.0,Shares of Broadsoft to Resume Trade at 4:35PM EST,2013-11-04 16:12:00-05:00,BSFT,Shares Broadsoft Resume Trade 4:35PM EST
4,535387.0,"Courier Corp Holder Gamco Reports 6.97%, Up Fr...",2015-03-30 16:20:00-04:00,GBL,"Courier Corp Holder Gamco Reports 6.97 % , 0.0..."


Loading the BART model:
1.   Tokenize the input text using the BART Tokenizer
2.   Pass the tokenized input through the BartforSequenceClassification model to get predictions
3.   Use the model output to interpret the results, such as predicting text entailment or performing zero-shot classification


In [18]:
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
model_name = "facebook/bart-large-mnli"
tokeniser = BartTokenizer.from_pretrained(model_name)
model = BartForSequenceClassification.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [19]:
classifier = pipeline(
    task="zero-shot-classification",
    model=model_name,
    tokenizer=tokeniser,
    device=device.index
)

In [20]:
def process_headline(headline):
    tokens = word_tokenize(headline)
    clean_tokens = [lemmatizer.lemmatize(token) for token in tokens if token.lower() not in stop_words]
    return " ".join(clean_tokens)

In [21]:
# Initialize zero-shot classification pipeline
@st.cache(allow_output_mutation=True)
def load_classifier():
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

classifier = load_classifier()


2024-10-12 12:53:36.049 
  command:

    streamlit run /usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py [ARGUMENTS]
2024-10-12 12:53:36.054 
`st.cache` is deprecated and will be removed soon. Please use one of Streamlit's new
caching commands, `st.cache_data` or `st.cache_resource`. More information
[in our docs](https://docs.streamlit.io/develop/concepts/architecture/caching).

**Note**: The behavior of `st.cache` was updated in Streamlit 1.36 to the new caching
logic used by `st.cache_data` and `st.cache_resource`. This might lead to some problems
or unexpected behavior in certain edge cases.



Streamlit app

In [22]:
# Streamlit app layout
st.set_page_config(
    page_title="Zero-shot Classifier App",
    page_icon="✅",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title("Zero-shot Classifier for Financial News")
st.write("Classify financial news headlines using zero-shot classification.")

# Input form: text area or file upload
user_input = st.text_area("Enter your headline here:", "")
uploaded_file = st.file_uploader("Or upload a CSV file", type=["csv"])

column_name = None
df = None

if uploaded_file:
    # Load the CSV to check columns
    df = pd.read_csv(uploaded_file)

    # Allow the user to select the column to classify
    column_name = st.selectbox("Choose a column to classify:", df.columns)

# Classify button with progress
if st.button("Classify"):
    progress_bar = st.progress(0)

    if user_input:
        # Single text input classification
        label, score = classify_text(user_input)
        st.write(f"Predicted Label: {label}")
        st.write(f"Confidence Score: {score:.4f}")
        progress_bar.progress(1.0)  # Complete progress bar when done

    elif uploaded_file and column_name:
        # Classify CSV data
        if column_name in df.columns:
            total_len = len(df)
            for i, row in enumerate(df[column_name]):
                df.at[i, 'Predicted Label'], df.at[i, 'Confidence Score'] = classify_text(row)
                progress_bar.progress((i + 1) / total_len)
            st.write(df)

            # Plotting charts and word cloud
            st.subheader('Predicted Label Distribution')
            fig1, ax1 = plt.subplots(figsize=(10, 6))
            df['Predicted Label'].value_counts().plot(kind='bar', ax=ax1, color='skyblue')
            plt.xticks(rotation=45)
            ax1.set_xlabel('Predicted Label')
            ax1.set_ylabel('Count')
            ax1.set_title('Distribution of Predicted Labels')
            st.pyplot(fig1)

            st.subheader('Confidence Score Distribution')
            fig2, ax2 = plt.subplots(figsize=(10, 6))
            df['Confidence Score'].hist(bins=30, ax=ax2, color='salmon')
            ax2.set_xlabel('Confidence Score')
            ax2.set_ylabel('Frequency')
            ax2.set_title('Distribution of Confidence Scores')
            st.pyplot(fig2)

            st.subheader('Word Cloud from Headlines')
            text = ' '.join(df[column_name].apply(process_headline))
            wordcloud = WordCloud(background_color='white', colormap='viridis', width=800, height=400, max_words=200).generate(text)
            fig3, ax3 = plt.subplots(figsize=(12, 6))
            ax3.imshow(wordcloud, interpolation='bilinear')
            ax3.axis('off')
            st.pyplot(fig3)
        else:
            st.write(f"Column '{column_name}' not found in the uploaded CSV.")

    progress_bar.empty()  # Reset the progress bar after completion

2024-10-12 12:53:42.555 Session state does not function when running a script without `streamlit run`


In [23]:
from pyngrok import ngrok

# Kill any existing Streamlit processes
!killall streamlit

# Start ngrok with a valid HTTP tunnel configuration on port 8501
public_url = ngrok.connect(addr="8501", proto="http")
print(f"Streamlit app will be available on: {public_url}")

# Run the Streamlit app
!streamlit run app.py &>/dev/null&

streamlit: no process found
Streamlit app will be available on: NgrokTunnel: "https://40b8-35-196-80-85.ngrok-free.app" -> "http://localhost:8501"
