In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from piqa import SSIM
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt
import os
import csv

from datasets.image_vae_dataset import ImageVAEDataset, ImageVAEDatasetFull
from datasets.bc_dataset import BCDataset
from models.visual_autoencoder import VisualAutoencoder
from models.bc_policy import LanguageConditionedPolicy

import config as CFG

  from .autonotebook import tqdm as notebook_tqdm


## VisualAutoencoder

dataloader

In [None]:
# Settings shared by every VAE DataLoader built in this cell.
# (num_workers=16 and prefetch_factor=6 are left disabled, as before.)
vae_loader_kwargs = dict(batch_size=CFG.batch_size, shuffle=True)

# Datasets built from a dataset root plus its language-annotation file.
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_training)
train_dataset = ImageVAEDataset(
    dataset_path=CFG.datapath_training,
    caption_path=caption_path_training,
)
train_loader = DataLoader(train_dataset, **vae_loader_kwargs)

caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_val)
val_dataset = ImageVAEDataset(
    dataset_path=CFG.datapath_val,
    caption_path=caption_path_val,
)
val_loader = DataLoader(val_dataset, **vae_loader_kwargs)

# "Full" datasets are built from the dataset root only (no caption file).
train_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_training)
train_loader_full = DataLoader(train_dataset_full, **vae_loader_kwargs)

val_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_val)
val_loader_full = DataLoader(val_dataset_full, **vae_loader_kwargs)

test dataloader

In [None]:
# Smoke test 1: pull a single collated batch from the DataLoader and report
# the shapes of both camera images.
for batch in train_loader:
    for field in ("img_static", "img_gripper"):
        print(f"{field}: ", getattr(batch, field).shape)
    break


# Smoke test 2: read a single un-batched sample straight from the dataset,
# report its shapes and display the static-camera image.
for sample in train_dataset:
    for field in ("img_static", "img_gripper"):
        print(f"{field}: ", getattr(sample, field).shape)
    # Axis 0 is moved last for imshow (presumably channel-first -> channel-last).
    plt.imshow(np.transpose(sample.img_static, (1, 2, 0)))
    plt.show()
    break

"""
for idx, batch in enumerate(train_loader_full):
    print("img_static: ", batch.img_static.shape)
    print("img_gripper: ", batch.img_gripper.shape)
    break
"""

### Train VAE model

In [None]:
# Hyperparameters for the autoencoder training run in this cell.
encoding_size = 512  # constructor argument passed to VisualAutoencoder(...)
learning_rate = 0.0001  # learning rate handed to the Adam optimizer

class SSIMLoss(SSIM):
    """Loss variant of ``SSIM``: ``forward`` returns ``1 - SSIM(x, y)``."""

    def forward(self, x, y):
        """Return one minus the parent class's ``forward(x, y)`` result."""
        similarity = super().forward(x, y)
        return 1. - similarity

def visualize_reconstruction(original, reconstructed):
    """Build a side-by-side figure of an input image and its reconstruction.

    Args:
        original: Image tensor; axis 0 is moved last before plotting
            (assumed to be channel-first, i.e. (C, H, W) -- TODO confirm).
        reconstructed: Model output with the same layout; it is detached
            before plotting.

    Returns:
        The matplotlib figure. It is closed before returning so it is not
        rendered inline, but it can still be passed on (e.g. to wandb.Image).
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        ('Original', original),
        ('Reconstructed', reconstructed.detach()),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(title)
    plt.close()
    return fig


def train():
    """Train the autoencoder on gripper-camera images with an MSE loss.

    Relies on notebook globals that must exist before the call: ``model``,
    ``optimizer``, ``criterion``, ``train_loader_full``,
    ``visualize_reconstruction`` and ``CFG`` (``epochs`` and ``device``).

    Side effects:
        * Logs "Total Loss" and "Reconstruction Loss" to wandb on every step.
        * Every 2000 steps (including step 0) logs a reconstruction figure
          and writes a numbered checkpoint of ``model.state_dict()``.
    """
    checkpoint_dir = "checkpoints/image_vae/image_vae_gripper"
    # torch.save does not create missing parent directories; without this the
    # very first checkpoint (step 0) fails on a fresh checkout.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The name `model` is reused by the validation cell, which switches it to
    # eval mode; make the training mode explicit.
    model.train()

    step = 0
    checkpoint_step = 0
    for epoch in range(CFG.epochs):
        print(f"Epoch {epoch+1}/{CFG.epochs}")

        for data in tqdm(train_loader_full):
            # Only the gripper camera is trained here
            # (swap in data['img_static'] for the static camera).
            inputs = data['img_gripper'].to(CFG.device)

            optimizer.zero_grad()
            outputs = model(inputs)

            reconstruction_loss = criterion(outputs, inputs)
            # The SSIM term is currently disabled:
            # ssim_loss = SSIM_criterion(outputs, inputs) * 0.001
            loss = reconstruction_loss  # + ssim_loss

            loss.backward()
            optimizer.step()

            # One log call per step; both keys share the same step index.
            wandb.log(
                {
                    "Total Loss": loss.item(),
                    "Reconstruction Loss": reconstruction_loss.item(),
                },
                step=step,
            )

            if step % 2000 == 0:
                figure = visualize_reconstruction(inputs[0].detach().cpu(), outputs[0].detach().cpu())
                wandb.log({"Training Images": wandb.Image(figure)}, step=step)
                torch.save(
                    model.state_dict(),
                    f"{checkpoint_dir}/image_vae_gripper_high_res-{checkpoint_step:03d}.pt",
                )
                checkpoint_step += 1

            step += 1

# Set up the wandb run, model, optimizer and losses, then start training.
wandb.init(project="ImageVAE")
model = VisualAutoencoder(encoding_size)
model.to(CFG.device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# Placed on the same device as the model; a hard-coded .cuda() would crash
# whenever CFG.device is not a CUDA device.
SSIM_criterion = SSIMLoss().to(CFG.device)

train()

validate image vae

In [None]:
def visualize_reconstruction(original, reconstructed):
    """Show an input image next to its reconstruction and return the figure.

    Args:
        original: Image tensor; axis 0 is moved last before plotting
            (assumed to be channel-first, i.e. (C, H, W) -- TODO confirm).
        reconstructed: Model output with the same layout; it is detached
            before plotting.

    Returns:
        The matplotlib figure, after it has been shown and closed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        ('Original', original),
        ('Reconstructed', reconstructed.detach()),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(title)
    plt.show()
    plt.close()
    return fig

def reconstruct(model, num_images=6):
    """Plot reconstructions for the first batch of ``val_loader_full``.

    Args:
        model: Autoencoder called as ``model(inputs)`` on the static-camera
            images of the batch.
        num_images: Maximum number of samples to visualise. Defaults to 6;
            fewer are shown when the batch is smaller.

    Relies on the notebook globals ``val_loader_full``, ``CFG.device`` and
    ``visualize_reconstruction``. Does nothing if the loader is empty.
    """
    data = next(iter(val_loader_full), None)
    if data is None:
        return

    inputs = data['img_static'].to(CFG.device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(inputs)

    # Clamp to the batch size so a small batch cannot raise IndexError.
    for sample_idx in range(min(num_images, len(inputs))):
        visualize_reconstruction(inputs[sample_idx].detach().cpu(), outputs[sample_idx].detach().cpu())
    
# Rebuild the autoencoder, restore the static-camera checkpoint and plot a
# few validation reconstructions.
encoding_size = 512
model = VisualAutoencoder(encoding_size).to(CFG.device)
state_dict = torch.load(
    'checkpoints/image_vae/image_vae_static_full_long/image_vae_static_full_high_res-045.pt',
    map_location=CFG.device,
)
model.load_state_dict(state_dict)
model.eval()  # eval() returns the module itself, so `model` is unchanged
reconstruct(model)


## Behavior Cloning Policy

dataloader

In [2]:
# Batch size for the behavior-cloning loaders (independent of CFG.batch_size).
batch_size = 32
bc_loader_kwargs = dict(batch_size=batch_size, shuffle=True)

# Training split: dataset root plus its language-annotation file.
caption_path_training = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_training)
train_dataset_bc = BCDataset(
    data_path=CFG.datapath_training,
    caption_path=caption_path_training,
)
train_loader_bc = DataLoader(train_dataset_bc, **bc_loader_kwargs)

# Validation split, built the same way.
caption_path_val = '{}/lang_annotations/auto_lang_ann.npy'.format(CFG.datapath_val)
val_dataset_bc = BCDataset(
    data_path=CFG.datapath_val,
    caption_path=caption_path_val,
)
val_loader_bc = DataLoader(val_dataset_bc, **bc_loader_kwargs)

test dataloader bc policy

In [3]:
# Smoke test: pull one batch from the BC loader, then report every field's
# shape (plus the first robot_obs row) followed by every field's dtype.
for batch in train_loader_bc:
    for field in ("img_static", "img_gripper", "robot_obs"):
        print(f"{field}: ", getattr(batch, field).shape)
    print("robot_obs: ", batch.robot_obs[0])
    for field in ("action", "rel_action", "text_encoding"):
        print(f"{field}: ", getattr(batch, field).shape)

    dtype_fields = ("img_static", "img_gripper", "action", "rel_action", "robot_obs", "text_encoding")
    for field in dtype_fields:
        print(f"{field}: ", getattr(batch, field).dtype)
    break


img_static:  torch.Size([32, 3, 180, 180])
img_gripper:  torch.Size([32, 3, 180, 180])
robot_obs:  torch.Size([32, 15])
robot_obs:  tensor([ 0.1724, -0.3310,  0.5218, -3.1360, -0.1580,  1.3245,  0.0800, -1.2581,
         1.1439,  1.8302, -2.1316, -1.0079,  1.7368,  0.3788,  1.0000])
action:  torch.Size([32, 7])
rel_action:  torch.Size([32, 7])
text_encoding:  torch.Size([32, 512])
img_static:  torch.float32
img_gripper:  torch.float32
action:  torch.float32
rel_action:  torch.float32
robot_obs:  torch.float32
text_encoding:  torch.float16


### Train BC Policy

Compute the mean and std of the robot actions (used in bc_dataset.py).

In [None]:
# Mean and std of the "actions" entry over every loadable file in the
# training directory.

def _load_actions(file_path):
    """Return the "actions" array stored in ``file_path`` and release the file."""
    data = np.load(file_path)
    try:
        return np.asarray(data["actions"])
    finally:
        # For .npz archives np.load returns an NpzFile that keeps the file
        # open; close it so a large directory does not exhaust file handles.
        close = getattr(data, "close", None)
        if close is not None:
            close()


datafiles = os.listdir(CFG.datapath_training)

all_actions = []
for file in tqdm(datafiles):
    file_path = os.path.join(CFG.datapath_training, file)
    try:
        all_actions.append(_load_actions(file_path))
    except Exception as exc:
        # Best effort: sub-directories, non-numpy files and files without an
        # "actions" entry are skipped. Catching Exception (not a bare except)
        # lets Ctrl-C still interrupt the loop, and the reason is reported.
        print(f"Skipping '{file}': {exc!r}")

# Stack once and reuse for both statistics.
actions = np.array(all_actions)
mean = np.mean(actions, axis=0)
std = np.std(actions, axis=0)

print("mean: ", mean)
print("std: ", std)


In [3]:
# Start the wandb run used by the behavior-cloning training below.
wandb.init(project="skillGPT_BC")
device = CFG.device

# Hyperparameters read as globals by train().
lr = 0.0001  # Adam learning rate
num_epochs = 1000  # number of passes over train_loader_bc
language_dim = 512  # first constructor argument of LanguageConditionedPolicy
image_dim = 512  # NOTE(review): not referenced by the code visible in this notebook
action_dim = 7  # second constructor argument of LanguageConditionedPolicy

def _load_image_encoder(checkpoint_path):
    """Build a VisualAutoencoder, restore ``checkpoint_path`` and return its trainable encoder."""
    autoencoder = VisualAutoencoder(512).to(device)
    # map_location keeps loading working when the checkpoint was saved on a
    # different device than the one in use now.
    autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    encoder = autoencoder.encoder
    # The encoder is fine-tuned together with the policy.
    for param in encoder.parameters():
        param.requires_grad = True
    return encoder


def _bc_batch_loss(data, policy, image_encoder_static, image_encoder_gripper, criterion):
    """Run the encoders and the policy on one batch and return the action loss."""
    # The text encoding arrives as float16 (see the dataloader check) and is
    # cast to float32 before it is fed to the policy.
    language_emb = data["text_encoding"].to(device).to(torch.float)
    action = data["action"].to(device)  # alternative target: data["rel_action"]
    robot_obs = data["robot_obs"].to(device)

    image_features_static = image_encoder_static(data["img_static"].to(device))
    image_features_gripper = image_encoder_gripper(data["img_gripper"].to(device))

    pred_action = policy(language_emb, image_features_static, image_features_gripper, robot_obs)
    return criterion(pred_action, action)


def train():
    """Train the language-conditioned BC policy and fine-tune both image encoders.

    Relies on notebook globals: ``device``, ``lr``, ``num_epochs``,
    ``language_dim``, ``action_dim``, ``train_loader_bc`` and ``val_loader_bc``.

    Side effects:
        * Logs ``train_loss`` every step and ``val_loss`` every epoch to wandb.
        * After every epoch writes ``bc_policy-XXX.pt`` (unchanged path) and,
          because the encoders are optimized as well, the matching
          ``image_encoder_static-XXX.pt`` and ``image_encoder_gripper-XXX.pt``.
    """
    criterion = nn.MSELoss()

    policy = LanguageConditionedPolicy(language_dim, action_dim).to(device)
    image_encoder_static = _load_image_encoder(
        "./checkpoints/image_vae/image_vae_static_full_long/image_vae_static_full_high_res-045.pt"
    )
    image_encoder_gripper = _load_image_encoder(
        "./checkpoints/image_vae/image_vae_gripper/image_vae_gripper_high_res-017.pt"
    )
    trainable_modules = (policy, image_encoder_static, image_encoder_gripper)

    # A single optimizer updates the policy and both encoders.
    params = [param for module in trainable_modules for param in module.parameters()]
    optimizer = optim.Adam(params, lr=lr)

    checkpoint_dir = "./checkpoints/bc_policy/bc_static_gripper_obs"
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    step = 0
    for epoch in range(num_epochs):
        # The encoders are trained too, so they must follow the policy's mode;
        # otherwise they would stay in training mode during validation.
        for module in trainable_modules:
            module.train()

        epoch_train_loss = 0.0
        for data in tqdm(train_loader_bc):
            optimizer.zero_grad()
            loss = _bc_batch_loss(data, policy, image_encoder_static, image_encoder_gripper, criterion)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

            wandb.log({"train_loss": loss.item()}, step=step)
            step += 1

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_epoch_train_loss = epoch_train_loss / max(len(train_loader_bc), 1)
        print(f"Epoch {epoch+1} Training Loss: {avg_epoch_train_loss:.4f}")

        # Validation: every module in eval mode, no gradients.
        for module in trainable_modules:
            module.eval()

        epoch_val_loss = 0.0
        with torch.no_grad():
            for data in tqdm(val_loader_bc):
                loss = _bc_batch_loss(data, policy, image_encoder_static, image_encoder_gripper, criterion)
                epoch_val_loss += loss.item()

        avg_epoch_val_loss = epoch_val_loss / max(len(val_loader_bc), 1)
        print(f"Epoch {epoch+1} Validation Loss: {avg_epoch_val_loss:.8f}")
        wandb.log({"val_loss": avg_epoch_val_loss}, step=step)

        # The policy checkpoint keeps its original path and format. The
        # fine-tuned encoders are saved next to it, since the policy was
        # trained against their updated weights.
        torch.save(policy.state_dict(), f"{checkpoint_dir}/bc_policy-{epoch:03d}.pt")
        torch.save(image_encoder_static.state_dict(), f"{checkpoint_dir}/image_encoder_static-{epoch:03d}.pt")
        torch.save(image_encoder_gripper.state_dict(), f"{checkpoint_dir}/image_encoder_gripper-{epoch:03d}.pt")
        
# Start BC training: runs num_epochs epochs and saves a policy checkpoint after each one.
train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtimlauffs[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 161/161 [00:46<00:00,  3.43it/s]


Epoch 1 Training Loss: 0.0785


100%|██████████| 32/32 [00:07<00:00,  4.04it/s]


Epoch 1 Validation Loss: 0.06364228


100%|██████████| 161/161 [00:43<00:00,  3.68it/s]


Epoch 2 Training Loss: 0.0613


100%|██████████| 32/32 [00:07<00:00,  4.05it/s]


Epoch 2 Validation Loss: 0.06202749


100%|██████████| 161/161 [00:44<00:00,  3.65it/s]


Epoch 3 Training Loss: 0.0557


100%|██████████| 32/32 [00:08<00:00,  3.98it/s]


Epoch 3 Validation Loss: 0.05844997


 22%|██▏       | 35/161 [00:09<00:35,  3.51it/s]


KeyboardInterrupt: 