In [3]:
import torch
import src.data_processing as dp
from src.rtf_pruning import prune_attention_heads,prune_heads_and_shrink,model_heads_by_magnitude
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.reports.single_table import DiagnosticReport
from realtabformer import REaLTabFormer

In [4]:
train_data, test_data, sample_data = dp.csv_data_split("../data/breast-cancer-wisconsin.csv")
my_metadata_dict = dp.metadata("../data/cancer_metadata.json")

In [83]:
rtf_model = REaLTabFormer.load_from_dir("../models/rtf_regular/id000017342890144858071040")

In [64]:
heads_to_prune = model_heads_by_magnitude(model=rtf_model.model,percentage=0.3,num_heads_per_layer=12,layers=6)

In [65]:
heads_to_prune

[(0, [8, 2, 1, 7]),
 (1, [2, 11, 0, 7]),
 (2, [5, 6, 8, 0]),
 (3, [5, 2, 10, 9]),
 (4, [8, 2, 3, 6]),
 (5, [6, 4, 5, 0])]

In [76]:
prune_attention_heads(rtf_model.model,heads_to_prune)

Processing transformer.h.0.attn.c_attn
Processing transformer.h.0.attn.c_proj
Processing transformer.h.0.mlp.c_fc
Processing transformer.h.0.mlp.c_proj
Processing transformer.h.1.attn.c_attn
Processing transformer.h.1.attn.c_proj
Processing transformer.h.1.mlp.c_fc
Processing transformer.h.1.mlp.c_proj
Processing transformer.h.2.attn.c_attn
Processing transformer.h.2.attn.c_proj
Processing transformer.h.2.mlp.c_fc
Processing transformer.h.2.mlp.c_proj
Processing transformer.h.3.attn.c_attn
Processing transformer.h.3.attn.c_proj
Processing transformer.h.3.mlp.c_fc
Processing transformer.h.3.mlp.c_proj
Processing transformer.h.4.attn.c_attn
Processing transformer.h.4.attn.c_proj
Processing transformer.h.4.mlp.c_fc
Processing transformer.h.4.mlp.c_proj
Processing transformer.h.5.attn.c_attn
Processing transformer.h.5.attn.c_proj
Processing transformer.h.5.mlp.c_fc
Processing transformer.h.5.mlp.c_proj


In [69]:
for i, block in enumerate(rtf_model.model.transformer.h):
    num_heads = block.attn.num_heads
    print(f"Layer {i} has {num_heads} attention heads.")

Layer 0 has 8 attention heads.
Layer 1 has 8 attention heads.
Layer 2 has 8 attention heads.
Layer 3 has 8 attention heads.
Layer 4 has 8 attention heads.
Layer 5 has 8 attention heads.


In [85]:
synthetic_data2 = rtf_model.sample(n_samples=len(test_data))

quality = QualityReport()
diagnostic = DiagnosticReport()

quality.generate(test_data,synthetic_data2,metadata=my_metadata_dict,verbose=False)
diagnostic.generate(test_data,synthetic_data2,metadata=my_metadata_dict,verbose=False)

print(quality.get_properties())
print(diagnostic.get_properties())



  0%|          | 0/137 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 256 samples generated. Sampling efficiency is: 100.0000%
             Property     Score
0       Column Shapes  0.921168
1  Column Pair Trends  0.913513
         Property     Score
0   Data Validity  0.994691
1  Data Structure  1.000000


In [22]:
for i, block in enumerate(rtf_model.model.transformer.h):
    num_heads = block.attn.num_heads
    print(f"Layer {i} has {num_heads} attention heads.")

Layer 0 has 6 attention heads.
Layer 1 has 6 attention heads.
Layer 2 has 6 attention heads.
Layer 3 has 6 attention heads.


In [201]:
# indices = torch.LongTensor([i for i in range(0, 512, 2 )])  # Keep 100 out of 512 channels
apply_structured_pruning(rtf_model.model)

Pruning layer: transformer.h.0.attn.c_attn
Pruning layer: transformer.h.0.attn.c_proj
Pruning layer: transformer.h.0.mlp.c_fc
Pruning layer: transformer.h.0.mlp.c_proj
Pruning layer: transformer.h.1.attn.c_attn
Pruning layer: transformer.h.1.attn.c_proj
Pruning layer: transformer.h.1.mlp.c_fc
Pruning layer: transformer.h.1.mlp.c_proj
Pruning layer: transformer.h.2.attn.c_attn
Pruning layer: transformer.h.2.attn.c_proj
Pruning layer: transformer.h.2.mlp.c_fc
Pruning layer: transformer.h.2.mlp.c_proj
Pruning layer: transformer.h.3.attn.c_attn
Pruning layer: transformer.h.3.attn.c_proj
Pruning layer: transformer.h.3.mlp.c_fc
Pruning layer: transformer.h.3.mlp.c_proj


In [518]:
def compute_sparsity(model):
    total_params = 0
    zero_params = 0
    for param in model.parameters():
        total_params += param.numel()
        zero_params += (param == 0).sum().item()
    
    sparsity = zero_params / total_params
    return sparsity,total_params, zero_params

# Example usage
sparsity, total_params, zero_params = compute_sparsity(rtf_model.model)
print(f"Sparsity: {sparsity * 100:.2f}%")
print(f"Total: {total_params}")
print(f"Zero: {zero_params}")

Sparsity: 0.00%
Total: 12164608
Zero: 0


In [70]:
def model_size(model):
    return sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Model size: {model_size(rtf_model.model)} bytes")

Model size: 154847232 bytes


In [73]:
def model_size(model):
    return sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Model size: {model_size(rtf_model.model)} bytes")

Model size: 154847232 bytes


In [32]:
48658432/52858880

0.9205346764819837

In [82]:
synthetic_data = rtf_model.sample(n_samples=(len(test_data)))

quality = QualityReport()
diagnostic = DiagnosticReport()

quality.generate(test_data,synthetic_data,metadata=my_metadata_dict,verbose=False)
diagnostic.generate(test_data,synthetic_data,metadata=my_metadata_dict,verbose=False)

print(quality.get_properties())
print(diagnostic.get_properties())



  0%|          | 0/137 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 256 samples generated. Sampling efficiency is: 100.0000%
             Property     Score
0       Column Shapes  0.869343
1  Column Pair Trends  0.874957
         Property     Score
0   Data Validity  0.998009
1  Data Structure  1.000000


In [80]:
rtf_model.save("save/")

ValueError: This directory is not empty, and contains either a config or a model. Consider setting `allow_overwrite=True` if you want to overwrite these.

In [134]:
test = REaLTabFormer.load_from_dir("/Users/sebastian/PycharmProjects/model-compression/notebooks/testing/id000017342868701547638784")

RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
	size mismatch for transformer.wte.weight: copying a param with shape torch.Size([156, 512]) from checkpoint, the shape in current model is torch.Size([156, 256]).
	size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for transformer.h.0.ln_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.0.ln_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.0.ln_2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.0.ln_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.1.ln_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.1.ln_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.1.ln_2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.1.ln_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.2.ln_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.2.ln_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.2.ln_2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.2.ln_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.3.ln_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.3.ln_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.3.ln_2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.h.3.ln_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for transformer.ln_f.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([156, 512]) from checkpoint, the shape in current model is torch.Size([156, 256]).

In [None]:
synthetic_data_pruned = 0

In [33]:
from sdmetrics.single_table import BinaryDecisionTreeClassifier

def evaluate_model(test_data, synthetic_data, target, metadata):
    return BinaryDecisionTreeClassifier.compute(
        test_data=test_data,
        train_data=synthetic_data,
        target=target,
        metadata=metadata
    )


In [44]:
evaluate_model(test_data,synthetic_data_pruned,'Class',my_metadata_dict)

0.8461538461538461

In [47]:
evaluate_model(test_data,synthetic_data,'Class',my_metadata_dict)

0.8712871287128713

In [38]:
checkpoint = torch.load("/Users/sebastian/PycharmProjects/model-compression/notebooks/testing/id000017342868701547638784/rtf_model.pt")
for key, value in checkpoint.items():
    if torch.is_tensor(value):
        print(f"Key: {key}, Shape: {value.shape}")

FileNotFoundError: [Errno 2] No such file or directory: '/Users/sebastian/PycharmProjects/model-compression/notebooks/testing/id000017342868701547638784/rtf_model.pt'