In [0]:
pip install faiss-gpu

In [0]:
import faiss
import numpy as np
import pickle
import torch
from transformers import AutoTokenizer, AutoModel

# Use embeddings as a lookup table: label new text with the labels of its nearest labeled neighbors

In [0]:
# Load the tokenizer and the base model from the same pretrained checkpoint id.
model_ckpt = "miguelvictor/python-gpt2-large"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_ckpt),
    AutoModel.from_pretrained(model_ckpt),
)
# Reuse the EOS token for padding so tokenizer calls with `padding=True` have a
# pad token to use (presumably the checkpoint's tokenizer defines none — confirm).
tokenizer.pad_token = tokenizer.eos_token

In [0]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings over the positions the attention mask keeps.

    Args:
        model_output: model forward output; its element 0 is used as the
            per-token embeddings (presumably (batch, seq_len, hidden) — confirm).
        attention_mask: 1 for tokens to keep, 0 for padding; it must broadcast
            to the token embeddings once a trailing axis is added.

    Returns:
        The masked mean over axis 1 (the token axis), i.e. one pooled vector
        per sequence. Fully masked rows yield zeros instead of NaN.
    """
    hidden_states = model_output[0]
    # Spread the mask across the hidden dimension so padded tokens contribute zero.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    totals = (hidden_states * weights).sum(dim=1)
    # Floor the token counts to avoid dividing by zero on all-padding rows.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return totals / counts

## Load the training embeddings

In [0]:
import pickle

In [0]:
# Load the embeddings of the labeled data. The unpickled object must support
# `add_faiss_index` (used below) and `get_nearest_examples_batch` (used later).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("amp_embs_labels", "rb") as embs_train_file:
    embs_train = pickle.load(embs_train_file)

# Add a FAISS index on the "embedding" column to query nearest neighbors.
embs_train.add_faiss_index("embedding")

In [0]:
# Inspect the unpickled object's type; it must provide `add_faiss_index` and
# `get_nearest_examples_batch`, both of which this notebook calls.
type(embs_train)

# Try getting a prediction for a single input

In [0]:
# Example claim used to exercise the nearest-neighbor label lookup below.
sample_text = "CBD oil is a cure for COVID-19."

In [0]:
def embed_single_text(text):
    """Embed ``text`` with the notebook's ``tokenizer`` and ``model``.

    The input is tokenized (padded, truncated to 128 tokens) and run through
    the model without gradient tracking. The token outputs are then
    mean-pooled with ``mean_pooling``.

    Returns:
        A dict whose ``"embedding"`` key holds the pooled vectors as a NumPy
        array on the CPU, in the shape expected by ``get_nearest_examples_batch``.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoded)
    pooled = mean_pooling(outputs, encoded["attention_mask"])
    return {"embedding": pooled.cpu().numpy()}

In [0]:
def get_predicted_labels(text, k=4):
    """Return the labels of the ``k`` training examples nearest to ``text``.

    ``text`` is embedded with ``embed_single_text``. Its nearest neighbors are
    then looked up in the FAISS index on the ``"embedding"`` column of
    ``embs_train``. The neighbors' labels are returned as-is (no voting or
    aggregation), so predictions mirror any class imbalance in ``embs_train``.

    Args:
        text: the input text to label.
        k: number of nearest neighbors to retrieve; defaults to 4, the value
            previously hard-coded here.

    Returns:
        The ``"labels"`` entry of the nearest examples for the first (and, for
        a single string, only) query.
    """
    query = embed_single_text(text)
    # Batch API called with a single query row; the similarity scores are unused.
    _scores, neighbors = embs_train.get_nearest_examples_batch(
        "embedding", query["embedding"], k=k
    )
    return neighbors[0]["labels"]

In [0]:
# Labels of the nearest labeled training examples for the sample claim.
predicted_themes = get_predicted_labels(sample_text)

In [0]:
predicted_themes

In [0]:
# Same lookup for a second example claim.
get_predicted_labels("Vitamin C is all you need to cure Covid")

#### This method is biased toward our existing labels. Because predictions are copied from the nearest labeled examples, imbalanced classes will pull predictions toward the labels that have the most examples in the training dataset.

# Check label counts in training data

In [0]:
import pandas as pd

In [0]:
# Load the label table used to check class balance.
# NOTE(review): confirm this CSV covers the same rows as the pickled `amp_embs_labels`.
df = pd.read_csv("amp_labels_viv.csv")

In [0]:
# Number of rows per `themeName` value.
df["themeName"].value_counts()

In [0]:
# Number of rows per `manual_themeName` value (presumably manually assigned themes — confirm).
df["manual_themeName"].value_counts()