In [None]:
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import re
from flask_cors import CORS
import os
import csv
from datetime import datetime
from flask import send_from_directory


# Vocabulary lookup: token string -> integer index fed to the embedding layers.
# NOTE: pickle.load can execute arbitrary code, so word2idx.pkl must come from a trusted source.
with open("word2idx.pkl", "rb") as f:
    word2idx = pickle.load(f)

# Trained weights, remapped onto the CPU whatever device they were saved from.
# NOTE: torch.load is pickle-based as well; only load checkpoints you trust.
checkpoint = torch.load("final_model.pth", map_location="cpu")

# Constants shared by the model definitions and the text preprocessing.
PAD_WORD = "<PAD>"   # vocabulary key used for padding positions
MAX_SENT_LEN = 16    # tokens per fixed-width "sentence" chunk
MAX_SENT_NUM = 4     # chunks per document
EMBED_DIM = 200      # concatenated (trainable + static) embedding width of the trained model
NUM_CLASSES = 4

# Define CNN Model
class WordCNN(nn.Module):
    """Sentence encoder: dual word embeddings -> parallel convolutions -> max-over-time pooling.

    Every token is looked up in two embedding tables, each ``embed_dim // 2`` wide:
    ``trainable_embedding`` and ``static_embedding``. The static table is frozen via
    ``requires_grad = False``. The two lookups are concatenated to ``embed_dim``
    features. Each sentence is then convolved by one ``Conv2d`` per kernel size,
    whose kernel spans the full embedding width. After ReLU, each feature map is
    max-pooled over all positions and the pooled vectors are concatenated.

    Args:
        vocab_size: number of rows in each embedding table.
        embed_dim: width of the *concatenated* embedding; must be even. The default
            of 200 gives a 100 + 100 layout with conv width 200, matching the
            EMBED_DIM used for the trained model.
        num_filters: output channels per convolution.
        kernel_sizes: convolution heights, in tokens.
        padding_idx: forwarded to both ``nn.Embedding`` tables.

    Raises:
        ValueError: if ``embed_dim`` is odd.

    Shapes:
        input ``x``: integer tensor (batch, num_sentences, sentence_len);
        output: (batch, num_sentences, output_dim), where
        ``output_dim = num_filters * len(kernel_sizes)``.
    """

    def __init__(self, vocab_size, embed_dim=200, num_filters=10, kernel_sizes=(2, 3, 4), padding_idx=0):
        super().__init__()
        if embed_dim % 2:
            raise ValueError(f"embed_dim must be even (two equal halves), got {embed_dim}")
        half_dim = embed_dim // 2

        self.trainable_embedding = nn.Embedding(vocab_size, half_dim, padding_idx=padding_idx)
        self.static_embedding = nn.Embedding(vocab_size, half_dim, padding_idx=padding_idx)
        # Static channel: excluded from gradient updates.
        self.static_embedding.weight.requires_grad = False

        # Each kernel covers k tokens x the whole concatenated embedding width.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, embed_dim)) for k in kernel_sizes
        ])
        self.output_dim = num_filters * len(kernel_sizes)

    def forward(self, x):
        batch_size, num_sentences, sentence_len = x.shape
        # Concatenate the trainable and frozen lookups along the feature axis.
        embedded = torch.cat(
            (self.trainable_embedding(x), self.static_embedding(x)), dim=-1
        )
        # Treat each sentence as an independent 1-channel "image" (sentence_len x embed_dim).
        embedded = embedded.reshape(batch_size * num_sentences, 1, sentence_len, -1)

        pooled = []
        for conv in self.convs:
            # Run each convolution once. The kernel spans the full width, so dim 3 is 1.
            activated = F.relu(conv(embedded)).squeeze(3)  # (N, num_filters, positions)
            # Max over every remaining position (max-over-time pooling).
            pooled.append(F.max_pool1d(activated, activated.size(2)).squeeze(2))
        return torch.cat(pooled, dim=1).view(batch_size, num_sentences, -1)

# Define Attention Layer
class Attention(nn.Module):
    """Additive attention pooling over dimension 1 of a sequence of BiLSTM states.

    Each step is projected by ``attn`` (2*hidden_dim -> hidden_dim) and passed through
    tanh. ``score`` then reduces it to one scalar per step. The scalars are
    softmax-normalised over dim 1, and the output is the weighted sum of the steps.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Input width is hidden_dim * 2: the sequence comes from a bidirectional LSTM.
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, lstm_out):
        # lstm_out presumably (batch, steps, 2*hidden_dim) from a batch_first LSTM — TODO confirm
        energy = torch.tanh(self.attn(lstm_out))
        logits = self.score(energy).squeeze(2)   # one score per step
        weights = torch.softmax(logits, dim=1)   # normalise across steps
        return (lstm_out * weights.unsqueeze(2)).sum(dim=1)

# Define LSTM Model
class SentenceLSTMWithAttention(nn.Module):
    """Classifier over a sequence of sentence feature vectors.

    A bidirectional, batch-first LSTM encodes the sequence. ``Attention`` pools the
    per-step outputs into one vector, and a linear layer maps that vector to
    ``num_classes`` logits.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = Attention(hidden_dim)
        # Both LSTM directions are concatenated, hence 2 * hidden_dim input features.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        states = self.lstm(x)[0]  # per-step outputs; final (h_n, c_n) are unused
        pooled = self.attention(states)
        return self.fc(pooled)

# Initialize models
# Two-stage model: per-sentence CNN encoder -> sentence-level BiLSTM with attention.
# Hyper-parameters must match the checkpoint, because load_state_dict below is strict by default.
vocab_size = len(word2idx)
cnn_model = WordCNN(vocab_size, embed_dim=EMBED_DIM)
lstm_model = SentenceLSTMWithAttention(
    input_dim=cnn_model.output_dim,
    hidden_dim=128,
    num_layers=2,
    num_classes=NUM_CLASSES,
)

# Restore the trained weights, then put both models in evaluation mode.
cnn_model.load_state_dict(checkpoint['cnn_model_state_dict'])
lstm_model.load_state_dict(checkpoint['lstm_model_state_dict'])
cnn_model.eval()
lstm_model.eval()

# Flask API
app = Flask(__name__)

# Allow cross-origin calls (from any origin) to the two JSON endpoints only.
# Use a single registration: an additional app-wide CORS(app) call would apply to
# every route and make this scoped configuration ineffective.
# The "/" page is served by this same app, so it should not need CORS headers.
CORS(app, resources={
    r"/predict": {"origins": "*"},
    r"/predict-and-log": {"origins": "*"},
})

def preprocess_text(text):
    """Convert raw text into the fixed-size index tensor expected by the models.

    Processing steps:
    - Lower-case the text.
    - Replace every character that is neither a word character nor whitespace with a
      space, then split on whitespace.
    - Cut the token stream into consecutive chunks of ``MAX_SENT_LEN`` tokens. This is
      fixed-width chunking, not real sentence segmentation.
    - Keep only the first ``MAX_SENT_NUM`` chunks.
    - Fill short chunks and missing chunks with the ``PAD_WORD`` index (0 if PAD_WORD
      is not in the vocabulary).
    - Map out-of-vocabulary words to the ``"<UNK>"`` index (1 if absent).

    Args:
        text: the input string.

    Returns:
        torch.LongTensor of shape (1, MAX_SENT_NUM, MAX_SENT_LEN).
    """
    # Loop-invariant lookups, hoisted out of the per-word loop.
    pad_idx = word2idx.get(PAD_WORD, 0)
    unk_idx = word2idx.get("<UNK>", 1)

    # Tokens past the first MAX_SENT_NUM chunks would be discarded anyway. Indexing
    # only these keeps the per-request work bounded for arbitrarily long inputs.
    max_tokens = MAX_SENT_LEN * MAX_SENT_NUM
    tokens = re.sub(r'[^\w\s]', ' ', text.lower()).split()[:max_tokens]

    chunks = []
    for start in range(0, len(tokens), MAX_SENT_LEN):
        chunk = [word2idx.get(word, unk_idx) for word in tokens[start:start + MAX_SENT_LEN]]
        chunk += [pad_idx] * (MAX_SENT_LEN - len(chunk))  # right-pad a short final chunk
        chunks.append(chunk)

    # Pad the document to exactly MAX_SENT_NUM chunks.
    chunks += [[pad_idx] * MAX_SENT_LEN for _ in range(MAX_SENT_NUM - len(chunks))]

    input_tensor = torch.tensor(chunks, dtype=torch.long).unsqueeze(0)

    print(f"Processed input tensor shape: {input_tensor.shape}")  # Debugging output
    return input_tensor


CSV_FILE = "predictions_log.csv"

# On first run only, create the prediction log with its header row.
# Existing logs are left untouched and appended to by /predict-and-log.
if not os.path.exists(CSV_FILE):
    with open(CSV_FILE, 'w', newline='') as f:
        csv.writer(f).writerow(
            ["timestamp", "text_snippet", "predicted_class", "full_text_length"]
        )

@app.route('/predict', methods=['POST'])
def predict():
    """Classify the ``"text"`` field of a JSON request body.

    Expects a JSON object body such as ``{"text": "..."}``.

    Returns:
        One of two responses:
        - 200 with ``{"prediction": <class index>, "confidence": <softmax probability
          of that class>}``.
        - 400 with ``{"error": ...}`` when the body is not a JSON object, or when
          ``text`` is missing, empty, or not a string.
    """
    data = request.get_json()
    # get_json() can yield None (a JSON null body) or a list/number/string.
    # Calling .get() on those would raise and surface as a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "No text provided"}), 400
    # A truthy non-string (e.g. a number) would otherwise crash in preprocess_text.
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400

    input_tensor = preprocess_text(text)

    with torch.no_grad():
        logits = lstm_model(cnn_model(input_tensor))
        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

    return jsonify({
        "prediction": predicted_class,
        "confidence": float(confidence),
    })
@app.route('/predict-and-log', methods=['POST'])
def predict_and_log():
    """Run ``predict`` on the current request and append a summary row to ``CSV_FILE``.

    Any non-200 response from ``predict`` is returned unchanged, and nothing is
    logged. On success, the prediction JSON is returned with an extra
    ``"logged": True`` field.

    The CSV row has four columns:
    - an ISO-format local timestamp;
    - the first 50 characters of the text, with "..." appended if longer;
    - the predicted class;
    - the full text length.
    """
    from flask import make_response

    # predict() returns either a Response or a (Response, status) tuple.
    # make_response normalises both into a Response, so .status_code is always
    # available. Reading it directly from a tuple would raise AttributeError and
    # turn a 400 into a 500.
    pred_response = make_response(predict())
    if pred_response.status_code != 200:
        return pred_response

    # Reaching here means predict() accepted the body: it is a JSON object with a non-empty "text".
    text = request.get_json().get("text", "")
    prediction_data = pred_response.get_json()

    # Explicit UTF-8 so non-ASCII snippets do not depend on the platform's default encoding.
    with open(CSV_FILE, 'a', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow([
            datetime.now().isoformat(),
            text[:50] + "..." if len(text) > 50 else text,
            prediction_data["prediction"],
            len(text),
        ])

    return jsonify({
        **prediction_data,
        "logged": True,
    })


@app.route('/')
def home():
    """Serve the static page ``index.html`` from the ``'.'`` directory.

    Flask resolves a relative directory against the application's root path.
    """
    page_name = 'index.html'
    return send_from_directory('.', page_name)
if __name__ == '__main__':
    # Start Flask's built-in server on port 5000.
    # It binds to every network interface (not just loopback), so other hosts can reach it.
    app.run(port=5000, host='0.0.0.0')


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://10.50.43.246:5000
Press CTRL+C to quit
127.0.0.1 - - [01/Apr/2025 06:25:25] "OPTIONS /predict HTTP/1.1" 200 -
127.0.0.1 - - [01/Apr/2025 06:25:25] "POST /predict HTTP/1.1" 200 -


Processed input tensor shape: torch.Size([1, 4, 16])


127.0.0.1 - - [01/Apr/2025 06:25:38] "OPTIONS /predict HTTP/1.1" 200 -
127.0.0.1 - - [01/Apr/2025 06:25:38] "POST /predict HTTP/1.1" 200 -


Processed input tensor shape: torch.Size([1, 4, 16])


127.0.0.1 - - [01/Apr/2025 06:25:54] "OPTIONS /predict-and-log HTTP/1.1" 200 -
127.0.0.1 - - [01/Apr/2025 06:25:54] "POST /predict-and-log HTTP/1.1" 200 -


Processed input tensor shape: torch.Size([1, 4, 16])


127.0.0.1 - - [01/Apr/2025 06:26:18] "OPTIONS /predict-and-log HTTP/1.1" 200 -
127.0.0.1 - - [01/Apr/2025 06:26:18] "POST /predict-and-log HTTP/1.1" 200 -


Processed input tensor shape: torch.Size([1, 4, 16])
