In [1]:
import torch
import src.data_processing as dp
from src.rtf_pruning import prune_attention_heads,model_heads_by_magnitude
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.reports.single_table import DiagnosticReport
from realtabformer import REaLTabFormer

In [36]:
train_data, test_data, sample_data = dp.csv_data_split("../data/breast-cancer-wisconsin.csv")
my_metadata_dict = dp.metadata("../data/cancer_metadata.json")

In [37]:
rtf_model = REaLTabFormer.load_from_dir("../models/rtf_regular/id000017342890144858071040")

In [38]:
# Size of model before pruning attn heads
def model_size(model):
    return sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Model size: {model_size(rtf_model.model)} bytes")

Model size: 173740032 bytes


In [39]:
# Prints how many attentions heads before pruning 
for i, block in enumerate(rtf_model.model.transformer.h):
    num_heads = block.attn.num_heads
    print(f"Layer {i} has {num_heads} attention heads.")

Layer 0 has 12 attention heads.
Layer 1 has 12 attention heads.
Layer 2 has 12 attention heads.
Layer 3 has 12 attention heads.
Layer 4 has 12 attention heads.
Layer 5 has 12 attention heads.


In [42]:
# Quality report before pruning
synthetic_data = rtf_model.sample(n_samples=len(test_data))

quality = QualityReport()
diagnostic = DiagnosticReport()

quality.generate(test_data,synthetic_data,metadata=my_metadata_dict,verbose=False)
diagnostic.generate(test_data,synthetic_data,metadata=my_metadata_dict,verbose=False)

print(quality.get_properties())
print(diagnostic.get_properties())



  0%|          | 0/137 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 256 samples generated. Sampling efficiency is: 100.0000%
             Property     Score
0       Column Shapes  0.923358
1  Column Pair Trends  0.910994
         Property     Score
0   Data Validity  0.992701
1  Data Structure  1.000000


In [43]:
# Finding least to most important attention heads per layer
heads_to_prune = model_heads_by_magnitude(model=rtf_model.model,percentage=0.3,num_heads_per_layer=12,layers=6)
heads_to_prune

[(0, [8, 2, 1, 7]),
 (1, [2, 11, 0, 7]),
 (2, [5, 6, 8, 0]),
 (3, [5, 2, 10, 9]),
 (4, [8, 2, 3, 6]),
 (5, [6, 4, 5, 0])]

In [44]:
prune_attention_heads(rtf_model.model,heads_to_prune)

In [45]:
# After pruning 
for i, block in enumerate(rtf_model.model.transformer.h):
    num_heads = block.attn.num_heads
    print(f"Layer {i} has {num_heads} attention heads.")

Layer 0 has 8 attention heads.
Layer 1 has 8 attention heads.
Layer 2 has 8 attention heads.
Layer 3 has 8 attention heads.
Layer 4 has 8 attention heads.
Layer 5 has 8 attention heads.


In [47]:
# Size of model after pruning attn heads
def model_size(model):
    return sum(p.numel() * p.element_size() for p in model.parameters())

print(f"Model size: {model_size(rtf_model.model)} bytes")

Model size: 154847232 bytes


In [50]:
# Evaluation after pruning attn heads
synthetic_data_2 = rtf_model.sample(n_samples=(len(test_data)))

quality = QualityReport()
diagnostic = DiagnosticReport()

quality.generate(test_data,synthetic_data_2,metadata=my_metadata_dict,verbose=False)
diagnostic.generate(test_data,synthetic_data_2,metadata=my_metadata_dict,verbose=False)

print(quality.get_properties())
print(diagnostic.get_properties())



  0%|          | 0/137 [00:00<?, ?it/s]

Generated 0 invalid samples out of total 256 samples generated. Sampling efficiency is: 100.0000%
             Property     Score
0       Column Shapes  0.870073
1  Column Pair Trends  0.881952
         Property     Score
0   Data Validity  0.993364
1  Data Structure  1.000000


In [55]:
rtf_model.save("../models/saved/")

Copying artefacts from: best-disc-model
Copying artefacts from: mean-best-disc-model
Copying artefacts from: not-best-disc-model
Copying artefacts from: last-epoch-model


In [56]:
# Since size of saved model and expected model don't match up after pruning loading doesn't work, 
# the config needs to be updated after pruning or making any adjustments to the model
test = REaLTabFormer.load_from_dir("../models/saved/id000017342890144858071040")

RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
	size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.0.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.0.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.1.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.1.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.2.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.2.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.3.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.3.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.4.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.4.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.5.attn.c_attn.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.5.attn.c_proj.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).