In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import pandas as pd
import random
import numpy as np
from tqdm import tqdm
import re

class LongTextParaphraser:
    def __init__(self, model_name="Vamsi/T5_Paraphrase_Paws", device=None):
        print(f"🚀 Loading model: {model_name}")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.model.eval()
        print(f"✅ Model loaded on {self.device.upper()}")

    def smart_chunk_text(self, text, max_chunk_tokens=450, overlap_sentences=1):
        """
        Smart chunking that preserves semantic coherence and flow
        - Uses paragraph boundaries as primary split points
        - Maintains topic coherence within chunks
        - Adds overlap between chunks to preserve context
        - Falls back to sentence boundaries only when necessary
        """
        # First, try to split by paragraphs (double newlines or clear paragraph markers)
        paragraphs = re.split(r'\n\s*\n|\r\n\s*\r\n', text.strip())
        paragraphs = [p.strip() for p in paragraphs if p.strip()]

        if len(paragraphs) <= 1:
            # No clear paragraphs, use sentence-based chunking with overlap
            return self._sentence_chunk_with_overlap(text, max_chunk_tokens, overlap_sentences)

        chunks = []
        current_chunk = []
        current_tokens = 0

        for paragraph in paragraphs:
            para_tokens = len(self.tokenizer.tokenize(paragraph))

            # If single paragraph is too long, split it intelligently
            if para_tokens > max_chunk_tokens:
                if current_chunk:
                    chunks.append(self._join_with_context(current_chunk))
                    current_chunk = []
                    current_tokens = 0

                # Split long paragraph while preserving semantic units
                para_chunks = self._split_long_paragraph(paragraph, max_chunk_tokens, overlap_sentences)
                chunks.extend(para_chunks)
                continue

            # Check if adding this paragraph exceeds limit
            if current_tokens + para_tokens > max_chunk_tokens and current_chunk:
                chunks.append(self._join_with_context(current_chunk))
                current_chunk = [paragraph]
                current_tokens = para_tokens
            else:
                current_chunk.append(paragraph)
                current_tokens += para_tokens

        if current_chunk:
            chunks.append(self._join_with_context(current_chunk))

        return chunks

    def _split_long_paragraph(self, paragraph, max_tokens, overlap_sentences):
        """Split a long paragraph while maintaining semantic coherence"""
        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
        chunks = []
        current_chunk = []
        current_tokens = 0

        for i, sentence in enumerate(sentences):
            sentence_tokens = len(self.tokenizer.tokenize(sentence))

            if current_tokens + sentence_tokens > max_tokens and current_chunk:
                # Add overlap from previous chunk
                overlap_start = max(0, len(current_chunk) - overlap_sentences)
                chunk_text = ' '.join(current_chunk)
                chunks.append(chunk_text)

                # Start new chunk with overlap
                overlap_sentences_list = current_chunk[overlap_start:] if overlap_sentences > 0 else []
                current_chunk = overlap_sentences_list + [sentence]
                current_tokens = sum(len(self.tokenizer.tokenize(s)) for s in current_chunk)
            else:
                current_chunk.append(sentence)
                current_tokens += sentence_tokens

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    def _sentence_chunk_with_overlap(self, text, max_tokens, overlap_sentences):
        """Fallback sentence-based chunking with overlap for coherence"""
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current_chunk = []
        current_tokens = 0

        for sentence in sentences:
            sentence_tokens = len(self.tokenizer.tokenize(sentence))

            if current_tokens + sentence_tokens > max_tokens and current_chunk:
                chunk_text = ' '.join(current_chunk)
                chunks.append(chunk_text)

                # Add overlap for context continuity
                overlap_start = max(0, len(current_chunk) - overlap_sentences)
                overlap_sentences_list = current_chunk[overlap_start:] if overlap_sentences > 0 else []
                current_chunk = overlap_sentences_list + [sentence]
                current_tokens = sum(len(self.tokenizer.tokenize(s)) for s in current_chunk)
            else:
                current_chunk.append(sentence)
                current_tokens += sentence_tokens

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    def _join_with_context(self, paragraphs):
        """Join paragraphs while preserving formatting context"""
        return '\n\n'.join(paragraphs)

    def paraphrase(self, text, num_return_sequences=1, max_length=512, preserve_structure=True):
        """
        Enhanced paraphrase method with coherence preservation
        - preserve_structure: maintains paragraph breaks and overall flow
        """
        if not text.strip():
            return [text] * num_return_sequences

        token_count = len(self.tokenizer.tokenize(text))

        # For short texts, paraphrase directly
        if token_count <= 450:
            return self._paraphrase_chunk(text, num_return_sequences, max_length)

        # For long texts, use smart chunking
        chunks = self.smart_chunk_text(text, max_chunk_tokens=450, overlap_sentences=2)

        if len(chunks) == 1:
            # Text fits in one chunk after smart processing
            return self._paraphrase_chunk(chunks[0], num_return_sequences, max_length)

        # Process multiple chunks with coherence preservation
        return self._paraphrase_long_text(chunks, num_return_sequences, max_length, preserve_structure)

    def _paraphrase_long_text(self, chunks, num_return_sequences, max_length, preserve_structure):
        """Paraphrase long text while maintaining coherence across chunks"""
        all_results = []

        for seq_idx in range(num_return_sequences):
            paraphrased_chunks = []
            previous_context = ""

            for chunk_idx, chunk in enumerate(chunks):
                # Add context from previous chunk for better coherence
                if chunk_idx > 0 and previous_context:
                    # Use last sentence of previous chunk as context
                    context_prompt = f"paraphrase: {previous_context} {chunk} </s>"
                else:
                    context_prompt = f"paraphrase: {chunk} </s>"

                chunk_result = self._paraphrase_chunk_with_prompt(
                    context_prompt, 1, max_length
                )[0]

                # Remove potential repetition from context
                if chunk_idx > 0 and previous_context:
                    chunk_result = self._remove_context_overlap(chunk_result, previous_context)

                paraphrased_chunks.append(chunk_result)

                # Store context for next chunk (last sentence)
                sentences = re.split(r'(?<=[.!?])\s+', chunk_result.strip())
                previous_context = sentences[-1] if sentences else ""

            # Join chunks appropriately
            if preserve_structure:
                final_text = self._smart_join_chunks(paraphrased_chunks, chunks)
            else:
                final_text = ' '.join(paraphrased_chunks)

            all_results.append(final_text)

        return all_results

    def _paraphrase_chunk_with_prompt(self, prompt, num_return_sequences, max_length):
        """Paraphrase with custom prompt (for context-aware processing)"""
        encoding = self.tokenizer.encode_plus(
            prompt,
            padding="longest",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                max_length=max_length,
                num_beams=5,
                num_return_sequences=num_return_sequences,
                do_sample=True,
                top_k=120,
                top_p=0.95,
                early_stopping=True,
                temperature=0.7,
                repetition_penalty=1.1  # Reduce repetition
            )

        return [
            self.tokenizer.decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for out in outputs
        ]

    def _paraphrase_chunk(self, text, num_return_sequences, max_length):
        """Standard chunk paraphrasing"""
        prompt = f"paraphrase: {text.strip().replace(chr(10), ' ')} </s>"
        return self._paraphrase_chunk_with_prompt(prompt, num_return_sequences, max_length)

    def _remove_context_overlap(self, text, context):
        """Remove potential overlap from context at the beginning of text"""
        context_words = context.lower().split()[-5:]  # Last 5 words of context
        text_words = text.lower().split()

        # Find overlap
        for i in range(min(len(context_words), len(text_words))):
            if context_words[-i-1:] == text_words[:i+1]:
                # Remove overlap
                return ' '.join(text.split()[i+1:])

        return text

    def _smart_join_chunks(self, paraphrased_chunks, original_chunks):
        """Join chunks while preserving original structure cues"""
        result = []

        for i, (para_chunk, orig_chunk) in enumerate(zip(paraphrased_chunks, original_chunks)):
            # Preserve paragraph breaks from original
            if '\n\n' in orig_chunk:
                # Split and rejoin with paragraph breaks
                para_parts = para_chunk.split('. ')
                if len(para_parts) > 1:
                    mid_point = len(para_parts) // 2
                    rejoined = '. '.join(para_parts[:mid_point]) + '.\n\n' + '. '.join(para_parts[mid_point:])
                    result.append(rejoined)
                else:
                    result.append(para_chunk)
            else:
                result.append(para_chunk)

        return ' '.join(result)

# Enhanced augmentation function
def augment_minority_class(df, label_column='mental_state', text_column='text',
                          target_per_class=10000, min_samples=500):
    paraphraser = LongTextParaphraser()
    augmented_data = []

    class_counts = df[label_column].value_counts()
    minority_classes = class_counts[class_counts < target_per_class].index.tolist()

    print(f"🔍 Minority classes: {', '.join(minority_classes)}")

    for label in minority_classes:
        class_df = df[df[label_column] == label]
        current_count = len(class_df)
        needed = max(target_per_class - current_count, min_samples)

        print(f"\n⚙️ Augmenting '{label}' ({current_count} → {current_count + needed} samples)")
        print(f"📝 Generating {needed} augmented samples with coherence preservation...")

        # Create augmentation plan
        augmentation_plan = []
        base_samples = class_df[text_column].tolist()

        # Calculate how many augmentations per original sample
        per_sample = max(1, needed // current_count)
        remainder = needed % current_count

        for i, text in enumerate(base_samples):
            count = per_sample + (1 if i < remainder else 0)
            augmentation_plan.append((text, count))

        # Process with progress bar
        for text, count in tqdm(augmentation_plan, desc=f"Augmenting {label}"):
            try:
                # Use enhanced paraphrasing with structure preservation
                paraphrases = paraphraser.paraphrase(
                    text,
                    num_return_sequences=count,
                    max_length=512,
                    preserve_structure=True
                )

                for paraphrase in paraphrases:
                    augmented_data.append({
                        text_column: paraphrase,
                        label_column: label,
                        'source': 'augmented_coherent'
                    })
            except Exception as e:
                print(f"❌ Error augmenting text: {str(e)[:100]}...")
                # Fallback to original text if needed
                for _ in range(count):
                    augmented_data.append({
                        text_column: text,
                        label_column: label,
                        'source': 'original_fallback'
                    })

    # Create augmented DataFrame
    augmented_df = pd.DataFrame(augmented_data)

    # Combine with original data
    original_df = df.copy()
    original_df['source'] = 'original'
    combined_df = pd.concat([original_df, augmented_df], ignore_index=True)

    print(f"\n✅ Coherent augmentation complete! Final counts:")
    print(combined_df[label_column].value_counts())

    return combined_df

# Main execution
if __name__ == "__main__":
    print("📄 Loading data...")
    data = pd.read_csv("/content/drive/My Drive/combined_data.csv")
    df = data.dropna().reset_index(drop=True)
    df = df[["text", "mental_state"]]

    print("\n🧠 Starting coherent augmentation for minority classes...")
    df_augmented = augment_minority_class(
        df,
        target_per_class=10000,
        min_samples=1000
    )

    output_path = "/content/drive/My Drive/augmented_data_coherent_t5_yy.csv"
    df_augmented.to_csv(output_path, index=False)
    print(f"\n💾 Saved coherently augmented dataset to: {output_path}")

📄 Loading data...

🧠 Starting coherent augmentation for minority classes...
🚀 Loading model: Vamsi/T5_Paraphrase_Paws


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

✅ Model loaded on CUDA
🔍 Minority classes: bipolar, lonely, stress, ptsd, personality disorder

⚙️ Augmenting 'bipolar' (7765 → 10000 samples)
📝 Generating 2235 augmented samples with coherence preservation...


Augmenting bipolar:   0%|          | 32/7765 [01:33<4:36:22,  2.14s/it]