In [None]:
import requests
from bs4 import BeautifulSoup
import os
import json
import random
import requests
from pathlib import Path
import fitz
import re
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import logging
from datetime import datetime
import torch
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from textstat import textstat
from transformers import AutoTokenizer, AutoModel
from torch.nn.functional import softmax
import numpy as np
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import threading
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset,  Subset
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
import pandas as pd
from tqdm import tqdm

In [None]:
# Placeholder template: "{CONFERENCE_NAME}" is a literal string (neither line
# is an f-string), so substitute the real conference name by hand before
# running this cell.
search_url = (
    "https://arxiv.org/search/"
    "?query={CONFERENCE_NAME}"
    "&searchtype=all&abstracts=show"
    "&order=-announced_date_first&size=200"
)
output_dir = "{CONFERENCE_NAME}"

# Download directory; exist_ok makes re-running the cell harmless.
os.makedirs(output_dir, exist_ok=True)

def download_pdf(pdf_url, output_path, timeout=30):
    """Stream ``pdf_url`` to ``output_path`` and report the outcome on stdout.

    A non-200 response or a network/file error is reported and does not raise,
    so one bad link does not abort a batch download. A partially written file
    is removed; otherwise the caller's "already downloaded" check would skip it
    on the next run.

    Args:
        pdf_url: Absolute URL of the PDF.
        output_path: Destination file path.
        timeout: Seconds to wait for connect/read. Previously unbounded.
    """
    try:
        # The context manager releases the connection even with stream=True.
        with requests.get(pdf_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"[ERROR] Failed to download: {pdf_url}")
                return
            with open(output_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)
    except (requests.RequestException, OSError) as exc:
        # Remove a truncated download so a rerun retries it.
        if os.path.exists(output_path):
            os.remove(output_path)
        print(f"[ERROR] Failed to download: {pdf_url} ({exc})")
        return
    print(f"[SUCCESS] Downloaded: {output_path}")

# Fetch one arXiv search-results page and download every linked PDF that is
# not already present in ``output_dir``.
from urllib.parse import urljoin

ARXIV_BASE_URL = "https://arxiv.org"

response = requests.get(search_url, timeout=30)
if response.status_code == 200:
    soup = BeautifulSoup(response.content, "html.parser")
    pdf_links = soup.find_all("a", href=lambda href: href and "/pdf/" in href)

    for link in pdf_links:
        href = link['href']
        # Listing hrefs may be absolute ("https://arxiv.org/pdf/<id>") or
        # site-relative ("/pdf/<id>"). urljoin handles both, whereas always
        # prefixing the host turns absolute links into "https://arxiv.orghttps://...".
        pdf_url = urljoin(ARXIV_BASE_URL, href).rstrip("/")
        paper_id = os.path.basename(pdf_url)
        if paper_id.endswith(".pdf"):
            paper_id = paper_id[:-len(".pdf")]
        else:
            pdf_url += ".pdf"
        pdf_path = os.path.join(output_dir, paper_id + ".pdf")

        if not os.path.exists(pdf_path):
            download_pdf(pdf_url, pdf_path)
        else:
            print(f"[INFO] Already downloaded: {pdf_path}")
else:
    print(f"[ERROR] Failed to fetch ArXiv page: {response.status_code}")


In [None]:
class TextCorruptor:
    """Synthesize "non-publishable" papers by randomly damaging clean ones.

    Documents are ``{heading: section_text}`` dicts. ``corrupt_document`` first
    drops some whole sections. It then applies a random mix of operations to the
    remaining sections:

    - deletions (". "-separated chunks, sentences, words, characters);
    - nonsense insertion;
    - grammar tweaks;
    - sentence shuffling.

    Nonsense phrases come from the Hugging Face Inference API. When the call
    fails or returns nothing usable, a phrase from ``fallback_phrases`` is used.
    """

    # Longest API phrase accepted (the prompt asks for 5 - 20 words).
    max_nonsense_words = 20

    def __init__(self, api_token: Optional[str] = None, timeout: float = 30):
        """Configure the API client and probe it once.

        Args:
            api_token: Hugging Face access token. When omitted it is read from
                the ``HF_TOKEN`` environment variable. Tokens must never be
                hard-coded; any token previously committed in this file should
                be treated as leaked and revoked.
            timeout: Per-request timeout in seconds for API calls.
        """
        self.api_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
        token = api_token if api_token is not None else os.environ.get("HF_TOKEN", "")
        self.headers = {"Authorization": f"Bearer {token}"} if token else {}
        self.timeout = timeout
        self.fallback_phrases = [
            "potato dreams fly upward", "singing mountains eat clouds",
            "blue ideas sleep furiously", "yesterday tomorrow today simultaneously"
        ]
        self._test_api_connection()

    def _test_api_connection(self):
        """Best-effort reachability check; a failure is only printed."""
        try:
            response = requests.post(self.api_url, headers=self.headers,
                                     json={"inputs": "test"}, timeout=self.timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"API connection failed: {e}")

    def clean_text(self, text):
        """Collapse newlines and runs of whitespace to single spaces.

        Non-string values are returned unchanged.
        """
        if not isinstance(text, str):
            return text
        return ' '.join(text.replace('\n', ' ').split())

    def remove_characters(self, text):
        """Delete a random 5-10% of the characters at random positions."""
        if not text:
            return text
        chars = list(text)
        remove_count = random.randint(len(chars) // 20, len(chars) // 10)
        for _ in range(remove_count):
            if chars:
                idx = random.randint(0, len(chars) - 1)
                chars.pop(idx)
        return ''.join(chars)

    def remove_words(self, text):
        """Delete a random 1/6 to 1/4 of the whitespace-separated words."""
        words = text.split()
        if len(words) <= 1:
            return text

        remove_count = random.randint(len(words) // 6, len(words) // 4)
        for _ in range(remove_count):
            if words:
                idx = random.randint(0, len(words) - 1)
                words.pop(idx)
        return ' '.join(words)

    def remove_sentences(self, section):
        """Delete 20-40% (at least one) of the '.'-separated sentences."""
        sentences = [s.strip() for s in section.split('.') if s.strip()]
        if len(sentences) <= 1:
            return section

        remove_count = random.randint(max(1, int(len(sentences) * 0.2)),
                                      max(1, int(len(sentences) * 0.4)))
        for _ in range(remove_count):
            if sentences:
                idx = random.randint(0, len(sentences) - 1)
                sentences.pop(idx)
        return '. '.join(sentences) + '.' if sentences else ''

    def remove_paragraphs(self, text):
        """Delete 1/4 to 1/2 of the ". "-separated chunks.

        Newlines are already gone after ``clean_text``, so these chunks are
        sentence-like pieces rather than real paragraphs.
        """
        paragraphs = text.split('. ')
        if len(paragraphs) <= 1:
            return text

        remove_count = random.randint(len(paragraphs) // 4, len(paragraphs) // 2)
        for _ in range(remove_count):
            if paragraphs:
                idx = random.randint(0, len(paragraphs) - 1)
                paragraphs.pop(idx)
        return '. '.join(paragraphs)

    def _extract_phrase(self, response_json) -> Optional[str]:
        """Return the phrase in an Inference API response, or None if unusable.

        Expects ``[{"generated_text": ...}]``. Anything else (for example an
        ``{"error": ...}`` dict) yields None. Only the text after the last ':'
        is kept, which drops a "Here is a phrase:" preamble, and surrounding
        whitespace and quotes are stripped.
        """
        try:
            generated = response_json[0]["generated_text"]
        except (KeyError, IndexError, TypeError):
            return None
        if not isinstance(generated, str):
            return None
        # Strip whitespace first so that leading/trailing quotes are reachable.
        phrase = generated.split(":")[-1].strip().strip('"\'').strip()
        if 0 < len(phrase.split()) <= self.max_nonsense_words:
            return phrase
        return None

    def generate_nonsense(self):
        """Return a nonsense phrase from the API, or a random fallback phrase."""
        payload = {
            "inputs": "Generate a nonsensical phrase. it should be completely random and should be atleast 5 - 20 words",
            # return_full_text=False: by default the API echoes the prompt in
            # generated_text. The prompt has no ':', so the whole prompt would
            # otherwise end up in the extracted phrase.
            "parameters": {"max_new_tokens": 50, "temperature": 0.9, "return_full_text": False},
        }
        try:
            response = requests.post(self.api_url, headers=self.headers,
                                     json=payload, timeout=self.timeout)
            response.raise_for_status()
            phrase = self._extract_phrase(response.json())
        except (requests.RequestException, ValueError):
            # ValueError covers a non-JSON body on older requests versions.
            phrase = None
        return phrase if phrase else random.choice(self.fallback_phrases)

    def add_nonsense(self, section):
        """Insert 1-2 nonsense phrases at random word positions."""
        words = section.split()
        if not words:
            return section
        num_phrases = random.randint(1, 2)
        for _ in range(num_phrases):
            if words:
                pos = random.randint(0, len(words))
                words.insert(pos, self.generate_nonsense())
        return ' '.join(words)

    def disturb_grammar(self, text):
        """For about 20% of words, drop articles or swap -ing/-ed endings."""
        if not isinstance(text, str) or not text.strip():
            return text
        words = text.split()
        if len(words) < 2:
            return text

        for i in range(len(words)):
            if random.random() > 0.8:
                if words[i].lower() in {'a', 'an', 'the'}:
                    words[i] = ''
                elif len(words[i]) > 3:
                    if words[i].endswith('ing'):
                        words[i] = words[i][:-3] + 'ed'
                    elif words[i].endswith('ed'):
                        words[i] = words[i][:-2] + 'ing'
        return ' '.join(w for w in words if w)

    def reorder_text(self, text):
        """Shuffle the '.'-separated sentences."""
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        if len(sentences) <= 1:
            return text
        random.shuffle(sentences)
        return '. '.join(sentences) + '.'

    def corrupt_document(self, data):
        """Return a randomly corrupted copy of a ``{heading: text}`` document.

        The section order of the result is shuffled, and sections left empty
        after corruption are dropped.
        """
        cleaned_data = {self.clean_text(k): self.clean_text(v) for k, v in data.items()}

        # Drop 20-40% of the sections (at least one) when there is more than one.
        sections = list(cleaned_data.keys())
        if len(sections) > 1:
            remove_count = random.randint(
                max(1, int(len(sections) * 0.2)),
                max(1, int(len(sections) * 0.4))
            )
            for _ in range(remove_count):
                if sections:
                    cleaned_data.pop(random.choice(sections))
                    sections = list(cleaned_data.keys())

        corrupted = {}
        for heading, content in cleaned_data.items():
            # Removal operations, each applied independently with its own probability.
            if random.random() < 0.5:
                content = self.remove_paragraphs(content)
            if random.random() < 0.7:
                content = self.remove_sentences(content)
            if random.random() < 0.7:
                content = self.remove_words(content)
            if random.random() < 0.6:
                content = self.remove_characters(content)

            # Other corruptions.
            if random.random() < 0.5:
                content = self.add_nonsense(content)
            if random.random() < 0.6:
                content = self.disturb_grammar(content)
            if random.random() < 0.5:
                content = self.reorder_text(content)

            # 40% chance to corrupt the heading.
            if random.random() < 0.4:
                heading = self.remove_words(heading)
                if random.random() < 0.4:
                    heading = self.remove_characters(heading)

            if content.strip():
                if heading in corrupted:
                    # Two headings can collapse to the same string once
                    # corrupted; keep both bodies instead of overwriting one.
                    corrupted[heading] = f"{corrupted[heading]} {content}"
                else:
                    corrupted[heading] = content

        items = list(corrupted.items())
        random.shuffle(items)
        return dict(items)

def process_directory(input_dir: str, output_dir: str):
    """Write a corrupted copy of every JSON document found under ``input_dir``.

    Inputs are found recursively but written flat into ``output_dir`` under
    their original file name. Existing outputs are skipped, so a rerun only
    processes new inputs. A failure on one file is reported and the run
    continues.

    NOTE(review): inputs in different subfolders that share a file name map to
    the same output, so only the first one (in sorted order) is written.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    corruptor = TextCorruptor()

    # sorted(): deterministic processing order regardless of filesystem.
    for json_file in sorted(input_path.glob('**/*.json')):
        output_file = output_path / json_file.name

        if output_file.exists():
            print(f"Skipping {json_file.name} - already processed")
            continue

        try:
            # Explicit UTF-8 rather than the platform's locale encoding.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            corrupted_data = corruptor.corrupt_document(data)
            # Serialize fully before touching the file. A failure mid-dump
            # would otherwise leave a partial file that later runs skip.
            output_file.write_text(json.dumps(corrupted_data, indent=4), encoding='utf-8')
            print(f"Processed: {json_file.name} -> {output_file.name}")
        except Exception as e:
            print(f"Error processing {json_file}: {e}")

if __name__ == "__main__":
    # Build the negative class: every parsed publishable document gets a
    # corrupted copy in the non-publishable text directory.
    # NOTE(review): the PDF-parsing cell further down also writes P####.json
    # files into "Dataset/texts/non-publishable". If names overlap, the step
    # that runs last wins. Confirm the intended run order.
    input_dir, output_dir = (
        "Dataset/texts/publishable",
        "Dataset/texts/non-publishable",
    )
    process_directory(input_dir, output_dir)

In [None]:
class DoraemonPDFParser:
    """Split a research-paper PDF into a ``{heading: body_text}`` mapping.

    How it works:

    - Text blocks are read with PyMuPDF (``fitz``) and cleaned with the regexes
      in ``self.patterns``.
    - A block is classified as a heading either by shape (numbered or
      title-cased lines) or by typography (larger than the running baseline
      font size, or bold).
    - Parsing stops at a heading-like References/Appendix block.
    - Sections whose body has fewer than ``min_section_length`` words are
      dropped.
    """

    # PyMuPDF span flag for bold text (bit 4). Bit 2 (value 4) is the *serif*
    # flag, which nearly every paper font sets.
    BOLD_FLAG = 2 ** 4

    # Patterns never stripped by clean_text: heading/caption helpers, plus the
    # structural markers the section logic must still be able to see.
    _UNSTRIPPED_PATTERNS = frozenset({
        'figure_captions', 'table_captions', 'section_number',
        'references', 'appendix', 'acknowledgments',
    })

    def __init__(self, min_section_length: int = 50):
        self.patterns = {
            'urls': r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
            'citations': r'\[[0-9,\s-]+\]',
            'references': r'^(?:References?|REFERENCES?|Bibliography|BIBLIOGRAPHY)(?:\s|$)',
            'appendix': r'^(?:Appendix|APPENDIX)(?:\s+[A-Z])?(?:\s|:|$)',
            'acknowledgments': r'^(?:Acknowledgments?|ACKNOWLEDGMENTS?)(?:\s|$)',
            'emails': r'[\w\.-]+@[\w\.-]+\.\w+',
            'line_numbers': r'^\d+$',
            'page_numbers': r'^\d+$',
            'cross_refs': r'(Fig\.|Figure|Table|Section)\s*\d+',
            'figure_captions': r'(Figure|Fig\.)\s*\d+[.:]\s*.*?(?=\n|$)',
            'table_captions': r'Table\s*\d+[.:]\s*.*?(?=\n|$)',
            'section_number': r'^(?:\d+\.)*\d+(?:\s+|\b)|^\.\d+(?:\s+|\b)'
        }
        # Font sizes of the first 10 blocks passed to is_heading; their mean
        # is the baseline for "larger font".
        self.heading_font_sizes = []
        self.found_references = False
        self.found_appendix = False
        self.min_section_length = min_section_length
        self.setup_logging()

    def setup_logging(self):
        """Log to a timestamped file under ``logs/`` and to the console.

        NOTE: logging.basicConfig is a no-op once the root logger has handlers,
        so only the first configuration in the process takes effect.
        """
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = os.path.join(log_dir, f"pdf_parser_{timestamp}.log")
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def clean_text(self, text: str) -> str:
        """Drop non-ASCII chars, strip noise patterns, and normalise whitespace.

        References/Appendix/Acknowledgments markers are kept. Stripping them
        here (as before) erased those headings before the section logic could
        recognise them, so back matter leaked into the last section.
        """
        text = text.encode('ascii', 'ignore').decode('ascii')
        for pattern_name, pattern in self.patterns.items():
            if pattern_name not in self._UNSTRIPPED_PATTERNS:
                text = re.sub(pattern, '', text)
        text = re.sub(r'(?<=[a-z])-\n(?=[a-z])', '', text)
        text = re.sub(r'(?<=[a-z])\n(?=[a-z])', ' ', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def get_font_properties(self, block: Dict) -> Tuple[str, float, bool, str]:
        """Return (concatenated text, max span size, any-span-bold, font of largest span)."""
        text = ""
        max_font_size = 0
        is_bold = False
        font_face = ""
        for line in block.get("lines", []):
            for span in line.get("spans", []):
                text += span.get("text", "")
                current_size = span.get("size", 0)
                if current_size > max_font_size:
                    max_font_size = current_size
                    font_face = span.get("font", "")
                is_bold = is_bold or (span.get("flags", 0) & self.BOLD_FLAG != 0)
        return text, max_font_size, is_bold, font_face

    def clean_heading(self, text: str) -> str:
        """Strip leading section numbers (e.g. "3.2 ", "3.2Method") and trailing dots."""
        text = re.sub(r'^(?:\d+\.)*\d+(?:\s+|\b)|^\.\d+(?:\s+|\b)', '', text)
        text = re.sub(r'^(?:\d+\.)*\d+([A-Z][a-z])', r'\1', text)
        text = re.sub(r'^\.\d+([A-Z][a-z])', r'\1', text)
        text = re.sub(r'\.+\s*$', '', text)
        text = re.sub(r'^\d+([A-Z])', r'\1', text)
        text = re.sub(r'^\.\d+([A-Z])', r'\1', text)
        return text.strip()

    def is_heading(self, text: str, font_size: float, is_bold: bool, font_face: str) -> Tuple[bool, str]:
        """Classify a cleaned block; return (is_heading, cleaned heading or text).

        A References/Appendix match sets the matching ``found_*`` flag and is
        reported as a non-heading.
        """
        if re.match(self.patterns['references'], text):
            self.found_references = True
            return False, text
        elif re.match(self.patterns['appendix'], text):
            self.found_appendix = True
            return False, text

        heading_patterns = [
            r'^(?:\d+\.)*\d+\s+[A-Z][A-Za-z\s]+$',
            r'^(?:\d+\.)*\d+\s+[A-Z][A-Z\s]+$',
            r'^(?:\d+\.)*\d+[A-Z][A-Za-z\s]+$',
            r'^\.\d+\s+[A-Z][A-Za-z\s]+$',
            r'^\.\d+[A-Z][A-Za-z\s]+$',
            r'^[A-Z][A-Z\s]{3,}[A-Z]$',
            r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*$'
        ]

        # Baseline = mean size of the first 10 blocks seen (headings or not).
        if len(self.heading_font_sizes) < 10:
            self.heading_font_sizes.append(font_size)

        avg_font_size = sum(self.heading_font_sizes) / len(self.heading_font_sizes) if self.heading_font_sizes else 11
        is_larger_font = font_size > avg_font_size + 1
        text = text.strip()

        is_pattern_match = any(re.match(pattern, text) for pattern in heading_patterns)
        is_short = len(text) < 200
        has_heading_properties = (is_larger_font or is_bold) and is_short

        if is_pattern_match or has_heading_properties:
            cleaned_heading = self.clean_heading(text)
            return True, cleaned_heading
        return False, text

    def should_include_section(self, heading: str, text: str) -> bool:
        """Keep a section unless it is back matter, follows back matter, or is too short."""
        if self.found_references or self.found_appendix:
            return False

        if any(re.match(self.patterns[pattern], heading.strip())
               for pattern in ['references', 'appendix', 'acknowledgments']):
            if re.match(self.patterns['references'], heading.strip()):
                self.found_references = True
            elif re.match(self.patterns['appendix'], heading.strip()):
                self.found_appendix = True
            return False

        if len(text.strip().split()) < self.min_section_length:
            return False

        return True

    def _back_matter_marker(self, text: str) -> Optional[str]:
        """Return 'references' or 'appendix' if *text* starts with that marker, else None."""
        if re.match(self.patterns['references'], text):
            return 'references'
        if re.match(self.patterns['appendix'], text):
            return 'appendix'
        return None

    @staticmethod
    def _is_heading_line(text: str) -> bool:
        """Heuristic for a stand-alone heading: at most 10 words and no final full stop."""
        return len(text.split()) <= 10 and not text.endswith('.')

    def _store_section(self, sections: Dict[str, str], heading: str, content: List[str]) -> None:
        """Add ``heading -> ' '.join(content)`` to *sections* if it qualifies."""
        if heading and content:
            section_text = ' '.join(content)
            if self.should_include_section(heading, section_text):
                sections[heading] = section_text

    def _build_sections(self, blocks) -> Dict[str, str]:
        """Group PyMuPDF text-block dicts into ``{heading: body}`` sections.

        ``blocks`` is any iterable of block dicts shaped like the entries of
        ``page.get_text("dict")["blocks"]``. The per-document state is reset
        first.
        """
        self.heading_font_sizes = []
        self.found_references = False
        self.found_appendix = False
        sections: Dict[str, str] = {}
        current_heading = ''
        current_content: List[str] = []

        for block in blocks:
            if block.get("type") != 0:  # only type-0 (text) blocks are parsed
                continue
            text, font_size, is_bold, font_face = self.get_font_properties(block)
            text = self.clean_text(text)
            if not text:
                continue

            marker = self._back_matter_marker(text)
            if marker is not None:
                if self._is_heading_line(text):
                    # Store the section in progress *before* raising the flag,
                    # because should_include_section rejects everything once set.
                    self._store_section(sections, current_heading, current_content)
                    if marker == 'references':
                        self.found_references = True
                    else:
                        self.found_appendix = True
                    return sections
                # A body sentence that merely starts with "Appendix ..." /
                # "References ...": keep it as text instead of letting
                # is_heading treat it as the end of the paper.
                current_content.append(text)
                continue

            is_heading, heading_text = self.is_heading(text, font_size, is_bold, font_face)
            if is_heading:
                self._store_section(sections, current_heading, current_content)
                current_heading = heading_text
                current_content = []
            else:
                current_content.append(text)

        self._store_section(sections, current_heading, current_content)
        return sections

    def parse_pdf(self, pdf_path: str) -> Dict:
        """Parse *pdf_path* into sections; on any error log it and return ``{}``."""
        try:
            # The context manager closes the document even if parsing fails.
            with fitz.open(pdf_path) as doc:
                return self._build_sections(
                    block
                    for page in doc
                    for block in page.get_text("dict")["blocks"]
                )
        except Exception as e:
            self.logger.error(f"Error parsing PDF {pdf_path}: {str(e)}")
            return {}

    def save_to_json(self, parsed_content: Dict, output_path: str):
        """Write *parsed_content* as UTF-8 JSON, creating parent directories; errors are logged."""
        try:
            # dirname() is '' for a bare file name, and makedirs('') raises.
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(parsed_content, f, indent=2, ensure_ascii=False)
            self.logger.info(f"Successfully saved parsed content to {output_path}")
        except Exception as e:
            self.logger.error(f"Error saving to JSON: {str(e)}")

def process_directory(input_dir: str, output_dir: str, min_section_length: int = 50):
    """Parse every ``.pdf`` under ``input_dir`` into JSON files under ``output_dir``.

    Outputs mirror the input sub-directory layout and are named ``P0000.json``,
    ``P0001.json``, ... using one counter shared across all sub-directories.
    Directories and files are visited in sorted order, so the PDF -> ``P####``
    numbering is reproducible across runs (raw ``os.walk`` order depends on the
    filesystem). Each mapping is logged because the output name no longer
    carries the PDF's name.

    NOTE(review): numbering restarts at P0000 on every call and save_to_json
    opens with 'w', so existing P####.json files in ``output_dir`` are
    overwritten.
    """
    parser = DoraemonPDFParser(min_section_length=min_section_length)
    count = 0
    for root, dirs, files in os.walk(input_dir):
        dirs.sort()  # sorting in place controls os.walk's descent order
        for file in sorted(files):
            if not file.lower().endswith('.pdf'):
                continue
            pdf_path = os.path.join(root, file)
            relative_path = os.path.relpath(root, input_dir)
            output_path = os.path.join(output_dir, relative_path, f"P{count:04d}.json")
            count += 1
            parser.logger.info(f"Processing: {pdf_path} -> {output_path}")
            parsed_content = parser.parse_pdf(pdf_path)
            parser.save_to_json(parsed_content, output_path)

if __name__ == "__main__":
    # Parse the non-publishable PDFs into per-paper JSON section files.
    base_input_dir = "Dataset/pdfs/non-publishable"
    base_output_dir = "Dataset/texts/non-publishable"
    # os.path.join with a single argument returns it unchanged, so the base
    # directories are used as-is.
    input_dir = base_input_dir
    output_dir = base_output_dir
    process_directory(input_dir, output_dir)

In [None]:
class DoraemonProcessor:
    """Turn parsed papers into fixed-length feature vectors with a BERT-style model.

    ``process_document`` concatenates, in this order:

    - max-pooled, mean-pooled and attention-pooled token embeddings;
    - the position-0 embedding of the first chunk;
    - the position-0 embedding of the last four hidden layers, averaged over
      chunks;
    - 5 statistical features;
    - 7 readability scores;
    - 20 keyword scores.

    ``create_weight_vectors`` hard-codes these component sizes for a 768-dim
    model.
    """

    def __init__(self, model_name: str = "allenai/scibert_scivocab_uncased"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"[INFO] Model loaded on: {self.device}")
        self.hidden_size = 768
        # Serializes console output and directory creation across worker threads.
        self.lock = threading.Lock()

    def get_chunks(self, text: str, max_length: int = 512) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize *text* and split it into padded windows of ``max_length`` tokens.

        Returns ids and mask shaped (1, num_chunks, max_length) on ``self.device``.

        NOTE(review): the whole text is tokenized once and then cut into windows.
        With a BERT-style tokenizer only the first window therefore starts with
        [CLS], and position 0 of later windows is an ordinary token. Left as is
        because changing it would change every stored feature vector.
        """
        tokens = self.tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = tokens["input_ids"][0]
        attention_mask = tokens["attention_mask"][0]
        num_tokens = input_ids.shape[0]
        chunk_boundaries = range(0, num_tokens, max_length)
        chunks_ids = []
        chunks_mask = []

        for start in chunk_boundaries:
            end = min(start + max_length, num_tokens)
            chunk_ids = input_ids[start:end]
            chunk_mask = attention_mask[start:end]
            if chunk_ids.shape[0] < max_length:
                pad_length = max_length - chunk_ids.shape[0]
                chunk_ids = torch.nn.functional.pad(chunk_ids, (0, pad_length), value=self.tokenizer.pad_token_id)
                chunk_mask = torch.nn.functional.pad(chunk_mask, (0, pad_length), value=0)
            chunks_ids.append(chunk_ids)
            chunks_mask.append(chunk_mask)

        all_chunk_ids = torch.stack(chunks_ids).unsqueeze(0)
        all_chunk_masks = torch.stack(chunks_mask).unsqueeze(0)
        return all_chunk_ids.to(self.device), all_chunk_masks.to(self.device)

    @torch.no_grad()
    def process_text(self, chunks_ids: torch.Tensor, chunks_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the model chunk by chunk.

        Returns:
            A tuple of:
            - the token embeddings concatenated along the sequence axis;
            - the flattened mask;
            - the last-4-layer hidden states stacked per chunk.
        """
        all_embeddings = []
        all_hidden_states = []

        for i in range(chunks_ids.shape[1]):
            chunk_ids = chunks_ids[:, i, :]
            chunk_mask = chunks_mask[:, i, :]
            outputs = self.model(input_ids=chunk_ids, attention_mask=chunk_mask, output_hidden_states=True)
            all_embeddings.append(outputs.last_hidden_state)
            all_hidden_states.append(torch.stack(outputs.hidden_states[-4:]))

        final_embeddings = torch.cat(all_embeddings, dim=1)
        final_mask = chunks_mask.squeeze(0).reshape(-1)
        final_mask = final_mask[:final_embeddings.shape[1]]
        return final_embeddings, final_mask, torch.stack(all_hidden_states)

    def aggregate_embeddings(self, embeddings: torch.Tensor, attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Masked mean, max and attention pooling over the sequence axis (dim 1)."""
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)
        if embeddings.dim() == 2:
            embeddings = embeddings.unsqueeze(0)

        mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size())
        masked_embeddings = embeddings * mask_expanded

        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        valid_tokens = torch.sum(attention_mask, dim=1, keepdim=True)
        valid_tokens = torch.clamp(valid_tokens, min=1e-9)
        mean_pooled = sum_embeddings / valid_tokens

        masked_embeddings_for_max = masked_embeddings.clone()
        masked_embeddings_for_max[~mask_expanded.bool()] = float('-inf')
        max_pooled = torch.max(masked_embeddings_for_max, dim=1)[0]

        # Token score = mean of its embedding; padding is excluded via -inf.
        attention_weights = torch.mean(embeddings, dim=-1)
        attention_weights = attention_weights.masked_fill(~attention_mask.bool(), float('-inf'))
        attention_scores = softmax(attention_weights, dim=1).unsqueeze(-1)
        attention_pooled = torch.sum(attention_scores * embeddings, dim=1)

        return {"mean": mean_pooled, "max": max_pooled, "attention": attention_pooled}

    def _calculate_statistical_features(self, text: str) -> Dict[str, float]:
        """Five simple length/diversity statistics (words split on whitespace, sentences on '.')."""
        words = text.split()
        sentences = text.split('.')
        unique_words = len(set(words))
        total_words = max(1, len(words))
        total_sentences = max(1, len(sentences))

        return {
            "word_count": float(total_words),
            "sentence_count": float(total_sentences),
            "avg_word_length": sum(len(word) for word in words) / total_words,
            "avg_sentence_length": total_words / total_sentences,
            "lexical_diversity": unique_words / total_words
        }

    def _calculate_readability_scores(self, text: str) -> Dict[str, float]:
        """Seven textstat readability metrics."""
        return {
            "flesch_reading_ease": float(textstat.flesch_reading_ease(text)),
            "gunning_fog": float(textstat.gunning_fog(text)),
            "smog_index": float(textstat.smog_index(text)),
            "automated_readability_index": float(textstat.automated_readability_index(text)),
            "dale_chall_score": float(textstat.dale_chall_readability_score(text)),
            "difficult_words": float(textstat.difficult_words(text)),
            "linsear_write_formula": float(textstat.linsear_write_formula(text))
        }

    def create_weight_vectors(self, total_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-feature weights for the two weighting schemes.

        Entries are scaled component by component, following the
        concatenation order used in ``process_document``.
        """
        component_sizes = [768, 768, 768, 768, 3072, 5, 7, 20]
        weight1 = torch.ones(total_size)
        weight2 = torch.ones(total_size)
        component_weights_v1 = [1.2, 1.2, 1.2, 1.5, 1.0, 1.0, 1.0, 0.8]
        component_weights_v2 = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]
        start_idx = 0
        for i, size in enumerate(component_sizes):
            weight1[start_idx:start_idx + size] *= component_weights_v1[i]
            weight2[start_idx:start_idx + size] *= component_weights_v2[i]
            start_idx += size
        return weight1, weight2

    @staticmethod
    def _top_k_indices(weights, k: int):
        """Indices of the ``k`` largest entries of *weights*, largest first."""
        return np.argsort(weights)[::-1][:k]

    def _extract_topics(self, text: str, n_topics: int = 1, num_keywords: int = 20) -> Tuple[torch.Tensor, List[Tuple[str, float]]]:
        """Fit LDA on *text* and score its top keywords.

        Returns a tensor of ``num_keywords`` scores and the (keyword, weight)
        list. Each score is the keyword's mean model embedding times its topic
        weight; unused slots stay 0.
        """
        vectorizer = CountVectorizer(stop_words="english", min_df=1)
        lda = LatentDirichletAllocation(n_components=n_topics, random_state=42)
        text_vectorized = vectorizer.fit_transform([text])
        lda.fit(text_vectorized)
        feature_names = vectorizer.get_feature_names_out()

        keywords = []
        for topic in lda.components_:
            # Previously `argsort()[:-k:][::-1]`, which selected every word
            # EXCEPT the k highest-weighted ones.
            top_indices = self._top_k_indices(topic, num_keywords)
            topic_keywords = [(feature_names[i], float(topic[i])) for i in top_indices]
            keywords.extend(topic_keywords)

        keywords = sorted(keywords, key=lambda x: x[1], reverse=True)[:num_keywords]

        keywords_embeddings = torch.zeros(num_keywords, device=self.device)
        for i, (keyword, weight) in enumerate(keywords):
            with torch.no_grad():
                keyword_tokens = self.tokenizer(
                    keyword,
                    return_tensors="pt",
                    padding=True,
                    truncation=True
                )
                input_ids = keyword_tokens["input_ids"].clone().to(self.device)
                attention_mask = keyword_tokens["attention_mask"].clone().to(self.device)

                keyword_outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                keyword_embedding = keyword_outputs.last_hidden_state.clone().mean()
                keywords_embeddings[i] = keyword_embedding * weight

        keywords_embeddings = keywords_embeddings.detach().clone()
        return keywords_embeddings, keywords

    def process_document(self, text: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Tuple[str, float]]]:
        """Return (combined feature vector, weight1, weight2, keywords) for *text*."""
        chunks_ids, chunks_mask = self.get_chunks(text)
        embeddings, attention_mask, hidden_states = self.process_text(chunks_ids, chunks_mask)

        pooled_embeddings = self.aggregate_embeddings(embeddings, attention_mask)
        mean_pooled = pooled_embeddings["mean"]
        max_pooled = pooled_embeddings["max"]
        attention_pooled = pooled_embeddings["attention"]
        cls_embeddings = embeddings[:, 0, :]
        layer_wise_embeddings = hidden_states[..., 0, :].mean(dim=0).unsqueeze(0)

        statistical_features = torch.tensor(list(self._calculate_statistical_features(text).values())).to(self.device)
        readability_scores = torch.tensor(list(self._calculate_readability_scores(text).values())).to(self.device)
        keyword_embeddings, keywords = self._extract_topics(text)

        combined_features = torch.cat([
            max_pooled.flatten(),
            mean_pooled.flatten(),
            attention_pooled.flatten(),
            cls_embeddings.flatten(),
            layer_wise_embeddings.flatten(),
            statistical_features,
            readability_scores,
            keyword_embeddings.flatten()
        ])
        print(f"[INFO] Keyword Embeddings Length : {keyword_embeddings.size()}")
        weight1, weight2 = self.create_weight_vectors(combined_features.size(0))
        return combined_features, weight1, weight2, keywords

    def process_single_file(self, file_data: tuple) -> Tuple[bool, torch.Tensor, torch.Tensor]:
        """Featurize one JSON document and save its vector (.pt) and keywords (.txt).

        ``file_data`` is (json_file, input_path, vector_output_path,
        keywords_output_path, expected_vector_size_or_None). Returns
        (True, weight1, weight2) on success and (False, None, None) on any error.
        """
        json_file, input_path, vector_output_path, keywords_output_path, first_vector_size = file_data
        try:
            relative_path = os.path.relpath(json_file.parent, input_path)
            vector_dir = vector_output_path / relative_path
            keywords_dir = keywords_output_path / relative_path

            with self.lock:
                vector_dir.mkdir(parents=True, exist_ok=True)
                keywords_dir.mkdir(parents=True, exist_ok=True)
                print(f"\n[INFO] Processing file: {json_file}")

            with open(json_file, "r", encoding="utf-8") as f:
                sections = json.load(f)

            text = "".join(f"{heading}\n{content}\n\n" for heading, content in sections.items())

            with self.lock:
                print(f"[INFO] Extracting features from text (length: {len(text)} chars)")

            combined_features, weight1, weight2, keywords = self.process_document(text)
            with self.lock:
                print(f"[INFO] Combined Features Length: {combined_features.size()}")

            current_size = combined_features.size(0)
            # Explicit check (not assert, which -O strips); reported via the except below.
            if first_vector_size is not None and current_size != first_vector_size:
                raise ValueError(f"Vector size mismatch: {current_size} vs {first_vector_size}")

            vector_file = vector_dir / f"{json_file.stem}.pt"
            torch.save(combined_features, vector_file)
            keywords_file = keywords_dir / f"{json_file.stem}.txt"
            with open(keywords_file, 'w', encoding='utf-8') as f:
                f.writelines(f"{keyword}\n" for keyword, _ in keywords)

            with self.lock:
                print(f"[SUCCESS] Saved vector features to: {vector_file}")
                print(f"[SUCCESS] Saved keywords to: {keywords_file}")

            return True, weight1, weight2

        except Exception as e:
            with self.lock:
                print(f"[ERROR] Failed to process {json_file}: {str(e)}")
            return False, None, None

    def process_json_files(self, input_dir: str, output_dir: str, max_workers: int = 4):
        """Featurize every JSON under ``input_dir/{publishable,non-publishable}``.

        Vectors go to ``output_dir/vectors/<category>`` and keywords to
        ``output_dir/keywords/<category>``. The weight vectors from the first
        successful file are saved as ``weight1.pt`` / ``weight2.pt``.
        """
        print(f"[INFO] Starting processing from input directory: {input_dir}")
        print(f"[INFO] Output will be saved to: {output_dir}")
        print(f"[INFO] Using {max_workers} worker threads")

        processed_count = 0
        total_files = 0
        first_vector_size = None
        weights_saved = False
        all_files = []
        for category in ["publishable", "non-publishable"]:
            input_path = Path(input_dir) / category
            vector_output_path = Path(output_dir) / "vectors" / category
            keywords_output_path = Path(output_dir) / "keywords" / category

            print(f"\n[INFO] Collecting files from category: {category}")

            for root, dirs, files in os.walk(input_path):
                json_files = [Path(root) / f for f in files if f.lower().endswith('.json')]
                all_files.extend([(f, input_path, vector_output_path, keywords_output_path, first_vector_size) for f in json_files])
                total_files += len(json_files)

        print(f"\n[INFO] Found {total_files} files to process")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(self.process_single_file, file_data) for file_data in all_files]

            for future in concurrent.futures.as_completed(futures):
                success, weight1, weight2 = future.result()
                if not success:
                    continue
                processed_count += 1

                # The size is checked here, not in the workers: the sizes are
                # only known once results arrive. The weight vectors have one
                # entry per feature, so their length is the feature-vector length.
                vector_size = weight1.numel()
                if first_vector_size is None:
                    first_vector_size = vector_size
                elif vector_size != first_vector_size:
                    print(f"[WARNING] Vector size mismatch: {vector_size} vs {first_vector_size}")

                if not weights_saved:
                    weight_path1 = Path(output_dir) / "weight1.pt"
                    weight_path2 = Path(output_dir) / "weight2.pt"
                    torch.save(weight1, weight_path1)
                    torch.save(weight2, weight_path2)
                    print(f"[SUCCESS] Saved weight vectors")
                    weights_saved = True

                print(f"[PROGRESS] Processed {processed_count}/{total_files} files ({(processed_count/total_files)*100:.1f}%)")

        print(f"\n[COMPLETE] Processing finished. Total files processed: {processed_count}/{total_files}")
        if first_vector_size is not None:
            print(f"Vector size for all processed files: {first_vector_size}")

if __name__ == "__main__":
    # Featurize every parsed document (both categories) and write vectors,
    # keywords and weight vectors under "Dataset". A single worker thread keeps
    # model inference serial.
    input_dir, output_dir = "Dataset/texts", "Dataset"
    processor = DoraemonProcessor()
    processor.process_json_files(input_dir, output_dir, max_workers=1)

In [None]:
class DoraemonDataset(Dataset):
    """In-memory dataset of weighted, standardized feature vectors for the
    publishable / non-publishable binary task.

    Expected layout under ``root_dir``::

        non-publishable/*.pt          -> label 0
        publishable/<subfolder>/*.pt  -> label 1

    Every ``.pt`` file is loaded with ``torch.load``, multiplied by the tensor
    stored at ``weights_path`` and stacked, so all files must hold tensors of
    one shape (presumably 1-D feature vectors; ``features.shape[1]`` is
    reported as the feature dimension). Labels are a float tensor, as
    ``BCELoss`` expects.

    NOTE(review): mean/std for standardization are computed over the whole
    dataset before any train/validation split, so validation statistics leak
    into the normalization.
    """

    # Added to the per-feature std so zero-variance features stay finite
    # (the same epsilon DoraemonConferenceDataset and the inference code use).
    _STD_EPS = 1e-6

    def __init__(self, root_dir, weights_path):
        self.features = []
        self.labels = []
        print(f"\nLoading weights from {weights_path}")
        self.weights = torch.load(weights_path)

        # Load non-publishable data (label 0)
        non_pub_dir = os.path.join(root_dir, "non-publishable")
        print(f"\nLoading non-publishable data from {non_pub_dir}")
        # sorted(): sample order, and hence the seeded split downstream, must
        # not depend on filesystem listing order.
        non_pub_files = sorted(f for f in os.listdir(non_pub_dir) if f.endswith(".pt"))
        for file in tqdm(non_pub_files, desc="Loading non-publishable data"):
            tensor = torch.load(os.path.join(non_pub_dir, file))
            self.features.append(tensor * self.weights)
            self.labels.append(0)

        # Load publishable data (label 1) from every subfolder
        pub_dir = os.path.join(root_dir, "publishable")
        print(f"\nLoading publishable data from {pub_dir}")
        pub_count = 0
        for subfolder in sorted(os.listdir(pub_dir)):
            subfolder_path = os.path.join(pub_dir, subfolder)
            if not os.path.isdir(subfolder_path):
                continue
            files = sorted(f for f in os.listdir(subfolder_path) if f.endswith(".pt"))
            for file in tqdm(files, desc=f"Loading {subfolder}"):
                tensor = torch.load(os.path.join(subfolder_path, file))
                self.features.append(tensor * self.weights)
                self.labels.append(1)
                pub_count += 1

        if not self.features:
            raise ValueError(f"No .pt feature files found under {root_dir}")

        self.features = torch.stack(self.features)
        self.labels = torch.tensor(self.labels, dtype=torch.float)

        print("\nDataset Summary:")
        print(f"Total samples: {len(self.labels)}")
        print(f"Non-publishable samples: {len(non_pub_files)}")
        print(f"Publishable samples: {pub_count}")
        print(f"Feature dimension: {self.features.shape[1]}")
        print(f"Class distribution: {torch.bincount(self.labels.long()).tolist()}")

        # Normalize features
        print("\nNormalizing features...")
        self.features = self._normalize(self.features)

    @classmethod
    def _normalize(cls, features):
        """Standardize each column of ``features`` to zero mean and unit std.

        Zero-variance columns map to 0 instead of NaN thanks to ``_STD_EPS``.
        """
        return (features - features.mean(dim=0)) / (features.std(dim=0) + cls._STD_EPS)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features[idx], labels[idx])``."""
        return self.features[idx], self.labels[idx]

class DoraemonBinaryClassifier(nn.Module):
    """MLP mapping a feature vector to a publishability probability.

    Hidden widths input_dim -> 256 -> 128 -> 64 -> 32; each hidden block is
    Linear -> ReLU -> BatchNorm1d, with Dropout(0.5) after every block except
    the last. A final Linear(32, 1) + Sigmoid yields values in (0, 1) of shape
    (batch, 1). The layer order (and so state-dict keys like
    ``model.0.weight``) must stay fixed: checkpoints are restored with
    ``load_state_dict``.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 256, 128, 64, 32)
        last_block = len(widths) - 2
        layers = []
        for block, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.BatchNorm1d(fan_out)]
            if block != last_block:
                layers.append(nn.Dropout(0.5))
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

        # Report parameter counts
        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("\nModel Architecture:")
        print(f"Total parameters: {n_total:,}")
        print(f"Trainable parameters: {n_trainable:,}")

    def forward(self, x):
        """Return sigmoid probabilities of shape (batch, 1)."""
        return self.model(x)

def calculate_metrics(y_true, y_pred):
    """Compute binary-classification metrics from 0/1 label and prediction tensors.

    Returns a dict of Python floats with keys accuracy, precision, recall, f1
    and the confusion-matrix cells tp, tn, fp, fn. Precision, recall and F1
    fall back to 0.0 when their denominator is zero; accuracy is NaN for
    empty input.
    """
    zero = torch.tensor(0.0)

    def ratio(num, den):
        # Guarded division: 0.0 rather than NaN for an empty denominator.
        return num / den if den > 0 else zero

    # (true label, predicted label) that defines each confusion-matrix cell
    cells = {'tp': (1, 1), 'tn': (0, 0), 'fp': (0, 1), 'fn': (1, 0)}
    cm = {
        name: torch.sum((y_true == truth) & (y_pred == guess)).float()
        for name, (truth, guess) in cells.items()
    }

    precision = ratio(cm['tp'], cm['tp'] + cm['fp'])
    recall = ratio(cm['tp'], cm['tp'] + cm['fn'])
    scores = {
        'accuracy': (cm['tp'] + cm['tn']) / (cm['tp'] + cm['tn'] + cm['fp'] + cm['fn']),
        'precision': precision,
        'recall': recall,
        'f1': ratio(2 * (precision * recall), precision + recall),
        **cm,
    }
    return {key: value.item() for key, value in scores.items()}

def create_model(input_dim, lr=0.001, weight_decay=0.05):
    """Build the binary classifier and its training objects.

    Args:
        input_dim: length of the feature vectors fed to the model.
        lr: initial AdamW learning rate.
        weight_decay: AdamW (decoupled) weight decay.

    Returns:
        ``(model, loss_fn, optimizer, scheduler)``: a DoraemonBinaryClassifier,
        a ``BCELoss`` (the model already ends in a Sigmoid), an AdamW
        optimizer, and a ``ReduceLROnPlateau`` scheduler (mode 'min',
        factor 0.5, patience 5) meant to be stepped with validation loss.
    """
    print(f"\nCreating model with input dimension: {input_dim}")
    model = DoraemonBinaryClassifier(input_dim)
    loss_fn = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose` is deprecated for LR schedulers in recent PyTorch releases;
    # train_model already prints the current learning rate every epoch.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    return model, loss_fn, optimizer, scheduler

def _binary_epoch_pass(model, loader, loss_fn, device, desc, optimizer=None):
    """Run one pass of the binary model over ``loader``.

    With an ``optimizer`` the pass trains (zero_grad, backward, gradient-norm
    clipping at 1.0, step); without one it runs under ``torch.no_grad()``.
    Returns ``(summed batch loss, predictions, labels)``; predictions are 0/1
    floats thresholded at 0.5 and both tensors are concatenated on the CPU.
    """
    training = optimizer is not None
    grad_mode = torch.enable_grad() if training else torch.no_grad()
    total_loss = 0.0
    predictions, labels = [], []

    pbar = tqdm(loader, desc=desc)
    with grad_mode:
        for X_batch, y_batch in pbar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            if training:
                optimizer.zero_grad()
            # squeeze(-1), not squeeze(): a one-sample batch must keep shape
            # (1,) to match y_batch; squeeze() would collapse it to 0-d.
            outputs = model(X_batch).squeeze(-1)
            loss = loss_fn(outputs, y_batch)
            if training:
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            total_loss += loss.item()
            predictions.append((outputs >= 0.5).float().detach().cpu())
            labels.append(y_batch.detach().cpu())
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    empty = torch.empty(0)
    return (
        total_loss,
        torch.cat(predictions) if predictions else empty,
        torch.cat(labels) if labels else empty,
    )


def _print_binary_metrics(title, avg_loss, metrics):
    """Print one split's mean loss and calculate_metrics results."""
    print(f"{title}:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall: {metrics['recall']:.4f}")
    print(f"  F1 Score: {metrics['f1']:.4f}")
    print(f"  Confusion Matrix: [TP: {metrics['tp']}, TN: {metrics['tn']}, FP: {metrics['fp']}, FN: {metrics['fn']}]")


def train_model(model, loss_fn, optimizer, scheduler, train_loader, val_loader, epochs=20, device='cpu',
                checkpoint_path='doraemon_binary_classifier.pt'):
    """Train the binary classifier, validating and checkpointing each epoch.

    Each epoch trains over ``train_loader``, evaluates over ``val_loader``,
    prints loss and metrics for both, saves a checkpoint dict (epoch, model
    and optimizer state, loss, validation metrics) to ``checkpoint_path``
    whenever the mean validation loss improves, and steps ``scheduler`` with
    that validation loss.

    Args:
        model: module returning probabilities of shape (batch, 1).
        loss_fn: loss over (probabilities, float targets), e.g. ``BCELoss``.
        optimizer / scheduler: training optimizer and a scheduler whose
            ``step`` accepts the validation loss.
        train_loader / val_loader: loaders yielding ``(features, labels)``.
        epochs: number of epochs to run.
        device: device the model and batches are moved to.
        checkpoint_path: file the best checkpoint is written to.

    Raises:
        ValueError: if either loader yields no batches (mean losses would be
            undefined).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model.to(device)
    best_val_loss = float('inf')

    print(f"\nStarting training for {epochs} epochs")
    print(f"Training batches per epoch: {len(train_loader)}")
    print(f"Validation batches per epoch: {len(val_loader)}")

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_predictions, train_labels = _binary_epoch_pass(
            model, train_loader, loss_fn, device,
            f'Epoch {epoch+1}/{epochs} [Train]', optimizer=optimizer,
        )

        # Validation phase
        model.eval()
        val_loss, val_predictions, val_labels = _binary_epoch_pass(
            model, val_loader, loss_fn, device, f'Epoch {epoch+1}/{epochs} [Val]',
        )

        train_metrics = calculate_metrics(train_labels, train_predictions)
        val_metrics = calculate_metrics(val_labels, val_predictions)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        print(f"\nEpoch {epoch+1}/{epochs} Summary:")
        _print_binary_metrics("Training", avg_train_loss, train_metrics)
        _print_binary_metrics("Validation", avg_val_loss, val_metrics)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print("\nNew best model found! Saving checkpoint...")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'metrics': val_metrics,
            }, checkpoint_path)

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")

def prepare_data(data_dir, weights_path, train_split=0.8, batch_size=32, num_workers=4):
    """Load DoraemonDataset and build stratified train/validation loaders.

    Args:
        data_dir: root containing ``non-publishable/`` and ``publishable/``.
        weights_path: tensor file multiplied into every feature vector.
        train_split: fraction of samples used for training.
        batch_size: loader batch size.
        num_workers: DataLoader worker processes. NOTE(review): with the
            ``spawn`` start method (Windows/macOS), classes defined inside a
            notebook may not be importable by workers; use 0 there.

    Returns:
        ``(train_loader, val_loader, feature_dim)``; the split is stratified
        on the labels with ``random_state=42``.
    """
    print(f"\nPreparing data from {data_dir}")
    print(f"Train split: {train_split}")
    print(f"Batch size: {batch_size}")

    dataset = DoraemonDataset(data_dir, weights_path)
    labels = dataset.labels

    stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=(1 - train_split), random_state=42)
    train_indices, val_indices = next(stratified_split.split(torch.arange(len(labels)), labels))

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    print(f"Train set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")

    # BatchNorm1d raises on a single-sample batch in training mode, so drop the
    # last training batch only when it would contain exactly one sample.
    drop_last = len(train_dataset) % batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    return train_loader, val_loader, dataset.features.shape[1]

def main():
    """Train the publishable/non-publishable classifier on Dataset/vectors.

    Features are weighted with Dataset/weight1.pt; train_model writes the
    best checkpoint.
    """
    print("\nStarting binary classification training")

    data_dir, weights_path = "Dataset/vectors", "Dataset/weight1.pt"
    batch_size, epochs = 32, 10

    train_loader, val_loader, input_dim = prepare_data(
        data_dir, weights_path, batch_size=batch_size
    )
    model, loss_fn, optimizer, scheduler = create_model(input_dim)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nUsing device: {device}")

    train_model(
        model, loss_fn, optimizer, scheduler, train_loader, val_loader,
        epochs=epochs, device=device,
    )

    print("\nTraining completed. Best model saved as 'doraemon_binary_classifier.pt'")


if __name__ == "__main__":
    main()

In [None]:
class DoraemonConferenceDataset(Dataset):
    """In-memory dataset of weighted, standardized feature vectors labelled by
    conference.

    ``label_map`` maps integer class ids to subfolder names. Samples are read
    from ``root_dir/publishable/<label_map[label]>/*.pt``, multiplied by the
    tensor stored at ``weights_path``, stacked, and standardized per feature.
    Labels are a long tensor of class ids.

    NOTE(review): normalization statistics are computed over the whole
    dataset before the train/validation split.
    """

    # Keeps zero-variance features finite during standardization.
    _STD_EPS = 1e-6

    def __init__(self, root_dir, weights_path, label_map):
        self.features = []
        self.labels = []
        print(f"\nLoading weights from {weights_path}")
        self.weights = torch.load(weights_path)

        total_samples = 0
        class_counts = {label: 0 for label in label_map}

        print("\nLoading conference data...")
        for label, subfolder in label_map.items():
            subfolder_path = os.path.join(root_dir, "publishable", subfolder)
            if not os.path.isdir(subfolder_path):
                # Previously skipped silently, leaving this class with zero samples.
                print(f"[WARNING] Folder not found for {subfolder}: {subfolder_path}")
                continue
            # sorted(): reproducible sample order for the seeded split.
            files = sorted(f for f in os.listdir(subfolder_path) if f.endswith(".pt"))
            for file in tqdm(files, desc=f"Loading {subfolder}"):
                tensor = torch.load(os.path.join(subfolder_path, file))
                self.features.append(tensor * self.weights)
                self.labels.append(label)
                class_counts[label] += 1
                total_samples += 1

        if not self.features:
            raise ValueError(f"No .pt feature files found under {os.path.join(root_dir, 'publishable')}")

        self.features = torch.stack(self.features)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

        print("\nDataset Summary:")
        print(f"Total samples: {total_samples}")
        for label, count in class_counts.items():
            print(f"{label_map[label]}: {count} samples")
        print(f"Feature dimension: {self.features.shape[1]}")

        # Normalize features
        print("\nNormalizing features...")
        self.features = self._normalize(self.features)

    @classmethod
    def _normalize(cls, features):
        """Standardize each column of ``features`` to zero mean and unit std."""
        return (features - features.mean(dim=0)) / (features.std(dim=0) + cls._STD_EPS)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features[idx], labels[idx])``."""
        return self.features[idx], self.labels[idx]

class DoraemonConferenceClassifier(nn.Module):
    """MLP mapping a feature vector to per-conference logits.

    Hidden widths input_dim -> 256 -> 128 -> 64 -> 32; each hidden block is
    Linear -> ReLU -> BatchNorm1d, with Dropout(0.5) after every block except
    the last, followed by Linear(32, num_classes) with no activation, so
    ``forward`` returns raw logits of shape (batch, num_classes). The layer
    order fixes state-dict keys used by ``load_state_dict``.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        widths = (input_dim, 256, 128, 64, 32)
        last_block = len(widths) - 2
        layers = []
        for block, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.BatchNorm1d(fan_out)]
            if block != last_block:
                layers.append(nn.Dropout(0.5))
        layers.append(nn.Linear(widths[-1], num_classes))
        self.model = nn.Sequential(*layers)

        # Report parameter counts
        n_total = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("\nModel Architecture:")
        print(f"Total parameters: {n_total:,}")
        print(f"Trainable parameters: {n_trainable:,}")

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.model(x)

def calculate_multiclass_metrics(y_true, y_pred, num_classes):
    """Compute overall and per-class accuracy for integer class predictions.

    Args:
        y_true: 1-D tensor of true class ids.
        y_pred: 1-D tensor of predicted class ids, same length as ``y_true``.
        num_classes: number of classes to report; ids outside
            ``[0, num_classes)`` count toward overall accuracy only.

    Returns:
        dict with ``accuracy`` (float; 0.0 for empty input),
        ``class_accuracies`` (list of floats; 0.0 for classes with no
        samples) and ``class_counts`` (list of per-class sample counts as
        floats).
    """
    y_true = y_true.long()
    y_pred = y_pred.long()
    hits = y_pred == y_true

    total = y_true.numel()
    accuracy = hits.sum().item() / total if total else 0.0

    # Per-class counts via bincount; out-of-range ids are excluded.
    in_range = (y_true >= 0) & (y_true < num_classes)
    class_total = torch.bincount(y_true[in_range], minlength=num_classes).float()
    class_correct = torch.bincount(y_true[in_range & hits], minlength=num_classes).float()

    # clamp(min=1) avoids 0/0 for absent classes without skewing the others
    # (a fully-correct class reports exactly 1.0).
    class_accuracies = class_correct / class_total.clamp(min=1.0)

    return {
        'accuracy': accuracy,
        'class_accuracies': class_accuracies.tolist(),
        'class_counts': class_total.tolist(),
    }

def conference_model(input_dim, num_classes, lr=0.001, weight_decay=0.01):
    """Build the conference classifier and its training objects.

    Args:
        input_dim: length of the feature vectors fed to the model.
        num_classes: number of conference classes.
        lr: initial AdamW learning rate.
        weight_decay: AdamW (decoupled) weight decay.

    Returns:
        ``(model, loss_fn, optimizer, scheduler)``: a
        DoraemonConferenceClassifier, a ``CrossEntropyLoss`` over its logits,
        an AdamW optimizer, and a ``ReduceLROnPlateau`` scheduler (mode 'min',
        factor 0.5, patience 5) meant to be stepped with validation loss.
    """
    print(f"\nCreating model with input dimension: {input_dim}")
    model = DoraemonConferenceClassifier(input_dim, num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose` is deprecated for LR schedulers in recent PyTorch releases;
    # train_multiclass_model already prints the learning rate every epoch.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    return model, loss_fn, optimizer, scheduler

def _multiclass_epoch_pass(model, loader, loss_fn, device, desc, optimizer=None):
    """Run one pass of the conference model over ``loader``.

    With an ``optimizer`` the pass trains (zero_grad, backward, gradient-norm
    clipping at 1.0, step); without one it runs under ``torch.no_grad()``.
    Returns ``(summed batch loss, argmax predictions, labels)`` with both
    tensors concatenated on the CPU.
    """
    training = optimizer is not None
    grad_mode = torch.enable_grad() if training else torch.no_grad()
    total_loss = 0.0
    predictions, labels = [], []

    pbar = tqdm(loader, desc=desc)
    with grad_mode:
        for X_batch, y_batch in pbar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(X_batch)
            loss = loss_fn(outputs, y_batch)
            if training:
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            total_loss += loss.item()
            predictions.append(torch.argmax(outputs, dim=1).detach().cpu())
            labels.append(y_batch.detach().cpu())
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    empty = torch.empty(0, dtype=torch.long)
    return (
        total_loss,
        torch.cat(predictions) if predictions else empty,
        torch.cat(labels) if labels else empty,
    )


def _print_multiclass_metrics(avg_loss, metrics, label_map):
    """Print one split's mean loss, overall accuracy and per-class accuracies."""
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Overall Accuracy: {metrics['accuracy']:.4f}")
    print("  Per-class Accuracies:")
    for i, acc in enumerate(metrics['class_accuracies']):
        print(f"    {label_map[i]}: {acc:.4f} ({metrics['class_counts'][i]} samples)")


def train_multiclass_model(model, loss_fn, optimizer, scheduler, train_loader, val_loader, label_map, epochs=20,
                           device='cpu', checkpoint_path='doraemon_conference_classifier.pt'):
    """Train the conference classifier, validating and checkpointing each epoch.

    Each epoch trains over ``train_loader``, evaluates over ``val_loader``,
    prints loss plus overall/per-class accuracy for both, saves a checkpoint
    dict to ``checkpoint_path`` whenever the mean validation loss improves,
    and steps ``scheduler`` with that validation loss.

    Args:
        label_map: class id -> name; assumed to use ids ``0..len-1``
            (they index the per-class metric lists).
        checkpoint_path: file the best checkpoint is written to.
        Other arguments as in ``train_model``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model.to(device)
    best_val_loss = float('inf')
    num_classes = len(label_map)

    print(f"\nStarting training for {epochs} epochs")
    print(f"Training batches per epoch: {len(train_loader)}")
    print(f"Validation batches per epoch: {len(val_loader)}")

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_predictions, train_labels = _multiclass_epoch_pass(
            model, train_loader, loss_fn, device,
            f'Epoch {epoch+1}/{epochs} [Train]', optimizer=optimizer,
        )

        # Validation phase
        model.eval()
        val_loss, val_predictions, val_labels = _multiclass_epoch_pass(
            model, val_loader, loss_fn, device, f'Epoch {epoch+1}/{epochs} [Val]',
        )

        train_metrics = calculate_multiclass_metrics(train_labels, train_predictions, num_classes)
        val_metrics = calculate_multiclass_metrics(val_labels, val_predictions, num_classes)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        print(f"\nEpoch {epoch+1}/{epochs} Summary:")
        print("Training:")
        _print_multiclass_metrics(avg_train_loss, train_metrics, label_map)
        print("\nValidation:")
        _print_multiclass_metrics(avg_val_loss, val_metrics, label_map)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print("\nNew best model found! Saving checkpoint...")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'metrics': val_metrics,
            }, checkpoint_path)

        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")

def prepare_multiclass_data(data_dir, weights_path, label_map, train_split=0.8, batch_size=32, num_workers=4):
    """Load DoraemonConferenceDataset and build stratified train/val loaders.

    Args:
        data_dir: root containing ``publishable/<conference>/``.
        weights_path: tensor file multiplied into every feature vector.
        label_map: class id -> conference subfolder name.
        train_split: fraction of samples used for training.
        batch_size: loader batch size.
        num_workers: DataLoader worker processes. NOTE(review): with the
            ``spawn`` start method, notebook-defined classes may not be
            importable by workers; use 0 there.

    Returns:
        ``(train_loader, val_loader, feature_dim)``; the split is stratified
        on the labels with ``random_state=42``.
    """
    print(f"\nPreparing data from {data_dir}")
    print(f"Train split: {train_split}")
    print(f"Batch size: {batch_size}")

    dataset = DoraemonConferenceDataset(data_dir, weights_path, label_map)
    # Read labels from the stored tensor instead of iterating the dataset
    # item by item just to collect them.
    labels = dataset.labels.tolist()

    train_indices, val_indices = train_test_split(
        range(len(labels)),
        test_size=1 - train_split,
        stratify=labels,
        random_state=42,
    )

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    print(f"Train set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")

    # BatchNorm1d raises on a single-sample batch in training mode, so drop the
    # last training batch only when it would contain exactly one sample.
    drop_last = len(train_dataset) % batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    return train_loader, val_loader, dataset.features.shape[1]

def main():
    """Train the five-way conference classifier on Dataset/vectors.

    Features are weighted with Dataset/weight2.pt; train_multiclass_model
    writes the best checkpoint.
    """
    print("\nStarting conference classification training")

    data_dir, weights_path = "Dataset/vectors", "Dataset/weight2.pt"
    label_map = {0: "CVPR", 1: "TMLR", 2: "KDD", 3: "NEURIPS", 4: "EMNLP"}
    batch_size, epochs = 32, 10

    train_loader, val_loader, input_dim = prepare_multiclass_data(
        data_dir, weights_path, label_map, batch_size=batch_size
    )
    num_classes = len(label_map)
    model, loss_fn, optimizer, scheduler = conference_model(input_dim, num_classes)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nUsing device: {device}")

    train_multiclass_model(
        model, loss_fn, optimizer, scheduler, train_loader, val_loader,
        label_map, epochs=epochs, device=device,
    )

    print("\nTraining completed. Best model saved as 'doraemon_conference_classifier.pt'")


if __name__ == "__main__":
    main()

In [1]:
import torch
import torch.nn as nn
import pandas as pd
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
from Mistral7b_Instruct import Doraemon_justification
from Binary_classification import DoraemonBinaryClassifier
from Conference_classification import DoraemonConferenceClassifier

def load_model(model, checkpoint_path, device):
    """Restore ``model``'s weights from ``checkpoint_path`` and return ``model``.

    Accepts either a checkpoint dict holding a ``'model_state_dict'`` entry
    (the format the training loops save) or a bare state dict. Tensors are
    mapped onto ``device`` while loading.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state)
    return model


def _load_classifiers(input_dim, num_classes, device):
    """Build both classifiers, restore their saved weights and switch to eval mode.

    Raises:
        RuntimeError: wrapping any error raised while building or loading.
    """
    try:
        binary_classifier = DoraemonBinaryClassifier(input_dim=input_dim).to(device)
        conference_classifier = DoraemonConferenceClassifier(input_dim=input_dim, num_classes=num_classes).to(device)

        binary_classifier = load_model(binary_classifier, "doraemon_binary_classifier.pt", device)
        conference_classifier = load_model(conference_classifier, "doraemon_conference_classifier.pt", device)

        binary_classifier.eval()
        conference_classifier.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading models: {str(e)}") from e
    return binary_classifier, conference_classifier


def _extract_abstract_and_conclusion(parsed_content):
    """Pick abstract-like and conclusion-like text from a ``{heading: content}`` dict.

    Headings are matched first (the last matching heading wins). Only a
    section still empty after that is filled from the first content body
    mentioning the keyword; a section found by heading is never overwritten.
    Assumes contents are strings — TODO confirm against the text extractor.
    """
    abstract = ""
    conclusion = ""
    for heading, content in parsed_content.items():
        lowered = heading.lower()
        if 'abstract' in lowered or 'introduction' in lowered:
            abstract = content
        elif 'conclusion' in lowered or 'summary' in lowered:
            conclusion = content

    if abstract == "":
        abstract = next(
            (c for c in parsed_content.values() if 'abstract' in c.lower() or 'introduction' in c.lower()),
            "",
        )
    if conclusion == "":
        conclusion = next(
            (c for c in parsed_content.values() if 'conclusion' in c.lower() or 'summary' in c.lower()),
            "",
        )
    return abstract, conclusion


def process_saved_data(input_dir: Path, output_dir: Path):
    """Classify every saved paper under ``input_dir`` and write ``results.csv``.

    ``input_dir`` must contain ``texts/<id>.json`` (parsed sections),
    ``vectors/<id>.pt`` (feature vectors) and ``keywords/<id>.txt``. For
    each loaded vector the binary classifier decides publishability
    (probability >= 0.5, the threshold the training metrics use);
    publishable papers are assigned a conference by the conference
    classifier and a rationale via ``Doraemon_justification``. Per-paper
    failures are recorded as ``'error'`` rows rather than aborting. Results
    go to ``output_dir / "results.csv"`` (created if needed) with columns
    Paper ID, Publishable, Conference, Rationale.

    Raises:
        ValueError: if an input directory is missing, it holds no ``.pt``
            files, or none of them could be loaded.
        RuntimeError: if either classifier fails to build or load.

    NOTE(review): features are standardized with statistics of *this* input
    set rather than the training set, and the weight1/weight2 multipliers the
    training datasets apply are not applied here. If those weights are
    positive and elementwise, per-feature standardization cancels them, but
    the train/inference statistics mismatch remains — confirm intent.
    """
    print("[INFO] Initializing processing of saved data...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    text_dir = input_dir / "texts"
    vector_dir = input_dir / "vectors"
    keywords_dir = input_dir / "keywords"

    for dir_path in (text_dir, vector_dir, keywords_dir):
        if not dir_path.exists():
            raise ValueError(f"Directory not found: {dir_path}")

    # sorted(): deterministic row order in results.csv
    vector_files = sorted(vector_dir.glob("*.pt"))
    if not vector_files:
        raise ValueError(f"No vector files found in {vector_dir}")
    print(f"[INFO] Found {len(vector_files)} files to process")

    sample_vector = torch.load(vector_files[0], map_location=device)
    input_dim = sample_vector.shape[0]
    print(f"[INFO] Detected input dimension: {input_dim}")

    label_map = {0: "CVPR", 1: "TMLR", 2: "KDD", 3: "NEURIPS", 4: "EMNLP"}
    binary_classifier, conference_classifier = _load_classifiers(input_dim, len(label_map), device)

    print("[INFO] Loading and processing saved data...")
    features_list = []
    file_ids = []
    for vector_file in tqdm(vector_files, desc="Loading vectors"):
        try:
            features = torch.load(vector_file, map_location=device)
        except Exception as e:
            print(f"[WARNING] Error loading vector {vector_file}: {str(e)}")
            continue
        features_list.append(features)
        file_ids.append(vector_file.stem)

    if not features_list:
        raise ValueError(f"None of the vector files in {vector_dir} could be loaded")

    print("[INFO] Computing normalization statistics...")
    all_features = torch.stack(features_list)
    feature_mean = all_features.mean(dim=0)
    if len(features_list) > 1:
        feature_std = all_features.std(dim=0) + 1e-6
    else:
        # The unbiased std of a single sample is NaN; its centred features are
        # all 0, so divide by 1 instead of poisoning every prediction.
        feature_std = torch.ones_like(feature_mean)

    print("[INFO] Processing with normalized features...")
    results = []

    pairs = zip(file_ids, features_list)
    for file_id, features in tqdm(pairs, total=len(file_ids), desc="Processing files"):
        try:
            with open(text_dir / f"{file_id}.json", 'r') as f:
                parsed_content = json.load(f)
            with open(keywords_dir / f"{file_id}.txt", 'r') as f:
                keywords = f.read().splitlines()

            abstract, conclusion = _extract_abstract_and_conclusion(parsed_content)

            normalized_features = ((features - feature_mean) / feature_std).unsqueeze(0).to(device)

            conference = "na"
            justification = "na"
            with torch.no_grad():
                binary_pred = binary_classifier(normalized_features)
                # >= matches the 0.5 threshold used by the training metrics
                is_publishable = binary_pred.item() >= 0.5

                if is_publishable:
                    conf_pred = conference_classifier(normalized_features)
                    conference = label_map[torch.argmax(conf_pred).item()]

                    justification = Doraemon_justification(
                        abstract=abstract,
                        conclusion=conclusion,
                        keywords=keywords,
                        conference_name=conference
                    )

            results.append([file_id, int(is_publishable), conference, justification])

        except Exception as e:
            # Deliberate best-effort: record the failure and move on.
            print(f"[WARNING] Error processing results for {file_id}: {str(e)}")
            results.append([file_id, 0, 'error', f'Error: {str(e)}'])

    output_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(results, columns=['Paper ID', 'Publishable', 'Conference', 'Rationale'])
    df.to_csv(output_dir / "results.csv", index=False)
    print(f"[INFO] Results saved to {output_dir / 'results.csv'}")

if __name__ == "__main__":
    # Read Sample/{texts,vectors,keywords} and write Sample/results.csv
    input_dir = output_dir = Path("Sample")
    process_saved_data(input_dir, output_dir)


KeyboardInterrupt

