In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder


In [2]:
np.random.seed(42)  # Set the seed for NumPy

# Set the seed for PyTorch (CPU and GPU)
torch.manual_seed(42)

# If you're using CUDA (GPU), set the seed for CUDA as well
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)  # For all GPUs (if you have more than one)

In [3]:
# 1. Download the dataset (OpenML: id 180)
dataset = openml.datasets.get_dataset(180)
X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute, dataset_format='dataframe')

In [4]:
# 2. Preprocessing: Numeric features only, standardize, encode labels (classes are 1-7)
X_numeric = X.select_dtypes(include=[np.number])

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_numeric.values.astype(np.float32))

le = LabelEncoder()
y_encoded = le.fit_transform(y)  # Converts to 0...6
y_encoded = y_encoded.astype(np.int64)

In [5]:
# 3. Train/test split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)

In [6]:
# 4. PyTorch Dataset
class CovertypeDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = CovertypeDataset(X_train, y_train)
test_ds = CovertypeDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=256)

In [7]:
from dpn_4.dpn import DPN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DPN(X_train.shape[1], 192 + len(le.classes_), len(le.classes_), False).to(device)
model.compile()

In [8]:
# 6. Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
from utils import train

model = model
train_loader = train_loader
val_loader = test_loader
optimizer = optimizer
criterion = criterion
device = device

percent = 0.99
rounds = 100

In [10]:
original_weights = [param.clone().detach() for param in model.weights]
final_masks = [torch.ones_like(param) for param in original_weights]
current_masks = [torch.ones_like(param) for param in original_weights]
p_per_round = 1 - (1 - percent) ** (1 / rounds)

In [11]:
def prune_by_percent_once(percent, mask, final_weight):
    # Get the absolute values of weights where mask == 1
    masked_weights = final_weight[mask == 1].abs()

    # Sort the unmasked weights
    sorted_weights, _ = torch.sort(masked_weights)
    if sorted_weights.shape[0] != 0:
        # Determine the cutoff index for pruning
        cutoff_index = int(round(percent * sorted_weights.shape[0]))
        cutoff = sorted_weights[cutoff_index]

        # Prune all weights below or equal to the cutoff
        new_mask = torch.where(final_weight.abs() <= cutoff, torch.zeros_like(mask), mask)

    else:
        new_mask = mask

    return new_mask

In [12]:
def prune_by_percent(model, masks, percent):

    blocks = model.weights
    for i in range(len(blocks)):
        masks[i] = prune_by_percent_once(percent, masks[i], blocks[i])

    return masks

In [13]:
_, val_metrics = train(model, train_loader, val_loader, 5, optimizer, criterion, device=device)
val_accuracy = val_metrics[-1][1]


Epoch: 1 Total_Time: 20.8289 Average_Time_per_batch: 0.0604 Train_Accuracy: 0.6576 Train_Loss: 0.9413 Validation_Accuracy: 0.6929 Validation_Loss: 0.8268
Epoch: 2 Total_Time: 20.4707 Average_Time_per_batch: 0.0593 Train_Accuracy: 0.6980 Train_Loss: 0.8217 Validation_Accuracy: 0.7055 Validation_Loss: 0.7944
Epoch: 3 Total_Time: 20.8366 Average_Time_per_batch: 0.0604 Train_Accuracy: 0.7119 Train_Loss: 0.7927 Validation_Accuracy: 0.7196 Validation_Loss: 0.7718
Epoch: 4 Total_Time: 20.5308 Average_Time_per_batch: 0.0595 Train_Accuracy: 0.7235 Train_Loss: 0.7711 Validation_Accuracy: 0.7289 Validation_Loss: 0.7562
Epoch: 5 Total_Time: 20.6018 Average_Time_per_batch: 0.0597 Train_Accuracy: 0.7329 Train_Loss: 0.7529 Validation_Accuracy: 0.7368 Validation_Loss: 0.7401Peak GPU memory: 18.96 MB


In [14]:
for round_idx in range(rounds):
    current_masks = prune_by_percent(model, current_masks, p_per_round)
    pruned_weights = [w * m for w, m in zip(original_weights, current_masks)]
    model.weights = nn.ParameterList([nn.Parameter(w) for w in pruned_weights])
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    _, val_metrics = train(model, train_loader, val_loader, 5, optimizer, criterion, device=device)
    new_val_accuracy = val_metrics[-1][1]
    
    if new_val_accuracy >= val_accuracy:
        #val_accuracy = new_val_accuracy
        final_masks = current_masks.copy()  # This creates a shallow copy of the list
    else:
        break


Epoch: 1 Total_Time: 20.8624 Average_Time_per_batch: 0.0605 Train_Accuracy: 0.6639 Train_Loss: 0.9227 Validation_Accuracy: 0.6958 Validation_Loss: 0.8196
Epoch: 2 Total_Time: 20.5130 Average_Time_per_batch: 0.0595 Train_Accuracy: 0.7036 Train_Loss: 0.8134 Validation_Accuracy: 0.7131 Validation_Loss: 0.7857
Epoch: 3 Total_Time: 20.5382 Average_Time_per_batch: 0.0595 Train_Accuracy: 0.7177 Train_Loss: 0.7844 Validation_Accuracy: 0.7257 Validation_Loss: 0.7634
Epoch: 4 Total_Time: 20.5290 Average_Time_per_batch: 0.0595 Train_Accuracy: 0.7265 Train_Loss: 0.7630 Validation_Accuracy: 0.7366 Validation_Loss: 0.7464
Epoch: 5 Total_Time: 20.5686 Average_Time_per_batch: 0.0596 Train_Accuracy: 0.7355 Train_Loss: 0.7460 Validation_Accuracy: 0.7413 Validation_Loss: 0.7332Peak GPU memory: 18.96 MB

Epoch: 1 Total_Time: 20.5155 Average_Time_per_batch: 0.0595 Train_Accuracy: 0.6702 Train_Loss: 0.9030 Validation_Accuracy: 0.7002 Validation_Loss: 0.8115
Epoch: 2 Total_Time: 20.4996 Average_Time_per_bat

In [15]:
pruned_weights = [w * m for w, m in zip(original_weights, final_masks)]

In [16]:
for weight in pruned_weights:
    print(torch.nonzero(weight))

tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', size=(0, 2),

In [17]:
merged_weights = []
size_limit = pruned_weights[0].shape[1]
start_idx = 0

In [18]:
zero_indices = [i + size_limit for i, t in enumerate(pruned_weights) if torch.all(t == 0)]
print(zero_indices)


while len(zero_indices) > 0:
    pruned_weights = [t for t in pruned_weights if not torch.all(t == 0)]

    for i in range(len(pruned_weights)):
        tensor = pruned_weights[i]
        columns_to_keep = [i for i in range(tensor.shape[1]) if i not in zero_indices]
        pruned_weights[i] = tensor[:, columns_to_keep]

    zero_indices = [i + size_limit for i, t in enumerate(pruned_weights) if torch.all(t == 0)]
    print(zero_indices)


[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[14, 16, 18, 20]
[22]
[]


In [19]:
for weight in pruned_weights:
    print(torch.nonzero(weight))

tensor([[0, 0]], device='cuda:0')
tensor([[0, 2]], device='cuda:0')
tensor([[ 0, 11]], device='cuda:0')
tensor([[0, 8]], device='cuda:0')
tensor([[0, 9]], device='cuda:0')
tensor([[ 0,  5],
        [ 0, 15]], device='cuda:0')
tensor([[0, 0]], device='cuda:0')
tensor([[ 0,  0],
        [ 0, 20]], device='cuda:0')
tensor([[ 0,  6],
        [ 0, 10],
        [ 0, 13]], device='cuda:0')
tensor([[0, 3]], device='cuda:0')
tensor([[ 0,  9],
        [ 0, 23]], device='cuda:0')
tensor([[ 0,  9],
        [ 0, 20]], device='cuda:0')
tensor([[ 0,  0],
        [ 0,  9],
        [ 0, 18]], device='cuda:0')
tensor([[ 0,  0],
        [ 0, 25]], device='cuda:0')
tensor([[ 0,  5],
        [ 0, 11],
        [ 0, 23]], device='cuda:0')
tensor([[ 0, 11],
        [ 0, 13]], device='cuda:0')
tensor([[0, 0],
        [0, 7]], device='cuda:0')
tensor([[ 0,  3],
        [ 0, 26],
        [ 0, 27]], device='cuda:0')
tensor([[ 0,  5],
        [ 0, 25],
        [ 0, 28]], device='cuda:0')
tensor([[ 0,  0],
        

tensor([[  0,   6],
        [  0,  10],
        [  0,  17],
        [  0,  20],
        [  0,  21],
        [  0,  25],
        [  0,  41],
        [  0,  47],
        [  0,  53],
        [  0,  56],
        [  0,  73],
        [  0,  79],
        [  0,  91],
        [  0,  99],
        [  0, 101],
        [  0, 105],
        [  0, 108],
        [  0, 112],
        [  0, 113],
        [  0, 115]], device='cuda:0')
tensor([[  0,   3],
        [  0,  13],
        [  0,  17],
        [  0,  19],
        [  0,  26],
        [  0,  29],
        [  0,  32],
        [  0,  45],
        [  0,  53],
        [  0,  55],
        [  0,  59],
        [  0,  62],
        [  0,  74],
        [  0,  76],
        [  0,  81],
        [  0,  87],
        [  0,  93],
        [  0, 106],
        [  0, 120]], device='cuda:0')
tensor([[  0,  12],
        [  0,  17],
        [  0,  19],
        [  0,  27],
        [  0,  30],
        [  0,  32],
        [  0,  33],
        [  0,  36],
        [  0,  45],
    

In [20]:
for i in range(1, len(pruned_weights)):
    nonzero_idx = torch.nonzero(pruned_weights[i])[-1][1].item()

    if nonzero_idx >= size_limit:

        merged_weights.append(torch.cat([t[:, :size_limit] for t in pruned_weights[start_idx:i]], dim=0))
        
        size_limit = pruned_weights[i].shape[1]
        start_idx = i

merged_weights.append(torch.cat([t[:, :size_limit] for t in pruned_weights[start_idx:]], dim=0))

In [21]:
for weight in merged_weights:
    print(weight.shape)

torch.Size([5, 14])
torch.Size([2, 19])
torch.Size([3, 21])
torch.Size([3, 24])
torch.Size([4, 27])
torch.Size([2, 31])
torch.Size([4, 33])
torch.Size([5, 37])
torch.Size([7, 42])
torch.Size([3, 49])
torch.Size([2, 52])
torch.Size([7, 54])
torch.Size([6, 61])
torch.Size([4, 67])
torch.Size([3, 71])
torch.Size([1, 74])
torch.Size([2, 75])
torch.Size([3, 77])
torch.Size([3, 80])
torch.Size([2, 83])
torch.Size([2, 85])
torch.Size([3, 87])
torch.Size([6, 90])
torch.Size([2, 96])
torch.Size([1, 98])
torch.Size([4, 99])
torch.Size([6, 103])
torch.Size([5, 109])
torch.Size([2, 114])
torch.Size([1, 116])
torch.Size([5, 117])
torch.Size([4, 122])
torch.Size([2, 126])
torch.Size([3, 128])
torch.Size([3, 131])
torch.Size([4, 134])
torch.Size([6, 138])
torch.Size([4, 144])
torch.Size([2, 148])
torch.Size([3, 150])
torch.Size([5, 153])
torch.Size([2, 158])


In [26]:
sizes = [w.shape[0] for w in merged_weights]
print(sum(sizes))

146


In [23]:
model = DPN(X_train.shape[1], sum(sizes), len(le.classes_), False).to(device)
model.compile()
model.weights = nn.ParameterList([nn.Parameter(w) for w in merged_weights])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [24]:
train_metrics, eval_metrics = train(model, train_loader, test_loader, 50, optimizer, criterion, device=device)


Epoch: 1 Total_Time: 9.8767 Average_Time_per_batch: 0.0286 Train_Accuracy: 0.6609 Train_Loss: 0.9078 Validation_Accuracy: 0.6912 Validation_Loss: 0.8193
Epoch: 2 Total_Time: 9.8237 Average_Time_per_batch: 0.0285 Train_Accuracy: 0.6952 Train_Loss: 0.8232 Validation_Accuracy: 0.7055 Validation_Loss: 0.7904
Epoch: 3 Total_Time: 9.9258 Average_Time_per_batch: 0.0288 Train_Accuracy: 0.7074 Train_Loss: 0.7995 Validation_Accuracy: 0.7171 Validation_Loss: 0.7710
Epoch: 4 Total_Time: 9.9071 Average_Time_per_batch: 0.0287 Train_Accuracy: 0.7173 Train_Loss: 0.7810 Validation_Accuracy: 0.7238 Validation_Loss: 0.7551
Epoch: 5 Total_Time: 9.5805 Average_Time_per_batch: 0.0278 Train_Accuracy: 0.7258 Train_Loss: 0.7651 Validation_Accuracy: 0.7303 Validation_Loss: 0.7425
Epoch: 6 Total_Time: 9.6045 Average_Time_per_batch: 0.0278 Train_Accuracy: 0.7329 Train_Loss: 0.7508 Validation_Accuracy: 0.7384 Validation_Loss: 0.7300
Epoch: 7 Total_Time: 9.5154 Average_Time_per_batch: 0.0276 Train_Accuracy: 0.7402

In [25]:
print(sizes)

[5, 2, 3, 3, 4, 2, 4, 5, 7, 3, 2, 7, 6, 4, 3, 1, 2, 3, 3, 2, 2, 3, 6, 2, 1, 4, 6, 5, 2, 1, 5, 4, 2, 3, 3, 4, 6, 4, 2, 3, 5, 2]
