# Task 1 - Data Scraping

In [None]:
# If using GPU

!pip install faiss-cpu
!pip install faiss-gpu
# !pip install torch torchvision torchaudio
!pip install wikipedia
!pip install datasets

In [None]:
import re
import json
import hashlib
import requests
from bs4 import BeautifulSoup
from tqdm import tqdm
from typing import Dict, List, Any
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
import wikipedia
import pandas as pd
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoder, DPRContextEncoderTokenizer
import faiss
from datasets import Dataset
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf
import json
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import CrossEncoder
import torch
import numpy as np
from transformers import pipeline, set_seed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sklearn.preprocessing import normalize
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import string
from torch.utils.data import random_split, DataLoader

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
import torch
import pandas as pd
from datasets import Dataset
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf

In [None]:
# device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
def g_regex(text: str) -> str:
    """Return ``text`` with every character removed that is not an ASCII
    letter, a digit, or whitespace."""
    disallowed = re.compile(r'[^a-zA-Z0-9\s]')
    return disallowed.sub('', text)

def g_unique_id(url: str) -> str:
    """Return the first 9 hexadecimal characters of the MD5 digest of ``url``."""
    digest = hashlib.md5(url.encode())
    hex_digest = digest.hexdigest()
    return hex_digest[:9]

def g_scraping(url: str) -> "str | None":
    """Download ``url`` and return the response body as text.

    This is a best-effort fetch: any ``requests`` failure (connection error,
    timeout after 30 seconds, or the HTTP error raised by
    ``raise_for_status``) is reported and ``None`` is returned so that the
    caller can skip the page.

    Args:
        url: Address of the page to download.

    Returns:
        The response text, or ``None`` when the request failed.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Report the failure instead of swallowing it silently, so gaps in
        # the scraped data can be traced back to the pages that failed.
        print(f"Failed to fetch {url}: {exc}")
        return None
    return response.text

def g_extract_info(html_content: str, url: str) -> Dict[str, Any]:
    """Parse a page's HTML into a record with id, title, summary and URL.

    The title is the text of the ``<h1 id="firstHeading">`` element, or the
    placeholder ``"No title found"`` when that element is absent.  The summary
    is built from the stripped text of the ``<p>`` elements, cleaned of
    bracketed numeric markers and of every character ``g_regex`` removes, and
    finally cut to 500 characters.
    """
    soup = BeautifulSoup(html_content, 'html.parser')

    heading = soup.find('h1', {'id': 'firstHeading'})
    page_title = heading.text if heading else "No title found"

    # Collect paragraph text until more than 5200 characters have been gathered.
    pieces = []
    collected = 0
    for paragraph in soup.find_all('p'):
        piece = paragraph.text.strip() + " "
        pieces.append(piece)
        collected += len(piece)
        if collected > 5200:
            break

    # Remove markers such as "[12]" first, then everything g_regex filters out.
    without_markers = re.sub(r'\[\d+\]', '', "".join(pieces))
    cleaned = g_regex(without_markers)

    return {
        "revision_id": g_unique_id(url),
        "title": page_title,
        "summary": cleaned[:500],
        "url": url,
    }

In [None]:
def get_wikipedia_pages(topic: str, num_pages: int = 600) -> List[str]:
    """Search English Wikipedia for ``topic`` and return article URLs.

    NOTE: a later cell of this notebook defines ``get_wikipedia_pages`` again;
    once that cell has run, the later definition is the one in effect.

    Args:
        topic: Free-text query sent as the ``srsearch`` parameter.
        num_pages: Maximum number of results to request.  A single search
            request serves at most 500 results to regular clients, so larger
            values are clamped to 500 (and values below 1 are raised to 1).

    Returns:
        One ``https://en.wikipedia.org/wiki/<Title>`` URL per search hit, with
        spaces replaced by underscores and the title percent-encoded.  An
        empty list when the response carries no search results.

    Raises:
        requests.RequestException: On network failure, timeout or HTTP error.
    """
    from urllib.parse import quote  # local import keeps the cell self-contained

    base_url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "format": "json",
        "list": "search",
        "srsearch": topic,
        "srlimit": max(1, min(num_pages, 500)),
    }
    response = requests.get(base_url, params=params, timeout=30)
    response.raise_for_status()
    # An error payload has no "query" key; treat it as "no results".
    hits = response.json().get("query", {}).get("search", [])
    # Percent-encode titles: a raw "?" or "#" in a title would otherwise be
    # parsed as the start of a query string or fragment.
    return [
        "https://en.wikipedia.org/wiki/" + quote(hit["title"].replace(" ", "_"), safe="/:(),")
        for hit in hits
    ]

# Maps each top-level topic to the list of subtopic phrases to search for.
# g_subtopic combines them into the search string "<topic> <subtopic>".
# The commented-out entries share their keys with the longer lists defined
# further down in this same dict.
topic_structure = {
    # "Health": ["Diseases", "Global health", "Mental health", "Nutrition", "Healthcare systems",],
    # "Environment": ["Global warming", "Endangered species", "Deforestation rates", "Pollution", "Renewable energy"],
    "Technology": ["Emerging technologies", "AI advancements", "Robotics", "Biotechnology", "Cybersecurity"],
    # "Economy": ["Stock market performance", "Job markets", "Cryptocurrency trends", "Economic policies", "International trade"],
    "Entertainment": ["Music industry", "Popular cultural events", "Streaming platforms", "Film industry", "Gaming industry"],
    "Sports": ["Major sporting events", "Sports analytics", "Professional leagues", "Olympic games", "Esports"],
    # "Politics": ["Elections", "Public policy analysis", "International relations", "Political ideologies", "Government systems"],
    # "Education": ["Literacy rates", "Online education trends", "Student loan data", "Educational technology", "Higher education"],
    "Travel": ["Top tourist destinations", "Airline industry data", "Travel trends", "Hospitality industry", "Adventure tourism"],
    "Food": ["Culinary trends", "Nutrition", "Restaurant industry", "Food technology", "Sustainable food practices"],
    "Environment": ["Global renewable energy policies", "Renewable energy transition strategies", "Sustainable energy frameworks", "National clean energy goals", "Ocean biodiversity threats",  "Marine pollution impacts",  "Coral reef degradation causes", "Overfishing consequences", "Seafloor habitat destruction","Global clean water projects",
        "Water sanitation and hygiene programs",
        "Freshwater conservation",
        "Safe drinking water access",
        "Water scarcity solutions",
        "Benefits of urban greenery",
        "Urban reforestation projects",
        "City parks and sustainability",
        "Green infrastructure in cities",
        "Community urban gardening",
        "Global climate strikes",
        "Youth-led climate campaigns",
        "Environmental advocacy organizations",
        "Climate action networks",
        "Protests against climate inaction",
        "Carbon offset programs",
        "Cap-and-trade systems",
        "Emissions trading mechanisms",
        "Global carbon market trends",
        "Carbon credits pricing",
        "Eco-literacy education",
        "Sustainability education curricula",
        "Green school initiatives",
        "Environmental awareness campaigns",
        "Environmental learning modules",
        "Zero-waste lifestyle practices",
        "Circular economy and waste reduction",
        "Minimalist sustainable living",
        "Composting and waste management",
        "Zero-waste product design",
        "Historical climate change data",
        "Global temperature rise patterns",
        "Effects of greenhouse gases",
        "Anthropogenic global warming causes",
        "Projected climate scenarios"
    ],
    "Health": [
        "Global telehealth solutions",
        "Virtual healthcare trends",
        "Digital health innovations",
        "Remote patient monitoring technology",
        "AI in telemedicine",
        "Global NCD prevalence",
        "Chronic disease burden",
        "Risk factors for NCDs",
        "Prevention strategies for lifestyle diseases",
        "Cardiovascular disease statistics",
        "Access to rural healthcare",
        "Challenges in rural health systems",
        "Rural-urban health inequality",
        "Healthcare infrastructure in remote areas",
        "HIV/AIDS prevention programs",
        "Antiretroviral treatment access",
        "UNAIDS global efforts",
        "HIV awareness campaigns"],
    "Education" : ["STEM education trends","Education reforms globally","Learning disabilities research","Access to education for refugees","Gender disparities in education","Dropout rates in high school","Future of hybrid learning","Lifelong learning trends"],
    "Economy" : ["Taxation policies worldwide","Economic sanctions and their effects","Trade wars between nations","Venture capital investments","Supply chain disruptions","Currency exchange rate trends","Global pension systems","Rural development through microfinance","Impact of mergers and acquisitions","Income tax reforms","Renewable energy investments","Foreign direct investment (FDI) analysis"],
    "Politics" : ["International treaties and pacts","Women's representation in politics","Role of social media in elections","Parliamentary systems worldwide","Public protests and their impact","Political corruption indices","Rise of authoritarianism","Political engagement among youth","Anti-globalization movements","Peace negotiations and diplomacy","Election monitoring organizations","Political lobbying in the economy","Green politics and climate action"]
}


def g_subtopic(topic: str, subtopic: str) -> List[Dict[str, Any]]:
    """Search for ``"<topic> <subtopic>"`` and return one record per page that
    could be downloaded, each tagged with ``topic``."""
    query = f"{topic} {subtopic}"
    collected = []
    for page_url in get_wikipedia_pages(query):
        html = g_scraping(page_url)
        if not html:
            # Download failed or returned nothing: skip this page.
            continue
        record = g_extract_info(html, page_url)
        record['topic'] = topic
        collected.append(record)
    return collected

In [None]:
import requests
import csv
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

def get_wikipedia_pages(topic: str, num_pages: int = 550) -> List[str]:
    """Fetch URLs of Wikipedia pages for a given topic.

    Args:
        topic: Free-text query sent as the ``srsearch`` parameter.
        num_pages: Maximum number of results to request.  A single search
            request serves at most 500 results to regular clients, so larger
            values are clamped to 500 (and values below 1 are raised to 1).

    Returns:
        One ``https://en.wikipedia.org/wiki/<Title>`` URL per search hit, with
        spaces replaced by underscores and the title percent-encoded.  An
        empty list when the response carries no search results.

    Raises:
        requests.RequestException: On network failure, timeout or HTTP error.
    """
    from urllib.parse import quote  # local import keeps the cell self-contained

    base_url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "format": "json",
        "list": "search",
        "srsearch": topic,
        "srlimit": max(1, min(num_pages, 500)),
    }
    response = requests.get(base_url, params=params, timeout=30)
    response.raise_for_status()
    # An error payload has no "query" key; treat it as "no results".
    hits = response.json().get("query", {}).get("search", [])
    # Percent-encode titles: a raw "?" or "#" in a title would otherwise be
    # parsed as the start of a query string or fragment.
    return [
        "https://en.wikipedia.org/wiki/" + quote(hit["title"].replace(" ", "_"), safe="/:(),")
        for hit in hits
    ]

def g_subtopic(topic: str, subtopic: str, max_docs: int) -> List[Dict[str, Any]]:
    """Fetch Wikipedia data for a given subtopic.

    Args:
        topic: Top-level topic; stored on every returned record.
        subtopic: Phrase combined with ``topic`` to form the search query.
        max_docs: Upper bound on the number of records to return; also passed
            to the search as the number of pages to request.

    Returns:
        At most ``max_docs`` page records (see ``g_extract_info``), each with
        an added ``'topic'`` key.  Pages that could not be downloaded are
        skipped.  An empty list when ``max_docs`` is not positive.
    """
    if max_docs <= 0:
        # Nothing was asked for: do not issue a search with a non-positive limit.
        return []

    urls = get_wikipedia_pages(f"{topic} {subtopic}", num_pages=max_docs)
    results: List[Dict[str, Any]] = []
    for url in urls:
        content = g_scraping(url)
        if not content:
            continue
        page_info = g_extract_info(content, url)
        page_info['topic'] = topic
        results.append(page_info)
        if len(results) >= max_docs:  # Stop when max_docs is reached
            break
    return results

def main():
    """Scrape every subtopic of ``topic_structure`` concurrently and save a CSV.

    One ``g_subtopic`` task is submitted per (topic, subtopic) pair.  Results
    are merged per topic as tasks complete, keeping at most
    ``max_docs_per_topic`` records for each topic, and are then written to
    ``check.csv`` with the columns topic, title, url, revision_id, summary.
    A task that raises is reported and skipped; it does not abort the run.
    """
    max_docs_per_topic = 5200  # Set the maximum documents per topic
    csv_file = 'check.csv'
    csv_columns = ['topic', 'title', 'url', 'revision_id', 'summary']

    retrieved_data = {topic: [] for topic in topic_structure}
    total_tasks = sum(len(subtopics) for subtopics in topic_structure.values())

    with tqdm(total=total_tasks, desc="Processing Subtopics", unit="subtopic") as pbar:
        with ThreadPoolExecutor(max_workers=50) as executor:  # Limit to 50 threads for stability
            # Every task is queued before any of them finishes, so a "remaining
            # budget" computed at submission time would always be the full
            # budget.  The per-topic cap is therefore enforced below, when the
            # results are merged.
            next_topic = {
                executor.submit(g_subtopic, topic, subtopic, max_docs_per_topic): topic
                for topic, subtopics in topic_structure.items()
                for subtopic in subtopics
            }

            for future in as_completed(next_topic):
                topic = next_topic[future]
                try:
                    results = future.result()
                except Exception as exc:
                    print(f"Error processing topic {topic}: {exc}")
                else:
                    # Keep only as many records as the topic still has room for.
                    room = max(0, max_docs_per_topic - len(retrieved_data[topic]))
                    retrieved_data[topic].extend(results[:room])
                finally:
                    pbar.update(1)

    # Save to CSV
    with open(csv_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
        writer.writeheader()
        for documents in retrieved_data.values():
            for doc in documents:
                writer.writerow({column: doc.get(column) for column in csv_columns})

    print(f"Data saved to {csv_file}")

if __name__ == "__main__":
    main()



In [None]:
df = pd.read_csv('Final_scraped_data.csv')
df['topic'].value_counts()

In [None]:
df

# Task 4 - Wiki Q/A Bot

In [None]:
# Load the scraped corpus, keep the columns used downstream, and clean it.
# dropna / drop_duplicates / reset_index each return a NEW frame, so the result
# has to be assigned back; calling them as bare statements discards the cleaning.
df = (
    pd.read_csv('Final_scraped_data.csv')[['topic', 'url', 'title', 'revision_id', 'summary']]
    .dropna()
    .drop_duplicates()
    .reset_index(drop=True)
)
df.head()

In [None]:
df['topic'].value_counts()

In [None]:
def filter_by_topic(data, topic):
    """Return the texts and embeddings stored under ``topic``.

    Args:
        data: Mapping from topic name to a list of entries.  Each entry has a
            ``"text"`` key and its vector under ``"embedding"`` (the key
            written by ``generate_and_store_embeddings``); the plural
            ``"embeddings"`` is accepted as a fallback.
        topic: Topic name to look up.

    Returns:
        A ``(texts, embeddings)`` tuple of two parallel lists; both are empty
        when ``topic`` is not present in ``data``.
    """
    # Get the relevant entries for the given topic
    filtered_data = data.get(topic, [])
    texts = [entry["text"] for entry in filtered_data]
    embeddings = [
        entry["embedding"] if "embedding" in entry else entry["embeddings"]
        for entry in filtered_data
    ]
    return texts, embeddings

In [None]:
# Turn every row into a JSON-serialisable record.  Ids are 1-based strings
# derived from the row's index label.
data_list = [
    {
        "id": str(index + 1),
        "title": row["title"],
        "text": row["summary"],
        "topic": row["topic"],
        "revision_id": row["revision_id"],
        "url": row["url"],
    }
    for index, row in zip(df.index, df.to_dict("records"))
]

dataset_json = {"data": data_list}

with open("Final_data.json", "w") as f:
    json.dump(dataset_json, f, indent=4)

# Preview the JSON structure with the first records only: printing the whole
# corpus floods the notebook output.
print(json.dumps({"data": data_list[:2]}, indent=4))


In [None]:
model = SentenceTransformer('all-mpnet-base-v2')

In [None]:
def generate_and_store_embeddings(json_data, output_file, encoder=None):
    """Embed every document's text and save the vectors grouped by topic.

    Args:
        json_data: Dict with a ``"data"`` list; every item needs the keys
            ``"id"``, ``"title"``, ``"text"`` and ``"topic"``.
        output_file: Path of the JSON file to write.
        encoder: Object with an ``encode(list_of_texts)`` method returning one
            vector per text.  Defaults to the notebook-level ``model``; pass
            it explicitly when ``model`` has been rebound to something else
            (a later cell assigns the DistilBERT classifier to ``model``).

    Writes a JSON object mapping each topic to a list of
    ``{"id", "title", "text", "embedding"}`` entries, in input order.
    """
    if encoder is None:
        encoder = model

    docs = json_data["data"]
    # Encode all texts in a single call so the encoder can batch them, rather
    # than invoking it once per document.
    texts = [str(doc["text"]) for doc in docs]
    vectors = encoder.encode(texts) if texts else []

    topic_embeddings = {}
    for doc, vector in zip(docs, vectors):
        topic_embeddings.setdefault(doc["topic"], []).append({
            "id": doc["id"],
            "title": doc["title"],
            "text": doc["text"],
            "embedding": vector.tolist(),  # Convert to list for JSON compatibility
        })

    # Save embeddings to a JSON file
    with open(output_file, 'w') as f:
        json.dump(topic_embeddings, f, indent=4)
    print(f"Embeddings saved to {output_file}")

In [None]:
def load_file(json_file):
    """Read ``json_file`` and return its parsed JSON content."""
    with open(json_file, 'r') as handle:
        payload = json.load(handle)
    return payload

In [None]:
output_file = 'topic_embeddings.json'
generate_and_store_embeddings(dataset_json,output_file)

In [None]:
data = load_file('/content/drive/MyDrive/Project-3/topic_embeddings.json')
# (texts, embeddings) for one topic.  The retrieval cell further down passes
# `topic_contents` to bi_cross_pipeline, so it has to be defined here.
topic_contents = filter_by_topic(data,'Health')

In [None]:
bi_encoder = SentenceTransformer("all-mpnet-base-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [None]:
def remove_unfinished_sentences(text):
    """
    Keep only the complete sentences of ``text``.

    A sentence is a run that starts with an uppercase ASCII letter and ends at
    the first '.', '!' or '?'.  Everything else is dropped, and the kept
    sentences are joined with single spaces.
    """
    sentence_regex = re.compile(r"([A-Z][^.!?]*[.!?])")
    complete = [match.group(1) for match in sentence_regex.finditer(text)]
    return " ".join(complete)

In [None]:
def bi_cross_pipeline(query, corpus, bi_encoder, cross_encoder, device=device, top_k=5, generator=None):
    """Retrieve, re-rank and expand the best passage for ``query``.

    Args:
        query: Search string.
        corpus: Pair ``(texts, embeddings)`` of parallel sequences, as returned
            by ``filter_by_topic``.
        bi_encoder: Model whose ``encode`` embeds the query.
        cross_encoder: Model whose ``predict`` scores (query, passage) pairs.
        device: Torch device on which the similarity is computed and the text
            generator is placed.
        top_k: Number of bi-encoder candidates to re-rank; reduced
            automatically when the corpus has fewer passages.
        generator: Optional ready-made text-generation pipeline.  When omitted
            a ``gpt2-xl`` pipeline is created on every call, which is slow;
            pass one in to reuse it across queries.

    Returns:
        Dict with ``"top_result"`` (best passage and its cross-encoder score),
        ``"reranked_results"`` (all candidates, best first) and
        ``"expanded_text"`` (generated continuation, complete sentences only).

    Raises:
        ValueError: If the corpus contains no passages.
    """
    corpus_text, corpus_embeddings = corpus[0], corpus[1]
    if len(corpus_text) == 0:
        raise ValueError("corpus is empty: nothing to retrieve from")

    # Step 1: Encode the query using the bi-encoder
    query_embedding = bi_encoder.encode(query, convert_to_numpy=True)
    query_embedding = torch.tensor(query_embedding, dtype=torch.float32, device=device)

    # Step 2: Move the corpus embeddings to the same device
    corpus_embeddings = torch.tensor(corpus_embeddings, dtype=torch.float32, device=device)

    # Step 3: Compute bi-encoder cosine similarity scores
    bi_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]

    # Step 4: Pick the best candidates; topk raises if k exceeds the corpus size
    k = min(top_k, len(corpus_text))
    top_k_indices = torch.topk(bi_scores, k=k).indices.tolist()

    # Step 5: Use cross-encoder to re-rank the top-k results
    cross_input = [(query, corpus_text[idx]) for idx in top_k_indices]
    cross_scores = cross_encoder.predict(cross_input)
    reranked_results = sorted(
        (
            {"text": corpus_text[idx], "score": float(score)}
            for idx, score in zip(top_k_indices, cross_scores)
        ),
        key=lambda item: item["score"],
        reverse=True,
    )

    # Step 6: Expand only the top result
    top_result = reranked_results[0]
    print(top_result)
    if generator is None:
        generator = pipeline('text-generation', model='gpt2-xl',device=device,torch_dtype=torch.bfloat16)
    generated_text = generator(top_result['text'], max_length=300, num_return_sequences=1, top_k=50, top_p=0.9)

    output_text = generated_text[0]['generated_text']
    expanded_text = remove_unfinished_sentences(output_text)
    print(expanded_text)

    # Return the results as well as printing them, so callers that assign the
    # call's value receive something usable.
    return {
        "top_result": top_result,
        "reranked_results": reranked_results,
        "expanded_text": expanded_text,
    }



In [None]:
hits = bi_cross_pipeline('AIDS',topic_contents,bi_encoder,cross_encoder,device)

# Task 4 - Using RAG

In [None]:
# Load DPR question and context encoders
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

In [None]:
def load_corpus_from_json(json_file):
    """Read the corpus JSON and return the ``text`` of every record under
    ``data``, each converted to ``str``."""
    with open(json_file, 'r') as file:
        records = json.load(file)['data']

    summaries = []
    for record in records:
        summaries.append(str(record['text']))
    return summaries

In [None]:
# def encode_corpus(corpus, context_encoder, context_tokenizer):
#     # Tokenize and encode the context (passage)
#     inputs = context_tokenizer(corpus, return_tensors="pt", padding=True, truncation=True)
#     with torch.no_grad():
#         context_embeddings = context_encoder(**inputs).pooler_output

#     # Normalize embeddings
#     context_embeddings = normalize(context_embeddings.numpy(), axis=1)
#     return context_embeddings

def encode_corpus(corpus, context_encoder, context_tokenizer, device, batch_size=16):
    """Encode passages in batches and return row-normalised embeddings.

    Args:
        corpus: List of passage strings.
        context_encoder: Model whose output exposes ``pooler_output``.
        context_tokenizer: Tokenizer matching ``context_encoder``.
        device: Torch device on which to run the encoder.
        batch_size: Number of passages encoded per forward pass.

    Returns:
        NumPy array with one row per passage, each row normalised with
        ``sklearn.preprocessing.normalize``.

    Raises:
        ValueError: If ``corpus`` is empty.
    """
    if len(corpus) == 0:
        # Fail early with a clear message, before the model is moved to the
        # device, instead of failing later when stacking zero batches.
        raise ValueError("corpus is empty: there is nothing to encode")

    embeddings = []
    context_encoder.to(device)
    context_encoder.eval()  # inference only

    # Process the corpus in batches
    for i in tqdm(range(0, len(corpus), batch_size), desc="Encoding Corpus"):
        batch = corpus[i:i + batch_size]

        # Tokenize the batch; inputs longer than 512 tokens are truncated
        inputs = context_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)

        # Move inputs to the same device as the model
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            context_embeddings = context_encoder(**inputs).pooler_output

        # Normalize on the CPU, because sklearn works on NumPy arrays
        context_embeddings = normalize(context_embeddings.cpu().numpy(), axis=1)
        embeddings.append(context_embeddings)

    # Concatenate all the embeddings into a single NumPy array
    return np.vstack(embeddings)

In [None]:
def save_embeddings_to_drive(embeddings, embeddings_file):
    """Write ``embeddings`` to ``embeddings_file`` with ``np.save`` and print
    a confirmation."""
    np.save(embeddings_file, embeddings)
    message = f"Embeddings saved to {embeddings_file}"
    print(message)

In [None]:
def load_embeddings_from_drive(embeddings_file):
    """Load an array from ``embeddings_file`` with ``np.load``, print a
    confirmation, and return the array."""
    loaded = np.load(embeddings_file)
    message = f"Embeddings loaded from {embeddings_file}"
    print(message)
    return loaded

In [None]:
summaries = load_corpus_from_json('/content/drive/MyDrive/Project-3/Final_data.json')
summaries_embeddings = encode_corpus(summaries,context_encoder,context_tokenizer,device)
embeddings_file = '/content/drive/MyDrive/Project-3/summaries_embeddings.npy'
save_embeddings_to_drive(summaries_embeddings, embeddings_file)

In [None]:
summaries = load_embeddings_from_drive('/content/drive/MyDrive/Project-3/summaries_embeddings.npy')

In [None]:
full_data = load_file('/content/drive/MyDrive/Project-3/Final_data.json')['data']

In [None]:
def get_text(json_data):
    """Return the ``text`` of every record as a string.

    Args:
        json_data: Either the full corpus dict (records under ``'data'``) or
            the list of records itself.  The list form is what this notebook
            passes: ``full_data`` is already unwrapped from ``'data'``.

    Returns:
        List of ``str(record['text'])`` in input order.
    """
    records = json_data['data'] if isinstance(json_data, dict) else json_data
    return [str(item['text']) for item in records]

In [None]:
corpus = get_text(full_data)

In [None]:
# Pair every corpus record with its precomputed embedding (row `row` of `summaries`).
dataset = [
    {
        'title': record['title'],
        'text': record['text'],
        'embedding': summaries[row],
    }
    for row, record in enumerate(full_data)
]

# Build a flat L2-distance FAISS index sized to the embedding width and fill it.
dimension = summaries.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(summaries)

# Keep the records and the index together.
dataset_dict = {'dataset': dataset, 'index': index}


# Task 3 - Topic Classification

In [None]:
pip install datasets

In [None]:
import pandas as pd
from datasets import Dataset
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import tensorflow as tf
import string

In [None]:
df = pd.read_csv('/content/drive/MyDrive/Project-3/Final_scraped_data.csv')
# dropna() returns a new frame; assign it back so that rows with missing values
# are actually removed.
df = df.dropna()
df = df[['topic','summary']]

In [None]:
def remove_punctuation(input_string):
    """Return ``input_string`` without any character listed in
    ``string.punctuation``."""
    # A translation table that maps a code point to None deletes that character.
    deletions = dict.fromkeys(map(ord, string.punctuation))
    return input_string.translate(deletions)

In [None]:
# Convert to str and strip punctuation in one pass over the column.  Assigning
# the whole column avoids chained indexing (`df['summary'].iloc[i] = ...`),
# which can write to a temporary copy and leave `df` unchanged.
df['summary'] = df['summary'].apply(str).map(remove_punctuation)

In [None]:
# Shuffle the rows with a fixed seed so the notebook gives the same order on
# every run (the train/validation split below also uses random_state=0).
df = df.sample(frac=1, random_state=0)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
unique_topics = sorted(df["topic"].unique())
topic_to_label = {topic: i for i, topic in enumerate(unique_topics)}

# Add a new column with numerical labels
df["label"] = df["topic"].map(topic_to_label)

In [None]:
df.groupby('topic').first()

In [None]:
# Numeric class id -> topic name.  Derived from `topic_to_label` so the two
# mappings cannot drift apart if the set of topics changes.  (The variable name
# is kept as-is even though the direction is id -> label.)
label2id = {label: topic for topic, label in topic_to_label.items()}

In [None]:
df = df.reset_index(drop=True)

In [None]:
data_texts = df["summary"].to_list() # Features (not tokenized yet)
data_labels = df["label"].to_list() # Labels

In [None]:
from sklearn.model_selection import train_test_split

# Split Train and Validation data
train_texts, val_texts, train_labels, val_labels = train_test_split(data_texts, data_labels, test_size=0.2, random_state=0, shuffle=True)

# Keep some data for inference (testing)
train_texts, test_texts, train_labels, test_labels = train_test_split(train_texts, train_labels, test_size=0.01, random_state=0, shuffle=True)


In [None]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)

In [None]:
train_dataset = tf.data.Dataset.from_tensor_slices((
dict(train_encodings),
train_labels
))
val_dataset = tf.data.Dataset.from_tensor_slices((
dict(val_encodings),
val_labels
))

In [None]:
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=10)
learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-5,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=True
)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_schedule,epsilon=1e-6,weight_decay=1e-5)
model.compile(optimizer=optimizer, loss=model.hf_compute_loss, metrics=['accuracy'])

In [None]:
# cmodel = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=10)
# optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5, epsilon=1e-08)
# model.compile(optimizer=optimizer, loss=model.hf_compute_loss, metrics=['accuracy'])

In [None]:
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# The callback only takes effect when it is handed to fit().  `batch_size` is
# not passed because both datasets are already batched with .batch(16).
model.fit(
    train_dataset.shuffle(1000).batch(16),
    epochs=10,
    validation_data=val_dataset.shuffle(1000).batch(16),
    callbacks=[early_stopping],
)

###########################################

In [None]:
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# The callback only takes effect when it is handed to fit().  `batch_size` is
# not passed because both datasets are already batched with .batch(16).
model.fit(
    train_dataset.shuffle(1000).batch(16),
    epochs=20,
    validation_data=val_dataset.shuffle(1000).batch(16),
    callbacks=[early_stopping],
)

In [None]:
model.save_pretrained('/content/drive/MyDrive/Project-3/DistilBERT')
tokenizer.save_pretrained('/content/drive/MyDrive/Project-3/DistilBERT')

#########################

In [None]:
model.fit(train_dataset.shuffle(1000).batch(16),
epochs=5,
batch_size=16,
validation_data=val_dataset.shuffle(1000).batch(16))

In [None]:
model.fit(train_dataset.shuffle(1000).batch(16),
epochs=5,
batch_size=16,
validation_data=val_dataset.shuffle(1000).batch(16))

In [None]:
model.fit(train_dataset.shuffle(1000).batch(16),
epochs=5,
batch_size=16,
validation_data=val_dataset.shuffle(1000).batch(16))

In [None]:
# Save the current model and tokenizer into the Project-3 folder.
# NOTE(review): the next cell reloads from '/content/drive/MyDrive/Project-3/DistilBERT',
# which is a different directory from the one written here — confirm which
# checkpoint is meant to be reloaded.
model.save_pretrained('/content/drive/MyDrive/Project-3/')
tokenizer.save_pretrained('/content/drive/MyDrive/Project-3/')

In [None]:
loaded_tokenizer = DistilBertTokenizer.from_pretrained('/content/drive/MyDrive/Project-3/DistilBERT')
loaded_model = TFDistilBertForSequenceClassification.from_pretrained('/content/drive/MyDrive/Project-3/DistilBERT')

In [None]:
test_text = test_texts[10]
test_text

In [None]:
predict_input = tokenizer.encode(test_text,
truncation=True,
padding=True,
return_tensors="tf")

output = model(predict_input)[0]

prediction_value = tf.argmax(output, axis=1).numpy()[0]
prediction_value

In [None]:
test_labels[10]

# References

1. https://medium.com/@kiddojazz/distilbert-for-multiclass-text-classification-using-transformers-d6374e6678ba
2. https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english
3. https://sbert.net/examples/applications/retrieve_rerank/README.html
4. https://huggingface.co/docs/transformers/en/model_doc/rag#transformers.RagRetriever
