In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torch.nn.utils as utils
import torch.nn.utils.parametrizations as param
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

import torchdyn 
from torchdyn.core import NeuralODE

import os
import cv2 
import numpy as np 
from tqdm import tqdm 
import matplotlib.pyplot as plt 

import warnings 
warnings.filterwarnings("ignore")

## DATA

In [2]:
# Step 1: Device setup
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Step 2: Load and Normalize the CIFAR-10 Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST has 1 channel
])


full_train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

# Step 3: Split the original training set into training (70%) and validation set (30%)
train_size = int(0.7 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

Using device: mps
Files already downloaded and verified
Files already downloaded and verified


## MODELS

### CNNODE Block

In [3]:
class CNNODEBlock(nn.Module):
    def __init__(self, filters:int, kernel:int=3, expand:float=1, drop=0.3):
        super(CNNODEBlock, self).__init__()
        xilters = int(filters * expand)
        self.conv1 = utils.spectral_norm(nn.Conv2d(filters, xilters, kernel - 2, padding= (kernel - 2) // 2), )
        self.conv2 = utils.spectral_norm(nn.Conv2d(filters, xilters, kernel    , padding= kernel // 2))
        self.conv3 = utils.spectral_norm(nn.Conv2d(filters, xilters, kernel + 2, padding= (kernel + 2) // 2))
        self.conv =  utils.spectral_norm(nn.Conv2d(xilters*3, filters, 1))
        self.drop = drop

        self.act = nn.SiLU()

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x = torch.concat([x1, x2, x3], dim=1)
        x = self.act(x)
        return self.conv(x)

### UPsampling Conv Block 

In [4]:
class UPBlock(nn.Module):
    def __init__(self, infilter:int, outfilter:int, kernel:int, moment:float, drop:float):
        super(UPBlock, self).__init__()
        self.conv1 = nn.Conv2d(infilter, outfilter, kernel, padding=kernel // 2, stride=2)
        self.norm1 = nn.BatchNorm2d(outfilter, momentum=moment)

        self.conv2 = nn.Conv2d(outfilter, outfilter, kernel, padding=kernel // 2)
        self.norm2 = nn.BatchNorm2d(outfilter, momentum=moment)

        self.act = nn.SiLU(0.1)
        self.drop = nn.Dropout2d(drop)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.drop(self.act(self.norm1(x)))
        x = self.conv2(x)
        x = self.drop(self.act(self.norm2(x)))
        return x

### Dense Block

In [5]:
class DenseBlock(nn.Module):
    def __init__(self, filters:int,  drop:float, classes:int):
        super(DenseBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.flat = nn.Flatten()
        self.drop = nn.Dropout(drop)
        self.final = nn.Linear(filters, classes)
    
    def forward(self,x):
        x = self.drop(self.flat(self.pool(x)))
        return F.softmax(self.final(x), dim=-1)

### CNNODE

In [6]:
class CNNODE(nn.Module):
    def __init__(self, num=4, filters=64, classes=10, gf=2, kernel=3, moment=0.99, drop=0.5, dropnode=0.1, dropconv=0.2):
        super(CNNODE, self).__init__()
        self.cnnode = nn.ModuleList([])
        self.upsamp = nn.ModuleList([])
        self.conv = nn.Conv2d(3, filters, 7, padding=2)
        self.norm = nn.BatchNorm2d(filters, momentum=moment)
        self.relu = nn.ReLU()
        # self.pool = nn.MaxPool2d(2, 2)

        for _ in range(num):
            f = CNNODEBlock(int(filters*gf), kernel, moment, dropnode)
            model = NeuralODE(f, sensitivity='adjoint', solver='rk4',
                               solver_adjoint='dopri5', atol_adjoint=1e-6, rtol_adjoint=1e-6)
            self.cnnode.append(model)
            self.upsamp.append(UPBlock(filters, int(filters*gf), kernel, moment, dropconv))
            filters = int(filters*gf)
        self.final = DenseBlock(filters, drop, classes)
        

    def forward(self, x, t_span):
        x = self.relu(self.norm(self.conv(x)))
        for neuralode, neuralnetwork in zip(self.cnnode, self.upsamp):
            x = neuralnetwork(x)
            # t, x = neuralode(x, t_span)
            # x = x[-1]
            
        return self.final(x)  

## TRAIN

### Training Class

### Training Values

In [None]:
t_span = torch.linspace(0, 0.8, 8).to(device)

model = CNNODE().to(device)  # Move the model to MPS or CPU
# state_dict = torch.load("mk2_dict.pt")  # Load the state_dict from the file
# model.load_state_dict(state_dict)
criterion = nn.CrossEntropyLoss()  # Cross-Entropy loss for classification
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=3e-3)  # Adam optimizer
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.75)

def l1_loss_on_bias(model, l1_lambda):
    l1_loss = 0.0
    for param in model.parameters():
        if param.requires_grad and param.ndimension() == 1:  # Check if it's a bias (usually 1D)
            l1_loss += torch.sum(torch.abs(param))
    return l1_lambda * l1_loss


# Step 5: Combined Training and Validation Loop with tqdm progress bar and accuracy updates
def train_and_validate(model, train_loader, val_loader, optimizer, criterion, scheduler, epochs=5):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        # Training Loop with tqdm
        loop = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{epochs}]', leave=False)
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)  # Move data to MPS or CPU
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images, t_span)
            loss = criterion(outputs, labels)

            l1_penalty = l1_loss_on_bias(model, 5e-5)
            loss = loss + l1_penalty


            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Accumulate training stats
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            # Calculate batch accuracy
            batch_accuracy = 100 * (correct_train / total_train)

            # Update tqdm progress bar with current loss and accuracy
            loop.set_postfix(train_loss=loss.item(), train_accuracy=batch_accuracy)

        # Adjust the learning rate
        scheduler.step()

        # Calculate training loss and accuracy
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct_train / total_train

        # Validation Loop with tqdm
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            val_loop = tqdm(val_loader, desc="Validation", leave=False)
            for images, labels in val_loop:
                images, labels = images.to(device), labels.to(device)  # Move validation data to device
                outputs = model(images, t_span)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                # Calculate batch accuracy
                batch_val_accuracy = 100 * (correct_val / total_val)

                # Update tqdm progress bar with current validation loss and accuracy
                val_loop.set_postfix(val_loss=loss.item(), val_accuracy=batch_val_accuracy)

        # Calculate validation loss and accuracy
        val_loss /= len(val_loader)
        val_accuracy = 100 * correct_val / total_val

        # Print epoch results
        if (epoch+1)%10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}]')
            print(f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%')
            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

# Step 6: Train and Validate the Model
train_and_validate(model, train_loader, val_loader, optimizer, criterion, scheduler, epochs=240)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.
Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


                                                                                                       

Epoch [10/240]
Training Loss: 1.9568, Training Accuracy: 68.43%
Validation Loss: 1.7953, Validation Accuracy: 67.03%


                                                                                                       

Epoch [20/240]
Training Loss: 1.8431, Training Accuracy: 77.75%
Validation Loss: 1.7455, Validation Accuracy: 72.25%


                                                                                                       

Epoch [30/240]
Training Loss: 1.7747, Training Accuracy: 82.72%
Validation Loss: 1.7295, Validation Accuracy: 73.93%


                                                                                                       

Epoch [40/240]
Training Loss: 1.7172, Training Accuracy: 86.98%
Validation Loss: 1.7183, Validation Accuracy: 75.33%


                                                                                                       

Epoch [50/240]
Training Loss: 1.6919, Training Accuracy: 88.48%
Validation Loss: 1.7266, Validation Accuracy: 74.71%


                                                                                                       

Epoch [60/240]
Training Loss: 1.6765, Training Accuracy: 89.21%
Validation Loss: 1.7326, Validation Accuracy: 74.57%


                                                                                                       

Epoch [70/240]
Training Loss: 1.6541, Training Accuracy: 91.08%
Validation Loss: 1.7311, Validation Accuracy: 75.31%


                                                                                                       

Epoch [80/240]
Training Loss: 1.6505, Training Accuracy: 91.46%
Validation Loss: 1.7491, Validation Accuracy: 73.72%


                                                                                                       

Epoch [90/240]
Training Loss: 1.6514, Training Accuracy: 91.29%
Validation Loss: 1.7576, Validation Accuracy: 73.51%


                                                                                                        

Epoch [100/240]
Training Loss: 1.6321, Training Accuracy: 92.69%
Validation Loss: 1.7548, Validation Accuracy: 73.63%


                                                                                                        

Epoch [110/240]
Training Loss: 1.6299, Training Accuracy: 92.73%
Validation Loss: 1.7589, Validation Accuracy: 74.07%


                                                                                                        

Epoch [120/240]
Training Loss: 1.6243, Training Accuracy: 93.03%
Validation Loss: 1.7753, Validation Accuracy: 71.89%


                                                                                                        

Epoch [130/240]
Training Loss: 1.6079, Training Accuracy: 94.15%
Validation Loss: 1.7589, Validation Accuracy: 73.89%


                                                                                                        

Epoch [140/240]
Training Loss: 1.6058, Training Accuracy: 94.22%
Validation Loss: 1.7632, Validation Accuracy: 73.24%


                                                                                                        

Epoch [150/240]
Training Loss: 1.6034, Training Accuracy: 94.41%
Validation Loss: 1.7637, Validation Accuracy: 73.34%


                                                                                                        

Epoch [160/240]
Training Loss: 1.5921, Training Accuracy: 95.06%
Validation Loss: 1.7695, Validation Accuracy: 72.62%


Epoch [166/240]:  88%|████████▊ | 121/137 [00:18<00:02,  6.67it/s, train_accuracy=95.1, train_loss=1.59]

In [8]:
def test_model(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        test_loop = tqdm(test_loader, desc="Testing", leave=False)
        for images, labels in test_loop:
            images, labels = images.to(device), labels.to(device)  # Move test data to MPS or CPU
            outputs = model(images, t_span)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

            # Calculate batch accuracy
            batch_test_accuracy = 100 * (correct_test / total_test)

            # Update tqdm progress bar with current test loss and accuracy
            test_loop.set_postfix(test_loss=loss.item(), test_accuracy=batch_test_accuracy)

    # Calculate test loss and accuracy
    test_loss /= len(test_loader)
    test_accuracy = 100 * correct_test / total_test

    # Print 0est results
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# After training and validation, run the test evaluation
test_model(model, test_loader, criterion)

                                                                                            

Test Loss: 1.7762, Test Accuracy: 71.55%




In [9]:
# model = torch.load("mk2.pt")
torch.save(model.state_dict(), "mk3_dict.pt")