In [None]:
# Mount Google Drive (Colab) so checkpoints and logs written under PATH persist
# beyond the runtime session.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Project root on Drive, and the experiment name (used as the checkpoint sub-folder).
PATH = '/content/drive/MyDrive/Vision24/'
VERSION = 'dynamicPruning'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Fail fast if the Drive project folder is missing.
project_dir_exists = os.path.isdir(PATH)
print(project_dir_exists)
assert project_dir_exists

True


In [None]:
# One checkpoint folder per experiment VERSION; created if absent.
CHECKPOINT_PATH = os.path.join(PATH, VERSION)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

checkpoint_dir_exists = os.path.isdir(CHECKPOINT_PATH)
print(checkpoint_dir_exists)
assert checkpoint_dir_exists

True


In [None]:
# Show the attached GPU, its driver/CUDA version and GPU memory usage.
# NOTE: this only prints information, it does not check anything. The intent
# ("verify >16GB") refers to GPU memory as reported here, not system RAM.
!nvidia-smi

Mon Nov 18 01:43:33 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0              66W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Install third-party dependencies into the Colab runtime.
# NOTE(review): versions are unpinned, so a re-run may pull different releases.
# Consider `%pip install -q pkg==X.Y.Z` with the versions used for this run.
# torch is presumably already present in the Colab image; reinstalling it
# unpinned may change the build. Confirm before keeping this line.
# `swarms` is not imported in any visible cell; possibly unneeded.
!pip install torch
!pip install zetascale
!pip install swarms
!pip install torchinfo

import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn, Tensor
from zeta.nn import SSM
from einops.layers.torch import Reduce

In [None]:
# For pruning
import math
import torch.nn.utils.prune as prune

def calculate_sparsity(epoch, start_epoch, end_epoch, initial_sparsity, final_sparsity):
    """Return the target sparsity for ``epoch`` under a cosine ramp.

    Before ``start_epoch`` the result is ``initial_sparsity``; from
    ``end_epoch`` on it is ``final_sparsity``. In between it follows

        s(e) = s_i + (s_f - s_i) * (1 - cos(pi * (e - start) / (end - start))) / 2

    which rises slowly at first, fastest mid-way, and flattens near ``end_epoch``.

    ``start_epoch == end_epoch`` is allowed and behaves as a step schedule.
    """
    if epoch < start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        # At epoch == end_epoch the ramp already equals final_sparsity; returning
        # it directly also avoids dividing by zero when start_epoch == end_epoch.
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return initial_sparsity + (final_sparsity - initial_sparsity) * (1 - math.cos(progress * math.pi)) / 2

def dynamic_pruning(
    model,
    epoch,
    start_epoch,
    end_epoch,
    initial_sparsity,
    final_sparsity,
    critical_layers=("first_layer", "to_patch_embedding", "output_head"),
    critical_scale=0.5,
):
    """Re-prune every ``nn.Conv1d`` / ``nn.Linear`` weight to this epoch's target sparsity.

    The target comes from ``calculate_sparsity``. Each matching module is
    L1-unstructured pruned, i.e. the given fraction of its weight entries with
    the smallest magnitude is masked to zero. Biases are not pruned.

    Args:
        model: Module whose sub-modules are pruned in place.
        epoch, start_epoch, end_epoch, initial_sparsity, final_sparsity:
            Forwarded to ``calculate_sparsity``.
        critical_layers: Substrings of module names (as yielded by
            ``model.named_modules()``) whose sparsity is reduced by
            ``critical_scale``. The defaults match ``Vim``'s patch embedding
            (``to_patch_embedding.*``) and classifier (``output_head.*``).
            ``"first_layer"`` is kept for models that use that name.
        critical_scale: Multiplier applied to the sparsity of critical layers.
    """
    sparsity = calculate_sparsity(epoch, start_epoch, end_epoch, initial_sparsity, final_sparsity)
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv1d, nn.Linear)):
            continue

        # Bake the previous mask into the weight and drop the reparametrization.
        # Pruning an already-pruned tensor again would stack a PruningContainer,
        # where `amount` is relative to the still-unpruned entries. After
        # `remove`, the zeroed entries have the smallest magnitude and are
        # re-selected, so `amount` is the fraction of the whole tensor.
        if hasattr(module, "weight_mask"):
            prune.remove(module, "weight")

        if any(key in name for key in critical_layers):
            layer_sparsity = sparsity * critical_scale  # keep more weights in critical layers
        else:
            layer_sparsity = sparsity
        prune.l1_unstructured(module, name="weight", amount=layer_sparsity)
        print(f"Epoch {epoch}: Pruned {name} to {layer_sparsity * 100:.2f}% sparsity")


def check_sparsity(model, epoch):
    """Print, for every sub-module with a non-None ``weight``, the share of exact zeros.

    ``epoch`` is 0-based; the header shows it 1-based.
    """
    print(f"After epoch {epoch + 1}:")
    for module_name, submodule in model.named_modules():
        weight = getattr(submodule, 'weight', None)
        if weight is None:
            continue
        zero_fraction = float(torch.sum(weight == 0)) / weight.nelement()
        print(f"\tSparsity in {module_name}: {zero_fraction:.2%}")



In [None]:
def pair(t):
    """Return ``t`` unchanged when it is already a tuple, otherwise ``(t, t)``.

    Lets size arguments be given as a single value (square) or as a
    ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)


def output_head(dim: int, num_classes: int):
    """Build the classification head: mean-pool over tokens, normalize, project.

    Args:
        dim (int): Feature size of each token entering the head.
        num_classes (int): Number of logits produced per example.

    Returns:
        nn.Sequential: ``Reduce`` (mean over the sequence axis, ``b s d -> b d``),
        ``LayerNorm(dim)`` and ``Linear(dim, num_classes)``, in that order.
    """
    pool_tokens = Reduce("b s d -> b d", "mean")
    normalize = nn.LayerNorm(dim)
    classify = nn.Linear(dim, num_classes)
    return nn.Sequential(pool_tokens, normalize, classify)


class VisionEncoderMambaBlock(nn.Module):
    """Bidirectional Vision-Mamba encoder block on ``(batch, seq, dim)`` tokens.

    ``forward`` runs LayerNorm -> ``proj``, then two branches, each
    ``Conv1d(kernel_size=1) -> Softplus -> SSM``. The second branch scans the
    token sequence in reverse. Both branch outputs are gated by
    ``SiLU(proj(norm(x)))`` and added to the residual input.

    Args:
        dim: Token feature size (input channels of the convs / projection).
        dt_rank: Forwarded to ``SSM``.
        dim_inner: Forwarded to ``SSM``.
        d_state: Forwarded to ``SSM``.
        reverse_backward: If True (default), the backward branch processes the
            token sequence in reverse order. If False, both branches see the
            same (forward) order, as this block originally did; use that to
            evaluate checkpoints trained that way. Parameters and
            ``state_dict`` keys are identical either way.

    NOTE(review): both directions share ``self.ssm``, and the gate reuses
    ``self.proj`` rather than a separate projection. Left unchanged so saved
    checkpoints keep loading.
    """

    def __init__(
        self,
        dim: int,
        dt_rank: int,
        dim_inner: int,
        d_state: int,
        reverse_backward: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.dt_rank = dt_rank
        self.dim_inner = dim_inner
        self.d_state = d_state
        self.reverse_backward = reverse_backward
        self.forward_conv1d = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1)
        self.backward_conv1d = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1)
        self.norm = nn.LayerNorm(dim)
        self.silu = nn.SiLU()
        self.ssm = SSM(dim, dt_rank, dim_inner, d_state)

        self.proj = nn.Linear(dim, dim)  # shared projection: branch input and gate
        self.softplus = nn.Softplus()

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x`` of shape ``(batch, seq, dim)``.

        The output includes the residual ``x``, so it is expected to keep
        ``x``'s shape (assuming the SSM maps ``(b, s, dim)`` to ``(b, s, dim)``).
        """
        skip = x  # residual connection

        # Compute the projection once; it feeds both the gate and the branches.
        x = self.proj(self.norm(x))
        z = self.silu(x)  # gate

        x1 = self.process_direction(x, self.forward_conv1d, self.ssm)
        if self.reverse_backward:
            # Reverse token order for the backward scan, then restore it so
            # positions line up with `z` and `skip`. Assumes the SSM output
            # depends on token order (sequential scan); TODO confirm in zeta.
            x2 = self.process_direction(x.flip(1), self.backward_conv1d, self.ssm).flip(1)
        else:
            x2 = self.process_direction(x, self.backward_conv1d, self.ssm)

        # Out-of-place gating: in-place `*=` on tensors autograd may still need is fragile.
        return x1 * z + x2 * z + skip

    def process_direction(
        self,
        x: Tensor,
        conv1d: nn.Conv1d,
        ssm: SSM,
    ):
        """Run ``conv1d`` -> Softplus -> ``ssm`` on ``x`` of shape ``(b, s, d)``."""
        x = rearrange(x, "b s d -> b d s")  # Conv1d expects channels before length
        x = self.softplus(conv1d(x))
        x = rearrange(x, "b d s -> b s d")
        x = ssm(x)
        return x


class Vim(nn.Module):
    """Vision Mamba classifier: patch embedding -> ``depth`` encoder blocks -> mean-pool head.

    Args:
        dim: Token embedding size.
        dt_rank, dim_inner, d_state: Forwarded to every ``VisionEncoderMambaBlock``.
        num_classes: Number of output logits (must be provided).
        image_size: Int or ``(height, width)``.
        patch_size: Int or ``(height, width)``; each image side must be
            divisible by the matching patch side.
        channels: Number of input image channels.
        dropout: Dropout probability applied to the patch embeddings.
        depth: Number of encoder blocks.
        *args, **kwargs: Forwarded to every ``VisionEncoderMambaBlock``.

    Raises:
        ValueError: If the image size is not divisible by the patch size.
    """

    def __init__(
        self,
        dim: int,
        dt_rank: int = 32,
        dim_inner: int = None,
        d_state: int = None,
        num_classes: int = None,
        image_size: int = 224,
        patch_size: int = 16,
        channels: int = 3,
        dropout: float = 0.1,
        depth: int = 12,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.dt_rank = dt_rank
        self.dim_inner = dim_inner
        self.d_state = d_state
        self.num_classes = num_classes
        self.image_size = image_size
        self.patch_size = patch_size
        self.channels = channels
        self.depth = depth

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        if image_height % patch_height or image_width % patch_width:
            raise ValueError(
                f"image size {(image_height, image_width)} must be divisible by "
                f"patch size {(patch_height, patch_width)}"
            )
        patch_dim = channels * patch_height * patch_width

        # Split the image into non-overlapping patches and flatten each patch.
        # p2 must be the patch *width*; using patch_height breaks non-square patches.
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.Linear(patch_dim, dim),
        )

        self.dropout = nn.Dropout(dropout)
        # NOTE: registered (and counted in parameter totals) but not used in
        # forward; the head mean-pools instead. Kept so saved state_dicts load.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.to_latent = nn.Identity()
        self.layers = nn.ModuleList()  # encoder layers
        for _ in range(depth):
            self.layers.append(
                VisionEncoderMambaBlock(
                    dim=dim,
                    dt_rank=dt_rank,
                    dim_inner=dim_inner,
                    d_state=d_state,
                    *args,
                    **kwargs,
                )
            )
        self.output_head = output_head(dim, num_classes)

    def forward(self, x: Tensor):
        """Map images ``(batch, channels, height, width)`` to ``(batch, num_classes)`` logits."""
        x = self.to_patch_embedding(x)  # (batch, num_patches, dim)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)

        x = self.to_latent(x)
        return self.output_head(x)  # mean over tokens -> LayerNorm -> Linear

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from einops import repeat

# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = Vim(dim=96,
            dt_rank=16,
            dim_inner=96,
            d_state=96,
            num_classes=10,
            image_size=32,
            patch_size=4,
            channels=3,
            dropout=0.1,
            depth=10,)

batch_size = 256
epochs = 100

# Pruning config
start_epoch = 10  # Epoch at which pruning begins
end_epoch = 70  # Epoch at which pruning stops (final_sparsity is reached here)
initial_sparsity = 0.0  # No pruning initially
final_sparsity = 0.6  # Target 60% sparsity

transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# One data root for both splits (the training split previously used './dataa').
DATA_ROOT = './data'
training_data = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
testing_data = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

training_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=4)
# Evaluation order does not affect accuracy; keep it deterministic.
testing_loader = DataLoader(testing_data, batch_size=batch_size, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Quick sanity check of the run size. len(training_loader) is batches per epoch.
steps_per_epoch = len(training_loader)
print('epochs: ', epochs)
print('Training data: ', steps_per_epoch)
print('Total steps: ', epochs * steps_per_epoch)

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Device:', device)

epochs:  100
Training data:  196
Total steps:  19600
Device: cuda


In [None]:
import time
import pandas as pd
import numpy as np

# Wall-clock start time (seconds since the Unix epoch) for the training-time report.
_start = time.time()
print(_start)

__losses = list()  # per-step training loss values, appended by the training loop
__n = 100  # print interval (in steps) for the training loop's loss log

1731894223.0920682


In [None]:
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

# Cosine LR decay over the whole run; stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

steps_per_epoch = len(training_loader)
# While pruning masks are attached, the state_dict holds weight_orig + weight_mask.
last_checkpoint_path = os.path.join(CHECKPOINT_PATH, 'checkpoint_last.pt')

for epoch in range(epochs):
    in_pruning_window = start_epoch <= epoch <= end_epoch
    if in_pruning_window:
        # Re-prune to this epoch's scheduled sparsity before training on it.
        dynamic_pruning(model, epoch, start_epoch, end_epoch, initial_sparsity, final_sparsity)

    model.train()
    running_loss = 0.0
    batches_in_window = 0  # batches accumulated into running_loss since the last log line

    for i, (inputs, labels) in enumerate(training_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # one host sync per step instead of two
        __losses.append(loss_value)
        running_loss += loss_value
        batches_in_window += 1

        # Average over the batches actually accumulated. Logging on
        # `i % __n == 0` would fire on batch 0 and divide a single loss by __n.
        if batches_in_window == __n or i + 1 == steps_per_epoch:
            print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{steps_per_epoch}], Loss: {running_loss / batches_in_window:.4f}")
            running_loss = 0.0
            batches_in_window = 0

    if in_pruning_window:
        check_sparsity(model, epoch)

    scheduler.step()

    # Save after scheduler.step() so the stored scheduler state is the one for
    # the next epoch; 'next_epoch' is where a resumed run should start.
    print(f"Epoch {epoch+1} finished. Saving checkpoint...")
    torch.save(
        {
            'next_epoch': epoch + 1,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        },
        last_checkpoint_path,
    )


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
	Sparsity in layers.4.ssm.dt_proj_layer: 13.67%
	Sparsity in layers.4.proj: 13.66%
	Sparsity in layers.5.forward_conv1d: 13.66%
	Sparsity in layers.5.backward_conv1d: 13.66%
	Sparsity in layers.5.norm: 0.00%
	Sparsity in layers.5.ssm.deltaBC_layer: 13.66%
	Sparsity in layers.5.ssm.dt_proj_layer: 13.67%
	Sparsity in layers.5.proj: 13.66%
	Sparsity in layers.6.forward_conv1d: 13.66%
	Sparsity in layers.6.backward_conv1d: 13.66%
	Sparsity in layers.6.norm: 0.00%
	Sparsity in layers.6.ssm.deltaBC_layer: 13.66%
	Sparsity in layers.6.ssm.dt_proj_layer: 13.67%
	Sparsity in layers.6.proj: 13.66%
	Sparsity in layers.7.forward_conv1d: 13.66%
	Sparsity in layers.7.backward_conv1d: 13.66%
	Sparsity in layers.7.norm: 0.00%
	Sparsity in layers.7.ssm.deltaBC_layer: 13.66%
	Sparsity in layers.7.ssm.dt_proj_layer: 13.67%
	Sparsity in layers.7.proj: 13.66%
	Sparsity in layers.8.forward_conv1d: 13.66%
	Sparsity in layers.8.backward_conv1d: 

In [None]:
# Save the pruned model while it is still reparametrized (weight_orig + weight_mask).
torch.save(model.state_dict(), os.path.join(CHECKPOINT_PATH, 'final_checkpoint_with_masks.pt'))

# Make pruning permanent: bake each mask into its weight and drop the
# reparametrization, so the next state_dict holds plain `weight` tensors.
pruned_modules = [
    (name, module)
    for name, module in model.named_modules()
    if hasattr(module, "weight_mask") and isinstance(module, (nn.Conv1d, nn.Linear))
]
for name, module in pruned_modules:
    prune.remove(module, "weight")
    print(f"Removed pruning mask from {name}")

torch.save(model.state_dict(), os.path.join(CHECKPOINT_PATH, 'final_checkpoint.pt'))

Removed pruning mask from to_patch_embedding.1
Removed pruning mask from layers.0.forward_conv1d
Removed pruning mask from layers.0.backward_conv1d
Removed pruning mask from layers.0.ssm.deltaBC_layer
Removed pruning mask from layers.0.ssm.dt_proj_layer
Removed pruning mask from layers.0.proj
Removed pruning mask from layers.1.forward_conv1d
Removed pruning mask from layers.1.backward_conv1d
Removed pruning mask from layers.1.ssm.deltaBC_layer
Removed pruning mask from layers.1.ssm.dt_proj_layer
Removed pruning mask from layers.1.proj
Removed pruning mask from layers.2.forward_conv1d
Removed pruning mask from layers.2.backward_conv1d
Removed pruning mask from layers.2.ssm.deltaBC_layer
Removed pruning mask from layers.2.ssm.dt_proj_layer
Removed pruning mask from layers.2.proj
Removed pruning mask from layers.3.forward_conv1d
Removed pruning mask from layers.3.backward_conv1d
Removed pruning mask from layers.3.ssm.deltaBC_layer
Removed pruning mask from layers.3.ssm.dt_proj_layer
Remov

In [None]:
_end = time.time()
print(_end)

# Training-time report; values are time.time() readings, so the total is in seconds.
# The file name's spelling is kept exactly as before.
elapsed_seconds = _end - _start
time_report_path = os.path.join(CHECKPOINT_PATH, 'trainning_time.txt')
with open(time_report_path, 'w') as report_file:
    report_file.write(f"Start Time: {_start}\nEnd Time:{_end}\nTotal Time:{elapsed_seconds}")

# Per-step training losses, one row per optimizer step.
DF = pd.DataFrame(np.array(__losses))
DF.to_csv(os.path.join(CHECKPOINT_PATH, 'losses.csv'))


1731910072.9769154


In [None]:

from torchinfo import summary
# One batch of CIFAR-10-sized images: (batch, channels, height, width).
summary_input_size = (batch_size, 3, 32, 32)
summary(model, input_size=summary_input_size)

Layer (type:depth-idx)                   Output Shape              Param #
Vim                                      [256, 10]                 96
├─Sequential: 1-1                        [256, 64, 96]             --
│    └─Rearrange: 2-1                    [256, 64, 48]             --
│    └─Linear: 2-2                       [256, 64, 96]             4,704
├─Dropout: 1-2                           [256, 64, 96]             --
├─ModuleList: 1-3                        --                        --
│    └─VisionEncoderMambaBlock: 2-3      [256, 64, 96]             --
│    │    └─LayerNorm: 3-1               [256, 64, 96]             192
│    │    └─Linear: 3-2                  [256, 64, 96]             9,312
│    │    └─SiLU: 3-3                    [256, 64, 96]             --
│    │    └─Linear: 3-4                  [256, 64, 96]             (recursive)
│    │    └─Conv1d: 3-5                  [256, 96, 64]             9,312
│    │    └─Softplus: 3-6                [256, 96, 64]            

In [None]:
model.eval()  # disables dropout for evaluation
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in testing_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        predicted = outputs.max(dim=1).indices  # highest-scoring class per image
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"accuracy of the model test images: {accuracy:.2f}%")

accuracy of the model test images: 70.02%


In [None]:
# Full (non-abbreviated) allocator statistics for the current CUDA device.
cuda_memory_report = torch.cuda.memory_summary(device=None, abbreviated=False)
print(cuda_memory_report)

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  29278 KiB |  38768 MiB |   5439 TiB |   5439 TiB |
|       from large pool |  19712 KiB |  38753 MiB |   5438 TiB |   5438 TiB |
|       from small pool |   9566 KiB |     17 MiB |      1 TiB |      1 TiB |
|---------------------------------------------------------------------------|
| Active memory         |  29278 KiB |  38768 MiB |   5439 TiB |   5439 TiB |
|       from large pool |  19712 KiB |  38753 MiB |   5438 TiB |   5438 TiB |
|       from small pool |   9566 KiB |     17 MiB |      1 TiB |      1 TiB |
|---------------------------------------------------------------