In [67]:
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from sentence_transformers import losses
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from sentence_transformers import SentenceTransformer
import json
import os, random
from scipy.stats import spearmanr
import numpy as np

class CustomTripletLoss(nn.Module):
    """Triplet hinge loss plus a hinge that keeps positives away from the anchor.

    For every (anchor, positive, negative) row the loss is the sum of:

    * ``relu(d(anchor, positive) - d(anchor, negative) + margin)`` -- the usual
      triplet term, and
    * ``relu(positive_margin - d(anchor, positive))`` -- active only while the
      positive is closer to the anchor than ``positive_margin``.

    The two terms are added with equal weight and averaged over the batch.

    Args:
        margin: Margin used by the triplet term.
        distance_metric: ``'euclidean'`` (L2 norm of the difference along
            ``dim=1``) or ``'cosine'`` (``1 - cosine_similarity``).
        positive_margin: Anchor-positive distance below which the second term
            contributes to the loss.

    Raises:
        ValueError: If ``distance_metric`` is not a supported name.
    """

    SUPPORTED_METRICS = ('euclidean', 'cosine')

    def __init__(self, margin=1.0, distance_metric='euclidean', positive_margin=0.5):
        super(CustomTripletLoss, self).__init__()
        # Fail at construction time instead of with an UnboundLocalError in forward().
        if distance_metric not in self.SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported distance_metric {distance_metric!r}; "
                f"expected one of {self.SUPPORTED_METRICS}"
            )
        self.margin = margin
        self.positive_margin = positive_margin
        self.distance_metric = distance_metric

    def _distance(self, left, right):
        """Return the row-wise distance between two equally shaped 2-D batches."""
        if self.distance_metric == 'euclidean':
            # vector_norm has a well-defined (zero) gradient when the two rows
            # are identical; sqrt(sum(diff ** 2)) yields NaN gradients there.
            return torch.linalg.vector_norm(left - right, ord=2, dim=1)
        if self.distance_metric == 'cosine':
            return 1 - F.cosine_similarity(left, right)
        # Only reachable if the attribute was changed after construction.
        raise ValueError(f"Unsupported distance_metric {self.distance_metric!r}")

    def forward(self, anchor, positive, negative):
        """Compute the mean combined loss for a batch of triplets.

        Args:
            anchor, positive, negative: Tensors whose rows (``dim=1`` is the
                feature axis) are compared pairwise.

        Returns:
            A scalar tensor: the batch mean of the summed hinge terms.
        """
        distance_pos = self._distance(anchor, positive)
        distance_neg = self._distance(anchor, negative)

        # Triplet component: the negative should be at least `margin` farther
        # from the anchor than the positive.
        triplet_loss = F.relu(distance_pos - distance_neg + self.margin)

        # Positive-distance component: penalise positives that sit closer to
        # the anchor than `positive_margin`.
        positive_distance_loss = F.relu(self.positive_margin - distance_pos)

        # Both components are combined with equal weight; a weighting factor
        # may be needed to balance the two objectives.
        loss = triplet_loss + positive_distance_loss

        return loss.mean()

class SentenceEmbeddingModel(nn.Module):
    """A SentenceTransformer encoder followed by a small MLP head.

    ``forward`` encodes raw sentences with the wrapped SentenceTransformer,
    optionally perturbs the embeddings (dropout and/or additive Gaussian noise)
    and passes them through the MLP. ``encode`` returns the SentenceTransformer
    embeddings directly and does NOT apply the MLP.

    Args:
        model_id: Name or path handed to ``SentenceTransformer``.
        dropout_rate: Drop probability of the optional dropout layer.
        noise_std: Standard deviation of the optional additive Gaussian noise.
        hidden_dim: Width of the MLP's hidden layer.
    """

    def __init__(self, model_id, dropout_rate=0.1, noise_std=0.05, hidden_dim=256):
        super().__init__()
        self.sentence_transformer = SentenceTransformer(model_id)

        # Size the MLP from the encoder instead of assuming 384, so that any
        # `model_id` works; 384 is kept as the fallback when the encoder does
        # not report a dimension.
        embedding_dim = self.sentence_transformer.get_sentence_embedding_dimension()
        if embedding_dim is None:
            embedding_dim = 384
        self.embedding_dim = embedding_dim

        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
        self.dropout = nn.Dropout(p=dropout_rate)
        self.noise_std = noise_std

    def forward(self, sentences, apply_dropout=False, apply_noise=False):
        """Embed `sentences` and project them through the MLP.

        Args:
            sentences: Sequence of strings; its length is also used as the
                encoder batch size, so the whole input is encoded in one pass.
            apply_dropout: If True, pass the embeddings through ``self.dropout``
                (an ``nn.Dropout``, so it only perturbs values while the module
                is in training mode).
            apply_noise: If True, add zero-mean Gaussian noise with standard
                deviation ``self.noise_std``; unlike dropout this is applied
                regardless of train/eval mode.

        Returns:
            The MLP output for the (optionally perturbed) embeddings.
        """
        # NOTE(review): `SentenceTransformer.encode` is the library's inference
        # entry point and presumably does not track gradients; if so, only the
        # MLP is updated during training -- confirm for the installed version.
        embeddings = self.sentence_transformer.encode(sentences, convert_to_tensor=True, batch_size=len(sentences))
        if apply_dropout:
            embeddings = self.dropout(embeddings)
        if apply_noise:
            noise = torch.randn_like(embeddings) * self.noise_std
            embeddings = embeddings + noise
        outputs = self.mlp(embeddings)
        return outputs


    def encode(self, sentences, convert_to_tensor=True):
        """Return the raw SentenceTransformer embeddings (the MLP is bypassed)."""
        return self.sentence_transformer.encode(sentences, convert_to_tensor=convert_to_tensor)


### Prepare data

In [62]:
data_path = './Data/LaMP-1/tiny_data.json'
# Read through a context manager so the file handle is closed deterministically,
# and decode explicitly as UTF-8 rather than the platform default.
with open(data_path, 'r', encoding='utf-8') as data_file:
    data = json.load(data_file)

# user_id -> list of "abstract: ... title: ..." strings, one per profile entry.
user_wise_RAG_dict = {}
# Flat list of {'user_id', 'title', 'abstract'} records across all users.
raw_data = []
for point in data:
    user_id = point['user_id']
    print(f'Process user: {user_id}')
    # A user id may appear in several data points; reuse its list if present.
    user_histories = user_wise_RAG_dict.setdefault(user_id, [])
    for q in point['profile']:
        title = q['title']
        abstract = q['abstract']
        user_histories.append(f'abstract: {abstract} title: {title}')
        raw_data.append({
            'user_id': user_id,
            'title': title,
            'abstract': abstract
        })

# Number of (positive, negative) pairs generated for every anchor.
NUM_SAMPLES_PER_ANCHOR = 5
# Fixed seed so the sampled negatives are the same after a kernel restart.
NEGATIVE_SAMPLING_SEED = 42
random.seed(NEGATIVE_SAMPLING_SEED)

train_examples = []

### Construct the anchor data ###
# An anchor is the record's own title paired with its own abstract.
print('Constructing anchor data...')
anchor_data = [
    f'title: {dp["title"]} \n abstract: {dp["abstract"]}'
    for dp in raw_data
]

### Construct positive data ###
# Positives are verbatim copies of the anchor text.
print('Constructing positive data...')
positive_data_dict = {
    anchor_sentence: [anchor_sentence] * NUM_SAMPLES_PER_ANCHOR
    for anchor_sentence in anchor_data
}

### Construct negative data ###
# A negative is the anchor's abstract paired with the title of another record.
print('Constructing negative data...')
negative_data_dict = {}
for dp, anchor_sentence in zip(raw_data, anchor_data):
    current_title = dp['title']
    current_abstract = dp['abstract']
    # Skip records with the same abstract, and also records with the same
    # title: pairing the anchor's own title with its abstract would produce a
    # "negative" that is identical to the anchor.
    other_titles = [
        dp_['title']
        for dp_ in raw_data
        if dp_['abstract'] != current_abstract and dp_['title'] != current_title
    ]
    # random.sample raises ValueError when asked for more items than exist, so
    # cap the request at the number of available candidates.
    sample_size = min(NUM_SAMPLES_PER_ANCHOR, len(other_titles))
    negative_data_dict[anchor_sentence] = [
        f'title: {title} \n abstract: {current_abstract}'
        for title in random.sample(other_titles, sample_size)
    ]

print('Constructing training examples...')
for anchor_sentence in anchor_data:
    # zip stops at the shorter list, so anchors with fewer than
    # NUM_SAMPLES_PER_ANCHOR negatives contribute fewer triplets instead of
    # raising IndexError.
    for pos_sentence, neg_sentence in zip(positive_data_dict[anchor_sentence],
                                          negative_data_dict[anchor_sentence]):
        train_examples.append((anchor_sentence, pos_sentence, neg_sentence))


Process user: 7005607
Process user: 7004451
Process user: 7005507
Constructing anchor data...
Constructing positive data...
Constructing negative data...
Constructing training examples...


### Extra evaluation metrics

In [63]:
from datasets import load_dataset

def evaluate_sts(model, sts_dataset, eval_split='dev'):
    """Spearman correlation between predicted and gold sentence similarity.

    Each sentence pair of the chosen split is embedded with ``model.encode``;
    the cosine similarity of the two embeddings is the predicted score.

    Args:
        model: Object exposing ``encode(sentences, convert_to_tensor=True)``
            that returns one embedding row per input sentence.
        sts_dataset: Mapping from split name to a table with the columns
            ``'sentence1'``, ``'sentence2'`` and ``'similarity_score'``.
        eval_split: Name of the split to evaluate on.

    Returns:
        The Spearman rank correlation coefficient (first element of
        ``scipy.stats.spearmanr``'s result).
    """
    split = sts_dataset[eval_split]
    sentences1 = list(split['sentence1'])
    sentences2 = list(split['sentence2'])
    actual_scores = list(split['similarity_score'])

    cos_sim = nn.CosineSimilarity(dim=1)
    with torch.no_grad():
        # Encode each column in a single call rather than one sentence at a
        # time; the encoder is then free to batch the work.
        embeddings1 = model.encode(sentences1, convert_to_tensor=True)
        embeddings2 = model.encode(sentences2, convert_to_tensor=True)
        # Row-wise cosine similarity -> flat array with one score per pair.
        predicted_scores = cos_sim(embeddings1, embeddings2).cpu().numpy()

    spearman_corr = spearmanr(predicted_scores, actual_scores)[0]
    return spearman_corr

def calculate_distances(model, dataloader, distance_metric='euclidean'):
    """Average anchor-positive and anchor-negative distance over a dataloader.

    Args:
        model: Callable ``model(sentences, apply_dropout=...)`` returning one
            embedding row per sentence.
        dataloader: Iterable of ``(anchor_sentences, positive_sentences,
            negative_sentences)`` batches, each a list of strings.
        distance_metric: ``'euclidean'`` or ``'cosine'`` (``1 - cosine
            similarity``).

    Returns:
        Tuple ``(mean_positive_distance, mean_negative_distance)``.

    Raises:
        ValueError: If ``distance_metric`` is not supported.

    Note:
        Positives are embedded with ``apply_dropout=True``. When the model is
        in eval mode its dropout layer is inactive, so positives that are
        textual copies of the anchor come out at a distance of ~0.
    """
    # Reject unknown metrics up front; otherwise the distance variables below
    # would never be assigned and the failure would be an UnboundLocalError.
    if distance_metric not in ('euclidean', 'cosine'):
        raise ValueError(
            f"Unsupported distance_metric {distance_metric!r}; "
            "expected 'euclidean' or 'cosine'"
        )

    cos_sim = nn.CosineSimilarity()
    positive_distances = []
    negative_distances = []

    with torch.no_grad():
        for batch in dataloader:
            anchor_sentences, positive_sentences, negative_sentences = batch
            # Anchors and negatives go through the model in one call and are
            # split apart again by position.
            sentences = anchor_sentences + negative_sentences
            embeddings = model(sentences)
            anchor_embeddings = embeddings[:len(anchor_sentences)]
            negative_embeddings = embeddings[len(anchor_sentences):]
            positive_embeddings = model(positive_sentences, apply_dropout=True)

            # Distance of every anchor to its positive and to its negative.
            if distance_metric == 'euclidean':
                pos_distance = torch.sqrt(torch.sum((anchor_embeddings - positive_embeddings) ** 2, dim=1))
                neg_distance = torch.sqrt(torch.sum((anchor_embeddings - negative_embeddings) ** 2, dim=1))
            else:
                # Cosine similarity has to be converted into a distance.
                pos_distance = 1 - cos_sim(anchor_embeddings, positive_embeddings)
                neg_distance = 1 - cos_sim(anchor_embeddings, negative_embeddings)

            positive_distances.extend(pos_distance.tolist())
            negative_distances.extend(neg_distance.tolist())

    return np.mean(positive_distances), np.mean(negative_distances)


# Load the "en" configuration of the "stsb_multi_mt" dataset; evaluate_sts
# reads the 'sentence1', 'sentence2' and 'similarity_score' columns of its
# 'dev' split.
sts_dataset = load_dataset("stsb_multi_mt", "en")

### Train the model

In [68]:
# Seed torch before the model is built so that the randomly initialised MLP,
# the dataloader shuffling and the dropout/noise draws repeat across runs.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceEmbeddingModel("sentence-transformers/all-MiniLM-L6-v2", dropout_rate=0.1, noise_std=0.05).to(device)

def collate_fn(batch):
    """Regroup a batch of (anchor, positive, negative) triples column-wise.

    Args:
        batch: Sequence of examples indexable at positions 0, 1 and 2.

    Returns:
        Tuple of three lists: all anchors, all positives, all negatives, each
        in batch order.
    """
    anchors, positives, negatives = [], [], []
    for triple in batch:
        anchors.append(triple[0])
        positives.append(triple[1])
        negatives.append(triple[2])
    return anchors, positives, negatives

# Loss: triplet hinge plus positive-distance hinge, both on Euclidean distance.
margin = 1.0
new_triplet_loss = CustomTripletLoss(
    positive_margin=0.5,
    distance_metric='euclidean',
    margin=margin,
)

# Shuffled batches of 16 triplets; collate_fn regroups each batch into three
# lists of sentences (anchors, positives, negatives).
train_dataloader = DataLoader(
    train_examples,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

# Adam over every parameter registered on the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)

print("Start training")
for epoch in range(1):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        anchor_sentences, positive_sentences, negative_sentences = batch
        num_anchors = len(anchor_sentences)

        # Anchors and negatives share one forward pass (noise only); the
        # positives get their own pass with dropout and noise.
        sentences = anchor_sentences + negative_sentences
        embeddings = model(sentences, apply_dropout=False, apply_noise=True)
        positive_embeddings = model(positive_sentences, apply_dropout=True, apply_noise=True)

        # The first `num_anchors` rows are the anchors, the rest the negatives.
        anchor_embeddings, negative_embeddings = embeddings[:num_anchors], embeddings[num_anchors:]

        loss = new_triplet_loss(anchor_embeddings, positive_embeddings, negative_embeddings)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()

    # ---- evaluation pass ----
    # The STS score comes from model.encode, which returns the sentence
    # transformer's embeddings without the MLP; the distances come from the
    # full forward pass over the training batches.
    model.eval()
    spearman_corr = evaluate_sts(model, sts_dataset)
    pos_dist_avg, neg_dist_avg = calculate_distances(model, train_dataloader)

    print(
        f"Epoch {epoch}, Loss: {train_loss/len(train_dataloader)}, "
        f"Spearman Correlation on STS: {spearman_corr}， "
        f"Positive Distance: {pos_dist_avg}, Negative Distance: {neg_dist_avg}"
    )




Start training
Epoch 0, Loss: 1.3843977402468197, Spearman Correlation on STS: 0.8671631197908373， Positive Distance: 6.995271792475658e-08, Negative Distance: 0.1295417695167737
