In [8]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch 
from torch.nn.utils.rnn import pad_sequence
# Load pre-trained MBART model and tokenizer (MBART-50 for multilingual tasks)
model_name = "../mbart_model"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Example batch of input sentences in various languages
batch_sentences = [
    "Hello, how are you?",   # English
    "What have you been up to recently?", # French
    "Do you want to go for a run?",    # Spanish
]



# Tokenize the input batch of sentences
inputs = tokenizer(batch_sentences, return_tensors="pt", padding=True, truncation=True)

# Generate translations (for example, to French) or any other target language
# Specify the target language for the model to generate in
forced_bos_token_id = tokenizer.lang_code_to_id["zh_CN"]

# Perform inference with the model to generate translations
outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], 
                         forced_bos_token_id=forced_bos_token_id)

predictions = [] 
for i in range(len(outputs)): 
    predictions.append(outputs[i, :])

pad_tensor = torch.ones(200-len(predictions[0]))
predictions[0] = torch.cat((predictions[0],pad_tensor.long()),dim = 0)
predictions = pad_sequence(predictions,batch_first=True,padding_value=1)

# Decode the generated outputs back to text
translated_sentences = tokenizer.batch_decode(predictions, skip_special_tokens=True)

# Print the generated translations
for i, translation in enumerate(translated_sentences):
    print(f"Original: {batch_sentences[i]}")
    print(f"Translated: {translation}")
    print()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Original: Hello, how are you?
Translated: 你好,你好吗?

Original: What have you been up to recently?
Translated: 你最近做了些什么?

Original: Do you want to go for a run?
Translated: 你想跑吗?



## Testing LLM adaptor 2 shape


In [27]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# Define a dummy PAD_IDX for padding purposes
PAD_IDX = 0

class LLMAdapter2(nn.Module):
    '''
    LLM adapter aims to capture temporal relations and transform 32 tokens into 1024 tokens.
    This version introduces an additional projection layer between the two convolution layers.
    '''
    def __init__(self, num_tokens=32, hidden_dim=1024, kernel_size=5):
        super(LLMAdapter2, self).__init__()
        
        # Store parameters
        self.num_tokens = num_tokens
        self.hidden_dim = hidden_dim
        
        # First projection from input tokens to hidden_dim/2
        self.proj = nn.Linear(self.num_tokens, self.hidden_dim // 2)

        # First convolutional block
        self.conv_block_1 = nn.Sequential(
            nn.Conv1d(self.hidden_dim // 2, self.hidden_dim // 2, kernel_size=kernel_size, stride=1, padding=0),
            nn.BatchNorm1d(self.hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.AvgPool1d(kernel_size=2, ceil_mode=False)
        )

        # New projection layer between convolution layers
        self.intermediate_proj = nn.Linear(self.hidden_dim // 2, self.hidden_dim)

        # Second convolutional block
        self.conv_block_2 = nn.Sequential(
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=kernel_size, stride=1, padding=0),
            nn.BatchNorm1d(self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.AvgPool1d(kernel_size=2, ceil_mode=False)
        )


    def forward(self, x, src_length):
        # Input shape: (batch_size, num_frames, num_tokens)
        
        # Split the input into individual batches according to src_length
        start = 0
        x_batch = []
        for length in src_length:
            end = start + length
            x_batch.append(x[start:end])
            start = end
        
        # Pad sequences to ensure uniform batch sizes
        x = pad_sequence(x_batch, padding_value=PAD_IDX, batch_first=True)
        print(f"After padding: {x.shape}")  # Check padding result (batch_size, num_frames, num_tokens)
        
        # Apply the initial projection layer
        x = self.proj(x)  # Shape: (batch_size, num_frames, hidden_dim / 2)
        print(f"After initial projection: {x.shape}")  # Should be (batch_size, num_frames, 512)
        
        # Permute to (batch_size, hidden_dim / 2, num_frames) for Conv1d
        x = x.permute(0, 2, 1)
        print(f"After permute (before first conv): {x.shape}")  # Should be (batch_size, 512, num_frames)
        
        # First convolutional block
        x = self.conv_block_1(x)  # Shape: (batch_size, hidden_dim / 2, reduced_num_frames)
        print(f"After first conv block: {x.shape}")  # Check after first conv
        
        # Apply the intermediate projection layer
        x = x.permute(0, 2, 1)  # Back to (batch_size, reduced_num_frames, hidden_dim / 2)
        x = self.intermediate_proj(x)  # Shape: (batch_size, reduced_num_frames, hidden_dim)
        print(f"After intermediate projection: {x.shape}")  # Should be (batch_size, reduced_num_frames, 1024)
        x = x.permute(0, 2, 1)  # Back to (batch_size, hidden_dim, reduced_num_frames)
        print(f"After permute (before second conv): {x.shape}")  # Check before second conv
        
        # Second convolutional block
        x = self.conv_block_2(x)  # Shape: (batch_size, hidden_dim, further_reduced_num_frames)
        print(f"After second conv block: {x.shape}")  # Check after second conv
        
        # Convert back to (batch_size, further_reduced_num_frames, hidden_dim)
        x = x.permute(0, 2, 1)
        print(f"Final output shape: {x.shape}")  # Should be (batch_size, further_reduced_num_frames, 1024)

        return x

# Create an instance of LLMAdapter2
model = LLMAdapter2()

# Test input
batch_size = 10
num_frames = 50  # Let's assume each sequence has 15 frames
num_tokens = 32  # As specified in the model

# Random test tensor simulating a batch of 10 sequences, each with 15 frames and 32 tokens
test_input = torch.rand((batch_size * num_frames, num_tokens))

# Source lengths for each batch (assuming all sequences have 15 frames)
src_length = torch.tensor([num_frames] * batch_size)

# Forward pass
output = model(test_input, src_length)

After padding: torch.Size([10, 50, 32])
After initial projection: torch.Size([10, 50, 512])
After permute (before first conv): torch.Size([10, 512, 50])
After first conv block: torch.Size([10, 512, 23])
After intermediate projection: torch.Size([10, 23, 1024])
After permute (before second conv): torch.Size([10, 1024, 23])
After second conv block: torch.Size([10, 1024, 9])
Final output shape: torch.Size([10, 9, 1024])


## Testing LLM adaptor 3 shape

In [32]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# Define a dummy PAD_IDX for this example
PAD_IDX = 0

class LLMAdapter3(nn.Module):
    def __init__(self, num_tokens=32, hidden_dim=1024, kernel_size=5):
        super(LLMAdapter3, self).__init__()
        self.num_tokens = num_tokens
        self.hidden_dim = hidden_dim
        
        # Temporal convolution over the time dimension
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(self.num_tokens, self.num_tokens * 2, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm1d(self.num_tokens * 2),  # Channels must match Conv1d output channels
            nn.ReLU(inplace=True),
            # Reduce kernel size for pooling to avoid sequence collapse
            nn.AvgPool1d(kernel_size=1, ceil_mode=False),  

            nn.Conv1d(self.num_tokens * 2, self.num_tokens * 4, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm1d(self.num_tokens * 4),  # Channels must match Conv1d output channels
            nn.ReLU(inplace=True),
            nn.AvgPool1d(kernel_size=1, ceil_mode=False)  # Adjusted pooling to avoid reducing size to zero
        )
        
        # Final projection layer
        self.final_proj = nn.Sequential(
            nn.Linear(self.num_tokens * 4, self.hidden_dim)
        )
        self.out = nn.Sequential(nn.BatchNorm1d(self.hidden_dim),
            nn.ReLU(inplace=True))

    def forward(self, x, src_length):
        start = 0
        x_batch = []
        for length in src_length:
            end = start + length
            x_batch.append(x[start:end])
            start = end
        print(f"Before padding: {x.shape}") 
        x = pad_sequence(x_batch, padding_value=PAD_IDX, batch_first=True)
        print(f"After padding: {x.shape}")  # Print shape after padding
        
        # Permute to match Conv1d expected shape: (batch_size, channels, sequence_length)
        x = x.permute(0, 2, 1)
        print(f"After permute: {x.shape}")  # Shape should now be (batch_size, num_tokens, num_frames)
        
        # Apply temporal convolution
        x = self.temporal_conv(x)
        print(f"After temporal_conv: {x.shape}")  # Check shape after convolution
        
        # Permute back to (batch_size, sequence_length, hidden_dim)
        x = x.permute(0, 2, 1)
        print(f"After second permute: {x.shape}")  # Shape should be (batch_size, num_frames, num_tokens*4)
        
        # Apply final projection (we need to flatten or reshape input to match Linear input requirements)
        batch_size, seq_len, hidden_dim = x.shape
        x = self.final_proj(x)
        #x = self.final_proj(x.reshape(batch_size * seq_len, hidden_dim))
        print(f"After final_proj: {x.shape}")  # Check final shape

        print(f"before out shape : {x.shape}")
        x = self.out(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x

# Create an instance of LLMAdapter3
model = LLMAdapter3()

# Test input
batch_size = 10
num_frames = 50 # Let's assume each sequence has 15 frames
num_tokens = 32  # As specified in the model

# Random test tensor simulating a batch of 10 sequences, each with 15 frames and 32 tokens
test_input = torch.rand((batch_size * num_frames, num_tokens))

# Source lengths for each batch (assuming all sequences have 15 frames)
src_length = torch.tensor([num_frames] * batch_size)

# Forward pass
output = model(test_input, src_length)

Before padding: torch.Size([500, 32])
After padding: torch.Size([10, 50, 32])
After permute: torch.Size([10, 32, 50])
After temporal_conv: torch.Size([10, 128, 42])
After second permute: torch.Size([10, 42, 128])
After final_proj: torch.Size([10, 42, 1024])
before out shape : torch.Size([10, 42, 1024])


### Test some generation

In [1]:
import torch 
import torch.nn as nn
from omegaconf import OmegaConf
import torch.distributed
from train_sign_utils import * 
import os 
from accelerate import Accelerator
from logger import setup_logger
from accelerate.utils import set_seed
import sys 
from transformers import MBart50Tokenizer
import torch.multiprocessing as mp

  from .autonotebook import tqdm as notebook_tqdm


attention mode is flash


In [2]:
## Take in a configuration 
import train_sign_utils
import signdata
import seq_model
from imp import reload
reload(seq_model)
reload(signdata)
reload(train_sign_utils)
from train_sign_utils import create_model, create_signloader




config = OmegaConf.load("./configs/Sign2Text_CSL_config_v3.yaml")


output_dir = config.experiment.output_dir

os.makedirs(output_dir, exist_ok=True)
config.experiment.logging_dir = os.path.join(output_dir, "logs")
# Load the model 
accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        project_dir=config.experiment.logging_dir,
        split_batches=False
    )


logger = setup_logger(name="Sign2Text", log_level="INFO",
    output_file=f"{output_dir}/log{accelerator.process_index}.txt")

# We need to initialize the trackers we use, and also store our configuration.

# If passed along, set the training seed now.
if config.training.seed is not None:
    set_seed(config.training.seed, device_specific=True)

# Create model 
model, ema_model = create_model(config, logger, accelerator)
# Create signloaders 
tokenizer = MBart50Tokenizer.from_pretrained(config.training.tokenizer,
                                            src_lang=config.dataset.lang,
                                              tgt_lang= config.dataset.lang)
train_dataloader, dev_dataloader, test_dataloader = create_signloader(config, logger, accelerator, tokenizer, "cpu")





[32m[10/24 14:13:21 Sign2Text]: [0mCreating model and loss module.


  from imp import reload
dataloader_config = DataLoaderConfiguration(split_batches=False)


Titok weights loaded successfully from: TiTok_weights/ema_model/pytorch_model.bin
TiTok weights are frozen chowwww!
[32m[10/24 14:13:35 Sign2Text]: [0mloading weight from ./frozen_sign2text/ema_model/pytorch_model.bin, msg: <All keys matched successfully>




[32m[10/24 14:13:37 Sign2Text]: [0mCreating Signloaders. Batch_size = 2
train dataloader done!
dev dataloader done!
train dataloader done!


In [5]:
def translate_images(model, images, tgt_labels, input_attn, src_length, config, accelerator,  logger, tokenizer): 

    logger.info("Translating images...")
    model = accelerator.unwrap_model(model).to("cuda")
    images = torch.clone(images)
    
    # Set appropriate dtype based on mixed precision
    dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        dtype = torch.bfloat16

    with torch.no_grad(): 

        # Directly generate translations using model.generate
        output = model.generate(
            src_input=images, 
            src_attn=input_attn, 
            src_length=src_length,
            max_new_tokens=150, 
            num_beams=4, 
            decoder_start_token_id=tokenizer.lang_code_to_id[config.dataset.lang]
        )

        
        # Use tokenizer to decode generated token IDs to translations
        pred_translations = tokenizer.batch_decode(output, skip_special_tokens=True)

        # Decode the target labels (ground truth)
        gt_translations = tokenizer.batch_decode(tgt_labels, skip_special_tokens=True)

    
    return pred_translations, gt_translations

In [6]:
# Load the dataset and dataloader 
for i, (src, tgt) in enumerate(tqdm(train_dataloader, desc=f"Generating!")):
    model.to("cuda")
    model.eval()
    batch = src['input_ids']
    src_length = src['src_length_batch']
    tgt_attn = tgt.attention_mask
    tgt_input = tgt['input_ids']
    input_attn = src['attention_mask']

    images = batch.to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
    #print(f"imagges type: {images.type()}")
    tgt_input = tgt['input_ids'].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    input_attn = input_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    tgt_attn = tgt_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    
    original_images= torch.clone(images)
    print("images shape ", original_images.shape)
    print("input attn shape", input_attn.shape)
    print("tgt attn shape", tgt_attn.shape)
    print("src length shape", src_length.shape)
    print("tgt input shape", tgt_input.shape)
    # Save a batch of translated images to check by reading

    fw_out = model( src_input = original_images,tgt_input = tgt_input, src_attn=input_attn, tgt_attn = tgt_attn, src_length = src_length)
    print(f"fw_out ")
    pred, gt = translate_images(
        model=model,
        images=images,
        tgt_labels=tgt_input,
        input_attn=input_attn, 
        src_length=src_length,
        config=config,
        accelerator=accelerator,
        logger=logger, 
        tokenizer=tokenizer
    )
    break 


print(f"predictions: {pred}")
print(f"ground truth: {gt}")

Generating!:   0%|          | 0/9201 [00:00<?, ?it/s]

images shape  torch.Size([280, 3, 256, 256])
input attn shape torch.Size([2, 37])
tgt attn shape torch.Size([2, 18])
src length shape torch.Size([2])
tgt input shape torch.Size([2, 18])
fw_out 
[32m[10/24 14:22:22 Sign2Text]: [0mTranslating images...


Generating!:   0%|          | 0/9201 [00:23<?, ?it/s]

predictions: ['zh_CN 。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。</s>', 'zh_CN 。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。。</s>']
ground truth: ['zh_CN 他不会生气的,我很了解他。</s><pad><pad><pad><pad><pad><pad><pad>', 'zh_CN 动车的车票已经卖完了,只有坐普通车了。</s>']





In [50]:
model

SignModel(
  (titok): TiTok(
    (encoder): TiTokEncoder(
      (patch_embed): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (transformer): ModuleList(
        (0-23): 24 x ResidualAttentionBlock(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
        )
      )
      (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (conv_out): Conv2d(1024, 12, kernel_size=(1, 1), stride=(1, 1))
    )
    (decoder)

### Checking LLM adaptation and linking it to gloss

In [14]:
 ## Reading in the labelled annotations
import pickle
def read_CSL_annotations(CSL_annot_path):
    with open(CSL_annot_path, 'rb') as f:
        data = pickle.load(f)
    return data
train_labels = read_CSL_annotations("../../CSL-Daily/sentence_label/processed/labels_train.pkl") 
dev_labels = read_CSL_annotations("../../CSL-Daily/sentence_label/processed/labels_dev.pkl")
test_labels = read_CSL_annotations("../../CSL-Daily/sentence_label/processed/labels_test.pkl")
combined_labels = read_CSL_annotations("../../CSL-Daily/sentence_label/csl2020ct_v2.pkl")

def find_entry_by_name(data, name):
    # Iterate through the list of dictionaries in 'info'
    for entry in data['info']:
        if entry['name'] == name:
            return entry
    return None


def get_gloss(combined_annotations, name): 
  # find entry in combined_annotations 
  entry = find_entry_by_name(combined_labels, name)
  gloss_entry = entry['label_gloss']
  return gloss_entry


def match_frames_w_gloss(tensor_lst, gloss): 
  # divide the frames into glosses
  num_tensors = len(tensor_lst)
  num_gloss = len(gloss)
  tensors_per_gloss = num_tensors//num_gloss

  # create a dictionary to hold the frames referring to a gloss
  gloss_dict = {}
  for i, g in enumerate(gloss):
    if i == num_gloss-1: 
      gloss_dict[g] = [tensor_lst[i*tensors_per_gloss:]]
    else: 
      gloss_dict[g] = [tensor_lst[i*tensors_per_gloss:(i+1)*tensors_per_gloss]]
  
  return gloss_dict

def find_and_combine_glossdicts(phase, num_samples = None , name_lst= None , dir="../../CSL-Daily/sentence/frames_512x512"):
  gloss_dict_lst = {}
  ## make assertions to prevent error 
  assert phase in ['train', 'dev', 'test'], "Phase must be either train, dev or test"
  assert num_samples is not None or name_lst is not None, "Either number of samples or name list must be provided"

  if name_lst is None:
    video_lst = os.listdir(f"{dir}/{phase}")
    # Sample random number of videos
    name_lst =  random.sample(video_lst, num_samples)

  ## Gather various gloss dicts 
  for name in name_lst: 
    gloss = get_gloss(combined_labels, name)
    tensor_lst = gather_vid_emb(name, phase)
    gloss_dict_lst[name] = match_frames_w_gloss(tensor_lst, gloss)

  # Comebine gloss dicts 
  combined_gloss_dict = {}
  for name in name_lst: 
    for k, v in gloss_dict_lst[name].items(): 
      if k in combined_gloss_dict: 
        combined_gloss_dict[k].extend(v)
      else: 
        combined_gloss_dict[k] = v
  return gloss_dict_lst, combined_gloss_dict



In [21]:
'''Checking after LLM adaptation'''
 
# Separate TikTok model and LLM adapter for testing
titok = model.titok
adapter = model.adapter

# run the video frames through the TikTok model 
# Load the dataset and dataloader 
for i, (src, tgt) in enumerate(tqdm(train_dataloader, desc=f"adapting train!")):
    phase = "train"
    titok.to("cuda")
    adapter.to("cuda")
    batch = src['input_ids']
    src_length = src['src_length_batch']
    tgt_attn = tgt.attention_mask
    tgt_input = tgt['input_ids']
    input_attn = src['attention_mask']

    images = batch.to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
    #print(f"imagges type: {images.type()}")
    tgt_input = tgt['input_ids'].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    input_attn = input_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    tgt_attn = tgt_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    
    original_images= torch.clone(images)
    # print(f"names: {src['name_batch']}")
    # print("images shape ", original_images.shape)
    # print("input attn shape", input_attn.shape)
    # print("tgt attn shape", tgt_attn.shape)
    # print("src length shape", src_length.shape)
    # print("tgt input shape", tgt_input.shape)
    # Save a batch of translated images to check by reading
    encoded_tokens = titok.encode(x=original_images)[1]['min_encoding_indices'].squeeze()
    #print(f"encoded tokens shape: {encoded_tokens.shape}")
    hidden_values = adapter(encoded_tokens.float(), src_length).squeeze()
    #print(f"hidden values shape: {hidden_values.shape}")

    for i,  hid_val in enumerate(hidden_values): 
        hid_val = hid_val[input_attn[i]==1]
        #print(f"hid val: {hid_val.shape}")
        assert hid_val.shape[0] == src_length[i]

        for j, token in enumerate(hid_val): 
            # save the token as a .pth file 
            path = os.path.join(config.dataset.img_path,phase )
            final_path = os.path.join(path, src['name_batch'][i])
            #print(final_path)
            torch.save(token, f"{final_path}/aft_adapter_{j}.pth")

   

# Save newly compressed and temporal conv as .pth files and also consider the gloss used for the translation
for i, (src, tgt) in enumerate(tqdm(dev_dataloader, desc=f"adapting dev!")):
    phase = "dev"
    titok.to("cuda")
    adapter.to("cuda")
    batch = src['input_ids']
    src_length = src['src_length_batch']
    tgt_attn = tgt.attention_mask
    tgt_input = tgt['input_ids']
    input_attn = src['attention_mask']

    images = batch.to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
    #print(f"imagges type: {images.type()}")
    tgt_input = tgt['input_ids'].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    input_attn = input_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    tgt_attn = tgt_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    
    original_images= torch.clone(images)
    # print(f"names: {src['name_batch']}")
    # print("images shape ", original_images.shape)
    # print("input attn shape", input_attn.shape)
    # print("tgt attn shape", tgt_attn.shape)
    # print("src length shape", src_length.shape)
    # print("tgt input shape", tgt_input.shape)
    # Save a batch of translated images to check by reading
    encoded_tokens = titok.encode(x=original_images)[1]['min_encoding_indices'].squeeze()
    #print(f"encoded tokens shape: {encoded_tokens.shape}")
    hidden_values = adapter(encoded_tokens.float(), src_length).squeeze()
    #print(f"hidden values shape: {hidden_values.shape}")

    for i,  hid_val in enumerate(hidden_values): 
        hid_val = hid_val[input_attn[i]==1]
        print(f"hid val: {hid_val.shape}")

        for j, token in enumerate(hid_val): 
            # save the token as a .pth file 
            path = os.path.join(config.dataset.img_path,phase )
            final_path = os.path.join(path, src['name_batch'][i])
            #print(final_path)
            torch.save(token, f"{final_path}/aft_adapter_{j}.pth")

   

# Save newly compressed and temporal conv as .pth files and also consider the gloss used for the translation
for i, (src, tgt) in enumerate(tqdm(test_dataloader, desc=f"adapting test!")):
    phase = "test"
    titok.to("cuda")
    adapter.to("cuda")
    batch = src['input_ids']
    src_length = src['src_length_batch']
    tgt_attn = tgt.attention_mask
    tgt_input = tgt['input_ids']
    input_attn = src['attention_mask']

    images = batch.to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
    #print(f"imagges type: {images.type()}")
    tgt_input = tgt['input_ids'].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    input_attn = input_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    tgt_attn = tgt_attn.to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )
    
    original_images= torch.clone(images)
    # print(f"names: {src['name_batch']}")
    # print("images shape ", original_images.shape)
    # print("input attn shape", input_attn.shape)
    # print("tgt attn shape", tgt_attn.shape)
    # print("src length shape", src_length.shape)
    # print("tgt input shape", tgt_input.shape)
    # Save a batch of translated images to check by reading
    encoded_tokens = titok.encode(x=original_images)[1]['min_encoding_indices'].squeeze()
    #print(f"encoded tokens shape: {encoded_tokens.shape}")
    hidden_values = adapter(encoded_tokens.float(), src_length).squeeze()
    #print(f"hidden values shape: {hidden_values.shape}")

    for i,  hid_val in enumerate(hidden_values): 
        hid_val = hid_val[input_attn[i]==1]
        print(f"hid val: {hid_val.shape}")

        for j, token in enumerate(hid_val): 
            # save the token as a .pth file 
            path = os.path.join(config.dataset.img_path,phase )
            final_path = os.path.join(path, src['name_batch'][i])
            #print(final_path)
            torch.save(token, f"{final_path}/aft_adapter_{j}.pth")

   

# Save newly compressed and temporal conv as .pth files and also consider the gloss used for the translation


adapting train!:   0%|          | 11/9201 [00:55<11:17:56,  4.43s/it]