In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import get_cosine_schedule_with_warmup
from models import DendResNet
from models.modules import DendriticConv2d, DendriticLinear

### Parameters ###
batch_size = 64
resolution = 30
dt = 0.01
workers = 8
num_training_steps = 50000
evaluation_steps = 250

In [4]:
DEVICE = "cuda:0"

class CifarDendConv(nn.Module):
    def __init__(self, resolution=30, dt=0.001, in_channels=3):
        super().__init__()
        self.conv1 = DendriticConv2d(in_channels,6,kernel_size=5,stride=1, resolution=resolution, dt=dt)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = DendriticConv2d(6,16,kernel_size=5,stride=1, resolution=resolution, dt=dt)
        self.fc1 = DendriticLinear(16*5*5, 120, resolution=resolution, dt=dt)
        self.fc2 = DendriticLinear(120, 84, resolution=resolution, dt=dt)
        self.fc3 = DendriticLinear(84, 10, resolution=resolution, dt=dt)

    def forward(self, x):
        # print(torch.abs(x).mean())
        x = self.pool(F.sigmoid(self.conv1(x)))
        # print(torch.abs(x).mean())
        x = self.pool(F.sigmoid(self.conv2(x)))
        # print(torch.abs(x).mean())
        x = torch.flatten(x,1)
        # print(torch.abs(x).mean())
        x = F.sigmoid(self.fc1(x))
        # print(torch.abs(x).mean())
        x = F.sigmoid(self.fc2(x))
        # print(torch.abs(x).mean())
        x = self.fc3(x)
        # print(torch.abs(x).mean())
        return x

class CifarConv(nn.Module):
    def __init__(self, resolution=30, dt=0.001, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,6,kernel_size=5,stride=1)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,kernel_size=5,stride=1)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # print(torch.abs(x).mean())
        x = self.pool(F.sigmoid(self.conv1(x)))
        # print(torch.abs(x).mean())
        x = self.pool(F.sigmoid(self.conv2(x)))
        # print(torch.abs(x).mean())
        x = torch.flatten(x,1)
        # print(torch.abs(x).mean())
        x = F.sigmoid(self.fc1(x))
        # print(torch.abs(x).mean())
        x = F.sigmoid(self.fc2(x))
        # print(torch.abs(x).mean())
        x = self.fc3(x)
        # print(torch.abs(x).mean())
        return x
        
### Define Model ###
model = CifarDendConv(resolution=resolution, dt=dt)
model = model.to(DEVICE)

### Define Datasets ###
train = CIFAR10('datasets', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.4914, 0.4822, 0.4465),(0.247, 0.243, 0.261))
                           ]))
    
test = CIFAR10('datasets', train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465),(0.247, 0.243, 0.261))
                       ]))


train_loader = DataLoader(train, 
                          batch_size=batch_size, 
                          shuffle=True, 
                          num_workers=workers, 
                          pin_memory=True)

test_loader = DataLoader(test, 
                         batch_size=batch_size, 
                         shuffle=False, 
                         num_workers=workers, 
                         pin_memory=True)

optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.005)

loss_fn = torch.nn.CrossEntropyLoss()

lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, 
                                               num_training_steps=num_training_steps, 
                                               num_warmup_steps=250)




Files already downloaded and verified
Files already downloaded and verified


In [None]:
training_logs = {"completed_steps": [], 
                 "training_loss": [], 
                 "testing_loss": [],
                 "training_acc": [], 
                 "testing_acc": []}

train = True
completed_steps = 0
train_loss, test_loss, train_acc, test_acc = [], [], [], []
progress_bar = tqdm(range(num_training_steps))

while train:

    model.train()

    for X, y in train_loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        pred = model(X)
        
    
        # print(pred)
        # print(y)
        loss = loss_fn(pred, y)
        train_loss.append(loss.item())

        predictions = torch.argmax(pred, axis=1)
        accuracy = (predictions == y).sum() / (len(predictions))
        train_acc.append(accuracy.item())

        loss.backward()

        # for name, param in model.named_parameters():
        #     print(param.grad)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        lr_scheduler.step()

        if completed_steps % evaluation_steps == 0:

            # model.eval()

            # for X, y in tqdm(test_loader):
            #     X, y = X.to(DEVICE), y.to(DEVICE)
            #     with torch.no_grad():
            #         pred = model(X)
            #     loss = loss_fn(pred, y)
            #     test_loss.append(loss.item())

            #     predictions = torch.argmax(pred, axis=1)
            #     accuracy = (predictions == y).sum() / len(predictions)
            #     test_acc.append(accuracy.item())

            ### Save Results ###
            avg_train_loss = np.mean(train_loss)
            # avg_test_loss = np.mean(test_loss)
            avg_train_acc = np.mean(train_acc)
            # avg_test_acc = np.mean(test_acc)

            print("Training Loss:", avg_train_loss)
            # print("Testing Loss:", avg_test_loss)
            print("Training Acc:", avg_train_acc)
            # print("Testing Acc:", avg_test_acc)
            
            training_logs["completed_steps"].append(completed_steps)
            training_logs["training_loss"].append(avg_train_loss)
            # training_logs["testing_loss"].append(avg_test_loss)
            training_logs["training_acc"].append(avg_train_acc)
            # training_logs["testing_acc"].append(avg_test_acc)

        completed_steps += 1 
        progress_bar.update(1)

        

  0%|          | 0/50000 [00:00<?, ?it/s]

Training Loss: 42.548736572265625
Training Acc: 0.0625
Training Loss: 35.269591000925494
Training Acc: 0.10346115537848606
Training Loss: 29.67610917690985
Training Acc: 0.0997692115768463
Training Loss: 26.4467038384449
Training Acc: 0.09928428761651131
Training Loss: 24.239008956855827
Training Acc: 0.10141421078921078
Training Loss: 22.399147705303776
Training Acc: 0.10231814548361311
Training Loss: 20.827378056988092
Training Acc: 0.10388907395069953
Training Loss: 19.39265432118144
Training Acc: 0.10522558537978298
Training Loss: 18.008123082318704
Training Acc: 0.10611100699650175
Training Loss: 16.633152545478172
Training Acc: 0.10578631719235895
Training Loss: 15.306959760613271
Training Acc: 0.10583891443422631
Training Loss: 14.160390968974484
Training Acc: 0.10579675572519084
Training Loss: 13.188381321308654
Training Acc: 0.10566269576807731
Training Loss: 12.357779202920332
Training Acc: 0.10510708243617349
Training Loss: 11.64405546914302
Training Acc: 0.10490752642102256