In [None]:
# TODO: add weights to the profile columns and validate the embeddings (e.g. confirm that the similarity scores between IDs are plausible)


from transformers import BertModel, BertTokenizer
import pandas as pd
import torch
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

# Profile-similarity pipeline: embed each profile with BERT, reduce the
# embeddings with truncated SVD, then look up each profile's nearest
# neighbours by cosine distance.

model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.eval()  # inference only: make sure dropout is disabled

# Number of profiles pushed through BERT per forward pass; keeps peak memory
# bounded instead of encoding the whole sample in one call.
BATCH_SIZE = 32

df = pd.read_csv("../../../clean_data/profiles.csv", encoding='utf-8')
# NOTE(review): astype(str) turns missing values into the literal text "nan",
# which then becomes part of the profile string -- confirm this is intended.
df = df.astype(str)
df = df.sample(frac=0.1, random_state=42)
if df.empty:
    raise ValueError("No profiles left after sampling; nothing to embed.")

# Build one free-text "profile" per row from every column that is not an
# identifier, a timestamp, the role label or an activity counter.
df['Profile'] = df.drop(columns=['Id', 'Created at', 'Relationship Role', 'Total Mentees', 'Number of Messages Sent', 'Resource Clicks', 'Courses Clicks']).agg(' '.join, axis=1)

# Let the tokenizer pad to the longest sequence and return the attention mask
# that marks which positions hold real tokens and which hold padding.
encoded = tokenizer(
    df['Profile'].tolist(),
    add_special_tokens=True,
    truncation=True,
    max_length=512,
    padding=True,
    return_tensors='pt',
)
input_ids = encoded['input_ids']
attention_mask = encoded['attention_mask']

batch_embeddings = []
with torch.no_grad():
    for start in range(0, input_ids.shape[0], BATCH_SIZE):
        batch_ids = input_ids[start:start + BATCH_SIZE]
        batch_mask = attention_mask[start:start + BATCH_SIZE]
        # Passing the mask keeps real tokens from attending to padding.
        hidden = model(input_ids=batch_ids, attention_mask=batch_mask).last_hidden_state
        # Masked mean pooling: average over real tokens only, so the embedding
        # of a profile does not depend on how much padding it received.
        mask = batch_mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        batch_embeddings.append((hidden * mask).sum(dim=1) / token_counts)

embeddings = torch.cat(batch_embeddings).numpy()

# Seeded so that repeated runs give the same reduced vectors, in line with
# the seeded sampling above.
svd = TruncatedSVD(n_components=100, random_state=42)
X_reduced = svd.fit_transform(embeddings)

# Full pairwise similarity matrix; not used further in this section.
cos_sim_matrix = cosine_similarity(X_reduced)

# k counts the query profile itself, so each profile is reported with up to
# k - 1 neighbours.
k = 5
# kneighbors() rejects n_neighbors larger than the number of fitted samples.
n_neighbors = min(k, len(df))
knn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
knn.fit(X_reduced)

distances, indices = knn.kneighbors(X_reduced)

# Pull the columns out once instead of doing a DataFrame lookup per neighbour.
ids = df['Id'].to_numpy()
roles = df['Relationship Role'].to_numpy()

results = []
for i, (neighbor_idx, neighbor_dist) in enumerate(zip(indices, distances)):
    # Drop the query row by index rather than by position: with duplicate
    # embeddings the query is not guaranteed to be returned first.
    nearest_neighbors = [
        (ids[j], round(1 - dist, 2), roles[j])  # cosine similarity = 1 - cosine distance
        for j, dist in zip(neighbor_idx, neighbor_dist)
        if j != i
    ][:n_neighbors - 1]
    result = {
        'Id': ids[i],
        #'Profile': profile,
        'Relationship Role': roles[i],
        'Nearest Neighbors': nearest_neighbors
    }
    results.append(result)

results_df = pd.DataFrame(results)

print(results_df)

#results_df.to_csv("results_rec.csv", index=False)

#print(embeddings)