In [33]:
import pandas as pd
import numpy as np
import random
import nltk
from nltk.tokenize import word_tokenize
from collections import defaultdict
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk.corpus import stopwords

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from keras.optimizers import Adam
from keras_nlp.layers import PositionEmbedding

In [34]:
np.random.seed(2024)
random.seed(2024)
c = pd.read_csv("world_population.csv")
c = c[['Country/Territory', 'Capital']]
c.columns = ['country', 'capital']
c = c[c['country'].apply(lambda x: len(x.split()) == 1)]
c = c[c['capital'].apply(lambda x: len(x.split()) == 1)]
c = c.reset_index(drop = True)
c = c.sample(n=10).reset_index(drop = True)

In [35]:
countries = list(c['country'])
capitals = list(c['capital'])

In [36]:
middles = [
    "is the capital of",
    "serves as capital for",
    "functions as heart of",
    "stands as capital of",
    "operates as center for",
    "represents the leadership of",
    "is the nucleus of",
    "acts as hub for",
    "shines as capital within",
    "maintains capital status in",
    "anchors the government of",
    "is federal capital for",
    "provides the capital to"
]

In [37]:
sentences_cc = []

for i in range(len(countries)):
    for middle in middles:
        temp = countries[i] + ' ' + middle + ' ' + capitals[i]
        sentences_cc.append(temp)

In [38]:
sentences_c = [
    # Brazil
    "Brazil's Amazon rainforest is an ecological wonder.",
    "Carnival in Brazil is a spectacle of joy.",
    "The biodiversity of Brazil is unmatched globally.",
    
    # Dominica
    "Dominica is an island brimming with natural hot springs.",
    "The tropical rainforests in Dominica are majestic.",
    "Dominica is a haven for whale-watching enthusiasts.",
  
    # Gabon
    "Gabon is a sanctuary for endangered forest elephants.",
    "The vast rainforests of Gabon are teeming with biodiversity.",
    "Gabon's efforts in conservation are globally recognized.",
    
    # Liechtenstein
    "Liechtenstein is a haven for winter sports enthusiasts.",
    "The banking sector is central to Liechtenstein's prosperity.",
    "Liechtenstein boasts a variety of cultural and historical sites.",
    
    # Mali
    "Mali's music scene has global influence and renown.",
    "The legendary city of Timbuktu is located in Mali.",
    "The Niger River is a lifeline for Mali's agriculture.",
    
    # Mayotte
    "Mayotte is enveloped by a vast coral barrier reef.",
    "The ylang-ylang perfume industry thrives in Mayotte.",
    "Mayotte's lagoon is one of the largest in the world.",
    
    # Micronesia
    "Micronesia is scattered across the western Pacific Ocean.",
    "Traditional navigation by the stars is significant in Micronesia.",
    "Micronesia has a diverse range of marine habitats.",
    
    # Suriname
    "Suriname's rainforest is part of the Amazon basin.",
    "Cultural diversity is the cornerstone of Suriname's identity.",
    "Suriname is rich in biodiversity and natural resources.",
    
    # Tajikistan
    "Tajikistan's Pamir Mountains are known as the Roof of the World.",
    "Silk Road history is evident throughout Tajikistan.",
    "Water resources play a crucial role in Tajikistan's agriculture.",
    
    # Turkmenistan
    "Turkmenistan's Karakum Desert covers much of the country.",
    "The ancient city of Merv in Turkmenistan is a historic jewel.",
    "Turkmenistan is known for its rich reserves of natural gas."
]

In [39]:
sentences = [
    "Every country has its unique cultural identity and heritage.",
    "The national dish of a country often tells a story of its past.",
    "Many countries embrace the beauty of their diverse landscapes.",
    "Countries with rich histories boast numerous UNESCO World Heritage Sites.",
    "Each country's flag symbolizes its distinct identity and values.",
    "Countries often have traditional attire that reflects their cultural heritage.",
    "National parks in various countries preserve their natural splendor.",
    "A country's language is a window into its society and culture.",
    "Countries around the world celebrate independence in their own unique ways.",
    "The economic stability of a country affects its global influence.",
    "Countries with coastlines enjoy the benefits of maritime trade.",
    "Some countries are renowned for their contributions to the world of music.",
    "Public transport systems can vary greatly from country to country.",
    "Countries have varying forms of government, from democracies to monarchies.",
    "International sporting events often bring countries together in friendly competition.",
    "Folklore and legends offer intriguing insights into a country's psyche.",
    "Countries prioritize education to ensure progress and development.",
    "The architecture within a country can reveal its historical eras.",
    "Countries with vast deserts have adapted uniquely to their environment.",
    "Many countries are making efforts to combat climate change.",
    "The traditional dance styles of a country are part of its allure.",
    "Countries strengthen their bonds through diplomatic relations and alliances.",
    "Each country deals with the challenges of urbanization in different ways.",
    "Mountains serve as natural borders between some countries.",
    "A country's literature often reflects its social and political issues.",
    "Festivals are a colorful expression of a country's cultural fabric.",
    "Countries with significant rainfall have lush, green landscapes.",
    "Countries that value innovation lead in global technological advancements.",
    "Traditional medicine in various countries has evolved into modern practices.",
    "Countries located on tectonic plate boundaries often experience earthquakes.",
    "Some countries have a vibrant street food culture that tantalizes the taste buds.",
    "The currency of a country is a part of its sovereignty.",
    "Tourism is a major economic driver for countries with natural wonders.",
    "Countries on the equator experience a tropical climate year-round.",
    "The legal system in each country has its own unique characteristics.",
    "Countries with a significant youth population focus on modern education.",
    "The art from different countries serves as cultural ambassadors.",
    "Many countries rely on renewable energy sources for a sustainable future.",
    "Topographical variety gives certain countries distinct climatic regions.",
    "Countries often have national animals that symbolize their wildlife.",
    "Cinema and movies are a reflection of a country's storytelling.",
    "Countries with an agrarian economy focus heavily on farming.",
    "Some countries experience all four seasons, while others do not.",
    "Agricultural exports from various countries feed the world's population.",
    "In many countries, traditional industries coexist with modern ones.",
    "Countries have unique ways of celebrating life's milestones.",
    "Customs and etiquette differ widely from country to country.",
    "The flora and fauna of a country contribute to its biodiversity.",
    "Countries facing the ocean have a rich tradition of seafaring.",
    "Cuisine from different countries often includes a variety of spices.",
    "Countries promote their language and culture through educational exchanges.",
    "The spiritual life in countries can vary widely among the population.",
    "Each country has a history marked by significant events and epochs.",
    "The national anthems of countries evoke patriotism and unity.",
    "Countries with rivers often develop rich agricultural and cultural societies.",
    "Local markets in countries are melting pots of tradition and trade.",
    "Startups and entrepreneurship are thriving in various forward-thinking countries.",
    "Fashion trends in different countries can influence global styles.",
    "Countries implement measures to protect their wildlife and ecosystems.",
    "In many countries, handcrafting skills are passed down through generations."
]

In [40]:
sentences_cccc = []

for i in range(len(sentences_cc)):
    for j in range(len(sentences_cc)):
        if i != j and i // 13 != j // 13:
            sentences_cccc.append(sentences_cc[i] + ' ' + sentences_cc[j])
            
sentences_cccc = random.sample(sentences_cccc, 1000)

In [41]:
all_sentences = sentences_cc + sentences_c + sentences + sentences_cccc

In [85]:
len(sentences_cc)

130

In [86]:
sentences_cccc[5]

'Brazil functions as heart of Brasilia Turkmenistan operates as center for Ashgabat'

In [59]:
cleaned_sentences = []

def clean_text(x):
    temp = x.lower()
    if temp.endswith("'s"):
        temp = temp[:-2]
    return temp
    
for sentence in all_sentences:
    tokens = [clean_text(x) for x in sentence.split()]
    cleaned_sentences.append(tokens)

In [60]:
vocab = set([item for sublist in cleaned_sentences for item in sublist])

In [64]:
vocab_map = {}

cnt = 1

for x in c['country'].str.lower():
    vocab_map[x] = cnt
    cnt += 1
    
for x in c['capital'].str.lower():
    vocab_map[x] = cnt
    cnt += 1
    
for x in vocab:
    if x not in vocab_map:
        vocab_map[x] = cnt
        cnt += 1

In [66]:
cleaned_sentences_number = []

for sentence in cleaned_sentences:  
    temp = []
    for word in sentence:
        temp.append(vocab_map[word])
    
    cleaned_sentences_number.append(temp)

In [67]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
max_sequence_len = max(len(sentence) for sentence in cleaned_sentences_number)
x_train_padded = pad_sequences(cleaned_sentences_number, maxlen=max_sequence_len, padding='post')

n_cat = len(vocab_map)

In [68]:
x_train_padded[0]

array([  1, 357, 205, 389, 161,  11,   0,   0,   0,   0,   0,   0,   0],
      dtype=int32)

In [69]:
x_train = np.array(x_train_padded)
n_cat = len(vocab_map)
np.random.shuffle(x_train)
x_masked_train = np.copy(x_train)
x_masked_train = x_masked_train[:,:-1]
y_masked_labels_train = np.copy(x_train)
y_masked_labels_train = y_masked_labels_train[:,1:]

In [73]:
# Building the model

embed_dim = 100
num_heads = 2
num_blocks = 5

batch_size = 1024

input_layer = layers.Input(shape=(x_masked_train.shape[1],), dtype=tf.int32)  # Input layer

embedding_layer = layers.Embedding(n_cat + 1, embed_dim, name="word_embedding")(input_layer)  # Embedding layer
position_embeddings = PositionEmbedding(sequence_length=len(x_masked_train[0]))(embedding_layer)
embedding_layer = embedding_layer + position_embeddings

# Transformer blocks with causal masking for next token prediction
x = embedding_layer
for i in range(num_blocks):
    # Apply the causal mask to ensure that each position can only attend to known tokens
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=embed_dim // num_heads
    )(x, x, x, use_causal_mask=True)
    
    x = layers.Add()([x, attention_output])  # Skip Connection
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    
    ff_net = keras.models.Sequential([
        layers.Dense(2 * embed_dim, activation='relu'),
        layers.Dense(embed_dim),
    ])

    # Apply Feedforward network
    x = ff_net(x)

    # Add & Normalize
    x = layers.Add()([attention_output, x]) 
    x = layers.LayerNormalization(epsilon=1e-6)(x)

# Output layer for providing predictions over the vocabulary
predict_layer = layers.Dense(n_cat, activation='softmax')(x)

model = keras.models.Model(inputs=input_layer, outputs=predict_layer)  # Model definition
model.compile(optimizer=keras.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'],
             weighted_metrics=[])  # Compile the model

# Reshape the target data to have an extra dimension
y_masked_labels_train_reshaped = y_masked_labels_train.reshape(y_masked_labels_train.shape[0], 
                                                               y_masked_labels_train.shape[1], 1)

target_mask = np.where(y_masked_labels_train_reshaped == 0, 0, 1)

y_masked_labels_train -= 1
y_masked_labels_train[y_masked_labels_train < 0] = 0

In [74]:
callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

history = model.fit(x_masked_train, y_masked_labels_train_reshaped, 
                    sample_weight=target_mask, epochs=1000, batch_size=batch_size,
                    validation_split = 0.5, callbacks = [callback])

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

In [80]:
np.random.seed(2024)
random.seed(2024)

def get_accuracy(num_ex): ### num_ex = 1 to 5
    acc_a = []
    samples = [random.sample([x for x in range(1,11)], num_ex + 1) for _ in range(500)]
    for sample in samples:
        temp = sample[:-1]
        targ = sample[-1]
        temp_a = []
        for x in temp:
            temp_a += [x, x + c.shape[0]]
        temp_a += [targ]
        targ_a = targ + c.shape[0]
        temp_a += [0] * (len(x_masked_train[0]) - len(temp_a))
        pred_a = keras.backend.function(inputs = model.layers[0].input, outputs = model.layers[-1].output) \
                (np.array(temp_a).reshape(1,len(temp_a)))
        if pred_a[:,(2 * num_ex),:][:(2 * c.shape[0])].argmax() + 1 == targ_a:
            acc_a.append(1)
    return np.sum(acc_a)/500

In [81]:
print(get_accuracy(1))
print(get_accuracy(2))
print(get_accuracy(3))
print(get_accuracy(4))
print(get_accuracy(5))

0.0
0.0
0.0
0.0
0.0
