<a href="https://colab.research.google.com/github/vvikasreddy/lexically_constrained_beam_search_/blob/main/constraints.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook generates lexical constraints (Turkish → English bigram pairs) from the WMT16 dataset

#### Installing and importing essential libraries

In [1]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# The version is pinned to the release this notebook was last executed with.
%pip install -q datasets==3.1.0

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [2]:
# importing necessary libraries

from datasets import load_dataset
from collections import defaultdict
from tqdm import tqdm

#### Loading the WMT dataset

In [3]:
# Load the Turkish-English ("tr-en") configuration of the WMT16 dataset.
# `ds` is read as a global by constraints() and get_count_dict_pairs(),
# which both iterate over ds["train"]["translation"].

ds = load_dataset("wmt/wmt16", "tr-en")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/11.1k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/36.8M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/153k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/466k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/205756 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1001 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3000 [00:00<?, ? examples/s]

#### Helper functions to get the constraints

In [4]:
def get_ngrams(src, n = 2):
  """
  Split a sentence on whitespace and return its contiguous n-grams.

  args:
    src: a sentence (string); tokens are separated by any run of whitespace
    n: size of each n-gram, must be >= 1 (default 2)

  returns:
    list of n-length tuples of tokens, in sentence order; the list is empty
    when the sentence has fewer than n tokens

  raises:
    ValueError: if n is smaller than 1
  """
  if n < 1:
    raise ValueError("n must be a positive integer")

  # split() without a separator collapses repeated whitespace and drops
  # leading/trailing blanks. split(" ") would instead emit empty-string
  # tokens, which end up inside n-grams such as ("word", "").
  tokens = src.split()
  return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

In [5]:
def constraints(n_gram = 2):

  """
  Makes one pass over the training split of the global `ds` and collects the
  n-gram statistics used later for the PMI computation.

  args:
    n_gram: size of the n-grams to count (default 2)

  returns (in this order)
    dict_pairs_count: normaliser for the joint counts. For every sentence pair
      and every (source n-gram, target n-gram) combination in it, this adds
      min(source count, target count), using the counts accumulated up to and
      including that sentence pair.
    src_ngrams_count: dictionary containing occurences of unique ngrams in source language (turkish)
    tgt_ngrams_count: dictionary containing occurences of unique ngrams in target language (english)
    total_src_words: n_gram times the number of n-gram occurrences in the source language (turkish)
    total_tgt_words: n_gram times the number of n-gram occurrences in the target language (english)
  """

  src_ngrams_count = defaultdict(int)
  tgt_ngrams_count = defaultdict(int)

  dict_pairs_count = 0
  total_src_words = 0
  total_tgt_words = 0

  for doc in tqdm(ds["train"]["translation"]):

    # n-grams of the source (turkish) and target (english) sentence
    src_ngrams = get_ngrams(doc["tr"], n = n_gram)
    tgt_ngrams = get_ngrams(doc["en"], n = n_gram)

    # count the occurences of each of the ngram
    for tgt_ngram in tgt_ngrams:
      tgt_ngrams_count[tgt_ngram] += 1
    for src_ngram in src_ngrams:
      src_ngrams_count[src_ngram] += 1

    # every n-gram occurrence contributes n_gram to the running totals
    total_tgt_words += n_gram * len(tgt_ngrams)
    total_src_words += n_gram * len(src_ngrams)

    # count the combined co-occurence of the ngrams; the counts do not change
    # inside this nested loop, so the source count is looked up once per row
    for src_ngram in src_ngrams:
      src_count = src_ngrams_count[src_ngram]
      for tgt_ngram in tgt_ngrams:
        dict_pairs_count += min(src_count, tgt_ngrams_count[tgt_ngram])

  return dict_pairs_count,  src_ngrams_count, tgt_ngrams_count, total_src_words, total_tgt_words



In [6]:
def get_count_dict_pairs(filtered_src, filtered_tgt, src_ngrams_count, tgt_ngrams_count, n_gram = 2):
  """
  Builds the joint count of every (source n-gram, target n-gram) pair that
  appears together in a sentence pair of the training split of the global `ds`.

  args:
    filtered_src: container of the source n-grams to keep (only membership is tested)
    filtered_tgt: container of the target n-grams to keep (only membership is tested)
    src_ngrams_count: dictionary containing occurences of unique ngrams in source language (turkish)
    tgt_ngrams_count: dictionary containing occurences of unique ngrams in target language (english)
    n_gram: size of the n-grams to extract (default 2)
  returns
    dict_pairs: dictionary keyed by (src_ngram, tgt_ngram). Each sentence pair
      containing both n-grams adds min(src_ngrams_count[src_ngram],
      tgt_ngrams_count[tgt_ngram]) to the entry.

  NOTE(review): each co-occurrence is weighted by the smaller corpus count
  rather than by 1 -- confirm this weighting is intended.
  """
  dict_pairs = defaultdict(int)

  for doc in tqdm(ds["train"]["translation"]):

    # keep only the source n-grams that passed the filter
    kept_src = [ng for ng in get_ngrams(doc["tr"], n = n_gram) if ng in filtered_src]
    if not kept_src:
      continue

    # filter the target side once per sentence instead of once per source n-gram
    kept_tgt = [ng for ng in get_ngrams(doc["en"], n = n_gram) if ng in filtered_tgt]

    # count the combined co-occurence of the ngram
    for src_ngram in kept_src:
      src_count = src_ngrams_count[src_ngram]
      for tgt_ngram in kept_tgt:
        dict_pairs[(src_ngram, tgt_ngram)] += min(src_count, tgt_ngrams_count[tgt_ngram])

  return dict_pairs

In [7]:
import math

def calculate_pmi(dict_pairs_count, dict_pairs, src_ngrams_count, tgt_ngrams_count, total_tgt_words, total_src_words, threshold = 0.9):

  """
  Scores every (source n-gram, target n-gram) pair and keeps the pairs whose
  score is greater than `threshold`.

  The score is PMI divided by the self-information of the pair:

      log2(p(src, tgt) / (p(src) * p(tgt))) / -log2(p(src, tgt))

  with p(src, tgt) = dict_pairs[pair] / dict_pairs_count,
       p(src)      = src_ngrams_count[src] / total_src_words,
       p(tgt)      = tgt_ngrams_count[tgt] / total_tgt_words.

  NOTE(review): the three probabilities use different normalisers, so the
  score is not bounded by 1 the way textbook normalised PMI is.

  args:
    dict_pairs_count: normaliser for the joint counts
    dict_pairs: dictionary mapping (src_ngram, tgt_ngram) to its joint count
    src_ngrams_count: dictionary of source n-gram counts
    tgt_ngrams_count: dictionary of target n-gram counts
    total_tgt_words: normaliser for the target n-gram counts
    total_src_words: normaliser for the source n-gram counts
    threshold: a pair is kept only when its score is strictly greater (default 0.9)

  returns
    pmi_dict: dictionary mapping (src_ngram, tgt_ngram) to its score
  """
  pmi_dict = {}

  # with an empty normaliser no probability can be formed, so nothing is scored
  if not dict_pairs_count or not total_src_words or not total_tgt_words:
    return pmi_dict

  for (src_ngram, tgt_ngram), pair_count in tqdm(dict_pairs.items()):
    p_src_tgt = pair_count / dict_pairs_count

    # an n-gram that is missing from the count dictionaries has probability 0
    p_src = src_ngrams_count.get(src_ngram, 0) / total_src_words
    p_tgt = tgt_ngrams_count.get(tgt_ngram, 0) / total_tgt_words

    if p_src <= 0 or p_tgt <= 0 or p_src_tgt <= 0:
      # log2 is undefined for these values; treat the pair as unrelated
      pmi = 0
    else:
      self_information = -math.log2(p_src_tgt)
      if self_information == 0:
        # p(src, tgt) == 1 would make the normalising division blow up
        pmi = 0
      else:
        pmi = math.log2(p_src_tgt / (p_src * p_tgt)) / self_information

    if pmi > threshold:
      pmi_dict[(src_ngram, tgt_ngram)] = pmi

  return pmi_dict


In [8]:
def get_constraints(min_count = 3):

  """
  Runs the whole pipeline and returns the constraints.

  Steps: collect the n-gram statistics with constraints(), keep the n-grams
  that occur at least `min_count` times, build the joint counts with
  get_count_dict_pairs(), score the pairs with calculate_pmi() and, for every
  source n-gram, keep only the target n-gram with the highest score.

  args:
    min_count: minimum number of occurrences an n-gram needs to be kept (default 3)

  returns
    max_pmi_dict: dictionary mapping a source n-gram to (target n-gram, score)
  """

  dict_pairs_count, src_ngrams_count, tgt_ngrams_count, total_src_words, total_tgt_words =  constraints()

  # drop the rare n-grams on both sides
  filtered_src = {k: v for k, v in src_ngrams_count.items() if v >= min_count}
  filtered_tgt = {k: v for k, v in tgt_ngrams_count.items() if v >= min_count}

  dict_pairs = get_count_dict_pairs(
    filtered_src=filtered_src,
    filtered_tgt=filtered_tgt,
    src_ngrams_count=src_ngrams_count,
    tgt_ngrams_count=tgt_ngrams_count,
  )

  # calculate_pmi takes the target total before the source total
  pmi_scores = calculate_pmi(dict_pairs_count, dict_pairs, filtered_src, filtered_tgt, total_tgt_words, total_src_words)

  # only keep the pair with max PMI value per source n-gram, since several
  # target n-grams may have been scored for the same source n-gram
  max_pmi_dict = {}
  for (src, tgt), pmi_value in pmi_scores.items():
    best = max_pmi_dict.get(src)
    if best is None or pmi_value > best[1]:
      max_pmi_dict[src] = (tgt, pmi_value)

  return max_pmi_dict

#### Call `get_constraints()` to build the constraints

In [9]:
# Despite its name, `pmi_scores` holds the final constraints: each source
# n-gram mapped to its best (target n-gram, score) pair.
pmi_scores = get_constraints()

100%|██████████| 205756/205756 [00:33<00:00, 6203.14it/s]
100%|██████████| 205756/205756 [01:03<00:00, 3226.84it/s]
100%|██████████| 26221852/26221852 [00:34<00:00, 751259.05it/s]


In [10]:
# Preview only the first few constraints; the full dict has hundreds of entries.
dict(list(pmi_scores.items())[:15])

{('Southeast', 'European'): (('Southeast', 'European'), 2.8832729719202517),
 ('European', 'Times'): (('European', 'Times'), 2.7325161234203272),
 ('Times', 'için'): (('for', 'Southeast'), 2.691656405976094),
 ('için', "Priştine'den"): (('in', 'Pristina'), 1.2344634523497555),
 ("Priştine'den", 'Muhamet'): (('By', 'Muhamet'), 0.9445522692748787),
 ('haberi', '--'): (('for', 'Southeast'), 1.85203947974123),
 ('yıl', 'önce'): (('years', 'ago,'), 0.9117258264951043),
 ("SETimes'a", 'konuşan'): (('told', 'SETimes.'), 1.3187489471640277),
 ('son', 'yıllarda'): (('in', 'recent'), 0.9050222137232432),
 ('Meclis', 'Başkanı'): (('Parliament', 'Speaker'), 1.0982522598673903),
 ('için', "Üsküp'ten"): (('in', 'Skopje'), 1.248587588170292),
 ("SETimes'a", 'verdiği'): (('told', 'SETimes.'), 1.328041226652118),
 ('verdiği', 'demeçte,'): (('told', 'SETimes.'), 1.387168566737143),
 ('olmak', 'üzere'): (('in', 'the'), 1.131720651701281),
 ('suç', 've'): (('crime', 'and'), 1.0859892472321813),
 ('milyon'

In [11]:
# Number of source n-grams that ended up with a constraint.
len(pmi_scores)

570

In [12]:
# Each item x is (source n-gram, (target n-gram, score)), so x[1][1] is the score.
max_entry = max(pmi_scores.items(), key=lambda x: x[1][1])
print("Entry with max value:", max_entry)

Entry with max value: (('Southeast', 'European'), (('Southeast', 'European'), 2.8832729719202517))
