In [None]:
!pip install spikingjelly
!pip install tensorboardX


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 KB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX
  Downloading tensorboardX-2.6-py2.py3-none-any.whl (114 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 KB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6


In [None]:
# from google.colab import files
# uploaded = files.upload()
from google.colab import drive
drive.mount('/content/drive',force_remount=True)
import sys
sys.path.insert(0, 'drive/MyDrive/Colab Notebooks/CS 679 Project')



In [None]:
import time
import logging
from data_builder import *
import argparse
from networks_for_CIFAR import *
from networks_for_ImageNet import *
from utils import *
from layers import *
from tensorboardX import SummaryWriter
from torch.cuda import amp
from schedulers import *
from Regularization import *
import random
import os
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch
from tqdm.autonotebook import tqdm

In [None]:
####################################################
# args                                             #
#                                                  #
####################################################
class Args:
    def __init__(self, eval=False, eval_resume='./raw/models', train_resume='./raw/models', batch_size=16, epochs=3, 
                 learning_rate=1e-1, momentum=0.9, weight_decay=4e-5, seed=9, auto_continue=False, display_interval=10, 
                 save_interval=10, dataset_path='./dataset/', train_dir='./fmnist/train', val_dir='./fmnist/val', 
                 tunable_lif=True, amp=False, modeltag='SNN', gate=[0.6, 0.8, 0.6], static_gate=False, static_param=False, 
                 channel_wise=False, softsimple=False, soft_mode=False, t=3, randomgate=False, fmnist=False, celeba=True):
        self.eval = eval
        self.eval_resume = eval_resume
        self.train_resume = train_resume
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.seed = seed
        self.auto_continue = auto_continue
        self.display_interval = display_interval
        self.save_interval = save_interval
        self.dataset_path = dataset_path
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.tunable_lif = tunable_lif
        self.amp = amp
        self.modeltag = modeltag
        self.gate = gate
        self.static_gate = static_gate
        self.static_param = static_param
        self.channel_wise = channel_wise
        self.softsimple = softsimple
        self.soft_mode = soft_mode
        self.t = t
        self.randomgate = randomgate
        self.fmnist = fmnist
        self.celeba = celeba


# Data Builder

In [None]:
def seed_all(seed=1):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

args = Args()

seed_all(args.seed)

if torch.cuda.device_count() > 1:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

# Log
log_format = '[%(asctime)s] %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
    format=log_format, datefmt='%d %I:%M:%S')
t = time.time()
local_time = time.localtime(t)
if not os.path.exists('./log'):
    os.mkdir('./log')
fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

epochs = 1
initial_dict = {'gate': [0.6, 0.8, 0.6], 'param': [tau, Vth, linear_decay, conduct],
                't': steps, 'static_gate': True, 'static_param': False, 'time_wise': True, 'soft_mode': False}
initial_dict['gate'] = args.gate
initial_dict['static_gate'] = args.static_gate
initial_dict['static_param'] = args.static_param
initial_dict['time_wise'] = False
initial_dict['soft_mode'] = args.soft_mode
if args.t != steps:
    initial_dict['t']=args.t

# In case time step is too large, we intuitively recommend to use the following code to alleviate the linear decay
# initial_dict['param'][2] = initial_dict['param'][1]/(initial_dict['t'] * 2)


use_gpu = False
if torch.cuda.is_available():
    use_gpu = True
if args.fmnist:
    train_loader, val_loader, _ = build_data(dpath=args.dataset_path,dataset='FashionMNIST',
                                              batch_size=args.batch_size, train_val_split=False, workers=2)
else:
    train_loader, val_loader, _ = build_data(dpath=args.dataset_path,dataset='CelebA',
                                              batch_size=args.batch_size, train_val_split=False, workers=2)

print('load data successfully')

print(initial_dict)



# Train & Test 

In [7]:
log_interval = 200

def train(args, model, device, train_loader, optimizer, epoch, writer, scaler=None):
    layer_cnt, gate_score_list = None, None
    t1 = time.time()
    train_loss = 0
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device, dtype=torch.float), target.to(device, dtype=torch.long)
        optimizer.zero_grad()
        if scaler is not None:
            with amp.autocast():
                output, input, mean, log_variances = model(data)
                train_loss = model.loss_function(mean, log_variances, output, input)
                total_loss = train_loss['loss']

        else:
            output, input, mean, log_variances = model(data)
            train_loss = model.loss_function(means=mean, log_variances=log_variances, output=output, target=input)
            total_loss = train_loss['loss']

        if scaler is not None:
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        else:
            total_loss.backward()
            optimizer.step()

        # if batch_idx % log_interval == 0:
        #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         epoch, batch_idx * len(data), len(train_loader.dataset),
        #         100. * batch_idx / len(train_loader),
        #         total_loss.item() / len(data)))
    # average_train_loss = train_loss / len(train_loader.dataset)
    return train_loss


In [8]:
def test(args, model, device, test_loader, epoch, writer, modeltag, dict_params, best= None):
    layer_cnt, gate_score_list = None, None
    test_loss = 0
    model.eval()# inactivate BN
    t1 = time.time()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device, dtype=torch.float), target.to(device, dtype=torch.long)
            output, input, mean, log_variances = model(data)
            total_loss, recon_loss, kld_loss = model.loss_function(mean, log_variances, output, input)
            test_loss+=total_loss

        record_param(args, model, dict=dict_params, epoch=epoch, modeltag=modeltag)
    average_test_loss = test_loss / len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(average_test_loss))
    return average_test_loss


# Model


In [9]:

class VanillaVAE_GLIF(nn.Module):
    def __init__(self, lif_param:dict, tunable_lif=False, in_channels=3, hidden_dims=[32, 64, 128, 256, 512], 
                latent_dim=128,
                beta: int = 4,
                gamma:float = 10.,
                max_capacity: int = 25,
                Capacity_max_iter: int = 1e5,
                loss_type = "beta",
                kld_weight_corrector = 1.0):
        super(VanillaVAE_GLIF, self).__init__()
        
        self.choice_param_name = ['alpha', 'beta', 'gamma']
        self.lifcal_param_name = ['tau', 'Vth', 'leak', 'conduct', 'reVth']
        self.T = lif_param['t']
        self.lif_param = lif_param
        self.tunable_lif = tunable_lif
        self.gamma = gamma
        self.beta = beta
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter
        self.loss_type = loss_type
        self.kld_weight = kld_weight_corrector
        self.latent_dim = latent_dim

        image_channels = in_channels
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    layer.SeqToANNContainer(nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding  = 1), nn.BatchNorm2d(h_dim)),
                    LIFSpike_CW(h_dim, **self.lif_param)
                    # nn.LeakyReLU())
            ))
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = layer.SeqToANNContainer(nn.Linear(hidden_dims[-1]*4, latent_dim))
        self.fc_var = layer.SeqToANNContainer(nn.Linear(hidden_dims[-1]*4, latent_dim))

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]*4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    layer.SeqToANNContainer(nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1])),
                    LIFSpike_CW(hidden_dims[i + 1], **self.lif_param)
            ))
            
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
                            layer.SeqToANNContainer(nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1])),
                            LIFSpike_CW(hidden_dims[-1], **self.lif_param),
                            layer.SeqToANNContainer(nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                                      kernel_size= 3, padding= 1)),
                            nn.Tanh())

        self._initialize_weights()
        print('steps:{}'.format(self.T),
              'init-tau:{}'.format(tau),
              'aa:{}'.format(aa),
              'Vth:{}'.format(Vth)
              )

    def encode(self, input):
        input = input.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=2)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        means = self.fc_mu(result)
        log_variances = self.fc_var(result)
        return [means, log_variances]

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(3, -1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input):
        means, log_variances = self.encode(input)
        z = self.reparameterize(means, log_variances)
        return  [self.decode(z).mean(0), input, means.mean(0), log_variances.mean(0)]
    
    def reparameterize(self, means, log_variances):
        std = torch.exp(0.5 * log_variances)
        eps = torch.randn_like(std)
        return eps * std + means


    
    def loss_function(self, means, log_variances, output, target):
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_variances - means ** 2 - log_variances.exp(), dim = 1), dim = 0)
        reconstruction_loss = F.mse_loss(output, target)
        if self.loss_type == "beta":
          loss = reconstruction_loss + self.beta * self.kld_weight * kld_loss
        else:
          self.C_max = self.C_max.to(device)
          C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
          loss = reconstruction_loss + self.gamma * self.kld_weight* (kld_loss - C).abs()
        return {"loss":loss, "Reconstruction_loss": reconstruction_loss, "KLD_loss": kld_loss}

    def generate(self, x,):
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]

    def sample(self, num_samples):
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(device)

        samples = self.decode(z)
        return samples
    
    def randomize_gate(self):
        for name, m in self.named_modules():
            if all([hasattr(m, i) for i in self.choice_param_name]):
                for i in range(len(self.choice_param_name)):
                    setattr(m, self.choice_param_name[i],
                            nn.Parameter(
                                torch.tensor(init_constrain * (np.random.rand(m.plane) - 0.5)
                                             , dtype=torch.float)
                                        )
                            )
                    

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, tdBatchNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


In [10]:
print(initial_dict)
#prepare the model
if args.fmnist:
    model = VanillaVAE_GLIF(lif_param=initial_dict, in_channels = 1)

elif args.celeba:
    model = VanillaVAE_GLIF(lif_param=initial_dict, in_channels = 3, kld_weight_corrector=.00025)

else: #cifar10
      model = VanillaVAE_GLIF(lif_param=initial_dict)


if args.randomgate:
    randomize_gate(model)
    # model.randomize_gate
    print('randomized gate')

modeltag = args.modeltag
writer = SummaryWriter('./summaries/' + modeltag)

dict_params = create_para_dict(args, model)
# recording the initial GLIF parameters
record_param(args, model, dict=dict_params, epoch=0, modeltag=modeltag)
# classify GLIF-related params
choice_param_name = ['alpha', 'beta', 'gamma']
lifcal_param_name = ['tau', 'Vth', 'leak', 'conduct', 'reVth']
all_params = model.parameters()
lif_params = []
lif_choice_params = []
lif_cal_params = []

for pname, p in model.named_parameters():
    if pname.split('.')[-1] in choice_param_name:
        lif_params.append(p)
        lif_choice_params.append(p)
    elif pname.split('.')[-1] in lifcal_param_name:
        lif_params.append(p)
        lif_cal_params.append(p)
# fetch id
params_id = list(map(id, lif_params))
other_params = list(filter(lambda p: id(p) not in params_id, all_params))
# optimizer & scheduler
if args.tunable_lif:
    init_lr_diff = 10
    optimizer = torch.optim.SGD([
            {'params': other_params},
            {'params': lif_cal_params, "weight_decay": 0.},
            {'params': lif_choice_params, "weight_decay": 0., "lr":args.learning_rate / init_lr_diff}
        ],
            lr=args.learning_rate,
            momentum=0.9,
            weight_decay=args.weight_decay
        )
    scheduler = CosineAnnealingLR_Multi_Params_soft(optimizer,
                                                        T_max=[args.epochs, args.epochs, int(args.epochs)])
else:
    optimizer = torch.optim.SGD([
        {'params': other_params},
        {'params': lif_params, "weight_decay": 0.}
    ],
        lr=args.learning_rate,
        momentum=0.9,
        weight_decay=args.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

# criterion = VAELoss(args)
device = torch.device("cuda" if use_gpu else "cpu")
#Distributed computation
# if torch.cuda.is_available():
#     loss_function = criterion.cuda()
# else:
#     loss_function = criterion.cpu()

if args.auto_continue:
    lastest_model = get_model(modeltag)
    if lastest_model is not None:
        checkpoint = torch.load(lastest_model, map_location='cpu')
        epochs = checkpoint['epoch']
        if torch.cuda.device_count() > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            checkpoint = deletStrmodule(checkpoint)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint, the epoch is {}'.format(epochs))
        dict_params = read_param(epoch=epochs, modeltag=modeltag)
        for i in range(epochs):
            scheduler.step()
        epochs += 1


best = {'acc': 0., 'epoch': 0}

if args.eval:
    lastest_model = get_model(modeltag, addr=args.eval_resume)
    if lastest_model is not None:
        epochs = -1
        checkpoint = torch.load(lastest_model, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        if torch.cuda.device_count() > 1:
            device = torch.device(local_rank)
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank],
                                                        output_device=local_rank,
                                                        find_unused_parameters=False)
        else:
            model = model.to(device)
        test(args, model, device, val_loader, epochs, writer, criterion=criterion,
              modeltag=modeltag, best=best, dict_params=dict_params)
    else:
        print('no model detected')
    exit(0)


if torch.cuda.device_count() > 1:
    device = torch.device(local_rank)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank], output_device=local_rank,
                                                find_unused_parameters=False)
else:
    model = model.to(device)


print('the random seed is {}'.format(args.seed))

# amp
if args.amp:
    scaler = amp.GradScaler()
else:
    scaler = None
for t in range(args.epochs):
    loss_result = train(args, model, device, train_loader, optimizer, epochs, writer,
          scaler=scaler)
    print('====> Epoch: {} Loss: {}'.format(
          t, loss_result))
    # if t % 1 == 0:
    #     test(args, model, device, val_loader, epochs, writer,
    #           modeltag=modeltag, best=best, dict_params=dict_params)
    # else:
    #     pass
    print('and lr now is {}'.format(scheduler.get_last_lr()))
    scheduler.step()
writer.close()


{'gate': [0.6, 0.8, 0.6], 'param': [0.25, 0.5, 0.0625, 0.5], 't': 3, 'static_gate': False, 'static_param': False, 'time_wise': False, 'soft_mode': False}
steps:3 init-tau:0.25 aa:0.5 Vth:0.5
the random seed is 9


  0%|          | 0/10174 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [11]:
# torch.save(model.cpu(), 'vae_glif_tunable.pt')
model = torch.load('vae_glif_sgd_3_epochs.pt').to(device)

In [None]:
test_input, test_label = next(iter((val_loader)))
test_input = test_input.to(device)
test_label = test_label.to(device)

recons = model.generate(test_input)

In [None]:
import matplotlib.pyplot as plt
plt.imshow(test_input[0].cpu().detach().permute(1, 2, 0))

In [None]:
plt.imshow(recons[0].cpu().detach().permute(1, 2, 0))

In [None]:
mu, log_var = model.encode(test_input)
mu[0][0] = mu[0][3] + 0.06
z = model.reparameterize(mu, log_var)
plt.imshow(model.decode(z).mean(0)[0].cpu().detach().permute(1, 2, 0))

In [None]:
# new_perp = model.sample(1)
z = torch.randn(3, 1, 128)
z = z.to(device)
samples = model.decode(z).mean(0)
plt.imshow(  samples[0].cpu().detach().permute(1, 2, 0)  )

In [12]:
optimizer = torch.optim.SGD([
        {'params': other_params},
        {'params': lif_params, "weight_decay": 0.}
    ],
        lr=1e-3,
        momentum=0.9,
        weight_decay=0.
    )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)


In [13]:
if args.amp:
    scaler = amp.GradScaler()
else:
    scaler = None
for t in range(args.epochs):
    loss_result = train(args, model, device, train_loader, optimizer, epochs, writer,
          scaler=scaler)
    print('====> Epoch: {} Loss: {}'.format(
          t, loss_result))
    # if t % 1 == 0:
    #     test(args, model, device, val_loader, epochs, writer,
    #           modeltag=modeltag, best=best, dict_params=dict_params)
    # else:
    #     pass
    print('and lr now is {}'.format(scheduler.get_last_lr()))
    scheduler.step()
writer.close()

  0%|          | 0/10174 [00:00<?, ?it/s]

====> Epoch: 0 Loss: {'loss': tensor(0.0289, device='cuda:0', grad_fn=<AddBackward0>), 'Reconstruction_loss': tensor(0.0247, device='cuda:0', grad_fn=<MseLossBackward0>), 'KLD_loss': tensor(4.2017, device='cuda:0', grad_fn=<MeanBackward1>)}
and lr now is [0.001, 0.001]


  0%|          | 0/10174 [00:00<?, ?it/s]

====> Epoch: 1 Loss: {'loss': tensor(0.0423, device='cuda:0', grad_fn=<AddBackward0>), 'Reconstruction_loss': tensor(0.0368, device='cuda:0', grad_fn=<MseLossBackward0>), 'KLD_loss': tensor(5.5110, device='cuda:0', grad_fn=<MeanBackward1>)}
and lr now is [0.00075, 0.00075]


  0%|          | 0/10174 [00:00<?, ?it/s]

====> Epoch: 2 Loss: {'loss': tensor(0.0235, device='cuda:0', grad_fn=<AddBackward0>), 'Reconstruction_loss': tensor(0.0185, device='cuda:0', grad_fn=<MseLossBackward0>), 'KLD_loss': tensor(4.9947, device='cuda:0', grad_fn=<MeanBackward1>)}
and lr now is [0.0002500000000000001, 0.0002500000000000001]


In [14]:
torch.save(model.cpu(), 'vae_glif_sgd_6_epochs.pt')
# model = torch.load('vae_glif_sgd_3_epochs.pt').to(device)