In [1]:
!pip install torchsummary
!export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'
import torch
import torchvision
from torch import Tensor
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary

import os
import zipfile
from torch.utils.data import Dataset
import json 
from PIL import Image
import gc

# ---- Run configuration -------------------------------------------------------
cpu_device = torch.device('cpu')
cuda_device = torch.device('cuda')
DTYPENET = torch.float32
DTYPEDATA = torch.float16  # NOTE(review): unused in the cells shown; data stays float32 from ToTensor
DEVICE = cuda_device if torch.cuda.is_available() else cpu_device
BATCH_SIZE = 256
LR = 1e-4
SEED = 0
# Anomaly detection re-checks every backward op for NaN/Inf and is documented as a
# debugging-only mode that slows execution; switch on only while hunting a NaN.
DETECT_ANOMALY = False

# Seed every RNG used by the notebook so weight init and shuffling are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()
print(DEVICE)

torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

cpu


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7e6836e907f0>

In [2]:
# ToTensor turns images into float tensors in [0, 1]; Normalize with mean = std = 0.5
# then maps every channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: fixed order so evaluation is repeatable.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Label index -> class name, in CIFAR-10 label order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 44.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
class BitConv2d(nn.Conv2d):
    """Conv2d with ternary {-1, 0, +1} weights and absmax-scaled activations (BitNet style).

    Args:
        isDepth: when True the convolution result is returned still in the scaled
            (quantized) domain; when False it is rescaled by
            ``input_gamma * weight_abs_mean / quantization_range``.
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
        num_bits: activation bit-width; ``quantization_range = 2 ** (num_bits - 1)``.
    """

    def __init__(self, isDepth, *args, num_bits: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_bits = num_bits

        self.eps: float = 1e-5
        self.quantization_range: int = 2 ** (num_bits - 1)  # Q_b in the paper
        self.isDepth = isDepth

    def ste_weights(self, weights_gamma: float) -> Tensor:
        """Ternarize ``self.weight`` as clamp(round(W / gamma), -1, 1) with a straight-through gradient."""
        eps: float = 1e-7
        scaled_weights: Tensor = self.weight / (weights_gamma + eps)
        bin_weights_no_grad: Tensor = torch.clamp(torch.round(scaled_weights), min=-1, max=1)
        # Forward value is the ternary tensor; backward treats the op as identity w.r.t. W.
        bin_weights_with_grad: Tensor = (bin_weights_no_grad - self.weight).detach() + self.weight
        return bin_weights_with_grad

    def binarize_weights(self, weights_gamma: float) -> Tensor:
        """Return the ternarized weights used by ``forward``."""
        return self.ste_weights(weights_gamma)

    def quantize_activations(self, _input: Tensor, input_gamma: float) -> Tensor:
        """Scale activations by Q_b / gamma and clip to (-Q_b, Q_b) (Equation 4, BitNet paper).

        Values are scaled and clipped but not rounded in this layer.
        """
        return torch.clamp(
            _input * self.quantization_range / input_gamma,
            -self.quantization_range + self.eps,
            self.quantization_range - self.eps,
        )

    def dequantize_activations(self, _input: Tensor, input_gamma: float, beta: float) -> Tensor:
        """Undo the activation (gamma / Q_b) and weight (beta) scaling."""
        return _input * input_gamma * beta / self.quantization_range

    def forward(self, _input: Tensor) -> Tensor:
        normalized_input: Tensor = nn.functional.layer_norm(_input, _input.shape[1:])
        # Absmax over the whole batch. A constant input normalizes to all zeros; without
        # the eps floor the scaling below would compute 0 / 0 = NaN.
        input_gamma: float = max(normalized_input.abs().max().item(), self.eps)
        weight_abs_mean: float = self.weight.abs().mean().item()

        binarized_weights = self.binarize_weights(weight_abs_mean)
        input_quant = self.quantize_activations(normalized_input, input_gamma)
        # When the result is dequantized, the bias is added afterwards so it is not
        # multiplied by the batch-dependent factor input_gamma * weight_abs_mean / Q_b.
        output = nn.functional.conv2d(
            input=input_quant,
            weight=binarized_weights,
            bias=self.bias if self.isDepth else None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        if not self.isDepth:
            output = self.dequantize_activations(output, input_gamma, weight_abs_mean)
            if self.bias is not None:
                output = output + self.bias.view(1, -1, 1, 1)

        return output
    
class BitLinear(nn.Linear):
    """Linear layer with ternary {-1, 0, +1} weights and num_bits absmax activation quantization.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        num_bits: activation bit-width; ``quantization_range = 2 ** (num_bits - 1)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        num_bits: int = 8,
    ):
        super().__init__(in_features, out_features, bias)
        self.eps: float = 1e-5
        self.quantization_range: int = 2 ** (num_bits - 1)  # Q_b in the paper

    def ste_weights(self, weights_gamma: float) -> Tensor:
        """Ternarize ``self.weight`` as clamp(round(W / gamma), -1, 1) with a straight-through gradient."""
        eps: float = 1e-7
        scaled_weights: Tensor = self.weight / (weights_gamma + eps)
        bin_weights_no_grad: Tensor = torch.clamp(torch.round(scaled_weights), min=-1, max=1)
        # Forward value is the ternary tensor; backward treats the op as identity w.r.t. W.
        bin_weights_with_grad: Tensor = (bin_weights_no_grad - self.weight).detach() + self.weight
        return bin_weights_with_grad

    def binarize_weights(self, weights_gamma: float) -> Tensor:
        """Return the ternarized weights used by ``forward``."""
        return self.ste_weights(weights_gamma)

    def quantize_activations(self, _input: Tensor, input_gamma: float) -> Tensor:
        """Scale by Q_b / gamma, clip to (-Q_b, Q_b) (Equation 4, BitNet paper) and floor to integers."""
        scaled = torch.clamp(
            _input * self.quantization_range / input_gamma,
            -self.quantization_range + self.eps,
            self.quantization_range - self.eps,
        )
        # torch.floor has zero gradient everywhere, which would stop all gradient flow into
        # this layer's input (and every layer before it). The straight-through form keeps
        # the floored forward value while passing the clamp's gradient through.
        return scaled + (torch.floor(scaled) - scaled).detach()

    def dequantize_activations(self, _input: Tensor, input_gamma: float, beta: float) -> Tensor:
        """Undo the activation (gamma / Q_b) and weight (beta) scaling."""
        return _input * input_gamma * beta / self.quantization_range

    def forward(self, _input: Tensor) -> Tensor:
        normalized_input: Tensor = nn.functional.layer_norm(_input, _input.shape[1:])
        # Absmax over the whole batch, floored at eps so a constant input (all zeros after
        # layer_norm) does not produce 0 / 0 = NaN.
        input_gamma: float = max(normalized_input.abs().max().item(), self.eps)
        weight_abs_mean: float = self.weight.abs().mean().item()

        binarized_weights = self.binarize_weights(weight_abs_mean)
        input_quant = self.quantize_activations(normalized_input, input_gamma)
        output = torch.nn.functional.linear(input_quant, binarized_weights)
        output = self.dequantize_activations(output, input_gamma, weight_abs_mean)
        # Bias is added in the real-valued domain so it is not rescaled per batch.
        if self.bias is not None:
            output = output + self.bias

        return output

class CNNBLOCK_DS(nn.Module):
    """Depthwise-separable block: depthwise BitConv2d -> 1x1 pointwise BitConv2d -> LeakyReLU(0.1).

    ``kernel_size_``, ``stride_``, ``padding_`` and ``bias_`` configure the depthwise conv
    only; the pointwise conv is always 1x1, stride 1, no padding, no bias.
    """

    def __init__(self, in_channels_, out_channels_,
                 kernel_size_=3, stride_=1,
                 padding_=1, bias_=False):
        super().__init__()
        # Depthwise: one filter per input channel (groups == channels), channel count unchanged.
        depthwise_cfg = dict(
            in_channels=in_channels_,
            out_channels=in_channels_,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=1,
            groups=in_channels_,
            bias=bias_,
        )
        # Pointwise: 1x1 conv that mixes channels and sets the output width.
        pointwise_cfg = dict(
            in_channels=in_channels_,
            out_channels=out_channels_,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        )
        self.conv_depth = BitConv2d(isDepth=True, **depthwise_cfg)
        self.conv_sep = BitConv2d(isDepth=False, **pointwise_cfg)
        self.lrlu = nn.LeakyReLU(0.1)

    def forward(self, x):
        # NOTE(review): no normalization between the convs besides each BitConv2d's own layer_norm.
        depthwise_out = self.conv_depth(x)
        pointwise_out = self.conv_sep(depthwise_out)
        return self.lrlu(pointwise_out)
        

class RESBLOCK(nn.Module):
    """Residual block: ``x + body(x)``, where ``body`` chains one CNNBLOCK_DS per spec.

    Args:
        list_of_params: sequence of ``[in_channels, out_channels]`` pairs, one per inner block.
            For the residual sum to be valid, the last block's output channels must equal
            the channels of ``x`` (assumed by callers; not checked here).
    """

    def __init__(self, list_of_params):
        super().__init__()
        # The attribute name is part of every parameter's state_dict key (e.g. "suka.0.conv_depth.weight"),
        # so it is kept unchanged to stay loadable from previously saved checkpoints.
        self.suka = nn.Sequential(*[CNNBLOCK_DS(spec[0], spec[1]) for spec in list_of_params])

    def forward(self, x):
        residual = self.suka(x)
        return x + residual

class YOLONET(nn.Module):
    """Depthwise-separable residual classifier producing 10 logits from 3-channel images.

    Five 2x2 max-pools shrink a 32x32 input to 1x1 before the 256-channel stage, so for
    CIFAR-10 those 3x3 depthwise kernels only ever see their centre tap. A global average
    pool before the classifier makes ``fc`` independent of the remaining spatial size
    (it is an exact no-op at 1x1).
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        self.convBlock1 = CNNBLOCK_DS(3, 64)

        # 64-channel stage
        self.resblock64_1 = RESBLOCK([[64, 64], [64, 64]])
        self.resblock64_2 = RESBLOCK([[64, 64], [64, 64]])
        self.resblock64_3 = RESBLOCK([[64, 64], [64, 64]])
        self.resblock64_4 = RESBLOCK([[64, 64], [64, 64]])
        self.resblock64_5 = RESBLOCK([[64, 64], [64, 64]])

        # 128-channel stage (channel width raised by a non-residual block)
        self.convInter_64_128 = CNNBLOCK_DS(64, 128)
        self.resblock128_1 = RESBLOCK([[128, 128], [128, 128]])
        self.resblock128_2 = RESBLOCK([[128, 128], [128, 128]])
        self.resblock128_3 = RESBLOCK([[128, 128], [128, 128]])
        self.resblock128_4 = RESBLOCK([[128, 128], [128, 128]])
        self.resblock128_5 = RESBLOCK([[128, 128], [128, 128]])

        # 256-channel stage
        self.convInter_128_256 = CNNBLOCK_DS(128, 256)
        self.resblock256_1 = RESBLOCK([[256, 256], [256, 256]])
        self.resblock256_2 = RESBLOCK([[256, 256], [256, 256]])
        self.resblock256_3 = RESBLOCK([[256, 256], [256, 256]])
        self.resblock256_4 = RESBLOCK([[256, 256], [256, 256]])
        self.resblock256_5 = RESBLOCK([[256, 256], [256, 256]])
        self.resblock256_6 = RESBLOCK([[256, 256], [256, 256]])

        # Previously an unused AvgPool2d(2, 3); now a parameter-free global pool that is
        # actually applied, so the state_dict is unchanged.
        self.avgPool = nn.AdaptiveAvgPool2d(1)
        self.fc = BitLinear(256, 10)

    def forward(self, x):
        # Spatial sizes in comments assume a 32x32 input.
        x = self.convBlock1(x)
        x = self.resblock64_1(x)
        x = self.resblock64_2(x)
        x = self.resblock64_3(x)
        x = self.pool(x)            # 32 -> 16
        x = self.resblock64_4(x)
        x = self.resblock64_5(x)
        x = self.pool(x)            # 16 -> 8
        x = self.convInter_64_128(x)
        x = self.resblock128_1(x)
        x = self.resblock128_2(x)
        x = self.pool(x)            # 8 -> 4
        x = self.resblock128_3(x)
        x = self.pool(x)            # 4 -> 2
        x = self.resblock128_4(x)
        x = self.resblock128_5(x)
        x = self.pool(x)            # 2 -> 1
        x = self.convInter_128_256(x)
        x = self.resblock256_1(x)
        x = self.resblock256_2(x)
        x = self.resblock256_3(x)
        x = self.resblock256_4(x)
        x = self.resblock256_5(x)
        x = self.resblock256_6(x)
        x = self.avgPool(x)         # (N, 256, H, W) -> (N, 256, 1, 1)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [4]:
# Build the network in DTYPENET, keep any BatchNorm2d layers in float32 (the current
# architecture defines none, so this loop has nothing to do), then move it to DEVICE.
resResResNet = YOLONET().to(DTYPENET)
batchnorm_layers = [m for m in resResResNet.modules() if isinstance(m, nn.BatchNorm2d)]
for bn_layer in batchnorm_layers:
    bn_layer.float()

resResResNet = resResResNet.to(DEVICE)

criterionRes = nn.CrossEntropyLoss()
optimizerRes = optim.Adam(resResResNet.parameters(), lr=LR)

In [5]:
# Optionally restore weights from a previous run instead of training from scratch.
LOAD_CHECKPOINT = False
CHECKPOINT_IN = "/kaggle/input/weaintsuffering/theCIFARbitnetWithNoBatchNorm (1).pt"

if LOAD_CHECKPOINT:
    # map_location lets a checkpoint saved on GPU load on a CPU-only kernel;
    # weights_only restricts unpickling to tensors/containers (a state_dict needs nothing more).
    checkpoint_state = torch.load(CHECKPOINT_IN, map_location=DEVICE, weights_only=True)
    resResResNet.load_state_dict(checkpoint_state)

In [6]:
TRAIN = False      # set True to (re)train; the weights are saved to CHECKPOINT_OUT afterwards
NUM_EPOCHS = 10
LOG_EVERY = 20     # batches per running-loss printout
CHECKPOINT_OUT = '/kaggle/working/theCIFARbitnetWithNoBatchNorm.pt'

if TRAIN:
    resResResNet.train()
    for epoch in range(NUM_EPOCHS):
        print("EPOCH", epoch + 1)
        running_loss = 0.0
        batches_logged = 0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizerRes.zero_grad()

            outputs = resResResNet(inputs.to(DEVICE))
            loss = criterionRes(outputs, labels.to(DEVICE))
            if torch.isfinite(loss):
                loss.backward()
                optimizerRes.step()
                running_loss += loss.item()
                batches_logged += 1
            else:
                # Skip the update: backpropagating a NaN/Inf loss would write it into the
                # weights and poison every later batch.
                print('NAN OR INF', loss.item())

            if i % LOG_EVERY == LOG_EVERY - 1:
                if batches_logged:
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / batches_logged:.3f}')
                running_loss = 0.0
                batches_logged = 0

    torch.save(resResResNet.state_dict(), CHECKPOINT_OUT)
    print('Finished Training')

In [7]:
# Top-1 accuracy on the CIFAR-10 test split.
# NOTE: the Bit* layers compute their activation scale over the whole batch, so results
# can depend on BATCH_SIZE. If neither training nor checkpoint loading ran above, this
# evaluates the random initialisation and should sit near chance (~10%).
resResResNet.eval()
correct = 0
total = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for images, labels in testloader:
        outputs = resResResNet(images.to(DEVICE))
        predicted = outputs.argmax(dim=1)  # highest-logit class
        total += labels.size(0)
        correct += (predicted == labels.to(DEVICE)).sum().item()

print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 9 %
