In [1]:
from models._config import C
from dataset.RadarDataset import RadarSignalDataset  


c = C()
if __name__ == "__main__":
    datajson = c.dataload(csv=True, mode='train')
    dataset = RadarSignalDataset(datajson, c.signalTypes[0:c.typeSize], snr_max=17)

<<Loading Train Data [True]>>
Data loading for 'Barker'.....Done!
Data loading for 'Costas'.....Done!
Data loading for 'Frank'.....Done!
Data loading for 'LFM'.....Done!
Data loading for 'P1'.....Done!
Data loading for 'P2'.....Done!
Data loading for 'P3'.....Done!
Data loading for 'P4'.....Done!
Data loading for 'T1'.....Done!
Data loading for 'T2'.....Done!
Data loading for 'T3'.....Done!
Data loading for 'T4'.....Done!


In [2]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from models.LSTM import BiLSTM


class LRP:
    def __init__(self, model, epsilon=1e-6):
        self.model = model
        self.epsilon = epsilon
    
    def relevance(self, x, lengths, target=None):
        output, h, w = self.model(x, lengths, lstm_outputs=True)
        target = torch.argmax(output, dim=1) if target is None else target
        # Output shape: torch.Size([batch_size, cls_size])
        # Hidden states shape: torch.Size([batch_size, length, hidden_size*2])
        # Attention weights shape: torch.Size([batch_size, length])
        # Context shape: torch.Size([batch_size, hidden_size*2])
        
        r = torch.zeros_like(output)
        for i in range(output.size(0)):
            r[i, target[i]] = output[i, target[i]] # Target class에 대한 relevance만 생존
        
        r_c = self.bpp_fc(h, r)
        r_h = self.bpp_att(r_c, h, w)
        r_x = self.bpp_lstm(r_h, h, lengths)
        
        return r_x
    
    def lstm_gates(self, x_t, h_prev, W_ih, W_hh, b_ih, b_hh, cl_prev):
        gates = torch.matmul(W_ih, x_t) + torch.matmul(W_hh, h_prev) + b_ih + b_hh
        i_t, f_t, o_t, g_t = torch.chunk(gates, 4, dim=0)
        
        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        o_t = torch.sigmoid(o_t)
        cl_t = f_t * cl_prev + i_t * torch.tanh(g_t)
        
        return (f_t, i_t, o_t, cl_t)
    
    
    def bpp_bilstm(self, rel_h, h, x, lengths):
        batch_size, seq_len, hidden_size = h.size()
        hidden_size = hidden_size // 2
        
        rel_x_fw = torch.zeros(batch_size, seq_len, 2)
        rel_cl_fw_t1 = torch.zeros(batch_size, hidden_size)
        W_ih = self.model.lstm.weight_ih_l0
        W_hh = self.model.lstm.weight_hh_l0
        b_ih = self.model.lstm.bias_ih_l0
        b_hh = self.model.lstm.bias_hh_l0
        
        rel_x_bw = torch.zeros(batch_size, seq_len, 2)
        rel_cl_bw_t1 = torch.zeros(batch_size, hidden_size)
        W_ih_bw = self.model.lstm.weight_ih_l0_reverse
        W_hh_bw = self.model.lstm.weight_hh_l0_reverse
        b_ih_bw = self.model.lstm.bias_ih_l0_reverse
        b_hh_bw = self.model.lstm.bias_hh_l0_reverse
        
        for i in range(batch_size):
            h_prev_fw = torch.zeros(hidden_size)
            cl_prev_fw = torch.zeros(hidden_size)
            for t in reversed(range(seq_len)):
                if t < lengths[i]:
                    x_t = x[i, t]
                    h_fw_t = h[i, t, :hidden_size]
                    rel_h_fw_t = rel_h[i, t, :hidden_size]
                    gates_fw = self.lstm_gates(x_t, h_prev_fw, W_ih, W_hh, b_ih, b_hh, cl_prev_fw)
                    rel_x_fw[i, t], rel_cl_fw_t1 = self.bpp_lstm_cell(h_fw_t, rel_h_fw_t, rel_cl_fw_t1, gates_fw, W_ih)
                    
                    h_prev_fw = h_fw_t
                    cl_prev_fw = gates_fw[-1]
            
            h_prev_bw = torch.zeros(hidden_size)
            cl_prev_bw = torch.zeros(hidden_size)
            for t in range(seq_len):
                if t < lengths[i]:
                    x_t = x[i, t]
                    h_bw_t = h[i, t, hidden_size:]
                    rel_h_bw_t = rel_h[i, t, hidden_size:]
                    gates_bw = self.lstm_gates(x_t, h_prev_bw, W_ih_bw, W_hh_bw, b_ih_bw, b_hh_bw, cl_prev_bw)
                    rel_x_bw[i, t], rel_cl_bw_t1 = self.bpp_lstm_cell(h_bw_t, rel_h_bw_t, rel_cl_bw_t1, gates_bw, W_ih_bw)
             
                    h_prev_bw = h_bw_t
                    cl_prev_bw = gates_bw[-1]
                    
        rel_x = rel_x_fw + rel_x_bw
        return rel_x
        
    def bpp_lstm_cell(self, rel_h_t, rel_cl_t1, gates, W_ih):
        f_t, i_t, o_t, cl_t = gates
        
        rel_cl_t = rel_cl_t1 + rel_h_t * o_t * (1-torch.tanh(cl_t)**2)
        rel_cl_t1 = rel_cl_t * f_t
        
        rel_x_t = torch.matmul(W_ih.T, rel_cl_t * i_t)
        return rel_x_t, rel_cl_t1
        
    def bpp_att(self, r_c, h, w):
        rel_h = torch.zeros_like(h)
        for i in range(h.size(0)):
            rel_h_t = r_c[i].unsqueeze(0) * w[i].unsqueeze(-1)
            rel_h[i] = rel_h_t

        return rel_h
    
    def bpp_fc(self, c, r):
        fc_W = self.model.fc.weight
        rel_c = torch.zeros_like(c)
        for i in range(r.size(0)):
            for j in range(r.size(1)):
                rel_c[i] += (c[i]*fc_W[j]) * r[i, j] / (fc_W[j].abs().sum() + self.epsilon)
        return rel_c
        
    
    
def explain_set(train_dataset):
    model = BiLSTM(input_size=2, hidden_size=128, num_layers=2, num_classes=12)
    model.load_state_dict(torch.load('/home/kiwan/TSC_XAI/ckpts/ 0dB/LSTM_ 0dB_0.0073.pt'))
    model.eval()
    
    lrp = LRP(model)

    batch = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=model.collate)

    fig = plt.figure(figsize=(20, 15))

    with torch.no_grad():
        for batch_idx, (data_batch, labels_batch, snrs_batch, lengths_batch) in enumerate(batch):
            r_scores = lrp.relevance(data_batch, lengths_batch)

            real_part = data_batch[batch_idx][:, 0].cpu().numpy()  # 실수부
            imag_part = data_batch[batch_idx][:, 1].cpu().numpy()  # 허수부
            r_scores_real = r_scores[:, 0].cpu().numpy()   # Relevance 실수부

            # 그래프 그리기
            ax = fig.add_subplot(4, 2, batch_idx + 1)
            ax.plot(real_part, label='Real Part')
            ax.plot(r_scores_real, label='Relevance (Real Part)', linestyle='--')
            ax.set_title(f'Sample {batch_idx+1} (SNR: {snrs_batch[batch_idx]} dB)')
            ax.legend()

            break  # 첫 번째 배치만 시각화

    plt.tight_layout()
    plt.show()

explain_set(dataset)

Ground Truth: tensor([9, 1])


RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu

<Figure size 2000x1500 with 0 Axes>

In [None]:
import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from models.LSTM import BiLSTM

def complex_sep(data):
    data = data.replace('i', 'j')  # 'i'를 'j'로 변경하여 복소수 형식에 맞춤
    complex_numbers = data.split(',')  # 쉼표로 구분된 복소수 분리
    complex_list = []
    for num in complex_numbers:
        num = complex(num.strip())
        complex_list.append(num)
    return complex_list

def exaplin_set(data):
    model = BiLSTM(input_size=2, hidden_size=128, num_layers=2, num_classes=12)
    model.load_state_dict(torch.load('/home/kiwan/TSC_XAI/ckpts/ 0dB/LSTM_ 0dB_0.0073.pt'))
    model.eval()
    
    lrp = LRP(model)
    with torch.no_grad():
        x = torch.tensor(data, dtype=torch.float).unsqueeze(0)
        lengths = torch.tensor([len(data)], dtype=torch.long)
        relevance = lrp.get_relevance(x, lengths)
	
	
        
dataset = '/data/kiwan/dataset-CWD-1000/'
signals = ['Barker', 'Costas', 'Frank', 'LFM', 'P1', 'P2', 'P3', 'P4', 'T1', 'T2', 'T3', 'T4']

for signal in signals:
    real_parts = []
    plt.figure(figsize=(20, 5))
    data_dir = os.path.join(dataset, signal)
    data_file = os.listdir(data_dir)
    for ie, file in enumerate(data_file):
        with open(os.path.join(data_dir, file), 'r') as f:
            data = f.readlines()
            data = [complex_sep(d) for d in data]
            data = [item for sublist in data for item in sublist]
            
            

        for c in data:
            real_parts.append(c.real)

        plt.axvline(x=len(real_parts), color='r', linestyle='--')
            
        if ie == 5:
            break


    plt.plot(real_parts, label='Real Part')
    # plt.plot(imag_parts, label='Imaginary Part')
    plt.title(f'{signal} Signal: Real and Imaginary Parts over Time')
    plt.xlabel('Sample Index')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.show()
        

In [None]:

def explain_set(train_dataset):
    model = BiLSTM(input_size=2, hidden_size=128, num_layers=2, num_classes=12)
    model.load_state_dict(torch.load('/home/kiwan/TSC_XAI/ckpts/ 0dB/LSTM_ 0dB_0.0073.pt'))
    model.eval()
    
    lrp = LRP(model)

    batch = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=model.collate)

    fig = plt.figure(figsize=(20, 15))

    with torch.no_grad():
        for data_batch, labels, snr, lengths in batch:
            for i in range(len(data_batch)):
                data = data_batch[i]
                seq_len = lengths[i].item()  
                print(seq_len)
                
                # LRP로 Relevance 스코어 계산
                r_scores = lrp.get_relevance(data, seq_len)

                real_part = data_batch[i][:, 0].cpu().numpy()  # 실수부
                r_scores_real = r_scores[:, 0].cpu().numpy()   # Relevance 실수부

                ax = fig.add_subplot(4, 2, i + 1)
                ax.plot(real_part, label='Real Part')
                ax.plot(r_scores_real, label='Relevance (Real Part)', linestyle='--')
                ax.set_title(f'Sample {i+1} (SNR: {snr[i]} dB)')
                ax.legend()

            break

    plt.tight_layout()
    plt.show()

explain_set(train_dataset)