In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from piqa import SSIM
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt
import os
import csv

from datasets.image_vae_dataset import ImageVAEDataset, ImageVAEDatasetFull
from datasets.bc_dataset import BCDataset
from models.visual_autoencoder import VisualAutoencoder
from models.bc_policy import LanguageConditionedPolicy
import config as CFG


  from .autonotebook import tqdm as notebook_tqdm


## VisualAutoencoder

dataloder

In [None]:
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_training)
train_dataset = ImageVAEDataset(dataset_path=CFG.datapath_training, caption_path=caption_path_training)
train_loader = DataLoader(train_dataset,
                          #num_workers=16,
                          #prefetch_factor=6, 
                          batch_size=CFG.batch_size, 
                          shuffle=True)

caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_val)
val_dataset = ImageVAEDataset(dataset_path=CFG.datapath_val, caption_path=caption_path_val)
val_loader = DataLoader(val_dataset,
                          #num_workers=16,
                          #prefetch_factor=6, 
                          batch_size=CFG.batch_size, 
                          shuffle=True)


train_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_training)
train_loader_full = DataLoader(train_dataset_full,
                          #num_workers=16,
                          #prefetch_factor=6, 
                          batch_size=CFG.batch_size, 
                          shuffle=True)

val_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_val)
val_loader_full = DataLoader(val_dataset_full,
                          #num_workers=16,
                          #prefetch_factor=6, 
                          batch_size=CFG.batch_size, 
                          shuffle=True)

test dataloader

In [None]:
for idx, batch in enumerate(train_loader):
    print("img_static: ", batch.img_static.shape)
    print("img_gripper: ", batch.img_gripper.shape)
    break

for idx, batch in enumerate(train_loader_full):
    print("img_static: ", batch.img_static.shape)
    print("img_gripper: ", batch.img_gripper.shape)
    break


### Train VAE model

In [None]:
encoding_size = 512
learning_rate = 0.0001

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

def visualize_reconstruction(original, reconstructed):
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(original.permute(1, 2, 0))
    axs[0].set_title('Original')
    axs[1].imshow(reconstructed.detach().permute(1, 2, 0))
    axs[1].set_title('Reconstructed')
    plt.close()
    # return figure for wandb
    return fig


def train():
    # Train the model
    step = 0
    checkpoint_step = 0;
    #use tqdm for progress bar
    for epoch in range(CFG.epochs):

        running_loss = 0.0
        print(f"Epoch {epoch+1}/{CFG.epochs}")
        
        #for i, data in enumerate(tqdm(train_loader)):
        for i, data in enumerate(tqdm(train_loader_full)):
            # Get the inputs and move them to the GPU if available
            inputs = data['img_static']
            inputs = inputs.to(CFG.device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(inputs)

            # Compute the loss
            reconstruction_loss = criterion(outputs, inputs)
            #ssim_loss = SSIM_criterion(outputs, inputs) * 0.001

            loss = reconstruction_loss #+ ssim_loss
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Log the loss
            wandb.log({"Total Loss": loss.item()}, step=step)
            wandb.log({"Reconstruction Loss": reconstruction_loss.item()}, step=step)
            #wandb.log({"SSIM Loss": ssim_loss.item()}, step=step)

            #if step % 321 == 0:
            if step % 2000 == 0:
                #visualize_reconstruction(inputs[0].detach().cpu(), outputs[0].detach().cpu())
                wandb.log({"Training Images": wandb.Image(visualize_reconstruction(inputs[0].detach().cpu(), outputs[0].detach().cpu()))}, step=step)
                #save model weights
                torch.save(model.state_dict(), f"checkpoints/image_vae/image_vae_static_full_long/image_vae_static_full_high_res-{checkpoint_step:03d}.pt")
                checkpoint_step += 1
            
            step += 1
        """
        wandb.log({"Training Images": wandb.Image(visualize_reconstruction(inputs[0].detach().cpu(), outputs[0].detach().cpu()))}, step=step)
        torch.save(model.state_dict(), f"checkpoints/image_vae/image_vae_static_long/image_vae_static_high_res-{epoch:03d}.pt")
        """

wandb.init(project="ImageVAE")
model = VisualAutoencoder(encoding_size)
model.to(CFG.device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
SSIM_criterion = SSIMLoss().cuda()

train()

validate image vae

In [None]:
def visualize_reconstruction(original, reconstructed):
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(original.permute(1, 2, 0))
    axs[0].set_title('Original')
    axs[1].imshow(reconstructed.detach().permute(1, 2, 0))
    axs[1].set_title('Reconstructed')
    plt.show()
    plt.close()
    # return figure for wandb
    return fig

def reconstruct(model):
    for i, data in enumerate(val_loader_full):
        inputs = data['img_static']
        inputs = inputs.to(CFG.device)
        outputs = model(inputs)
        for i in range(0, 6):
            visualize_reconstruction(inputs[i].detach().cpu(), outputs[i].detach().cpu())
        break
    
encoding_size = 512
model = VisualAutoencoder(encoding_size).to(CFG.device)
model.load_state_dict(torch.load('checkpoints/image_vae/image_vae_static_full_long/image_vae_static_full_high_res-045.pt', map_location=CFG.device))
model = model.eval()
reconstruct(model)


## Behavior Cloning Policy

dataloader

In [2]:
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_training)
train_dataset_bc = BCDataset(data_path=CFG.datapath_training, caption_path=caption_path_training)
train_loader_bc = DataLoader(train_dataset_bc, 
                          batch_size=CFG.batch_size, 
                          shuffle=True)

caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_val)
val_dataset_bc = BCDataset(data_path=CFG.datapath_val, caption_path=caption_path_val)
val_loader_bc = DataLoader(val_dataset_bc,
                          batch_size=CFG.batch_size, 
                          shuffle=True)

test dataloader bc policy

In [None]:
for idx, batch in enumerate(train_loader_bc):
    print("img_static: ", batch.img_static.shape)
    print("img_gripper: ", batch.img_gripper.shape)
    print("action: ", batch.action.shape)
    print("text_encoding: ", batch.text_encoding.shape)

    print("img_static: ", batch.img_static.dtype)
    print("img_gripper: ", batch.img_gripper.dtype)
    print("action: ", batch.action.dtype)
    print("text_encoding: ", batch.text_encoding.dtype)
    break


### Train BC Policy

calc mean and std of robot actions (used in bc_dataset.py)

In [6]:
# mean and std for entire training data

datafiles = os.listdir(CFG.datapath_training)

all_actions = []
# Loop through and print the file names
for file in tqdm(datafiles):
    try:
        file_path = os.path.join(CFG.datapath_training, file)
        data = np.load(file_path)
        all_actions.append(data["actions"])
    except:
        print(f"Skipping '{file}'")

mean = np.mean(np.array(all_actions), axis=0)
std = np.std(np.array(all_actions), axis=0)

print("mean: ", mean)
print("std: ", std)


  4%|▎         | 18916/512093 [00:03<01:27, 5614.74it/s]

Skipping 'lang_huggingface_distilroberta'


  6%|▌         | 29331/512093 [00:05<01:25, 5621.76it/s]

Skipping 'lang_all-mpnet-base-v2'


 12%|█▏        | 62716/512093 [00:11<01:20, 5590.20it/s]

Skipping 'lang_clip_ViTB32'


 14%|█▍        | 73362/512093 [00:13<01:18, 5562.20it/s]

Skipping 'scene_info.npy'


 23%|██▎       | 118466/512093 [00:22<01:15, 5214.49it/s]

Skipping 'lang_clip_resnet50'


 44%|████▍     | 225185/512093 [00:41<00:53, 5411.02it/s]

Skipping 'lang_BERT'


 49%|████▉     | 252485/512093 [00:46<00:47, 5495.48it/s]

Skipping 'lang_annotations'


 50%|████▉     | 254713/512093 [00:47<00:46, 5563.40it/s]

Skipping 'lang_all-distilroberta-v1'


 50%|█████     | 256408/512093 [00:47<00:45, 5617.01it/s]

Skipping 'lang_huggingface_mpnet'


 60%|█████▉    | 305779/512093 [00:56<00:37, 5572.79it/s]

Skipping 'lang_msmarco-bert-base-dot-v5'


 63%|██████▎   | 320296/512093 [00:58<00:34, 5597.29it/s]

Skipping 'lang_all-MiniLM-L6-v2'


 74%|███████▎  | 377467/512093 [01:09<00:25, 5321.87it/s]

Skipping 'lang_paraphrase-MiniLM-L3-v2'


 83%|████████▎ | 424024/512093 [01:18<00:17, 5120.71it/s]

Skipping 'ep_start_end_ids.npy'


 90%|████████▉ | 458941/512093 [01:24<00:09, 5538.18it/s]

Skipping 'ep_lens.npy'


 94%|█████████▎| 480045/512093 [01:28<00:05, 5475.59it/s]

Skipping '.hydra'


 99%|█████████▊| 505671/512093 [01:33<00:01, 5470.24it/s]

Skipping 'statistics.yaml'


100%|██████████| 512093/512093 [01:34<00:00, 5423.52it/s]


mean:  [ 0.05244775 -0.12080607  0.50815218  1.01496132 -0.03902264  1.56418701
  0.13438409]
std:  [0.15992226 0.10983621 0.0623301  2.90982278 0.10183952 0.34633791
 0.99092932]


In [14]:
# mean and std only for captioned training data 

caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_training_parsed)
annotations = np.load(f"{caption_path_training}", allow_pickle=True).item()
annotations = annotations["info"]["indx"]

for range in annotations:
    files = [f"episode_{i:07d}.npz" for i in range]
    for file in files:
        try:
            file_path = os.path.join(CFG.datapath_training, file)
            data = np.load(file_path)
            all_actions.append(data["actions"])
        except:
            print(f"Skipping '{file}'")

mean = np.mean(np.array(all_actions), axis=0)
std = np.std(np.array(all_actions), axis=0)

print("mean: ", mean)
print("std: ", std)

mean:  [ 0.05256253 -0.12068113  0.50823375  1.01427333 -0.03897106  1.56396382
  0.13202699]
std:  [0.15978371 0.10957365 0.06213988 2.91004043 0.10183245 0.35185281
 0.99124612]


In [3]:
wandb.init(project="skillGPT_BC")
device = CFG.device

#hyperparameters
lr = 0.001
batch_size = 32
num_epochs = 1000
language_dim = 512
image_dim = 512
action_dim = 7

def train():
    # loss function
    criterion = nn.MSELoss()

    # Initialize the model
    policy = LanguageConditionedPolicy(language_dim, action_dim)
    policy = policy.to(device)

    # initialize the visual autoencoder
    visual_autoencoder = VisualAutoencoder(512)
    visual_autoencoder = visual_autoencoder.to(device)

    # load the weights
    visual_autoencoder.load_state_dict(torch.load("./checkpoints/image_vae/image_vae_static_full_long/image_vae_static_full_high_res-045.pt"))
    image_encoder = visual_autoencoder.encoder
    for param in image_encoder.parameters():
        param.requires_grad = True

    # Define the optimizer
    params = list(policy.parameters()) + list(image_encoder.parameters())
    optimizer = optim.Adam(params, lr=lr)

    step = 0
    # Training loop
    for epoch in range(num_epochs):
        # Train the model
        policy.train()
        epoch_train_loss = 0.0
        for i, data in enumerate(tqdm(train_loader_bc)):

            language_emb = data["text_encoding"].to(device).to(torch.float)
            image = data["img_gripper"].to(device)
            action = data["action"].to(device)
            #action = data["action"].to(torch.float32).to(device)
            optimizer.zero_grad()
            image_features = image_encoder(image)
            pred_action = policy(language_emb, image_features)
            loss = criterion(pred_action, action)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
            
            wandb.log({"train_loss": loss.item()}, step)
            step += 1
            
        # Print the average training loss for the epoch
        avg_epoch_train_loss = epoch_train_loss / len(train_loader_bc)
        print(f"Epoch {epoch+1} Training Loss: {avg_epoch_train_loss:.4f}")

        # Evaluate the model on the validation set
        policy.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(tqdm(val_loader_bc)):
                language_emb = data["text_encoding"].to(device).to(torch.float)
                image = data["img_gripper"].to(device)
                action = data["action"].to(device)
                image_features = image_encoder(image)
                pred_action = policy(language_emb, image_features)
                loss = criterion(pred_action, action)
                epoch_val_loss += loss.item()
                
        # Print the average validation loss for the epoch
        avg_epoch_val_loss = epoch_val_loss / len(val_loader_bc)
        print(f"Epoch {epoch+1} Validation Loss: {avg_epoch_val_loss:.8f}")
        wandb.log({"val_loss": avg_epoch_val_loss}, step)
        
        # Save the model checkpoint
        torch.save(policy.state_dict(), f"./checkpoints/bc_policy/bc_policy_static_2/bc_policy_static-{epoch:03d}")
        
train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtimlauffs[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/321 [00:00<?, ?it/s]


AttributeError: 'BCDataset' object has no attribute 'action_stats'