In [1]:
import torch
import json
import time
import statistics
import numpy as np
import pickle as pkl

from tqdm import tqdm


from transformers import AutoTokenizer
from transformers.models.bloom.configuration_bloom import BloomConfig
from pruning.pruned_bloom import PrunedBloomForCausalLM
from node_attribution.utils import count_params
from rouge_score import rouge_scorer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tokenizer used to encode every sentence that gets scored below.
# NOTE(review): the weights loaded later are named "560m" while this tokenizer
# comes from the bloomz-1b1 checkpoint — presumably both share one vocabulary;
# confirm.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-1b1")

In [3]:
# Load the pickled conversation pairs and split them into a calibration part
# and a validation part.
# NOTE(security): unpickling can execute arbitrary code — only load this file
# if it comes from a trusted source.
with open("44_human_filtered_conv_pairs.pkl", "rb") as pairs_file:
    data = pkl.load(pairs_file)
cali_data = data[:22]  # first 22 entries
val_data = data[22:]   # everything after the first 22 entries

In [4]:
def score(model, tokenizer, sentence):
    """Return exp(loss) of `model` on `sentence`, masking one token per row.

    The encoded sentence (n tokens) is copied into n - 2 rows. In row i the
    token at position i + 1 is replaced by `tokenizer.pad_token_id`, so every
    position except the first and the last is masked exactly once. The labels
    keep the original token only at the masked position of each row and are
    -100 everywhere else.

    Args:
        model: callable invoked as `model(input_ids, labels=labels)`; the
            returned object must expose a scalar `.loss` tensor.
            NOTE(review): presumably a Hugging Face causal LM whose loss
            ignores -100 labels — confirm for PrunedBloomForCausalLM.
        tokenizer: object with `encode(text, return_tensors='pt')` returning a
            (1, n) tensor of token ids, and a `pad_token_id` attribute.
        sentence: text to score.

    Returns:
        `np.exp` of the loss value (a NumPy float).

    Raises:
        ValueError: if the sentence encodes to fewer than 3 tokens, because
            there would be no interior position to mask.
    """
    tensor_input = tokenizer.encode(sentence, return_tensors='pt')
    seq_len = tensor_input.size(-1)
    if seq_len < 3:
        raise ValueError(
            f"sentence must encode to at least 3 tokens, got {seq_len}"
        )

    # One row per masked position: shape (n - 2, n).
    repeat_input = tensor_input.repeat(seq_len - 2, 1)
    # Ones on the first super-diagonal; dropping the last two rows leaves
    # row i with a single 1 at column i + 1.
    mask = torch.ones(seq_len - 1).diag(1)[:-2]
    masked_input = repeat_input.masked_fill(mask == 1, tokenizer.pad_token_id)
    # Build the labels from the mask itself rather than by looking for the pad
    # id in `masked_input`: a sentence that already contains the pad id would
    # otherwise get extra, unmasked positions counted in the loss.
    labels = repeat_input.masked_fill(mask != 1, -100)
    with torch.inference_mode():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())

In [5]:
def calc_p(data):
    """Return the mean `score` value over every sentence in `data`.

    Each item is scored with the module-level `pruned_model` and `tokenizer`,
    so both must be defined before this is called.

    Args:
        data: iterable of sentences accepted by `score`.

    Returns:
        The arithmetic mean of the per-sentence scores.

    Raises:
        ValueError: if `data` yields no items (the mean would be undefined).
    """
    perplexity_sum = 0
    n_scored = 0
    for pair in tqdm(data):
        perplexity_sum += score(sentence=pair, model=pruned_model, tokenizer=tokenizer)
        n_scored += 1

    if n_scored == 0:
        raise ValueError("calc_p needs at least one sentence to average over")

    # Divide by the number of items actually scored, not by the size of some
    # other dataset, so the mean is right for splits of any length.
    return perplexity_sum / n_scored

In [6]:
# Files written by the pruning run: the weight checkpoint and a pickle that is
# handed to the pruned-model constructor alongside the config.
weights_path = "pruned_30percent_560m_bloom.pt"
state_dict_shapes_path = "pruned_30percent_560m_bloom_state_dict_shapes.pkl"

# Keyword arguments for BloomConfig, kept in one dict so the architecture
# settings can be read (or diffed) at a glance.
# NOTE(review): several keys (e.g. offset_alibi, skip_bias_add,
# masked_softmax_fusion) look like Megatron-era options — confirm the installed
# BloomConfig still reads them rather than just storing them.
bloom_config_kwargs = dict(
    # Model size
    vocab_size=250880,
    hidden_size=1024,
    n_layer=24,
    n_head=16,
    # Numerics / initialisation
    layer_norm_epsilon=1e-5,
    initializer_range=0.02,
    use_cache=True,
    # Special token ids
    bos_token_id=1,
    eos_token_id=2,
    # Block behaviour
    apply_residual_connection_post_layernorm=False,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    pretraining_tp=1,  # TP rank used when training with megatron
    slow_but_exact=False,
    attention_softmax_in_fp32=True,
    bias_dropout_fusion=True,
    masked_softmax_fusion=True,
    offset_alibi=100,
    pad_token_id=3,
    seq_length=2048,
    skip_bias_add=True,
    skip_bias_add_qkv=False,
    unk_token_id=0,
)

bloom_config = BloomConfig(**bloom_config_kwargs)

In [7]:
# Instantiate the pruned BLOOM model from the config and the shapes file; the
# trained weights are loaded separately afterwards.
# NOTE(review): presumably the shapes file tells the constructor the tensor
# sizes that remain after pruning — confirm in pruning/pruned_bloom.py.
pruned_model = PrunedBloomForCausalLM(bloom_config, state_dict_shapes_path)

In [8]:
# Deserialize the checkpoint onto the CPU first so a file saved from a GPU
# session still loads on a CPU-only machine; load_state_dict then copies the
# values into the model's existing parameters. Kept as the cell's last
# expression so the key-matching report is still displayed.
pruned_model.load_state_dict(torch.load(weights_path, map_location="cpu"))

<All keys matched successfully>

In [9]:
# Share of parameters removed by pruning, as a fraction in [0, 1] (not a
# percentage, despite the variable name).
# NOTE(review): 559214592 is presumably the parameter count of the unpruned
# 560m model, and the last element returned by count_params presumably the
# pruned model's total — confirm both.
params_remaining = count_params(pruned_model)[-1]
pruned_percent = 1.0 - (params_remaining / 559214592)
print(pruned_percent)

0.30908961152430015


In [10]:
# Report the mean score on the calibration split first, then on the
# validation split.
for eval_split in (cali_data, val_data):
    print(calc_p(eval_split))

  0%|                                                                                                                                                               | 0/22 [00:00<?, ?it/s]

0
torch.Size([57, 59, 256])
1
torch.Size([57, 59, 256])
2
torch.Size([57, 59, 256])
3
torch.Size([57, 59, 256])
4
torch.Size([57, 59, 256])
5
torch.Size([57, 59, 256])
6
torch.Size([57, 59, 256])
7
torch.Size([57, 59, 256])
8
torch.Size([57, 59, 256])
9
torch.Size([57, 59, 256])
10
torch.Size([57, 59, 256])
11
torch.Size([57, 59, 256])
12
torch.Size([57, 59, 256])
13
torch.Size([57, 59, 256])
14
torch.Size([57, 59, 320])
15
torch.Size([57, 59, 256])
16
torch.Size([57, 59, 256])
17
torch.Size([57, 59, 256])
18
torch.Size([57, 59, 320])
19
torch.Size([57, 59, 256])
20
torch.Size([57, 59, 704])
21
torch.Size([57, 59, 1024])
22
torch.Size([57, 59, 1024])
23
torch.Size([57, 59, 256])


  5%|██████▊                                                                                                                                                | 1/22 [00:03<01:09,  3.29s/it]

0
torch.Size([53, 55, 256])
1
torch.Size([53, 55, 256])
2
torch.Size([53, 55, 256])
3
torch.Size([53, 55, 256])
4
torch.Size([53, 55, 256])
5
torch.Size([53, 55, 256])
6
torch.Size([53, 55, 256])
7
torch.Size([53, 55, 256])
8
torch.Size([53, 55, 256])
9
torch.Size([53, 55, 256])
10
torch.Size([53, 55, 256])
11
torch.Size([53, 55, 256])
12
torch.Size([53, 55, 256])
13
torch.Size([53, 55, 256])
14
torch.Size([53, 55, 320])
15
torch.Size([53, 55, 256])
16
torch.Size([53, 55, 256])
17
torch.Size([53, 55, 256])
18
torch.Size([53, 55, 320])
19
torch.Size([53, 55, 256])
20
torch.Size([53, 55, 704])
21
torch.Size([53, 55, 1024])
22
torch.Size([53, 55, 1024])
23
torch.Size([53, 55, 256])


  9%|█████████████▋                                                                                                                                         | 2/22 [00:05<00:58,  2.92s/it]

0
torch.Size([64, 66, 256])
1
torch.Size([64, 66, 256])
2
torch.Size([64, 66, 256])
3
torch.Size([64, 66, 256])
4
torch.Size([64, 66, 256])
5
torch.Size([64, 66, 256])
6
torch.Size([64, 66, 256])
7
torch.Size([64, 66, 256])
8
torch.Size([64, 66, 256])
9
torch.Size([64, 66, 256])
10
torch.Size([64, 66, 256])
11
torch.Size([64, 66, 256])
12
torch.Size([64, 66, 256])
13
torch.Size([64, 66, 256])
14
torch.Size([64, 66, 320])
15
torch.Size([64, 66, 256])
16
torch.Size([64, 66, 256])
17
torch.Size([64, 66, 256])
18
torch.Size([64, 66, 320])
19
torch.Size([64, 66, 256])
20
torch.Size([64, 66, 704])
21
torch.Size([64, 66, 1024])
22
torch.Size([64, 66, 1024])
23
torch.Size([64, 66, 256])


 14%|████████████████████▌                                                                                                                                  | 3/22 [00:09<01:03,  3.36s/it]

0
torch.Size([79, 81, 256])
1
torch.Size([79, 81, 256])
2
torch.Size([79, 81, 256])
3
torch.Size([79, 81, 256])
4
torch.Size([79, 81, 256])
5
torch.Size([79, 81, 256])
6
torch.Size([79, 81, 256])
7
torch.Size([79, 81, 256])
8
torch.Size([79, 81, 256])
9
torch.Size([79, 81, 256])
10
torch.Size([79, 81, 256])
11
torch.Size([79, 81, 256])
12
torch.Size([79, 81, 256])
13
torch.Size([79, 81, 256])
14
torch.Size([79, 81, 320])
15
torch.Size([79, 81, 256])
16
torch.Size([79, 81, 256])
17
torch.Size([79, 81, 256])
18
torch.Size([79, 81, 320])
19
torch.Size([79, 81, 256])
20
torch.Size([79, 81, 704])
21
torch.Size([79, 81, 1024])
22
torch.Size([79, 81, 1024])
23
torch.Size([79, 81, 256])


 18%|███████████████████████████▍                                                                                                                           | 4/22 [00:15<01:20,  4.46s/it]

0
torch.Size([53, 55, 256])
1
torch.Size([53, 55, 256])
2
torch.Size([53, 55, 256])
3
torch.Size([53, 55, 256])
4
torch.Size([53, 55, 256])
5
torch.Size([53, 55, 256])
6
torch.Size([53, 55, 256])
7
torch.Size([53, 55, 256])
8
torch.Size([53, 55, 256])
9
torch.Size([53, 55, 256])
10
torch.Size([53, 55, 256])
11
torch.Size([53, 55, 256])
12
torch.Size([53, 55, 256])
13
torch.Size([53, 55, 256])
14
torch.Size([53, 55, 320])
15
torch.Size([53, 55, 256])
16
torch.Size([53, 55, 256])
17
torch.Size([53, 55, 256])
18
torch.Size([53, 55, 320])
19
torch.Size([53, 55, 256])
20
torch.Size([53, 55, 704])
21
torch.Size([53, 55, 1024])
22
torch.Size([53, 55, 1024])
23
torch.Size([53, 55, 256])


 23%|██████████████████████████████████▎                                                                                                                    | 5/22 [00:18<01:05,  3.86s/it]

0
torch.Size([104, 106, 256])
1
torch.Size([104, 106, 256])
2
torch.Size([104, 106, 256])
3
torch.Size([104, 106, 256])
4
torch.Size([104, 106, 256])
5
torch.Size([104, 106, 256])
6
torch.Size([104, 106, 256])
7
torch.Size([104, 106, 256])
8
torch.Size([104, 106, 256])
9
torch.Size([104, 106, 256])
10
torch.Size([104, 106, 256])
11
torch.Size([104, 106, 256])
12
torch.Size([104, 106, 256])
13
torch.Size([104, 106, 256])
14
torch.Size([104, 106, 320])
15
torch.Size([104, 106, 256])
16
torch.Size([104, 106, 256])
17
torch.Size([104, 106, 256])
18
torch.Size([104, 106, 320])
19
torch.Size([104, 106, 256])
20
torch.Size([104, 106, 704])
21
torch.Size([104, 106, 1024])
22
torch.Size([104, 106, 1024])
23
torch.Size([104, 106, 256])


 27%|█████████████████████████████████████████▏                                                                                                             | 6/22 [00:29<01:39,  6.23s/it]

0
torch.Size([60, 62, 256])
1
torch.Size([60, 62, 256])
2
torch.Size([60, 62, 256])
3
torch.Size([60, 62, 256])
4
torch.Size([60, 62, 256])
5
torch.Size([60, 62, 256])
6
torch.Size([60, 62, 256])
7
torch.Size([60, 62, 256])
8
torch.Size([60, 62, 256])
9
torch.Size([60, 62, 256])
10
torch.Size([60, 62, 256])
11
torch.Size([60, 62, 256])
12
torch.Size([60, 62, 256])
13
torch.Size([60, 62, 256])
14
torch.Size([60, 62, 320])
15
torch.Size([60, 62, 256])
16
torch.Size([60, 62, 256])
17
torch.Size([60, 62, 256])
18
torch.Size([60, 62, 320])
19
torch.Size([60, 62, 256])
20
torch.Size([60, 62, 704])
21
torch.Size([60, 62, 1024])
22
torch.Size([60, 62, 1024])
23
torch.Size([60, 62, 256])


 32%|████████████████████████████████████████████████                                                                                                       | 7/22 [00:33<01:21,  5.41s/it]

0
torch.Size([98, 100, 256])
1
torch.Size([98, 100, 256])
2
torch.Size([98, 100, 256])
3
torch.Size([98, 100, 256])
4
torch.Size([98, 100, 256])
5
torch.Size([98, 100, 256])
6
torch.Size([98, 100, 256])
7
torch.Size([98, 100, 256])
8
torch.Size([98, 100, 256])
9
torch.Size([98, 100, 256])
10
torch.Size([98, 100, 256])
11
torch.Size([98, 100, 256])
12
torch.Size([98, 100, 256])
13
torch.Size([98, 100, 256])
14
torch.Size([98, 100, 320])
15
torch.Size([98, 100, 256])
16
torch.Size([98, 100, 256])
17
torch.Size([98, 100, 256])
18
torch.Size([98, 100, 320])
19
torch.Size([98, 100, 256])
20
torch.Size([98, 100, 704])
21
torch.Size([98, 100, 1024])
22
torch.Size([98, 100, 1024])
23
torch.Size([98, 100, 256])


 36%|██████████████████████████████████████████████████████▉                                                                                                | 8/22 [00:42<01:33,  6.67s/it]

0
torch.Size([48, 50, 256])
1
torch.Size([48, 50, 256])
2
torch.Size([48, 50, 256])
3
torch.Size([48, 50, 256])
4
torch.Size([48, 50, 256])
5
torch.Size([48, 50, 256])
6
torch.Size([48, 50, 256])
7
torch.Size([48, 50, 256])
8
torch.Size([48, 50, 256])
9
torch.Size([48, 50, 256])
10
torch.Size([48, 50, 256])
11
torch.Size([48, 50, 256])
12
torch.Size([48, 50, 256])
13
torch.Size([48, 50, 256])
14
torch.Size([48, 50, 320])
15
torch.Size([48, 50, 256])
16
torch.Size([48, 50, 256])
17
torch.Size([48, 50, 256])
18
torch.Size([48, 50, 320])
19
torch.Size([48, 50, 256])
20
torch.Size([48, 50, 704])
21
torch.Size([48, 50, 1024])
22
torch.Size([48, 50, 1024])
23
torch.Size([48, 50, 256])


 41%|█████████████████████████████████████████████████████████████▊                                                                                         | 9/22 [00:44<01:08,  5.30s/it]

0
torch.Size([81, 83, 256])
1
torch.Size([81, 83, 256])
2
torch.Size([81, 83, 256])
3
torch.Size([81, 83, 256])
4
torch.Size([81, 83, 256])
5
torch.Size([81, 83, 256])
6
torch.Size([81, 83, 256])
7
torch.Size([81, 83, 256])
8
torch.Size([81, 83, 256])
9
torch.Size([81, 83, 256])
10
torch.Size([81, 83, 256])
11
torch.Size([81, 83, 256])
12
torch.Size([81, 83, 256])
13
torch.Size([81, 83, 256])
14
torch.Size([81, 83, 320])
15
torch.Size([81, 83, 256])
16
torch.Size([81, 83, 256])
17
torch.Size([81, 83, 256])
18
torch.Size([81, 83, 320])
19
torch.Size([81, 83, 256])
20
torch.Size([81, 83, 704])
21
torch.Size([81, 83, 1024])
22
torch.Size([81, 83, 1024])
23
torch.Size([81, 83, 256])


 45%|████████████████████████████████████████████████████████████████████▏                                                                                 | 10/22 [00:51<01:07,  5.66s/it]

0
torch.Size([86, 88, 256])
1
torch.Size([86, 88, 256])
2
torch.Size([86, 88, 256])
3
torch.Size([86, 88, 256])
4
torch.Size([86, 88, 256])
5
torch.Size([86, 88, 256])
6
torch.Size([86, 88, 256])
7
torch.Size([86, 88, 256])
8
torch.Size([86, 88, 256])
9
torch.Size([86, 88, 256])
10
torch.Size([86, 88, 256])
11
torch.Size([86, 88, 256])
12
torch.Size([86, 88, 256])
13
torch.Size([86, 88, 256])
14
torch.Size([86, 88, 320])
15
torch.Size([86, 88, 256])
16
torch.Size([86, 88, 256])
17
torch.Size([86, 88, 256])
18
torch.Size([86, 88, 320])
19
torch.Size([86, 88, 256])
20
torch.Size([86, 88, 704])
21
torch.Size([86, 88, 1024])
22
torch.Size([86, 88, 1024])
23
torch.Size([86, 88, 256])


 50%|███████████████████████████████████████████████████████████████████████████                                                                           | 11/22 [00:58<01:07,  6.13s/it]

0
torch.Size([66, 68, 256])
1
torch.Size([66, 68, 256])
2
torch.Size([66, 68, 256])
3
torch.Size([66, 68, 256])
4
torch.Size([66, 68, 256])
5
torch.Size([66, 68, 256])
6
torch.Size([66, 68, 256])
7
torch.Size([66, 68, 256])
8
torch.Size([66, 68, 256])
9
torch.Size([66, 68, 256])
10
torch.Size([66, 68, 256])
11
torch.Size([66, 68, 256])
12
torch.Size([66, 68, 256])
13
torch.Size([66, 68, 256])
14
torch.Size([66, 68, 320])
15
torch.Size([66, 68, 256])
16
torch.Size([66, 68, 256])
17
torch.Size([66, 68, 256])
18
torch.Size([66, 68, 320])
19
torch.Size([66, 68, 256])
20
torch.Size([66, 68, 704])
21
torch.Size([66, 68, 1024])
22
torch.Size([66, 68, 1024])
23
torch.Size([66, 68, 256])


 55%|█████████████████████████████████████████████████████████████████████████████████▊                                                                    | 12/22 [01:02<00:55,  5.53s/it]

0
torch.Size([127, 129, 256])
1
torch.Size([127, 129, 256])
2
torch.Size([127, 129, 256])
3
torch.Size([127, 129, 256])
4
torch.Size([127, 129, 256])
5
torch.Size([127, 129, 256])
6
torch.Size([127, 129, 256])
7
torch.Size([127, 129, 256])
8
torch.Size([127, 129, 256])
9
torch.Size([127, 129, 256])
10
torch.Size([127, 129, 256])
11
torch.Size([127, 129, 256])
12
torch.Size([127, 129, 256])
13
torch.Size([127, 129, 256])
14
torch.Size([127, 129, 320])
15
torch.Size([127, 129, 256])
16
torch.Size([127, 129, 256])
17
torch.Size([127, 129, 256])
18
torch.Size([127, 129, 320])
19
torch.Size([127, 129, 256])
20
torch.Size([127, 129, 704])
21
torch.Size([127, 129, 1024])
22
torch.Size([127, 129, 1024])
23
torch.Size([127, 129, 256])


 59%|████████████████████████████████████████████████████████████████████████████████████████▋                                                             | 13/22 [01:34<02:00, 13.43s/it]

0
torch.Size([89, 91, 256])
1
torch.Size([89, 91, 256])
2
torch.Size([89, 91, 256])
3
torch.Size([89, 91, 256])
4
torch.Size([89, 91, 256])
5
torch.Size([89, 91, 256])
6
torch.Size([89, 91, 256])
7
torch.Size([89, 91, 256])
8
torch.Size([89, 91, 256])
9
torch.Size([89, 91, 256])
10
torch.Size([89, 91, 256])
11
torch.Size([89, 91, 256])
12
torch.Size([89, 91, 256])
13
torch.Size([89, 91, 256])
14
torch.Size([89, 91, 320])
15
torch.Size([89, 91, 256])
16
torch.Size([89, 91, 256])
17
torch.Size([89, 91, 256])
18
torch.Size([89, 91, 320])
19
torch.Size([89, 91, 256])
20
torch.Size([89, 91, 704])
21
torch.Size([89, 91, 1024])
22
torch.Size([89, 91, 1024])
23
torch.Size([89, 91, 256])


 64%|███████████████████████████████████████████████████████████████████████████████████████████████▍                                                      | 14/22 [01:42<01:33, 11.73s/it]

0
torch.Size([126, 128, 256])
1
torch.Size([126, 128, 256])
2
torch.Size([126, 128, 256])
3
torch.Size([126, 128, 256])
4
torch.Size([126, 128, 256])
5
torch.Size([126, 128, 256])
6
torch.Size([126, 128, 256])
7
torch.Size([126, 128, 256])
8
torch.Size([126, 128, 256])
9
torch.Size([126, 128, 256])
10
torch.Size([126, 128, 256])
11
torch.Size([126, 128, 256])
12
torch.Size([126, 128, 256])
13
torch.Size([126, 128, 256])
14
torch.Size([126, 128, 320])
15
torch.Size([126, 128, 256])
16
torch.Size([126, 128, 256])
17
torch.Size([126, 128, 256])
18
torch.Size([126, 128, 320])
19
torch.Size([126, 128, 256])
20
torch.Size([126, 128, 704])
21
torch.Size([126, 128, 1024])
22
torch.Size([126, 128, 1024])
23
torch.Size([126, 128, 256])


 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████▎                                               | 15/22 [02:11<01:59, 17.11s/it]

0
torch.Size([117, 119, 256])
1
torch.Size([117, 119, 256])
2
torch.Size([117, 119, 256])
3
torch.Size([117, 119, 256])
4
torch.Size([117, 119, 256])
5
torch.Size([117, 119, 256])
6
torch.Size([117, 119, 256])
7
torch.Size([117, 119, 256])
8
torch.Size([117, 119, 256])
9
torch.Size([117, 119, 256])
10
torch.Size([117, 119, 256])
11
torch.Size([117, 119, 256])
12
torch.Size([117, 119, 256])
13
torch.Size([117, 119, 256])
14
torch.Size([117, 119, 320])
15
torch.Size([117, 119, 256])
16
torch.Size([117, 119, 256])
17
torch.Size([117, 119, 256])
18
torch.Size([117, 119, 320])
19
torch.Size([117, 119, 256])
20
torch.Size([117, 119, 704])
21
torch.Size([117, 119, 1024])
22
torch.Size([117, 119, 1024])
23
torch.Size([117, 119, 256])


 73%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████                                         | 16/22 [02:24<01:34, 15.78s/it]

0
torch.Size([120, 122, 256])
1
torch.Size([120, 122, 256])
2
torch.Size([120, 122, 256])
3
torch.Size([120, 122, 256])
4
torch.Size([120, 122, 256])
5
torch.Size([120, 122, 256])
6
torch.Size([120, 122, 256])
7
torch.Size([120, 122, 256])
8
torch.Size([120, 122, 256])
9
torch.Size([120, 122, 256])
10
torch.Size([120, 122, 256])
11
torch.Size([120, 122, 256])
12
torch.Size([120, 122, 256])
13
torch.Size([120, 122, 256])
14
torch.Size([120, 122, 320])
15
torch.Size([120, 122, 256])
16
torch.Size([120, 122, 256])
17
torch.Size([120, 122, 256])
18
torch.Size([120, 122, 320])
19
torch.Size([120, 122, 256])
20
torch.Size([120, 122, 704])
21
torch.Size([120, 122, 1024])
22
torch.Size([120, 122, 1024])
23
torch.Size([120, 122, 256])


 77%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                  | 17/22 [02:38<01:15, 15.19s/it]

0
torch.Size([59, 61, 256])
1
torch.Size([59, 61, 256])
2
torch.Size([59, 61, 256])
3
torch.Size([59, 61, 256])
4
torch.Size([59, 61, 256])
5
torch.Size([59, 61, 256])
6
torch.Size([59, 61, 256])
7
torch.Size([59, 61, 256])
8
torch.Size([59, 61, 256])
9
torch.Size([59, 61, 256])
10
torch.Size([59, 61, 256])
11
torch.Size([59, 61, 256])
12
torch.Size([59, 61, 256])
13
torch.Size([59, 61, 256])
14
torch.Size([59, 61, 320])
15
torch.Size([59, 61, 256])
16
torch.Size([59, 61, 256])
17
torch.Size([59, 61, 256])
18
torch.Size([59, 61, 320])
19
torch.Size([59, 61, 256])
20
torch.Size([59, 61, 704])
21
torch.Size([59, 61, 1024])
22
torch.Size([59, 61, 1024])
23
torch.Size([59, 61, 256])


 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                           | 18/22 [02:41<00:46, 11.59s/it]

0
torch.Size([70, 72, 256])
1
torch.Size([70, 72, 256])
2
torch.Size([70, 72, 256])
3
torch.Size([70, 72, 256])
4
torch.Size([70, 72, 256])
5
torch.Size([70, 72, 256])
6
torch.Size([70, 72, 256])
7
torch.Size([70, 72, 256])
8
torch.Size([70, 72, 256])
9
torch.Size([70, 72, 256])
10
torch.Size([70, 72, 256])
11
torch.Size([70, 72, 256])
12
torch.Size([70, 72, 256])
13
torch.Size([70, 72, 256])
14
torch.Size([70, 72, 320])
15
torch.Size([70, 72, 256])
16
torch.Size([70, 72, 256])
17
torch.Size([70, 72, 256])
18
torch.Size([70, 72, 320])
19
torch.Size([70, 72, 256])
20
torch.Size([70, 72, 704])
21
torch.Size([70, 72, 1024])
22
torch.Size([70, 72, 1024])
23
torch.Size([70, 72, 256])


 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                    | 19/22 [02:46<00:28,  9.46s/it]

0
torch.Size([69, 71, 256])
1
torch.Size([69, 71, 256])
2
torch.Size([69, 71, 256])
3
torch.Size([69, 71, 256])
4
torch.Size([69, 71, 256])
5
torch.Size([69, 71, 256])
6
torch.Size([69, 71, 256])
7
torch.Size([69, 71, 256])
8
torch.Size([69, 71, 256])
9
torch.Size([69, 71, 256])
10
torch.Size([69, 71, 256])
11
torch.Size([69, 71, 256])
12
torch.Size([69, 71, 256])
13
torch.Size([69, 71, 256])
14
torch.Size([69, 71, 320])
15
torch.Size([69, 71, 256])
16
torch.Size([69, 71, 256])
17
torch.Size([69, 71, 256])
18
torch.Size([69, 71, 320])
19
torch.Size([69, 71, 256])
20
torch.Size([69, 71, 704])
21
torch.Size([69, 71, 1024])
22
torch.Size([69, 71, 1024])
23
torch.Size([69, 71, 256])


 91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 20/22 [02:50<00:15,  7.95s/it]

0
torch.Size([104, 106, 256])
1
torch.Size([104, 106, 256])
2
torch.Size([104, 106, 256])
3
torch.Size([104, 106, 256])
4
torch.Size([104, 106, 256])
5
torch.Size([104, 106, 256])
6
torch.Size([104, 106, 256])
7
torch.Size([104, 106, 256])
8
torch.Size([104, 106, 256])
9
torch.Size([104, 106, 256])
10
torch.Size([104, 106, 256])
11
torch.Size([104, 106, 256])
12
torch.Size([104, 106, 256])
13
torch.Size([104, 106, 256])
14
torch.Size([104, 106, 320])
15
torch.Size([104, 106, 256])
16
torch.Size([104, 106, 256])
17
torch.Size([104, 106, 256])
18
torch.Size([104, 106, 320])
19
torch.Size([104, 106, 256])
20
torch.Size([104, 106, 704])
21
torch.Size([104, 106, 1024])
22
torch.Size([104, 106, 1024])
23
torch.Size([104, 106, 256])


 95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 21/22 [03:01<00:08,  8.74s/it]

0
torch.Size([91, 93, 256])
1
torch.Size([91, 93, 256])
2
torch.Size([91, 93, 256])
3
torch.Size([91, 93, 256])
4
torch.Size([91, 93, 256])
5
torch.Size([91, 93, 256])
6
torch.Size([91, 93, 256])
7
torch.Size([91, 93, 256])
8
torch.Size([91, 93, 256])
9
torch.Size([91, 93, 256])
10
torch.Size([91, 93, 256])
11
torch.Size([91, 93, 256])
12
torch.Size([91, 93, 256])
13
torch.Size([91, 93, 256])
14
torch.Size([91, 93, 320])
15
torch.Size([91, 93, 256])
16
torch.Size([91, 93, 256])
17
torch.Size([91, 93, 256])
18
torch.Size([91, 93, 320])
19
torch.Size([91, 93, 256])
20
torch.Size([91, 93, 704])
21
torch.Size([91, 93, 1024])
22
torch.Size([91, 93, 1024])
23
torch.Size([91, 93, 256])


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [03:09<00:00,  8.60s/it]


21432584.78894007


  0%|                                                                                                                                                               | 0/22 [00:00<?, ?it/s]

0
torch.Size([86, 88, 256])
1
torch.Size([86, 88, 256])
2
torch.Size([86, 88, 256])
3
torch.Size([86, 88, 256])
4
torch.Size([86, 88, 256])
5
torch.Size([86, 88, 256])
6
torch.Size([86, 88, 256])
7
torch.Size([86, 88, 256])
8
torch.Size([86, 88, 256])
9
torch.Size([86, 88, 256])
10
torch.Size([86, 88, 256])
11
torch.Size([86, 88, 256])
12
torch.Size([86, 88, 256])
13
torch.Size([86, 88, 256])
14
torch.Size([86, 88, 320])
15
torch.Size([86, 88, 256])
16
torch.Size([86, 88, 256])
17
torch.Size([86, 88, 256])
18
torch.Size([86, 88, 320])
19
torch.Size([86, 88, 256])
20
torch.Size([86, 88, 704])
21
torch.Size([86, 88, 1024])
22
torch.Size([86, 88, 1024])
23
torch.Size([86, 88, 256])


  5%|██████▊                                                                                                                                                | 1/22 [00:07<02:30,  7.17s/it]

0
torch.Size([98, 100, 256])
1
torch.Size([98, 100, 256])
2
torch.Size([98, 100, 256])
3
torch.Size([98, 100, 256])
4
torch.Size([98, 100, 256])
5
torch.Size([98, 100, 256])
6
torch.Size([98, 100, 256])
7
torch.Size([98, 100, 256])
8
torch.Size([98, 100, 256])
9
torch.Size([98, 100, 256])
10
torch.Size([98, 100, 256])
11
torch.Size([98, 100, 256])
12
torch.Size([98, 100, 256])
13
torch.Size([98, 100, 256])
14
torch.Size([98, 100, 320])
15
torch.Size([98, 100, 256])
16
torch.Size([98, 100, 256])
17
torch.Size([98, 100, 256])
18
torch.Size([98, 100, 320])
19
torch.Size([98, 100, 256])
20
torch.Size([98, 100, 704])
21
torch.Size([98, 100, 1024])
22
torch.Size([98, 100, 1024])
23
torch.Size([98, 100, 256])


  9%|█████████████▋                                                                                                                                         | 2/22 [00:16<02:46,  8.31s/it]

0
torch.Size([103, 105, 256])
1
torch.Size([103, 105, 256])
2
torch.Size([103, 105, 256])
3
torch.Size([103, 105, 256])
4
torch.Size([103, 105, 256])
5
torch.Size([103, 105, 256])
6
torch.Size([103, 105, 256])
7
torch.Size([103, 105, 256])
8
torch.Size([103, 105, 256])
9
torch.Size([103, 105, 256])
10
torch.Size([103, 105, 256])
11
torch.Size([103, 105, 256])
12
torch.Size([103, 105, 256])
13
torch.Size([103, 105, 256])
14
torch.Size([103, 105, 320])
15
torch.Size([103, 105, 256])
16
torch.Size([103, 105, 256])
17
torch.Size([103, 105, 256])
18
torch.Size([103, 105, 320])
19
torch.Size([103, 105, 256])
20
torch.Size([103, 105, 704])
21
torch.Size([103, 105, 1024])
22
torch.Size([103, 105, 1024])
23
torch.Size([103, 105, 256])


 14%|████████████████████▌                                                                                                                                  | 3/22 [00:26<02:54,  9.20s/it]

0
torch.Size([67, 69, 256])
1
torch.Size([67, 69, 256])
2
torch.Size([67, 69, 256])
3
torch.Size([67, 69, 256])
4
torch.Size([67, 69, 256])
5
torch.Size([67, 69, 256])
6
torch.Size([67, 69, 256])
7
torch.Size([67, 69, 256])
8
torch.Size([67, 69, 256])
9
torch.Size([67, 69, 256])
10
torch.Size([67, 69, 256])
11
torch.Size([67, 69, 256])
12
torch.Size([67, 69, 256])
13
torch.Size([67, 69, 256])
14
torch.Size([67, 69, 320])
15
torch.Size([67, 69, 256])
16
torch.Size([67, 69, 256])
17
torch.Size([67, 69, 256])
18
torch.Size([67, 69, 320])
19
torch.Size([67, 69, 256])
20
torch.Size([67, 69, 704])
21
torch.Size([67, 69, 1024])
22
torch.Size([67, 69, 1024])
23
torch.Size([67, 69, 256])


 18%|███████████████████████████▍                                                                                                                           | 4/22 [00:30<02:11,  7.29s/it]

0
torch.Size([62, 64, 256])
1
torch.Size([62, 64, 256])
2
torch.Size([62, 64, 256])
3
torch.Size([62, 64, 256])
4
torch.Size([62, 64, 256])
5
torch.Size([62, 64, 256])
6
torch.Size([62, 64, 256])
7
torch.Size([62, 64, 256])
8
torch.Size([62, 64, 256])
9
torch.Size([62, 64, 256])
10
torch.Size([62, 64, 256])
11
torch.Size([62, 64, 256])
12
torch.Size([62, 64, 256])
13
torch.Size([62, 64, 256])
14
torch.Size([62, 64, 320])
15
torch.Size([62, 64, 256])
16
torch.Size([62, 64, 256])
17
torch.Size([62, 64, 256])
18
torch.Size([62, 64, 320])
19
torch.Size([62, 64, 256])
20
torch.Size([62, 64, 704])
21
torch.Size([62, 64, 1024])
22
torch.Size([62, 64, 1024])
23
torch.Size([62, 64, 256])


 23%|██████████████████████████████████▎                                                                                                                    | 5/22 [00:34<01:42,  6.05s/it]

0
torch.Size([39, 41, 256])
1
torch.Size([39, 41, 256])
2
torch.Size([39, 41, 256])
3
torch.Size([39, 41, 256])
4
torch.Size([39, 41, 256])
5
torch.Size([39, 41, 256])
6
torch.Size([39, 41, 256])
7
torch.Size([39, 41, 256])
8
torch.Size([39, 41, 256])
9
torch.Size([39, 41, 256])
10
torch.Size([39, 41, 256])
11
torch.Size([39, 41, 256])
12
torch.Size([39, 41, 256])
13
torch.Size([39, 41, 256])
14
torch.Size([39, 41, 320])
15
torch.Size([39, 41, 256])
16
torch.Size([39, 41, 256])
17
torch.Size([39, 41, 256])
18
torch.Size([39, 41, 320])
19
torch.Size([39, 41, 256])
20
torch.Size([39, 41, 704])
21
torch.Size([39, 41, 1024])
22
torch.Size([39, 41, 1024])
23
torch.Size([39, 41, 256])


 27%|█████████████████████████████████████████▏                                                                                                             | 6/22 [00:36<01:13,  4.57s/it]

0
torch.Size([39, 41, 256])
1
torch.Size([39, 41, 256])
2
torch.Size([39, 41, 256])
3
torch.Size([39, 41, 256])
4
torch.Size([39, 41, 256])
5
torch.Size([39, 41, 256])
6
torch.Size([39, 41, 256])
7
torch.Size([39, 41, 256])
8
torch.Size([39, 41, 256])
9
torch.Size([39, 41, 256])
10
torch.Size([39, 41, 256])
11
torch.Size([39, 41, 256])
12
torch.Size([39, 41, 256])
13
torch.Size([39, 41, 256])
14
torch.Size([39, 41, 320])
15
torch.Size([39, 41, 256])
16
torch.Size([39, 41, 256])
17
torch.Size([39, 41, 256])
18
torch.Size([39, 41, 320])
19
torch.Size([39, 41, 256])
20
torch.Size([39, 41, 704])
21
torch.Size([39, 41, 1024])
22
torch.Size([39, 41, 1024])
23
torch.Size([39, 41, 256])


 32%|████████████████████████████████████████████████                                                                                                       | 7/22 [00:38<00:54,  3.60s/it]

0
torch.Size([68, 70, 256])
1
torch.Size([68, 70, 256])
2
torch.Size([68, 70, 256])
3
torch.Size([68, 70, 256])
4
torch.Size([68, 70, 256])
5
torch.Size([68, 70, 256])
6
torch.Size([68, 70, 256])
7
torch.Size([68, 70, 256])
8
torch.Size([68, 70, 256])
9
torch.Size([68, 70, 256])
10
torch.Size([68, 70, 256])
11
torch.Size([68, 70, 256])
12
torch.Size([68, 70, 256])
13
torch.Size([68, 70, 256])
14
torch.Size([68, 70, 320])
15
torch.Size([68, 70, 256])
16
torch.Size([68, 70, 256])
17
torch.Size([68, 70, 256])
18
torch.Size([68, 70, 320])
19
torch.Size([68, 70, 256])
20
torch.Size([68, 70, 704])
21
torch.Size([68, 70, 1024])
22
torch.Size([68, 70, 1024])
23
torch.Size([68, 70, 256])


 36%|██████████████████████████████████████████████████████▉                                                                                                | 8/22 [00:42<00:54,  3.88s/it]

0
torch.Size([75, 77, 256])
1
torch.Size([75, 77, 256])
2
torch.Size([75, 77, 256])
3
torch.Size([75, 77, 256])
4
torch.Size([75, 77, 256])
5
torch.Size([75, 77, 256])
6
torch.Size([75, 77, 256])
7
torch.Size([75, 77, 256])
8
torch.Size([75, 77, 256])
9
torch.Size([75, 77, 256])
10
torch.Size([75, 77, 256])
11
torch.Size([75, 77, 256])
12
torch.Size([75, 77, 256])
13
torch.Size([75, 77, 256])
14
torch.Size([75, 77, 320])
15
torch.Size([75, 77, 256])
16
torch.Size([75, 77, 256])
17
torch.Size([75, 77, 256])
18
torch.Size([75, 77, 320])
19
torch.Size([75, 77, 256])
20
torch.Size([75, 77, 704])
21
torch.Size([75, 77, 1024])
22
torch.Size([75, 77, 1024])
23
torch.Size([75, 77, 256])


 41%|█████████████████████████████████████████████████████████████▊                                                                                         | 9/22 [00:47<00:56,  4.37s/it]

0
torch.Size([79, 81, 256])
1
torch.Size([79, 81, 256])
2
torch.Size([79, 81, 256])
3
torch.Size([79, 81, 256])
4
torch.Size([79, 81, 256])
5
torch.Size([79, 81, 256])
6
torch.Size([79, 81, 256])
7
torch.Size([79, 81, 256])
8
torch.Size([79, 81, 256])
9
torch.Size([79, 81, 256])
10
torch.Size([79, 81, 256])
11
torch.Size([79, 81, 256])
12
torch.Size([79, 81, 256])
13
torch.Size([79, 81, 256])
14
torch.Size([79, 81, 320])
15
torch.Size([79, 81, 256])
16
torch.Size([79, 81, 256])
17
torch.Size([79, 81, 256])
18
torch.Size([79, 81, 320])
19
torch.Size([79, 81, 256])
20
torch.Size([79, 81, 704])
21
torch.Size([79, 81, 1024])
22
torch.Size([79, 81, 1024])
23
torch.Size([79, 81, 256])


 45%|████████████████████████████████████████████████████████████████████▏                                                                                 | 10/22 [00:54<00:58,  4.92s/it]

0
torch.Size([64, 66, 256])
1
torch.Size([64, 66, 256])
2
torch.Size([64, 66, 256])
3
torch.Size([64, 66, 256])
4
torch.Size([64, 66, 256])
5
torch.Size([64, 66, 256])
6
torch.Size([64, 66, 256])
7
torch.Size([64, 66, 256])
8
torch.Size([64, 66, 256])
9
torch.Size([64, 66, 256])
10
torch.Size([64, 66, 256])
11
torch.Size([64, 66, 256])
12
torch.Size([64, 66, 256])
13
torch.Size([64, 66, 256])
14
torch.Size([64, 66, 320])
15
torch.Size([64, 66, 256])
16
torch.Size([64, 66, 256])
17
torch.Size([64, 66, 256])
18
torch.Size([64, 66, 320])
19
torch.Size([64, 66, 256])
20
torch.Size([64, 66, 704])
21
torch.Size([64, 66, 1024])
22
torch.Size([64, 66, 1024])
23
torch.Size([64, 66, 256])


 50%|███████████████████████████████████████████████████████████████████████████                                                                           | 11/22 [00:57<00:50,  4.59s/it]

0
torch.Size([53, 55, 256])
1
torch.Size([53, 55, 256])
2
torch.Size([53, 55, 256])
3
torch.Size([53, 55, 256])
4
torch.Size([53, 55, 256])
5
torch.Size([53, 55, 256])
6
torch.Size([53, 55, 256])
7
torch.Size([53, 55, 256])
8
torch.Size([53, 55, 256])
9
torch.Size([53, 55, 256])
10
torch.Size([53, 55, 256])
11
torch.Size([53, 55, 256])
12
torch.Size([53, 55, 256])
13
torch.Size([53, 55, 256])
14
torch.Size([53, 55, 320])
15
torch.Size([53, 55, 256])
16
torch.Size([53, 55, 256])
17
torch.Size([53, 55, 256])
18
torch.Size([53, 55, 320])
19
torch.Size([53, 55, 256])
20
torch.Size([53, 55, 704])
21
torch.Size([53, 55, 1024])
22
torch.Size([53, 55, 1024])
23
torch.Size([53, 55, 256])


 55%|█████████████████████████████████████████████████████████████████████████████████▊                                                                    | 12/22 [01:00<00:40,  4.01s/it]

0
torch.Size([52, 54, 256])
1
torch.Size([52, 54, 256])
2
torch.Size([52, 54, 256])
3
torch.Size([52, 54, 256])
4
torch.Size([52, 54, 256])
5
torch.Size([52, 54, 256])
6
torch.Size([52, 54, 256])
7
torch.Size([52, 54, 256])
8
torch.Size([52, 54, 256])
9
torch.Size([52, 54, 256])
10
torch.Size([52, 54, 256])
11
torch.Size([52, 54, 256])
12
torch.Size([52, 54, 256])
13
torch.Size([52, 54, 256])
14
torch.Size([52, 54, 320])
15
torch.Size([52, 54, 256])
16
torch.Size([52, 54, 256])
17
torch.Size([52, 54, 256])
18
torch.Size([52, 54, 320])
19
torch.Size([52, 54, 256])
20
torch.Size([52, 54, 704])
21
torch.Size([52, 54, 1024])
22
torch.Size([52, 54, 1024])
23
torch.Size([52, 54, 256])


 59%|████████████████████████████████████████████████████████████████████████████████████████▋                                                             | 13/22 [01:03<00:32,  3.58s/it]

0
torch.Size([91, 93, 256])
1
torch.Size([91, 93, 256])
2
torch.Size([91, 93, 256])
3
torch.Size([91, 93, 256])
4
torch.Size([91, 93, 256])
5
torch.Size([91, 93, 256])
6
torch.Size([91, 93, 256])
7
torch.Size([91, 93, 256])
8
torch.Size([91, 93, 256])
9
torch.Size([91, 93, 256])
10
torch.Size([91, 93, 256])
11
torch.Size([91, 93, 256])
12
torch.Size([91, 93, 256])
13
torch.Size([91, 93, 256])
14
torch.Size([91, 93, 320])
15
torch.Size([91, 93, 256])
16
torch.Size([91, 93, 256])
17
torch.Size([91, 93, 256])
18
torch.Size([91, 93, 320])
19
torch.Size([91, 93, 256])
20
torch.Size([91, 93, 704])
21
torch.Size([91, 93, 1024])
22
torch.Size([91, 93, 1024])
23
torch.Size([91, 93, 256])


 64%|███████████████████████████████████████████████████████████████████████████████████████████████▍                                                      | 14/22 [01:11<00:39,  4.88s/it]

0
torch.Size([40, 42, 256])
1
torch.Size([40, 42, 256])
2
torch.Size([40, 42, 256])
3
torch.Size([40, 42, 256])
4
torch.Size([40, 42, 256])
5
torch.Size([40, 42, 256])
6
torch.Size([40, 42, 256])
7
torch.Size([40, 42, 256])
8
torch.Size([40, 42, 256])
9
torch.Size([40, 42, 256])
10
torch.Size([40, 42, 256])
11
torch.Size([40, 42, 256])
12
torch.Size([40, 42, 256])
13
torch.Size([40, 42, 256])
14
torch.Size([40, 42, 320])
15
torch.Size([40, 42, 256])
16
torch.Size([40, 42, 256])
17
torch.Size([40, 42, 256])
18
torch.Size([40, 42, 320])
19
torch.Size([40, 42, 256])
20
torch.Size([40, 42, 704])
21
torch.Size([40, 42, 1024])
22
torch.Size([40, 42, 1024])
23
torch.Size([40, 42, 256])


 68%|██████████████████████████████████████████████████████████████████████████████████████████████████████▎                                               | 15/22 [01:12<00:27,  3.92s/it]

0
torch.Size([46, 48, 256])
1
torch.Size([46, 48, 256])
2
torch.Size([46, 48, 256])
3
torch.Size([46, 48, 256])
4
torch.Size([46, 48, 256])
5
torch.Size([46, 48, 256])
6
torch.Size([46, 48, 256])
7
torch.Size([46, 48, 256])
8
torch.Size([46, 48, 256])
9
torch.Size([46, 48, 256])
10
torch.Size([46, 48, 256])
11
torch.Size([46, 48, 256])
12
torch.Size([46, 48, 256])
13
torch.Size([46, 48, 256])
14
torch.Size([46, 48, 320])
15
torch.Size([46, 48, 256])
16
torch.Size([46, 48, 256])
17
torch.Size([46, 48, 256])
18
torch.Size([46, 48, 320])
19
torch.Size([46, 48, 256])
20
torch.Size([46, 48, 704])
21
torch.Size([46, 48, 1024])
22
torch.Size([46, 48, 1024])
23
torch.Size([46, 48, 256])


 73%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████                                         | 16/22 [01:14<00:20,  3.36s/it]

0
torch.Size([84, 86, 256])
1
torch.Size([84, 86, 256])
2
torch.Size([84, 86, 256])
3
torch.Size([84, 86, 256])
4
torch.Size([84, 86, 256])
5
torch.Size([84, 86, 256])
6
torch.Size([84, 86, 256])
7
torch.Size([84, 86, 256])
8
torch.Size([84, 86, 256])
9
torch.Size([84, 86, 256])
10
torch.Size([84, 86, 256])
11
torch.Size([84, 86, 256])
12
torch.Size([84, 86, 256])
13
torch.Size([84, 86, 256])
14
torch.Size([84, 86, 320])
15
torch.Size([84, 86, 256])
16
torch.Size([84, 86, 256])
17
torch.Size([84, 86, 256])
18
torch.Size([84, 86, 320])
19
torch.Size([84, 86, 256])
20
torch.Size([84, 86, 704])
21
torch.Size([84, 86, 1024])
22
torch.Size([84, 86, 1024])
23
torch.Size([84, 86, 256])


 77%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                  | 17/22 [01:21<00:21,  4.34s/it]

0
torch.Size([77, 79, 256])
1
torch.Size([77, 79, 256])
2
torch.Size([77, 79, 256])
3
torch.Size([77, 79, 256])
4
torch.Size([77, 79, 256])
5
torch.Size([77, 79, 256])
6
torch.Size([77, 79, 256])
7
torch.Size([77, 79, 256])
8
torch.Size([77, 79, 256])
9
torch.Size([77, 79, 256])
10
torch.Size([77, 79, 256])
11
torch.Size([77, 79, 256])
12
torch.Size([77, 79, 256])
13
torch.Size([77, 79, 256])
14
torch.Size([77, 79, 320])
15
torch.Size([77, 79, 256])
16
torch.Size([77, 79, 256])
17
torch.Size([77, 79, 256])
18
torch.Size([77, 79, 320])
19
torch.Size([77, 79, 256])
20
torch.Size([77, 79, 704])
21
torch.Size([77, 79, 1024])
22
torch.Size([77, 79, 1024])
23
torch.Size([77, 79, 256])


 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                           | 18/22 [01:26<00:18,  4.68s/it]

0
torch.Size([97, 99, 256])
1
torch.Size([97, 99, 256])
2
torch.Size([97, 99, 256])
3
torch.Size([97, 99, 256])
4
torch.Size([97, 99, 256])
5
torch.Size([97, 99, 256])
6
torch.Size([97, 99, 256])
7
torch.Size([97, 99, 256])
8
torch.Size([97, 99, 256])
9
torch.Size([97, 99, 256])
10
torch.Size([97, 99, 256])
11
torch.Size([97, 99, 256])
12
torch.Size([97, 99, 256])
13
torch.Size([97, 99, 256])
14
torch.Size([97, 99, 320])
15
torch.Size([97, 99, 256])
16
torch.Size([97, 99, 256])
17
torch.Size([97, 99, 256])
18
torch.Size([97, 99, 320])
19
torch.Size([97, 99, 256])
20
torch.Size([97, 99, 704])
21
torch.Size([97, 99, 1024])
22
torch.Size([97, 99, 1024])
23
torch.Size([97, 99, 256])


 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                    | 19/22 [01:35<00:17,  5.94s/it]

0
torch.Size([96, 98, 256])
1
torch.Size([96, 98, 256])
2
torch.Size([96, 98, 256])
3
torch.Size([96, 98, 256])
4
torch.Size([96, 98, 256])
5
torch.Size([96, 98, 256])
6
torch.Size([96, 98, 256])
7
torch.Size([96, 98, 256])
8
torch.Size([96, 98, 256])
9
torch.Size([96, 98, 256])
10
torch.Size([96, 98, 256])
11
torch.Size([96, 98, 256])
12
torch.Size([96, 98, 256])
13
torch.Size([96, 98, 256])
14
torch.Size([96, 98, 320])
15
torch.Size([96, 98, 256])
16
torch.Size([96, 98, 256])
17
torch.Size([96, 98, 256])
18
torch.Size([96, 98, 320])
19
torch.Size([96, 98, 256])
20
torch.Size([96, 98, 704])
21
torch.Size([96, 98, 1024])
22
torch.Size([96, 98, 1024])
23
torch.Size([96, 98, 256])


 91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 20/22 [01:44<00:13,  6.78s/it]

0
torch.Size([123, 125, 256])
1
torch.Size([123, 125, 256])
2
torch.Size([123, 125, 256])
3
torch.Size([123, 125, 256])
4
torch.Size([123, 125, 256])
5
torch.Size([123, 125, 256])
6
torch.Size([123, 125, 256])
7
torch.Size([123, 125, 256])
8
torch.Size([123, 125, 256])
9
torch.Size([123, 125, 256])
10
torch.Size([123, 125, 256])
11
torch.Size([123, 125, 256])
12
torch.Size([123, 125, 256])
13
torch.Size([123, 125, 256])
14
torch.Size([123, 125, 320])
15
torch.Size([123, 125, 256])
16
torch.Size([123, 125, 256])
17
torch.Size([123, 125, 256])
18
torch.Size([123, 125, 320])
19
torch.Size([123, 125, 256])
20
torch.Size([123, 125, 704])
21
torch.Size([123, 125, 1024])
22
torch.Size([123, 125, 1024])
23
torch.Size([123, 125, 256])


 95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 21/22 [02:11<00:12, 12.72s/it]

0
torch.Size([116, 118, 256])
1
torch.Size([116, 118, 256])
2
torch.Size([116, 118, 256])
3
torch.Size([116, 118, 256])
4
torch.Size([116, 118, 256])
5
torch.Size([116, 118, 256])
6
torch.Size([116, 118, 256])
7
torch.Size([116, 118, 256])
8
torch.Size([116, 118, 256])
9
torch.Size([116, 118, 256])
10
torch.Size([116, 118, 256])
11
torch.Size([116, 118, 256])
12
torch.Size([116, 118, 256])
13
torch.Size([116, 118, 256])
14
torch.Size([116, 118, 320])
15
torch.Size([116, 118, 256])
16
torch.Size([116, 118, 256])
17
torch.Size([116, 118, 256])
18
torch.Size([116, 118, 320])
19
torch.Size([116, 118, 256])
20
torch.Size([116, 118, 704])
21
torch.Size([116, 118, 1024])
22
torch.Size([116, 118, 1024])
23
torch.Size([116, 118, 256])


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [02:23<00:00,  6.54s/it]

27434652.018668972





In [None]:
# Number of repeated generate() calls timed by the latency benchmark loop.
num_trials = 20

In [None]:
# Latency benchmark: time `num_trials` sampled generations from the pruned model
# on a fixed prompt and collect the per-call durations in `pruned_times`.
# NOTE(review): sampling is not seeded here, so generated text (and therefore
# the measured durations) will vary from run to run.
line = "Hello, I am a social bot! "
inputs = tokenizer(line, return_tensors="pt")
pruned_times = []  # elapsed seconds, one entry per trial

for i in range(num_trials):
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump if the system clock
    # is adjusted mid-benchmark.
    start = time.perf_counter()
    outputs = pruned_model.generate(
        input_ids=inputs["input_ids"],
        max_new_tokens=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    end = time.perf_counter()
    elapsed = end - start
    pruned_times.append(elapsed)
    print(f"inference time: {elapsed}")

# Only the final trial's output is decoded; earlier samples are overwritten.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [None]:
# Summary statistics (in seconds) over the durations recorded in `pruned_times`.
mean_pruned = np.mean(pruned_times)
std_pruned = np.std(pruned_times)  # np.std default (ddof=0): population standard deviation
median_pruned = statistics.median(pruned_times)
# Legacy name kept so any code that still reads `mode` keeps working; the value
# is the median of the timings, not their statistical mode.
mode = median_pruned
print(f"mean:   {mean_pruned}")
print(f"std:    {std_pruned}")
print(f"median: {median_pruned}")

In [None]:
# Disabled interactive chat loop: prompts with input("You:"), samples up to 20
# new tokens from pruned_model, and prints the decoded text.
# NOTE(review): the loop has no exit condition, so uncommenting it will block a
# top-to-bottom run of the notebook until the kernel is interrupted.
# while True:
#     line = input("You:")
#     inputs = tokenizer(line, return_tensors="pt")
#     outputs = pruned_model.generate(
#         input_ids=inputs["input_ids"], 
#         max_new_tokens=20, 
#         do_sample=True, 
#         top_k=50, 
#         top_p=0.95,
#     )
#     print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In [None]:
# Disabled single-example ROUGE check: generate a reply to prompt_line with
# pruned_model, strip the prompt from the decoded text, and score the remainder
# against the reference completion using rouge1 and rougeL.
# full_line must start with prompt_line verbatim; otherwise split() finds no
# match and `completion` ends up containing the prompt as well.
# full_line = "Person: My favorite movie is The Day After Tomorrow\nSocialBot: Oh, interesting, I am not familiar with that movie! Can you tell me more about it?"
# prompt_line = "Person: My favorite movie is The Day After Tomorrow\nSocialBot: "
# completion = full_line.split(prompt_line)[-1]
# inputs = tokenizer(prompt_line, return_tensors="pt")

# #for i in range(num_trials):
# outputs = pruned_model.generate(
#     input_ids=inputs["input_ids"], 
#     max_new_tokens=25, 
#     do_sample=True, 
#     top_k=50, 
#     top_p=0.95,
# )
# out_seq = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# out_seq = out_seq[0].split("Person: My favorite movie is The Day After Tomorrow\nSocialBot: ")[-1]
# r_scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
# rouge_scores = r_scorer.score(completion, out_seq)

In [None]:
# from transformers import AutoTokenizer, BloomForCausalLM

In [None]:
# Disabled: loads the bigscience/bloom-560m checkpoint as `model`. Requires the
# BloomForCausalLM import, which is also commented out in this notebook.
# model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

In [None]:
# Disabled: times a single sampled generation (up to 25 new tokens) from `model`
# on the movie prompt, then prints the decoded text and the elapsed seconds.
# `model` is only defined if the commented-out from_pretrained cell is restored.
# inputs = tokenizer("Person: My favorite movie is The Day After Tomorrow\nSocialBot: ", return_tensors="pt")
# start = time.time()
# outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=25, do_sample=True, top_k=50, top_p=0.95)
# end = time.time()
# print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# print(f"inference time: {end - start}")