In [1]:
import pandas as pd
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')

In [2]:
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch

In [3]:
class model(nn.Module):
    """Wrap a Hugging Face encoder and return one mean-pooled vector per sequence.

    Args:
        checkpoint: name or path passed to ``AutoModel.from_pretrained``.
        freeze: if True, encoder parameters get ``requires_grad=False`` and the
            forward pass runs without autograd tracking.
        device: device the tokenized inputs are moved to in ``forward``. The
            wrapper does not move its own weights; call ``.to(device)`` separately.
    """

    def __init__(self, checkpoint, freeze=False, device='cpu'):
        super().__init__()

        self.model = AutoModel.from_pretrained(checkpoint)
        # Device the tokenized inputs are moved to in forward().
        self.device = device
        # Remembered so forward() only disables autograd when the encoder is frozen.
        self.frozen = freeze
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x, attention_mask=None):
        """Embed a tokenized batch via attention-masked mean pooling.

        Args:
            x: tokenizer output (``BatchEncoding`` or plain dict) holding
               ``'input_ids'`` and, unless ``attention_mask`` is given,
               ``'attention_mask'``.
            attention_mask: optional mask that overrides ``x['attention_mask']``.

        Returns:
            Tensor of shape (batch, hidden): the mean of the last hidden state
            over the positions where the mask is non-zero.
        """
        input_ids = x['input_ids'].to(self.device)
        mask = x['attention_mask'] if attention_mask is None else attention_mask
        mask = mask.to(self.device)

        # Skip autograd only when frozen, and never re-enable it inside a
        # caller's own no_grad() block.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.frozen):
            model_out = self.model(input_ids=input_ids, attention_mask=mask, return_dict=True)

        embds = model_out.last_hidden_state  # (batch, seq, hidden)
        # Zero out padded positions before summing so they do not leak into the
        # mean; clamp avoids 0/0 for an all-padding row.
        weights = mask.unsqueeze(-1).to(embds.dtype)
        summed = (embds * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

In [4]:
# Build the frozen DistilBERT encoder and its tokenizer.
# Required: the embedding cell below uses `distilbert` and `tokenizer`.
checkpoint = 'distilbert-base-uncased'
DEVICE = 'cpu'  # single place to choose where inputs and weights live

distilbert = model(checkpoint, freeze=True, device=DEVICE)
distilbert.to(DEVICE)
# Inference only: make eval mode (dropout off) explicit rather than relying on defaults.
distilbert.eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Load the GPT-generated pro/con text per county; read cfips as text rather than int.
# NOTE(review): the preview shows cfips like '1001' (no leading zero), so these are
# not 5-digit FIPS strings — confirm before joining on them.
dtypes = dict(cfips=str)
df = pd.read_csv('gpt_pro_con_no_null.csv', dtype=dtypes)
df.head()

Unnamed: 0,cfips,gpt_pro_1,gpt_pro_2,gpt_pro_3,gpt_con_1,gpt_con_2,gpt_con_3
0,1001,"Autauga County has a low cost of living, makin...","The county has a strong business community, pr...",Autauga County is located in the heart of Alab...,Autauga County has a relatively small populati...,The county has a limited number of resources a...,Autauga County is subject to the laws and regu...
1,1003,Baldwin County has a strong economy with a low...,"The cost of living is relatively low, making i...",There are numerous resources available to help...,The local government has strict regulations an...,"The area is prone to natural disasters, such a...",There is a limited pool of skilled labor avail...
2,1005,"Low cost of living in Barbour County, Alabama,...",Access to a large customer base due to the cou...,Access to a variety of resources and support f...,Limited access to capital and financing option...,Lack of access to a skilled workforce due to t...,Limited access to technology and infrastructur...
3,1007,"Low cost of living in Bibb County, Alabama, ma...",Access to a large customer base due to the cou...,Access to a variety of resources and support f...,Limited access to capital and financing option...,Limited access to skilled labor due to the cou...,Limited access to technology and infrastructur...
4,1009,"Low cost of living in Blount County, Alabama, ...",Access to a large customer base due to the cou...,Access to resources such as the Blount County ...,Limited access to venture capital and other fo...,Limited access to skilled labor due to the cou...,Limited access to technology and other resourc...


In [6]:
# Sanity check: column dtypes and non-null counts of the loaded frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3142 entries, 0 to 3141
Data columns (total 7 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   cfips      3142 non-null   object
 1   gpt_pro_1  3142 non-null   object
 2   gpt_pro_2  3142 non-null   object
 3   gpt_pro_3  3142 non-null   object
 4   gpt_con_1  3142 non-null   object
 5   gpt_con_2  3142 non-null   object
 6   gpt_con_3  3142 non-null   object
dtypes: object(7)
memory usage: 172.0+ KB


In [10]:
def read_csv_to_sentences(df):
    """Concatenate each row's six GPT pro/con columns into one labelled document.

    Rows are read by position, so the result does not depend on the frame's
    index (filtered, shuffled, or non-integer indexes all work). The output
    keeps row order, one string per row.

    Args:
        df: DataFrame with text columns ``gpt_pro_1..3`` and ``gpt_con_1..3``.

    Returns:
        List of strings shaped like ``'PRO_1: ... PRO_2: ... CON_3: ...'``.

    Raises:
        KeyError: if one of the expected columns is missing.
        TypeError: if a cell is not a string (e.g. NaN).
    """
    fields = (
        ('PRO_1', 'gpt_pro_1'),
        ('PRO_2', 'gpt_pro_2'),
        ('PRO_3', 'gpt_pro_3'),
        ('CON_1', 'gpt_con_1'),
        ('CON_2', 'gpt_con_2'),
        ('CON_3', 'gpt_con_3'),
    )
    # .tolist() + zip walks rows positionally; df[col][i] is a label lookup that
    # fails or picks the wrong row once the index is not 0..n-1.
    columns = [df[column].tolist() for _, column in fields]
    docs = []
    for row in zip(*columns):
        parts = [label + ': ' + text for (label, _), text in zip(fields, row)]
        docs.append(' '.join(parts))
    return docs

In [11]:
# One labelled document per county row; both names refer to the same list.
final_sentences = docs = read_csv_to_sentences(df)
print(len(final_sentences))

3142


In [12]:
# Inspect the first assembled document to confirm the PRO/CON labelling.
print(final_sentences[0])

PRO_1: Autauga County has a low cost of living, making it an affordable place to start a business. PRO_2: The county has a strong business community, providing resources and support for entrepreneurs. PRO_3: Autauga County is located in the heart of Alabama, providing easy access to major cities and transportation hubs. CON_1: Autauga County has a relatively small population, which may limit the potential customer base for a small business. CON_2: The county has a limited number of resources and services available to small businesses. CON_3: Autauga County is subject to the laws and regulations of the state of Alabama, which may be restrictive for certain types of businesses.


In [13]:
# Embed every document with the frozen DistilBERT wrapper and cache the vectors as JSON.
# Each sentence is tokenized on its own, so no padding tokens enter the mean pool.
final_embeddings = list()
all_embeddings = []

final_sentences = docs

batch_sz = 64  # sentences per outer iteration (chunking only; sentences are not padded together)
for idx in range(0, len(final_sentences), batch_sz):
    batch_sentences = final_sentences[idx:idx + batch_sz]
    for sent in batch_sentences:
        tokens = tokenizer(sent, truncation='longest_first', return_tensors='pt',
                           return_attention_mask=True, padding=True)
        embeddings = distilbert(tokens)
        # One row per tokenized sentence; extend appends that row.
        final_embeddings.extend(embeddings)

# Stack once after the loop: stacking inside it re-copied the whole growing list
# for every sentence (quadratic in the number of documents).
if final_embeddings:
    all_embeddings = torch.stack(final_embeddings)

with open('distillBERT_embeddings_paragraph.json', 'w') as f:
    # An empty corpus leaves all_embeddings as a plain [] list, which has no .tolist().
    json.dump(all_embeddings.tolist() if torch.is_tensor(all_embeddings) else all_embeddings, f)

In [14]:
# Reload the cached per-document embeddings written by the embedding cell.
with open('distillBERT_embeddings_paragraph.json', mode='r') as f:
    all_embeddings = json.load(fp=f)

In [15]:
# Number of cached embeddings, then the length of the first vector.
print(len(all_embeddings))
print(len(all_embeddings[0]))

3142
768
