<a href="https://colab.research.google.com/github/psriraj17/Bantu_language-model/blob/bantu-ssr/swahili_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The entire code in this file is extracted from 'https://github.com/gregfromstl/bantu-language-*modeling*'

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
PARAMS = {
    'experiment_name': "Swahili",
    'tags': ["swahili", "from scratch"],
    'n': 1000,
    'threshold': 750,
    'train_iterations': 2,
    'carry_hidden_state': False,
    'val_split': 0.3,
    'swahili_train': "/content/drive/MyDrive/Colab Notebooks/train-04/sw-train.txt",
    'test_data': "/content/drive/MyDrive/test04/sw-test.txt"
}

In [3]:
import math

In [4]:
class Dataset():
    def __init__(self, raw_data: str):
        self.chars = set(list(set(raw_data)))
        self.chars.add('~')
        self.data_size, self.vocab_size = len(raw_data), len(self.chars)
        print("{} characters, {} unique".format(self.data_size, self.vocab_size))
        
        self.char_to_idx = { char: idx for idx, char in enumerate(self.chars) }
        self.idx_to_char = { idx: char for idx, char in enumerate(self.chars) }
        
        self.data = [self.char_to_idx[char] for char in list(raw_data)]
    
    def __len__(self):
        return self.data_size
    
    def __getitem__(self, index):
        return self.data[index]


In [5]:
def clean_data(raw_data: str, known_chars: str) -> str:
    cleaned = ""
    for char in raw_data:
        if char not in known_chars:
            cleaned += "~"
        else:
            cleaned += char
    return cleaned

In [6]:
print("Loading Swahili training data:", end="\n\t")
raw_swahili = open(PARAMS['swahili_train'], 'r').read()[:10000000]
swahili_train_size, swahili_val_size = int(len(raw_swahili)*(1-PARAMS['val_split'])), int(len(raw_swahili)*PARAMS['val_split'])

train_data = Dataset(raw_swahili[:swahili_train_size])

print("Loading Swahili validation data:", end="\n\t")
cleaned_swahili_val_data = clean_data(raw_swahili[swahili_train_size:], train_data.chars)
val_data = Dataset(cleaned_swahili_val_data)


if len(PARAMS['test_data']) > 0:
    print("Loading Testing data:", end="\n\t")
    raw_test = open(PARAMS['test_data'], 'r').read()

    cleaned_test_data = clean_data(raw_test, train_data.chars)
    test_data = Dataset(cleaned_test_data)

Loading Swahili training data:
	7000000 characters, 49 unique
Loading Swahili validation data:
	3000000 characters, 49 unique
Loading Testing data:
	3451383 characters, 49 unique


In [7]:
class CountMatrix:
    def __init__(self, vocab: list, init_matrix=None):
        self.counts = init_matrix if init_matrix is not None else {i:0 for i in vocab}
        self.next = {i:None for i in vocab}

In [8]:
def increment_count(char: str, sequence: list, count_matrix: CountMatrix) -> list:
    next_char = sequence[-1]
    
    count_matrix.counts[char] += 1
    if count_matrix.next[next_char] is not None:
        count_matrix.next[next_char] = increment_count(char, sequence[:-1], count_matrix.next[next_char])
    elif sum(count_matrix.counts.values()) > PARAMS['threshold']:
        vocab = count_matrix.next.keys()
        initial_matrix = {i:0 for i in vocab}
        initial_matrix[char] += 1
        count_matrix.next = {i:CountMatrix(vocab, initial_matrix) for i in vocab}
    
    return count_matrix

In [9]:
def iterate_counts(data: Dataset, n: int, count_matrix: CountMatrix):
    for idx, char in enumerate(data[n:]):
        idx = n + idx
        sequence = data[idx-n:idx]
        
        count_matrix = increment_count(data[idx], sequence, count_matrix)
    return count_matrix

In [10]:
print("Building Matrix...")
count_matrix = CountMatrix(vocab=train_data.idx_to_char.keys())

print("Fitting...")
for i in range(PARAMS['train_iterations']):
    print("Iteration {}".format(i+1))
    count_matrix = iterate_counts(train_data, PARAMS['n'], count_matrix)

Building Matrix...
Fitting...
Iteration 1
Iteration 2


In [11]:
def probabilities_from_counts(counts: dict):
    # add one smoothing
    counts = {key:counts[key]+1 for key in counts.keys()}
    
    probabilities = {key: counts[key] / sum(counts.values()) for key in counts.keys()}
    prob_sum = sum(probabilities.values())
    assert(abs(prob_sum - 1) < 0.0001), "Probabilities should sum to 1.0 but got {}".format(prob_sum)
    
    return probabilities

In [12]:
def get_probabilities_for_sequence(sequence: list, count_matrix: CountMatrix):
    # return counts if sequence has been exhausted
    if len(sequence) == 0:
        return count_matrix.counts
    
    next_char = sequence[-1]
    
    if count_matrix.next[next_char] is not None:
        return get_probabilities_for_sequence(sequence[:-1], count_matrix.next[next_char])
    else:
        return probabilities_from_counts(count_matrix.counts)

In [13]:
def calc_loss(target_prob):
    return -math.log(target_prob, 2)

In [14]:
def eval(data: Dataset, n: int, count_matrix: CountMatrix):
    print("Evaluating...")
    
    counter = 0
    running_loss = 0
    running_acc = 0
    
    for idx, char in enumerate(data[n:]):
        idx = n + idx
        sequence = data[idx-n:idx]

        probabilities: dict = get_probabilities_for_sequence(sequence, count_matrix)
        pred: str = max(probabilities, key=probabilities.get)
        target: str = data[idx]
        target_prob: float = probabilities[target]
        
        running_loss += calc_loss(target_prob)
        running_acc += 1 if target == pred else 0
        counter += 1
        
    return running_loss / counter, running_acc / counter

In [15]:
train_loss, train_acc = eval(train_data, PARAMS['n'], count_matrix)
print("Train Loss: {:.3f}\t\t|\tTrain Accuracy: {:.2f}%".format(train_loss, train_acc*100))

val_loss, val_acc = eval(val_data, PARAMS['n'], count_matrix)
print("Validation Loss: {:.3f}\t\t|\tValidation Accuracy: {:.2f}%".format(val_loss, val_acc*100))

Evaluating...
Train Loss: 2.373		|	Train Accuracy: 58.43%
Evaluating...
Validation Loss: 3.389		|	Validation Accuracy: 43.35%


In [16]:
if len(PARAMS['test_data']) > 0:
    test_loss, test_acc = eval(test_data, PARAMS['n'], count_matrix)
    print("Test Loss: {:.3f}\t\t|\tTest Accuracy: {:.2f}%".format(test_loss, test_acc*100))

Evaluating...
Test Loss: 3.383		|	Test Accuracy: 43.52%
