In [2]:
import pandas as pd
from nnsight import LanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [3]:
genaudit_read_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/Genaudit_annotations.csv'
df_genaudit = pd.read_csv(genaudit_read_path)
df_genaudit.head()[:1]

Unnamed: 0.1,Unnamed: 0,id,source,summary,annotated_spans,model,origin
0,0,Rachel Usher#REDDIT-83:flanul2-ul2,both me and my girlfriend participate in winte...,I waited outside the locker rooms for my girlf...,nobody,flanul2,REDDIT


In [3]:
model_name = 'mistral7b'
df_genaudit_mistral = df_genaudit[df_genaudit['model'] == model_name]
df_genaudit_mistral.head()[:1]

Unnamed: 0.1,Unnamed: 0,id,source,summary,annotated_spans,model,origin
4,4,Rachel Usher#REDDIT-83:mistral7b-ul2,both me and my girlfriend participate in winte...,The document describes the experience of a hig...,participates<sep>with their girlfriend<sep>ends,mistral7b,REDDIT


In [6]:
print(set(df_genaudit['model']))

{'falcon7b', 'flanul2', 'mistral7b', 'llama70b', 'chatgpt', 'geminipro', 'llama7b', 'gpt4'}


In [5]:

model_path = {'mistral7b': 'mistralai/Mistral-7B-Instruct-v0.1',
             'falcon7b': 'tiiuae/falcon-7b-instruct',
             'llama7b': '/work/frink/models/Llama-2-7b-chat-hf',
             'flanul2': 'google/flan-ul2'}

def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_path[model_name],
                                         cache_dir = '/scratch/ramprasad.sa/huggingface_models')

    model = AutoModelForCausalLM.from_pretrained(model_path[model_name],
                                            cache_dir = '/scratch/ramprasad.sa/huggingface_models')
    # model = model.to('cuda')
    return tokenizer, model


tokenizer, mistral_model = load_model(model_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = LanguageModel(mistral_model, tokenizer=tokenizer, device_map="auto", dispatch=True)
model



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [118]:
import json 
prompt_template = f'{{instruction}}{{prompt_prefix}}{{source}}{{prompt_suffix}}{{summary}}'

prompt_template_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/prompt_templates/'
prompt_type = 'document_context_causal'

with open(f'{prompt_template_path}/{prompt_type}.json', 'r') as fp:
    prompt_dict = json.load(fp)
prompt_dict, prompt_template

({'instruction': 'Generate a summary for the following document in brief. When creating the summary, only use information that is present in the document',
  'prompt_prefix_template': ' CONTENT: ',
  'prompt_suffix_template': ' SUMMARY: '},
 '{instruction}{prompt_prefix}{source}{prompt_suffix}{summary}')

In [6]:
import json


def strip_bos_eos_ids(ids):
    ids = ids[:, 1:] if ids[0][0] in [0,1,2] else ids
    return ids
    
class PromptProcessor():

    def __init__(self,
                 prompt_template,
                 prompt_dict_path,
                prompt_type):

        self.prompt_template = prompt_template

        with open(f'{prompt_dict_path}/{prompt_type}.json', 'r') as fp:
            self.prompt_dict = json.load(fp)
            
        self.instruction = self.prompt_dict['instruction'] if 'instruction' in self.prompt_dict else ''
        self.prefix = self.prompt_dict['prompt_prefix_template'] if 'prompt_prefix_template' in self.prompt_dict else ''
        self.suffix = self.prompt_dict['prompt_suffix_template'] if 'prompt_suffix_template' in self.prompt_dict else ''

    def make_prompt(self,
                    source,
                    summary):
        
        prompt = self.prompt_template.format(instruction = self.instruction,
                                    prompt_prefix = self.prefix,
                                    source = source,
                                    prompt_suffix = self.suffix,
                                    summary = summary
                                   )
        return prompt


    def get_prompt_attributes_idx(self,
                                  prompt_ids,
                                  source,
                                  summary,
                                  tokenizer):
        instr_idx = -1
        instr_prefix_idx = -1
        instr_prefix_src_idx = -1
        instr_prefix_src_suffix_idx = -1

        for span_idx in range( len(prompt_ids)):
        
            span_tokens = prompt_ids[1:span_idx]
            if tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                            prompt_prefix = '',
                                            source = '',
                                            prompt_suffix = '',
                                            summary = ''
                                           ).strip():
                instr_idx = span_idx
            
            if tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = '',
                                        prompt_suffix = '',
                                        summary = ''
                                       ).strip():
                instr_prefix_idx = span_idx
            
            if tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = source,
                                        prompt_suffix = '',
                                        summary = ''
                                       ).strip():
                instr_prefix_src_idx = span_idx
            
            if tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = source,
                                        prompt_suffix = self.suffix,
                                        summary = ''
                                       ).strip():
                instr_prefix_src_suffix_idx = span_idx 

        return instr_idx, instr_prefix_idx, instr_prefix_src_idx, instr_prefix_src_suffix_idx

    

In [7]:
import json 
prompt_template = f'{{instruction}}{{prompt_prefix}}{{source}}{{prompt_suffix}}{{summary}}'

prompt_dict_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/prompt_templates/'
prompt_type = 'document_context_causal'

prompt_processor = PromptProcessor(prompt_template = prompt_template,
                 prompt_dict_path = prompt_dict_path,
                prompt_type = prompt_type)

In [8]:

counter = 0

for idx, row in df_genaudit_mistral[~df_genaudit_mistral['annotated_spans'].isnull()].iterrows():
    source = row['source']
    summary = row['summary']
    nonfactual_spans = row['annotated_spans']
    prompt = prompt_processor.make_prompt(source= source,
                                 summary = summary)
    if counter == 4:
        break
    counter += 1

In [9]:
# print(prompt)
nonfactual_spans

'peroxide<sep>earrings overnight<sep>mild'

In [34]:
'''
Get non corrupted hidden states and save --> clean hstates
Corrupt instruction embeddings --> save corrupted hidden states 
For all summary tokens as tgt
    For all layers
        For all clean tokens 
            Replace corrupted clean_token_layer with clean hstates 
            Get p(tgt_token) with corrupted 
            Get p(tgt_token) with corrupted and replaced 
            Accumulate difference --> {t1: {l1: , l2: ,,,}}
'''
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

# util.apply(comparator, lambda x: x.value.item(), Proxy)


source = source
summary  = summary
N_LAYERS = len(model.model.layers)

prompt = prompt

prompt_tokens = tokenizer(prompt).input_ids
instr_idx, instr_prefix_idx, instr_prefix_src_idx, instr_prefix_src_suffix_idx = prompt_processor.get_prompt_attributes_idx(prompt_ids = tokenizer(prompt).input_ids,
                              source = source,
                              summary = summary,
                              tokenizer = tokenizer)


with model.trace() as tracer:
    ''' clean run ''' 
    with tracer.invoke(prompt) as invoker:
        clean_input_embeddings = model.model.embed_tokens.output.save()
        
        clean_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]
        clean_logits = model.lm_head.output[0]
        clean_logits = F.softmax(clean_logits, dim = 1)

    # ''' corrupted run '''
    with tracer.invoke(prompt) as invoker:
        init_noise = torch.zeros(clean_input_embeddings.shape)
    #     ''' only noise the instruction tokens'''
        init_noise[:, :instr_idx, :] = (0.1**0.5)*torch.randn(init_noise[:, :instr_idx, :].shape)
        
        model.model.embed_tokens.output = clean_input_embeddings + init_noise
        corrupted_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]
        corrupted_logits = model.lm_head.output[0]
        corrupted_logits = F.softmax(corrupted_logits, dim = 1).save()
        noised_embeddings = model.model.embed_tokens.output.save() 

    comparator = []
    # for token_idx in range(len(prompt_tokens)):
    #     if token_idx >= instr_prefix_src_suffix_idx:
    #         p_token_idx = clean_logits[token_idx - 1]
    #         p_token_idx_corr 
    #         comparator.append(p_token[prompt_tokens[token_idx]].save())

    for tidx in [470, 471, 475, 485, 486, 487, 492, 493, 494][:1]:

        if tidx >= instr_prefix_src_suffix_idx:

            prob_clean_tidx = clean_logits[tidx - 1]
            prob_corr_tidx = corrupted_logits[tidx - 1]
            pred_token = torch.argmax(prob_clean_tidx).item()
            tgt_token = prompt_tokens[tidx]

            layerwise_patching_results = []
            for layer_idx in range(len(model.model.layers))[:1]:
                # layerwise_patching_results = []
                
                with tracer.invoke(prompt) as invoker:
                        model.model.embed_tokens.output = noised_embeddings
                        model.model.layers[layer_idx].output[0].t[tidx - 1] = clean_hs[layer_idx].t[tidx - 1]
                        corrupted_repl_logits = model.lm_head.output[0]
                        corrupted_repl_logits = F.softmax(corrupted_repl_logits, dim = 1).save()
                        
                        # assert(corrupted_repl_logits
                        prob_corr_repl_tidx = corrupted_repl_logits[tidx - 1]

                        assert(corrupted_repl_logits[tidx - 1] == corrupted_logits[tidx - 1])
                        layerwise_patching_results.append((layer_idx,
                                                           tidx - 1,
                                                           tgt_token, 
                                                           prob_clean_tidx[tgt_token].save(), 
                                                           prob_corr_tidx[tgt_token].save(),
                                                           prob_corr_repl_tidx[tgt_token].save()))
                        
            
                
            comparator.append(layerwise_patching_results)

        
    
                        
                        

In [44]:
corrupted_repl_logits[468] == corrupted_logits[468]

tensor([True, True, True,  ..., True, True, True])

In [45]:
corrupted_repl_logits[469] == corrupted_logits[469]

tensor([False, False, False,  ..., False, False, False])

In [40]:
comparator

[[(0,
   469,
   8120,
   tensor(0.8350, grad_fn=<SelectBackward0>),
   tensor(0.9613, grad_fn=<SelectBackward0>),
   tensor(0.9611, grad_fn=<SelectBackward0>))]]

In [444]:
torch.argmax(clean_logits[instr_prefix_src_suffix_idx: ])

pred_tokens = [torch.argmax(logit).item() for logit in clean_logits[instr_prefix_src_suffix_idx -1: -1]]
summary_tokens = prompt_ids[instr_prefix_src_suffix_idx:]

assert(len(pred_tokens) == len(summary_tokens))

In [448]:
instr_prefix_src_suffix_idx

100

In [452]:
torch.max(clean_logits[instr_prefix_src_suffix_idx -1 + 0]), summary_tokens[0]

(tensor(14.5398, grad_fn=<MaxBackward1>), 415)

In [446]:

counter = 0
for pred, summ in list(zip(pred_tokens, summary_tokens)):
    
    if pred == summ:
        print(counter)
    else:
        print(tokenizer.decode([pred, summ]))
    counter += 1

0
person document
2
a the
4
5
6
experience confusion
8
the a
10
11
12
13
14
frustration experience
16
trying spending
18
trying at
20
21
22
23
24
find open
26
. and
realizing ultimately
29
30
correct test
was had
33
passed taken
35
on.


In [447]:
list(zip(pred_tokens, summary_tokens))

[(415, 415),
 (1338, 3248),
 (13966, 13966),
 (264, 272),
 (3227, 3227),
 (28742, 28742),
 (28713, 28713),
 (2659, 16630),
 (684, 684),
 (272, 264),
 (1369, 1369),
 (3608, 3608),
 (304, 304),
 (652, 652),
 (14235, 14235),
 (22802, 2659),
 (302, 302),
 (2942, 9981),
 (727, 727),
 (2942, 438),
 (264, 264),
 (1486, 1486),
 (2052, 2052),
 (2942, 2942),
 (298, 298),
 (1300, 1565),
 (9289, 9289),
 (28723, 304),
 (27494, 12665),
 (27494, 27494),
 (272, 272),
 (4714, 1369),
 (403, 553),
 (2141, 2141),
 (4568, 3214),
 (1633, 1633),
 (356, 28723)]

In [426]:
prompt_ids = tokenizer(prompt).input_ids
prompt_ids[instr_prefix_src_suffix_idx:]

[415,
 3248,
 13966,
 272,
 3227,
 28742,
 28713,
 16630,
 684,
 264,
 1369,
 3608,
 304,
 652,
 14235,
 2659,
 302,
 9981,
 727,
 438,
 264,
 1486,
 2052,
 2942,
 298,
 1565,
 9289,
 304,
 12665,
 27494,
 272,
 1369,
 553,
 2141,
 3214,
 1633,
 28723]

In [368]:
corrupted_repl_hs[13][0][:, repl_token_idx, :]

tensor([[ 0.0818, -0.1291, -0.0457,  ...,  0.0578, -0.0109,  0.0048]],
       grad_fn=<SliceBackward0>)

In [396]:
# corrupted_hs[0] == corrupted_repl_hs[0]
for layer_idx in range(0, 32):
    if layer_idx == 9:
        differing = corrupted_hs[layer_idx] == corrupted_repl_hs[layer_idx]
        # print(corrupted_hs[layer_idx] == corrupted_repl_hs[layer_idx])
    print(layer_idx, torch.allclose(corrupted_hs[layer_idx], corrupted_repl_hs[layer_idx]))

0 True
1 True
2 True
3 True
4 True
5 True
6 True
7 True
8 True
9 False
10 False
11 False
12 False
13 False
14 False
15 False
16 False
17 False
18 False
19 False
20 False
21 False
22 False
23 False
24 False
25 False
26 False
27 False
28 False
29 False
30 False
31 False


In [410]:
# import torch

# # Example boolean tensor
# boolean_tensor = torch.tensor([[True, False, True], [False, True, False]])

# # Find indices of False values
# false_indices = (boolean_tensor == False).nonzero()

# print("Indices of False values:")
# print(false_indices)

false_indices = (differing == False).nonzero()
false_indices

tensor([[   0,    4,    0],
        [   0,    4,    1],
        [   0,    4,    2],
        ...,
        [   0,    4, 4093],
        [   0,    4, 4094],
        [   0,    4, 4095]])

In [411]:
differing.shape

torch.Size([1, 137, 4096])

In [409]:
'''
Save original output
Corrupt input embeddings 
Save corrupted output 

For each layer and token replace corrupted with original
'''
import torch

N_LAYERS = 32
with model.trace() as tracer:
    #### clean run ####
    with tracer.invoke(prompt) as invoker:
        clean_tokens = model.input[1]["input_ids"].squeeze().save()
        embeddings = model.model.embed_tokens.output.save()
        clean_hs = [
            model.model.layers[layer_idx].output[0].save()
            for layer_idx in range(N_LAYERS)
        ]
        clean_logits = model.lm_head.output.save()
        
    #### corrupted run #### 
    with tracer.invoke(prompt): 
        # Noise the input embeddings for instructions only 
        noise = torch.zeros(embeddings.shape)
        noise[:, :prompt_instruction_ids.shape[-1], :] = (0.1**0.5)*torch.randn(noise[:, :prompt_instruction_ids.shape[-1], :].shape)
        
        corrupted_logits = model.lm_head.output.save()

        
        
        ''' Test output differences '''
        # model.model.embed_tokens.output = embeddings + noise
        
        # corrupted_output = [
        #     model.model.layers[layer_idx].output[0].save()
        #     for layer_idx in range(N_LAYERS)
        # ]
        # corrupted_logits = model.lm_head.output.save()

    # #### corrupted with replacement 
    # for tgt_token_idx in range(1, len(clean_tokens)):
    #     tgt_token = clean_tokens[tgt_token_idx]
    #     p_tgt_token_
    #     p_corrupted_tgt_token = 

        
    #     for layer_idx in range(len(model.model.layers)):
    #         for token_idx in range(len(clean_tokens)):
                
            
    # for token_idx in range(len(clean_tokens)):
    #     for layer_idx in range(len(model.model.layers)):
        
    #         #### Making sure the predicted token is not beyond our summary/prompt
    #         tgt_idx = token_idx + 1 
    #         if tgt_idx < len(clean_tokens):
    #             tgt_token = clean_tokens[tgt_idx]

    #             model.model.layers[layer_idx].output[0].t[token_idx] = 
    


KeyboardInterrupt



In [186]:
clean_tokens

tensor([    1, 26075,   264, 14060,   354,   272,  2296,  3248,   297,  6817,
        28723,  1684,  6818,   272, 14060, 28725,   865,   938,  1871,   369,
          349,  2169,   297,   272,  3248,  4192, 28738,  2431, 28747,   613,
         1654,   586,   960,  1369,   403,  3154,   568,   378,   403, 12091,
          568,   613,  5223,   264,  3102,   302, 28705, 28770, 28734,  3486,
          442,   579,  7312,  1401,   264,  1486, 19185,  2942,   298,  1565,
         1012,  2692,  2251,   568,   613,  4251,  9681,   586,  1369, 13490,
          304,  7185,   272,  3608,   773,  2495,   324,  1466,   461,  1802,
        28705, 28740, 28770,   568,   708,  3383,   708,   624,  1112,   403,
          297,   272, 12128,  2055, 28723,   318,  4171, 28755, 10713, 28747,
          415,  3248, 13966,   272,  3227, 28742, 28713, 16630,   684,   264,
         1369,  3608,   304,   652, 14235,  2659,   302,  9981,   727,   438,
          264,  1486,  2052,  2942,   298,  1565,  9289,   304, 

In [191]:
instr_idx, instr_prefix_idx, instr_prefix_src_idx, instr_prefix_src_suffix_idx = prompt_processor.get_prompt_attributes_idx(prompt_ids = clean_tokens,
                          source = source,
                          summary = summary,
                          tokenizer = tokenizer)

summary_tokens = clean_tokens[instr_prefix_src_suffix_idx:]
assert(tokenizer.decode(summary_tokens) == summary)

In [190]:
summary

"The document describes the author's confusion about a test date and their subsequent experience of spending time at a high school trying to open doors and ultimately realizing the test had already taken place."

In [64]:
layer_repl = 13
token_repl = 26

with model.trace() as tracer:
    with tracer.invoke(prompt): 
        model.model.embed_tokens.output = embeddings + noise
        model.model.layers[layer_repl].output[0].t[token_repl] = clean_hs[layer_repl].t[token_repl]
        corrupted_replaced_logits = model.lm_head.output.save()

In [91]:
tokenizer.decode(clean_tokens), source_len

NameError: name 'source_len' is not defined

In [73]:
corrupted_summ_logits = corrupted_replaced_logits[:,-summ_len - 1:,:]

# clean_summary_logits.squeeze(0).shape
tokens = [torch.argmax(each).item() for each in corrupted_summ_logits.squeeze(0)]
tokenizer.decode(tokens)

"i test is a protagon's experience and the test date and their subsequent realization of walking time trying a high school trying to find doors. finding realizing the test date already passed place. The"

In [74]:
corrupted_summ_logits = corrupted_logits[:,-summ_len - 1:,:]

# clean_summary_logits.squeeze(0).shape
tokens = [torch.argmax(each).item() for each in corrupted_summ_logits.squeeze(0)]
tokenizer.decode(tokens)

"i test is the protagon's experience and the test date and their subsequent realization of walking time trying a high school trying to find doors. finding realizing the test date already passed place. The"

In [41]:
clean_summary_logits = clean_logits[:,-summ_len - 1:,:]

# clean_summary_logits.squeeze(0).shape
tokens = [torch.argmax(each).item() for each in clean_summary_logits.squeeze(0)]
tokenizer.decode(tokens)
# tokens = [torch.argmax(clean_logits[:, token_idx, :]).item() for token_idx in range(clean_logits.shape[1])]

"The author describes the author's confusion about the test date and their subsequent attempt of trying time trying a high school trying to find doors until realizing realizing the error was already passed place on</s>"

In [40]:
summary

"The document describes the author's confusion about a test date and their subsequent experience of spending time at a high school trying to open doors and ultimately realizing the test had already taken place."

In [33]:
input_ids = tokenizer(prompt,
          return_tensors="pt").input_ids
mistral_model_output = mistral_model.generate(input_ids,
                      max_length=input_ids.shape[-1] + 1, 
                        output_attentions = True,
                        output_hidden_states=True, 
                        return_dict_in_generate=True)




The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [34]:
summ_len

37

In [35]:
import torch

hstates = torch.cat(mistral_model_output['hidden_states'][0])
summary_hstates = hstates[:, -summ_len-1:, :]
logits = [mistral_model.lm_head(hstate) for hstate in summary_hstates[-1]]
# hstates[-1].shape

In [36]:
tokens = [torch.argmax(logit).item() for logit in logits]
tokenizer.decode(tokens)

"The author describes the author's confusion about the test date and their subsequent attempt of trying time trying a high school trying to find doors until realizing realizing the error was already passed place on</s>"

In [37]:
summary

"The document describes the author's confusion about a test date and their subsequent experience of spending time at a high school trying to open doors and ultimately realizing the test had already taken place."

In [178]:
tokenizer.decode(tokens)

"Q a list of the following story: .\n the a summary, consider include the that is relevant in the document and\n: https the and my friend are in a sports\n the local school. we am and she playsences\n\n we the great start, both andi:)) and i was cold cold that have follow in in the formal, i girlfriend was a black meet today and so she practices out at 3:30pm she practice is at 3:05..\n is to dress home and she can for for locker room for her to she can help her my ride beforebye before she do did't know is that she girlfriend is at she amm waiting outside the locker rooms waiting i late senior walkresser, takes not no for  bus minute.. i i gets out she give her a kiss goodbye and i to.. i is a wrong no there nothing me bus soundfall.. it i wased and start walking my to i girlfriend who i' a off at and i call my mom and my one.. i mom are i brother are out late i one.. them either i i'm freak  bus for my 20 walk home in thezing weather.. the swim swim and and jeans jacket.. hat.. i i l

In [179]:
summary

'The document describes the experience of a high school student who participates in winter sports with their girlfriend.. Today, the student has a late meet and must dress up for it.. Meanwhile, their girlfriend has a normal practice that ends at 3:15 PM.. The student waits outside the locker rooms to give their girlfriend a goodbye kiss before they leave for their respective activities.. However, the student realizes that their bus leaves while they are waiting and they miss it.. As a result, the student has to walk six miles in freezing temperatures to get home, wearing only a dress shirt and no gloves or hat.. The document ends with the student wishing for luck.'

In [173]:
clean_logits[:, 0, :].shape

torch.Size([1, 32000])