In [None]:
#!/usr/bin/env python3
"""
Improved LoRA Fine-tuning Script for Causal Language Models
Supports both instruction-following and conversation formats
"""

import argparse
import json
import logging
import os
import sys
from typing import Dict, List, Optional, Union
import warnings

import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from datasets import load_dataset, Dataset
from transformers.trainer_utils import get_last_checkpoint
import wandb

# Setup logging: configure the root logger once with timestamped INFO-level
# records, and create the module-level `logger` used by the functions below.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

def parse_arguments():
    """Parse the command-line options of the fine-tuning script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` holding the data
    paths, model/output locations, optimisation hyper-parameters, LoRA
    settings and the optional features (8-bit loading, gradient
    checkpointing, early stopping, W&B logging, prompt format).

    ``--train_data``, ``--valid_data``, ``--model`` and ``--output_dir`` are
    required; argparse exits the process with an error if any is missing.
    """
    parser = argparse.ArgumentParser(description="Fine-tune language models with LoRA")
    
    # Data arguments
    parser.add_argument("--train_data", type=str, required=True,
                       help="Path to training data (JSON/JSONL)")
    parser.add_argument("--valid_data", type=str, required=True,
                       help="Path to validation data (JSON/JSONL)")
    
    # Model arguments
    parser.add_argument("--model", type=str, required=True,
                       help="Model name or path (e.g., microsoft/DialoGPT-medium)")
    parser.add_argument("--output_dir", type=str, required=True,
                       help="Output directory for model and checkpoints")
    
    # Training arguments
    parser.add_argument("--epochs", type=int, default=3,
                       help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                       help="Training batch size per device")
    parser.add_argument("--eval_batch_size", type=int, default=4,
                       help="Evaluation batch size per device")
    parser.add_argument("--learning_rate", type=float, default=2e-4,
                       help="Learning rate")
    parser.add_argument("--warmup_steps", type=int, default=100,
                       help="Number of warmup steps")
    parser.add_argument("--max_length", type=int, default=512,
                       help="Maximum sequence length")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                       help="Gradient accumulation steps")
    
    # LoRA arguments
    parser.add_argument("--lora_r", type=int, default=16,
                       help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32,
                       help="LoRA alpha parameter")
    parser.add_argument("--lora_dropout", type=float, default=0.1,
                       help="LoRA dropout")
    parser.add_argument("--target_modules", type=str, nargs="+", 
                       default=["q_proj", "v_proj", "k_proj", "o_proj"],
                       help="Target modules for LoRA")
    
    # Advanced arguments
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                       help="Resume training from checkpoint")
    # FP16 is on by default. `--use_fp16` is kept so existing command lines
    # still parse, but on its own it could never switch FP16 off; `--no_fp16`
    # writes False into the same destination to make that possible.
    parser.add_argument("--use_fp16", action="store_true", default=True,
                       help="Use FP16 training (enabled by default)")
    parser.add_argument("--no_fp16", dest="use_fp16", action="store_false",
                       help="Disable FP16 training and load the model in FP32")
    parser.add_argument("--use_8bit", action="store_true",
                       help="Use 8-bit quantization")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                       help="Enable gradient checkpointing")
    parser.add_argument("--early_stopping_patience", type=int, default=3,
                       help="Early stopping patience")
    parser.add_argument("--wandb_project", type=str, default=None,
                       help="Weights & Biases project name")
    parser.add_argument("--data_format", type=str, choices=["instruction", "conversation"], 
                       default="instruction",
                       help="Data format: instruction (input/output) or conversation")
    parser.add_argument("--instruction_template", type=str, 
                       default="### Instruction:\n{instruction}\n\n### Response:\n{response}",
                       help="Template for instruction format")
    
    return parser.parse_args()

def setup_model_and_tokenizer(args):
    """Load the base model and tokenizer and wrap the model with LoRA adapters.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``use_fp16``, ``use_8bit``,
            ``gradient_checkpointing``, ``lora_r``, ``lora_alpha``,
            ``lora_dropout`` and ``target_modules``.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped causal LM.
    """
    logger.info(f"Loading model and tokenizer: {args.model}")
    
    # Load tokenizer
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only point --model at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, 
        use_fast=False,
        trust_remote_code=True
    )
    
    # Add padding token if missing: reuse EOS when available so the vocabulary
    # does not grow, otherwise register a dedicated [PAD] token.
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    
    # Load model with quantization if specified
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.float16 if args.use_fp16 else torch.float32,
    }
    
    if args.use_8bit:
        model_kwargs.update({
            "load_in_8bit": True,
            "device_map": "auto",
        })
    
    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
    
    # Resize token embeddings if we added new tokens
    if len(tokenizer) > model.config.vocab_size:
        model.resize_token_embeddings(len(tokenizer))
    
    # Prepare model for training
    if args.use_8bit:
        model = prepare_model_for_kbit_training(model)
    
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        # The generation KV cache is not usable together with gradient
        # checkpointing during training, so switch it off explicitly.
        model.config.use_cache = False
        # With LoRA the base weights are frozen, so the embedding output does
        # not require grad and checkpointed segments would receive no
        # gradient. Making the inputs require grad keeps backward working on
        # the non-8-bit path as well.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
    
    # Setup LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
        bias="none",
    )
    
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    return model, tokenizer

def load_and_preprocess_data(args, tokenizer):
    """Load the train/validation files and tokenize them for causal-LM training.

    Args:
        args: Parsed CLI namespace. Reads ``train_data``, ``valid_data``,
            ``data_format``, ``instruction_template`` and ``max_length``.
        tokenizer: Tokenizer used to encode the rendered text.

    Returns:
        ``(train_ds, valid_ds)``: datasets whose original columns are replaced
        by the tokenizer outputs plus a ``labels`` column equal to
        ``input_ids``. Every example is padded/truncated to ``max_length``.

    The ``instruction`` format reads the ``input`` and ``output`` columns; the
    ``conversation`` format reads a ``conversation`` column holding a list of
    ``{"role", "content"}`` turns.
    """
    logger.info("Loading and preprocessing data...")

    def detect_format(path):
        # Both .json and .jsonl are handled by the "json" loader. The format
        # is chosen per file so the validation file does not inherit the
        # training file's extension.
        return "json" if path.lower().endswith((".json", ".jsonl")) else "text"

    # Load datasets
    train_ds = load_dataset(detect_format(args.train_data), data_files=args.train_data, split="train")
    valid_ds = load_dataset(detect_format(args.valid_data), data_files=args.valid_data, split="train")

    logger.info(f"Training samples: {len(train_ds)}")
    logger.info(f"Validation samples: {len(valid_ds)}")

    def tokenize_texts(texts):
        """Tokenize rendered texts and attach causal-LM labels."""
        model_inputs = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=args.max_length,
            return_tensors=None
        )
        # For causal LM, labels are the same as input_ids
        model_inputs["labels"] = model_inputs["input_ids"].copy()
        return model_inputs

    def preprocess_instruction_format(examples):
        """Preprocess data in instruction format (input/output)"""
        texts = [
            args.instruction_template.format(instruction=instruction, response=response)
            for instruction, response in zip(examples["input"], examples["output"])
        ]
        return tokenize_texts(texts)

    def preprocess_conversation_format(examples):
        """Preprocess data in conversation format"""
        texts = []
        for conversation in examples["conversation"]:
            # One "role: content" line per turn; missing keys fall back to
            # role "user" and empty content.
            rendered = "".join(
                f"{turn.get('role', 'user')}: {turn.get('content', '')}\n"
                for turn in conversation
            )
            texts.append(rendered.strip())
        return tokenize_texts(texts)

    # Choose preprocessing function based on data format
    if args.data_format == "instruction":
        preprocess_fn = preprocess_instruction_format
    else:
        preprocess_fn = preprocess_conversation_format

    # Apply preprocessing
    train_ds = train_ds.map(
        preprocess_fn,
        batched=True,
        remove_columns=train_ds.column_names,
        desc="Preprocessing training data"
    )

    valid_ds = valid_ds.map(
        preprocess_fn,
        batched=True,
        remove_columns=valid_ds.column_names,
        desc="Preprocessing validation data"
    )

    return train_ds, valid_ds

def setup_training_arguments(args):
    """Build the ``TrainingArguments`` for the run.

    Args:
        args: Parsed CLI namespace. Reads ``output_dir``, ``epochs``,
            ``batch_size``, ``eval_batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``,
            ``warmup_steps``, ``use_fp16`` and ``wandb_project``.

    Returns:
        A ``TrainingArguments`` instance that evaluates and saves every 100
        steps, logs every 10 steps and reloads the checkpoint with the lowest
        ``eval_loss`` at the end.

    Side effects:
        Starts a Weights & Biases run when ``args.wandb_project`` is set.
    """
    
    # Initialize wandb if specified
    if args.wandb_project:
        wandb.init(project=args.wandb_project)
        report_to = "wandb"
    else:
        # `None` does not mean "off": in Transformers 4.x it is resolved to
        # every installed integration (wandb is imported by this file, so it
        # is installed). The string "none" disables reporting.
        report_to = "none"
    
    # The evaluation-schedule argument was renamed from `evaluation_strategy`
    # to `eval_strategy`; pick whichever name the installed version declares.
    declared_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in declared_fields else "evaluation_strategy"
    )
    
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        
        # Training parameters
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        
        # Optimization
        fp16=args.use_fp16,
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        
        # Evaluation and saving (same cadence, so every saved checkpoint has
        # an eval_loss for best-model selection)
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        
        # Logging
        logging_strategy="steps",
        logging_steps=10,
        report_to=report_to,
        
        # Early stopping
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        
        # Memory optimization
        dataloader_pin_memory=False,
        remove_unused_columns=False,
        
        **{eval_strategy_key: "steps"},
    )
    
    return training_args

class CustomTrainer(Trainer):
    """Trainer whose loss is next-token cross-entropy over non-padding tokens."""
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the shifted causal-LM loss, ignoring padding positions.

        Args:
            model: Model being trained; called as ``model(**inputs)``.
            inputs: Batch dict; ``labels`` is optional.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Token count supplied by newer Trainer versions
                for gradient-accumulation-aware normalisation. Defaults to
                ``None`` so older Trainers that do not pass it still work.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        
        if labels is None:
            loss = outputs.loss
            return (loss, outputs) if return_outputs else loss
        
        # Shift so that position t predicts token t+1
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        
        # Padding may show up either as -100 (DataCollatorForLanguageModeling
        # rewrites padded labels to -100) or as the raw pad id. Normalise both
        # to -100: using ignore_index=pad_token_id alone would feed -100 to
        # the loss as an out-of-range class index.
        tokenizer = getattr(self, "processing_class", None) or getattr(self, "tokenizer", None)
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            shift_labels = shift_labels.masked_fill(shift_labels == pad_token_id, -100)
        
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        
        if num_items_in_batch is not None and getattr(self, "model_accepts_loss_kwargs", False):
            # NOTE(review): mirrors the Trainer convention that, in this
            # case, the loss is a sum divided by the token count of the whole
            # accumulated batch (the Trainer then skips its own division by
            # gradient_accumulation_steps) - confirm against the installed
            # Transformers version.
            if torch.is_tensor(num_items_in_batch):
                num_items_in_batch = num_items_in_batch.to(flat_logits.device)
            loss = torch.nn.functional.cross_entropy(
                flat_logits, flat_labels, ignore_index=-100, reduction="sum"
            ) / num_items_in_batch
        else:
            loss = torch.nn.functional.cross_entropy(
                flat_logits, flat_labels, ignore_index=-100
            )
        
        return (loss, outputs) if return_outputs else loss

def main():
    """Run the whole pipeline: parse args, train with LoRA, save, smoke-test."""
    args = parse_arguments()

    # Persist the exact CLI configuration next to the checkpoints.
    os.makedirs(args.output_dir, exist_ok=True)
    args_dump_path = os.path.join(args.output_dir, "training_args.json")
    with open(args_dump_path, "w") as handle:
        json.dump(vars(args), handle, indent=2)

    # Model, data and trainer configuration come from the helpers above.
    model, tokenizer = setup_model_and_tokenizer(args)
    train_ds, valid_ds = load_and_preprocess_data(args, tokenizer)
    training_args = setup_training_arguments(args)

    # Causal-LM collation (mlm=False); pad to multiples of 8 only under FP16.
    pad_multiple = 8 if args.use_fp16 else None
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=pad_multiple,
    )

    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=args.early_stopping_patience
    )
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=[early_stopping],
    )

    # An explicit checkpoint wins; otherwise pick up the newest checkpoint
    # already present in the output directory, if any.
    if args.resume_from_checkpoint:
        checkpoint = args.resume_from_checkpoint
    elif os.path.isdir(args.output_dir):
        checkpoint = get_last_checkpoint(args.output_dir)
    else:
        checkpoint = None

    if checkpoint:
        logger.info(f"Resuming training from {checkpoint}")

    logger.info("Starting training...")
    trainer.train(resume_from_checkpoint=checkpoint)

    # Persist model, trainer state and tokenizer to the output directory.
    logger.info("Saving final model...")
    trainer.save_model()
    trainer.save_state()
    tokenizer.save_pretrained(args.output_dir)

    # Dump the metric history, when there is one, for later inspection.
    history = trainer.state.log_history
    if history:
        log_path = os.path.join(args.output_dir, "training_log.json")
        with open(log_path, "w") as handle:
            json.dump(history, handle, indent=2)

    logger.info(f"Training completed! Model saved to {args.output_dir}")

    # Finish with a single sample generation as a sanity check.
    test_generation(model, tokenizer, args)

def test_generation(model, tokenizer, args):
    """Generate one sample from the trained model and log it.

    Args:
        model: Trained (PEFT-wrapped) causal LM.
        tokenizer: Tokenizer matching ``model``.
        args: Parsed CLI namespace; only ``data_format`` is read, to pick a
            prompt in the same shape as the training data.
    """
    logger.info("Testing model generation...")
    
    # Sample prompt
    if args.data_format == "instruction":
        prompt = "### Instruction:\nWhat is machine learning?\n\n### Response:\n"
    else:
        prompt = "user: Hello, how are you?\nassistant:"
    
    # The tokenizer returns CPU tensors, while after training the model may
    # live on an accelerator; move the inputs to wherever the weights are.
    device = next(model.parameters()).device
    inputs = {
        name: tensor.to(device)
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }
    
    # Generate in eval mode so dropout (including LoRA dropout) is inactive,
    # then restore the previous mode for any caller that keeps training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id
            )
    finally:
        if was_training:
            model.train()
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated text:\n{generated_text}")

# Run the training pipeline only when this code is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

In [None]:
# Example invocation: 8-bit LoRA fine-tuning of Llama-2-7B with gradient
# checkpointing and W&B logging.
# NOTE(review): presumably the script cell above is saved as train_lora.py - confirm.
!python train_lora.py \
  --train_data train.json \
  --valid_data valid.json \
  --model meta-llama/Llama-2-7b-hf \
  --output_dir ./llama-finetuned \
  --epochs 3 \
  --batch_size 2 \
  --use_8bit \
  --gradient_checkpointing \
  --lora_r 32 \
  --lora_alpha 64 \
  --wandb_project my-finetuning

In [None]:
### RAG

# Install dependencies (run in Colab)
!pip install -U google-cloud-aiplatform chromadb vertexai

import vertexai
from vertexai.language_models import TextEmbeddingModel
from vertexai.generative_models import GenerativeModel
import chromadb

# Minimal RAG walkthrough: embed a tiny corpus with Vertex AI, index it in an
# in-memory Chroma collection, retrieve the best match for a query and ask
# Gemini to answer from that context.
#
# Fatal steps stop with `raise SystemExit(1)` instead of `exit()`: in a
# notebook `exit()` asks the kernel to shut down rather than simply ending
# the cell, and in a plain script it would report success (exit code 0).

# Initialize Vertex AI (replace with your project details)
PROJECT_ID = "your-project-id"  # Replace with your actual project ID
LOCATION = "us-central1"  # or your preferred location
vertexai.init(project=PROJECT_ID, location=LOCATION)

# 1. Prepare corpus and get embeddings
texts = ["これは最初の文書です。", "これは二番目の文書です。"]
EMBED_MODEL = "text-multilingual-embedding-002"

try:
    embedding_model = TextEmbeddingModel.from_pretrained(EMBED_MODEL)
    # One batched request instead of one round-trip per document.
    doc_embeddings = [item.values for item in embedding_model.get_embeddings(texts)]
    
    print(f"Generated {len(doc_embeddings)} document embeddings")
    
except Exception as e:
    print(f"Error generating embeddings: {e}")
    raise SystemExit(1)

# 2. Build vector database (Chroma)
try:
    chroma_client = chromadb.Client()
    
    # Delete collection if it exists (for testing). A missing collection is
    # expected on the first run; `except Exception` (not a bare `except`)
    # keeps Ctrl-C and SystemExit working.
    try:
        chroma_client.delete_collection("docs")
    except Exception:
        pass
    
    collection = chroma_client.create_collection("docs")
    
    # Add all documents with their embeddings in a single call
    collection.add(
        documents=texts,
        embeddings=doc_embeddings,
        ids=[str(i) for i in range(len(texts))]
    )
    
    print("Vector database created successfully")
    
except Exception as e:
    print(f"Error creating vector database: {e}")
    raise SystemExit(1)

# 3. Process user query
query = "二番目の文書について教えて"

try:
    # Get query embedding
    query_embedding_result = embedding_model.get_embeddings([query])
    query_emb = query_embedding_result[0].values
    
    # Search for similar documents
    results = collection.query(
        query_embeddings=[query_emb], 
        n_results=2
    )
    
    if results['documents'] and len(results['documents'][0]) > 0:
        top_context = results['documents'][0][0]  # Top document text
        print(f"Retrieved context: {top_context}")
    else:
        print("No relevant documents found")
        raise SystemExit(1)
        
except Exception as e:
    print(f"Error processing query: {e}")
    raise SystemExit(1)

# 4. Generate response with Gemini
# The prompt is built before the try block so the fallback model can always
# reuse it, even if constructing the primary model fails.
prompt = f"""参考文書: {top_context}

質問: {query}

上記の参考文書に基づいて質問に答えてください。

答え:"""

try:
    # Use the newer GenerativeModel class
    model = GenerativeModel("gemini-1.5-pro")
    response = model.generate_content(prompt)
    print(f"\nFinal Response: {response.text}")
    
except Exception as e:
    print(f"Error generating response: {e}")
    
    # Fallback: try with different model name
    try:
        model = GenerativeModel("gemini-1.5-flash")
        response = model.generate_content(prompt)
        print(f"\nFinal Response (fallback): {response.text}")
    except Exception as e2:
        print(f"Fallback also failed: {e2}")
        print("Please check your model access permissions and available models")

In [None]:
### GraphRAG

import networkx as nx
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any, Tuple

class GraphRAG:
    """Retrieval-augmented generation over a document graph.

    Documents are nodes of an undirected ``networkx`` graph carrying ``text``
    and ``embedding`` attributes; edges carry a ``type`` attribute. A query
    is answered by ranking nodes by cosine similarity, expanding the best
    matches to their graph neighbourhood and passing the collected texts to
    the LLM.

    Args:
        embedding_model: Object exposing ``get_embeddings([text])[0].values``
            or, as a fallback in ``query``, ``encode([text])[0]``.
        llm: Object exposing ``predict(prompt)`` or ``generate_content(prompt)``.
    """

    def __init__(self, embedding_model=None, llm=None):
        self.G = nx.Graph()
        self.embedding_model = embedding_model
        self.llm = llm
    
    @staticmethod
    def _as_row_vector(embedding) -> np.ndarray:
        """Return *embedding* (list or array) as a ``(1, dim)`` float array."""
        return np.asarray(embedding, dtype=float).reshape(1, -1)
    
    def add_document(self, doc_id: int, text: str, embedding: np.ndarray = None):
        """Add a document node to the graph.

        If ``embedding`` is omitted and an embedding model is configured, the
        embedding is generated from ``text``; otherwise the node is stored
        with ``embedding=None`` and is skipped by similarity search.
        """
        if embedding is None and self.embedding_model:
            # Generate embedding if not provided
            embedding = self.embedding_model.get_embeddings([text])[0].values
        
        self.G.add_node(doc_id, text=text, embedding=embedding)
    
    def add_relation(self, doc_id1: int, doc_id2: int, relation_type: str = "reference"):
        """Add an edge between two existing documents.

        Raises:
            ValueError: If either document id is not in the graph.
        """
        if doc_id1 in self.G.nodes and doc_id2 in self.G.nodes:
            self.G.add_edge(doc_id1, doc_id2, type=relation_type)
        else:
            raise ValueError(f"One or both document IDs ({doc_id1}, {doc_id2}) not found in graph")
    
    def find_similar_nodes(self, query_embedding: np.ndarray, top_k: int = 3) -> List[Tuple[int, float]]:
        """Return the ``top_k`` ``(node_id, similarity)`` pairs, best first.

        Embeddings may be NumPy arrays or plain sequences of floats (the
        embedding model's ``.values`` is used as-is by ``add_document`` and
        ``query``), so both sides are coerced before comparison. Nodes
        without an embedding are skipped.
        """
        query_vector = self._as_row_vector(query_embedding)
        similarities = []
        
        # data="embedding" yields None for nodes lacking the attribute instead
        # of raising KeyError.
        for node_id, node_embedding in self.G.nodes(data="embedding"):
            if node_embedding is None:
                continue
            score = cosine_similarity(query_vector, self._as_row_vector(node_embedding))[0][0]
            similarities.append((node_id, float(score)))
        
        # Sort by similarity (descending) and return top-k
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:top_k]
    
    def get_subgraph_context(self, node_ids: List[int], max_hops: int = 1) -> str:
        """Join the texts of ``node_ids`` and of nodes within ``max_hops`` of them.

        The output order is deterministic: the requested nodes first, in the
        order given (i.e. by relevance when called from ``query``), followed
        by newly reached neighbours in breadth-first order. Ids that are not
        in the graph, or nodes without a ``text`` attribute, are omitted.
        """
        # A dict is used as an insertion-ordered set.
        ordered_nodes = dict.fromkeys(node_ids)
        
        for node_id in node_ids:
            if node_id in self.G.nodes:
                # Nodes reachable within max_hops, found by BFS
                reachable = nx.single_source_shortest_path_length(
                    self.G, node_id, cutoff=max_hops
                )
                for neighbor in reachable:
                    ordered_nodes.setdefault(neighbor, None)
        
        context_texts = [
            f"文書{node_id}: {self.G.nodes[node_id]['text']}"
            for node_id in ordered_nodes
            if node_id in self.G.nodes and 'text' in self.G.nodes[node_id]
        ]
        
        return "\n".join(context_texts)
    
    def query(self, query: str, top_k: int = 3, max_hops: int = 1) -> str:
        """Answer ``query`` from the graph.

        Returns the LLM's answer, a fixed message when no node has an
        embedding, or an error message string if the LLM call fails.

        Raises:
            ValueError: If the embedding model or the LLM is not configured.
        """
        if not self.embedding_model or not self.llm:
            raise ValueError("Both embedding_model and llm must be provided")
        
        # 1. Get query embedding
        try:
            query_embedding = self.embedding_model.get_embeddings([query])[0].values
        except AttributeError:
            # Handle different embedding model interfaces
            query_embedding = self.embedding_model.encode([query])[0]
        
        # 2. Find similar nodes
        similar_nodes = self.find_similar_nodes(query_embedding, top_k)
        
        if not similar_nodes:
            return "関連する文書が見つかりませんでした。"
        
        # 3. Get subgraph context
        top_node_ids = [node_id for node_id, _ in similar_nodes]
        context = self.get_subgraph_context(top_node_ids, max_hops)
        
        # 4. Generate response using LLM
        prompt = f"""参考文書:
{context}

質問: {query}

上記の参考文書の内容に基づいて、質問に答えてください。

答え:"""
        
        try:
            # Prefer `predict`; fall back to `generate_content` for models
            # that only expose that method.
            if hasattr(self.llm, "predict"):
                response = self.llm.predict(prompt)
            else:
                response = self.llm.generate_content(prompt)
            return response.text if hasattr(response, 'text') else str(response)
        except Exception as e:
            return f"LLMによる回答生成中にエラーが発生しました: {str(e)}"

# Usage example
def example_usage():
    """Build a three-document demo graph and print its nodes and edges."""

    # No embedding model or LLM is attached here; pass real ones to run queries.
    graph_rag = GraphRAG()  # You would pass your actual embedding_model and llm here

    # Random stand-ins for real embeddings (384 is just an example dimension).
    doc_embeddings = [np.random.rand(384) for _ in range(3)]

    documents = (
        "これは最初の文書です。人工知能について説明しています。",
        "これは二番目の文書です。機械学習について詳しく述べています。",
        "これは三番目の文書です。深層学習の応用について書かれています。",
    )
    for doc_id, (text, vector) in enumerate(zip(documents, doc_embeddings)):
        graph_rag.add_document(doc_id, text, vector)

    # Chain the documents: 0 -(reference)- 1 -(related)- 2
    for source, target, relation in ((0, 1, "reference"), (1, 2, "related")):
        graph_rag.add_relation(source, target, relation)

    # Sample query; unused below because no embedding_model/llm is configured.
    query = "機械学習について教えてください"

    graph = graph_rag.G
    print("GraphRAG setup completed!")
    print(f"Graph has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")

    print("\nGraph structure:")
    for node_id, attributes in graph.nodes(data=True):
        print(f"Node {node_id}: {attributes['text'][:50]}...")

    for first, second, attributes in graph.edges(data=True):
        print(f"Edge {first}-{second}: {attributes['type']}")

# Run the demo only when this code is executed as a script, not on import.
if __name__ == "__main__":
    example_usage()

In [None]:
# Install dependencies (run in Colab)
!pip install -U google-cloud-aiplatform chromadb vertexai

import vertexai
from vertexai.language_models import TextEmbeddingModel
from vertexai.generative_models import GenerativeModel
import chromadb

# Minimal RAG walkthrough: embed a tiny corpus with Vertex AI, index it in an
# in-memory Chroma collection, retrieve the best match for a query and ask
# Gemini to answer from that context.
#
# Fatal steps stop with `raise SystemExit(1)` instead of `exit()`: in a
# notebook `exit()` asks the kernel to shut down rather than simply ending
# the cell, and in a plain script it would report success (exit code 0).

# Initialize Vertex AI (replace with your project details)
PROJECT_ID = "your-project-id"  # Replace with your actual project ID
LOCATION = "us-central1"  # or your preferred location
vertexai.init(project=PROJECT_ID, location=LOCATION)

# 1. Prepare corpus and get embeddings
texts = ["これは最初の文書です。", "これは二番目の文書です。"]
EMBED_MODEL = "text-multilingual-embedding-002"

try:
    embedding_model = TextEmbeddingModel.from_pretrained(EMBED_MODEL)
    # One batched request instead of one round-trip per document.
    doc_embeddings = [item.values for item in embedding_model.get_embeddings(texts)]
    
    print(f"Generated {len(doc_embeddings)} document embeddings")
    
except Exception as e:
    print(f"Error generating embeddings: {e}")
    raise SystemExit(1)

# 2. Build vector database (Chroma)
try:
    chroma_client = chromadb.Client()
    
    # Delete collection if it exists (for testing). A missing collection is
    # expected on the first run; `except Exception` (not a bare `except`)
    # keeps Ctrl-C and SystemExit working.
    try:
        chroma_client.delete_collection("docs")
    except Exception:
        pass
    
    collection = chroma_client.create_collection("docs")
    
    # Add all documents with their embeddings in a single call
    collection.add(
        documents=texts,
        embeddings=doc_embeddings,
        ids=[str(i) for i in range(len(texts))]
    )
    
    print("Vector database created successfully")
    
except Exception as e:
    print(f"Error creating vector database: {e}")
    raise SystemExit(1)

# 3. Process user query
query = "二番目の文書について教えて"

try:
    # Get query embedding
    query_embedding_result = embedding_model.get_embeddings([query])
    query_emb = query_embedding_result[0].values
    
    # Search for similar documents
    results = collection.query(
        query_embeddings=[query_emb], 
        n_results=2
    )
    
    if results['documents'] and len(results['documents'][0]) > 0:
        top_context = results['documents'][0][0]  # Top document text
        print(f"Retrieved context: {top_context}")
    else:
        print("No relevant documents found")
        raise SystemExit(1)
        
except Exception as e:
    print(f"Error processing query: {e}")
    raise SystemExit(1)

# 4. Generate response with Gemini
# The prompt is built before the try block so the fallback model can always
# reuse it, even if constructing the primary model fails.
prompt = f"""参考文書: {top_context}

質問: {query}

上記の参考文書に基づいて質問に答えてください。

答え:"""

try:
    # Use the newer GenerativeModel class
    model = GenerativeModel("gemini-1.5-pro")
    response = model.generate_content(prompt)
    print(f"\nFinal Response: {response.text}")
    
except Exception as e:
    print(f"Error generating response: {e}")
    
    # Fallback: try with different model name
    try:
        model = GenerativeModel("gemini-1.5-flash")
        response = model.generate_content(prompt)
        print(f"\nFinal Response (fallback): {response.text}")
    except Exception as e2:
        print(f"Fallback also failed: {e2}")
        print("Please check your model access permissions and available models")