In [2]:
import torch
import transformers

In [3]:
# Beam width shared by `model.generate(num_beams=...)` and `resolve_logits_for_best_beam`.
BEAM_SIZE = 3

In [4]:

def resolve_logits_for_best_beam(outputs, num_beams):
    """Select, for every generation step, the logits rows of the returned beam.

    Args:
        outputs: Result of a beam-search generate call. Must expose `beam_indices`
            (indexed here as [batch, step]) and `logits` (one tensor per step,
            indexed here as [batch_size*num_beams, vocab]).
        num_beams (int): Beam width that was used for generation.

    Returns:
        list[Tensor]: One tensor per generation step containing only the selected
        rows, i.e. [tokens?, batch_size, vocab] for an input of
        [tokens?, batch_size*num_beams, vocab].

    Assumes num_return_sequences=1.
    """
    per_step_logits = outputs.logits
    selected = []

    for step in range(len(per_step_logits)):
        rows = []
        for batch, beam in enumerate(outputs.beam_indices[:, step].tolist()):
            # -1 entries (presumably padding after a sequence has finished -- confirm
            # against the generate docs) are replaced by the last beam slot of this
            # batch element so that the row index stays valid.
            if beam == -1:
                beam = (num_beams * (batch + 1)) - 1
            rows.append(beam)
        selected.append(per_step_logits[step][rows, :])

    return selected

In [15]:

# Single source of truth for the checkpoint, so the tokenizer and the model
# are always loaded from the same identifier.
MODEL_NAME = "google/flan-t5-small"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
# Load the seq2seq model in bfloat16 and let `device_map='auto'` decide placement.
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME, torch_dtype='bfloat16', device_map='auto'
)




In [7]:


# Prompts fed to the model as one batch (one string per batch element).
# Both prompts mention "Delhi", which is the span searched for in the generated text later on.
inputs = [
    "Translate English to French: The capital of India is New Delhi.",
    "Translate English to German: The city of Delhi is very polluted these days."
]

In [13]:
# Scratch check: character offset of "Delhi" inside the French output string,
# i.e. the value `str.index` yields for it in `get_logits_for_span`.
"La capital de l'Inde est New Delhi.".index("Delhi")

29

In [23]:



def _find_token_span(seq, subtokens):
    """Return the first index at which `subtokens` occurs contiguously in `seq`, or -1."""
    # Only start positions that leave room for the whole span are tried, so
    # `seq[start + i]` can never run past the end of `seq`.
    for start in range(len(seq) - len(subtokens) + 1):
        if all(seq[start + i] == tok for i, tok in enumerate(subtokens)):
            return start
    return -1


def get_logits_for_span(logits, sequences, tokenizer, search_spans, verbose=True):
    """ Given search spans, returns the logits at the position where each span starts.

    The span is first located in the decoded text, then re-tokenized and located
    again in the token ids; the resulting token index selects the step in `logits`.

    NOTE(review): the token index found in `sequences` is used unchanged as the step
    index into `logits`. If `sequences` begins with a decoder start token (as the
    outputs of an encoder-decoder `generate` call do), `logits[k]` is the distribution
    that produced `sequences[:, k + 1]`, so this returns the step *after* the span's
    first token. Confirm whether `idx - 1` is what the experiments actually need.

    Args:
        logits (tuple[Tensor]): Tuple of tensors, of shape [tokens?, batch_size, vocab]
        sequences (tuple[list[int]]): Tokenized output sequences.
        tokenizer (PreTrainedTokenizerBase): Tokenizer for the model.
        search_spans (list[str] | str): batch_size spans to search for. Must be present in the
            generated sequences. A single string is reused for every batch element.
        verbose (bool): Print the intermediate trace (decoded text, positions, sub-tokens).
            Defaults to True so the existing trace output is unchanged.

    Returns:
        Tensor: Tensor of shape [batch_size, vocab] holding the selected logits for each batch element.

    Raises:
        ValueError: If a span is missing from the decoded text, or if its token ids
            cannot be found contiguously in the corresponding sequence.
    """

    def trace(*args):
        if verbose:
            print(*args)

    if isinstance(search_spans, str):
        search_spans = [ search_spans ] * len(sequences)

    detok_outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    trace("detok", detok_outputs)

    # Character offset of each span in its decoded output (ValueError if absent).
    positions = [ output.index(span) for output, span in zip(detok_outputs, search_spans) ]
    trace("pos  ", positions)
    logit_pos = [  ]

    for seq, detok_seq, span, pos in zip(sequences, detok_outputs, search_spans, positions):
        subtokens = tokenizer(span, add_special_tokens=False).input_ids
        if pos == 0:
            trace("subtok if 1", subtokens)
        else:
            trace("subtok else 1", subtokens)
            # Re-tokenize the span together with the character preceding it; if the
            # stand-alone first token does not show up in that tokenization, the
            # in-context tokenization is used instead (presumably because the span
            # is encoded differently mid-sentence -- confirm for the tokenizer used).
            subtokens_2 = tokenizer(detok_seq[pos-1] + span, add_special_tokens=False).input_ids
            trace("subtok else 2", subtokens_2)
            if subtokens[0] not in subtokens_2:
                subtokens = subtokens_2

        trace("subtok effective", subtokens)
        trace("seq::", seq)

        idx = _find_token_span(seq, subtokens)
        if idx < 0:
            # Previously this fell through with idx == len(seq) (or raised IndexError),
            # which could silently select an unrelated generation step.
            raise ValueError(
                f"Could not locate the tokens {subtokens} for span {span!r} in the generated sequence."
            )
        trace("found", idx)
        logit_pos.append(idx)
        trace("subtok", subtokens, "idx", idx, "logit_pos", logit_pos)

    return torch.stack([ logits[token][batch,:] for batch, token in enumerate(logit_pos) ])


In [24]:
# TODO: confirm that inference_mode (rather than e.g. no_grad) is what we want here.
with torch.inference_mode():
    inputs_t = tokenizer(inputs, padding='longest', return_tensors='pt')
    outputs = model.generate(
        **inputs_t.to(model.device),
        max_new_tokens=20,
        # Deterministic beam search: sampling-only knobs such as `temperature`
        # are not passed because they have no effect when do_sample=False.
        do_sample=False,
        output_scores=True, # must be true since beam indices are needed
        output_logits=True,
        return_dict_in_generate=True,
        num_beams=BEAM_SIZE,
        num_return_sequences=1
    )
    # Scores were only requested to obtain beam indices; drop them, and move
    # what is needed afterwards to the CPU.
    del outputs.scores
    outputs.beam_indices = outputs.beam_indices.cpu()
    outputs.logits = tuple(logits.cpu() for logits in outputs.logits)


print("model output y_hat:", tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True))

# To get logits for the generation before Delhi:

# 1. Resolve logits for the best beam from these inputs
#    (would be a tuple of tensors of shape [tokens?, batch_size, vocab])
best_logits = resolve_logits_for_best_beam(outputs, num_beams=BEAM_SIZE)

# 2. Get the logits before New Delhi
#    (would be a tensor of shape [batch_size, vocab])
start_logits = get_logits_for_span(best_logits, outputs.sequences, tokenizer, [ "Delhi", "Delhi" ])

# 3. Use these for experiments ...




model output y_hat: ["La capital de l'Inde est New Delhi.", 'Das Stadt von Delhi ist sehr polluted diesen Tagen.']
detok ["La capital de l'Inde est New Delhi.", 'Das Stadt von Delhi ist sehr polluted diesen Tagen.']
pos   [29, 14]
subtok else 1 [10619]
subtok else 2 [10619]
subtok effective [10619]
seq:: tensor([    0,   325,  1784,    20,     3,    40,    31, 26267,   259,   368,
        10619,     5,     1,     0,     0], device='cuda:7')
found 10
subtok [10619] idx 10 logit_pos [10]
subtok else 1 [10619]
subtok else 2 [10619]
subtok effective [10619]
seq:: tensor([    0,   644,  3287,   193, 10619,   229,  1319,  5492,  2810,    26,
            3,  5162, 13657,     5,     1], device='cuda:7')
found 4
subtok [10619] idx 4 logit_pos [10, 4]


In [10]:
# Dump the per-step best-beam logits and the logits selected for the "Delhi" span.
# NOTE(review): this prints every tensor in full; consider printing shapes only.
print("Best logits:", best_logits)
print("Start logits:", start_logits)

Best logits: [tensor([[-43.5000,   0.1719,  -4.0625,  ..., -43.5000, -43.5000, -43.5000],
        [-38.0000,   1.6719,  -2.7188,  ..., -37.7500, -38.0000, -37.7500]],
       dtype=torch.bfloat16), tensor([[-35.2500,   0.1602,  -2.2500,  ..., -35.0000, -35.2500, -35.0000],
        [-38.0000,   0.2656,  -2.1406,  ..., -38.0000, -38.0000, -37.7500]],
       dtype=torch.bfloat16), tensor([[-46.7500,  -2.7969,  -4.4688,  ..., -46.5000, -46.5000, -46.2500],
        [-40.2500,  -2.2969,  -2.4688,  ..., -40.2500, -40.2500, -40.0000]],
       dtype=torch.bfloat16), tensor([[-59.7500,  -6.6250,  -8.1250,  ..., -59.7500, -59.7500, -59.5000],
        [-43.0000,  -2.4062,  -4.4062,  ..., -43.0000, -43.0000, -42.7500]],
       dtype=torch.bfloat16), tensor([[-60.7500,  -9.5625,  -1.4297,  ..., -60.7500, -60.7500, -60.5000],
        [-37.7500,   0.2500,  -1.2656,  ..., -37.7500, -37.7500, -37.5000]],
       dtype=torch.bfloat16), tensor([[-61.0000,  -9.4375,  -5.2812,  ..., -61.0000, -61.0000, -60.75