In [None]:
# Imports
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
import flwr as fl
from typing import Dict, List, Tuple
import numpy as np
from collections import OrderedDict
import os
from pathlib import Path
import tempfile
import logging
import json
from datetime import datetime
import time

In [None]:
# Set client ID (change for each client: 1,2,3,4)
# A CLIENT_ID environment variable, when set, overrides the default of 2 so the
# same notebook can be executed once per client without editing this cell.
CLIENT_ID = int(os.environ.get("CLIENT_ID", "2"))
# Ids start at 1: the dataset slice below is (CLIENT_ID - 1) * 50 .. CLIENT_ID * 50,
# which would start at a negative index for 0 or below.
if CLIENT_ID < 1:
    raise ValueError(f"CLIENT_ID must be >= 1, got {CLIENT_ID}")

In [None]:
# Configure root logging (timestamp, logger name, level, message) at INFO, then
# create a per-client named logger used by the helpers in this notebook.
# NOTE(review): logging.basicConfig does nothing if the root logger already has
# handlers (e.g. configured earlier in the kernel) — confirm output appears.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("Client_" + str(CLIENT_ID))

def log_status(status: str, details: str = "") -> None:
    """Emit a framed status banner to both the client logger and stdout.

    Args:
        status: Title shown between two runs of 20 ``=`` characters.
        details: Optional extra text emitted between the banner and the footer.
    """
    banner = f"{'=' * 20} {status} {'=' * 20}"
    # The leading newline puts a blank line before each banner. It is excluded
    # from the footer width so the footer lines up with the visible banner.
    status_line = f"\n{banner}"
    footer = "=" * len(banner)

    logger.info(status_line)
    if details:
        logger.info(details)
    logger.info(footer)

    print(status_line)
    if details:
        print(details)
    print(footer)

# Seed torch with a per-client offset so clients do not share one random stream.
# NOTE(review): transformers' Trainer re-seeds from TrainingArguments.seed when it
# is constructed, so this seed may not govern training-time shuffling — confirm.
torch.manual_seed(CLIENT_ID + 42)

# Prefer the GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
log_status("DEVICE INFO", f"Client {CLIENT_ID} using device: {device}")

In [None]:
# Scratch directory for Trainer output, the results JSON and the saved model.
# mkdtemp() does not delete it automatically.
temp_dir = tempfile.mkdtemp()
log_status("DIRECTORY INFO", f"Using temporary directory: {temp_dir}")

# 1 GiB cap on gRPC message size, handed to the Flower client. The full model
# state is sent in one message, so the cap must exceed the serialized weights
# (presumably above Flower's default — confirm against the flwr version in use).
GRPC_MAX_MESSAGE_LENGTH = 1 << 30

In [None]:
# Load dataset
log_status("LOADING DATASET", f"Client {CLIENT_ID} loading dataset...")
try:
    dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards")

    # Each client trains on its own contiguous, non-overlapping window of 50 rows
    # (end index exclusive).
    start_idx, end_idx = (CLIENT_ID - 1) * 50, CLIENT_ID * 50
    small_dataset = dataset['train'].select(range(start_idx, end_idx))

    log_status("DATASET LOADED", 
              f"Dataset size: {len(small_dataset)} examples\n"
              f"Range: {start_idx} to {end_idx}")

    # Five evenly spaced rows of this client's slice serve as probe questions.
    # NOTE(review): these rows are also part of the training data, so
    # before/after comparisons measure fit, not generalisation.
    test_indices = np.linspace(0, len(small_dataset)-1, 5, dtype=int)

    test_questions, test_answers = [], []
    for position in test_indices:
        row = small_dataset[int(position)]  # numpy int -> Python int for datasets indexing
        test_questions.append(row['input'])
        test_answers.append(row['output'])

    log_status("TEST QUESTIONS SELECTED", 
              f"Number of test questions: {len(test_questions)}")

except Exception as err:
    log_status("DATASET ERROR", str(err))
    raise

In [None]:
# Show the selected probe questions (Jupyter renders the cell's last expression).
test_questions

In [None]:
# Format flashcards
def format_flashcard(example):
    """Render one flashcard row as a single training string.

    Args:
        example: Mapping with ``'input'`` (question) and ``'output'`` (answer).

    Returns:
        ``{'text': "Question: <input>\\nAnswer: <output>\\n\\n"}``.
    """
    template = "Question: {}\nAnswer: {}\n\n"
    return {'text': template.format(example['input'], example['output'])}

# Add a 'text' column in the "Question: ...\nAnswer: ..." training format.
formatted_dataset = small_dataset.map(function=format_flashcard)

In [None]:
# Initialize model and tokenizer
log_status("MODEL INITIALIZATION", "Loading model and tokenizer...")
try:
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # padding="max_length" during tokenization needs a pad token. When the
    # tokenizer defines none (presumably the case for GPT-2 — confirm), reuse EOS
    # and mirror that id in the model config.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    total_params = sum(param.numel() for param in model.parameters())
    log_status("MODEL LOADED", 
              f"Model: {model_name}\n"
              f"Parameters: {total_params}")
except Exception as err:
    log_status("MODEL ERROR", str(err))
    raise

In [None]:
# Function to generate answers
def generate_answer(question: str, max_length: int = 100) -> str:
    """Generate a completion for ``question`` with the notebook's global model.

    Args:
        question: Flashcard question, wrapped in the same "Question: ...\\nAnswer:"
            template used for the training text.
        max_length: Total token budget passed to ``model.generate``. It includes
            the prompt tokens, so long questions leave less room for the answer.

    Returns:
        The decoded sequence (the prompt followed by the generated continuation),
        or an ``"Error generating answer: ..."`` string if generation fails.
    """
    was_training = model.training
    try:
        prompt = f"Question: {question}\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # After a Trainer run the model can be left in training mode with dropout
        # active, which would make post-training answers noisy. Generate in eval mode.
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                # Passing the mask and pad id explicitly avoids inference warnings;
                # the pad token was set to EOS during model setup when missing.
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.pad_token_id,
                max_length=max_length,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                # `temperature` was removed: without do_sample=True decoding is
                # greedy, so it had no effect.
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        logger.error(f"Error generating answer: {e}")
        return f"Error generating answer: {str(e)}"
    finally:
        # Restore the previous mode so later training is unaffected.
        model.train(was_training)

In [None]:
# Function to evaluate model responses
def evaluate_model_responses(phase="Before"):
    """Answer every probe question, print each answer beside its ground truth, and collect results.

    Args:
        phase: Label used in the banner (upper-cased), e.g. "Before" or "After".

    Returns:
        Dict mapping each question to ``{'model_response': str, 'ground_truth': str}``.
        Questions are dict keys, so a repeated question keeps only its last entry.
    """
    log_status(f"{phase.upper()} TRAINING EVALUATION", "Starting model evaluation...")
    divider = "-" * 50
    collected = {}
    for question, truth in zip(test_questions, test_answers):
        answer = generate_answer(question)
        collected[question] = {'model_response': answer, 'ground_truth': truth}
        print(f"\nQuestion: {question}")
        print(f"Model Response: {answer}")
        print(f"Ground Truth: {truth}")
        print(divider)
    return collected

In [None]:
# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of formatted flashcards to fixed 256-token sequences.

    Args:
        examples: Batch mapping whose ``"text"`` entry is a list of strings
            (the map below runs with ``batched=True``).

    Returns:
        The tokenizer's encoding (input_ids, attention_mask), padded/truncated to 256.
    """
    encode_options = dict(
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    return tokenizer(examples["text"], **encode_options)

log_status("TOKENIZATION", "Tokenizing dataset...")
# Drop every original column so only tokenizer outputs feed the Trainer.
tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=formatted_dataset.column_names,
)

In [None]:
# Training arguments.
# Each client has 50 examples and an effective batch of 16 (4 per device x 4
# accumulation steps). On a single device that is only ~3-4 optimizer steps per
# epoch, roughly a dozen per federated round, and fit() restarts the schedule
# every round.
training_args = TrainingArguments(
    output_dir=os.path.join(temp_dir, f"client_{CLIENT_ID}"),
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # A fixed warmup_steps=100 never completes in ~12 steps, which kept the LR far
    # below its target. Warm up over the first 10% of each round instead.
    warmup_ratio=0.1,
    learning_rate=5e-5,
    fp16=torch.cuda.is_available(),
    # With so few steps, logging_steps=10 would report next to nothing.
    logging_steps=1,
    save_strategy="epoch",
    # No periodic evaluation is the default. The former `evaluation_strategy`
    # keyword is deprecated/removed in recent transformers releases.
    save_total_limit=2,
    overwrite_output_dir=True,
    # Trainer seeds from this value, so keep it client-specific like the
    # torch.manual_seed call made during setup.
    seed=42 + CLIENT_ID,
)

In [None]:
# Data collator for causal-LM training (mlm=False): labels come from input_ids,
# with padding positions excluded from the loss (per transformers docs).
# NOTE(review): the pad token is EOS here, so any real EOS tokens in the text would
# also be masked. The formatted flashcards do not append one.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
# Define Flower client
class MedicalFlashcardsClient(fl.client.NumPyClient):
    """Flower NumPy client that fine-tunes the notebook's global ``model`` on this client's flashcards.

    Weights are exchanged as a list of numpy arrays in ``model.state_dict()`` order.
    """

    def __init__(self):
        log_status("CLIENT INITIALIZATION", f"Client {CLIENT_ID} initializing...")
        self.trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
        )
        log_status("CLIENT READY", f"Client {CLIENT_ID} initialized and ready")

    def get_parameters(self, config: Dict[str, str]) -> List[np.ndarray]:
        """Return every state_dict tensor of ``model`` as a CPU numpy array, in state_dict order."""
        log_status("PARAMETER RETRIEVAL", f"Client {CLIENT_ID}: Getting parameters")
        return [tensor.cpu().numpy() for tensor in model.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load ``parameters`` (in state_dict order) into ``model``.

        Raises:
            ValueError: if the number of arrays does not match the state_dict.
                Without this check, zip() would silently drop extra arrays.
        """
        log_status("PARAMETER UPDATE", f"Client {CLIENT_ID}: Setting parameters")
        keys = list(model.state_dict().keys())
        if len(keys) != len(parameters):
            raise ValueError(
                f"Client {CLIENT_ID}: expected {len(keys)} parameter arrays, "
                f"got {len(parameters)}"
            )
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
        model.load_state_dict(state_dict, strict=True)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:
        """Load the global weights, train locally, and return the updated weights.

        Returns (weights, number of local examples, {"train_loss": mean training loss}).
        """
        log_status("TRAINING START", f"Client {CLIENT_ID}: Starting training round")
        self.set_parameters(parameters)
        train_output = self.trainer.train()
        log_status("TRAINING COMPLETE", f"Client {CLIENT_ID}: Completed training round")
        metrics = {"train_loss": float(train_output.training_loss)}
        return self.get_parameters(config), len(tokenized_dataset), metrics

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict[str, float]]:
        """Load the global weights and report the LM loss on this client's data."""
        log_status("EVALUATION", f"Client {CLIENT_ID}: Evaluating model")
        self.set_parameters(parameters)
        # The Trainer was built without an eval_dataset, and Trainer.evaluate()
        # raises when it has none, so pass one explicitly.
        # NOTE(review): this is the client's training data, so the loss measures
        # fit, not generalisation.
        metrics = self.trainer.evaluate(eval_dataset=tokenized_dataset)
        loss = float(metrics["eval_loss"])
        return loss, len(tokenized_dataset), {"loss": loss}


In [None]:
# Save model function
def save_model(path: str = None):
    """Save the global model and tokenizer with ``save_pretrained``.

    Args:
        path: Target directory. Defaults to ``<temp_dir>/medical-model-client-<CLIENT_ID>``.

    If saving there fails, the error is logged and a second attempt is made under
    ``~/medical_model_backup_client_<CLIENT_ID>``. Errors from that retry propagate.
    """
    def _write_artifacts(target_dir):
        # Create the directory, then write model weights/config and tokenizer files.
        Path(target_dir).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(target_dir)
        tokenizer.save_pretrained(target_dir)

    try:
        target = path
        if target is None:
            target = os.path.join(temp_dir, f"medical-model-client-{CLIENT_ID}")
        _write_artifacts(target)
        log_status("MODEL SAVED", f"Model saved to {target}")
    except Exception as err:
        log_status("SAVE ERROR", str(err))
        backup_dir = os.path.join(
            os.path.expanduser("~"), f"medical_model_backup_client_{CLIENT_ID}"
        )
        _write_artifacts(backup_dir)
        log_status("FALLBACK SAVE", f"Model saved to fallback location: {backup_dir}")

In [None]:
# Calculate similarity between responses
def calculate_similarity(str1, str2):
    """Jaccard similarity of the case-folded, whitespace-split word sets of two strings.

    Returns a float in [0, 1], or the int 0 when both strings contain no words.
    """
    tokens_a = set(str1.lower().split())
    tokens_b = set(str2.lower().split())
    union_size = len(tokens_a | tokens_b)
    if not union_size:
        return 0
    return len(tokens_a & tokens_b) / union_size

# Baseline: answer the probe questions with the pretrained model before any federated training.
log_status("PRE-TRAINING EVALUATION", "Testing model before training...")
before_responses = evaluate_model_responses(phase="Before")

In [None]:
# Start Flower client, retrying a few times if the server is not reachable.
SERVER_ADDRESS = "127.0.0.1:8081"

log_status("CONNECTION SETUP", 
         f"Starting Flower client {CLIENT_ID}\n"
         f"Server address: {SERVER_ADDRESS}")

connection_attempts = 0
max_attempts = 3
retry_delay = 5
# True only once start_numpy_client returns normally (all rounds completed).
connection_succeeded = False

while connection_attempts < max_attempts:
    connection_attempts += 1
    try:
        log_status("CONNECTION ATTEMPT", f"Attempt {connection_attempts} of {max_attempts}")

        # NOTE(review): start_numpy_client is deprecated in newer flwr releases in
        # favour of start_client(client=...to_client()) — check the installed version.
        fl.client.start_numpy_client(
            server_address=SERVER_ADDRESS,
            client=MedicalFlashcardsClient(),
            transport="grpc-bidi",
            grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH
        )

        connection_succeeded = True
        log_status("TRAINING SUCCESS", "Client completed all training rounds")
        break

    except ConnectionRefusedError:
        log_status("CONNECTION REFUSED", 
                  f"Server not available (attempt {connection_attempts})")

    except Exception as e:
        log_status("CONNECTION ERROR", 
                  f"Error on attempt {connection_attempts}: {str(e)}")

    # Only failed attempts reach this point; success breaks out above.
    if connection_attempts < max_attempts:
        logger.info(f"Retrying in {retry_delay} seconds...")
        time.sleep(retry_delay)

if not connection_succeeded:
    log_status("CONNECTION FAILED",
               f"Client {CLIENT_ID} gave up after {connection_attempts} attempts")

In [None]:
# Test after training
log_status("POST-TRAINING EVALUATION", "Testing model after training...")
after_responses = evaluate_model_responses("After")

# Use an explicit success flag when one was recorded for the connection loop.
# The attempt-count fallback is the only other signal available, and it reports
# a success on the final attempt as "failed".
run_succeeded = globals().get("connection_succeeded", connection_attempts < max_attempts)
if not run_succeeded:
    logger.warning("Federated training did not complete; 'after' responses may "
                   "not reflect a trained model.")

# Compare responses
comparison_data = {
    "client_id": CLIENT_ID,
    "timestamp": datetime.now().isoformat(),
    "connection_info": {
        "attempts": connection_attempts,
        "max_attempts": max_attempts,
        "status": "success" if run_succeeded else "failed"
    },
    "comparisons": []
}

log_status("RESULTS COMPARISON", "Analyzing before/after performance")
before_similarities = []
after_similarities = []

for q, a in zip(test_questions, test_answers):
    before_text = before_responses[q]['model_response']
    after_text = after_responses[q]['model_response']
    before_sim = calculate_similarity(a, before_text)
    after_sim = calculate_similarity(a, after_text)
    before_similarities.append(before_sim)
    after_similarities.append(after_sim)
    # Change in Jaccard similarity (shown below as percentage points).
    improvement = after_sim - before_sim

    comparison_data["comparisons"].append({
        "question": q,
        "ground_truth": a,
        "before": before_text,
        "after": after_text,
        "similarity_before": before_sim,
        "similarity_after": after_sim,
        "improvement": improvement
    })

    print(f"\nQuestion: {q}")
    print(f"Ground Truth: {a}")
    print(f"Before: {before_text}")
    print(f"After: {after_text}")
    print(f"Improvement: {improvement * 100:.1f}%")

In [None]:
# Calculate overall metrics: per-question similarity deltas (after - before),
# computed once and summarised.
similarity_deltas = np.asarray(after_similarities) - np.asarray(before_similarities)
avg_improvement = np.mean(similarity_deltas)
comparison_data["metrics"] = {
    "average_similarity_before": float(np.mean(before_similarities)),
    "average_similarity_after": float(np.mean(after_similarities)),
    "average_improvement": float(avg_improvement),
    "max_improvement": float(similarity_deltas.max()),
    "min_improvement": float(similarity_deltas.min()),
}

In [None]:
# Save results as pretty-printed JSON in the scratch directory.
results_path = os.path.join(temp_dir, "client_{}_results.json".format(CLIENT_ID))
with open(results_path, "w") as results_file:
    json.dump(comparison_data, results_file, indent=2)

parameter_count = sum(p.numel() for p in model.parameters())
summary_lines = [
    f"Client {CLIENT_ID} Training Summary",
    f"Total examples: {len(small_dataset)}",
    f"Device: {device}",
    f"Model parameters: {parameter_count}",
    f"Average improvement: {avg_improvement * 100:.1f}%",
    f"Results saved: {results_path}",
]
log_status("FINAL STATISTICS", "\n".join(summary_lines))

# Save final model
save_model()