In [17]:
import logging
import math

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, get_scheduler

from modular.data_setup import TripletDataset, mean_pooling, triplet_collate_fn

In [3]:
# The CSV is a single table, so load it as a flat `Dataset` (DatasetDict.from_csv
# is meant for a {split_name: path} mapping). Train/validation/test splits are
# created explicitly in the next cell.
ds = Dataset.from_csv("data/r_depression_posts.csv")
# "Unnamed: 0" is presumably a leftover index column from DataFrame.to_csv — drop it.
ds = ds.remove_columns("Unnamed: 0")
ds

Dataset({
    features: ['text', 'url', 'positive', 'negative'],
    num_rows: 32165
})

In [4]:
# Hold out 3% for test, then take 17% of the ORIGINAL total for validation out of
# the remaining 97%, which leaves ~80% for training.
test_fraction = 0.03
val_fraction = 0.17

first_split = ds.train_test_split(test_size=test_fraction, seed=42)
train_val = first_split["train"]
test = first_split["test"]

# Rescale so validation is 17% of the full dataset, not 17% of train_val.
val_size = val_fraction / (1 - test_fraction)
second_split = train_val.train_test_split(test_size=val_size, seed=42)
train = second_split["train"]
val = second_split["test"]

ds_splits = DatasetDict(
    {
        "train": train,
        "validation": val,
        "test": test,
    }
)
ds_splits

DatasetDict({
    train: Dataset({
        features: ['text', 'url', 'positive', 'negative'],
        num_rows: 25731
    })
    validation: Dataset({
        features: ['text', 'url', 'positive', 'negative'],
        num_rows: 5469
    })
    test: Dataset({
        features: ['text', 'url', 'positive', 'negative'],
        num_rows: 965
    })
})

In [5]:
checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'

# Tokenizer and bare encoder come from the same checkpoint so their vocabularies match.
tokenizer, model = (
    AutoTokenizer.from_pretrained(checkpoint),
    AutoModel.from_pretrained(checkpoint),
)

# Reference embedding recipe (commented out below): tokenize -> encode ->
# mean-pool over the attention mask -> L2-normalise.


In [6]:
# # Tokenize sentences
# encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# # Compute token embeddings
# with torch.no_grad():
#     model_output = model(**encoded_input)

# # Perform pooling
# sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

# # Normalize embeddings
# sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

# print("Sentence embeddings:")
# print(sentence_embeddings)


In [7]:
# Wrap the train and validation splits in the project's TripletDataset; the
# tokenizer is handed over, presumably so tokenization happens per item.
train_dataset, val_dataset = (
    TripletDataset(ds_splits[split_name], tokenizer)
    for split_name in ("train", "validation")
)


In [None]:

# Each training step encodes anchor, positive AND negative, i.e. 3x this many
# sequences per forward/backward. NOTE: the CUDA OOM recorded at the bottom of
# this notebook came from an earlier run with 403 steps/epoch (= ceil(25731 / 64),
# i.e. batch_size=64); lower this further if the OOM recurs.
BATCH_SIZE = 32
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=triplet_collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


In [9]:
def calculate_triplet_accuracy(anchor_emb, positive_emb, negative_emb):
    """Return the fraction of triplets ranked correctly by cosine similarity.

    A triplet counts as correct when the anchor is strictly more similar to its
    positive than to its negative (ties count as wrong).

    Args:
        anchor_emb, positive_emb, negative_emb: embedding tensors compared
            row-wise (similarity is taken along dim 1).

    Returns:
        The accuracy as a Python float.
    """
    similarity_to_positive = F.cosine_similarity(anchor_emb, positive_emb, dim=1)
    similarity_to_negative = F.cosine_similarity(anchor_emb, negative_emb, dim=1)
    ranked_correctly = similarity_to_positive.gt(similarity_to_negative)
    return ranked_correctly.float().mean().item()

In [10]:
# Embeddings are L2-normalised before the loss, so Euclidean distances lie in
# roughly [0, 2]; a margin of 1 is half that range.
triplet_loss = torch.nn.TripletMarginLoss(margin=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Accelerate takes care of device placement (and any distributed wrapping).
accelerator = Accelerator()
prepared_objects = accelerator.prepare(model, optimizer, train_loader, val_loader)
model, optimizer, train_loader, val_loader = prepared_objects

In [11]:
# from huggingface_hub import get_full_repo_name

# model_name = "all-MiniLM-L6-v2-finetuned-on-rDepression"
# repo_name = get_full_repo_name(model_name)
# repo_name

In [12]:
def get_embeddings(model, input_ids, attention_mask):
    """Encode a batch into mean-pooled, L2-normalised sentence embeddings.

    Gradients are tracked only when the model is in training mode AND the
    caller has not disabled autograd (e.g. with ``torch.no_grad()``).

    Args:
        model: encoder called as ``model(input_ids=..., attention_mask=...)``;
            its output is passed to ``mean_pooling``.
        input_ids: token ids for the batch — presumably (batch, seq_len);
            TODO confirm against triplet_collate_fn.
        attention_mask: mask passed to the encoder and to ``mean_pooling``
            (presumably so padding is excluded from the mean).

    Returns:
        Embedding tensor normalised to unit L2 norm along dim 1.
    """
    # Never force-enable autograd: torch.enable_grad() would override an outer
    # torch.no_grad() and silently build (memory-hungry) graphs.
    track_grad = model.training and torch.is_grad_enabled()
    with torch.set_grad_enabled(track_grad):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = mean_pooling(outputs, attention_mask)
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings

In [13]:
def validate_model(model, val_loader, triplet_loss, accelerator):
    """Compute mean triplet loss and triplet accuracy over a validation loader.

    Args:
        model: encoder passed to ``get_embeddings``.
        val_loader: yields dicts with ``anchor_/positive_/negative_`` ×
            ``input_ids/attention_mask`` keys.
        triplet_loss: callable(anchor, positive, negative) -> scalar tensor;
            assumed to use mean reduction, as the ``TripletMarginLoss`` built
            in this notebook does.
        accelerator: used only to show the progress bar on the local main process.

    Returns:
        ``(avg_loss, avg_accuracy)`` as Python floats, averaged per triplet so a
        short final batch is not over-weighted.

    Raises:
        ValueError: if ``val_loader`` yields no batches.

    The model runs in eval mode under ``torch.no_grad()``; its previous
    train/eval mode is restored before returning.

    NOTE(review): metrics are per-process; multi-GPU runs would need
    ``accelerator.gather_for_metrics`` to cover the full validation set.
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    total_correct = 0.0
    num_examples = 0

    try:
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", disable=not accelerator.is_local_main_process):
                anchor_emb = get_embeddings(model, batch['anchor_input_ids'], batch['anchor_attention_mask'])
                positive_emb = get_embeddings(model, batch['positive_input_ids'], batch['positive_attention_mask'])
                negative_emb = get_embeddings(model, batch['negative_input_ids'], batch['negative_attention_mask'])

                loss = triplet_loss(anchor_emb, positive_emb, negative_emb)
                accuracy = calculate_triplet_accuracy(anchor_emb, positive_emb, negative_emb)

                # Both values are batch means; weight them by batch size so every
                # triplet contributes equally to the final average.
                batch_size = anchor_emb.size(0)
                total_loss += loss.item() * batch_size
                total_correct += accuracy * batch_size
                num_examples += batch_size
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if num_examples == 0:
        raise ValueError("validation loader yielded no batches")

    return total_loss / num_examples, total_correct / num_examples

In [None]:
# Training hyper-parameters.
num_epochs = 3
gradient_accumulation_steps = 1  # micro-batches per optimizer update
logging_steps = 50               # log running averages every N micro-batches
eval_steps = 500                 # validate every N micro-batches

# Size the schedule in optimizer updates: with gradient accumulation an epoch has
# ceil(len(train_loader) / gradient_accumulation_steps) updates, not
# len(train_loader). (Same value as before when accumulation is 1.)
# lr_scheduler.step() must be called once after each optimizer.step().
num_update_steps_per_epoch = math.ceil(len(train_loader) / gradient_accumulation_steps)
num_training_steps = num_epochs * num_update_steps_per_epoch

# Cosine decay from the optimizer's initial LR to 0, no warmup.
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Training Loop

In [15]:
# Route INFO-level messages to the root handler. Note: basicConfig does nothing
# if the root logger already has handlers configured.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


In [None]:
# Training loop
model.train()
global_step = 0
best_val_accuracy = 0

for epoch in range(num_epochs):
    logger.info(f"Starting epoch {epoch + 1}/{num_epochs}")
    
    epoch_loss = 0
    epoch_accuracy = 0
    num_batches = 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", disable=not accelerator.is_local_main_process)
    
    for batch in progress_bar:
        anchor_ids = batch['anchor_input_ids']
        anchor_mask = batch['anchor_attention_mask']
        positive_ids = batch['positive_input_ids']
        positive_mask = batch['positive_attention_mask']
        negative_ids = batch['negative_input_ids']
        negative_mask = batch['negative_attention_mask']
        
        # Forward pass
        model.train()
        
        # Get embeddings (with gradients)
        anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask)
        anchor_emb = mean_pooling(anchor_outputs, anchor_mask)
        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
        
        positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask)
        positive_emb = mean_pooling(positive_outputs, positive_mask)
        positive_emb = F.normalize(positive_emb, p=2, dim=1)
        
        negative_outputs = model(input_ids=negative_ids, attention_mask=negative_mask)
        negative_emb = mean_pooling(negative_outputs, negative_mask)
        negative_emb = F.normalize(negative_emb, p=2, dim=1)
        
        # Calculate loss
        loss = triplet_loss(anchor_emb, positive_emb, negative_emb)
        
        # Calculate accuracy for monitoring
        with torch.no_grad():
            accuracy = calculate_triplet_accuracy(anchor_emb, positive_emb, negative_emb)
        
        # Backward pass
        accelerator.backward(loss)
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        if (global_step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        
        # Update metrics
        epoch_loss += loss.item()
        epoch_accuracy += accuracy
        num_batches += 1
        global_step += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{accuracy:.4f}'
        })
        
        # Logging
        if global_step % logging_steps == 0:
            avg_loss = epoch_loss / num_batches
            avg_acc = epoch_accuracy / num_batches
            logger.info(f"Step {global_step} - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
        
        # Validation
        if global_step % eval_steps == 0:
            val_loss, val_accuracy = validate_model(model, val_loader, triplet_loss, accelerator)
            logger.info(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")
            
            # Save best model
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                if accelerator.is_main_process:
                    accelerator.save_state(f"best_model_checkpoint")
                    logger.info(f"New best model saved with accuracy: {val_accuracy:.4f}")
            
            model.train()
    
    # End of epoch validation
    val_loss, val_accuracy = validate_model(model, val_loader, triplet_loss, accelerator)
    avg_train_loss = epoch_loss / num_batches
    avg_train_accuracy = epoch_accuracy / num_batches
    
    logger.info(f"Epoch {epoch + 1} completed:")
    logger.info(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}")
    logger.info(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
    
    # Save checkpoint at end of epoch
    if accelerator.is_main_process:
        accelerator.save_state(f"epoch_{epoch + 1}_checkpoint")

# Final model save
if accelerator.is_main_process:
    accelerator.save_state("final_model_checkpoint")
    logger.info("Training completed and final model saved!")

# # Optional: Save the model in HuggingFace format
# if accelerator.is_main_process:
#     unwrapped_model = accelerator.unwrap_model(model)
#     unwrapped_model.save_pretrained("fine_tuned_sentence_transformer")
#     tokenizer.save_pretrained("fine_tuned_sentence_transformer")
#     logger.info("Model saved in HuggingFace format!")


INFO:__main__:Starting epoch 1/3
Epoch 1:   0%|          | 0/403 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 9.62 GiB of which 64.50 MiB is free. Including non-PyTorch memory, this process has 7.86 GiB memory in use. Of the allocated memory 7.60 GiB is allocated by PyTorch, and 7.35 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)