# Explanation

This notebook generates embeddings with sentence transformers. Their architecture is similar to that of other transformers, but they are trained to produce a single embedding per sentence that captures its meaning. A plain model like BERT instead produces token-level (word-piece) embeddings that have to be pooled afterwards.


### Cell 1
The cell below is a copy of the `Embedding` class from the `src/data/tagging.py` module. Instead of running it, you can import the class with `from src.data.tagging import Embedding` and then use it as shown in cells 2-6.

In [58]:
import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics.pairwise import cosine_similarity
import warnings
from tqdm.notebook import tqdm  # tqdm_notebook is deprecated in favour of tqdm.notebook.tqdm
from sentence_transformers import SentenceTransformer
import torch


class Embedding:
    """
    A class to generate embeddings for startups and industries using specified language models and pooling methods.
    """

    def __init__(self, startups, industries, llm='bert', pool='max', sentence_transformer=False, sent='all-MiniLM-L6-v2'):

        """
        Initializes the Embedding class with specified language models and pooling methods.

        :param startups: DataFrame containing startup data with 'id' and 'cb_description' columns
        :param industries: DataFrame containing industry data with 'id', 'industry' and 'keywords' columns
        :param llm: string, key into the self.llm dictionary that selects the Hugging Face model, default is 'bert' (bert-base-uncased)
        :param pool: string, the pooling method ('max', 'avg' or 'concat'), default is 'max'; ignored when sentence_transformer=True
        :param sentence_transformer: bool, whether to use a sentence transformer model, default is False
        :param sent: string, name of the SentenceTransformer model to load when sentence_transformer=True, default is 'all-MiniLM-L6-v2'
        """

        self.startups = startups
        self.industries = industries
        self.sentence_transformer = sentence_transformer
        self.pool = pool
        self.llm = {
            'bert': 'bert-base-uncased',
            'gpt2': 'gpt2',
            'gpt': 'openai-gpt',
            'roberta': 'roberta-base',
            'distilbert': 'distilbert-base-uncased',
            'xlnet': 'xlnet-base-uncased',
            'electra': 'google/electra-base-discriminator',
            'industry_classifier': 'sampathkethineedi/industry-classification'
        }
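        if not sentence_transformer:
            # AutoModel loads the base encoder, whose output exposes `last_hidden_state`.
            # A *ForSequenceClassification model only returns logits (plus optional hidden states),
            # so `outputs.last_hidden_state` in generate_embeddings would raise an AttributeError.
            self.model = AutoModel.from_pretrained(self.llm[llm])
            self.tokenizer = AutoTokenizer.from_pretrained(self.llm[llm])
        else:
            self.model = SentenceTransformer(sent)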


    def generate_embeddings(self, startup=True):
        """
        Generates embeddings for startups or industries using the specified language model and pooling method.

        :param startup: bool, if True, generates embeddings for startups, if False, generates embeddings for industries
        :return: DataFrame with generated embeddings merged with the original input DataFrame
        """
        texts = self.startups if startup else self.industries
        text_column = 'cb_description' if startup else 'keywords'
        embeddings_list = []

        for _, row in tqdm(texts.iterrows(), total=len(texts)):
            row_id = row['id']  # avoid shadowing the built-in id()
            description = row[text_column]
            if self.sentence_transformer:
                embeddings = self.model.encode(description)
            else:
                # Encode one text at a time without padding, so max/avg pooling only sees real
                # tokens (this also works for GPT-style tokenizers that have no pad token).
                # Long texts are still truncated to 60 tokens.
                inputs = self.tokenizer(description, return_tensors="pt", truncation=True, max_length=60)
                with torch.no_grad():  # inference only, no gradients needed
                    outputs = self.model(**inputs)
                last_hidden_states = outputs.last_hidden_state
                embeddings = self.pooling(last_hidden_states)

            embeddings_list.append({'id': row_id, 'embeddings': embeddings.tolist()})

        embeddings_df = pd.DataFrame(embeddings_list)
        merged_df = pd.merge(texts, embeddings_df, on='id', how='left')

        if startup:
            self.startups = merged_df
        else:
            self.industries = merged_df

        return merged_df


    def assign_industry(self, num_labels=3):
        """
        Assigns top industries to startups based on their cosine similarity to the industry embeddings.

        :param num_labels: int, the number of top industries to assign to each startup, default is 3
        :return: list of lists containing dictionaries with assigned industries and their similarity scores
        """
        self.assigned_industries = []
        # The industry matrix is the same for every startup, so build it once outside the loop.
        industry_embeddings = np.array([np.array(x).flatten() for x in self.industries['embeddings']])

        for startup_embedding in tqdm(self.startups['embeddings']):
            startup_embedding = np.array(startup_embedding).flatten()

            similarities = cosine_similarity([startup_embedding], industry_embeddings)[0]
            # argsort is ascending: take the last num_labels indices and reverse them so the best match comes first
            top_industry_indices = np.argsort(similarities)[-num_labels:][::-1]
            top_industries = [{'industry': self.industries.iloc[index]['industry'], 'score': similarities[index]} for index in top_industry_indices]

            self.assigned_industries.append(top_industries)

        return self.assigned_industries

    def pooling(self, last_hidden_states):
        """
        Applies the specified pooling method to the given last hidden states tensor.

        :param last_hidden_states: tensor, the last hidden states from the language model
        :return: NumPy array of pooled embeddings
        """
        if self.pool == 'max':
            self.pooled_embeds = torch.max(last_hidden_states, dim=1).values
        elif self.pool == 'avg':
            self.pooled_embeds = torch.mean(last_hidden_states, dim=1)
        elif self.pool == 'concat':
            max_pooling = torch.max(last_hidden_states, dim=1).values
            average_pooling = torch.mean(last_hidden_states, dim=1)
            self.pooled_embeds = torch.cat((max_pooling, average_pooling), dim=1)
        else:
            raise ValueError('pool must be either max, avg or concat')
        return self.pooled_embeds.detach().numpy()

    def update_dataframe(self):
        """
        Updates the startup and industry DataFrames with assigned industries and their similarity scores.

        :return: DataFrame with updated startups data
        """
        max_industries = max([len(x) for x in self.assigned_industries])

        for i in range(max_industries):
            self.startups[f'industry{i + 1}'] = [x[i]['industry'] if i < len(x) else None for x in self.assigned_industries]
            self.startups[f'score{i + 1}'] = [x[i]['score'].round(3) if i < len(x) else None for x in self.assigned_industries]

        self.startups.drop(columns=['embeddings'], inplace=True)
        self.industries.drop(columns=['embeddings'], inplace=True)

        return self.startups


### Cell 2
We load the data in a separate cell so that we can switch datasets ad hoc; the commented-out lines load alternative dataset versions.
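The `Embedding` class reads specific columns from these two DataFrames: `id` and `cb_description` for startups, and `id`, `keywords` and `industry` for industries (`industry` is used by `assign_industry`). When swapping datasets, a quick check like the one below fails fast if one of them is missing. `check_inputs` is a hypothetical helper, not part of `src/data/tagging.py`.

```python
import pandas as pd

# Columns the Embedding class reads from each DataFrame (taken from the class code above).
REQUIRED_COLUMNS = {
    "startups": {"id", "cb_description"},
    "industries": {"id", "industry", "keywords"},
}


def check_inputs(startups: pd.DataFrame, industries: pd.DataFrame) -> None:
    """Hypothetical helper: raise a clear error if a required column is missing."""
    for label, df in (("startups", startups), ("industries", industries)):
        missing = REQUIRED_COLUMNS[label] - set(df.columns)
        if missing:
            raise KeyError(f"{label} is missing columns: {sorted(missing)}")
        if not df["id"].is_unique:
            # generate_embeddings merges on 'id', so duplicate ids would multiply rows.
            raise ValueError(f"{label}['id'] contains duplicates")
```

Calling `check_inputs(startups, industry_data)` at the end of the loading cells surfaces a renamed or missing column before any model is loaded.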

In [76]:
warnings.filterwarnings('ignore')

#v1
#industry_data = pd.read_csv(r'C:\Users\imran\DataspellProjects\WalidCase\data\processed\industry_dataset_clean.csv', sep='\t')

#v1.2

industry_data = pd.read_csv(r'C:\Users\imran\DataspellProjects\WalidCase\data\processed\industry_dataset_clean_some_deleted.csv', sep='\t')

#v2
#industry_data = pd.read_csv(r'C:\Users\imran\DataspellProjects\WalidCase\data\processed\GPT4_generated_keywords.csv')

#v2.2
# keep only rows where the 'delete' flag is 0
#industry_data = industry_data[industry_data['delete'] == 0]

industry_data.insert(0, 'id', industry_data.index)
startups = pd.read_csv(r'C:\Users\imran\DataspellProjects\WalidCase\data\processed\clustered/3_clusters_nouns_adjectives_only.csv')
#with open(r'C:\Users\imran\DataspellProjects\WalidCase\data\processed/full_startups.csv', 'r', encoding='utf-8', errors='ignore') as f:
#    startups = pd.read_csv(f)
industry_data

| Unnamed: 0 | id | industry | keywords |
|---|---|---|---|
| 0 | 0 | Procurement | purchasing, sourcing, supplier, contract, supp... |
| 1 | 1 | GreenTech | renewable energy, solar, wind, biofuel, geothe... |
| 2 | 2 | Esports | gaming, competition, professional, tournament,... |
| 3 | 3 | Quantum Computing | qubit, superposition, entanglement, quantum al... |
| 4 | 4 | Manufacturing | production, assembly, factory, automation, mac... |
| ... | ... | ... | ... |
| 93 | 93 | Big Data | data processing, analytics, data mining, stora... |
| 94 | 94 | Connected Home | smart home, IoT, home automation, security, en... |
| 95 | 95 | Network Infrastructure | internet backbone, data transmission, routers,... |
| 96 | 96 | Food & Beverage | food industry, restaurant, food production, fo... |


In [69]:
startups

| Unnamed: 0 | name | cb_description | id |
|---|---|---|---|
| 0 | InterResolve | insurer, insurance, payment, compensation, agency | 0 |
| 1 | GladCloud | platform, brand, large, asset, merchant, solut... | 1 |
| 2 | 21GRAMS | postal, document, print, postage, mailing | 2 |
| 3 | Geltor | biodesign, nature, fermentation, biology, protein | 3 |
| 4 | 21st.BIO | bioproduction, molecule, bioindustrial | 4 |
| 5 | 24SevenOffice | web, ajax, file, integration, webex | 5 |
| 6 | Jiffy.ai | digital, automate, integrate, robot, algorithm... | 6 |
| 7 | Gemfire | planar, optical, innovative, company, operatio... | 7 |
| 8 | Jobandtalent | labor, worker, office, employment, staff, unem... | 8 |
| 9 | Joost | platform, tool, livestation, server, screen, p... | 9 |


In [25]:
# only keep 'id','name', 'top_matched_cluster_keywords'
#startups = startups[['name', 'top_matched_cluster_keywords']]
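# Embedding expects an 'id' column; only add it if the CSV doesn't already have one
# (DataFrame.insert raises a ValueError when the column already exists).
if 'id' not in startups.columns:
    startups.insert(0, 'id', startups.index)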

#startups.rename(columns={'top_matched_cluster_keywords': 'cb_description'}, inplace=True)


In [5]:
print(industry_data)

      id                industry  \
0      0                Telecoms   
1      1                  Mobile   
2      2          Communications   
3      3  Network Infrastructure   
4      4             5G Networks   
..   ...                     ...   
113  113             Video/Audio   
114  114                Genomics   
115  115               Longevity   
116  116          Gut Microbiome   
117  117           Life Sciences   

                                              keywords  
0    carrier services, satellite communication, fix...  
1    mobile applications, mobile devices, mobile op...  
2    real-time communication, chat applications, vi...  
3    content delivery network, network function vir...  
4    5G radio access network, millimeter wave, 5G s...  
..                                                 ...  
113  media production, video platforms, audio engin...  
114  functional genomics, gene expression, epigenom...  
115  aging biology, lifespan extension, rejuvenatio...

In [30]:
#mask = startups['cb_description'].apply(lambda x: len(x.split()) > 20)
#startups = startups[mask]

### Cell 3
This is where the magic happens. We create an instance of the `Embedding` class and pass the two DataFrames as arguments. Embeddings can be generated in one of two ways:
1. General transformers (BERT, GPT-2, etc.)
    - Select the model with the `llm` argument, a string key into the `llm` dictionary defined in the class (`'bert'`, `'gpt2'`, `'gpt'`, `'roberta'`, `'distilbert'`, `'xlnet'`, `'electra'` or `'industry_classifier'`), which maps it to a Hugging Face model name. The default, `'bert'`, loads `bert-base-uncased`. To use another model from the [Hugging Face hub](https://huggingface.co/models), add an entry to that dictionary first; any other key raises a `KeyError`. You can also choose the pooling method with the `pool` argument: `max` (the default), `avg`, or `concat`, which concatenates the max-pooled and average-pooled vectors and therefore doubles the embedding size.
2. Sentence transformers (SBERT)
    - Set `sentence_transformer=True` and, optionally, pass a model name through `sent` (default `'all-MiniLM-L6-v2'`). The model returns one vector per text directly, so `llm` and `pool` are ignored in this mode.
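To make the three `pool` options concrete, the sketch below reproduces them in NumPy rather than PyTorch (an assumption made only so it runs without a model). `last_hidden_states` has shape `(batch, tokens, hidden)` and each option collapses the token axis. `Embedding.pooling` averages over every position, so padded inputs would pull pad tokens into the average; `masked_mean` is a common variant (it is how sentence-transformers' mean pooling treats padding) that averages only over real tokens. Function names here are illustrative.

```python
import numpy as np


def pool_hidden_states(last_hidden_states: np.ndarray, method: str = "max") -> np.ndarray:
    """Collapse the token axis (axis=1) of a (batch, tokens, hidden) array."""
    if method == "max":
        return last_hidden_states.max(axis=1)  # (batch, hidden)
    if method == "avg":
        return last_hidden_states.mean(axis=1)  # (batch, hidden)
    if method == "concat":
        # max and average side by side -> (batch, 2 * hidden)
        return np.concatenate(
            [last_hidden_states.max(axis=1), last_hidden_states.mean(axis=1)], axis=1
        )
    raise ValueError("pool must be either max, avg or concat")


def masked_mean(last_hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average only over real tokens (mask == 1), ignoring padding positions."""
    mask = attention_mask[:, :, None].astype(last_hidden_states.dtype)  # (batch, tokens, 1)
    summed = (last_hidden_states * mask).sum(axis=1)  # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # (batch, 1), avoids division by zero
    return summed / counts


# One "sentence" of 4 token positions with hidden size 3; the last position is padding.
hidden = np.arange(12, dtype=float).reshape(1, 4, 3)
mask = np.array([[1, 1, 1, 0]])

print(pool_hidden_states(hidden, "max").shape)     # (1, 3)
print(pool_hidden_states(hidden, "concat").shape)  # (1, 6)
print(masked_mean(hidden, mask))                   # [[3. 4. 5.]]
```

The plain average of the same array is `[[4.5 5.5 6.5]]`, because it also counts the padded row.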

In [77]:
embeddings = Embedding(startups, industry_data, sentence_transformer=True, sent='sentence-transformers/all-MiniLM-L6-v2')

### Cell 4
We can now generate the embeddings for the startups and industries. The `startup` argument is a boolean that selects which DataFrame to embed: `True` (the default) embeds the startups' `cb_description` column, and `False` embeds the industries' `keywords` column. The first line below embeds the startups and the second line embeds the industries. Each call returns the input DataFrame with a new `embeddings` column (which is why the industry table appears as the cell output), but you don't need to keep the return value: the merged DataFrame is also stored in the `startups` or `industries` attribute of the `Embedding` instance.
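The bookkeeping inside `generate_embeddings` can be sketched without loading a model: embed each row, collect `{'id', 'embeddings'}` records, and left-merge them back onto the input on `id`. In this sketch `fake_encode` is a hypothetical stand-in for `SentenceTransformer.encode` that returns a toy vector, and `attach_embeddings` is an illustrative name, not part of the module.

```python
import numpy as np
import pandas as pd


def fake_encode(text: str, dim: int = 4) -> np.ndarray:
    """Hypothetical stand-in for SentenceTransformer.encode: a repeatable toy vector."""
    rng = np.random.default_rng(len(text))  # seeded on text length purely for repeatability
    return rng.standard_normal(dim)


def attach_embeddings(texts: pd.DataFrame, text_column: str) -> pd.DataFrame:
    """Mirror of the generate_embeddings flow: one record per row, then a left merge on 'id'."""
    records = [
        {"id": row["id"], "embeddings": fake_encode(row[text_column]).tolist()}
        for _, row in texts.iterrows()
    ]
    return pd.merge(texts, pd.DataFrame(records), on="id", how="left")


industry_toy = pd.DataFrame(
    {"id": [0, 1], "industry": ["GreenTech", "Esports"], "keywords": ["solar, wind", "gaming"]}
)
merged_toy = attach_embeddings(industry_toy, "keywords")
print(merged_toy.columns.tolist())  # ['id', 'industry', 'keywords', 'embeddings']
```

Storing each vector as a Python list keeps the `embeddings` column a plain object column, which is why `assign_industry` converts it back with `np.array(...)`.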

In [78]:
embeddings.generate_embeddings(startup=True)
embeddings.generate_embeddings(startup=False)



| Unnamed: 0 | id | industry | keywords | embeddings |
|---|---|---|---|---|
| 0 | 0 | Procurement | purchasing, sourcing, supplier, contract, supp... | [0.021745409816503525, -0.0345156267285347, -0... |
| 1 | 1 | GreenTech | renewable energy, solar, wind, biofuel, geothe... | [0.0463322289288044, -0.016174955293536186, 0.... |
| 2 | 2 | Esports | gaming, competition, professional, tournament,... | [0.07127097249031067, -0.07353267073631287, -0... |
| 3 | 3 | Quantum Computing | qubit, superposition, entanglement, quantum al... | [-0.03036649152636528, 0.023698771372437477, -... |
| 4 | 4 | Manufacturing | production, assembly, factory, automation, mac... | [-0.011403301730751991, -0.10180896520614624, ... |
| ... | ... | ... | ... | ... |
| 93 | 93 | Big Data | data processing, analytics, data mining, stora... | [-0.028502661734819412, -0.040501710027456284,... |
| 94 | 94 | Connected Home | smart home, IoT, home automation, security, en... | [0.06562431156635284, -0.06175098195672035, 0.... |
| 95 | 95 | Network Infrastructure | internet backbone, data transmission, routers,... | [0.0016931991558521986, -0.07349466532468796, ... |
| 96 | 96 | Food & Beverage | food industry, restaurant, food production, fo... | [0.06983721256256104, -0.10127245634794235, -0... |


### Cell 5

In this cell we assign industries to the startups. The `num_labels` argument sets how many industries are assigned to each startup (the default is 3; this cell uses 5). The function returns a list with one entry per startup; each entry is a list of `num_labels` dictionaries holding an industry name and its cosine similarity score, ordered from most to least similar. The result is also stored in the `assigned_industries` attribute of the `Embedding` instance.
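The scoring step is cosine similarity followed by a top-k selection. The class does this one startup at a time with scikit-learn's `cosine_similarity`. The sketch below is a NumPy-only alternative, not the module's code: it normalizes the rows and scores every startup against every industry in a single matrix product, which yields the same scores. The toy vectors and labels are hypothetical.

```python
import numpy as np


def top_k_cosine(startup_vecs: np.ndarray, industry_vecs: np.ndarray, k: int = 3):
    """Return (indices, scores) of the k most similar industries for every startup, best first."""
    # L2-normalise rows so a plain dot product equals cosine similarity.
    s = startup_vecs / np.linalg.norm(startup_vecs, axis=1, keepdims=True)
    i = industry_vecs / np.linalg.norm(industry_vecs, axis=1, keepdims=True)
    sims = s @ i.T  # (n_startups, n_industries)
    # argsort is ascending, so take the last k columns and reverse them for highest-first.
    idx = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
    return idx, np.take_along_axis(sims, idx, axis=1)


startup_vecs_toy = np.array([[1.0, 0.0], [0.0, 1.0]])
industry_vecs_toy = np.array([[1.0, 0.1], [0.1, 1.0], [-1.0, 0.0]])
labels = ["Fintech", "Biotech", "Esports"]  # hypothetical industry names

idx, scores = top_k_cosine(startup_vecs_toy, industry_vecs_toy, k=2)
for row_idx, row_scores in zip(idx, scores):
    print([(labels[j], round(float(score), 3)) for j, score in zip(row_idx, row_scores)])
```

Building the full similarity matrix at once trades memory (`n_startups x n_industries` floats) for speed; with about a hundred industries that matrix stays small.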

In [79]:
embeddings.assign_industry(num_labels=5)


[[{'industry': 'InsurTech', 'score': 0.4859287793632142},
  {'industry': 'Healthcare', 'score': 0.31536848938589845},
  {'industry': 'Automotive', 'score': 0.27471934735890535},
  {'industry': 'Security', 'score': 0.2631007004120769},
  {'industry': 'Energy Efficiency', 'score': 0.26100271962133936}],
 [{'industry': 'Hardware', 'score': 0.5159415062236168},
  {'industry': 'Procurement', 'score': 0.5138512956802737},
  {'industry': 'Sharing Economy', 'score': 0.5112266425199182},
  {'industry': 'Cloud Infrastructure', 'score': 0.4970310517514564},
  {'industry': 'Data Storage', 'score': 0.43982915448759746}],
 [{'industry': 'Logistics', 'score': 0.40465099971957885},
  {'industry': 'Transportation', 'score': 0.39165251413943997},
  {'industry': 'Supply Chains', 'score': 0.3556514096647025},
  {'industry': 'Procurement', 'score': 0.34256116434896167},
  {'industry': 'Physical Storage', 'score': 0.3389929405465564}],
 [{'industry': 'Life Sciences', 'score': 0.27305506666270446},
  {'indus

### Cell 6

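Finally, we update the dataframe with the assigned industries. The function adds `industry1`/`score1` through `industryN`/`scoreN` columns (scores rounded to 3 decimals) to the `startups` attribute of the `Embedding` instance and returns that DataFrame. It also drops the `embeddings` column from both the startup and industry DataFrames, so `generate_embeddings` has to be run again before calling `assign_industry` a second time.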
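The reshaping done by `update_dataframe` turns each startup's ranked list into numbered column pairs. A compact pandas sketch of the same idea, on hypothetical toy data and with an illustrative function name, looks like this:

```python
import pandas as pd

# Same shape as assign_industry's output: one list of {'industry', 'score'} dicts per startup.
assigned_toy = [
    [{"industry": "InsurTech", "score": 0.9}, {"industry": "Healthcare", "score": 0.5}],
    [{"industry": "Logistics", "score": 0.7}],  # shorter list: its second slot stays empty
]


def to_columns(assigned_industries):
    """Flatten ranked matches into industry1/score1, industry2/score2, ... columns."""
    rows = []
    for matches in assigned_industries:
        row = {}
        for rank, match in enumerate(matches, start=1):
            row[f"industry{rank}"] = match["industry"]
            row[f"score{rank}"] = round(float(match["score"]), 3)
        rows.append(row)
    return pd.DataFrame(rows)  # slots a startup doesn't have show up as NaN


wide = to_columns(assigned_toy)
print(wide.columns.tolist())  # ['industry1', 'score1', 'industry2', 'score2']
```

Joining `wide` onto the startups frame (for example with `pd.concat([startups_df, wide], axis=1)` on a matching index) gives the same layout as the output below.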

In [80]:
df = embeddings.update_dataframe()

In [81]:
print(df.head(5)) # maybe try clustering industries, and then assigning the clusters to the startups. maybe 20 clusters, take 2-3 kw from each and append into a super kw list, and then assign the startups to the clusters

           name                                     cb_description  id  \
0  InterResolve                                 insurance, insurer   0   
1     GladCloud  platform, brand, large, datum, asset, product,...   1   
2       21GRAMS                           postal, postage, mailing   2   
3        Geltor                            biodesign, fermentation   3   
4      21st.BIO                                      bioproduction   4   

       industry1  score1       industry2  score2        industry3  score3  \
0      InsurTech   0.486      Healthcare   0.315       Automotive   0.275   
1       Hardware   0.516     Procurement   0.514  Sharing Economy   0.511   
2      Logistics   0.405  Transportation   0.392    Supply Chains   0.356   
3  Life Sciences   0.273   Manufacturing   0.256  Food & Beverage   0.241   
4  Life Sciences   0.292        Genomics   0.261   Carbon Capture   0.246   

              industry4  score4          industry5  score5  
0              Security   0.263

In [82]:
df.to_csv(r'C:\Users\imran\DataspellProjects\WalidCase\data\tagged/clustered_with_spacy_engineering/3_clusters_nouns_adjectives_only_industry_v1.2.csv', index=False)