In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


warnings.filterwarnings('ignore')

In [4]:
# Mount Google Drive at /content/drive so the '/content/drive/My Drive/...'
# paths used later in this notebook (input CSV and augmented output) resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [11]:
import pandas as pd
import numpy as np
import random
from collections import Counter
from typing import List, Dict, Tuple
import re
from sklearn.utils import resample
import nltk
from nltk.corpus import wordnet
from transformers import pipeline, AutoTokenizer, AutoModel
import torch
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')

# Download required NLTK data. Best-effort: the augmenter's WordNet lookup is
# wrapped in try/except, so it still returns domain synonyms without WordNet.
NLTK_PACKAGES = ['wordnet', 'omw-1.4', 'punkt', 'averaged_perceptron_tagger']
try:
    # nltk.download can report failure by returning False instead of raising
    # (and prints nothing with quiet=True), so check each return value.
    failed_nltk_packages = [pkg for pkg in NLTK_PACKAGES if not nltk.download(pkg, quiet=True)]
except Exception as exc:  # not a bare except: let KeyboardInterrupt through
    print(f"NLTK download raised {exc!r}")
    failed_nltk_packages = list(NLTK_PACKAGES)

if failed_nltk_packages:
    print(f"NLTK data download failed for {failed_nltk_packages} - some features may not work")

class MentalHealthTextAugmenter:
    """EDA-style text augmenter for rebalancing a mental-health text dataset.

    Each class in a DataFrame with ``text`` and ``mental_state`` columns is
    topped up to ``target_samples_per_class`` rows by generating perturbed
    copies of that class's own texts (synonym replacement/insertion, word
    swaps, deletions, sentence-level rotation and length normalization).
    All random choices go through the ``random`` module and synonym lists are
    built in a fixed order, so seeding ``random`` makes the output repeatable.
    """

    def __init__(self, target_samples_per_class=15000, min_words=500, max_words=15000):
        """
        Advanced text augmentation for mental health classification

        Args:
            target_samples_per_class: Target number of samples per class after augmentation
            min_words: Lower bound for the random target word count of an
                augmented text. Ignored for texts where it exceeds 1.5x the
                source length (the source length is used instead).
            max_words: Upper bound for the random target word count.
        """
        self.target_samples_per_class = target_samples_per_class
        self.min_words = min_words
        self.max_words = max_words

        # Optional heavy models, loaded lazily by _init_models().
        self.paraphrase_model = None
        self.sentence_model = None

        # Mental health domain-specific synonyms and patterns
        self.domain_synonyms = {
            'anxiety': ['worry', 'nervousness', 'unease', 'apprehension', 'concern', 'fear'],
            'depression': ['sadness', 'melancholy', 'despair', 'gloom', 'dejection'],
            'stress': ['pressure', 'strain', 'tension', 'burden', 'overwhelm'],
            'lonely': ['isolated', 'alone', 'solitary', 'disconnected', 'friendless'],
            'suicidal': ['hopeless', 'desperate', 'helpless', 'lost', 'trapped'],
            'bipolar': ['mood swings', 'emotional instability', 'ups and downs'],
            'ptsd': ['trauma', 'flashbacks', 'nightmares', 'triggers'],
            'feel': ['experience', 'sense', 'perceive', 'undergo', 'encounter'],
            'think': ['believe', 'consider', 'contemplate', 'reflect', 'ponder'],
            'help': ['support', 'assistance', 'aid', 'guidance', 'relief']
        }

    def _init_models(self):
        """Load the optional paraphrase and sentence-embedding models once.

        On a load failure the attribute is set to the sentinel string
        ``"failed"`` so loading is not retried.
        NOTE(review): no method in this class currently calls this.
        """
        if self.paraphrase_model is None:
            try:
                # Use a lightweight paraphrasing approach
                self.paraphrase_model = pipeline(
                    "text2text-generation",
                    model="t5-small",
                    device=0 if torch.cuda.is_available() else -1
                )
            except Exception:
                print("Warning: Paraphrase model failed to load, using synonym replacement only")
                self.paraphrase_model = "failed"

        if self.sentence_model is None:
            try:
                self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception:
                print("Warning: Sentence model failed to load")
                self.sentence_model = "failed"

    def get_synonyms(self, word: str, pos: str = None) -> List[str]:
        """Return up to three synonyms for ``word``.

        Domain-specific synonyms come first, followed by WordNet lemmas. The
        order is deterministic (first-seen order, not ``set`` order), so the
        result is stable across interpreter runs. ``word`` itself is never
        returned.

        Args:
            word: Word to look up (compared case-insensitively).
            pos: Currently unused; kept for interface compatibility.

        Returns:
            At most three distinct synonyms. Multi-word synonyms keep their
            spaces.
        """
        word_lower = word.lower()
        # dict used as an insertion-ordered set
        synonyms: Dict[str, None] = {}

        def _add(candidate: str) -> None:
            # Skip the word itself (domain lists contain it when word is a value).
            if candidate.lower() != word_lower:
                synonyms.setdefault(candidate, None)

        # Check domain-specific synonyms first
        for key, values in self.domain_synonyms.items():
            if word_lower == key or word_lower in values:
                for value in values:
                    _add(value)

        # Use WordNet for additional synonyms (best-effort: e.g. corpus data
        # may not be downloaded)
        try:
            for syn in wordnet.synsets(word):
                for lemma in syn.lemmas():
                    _add(lemma.name().replace('_', ' '))
        except Exception:
            pass

        return list(synonyms)[:3]  # Limit to the first 3 synonyms

    def synonym_replacement(self, text: str, n: int = None) -> str:
        """Replace up to ``n`` distinct words (every occurrence) with a synonym.

        Args:
            text: Input text, split on whitespace.
            n: Number of distinct words to replace; defaults to ~5% of the word
                count (at least 1).

        Returns:
            The modified text, re-joined with single spaces.
        """
        words = text.split()
        if n is None:
            n = max(1, len(words) // 20)  # Replace ~5% of words

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # shuffle is reproducible under a fixed seed (set order is not).
        candidates = list(dict.fromkeys(w for w in words if w.isalpha()))
        random.shuffle(candidates)

        new_words = words
        num_replaced = 0
        for candidate in candidates:
            if num_replaced >= n:
                break  # done: skip the remaining (costly) synonym lookups
            synonyms = self.get_synonyms(candidate)
            if synonyms:
                replacement = random.choice(synonyms)
                new_words = [replacement if w == candidate else w for w in new_words]
                num_replaced += 1

        return ' '.join(new_words)

    def random_insertion(self, text: str, n: int = None) -> str:
        """Insert ``n`` synonyms of randomly chosen words at random positions.

        Args:
            text: Input text, split on whitespace.
            n: Number of insertion attempts; defaults to ~4% of the word count
                (at least 1).

        Returns:
            The modified text, re-joined with single spaces. The text is
            unchanged, apart from whitespace normalization, when it contains no
            purely alphabetic words.
        """
        words = text.split()
        if n is None:
            n = max(1, len(words) // 25)  # Insert ~4% new words

        for _ in range(n):
            # Recomputed each pass: inserted synonyms may themselves be chosen.
            candidates = [w for w in words if w.isalpha()]
            if not candidates:
                break  # e.g. empty or purely numeric text: nothing to draw from
            synonyms = self.get_synonyms(random.choice(candidates))
            if synonyms:
                random_synonym = random.choice(synonyms)
                random_idx = random.randint(0, len(words))  # inclusive: may append
                words.insert(random_idx, random_synonym)

        return ' '.join(words)

    def random_swap(self, text: str, n: int = None) -> str:
        """Swap two random word positions ``n`` times.

        Texts with fewer than 4 words are returned unchanged.
        ``n`` defaults to ~3% of the word count (at least 1).
        """
        words = text.split()
        if len(words) < 4:
            return text

        if n is None:
            n = max(1, len(words) // 30)  # Swap ~3% of words

        for _ in range(n):
            idx1, idx2 = random.sample(range(len(words)), 2)
            words[idx1], words[idx2] = words[idx2], words[idx1]

        return ' '.join(words)

    def random_deletion(self, text: str, p: float = 0.05) -> str:
        """Drop each word independently with probability ``p``.

        Single-word texts, and results where every word was dropped, return
        the original ``text``.
        """
        words = text.split()
        if len(words) == 1:
            return text

        new_words = [word for word in words if random.uniform(0, 1) > p]

        # If all words were deleted, return original
        if not new_words:
            return text

        return ' '.join(new_words)

    def sentence_level_augmentation(self, text: str) -> str:
        """Randomly rotate the two halves of some sentences.

        Sentences are split on runs of ``.``, ``!`` and ``?`` and re-joined
        with ``'. '``, so original terminators are normalized to periods and
        the final terminator is dropped. Each sentence with more than 5 words
        is rotated with probability 0.3 * 0.5.
        """
        sentences = re.split(r'[.!?]+', text)
        augmented_sentences = []

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            if random.random() < 0.3:  # 30% chance
                words = sentence.split()
                if len(words) > 5:
                    # Move the second half of the words in front of the first
                    mid_point = len(words) // 2
                    if random.random() < 0.5:
                        sentence = ' '.join(words[mid_point:] + words[:mid_point])

            augmented_sentences.append(sentence)

        return '. '.join(augmented_sentences)

    def length_normalization(self, text: str, target_length: int) -> str:
        """Push the word count of ``text`` toward ``target_length``.

        - Under 80% of the target and more than 1.5x short: adds elaboration
          sentences built from existing ones.
        - Over 120% of the target: keeps the first and last sentences plus a
          random, order-preserving subset of the middle sentences.
        - Otherwise (or for empty text) the text is returned unchanged.
        """
        words = text.split()
        current_length = len(words)
        if current_length == 0:
            return text  # nothing to expand/contract; avoids division by zero

        if current_length < target_length * 0.8:
            # Expand text
            expansion_factor = target_length / current_length
            if expansion_factor > 1.5:
                # Use repetition and elaboration
                elaboration_phrases = [
                    "I really feel that", "It's important to note that", "What I mean is",
                    "To elaborate further", "Additionally", "Furthermore", "In other words",
                    "This makes me think", "I can't help but feel", "It seems to me that"
                ]

                sentences = re.split(r'[.!?]+', text)
                expanded_sentences = []

                for sentence in sentences:
                    if sentence.strip():
                        expanded_sentences.append(sentence.strip())
                        if random.random() < 0.4 and len(expanded_sentences) < target_length // 20:
                            phrase = random.choice(elaboration_phrases)
                            elaboration = f"{phrase} {sentence.strip().lower()}"
                            expanded_sentences.append(elaboration)

                text = '. '.join(expanded_sentences)

        elif current_length > target_length * 1.2:
            # Contract text by removing middle sentences. Empty splits (e.g. the
            # one after a trailing period) are dropped first so the "last
            # sentence" kept below is a real sentence.
            sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]
            if len(sentences) > 3:
                keep_ratio = target_length / current_length
                middle_sentences = sentences[1:-1]
                keep_count = min(max(1, int(len(middle_sentences) * keep_ratio)),
                                 len(middle_sentences))
                # Sample indices and sort them so kept sentences stay in order.
                kept_idx = sorted(random.sample(range(len(middle_sentences)), keep_count))
                kept_middle = [middle_sentences[i] for i in kept_idx]
                text = '. '.join([sentences[0]] + kept_middle + [sentences[-1]])

        return text

    def augment_text(self, text: str, augmentation_type: str = 'mixed') -> str:
        """Produce one augmented variant of ``text``.

        Args:
            text: Source text.
            augmentation_type: ``'eda'`` applies 1-2 randomly chosen EDA
                operations, ``'sentence'`` applies sentence-level rotation and
                ``'mixed'`` applies both. Any other value applies only length
                normalization.

        Returns:
            The augmented text, length-normalized toward a random target of
            0.7x-1.5x the source word count, clamped to
            ``[min_words, max_words]``. When those bounds conflict, the source
            word count is the target.
        """
        original_length = len(text.split())

        # Calculate bounds
        min_len = max(self.min_words, int(original_length * 0.7))
        max_len = min(self.max_words, int(original_length * 1.5))

        # Safe-guard: short texts make the range empty (min_len > max_len);
        # fall back to the source length instead of letting randint raise.
        if min_len > max_len:
            target_length = original_length
        else:
            target_length = random.randint(min_len, max_len)

        augmented = text

        if augmentation_type in ['mixed', 'eda']:
            # Apply EDA techniques
            techniques = [
                self.synonym_replacement,
                self.random_insertion,
                self.random_swap,
                self.random_deletion
            ]
            num_techniques = random.randint(1, 2)
            selected_techniques = random.sample(techniques, num_techniques)

            for technique in selected_techniques:
                # `==`, not `is`: each attribute access creates a new bound method.
                if technique == self.random_deletion:
                    augmented = technique(augmented, p=0.03)  # Lower deletion rate
                else:
                    augmented = technique(augmented)

        if augmentation_type in ['mixed', 'sentence']:
            # Apply sentence-level transformations
            augmented = self.sentence_level_augmentation(augmented)

        # Normalize to target length
        return self.length_normalization(augmented, target_length)

    def calculate_augmentation_needs(self, df: pd.DataFrame) -> Dict[str, int]:
        """Map each ``mental_state`` label to the number of rows it is short of
        ``target_samples_per_class``, or 0 if it already has enough.
        """
        class_counts = df['mental_state'].value_counts()
        return {
            class_name: max(0, self.target_samples_per_class - current_count)
            for class_name, current_count in class_counts.items()
        }

    def augment_class(self, class_df: pd.DataFrame, needed_samples: int,
                      class_name: str) -> pd.DataFrame:
        """Generate ``needed_samples`` augmented rows for one class.

        A first pass gives each source text an equal share of augmentations and
        keeps only variants that add more than 5 new distinct words. Any
        shortfall is then filled from random source texts without that
        quality check.

        Args:
            class_df: Rows of a single class; must have a ``text`` column.
            needed_samples: Number of new rows to create.
            class_name: Label written into ``mental_state`` of new rows.

        Returns:
            A DataFrame of exactly ``needed_samples`` new rows with columns
            ``text``, ``mental_state`` and ``augmented`` (True). If
            ``needed_samples <= 0``, ``class_df`` itself is returned.

        Raises:
            ValueError: if ``class_df`` is empty but samples are needed.
        """
        if needed_samples <= 0:
            return class_df

        original_texts = class_df['text'].tolist()
        if not original_texts:
            raise ValueError(f"Cannot augment class {class_name!r}: it has no source texts")

        print(f"Augmenting {class_name}: {len(class_df)} -> {len(class_df) + needed_samples}")

        augmented_data = []

        # Create multiple augmented versions of each text
        samples_per_original = max(1, needed_samples // len(original_texts))

        for text in original_texts:
            for _ in range(samples_per_original):
                if len(augmented_data) >= needed_samples:
                    break

                # Use different augmentation strategies
                strategy = random.choice(['mixed', 'eda', 'sentence'])
                augmented_text = self.augment_text(text, strategy)

                # Quality check - ensure minimum difference from original
                if len(set(augmented_text.split()) - set(text.split())) > 5:
                    augmented_data.append({
                        'text': augmented_text,
                        'mental_state': class_name,
                        'augmented': True
                    })

            if len(augmented_data) >= needed_samples:
                break

        # Fill any shortfall (no quality check here, so this always terminates)
        while len(augmented_data) < needed_samples:
            text = random.choice(original_texts)
            augmented_text = self.augment_text(text, 'mixed')
            augmented_data.append({
                'text': augmented_text,
                'mental_state': class_name,
                'augmented': True
            })

        return pd.DataFrame(augmented_data[:needed_samples])

    def augment_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
        """Top every class of ``df`` up to ``target_samples_per_class`` rows.

        ``df`` is not modified. The result contains all original rows with
        ``augmented=False``, followed by the generated rows
        (``augmented=True``), with a fresh RangeIndex. Generated rows only set
        ``text``, ``mental_state`` and ``augmented``, so any other column of
        ``df`` is NaN for them. Progress and the final distribution are
        printed.
        """
        print("=== Mental Health Text Augmentation ===")
        print(f"Original dataset shape: {df.shape}")
        print(f"Target samples per class: {self.target_samples_per_class}")

        # Add augmented flag to original data (on a copy; caller's df untouched)
        df = df.copy()
        df['augmented'] = False

        # Calculate augmentation needs
        augmentation_needs = self.calculate_augmentation_needs(df)

        print("\nAugmentation needs:")
        for class_name, needed in augmentation_needs.items():
            current = len(df[df['mental_state'] == class_name])
            print(f"  {class_name}: {current} -> {current + needed} (+{needed})")

        # Augment each class
        augmented_dfs = [df]

        for class_name, needed_samples in augmentation_needs.items():
            if needed_samples > 0:
                class_df = df[df['mental_state'] == class_name]
                augmented_class_df = self.augment_class(class_df, needed_samples, class_name)
                augmented_dfs.append(augmented_class_df)

        # Combine all data
        final_df = pd.concat(augmented_dfs, ignore_index=True)

        print(f"\nFinal dataset shape: {final_df.shape}")
        print(f"Augmented samples: {final_df['augmented'].sum()}")

        # Show final distribution
        print("\nFinal class distribution:")
        final_counts = final_df['mental_state'].value_counts()
        for class_name, count in final_counts.items():
            percentage = (count / len(final_df)) * 100
            print(f"  {class_name}: {count} ({percentage:.1f}%)")

        return final_df

# Usage example and utility functions
def load_and_preprocess_data(file_path: str) -> pd.DataFrame:
    """Read the dataset CSV and keep usable rows.

    Drops rows missing ``text`` or ``mental_state``, casts both columns to
    ``str``, and removes texts with fewer than 50 whitespace-separated words.

    Returns:
        A new DataFrame with only the ``text`` and ``mental_state`` columns
        (original row index preserved).
    """
    cleaned = (
        pd.read_csv(file_path)
        .dropna(subset=['text', 'mental_state'])
        .astype({'text': str, 'mental_state': str})
    )

    # Whitespace word count, same as len(text.split()).
    word_counts = cleaned['text'].str.split().str.len()
    long_enough = word_counts >= 50

    return cleaned.loc[long_enough, ['text', 'mental_state']]

def analyze_text_lengths(df: pd.DataFrame):
    """Print per-class word-count statistics (count, mean, median, range).

    Classes are reported in order of first appearance in ``df``.

    Args:
        df: Frame with ``text`` (str) and ``mental_state`` columns. It is not
            modified: word counts live in a local Series, so no helper column
            leaks into the caller's data (e.g. into a CSV saved later).
    """
    print("=== Text Length Analysis ===")
    word_counts = df['text'].apply(lambda x: len(x.split()))

    for class_name in df['mental_state'].unique():
        # Positional boolean mask: robust even if df has duplicate index labels.
        mask = (df['mental_state'] == class_name).to_numpy()
        class_texts = word_counts[mask]
        print(f"\n{class_name}:")
        print(f"  Count: {len(class_texts)}")
        print(f"  Mean length: {class_texts.mean():.0f} words")
        print(f"  Median length: {class_texts.median():.0f} words")
        print(f"  Length range: {class_texts.min()}-{class_texts.max()} words")

# Main execution example
if __name__ == "__main__":
    # Seed `random`, which drives every augmentation decision in
    # MentalHealthTextAugmenter.
    SEED = 42
    random.seed(SEED)

    INPUT_PATH = '/content/drive/My Drive/combined_data.csv'
    OUTPUT_PATH = "/content/drive/My Drive/augmented_mental_health_data.csv"

    # Load your data
    print("Loading data...")
    df = load_and_preprocess_data(INPUT_PATH)

    # Analyze original distribution
    analyze_text_lengths(df)

    # Initialize augmenter with balanced targets.
    # NOTE: augment_text falls back to a text's own length whenever 1.5x that
    # length is below min_words. The length analysis shows most texts are far
    # shorter than ~667 words, so min_words=1000 rarely takes effect.
    augmenter = MentalHealthTextAugmenter(
        target_samples_per_class=12000,  # Adjust based on your needs
        min_words=1000,
        max_words=12000
    )

    # Augment the dataset
    augmented_df = augmenter.augment_dataset(df)

    # Drop the helper column analyze_text_lengths may have added; augmented
    # rows would otherwise carry NaN in it and the stale counts would be saved.
    augmented_df = augmented_df.drop(columns=['word_count'], errors='ignore')

    # Save the augmented dataset
    augmented_df.to_csv(OUTPUT_PATH, index=False)
    print("\nAugmented dataset saved to 'augmented_mental_health_data.csv'")

    # Analyze final distribution
    print("\n" + "="*50)
    analyze_text_lengths(augmented_df)

Loading data...
=== Text Length Analysis ===

anxiety:
  Count: 14481
  Mean length: 197 words
  Median length: 151 words
  Length range: 50-3009 words

normal:
  Count: 1314
  Mean length: 85 words
  Median length: 79 words
  Length range: 50-255 words

depression:
  Count: 40532
  Mean length: 214 words
  Median length: 156 words
  Length range: 50-4830 words

suicidal:
  Count: 19335
  Mean length: 207 words
  Median length: 148 words
  Length range: 50-5248 words

stress:
  Count: 2219
  Mean length: 128 words
  Median length: 95 words
  Length range: 50-1606 words

bipolar:
  Count: 6800
  Mean length: 207 words
  Median length: 154 words
  Length range: 50-4804 words

personality disorder:
  Count: 876
  Mean length: 214 words
  Median length: 168 words
  Length range: 50-5419 words

lonely:
  Count: 2158
  Mean length: 187 words
  Median length: 140 words
  Length range: 50-2078 words

ptsd:
  Count: 1249
  Mean length: 233 words
  Median length: 174 words
  Length range: 50-450