In [4]:
# # ArXiv Dataset Generation with Llama-2-7B for Research Paper Critiques
# # This script downloads real papers from arXiv and uses Llama-2 to create a critique dataset

# Install required packages - minimizing dependencies
!pip install -q arxiv==1.4.7 pymupdf==1.23.5 datasets==2.14.5 tqdm==4.66.1
!pip install -q torch==2.1.0 transformers==4.36.0 accelerate==0.25.0 bitsandbytes==0.41.0

# # Import necessary libraries
# import os
# import re
# import json
# import random
# import time
# import arxiv
# import fitz  # PyMuPDF
# import torch
# from tqdm.notebook import tqdm
# from datasets import Dataset
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# # Create directories
# !mkdir -p ./data
# !mkdir -p ./papers

# # Function to download papers from arXiv
# def download_arxiv_papers(search_queries, max_results=10, output_dir="./papers"):
#     """Download papers from arXiv using the API."""
#     os.makedirs(output_dir, exist_ok=True)
#     downloaded_papers = []
    
#     for query in search_queries:
#         print(f"Searching arXiv for: {query}")
        
#         # Create a search client
#         search = arxiv.Search(
#             query=query,
#             max_results=max_results,
#             sort_by=arxiv.SortCriterion.SubmittedDate,
#             sort_order=arxiv.SortOrder.Descending
#         )
        
#         # Download papers
#         for result in tqdm(search.results(), desc=f"Downloading papers for '{query}'"):
#             # Create a safe filename
#             paper_id = result.get_short_id()
#             filename = f"{paper_id}.pdf"
#             filepath = os.path.join(output_dir, filename)
            
#             # Skip if already downloaded
#             if os.path.exists(filepath):
#                 downloaded_papers.append(filepath)
#                 continue
            
#             try:
#                 # Download the paper
#                 result.download_pdf(dirpath=output_dir, filename=filename)
#                 downloaded_papers.append(filepath)
                
#                 # Be nice to arXiv API
#                 time.sleep(3)
#             except Exception as e:
#                 print(f"Error downloading {paper_id}: {str(e)}")
    
#     return downloaded_papers

# # Function to extract paragraphs from PDFs
# def extract_paragraphs_from_pdfs(pdf_paths, min_words=50, max_words=300):
#     """Extract paragraphs from PDF files."""
#     all_paragraphs = []
    
#     for pdf_path in tqdm(pdf_paths, desc="Processing PDFs"):
#         try:
#             # Open the PDF
#             doc = fitz.open(pdf_path)
            
#             # Detect the starting page (skip front matter)
#             start_page = 0
#             for i in range(min(5, len(doc))):
#                 text = doc[i].get_text().lower()
#                 if any(marker in text for marker in ["introduction", "background", "1. introduction"]):
#                     start_page = i
#                     break
            
#             # Extract text from each page
#             for page_num in range(start_page, len(doc)):
#                 page = doc[page_num]
#                 blocks = page.get_text("blocks")  # Get text as blocks which preserves layout better
                
#                 for block in blocks:
#                     # Block structure: (x0, y0, x1, y1, text, block_type, block_no)
#                     text = block[4].strip()
                    
#                     # Clean up the text
#                     text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces with single space
                    
#                     # Skip if empty or too short
#                     if not text or len(text.split()) < min_words:
#                         continue
                        
#                     # Skip likely headers, footers, figure captions, etc.
#                     if re.match(r'^[0-9]+\.?\s*$', text):  # Just a number
#                         continue
#                     if re.match(r'^figure\s+[0-9]+', text.lower()):  # Figure caption
#                         continue
#                     if re.match(r'^table\s+[0-9]+', text.lower()):  # Table caption
#                         continue
#                     if text.isupper() and len(text.split()) < 15:  # ALL CAPS header
#                         continue
                    
#                     # Handle paragraphs that are too long
#                     if len(text.split()) > max_words:
#                         words = text.split()
#                         chunks = []
#                         for i in range(0, len(words), max_words):
#                             chunk = ' '.join(words[i:i+max_words])
#                             chunks.append(chunk)
#                         all_paragraphs.extend(chunks)
#                     else:
#                         all_paragraphs.append(text)
        
#         except Exception as e:
#             print(f"Error processing {pdf_path}: {str(e)}")
    
#     print(f"Extracted {len(all_paragraphs)} paragraphs from {len(pdf_paths)} PDFs")
#     return all_paragraphs

# # Function to set up Llama-2 model
# def setup_llama_model(device="cuda", load_in_8bit=True):
#     """Set up the Llama-2 model for text generation."""
#     print("Loading Llama-2 7B Chat model...")
    
#     # Configure quantization if using CUDA
#     if device == "cuda" and torch.cuda.is_available() and load_in_8bit:
#         bnb_config = BitsAndBytesConfig(
#             load_in_8bit=True,
#             bnb_8bit_use_double_quant=True,
#             bnb_8bit_quant_type="nf4",
#             bnb_8bit_compute_dtype=torch.float16
#         )
        
#         # Load tokenizer and model - Using Meta AI's open source version
#         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#         model = AutoModelForCausalLM.from_pretrained(
#             "meta-llama/Llama-2-7b-chat-hf",
#             device_map="auto",
#             quantization_config=bnb_config,
#             trust_remote_code=True
#         )
#     else:
#         # CPU or non-quantized loading
#         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#         model = AutoModelForCausalLM.from_pretrained(
#             "meta-llama/Llama-2-7b-chat-hf",
#             device_map="auto" if device == "cuda" else None,
#             trust_remote_code=True
#         )
    
#     # Set padding token if not set
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
    
#     return tokenizer, model

# # Function to format Llama-2 prompt
# def format_llama_prompt(instruction, content=None):
#     """Format a prompt according to Llama-2's expected format."""
#     if content:
#         return f"<s>[INST] {instruction}\n\n{content} [/INST]"
#     else:
#         return f"<s>[INST] {instruction} [/INST]"

# # Function to extract Llama-2 response
# def extract_llama_response(generated_text):
#     """Extract the model's response from the generated text."""
#     response = generated_text.split("[/INST]")[-1].strip()
#     if "</s>" in response:
#         response = response.split("</s>")[0].strip()
    
#     return response

# # Function to generate a critique for a paragraph
# def generate_critique(tokenizer, model, paragraph, issue_type=None):
#     """Generate a critique for a paragraph using Llama-2."""
#     # Define issue types
#     issue_types = [
#         "missing_evidence",
#         "logical_contradiction",
#         "unclear_argument",
#         "poor_citation",
#         "grammar_spelling",
#         "undefined_terminology",
#         "statistical_error",
#         "methodology_issue",
#         "unsubstantiated_claim",
#         "structural_issue",
#         "well_written"  # Include examples with no issues
#     ]
    
#     # Select a random issue type if not specified
#     if issue_type is None:
#         issue_type = random.choice(issue_types)
    
#     # Create instruction
#     instruction = f"""You are an expert academic reviewer with years of experience reviewing research papers.
# Analyze the following paragraph from a real research paper and provide a detailed critique.
# Focus on identifying issues related to: {issue_type.replace('_', ' ')}.
# If you genuinely find no issues, explain why the paragraph is well-written instead.
# Your critique should be specific, actionable, and professional.

# Provide only the critique - do not include any introductory text or explanations about your role."""
    
#     # Format prompt
#     prompt = format_llama_prompt(instruction, paragraph)
    
#     try:
#         # Tokenize input
#         inputs = tokenizer(prompt, return_tensors="pt")
#         inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
#         # Generate text
#         outputs = model.generate(
#             **inputs,
#             max_new_tokens=512,
#             temperature=0.7,
#             top_p=0.9,
#             do_sample=True,
#             pad_token_id=tokenizer.pad_token_id
#         )
        
#         # Decode output
#         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        
#         # Extract response
#         critique = extract_llama_response(generated_text)
        
#         # Create result
#         result = {
#             "paragraph": paragraph,
#             "critique": critique,
#             "issue_type": issue_type
#         }
        
#         return result
    
#     except Exception as e:
#         print(f"Error generating critique: {str(e)}")
#         return None

# # Function to create dataset
# def create_dataset(paragraphs, tokenizer, model, num_examples, output_path="./data"):
#     """Create a dataset of paragraph-critique pairs."""
#     examples = []
    
#     # Ensure we don't try to generate more examples than paragraphs
#     num_examples = min(num_examples, len(paragraphs))
    
#     # Shuffle paragraphs
#     random.shuffle(paragraphs)
    
#     # Create progress bar
#     with tqdm(total=num_examples, desc="Generating critiques") as pbar:
#         for i in range(num_examples):
#             # Get paragraph
#             paragraph = paragraphs[i]
            
#             # Generate critique
#             example = generate_critique(tokenizer, model, paragraph)
            
#             if example:
#                 examples.append(example)
#                 pbar.update(1)
            
#             # Clear CUDA cache occasionally
#             if i % 10 == 9:
#                 torch.cuda.empty_cache()
    
#     # Create output directory
#     os.makedirs(output_path, exist_ok=True)
    
#     # Convert to HuggingFace Dataset
#     dataset = Dataset.from_list(examples)
    
#     # Split into train/validation sets (90/10)
#     splits = dataset.train_test_split(test_size=0.1, seed=42)
    
#     # Save the dataset
#     splits["train"].to_json(os.path.join(output_path, "train_data.json"))
#     splits["test"].to_json(os.path.join(output_path, "val_data.json"))
    
#     # Format for different models
#     format_for_models(splits, output_path)
    
#     print(f"Dataset saved to {output_path}")
#     print(f"Train set: {len(splits['train'])} examples")
#     print(f"Validation set: {len(splits['test'])} examples")
    
#     return splits

# # Function to format for different models
# def format_for_models(dataset, output_path):
#     """Format dataset for different model types."""
#     # Format for Mistral
#     formatted_mistral_train = []
#     for example in dataset["train"]:
#         formatted_mistral_train.append({
#             "messages": [
#                 {"role": "user", "content": f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{example['paragraph']}"},
#                 {"role": "assistant", "content": example['critique']}
#             ]
#         })
    
#     formatted_mistral_val = []
#     for example in dataset["test"]:
#         formatted_mistral_val.append({
#             "messages": [
#                 {"role": "user", "content": f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{example['paragraph']}"},
#                 {"role": "assistant", "content": example['critique']}
#             ]
#         })
    
#     # Save Mistral format
#     with open(os.path.join(output_path, "train_mistral_format.json"), "w") as f:
#         json.dump(formatted_mistral_train, f, indent=2)
    
#     with open(os.path.join(output_path, "val_mistral_format.json"), "w") as f:
#         json.dump(formatted_mistral_val, f, indent=2)
    
#     print(f"Model-specific formatted data saved to {output_path}")

# # MAIN EXECUTION

# # Check if we need to use a token for Llama-2 access
# import getpass
# print("Note: You'll need a Hugging Face token with access to Llama-2 to use this script.")
# print("If you don't have one, consider using another model like TinyLlama or Phi-2.")

# try_llama = input("Do you want to try using Llama-2? (y/n): ")

# if try_llama.lower() == 'y':
#     # 1. Define search queries for arXiv
#     search_queries = [
#         "cat:cs.AI",  # Artificial Intelligence
#         "cat:cs.CL",  # Computational Linguistics
#         "cat:cs.LG",  # Machine Learning
#         "cat:cs.CV",  # Computer Vision
#         "cat:cs.NE",  # Neural and Evolutionary Computing
#         "cat:cs.IR",  # Information Retrieval
#         "cat:stat.ML",  # Machine Learning (Statistics)
#         "cat:cs.RO"   # Robotics
#     ]

#     # Ask user for number of papers to download per category
#     num_papers = int(input("Enter number of papers to download per category (5-10 recommended): "))
#     print(f"Will download approximately {num_papers} papers per category...")

#     # 2. Download papers from arXiv
#     pdf_paths = download_arxiv_papers(search_queries, max_results=num_papers)
#     print(f"Downloaded {len(pdf_paths)} papers")

#     # 3. Extract paragraphs from PDFs
#     paragraphs = extract_paragraphs_from_pdfs(pdf_paths)

#     # 4. Set up Llama-2 model
#     try:
#         tokenizer, model = setup_llama_model()

#         # 5. Ask user for dataset size
#         num_examples = int(input(f"Found {len(paragraphs)} paragraphs. How many examples to generate? "))
#         num_examples = min(num_examples, len(paragraphs))
#         print(f"Will generate {num_examples} examples...")

#         # 6. Create dataset
#         dataset_splits = create_dataset(paragraphs, tokenizer, model, num_examples)

#         # 7. Print information about saving
#         print("\nDataset generation complete!")
#         print("You can now save this dataset to your Kaggle account:")
#         print("1. Click on the 'Data' tab in the right panel")
#         print("2. Under 'Output', click '+ Save All'")
#         print("3. Enter a name for your dataset (e.g., 'research-paper-critique-data')")
#         print("4. Click 'Save'")
#         print("\nYou can then use this dataset in a new notebook for fine-tuning Mistral")

#         # Display a sample from the dataset
#         print("\nHere's a sample from the generated dataset:")
#         with open("./data/train_data.json", "r") as f:
#             data = json.load(f)
#             sample = data[0]
#             print("\nPARAGRAPH:")
#             print(sample["paragraph"])
#             print("\nCRITIQUE:")
#             print(sample["critique"])
#             print("\nISSUE TYPE:", sample["issue_type"])

#     except Exception as e:
#         print(f"Error loading Llama-2: {str(e)}")
#         print("Falling back to another model...")
#         use_alternative = True
# else:
#     use_alternative = True

# # Fallback to an alternative model if Llama-2 fails or user chose not to use it
# if 'use_alternative' in locals() and use_alternative:
#     print("\n\nUsing TinyLlama or Phi-2 instead...")
    
#     # Define function to use TinyLlama
#     def setup_tinyllama_model(device="cuda", load_in_8bit=True):
#         """Set up the TinyLlama model for text generation."""
#         print("Loading TinyLlama model...")
        
#         # Configure quantization if using CUDA
#         if device == "cuda" and torch.cuda.is_available() and load_in_8bit:
#             bnb_config = BitsAndBytesConfig(
#                 load_in_8bit=True,
#                 bnb_8bit_use_double_quant=True,
#                 bnb_8bit_quant_type="nf4",
#                 bnb_8bit_compute_dtype=torch.float16
#             )
            
#             # Load tokenizer and model - TinyLlama is open access
#             tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
#             model = AutoModelForCausalLM.from_pretrained(
#                 "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#                 device_map="auto",
#                 quantization_config=bnb_config,
#                 trust_remote_code=True
#             )
#         else:
#             # CPU or non-quantized loading
#             tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
#             model = AutoModelForCausalLM.from_pretrained(
#                 "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#                 device_map="auto" if device == "cuda" else None,
#                 trust_remote_code=True
#             )
        
#         # Set padding token if not set
#         if tokenizer.pad_token is None:
#             tokenizer.pad_token = tokenizer.eos_token
        
#         return tokenizer, model
    
#     # Define function to format TinyLlama prompt (same as Llama-2)
#     def format_tinyllama_prompt(instruction, content=None):
#         """Format a prompt for TinyLlama."""
#         if content:
#             return f"<s>[INST] {instruction}\n\n{content} [/INST]"
#         else:
#             return f"<s>[INST] {instruction} [/INST]"
    
#     # Run the same workflow with TinyLlama
#     search_queries = [
#         "cat:cs.AI",  # Artificial Intelligence
#         "cat:cs.CL",  # Computational Linguistics
#         "cat:cs.LG",  # Machine Learning
#         "cat:cs.CV",  # Computer Vision
#     ]
    
#     num_papers = int(input("Enter number of papers to download per category (5-10 recommended): "))
#     pdf_paths = download_arxiv_papers(search_queries, max_results=num_papers)
#     paragraphs = extract_paragraphs_from_pdfs(pdf_paths)
    
#     tokenizer, model = setup_tinyllama_model()
    
#     num_examples = int(input(f"Found {len(paragraphs)} paragraphs. How many examples to generate? "))
#     num_examples = min(num_examples, len(paragraphs))
    
#     dataset_splits = create_dataset(paragraphs, tokenizer, model, num_examples)
    
#     print("\nDataset generation complete!")
#     print("You can now save this dataset to your Kaggle account.")

In [None]:
# Llama 3 8B Dataset Generation for Research Paper Critiques
# This script downloads real papers from arXiv and uses Meta's Llama 3 8B to create a critique dataset

# Install required packages
!pip install -q arxiv==1.4.7 pymupdf==1.23.5 datasets==2.14.5 torch==2.1.0 
!pip install -q transformers==4.36.0 accelerate==0.25.0 bitsandbytes==0.41.0 tqdm==4.66.1

# Import necessary libraries
import os
import re
import json
import random
import time
import arxiv
import fitz  # PyMuPDF
import torch
from tqdm.notebook import tqdm
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Create directories
!mkdir -p ./data
!mkdir -p ./papers

# Set your Hugging Face token directly
os.environ["HF_TOKEN"] = "hf_token"  # Replace with your actual token

# Function to download papers from arXiv
def download_arxiv_papers(search_queries, max_results=10, output_dir="./papers"):
    """Download papers from arXiv using the API."""
    os.makedirs(output_dir, exist_ok=True)
    downloaded_papers = []
    
    for query in search_queries:
        print(f"Searching arXiv for: {query}")
        
        # Create a search client
        search = arxiv.Search(
            query=query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate,
            sort_order=arxiv.SortOrder.Descending
        )
        
        # Download papers
        for result in tqdm(search.results(), desc=f"Downloading papers for '{query}'"):
            # Create a safe filename
            paper_id = result.get_short_id()
            filename = f"{paper_id}.pdf"
            filepath = os.path.join(output_dir, filename)
            
            # Skip if already downloaded
            if os.path.exists(filepath):
                downloaded_papers.append(filepath)
                continue
            
            try:
                # Download the paper
                result.download_pdf(dirpath=output_dir, filename=filename)
                downloaded_papers.append(filepath)
                
                # Be nice to arXiv API
                time.sleep(3)
            except Exception as e:
                print(f"Error downloading {paper_id}: {str(e)}")
    
    return downloaded_papers

# Function to extract paragraphs from PDFs
def extract_paragraphs_from_pdfs(pdf_paths, min_words=50, max_words=300):
    """Extract paragraphs from PDF files."""
    all_paragraphs = []
    
    for pdf_path in tqdm(pdf_paths, desc="Processing PDFs"):
        try:
            # Open the PDF
            doc = fitz.open(pdf_path)
            
            # Detect the starting page (skip front matter)
            start_page = 0
            for i in range(min(5, len(doc))):
                text = doc[i].get_text().lower()
                if any(marker in text for marker in ["introduction", "background", "1. introduction"]):
                    start_page = i
                    break
            
            # Extract text from each page
            for page_num in range(start_page, len(doc)):
                page = doc[page_num]
                blocks = page.get_text("blocks")  # Get text as blocks which preserves layout better
                
                for block in blocks:
                    # Block structure: (x0, y0, x1, y1, text, block_type, block_no)
                    text = block[4].strip()
                    
                    # Clean up the text
                    text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces with single space
                    
                    # Skip if empty or too short
                    if not text or len(text.split()) < min_words:
                        continue
                        
                    # Skip likely headers, footers, figure captions, etc.
                    if re.match(r'^[0-9]+\.?\s*$', text):  # Just a number
                        continue
                    if re.match(r'^figure\s+[0-9]+', text.lower()):  # Figure caption
                        continue
                    if re.match(r'^table\s+[0-9]+', text.lower()):  # Table caption
                        continue
                    if re.match(r'^references', text.lower()):  # References section
                        continue
                    if text.isupper() and len(text.split()) < 15:  # ALL CAPS header
                        continue
                    
                    # Handle paragraphs that are too long
                    if len(text.split()) > max_words:
                        words = text.split()
                        chunks = []
                        for i in range(0, len(words), max_words):
                            chunk = ' '.join(words[i:i+max_words])
                            chunks.append(chunk)
                        all_paragraphs.extend(chunks)
                    else:
                        all_paragraphs.append(text)
        
        except Exception as e:
            print(f"Error processing {pdf_path}: {str(e)}")
    
    print(f"Extracted {len(all_paragraphs)} paragraphs from {len(pdf_paths)} PDFs")
    return all_paragraphs

# Function to set up Llama 3 model
def setup_llama3_model(load_in_4bit=True):
    """Set up the Llama 3 8B model for text generation."""
    print("Loading Llama 3 8B Instruct model...")
    
    # Configure quantization for memory efficiency
    if load_in_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
    else:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_quant_type="nf4",
            bnb_8bit_compute_dtype=torch.float16
        )
    
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", 
                                             use_auth_token=os.environ.get("HF_TOKEN"))
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
        use_auth_token=os.environ.get("HF_TOKEN")
    )
    
    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    return tokenizer, model

# Function to format Llama 3 prompt
def format_llama3_prompt(instruction, content=None):
    """Format a prompt according to Llama 3's expected format."""
    if content:
        return f"<|begin_of_text|><|user|>\n{instruction}\n\n{content}<|end_of_turn|>\n<|assistant|>"
    else:
        return f"<|begin_of_text|><|user|>\n{instruction}<|end_of_turn|>\n<|assistant|>"

# Function to extract Llama 3 response
def extract_llama3_response(generated_text):
    """Extract the model's response from the generated text."""
    response = generated_text.split("<|assistant|>")[-1].strip()
    if "<|end_of_turn|>" in response:
        response = response.split("<|end_of_turn|>")[0].strip()
    
    return response

# Fallback functions for TinyLlama (in case Llama 3 fails)
def format_tinyllama_prompt(instruction, content=None):
    if content:
        return f"<|user|>\n{instruction}\n\n{content}\n<|assistant|>"
    else:
        return f"<|user|>\n{instruction}\n<|assistant|>"

def extract_tinyllama_response(generated_text):
    response = generated_text.split("<|assistant|>")[-1].strip()
    if "<|user|>" in response:
        response = response.split("<|user|>")[0].strip()
    return response

# Fallback functions for Phi-2 (in case Llama 3 fails)
def format_phi2_prompt(instruction, content=None):
    if content:
        return f"Instruct: {instruction}\n\nInput: {content}\n\nOutput:"
    else:
        return f"Instruct: {instruction}\n\nOutput:"

def extract_phi2_response(generated_text):
    if "Output:" in generated_text:
        response = generated_text.split("Output:")[-1].strip()
        return response
    else:
        # As a fallback, return everything after the prompt
        return generated_text.split("Input:")[-1].strip()

# Function to generate a critique for a paragraph
def generate_critique(tokenizer, model, paragraph, issue_type=None, format_prompt=format_llama3_prompt, extract_response=extract_llama3_response):
    """
    Generate a critique for a paragraph using the loaded model.
    
    Args:
        tokenizer: Model tokenizer
        model: Loaded model
        paragraph: Research paper paragraph to critique
        issue_type: Specific issue type to focus on (optional)
        format_prompt: Function to format the prompt for the specific model
        extract_response: Function to extract the response from the specific model
    
    Returns:
        Dictionary with paragraph, critique, and issue_type
    """
    # Define issue types
    issue_types = [
        "missing_evidence",
        "logical_contradiction",
        "unclear_argument",
        "poor_citation",
        "grammar_spelling",
        "undefined_terminology",
        "statistical_error",
        "methodology_issue",
        "unsubstantiated_claim",
        "structural_issue",
        "well_written"  # Include examples with no issues
    ]
    
    # Select a random issue type if not specified
    if issue_type is None:
        issue_type = random.choice(issue_types)
    
    # Create instruction
    instruction = f"""You are an expert academic reviewer with years of experience reviewing research papers.
Analyze the following paragraph from a real research paper and provide a detailed critique.
Focus on identifying issues related to: {issue_type.replace('_', ' ')}.
If you genuinely find no issues, explain why the paragraph is well-written instead.
Your critique should be specific, actionable, and professional.

Provide only the critique - do not include any introductory text or explanations about your role."""
    
    # Format prompt
    prompt = format_prompt(instruction, paragraph)
    
    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        # Generate text
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
        
        # Decode output
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        
        # Extract response
        critique = extract_response(generated_text)
        
        # Create result
        result = {
            "paragraph": paragraph,
            "critique": critique,
            "issue_type": issue_type
        }
        
        return result
    
    except Exception as e:
        print(f"Error generating critique: {str(e)}")
        return None

# Function to create dataset
def create_dataset(paragraphs, tokenizer, model, num_examples, output_path="./data", 
                  format_prompt=format_llama3_prompt, extract_response=extract_llama3_response):
    """
    Create a dataset of paragraph-critique pairs.
    
    Args:
        paragraphs: List of research paper paragraphs
        tokenizer: Model tokenizer
        model: Loaded model
        num_examples: Number of examples to generate
        output_path: Path to save the dataset
        format_prompt: Function to format prompts for the specific model
        extract_response: Function to extract responses from the specific model
    
    Returns:
        Dataset splits (train/val)
    """
    examples = []
    
    # Ensure we don't try to generate more examples than paragraphs
    num_examples = min(num_examples, len(paragraphs))
    
    # Shuffle paragraphs
    random.shuffle(paragraphs)
    
    # Create progress bar
    with tqdm(total=num_examples, desc="Generating critiques") as pbar:
        for i in range(num_examples):
            # Get paragraph
            paragraph = paragraphs[i]
            
            # Generate critique
            example = generate_critique(tokenizer, model, paragraph, 
                                       format_prompt=format_prompt, 
                                       extract_response=extract_response)
            
            if example:
                examples.append(example)
                pbar.update(1)
            
            # Clear CUDA cache occasionally
            if i % 5 == 4:  # Clear cache frequently for memory efficiency
                torch.cuda.empty_cache()
    
    # Create output directory
    os.makedirs(output_path, exist_ok=True)
    
    # Convert to HuggingFace Dataset
    dataset = Dataset.from_list(examples)
    
    # Split into train/validation sets (90/10)
    splits = dataset.train_test_split(test_size=0.1, seed=42)
    
    # Save the dataset
    splits["train"].to_json(os.path.join(output_path, "train_data.json"))
    splits["test"].to_json(os.path.join(output_path, "val_data.json"))
    
    # Format for Mistral
    format_for_mistral(splits, output_path)
    
    print(f"Dataset saved to {output_path}")
    print(f"Train set: {len(splits['train'])} examples")
    print(f"Validation set: {len(splits['test'])} examples")
    
    return splits

# Function to format for Mistral
def format_for_mistral(dataset, output_path):
    """Format dataset specifically for Mistral fine-tuning."""
    formatted_train = []
    
    for example in dataset["train"]:
        formatted_train.append({
            "messages": [
                {"role": "user", "content": f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{example['paragraph']}"},
                {"role": "assistant", "content": example['critique']}
            ]
        })
    
    formatted_val = []
    
    for example in dataset["test"]:
        formatted_val.append({
            "messages": [
                {"role": "user", "content": f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{example['paragraph']}"},
                {"role": "assistant", "content": example['critique']}
            ]
        })
    
    # Save formatted data
    with open(os.path.join(output_path, "train_mistral_format.json"), "w") as f:
        json.dump(formatted_train, f, indent=2)
    
    with open(os.path.join(output_path, "val_mistral_format.json"), "w") as f:
        json.dump(formatted_val, f, indent=2)
    
    print(f"Mistral-formatted data saved to {output_path}")

# Function to analyze dataset diversity
def analyze_dataset(dataset_path):
    """Analyze the diversity and quality of the generated dataset."""
    # Load dataset
    with open(dataset_path, 'r') as f:
        data = json.load(f)
    
    # Count issue types
    issue_counts = {}
    for example in data:
        issue_type = example['issue_type']
        issue_counts[issue_type] = issue_counts.get(issue_type, 0) + 1
    
    # Calculate average lengths
    paragraph_lengths = [len(example['paragraph'].split()) for example in data]
    critique_lengths = [len(example['critique'].split()) for example in data]
    
    avg_paragraph_length = sum(paragraph_lengths) / len(paragraph_lengths)
    avg_critique_length = sum(critique_lengths) / len(critique_lengths)
    
    # Print statistics
    print("\nDataset Analysis:")
    print(f"Total examples: {len(data)}")
    print(f"Average paragraph length: {avg_paragraph_length:.1f} words")
    print(f"Average critique length: {avg_critique_length:.1f} words")
    print("\nIssue type distribution:")
    for issue_type, count in sorted(issue_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"- {issue_type}: {count} ({count/len(data)*100:.1f}%)")

# MAIN EXECUTION

# 1. Define search queries for arXiv - diverse academic fields
search_queries = [
    "cat:cs.AI",  # Artificial Intelligence
    "cat:cs.CL",  # Computational Linguistics
    "cat:cs.CV",  # Computer Vision
    "cat:cs.LG",  # Machine Learning
    "cat:stat.ML",  # Machine Learning (Statistics)
    "cat:cs.SE",  # Software Engineering
    "cat:physics.comp-ph",  # Computational Physics
    "cat:q-bio.QM",  # Quantitative Methods in Biology
    "cat:q-fin.ST",  # Statistical Finance
    "cat:cs.HC"  # Human-Computer Interaction
]

# Ask user for number of papers to download per category
num_papers = int(input("Enter number of papers to download per category (5-10 recommended): "))
print(f"Will download approximately {num_papers} papers per category...")

# 2. Download papers from arXiv
pdf_paths = download_arxiv_papers(search_queries, max_results=num_papers)
print(f"Downloaded {len(pdf_paths)} papers")

# 3. Extract paragraphs from PDFs
paragraphs = extract_paragraphs_from_pdfs(pdf_paths)

# 4. Set up model with 4-bit quantization for memory efficiency
format_prompt = format_llama3_prompt
extract_response = extract_llama3_response

try:
    print("Attempting to load Llama 3 8B with 4-bit quantization...")
    tokenizer, model = setup_llama3_model(load_in_4bit=True)
except Exception as e:
    print(f"Error loading Llama 3: {str(e)}")
    print("This might be due to Hugging Face token access issues or memory constraints.")
    
    fallback = input("Do you want to try a fallback model instead? (y/n): ")
    if fallback.lower() == 'y':
        # Offer alternative models
        print("\nAlternative models:")
        print("1. TinyLlama (1.1B parameters)")
        print("2. Phi-2 (2.7B parameters)")
        alt_choice = input("Choose a fallback model (1-2): ")
        
        if alt_choice == "1":
            # TinyLlama fallback
            from transformers import AutoModelForCausalLM, AutoTokenizer
            print("Loading TinyLlama model...")
            tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            model = AutoModelForCausalLM.from_pretrained(
                "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                device_map="auto",
                torch_dtype=torch.float16
            )
            # Use TinyLlama prompt format
            format_prompt = format_tinyllama_prompt
            extract_response = extract_tinyllama_response
        elif alt_choice == "2":
            # Phi-2 fallback
            from transformers import AutoModelForCausalLM, AutoTokenizer
            print("Loading Phi-2 model...")
            tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/phi-2",
                device_map="auto",
                trust_remote_code=True
            )
            # Use Phi-2 prompt format
            format_prompt = format_phi2_prompt
            extract_response = extract_phi2_response
        else:
            print("Invalid choice. Exiting.")
            exit(1)
    else:
        print("Exiting program.")
        exit(1)

# 5. Ask user for dataset size
max_size = len(paragraphs)
print(f"Found {max_size} paragraphs. How many examples would you like to generate?")
print(f"Recommended: At least 500 for good fine-tuning results, but more is better.")
print(f"Note: Starting with a smaller number (50-100) is good for testing.")
num_examples = int(input(f"Enter number of examples to generate (max {max_size}): "))
num_examples = min(num_examples, max_size)

# 6. Generate dataset
print(f"Generating {num_examples} examples. This will take some time...")
dataset_splits = create_dataset(paragraphs, tokenizer, model, num_examples, 
                               format_prompt=format_prompt, 
                               extract_response=extract_response)

# 7. Analyze the dataset
analyze_dataset("./data/train_data.json")

# 8. Print instructions for saving and using the dataset
print("\nDataset generation complete!")
print("To save this dataset to your Kaggle account:")
print("1. Click on the 'Data' tab in the right panel")
print("2. Under 'Output', click '+ Save All'")
print("3. Enter a name for your dataset (e.g., 'llama3-paper-critique-data')")
print("4. Click 'Save'")
print("\nYou can then use this dataset in a new notebook for fine-tuning Mistral")

# Display a sample from the dataset
print("\nHere's a sample from the generated dataset:")
with open("./data/train_data.json", "r") as f:
    data = json.load(f)
    sample = data[0]
    print("\nPARAGRAPH:")
    print(sample["paragraph"])
    print("\nCRITIQUE:")
    print(sample["critique"])
    print("\nISSUE TYPE:", sample["issue_type"])

Enter number of papers to download per category (5-10 recommended):  8


Will download approximately 8 papers per category...
Searching arXiv for: cat:cs.AI


Downloading papers for 'cat:cs.AI': 0it [00:00, ?it/s]

Searching arXiv for: cat:cs.CL


Downloading papers for 'cat:cs.CL': 0it [00:00, ?it/s]

Searching arXiv for: cat:cs.CV


Downloading papers for 'cat:cs.CV': 0it [00:00, ?it/s]

Searching arXiv for: cat:cs.LG


Downloading papers for 'cat:cs.LG': 0it [00:00, ?it/s]

Searching arXiv for: cat:stat.ML


Downloading papers for 'cat:stat.ML': 0it [00:00, ?it/s]

Searching arXiv for: cat:cs.SE


Downloading papers for 'cat:cs.SE': 0it [00:00, ?it/s]

Searching arXiv for: cat:physics.comp-ph


Downloading papers for 'cat:physics.comp-ph': 0it [00:00, ?it/s]

Searching arXiv for: cat:q-bio.QM


Downloading papers for 'cat:q-bio.QM': 0it [00:00, ?it/s]

Searching arXiv for: cat:q-fin.ST


Downloading papers for 'cat:q-fin.ST': 0it [00:00, ?it/s]

Searching arXiv for: cat:cs.HC


Downloading papers for 'cat:cs.HC': 0it [00:00, ?it/s]

Downloaded 80 papers


Processing PDFs:   0%|          | 0/80 [00:00<?, ?it/s]

Extracted 3741 paragraphs from 80 PDFs
Attempting to load Llama 3 8B with 4-bit quantization...
Loading Llama 3 8B Instruct model...




tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Found 3741 paragraphs. How many examples would you like to generate?
Recommended: At least 500 for good fine-tuning results, but more is better.
Note: Starting with a smaller number (50-100) is good for testing.


Enter number of examples to generate (max 3741):  500


Generating 500 examples. This will take some time...


Generating critiques:   0%|          | 0/500 [00:00<?, ?it/s]

2025-04-15 21:13:48.853159: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744751629.104230      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744751629.177884      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Mistral-formatted data saved to ./data
Dataset saved to ./data
Train set: 450 examples
Validation set: 50 examples


JSONDecodeError: Extra data: line 2 column 1 (char 3312)

In [9]:
!pip install flash_attn

Collecting flash_attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m101.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash_attn
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for flash_attn (setup.py) ... [?25lerror
[31m  ERROR: Failed building wheel for flash_attn[0m[31m
[0m[?25h  Running setup.py clean for flash_attn
Failed to build flash_attn
[31mERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (flash_attn)[0m[31m
[0m

In [None]:
import shutil

# Replace 'output' with your actual output directory name
shutil.make_archive('/kaggle/working/data', 'zip', '/kaggle/working/data')


In [None]:
from IPython.display import FileLink

# Display clickable link to download
FileLink('/kaggle/working/output_dir.zip')


In [13]:
# Mistral 7B Fine-Tuning Script for Research Paper Critique
# This script fine-tunes Mistral 7B on your research paper critique dataset

# Install required packages
!pip install -q transformers==4.36.0 peft==0.8.0 trl==0.7.4 accelerate==0.25.0
!pip install -q bitsandbytes==0.41.0 datasets==2.14.5 scipy==1.11.4 tqdm==4.66.1

# Import necessary libraries
import os
import json
import datetime
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
    TrainingArguments,
    logging
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training,
    TaskType
)
from trl import SFTTrainer
from tqdm.notebook import tqdm

# Set up logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create output directory
os.makedirs("./models", exist_ok=True)

# Configuration
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"  # Or v0.1 if preferred
TRAIN_FILE = "/kaggle/working/data/train_mistral_format.json"
VAL_FILE = "/kaggle/working/data/val_mistral_format.json"

# Hyperparameters
LEARNING_RATE = 2e-5
BATCH_SIZE = 2  # Adjust based on GPU memory
GRADIENT_ACCUMULATION = 4
EPOCHS = 3
MAX_SEQ_LENGTH = 2048
LOAD_IN_4BIT = True  # Use 4-bit quantization for memory efficiency

# Function to load datasets
def load_datasets(train_file, val_file):
    """Load train and validation datasets formatted for Mistral."""
    logger.info(f"Loading datasets from {train_file} and {val_file}")
    
    # Load datasets
    try:
        train_dataset = load_dataset('json', data_files=train_file)['train']
        val_dataset = load_dataset('json', data_files=val_file)['train']
        
        logger.info(f"Loaded {len(train_dataset)} training examples and {len(val_dataset)} validation examples")
        
        # Clean up any special tokens in the data
        def clean_text(example):
            for i, message in enumerate(example['messages']):
                if 'content' in message:
                    # Remove model-specific tokens like <|eot_id|>
                    message['content'] = message['content'].replace("<|eot_id|>", "").strip()
                    example['messages'][i] = message
            return example
        
        train_dataset = train_dataset.map(clean_text)
        val_dataset = val_dataset.map(clean_text)
        
        return train_dataset, val_dataset
    except Exception as e:
        logger.error(f"Error loading datasets: {e}")
        raise

# Set up Mistral model
def setup_mistral_model():
    """Set up the Mistral model for fine-tuning."""
    logger.info(f"Loading {MODEL_NAME}")
    
    # Configure quantization for memory efficiency
    if LOAD_IN_4BIT:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
    else:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
            llm_int8_skip_modules=None,
            llm_int8_enable_fp32_cpu_offload=True
        )
    
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    
    model.config.use_cache = False
    
    return model, tokenizer

# Set up LoRA configuration
def get_lora_config():
    """Get LoRA configuration for Mistral model."""
    # Define target modules for Mistral
    target_modules = [
        "q_proj", 
        "k_proj", 
        "v_proj", 
        "o_proj",
        "gate_proj", 
        "up_proj", 
        "down_proj"
    ]
    
    # Create LoRA config
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=target_modules,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    
    return lora_config

# Set up training arguments
def get_training_args():
    """Create training arguments."""
    # Create a timestamp for the output directory
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = os.path.join("./models", f"mistral-7b-critique_{timestamp}")
    
    # Create training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        optim="paged_adamw_32bit",
        save_steps=100,
        logging_steps=10,
        learning_rate=LEARNING_RATE,
        weight_decay=0.001,
        fp16=True,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        warmup_ratio=0.03,
        group_by_length=True,
        lr_scheduler_type="cosine",
        report_to="tensorboard",
        evaluation_strategy="steps",
        eval_steps=100,
        load_best_model_at_end=True,
    )
    
    return training_args

# Main execution
def main():
    # Load datasets
    train_dataset, val_dataset = load_datasets(TRAIN_FILE, VAL_FILE)

    # Setup model
    model, tokenizer = setup_mistral_model()
    
    # Prepare model for training
    model = prepare_model_for_kbit_training(model)
    
    # Get LoRA configuration
    lora_config = get_lora_config()
    
    # Apply LoRA to the model
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    # Get training arguments
    training_args = get_training_args()
    
    # Set up the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=lora_config,
        tokenizer=tokenizer,
        max_seq_length=MAX_SEQ_LENGTH,
        packing=False,  # Disable packing for more stable training
        dataset_text_field="messages"  # Use messages field for ChatML format
    )
    
    # Train the model
    logger.info("Starting training...")
    trainer.train()
    
    # Save the fine-tuned model
    output_dir = os.path.join(training_args.output_dir, "final_model")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    logger.info(f"Training complete. Model saved to {output_dir}")
    
    # Save a smaller file with just the LoRA adapters
    adapter_output_dir = os.path.join(training_args.output_dir, "lora_adapters")
    model.save_pretrained(adapter_output_dir)
    logger.info(f"LoRA adapters saved to {adapter_output_dir}")
    
    return output_dir

if __name__ == "__main__":
    main()

# To run inference with the fine-tuned model
def run_inference(model_path, prompt):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16
    )
    
    messages = [
        {"role": "user", "content": prompt}
    ]
    
    inputs = tokenizer.apply_chat_template(
        messages, 
        return_tensors="pt"
    ).to(model.device)
    
    outputs = model.generate(
        inputs,
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    
    response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return response

# Example usage:
# paragraph = "Your research paragraph here..."
# prompt = f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{paragraph}"
# response = run_inference("./models/mistral-7b-critique_TIMESTAMP/final_model", prompt)
# print(response)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.2/183.2 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.9/133.9 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.6/123.6 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25h

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.4/60.4 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.4/36.4 MB[0m [31m48.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tsfresh 0.21.0 requires scipy>=1.14.0; python_version >= "3.10", but you have scipy 1.11.4 which is incompatible.
featuretools 1.31.0 requires tqdm>=4.66.3, but you have tqdm 4.66.1 which is incompatible.
nilearn 0.11.1 requires scikit-learn>=1.4.0, but you have scikit-learn 1.2.2 which is incompatible.
sentence-transformers 3.4.1 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.36.0 which is incompatible.
bigframes 1.36.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.
imbalanced-learn 0.13.0 requires scik

RuntimeError: Failed to import diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion because of the following error (look up to see its traceback):
Failed to import diffusers.loaders.single_file because of the following error (look up to see its traceback):
No module named 'torch.sparse._triton_ops_meta'

In [1]:
# Kaggle notebook for testing your fine-tuned Mistral model
!pip install -q transformers==4.36.0 accelerate==0.25.0 bitsandbytes==0.41.0 tqdm==4.66.1

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set your model path (replace with your actual dataset name)
model_path = "/kaggle/input/mistral-llama/transformers/default/1/content/mistral-critique-model"  # Adjust this path

# Create offload directory if needed
os.makedirs("./offload_folder", exist_ok=True)

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model with proper quantization
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=True  # Use 8-bit quantization for efficient inference
)

# Function to generate critiques
def critique_paragraph(paragraph, max_new_tokens=1024):
    """Generate a research paper critique for a given paragraph"""
    # Format the prompt
    prompt = f"Review and critique the following research paper paragraph. Identify any logical issues, missing evidence, contradictions, or other problems:\n\n{paragraph}"
    
    # Format as chat for Mistral
    messages = [{"role": "user", "content": prompt}]
    
    # Apply chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt"
    ).to(model.device)
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    
    # Decode response
    response = tokenizer.decode(
        outputs[0][inputs.shape[1]:], 
        skip_special_tokens=True
    )
    
    return response

# Test examples
test_paragraphs = [
    """The integration of machine learning algorithms into healthcare settings has shown promising results, reducing diagnostic errors by 30% in preliminary studies. This improvement can be attributed to the ability of these systems to process vast amounts of patient data and identify patterns that might be missed by human practitioners. Furthermore, the implementation of these technologies has been well-received by medical professionals, with 85% reporting increased confidence in their diagnostic decisions when supported by AI tools.""",
    
    """Recent advances in natural language processing have enabled more accurate sentiment analysis in social media posts, with reported accuracy rates exceeding 90%. This represents a significant improvement over previous methods and opens new avenues for understanding public opinion at scale.""",
    
    """Our experiment showed a statistically significant effect (p < 0.05) of the new treatment on patient recovery times. The treatment group showed a 45% reduction in recovery time compared to the control group, suggesting this approach could revolutionize care standards."""
]

# Run tests
for i, paragraph in enumerate(test_paragraphs):
    print(f"\n\n=== Test Example {i+1} ===")
    print(f"\nParagraph:\n{paragraph}\n")
    print(f"Generating critique...")
    critique = critique_paragraph(paragraph)
    print(f"\nCritique:\n{critique}")
    print("\n" + "="*50)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.6/57.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m90.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.3/78.3 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m91.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0

  _torch_pytree._register_pytree_node(


Loading tokenizer...


Exception: data did not match any variant of untagged enum PyPreTokenizerTypeWrapper at line 54 column 3