In [None]:
import os
if "mla-prj-23-mla-prj-12-gu1" not in os.listdir("./"):
  !git clone https://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@github.com/MLinApp-polito/mla-prj-23-mla-prj-12-gu1.git

!pip install snntorch

In [None]:
%cd mla-prj-23-mla-prj-12-gu1/

In [None]:
import torch
import torch.nn as nn
import csv
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split
import math
from torch.utils.data import TensorDataset, DataLoader
import snntorch as snn
from snntorch import utils, surrogate
import snntorch.functional as SF
import json

### Utilities for data loading and processing

In [None]:
def add_pad_data(data):
    """Zero-pad every sample to the next perfect square and reshape it to an image.

    Args:
        data: 2-D array of shape (n_samples, n_features).

    Returns:
        float64 array of shape (n_samples, 1, side, side) where
        side = ceil(sqrt(n_features)); the trailing side**2 - n_features
        values of each sample are zeros.
    """
    n_features = len(data[0])
    side = math.ceil(np.sqrt(n_features))
    extra = side ** 2 - n_features

    # Pad only the feature axis (on the right). Casting to float64 first keeps
    # the output dtype independent of the input dtype.
    padded = np.pad(np.asarray(data, dtype=np.float64), ((0, 0), (0, extra)), mode='constant')

    return padded.reshape((padded.shape[0], 1, side, side))

def build_dataloader(miR_data, num_miR_label, padded_data, batch_size=404,
                     test_size=0.20, random_state=42):
    """Split the data into train/validation sets and wrap them in DataLoaders.

    Args:
        miR_data: 2-D array (n_samples, n_features) of expression values.
        num_miR_label: integer class id per sample, aligned with miR_data.
        padded_data: if True, samples are zero-padded and reshaped into
            (1, side, side) images by ``add_pad_data``; otherwise each sample
            becomes a single-channel sequence of shape (1, n_features).
        batch_size: mini-batch size used by both loaders.
        test_size: fraction of samples held out for the validation loader.
        random_state: seed for the train/validation split.

    Returns:
        (num_inputs, train_loader, test_loader): the number of input features
        per sample (side**2 when padded) and the two loaders. Only the
        training loader is shuffled.
    """
    if padded_data:
        miR_data = add_pad_data(miR_data)

    train_data, val_data, train_label, val_label = train_test_split(
        miR_data, num_miR_label, test_size=test_size, random_state=random_state)

    def _to_dataset(samples, labels):
        # Wrap one split as a TensorDataset of (float32 samples, int64 labels).
        x = torch.as_tensor(np.asarray(samples), dtype=torch.float32)
        if not padded_data:
            # (N, F) -> (N, 1, F): add the channel axis expected by Conv1d.
            # add_pad_data already returns (N, 1, side, side), so padded data
            # must not receive a second singleton axis (that would make it 5-D).
            x = x.unsqueeze(1)
        y = torch.as_tensor(np.asarray(labels), dtype=torch.long)
        return TensorDataset(x, y)

    train_loader = DataLoader(_to_dataset(train_data, train_label), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(_to_dataset(val_data, val_label), batch_size=batch_size)

    if padded_data:
        num_inputs = train_data.shape[2] ** 2
    else:
        num_inputs = train_data.shape[1]

    return num_inputs, train_loader, test_loader

def normalize(data, method='zscore'):
    """Normalize an expression matrix.

    Args:
        data: numeric array; the "zscore" method needs at least 2 dimensions.
        method: "zscore" standardizes each row along axis 1; "log2" adds
            |min(data)| + 0.001 (making every value strictly positive) and
            takes log2; any other value rescales globally to [0, 255].

    Returns:
        A new array with the same shape as ``data``.
    """
    if method == "zscore":
        # Standardize each row (sample) across its features.
        return scipy.stats.zscore(data, axis=1)

    if method == "log2":
        # Shift the whole matrix so log2 never sees a non-positive value.
        return np.log2(data + abs(np.min(data)) + 0.001)

    # Fallback: global min-max scaling to the [0, 255] range.
    lo = np.min(data)
    hi = np.max(data)
    return (data - lo) / (hi - lo) * 255

def extract_label(file_name, verbose=False):
    """Read per-sample class labels from the third column of a CSV file.

    The first row is treated as a header and skipped, as is every row whose
    label contains "TARGET". The "TCGA-" prefix is stripped from the others.

    Args:
        file_name: path to the CSV file.
        verbose: if True, print the number of classes and the per-class counts.

    Returns:
        List of label strings, in file order.
    """
    # Local import: pprint is not among the notebook's top-level imports, so
    # the verbose branch would otherwise raise NameError.
    import pprint

    counts = {}
    label = []
    # newline="" is the mode the csv module expects for files it reads.
    with open(file_name, "r", newline="") as fin:
        reader = csv.reader(fin, delimiter=',')
        next(reader, None)  # skip the header row without indexing into it
        for row in reader:
            lbl = row[2]
            if "TARGET" in lbl:
                continue
            lbl = lbl.replace("TCGA-", "")
            label.append(lbl)
            counts[lbl] = counts.get(lbl, 0) + 1

    if verbose:
        print(f"Number of classes in the dataset = {len(counts)}")
        pprint.pprint(counts, indent=4)

    return label

def create_dictionary(labels):
    """Map each distinct label to a contiguous integer id.

    Ids follow the sorted order of ``np.unique(labels)``, so the mapping is
    deterministic for a given set of labels.
    """
    return {name: idx for idx, name in enumerate(np.unique(labels))}

def label_processing(labels):
    """Encode labels as integer class ids using ``create_dictionary``.

    Returns:
        A list with one int id per input label, in input order.
    """
    mapping = create_dictionary(labels)
    return [mapping[name] for name in labels]

def top_10_dataset(miR_data, miR_label, n_top=10):
    """Keep only the samples that belong to the ``n_top`` most frequent classes.

    Classes with equal frequency are ordered by label, so which classes make
    the cut-off is reproducible. Iterating a set of strings is not: its order
    depends on per-process hash randomization.

    Args:
        miR_data: array whose first axis indexes samples, aligned with miR_label.
        miR_label: sequence of class labels, one per sample.
        n_top: number of most frequent classes to keep (default 10).

    Returns:
        (miR_data_reduced, miR_label_reduced, num_miR_label_reduced): the kept
        samples (original order preserved), their labels, and those labels
        re-encoded as integer ids by ``label_processing``.
    """
    from collections import Counter

    class_counts = Counter(miR_label)
    # Descending count first, then label name, as a deterministic tie-break.
    ranked = sorted(class_counts.items(), key=lambda item: (-item[1], item[0]))
    top_classes = {name for name, _ in ranked[:n_top]}

    keep = np.array([lbl in top_classes for lbl in miR_label], dtype=bool)
    miR_data_reduced = np.asarray(miR_data)[keep]
    miR_label_reduced = [lbl for lbl in miR_label if lbl in top_classes]

    num_miR_label_reduced = label_processing(miR_label_reduced)

    return miR_data_reduced, miR_label_reduced, num_miR_label_reduced

def set_seed(seed):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), then disables cuDNN autotuning so the same
    convolution algorithms are picked on every run.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

### Utility functions for network training

In [None]:
def forward_pass(net, num_steps, data):
    """Run ``net`` on the same input for ``num_steps`` simulation steps.

    The network's hidden neuron states are reset first, then the identical
    ``data`` tensor is fed at every step.

    Returns:
        (spk_rec, mem_rec): per-step spikes and membrane potentials, stacked
        along a new leading time axis (dim 0).
    """
    utils.reset(net)  # clear hidden states left over from the previous call

    spikes, membranes = [], []
    for _ in range(num_steps):
        spk_t, mem_t = net(data)
        spikes.append(spk_t)
        membranes.append(mem_t)

    return torch.stack(spikes), torch.stack(membranes)

def test_accuracy(train_loader, net, num_steps, device):
    """Compute the spike-count accuracy of ``net`` over every batch of a loader.

    Despite its name, ``train_loader`` can be any loader yielding
    ``(data, targets)`` batches; the training loop passes the validation
    loader. Evaluation runs under ``torch.no_grad()`` in eval mode. The
    network's previous train/eval mode is restored afterwards, so callers are
    not left with a silently switched model.

    Returns:
        Accuracy in [0, 1], weighted by batch size.

    Raises:
        ValueError: if the loader yields no samples.
    """
    was_training = net.training
    total = 0
    acc = 0.0
    net.eval()
    try:
        with torch.no_grad():
            for data, targets in train_loader:
                data = data.to(device)
                targets = targets.to(device)
                spk_rec, _ = forward_pass(net, num_steps, data)

                # forward_pass stacks steps on dim 0, so dim 1 is the batch size;
                # it turns the per-batch rate into a count of correct samples.
                batch_size = spk_rec.size(1)
                acc += SF.accuracy_rate(spk_rec, targets) * batch_size
                total += batch_size
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("test_accuracy: the loader yielded no samples")
    return acc / total

def build_layer(layer_type, beta, grad, num_outputs, output=False, threshold=0.4):
    """Create an snnTorch spiking-neuron layer for use inside ``nn.Sequential``.

    Args:
        layer_type: "leaky" (snn.Leaky), "lapicque" (snn.Lapicque) or
            "rleaky" (snn.RLeaky).
        beta: membrane decay rate passed to the neuron.
        grad: surrogate gradient used for the spike non-linearity.
        num_outputs: size of the recurrent linear connection; only used by
            "rleaky".
        output: if truthy, the neuron is built with ``output=True``. The last
            layer needs this so that the model returns (spikes, membrane) as
            ``forward_pass`` expects.
        threshold: firing threshold (default 0.4).

    Returns:
        The configured neuron, always with ``init_hidden=True``.

    Raises:
        ValueError: if ``layer_type`` is not one of the supported names. An
            unknown name is rejected here rather than producing an unusable
            ``None`` layer.
    """
    neuron_classes = {
        "leaky": snn.Leaky,
        "lapicque": snn.Lapicque,
        "rleaky": snn.RLeaky,
    }
    if layer_type not in neuron_classes:
        raise ValueError(
            f"Unknown layer_type {layer_type!r}; expected one of {sorted(neuron_classes)}")

    kwargs = dict(beta=beta, spike_grad=grad, init_hidden=True, threshold=threshold)
    if layer_type == "rleaky":
        # RLeaky needs the feature size of its recurrent linear connection.
        kwargs["linear_features"] = num_outputs
    if output:
        kwargs["output"] = True

    return neuron_classes[layer_type](**kwargs)

def get_cnn_dimension(input_size, params_cnn):
    """Compute the sequence length after each conv -> max-pool stage of the CNN.

    Each stage is a stride-1, unpadded Conv1d with kernel ``wd{i}`` followed by
    a MaxPool1d whose kernel (and default stride) is ``h{i}``:
        conv_len = length - kernel + 1
        pool_len = (conv_len - (h - 1) - 1) / h + 1
    The first two pooled lengths are truncated with ``int``. The third is
    floored and clamped to at least 1.

    Returns:
        (s1, s2, s3): pooled lengths after stages 1, 2 and 3.
    """
    def conv_len(length, kernel):
        return int((length - (kernel - 1) - 1) + 1)

    def pool_len(length, kernel):
        return ((length - (kernel - 1) - 1) / kernel) + 1

    s1 = int(pool_len(conv_len(input_size, params_cnn['wd1']), params_cnn['h1']))
    s2 = int(pool_len(conv_len(s1, params_cnn['wd2']), params_cnn['h2']))
    # Clamp: a non-positive length would give an unusable Linear input size.
    s3 = max(1, math.floor(pool_len(conv_len(s2, params_cnn['wd3']), params_cnn['h3'])))
    return s1, s2, s3

# count_parameters is already defined, identically, in the data-utilities cell
# above. The duplicate definition that used to be here was removed so that the
# helper has a single source of truth.

## Data loading and network training

In [None]:
set_seed(42)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

miR_label = extract_label("./dataset/tcga_mir_label.csv")
# Drop the first row ([1:]) and the last column ([:, :-1]) of the expression matrix.
miR_data = np.genfromtxt('./dataset/tcga_mir_rpm.csv', delimiter=',')[1:, 0:-1]

# extract_label skips "TARGET" rows, so the label count can be smaller than the
# number of expression rows; the surplus rows are dropped from the TOP here.
# NOTE(review): this is only correct if the skipped samples are the first rows of
# tcga_mir_rpm.csv. Confirm against the data files (aligning by sample id would be safer).
number_to_delete = abs(len(miR_label) - miR_data.shape[0])
miR_data = miR_data[number_to_delete:, :]
assert len(miR_label) == miR_data.shape[0], (
    f"{len(miR_label)} labels vs {miR_data.shape[0]} expression rows after alignment")

# Convert labels to integer class ids
num_miR_label = label_processing(miR_label)

# Z-score each sample (row) across its features
miR_data = normalize(miR_data)

assert np.isnan(miR_data).sum() == 0

#---Number of classes---#
top_10_classes = True   # keep only the 10 most frequent classes
padded_data = False     # False: 1-D samples, as required by the Conv1d models below

if top_10_classes:
    miR_data, miR_label, num_miR_label = top_10_dataset(miR_data, miR_label)

# Derive the class count from the labels actually kept (top_10_dataset returns
# fewer than 10 classes when fewer exist), instead of hard-coding it.
n_classes = np.unique(num_miR_label).size

num_inputs, train_loader, test_loader = build_dataloader(miR_data, num_miR_label, padded_data, batch_size=128)

# Surrogate gradient shared by every spiking layer built below.
grad = surrogate.fast_sigmoid()

Run **one** of the next two cells to define the spiking network derived from the non-spiking architecture CNN_PT_1 (first cell) or from CNN_PT_2 (second cell). The training loop below then trains whichever network was defined last.

In [None]:
# SNN hyper-parameter variant to load: "best" or "tradeoff".
SNN_PARAMS_VARIANT = "best"

with open("./best_hyperparams/cnn_params_best.json", "r") as f:
    params_cnn = json.load(f)

s1, s2, s3 = get_cnn_dimension(num_inputs, params_cnn)

with open(f"./best_hyperparams/snn_params_SNN_PT_1_ce_{SNN_PARAMS_VARIANT}.json", "r") as f:
    params_snn = json.load(f)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# the model is moved to `device` before training.
state_dict = torch.load("./trained_models/cnn_best_acc_trained_model.pt", map_location="cpu")

# Layer positions matter: the training cell copies the pre-trained weights into
# indices 0, 3, 6 (Conv1d) and 10 (Linear).
snn_model = nn.Sequential(
    nn.Conv1d(in_channels=1, out_channels=params_cnn['w1'], kernel_size=params_cnn['wd1']),
    nn.MaxPool1d(kernel_size=params_cnn['h1']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s1),
    nn.Conv1d(in_channels=params_cnn['w1'], out_channels=params_cnn['w2'], kernel_size=params_cnn['wd2']),
    nn.MaxPool1d(kernel_size=params_cnn['h2']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s2),
    nn.Conv1d(in_channels=params_cnn['w2'], out_channels=params_cnn['w3'], kernel_size=params_cnn['wd3']),
    nn.MaxPool1d(kernel_size=params_cnn['h3']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s3),
    nn.Flatten(),
    nn.Linear(s3*params_cnn['w3'], n_classes),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, n_classes, output=True)
)

In [None]:
# SNN hyper-parameter variant to load: "best" or "tradeoff".
SNN_PARAMS_VARIANT = "best"

with open("./best_hyperparams/cnn_params_light.json", "r") as f:
    params_cnn = json.load(f)

s1, s2, s3 = get_cnn_dimension(num_inputs, params_cnn)

with open(f"./best_hyperparams/snn_params_SNN_PT_2_ce_{SNN_PARAMS_VARIANT}.json", "r") as f:
    params_snn = json.load(f)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# the model is moved to `device` before training.
state_dict = torch.load("./trained_models/cnn_light_trained_model.pt", map_location="cpu")

# Layer positions matter: the training cell copies the pre-trained weights into
# indices 0, 3, 6 (Conv1d) and 10 (Linear).
snn_model = nn.Sequential(
    nn.Conv1d(in_channels=1, out_channels=params_cnn['w1'], kernel_size=params_cnn['wd1']),
    nn.MaxPool1d(kernel_size=params_cnn['h1']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s1),
    nn.Conv1d(in_channels=params_cnn['w1'], out_channels=params_cnn['w2'], kernel_size=params_cnn['wd2']),
    nn.MaxPool1d(kernel_size=params_cnn['h2']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s2),
    nn.Conv1d(in_channels=params_cnn['w2'], out_channels=params_cnn['w3'], kernel_size=params_cnn['wd3']),
    nn.MaxPool1d(kernel_size=params_cnn['h3']),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, s3),
    nn.Flatten(),
    nn.Linear(s3*params_cnn['w3'], n_classes),
    build_layer(params_snn['neuron_type'], params_snn['beta'], grad, n_classes, output=True)
)

### Training loop

In [None]:
# Copy the pre-trained CNN weights into the matching SNN layers
# (indices into snn_model: 0/3/6 are the Conv1d layers, 10 is the Linear layer).
PRETRAINED_LAYERS = {0: "conv1", 3: "conv2", 6: "conv3", 10: "fc1"}
with torch.no_grad():
    for layer_idx, prefix in PRETRAINED_LAYERS.items():
        snn_model[layer_idx].weight = torch.nn.Parameter(state_dict[f"{prefix}.weight"])
        snn_model[layer_idx].bias = torch.nn.Parameter(state_dict[f"{prefix}.bias"])

snn_model.to(device)
print(f"SNN number of parameters: {count_parameters(snn_model)}")

optimizer = torch.optim.Adam(snn_model.parameters(), lr=params_snn['learning_rate'], betas=(0.9, 0.999))
# Cross-entropy computed on the output neurons' accumulated spike counts.
loss_fn = SF.ce_count_loss()

num_epochs = 30
curr_acc = -np.inf  # best validation accuracy seen so far
best_epoch = 0

# Outer training loop
for epoch in range(num_epochs):
    # Make sure the model is in training mode for this epoch; the evaluation
    # at the end of the previous epoch may have switched it to eval mode.
    snn_model.train()

    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass over num_steps simulation steps, then loss on spike counts
        spk_rec, _ = forward_pass(snn_model, params_snn['num_steps'], data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # test_loader holds the held-out split created by build_dataloader.
    epoch_acc = test_accuracy(test_loader, snn_model, params_snn['num_steps'], device)
    if epoch_acc >= curr_acc:
        curr_acc = epoch_acc
        best_epoch = epoch + 1

    print(f"Epoch [{epoch + 1}/{num_epochs}] Test Accuracy: {epoch_acc*100:.2f}%")

print(f"Best Test Accuracy: {curr_acc*100:.2f}% (epoch {best_epoch})")