# Save Activations and the Input Data Used to Compute Them

In [1]:
import numpy as np
import torch
import os
import sys

sys.path.append('../')
from sample_batch_data import get_data_info, get_batch

sys.path.append('../../')
from decision_transformer.models.decision_transformer import DecisionTransformer

In [4]:
# Value stored in variant['pretrained_lm'] for each supported short model name.
# 'dt' maps to False, i.e. no pretrained language model is requested.
_PRETRAINED_LM_BY_MODEL = {
    'gpt2': 'gpt2',
    'clip': 'openai/clip-vit-base-patch32',
    'igpt': 'openai/imagegpt-small',
    'dt': False,
}

# Sub-modules (relative to each block in model.transformer.h) whose forward
# outputs are captured, in the order they are written to the saved dict.
# `mlp.act` is deliberately absent: it appears to be a plain function rather
# than an nn.Module, so it has no register_forward_hook.
_HOOKED_SUBMODULES = (
    'ln_1',
    'attn.c_attn',
    'attn.c_proj',
    'attn.resid_dropout',
    'ln_2',
    'mlp.c_fc',
    'mlp.c_proj',
    'mlp.dropout',
)


def _build_variant(model_name, env_name, dataset_name, seed, epoch, checkpoint_seed):
    """Return the config dict consumed by get_data_info/get_batch/DecisionTransformer."""
    if epoch == 0:
        # No checkpoint: keep whatever weights DecisionTransformer initialises.
        load_checkpoint = False
    else:
        load_checkpoint = (
            f'../checkpoints/{model_name}_{dataset_name}_{env_name}_{checkpoint_seed}'
            f'/model_{epoch}.pt'
        )
    return {
        'embed_dim': 768,
        'n_layer': 12,
        'n_head': 1,
        'activation_function': 'relu',
        'dropout': 0.2,
        'load_checkpoint': load_checkpoint,
        'seed': seed,
        'outdir': f"checkpoints/{model_name}_{dataset_name}_{env_name}_{seed}",
        'env': env_name,
        'dataset': dataset_name,
        'model_type': 'dt',
        'K': 20,
        'pct_traj': 1.0,
        'batch_size': 100,
        'num_eval_episodes': 100,
        'max_iters': 40,
        'num_steps_per_iter': 2500,
        'pretrained_lm': _PRETRAINED_LM_BY_MODEL[model_name],
        'gpt_kmeans': None,
        'kmeans_cache': None,
        'frozen': False,
        'extend_positions': False,
        'share_input_output_proj': True,
    }


def _capture_activations(model, states, actions, rewards, returns_to_go, timesteps, attention_mask):
    """Run one no-grad forward pass and return the hooked sub-module outputs.

    Keys are '<block_id>.<submodule>' in block-major order following
    _HOOKED_SUBMODULES; values are the detached output tensors. If a
    sub-module runs more than once, only its last output is kept.
    """
    captured = {}
    ordered_keys = []
    handles = []

    def make_hook(key):
        def hook(_module, _inputs, output):
            captured[key] = output.detach()
        return hook

    try:
        for block_id, block in enumerate(model.transformer.h):
            for name in _HOOKED_SUBMODULES:
                submodule = block
                for attr in name.split('.'):  # resolve dotted path, e.g. 'attn.c_attn'
                    submodule = getattr(submodule, attr)
                key = f'{block_id}.{name}'
                ordered_keys.append(key)
                handles.append(submodule.register_forward_hook(make_hook(key)))

        # Gradients are never needed here; no_grad avoids building the autograd graph.
        with torch.no_grad():
            model(
                states,
                actions,
                rewards,
                returns_to_go,
                timesteps,
                attention_mask=attention_mask,
            )
    finally:
        # Detach the hooks so the model is left as it was found.
        for handle in handles:
            handle.remove()

    # Re-order explicitly: hooks fire in execution order, not in saved-key order.
    return {key: captured[key] for key in ordered_keys}


def save_data_and_activation(
    seed=666,
    model_name='gpt2',
    epoch=40,
    env_name_list=('hopper', 'halfcheetah', 'walker2d'),
    dataset_name='medium',
    checkpoint_seed=666,
):
    """Run one sampled batch through a DecisionTransformer and save its activations.

    For every environment in ``env_name_list`` this does the following:

    * samples a batch with ``get_batch``;
    * builds a DecisionTransformer, loading the ``epoch`` checkpoint unless
      ``epoch == 0``;
    * records, for every transformer block, the outputs of the sub-modules
      listed in ``_HOOKED_SUBMODULES``;
    * writes two files with ``np.save``:

      - ``results/activation_{epoch}_{model_name}_{env_name}_{dataset_name}_{seed}_{batch_size}.npy``:
        a dict mapping ``'<block_id>.<submodule>'`` to a detached tensor;
      - ``data/data_{env_name}_{dataset_name}_{seed}_{batch_size}.npy``:
        a dict with the batch's ``states``, ``actions``, ``rtg``,
        ``timesteps`` and ``attention_mask``.

    Both files hold pickled dicts of tensors that stay on the compute
    device. Read them with ``np.load(path, allow_pickle=True).item()``.

    Args:
        seed: Passed to ``torch.manual_seed`` before each environment. It is
            stored in the config and used in the output file names.
        model_name: One of 'gpt2', 'clip', 'igpt', 'dt'.
        epoch: Checkpoint epoch to load; 0 means no checkpoint is loaded.
        env_name_list: Environments to process, in order.
        dataset_name: Dataset name used for sampling, the checkpoint path and
            the output file names.
        checkpoint_seed: Seed component of the checkpoint directory name.
            It is independent of ``seed``, so a different batch-sampling seed
            can reuse the same checkpoint.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    # Validate before doing any work; previously an unknown name surfaced as an
    # UnboundLocalError deep inside the loop.
    if model_name not in _PRETRAINED_LM_BY_MODEL:
        raise ValueError(
            f"unknown model_name {model_name!r}; expected one of "
            f"{sorted(_PRETRAINED_LM_BY_MODEL)}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for env_name in env_name_list:
        # NOTE(review): only torch's RNG is seeded; if get_batch samples with
        # numpy/random the batch is not reproducible — confirm.
        torch.manual_seed(seed)

        variant = _build_variant(model_name, env_name, dataset_name, seed, epoch, checkpoint_seed)

        state_dim, act_dim, max_ep_len, scale = get_data_info(variant)
        states, actions, rewards, dones, rtg, timesteps, attention_mask = get_batch(
            variant, state_dim, act_dim, max_ep_len, scale, device
        )

        data = {
            'states': states,
            'actions': actions,
            'rtg': rtg,
            'timesteps': timesteps,
            'attention_mask': attention_mask,
        }

        model = DecisionTransformer(
            args=variant,
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=variant["K"],
            max_ep_len=max_ep_len,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            n_head=variant["n_head"],
            n_inner=4 * variant["embed_dim"],
            activation_function=variant["activation_function"],
            n_positions=1024,
            resid_pdrop=variant["dropout"],
            attn_pdrop=0.1,
        )
        if variant["load_checkpoint"]:
            # map_location lets GPU-saved checkpoints load on CPU-only machines.
            state_dict = torch.load(variant["load_checkpoint"], map_location=device)
            model.load_state_dict(state_dict)
            print(f"Loaded from {variant['load_checkpoint']}")

        # Use the same device the batch was placed on (was hard-coded to 'cuda').
        model.to(device)
        model.eval()  # disable dropout so activations are deterministic

        # The last returns-to-go step is dropped. Presumably get_batch returns
        # one more rtg step than the other inputs; TODO confirm.
        activation_ordered = _capture_activations(
            model, states, actions, rewards, rtg[:, :-1], timesteps, attention_mask
        )

        batch_size = variant['batch_size']
        os.makedirs('results', exist_ok=True)
        os.makedirs('data', exist_ok=True)
        np.save(f'results/activation_{epoch}_{model_name}_{env_name}_{dataset_name}_{seed}_{batch_size}.npy', activation_ordered)
        np.save(f'data/data_{env_name}_{dataset_name}_{seed}_{batch_size}.npy', data)

In [1]:
# Hopper only: save activations of the 'gpt2' model at its epoch-40 checkpoint (seed 666).
save_data_and_activation(seed=666, model_name='gpt2', epoch=40, env_name_list=['hopper'])