In [1]:
#python3.7 
# install dependencies with requirements.txt (included in zip)

# To do experiments for Pong, or any other ALE game
# just change env_id
# e.g. for Pong, env_id = "PongNoFrameskip-v4"

In [None]:
from models import *
from train_utils import *
from pytorchtools import EarlyStopping

import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from dagger import *
#from training import *
import cv2
from visualization import *
import pandas as pd

# import EarlyStopping
from pytorchtools import EarlyStopping

import sys, os
from copy import deepcopy

# Pick the compute device. `torch.cuda.is_available` must be *called*: the bare
# function object is always truthy, which would force "cuda:0" on CPU-only hosts.
if torch.cuda.is_available():
        device = "cuda:0"
        print('Using GPU')
else:
        device = "cpu"
        print('using CPU')

In [None]:
# Re-import edited project modules (models, train_utils, dagger, ...) automatically
# before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
#@title get_data
# get dataset of state-action pairs from expert
from skimage.transform import resize
from PIL import Image
def gen_color_data(num_interactions=int(6e4), env_id="PongNoFrameskip-v4", preprocess=False):
    """Roll out the PPO expert and record what it saw and did.

    Args:
        num_interactions: number of environment steps to record.
        env_id: environment id passed to `get_env_and_model`.
        preprocess: if True, crop (`crop_pong`) and resize each observation to
            84x84x4 before storing it.

    Returns:
        (expert_observations, color_observations, expert_actions, episode_schedule)
        where color_observations holds 84x84 RGB renders and
        episode_schedule[step] == [episode number, step].
    """
    env, ppo_expert = get_env_and_model(env_id)

    if env_id == 'CartPole-v1':
        # Render once up front for CartPole; the frame itself is not used.
        env.render(mode='rgb_array')

    print('state shape: ', env.observation_space.shape)
    print('action shape: ', env.action_space.shape)

    n = num_interactions
    if isinstance(env.action_space, gym.spaces.Box):
        # Continuous actions: keep raw observations and flat action vectors.
        expert_observations = np.empty((n,) + env.observation_space.shape)
        expert_actions = np.empty((n,) + (env.action_space.shape[0],))
    else:
        # Other action spaces: observations stored channels-first as 4x84x84
        # (assumed stacked Atari frames — TODO confirm for non-Atari envs).
        expert_observations = np.empty((n, 4, 84, 84))
        expert_actions = np.empty((n,) + env.action_space.shape)

    episode_schedule = np.empty((n, 2))
    color_observations = np.empty((n, 84, 84, 3))

    obs = env.reset()
    episode = 0
    for step in tqdm(range(n)):
        action, _ = ppo_expert.predict(obs, deterministic=True)
        if preprocess:
            cropped = crop_pong(obs)[0]
            obs = np.expand_dims(resize(cropped, (84, 84, 4)), 0)

        expert_observations[step] = obs.transpose(0, 3, 1, 2)
        thumbnail = Image.fromarray(env.render(mode='rgb_array')).resize(
            size=(84, 84), resample=Image.BICUBIC, reducing_gap=3.0)
        color_observations[step] = np.array(thumbnail)

        expert_actions[step] = action
        episode_schedule[step] = np.array([episode, step])

        obs, reward, done, info = env.step(action)
        if done:
            episode += 1
            obs = env.reset()

    env.close()
    return expert_observations, color_observations, expert_actions, episode_schedule

In [None]:
def dataset_with_indices(cls):
    """Build a subclass of `cls` whose items are ``(data, target, index)``.

    The base class's ``__getitem__`` must return exactly ``(data, target)``;
    the new class keeps the base class's name.
    """
    def _getitem_with_index(self, index):
        data, target = cls.__getitem__(self, index)
        return (data, target, index)

    _getitem_with_index.__name__ = '__getitem__'
    namespace = {'__getitem__': _getitem_with_index}
    return type(cls.__name__, (cls,), namespace)

In [None]:
# True: roll out the expert and write 'seaquest_holdout.npz'; False: load the cached 'seaqest.npz'.
new_data = False

In [None]:
if new_data:
    # NOTE(review): this branch records 1e4 steps into the *holdout* file, while the
    # else-branch loads the training file 'seaqest.npz' (no 'u'). Confirm both names.
    expert_observations, color_observations, expert_actions, episode_schedule = gen_color_data(num_interactions=int(1e4), env_id="SeaquestNoFrameskip-v4", preprocess=False)
    np.savez_compressed(
        'seaquest_holdout.npz',
        expert_actions=expert_actions,
        color_observations=color_observations,
        expert_observations=expert_observations,
        episode_schedule=episode_schedule,
    )
else:
    # np.load on an .npz returns an open NpzFile; the context manager closes the
    # archive once the arrays have been read into memory.
    with np.load('seaqest.npz') as arrs:
        expert_observations = arrs['expert_observations']
        color_observations = arrs['color_observations']
        expert_actions = arrs['expert_actions']
        episode_schedule = arrs['episode_schedule']

In [None]:
# Dataset over all loaded expert (observation, action) pairs; split later for downstream training.
expert_dataset = ExpertDataset(expert_observations, expert_actions)

In [None]:
# Per-sample episode number and step, indexed by dataset position; the
# quadruplet mask looks them up with the indices returned by `dset`.
episode_labels, step_labels = (
    torch.FloatTensor(episode_schedule[:, col]).to(device) for col in (0, 1)
)
DatasetWithInDices = dataset_with_indices(ExpertDataset)
dset = DatasetWithInDices(expert_observations, expert_actions)

In [None]:
# Held-out expert rollouts used for final accuracy (eval_loader). The context
# manager closes the .npz archive after the arrays are read.
with np.load('seaquest_holdout.npz') as holdout_arrs:
    holdout_expert_observations = holdout_arrs['expert_observations']
    holdout_color_observations = holdout_arrs['color_observations']
    holdout_expert_actions = holdout_arrs['expert_actions']
    holdout_episode_schedule = holdout_arrs['episode_schedule']

holdout_expert_dataset = ExpertDataset(holdout_expert_observations, holdout_expert_actions)
eval_loader = th.utils.data.DataLoader(
    dataset=holdout_expert_dataset, batch_size=128, shuffle=False
)

# Encoder Pretraining

In [None]:
# dataset for encoder pretraining
# RUN THIS FOR PRE-TRAINING

# The test split is the remainder so the lengths always sum to len(dset);
# two independent int() truncations can drop a sample and make random_split raise.
train_size = int(0.8 * len(dset))
test_size = len(dset) - train_size

train_expert_dataset, test_expert_dataset = random_split(
        dset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

kwargs = {"num_workers": 8, "pin_memory": False}
train_loader = th.utils.data.DataLoader(
        dataset=train_expert_dataset, batch_size=64, shuffle=True, **kwargs
)

test_loader = th.utils.data.DataLoader(
        dataset=test_expert_dataset, batch_size=64, shuffle=True, **kwargs,
)
    

In [None]:
# Unshuffled loader over the whole expert dataset.
# NOTE(review): not referenced by the visible cells (train_iso pushes with
# train_loader) — confirm it is still needed.
push_loader = th.utils.data.DataLoader(
    dataset=expert_dataset, batch_size=128, shuffle=False
)

In [None]:
def plot_ae_outputs(encoder,decoder,n=10, loader=None):
    """Plot `n` inputs (top row) above their reconstructions decoder(encoder(x)).

    Args:
        encoder, decoder: modules to reconstruct with.
        n: number of samples to show; must not exceed the loader's batch size.
        loader: loader to draw one batch from; defaults to the global `test_loader`.
            Batches may be (x, y) or (x, y, idx); only x is used.

    Only stacked frame 3 of each observation is drawn (drawing all frames with
    opaque imshow calls on one axes leaves only the last one visible).
    Both modules are returned to their previous train/eval mode afterwards.
    """
    if loader is None:
        loader = test_loader
    # Index the batch instead of tuple-unpacking so 2- and 3-tuples both work
    # without creating a second iterator (and worker pool) on failure.
    imgs = next(iter(loader))[0]

    enc_was_training, dec_was_training = encoder.training, decoder.training
    encoder.eval()
    decoder.eval()

    plt.figure(figsize=(16,4.5))
    for i in range(n):
        img = imgs[i].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        # Top row: input frame.
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(img.cpu().squeeze().numpy()[3], cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Original images')

        # Bottom row: reconstructed frame.
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy()[3], cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2:
            ax.set_title('Reconstructed images')
    plt.show()

    encoder.train(enc_was_training)
    decoder.train(dec_was_training)

In [None]:
#https://raw.githubusercontent.com/lyakaap/NetVLAD-pytorch/master/hard_triplet_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class HardQuadrupletLoss(nn.Module):
    """Quadruplet metric-learning loss over a batch of embeddings.

    Adapted from a hard-triplet-loss implementation
    (https://omoindrot.github.io/triplet-loss). Only the batch-all variant
    (``hardest=False``) is implemented; ``hardest=True`` raises NotImplementedError
    from `forward`.
    """
    def __init__(self, margin1=0.1, margin2=.05, hardest=False, squared=False, epsilon=10, tau=11):
        """
        Args:
            margin1: margin of the D[a, b] - D[a, c] term.
            margin2: margin of the D[a, b] - D[b, c] term.
            hardest: request hardest-example mining (not implemented).
            squared: if True use squared euclidean distances, else euclidean.
            epsilon, tau: stored only. NOTE(review): `forward` does not pass them
                to `_get_quadruplet_mask`, which uses its own step-window defaults.
        """
        super(HardQuadrupletLoss, self).__init__()
        self.margin1 = margin1
        self.margin2 = margin2
        self.hardest = hardest
        self.squared = squared
        self.epsilon = epsilon
        self.tau = tau

    def forward(self, embeddings, labels, idx):
        """
        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            labels: labels of the batch, of size (batch_size,)
            idx: dataset indices of the batch samples, forwarded to `_get_quadruplet_mask`.

        Returns:
            scalar tensor: sum of the masked losses divided by the number of
            strictly positive masked entries.

        Raises:
            NotImplementedError: if the loss was built with ``hardest=True``.
        """
        if self.hardest:
            # The hardest-mining branch never defined its second-negative distance
            # and always failed with NameError; fail explicitly instead.
            raise NotImplementedError(
                "HardQuadrupletLoss(hardest=True) is not implemented; use hardest=False.")

        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)

        anc_pos_dist = pairwise_dist.unsqueeze(dim=2)   # [B, B, 1] -> D[a, b]
        anc_neg_dist = pairwise_dist.unsqueeze(dim=1)   # [B, 1, B] -> D[a, c]
        anc_neg2_dist = pairwise_dist.unsqueeze(dim=0)  # [1, B, B] -> D[b, c]

        # loss[a, b, c] = relu(D[a,b] - D[a,c] + m1) + relu(D[a,b] - D[b,c] + m2)
        loss = F.relu(anc_pos_dist - anc_neg_dist + self.margin1)
        loss += F.relu(anc_pos_dist - anc_neg2_dist + self.margin2)

        # NOTE(review): the mask is [B, B, B, B]; broadcasting aligns this 3-D loss
        # with the mask's last three axes (j, k, l) — confirm that is intended.
        mask = _get_quadruplet_mask(labels, idx).float()
        quadruplet_loss = loss * mask

        # Average only over quadruplets with a positive loss.
        hard_quadruplets = torch.gt(quadruplet_loss, 1e-16).float()
        num_hard_quadruplets = torch.sum(hard_quadruplets)

        quadruplet_loss = torch.sum(quadruplet_loss) / (num_hard_quadruplets + 1e-16)

        return quadruplet_loss



def _pairwise_distance(x, squared=False, eps=1e-16):
    # Compute the 2D matrix of distances between all the embeddings.

    cor_mat = torch.matmul(x, x.t())
    norm_mat = cor_mat.diag()
    distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
    distances = F.relu(distances)

    if not squared:
        mask = torch.eq(distances, 0.0).float()
        distances = distances + mask * eps
        distances = torch.sqrt(distances)
        distances = distances * (1.0 - mask)

    return distances


def _get_anchor_positive_triplet_mask(labels):
    # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1

    # Check if labels[i] == labels[j]
    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)

    mask = indices_not_equal * labels_equal

    return mask


def _get_anchor_negative_triplet_mask(labels):
    # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.

    # Check if labels[i] != labels[k]
    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
    mask = labels_equal ^ 1

    return mask


def _get_quadruplet_mask(
        labels,
        idx
    ):
    B = labels.size(0)

    # Make sure that i != j != k != l
    indices_equal = torch.eye(B, dtype=torch.bool).to(device)  # [B, B] 
    indices_not_equal = ~indices_equal  # [B, B] 
    i_not_equal_j = indices_not_equal.view(B, B, 1, 1)  # [B, B, 1, 1]
    j_not_equal_k = indices_not_equal.view(1, B, B, 1)  # [B, 1, 1, B] 
    k_not_equal_l = indices_not_equal.view(1, 1, B, B)  # [1, 1, B, B] 
    distinct_indices = i_not_equal_j & j_not_equal_k & k_not_equal_l  # [B, B, B, B] 

    # Make sure that labels[i] == labels[j] 
    #            and labels[j] != labels[k] 
    #            and labels[k] != labels[l]
    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]
    i_equal_j = labels_equal.view(B, B, 1, 1)  # [B, B, 1, 1]
    j_equal_k = labels_equal.view(1, B, B, 1)  # [1, B, B, 1]
    l_equal_i = labels_equal.view(B, 1, 1, B)  # [1, 1, B, B]
    label_match = i_equal_j & ~j_equal_k #& ~l_equal_i
    
    eps = 15#self.epsilon
    tau = 15#self.tau
    
    ep_labels = episode_labels[idx]
    lst = step_labels[idx] 
    
    ep_labels_equal = ep_labels.view(1, B) == ep_labels.view(B, 1)  # [B, B]
    i_equal_j_ep = labels_equal.view(B, B, 1, 1)  # [B, B, 1, 1]
    j_equal_k_ep = labels_equal.view(1, B, B, 1)  # [1, B, B, 1]
    k_equal_l_ep = labels_equal.view(1, 1, B, B)  # [1, 1, B, B]
    episode_match = i_equal_j_ep & j_equal_k_ep & k_equal_l_ep
    
    # uncomment this to keep quadruplets in the same episode
    '''
    within_eps_over_tau = torch.logical_and(torch.abs(lst[:, None, None, None] - lst[None, :, None, None]) <= eps,
                                            torch.abs(lst[:,None,None,None]-lst[None,None,:,None]) <= eps,
                                            torch.abs(lst[:, None, None,None] - lst[None, None,None,:]) >= tau)
    '''
    within_eps_over_tau = (torch.abs(lst[:, None, None, None] - lst[None, :, None, None]) <= eps) & (torch.abs(lst[:,None,None,None]-lst[None,None,:,None] <= eps)) & (torch.abs(lst[:, None, None,None] - lst[None, None,None,:]) >= tau)
    #return within_eps_over_tau & episode_match & label_match & distinct_indices  # [B, B, B, B] 
    return within_eps_over_tau & label_match & distinct_indices


In [None]:
### Training function
def train_siamese_epoch(vae, device, dataloader, optimizer, criterion,beta = 1., lbda = .1):
    """Train `vae` for one epoch on reconstruction + beta*KL + lbda*metric loss.

    Args:
        vae: model returning ``(x_hat, z)`` and exposing ``vae.encoder.kl``.
        device: device the batches are moved to.
        dataloader: yields ``(x, actions, idx)``; the actions are the labels and
            idx the dataset indices passed to `criterion`.
        optimizer: stepped once per batch.
        criterion: called as ``criterion(embeddings=z, labels=actions, idx=idx)``.
        beta: weight of the KL term.
        lbda: weight of the metric (triplet/quadruplet) loss.

    Returns:
        Sum of per-batch combined losses divided by the number of samples.
    """
    vae.train()
    combined_sum = 0.0
    recon_sum = 0.0
    metric_sum = 0.0

    for x, actions, idx in dataloader:
        x, actions, idx = x.to(device), actions.to(device), idx.to(device)

        x_hat, z = vae(x)
        recon_plus_kl = ((x - x_hat)**2).sum() + beta * vae.encoder.kl
        metric = criterion(embeddings = z, labels = actions, idx = idx)
        combined = recon_plus_kl + lbda * metric

        optimizer.zero_grad()
        combined.backward()
        optimizer.step()

        combined_sum += combined.item()
        recon_sum += recon_plus_kl.item()
        metric_sum += metric.item()

    n_samples = len(dataloader.dataset)
    print('Avg VAE/SIAM loss: ', recon_sum / n_samples, metric_sum / n_samples)

    return combined_sum / n_samples

In [None]:
### Set the random seed for reproducible results
torch.manual_seed(0)

# Latent dimension of the VAE.
d = 7

vae = VAE(z_dim = d,nc=4)
lr = 1e-4

optimizer = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-7)

# NOTE(review): rebinds the global `device` (a string from the first cell) to a
# torch.device; later cells use this value.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

vae.to(device)

# Quadruplet loss combined with the VAE loss during pretraining.
# NOTE(review): `epsilon` is stored on the loss but not used by its mask.
criterion = HardQuadrupletLoss(margin1=2., margin2=2.5,squared=True,hardest=False,epsilon=10)

In [None]:
# Pretrain the VAE with the quadruplet loss weighted by lbda=1000 and KL weight
# beta=1.5, plotting reconstructions after every epoch.
num_epochs = 1

for epoch in tqdm(range(num_epochs)):
   train_loss = train_siamese_epoch(vae,device,train_loader,optimizer, criterion,beta =1.5,lbda=1000)
   #val_loss = test_epoch(vae,device,test_loader, beta = 1.5)
   #print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))
   plot_ae_outputs(vae.encoder,vae.decoder,n=10)
   print('epoch, train loss: ', epoch, train_loss)

In [None]:
        # Action index -> name for the full 18-action ALE set, in ALE's order
        # (same order as env.unwrapped.get_action_meanings() for Seaquest).
        # The previous table swapped LEFT/RIGHT and misordered indices 6-17,
        # so every plot label derived from it was wrong.
        # NOTE(review): the leading indent relies on IPython stripping a cell's
        # common indentation; it is not valid as plain Python.
        acts = {0:'NOOP',
                1:'FIRE',
                2:'UP',
                3:'RIGHT',
                4:'LEFT',
                5:'DOWN',
                6:'UP-RIGHT',
                7:'UP-LEFT',
                8:'DOWN-RIGHT',
                9:'DOWN-LEFT',
                10:'UP-FIRE',
                11:'RIGHT-FIRE',
                12:'LEFT-FIRE',
                13:'DOWN-FIRE',
                14:'UP-RIGHT-FIRE',
                15:'UP-LEFT-FIRE',
                16:'DOWN-RIGHT-FIRE',
                17:'DOWN-LEFT-FIRE'
                }

In [None]:
# Pickle the whole pretrained VAE module (not just its state_dict).
torch.save(vae, 'seaquest_svae.pt')

In [None]:
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# map_location lets a GPU-saved model load on the current device.
vae = torch.load('seaquest_svae.pt', map_location=device)
vae.eval()
encoded_samples = []
with torch.no_grad():
    for sample in tqdm(test_expert_dataset):
        img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)
        label = sample[1]
        # Encode image and flatten the latent into one row of the frame.
        encoded_img = vae.encoder(img).flatten().cpu().numpy()
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        # int() so float, numpy and 0-d tensor labels all hit the integer keys of `acts`.
        encoded_sample['label'] = acts[int(label)]
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)


from sklearn.manifold import TSNE
import plotly.express as px

px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)

In [None]:
#beta = 1.5
# Fixed random_state so the t-SNE layout is reproducible from run to run.
tsne = TSNE(n_components=2, random_state=0)
tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))

fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig.show()

## DOWNSTREAM TRAINING

In [None]:
# dataset for downstream 
# RUN THIS FOR DOWNSTREAM TRAINING

# Remainder-based test size keeps the lengths summing to len(expert_dataset).
train_size = int(0.8 * len(expert_dataset))
test_size = len(expert_dataset) - train_size

# Same lengths + same seed as the pretraining split => identical indices, so the
# encoder (pretrained with action labels through the quadruplet loss) has not
# seen the downstream test samples.
train_expert_dataset, test_expert_dataset = random_split(
    expert_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

kwargs = {"num_workers": 8, "pin_memory": False}
train_loader = th.utils.data.DataLoader(
        dataset=train_expert_dataset, batch_size=32, shuffle=False, **kwargs
)
test_loader = th.utils.data.DataLoader(
    dataset=test_expert_dataset, batch_size=32, shuffle=False, **kwargs
)
  

In [None]:
def train_ds_epoch(net, 
          train_loader, 
          optimizer,
          iso_coeff=10,
          clst_coeff =0.,
          sep_coeff = 0.,
          rep_coeff = 0.,
          validation = False,
          resnetbc=False
         ):
    """Run one epoch of downstream (behaviour-cloning) training or validation.

    Prototype network (resnetbc=False) loss:
        CE + iso_coeff*||A^T A - I||_F^2 + clst_coeff*clst - sep_coeff*sep + rep_coeff*rep
    with A = net.isometry.weight. Plain classifier (resnetbc=True): CE only.

    Args:
        net: model; prototype networks return ``(logits, min_distances)``.
        train_loader: iterable of ``(states, actions)`` batches.
        optimizer: stepped once per batch unless `validation`.
        iso_coeff, clst_coeff, sep_coeff, rep_coeff: loss-term weights.
        validation: if True, no gradients are tracked and no step is taken.
        resnetbc: True for a plain classifier returning only logits.

    Returns:
        (running_CE, running_iso, running_clst, running_sep, running_rep): each term
        summed over batches (only running_CE is filled when resnetbc=True).

    Uses the global `device`.
    """
    running_CE = 0.
    running_iso = 0.
    running_clst= 0.
    running_sep = 0.
    running_rep = 0.

    if not resnetbc:
        I = torch.eye(net.isometry.weight.shape[0]).to(device)
        max_dist = net.prototype_shape[1]

    criterion = nn.CrossEntropyLoss()

    correct = 0
    total = 0
    # Validation needs no autograd graph.
    with torch.set_grad_enabled(not validation):
        for (states, actions) in train_loader:
            states = states.to(device)
            actions = actions.to(device)

            if not validation:
                optimizer.zero_grad()

            if not resnetbc:
                logits, min_distances = net(states.float())
            else:
                logits = net(states.float())

            _, predicted = torch.max(logits.data, 1)
            total += actions.size(0)
            correct += (predicted == actions).sum().item()

            if not resnetbc:
                # Assumes min_distances is [batch, n_prototypes] and
                # prototype_action_identity is [n_prototypes, n_actions] — TODO confirm.
                # Cluster cost: per sample, distance to the closest prototype of its
                # own action (max of max_dist - d over same-class prototypes), averaged.
                prototypes_of_correct_class = torch.t(net.prototype_action_identity[:,actions]).to(device)
                inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1)
                clst_cost = torch.mean(max_dist - inverted_distances)

                # Separation cost: same, over prototypes of the other actions.
                prototypes_of_wrong_class = 1 - prototypes_of_correct_class
                inverted_distances_to_nontarget_prototypes, _ = torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1)
                sep_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)

                # Rep term: for each prototype, its min distance over the batch, summed.
                rep = torch.sum(torch.min(min_distances, dim=0)[0])

                CE = criterion(logits, actions.long())

                # Use the live parameter, not `.data`: `.data` is detached, so the
                # penalty contributed no gradient and iso_coeff never affected training.
                A = net.isometry.weight
                iso_penalty = torch.linalg.matrix_norm(torch.mm(A.T, A) - I, ord = 'fro')**2

                loss = CE + iso_coeff * iso_penalty + clst_coeff * clst_cost - sep_coeff * sep_cost + rep_coeff * rep 

                running_CE += CE.item()
                running_iso += iso_penalty.item()
                running_clst += clst_cost.item()
                running_sep += sep_cost.item()
                running_rep += rep.item()
            else:
                loss = criterion(logits, actions.long())
                running_CE += loss.item()

            if not validation:
                loss.backward()
                optimizer.step()

    # Skip the accuracy print on an empty loader instead of dividing by zero.
    if total:
        if not validation:
            print('Train Acc: %d %%' % (100 * correct / total))
        else:
            print('Val Acc: %d %%' % (100 * correct / total))
    return running_CE, running_iso, running_clst, running_sep, running_rep

In [None]:
def train_iso(net, 
          train_loader, 
          val_loader,
          iso_coeff=10,
          clst_coeff =0.,
          sep_coeff = 0.,
          rep_coeff = 0.,
          push_interval=None,
          push_epochs=100,
          n_epochs=100,
          main_lr = 1e-6,
          pruned=False,
          lr_sched=False,
          resnetbc=False
         ):
    """Train `net` with early stopping, optional periodic prototype pushes and a final push.

    Each epoch trains on `train_loader` and validates on `val_loader`. Early stopping
    tracks the validation CE and checkpoints to 'checkpoint.pt', which is reloaded at
    the end. For prototype networks (resnetbc=False) a final push projects the
    prototypes (`save_proto`), freezes isometry and prototypes, and fine-tunes the
    last layer for `push_epochs` epochs.

    Args:
        push_interval: if set, push + last-layer fine-tuning every `push_interval` epochs.
        pruned: train only the last layer (requires a last-layer optimizer).
        lr_sched: StepLR(step=10, gamma=0.8) stepped once per epoch.
        Other args are forwarded to `train_ds_epoch`.

    Returns:
        the trained `net`.

    Raises:
        ValueError: if `pruned` is requested without a last-layer optimizer.
    """
    net_optimizer = optim.Adam(net.parameters(), lr=main_lr)
    # The last-layer optimizer is needed for periodic pushes and for the final
    # push every prototype network goes through (previously None when
    # push_interval was None, which crashed the final push).
    if push_interval is not None or not resnetbc:
        ll_optimizer = optim.Adam(net.last_layer.parameters(), lr=1e-5)
    else:
        ll_optimizer = None
    if pruned and ll_optimizer is None:
        raise ValueError("pruned=True requires a last-layer optimizer "
                         "(set push_interval or use resnetbc=False)")

    sched = None
    if lr_sched:
        sched = StepLR(ll_optimizer if pruned else net_optimizer, 10, .8)

    early_stopping = EarlyStopping(patience=50, verbose=True)

    def _epoch(loader, optimizer, validation):
        # One train_ds_epoch pass with this run's loss weights.
        return train_ds_epoch(net,
                              loader,
                              optimizer,
                              iso_coeff,
                              clst_coeff,
                              sep_coeff,
                              rep_coeff,
                              validation=validation,
                              resnetbc=resnetbc)

    def _push_and_tune_last_layer(unfreeze):
        # Project prototypes, freeze isometry + prototypes, fine-tune the last layer.
        save_proto(net, train_loader, project=True, device=device)
        for param in net.isometry.parameters():
            param.requires_grad = False
        net.prototype_vectors.requires_grad = False

        for _ in range(push_epochs):
            _epoch(train_loader, ll_optimizer, validation=False)

        if unfreeze:
            for param in net.isometry.parameters():
                param.requires_grad = True
            net.prototype_vectors.requires_grad = True

    for i in range(n_epochs):
        net.train()
        (ce,iso,clst,sep,rep) = _epoch(train_loader,
                                       ll_optimizer if pruned else net_optimizer,
                                       validation=False)
        print('\nEpoch', i, {'CE':ce,'Iso':iso,'Clst':clst,'Sep':sep,'Rep':rep})

        if push_interval is not None and i > 0 and i % push_interval == 0:
            _push_and_tune_last_layer(unfreeze=True)

        # validation step
        net.eval()
        (ce,iso,clst,sep,rep) = _epoch(val_loader, net_optimizer, validation=True)
        early_stopping(ce, net)

        # Step the scheduler once per epoch (it was previously stepped only once,
        # after training had finished, so it never changed the learning rate).
        if sched is not None:
            sched.step()

        if early_stopping.early_stop:
            print('early Stopping')
            break
    net.load_state_dict(torch.load('checkpoint.pt'))

    # final push
    # NOTE(review): `net` is still in eval mode here (from the last validation),
    # so the last-layer fine-tuning runs with eval-mode layers — confirm intended.
    print('Doing final push!')
    if not resnetbc:
        _push_and_tune_last_layer(unfreeze=False)

    return net

In [None]:
def test(net, device, test_loader):
    net.eval()

    total = 0.
    correct = 0.
    for (states, actions) in test_loader:
        states = states.to(device)
        actions = actions.to(device)

        logits, _ = net(states)

        _, predicted = torch.max(logits.data, 1)
        total += actions.size(0)
        correct += (predicted == actions).sum().item()

    acc = 100 * (correct / total)
    return acc


In [None]:
#@title flip_fidelity
# measure agreement between an agent and the expert, overall and at action flips
from skimage.transform import resize
from PIL import Image
def _agent_action(net, obs):
    """Greedy action index of `net` for one NHWC observation batch `obs`."""
    out = net(torch.FloatTensor(obs).permute(0,3,1,2).to(device))
    # Prototype networks return (logits, distances); plain classifiers return logits.
    logits = out[0] if isinstance(out, (tuple, list)) else out
    agent_policy = F.softmax(logits, dim=1)
    return torch.argmax(agent_policy).item()


def flip_fidelity(net,num_interactions=int(6e4), n_flip=10000,env_id="SeaquestNoFrameskip-v4", preprocess=False,save=False):
    """Roll out the PPO expert and measure how often `net` picks the expert's action.

    The expert drives the environment; at every step `net` predicts from the same
    observation. A *flip* is a step whose expert action differs from the expert's
    previous action; it counts as correct when the agent matched the expert both at
    the flip and at the step before it. Stops after `num_interactions` steps or
    `n_flip` flips.

    Args:
        net: model to evaluate (see `_agent_action`).
        num_interactions: maximum number of environment steps.
        n_flip: maximum number of flips to collect.
        env_id: environment id passed to `get_env_and_model`.
        preprocess: crop/resize observations as in `gen_color_data` (after the
            expert has acted).
        save: currently unused.

    Returns:
        (overall agreement, flip agreement); flip agreement is nan if no flip occurred.

    Only local buffers are written; the module-level training arrays are untouched.
    """
    env, ppo_expert = get_env_and_model(env_id)

    print('state shape: ', env.observation_space.shape)
    print('action shape: ', env.action_space.shape)

    expert_actions = np.empty((num_interactions,) + env.action_space.shape)
    agent_actions = np.empty((num_interactions,) + env.action_space.shape)

    # Flip points (observations stored channels-first) — collected, not returned.
    flip_observations = np.empty((n_flip, 4,84,84))
    flip_actions = np.empty((n_flip,) + env.action_space.shape)

    net.eval()
    obs = env.reset()

    correct = 0
    total = 0
    flip_total = 0
    flip_correct = 0
    with torch.no_grad():
        for step in tqdm(range(num_interactions)):
            action, _ = ppo_expert.predict(obs, deterministic=True)
            #PREPROCESS AFTER EXPERT IS DONE!!!!!!
            if preprocess:
                obs = crop_pong(obs)[0]
                obs = np.expand_dims(resize(obs, (84,84,4)),0)

            expert_actions[step] = action

            agent_action = _agent_action(net, obs)
            agent_actions[step] = agent_action

            if action == agent_action:
                correct += 1
            total += 1

            if step >= 1 and action != expert_actions[step-1]:
                flip_total += 1
                if agent_action == action and agent_actions[step-1] == expert_actions[step-1]:
                    flip_correct += 1
                # obs is NHWC with batch 1; store it as a single CHW frame stack.
                flip_observations[flip_total - 1] = obs.transpose(0,3,1,2)[0]
                flip_actions[flip_total - 1] = action

            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()

            if flip_total >= n_flip:
                break

    env.close()

    flip_rate = flip_correct/flip_total if flip_total else float('nan')
    return correct/total, flip_rate

# ResNet BC

In [None]:
# train bbox: ResNet-18 black-box baseline, adapted to 4-channel input and 18 action logits.
bbox = models.resnet18(pretrained=False)
# Replace the stem so it accepts 4 input channels.
bbox.conv1 = nn.Conv2d(4, 64, kernel_size=7,stride=2,padding=3,bias=False)
bbox.fc = nn.Linear(512,18)
bbox=bbox.to(device)

In [None]:
# train_iso has no `ev_coeff` / `div_coeff` parameters (passing them raised TypeError).
# With resnetbc=True only the cross-entropy term is used, so these weights are inert.
bbox = train_iso(bbox, 
      train_loader, 
      test_loader, 
      iso_coeff=1e-8,#.01,
      clst_coeff=4e-4,
      sep_coeff=1e-4,
      rep_coeff=1e-5,
      n_epochs=2,
      main_lr=5e-6,
      push_interval=None,
      push_epochs=1,
      resnetbc=True)

In [None]:
# Overall (tf) and flip-point (ff) agreement of the ResNet baseline with the PPO expert.
tf, ff = flip_fidelity(bbox,n_flip=10000,num_interactions=30000, env_id="SeaquestNoFrameskip-v4", preprocess=False)
print(tf, ff)

In [None]:
# Pickle the trained ResNet baseline.
torch.save(bbox,'resnetbc_seaquest.pt')

# ProtoX

In [None]:
# Make Prototype network
net = ProtoIsoResNet(prototype_shape=(360,512*3*3),num_actions=18,tl=False, sim_method=0)# MY sim method
# Start the isometry at the identity.
net.isometry.weight.data.copy_(torch.eye(512*3*3))

# Load encoder from pre-training (change the path to whatever the pre-trained model is).
# map_location='cpu': `net` is still on the CPU here, and a GPU-saved file then loads anywhere.
# strict=False ignores missing/unexpected keys. NOTE(review): confirm the VAE encoder's
# key names match the network's (e.g. under `convunit`), otherwise nothing is copied.
encoder_state = torch.load('seaquest_svae.pt', map_location='cpu').encoder.state_dict()
net.load_state_dict(encoder_state, strict=False)

# Freeze encoder. (`requires_grad` — the earlier misspelling `requries_grad` only
# created a new attribute and left the encoder trainable.)
for parm in net.convunit.parameters():
    parm.requires_grad = False

net = net.to(device)

In [None]:
I = torch.eye(net.isometry.weight.shape[0], device=device)  # identity sized like the isometry layer

In [None]:
# Train the prototype network (encoder frozen above).
# Previously tried value: iso_coeff=.01.
net = train_iso(
    net,
    train_loader,
    test_loader,
    iso_coeff=1e-8,
    clst_coeff=4e-4,
    sep_coeff=1e-4,
    rep_coeff=1e-5,
    n_epochs=1,
    main_lr=5e-6,
    push_interval=25,
    push_epochs=1,
)

In [None]:
torch.save(obj=net, f='seaquest_svae_ds.pt')  # pickle the trained prototype network

# TSNE

In [None]:
# Encode every held-out expert state with the trained prototype network and
# collect one row per state: the flattened encoding plus the action's name.
from sklearn.manifold import TSNE
import plotly.express as px

net = torch.load('seaquest_svae_ds.pt')
net.eval()  # switch to evaluation mode once, before encoding

encoded_samples = []
with torch.no_grad():
    for sample in tqdm(test_expert_dataset):
        img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)
        label = sample[1]
        encoded_img, _ = net.push_forward(img)
        encoded_img = encoded_img.flatten().cpu().numpy()
        encoded_sample = dict(
            (f"Enc. Variable {i}", enc) for i, enc in enumerate(encoded_img)
        )
        encoded_sample['label'] = acts[label]
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

# Quick look at the first two raw encoding coordinates, coloured by action.
px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1',
           color=encoded_samples.label.astype(str), opacity=0.7)

In [None]:
# Project the encodings to 2-D with t-SNE and colour each point by expert action.
# random_state pins the stochastic t-SNE optimisation so the figure is
# reproducible under Restart & Run All.
tsne = TSNE(n_components=2, random_state=0)
tsne_results = tsne.fit_transform(encoded_samples.drop(columns=['label']))

fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig.show()

In [None]:
len(torch.unique(net.prototype_vectors, dim=0))  # number of distinct prototype rows

# Fidelity

In [None]:
# Evaluate the prototype network on the held-out evaluation loader.
test(net, device, eval_loader)

In [None]:
# Agreement with the expert on all steps (tf) and on action-flip steps (ff)
# for the prototype network in Seaquest.
tf, ff = flip_fidelity(
    net,
    n_flip=10000,
    num_interactions=30000,
    env_id="SeaquestNoFrameskip-v4",
    preprocess=False,
)
print(tf, ff)

# Visualization functions

In [None]:
def top_k_prots(net, oidx, action_filter=False, k=3, score_sort=False, sim_method=0, act=None):
    """Return the ``k`` top-ranked prototypes for expert state ``oidx``.

    Args:
        net: prototype network. ``net(x)`` must return a pair whose second
            element is used as per-prototype distances (row 0 for the single
            input). The last layer must expose ``weight`` or ``log_weight``.
        oidx: index into the global ``expert_observations`` / ``expert_actions``.
        action_filter: unused; kept for call-site compatibility.
        k: number of prototypes to return.
        score_sort: False -> rank by similarity alone; True -> rank by
            similarity times the last-layer weight from each prototype to ``action``.
        sim_method: 1 -> log((1 + d) / (d + 1e-10)); otherwise exp(-0.05 * d).
        act: action whose weights are used; defaults to ``expert_actions[oidx]``.

    Returns:
        (idx, sims, fcs): prototype indices (ints), their similarities
        (0-d tensors), and the prototype->action last-layer weights (floats).
    """
    obs = torch.FloatTensor(expert_observations[oidx]).unsqueeze(0).to(device)
    with torch.no_grad():  # inspection only; no autograd graph needed
        _, dist = net(obs)

        if sim_method == 1:
            sim = torch.log((1 + dist) / (dist + 1e-10))[0]
        else:
            sim = torch.exp(dist * -.05)[0]

        # fcl is (num_prototypes, num_actions).
        try:
            fcl = net.last_layer.weight.data.T
        except AttributeError:
            # Variant whose last layer stores log-weights.
            fcl = torch.exp(net.last_layer.log_weight.data.T)

        action = expert_actions[oidx].item() if act is None else act
        action_col = int(action)

        if score_sort:
            # fcl[:, a] equals last_layer.weight[a] for plain layers, and it also
            # works for log-weight layers (reading .weight directly crashed there).
            topk = torch.topk(sim * fcl[:, action_col], k=k)
        else:
            topk = torch.topk(sim, k=k)

    idx = topk.indices.tolist()
    sims = [sim[j] for j in idx]
    fcs = [fcl[j][action_col].item() for j in idx]
    return idx, sims, fcs


In [None]:
def prot_rep(net, ix, push_loader, project=False, sim_method=0, max_samples=5000):
    """Find the state in ``push_loader`` whose encoding is most similar to prototype ``ix``.

    Scans ``push_loader`` batch by batch and stops once more than
    ``max_samples`` states have been seen. The index of the best state is
    counted in iteration order and used to look up ``color_observations`` and
    ``expert_actions``. This assumes the loader is unshuffled and aligned with
    those arrays (TODO confirm for the loaders passed in).

    Args:
        net: network with ``push_forward(x) -> (enc, dist)``, where ``dist`` is
            (batch, num_prototypes).
        ix: prototype index.
        push_loader: iterable of ``(x, actions)`` batches.
        project: unused; kept for call-site compatibility.
        sim_method: 1 -> log((1 + d) / (d + 1e-10)); otherwise exp(-0.05 * d).
        max_samples: scanning stops after the batch that pushes the count past this.

    Returns:
        (color_observations[best_idx], best_prot): the best state's colour
        frame and its encoding.

    Raises:
        ValueError: if ``push_loader`` yields no batches.
    """
    best_prot = None
    best_sim = -float('inf')
    best_idx = -1
    seen = 0
    with torch.no_grad():  # pure search; avoid building autograd graphs
        for x, _actions in push_loader:
            x = x.to(device)
            enc, dist = net.push_forward(x)

            if sim_method == 1:
                sim = torch.log((1 + dist) / (dist + 1e-10))
            else:
                sim = torch.exp(-.05 * dist)
            sim = sim[:, ix]

            bat_ix = torch.argmax(sim).item()
            top_sim = sim[bat_ix].item()
            if top_sim > best_sim:
                best_prot = enc[bat_ix]
                best_idx = seen + bat_ix
                best_sim = top_sim

            seen += x.shape[0]
            if seen > max_samples:
                break

    if best_idx < 0:
        raise ValueError('push_loader yielded no samples')
    # Report and return the overall best state (previously this used the last
    # batch's argmax, which is usually not the best match).
    print('prototype action is ',acts[expert_actions[best_idx]])
    return color_observations[best_idx], best_prot

In [None]:
def explain_color(net,
                  oidx,
                  push_loader,
                  action_filter=False,
                  k=3,
                  fname=None,
                  score_sort=False,
                  sim_method=0,
                  action_names=None):
    """Plot expert state ``oidx`` next to the colour frames of its top-``k`` prototypes.

    Puts ``net`` in eval mode. Prototypes are chosen with ``top_k_prots``, and
    each one is visualised by the nearest state found by ``prot_rep`` over
    ``push_loader``.

    Args:
        net: prototype network.
        oidx: index into ``expert_observations`` / ``color_observations``.
        push_loader: loader scanned by ``prot_rep``.
        action_filter: forwarded to ``top_k_prots``.
        k: number of prototypes to show.
        fname: if given, the figure is also saved to this path.
        score_sort: forwarded to ``top_k_prots``.
        sim_method: similarity transform used by ``top_k_prots`` and ``prot_rep``.
        action_names: optional {action_id: name} mapping. Defaults to the
            original 7-entry table; ids missing from it are shown as numbers
            instead of raising KeyError (e.g. for 18-action games).
    """
    net.eval()
    if action_names is None:
        action_names = {0:'NOOP',1:'RIGHT',2:'RIGHT+A',3:'RIGHT+B',4:'RIGHT+A+B',5:'A',6:'LEFT'}

    action = expert_actions[oidx].item()
    action_name = action_names.get(int(action), str(int(action)))

    # squeeze=False keeps `axes` indexable even when k == 0.
    fig, axes = plt.subplots(1, 1 + k, figsize=(30, 30), squeeze=False)
    axes = axes[0]

    # Input frame.
    axes[0].imshow(color_observations[oidx].astype(int))
    if k > 1:
        axes[0].set_title('Input w/ action: ' + action_name, size=25)
    else:
        axes[0].set_title('Input w/ action: ' + action_name + '\nat t=' + str(oidx), size=40)

    # Prototype frames.
    top_k_ix, sims, fcs = top_k_prots(net, oidx,
                                      action_filter=action_filter,
                                      k=k,
                                      sim_method=sim_method,
                                      score_sort=score_sort)
    for j, (pidx, sim, fc) in enumerate(zip(top_k_ix, sims, fcs), start=1):
        im, _ = prot_rep(net, pidx, push_loader, sim_method=sim_method)
        axes[j].imshow(im.astype(np.uint8))
        if k > 1:
            sim_value = float(sim)
            axes[j].set_title('Sim score: {:.2f} \n'.format(sim_value) + action_name
                              + ' score: {:.2f}\nPoints: {:.2f}'.format(fc, sim_value * fc), size=25)
        else:
            axes[j].set_title('Most Similar Prototype', size=40)

    for ax in axes:
        ax.axis('off')  # hides ticks and tick labels; titles stay visible

    if fname is not None:
        fig.savefig(fname, bbox_inches='tight')

In [None]:
def color_ball(net, oidx, min_sim, n_samples=1000, alpha=.2, Af=True):
    """Overlay every expert state whose encoding is similar enough to state ``oidx``.

    Similarity is exp(-0.05 * ||e_oidx - e_i||) over the first ``n_samples``
    expert states (clamped to the number available). States with similarity
    above ``min_sim`` are drawn translucently on the right panel, and
    ``oidx`` is drawn on the left.

    Args:
        net: model to encode with.
        oidx: index of the query state.
        min_sim: similarity threshold (strict ``>``).
        n_samples: how many leading expert states to compare against.
        alpha: opacity of each overlaid state.
        Af: True -> encode with ``net.push_forward`` (first output);
            False -> encode with ``net.encoder``.

    ``net`` is temporarily put in eval mode, and its previous mode is restored afterwards.
    """
    def _encode(i):
        x = torch.FloatTensor(expert_observations[i]).unsqueeze(0).to(device)
        if Af:
            enc, _ = net.push_forward(x)
            return enc
        return net.encoder(x)

    # Inference-only: eval mode (as in explain_color / the t-SNE cell) and
    # no autograd graphs across potentially thousands of forward passes.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            enc1 = _encode(oidx)
            best_list = []
            for i in range(min(n_samples, len(expert_observations))):
                d = torch.norm(enc1 - _encode(i))
                if torch.exp(-.05 * d) > min_sim:
                    best_list.append(i)
    finally:
        net.train(was_training)

    print('Found ', len(best_list))
    fig, axes = plt.subplots(1,2)
    axes[0].imshow(color_observations[oidx].astype(int))
    for ix in best_list:
        axes[1].imshow(color_observations[ix].astype(int), alpha=alpha)
    for ax in axes:
        ax.axis('off')
    axes[0].set_title('Input')
    # The number of overlaid states depends on min_sim; report the actual count.
    axes[1].set_title('Top {}\n Most Similar States'.format(len(best_list)))

# Compare similar states between Af and f

In [None]:
# f: the pre-trained SVAE checkpoint; Af: the prototype network saved from `net`
# (whose encoder weights were loaded from that same SVAE checkpoint).
f, Af = torch.load('seaquest_svae.pt'), torch.load('seaquest_svae_ds.pt')

In [None]:
# States near step 200 in the raw encoder space (hand-picked threshold for f).
color_ball(f, oidx=200, min_sim=0.943, n_samples=5000, Af=False, alpha=0.05)

In [None]:
# States near step 200 in the push_forward space of Af (hand-picked threshold for Af).
color_ball(Af, oidx=200, min_sim=0.507, n_samples=5000, Af=True, alpha=0.05)

# Make pruned network

In [None]:
# Group equivalent prototypes and build a pruned network from the group keys.
eq = prot_equivs(net)
unique_ix = [key for key in eq]
pruned_net = merge_weights(net, device, eq, unique_ix, action_filter=False)

In [None]:
pruned_net.num_prototypes = pruned_net.prototype_vectors.size(0)  # keep the count in sync after merging

# Explain step 32 using pruned network

In [None]:
# Get prototype representations without projecting into net
pruned_prots = save_proto(
    pruned_net,
    train_loader,
    project=False,  # compute representations only; leave the prototypes untouched
    device=device,
)

In [None]:
net.last_layer.weight.shape  # last-layer shape before pruning

In [None]:
pruned_net.last_layer.weight.shape  # last-layer shape after pruning

In [None]:
# Explain expert step 32 with the pruned network, ranking prototypes by
# similarity x last-layer weight.
explain_color(
    pruned_net.to(device),
    32,
    train_loader,
    action_filter=False,
    k=3,
    fname=None,
    sim_method=0,
    score_sort=True,
)