In [3]:
import os
import argparse
import numpy as np

import shap
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import models._config as c
from models.LSTM import BiLSTM
from models.Attention import Transformer
from explainer import LRP
    
criterion = nn.CrossEntropyLoss()

parser = argparse.ArgumentParser(description='Training parameters')
parser.add_argument('-m', '--mode', type=str, default='train', help='Mode of operation (train/eval)')

parser.add_argument('-smin','--snr_min', type=int, default=0, help='Minimum SNR value')
parser.add_argument('-smax','--snr_max', type=int, default=16, help='Maximum SNR value')
parser.add_argument('--split_size', type=float, default=0.8, help='Train/Test split size')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs for training')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate for optimizer')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
parser.add_argument('--input_size', type=int, default=2, help='Input size for the model')
parser.add_argument('--hidden_size', type=int, default=128, help='Hidden size for the model')
parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
parser.add_argument('--num_classes', type=int, default=c.typeSize, help='Number of output classes')

args = parser.parse_args(args=[])

params = {
    'snr_min': args.snr_min,
    'snr_max': args.snr_max,
    'split_size': args.split_size,
    'batch_size': args.batch_size,
    'num_epochs': args.num_epochs,
    'learning_rate': args.learning_rate,
    'weight_decay': args.weight_decay,
    'input_size': args.input_size,
    'hidden_size': args.hidden_size,
    'num_layers': args.num_layers,
    'num_classes': args.num_classes
}

def train_eachset(dataset, mtype='LSTM'):
    print(f"Size of train dataset: {len(dataset)}")
    for snr in range(params['snr_min'], params['snr_max']+1, 2):
        print(f"\nTraining model for SNR: {snr}...")
        ckpt = os.path.join("./ckpts/", snr_str:=f"SNR-{snr}dB" if snr != 0 else " 0dB")
        os.makedirs(ckpt, exist_ok=True)
        
        model = BiLSTM(params['input_size'], params['hidden_size'], params['num_layers'], params['num_classes'])
        model.to(c.device)
        
        criterion = nn.CrossEntropyLoss()
        # criterion = LabelSmoothingLoss(classes=params['num_classes'], smoothing=0.1).to(c.device)
        optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'], 
                                                    weight_decay=params['weight_decay'])
        
        snr_dataset = [(data, label, data_snr, length) for data, label, data_snr, length in dataset if data_snr == snr]
        snr_loader = DataLoader(snr_dataset, batch_size=params['batch_size'], shuffle=True, collate_fn=model.collate)
            
        best_state, best_loss = model.train_model(snr_loader, criterion, optimizer, params['num_epochs'], c.device, snr_str, ckpt)

        save_point = f'{ckpt}/{mtype}_{snr_str}_{best_loss:.4f}.pt'
        
        torch.save(best_state, save_point)
        print(f"Model checkpoint saved at {save_point}")

def train_set(dataset):
    print(f"Size of train dataset: {len(dataset)}")
    
    criterion = nn.CrossEntropyLoss()

    model = BiLSTM(params['input_size'], params['hidden_size'], params['num_layers'], params['num_classes'])
    
    if torch.cuda.device_count()>1:
        print("use device parallelly : ", torch.cuda.device_count(),"devices")
        model = nn.DataParallel(model)
    optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay'])
    model.to(c.device)
    
    
    snr_loader = DataLoader(dataset, batch_size=params['batch_size'], shuffle=True, collate_fn=model.collate)
    best_state, best_loss = model.module.train_model(snr_loader, criterion, optimizer, params['num_epochs'], c.device)


    torch.save(best_state, f'./ckpts/result_loss_{best_loss:.4f}.pt')
    print("Train is done.")
        
def eval_set(dataset, mtype='LSTM'):
    print(f"Size of test dataset: {len(dataset)}")
    for snr in range(params['snr_min'], params['snr_max']+1, 2):
        ckpt = os.path.join("./ckpts/", snr_str:=f"SNR-{snr}dB" if snr != 0 else " 0dB")
        ckpts = [f for f in os.listdir(ckpt) if f.endswith(".pt") and f.startswith("LSTM")]
        
        model = BiLSTM(params['input_size'], params['hidden_size'], params['num_layers'], params['num_classes']).to(c.device)
        model.load_state_dict(torch.load(f"ckpts/result_loss.pt"))
        model.eval()
        
        criterion = nn.CrossEntropyLoss()
        snr_dataset = [(data, label, data_snr, length) for data, label, data_snr, length in dataset if data_snr == snr]
        snr_loader = DataLoader(snr_dataset, batch_size=128, shuffle=True, collate_fn=model.collate)
        
        total_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for batch in snr_loader:
                data, labels, length = batch[0].to(c.device), batch[1].to(c.device), batch[3]
                outputs = model(data, length)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
                
        acc = correct / total * 100
        avg_loss = total_loss / len(snr_loader)
        
        print(f"SNR -{snr}dB | Accuracy: {acc:.2f}% | Average Loss: {avg_loss:.4f}")


In [1]:
from dataset.RadarDataset import RadarSignalDataset  

from models._config import C
c = C()

if __name__ == "__main__":
    datajson = c.dataload(csv=True, mode='train')
    dataset = RadarSignalDataset(datajson, c.signalTypes[0:c.typeSize], snr_max=17)

<<Loading Train Data [True]>>
Data loading for 'Barker'.....Done!
Data loading for 'Costas'.....Done!
Data loading for 'Frank'.....Done!
Data loading for 'LFM'.....Done!
Data loading for 'P1'.....Done!
Data loading for 'P2'.....Done!
Data loading for 'P3'.....Done!
Data loading for 'P4'.....Done!
Data loading for 'T1'.....Done!
Data loading for 'T2'.....Done!
Data loading for 'T3'.....Done!
Data loading for 'T4'.....Done!


In [4]:
from collections import OrderedDict

def eval_set(dataset):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Size of test dataset: {len(dataset)}")
    for snr in range(params['snr_min'], params['snr_max']+1, 2):
        model = BiLSTM(params['input_size'], params['hidden_size'], params['num_layers'], params['num_classes']).to('cuda')
          
        state_dict = torch.load(f"ckpts/result_loss.pt")

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_key = k.replace('module.', '')  # 'module.' 접두어 제거
            new_state_dict[new_key] = v
        
        model.load_state_dict(new_state_dict)
        
        criterion = nn.CrossEntropyLoss()
        
        snr_dataset = [(data, label, data_snr, length) for data, label, data_snr, length in dataset if data_snr == snr]
        snr_loader = DataLoader(snr_dataset, batch_size=128, shuffle=True, collate_fn=model.collate)
        
        total_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for batch in snr_loader:
                data, labels, length = batch[0].to(device), batch[1].to(device), batch[3]
                outputs = model(data, length)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
                
        acc = correct / total * 100
        avg_loss = total_loss / len(snr_loader)
        
        print(f"SNR -{snr}dB | Accuracy: {acc:.2f}% | Average Loss: {avg_loss:.4f}")
        
eval_set(dataset)

Size of test dataset: 108000
SNR -0dB | Accuracy: 97.96% | Average Loss: 0.0958
SNR -2dB | Accuracy: 96.91% | Average Loss: 0.1434
SNR -4dB | Accuracy: 95.30% | Average Loss: 0.2261
SNR -6dB | Accuracy: 91.93% | Average Loss: 0.4116
SNR -8dB | Accuracy: 87.91% | Average Loss: 0.6351
SNR -10dB | Accuracy: 83.05% | Average Loss: 0.9147
SNR -12dB | Accuracy: 74.74% | Average Loss: 1.3739
SNR -14dB | Accuracy: 62.91% | Average Loss: 2.1408
SNR -16dB | Accuracy: 48.99% | Average Loss: 3.2212


In [None]:
import visdom
import torch
import torch.nn as nn
import models._config as c
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from matplotlib import pyplot as plt
from tqdm import tqdm


    
# Temporal Attention Layer (Zero-padding에 대한 가중치 조정 포함)
class TemporalAttention(nn.Module):
    def __init__(self, hidden_size):
        super(TemporalAttention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size * 2, hidden_size)  # BiLSTM이므로 hidden_size * 2
        self.v = nn.Parameter(torch.rand(hidden_size)) 
        
    def forward(self, hidden_states, mask=None):
        """
        hidden_states: [batch_size, seq_len, hidden_size * 2]
        mask: [batch_size, seq_len] - zero-padding mask
        """
        attn_weights = torch.tanh(self.attn(hidden_states))  # [batch_size, seq_len, hidden_size]
        attn_weights = attn_weights.matmul(self.v)           # [batch_size, seq_len]

        if mask is not None:
            mask = mask.to(attn_weights.device)
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)  # Zero-padding에 대한 large negative

        attn_weights = F.softmax(attn_weights, dim=1)  # [batch_size, seq_len]에서 softmax로 중요도 결정
        
        # 가중치를 반영하여 각 타임 스텝의 hidden state를 곱해줌
        context = torch.sum(hidden_states * attn_weights.unsqueeze(-1), dim=1)  # [batch_size, hidden_size * 2]
        return context, attn_weights

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Bidirectional LSTM Layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True, dropout=0.3)
        self.attention = TemporalAttention(hidden_size)
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # Bidirectional이므로 hidden_size * 2
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, lengths, lstm_outputs=False):
        batch_size, seq_len, _ = x.size()

        # Initial hidden state and cell state
        h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size).to(x.device)

        # PackedSequence로 변환하여 RNN/LSTM에서 패딩 무시
        packed_x = rnn_utils.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, (hn, cn) = self.lstm(packed_x, (h0, c0))  # LSTM 통과
        
        # 다시 패딩된 시퀀스로 변환
        out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)

        # Zero-padding Mask 생성
        mask = torch.arange(seq_len).expand(batch_size, seq_len) < lengths.unsqueeze(1)
        
        # Attention with Zero-padding Mask 적용
        context, attn_weights = self.attention(out, mask)  # Self-Attention 통과
        
        out_last = self.dropout(context)  # Dropout
        out_fc = self.fc(out_last)  # Fully connected layer
                
        if lstm_outputs:
            return out_fc, out, attn_weights
        else:
            return out_fc


    def train_model(self, train_loader, criterion, optimizer, num_epochs, device):
        vis = visdom.Visdom()
        assert vis.check_connection(), "Visdom 서버를 실행 필수 : python -m visdom.server"

        losses = []  
        vis_window = vis.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1,)).cpu(),
            opts=dict(xlabel='Epoch', ylabel='Loss', title=f'Training Loss', legend=['Loss'])
        )
        
        best_loss = float('inf')  # Best loss 초기화
        best_state = None

        for epoch in range(num_epochs):
            self.train()  
            running_loss = 0.0  
            
            progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)
            
            for batch_idx, (data_batch, labels_batch, _, lengths_batch) in progress_bar:
                data_batch = data_batch.to(device)
                labels_batch = labels_batch.to(device)
                lengths_batch = lengths_batch.cpu()  # 시퀀스 길이를 CPU로 이동
                optimizer.zero_grad()
                
                outputs = self(data_batch, lengths_batch)  # self()는 forward()를 호출함
                
                # Loss 계산
                loss = criterion(outputs, labels_batch)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # 현재 배치 번호와 평균 손실을 tqdm에 표시
                progress_bar.set_postfix({
                    'Batch': f"{batch_idx + 1}/{len(train_loader)}",
                    'Loss': f"{loss.item():.4f}"
                })

                loss.detach()

            avg_loss = running_loss / len(train_loader)
            losses.append(avg_loss)

            vis.line(
                X=torch.tensor([epoch + 1]).cpu(),
                Y=torch.tensor([avg_loss]).cpu(),
                win=vis_window,
                update='append'
            )

            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
            
            # Best loss 갱신
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_state = self.state_dict()

            if best_loss < 0.01:
                break
            
            torch.cuda.empty_cache()

        return best_state, best_loss

    @staticmethod        
    def collate(batch):
        data, labels, snrs, lengths = zip(*batch)
        data_pad = rnn_utils.pad_sequence([torch.tensor(seq, dtype=torch.float32) for seq in data], batch_first=True)
        
        labels = torch.tensor([c.label_mapping[label] for label in labels], dtype=torch.long)
        snrs = torch.tensor(snrs, dtype=torch.float32)
        lengths = torch.tensor(lengths, dtype=torch.long)  # 시퀀스 길이를 함께 전달

        return data_pad, labels, snrs, lengths


# def train_set(dataset):

print(f"Size of train dataset: {len(train_dataset)}")

criterion = nn.CrossEntropyLoss()

model = BiLSTM(params['input_size'], params['hidden_size'], params['num_layers'], params['num_classes'])

if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs with DataParallel.")
    model = nn.DataParallel(model)
    model.to('cuda')
else:
    model.to(c.device)
    
optimizer = optim.Adam(model.module.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay'])

snr_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=model.module.collate, num_workers=8)
# best_state, best_loss = model.module.train_model(snr_loader, criterion, optimizer, params['num_epochs'], c.device)


# torch.save(best_state, f'./ckpts/result_loss_{best_loss:.4f}.pt')
# print("Train is done.")
    
# train_set(train_dataset) # 8374MB

In [None]:
best_state, best_loss = model.module.train_model(snr_loader, criterion, optimizer, params['num_epochs'], c.device)


torch.save(best_state, f'./ckpts/result_loss_{best_loss:.4f}.pt')
print("Train is done.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


#for j in range(data.size(0)):
#   relevance = lrp.get_relevance(data[j]) # [seq_len, input_size]

class LRP:
    def __init__(self, model, epsilon=1e-5):
        self.model = model
        self.epsilon = epsilon
        self.model.eval()

    def forward(self, x, lstm_outputs):
        return self.model(x, lstm_outputs)
    
    def get_relevance(self, x, target=None):        
        x.requires_grad = True
        output, hiddens = self.forward(x, lstm_outputs=True) # [32, 4] | [32, 1495, 256]
        # output은 모델의 예측 결과, hid_outputs는 LSTM의 각 시간 스텝의 hidden state를 포함
        # output : [batch_size, num_classes], hiddens : [batch_size, seq_len, hidden_size]
        target = torch.argmax(output, dim=1) if target is None else torch.tensor(target).to(x.device)

        # 정답 클래스에 대한 기여도 초기화
        relevance = torch.zeros_like(output)
        for i in range(output.size(0)):
            relevance[i, target[i]] = output[i, target[i]]
        
        relevance = relevance.unsqueeze(1).expand(-1, hiddens.size(1), -1) # [batch_size, seq_len, num_classes]
        
        
        for module in reversed(list(self.model.modules())):
            if isinstance(module, nn.Linear):
                relevance = self.linear_lrp(module, hiddens, relevance)
            elif isinstance(module, nn.LSTM):
                relevance = self.bilstm_lrp(module, hiddens, relevance)
            elif isinstance(module, nn.ReLU) or isinstance(module, nn.Tanh):
                relevance = self.activation_lrp(module, hiddens, relevance)
                
        return relevance


    # hiddens : torch.Size([32, 1495, 256]) : [batch_size, seq_len, hidden_size] | relevance : torch.Size([32, 4])
    def linear_lrp(self, layer, hiddens, relevance): 
        print(hiddens.size())
        batch_size, seq_len, _ = hiddens.size()
        
        weight = layer.weight  # [num_classes, hidden_size]
        bias = layer.bias      # [num_classes]
        total_relevance = torch.zeros_like(hiddens)  # [batch_size, seq_len, hidden_size]

        for i in range(batch_size):
            hiddens_i = hiddens[i]                           # [seq_len, hidden_size]
            relevance_i = relevance[i]                       # [num_classes]
            total_relevance_i = torch.zeros_like(hiddens_i)  # [seq_len, hidden_size]

            for t in range(seq_len):
                h_t = hiddens_i[t, :]                              
                z_t = F.linear(h_t, weight, bias) + self.epsilon   
                s_t = relevance_i / z_t                            
                c_t = torch.matmul(s_t, weight)                    

                total_relevance_i[t, :] = h_t * c_t  # [hidden_size]

            total_relevance[i] = total_relevance_i

        print(total_relevance.size())  # total_relevance: [batch_size, seq_len, hidden_size]
        return total_relevance
    
    def bilstm_lrp(self, layer, hiddens, relevance):
        batch_size, seq_len, hidden_size2 = hiddens.size()
        hidden_size = hidden_size2 // 2                     # only one-direction hidden size
        
        h_fw, h_bw = hiddens[:, :, :hidden_size], hiddens[:, :, hidden_size:]
        rel_fw, rel_bw = torch.zeros_like(h_fw), torch.zeros_like(h_bw)
        
        for t in reversed(range(seq_len)):
            h_t_fw = h_fw[:, t, :]  # [batch_size, hidden_size]
 
            z_t_fw = F.linear(h_t_fw, layer.weight_ih_l0, layer.bias_ih_l0) + self.epsilon
                    
            s_t_fw = relevance / (z_t_fw + self.epsilon)
            c_t_fw = s_t_fw @ layer.weight_ih_l0[:hidden_size, :].t() 
            rel_fw[:, t, :] = h_t_fw * c_t_fw
            
            h_t_bw = h_bw[:, t, :]  # [batch_size, hidden_size]
            z_t_bw = F.linear(h_t_bw, layer.weight_ih_l0_reverse, layer.bias_ih_l0_reverse) + self.epsilon
            s_t_bw = relevance / (z_t_bw + self.epsilon)
            c_t_bw = s_t_bw @ layer.weight_ih_l0[hidden_size:, :].t()
            rel_bw[:, t, :] = h_t_bw * c_t_bw
            
        total_relevance = rel_fw + rel_bw
        print(f"total relevance: {total_relevance.size()}")
        return total_relevance
 
    def compute_lstm_relevance(self, layer, x, relevance, direction='forward'):
        """
        LSTM의 각 시간 스텝에 대해 relevance 계산.
        - layer: nn.LSTM 레이어
        - x: 입력 텐서
        - relevance: 기여도 텐서
        - direction: 'forward' or 'backward' (방향 선택)
        """
        # 정방향 또는 역방향 LSTM의 각 시간 스텝에 대해 relevance를 계산
        seq_len = x.size(1)
        
        # 방향에 따라 시퀀스를 정방향 또는 역방향으로 순회하며 기여도 계산
        if direction == 'forward':
            time_steps = range(seq_len)
        elif direction == 'backward':
            time_steps = reversed(range(seq_len))
        
        # 시퀀스를 순회하면서 각 시간 스텝에 대한 기여도 계산
        for t in time_steps:
            h_t = layer(x[:, t, :])[0]  # 현재 시간 스텝의 LSTM 출력
            z = h_t + self.epsilon  # 작은 epsilon 추가
            s = relevance[:, t, :] / z  # relevance 계산
            relevance[:, t, :] = x[:, t, :] * s
        
        return relevance
    
    def activation_lrp(self, layer, x, relevance):
        """
        활성화 함수의 LRP 계산 (ReLU, Tanh 등).
        활성화 함수에서는 LRP를 사용하여 기여도 역전파를 수행.
        """
        # 활성화 함수의 입력 x에 대해 기여도를 직접 역전파
        return relevance * (x > 0).float()  # 활성화된 뉴런만 기여도를 전달


explain_set(explain_dataset)

In [None]:
import numpy as np
from explainer import LRP

def explain_mode(dataset):
    print(f"Size of test dataset: {len(dataset)}")
    for snr in range(params['snr_min'], params['snr_max']+1, 2):
        ckpt = os.path.join("./ckpts/", snr_str:=f"-{snr}dB" if snr != 0 else "0dB")
        ckpts = [f for f in os.listdir(ckpt) if f.endswith(".pt")]

        model = BiLSTM(params['input_size'], params['hidden_size'], 
                    params['num_layers'], params['num_classes']).to(c.device)
        model.load_state_dict(torch.load(f"{ckpt}/{ckpts[0]}"))
        model.eval()
        
        explainer = LRP(model)
        dataset_batch = DataLoader(dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=model.collate)
        for data, labels, _ in dataset_batch:
            data = data.to(c.device)
            relevances = explainer.get_relevance(data)
            num_classes = relevances.shape[3]
            
            for i, d, l in zip(range(len(data)), data, labels):
                fig, axs = plt.subplots(1, num_classes, figsize=(20, 5))
                for class_idx in range(num_classes):
                    relevance_class = relevances[i, :, :, class_idx].sum(axis=1) # sum over input channels
                    real, imag = data[i, :, 0].cpu().numpy(), data[i, :, 1].cpu().numpy()
                    
                    axs[class_idx].plot(real, color='gray', label='Real')
                    
                    relv_pos = np.where(relevance_class > 0, relevance_class, np.nan)
                    relv_neg = np.where(relevance_class < 0, relevance_class, np.nan)
                    
                    ax2 = axs[class_idx].twinx()
                    ax2.plot(relv_pos, color='red', label='Positive Relevance')
                    ax2.plot(relv_neg, color='blue', label='Negative Relevance')
                    # ax2.set_ylim(-0.5, 0.5)
                    ax2.set_yticks([])
        
                plt.subtitle(f'LRP values for {l} sample {i} at SNR -{snr}dB')
                plt.tight_layout()
                plt.show()
                break
            break
        break            
        
explain_mode(explain_dataset)

# 데이터셋을 불러오는 과정에서 패딩은 해당 배치에서 가장 길이가 긴 값을 기준으로 이루어지기 때문에 data의 일부가 0으로 되어있을 수 있음
# 
# shap 값은 해당 데이터의 실제 값과 비교하여 어떤 부분이 중요한지를 나타내는 값이기 때문에 0으로 패딩된 부분은 중요하지 않음
# 따라서 shap 값을 보기 위해서는 해당 데이터의 실제 값만을 보는 것이 좋음
# 이를 반영해 코드는 수정

# SHAP

In [None]:

def explain_mode(dataset):
    print(f"Size of test dataset: {len(dataset)}")
    for snr in range(params['snr_min'], params['snr_max']+1, 2):
        ckpt = os.path.join("./ckpts/", snr_str:=f"-{snr}dB" if snr != 0 else "0dB")
        ckpts = [f for f in os.listdir(ckpt) if f.endswith(".pt")]

        LSTMmodel = BiLSTM(params['input_size'], params['hidden_size'], 
                    params['num_layers'], params['num_classes']).to(c.device)
        LSTMmodel.load_state_dict(torch.load(f"{ckpt}/{ckpts[0]}"))
        
        
        dataset_batch = DataLoader(dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=LSTMmodel.collate)
        for data, labels, _ in dataset_batch:
            data = data.to(c.device)
            
            for i, d, l in zip(range(len(data)), data, labels):                
                fig, axs = plt.subplots(1, 2, figsize=(20, 5))
                for class_idx in range(2): 
                    # shap_class = shap_values[i, :, :, class_idx].sum(axis=1)
                    
                    real_data = data[i, :, 0].cpu().numpy()
                    imag_data = data[i, :, 1].cpu().numpy()
                    
                    axs[class_idx].plot(real_data, color='blue') # plot(x, y)
                    axs[class_idx+1].plot(imag_data, color='orange')
                    # shap_pos = np.where(shap_class >= 0, shap_class, np.nan)
                    # shap_neg = np.where(shap_class < 0, shap_class, np.nan)
                    
                    # ax2 = axs[class_idx].twinx()
                    # ax2.plot(shap_pos, label='Positive SHAP', color='orange', linestyle='--', alpha=0.7)
                    # ax2.plot(shap_neg, label='Negative SHAP', color='blue', linestyle='--', alpha=0.7)
 
                    # ax2.set_ylim(-0.06, 0.053)
                    # ax2.set_yticks([])
                    # ax2.legend(loc='upper right')
                    
                    # axs[class_idx].set_title(f'Class {class_idx}')
                    # axs[class_idx].set_xlabel('Time')
                    # axs[class_idx].set_ylabel('SHAP value')
                    
                    # axs[class_idx].grid(True)
                    
                plt.suptitle(f'SHAP values for {l} Sample {i} at SNR {snr_str}')
                plt.tight_layout(rect=[0, 0, 1, 0.96])
                plt.show()
                break
            break
        break
        
explain_mode(explain_dataset)

# 데이터셋을 불러오는 과정에서 패딩은 해당 배치에서 가장 길이가 긴 값을 기준으로 이루어지기 때문에 data의 일부가 0으로 되어있을 수 있음
# 
# shap 값은 해당 데이터의 실제 값과 비교하여 어떤 부분이 중요한지를 나타내는 값이기 때문에 0으로 패딩된 부분은 중요하지 않음
# 따라서 shap 값을 보기 위해서는 해당 데이터의 실제 값만을 보는 것이 좋음
# 이를 반영해 코드는 수정

In [None]:
import os
import re
import torch
import glob
import json

import numpy as np
import pandas as pd
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from TripletConvolution import TCN, trainTCN

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)

signalTypes = ['Barker', 'Costas', 'Frank', 'LFM', 'P1', 'P2', 'P3', 'P4', 'T1', 'T2', 'T3', 'T4']

RawType = "/data/kiwan/dataset-CWD-50/"
TransformedTypes = {'DWT' : "/data/kiwan/Unknown_radar_detection/Adaptive_wavelet_transform/dataset-SPWVD-denoised-Adaptive_DWT",
                    'CWD' : "/data/kiwan/Unknown_radar_detection/Adaptive_wavelet_transform/240523_CWD-v1/",
                    'SAFI' : "/data/kiwan/Unknown_radar_detection/Adaptive_wavelet_transform/240523_SAFI-v1/",}


In [None]:
json_file = 'dataset/CWD_signals.json'

with open(json_file, 'r') as f:
    SignalData = json.load(f)

In [None]:
# [signal_type]>[snr_value]>[step]>[timepoint][real, imag]

import matplotlib.pyplot as plt
from torch.utils.data import Dataset


import torch.nn.utils.rnn as rnn_utils

class RadarSignalDataset(Dataset):
    def __init__(self, signals_data, signal_types, snr_max=17):
        self.data = []
        self.labels = []
        self.label_mapping = {signal: idx for idx, signal in enumerate(signalTypes)}

        for signal_type in signal_types:
            print(f"Data loading for '{signal_type}'", end='')
            for snr_idx, snr in enumerate(range(0, snr_max, 2)): 
                print(".", end='') if snr_idx % 2 == 0 else None
                ssnr = str(snr)
                if ssnr in signals_data[signal_type]: 
                    signal_snr_data = signals_data[signal_type][ssnr]
                    for signal in signal_snr_data:
                        complex_signal = [self.convIQ(x) for x in signal]
                        self.data.append(complex_signal)
                        self.labels.append(signal_type)
            print("Done!")
    
    @staticmethod
    def convIQ(datastring):
        comp = complex(datastring.replace('i', 'j'))
        return comp.real, comp.imag
    
    staticmethod
    def collate(self, batch):
        data, labels = zip(*batch)
        padding = rnn_utils.pad_sequence([torch.tensor(seq, dtype=torch.float32) for seq in data], batch_first=True)
        labels = [self.label_mapping[label] for label in labels]
        labels = torch.tensor(labels, dtype=torch.long)
        return padding, labels
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


In [None]:
typeSize = 4
dataset = RadarSignalDataset(SignalData, signalTypes[0:typeSize], snr_max=17)

In [None]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.collate)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.5)

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initial hidden state와 cell state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        out, _ = self.lstm(x, (h0, c0))     # LSTM 출력: (배치 크기, 시퀀스 길이, hidden state 크기)
        out = self.fc(out[:, -1, :])        # 마지막 타임스텝만 사용하여 출력 계산
        return out

input_size = 2          # (Real, Imag)
hidden_size = 128       # LSTM hidden state size
num_layers = 2          
num_classes = len(signalTypes[:typeSize])  # Expected Output size

model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)

criterion = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)


In [None]:
num_epochs = 40 
batch_size = 128

for epoch in range(num_epochs):
    model.train()  
    rl = 0.0
    
    for data_batch, labels_batch in train_loader:
        data_batch = data_batch.to(device)
        labels_batch = labels_batch.to(device)

        optimizer.zero_grad()
        outputs = model(data_batch)

        loss = criterion(outputs, labels_batch)
        loss.backward()
        optimizer.step()

        rl += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {rl/len(train_loader):.4f}')
torch.save(model.state_dict(), f'./lstm_l{rl/len(train_loader):.4f}.pt')

model.eval() 
correct = 0
total = 0

with torch.no_grad():
    for data_batch, labels_batch in test_loader:
        data_batch = data_batch.to(device)
        labels_batch = labels_batch.to(device)
        outputs = model(data_batch)
        _, predicted = torch.max(outputs.data, 1)
        total += labels_batch.size(0)
        correct += (predicted == labels_batch).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')


In [None]:
version = 0
targetSNR = 0

vDataset = [(d,l,s) for d, l, s, _ in Dataset if s == targetSNR and signalTypes[l] in signal_groups[f'v{version}']]

plot = {}
for data, label, _ in vDataset:
    signalType = signalTypes[label]
    if signalType not in plot:
        plot[signalType] = data

fig, axs = plt.subplots(1, len(plot))
fig.set_size_inches(15, 5)

for i, (signal_type, image) in enumerate(plot.items()):
    axs[i].imshow(image.squeeze(), cmap='gray')
    axs[i].set_title(signal_type)
    axs[i].axis('off')

plt.show()

In [None]:
tcn = TCN(input_channel=1).cuda() 
optimizer = optim.Adam(tcn.parameters(), lr=1e-4)


maxSNR = 2
unique_labels = np.unique([l for _, l, _ in vDataset])

for snr in range(0, maxSNR, 2):
    snrDataset = [(d,l,s) for d, l, s in vDataset if s == snr]
    trainTCN(tcn, optim=optimizer, dataset=snrDataset, data_type='DWT', snr=snr, epochs=20)