In [None]:
# Imports and environment setup.
# Posts are read from a SQLite database whose path is given by the DB_URL
# environment variable (a local .env file is loaded below). An earlier version
# of this notebook read them from 'reddit_posts.csv' instead.
import os

import pandas as pd
from dotenv import load_dotenv
from peewee import CharField, Model, SqliteDatabase, TextField

# Populate os.environ from a .env file, if one is present.
load_dotenv()

In [None]:
# Path of the SQLite database holding the posts, taken from the environment.
# Fail fast when it is missing or empty:
# - None would make peewee defer initialization and fail later at connect time.
# - An empty string would make SQLite open a temporary, empty database, silently yielding zero posts.
DB_URL = os.getenv('DB_URL')
if not DB_URL:
    raise RuntimeError("DB_URL is not set; add it to your environment or to a .env file")
db = SqliteDatabase(DB_URL)


class Posts(Model):
    """Peewee model for one stored Reddit post.

    Fields:
        title: post title.
        description: post body text; this is the text that is embedded and scored.
        url: URL stored with the post (presumably a link to it - confirm with the producer of the DB).
    """

    title = CharField()
    description = TextField()
    url = CharField()

    class Meta:
        database = db

In [None]:
# Load all posts into a pandas DataFrame so later cells can attach score
# columns and export with pandas.
# connection_context() closes the connection even if the query raises. With the
# old connect()/close() pair, a failed run left the connection open, and the
# next db.connect() then failed with "Connection already opened".
with db.connection_context():
    rows = list(
        Posts.select(Posts.id, Posts.title, Posts.description, Posts.url).dicts()
    )

# Explicit columns keep the schema (and data['description']) valid even when
# the table is empty.
data = pd.DataFrame(rows, columns=['id', 'title', 'description', 'url'])

# Post texts in row order; row i of `data` and posts[i] describe the same post.
posts = data['description'].tolist()

In [None]:
from sentence_transformers import SentenceTransformer

# Embedding model shared by the topic-description and post-encoding cells.
# Alternative: 'all-mpnet-base-v2' for slightly better performance.
model = SentenceTransformer(
    'all-MiniLM-L6-v2'
)


In [None]:
# Free-text description of the target topic. Its embedding is the reference
# point that every post is scored against.
topic_description = (
    "Shipping carriers, logistics, delivery services, postal services, "
    "FedEx, UPS, DHL, tracking, shipping issues, package delivery"
)


In [None]:
# Encode the topic description once. convert_to_tensor=True returns a torch
# tensor so it can be passed straight to cos_sim together with the post embeddings.
topic_embedding = model.encode(
    topic_description,
    convert_to_tensor=True,
)


In [None]:
import torch
from torch.utils.data import DataLoader

# Define a simple dataset class
class PostsDataset(torch.utils.data.Dataset):
    def __init__(self, posts):
        self.posts = posts

    def __len__(self):
        return len(self.posts)

    def __getitem__(self, idx):
        return self.posts[idx]

# Serve the post texts in batches of 64; tune the size to the available memory.
# shuffle=False is the DataLoader default. It is spelled out because later cells
# map scores back to `posts` by position.
dataset = PostsDataset(posts)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)


In [None]:
# Encode the posts batch by batch. The DataLoader yields batches in post order,
# so row i of the final tensor corresponds to posts[i].
batch_embeddings = []
with torch.no_grad():
    for batch in dataloader:
        batch_embeddings.append(model.encode(batch, convert_to_tensor=True))

# torch.cat rejects an empty list with an opaque error; say what actually went wrong.
if not batch_embeddings:
    raise ValueError("No posts to embed: `posts` is empty.")

# Stack the per-batch tensors into a single tensor with one row per post.
post_embeddings = torch.cat(batch_embeddings)


In [None]:
from sentence_transformers.util import cos_sim

# cos_sim returns a matrix with one row per post and one column per topic
# embedding. There is a single topic here, so it is one column of scores.
similarities = cos_sim(post_embeddings, topic_embedding)

# Take that column explicitly. A bare .squeeze() also drops the post dimension
# when there is exactly one post, and .tolist() then returns a float rather than
# a list, which breaks the enumerate() in the next cell.
similarity_scores = similarities[:, 0].tolist()


In [None]:
# Cosine-similarity cut-off: raise it for precision, lower it for recall.
threshold = 0.5

# Positions (into `posts`) of every post at or above the cut-off, and the posts themselves.
relevant_indices = [idx for idx, sim in enumerate(similarity_scores) if sim >= threshold]
relevant_posts = [posts[idx] for idx in relevant_indices]

# Record each row's score and relevance flag as new columns.
# NOTE(review): expects `data` to be a pandas DataFrame whose rows are in the
# same order as `posts` - confirm against the loading cell.
data['similarity_score'] = similarity_scores
data['is_relevant'] = data['similarity_score'] >= threshold

# Export only the on-topic rows, without the index column.
filtered_data = data[data['is_relevant']]
filtered_data.to_csv('filtered_reddit_posts.csv', index=False)
