## a-AlphaBio homework 
### Mark Thompson
### Started April 29, 2024 

In [None]:
%load_ext autoreload

In [None]:
%autoreload
# import libraries
import numpy as np
import pickle as pk
import pandas as pd
import math
import os
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
import torch
import torch.nn as nn

# Some plotting functions
#
def plot_preds_hist(preds_file_path):
    """Plot a 100-bin histogram of the predictions stored in a pickle file.

    Args:
        preds_file_path: Path to a pickle file holding a sequence of
            predictions. Only the first element of each entry is plotted.

    Raises:
        OSError: If the file cannot be opened.
    """
    # The context manager closes the handle (the bare open() call leaked it).
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(preds_file_path, 'rb') as fh:
        preds = pk.load(fh)
    print('len(preds):', len(preds))
    preds = [p[0] for p in preds]
    print('preds[0:10]:', preds[0:10])

    # Histogram of predicted values
    plt.hist(preds, bins=100)
    plt.xlabel('pred Kd (nm)')
    plt.ylabel('count')
    plt.title('Distribution of pred values set')
    plt.show()


def plot_pred_vs_true(preds_file_path, true_file_path, xlim=(0,5), ylim=(0,5),
                      title='true vs predicted Kd on validation set'):
    """Scatter-plot predicted values against experimental values.

    Args:
        preds_file_path: Path to a pickle file holding a sequence of
            predictions. Only the first element of each entry is used.
        true_file_path: Path to a pickle file holding the matching targets,
            stored the same way.
        xlim: (min, max) limits of the x axis (experimental values).
        ylim: (min, max) limits of the y axis (predicted values).
        title: Plot title. The default keeps the original hard-coded text;
            pass another title when plotting a different split.

    Raises:
        OSError: If either file cannot be opened.
    """
    # Context managers close both handles (the bare open() calls leaked them).
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(preds_file_path, 'rb') as fh:
        preds = pk.load(fh)
    with open(true_file_path, 'rb') as fh:
        y = pk.load(fh)
    print('len(preds):', len(preds), ', len(y):', len(y))
    preds = [p[0] for p in preds]
    y = [a[0] for a in y]

    # scatter plot of true vs pred
    plt.scatter(y, preds, c ="blue")
    plt.xlabel('experimental Kd (nm)')
    plt.ylabel('predicted Kd (nm)')
    plt.title(title)
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.show()



---------
## Holdout dataset and predictions

In [None]:
# Load the hold-out set and report its size, its columns, and a summary of
# the `sequence_a` column.
data_file = './data/alphaseq_data_hold_out.csv'
df = pd.read_csv(data_file)
rows1 = len(df)
print(f'holdout dataframe has {rows1} rows')
print(list(df.columns))
print(df['sequence_a'].describe())

In [None]:
# The predictions on the holdout set, for both transformer variants.

def _load_holdout_preds(csv_path):
    """Read a predictions CSV, print a short summary, and return its `pred_Kd` values."""
    preds_df = pd.read_csv(csv_path)
    print('holdout predictions has', preds_df.shape[0], 'rows')
    print(preds_df.columns.tolist())
    print(preds_df.describe())
    return preds_df['pred_Kd'].values


# tform_mlp version 1 predictions
preds = _load_holdout_preds('./inference_results/tform_mlp_model/cleaned-4-data/preds_tform_mlp_1715104590.5575511.csv')

# tform_mlp_v2 version 2 predictions
preds_v2 = _load_holdout_preds('./inference_results/tform_mlp_model_v2/addendum/cleaned-4b-data/preds_tform_mlp_1715280172.2843447.csv')

# Histogram of predicted values.
# Both series share one set of bin edges so bar heights are directly
# comparable, and the bars are semi-transparent so neither series hides the
# other (they were opaque, with independently computed bins).
bin_edges = np.histogram_bin_edges(np.concatenate([preds, preds_v2]), bins=100)
plt.hist(preds, bins=bin_edges, alpha=0.6, label='Transformer v1', color='b')
plt.hist(preds_v2, bins=bin_edges, alpha=0.6, label='Transformer v2', color='r')
plt.legend(loc='upper left')

plt.xlabel('pred Kd (nm)')
plt.ylabel('count')
plt.title('Distribution of pred Kd values on the holdout set')
plt.xlim((-1,5))
plt.show()


----
## alphaseq_data_train dataset  (not cleaned)

In [None]:
# The raw (uncleaned) training data: summary statistics and Kd distribution.
data_file = './data/alphaseq_data_train.csv'
df = pd.read_csv(data_file)
rows1 = len(df)
print(f'dataset has {rows1} rows')
print(list(df.columns))
print(df.describe())

# Kept as a separate array because it is plotted again further down.
raw_Kds = df['Kd'].to_numpy()

# Histogram of Kd values
fig, ax = plt.subplots(figsize=(3, 3))
ax.hist(raw_Kds, bins=100)
ax.set_xlabel('Experimental Kd (nm)')
ax.set_ylabel('count')
ax.set_title('Distribution of Kd values in the raw training data set')
ax.set_xlim((-1, 5))
plt.show()


----
## Kd distribution for clean-4 dataset train only

In [None]:
# The cleaned (4b) training set, compared against the raw training data.
# Needs `raw_Kds` from the raw-training-data cell.
# data_file = './data/q_cleaned_4_train_set.csv'
data_file = './data/q_cleaned_4b_train_set.csv'
df = pd.read_csv(data_file)
rows1 = df.shape[0]
print('dataset has', rows1, 'rows')
print(df.columns.tolist())
print(df.describe())

clean4_Kds = df['Kd'].values

# Histogram of Kd values.
# Both series share one set of bin edges so the counts are directly
# comparable, and the bars are semi-transparent so the overlap stays visible
# (they were opaque, with independently computed bins).
bin_edges = np.histogram_bin_edges(np.concatenate([raw_Kds, clean4_Kds]), bins=100)
plt.hist(raw_Kds, bins=bin_edges, alpha=0.6, label='raw data', color='b')
plt.hist(clean4_Kds, bins=bin_edges, alpha=0.6, label='clean-4 data', color='r')
plt.legend(loc='upper left')
plt.xlabel('Experimental Kd (nm)')
plt.ylabel('count')
plt.title('Distribution of experimental Kd values')
plt.xlim((-1,5))
plt.show()


----
### MLP model  Clean-3b dataset

In [None]:
# TODO: point this at the MLP predictions pickle; it was left empty.
pred_file_path = ''
# An empty path makes open() raise, so only plot once a path has been set.
if pred_file_path:
    plot_preds_hist(pred_file_path)
else:
    print('pred_file_path is not set; skipping the prediction histogram.')

In [None]:
# Predicted vs. experimental Kd for the MLP model.
results_dir = './inference_results/mlp_model/cleaned-3/test_no_cls_token'
pred_file_path = f'{results_dir}/preds_mlp_1714982629.5918856.pkl'
true_file_path = f'{results_dir}/y_mlp_1714982629.5921109.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(0, 3.5), ylim=(0, 3.5))

In [None]:
# Pearson Correlation Coefficient (MLP model)
#
# Paths are relative to the project root, like the other cells in this
# notebook, rather than absolute paths into one user's home directory.
pred_file_path = './test_results/mlp_model/cleaned-3b-data/preds_mlp_1715105115.4793816.pkl'
true_file_path = './test_results/mlp_model/cleaned-3b-data/y_mlp_1715105115.4793816.pkl'

# Context managers close the pickle files. Only unpickle files you trust.
with open(pred_file_path, 'rb') as fh:
    pred = torch.tensor(pk.load(fh)).squeeze()
with open(true_file_path, 'rb') as fh:
    true = torch.tensor(pk.load(fh)).squeeze()

print(pred.shape)
print(true.shape)

# torch.corrcoef treats each row as a variable: row 0 = predictions, row 1 = targets.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# 2x2 correlation matrix; the off-diagonal entries are the Pearson r.
p = torch.corrcoef(c)
print(p)


-------
### Vision Transformer (ViT) model: 1-channel, Clean-3b dataset

In [None]:
# TODO: point this at the ViT predictions pickle; it was left empty.
pred_file_path = ''
# An empty path makes open() raise, so only plot once a path has been set.
if pred_file_path:
    plot_preds_hist(pred_file_path)
else:
    print('pred_file_path is not set; skipping the prediction histogram.')

In [None]:
# Compare actual with predicted values from inference on the validation set
# (1-channel ViT).
results_dir = './inference_results/vit_model/cleaned-3b/BW/test_no_cls_token'
pred_file_path = f'{results_dir}/preds_vit_1715016846.793841.pkl'
true_file_path = f'{results_dir}/y_vit_1715016846.7940617.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(0, 3.5), ylim=(0, 3.5))

In [None]:
# Pearson Correlation Coefficient (ViT, 1-channel)
#
# Paths are relative to the project root, like the other cells in this
# notebook, rather than absolute paths into one user's home directory.
pred_file_path = './test_results/vit_model/cleaned-3b-data/1-channel/preds_vit_1715105283.6443539.pkl'
true_file_path = './test_results/vit_model/cleaned-3b-data/1-channel/y_vit_1715105283.6443539.pkl'

# Context managers close the pickle files. Only unpickle files you trust.
with open(pred_file_path, 'rb') as fh:
    pred = torch.tensor(pk.load(fh)).squeeze()
with open(true_file_path, 'rb') as fh:
    true = torch.tensor(pk.load(fh)).squeeze()

print(pred.shape)
print(true.shape)

# torch.corrcoef treats each row as a variable: row 0 = predictions, row 1 = targets.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# 2x2 correlation matrix; the off-diagonal entries are the Pearson r.
p = torch.corrcoef(c)
print(p)


----
### Vision Transformer (ViT) model: 3-channel, Clean-3b dataset

In [None]:
# Compare actual with predicted values from inference on the validation set
# (3-channel ViT).
results_dir = './inference_results/vit_model/cleaned-3b/BGR/test_no_cls_token'
pred_file_path = f'{results_dir}/preds_vit_1715020986.6452327.pkl'
true_file_path = f'{results_dir}/y_vit_1715020986.6455376.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(0, 3.5), ylim=(0, 3.5))

In [None]:
# Pearson Correlation Coefficient (ViT, 3-channel)
#
# Paths are relative to the project root, like the other cells in this
# notebook, rather than absolute paths into one user's home directory.
pred_file_path = './test_results/vit_model/cleaned-3b-data/3-channel/preds_vit_1715105421.8529866.pkl'
true_file_path = './test_results/vit_model/cleaned-3b-data/3-channel/y_vit_1715105421.8529866.pkl'

# Context managers close the pickle files. Only unpickle files you trust.
with open(pred_file_path, 'rb') as fh:
    pred = torch.tensor(pk.load(fh)).squeeze()
with open(true_file_path, 'rb') as fh:
    true = torch.tensor(pk.load(fh)).squeeze()

print(pred.shape)
print(true.shape)

# torch.corrcoef treats each row as a variable: row 0 = predictions, row 1 = targets.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# 2x2 correlation matrix; the off-diagonal entries are the Pearson r.
p = torch.corrcoef(c)
print(p)


----
### TFormMLP Clean-4 Dataset 

In [None]:
# Predicted vs. experimental Kd for the TFormMLP model on the clean-4 data.
results_dir = './test_results/tform_mlp_model/cleaned-4-data'
pred_file_path = f'{results_dir}/preds_tform_mlp_1715105469.0702462.pkl'
true_file_path = f'{results_dir}/y_tform_mlp_1715105469.0702462.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(0, 3.5), ylim=(0, 3.5))

In [None]:
# Pearson Correlation Coefficient (TFormMLP, clean-4 data)
#
# Paths are relative to the project root, like the other cells in this
# notebook, rather than absolute paths into one user's home directory.
pred_file_path = './test_results/tform_mlp_model/cleaned-4-data/preds_tform_mlp_1715105469.0702462.pkl'
true_file_path = './test_results/tform_mlp_model/cleaned-4-data/y_tform_mlp_1715105469.0702462.pkl'

# Context managers close the pickle files. Only unpickle files you trust.
with open(pred_file_path, 'rb') as fh:
    pred = torch.tensor(pk.load(fh)).squeeze()
with open(true_file_path, 'rb') as fh:
    true = torch.tensor(pk.load(fh)).squeeze()

print(pred.shape)
print(true.shape)

# torch.corrcoef treats each row as a variable: row 0 = predictions, row 1 = targets.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# 2x2 correlation matrix; the off-diagonal entries are the Pearson r.
p = torch.corrcoef(c)
print(p)


----
### TFormMLP Clean-4 Dataset:  pretrained, then fine-tuned

In [None]:
# Predicted vs. experimental values for the pretrained-then-fine-tuned model.
pred_file_path = './test_results/tform_mlp_model/finetune/cleaned-4b-data/preds_tform_mlp_1715572603.424339.pkl'
true_file_path = './test_results/tform_mlp_model/finetune/cleaned-4b-data/y_tform_mlp_1715572603.424339.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(-1.0, 0.5), ylim=(-0.2,0))

# Inspect the raw predictions. The context manager closes the file, and only
# the length and the first few entries are printed instead of the whole object.
with open(pred_file_path, 'rb') as fh:
    a = pk.load(fh)
print('len(a):', len(a))
print(a[:10])

----
### t-SNE analysis

The transformer-based models should have learned relationships between the elements of the sequence. The t-SNE plots below show how this appears in the model's output vectors.

In [None]:
%autoreload
from torch.utils.data import DataLoader
from models.tform import TFormMLP_Lightning
from datasets.scFv_diy_dataset_pretrain import scFv_diy_pretrain_Dataset as dataset
import os
import yaml
import pytorch_lightning as pl
import torch

# Read the config
config_path = './config/tform_params.yaml'
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error, then stop. Swallowing it would leave
        # `config` undefined, or silently stale from an earlier run.
        print(exc)
        raise

model_config = config['model_params']
train_config = config['train_params']
test_config = config['test_params']
print(model_config)
print(train_config)
print(test_config)

# Seed the random number generators so the analysis below is reproducible.
pl.seed_everything(config['seed'])



In [None]:
#----------------------------------------------------------
# Dataset and dataloader for the paired validation split
#----------------------------------------------------------
test_data_path = './data/oas_2/paired/paired_2_val_set.csv'
test_dataset = dataset(train_config, model_config['block_size'], test_data_path, regularize=False)
# Fixed batch of 256 here rather than train_config['batch_size'].
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

#----------------------------------------------------------
# Restore the pre-trained model from its checkpoint
#----------------------------------------------------------
checkpoint_name = './lightning_logs/tform_mlp_model/pretrain/big2/chain_id/paired_2_set/version_1/checkpoints/epoch=60-step=8000-val_loss=0.44-loss=0.43.ckpt'
model = TFormMLP_Lightning.load_from_checkpoint(
    checkpoint_path=checkpoint_name,
    model_config=model_config,
    config=train_config,
)

# Forward hook that stashes a module's output so it can be inspected after
# a forward pass.
feats = {}

def hook_func(module, inputs, output):
    """Store a detached copy of the hooked module's output under feats['feat']."""
    feats.update(feat=output.detach())

# Hook the `attend` module inside the attention block of the last encoder layer.
final_attend = model.model.transformer.encoders[-1].attn_block.attend
final_attend.register_forward_hook(hook_func)

In [None]:
# Take the first batch from the loader.
it = iter(test_loader)
batch = next(it)
# x, x2, x3, y, name = batch
x, chain, mask, y, name = batch
print('x shape:', x.shape, ', chain shape:', chain.shape, ', mask shape:', mask.shape, ', y shape:', y.shape)

# Inference only: eval() disables train-time behaviour such as dropout, and
# no_grad() skips building the autograd graph for this forward pass.
model.eval()
with torch.no_grad():
    y_hat, loss, trans_out = model(x.to(model.device), chain.to(model.device), mask.to(model.device))
print('y_hat shape:', y_hat.shape, ', trans_out shape', trans_out.shape)

# The forward pass above triggered the hook, which filled feats['feat'].
print('hook function results. feats[feat].shape:', feats['feat'].shape)

In [None]:

import matplotlib.colors as colors

seq_idx = 1   # which sequence in the batch to inspect
head_idx = 2  # which slice along the second axis of the hooked output to display
heads = feats['feat']
print('heads shape:', heads.shape)
nhead = heads[seq_idx, head_idx, :, :].cpu()
print('nhead shape:', nhead.shape)

# Sum of one row (row 10) of the selected map, printed as a sanity check.
# Named `row_sum` so the builtin `sum` is not shadowed for the rest of the
# notebook.
row_sum = torch.sum(nhead[10, :])
print(row_sum)

plt.figure(figsize=(10,10))
cmap = plt.get_cmap('brg_r')
plt.imshow(nhead, cmap=cmap)

# Zoom in on rows 48-50 and columns 0-60 of the map.
xmin = 48
xmax = 50
plt.xlim(0, 60)
plt.ylim(xmax, xmin)  # larger value first keeps row indices increasing downward
plt.xlabel('aa residue')
plt.ylabel('aa residue')
plt.xticks(torch.arange(1, 60, 2).tolist())
plt.yticks([48, 49, 50])

# Numeric values of row 49 for the first 60 columns.
for i in range(0, 60):
    print('{:.3f}'.format(nhead[49,i]), end=' ')

In [None]:

import matplotlib.colors as colors

seq_idx = 1  # which sequence in the batch to inspect
# Read the hooked output directly so this cell does not depend on the
# single-head cell having been run first.
heads = feats['feat']

# Any name registered with matplotlib works for the colormap, e.g. 'viridis',
# 'brg', 'brg_r'; plt.get_cmap lists the valid names in its error message
# when given an unknown one.
cmap = plt.get_cmap('brg')

# Size the grid from the tensor instead of assuming 16 heads in a 4x4 layout.
num_heads = heads.shape[1]
cols = 4
rows = math.ceil(num_heads / cols)
print('num rows:', rows, ', num cols:', cols)
plt.figure(figsize=(15,15))

# One panel per head; subplot indices are 1-based.
for hidx in range(num_heads):
    nhead = heads[seq_idx, hidx, :, :].cpu()
    ax = plt.subplot(rows, cols, hidx + 1)
    ax.axis('off')
    plt.imshow(nhead, cmap=cmap)


In [None]:
from einops import repeat

# Unpack the three axes of the transformer output
# (presumably batch, sequence length, embedding size; confirm against the model).
print('trans_out shape:', trans_out.shape)
b, n, d = trans_out.size()
print('b:', b, ', n:', n, ', d:', d)

ctx_labels = torch.arange(model_config['block_size'])
print('ctx_labels shape:', ctx_labels.shape)

# Token categories for colouring the t-SNE plot come from the chain id:
#   0: CLS, 1: heavy chain, 2: light chain, 3: linker, 4: PAD
# (An earlier version used a hand-built CDR / non-CDR category vector instead.)
labels = chain.squeeze()
print('labels shape:', labels.shape)

# Collapse the first two axes so each row of ctx_vectors is one token's
# vector, and flatten the labels to line up with those rows.
ctx_vectors = trans_out.reshape(b * n, d)
labels = labels.flatten()
print('ctx_vectors shape', ctx_vectors.shape)
print('labels shape', labels.shape)

In [None]:
# Look at just one sequence
#
num_seqs = 1
idx = 4
one_seq = trans_out[idx:idx+num_seqs]
print('shape one_seq:', one_seq.shape)

# Collapse the first two axes so each row is one token's vector.
ctx_vectors = torch.reshape(one_seq, (one_seq.shape[0]*one_seq.shape[1], one_seq.shape[2]))
print('ctx_vectors shape', ctx_vectors.shape)

# Labels for the same sequence(s), taken from the chain ids of the batch.
# This used to read a `ctx_cats` vector whose definition is now commented
# out, so the cell raised NameError.
# NOTE(review): assumes the first axis of `chain` is the batch axis, as it is
# for `trans_out`; confirm against the dataset.
labels = torch.flatten(chain[idx:idx+num_seqs])
print('labels shape:', labels.shape)
assert labels.shape[0] == ctx_vectors.shape[0], 'need exactly one label per context vector'

In [None]:
from sklearn.manifold import TSNE
n_iter = 3000

# Two-dimensional t-SNE embedding of the context vectors, using cosine
# distance and the seed from the config.
# NOTE(review): newer scikit-learn releases rename TSNE's `n_iter` argument to
# `max_iter`; confirm against the installed version.
tsne = TSNE(
    n_components=2,
    metric="cosine",
    n_iter=n_iter,
    random_state=config['seed'],
    verbose=True,
)
tsne_input = ctx_vectors.detach().cpu().numpy()
x_tsne = tsne.fit_transform(tsne_input)
print('x_tsne shape:', x_tsne.shape)
print(tsne.kl_divergence_)


In [None]:
from pathlib import Path
path = Path('./misc_analysis/tSNE/tform_mlp/pretrain/big2/chain_id/paired_2_set/')
path.mkdir(parents=True, exist_ok=True)

# The file name records the iteration count actually used (`n_iter` from the
# t-SNE cell). It was hard-coded as 10000 while the fit ran 3000 iterations,
# which mislabels the artifact and can overwrite a real 10000-iteration result.
tsne_file = path / f'tsne_{n_iter}iter_tform_pretrain_paired2.pkl'
# The context manager flushes and closes the file once the dump is written.
with open(tsne_file, 'wb') as fh:
    pk.dump(x_tsne, fh)

# x_tsne = pk.load(open('./misc_analysis/tSNE/tform_mlp_v2/tsne_x_10000iter_tform_mlp_v2.pkl', 'rb'))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One colour per token type: green, red, blue, purple, black.
# Named `token_colors` so the `colors` alias of matplotlib.colors is not rebound.
token_colors = ["#55FF00", "#FF0000", "#0000FF", "#aa00FF", "#000000"]
# Set custom color palette
my_palette = sns.color_palette(token_colors)
sns.set_palette(my_palette)

fig, ax = plt.subplots(figsize=(10, 8))
x_all = x_tsne[:, 0]
y_all = x_tsne[:, 1]
ax.set_title('t-SNE of transformer model output')
# Column 0 is drawn on the x axis and column 1 on the y axis (the axis labels
# were swapped).
ax.set_xlabel("t-SNE x")
ax.set_ylabel("t-SNE y")
p = sns.scatterplot(x=x_all, y=y_all, alpha=0.6, ax=ax, palette=my_palette, hue=labels, legend=True)

# Replace the numeric legend entries with readable token-type names.
# NOTE(review): the names are applied in legend order, which assumes all five
# chain ids (0-4) occur in `labels` and are listed in ascending order; confirm.
p.legend_.set_title('Token type')
new_labels = ['regression token', 'heavy chain', 'light chain', 'linker',  'PAD token']
for t, l in zip(p.legend_.texts, new_labels):
    t.set_text(l)


In [None]:
# NOTE: this cell held a generic seaborn example (columns 'day',
# 'miles_walked', 'day_category') that is unrelated to this analysis. In this
# notebook `dataset` is the Dataset class imported above, not a table with
# those columns, so the example cannot run as written. It is kept below as a
# commented-out styling reference so the notebook can run top to bottom.
#
# sns.set_palette(sns.color_palette("Paired"))
# ax = sns.scatterplot(x='day', y='miles_walked', data=dataset, hue='day_category')
# ax.set_title("Miles walked")
# ax.set_xlabel("day")
# ax.set_ylabel("total miles")
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)
# plt.show()

----
## Analysis of output of attention blocks to see the interactions

### (also, see the BertViz tool:  https://github.com/jessevig/bertviz)

In [None]:
%autoreload
from torch.utils.data import DataLoader
from models.tform_mlp_v2 import TFormMLP_Lightning_v2
from datasets.scFv_dataset_v2 import scFv_Dataset_v2 as dataset
import os
import yaml
import pytorch_lightning as pl
import torch

# Read the config
config_path = './config/tform_mlp_params.yaml'
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error, then stop. Swallowing it would leave
        # `config` undefined, or silently stale from an earlier run.
        print(exc)
        raise

model_config = config['model_params']
train_config = config['train_params']
test_config = config['test_params']
print(model_config)
print(train_config)
print(test_config)

# Seed the random number generators so the analysis below is reproducible.
pl.seed_everything(config['seed'])



In [None]:
#----------------------------------------------------------
# Load the dataset and dataloaders
#----------------------------------------------------------
test_data_path = './data/q_cleaned_4b_test_set.csv'
test_dataset = dataset(train_config, model_config['block_size'], test_data_path, regularize=False)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=train_config['batch_size'])

#----------------------------------------------------------
# Load pre-trained model
#----------------------------------------------------------
checkpoint_name = './lightning_logs/tform_mlp_model/cleaned-4-data/trained/checkpoints/epoch=316-step=9800-val_loss=0.11-loss=0.05.ckpt'
model = TFormMLP_Lightning_v2.load_from_checkpoint(checkpoint_path=checkpoint_name, model_config=model_config, config=train_config)
# Inference only: eval() disables train-time behaviour such as dropout.
model.eval()

# Take the first batch from the loader.
it = iter(test_loader)
batch = next(it)
x, y, name = batch
print('x shape:', x.shape, ', y shape:', y.shape)

# no_grad() skips building the autograd graph for this forward pass.
with torch.no_grad():
    y_hat, trans_out = model(x.to(model.device))
print('y_hat shape:', y_hat.shape, ', trans_out shape', trans_out.shape)

In [None]:
#
# Use forward hooks to pick out the output tensors for the attention blocks
#

feats = {} #an empty dictionary
def hook_func(m , inp ,op):
    """Store a detached copy of the hooked module's output under feats['feat']."""
    feats['feat'] = op.detach()

# NOTE(review): hook_func is not registered on any module in this cell, so the
# forward pass below fills `feats` only if a hook was registered elsewhere.
# Register it on the module of interest before running this cell.

# This dataset yields (x, y, name), so the model takes `x` alone. The earlier
# call also passed `x2` and `x3`, which are not defined in this notebook.
model.eval()
with torch.no_grad():
    y_hat, trans_out = model(x.to(model.device))

# Guarded so the cell reports the missing hook instead of raising KeyError.
if 'feat' in feats:
    print(feats['feat'].shape)
else:
    print("feats is empty: hook_func has not been registered on a module yet")

In [None]:
# Register hooks to pick outputs from specific layers in model