In [None]:
!pip install --upgrade transformers



In [None]:
import pandas as pd
import math
import torch
from transformers import *
from transformers import T5Tokenizer as HF_T5Tokenizer, T5Model
from transformers import GPT2Tokenizer as HF_GPT2Tokenizer
#from transformers import BertConfig, BertTokenizer, BertModel, DistilBertTokenizer, GPT2Tokenizer, T5Tokenizer

import torch
# NOTE(review): the result of this call is discarded — it is not the last
# expression of the cell and is neither assigned nor printed.
torch.cuda.is_available()

# Mount Google Drive (Colab only) so the CSV paths under /content/drive that
# the later cells read from and write to can be resolved.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class ModelTokenizer:
    """Common wrapper that bundles a model name with its tokenizer and optional model.

    Concrete subclasses are expected to override ``calculate_num_tokens``.
    """

    def __init__(self, name, tokenizer, model=None):
        """Store the identifier, the tokenizer object and the (optional) model object.

        Args:
            name: Identifier of the model; kept on ``self.name``.
            tokenizer: Tokenizer object; kept on ``self.tokenizer``.
            model: Optional model object; kept on ``self.model`` (``None`` by default).
        """
        self.name, self.tokenizer, self.model = name, tokenizer, model

    def calculate_num_tokens(self, keyword):
        """Return how many tokens ``keyword`` is split into; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Subclasses should implement this!"
        raise NotImplementedError(message)

class BERTTokenizer(ModelTokenizer):
    """BERT wrapper: loads the tokenizer and model for a checkpoint and counts sub-word tokens."""

    def __init__(self, name):
        """Load the BERT tokenizer and model identified by ``name``.

        Args:
            name: Checkpoint identifier handed to ``from_pretrained``.
        """
        # NOTE(review): this config is built from defaults and is never passed to
        # the model loaded below; it is kept only so the ``config`` attribute that
        # this class has always exposed continues to exist.
        self.config = BertConfig(output_hidden_states=True)
        # ``max_length`` / ``truncation`` are encode-time arguments, so passing them
        # to ``from_pretrained`` does not configure truncation. The length limit
        # stored on the tokenizer itself is ``model_max_length``.
        tokenizer = BertTokenizer.from_pretrained(name, model_max_length=512)
        model = BertModel.from_pretrained(name, output_hidden_states=True)
        super().__init__(name, tokenizer, model)

    def calculate_num_tokens(self, keyword):
        """Return the number of sub-word tokens ``keyword`` is split into.

        The keyword is tokenized on its own. No "[CLS]"/"[SEP]" markers are
        added, so nothing has to be sliced off again and the count does not
        depend on the tokenizer recognising literal special tokens in the text.

        Args:
            keyword: Text to tokenize (a country name in this notebook).

        Returns:
            The length of ``self.tokenizer.tokenize(keyword)``.
        """
        return len(self.tokenizer.tokenize(keyword))

class DistilBERTTokenizer(ModelTokenizer):
    """DistilBERT wrapper: loads only the tokenizer for a checkpoint and counts sub-word tokens."""

    def __init__(self, name):
        """Load the DistilBERT tokenizer identified by ``name`` (no model is loaded).

        Args:
            name: Checkpoint identifier handed to ``from_pretrained``.
        """
        tokenizer = DistilBertTokenizer.from_pretrained(name)
        # The base class stores ``name`` and ``tokenizer`` and leaves ``model`` as None.
        super().__init__(name, tokenizer)

    def calculate_num_tokens(self, keyword):
        """Return the number of sub-word tokens ``keyword`` is split into.

        The keyword is tokenized on its own. No "[CLS]"/"[SEP]" markers are
        added, so nothing has to be sliced off again and the count does not
        depend on the tokenizer recognising literal special tokens in the text.

        Args:
            keyword: Text to tokenize (a country name in this notebook).

        Returns:
            The length of ``self.tokenizer.tokenize(keyword)``.
        """
        return len(self.tokenizer.tokenize(keyword))

class GPT2Tokenizer(ModelTokenizer):
    """GPT-2 wrapper: loads only the tokenizer for a checkpoint and counts its tokens.

    This class shadows the Hugging Face class of the same name, which is why the
    library class is referenced through the ``HF_GPT2Tokenizer`` alias.
    """

    def __init__(self, name):
        """Load the GPT-2 tokenizer identified by ``name``; no model is attached."""
        super().__init__(name, HF_GPT2Tokenizer.from_pretrained(name))

    def calculate_num_tokens(self, keyword):
        """Return the number of token ids ``keyword`` encodes to, without special tokens."""
        token_ids = self.tokenizer.encode(keyword, add_special_tokens=False)
        return len(token_ids)

class T5Tokenizer(ModelTokenizer):
    """T5 wrapper: loads the tokenizer and model for a checkpoint and counts its tokens.

    This class shadows the Hugging Face class of the same name, which is why the
    library class is referenced through the ``HF_T5Tokenizer`` alias.
    """

    def __init__(self, name):
        """Load the T5 tokenizer, then the T5 model, identified by ``name``."""
        super().__init__(
            name,
            HF_T5Tokenizer.from_pretrained(name),
            T5Model.from_pretrained(name),
        )

    def calculate_num_tokens(self, keyword):
        """Return the number of tokens ``keyword`` is split into."""
        pieces = self.tokenizer.tokenize(keyword)
        piece_ids = self.tokenizer.convert_tokens_to_ids(pieces)
        return len(piece_ids)

In [None]:
# Instantiate tokenizer for each model
# Each wrapper loads its pretrained assets by checkpoint name. The BERT and T5
# wrappers also load the full model, while the DistilBERT and GPT-2 wrappers
# load the tokenizer only. These four globals are consumed by the token-counting
# loop in the next cell.
bert_tokenizer = BERTTokenizer('bert-base-cased')
distilbert_tokenizer = DistilBERTTokenizer('distilbert-base-cased')
gpt2_tokenizer = GPT2Tokenizer('gpt2')
t5_tokenizer = T5Tokenizer('t5-base')

In [None]:
# Build the per-country metadata table: token counts per model, GDP, and the
# combined Wikipedia + BookCorpus mention frequency, then export it as CSV.

# All inputs and the output live under one Drive folder; keep the root in one place.
SECTION_DIR = "/content/drive/MyDrive/Colab Notebooks/country_distortions/Section_3_4"
COUNTRY_META_DIR = f"{SECTION_DIR}/country_metadata"
COUNTRY_FREQ_DIR = f"{SECTION_DIR}/counting_country_frequencies"

# Read in the UN data
un_data = pd.read_csv(f"{COUNTRY_META_DIR}/un_countries.csv")

# Calculate tokens for each model and add as new columns
# (columns are named "<model name>_tokens" here and renamed further down).
for tokenizer_model in [bert_tokenizer, distilbert_tokenizer, gpt2_tokenizer, t5_tokenizer]:
    token_column_name = f'{tokenizer_model.name}_tokens'
    un_data[token_column_name] = un_data['Name'].apply(tokenizer_model.calculate_num_tokens)

# Read in frequency data; both files are relabelled to the same two column names.
wiki_freq = pd.read_csv(f"{COUNTRY_FREQ_DIR}/country_counts_wiki.csv")
wiki_freq.columns = ['keyword', 'freq']
book_corpus = pd.read_csv(f"{COUNTRY_FREQ_DIR}/country_counts_bookcorpus.csv")
book_corpus.columns = ['keyword', 'freq']

# wiki_freq is the LEFT frame, so its column takes the "_wiki" suffix and the
# BookCorpus column takes "_bc". (The suffixes used to be swapped, which left the
# sum correct but labelled each corpus' counts with the other corpus' name.)
freq = (
    wiki_freq.set_index("keyword")
    .join(book_corpus.set_index("keyword"), how='outer', lsuffix="_wiki", rsuffix="_bc")
    .reset_index()
)
# fill na if one corpus is missing keywords
freq = freq.fillna(0)
freq['freq'] = freq['freq_wiki'] + freq['freq_bc']
# math.log raises ValueError if a keyword's combined count is 0.
freq['freq_logged'] = freq['freq'].apply(math.log)

# Combine with UN data
un_data['name'] = un_data['Name'].str.lower()
# Go through str() so this works whether pandas parsed Estimate as text
# (e.g. values with thousands separators) or as a numeric column.
un_data['gdp'] = un_data['Estimate'].apply(lambda x: int(str(x).replace(",", "")))
un_data['gdp_logged'] = un_data['gdp'].apply(math.log)

# The inner join matches on the original-cased 'Name' (not the lowercased 'name'
# column), so only countries whose Name appears verbatim as a keyword are kept.
freq_by_keyword = freq[['keyword', 'freq', 'freq_logged']].set_index('keyword')
un_data = un_data.set_index("Name").join(freq_by_keyword, how='inner').reset_index()

un_data = un_data.rename(columns={
    'bert-base-cased_tokens': 'subpieces_bert',
    'distilbert-base-cased_tokens': 'subpieces_distilbert',
    'gpt2_tokens': 'subpieces_gpt2',
    't5-base_tokens': 'subpieces_t5'
})

# Export to CSV
un_data.to_csv(f"{COUNTRY_META_DIR}/generalized_country_meta.csv", index=False)

In [None]:
# Show the merged country metadata table as this cell's output.
un_data

Unnamed: 0,index,Official_Name,Same,Match,Country/Territory,Region,Estimate,Year,subpieces_bert,subpieces_distilbert,subpieces_gpt2,subpieces_t5,name,gdp,gdp_logged,freq,freq_logged
0,Cote d'Ivoire,Côte d'Ivoire,False,False,Ivory Coast,Africa,58539,2019,7,7,7,8,cote d'ivoire,58539,10.977448,264.0,5.575949
1,Democratic Republic of the Congo,Democratic Republic of the Congo,True,False,DR Congo,Africa,47319,2019,5,5,5,5,democratic republic of the congo,47319,10.764667,11229.0,9.326255
2,East Timor,Timor-Leste,False,True,East Timor,Asia,2017,2019,2,2,3,3,east timor,2017,7.609367,5834.0,8.671458
3,Micronesia,Federated States of Micronesia,False,True,Micronesia,Oceania,414,2019,2,2,3,4,micronesia,414,6.025866,3395.0,8.130059
4,Moldova,Republic of Moldova,False,True,Moldova,Europe,11955,2019,1,1,3,1,moldova,11955,9.388905,13791.0,9.531771
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
188,Uzbekistan,Uzbekistan,True,True,Uzbekistan,Asia,57921,2019,1,1,3,5,uzbekistan,57921,10.966835,13244.0,9.491300
189,Vanuatu,Vanuatu,True,True,Vanuatu,Oceania,906,2019,3,3,3,4,vanuatu,906,6.809039,6379.0,8.760767
190,Yemen,Yemen,True,True,Yemen,Asia,24935,2019,1,1,2,1,yemen,24935,10.124028,19004.0,9.852405
191,Zambia,Zambia,True,True,Zambia,Africa,23085,2019,1,1,3,1,zambia,23085,10.046938,15048.0,9.619000
