In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel

In [2]:
# add logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


In [3]:
# Load model and tokenizer
logger.info("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModel.from_pretrained("google-bert/bert-base-uncased")


INFO:__main__:Loading model and tokenizer...


In [4]:
# Set device (GPU if available)
logger.info("Setting device...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
logger.info(f"Using device: {device}")

INFO:__main__:Setting device...
INFO:__main__:Using device: cpu


In [5]:
file_path = "data/adhd-posts.csv"
text_column = "body"

In [6]:
# Function to split long texts into chunks of up to 512 tokens
def chunk_text(text, max_length=512):
    logger.info("Chunking text...")
    tokens = tokenizer.tokenize(text)
    logger.info(f"Tokenized text: {tokens}")
    logger.info(f"Number of tokens: {len(tokens)}")
    return [tokens[i:i+max_length] for i in range(0, len(tokens), max_length)]

In [7]:
# Function to get the embedding of a list of tokens (a single chunk)
def get_chunk_embedding(tokens):
    logger.info("Getting chunk embedding...")
    inputs = tokenizer.encode_plus(
        tokens,
        return_tensors='pt',
        is_split_into_words=True,
        truncation=True,
        max_length=512,
        padding='max_length'
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = torch.mean(outputs.last_hidden_state, dim=1)
    return embedding

In [8]:
# Full function to get embedding for any text by averaging chunk embeddings
def get_full_embedding(text):
    logger.info("Getting embedding...")
    text = str(text)  # in case of NaNs
    chunks = chunk_text(text)
    logger.info(f"Number of chunks: {len(chunks)}")
    logger.info(f"Chunks: {chunks}")
    chunk_embeddings = []

    for tokens in chunks:
        chunk_emb = get_chunk_embedding(tokens)
        chunk_embeddings.append(chunk_emb)
        logger.info(f"Chunk embedding shape: {chunk_emb.shape}")
    # Average all chunk embeddings
    full_embedding = torch.mean(torch.stack(chunk_embeddings), dim=0)
    return full_embedding.cpu().numpy()

In [9]:
def load_dataset(file_path):
    logger.info("Loading dataset...")
    df = pd.read_csv(file_path)
    logger.info(f"Dataset shape: {df.shape}")
    return df

In [None]:
def start_embedding_workflow(file_path, text_column):
    logger.info("Starting embedding workflow...")
    df = load_dataset(file_path)
    # get only the first 5k rows
    df = df.iloc[:5000]
    logger.info(f"Reduced dataset shape: {df.shape}")
    df[text_column] = df[text_column].astype(str)
    # Apply the embedding function to the specified column
    logger.info(f"Applying embedding function to column: {text_column}")
    df['embedding'] = df[text_column].apply(get_full_embedding)
    logger.info("Embedding completed.")
    # Save the DataFrame with embeddings to a new CSV file
    logger.info("Saving DataFrame with embeddings...")
    df.to_csv(file_path, index=False)
    logger.info("DataFrame saved.")
    return df

In [11]:
# Apply the function to the 'body' column of the DataFrame
df = start_embedding_workflow(file_path, text_column)

INFO:__main__:Starting embedding workflow...
INFO:__main__:Loading dataset...
INFO:__main__:Dataset shape: (330693, 4)
INFO:__main__:Reduced dataset shape: (5000, 4)
INFO:__main__:Applying embedding function to column: body
INFO:__main__:Getting embedding...
INFO:__main__:Chunking text...
INFO:__main__:Tokenized text: ['android', 'app', 'to', 'strengthen', 'attention', '/', 'focus', 'hey', '/', 'r', '/', 'ad', '##hd', ',', 'check', 'out', 'my', 'simple', 'android', 'app', ':', '[', 'attention', 'exercise', ']', '(', 'https', ':', '/', '/', 'market', '.', 'android', '.', 'com', '/', 'details', '?', 'id', '=', 'com', '.', 'race', '##car', '##lab', '##s', '.', 'apps', '.', 'android', '.', 'attention', '##ex', '##er', '##cise', ')', 'it', "'", 's', 'just', 'a', 'series', 'of', 'simple', 'touch', '##screen', 'drawing', 'exercises', 'that', ',', 'with', 'practice', ',', 'noticeably', 'improve', 'attention', 'span', 'and', 'focus', '.', 'a', 'session', 'really', 'shouldn', "'", 't', 'take', '

KeyboardInterrupt: 

In [None]:
print(df.head())

       id                                               body  \
0  29kaf8               Adult Women Are the New Face of ADHD   
1  2ip2ra                 Why Women Hide Their ADHD Symptoms   
2  2q6jdk        Adult ADHD and Burnout: Success or Failure?   
3  2sc7fa                  How Am I And My ADHD Still Alive?   
4  3296xx  I'd like to see this subreddit grow! Hello, I'...   

                                           embedding  
0  [[0.1278188, -0.20812818, 0.54785943, 0.262456...  
1  [[0.1341694, -0.04314984, 0.11840491, 0.153770...  
2  [[0.2748738, -0.34320945, 0.29193562, 0.353753...  
3  [[0.17047045, 0.09128848, 0.4974976, 0.2096059...  
4  [[0.09756681, 0.011357546, 0.49866802, 0.06258...  


: 