## Attention plots: Gradient-weighted Class Activation Mapping (Grad-CAM)
See Figure 4 of Melchior et al. (2022).

In [None]:
import torch
from torch import nn
from torch import optim
from accelerate import Accelerator
from torch.utils.data import DataLoader
from spender import SpectrumEncoder,MLP,encoder_percentiles,load_model
import gc


# Device selection: first CUDA GPU when available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Enable the cuDNN autotuner, which picks convolution algorithms for the
# input sizes it observes.
torch.backends.cudnn.benchmark = True
# Report the device that was actually selected (this used to print
# "CPU prepared" unconditionally, even on a GPU).
print(f'Device prepared: {device}')

In [None]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self,x,y):

        """ generate and organize artificial data from parametrizations of SFHs"""

        self.x=torch.from_numpy(x) #seds
        self.y=torch.from_numpy(y) #percentiles


    def __len__(self):
        """total number of samples"""
        return len(self.x[:,0])

    def __getitem__(self,index):
        """Generates one sample of data"""
        x=self.x[index,:]
        y=self.y[index,:]
        return x,y


In [None]:
# Rebuild the test-set DataLoader and reload the trained model for the
# attention analysis.
# NOTE(review): x_test and y_test come from a cell not shown here; Dataset
# requires them to be NumPy arrays.
from torch.autograd import Variable

n_latent = 16

test_set = Dataset(x_test, y_test[0])
# Use the arrays' own .shape (same tuple as np.shape) so this cell does not
# silently depend on a NumPy alias imported elsewhere.
print('Shape of the test set: ', x_test.shape, y_test[0].shape)
# Alternative without mini-batches: {'batch_size': len(x_test)}.
params = {'batch_size': 64}
test_generator = torch.utils.data.DataLoader(test_set, **params)

print('Calling accelerator...')
accelerator = Accelerator(mixed_precision='fp16')
print(accelerator.distributed_type)
testloader = accelerator.prepare(test_generator)
print(testloader, len(testloader))

print('Loading model...')
model_file = f"./saved_model/generate_latent_2/latent_{n_latent}/checkpoint.pt"
model, loss = load_model(model_file, device=accelerator.device, n_hidden=(16, 32))
model = accelerator.prepare(model)



def grad_fam(model, spec, l_callback):
    """Return encoder attention weights and their gradient (Grad-FAM).

    Parameters
    ----------
    model
        Model exposing ``encoder._downsample``, ``encoder.softmax``,
        ``_forward`` and ``encoder.attention_grad``.
    spec : torch.Tensor
        Batch of input spectra.
    l_callback : callable
        Called as ``l_callback(spec)``. Must return a scalar tensor whose
        autograd graph passes through the encoder's attention weights.
        NOTE(review): the callback receives the *input* spectrum, which does
        not depend on the attention. The forward pass below only matters if
        the callback itself derives its value from a model output. Confirm
        intent (e.g. use ``y_pred``).

    Returns
    -------
    att, att_grad : torch.Tensor
        Detached softmax attention weights and the attention gradient read
        from ``model.encoder.attention_grad``.

    Raises
    ------
    RuntimeError
        If the callback's value has no autograd history, or if no attention
        gradient is available after ``backward()``.
    """
    # Attention weights for display only: no graph needed.
    with torch.no_grad():
        _, a = model.encoder._downsample(spec)
        a = model.encoder.softmax(a)

    # Forward pass WITH gradients. Presumably the encoder records
    # d(l)/d(attention) via a backward hook set up here (confirm in encoder).
    model._forward(spec)

    l = l_callback(spec)
    # A tensor without grad_fn (e.g. a Variable/detached copy) is a leaf:
    # backward() would stop immediately and leave attention_grad stale or unset.
    if getattr(l, "grad_fn", None) is None:
        raise RuntimeError(
            "l_callback returned a value with no autograd history; its gradient "
            "cannot reach the encoder attention. Compute it from a model output "
            "instead of the raw input or a detached/Variable-wrapped copy."
        )
    l.backward()

    att_grad = model.encoder.attention_grad
    if att_grad is None:
        raise RuntimeError(
            "model.encoder.attention_grad is None after backward(); the "
            "callback's graph did not pass through the attention weights."
        )
    return a.detach(), att_grad.detach()
    
def l_halpha(spec, wavelengths=None):
    """Scalar H-alpha score of a batch of spectra.

    Sums ``spec - 1`` over every pixel with 6560 < wavelength < 6566
    (presumably Angstrom around H-alpha; units are not shown here) and over
    the whole batch.

    Parameters
    ----------
    spec : torch.Tensor
        Spectra indexed as ``spec[:, pixel]``; the pixel axis must align with
        ``wavelengths``.
    wavelengths : array-like, optional
        Wavelength of each pixel. Defaults to the module-level ``wave``
        defined in another cell.

    Returns
    -------
    torch.Tensor
        0-d tensor that keeps ``spec``'s autograd history, so ``backward()``
        propagates to whatever produced ``spec``.
    """
    if wavelengths is None:
        wavelengths = wave
    sel = (wavelengths > 6560) & (wavelengths < 6566)
    diff = spec[:, sel] - 1
    # This used to be wrapped in ``Variable(..., requires_grad=True)``, which
    # detaches the sum into a fresh leaf: backward() then never reached the
    # model. Without it, calling backward() on a result built from a raw input
    # batch (no requires_grad) raises instead of silently doing nothing.
    return torch.sum(diff)


def grad_fam_halpha(model, spec):
    """Attention weights and Grad-FAM gradient for the H-alpha score.

    Convenience wrapper equivalent to ``grad_fam(model, spec, l_halpha)``.
    """
    return grad_fam(model, spec, l_callback=l_halpha)


# Grad-FAM on the first test batch, followed by a gradient-free forward pass
# that keeps the predictions (s, y_pred) for later use.
# NOTE(review): model.eval() is deliberately left disabled here, as before;
# confirm whether the model should be in eval mode for attribution.
batch = next(iter(testloader))
print(batch[0].shape)

spec_batch = batch[0].float()
att, att_grad = grad_fam_halpha(model, spec_batch)

with torch.no_grad():
    s, y_pred = model._forward(spec_batch)
    
#print(att_grad)



In [None]:
# Drop unreachable Python objects (and any tensors they held), then hand
# cached CUDA allocator blocks back to the driver when a GPU is present.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()