<a href="https://colab.research.google.com/github/vvikasreddy/lexically_constrained_beam_search_/blob/main/beam_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References:

Marian MT model : https://huggingface.co/docs/transformers/model_doc/marian

Code to get the logits : https://huggingface.co/docs/transformers/main_classes/output

to get the BOS and EOS tokens: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig.decoder_start_token_id

get topk values : https://pytorch.org/docs/stable/generated/torch.topk.html

ideas and core implementation drawn from this paper: https://arxiv.org/pdf/1704.07138

reference to link google colab with .py file from git : https://colab.research.google.com/github/jckantor/cbe61622/blob/master/docs/A.02-Downloading_Python_source_files_from_github.ipynb


## downloading essential modules

In [1]:
!pip install datasets



## Importing necessary libraries

In [2]:
import torch, random
from datasets import load_dataset
from transformers import MarianMTModel, MarianTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

## loading the dataset, considering the wmt turkish - english translation

In [3]:
ds = load_dataset("wmt/wmt16", "tr-en")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Glancing the organization of the dataset

In [4]:
ds

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 205756
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1001
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 3000
    })
})

In [5]:
ds['train'][0]

{'translation': {'en': "Kosovo's privatisation process is under scrutiny",
  'tr': "Kosova'nın özelleştirme süreci büyüteç altında"}}

## Loading the tokenizer and model, based of Marian-NMT

In [6]:
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-tr-en")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-tr-en")



In [7]:
model

MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(62389, 512, padding_idx=62388)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(62389, 512, padding_idx=62388)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0-5): 6 x MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLU()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05

In [9]:
# def generate_translation(src_text, max_length = 50):

#   # Tokenize input
#   encoder_inputs = tokenizer(src_text, return_tensors="pt")

#   # intializes the decoder input with decoder start token
#   decoder_input = torch.tensor([[model.config.decoder_start_token_id]])

#   # chang the model to eval mode.
#   model.eval()
#   with torch.no_grad():

#     generated_tokens = []

#     while len(generated_tokens) < max_length:

#       outputs = model(
#           input_ids=encoder_inputs.input_ids,
#           attention_mask=encoder_inputs.attention_mask,
#           decoder_input_ids=decoder_input
#       )


#       next_token_logits = outputs.logits[:, -1, :]


#       # get the token with maximum logits value
#       next_token = torch.argmax(next_token_logits, dim=-1)

#       generated_tokens.append(next_token.item())

#       if next_token.item() == tokenizer.eos_token_id:
#           break
#       print(decoder_input)
#       decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0)], dim=1)
#       print(decoder_input)
#     translated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
#   return translated_text


In [10]:

### step by Step generation
# for i in range(len(generated_tokens)):
#     partial_text = tokenizer.decode(generated_tokens[:i+1], skip_special_tokens=True)
#     print(f"Step {i+1}: {partial_text}")

In [11]:
model.config.num_beams

6

In [12]:
# print(generate_translation(ds['validation'][1]['translation']['tr']))

In [13]:
ds["validation"][1]["translation"]["tr"]

"Norveç'in beş milyon insanı en yüksek yaşam standartlarının tadını çıkarıyor, sadece Avrupa'da değil, dünyada."

In [14]:
ds["validation"][1]["translation"]["en"]

"Norway's five million people enjoy one of the highest standards of living, not just in Europe, but in the world."

## Extracting the constraints

In [15]:
# code to import constraints and store in a local directory, from my git

user = "vvikasreddy"
repo = "lexically_constrained_beam_search_"
pyfile = "constraints.py"

# i.e url is "https://github.com/vvikasreddy/lexically_constrained_beam_search_/blob/main/constraints.py"

url = f"https://raw.githubusercontent.com/{user}/{repo}/main/{pyfile}"
!wget --no-cache --backups=1 {url}

import constraints

--2024-12-04 20:17:24--  https://raw.githubusercontent.com/vvikasreddy/lexically_constrained_beam_search_/main/constraints.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4914 (4.8K) [text/plain]
Saving to: ‘constraints.py’


2024-12-04 20:17:24 (52.2 MB/s) - ‘constraints.py’ saved [4914/4914]



In [16]:
# takes almost 4 minutes to get the constraints, you will see 3 progress bars
c = constraints.get_constraints()

100%|██████████| 205756/205756 [00:57<00:00, 3585.56it/s]
100%|██████████| 205756/205756 [01:39<00:00, 2062.50it/s]
100%|██████████| 26221852/26221852 [00:58<00:00, 448485.10it/s]


In [22]:
print("some of the constraints are :")

# Extract 5 random keys
random_keys = random.sample(list(c.keys()), 5)

for key in random_keys:
  print(key, c[key])

print("The length of the constraints is", len(c))

some of the constraints are :
('ABD', 'Dışişleri') (('of', 'State'), 1.1697696008746603)
('eski', 'Bosnalı') (('former', 'Bosnian'), 1.0410855671492016)
('Orta', 've') (('Central', 'and'), 0.993313184724903)
('Başbakanı', 'Vojislav') (('Serbian', 'Prime'), 1.0148018894754143)
("Jovanoviç'in", 'haberi') (('Jovanovic', 'for'), 1.1927609036257918)
The length of the constraints is 570


In [53]:
def generate_translation(src_text, decoder_input = [], probabilites = [], get_constrained_token_probability = -1, k = 5):

  """
    returns decoder_input_tokens, probs, vis_data

    decoder_input_tokens : next top k tokens, or constraints probability if get_constrained_token_probabliity != -1
    probs : corresponding probablities of decoder_input_tokens
    vis_data : top k beams

    generate the decoder input ids and corresponding probabilities
    src_text : It is the source text
    decoder_input : Represents the decoder tokens
    probabilities : Represents corresponding decoder token probablities
    get_constrained_token_probabliity : holds the constraint, -1 indicates no constraint,
    k : number of beams to be generated, default is 5
  """


  decoder_input_tokens = []
  probs = []
  vis_data = []

  # Tokenize input
  encoder_inputs = tokenizer(src_text, return_tensors="pt")

  # if decoder_input is empty, then include the decoder start token
  if decoder_input == []:
    # intial decoder start token has probability 1
    probabilites = torch.tensor([[1]])
    decoder_input = torch.tensor([[model.config.decoder_start_token_id]])

  # change the model to eval mode and stop the computation of gradients.
  model.eval()
  with torch.no_grad():

    generated_tokens = []

    outputs = model(
        input_ids=encoder_inputs.input_ids,
        attention_mask=encoder_inputs.attention_mask,
        decoder_input_ids=decoder_input
    )

    # gets the most frequenlty generated token.
    next_token_logits = outputs.logits[:, -1, :]

    # constraint, if provided, returns the probability.
    if get_constrained_token_probability != -1:
      softmax_  = torch.softmax(next_token_logits, dim=-1)
      return softmax_[0][get_constrained_token_probability]

    # get the top k tokens with maximum logits value
    top_probs, top_indices = torch.topk(torch.softmax(next_token_logits, dim=-1), k = k)

  for indx, id in enumerate(top_indices[0]):
    decoder_input_tokens.append(torch.cat([decoder_input, id.unsqueeze(0).unsqueeze(0)], dim=1))
    probs.append(torch.cat([probabilites, top_probs[0][indx].unsqueeze(0).unsqueeze(0)], dim=1))
    vis_data.append((vis_data, tokenizer.decode(decoder_input_tokens[indx].squeeze(), skip_special_tokens = True)))

  return decoder_input_tokens, probs, vis_data

x,y,z = generate_translation(ds['validation'][1]['translation']['tr'], decoder_input = torch.tensor([[62388,  1969]]), probabilites = torch.tensor([[0.0000, 0.0000]]))
print( x)
print(y)
print(z)
# c

[tensor([[62388,  1969,   261]]), tensor([[62388,  1969,    15]]), tensor([[62388,  1969,   510]]), tensor([[62388,  1969,     8]]), tensor([[62388,  1969,    47]])]
[tensor([[0.0000, 0.0000, 0.3950]]), tensor([[0.0000, 0.0000, 0.1113]]), tensor([[0.0000, 0.0000, 0.0247]]), tensor([[0.0000, 0.0000, 0.0110]]), tensor([[0.0000, 0.0000, 0.0105]])]
[([...], 'Koso'), ([...], 'Kost'), ([...], 'Kosum'), ([...], 'Koss'), ([...], 'Kosa')]


In [19]:
# def get_ngrams(src, n = 2, ):

#   src = src.split(" ")
#   src = [tuple(src[i:i+n]) for i in range(len(src) - n + 1)]

#   return src

# def constraints_tokens(src, c):
#   ngrams = get_ngrams(src)
#   constraints_src = []
#   for ngram in ngrams:
#     # print(ngram)
#     if ngram in c:
#       for gram in ngram:

#         constraints_src.append(tokenizer(gram, return_tensors="pt"))
#   return constraints_src

# def get_input_ids(data):

#   input_ids = []
#   for example in data:
#     input_ids.append((example['input_ids'][0].tolist())[0])
#   return input_ids

In [20]:
# x = constraints_tokens(ds['validation'][1]['translation']['tr'] + ' Times' +  ' için', c)
# # c
# print(x)
# print(get_input_ids(x))

In [26]:
def visualize_data(decoder_input):
  return tokenizer.decode(decoder_input.squeeze(), skip_special_tokens = True)

In [73]:
import torch

def get_top_k_prob(A, B, k=2):

  d = {}
  # cummulative sum
  for indx, val in enumerate(B):
    print(val, type(val))
    cum_sum = torch.prod(val)
    d[cum_sum] = indx

  sorted_keys = sorted(d.keys(), reverse = True)

  top_k_indices = []
  top_k_sequences = []

  for key in sorted_keys[:k]:
    top_k_indices.append(A[d[key]])
    top_k_sequences.append(B[d[key]])

  return top_k_sequences, top_k_indices

# sanity
k = 2

# A = [torch.tensor([[62388,   626,    13]]), torch.tensor([[62388,   626,     9]]), torch.tensor([[62388,   626,  1341]]), torch.tensor([[62388,   626,    27]])]
# B = [torch.tensor([[1.0000, 0.0038, 0.2500]]), torch.tensor([[1.0000, 0.0038, 0.0619]]), torch.tensor([[1.0000, 0.0038, 0.0474]]), torch.tensor([[1.0000, 0.0038, 0.0425]])]


A = [torch.tensor([[62388,  1969]]), torch.tensor([[62388,   323]]), torch.tensor([[62388,    67]]), torch.tensor([[62388,  1132]]), torch.tensor([[62388,   626]])]
B = [torch.tensor([[1.0000, 0.8746]]), torch.tensor([[1.0000, 0.0156]]), torch.tensor([[1.0000, 0.0114]]), torch.tensor([[1.0000, 0.0039]]), torch.tensor([[1.0000, 0.0038]])]

top_sequences, indices = get_top_k_prob(A, B, k)
print(f"Top {k} sequences:", top_sequences)
print("Their indices:", indices)

tensor([[1.0000, 0.8746]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0156]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0114]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0039]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0038]]) <class 'torch.Tensor'>
Top 2 sequences: [tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]])]
Their indices: [tensor([[62388,  1969]]), tensor([[62388,   323]])]


In [72]:
def beam_search(maxlen, numC, k, src, constraints):

    decoder_start_token = model.config.decoder_start_token_id

    # initialize the grids
    grids = [[[] for _ in range(numC + 1)] for _ in range(maxlen + 1)]
    probs_grid  = [[[] for _ in range(numC + 1)] for _ in range(maxlen + 1)]

    # intialize the first grid to start hyp
    grids[0][0] = [1]

    # remove during testsrc
    # constrained_tokens = get_input_ids(constraints_tokens(, constraints))
    # temporary
    constrained_tokens = [3762, 37]

    generated_constraint_index = 0

    for t in range(1, maxlen):

      index_c = max(0, (numC - t) - maxlen)

      g = torch.tensor([])
      s = torch.tensor([])
      probs = torch.tensor([])

      for c in range(index_c, min(t, numC) + 1):

          print("cur iteration ", t, c )
          s = []
          g = []

          # storing decoder inputs
          decoder_inputs = []
          probs = []
          vis_data = []

          # print(grids)
          print(probs_grid[t-1][c], "hey yo ", t - 1, c, "-------------------")
          for indx, element in enumerate(grids[t-1][c]):

            # guess there is no need for conditioning, just generate.
            print(element, "this is the element", )
            if type(element) == int:
              decoder_input = []
              probs =[]
            else:
              decoder_input = element
              probs = probs_grid[t-1][c][indx]

            # print(element)
            # print(decoder_input)
            # print(probs)
            # print("----------------------------")
            t_g, t_probs,vis_data = generate_translation(src_text= src, decoder_input = decoder_input, probabilites = probs)
            print(type(g))
            g.append(t_g)
            print(probs)
            print(type(probs))
            probs.append(t_probs)
            # g += t_g
            # probs += t_probs
            print(probs)
            print(vis_data)
            print("here")

          # retrieve the  probability of the constraint and add that to the decoder_input.
          if c > 0 and constrained_tokens:

            for indx, element in enumerate(grids[t-1][c-1]):

              if c == 1 and t == 1:
                decoder_inputs = torch.tensor([[model.config.decoder_start_token_id]])
                prob = torch.tensor([[1]])
              else:
                decoder_inputs = element
                prob = probs_grid[t-1][c-1][indx]

              # Gets the constraints probability and stores them
              cons = generate_translation(src, decoder_input = decoder_input, get_constrained_token_probability = constrained_tokens[c - 1])
              # print(cons, decoder_inputs, element)
              decoder_inputs = torch.cat([decoder_inputs, torch.tensor(constrained_tokens[c-1]).unsqueeze(0).unsqueeze(0)], dim=1)
              prob = torch.cat([prob, torch.tensor(cons).unsqueeze(0).unsqueeze(0)], dim=1)


              g.append(decoder_inputs)

              probs.append(prob)
          print("before")
          print(g, "g")
          print(probs, "probs")
          probs_grid[t][c], grids[t][c] = get_top_k_prob(g, probs, k)


          # print(grids[t][c], t, c )
          # print(probs_grid[t][c])
          for i in  grids[t][c]:
            print(visualize_data(i), i )


          print('asd-------------------------------------------')

  # sanity : print grids
    print(grids)


beam_search(maxlen= 5, numC=0, k =6, src = ds["train"][1]["translation"]["tr"], constraints = c)

cur iteration  1 0
[] hey yo  0 0 -------------------
1 this is the element
<class 'list'>
[]
<class 'list'>
[[tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]]), tensor([[1.0000, 0.0114]]), tensor([[1.0000, 0.0039]]), tensor([[1.0000, 0.0038]])]]
[([...], 'Kos'), ([...], 'In'), ([...], 'He'), ([...], 'With'), ([...], 'As')]
here
before
[[tensor([[62388,  1969]]), tensor([[62388,   323]]), tensor([[62388,    67]]), tensor([[62388,  1132]]), tensor([[62388,   626]])]] g
[[tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]]), tensor([[1.0000, 0.0114]]), tensor([[1.0000, 0.0039]]), tensor([[1.0000, 0.0038]])]] probs
tensor([[1.0000, 0.8746]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0156]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0114]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0039]]) <class 'torch.Tensor'>
tensor([[1.0000, 0.0038]]) <class 'torch.Tensor'>


IndexError: list index out of range

In [68]:
def beam_search(maxlen, numC, k, src, constraints):

    decoder_start_token = model.config.decoder_start_token_id

    # initialize the grids
    grids = [[[] for _ in range(numC + 1)] for _ in range(maxlen + 1)]
    probs_grid  = [[[] for _ in range(numC + 1)] for _ in range(maxlen + 1)]

    # intialize the first grid to start hyp
    grids[0][0] = [1]

    # remove during testsrc
    # constrained_tokens = get_input_ids(constraints_tokens(, constraints))
    # temporary
    constrained_tokens = [3762, 37]

    generated_constraint_index = 0

    for t in range(1, maxlen):

      index_c = max(0, (numC - t) - maxlen)

      g = torch.tensor([])
      s = torch.tensor([])
      probs = torch.tensor([])

      for c in range(index_c, min(t, numC) + 1):

          print("cur iteration ", t, c )
          s = []
          g = []

          # storing decoder inputs
          decoder_inputs = []
          probs = []
          vis_data = []

          # print(grids)
          print(probs_grid[t-1][c], "hey yo ", t - 1, c, "-------------------")
          for indx, element in enumerate(grids[t-1][c]):

            # guess there is no need for conditioning, just generate.
            print(element, "this is the element", )
            if type(element) == int:
              decoder_input = []
              probs =[]
            else:
              decoder_input = element
              probs = probs_grid[t-1][c][indx]

            # print(element)
            # print(decoder_input)
            # print(probs)
            # print("----------------------------")
            g, probs,vis_data = generate_translation(src_text= src, decoder_input = decoder_input, probabilites = probs)
            # print(type(g))
            # g.append(t_g)
            # print(probs)
            # print(type(probs))
            # probs.append(t_probs)
            # g += t_g
            # probs += t_probs
            print(probs)
            print(vis_data)
            print("here")

          # retrieve the  probability of the constraint and add that to the decoder_input.
          if c > 0 and constrained_tokens:

            for indx, element in enumerate(grids[t-1][c-1]):

              if c == 1 and t == 1:
                decoder_inputs = torch.tensor([[model.config.decoder_start_token_id]])
                prob = torch.tensor([[1]])
              else:
                decoder_inputs = element
                prob = probs_grid[t-1][c-1][indx]

              # Gets the constraints probability and stores them
              cons = generate_translation(src, decoder_input = decoder_input, get_constrained_token_probability = constrained_tokens[c - 1])
              # print(cons, decoder_inputs, element)
              decoder_inputs = torch.cat([decoder_inputs, torch.tensor(constrained_tokens[c-1]).unsqueeze(0).unsqueeze(0)], dim=1)
              prob = torch.cat([prob, torch.tensor(cons).unsqueeze(0).unsqueeze(0)], dim=1)


              g.append(decoder_inputs)

              probs.append(prob)
          print("before")
          print(g, "g")
          print(probs, "probs")
          probs_grid[t][c], grids[t][c] = get_top_k_prob(g, probs, k)


          # print(grids[t][c], t, c )
          # print(probs_grid[t][c])
          for i in  grids[t][c]:
            print(visualize_data(i), i )


          print('asd-------------------------------------------')

  # sanity : print grids
    print(grids)


beam_search(maxlen= 5, numC=0, k =6, src = ds["train"][1]["translation"]["tr"], constraints = c)

cur iteration  1 0
[] hey yo  0 0 -------------------
1 this is the element
[tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]]), tensor([[1.0000, 0.0114]]), tensor([[1.0000, 0.0039]]), tensor([[1.0000, 0.0038]])]
[([...], 'Kos'), ([...], 'In'), ([...], 'He'), ([...], 'With'), ([...], 'As')]
here
before
[tensor([[62388,  1969]]), tensor([[62388,   323]]), tensor([[62388,    67]]), tensor([[62388,  1132]]), tensor([[62388,   626]])] g
[tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]]), tensor([[1.0000, 0.0114]]), tensor([[1.0000, 0.0039]]), tensor([[1.0000, 0.0038]])] probs
Kos tensor([[62388,  1969]])
In tensor([[62388,   323]])
He tensor([[62388,    67]])
With tensor([[62388,  1132]])
As tensor([[62388,   626]])
asd-------------------------------------------
cur iteration  2 0
[tensor([[1.0000, 0.8746]]), tensor([[1.0000, 0.0156]]), tensor([[1.0000, 0.0114]]), tensor([[1.0000, 0.0039]]), tensor([[1.0000, 0.0038]])] hey yo  1 0 -------------------
tensor([[62388,  1969]]) this 

In [33]:
print(ds["train"][1]["translation"]["tr"])
ds["train"][1]["translation"]["en"]

Kosova, tekrar eden şikayetler ışığında özelleştirme sürecini incelemeye alıyor.


'Kosovo is taking a hard look at its privatisation process in light of recurring complaints.'

In [34]:

text = "Kosova, tekrar eden şikayetler ışığında özelleştirme sürecini incelemeye alıyor."

# Tokenize input text
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Generate translation
translated_tokens = model.generate(**inputs)

# Decode and print the translation
translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
print("Translated text:", translated_text)

Translated text: Kosovo is reviewing the privatisation process in light of repeated complaints.
