# Imports and utils

In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from typing import List, Tuple

In [None]:
import einops
import numpy as np
# from torchvision import transforms
# from PIL import Image

In [None]:
from common.model_target import ImpalaModelTarget

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data

## Load data locally

In [None]:
def get_rewarded_obs_idxs(rew_batch: torch.Tensor, 
                          done_batch: torch.Tensor
                         ) -> torch.Tensor:
    """Return the positions of all steps that end up with a non-zero reward.

    Every non-zero entry of ``done_batch`` in a row is treated as the end of
    a segment. The reward found at that position is copied backwards over the
    steps of the segment that precede it; steps after the last non-zero
    ``done`` entry keep their own reward.

    Args:
        rew_batch: Rewards, indexed as ``[row, step]``. Not modified.
        done_batch: Done flags, indexed the same way as ``rew_batch``.

    Returns:
        The output of ``torch.nonzero`` on the back-filled rewards: one
        ``[row, step]`` index pair per non-zero entry.
    """
    filled = rew_batch.clone()
    for row in range(filled.shape[0]):
        segment_start = 0
        for segment_end in torch.nonzero(done_batch[row]).flatten().tolist():
            # Spread the reward seen at the end of the segment over the
            # steps leading up to it (the end step already holds it).
            filled[row, segment_start:segment_end] = rew_batch[row, segment_end]
            segment_start = segment_end + 1
    return torch.nonzero(filled)

In [None]:
# def collect_obs(dir: str, n_epochs: int) -> torch.Tensor:
#     obs = []
#     for e in range(1, n_epochs + 1):
#         print(f'Epoch: {e}')
#         rew_batch = torch.load(os.path.join(dir, str(e), 'rew_batch.pt'), map_location=device)
#         done_batch = torch.load(os.path.join(dir, str(e), 'done_batch.pt'), map_location=device)
#         obs_idxs = get_rewarded_obs_idxs(rew_batch, done_batch)
#         print('collect obs idxs')

#         tmp_obs = torch.load(os.path.join(dir, str(e), 'observations_batch.pt'), map_location=device)[obs_idxs]
#         print('load obs')
#         #tmp_obs = torch.quantize_per_tensor_dynamic(tmp_obs, dtype=torch.quint8, reduce_range=True)
#         print('quantize obs')
#         obs.append(tmp_obs)
    
#     return torch.stack(obs)

In [None]:
# Directory (relative to the working directory) holding the saved sample batches.
path = "samples"

## Create Dataset
- [ ] Load data  
- [ ] Generate wrong target twin 

In [None]:
def get_target_idx(targets: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
    """Map every entry of the first row of ``targets`` to a unique-asset index.

    The unique slices of ``targets`` along dimension 1 are computed with
    ``Tensor.unique(dim=1)``. Each element of ``targets[0]`` is then looked up
    among the first-row parts of those unique slices, and the position of the
    first match is recorded.

    Args:
        targets: Tensor whose dimension 1 enumerates the entries to index
            (presumably ``(env, step, c, w, h)`` target images -- confirm
            against the caller).

    Returns:
        A pair ``(target_idxs, target_asset)``. ``target_idxs[k]`` is the
        index along dimension 1 of ``target_asset`` whose first-row part
        equals ``targets[0][k]``; ``target_asset`` is the tensor returned by
        ``targets.unique(dim=1)``.
    """
    target_asset = targets.unique(dim=1)
    first_row_assets = target_asset[0]
    target_idxs = []
    for t in targets[0]:
        # The candidates live along dimension 1 of `target_asset`, so the
        # search must span every unique slice, not the size of dimension 0.
        for i in range(first_row_assets.shape[0]):
            if torch.all(first_row_assets[i].eq(t)):
                target_idxs.append(i)
                break
    return target_idxs, target_asset

In [None]:
def load_tensor(path: str):
    """Load a two-dimensional tensor saved at ``path`` and swap its axes.

    The stored layout is read as ``(step, env)`` and the result is returned
    as ``(env, step)``.
    """
    return einops.rearrange(torch.load(path), 'step env -> env step')

def load_img(path: str):
    """Load a five-dimensional image tensor saved at ``path`` and swap its two leading axes.

    The stored layout is read as ``(step, env, c, w, h)`` and the result is
    returned as ``(env, step, c, w, h)``.
    """
    return einops.rearrange(torch.load(path), 'step env c w h -> env step c w h')

In [None]:
class ProbingDataset(Dataset):
    """Observations from rewarded segments, paired with a true or a decoy target.

    One sample is produced per ``[env, step]`` pair returned by
    ``get_rewarded_obs_idxs``. ``__getitem__`` returns the observation together
    with, at random, either the target asset recorded for that step (label:
    the back-filled reward) or a different asset (label: the negated reward).

    Attributes set by ``__init__``:
        path: Directory ``<dir>/<n_dir>`` the batches are read from.
        obs_idxs: ``(n_samples, 2)`` tensor of ``[env, step]`` pairs.
        observations: Selected observations, one per sample.
        rewards: Back-filled reward of each sample (same order as
            ``observations``).
        target_idxs / target_assets: Output of ``get_target_idx``.
    """

    def __init__(self, dir:str, n_dir:int = 1, transform=None, target_transform=None) -> None:
        self.path = os.path.join(dir, str(n_dir))

        rewards = load_tensor(os.path.join(self.path, 'rew_batch.pt'))
        dones = load_tensor(os.path.join(self.path, 'done_batch.pt'))

        self.obs_idxs = get_rewarded_obs_idxs(rewards, dones)
        env_idx, step_idx = self.obs_idxs[:, 0], self.obs_idxs[:, 1]

        # Index with the (env, step) pairs so that one observation is kept
        # per sample; indexing with the raw (N, 2) tensor would only select
        # along the env axis.
        observations = load_img(os.path.join(self.path, 'observations_batch.pt'))
        self.observations = observations[env_idx, step_idx]

        # Keep rewards aligned with the selected observations.
        # NOTE(review): the label is taken to be the reward back-filled over
        # the segment (the quantity get_rewarded_obs_idxs thresholds on), not
        # the raw per-step reward -- confirm this is the intended label.
        self.rewards = self._episode_rewards(rewards, dones)[env_idx, step_idx]

        targets = load_img(os.path.join(self.path, 'target_idxs.pt'))
        self.target_idxs, self.target_assets = get_target_idx(targets)

        # Release the full batches; only the selected samples are kept.
        del observations, rewards, dones, targets

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _episode_rewards(rewards: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
        """Copy the reward found at each non-zero ``dones`` position back over its segment."""
        filled = rewards.clone()
        for env in range(filled.shape[0]):
            start = 0
            for end in torch.nonzero(dones[env]).flatten().tolist():
                filled[env, start:end] = rewards[env, end]
                start = end + 1
        return filled

    def __len__(self):
        return self.observations.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(observation, target asset, label)`` for sample ``idx``.

        With probability one half the target is the asset recorded for the
        sample's step and the label is the sample's reward; otherwise a
        different asset is drawn uniformly and the label is the negated reward.

        Raises:
            ValueError: If a decoy is requested but only one asset exists.
        """
        obs = self.observations[idx]
        reward = self.rewards[idx]

        # `target_idxs` is indexed by step, so look the step of this sample up
        # instead of assuming a fixed number of steps per env.
        # NOTE(review): get_target_idx only inspects the first env, so this
        # presumes every env shares the same target at a given step -- confirm.
        step = int(self.obs_idxs[idx, 1])
        target_id = self.target_idxs[step]

        true_target = np.random.rand() > .5
        if true_target:
            target = self.target_assets[0][target_id]
        else:
            n_assets = self.target_assets.shape[1]
            decoys = [asset for asset in range(n_assets) if asset != target_id]
            if not decoys:
                raise ValueError("cannot draw a wrong target: only one target asset is available")
            target = self.target_assets[0][int(np.random.choice(decoys))]
            reward = reward * -1

        if self.transform is not None:
            obs = self.transform(obs)
        if self.target_transform is not None:
            # NOTE(review): applied to the label (reward), following the
            # transform / target_transform convention of torch datasets --
            # confirm it was not meant for the target image instead.
            reward = self.target_transform(reward)
        return obs, target, reward

    # def get_rewarded_obs_idxs(self) -> torch.Tensor:
    #     sliced = torch.clone(self.reward)
    #     for j in range(self.reward.shape[0]):
    #       for r in range(sliced.shape[1]):
    #         start = 0
    #         for i in torch.nonzero(self.done[j][r]):
    #             sliced[j][r][start:i] = self.reward[j][r][i]
    #             start = i + 1
    #     return torch.nonzero(sliced)

# Load Impala Model

In [None]:
# Checkpoint file to probe; the path points at one specific training run and
# has to be adapted to the local setup (see FIXME).
model_path = "logs/procgen/coinrun/easy-random-100-res-128-coins-27-pierre-old/seed_3087_15-12-2022_11-01-44/model_31031296.pth" #FIXME
# Only the "state_dict" entry of the checkpoint is kept; tensors are mapped onto `device`.
tmp_dict = torch.load(model_path, map_location=device)["state_dict"]

#### Adapt the state_dict so that it matches the `.embedder` module rather than the full agent

In [None]:
# Remove the fc_policy / fc_value entries from the checkpoint dictionary.
# A missing key raises KeyError, exactly as a plain `del` would.
for head_param in ('fc_policy.weight', 'fc_policy.bias',
                   'fc_value.weight', 'fc_value.bias'):
    del tmp_dict[head_param]

In [None]:
# Rebuild the dictionary with every occurrence of 'embedder.' removed from
# the parameter names; the tensors themselves are reused as-is.
state_dict = {
    name.replace('embedder.', ''): weights
    for name, weights in tmp_dict.items()
}

# The original dictionary is no longer needed.
del tmp_dict

In [None]:
# Build the model and load the renamed checkpoint weights into it.
impala_model = ImpalaModelTarget(in_channels=3)
# load_state_dict is strict by default: missing or unexpected keys raise an error.
impala_model.load_state_dict(state_dict)
# Move the parameters to the selected device (also displays the module in a notebook).
impala_model.to(device)

# Create Linear probe

## Linear probe architecture


In [None]:
class LinearProbe(nn.Module):
  """Single linear layer followed by a sigmoid.

  Args:
    input_dim: Size of the last dimension of the input.
    output_dim: Size of the last dimension of the output.
  """

  def __init__(self, input_dim, output_dim) -> None:
    super().__init__()
    self.clf = nn.Linear(in_features=input_dim, out_features=output_dim)

  def forward(self, x):
    """Return ``sigmoid(clf(x))``; every output value lies in (0, 1).

    NOTE(review): the dataset in this file produces negated rewards as
    labels for wrong targets, which a sigmoid output cannot reach --
    confirm the intended label range.
    """
    x = self.clf(x)
    # Return the activated logits; calling `self.clf()` again without an
    # input would raise and discard the result computed above.
    return torch.sigmoid(x)

In [None]:
class ProbedModel(nn.Module):
  """Frozen backbone whose output is fed to a trainable ``LinearProbe``.

  Args:
    model: Backbone module. Gradients are disabled on it in place, and it
      must expose an ``fc1`` layer with an ``in_features`` attribute.
  """

  def __init__(self, model: nn.Module) -> None:
    super().__init__()

    # Freeze the backbone so that only the probe can receive gradients.
    model.requires_grad_(False)
    self.model = model

    # NOTE(review): the probe is sized from fc1.in_features, so the backbone's
    # forward is expected to return a tensor of that width -- confirm against
    # ImpalaModelTarget.
    probe_width = self.model.fc1.in_features
    self.probe = LinearProbe(input_dim=probe_width, output_dim=1)

  def forward(self, x, target) -> torch.Tensor:
    """Run the backbone on ``(x, target)`` and apply the probe to its output."""
    return self.probe(self.model.forward(x, target))

# Train Linear Probe

In [None]:
def train(model: nn.Module, dir: str, 
          epochs: int, lr: float, regularization=None,
          n_dirs: int = 40
          ) -> List[float]:
  """Train the linear probe of ``model`` with an MSE objective.

  At every epoch one sample directory ``<dir>/<k>`` with ``1 <= k <= n_dirs``
  is drawn at random, wrapped in a ``ProbingDataset`` and iterated once.

  Args:
    model: Module exposing a frozen ``model`` backbone and a trainable
      ``probe`` (see ``ProbedModel``).
    dir: Root directory containing the numbered sample directories.
    epochs: Number of directories to draw and train on.
    lr: Learning rate passed to Adam.
    regularization: ``"L1"`` adds the sum of absolute probe parameters to
      the loss; any other value disables regularization.
    n_dirs: Number of sample directories available under ``dir``.

  Returns:
    The loss of every optimization step, as Python floats.

  Raises:
    ValueError: If any parameter of the backbone still requires gradients.
  """
  # nn.Module has no `requires_grad` attribute; the flag lives on parameters.
  if any(p.requires_grad for p in model.model.parameters()):
    raise ValueError("the probed model must be frozen (requires_grad=False)")

  loss_fn = nn.MSELoss()
  loss_hist = []
  # Only the probe is trainable, so it is the only thing handed to Adam.
  optimizer = Adam(model.probe.parameters(), lr=lr)

  # NOTE(review): batches are moved to the device of the model's first
  # parameter; the backbone and the probe must live on the same device.
  batch_device = next(model.parameters()).device

  for _ in range(epochs):
    random_dir = np.random.randint(1, n_dirs + 1)
    dataset = ProbingDataset(dir=dir, n_dir=random_dir)
    dataloader = DataLoader(dataset)
    for obs, target, y in dataloader:
      obs, target, y = obs.to(batch_device), target.to(batch_device), y.to(batch_device)

      optimizer.zero_grad()
      y_pred = model(obs, target)
      # Give the labels the prediction's dtype and shape so that MSELoss
      # does not broadcast (batch,) against (batch, 1).
      loss = loss_fn(y_pred, y.to(y_pred.dtype).reshape(y_pred.shape))

      if regularization == "L1":
        l1_loss = sum(p.abs().sum() for p in model.probe.parameters())
        loss = loss + l1_loss

      loss.backward()
      # Apply the update; without this call the probe never changes.
      optimizer.step()
      # Store a plain float so the autograd graph of each step can be freed.
      loss_hist.append(loss.item())

  return loss_hist

In [None]:
# Wrap the loaded model with a fresh linear probe.
model = ProbedModel(model=impala_model)

# NOTE(review): `lr=...` passes the Ellipsis placeholder; a real learning rate
# has to be filled in before this cell can run.
loss_hist = train(model, dir="samples", epochs=50, lr=...) #FIXME 

In [None]:
# Raw batches of sample directory 1, loaded without any axis rearrangement.
# NOTE(review): these are absolute, machine-specific paths; they will not resolve elsewhere.
observations = torch.load("/home/qfeuilla/Desktop/AI Safety/Explicit_Goal_Pointer/EGPWorkBench/samples/1/observations_batch.pt")
rewards = torch.load("/home/qfeuilla/Desktop/AI Safety/Explicit_Goal_Pointer/EGPWorkBench/samples/1/rew_batch.pt")
targets = torch.load("/home/qfeuilla/Desktop/AI Safety/Explicit_Goal_Pointer/EGPWorkBench/samples/1/target_idxs.pt")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Display, for every position with a non-zero reward, the observation stored
# there followed by the matching target image (channels moved last for imshow).
# The non-zero positions are computed once instead of on every access.
for first_idx, second_idx in torch.nonzero(rewards).tolist():
    plt.imshow(einops.rearrange(observations[first_idx][second_idx], 'c w h -> w h c'))
    plt.show()
    plt.imshow(einops.rearrange(targets[first_idx][second_idx], 'c w h -> w h c'))
    plt.show()