In [1]:
import numpy as np
import os
import math
import sys

from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from sklearn.manifold import TSNE
from tqdm import tqdm
from enum import Enum
from typing import Optional
from r3m import load_r3m
import clip

import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage, InterpolationMode

from torch.nn import functional as nnf
import wandb

from dataset import CustomDataset, AttrDict 
from utils.visualize import visualize
from models.caption_model import ClipCaptionModel, MappingType

import config as CFG


mp.set_start_method('spawn')


  from .autonotebook import tqdm as notebook_tqdm
2023-09-05 23:19:16.329930: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Visualize Calvin Data

In [None]:
visualize()

In [None]:
datapath_test= '.././test_data/D_D/task_D_D_episode.npz'
data = np.load(datapath_test)
print(list(data.keys()))
print(data['actions'].shape)
print(data['rel_actions'].shape)
print(data['rgb_static'].shape)
print(data['rgb_gripper'].shape)
print(data['scene_obs'])

## Captioning

find sutiable caption for optimization

In [None]:
path = CFG.datapath_training_parsed + "/lang_annotations/auto_lang_ann.npy"

annotations = np.load(path, allow_pickle=True).item()
annotations = annotations["language"]["ann"]

unique_annotaions = set(annotations)
for unique_annotaion in unique_annotaions:
    print(unique_annotaion, ': ', annotations.count(unique_annotaion))

print(len(annotations))

### Data Loader

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

datapath_training_parsed = CFG.datapath_training_parsed
datapath_val_parsed = CFG.datapath_val_parsed
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(datapath_training_parsed)
caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(datapath_val_parsed)

# train_dataset = CustomDataset(datapath_dd_training, caption_path_training, tokenizer, max_seq_length)
train_dataset = CustomDataset(datapath_training_parsed, caption_path_training, tokenizer, CFG.max_seq_length)
val_dataset  = CustomDataset(datapath_val_parsed, caption_path_val, tokenizer, CFG.max_seq_length)

train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)



Test Data Loader

In [None]:
for idx, batch in enumerate(val_dataloader):
    if os.name == 'nt': 
        os.system('cls')
    else:
        os.system('clear')

    print("gpt_tokens: ", batch.gpt_tokens.shape)
    print("gpt_mask: ", batch.gpt_mask.shape)
    print(batch.instruction[0])
    print("actions: ", batch.actions.shape)
    print("observations: ", batch.observations.shape)
    print("batch at index done: ", idx)
    break


In [None]:

clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text
clip_text_features = clip_text_encoder(clip.tokenize("grasp the blue block, then rotate it left").to(CFG.device))   #.detach().cpu().numpy()[()].squeeze(0)

print(clip_text_features.shape)

## Training

In [None]:
def validate(model: ClipCaptionModel):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in tqdm(val_dataloader):

            data.observations = data.observations.to(CFG.device)
            data.actions = data.actions.to(CFG.device)
            # data.instruction = data.instruction.to(CFG.device)
            # data.instruction = (clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).detach().cpu().numpy()).to(CFG.device)
            data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
            data.gpt_tokens = data.gpt_tokens.to(CFG.device)
            data.gpt_mask = data.gpt_mask.to(CFG.device)
            outputs = model(data)

            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)
            
            total_loss += loss.item()

    return total_loss / len(val_dataloader)

def train(model: ClipCaptionModel,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    epochs = CFG.epochs
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )

    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, data in enumerate(train_dataloader):

            model.zero_grad()

            data.observations = data.observations.to(CFG.device)
            data.actions = data.actions.to(CFG.device)
            data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
            data.gpt_tokens = data.gpt_tokens.to(CFG.device)
            data.gpt_mask = data.gpt_mask.to(CFG.device)
            
            outputs = model(data)


            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.set_postfix({"loss": loss.item()})
            wandb.log({"loss": loss.item()})
            progress.update()
            if (idx+1) % 20 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
                # val_loss = validate(model)
                # wandb.log({"val_loss": val_loss})
        progress.close()
        if epoch % 1 == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )
    return model

### Load and Train Model

In [None]:
wandb.init(project="clipcalvin")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

prefix_length = 10
prefix_length_clip = 10
num_layers = 8
prefix_dim = 512
mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}["transformer"]

model = ClipCaptionModel(prefix_length, clip_length=prefix_length_clip, prefix_size=prefix_dim,
                          num_layers=num_layers, mapping_type=mapping_type)


clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

train(model, output_dir="./checkpoints/hulccap/run2_seq2seq", output_prefix="hulccap_prefix")


## Evaluation

### find best model

In [None]:
clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 


def evaluate_loss(path, filename, val_dataloader):

    mapper_model = ClipCaptionModel(prefix_length=10, clip_length=10).to(CFG.device)
    mapper_model.load_state_dict(torch.load(path, map_location=CFG.device))
    mapper_model = mapper_model.eval()

    total_loss = 0
    for data in val_dataloader:

        data.observations = data.observations.to(CFG.device)
        data.actions = data.actions.to(CFG.device)
        data.instruction = clip_text_encoder(clip.tokenize(data.instruction).to(CFG.device)).to(CFG.device)
        data.gpt_tokens = data.gpt_tokens.to(CFG.device)
        data.gpt_mask = data.gpt_mask.to(CFG.device)

        outputs = mapper_model(data)

        with torch.no_grad():

            logits = outputs.logits[:, data.observations.shape[1] - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), data.gpt_tokens.flatten(), ignore_index=0)    
            total_loss += loss.item()
    print(filename, ' loss: ', total_loss / len(val_dataloader))
    return total_loss  / len(val_dataloader)

        

model_dir = "./checkpoints/hulccap/run_1/"

best_model = None
best_loss = float('inf')

for filename in os.listdir(model_dir):
    model_path = os.path.join(model_dir, filename)
    current_loss = evaluate_loss(model_path, filename, val_dataloader)
    if current_loss < best_loss:
        best_loss = current_loss
        best_model = filename

print("best model: ", best_model, ' loss: ', best_loss)


### Evaluate best model

Greedy decoding

In [3]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

path_best = "./checkpoints/hulccap/run_1/hulccap_prefix-032.pt"
best_model = ClipCaptionModel(prefix_length=10, clip_length=10).to(CFG.device)
best_model.load_state_dict(torch.load(path_best, map_location=CFG.device))
best_model = best_model.eval()

for data in val_dataloader:

    instruction_ground = data.instruction

    data.observations = data.observations.to(CFG.device)
    data.actions = data.actions.to(CFG.device)
    data.gpt_tokens = data.gpt_tokens.to(CFG.device)
    data.gpt_mask = data.gpt_mask.to(CFG.device)

    outputs = best_model(data)
    
    for i in range(len(data.instruction)):
        print("INSTRUCTION:", instruction_ground[i])
        

        with torch.no_grad():
            generated_token = outputs.logits[i].argmax(dim=-1)
            print(outputs.logits[i].shape)
            print(generated_token.shape)

            # remove duplicates
            result = []
            for i in range(len(generated_token)):
                word = generated_token[i]
                if word not in result or result[-1] != word:
                    result.append(word)
            result_tensor = torch.tensor(result)

            generated_text = tokenizer.decode(result_tensor)
            print("Generated Caption:", generated_text)

    break

INSTRUCTION: turn on the green light
torch.Size([80, 50257])
torch.Size([80])
Generated Caption: grputturn theturnpush on the led lamp bulb This
INSTRUCTION: lift the pink block lying in the drawer
torch.Size([80, 50257])
torch.Size([80])
Generated Caption:  thego thego the pink block lying in the drawer and grasp
INSTRUCTION: slide the door to the left
torch.Size([80, 50257])
torch.Size([80])
Generated Caption: grslpushmove thepushide the door to the left, then
INSTRUCTION: pick up the red block from the table
torch.Size([80, 50257])
torch.Size([80])
Generated Caption: lift thelift up the red block from the table and lift
INSTRUCTION: take the pink block and turn it right
torch.Size([80, 50257])
torch.Size([80])
Generated Caption: rotgrrotgr the pink block and rotate it right the The
INSTRUCTION: toggle the light switch to turn off the yellow light
torch.Size([80, 50257])
torch.Size([80])
Generated Caption: inmoveturnmove the,turn the light switch to turn off the yellow light bulb the

Beam decoding

In [4]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=CFG.num_workers)

def beamsearch(model, tokenizer, embed, beam_size: int = 5, stop_token: str = '\n'):
    model.eval()
    scores = None
    tokens = None
    stop_token_index = tokenizer.encode(stop_token)[0]
    seq_lengths = torch.ones(beam_size, device=CFG.device)
    is_stopped = torch.zeros(beam_size, device=CFG.device, dtype=torch.bool)
    generated = embed
    with torch.no_grad():
        for i in range(10):
            outputs = best_model.gpt(inputs_embeds=generated)
            logits = outputs.logits[:, -1, :]
            logits = logits.softmax(-1).log()
            #print(logits.shape)
            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                #print(scores)
                #print(next_tokens)
                generated = generated.expand(beam_size, *generated.shape[1:])
                #print(generated.shape)
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts


clip_model, _ = clip.load("ViT-B/32", device=CFG.device, jit=True)
clip_text_encoder = clip_model.encode_text

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

path_best = "./checkpoints/hulccap/run_1/hulccap_prefix-032.pt"
best_model = ClipCaptionModel(prefix_length=10, clip_length=10).to(CFG.device)
best_model.load_state_dict(torch.load(path_best, map_location=CFG.device))
# best_model = best_model.eval()

for data in val_dataloader:

    instruction_ground = data.instruction

    data.observations = data.observations.to(CFG.device)
    data.actions = data.actions.to(CFG.device)
    data.gpt_tokens = data.gpt_tokens.to(CFG.device)
    data.gpt_mask = data.gpt_mask.to(CFG.device)

    # outputs = best_model(data)
    
    for i in range(len(data.instruction)):
        print("INSTRUCTION:", instruction_ground[i])
        
      #  generated_token = outputs.logits[i].argmax(dim=-1)
        """
        tokens = torch.tensor(tokenizer.encode("take"))
        tokens = tokens.unsqueeze(0).to(CFG.device)
        generated = best_model.gpt.transformer.wte(tokens)
        inputs = best_model.gpt.transformer.wte(tokens)
        """
        src = AttrDict(observations=data.observations, actions=data.actions)
        behaviour_encoding = best_model.behaviour_encoder(src)
        prefix_embed = best_model.project_to_gpt(behaviour_encoding)

        generated_caption =  beamsearch(best_model, tokenizer, prefix_embed)
        print(generated_caption)

INSTRUCTION: turn on the green light
['push down the button to turn on the led light', 'push the button to turn on the led light bulb', 'turn on the led light bulb to turn on the', 'turn on the led lamp on the led light bulb', 'toggle the button to turn on the led light bulb']
INSTRUCTION: lift the pink block lying in the drawer
['sweep the pink block to the left, then', 'go slide the pink block to the left side,', 'slide the pink block towards the left, then', 'go towards the pink block in the drawer and grasp', 'lift the pink block lying in the drawer and lift']
INSTRUCTION: slide the door to the left
['slide the door to the left, then let', 'push the sliding door to the left side, then', 'grasp the door handle, then slide the door', 'grasp the door handle and slide the door to', 'move the sliding door to the left side, then']
INSTRUCTION: pick up the red block from the table
['lift the red block from the table, then lift', 'lift the red block from the table and lift it', 'lift the p

KeyboardInterrupt: 