In [1]:
import numpy as np
import glob
import os
import math
import sys

from pathlib import Path
import cv2
import numpy as np
from matplotlib.animation import ArtistAnimation
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from sklearn.manifold import TSNE
from tqdm import tqdm
from enum import Enum
from typing import Optional
from r3m import load_r3m
import clip

import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage
from torch.nn import functional as nnf
import wandb

from dataset import CustomDataset, AttrDict 

import config as CFG


mp.set_start_method('spawn')


  from .autonotebook import tqdm as notebook_tqdm
2023-09-03 02:49:24.757335: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Visualize Calvin Data

In [4]:
def visualize():

    #parser = ArgumentParser(description="Interactive visualization of CALVIN dataset")
    #parser.add_argument("path", type=str, help="Path to dir containing scene_info.npy")
    #parser.add_argument("-d", "--data", nargs="*", default=["rgb_static", "rgb_gripper"], help="Data to visualize")
    #args = parser.parse_args()

    path = CFG.datapath_training
    data = ["rgb_static", "rgb_gripper"]

    if not Path(path).is_dir():
        print(f"Path {path} is either not a directory, or does not exist.")
        exit()

    indices = next(iter(np.load(f"{path}/scene_info.npy", allow_pickle=True).item().values()))
    indices = list(range(indices[0], indices[1] + 1))

    scene_info = np.load(f"{path}/scene_info.npy", allow_pickle=True)
    print(scene_info)

    annotations = np.load(f"{path}/lang_annotations/auto_lang_ann.npy", allow_pickle=True).item()
    annotations = list(zip(annotations["info"]["indx"], annotations["language"]["ann"]))
    print(annotations)
    print(len(annotations))

    # idx = 0
    idx = 60000
    ann_idx = -1

    while True:
        t = np.load(f"{path}/episode_{indices[idx]:07d}.npz", allow_pickle=True)

        for d in data:
            if d not in t:
                print(f"Data {d} cannot be found in transition")
                continue

            img = cv2.resize(t[d], (400, 400))
            cv2.imshow(d, img[:, :, ::-1])

        for n, ((low, high), ann) in enumerate(annotations):
            if indices[idx] >= low and indices[idx] <= high:
                if n != ann_idx:
                    print(f"{ann}")
                    ann_idx = n

        # user_input = input("Enter something: ")


        key = cv2.waitKey(0)
        if key == ord("q"):
            break
        elif key == 83:  # Right arrow
            idx = (idx + 1) % len(indices)
        elif key == 81:  # Left arrow
            idx = (len(indices) + idx - 1) % len(indices)
        else:
            print(f'Unrecognized keycode "{key}"')

        

In [5]:
visualize()

{'calvin_scene_D': [0, 611098]}
[((315660, 315724), 'move the door to the left side'), ((191730, 191794), 'slide the door to the left side'), ((305439, 305503), 'slide down the switch'), ((340730, 340794), 'toggle the button to turn on the green light'), ((542337, 542401), 'toggle the light switch to turn on the yellow light'), ((536830, 536894), 'push the switch upwards'), ((575627, 575691), 'push down the button to turn on the led'), ((80243, 80307), 'open the cabinet drawer'), ((68433, 68497), 'grasp the drawer handle and open it'), ((370674, 370738), 'move up the switch'), ((526635, 526699), 'pull the handle of the drawer'), ((485616, 485680), 'move the sliding door to the left'), ((473791, 473839), 'put the block in the drawer'), ((201910, 201957), 'toggle the light switch to turn off the light bulb'), ((292365, 292429), 'move the door to the right side'), ((425910, 425974), 'turn on the yellow light'), ((91077, 91141), 'grasp the blue block and lift it up'), ((610343, 610407), 't

In [10]:
datapath_test= '.././test_data/D_D/task_D_D_episode.npz'
data = np.load(datapath_test)
print(list(data.keys()))
print(data['actions'].shape)
print(data['rel_actions'].shape)
print(data['rgb_static'].shape)
print(data['rgb_gripper'].shape)
print(data['scene_obs'])

['actions', 'rel_actions', 'robot_obs', 'scene_obs', 'rgb_static', 'rgb_gripper', 'rgb_tactile', 'depth_static', 'depth_gripper', 'depth_tactile']
(7,)
(7,)
(200, 200, 3)
(84, 84, 3)
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  2.07909500e-15
  0.00000000e+00  0.00000000e+00  1.63225568e-01 -4.65878297e-02
  4.59990009e-01  3.30461182e-07 -6.46899737e-08 -6.43902616e-01
 -1.63540040e-01  5.52660776e-02  4.60989636e-01  4.61108288e-05
 -7.81370023e-06 -2.05803878e+00 -2.77312162e-01  8.50804019e-02
  4.60989884e-01 -2.63198658e-06  6.74378838e-06 -7.66156532e-01]


## Captioning

### Models
Captioning Model

In [2]:

class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
       
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        #div_term = torch.exp(torch.arange(0, d_model) * (-math.log(10000.0) / d_model))


        # pe = torch.zeros(max_len, 1, d_model)

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # pe[0, :, 1::2] = torch.cos(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term[:-1])

        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class BehaviourEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.pos_encoder = PositionalEncoding(CFG.d_model, dropout=CFG.dropout)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=CFG.d_model, nhead=CFG.n_heads, batch_first=True, dim_feedforward=CFG.d_ff, dropout=CFG.dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=CFG.n_layers)

    def forward(self, src):

        image_features = src.observations # images: [batch_size, sequence_length, 2048]
        actions = src.actions

        src_key_padding_mask = (image_features.mean(dim=2)==0.0)
        features = torch.cat((image_features, actions), dim=-1)

        # Transformer encoder
        # add positional encoding
        features = features * math.sqrt(CFG.d_model)
        features = self.pos_encoder(features)
        behaviour_encoding = self.transformer_encoder(features, src_key_padding_mask=src_key_padding_mask)

        #print("behaviour_encoding: ", behaviour_encoding.shape)

        return behaviour_encoding

class TransformerMapper(nn.Module):

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        #self.transformer = Transformer(dim_embedding, 8, num_layers)

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=dim_embedding, nhead=8, batch_first=True, dim_feedforward=int(dim_embedding * 2), dropout=0.0)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

class ClipCaptionModel(nn.Module):

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
        


        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                                     clip_length, num_layers)
        self.behaviour_encoder = BehaviourEncoder()
        # self.project_to_gpt = nn.Linear(514, self.gpt_embedding_size).to("cuda")
        self.project_to_gpt = nn.Linear(2055, self.gpt_embedding_size).to("cuda")

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, data: AttrDict):

        tokens = data.gpt_tokens
        gpt_mask = data.gpt_mask
        labels = None

        embedding_text = self.gpt.transformer.wte(tokens)
        behaviour_encoding  = self.behaviour_encoder(data)

        
        behaviour_encoder_padding_mask = ~(data.observations.mean(dim=2)==0.0) * 1.0


        ########################################################################        
        prefix_projections = self.project_to_gpt(behaviour_encoding)


        # total_mask = torch.cat((behaviour_encoder_padding_mask, embedding_text), dim=1)
        total_mask = torch.cat((behaviour_encoder_padding_mask, gpt_mask), dim=1)

        # prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=total_mask)
        return out

class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self

find sutiable caption for optimization

In [None]:
path = CFG.datapath_training_parsed + "/lang_annotations/auto_lang_ann.npy"

annotations = np.load(path, allow_pickle=True).item()
annotations = annotations["language"]["ann"]

unique_annotaions = set(annotations)
for unique_annotaion in unique_annotaions:
    print(unique_annotaion, ': ', annotations.count(unique_annotaion))

print(len(annotations))

### Data Loader

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

datapath_training_parsed = CFG.datapath_training_parsed
datapath_val_parsed = CFG.datapath_val_parsed
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(datapath_training_parsed)
caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(datapath_val_parsed)

# train_dataset = CustomDataset(datapath_dd_training, caption_path_training, tokenizer, max_seq_length)
train_dataset = CustomDataset(datapath_training_parsed, caption_path_training, tokenizer, CFG.max_seq_length)
val_dataset  = CustomDataset(datapath_val_parsed, caption_path_val, tokenizer, CFG.max_seq_length)

train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)



Test Data Loader

In [6]:
for idx, batch in enumerate(val_dataloader):
    """
    print("gpt_tokens: ", batch.gpt_tokens.shape)
    print("gpt_mask: ", batch.gpt_mask.shape)
    print("gpt_tokens: ", batch.gpt_tokens[0])
    print("gpt_mask: ", batch.gpt_mask[0])
    print(batch.instruction[0])
    print("actions: ", batch.actions.shape)
    print("observations: ", batch.observations.shape)
    print("observations: ", batch.observations)
    """
    if os.name == 'nt': 
        os.system('cls')
    else:
        os.system('clear')

    print("gpt_tokens: ", batch.gpt_tokens.shape)
    print("gpt_mask: ", batch.gpt_mask.shape)
    print(batch.instruction[0])
    print("actions: ", batch.actions.shape)
    print("observations: ", batch.observations.shape)
    print("batch at index done: ", idx)
    break


2023-09-02 18:11:31.404291: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-02 18:11:34.038852: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-02 18:11:36.624513: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-02

[H[2Jgpt_tokens:  torch.Size([16, 16])
gpt_mask:  torch.Size([16, 16])
turn on the green light
actions:  torch.Size([16, 64, 7])
observations:  torch.Size([16, 64, 2048])
batch at index done:  0


In [None]:

clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text
clip_text_features = clip_text_encoder(clip.tokenize("grasp the blue block, then rotate it left").to(CFG.device))   #.detach().cpu().numpy()[()].squeeze(0)

print(clip_text_features.shape)

## Training

In [4]:
def validate(model: ClipCaptionModel):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in tqdm(val_dataloader):

            data.observations = data.observations.to(CFG.device)
            data.actions = data.actions.to(CFG.device)
            # data.instruction = data.instruction.to(CFG.device)
            # data.instruction = (clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).detach().cpu().numpy()).to(CFG.device)
            data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
            data.gpt_tokens = data.gpt_tokens.to(CFG.device)
            data.gpt_mask = data.gpt_mask.to(CFG.device)
            outputs = model(data)

            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)
            
            total_loss += loss.item()

    return total_loss / len(val_dataloader)

def train(model: ClipCaptionModel,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    epochs = CFG.epochs
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )

    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, data in enumerate(train_dataloader):

            model.zero_grad()

            data.observations = data.observations.to(CFG.device)
            data.actions = data.actions.to(CFG.device)
            data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
            data.gpt_tokens = data.gpt_tokens.to(CFG.device)
            data.gpt_mask = data.gpt_mask.to(CFG.device)
            
            outputs = model(data)


            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.set_postfix({"loss": loss.item()})
            wandb.log({"loss": loss.item()})
            progress.update()
            if (idx+1) % 20 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
                # val_loss = validate(model)
                # wandb.log({"val_loss": val_loss})
        progress.close()
        if epoch % 1 == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )
    return model

In [5]:
wandb.init(project="clipcalvin")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

prefix_length = 10
prefix_length_clip = 10
num_layers = 8
prefix_dim = 512
mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}["transformer"]

model = ClipCaptionModel(prefix_length, clip_length=prefix_length_clip, prefix_size=prefix_dim,
                          num_layers=num_layers, mapping_type=mapping_type)

# print(model)

#model = ClipCaptionPrefix(prefix_length, clip_length=prefix_length_clip, prefix_size=prefix_dim,
#                           num_layers=num_layers, mapping_type=mapping_type).to(CFG.device)

clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

train(model, output_dir="./checkpoints/hulccap/run2", output_prefix="hulccap_prefix")
#for batch in train_dataloader:
   # print("token: ", batch.token.shape)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtimlauffs[0m. Use [1m`wandb login --relogin`[0m to force relogin


768
torch.Size([1, 5000, 2055])
>>> Training epoch 0


hulccap_prefix:   0%|          | 0/321 [00:00<?, ?it/s]2023-09-02 18:08:55.951991: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-02 18:08:58.647269: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-02 18:09:01.296759: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild Ten

RuntimeError: Only support when num_heads is even in transformer

## Evaluation

### find best model

In [4]:
clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 


def evaluate_loss(path, filename, val_dataloader):

    mapper_model = ClipCaptionModel(prefix_length=10, clip_length=10).to(CFG.device)
    mapper_model.load_state_dict(torch.load(path, map_location=CFG.device))
    mapper_model = mapper_model.eval()

    total_loss = 0
    for data in val_dataloader:

        data.observations = data.observations.to(CFG.device)
        data.actions = data.actions.to(CFG.device)
        data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
        data.gpt_tokens = data.gpt_tokens.to(CFG.device)
        data.gpt_mask = data.gpt_mask.to(CFG.device)

        outputs = mapper_model(data)

        with torch.no_grad():

            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)    
            total_loss += loss.item()
    print(filename, ' loss: ', total_loss / len(val_dataloader))
    return total_loss  / len(val_dataloader)

        

model_dir = "./checkpoints/hulccap/run_1/"

best_model = None
best_loss = float('inf')

for filename in os.listdir(model_dir):
    model_path = os.path.join(model_dir, filename)
    current_loss = evaluate_loss(model_path, filename, val_dataloader)
    if current_loss < best_loss:
        best_loss = current_loss
        best_model = filename

print("best model: ", best_model, ' loss: ', best_loss)


hulccap_prefix-028.pt  loss:  0.42106717685237527


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 828114de-54fb-499e-8d3b-e4d625cf80c6)')' thrown while requesting HEAD https://huggingface.co/gpt2/resolve/main/generation_config.json


hulccap_prefix-041.pt  loss:  0.43790126824751496
hulccap_prefix-037.pt  loss:  0.43389218486845493
hulccap_prefix-010.pt  loss:  0.6058757989667356
hulccap_prefix-051.pt  loss:  0.5002122279256582
hulccap_prefix-008.pt  loss:  0.7012736713513732
hulccap_prefix-027.pt  loss:  0.40414822241291404
hulccap_prefix-042.pt  loss:  0.47727493289858103
hulccap_prefix-019.pt  loss:  0.4471715269610286
hulccap_prefix-049.pt  loss:  0.4747509784065187
hulccap_prefix-048.pt  loss:  0.4856449053622782
hulccap_prefix-036.pt  loss:  0.41252506291493773
hulccap_prefix-015.pt  loss:  0.6118547092191875
hulccap_prefix-001.pt  loss:  1.5453845523297787
hulccap_prefix-043.pt  loss:  0.45606917794793844
hulccap_prefix-011.pt  loss:  0.5996267069131136
hulccap_prefix-038.pt  loss:  0.4191669840365648
hulccap_prefix-018.pt  loss:  0.4846190740354359
hulccap_prefix-045.pt  loss:  0.46406757459044456
hulccap_prefix-005.pt  loss:  0.8662657756358385
hulccap_prefix-016.pt  loss:  0.5008556833490729
hulccap_prefi

### Evaluate best model

#### generate text using beam

In [4]:
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = '\n'):

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
        for i in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts


In [5]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

path_best = "./checkpoints/hulccap/run_1/hulccap_prefix-034.pt"
best_model = ClipCaptionModel(prefix_length=10, clip_length=10).to(CFG.device)
best_model.load_state_dict(torch.load(path_best, map_location=CFG.device))
best_model = best_model.eval()

for data in val_dataloader:

    instruction_ground = data.instruction

    data.observations = data.observations.to(CFG.device)
    data.actions = data.actions.to(CFG.device)
    data.gpt_tokens = data.gpt_tokens.to(CFG.device)
    data.gpt_mask = data.gpt_mask.to(CFG.device)

    outputs = best_model(data)
    
    for i in range(len(data.instruction)):
        print("INSTRUCTION:", instruction_ground[i])

        #generated_captions_beam = generate_beam(best_model, tokenizer)
        #print(generated_captions_beam)
        
        with torch.no_grad():
            generated_token = outputs.logits[i].argmax(dim=-1)
            generated_text = tokenizer.decode(generated_token)
            print("Generated Caption:", generated_text)
    break
        

INSTRUCTION: turn on the green light


terminate called without an active exception
terminate called without an active exception
terminate called without an active exception
terminate called without an active exception
terminate called without an active exception
terminate called without an active exception
terminate called without an active exception
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda952e8c20>
Traceback (most recent call last):
  File "/home/tim/anaconda3/envs/cap-env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/tim/anaconda3/envs/cap-env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/tim/anaconda3/envs/cap-env/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tim/anaconda3/envs/cap-env/lib

ValueError: Input None is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.