# Exp 019: Decoding-based grammar control
This experiments sets up decoding-strategies.

In [1]:
from dotenv import load_dotenv
load_dotenv()
import os
os.environ['CACHE_DIR'] = f"/scratch/tmp.{os.getenv('SLURM_JOB_ID')}.dglandorf" # speed up model loading

import sys
sys.path.append(f'../source')
import helpers
import models
import evaluation
import importlib
importlib.reload(evaluation)

import torch
import pandas as pd
from tqdm.notebook import tqdm

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

[nltk_data] Downloading package punkt to
[nltk_data]     /cluster/home/dglandorf/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /cluster/home/dglandorf/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Ground-truth partial scoring
For an efficient decoding, partial sequences should already score high. Then, we do not need to train future discriminators anymore.

In [2]:
# Load a model
model, tokenizer = models.load_generator()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
input_file = "../data/task1_test.json"
output_file = "../data/task1_test_decoding.json"
max_rows=5
if os.path.exists(output_file):
    original_testset = pd.read_json(input_file)
    testset = pd.read_json(output_file)
    cols_to_assert = ['context', 'constraints']
    assert_frame_equal(testset[cols_to_assert], original_testset[cols_to_assert])
else:
    testset = pd.read_json(input_file)
    testset['responses'] = [[]] * len(testset)

condition = testset['responses'].apply(len)==0
max_rows = min(max_rows, len(testset))
remaining_testset = testset[condition]
for idx, case in tqdm(remaining_testset.iterrows(), total=max_rows-(~condition).sum()):
    if idx >= max_rows: break
    prompt = helpers.get_generation_prompt(case)['prompt']
    constraints = case['constraints']

  0%|          | 0/5 [00:00<?, ?it/s]

In [4]:
print(prompt)

[INST] Write the response of A and include these grammatical items in the response:
- superlatives - FORM/USE: 'THE BEST' WITH NOUN AND PRESENT PERFECT: Can use 'the best' before a noun + present perfect to talk about a unique experience.B1
- superlatives - FORM/USE: WITH NOUN AND POSTMODFIER: Can use a postmodifier to make the superlative stronger in the structure superlative + postmodifier + noun. C1
- would - FORM/USE: AFTER 'IF' CLAUSES: Can use 'would' in the main clause of a conditional sentence to talk about an imagined situation, often in the context of advice or opinion-giving.B1
- would - USE: POLITE REQUESTS: Can use 'would' to make polite requests, often in the fixed expression 'would you mind'.B1
Dialog:
A: It was a melodrama filmed entirely in a Burbank Ikea store, without the store knowing!  Pretty wild,. You could be an employee at Ikea and not even know you are a TV star!
B: wow, that is very different. do you like dogs?
A: I love dogs. I have a GSD.  Did you know a do

In [5]:
constraints

[70, 77, 625, 633]

In [7]:
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_len = input_ids.shape[1]
p=0.9

with torch.no_grad():
    model.eval()
    for _ in range(10):
        output = model(input_ids, use_cache=True)
        prediction = output.logits[:,-1,:]

        probs = torch.softmax(prediction, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        probs[indices_to_remove] = 0
        probs /= probs.sum(dim=-1, keepdim=True)
        candidate_tokens = torch.where(probs>0)[1]
        print(len(candidate_tokens))
        candidate_sequences = torch.cat([input_ids[:,input_len:].expand(len(candidate_tokens), -1), torch.where(probs>0)[1].unsqueeze(1)], dim=-1)
        candidates = tokenizer.batch_decode(candidate_sequences)
        
        k=100
        top_k = torch.topk(prediction, k, dim=-1).indices
        #candidates = tokenizer.batch_decode(torch.cat([input_ids[:,input_len:].expand(k, -1), top_k.transpose(0,1)], dim=-1))
        #print(candidates)
        scores = evaluation.detector.score_texts(candidates, constraints)
        for nr, score in scores.items():
            if (score[0]>0.05).any().item():
                print(nr)
                print([candidates[idx] for idx in torch.where(score[0]>0.05)[0].tolist()])
        
        selected_token = top_k[:,0]
        if selected_token == 2: break
        input_ids = torch.cat((input_ids, torch.tensor([[selected_token]])), dim=1)
        text = tokenizer.batch_decode(input_ids[:,input_len:])[0]
        print(text)
        if text[-2:] == "B:":
            break 

1091


OutOfMemoryError: CUDA out of memory. Tried to allocate 820.00 MiB (GPU 0; 10.75 GiB total capacity; 8.23 GiB already allocated; 261.62 MiB free; 9.38 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

[18]

In [266]:
score[0]>0.1

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [217]:
top_k.transpose(0,1).shape

torch.Size([5, 1])

In [220]:
torch.cat([x, top_k.transpose(0,1)], dim=-1)

tensor([[28741, 28747, 28705,  3840],
        [28741, 28747, 28705,  1725],
        [28741, 28747, 28705,  6087],
        [28741, 28747, 28705,   369],
        [28741, 28747, 28705,   315]])

In [187]:
torch.cat((input_ids.unsqueeze(1).expand(-1, k, -1), top_k.unsqueeze(-1)), dim=-1)

tensor([[[    1,   733, 16289,  ..., 28760, 28747, 28747],
         [    1,   733, 16289,  ..., 28760, 28747,  4049],
         [    1,   733, 16289,  ..., 28760, 28747, 19746],
         [    1,   733, 16289,  ..., 28760, 28747, 12813],
         [    1,   733, 16289,  ..., 28760, 28747,   714]]])

In [178]:
input_ids.unsqueeze(1).expand(-1, k, -1).shape

torch.Size([1, 5, 270])