In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import random
import matplotlib.pyplot as plt
from torch.utils.data import random_split
import torch.optim as optim
import lovely_tensors as lt
import os
import tarfile
import urllib.request

from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder

lt.monkey_patch()

import os
os.environ['TORCH_LOGS'] = "output_code"
##https://github.com/fastai/imagenette

## https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz

In [2]:
# Set the URL and local path for the dataset
url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
local_path = "imagenette2-160.tgz"

# Download the dataset
urllib.request.urlretrieve(url, local_path)

# Extract the dataset
with tarfile.open(local_path, "r:gz") as tar:
    tar.extractall()

# Remove the downloaded archive
os.remove(local_path)

# Set the path to the extracted dataset
dataset_path = "imagenette2-160"

# Define the transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((160, 160)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset using ImageFolder
dataset = ImageFolder(root=dataset_path, transform=transform)

# Split the dataset into train and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Print the sizes of the train and validation sets
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 10715
Validation dataset size: 2679


In [3]:
class ImageMLP(nn.Module):
    def __init__(self):
        super(ImageMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.linear = BitLinearTrain(160 * 160 * 3, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear(x)
        x = self.softmax(x)
        return x

class BitLinearTrain(nn.Linear):
    def forward(self, x):
        w = self.weight
        x_norm = x
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

In [4]:
model = ImageMLP()
image = torch.randn(1, 3, 160, 160)  # Batch size 1, single channel, 160x160 image
output = model(image)
print(output)

tensor[1, 10] x∈[0.033, 0.165] μ=0.100 σ=0.050 grad SoftmaxBackward0 [[0.066, 0.108, 0.157, 0.165, 0.143, 0.033, 0.041, 0.141, 0.090, 0.057]]


In [5]:
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move model to device
model = ImageMLP().to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate loss and accuracy
        train_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    
    train_loss = train_loss / len(train_dataset)
    train_acc = train_correct / train_total
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Calculate loss and accuracy
            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_loss = val_loss / len(val_dataset)
    val_acc = val_correct / val_total
    
    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch [1/1], Train Loss: 1.9501, Train Acc: 0.5159, Val Loss: 1.9088, Val Acc: 0.5558


In [6]:
for name, param in model.named_parameters():
    if 'weight' in name:
        print(f"Layer: {name}")
        print(f"Weight Shape: {param.shape}")
        print(f"Weight Data: {param}")
        print("---")


Layer: linear.weight
Weight Shape: torch.Size([10, 76800])
Weight Data: Parameter containing:
Parameter[10, 76800] n=768000 (2.9Mb) x∈[-0.012, 0.011] μ=4.038e-05 σ=0.003 grad cuda:0
---


In [7]:
from torch import Tensor 

def roundclip(x, a, b):
    return torch.max(a, torch.min(b, torch.round(x)))

def quantize_weights(weights):
    # Compute the average absolute value of the weight matrix
    gamma = torch.mean(torch.abs(weights))
    
    # Scale the weight matrix by the average absolute value
    scaled_weights = weights / (gamma + 1e-8)
    
    # Round each scaled weight to the nearest integer in {-1, 0, +1}
    quantized_weights = roundclip(scaled_weights, Tensor([-1]), Tensor([1]))
    
    return quantized_weights

In [8]:
weights = model.linear.weight.detach().cpu()

# Quantize the weights
quantized_weights = quantize_weights(weights)

# Assign the quantized weights back to the model
model.linear.weight.data = quantized_weights

In [9]:
quantized_weights = quantized_weights.to(device)

In [10]:
def down_size(size):
    assert size[-1] % 4 == 0, f"{size} last dim not divisible by four"
    return (*size[:-1], size[-1] // 4)

def up_size(size):
    return (*size[:-1], size[-1] * 4)

#unpack int8
@torch.compile
def unpack_uint8_to_trinary2(uint8_data) -> torch.Tensor:
    """Get the original weight from the normalized float weight format"""
    # since we are using uint8 we will decode 4 entries per byte
    shape = uint8_data.shape
    first_elements = ((uint8_data >> 6) & 0b11).to(torch.int8) - 1
    second_elements = ((uint8_data >> 4) & 0b11).to(torch.int8) - 1
    third_elements = ((uint8_data >> 2) & 0b11).to(torch.int8) - 1
    fourth_elements = (uint8_data & 0b11).to(torch.int8) - 1
    return torch.stack([first_elements, second_elements, third_elements, fourth_elements], dim=-1).view(up_size(shape))

#packing uint8
@torch.compile
def pack_int2(uint8_data) -> torch.Tensor:
    # converting to uint8 for operations
    shape = uint8_data.shape
    assert shape[-1] % 4 == 0
    uint8_data = uint8_data.contiguous().view(-1)
    packed_data = (uint8_data[::4] << 6 | uint8_data[1::4] << 4 | uint8_data[2::4] << 2 | uint8_data[3::4]).view(down_size(shape))
    return packed_data

In [11]:
shifted_layer = (quantized_weights + 1.0).to(torch.uint8).to(device)
shifted_layer

tensor[10, 76800] u8 n=768000 (0.7Mb) x∈[0, 2] μ=1.012 σ=0.845 cuda:0

In [12]:
packed = pack_int2(shifted_layer).to(device)
unpacked = unpack_uint8_to_trinary2(packed).to(device)

In [13]:
print(unpacked)
print(unpacked.dtype)
print(unpacked.allclose(quantized_weights.to(torch.int8)))
assert(unpacked.allclose(quantized_weights.to(torch.int8)))

tensor[10, 76800] i8 n=768000 (0.7Mb) x∈[-1, 1] μ=0.012 σ=0.845 cuda:0
torch.int8
True
