In [1]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from models import *   # bring everything in the folder models
from models.project_vgg import *

# Notebook setup shared by the cells below: GPU flag, model, CIFAR-10 data
# pipeline and loaders. `best_prec` is assigned by the training cell itself;
# a `global` declaration has no effect at module level, so none is needed here.
use_gpu = torch.cuda.is_available()
print('=> Building model...')

batch_size = 128

model_name = "project_part2_VGG16_quant"
#model = cifar10()
# model = VGG19()

model = VGG16_quant_project() # VGG16()

# model = model.cuda() # used in the cell below

# Per-channel mean/std normalization applied to both splits.
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

# Training split: random 32x32 crop (4-pixel padding) + horizontal flip augmentation.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split: no augmentation, fixed order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print_freq = 100 # print progress every 100 batches; each batch holds `batch_size` samples
# CIFAR10 has 50,000 training data, and 10,000 validation data.

def train(trainloader, model, criterion, optimizer, epoch):
    """Run one training epoch of ``model`` over ``trainloader``.

    Every ``print_freq`` batches (module-level setting) the current and
    running-average batch time, data-loading time, loss and top-1 precision
    are printed.

    Args:
        trainloader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, only used in the progress line.
    """
    # Fresh meters each call, so the printed averages cover this epoch only.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()
    # Send batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    end = time.time()  # measure current time

    for i, (images, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        images, target = images.to(device), target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss (weighted by batch size)
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time for the whole batch (loading + compute)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   epoch, i, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

            

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Prints progress every ``print_freq`` batches (module-level setting) and a
    final summary line.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss function called as ``criterion(output, target)``.

    Returns:
        Sample-weighted average top-1 precision over the loader, in percent.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    # Send batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss (weighted by batch size)
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:  # e.g. print_freq=5 => print every 5 batches
                print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,5)):
    """Compute precision@k, in percent, for every k in ``topk``.

    Args:
        output: scores of shape [batch, num_classes]; ``max(topk)`` must not
            exceed ``num_classes``.
        target: integer labels of shape [batch].
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, in the same order:
        the percentage of samples whose label is among the k highest scores.
    """
    n_samples = target.size(0)

    # Best-first class indices, transposed to [maxk, batch] so that the first
    # k rows hold every sample's top-k guesses.
    _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
    ranked = ranked.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples) for k in topk]


class AverageMeter(object):
    """Keep the latest value of a metric plus its sample-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the weighted sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest reading, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n   # n weights the reading, e.g. by batch size
        self.count += n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    """Save ``state`` to ``<fdir>/checkpoint.pth``.

    When ``is_best`` is true the freshly written file is also copied to
    ``<fdir>/vgg16_best.pth.tar``, so the best checkpoint survives later epochs.
    """
    latest_path = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(fdir, 'vgg16_best.pth.tar')
    shutil.copyfile(latest_path, best_path)


def adjust_learning_rate(optimizer, epoch, milestones=(150, 225), gamma=0.1):
    """Step-decay schedule: scale every param group's lr by ``gamma`` at milestones.

    Call once per epoch. When ``epoch`` is in ``milestones`` each group's
    learning rate is multiplied by ``gamma``, so decays compound across
    milestones. The defaults reproduce the original schedule: x0.1 at epochs
    150 and 225.

    Args:
        optimizer: any object exposing ``param_groups`` (list of dicts with 'lr').
        epoch: current epoch index.
        milestones: epochs at which to decay.
        gamma: multiplicative decay factor.
    """
    if epoch in milestones:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= gamma

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)

=> Building model...
Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Set to True to (re)train the model; False skips straight to evaluation below.
# The flag is deliberately NOT named `train`: that would overwrite the train()
# function, and the loop below would then try to call a bool.
run_training = False
if run_training:
    lr = 1e-3
    # lr = 0.1
    # lr = 0.05

    # Weight decay (regularization against overfitting) is applied only to
    # the full-precision layers; see the parameter groups below.
    weight_decay = 5e-4
    epochs = 200 # orig: 30
    best_prec = 0

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    fdir = 'result/'+str(model_name)
    # makedirs also creates 'result/'; exist_ok keeps re-runs from failing.
    os.makedirs(fdir, exist_ok=True)

    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # First and last layers of VGG16 stay full precision -> use weight decay.
        if "features.0." in name or "classifier" in name:
            decay.append(param)
        else:
            no_decay.append(param)  # quantized layers get no weight decay

    optimizer = torch.optim.SGD(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
        momentum=0.9,
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

    for epoch in range(0, epochs):
        train(trainloader, model, criterion, optimizer, epoch)
        scheduler.step()

        # evaluate on test set
        print("Validation starts")
        prec = validate(testloader, model, criterion)

        # remember best precision and save checkpoint
        is_best = prec > best_prec
        best_prec = max(prec, best_prec)
        print('best acc: {:1f}'.format(best_prec))
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec': best_prec,
            'optimizer': optimizer.state_dict(),
        }, is_best, fdir)
        

In [2]:
# model_name = 'VGG16'
# ---- LOADING MODEL AND TESTING PART ----
# Load the saved 2-bit checkpoint and re-evaluate it on the test set.
# NOTE: despite its name, `fdir` here is the checkpoint *file* path.
fdir = 'result/'+str(model_name)+'/2bit_training_best_above85.tar'

checkpoint = torch.load(fdir)
model.load_state_dict(checkpoint['state_dict'])

criterion = nn.CrossEntropyLoss().cuda()

# validate() switches to eval mode itself; kept explicit for the cells below,
# which run the model directly.
model.eval()
model.cuda()

prec = validate(testloader, model, criterion)


Test: [0/79]	Time 0.507 (0.507)	Loss 0.3292 (0.3292)	Prec 88.281% (88.281%)
 * Prec 85.050% 


In [3]:
# Device used below to move the captured test batch: CUDA when available,
# otherwise CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu") 

class SaveOutput:
    """Callable forward pre-hook that records each input tuple a module receives.

    Register with ``module.register_forward_pre_hook(SaveOutput())``; after a
    forward pass, ``outputs[k]`` is the positional-input tuple of the k-th call.
    """

    def __init__(self):
        self.clear()

    def __call__(self, module, module_in):
        # Returning None (implicitly) leaves the module's input unchanged.
        self.outputs.append(module_in)

    def clear(self):
        """Start over with a brand-new, empty capture list."""
        self.outputs = []
        
######### Save inputs from selected layer ##########
# Forward pre-hooks receive the tuple of positional inputs, so
# save_output.outputs[k][0] is the activation fed into layer 27 on the k-th pass.
save_output = SaveOutput()
hook27 = model.features[27].register_forward_pre_hook(save_output)
####################################################

save_input28 = SaveOutput()
hook28 = model.features[28].register_forward_pre_hook(save_input28)

dataiter = iter(testloader)
images, labels = next(dataiter)
images = images.to(device)
# Inference only: skip building an autograd graph for the whole batch.
with torch.no_grad():
    out = model(images)

# Detach the hooks so later forward passes (and re-runs of this cell) do not
# keep stacking hooks and appending to the capture lists.
hook27.remove()
hook28.remove()

# Input to layer 28:
layer28_input = save_input28.outputs[0][0]   # (tuple → [0])
print("Layer 28 input:", layer28_input.shape)

Layer 28 input: torch.Size([128, 16, 4, 4])


In [4]:
# Integer weights of layer 27: divide the fake-quantized weights by the
# quantization step alpha / (2^(w_bit-1) - 1).
weight_q = model.features[27].weight_q
w_alpha = model.features[27].weight_quant.wgt_alpha
w_bit = 4

weight_scaled = weight_q / (w_alpha / (2**(w_bit-1)-1))
#print(weight_scaled)
print(weight_scaled.abs().sum()) #should be integer
# Check that weight_q really lies on the grid *before* rounding (rounding would
# make the check above pass trivially), then snap away the float noise of the
# division so the integer conv / exports see exact integers, as act_int does.
print("max distance from integer grid:", (weight_scaled - torch.round(weight_scaled)).abs().max().item())
weight_int = torch.round(weight_scaled)

tensor(5040., device='cuda:0', grad_fn=<SumBackward0>)


In [5]:
# Integer activations for layer 27's input: quantize the captured input with
# act_quantization(act_bit) and the layer's act_alpha, then divide by the step
# alpha / (2^act_bit - 1).
# NOTE(review): presumably act_quantization is the same quantizer the layer
# applies internally — confirm against models.project_vgg.
act = save_output.outputs[0][0]   # first captured pass: input tuple -> tensor
print(act.size())
act_alpha  = model.features[27].act_alpha
act_bit = 2
act_quant_fn = act_quantization(act_bit)

act_q = act_quant_fn(act, act_alpha)

# torch.round removes float noise from the division; note this also makes the
# "should be an integer" check below hold by construction.
act_int = torch.round(act_q / (act_alpha / (2**act_bit-1)))
#print(act_int)
print(act_int.abs().sum()) #should be an integer

torch.Size([128, 16, 4, 4])
tensor(9707., device='cuda:0', grad_fn=<SumBackward0>)


In [33]:
a_int = act_int[0].unsqueeze(0) # first image of the batch: [1, C_in, H, W] = [1, 16, 4, 4]
device = a_int.device

# Must use a regular nn.Conv2d: the custom quantized conv layer would quantize
# these already-integer tensors a second time and cause an error.
# Channel counts and kernel size are taken from weight_int itself.
conv_int = nn.Conv2d(
    in_channels=weight_int.size(1),
    out_channels=weight_int.size(0),
    kernel_size=tuple(weight_int.shape[2:]),
    padding=0,
    bias=False,   # bias=False already leaves conv_int.bias as None
).to(device)

conv_int.weight = torch.nn.Parameter(weight_int.clone())

# padding=0, stride 1: output is [1, C_out, H-kh+1, W-kw+1] = [1, 16, 2, 2] here
output_int = conv_int(a_int)

scale_act = act_alpha / (2**act_bit - 1)
scale_w   = w_alpha   / (2**(w_bit-1) - 1)

# Map the integer conv result back to the float domain.
output_recovered = output_int * scale_act * scale_w

# print("Recovered output (float):", output_recovered[0])
print("act alpha: ", act_alpha.item())
print("weight alpha:", w_alpha.item())
print(output_int.size())
print(output_recovered.size())

act alpha:  5.192733287811279
weight alpha: 2.63910174369812
torch.Size([1, 16, 2, 2])
torch.Size([1, 16, 2, 2])


In [7]:
# Element-wise gap between the integer-domain conv output and the same output
# rescaled to float (output_recovered = output_int * scale_act * scale_w).
# NOTE(review): the two tensors are on different scales, so this equals
# |output_int| * |1 - scale_act*scale_w| rather than a quantization error; a
# float reference output is probably what was intended — confirm.
diff = abs( output_int - output_recovered)
print(diff.mean().item())
# diff is already non-negative, so the .abs() below is a no-op.
diff.abs().sum().item()


7.762695789337158


496.8125305175781

In [8]:
def rmse(x, y):
    """Return the root-mean-square difference between tensors ``x`` and ``y`` (0-dim tensor)."""
    squared_err = (x - y) ** 2
    return squared_err.mean().sqrt()

# Quantization-error diagnostics for layer 27: float vs fake-quantized
# activations and weights, their integer reconstructions, and the resulting
# difference in the layer's convolution output.
act_rmse = rmse(act, act_q)
act_error = act - act_q
print("act quantization error sum:", act_error.abs().sum().item())  # L1 error
print("act quantization MSE:", (act_error**2).mean().item())
print("act quantization RMSE:",act_rmse.item())

# act_int * step rebuilds act_q, so this is expected to match the L1 error above.
act_reconstructed = act_int * (act_alpha / (2**act_bit - 1))
act_int_error = act - act_reconstructed

print("act int to float reconstruction error:", act_int_error.abs().sum().item())

weight_float = model.features[27].weight
weight_q     = model.features[27].weight_q
weight_rmse = rmse(weight_float, weight_q)


w_error = weight_float - weight_q
print("weight quantization L1:", w_error.abs().sum().item())
print("weight quantization MSE:", (w_error**2).mean().item())
print("weight quantization RMSE:", weight_rmse.item())

# weight_int * step rebuilds weight_q, so this should match the L1 error above.
w_reconstructed = weight_int * (w_alpha / (2**(w_bit-1)-1))
w_int_error = weight_float - w_reconstructed
print("weight int reconstruction error:", w_int_error.abs().sum().item())

# F.conv2d positional args are (input, weight, bias, stride, padding):
# float vs quantized convolution with stride 1 and padding 1.
# NOTE(review): conv_int / the tiled emulation use padding=0, so these outputs
# are not shape-comparable with output_int — confirm padding=1 is intended here.
with torch.no_grad():
    out_float = F.conv2d(act, weight_float, None, 1, 1)

with torch.no_grad():
    out_quant = F.conv2d(act_q, weight_q, None, 1, 1)

out_error = out_float - out_quant
# out_rmse = rmse(out_float, out_quant)
print("output L1 quant error:", out_error.abs().sum().item())
print("output RMSE quant error:", torch.sqrt((out_error**2).mean()).item())
print("output max error:", out_error.abs().max().item())

act quantization error sum: 7053.25927734375
act quantization MSE: 0.12529322504997253
act quantization RMSE: 0.3539678156375885
act int to float reconstruction error: 7053.25927734375
weight quantization L1: 1806.7098388671875
weight quantization MSE: 0.9202349185943604
weight quantization RMSE: 0.9592887759208679
weight int reconstruction error: 1806.7098388671875
output L1 quant error: 404227.96875
output RMSE quant error: 15.826031684875488
output max error: 80.21974182128906


In [9]:
########### TILED CALCULATION HERE #################
# taken from the hw solution:
# Emulates the convolution on an array_size x array_size systolic array.
# Input/output channels are cut into array_size-wide tiles (spatial loops);
# each kernel offset kij and each padded input position nij is streamed through
# the array one step at a time (temporal loops). The per-(kij, nij) partial sums
# are then accumulated into the conv output (SFP accumulation).
import math

# act_int: [batch, C_in, H, W]  ([128, 16, 4, 4] for layer 27)
a_int = act_int[0,:,:,:]  # pick only one input out of batch -> [C_in, H, W]

# weight_int: [C_out, C_in, ki, kj] -> merge ki, kj into kij -> [C_out, C_in, ki*kj]
w_int = torch.reshape(weight_int, (weight_int.size(0), weight_int.size(1), -1))

padding = 0
stride = 1
array_size = 8 # row and column number

######## Inputs ########
nig = range(a_int.size(1))  ## ni group: input rows
njg = range(a_int.size(2))  ## nj group: input columns
nijg = range(a_int.size(1)*a_int.size(2))

######## Weights and related stuff ########
kijg = range(w_int.size(2)) # flattened kernel offsets, [0, ..., 8] for 3x3
ki_dim = int(math.sqrt(w_int.size(2)))  ## kernel side length (square kernel assumed)
kig = range(ki_dim)
kjg = range(ki_dim)

######## Channels ########
icg = range(int(w_int.size(1)))  ## input channels
ocg = range(int(w_int.size(0)))  ## output channels
# Integer division would silently drop channels that do not fill a whole tile.
if len(icg) % array_size or len(ocg) % array_size:
    raise ValueError(
        f"channel counts ({len(icg)} in, {len(ocg)} out) must be multiples of array_size={array_size}")
ic_tileg = range(len(icg) // array_size)
oc_tileg = range(len(ocg) // array_size)

######## Padding before Convolution #######
pad_h = len(nig) + 2*padding
pad_w = len(njg) + 2*padding
a_pad = torch.zeros(len(icg), pad_h, pad_w).cuda()
a_pad[ :, padding:padding+len(nig), padding:padding+len(njg)] = a_int.cuda()
a_pad = torch.reshape(a_pad, (a_pad.size(0), -1))  ## merge ni and nj into nij -> [C_in, pad_h*pad_w]
a_tile = torch.zeros(len(ic_tileg), len(oc_tileg), array_size, a_pad.size(1)).cuda()
for ic_tile in ic_tileg: #spatial
    for oc_tile in oc_tileg: #spatial
        a_tile[ic_tile,oc_tile,:,:] = a_pad[ic_tile*array_size:(ic_tile+1)*array_size,:]

p_nijg = range(a_tile.size(3)) ## padded activation's nij group [0, ..., pad_h*pad_w-1]

######## Outputs ########
# Standard conv output size: (padded size - kernel) // stride + 1
o_nig = range((pad_h - ki_dim) // stride + 1)
o_njg = range((pad_w - ki_dim) // stride + 1)
psum=torch.zeros(len(ic_tileg), len(oc_tileg), array_size, len(p_nijg), len(kijg)).cuda()
out = torch.zeros(len(ocg), len(o_nig), len(o_njg)).cuda()

# Pure integer emulation: no autograd graph is needed.
with torch.no_grad():
    ######## Tiled 2D version ########
    for ic_tile in ic_tileg: #spatial
        for oc_tile in oc_tileg: #spatial
            for kij in kijg: #temporal
                # Load one array_size x array_size weight tile (out rows x in cols).
                m = nn.Linear(array_size, array_size, bias=False)
                m.weight = torch.nn.Parameter(w_int[oc_tile*array_size:(oc_tile+1)*array_size,ic_tile*array_size:(ic_tile+1)*array_size,kij])
                for nij in p_nijg: #temporal
                    psum[ic_tile, oc_tile, :, nij, kij] = m(a_tile[ic_tile,oc_tile,:, nij]).cuda()

    ### SFP accumulation ###
    for ni in o_nig:
        for nj in o_njg:
            for ki in kig:
                for kj in kjg:
                    # Padded-input position read by kernel tap (ki, kj) for output
                    # (ni, nj); the window origin moves by `stride` per output step.
                    p_nij = pad_w*(ni*stride + ki) + (nj*stride + kj)
                    kij = ki_dim*ki + kj
                    for ic_tile in ic_tileg:
                        for oc_tile in oc_tileg:
                            oc_rows = slice(oc_tile*array_size, (oc_tile+1)*array_size)
                            out[oc_rows, ni, nj] = out[oc_rows, ni, nj] + psum[ic_tile, oc_tile, :, p_nij, kij]
print(out.size())
print(out.abs().sum())
print(psum[0, 0, :, :, 0].size())
print(len(ic_tileg))
print(len(oc_tileg))
print(a_tile.size())

torch.Size([16, 2, 2])
tensor(1430., device='cuda:0', grad_fn=<SumBackward0>)
torch.Size([8, 16])
2
2
torch.Size([2, 2, 8, 16])


In [32]:
# Compare the tiled / systolic emulation with the nn.Conv2d reference for the
# same image (both use act_int[0]): output_int[0] drops the batch dim so it
# lines up with `out`; a zero sum of absolute differences means an exact match.
ref = output_int[0]
diff = (out - ref).abs().sum()
print(ref.size())
print("8×8 tiled systolic difference =", diff.item())

torch.Size([1, 16, 2, 2])
torch.Size([16, 2, 2])
8×8 tiled systolic difference = 0.0


In [52]:
import torch
import math
import os
import torch.nn as nn



def make_dir(d):
    """Create directory ``d`` and any missing parents; do nothing if it exists.

    ``exist_ok=True`` replaces the separate existence check, which could race
    with another process creating the directory in between. An existing
    non-directory at ``d`` now raises here instead of failing later on write.
    """
    os.makedirs(d, exist_ok=True)

def to_bin_signed(val, bits):
    """Return integer ``val`` as a ``bits``-character binary string, MSB first.

    Negative values use two's complement. Non-negative values up to
    ``2**bits - 1`` are accepted as well, so unsigned fields can be written
    through the same helper.

    Raises:
        ValueError: if ``val`` is outside ``[-2**(bits-1), 2**bits - 1]``.
            Such values would otherwise yield a string of the wrong width (or
            with a '-' sign), silently corrupting fixed-width vector files.
    """
    if not -(1 << (bits - 1)) <= val < (1 << bits):
        raise ValueError(f"value {val} does not fit in {bits} bits")
    if val < 0:
        val += 1 << bits
    return f'{val:0{bits}b}'



def extract_layer27_int(model, images, act_bit=2, w_bit=4):
    """Return integer-domain input activation and weights of ``model.features[27]``.

    Runs one forward pass on ``images`` and captures the input of layer 27 for
    the first image of the batch. It then quantizes that input with the layer's
    ``act_alq`` and divides by the step ``act_alpha / (2**act_bit - 1)``. The
    layer's stored ``weight_q`` is divided by ``wgt_alpha / (2**(w_bit-1) - 1)``.
    Both results are rounded to exact integers.

    Returns:
        (act_int, weight_int, act_scale, w_scale): detached integer-valued
        tensors plus the two quantization steps (float value = int * step).
    """

    class Save:
        def __init__(self):
            self.out = None
        def __call__(self, module, inp):
            self.out = inp[0]   # only the input activation

    layer = model.features[27]
    hook = Save()
    handle = layer.register_forward_pre_hook(hook)

    device = next(model.parameters()).device
    images = images.to(device)
    try:
        with torch.no_grad():
            model(images)  # run forward once
    finally:
        # Always detach the hook so repeated calls do not stack hooks on the layer.
        handle.remove()
    a_float = hook.out[0]   # first image of the batch

    act_alpha = layer.act_alpha
    act_scale = act_alpha / (2**act_bit - 1)

    w_alpha = layer.weight_quant.wgt_alpha
    w_scale = w_alpha / (2**(w_bit-1) - 1)

    act_q = layer.act_alq(a_float, act_alpha)
    act_int = torch.round(act_q / act_scale)   # strip float noise of the division

    weight_q = layer.weight_q    # stored by the layer during the forward pass
    weight_int = torch.round(weight_q / w_scale)

    return act_int.detach(), weight_int.detach(), act_scale, w_scale


# one file per each 4 TILES
# def export_weight_tiles(weight_int, outdir="weights", bits=4):

#     make_dir(outdir)

#     # reshape to [C_out, C_in, 9]
#     w_flat = weight_int.reshape(16, 16, 9)

#     array_size = 8
#     ic_tiles = 2  # 16/8
#     oc_tiles = 2

#     tile_id = 0

#     for oc_tile in range(oc_tiles):
#         for ic_tile in range(ic_tiles):
#             for kij in range(9):

#                 W = w_flat[
#                     oc_tile*8:(oc_tile+1)*8,
#                     ic_tile*8:(ic_tile+1)*8,
#                     kij
#                 ]  # [8,8]

#                 fname = f"{outdir}/weight_tile{tile_id}_kij{kij}.txt"
#                 with open(fname, "w") as f:
#                     f.write('#W[col0,row7], W[col1,row7], ..., W[col7,row7]#\n')
#                     f.write('#W[col0,row6], W[col1,row6], ..., W[col7,row6]#\n')
#                     f.write('#................#\n')
#                     for row in range(8):
#                         for col in range(8):
#                             val = int(round(W[col, 7 - row].item()))  # reversed row ordering
#                             bits_bin = to_bin_signed(val, bits)
#                             f.write(bits_bin)
#                         f.write("\n")

#                 print("Generated:", fname)
#             tile_id += 1

import tensorflow as tf
            
# one file per output channel tile - 2 files in total
def export_weight_tiles(weight_int, outdir="weights", bits=4):
    """Write the integer conv weights as systolic-array weight files.

    One file per (output-channel tile, kernel offset kij):
    ``<outdir>/weight_tile{oc_tile}_kij{kij}.txt``. Each file starts with three
    ``#......#`` header lines, followed by two lines for each of the tile's 8
    output channels ``col``:

    * first line:  input channels C_in-2, C_in-4, ..., 0 (14, 12, ..., 0 for 16)
    * second line: input channels C_in-1, C_in-3, ..., 1 (15, 13, ..., 1 for 16)

    Each weight is written as a ``bits``-bit code from ``to_bin_signed``, concatenated.
    Per the original bit-layout notes, PE row r holds input channels 2r and
    2r+1 (PE(0,0) gets channels 0 & 1, not 0 & 8).

    Args:
        weight_int: integer weights [C_out, C_in, kh, kw]; C_out must be a
            multiple of 8 and C_in even.

    Raises:
        ValueError: if the channel counts do not fit that layout.
    """
    make_dir(outdir)

    array_size = 8
    c_out, c_in = weight_int.shape[0], weight_int.shape[1]
    if c_out % array_size or c_in % 2:
        raise ValueError(
            f"expected C_out multiple of {array_size} and even C_in, got {c_out}, {c_in}")

    # [C_out, C_in, kh*kw]: merge the kernel offsets into a single kij index
    w_flat = weight_int.reshape(c_out, c_in, -1)
    num_kij = w_flat.size(2)

    for oc_tile in range(c_out // array_size):
        for kij in range(num_kij):
            # This tile's output channels at kernel offset kij: [8, C_in]
            W = w_flat[oc_tile*array_size:(oc_tile+1)*array_size, :, kij]

            fname = f"{outdir}/weight_tile{oc_tile}_kij{kij}.txt"
            with open(fname, "w") as f:
                f.write("#......#\n")
                f.write("#......#\n")
                f.write("#......#\n")
                for col in range(array_size):
                    # Even input channels, highest first (reversed row ordering)
                    for row in range(1, c_in, 2):
                        val = int(round(W[col, c_in - 1 - row].item()))
                        f.write(to_bin_signed(val, bits))
                    f.write("\n")
                    # Odd input channels, highest first
                    for row in range(0, c_in, 2):
                        val = int(round(W[col, c_in - 1 - row].item()))
                        f.write(to_bin_signed(val, bits))
                    f.write("\n")

            print("Generated:", fname)

# def export_activation_tiles(act_int, outdir="activations", bits=4):

#     make_dir(outdir)

#     array_size = 8
#     ic_tiles = 2   # 16→2 tiles
#     a_flat = act_int.reshape(16, 16)  # [C_in, H*W] = [16,16]

#     for ic_tile in range(ic_tiles):
#         tile_block = a_flat[ic_tile*8:(ic_tile+1)*8, :]   # [8,16]
#         fname = f"{outdir}/activation_ic{ic_tile}.txt"
#         with open(fname, "w") as f:
#             f.write('#time0row7[msb-lsb],time0row6[msb-lst],....,time0row0[msb-lst]#\n')
#             f.write('#time1row7[msb-lsb],time1row6[msb-lst],....,time1row0[msb-lst]#\n')
#             f.write('#................#\n')
#             for t in range(16):           # time 0..15
#                 for row in range(8):      # row 7..0
#                     val = int(round(tile_block[7-row, t].item()))
#                     bits_bin = to_bin_signed(val, bits)
#                     f.write(bits_bin)
#                 f.write("\n")
#         print("Generated:", fname)

# def export_activation_tiles(act_int, outdir="activations", bits=2):

#     make_dir(outdir)

#     array_size = 8
#     ic_tiles = 2   # 16→2 tiles
#     a_flat = act_int.reshape(16, 16)  # [C_in, H*W] = [16,16]
#     nij_total = 16
        
#     fname = f"{outdir}/activation.txt"
#     with open(fname, "w") as f:
#         f.write('#w1time0row7[msb-lsb],w0time0row7[msb-lsb],w1time0row6[msb-lst],....,w0time0row0[msb-lst]#\n')
#         f.write('#w1time1row7[msb-lsb],w0time1row6[msb-lst],w0time1row6[msb-lst],....,w0time1row0[msb-lst]#\n')
#         f.write('#................#\n')
#         for t in range(nij_total):           # time 0..15
#             for row in range(8):      # row 7..0
#                 for ic_tile in range(ic_tiles):
#                     # tile_block = a_flat[(1-ic_tile)*8:((1-ic_tile)+1)*8, :]   # [8,16]
#                     tile_block = a_flat[(ic_tile)*8:((ic_tile)+1)*8, :]   # [8,16]
#                     val = int(round(tile_block[7-row, t].item()))
#                     bits_bin = to_bin_signed(val, bits)
#                     # print(ic_tile, bits_bin)
#                     f.write(bits_bin)
#             f.write("\n")
#     print("Generated:", fname)

from bitarray import bitarray

# B chunk first, then A chunk (left to right), two bits at a time.
def interleave(A, B):
    """Interleave the low 32 bits of ``A`` and ``B`` in 2-bit chunks into 64 bits.

    Walking both inputs from bit 31 down to bit 0 in 2-bit steps, each step
    emits B's chunk followed by A's chunk, filling the result from bit 63 down.
    """
    result = 0
    for step, shift in enumerate(range(30, -1, -2)):  # 30, 28, ..., 0
        top = 62 - 4 * step
        result |= ((B >> shift) & 0b11) << top          # B chunk: upper 2 bits
        result |= ((A >> shift) & 0b11) << (top - 2)    # A chunk: lower 2 bits
    return result

def export_activation_tiles(act_int, outdir="activations", bits=2):
    """Write ``act_int`` as the activation stream file for the systolic array.

    ``act_int`` is one image's integer activation of shape [C_in, H, W]. The
    function writes ``<outdir>/activation.txt``: three ``#...#`` header lines, then one
    line per flattened spatial position nij (row-major over H, W). Each line
    concatenates, for input-channel tile 0, 1, ..., that tile's 8 rows from row
    7 down to row 0, i.e. channels ``tile*8 + 7`` down to ``tile*8``. Each value is
    written as a ``bits``-bit code from ``to_bin_signed``.

    The data now comes from the ``act_int`` argument; previously the global
    ``a_tile`` from the tiled-calculation cell was read and the argument was
    ignored.

    Raises:
        ValueError: if C_in is not a multiple of 8.
    """
    make_dir(outdir)

    array_size = 8
    a_flat = act_int.reshape(act_int.size(0), -1)   # [C_in, H*W]
    n_ch, nij_total = a_flat.shape
    if n_ch % array_size:
        raise ValueError(f"expected a multiple of {array_size} input channels, got {n_ch}")
    ic_tiles = n_ch // array_size

    fname = f"{outdir}/activation.txt"
    with open(fname, "w") as f:
        f.write('#................#\n')
        f.write('#................#\n')
        f.write('#................#\n')
        for t in range(nij_total):           # one line per nij (time step)
            for ic_tile in range(ic_tiles):
                base = ic_tile * array_size
                for row in range(array_size):    # row 7..0 of this tile
                    val = int(round(a_flat[base + array_size - 1 - row, t].item()))
                    f.write(to_bin_signed(val, bits))
            f.write("\n")
    print("Generated:", fname)



def export_psum(psum, outdir="psum", bits=16, combine_ic_tiles=False):
    """Write the tiled-emulation partial sums, one file per kernel offset kij.

    ``psum`` is indexed ``[ic_tile, oc_tile, row, nij, kij]``, matching the
    array built by the tiled calculation. ``<outdir>/psum_kij{kij}.txt``
    gets three ``#`` header lines, then one line per nij. Each line holds, for
    oc_tile 0, 1, ..., that tile's rows from highest to lowest (e.g. ch 7..0,
    then ch 15..8). Each value is written as a ``bits``-bit code.

    NOTE(review): the original author marked this export as wrong. It writes
    partial sums taken during the convolution, not the final result. By
    default only input-channel tile 0 is written, as before. With
    ``combine_ic_tiles=True`` the sum over all input-channel tiles is written
    instead. Confirm which one the hardware testbench expects. There should be
    16 psum outputs, one per output channel.
    """
    make_dir(outdir)

    ic_tiles, oc_tiles, array_size, nij_total, num_kij = psum.shape

    for kij in range(num_kij):
        if combine_ic_tiles:
            kij_psum = psum[:, :, :, :, kij].sum(dim=0)   # [oc_tiles, rows, nij]
        else:
            kij_psum = psum[0, :, :, :, kij]              # ic_tile 0 only

        fname = f"{outdir}/psum_kij{kij}.txt"
        with open(fname, "w") as f:
            f.write("#every 16 bits in each column = psum at oc\n")
            f.write("#every row is nij\n")
            f.write("#time0 col7..col0 for kij slice\n")

            for t in range(nij_total):            # time = line in file
                for oc_tile in range(oc_tiles):   # output-channel tile = column group
                    for oc in range(array_size):
                        val = int(round(kij_psum[oc_tile, array_size - 1 - oc, t].item()))
                        f.write(to_bin_signed(val, bits))
                f.write("\n")

        print("Generated:", fname)

# fixed version
    # 256 BITS (16 output channels, 16 bits each) per row - each row is time 0, then time 1 ... each row is a specific time
    # one file per kij -- per each kij we do nij many input macs and accumulations
# we need the psum outputs of the sfu, the two weights' mac results (input channel) will be combined anyway 
# psum is the accumulatd result in the spec func unit -- 1 file per kij
# d

# def export_psum(psum, outdir="psum", bits=16):
    # make_dir(outdir)

    # ic_tiles, oc_tiles, array_size, nij_total, num_kij = psum.shape

    # C_out = oc_tiles * array_size
    # row_in_tile = oc % array_size  # 0..7
    # row_idx = array_size - 1 - row_in_tile


    #     ## USES REVERSED ROW: - ie begins at row 7

    #     for kij in range(num_kij):
    #         fname = f"{outdir}/psum_ch{oc}_kij{kij}.txt"
    #         with open(fname, "w") as f:
    #             f.write(f"# PSUM for output channel {oc}, kij={kij}\n")
    #             f.write("# One value per time step (nij index)\n")
    #             f.write("# One value per time step (nij index)\n")

    #             for t in range(nij_total):
    #                 # accumulate over all input-channel tiles
    #                 acc_val = 0
                    
    #             for oc in range(C_out):
    #                 oc_tile_id  = oc // array_size  # 0 or 1
    #                 for ic_tile_id in range(ic_tiles):
    #                     v = int(round(
    #                         psum[ic_tile_id,   #  input tile (0..1)
    #                              oc_tile_id,   #  output tile (0..1)
    #                              row_idx,      # row in the 8x8 array
    #                              t,            # time - nij
    #                              kij].item()   # kernel index
    #                     ))
    #                     acc_val += v

    #                 bits_bin = to_bin_signed(acc_val, bits)
    #                 f.write(bits_bin + "\n")

    #         print("Generated:", fname)

def export_final_output(out, outdir="final_psums", bits=16, separate_files=True):
    """Write the final conv output as one file per 8-channel output tile.

    ``<outdir>/final_out_tile{tile}.txt`` starts with three ``#.....#`` header lines, then
    one line per flattened output position bij = i * W_out + j. Each line holds
    that position's value for the tile's channels, highest channel first
    (tile 0: ch 7 ... ch 0, tile 1: ch 15 ... ch 8), each as a ``bits``-bit
    code. For the 16x2x2 layer output this gives 2 files of 4 lines, each line
    8 x 16 = 128 bits.

    Args:
        out: conv output of shape [C_out, H_out, W_out]; C_out must be a
            multiple of 8 (the array width).
        separate_files: unused; kept so existing calls keep working.

    Raises:
        ValueError: if C_out is not a multiple of 8.
    """
    make_dir(outdir)

    array_size = 8
    C_out, H_out, W_out = out.shape
    if C_out % array_size:
        raise ValueError(f"Expected a multiple of {array_size} output channels, got {C_out}")

    out_flat = out.reshape(C_out, H_out * W_out)   # [C_out, bij]

    for tile in range(C_out // array_size):
        first_ch = tile * array_size
        fname = f"{outdir}/final_out_tile{tile}.txt"
        with open(fname, "w") as f:
            f.write("#.....#\n")
            f.write("#.....#\n")
            f.write("#.....#\n")
            for bij in range(H_out * W_out):
                for ch in reversed(range(first_ch, first_ch + array_size)):
                    f.write(to_bin_signed(int(round(out_flat[ch, bij].item())), bits))
                f.write("\n")
        print("Generated:", fname)

def export_systolic_layer27(model, images, psum, outdir="systolic_export", final_out=None):
    """Export every test-vector file for the layer-27 systolic-array run.

    Writes weights, activations, per-kij psums and the expected final output
    under ``outdir``. It also saves the raw act_int / weight_int / psum tensors
    as ``.pt`` files.

    Args:
        model, images: passed to ``extract_layer27_int`` (first image is used).
        psum: tiled-emulation partial sums ``[ic_tile, oc_tile, row, nij, kij]``.
        outdir: root output directory.
        final_out: expected integer conv output ``[C_out, H_out, W_out]``. If
            None, it is recomputed as a stride-1, padding-0 ``F.conv2d`` of the
            extracted act_int and weight_int, the same setup as the tiled
            calculation. Previously the global ``out`` was read implicitly.
    """
    make_dir(outdir)

    act_int, weight_int, act_scale, w_scale = extract_layer27_int(model, images)

    if final_out is None:
        with torch.no_grad():
            final_out = F.conv2d(act_int.unsqueeze(0), weight_int)[0]

    export_weight_tiles(weight_int, outdir=f"{outdir}/weights")

    export_activation_tiles(act_int, outdir=f"{outdir}/activations")

    export_psum(psum, outdir=f"{outdir}/psum")

    export_final_output(final_out, outdir=f"{outdir}/outputs")

    torch.save(act_int,  f"{outdir}/act_int.pt")
    torch.save(weight_int, f"{outdir}/weight_int.pt")
    torch.save(psum, f"{outdir}/psum_raw.pt")

# Pass the tiled-emulation output explicitly instead of relying on a global.
export_systolic_layer27(model, images, psum, final_out=out)

Generated: systolic_export/weights/weight_tile0_kij0.txt
Generated: systolic_export/weights/weight_tile0_kij1.txt
Generated: systolic_export/weights/weight_tile0_kij2.txt
Generated: systolic_export/weights/weight_tile0_kij3.txt
Generated: systolic_export/weights/weight_tile0_kij4.txt
Generated: systolic_export/weights/weight_tile0_kij5.txt
Generated: systolic_export/weights/weight_tile0_kij6.txt
Generated: systolic_export/weights/weight_tile0_kij7.txt
Generated: systolic_export/weights/weight_tile0_kij8.txt
Generated: systolic_export/weights/weight_tile1_kij0.txt
Generated: systolic_export/weights/weight_tile1_kij1.txt
Generated: systolic_export/weights/weight_tile1_kij2.txt
Generated: systolic_export/weights/weight_tile1_kij3.txt
Generated: systolic_export/weights/weight_tile1_kij4.txt
Generated: systolic_export/weights/weight_tile1_kij5.txt
Generated: systolic_export/weights/weight_tile1_kij6.txt
Generated: systolic_export/weights/weight_tile1_kij7.txt
Generated: systolic_export/weig