# In [None]:
import streamlit as st
from transformers import BartForConditionalGeneration, BartTokenizer, AutoTokenizer, AutoModel, AutoModelForQuestionAnswering
import torch
from docx import Document
import pdfplumber
from io import StringIO
import os
from googletrans import Translator

# Load models
@st.cache_resource
def load_bert_model():
    """Build the fake-news classifier and its tokenizer.

    A ``bert-base-uncased`` encoder is wrapped in a small classification
    head, the weights stored in ``fakenews_bert_weights.pt`` are loaded onto
    the CPU, and the model is put into evaluation mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """

    class BERT_Arch(torch.nn.Module):
        """BERT encoder followed by a two-layer head with two outputs."""

        def __init__(self, bert):
            super().__init__()
            # The attribute names below are what load_state_dict() matches
            # against the saved checkpoint, so they are kept unchanged.
            self.bert = bert
            self.dropout = torch.nn.Dropout(0.1)
            self.relu = torch.nn.ReLU()
            self.fc1 = torch.nn.Linear(768, 512)
            self.fc2 = torch.nn.Linear(512, 2)
            self.softmax = torch.nn.LogSoftmax(dim=1)

        def forward(self, sent_id, mask):
            """Return per-class log-probabilities for a batch of token ids."""
            pooled = self.bert(sent_id, attention_mask=mask)['pooler_output']
            hidden = self.dropout(self.relu(self.fc1(pooled)))
            return self.softmax(self.fc2(hidden))

    bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    classifier = BERT_Arch(AutoModel.from_pretrained('bert-base-uncased'))
    weights = torch.load("fakenews_bert_weights.pt", map_location=torch.device('cpu'))
    classifier.load_state_dict(weights)
    classifier.eval()
    return bert_tokenizer, classifier

@st.cache_resource
def load_bart_model():
    """Load the ``facebook/bart-large-cnn`` tokenizer and model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'facebook/bart-large-cnn'
    return (
        BartTokenizer.from_pretrained(checkpoint),
        BartForConditionalGeneration.from_pretrained(checkpoint),
    )

@st.cache_resource
def load_qa_model():
    """Load the extractive question-answering tokenizer and model.

    Both come from the ``distilbert-base-cased-distilled-squad`` checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "distilbert-base-cased-distilled-squad"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    qa = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    return tok, qa

# Utility functions
def extract_text(file):
    """Return the plain text of an uploaded PDF, DOCX or TXT file.

    Args:
        file: A binary file-like object with a ``name`` attribute. The
            extension of ``name`` (case-insensitive) selects the parser.

    Returns:
        The extracted text as one string. PDF pages are each followed by a
        newline; DOCX paragraphs are joined with newlines; TXT content is
        decoded as UTF-8.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``.
    """
    file_extension = os.path.splitext(file.name)[1].lower()
    if file_extension == ".pdf":
        with pdfplumber.open(file) as pdf:
            # extract_text() may give None for a page without a text layer
            # (e.g. a scanned image); treat that as an empty page instead of
            # failing with a TypeError on the concatenation.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf.pages
            )
    if file_extension == ".docx":
        return "\n".join(para.text for para in Document(file).paragraphs)
    if file_extension == ".txt":
        return file.read().decode('utf-8')
    raise ValueError("Unsupported file type.")

def split_text_into_chunks(text, tokenizer, max_tokens=900, overlap=100):
    """Split text into overlapping windows of at most ``max_tokens`` tokens.

    Args:
        text: The text to split.
        tokenizer: Object providing ``tokenize(text)`` and
            ``convert_tokens_to_string(tokens)``.
        max_tokens: Maximum number of tokens per chunk; must be positive.
        overlap: Number of tokens shared by consecutive chunks; must be
            non-negative and smaller than ``max_tokens``.

    Returns:
        A list of chunk strings in document order (empty for empty text).

    Raises:
        ValueError: If ``max_tokens`` or ``overlap`` is out of range.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive.")
    if not 0 <= overlap < max_tokens:
        # A non-positive stride would either crash range() or silently
        # return no chunks; a negative overlap would skip tokens.
        raise ValueError("overlap must be non-negative and smaller than max_tokens.")

    tokens = tokenizer.tokenize(text)
    stride = max_tokens - overlap
    chunks = []
    for start in range(0, len(tokens), stride):
        window = tokens[start:start + max_tokens]
        chunks.append(tokenizer.convert_tokens_to_string(window))
        # Once a window reaches the end of the text, any further window
        # would consist solely of tokens already covered by this one.
        if start + max_tokens >= len(tokens):
            break
    return chunks

def summarize_chunk(text, model, tokenizer, max_length=300):
    """Summarize a single chunk of text with a seq2seq model.

    Args:
        text: The chunk to summarize.
        model: Model exposing ``generate``; this file passes the
            ``facebook/bart-large-cnn`` model.
        tokenizer: Matching tokenizer exposing ``encode`` and ``decode``.
        max_length: Upper bound on the generated summary length in tokens.

    Returns:
        The decoded summary, or an empty string for blank input.
    """
    if not text.strip():
        # With min_length=50 the model would be forced to emit text even
        # though there is nothing to summarize.
        return ""
    # The text is encoded as-is: a "summarize: " task prefix is a T5
    # convention, and BART would treat it as part of the article.
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(
        inputs,
        max_length=max_length,
        min_length=50,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def recursive_summarize(text, model, tokenizer):
    """Summarize text of any length by repeatedly summarizing its chunks.

    The text is split into chunks, each chunk is summarized, and the
    summaries are joined with spaces. While the joined result is longer than
    1024 tokens, it is fed through the same process again.

    Args:
        text: The text to summarize.
        model: Summarization model passed through to ``summarize_chunk``.
        tokenizer: Tokenizer used for chunking, summarizing and measuring.

    Returns:
        The combined summary string.
    """
    current = text
    while True:
        pieces = split_text_into_chunks(current, tokenizer)
        partials = []
        for piece in pieces:
            partials.append(summarize_chunk(piece, model, tokenizer))
        merged = " ".join(partials)
        if len(tokenizer.tokenize(merged)) <= 1024:
            return merged
        # Still too long: run another pass over the joined summaries.
        current = merged

def classify_fake_news(text, tokenizer, model):
    """Label text as ``"Fake"`` or ``"Real"`` using the classifier.

    Args:
        text: The article text; it is truncated to 512 tokens.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            per-class scores of shape ``(1, num_classes)``.

    Returns:
        ``"Fake"`` when class 0 has the highest score, otherwise ``"Real"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = model(encoded["input_ids"], encoded["attention_mask"])
        label_idx = torch.softmax(scores, dim=1).argmax(dim=1).item()
    if label_idx == 0:
        return "Fake"
    return "Real"

def translate_text(text, target_language="fr"):
    """Translate text with googletrans and return the translated string.

    Args:
        text: The text to translate.
        target_language: Destination language code passed as ``dest``.

    Returns:
        The ``text`` attribute of the translation result.
    """
    result = Translator().translate(text, dest=target_language)
    return result.text

def answer_question(text, question, model, tokenizer):
    """Extract the answer span for a question from the given text.

    The question/text pair is encoded (truncated to 512 tokens), the model's
    highest-scoring start and end positions are taken, and the tokens
    between them are turned back into a string.

    Args:
        text: The context to search for an answer.
        question: The question to answer.
        model: Model returning ``start_logits`` and ``end_logits``.
        tokenizer: Tokenizer matching the model.

    Returns:
        The answer text, or ``"Could not find a relevant answer."`` when the
        selected span is empty or blank.
    """
    encoded = tokenizer(question, text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        prediction = model(**encoded)
        span_start = torch.argmax(prediction.start_logits)
        # +1 so the slice below includes the predicted end token.
        span_stop = torch.argmax(prediction.end_logits) + 1
        span_ids = encoded.input_ids[0][span_start:span_stop]
        span_tokens = tokenizer.convert_ids_to_tokens(span_ids)
        answer = tokenizer.convert_tokens_to_string(span_tokens)
    if not answer.strip():
        return "Could not find a relevant answer."
    return answer

# ------------------ Streamlit UI ------------------

def main():
    """Render the VeriNews AI page.

    The user uploads a file or pastes text and presses Submit, which stores
    the article text and the loaded models in ``st.session_state``. The
    Summarize, Classify, Translate and Q&A buttons then operate on that
    stored text, and each result is kept in ``st.session_state`` so it stays
    visible across Streamlit reruns.
    """
    st.set_page_config(page_title="VeriNews AI – Detect. Summarize. Translate.", layout="wide")
    st.title("VeriNews AI – Detect. Summarize. Translate.")

    uploaded_file = st.file_uploader("Upload PDF, DOCX or TXT file:", type=["pdf", "docx", "txt"])
    custom_text = st.text_area("Or paste/edit your own news article here:")

    if uploaded_file or custom_text:
        if st.button("Submit"):
            # Parse the upload only when Submit is pressed rather than on
            # every rerun, and rewind first in case the stream was read
            # during an earlier run.
            if uploaded_file:
                uploaded_file.seek(0)
                text = extract_text(uploaded_file)
            else:
                text = custom_text

            # Results computed for a previously submitted article must not
            # be displayed next to the new one.
            for stale_key in (
                "summary",
                "prediction",
                "translated_news",
                "translated_summary",
                "qa_answer",
                "qa_question",
            ):
                if stale_key in st.session_state:
                    del st.session_state[stale_key]

            st.session_state["text"] = text
            st.session_state["bart_tokenizer"], st.session_state["bart_model"] = load_bart_model()
            st.session_state["bert_tokenizer"], st.session_state["bert_model"] = load_bert_model()
            st.session_state["qa_tokenizer"], st.session_state["qa_model"] = load_qa_model()

    if "text" in st.session_state:
        text = st.session_state["text"]
        bart_tokenizer = st.session_state["bart_tokenizer"]
        bart_model = st.session_state["bart_model"]
        bert_tokenizer = st.session_state["bert_tokenizer"]
        bert_model = st.session_state["bert_model"]
        qa_tokenizer = st.session_state["qa_tokenizer"]
        qa_model = st.session_state["qa_model"]

        # --- Summarize ---
        with st.container():
            if st.button("Summarize"):
                with st.spinner("Summarizing..."):
                    summary = recursive_summarize(text, bart_model, bart_tokenizer)
                    st.session_state["summary"] = summary
            if "summary" in st.session_state:
                st.subheader("📄 Summary")
                st.write(st.session_state["summary"])

        # --- Classify News ---
        with st.container():
            if st.button("Classify News"):
                with st.spinner("Classifying..."):
                    prediction = classify_fake_news(text, bert_tokenizer, bert_model)
                    st.session_state["prediction"] = prediction
            if "prediction" in st.session_state:
                st.subheader("🕵️ Fake News Prediction")
                st.markdown(f"**Prediction:** {st.session_state['prediction']}")

        # --- Translation ---
        target_lang = st.text_input("Enter language code (e.g., fr, hi, es):")

        with st.container():
            if st.button("Translate News") and target_lang:
                with st.spinner("Translating full news..."):
                    translated = translate_text(text, target_lang)
                    st.session_state["translated_news"] = translated
            if "translated_news" in st.session_state:
                st.subheader("🌍 Translated Full News")
                st.write(st.session_state["translated_news"])

        with st.container():
            if st.button("Translate Summary") and target_lang:
                # The summary translation needs a summary to work from.
                if "summary" in st.session_state:
                    with st.spinner("Translating summary..."):
                        translated_summary = translate_text(st.session_state["summary"], target_lang)
                        st.session_state["translated_summary"] = translated_summary
                else:
                    st.warning("Please summarize the news before translating it.")
            if "translated_summary" in st.session_state:
                st.subheader("🌍 Translated Summary")
                st.write(st.session_state["translated_summary"])

        # --- Q&A ---
        question = st.text_input("Enter your question:")
        with st.container():
            if st.button("Get Answer") and question:
                with st.spinner("Answering..."):
                    answer = answer_question(text, question, qa_model, qa_tokenizer)
                    st.session_state["qa_answer"] = answer
                    # Remember the question that produced this answer so the
                    # pair stays consistent if the input box is edited later.
                    st.session_state["qa_question"] = question
            if "qa_answer" in st.session_state:
                st.subheader("❓ Q&A")
                st.markdown(f"**Q:** {st.session_state['qa_question']}")
                st.markdown(f"**A:** {st.session_state['qa_answer']}")

# Entry point: build the page only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
