In [None]:
import torch
import torch.nn as nn
import math

class VariableSizeDecoderLayer(nn.Module):
    """Self-attention + feed-forward layer whose output width may differ from its input width.

    Attention runs at ``d_model_in``. The feed-forward sub-layer projects to
    ``d_model_out``; its residual branch passes through ``resize``, which is a
    linear projection when the two widths differ and an identity otherwise.
    Both sub-layers are post-normalised (residual add, then LayerNorm).

    Args:
        d_model_in: Feature width of the incoming tensor.
        d_model_out: Feature width of the returned tensor.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside attention and on both
            residual branches.
    """

    def __init__(self, d_model_in, d_model_out, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        widths_match = d_model_in == d_model_out

        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model_in,
            num_heads=nhead,
            dropout=dropout,
        )

        # Expand to the hidden width, apply ReLU, then project to the output width.
        ff_stages = [
            nn.Linear(d_model_in, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model_out),
        ]
        self.feed_forward = nn.Sequential(*ff_stages)

        self.norm1 = nn.LayerNorm(d_model_in)
        self.norm2 = nn.LayerNorm(d_model_out)
        self.dropout = nn.Dropout(p=dropout)

        # The residual around the feed-forward block must match d_model_out.
        if widths_match:
            self.resize = nn.Identity()
        else:
            self.resize = nn.Linear(d_model_in, d_model_out)

    def forward(self, x, mask=None):
        """Run attention and feed-forward sub-layers on ``x``.

        Args:
            x: Input tensor with ``d_model_in`` features in its last dimension.
            mask: Optional attention mask, forwarded as ``attn_mask``.

        Returns:
            Tensor shaped like ``x`` but with ``d_model_out`` features.
        """
        # Sub-layer 1: self-attention, residual add, LayerNorm.
        attended = self.self_attn(x, x, x, attn_mask=mask)[0]
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward, (possibly resized) residual add, LayerNorm.
        projected = self.feed_forward(hidden)
        shortcut = self.resize(hidden)
        return self.norm2(shortcut + self.dropout(projected))

class VariableSizeTransformerDecoder(nn.Module):
    """Stack of ``VariableSizeDecoderLayer`` modules with per-layer widths.

    Layer ``i`` maps ``d_model_list[i]`` features to ``d_model_list[i + 1]``
    features, so ``d_model_list`` needs at least ``num_layers + 1`` entries
    while ``nhead_list`` and ``dim_feedforward_list`` need at least
    ``num_layers`` entries each.

    Args:
        num_layers: Number of decoder layers to build.
        d_model_list: Feature widths at each layer boundary.
        nhead_list: Attention head count for each layer.
        dim_feedforward_list: Feed-forward hidden width for each layer.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, num_layers, d_model_list, nhead_list, dim_feedforward_list, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            VariableSizeDecoderLayer(d_model_list[i], d_model_list[i+1], nhead_list[i], dim_feedforward_list[i], dropout)
            for i in range(num_layers)
        ])
        # Size the final norm from the width the last *constructed* layer emits.
        # Indexing with [-1] only agrees with that when d_model_list has exactly
        # num_layers + 1 entries; a longer list would otherwise cause a shape
        # mismatch in forward().
        self.norm = nn.LayerNorm(d_model_list[num_layers])

    def forward(self, x, mask=None):
        """Apply every layer in order (sharing ``mask``), then the final LayerNorm."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    The table is stored in the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature indices hold sines and odd
    feature indices hold cosines of ``position * div_term``.

    Args:
        d_model: Feature width of the tensors to be encoded (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angle table to fit instead of failing
        # with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Insert a singleton axis at dim 1 so the table broadcasts across it.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dim 0 indexes positions and whose last dim has
                ``d_model`` features.

        Raises:
            ValueError: If ``x`` has more positions than were precomputed.
        """
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque broadcast
            # error (or silently broadcasts when max_len == 1).
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_len}"
            )
        return x + self.pe[:seq_len, :]

class HierarchicalTransformerDecoder(nn.Module):
    """Token-level model: embedding -> positional encoding -> variable-width decoder -> vocabulary logits.

    Args:
        num_layers: Number of decoder layers.
        d_model_list: Feature widths at each layer boundary; entry 0 is the
            embedding width and entry ``num_layers`` is the width fed to the
            output projection.
        nhead_list: Attention head count for each layer.
        dim_feedforward_list: Feed-forward hidden width for each layer.
        vocab_size: Size of both the input vocabulary and the output logits.
        dropout: Dropout probability passed to the decoder layers.
    """

    def __init__(self, num_layers, d_model_list, nhead_list, dim_feedforward_list, vocab_size, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model_list[0])
        self.pos_encoder = PositionalEncoding(d_model_list[0])
        self.decoder = VariableSizeTransformerDecoder(num_layers, d_model_list, nhead_list, dim_feedforward_list, dropout)
        # The last constructed layer emits d_model_list[num_layers] features;
        # this equals d_model_list[-1] when the list has num_layers + 1 entries.
        self.output_layer = nn.Linear(d_model_list[num_layers], vocab_size)

    def forward(self, x, mask=None):
        """Map integer token ids to per-position vocabulary logits.

        Args:
            x: Integer token ids accepted by ``nn.Embedding``.
            mask: Optional attention mask forwarded to every decoder layer.

        Returns:
            Tensor of logits whose last dimension is ``vocab_size``.
        """
        # Scale by sqrt(embedding width). Read the width from the module
        # itself rather than from a notebook-level global, so the model works
        # outside the cell that happened to define `d_model_list`.
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoder(x)
        x = self.decoder(x, mask)
        return self.output_layer(x)

# Example usage
# Six boundary widths describe five layers: layer i maps
# d_model_list[i] -> d_model_list[i + 1].
d_model_list = [256, 256, 512, 512, 1024, 1024]
nhead_list = [4, 4, 8, 8, 16]
dim_feedforward_list = [1024, 1024, 2048, 2048, 4096]
num_layers = 5
vocab_size = 10000

model = HierarchicalTransformerDecoder(
    num_layers=num_layers,
    d_model_list=d_model_list,
    nhead_list=nhead_list,
    dim_feedforward_list=dim_feedforward_list,
    vocab_size=vocab_size,
)

# Random token ids laid out as (sequence_length, batch_size)
src = torch.randint(low=0, high=vocab_size, size=(10, 32))
output = model(src)
print(output.shape)  # Expected: (10, 32, vocab_size)


In [None]:
import matplotlib.pyplot as plt
import re

# Raw log lines to analyse; each non-empty line contains one bracketed, comma-separated list of integers
data_lines = """00:19:59.466 - B 0 [93869, 111155, 33845, 94844, 35294, 81963, 68236, 101974, 10249, 58798, 5601, 39995, 67901, 53486, 70831, 45838, 95554, 110077, 9479, 68059, 101752, 19193, 91254, 37498, 30459, 114206, 17742, 22699, 47439, 19184, 102967, 65685, 37040, 115353, 51420, 3628, 2112, 8964, 57148, 108897, 70224, 46612, 70249, 55734] 2048
00:20:00.638 - B 1 [80680, 115539, 106120, 14552, 16070, 48386, 34087, 34825, 18586, 109237, 93044, 82142, 14525, 17536, 32220, 14813, 95425, 2923, 62688, 7760, 23321, 70265, 64270, 102321, 98539, 16786, 10769, 37950, 78306, 73059, 89658, 12549, 88344, 62321, 104313, 76086, 67089, 38206, 64276, 71457, 43008, 62313, 17330, 53201] 2048
00:20:01.798 - B 2 [45841, 102715, 61874, 55308, 51271, 13320, 35674, 104679, 114054, 96819, 85397, 14223, 80353, 1156, 83366, 112357, 42405, 61880, 86966, 44319, 89807, 73091, 29859, 37418, 30299, 32881, 112793, 73756, 18039, 28953, 105930, 107739, 88810, 102794, 19138, 23834, 18576, 58419, 51415, 84426, 2728, 69124, 85866, 63341] 2048
00:20:02.926 - B 3 [116806, 41479, 24884, 10875, 42240, 34252, 77267, 35387, 30088, 59563, 113214, 106575, 70312, 108889, 76429, 92022, 9672, 103831, 24396, 85532, 55743, 62677, 50014, 80380, 108150, 103094, 68255, 50664, 23558, 94749, 11365, 70679, 34914, 35154, 42047, 112822, 67406, 116470, 77115, 12185, 4069, 101722, 68947, 87139] 2048
00:20:04.217 - B 4 [95898, 38138, 68885, 31825, 19451, 5743, 98018, 26784, 48854, 10686, 42009, 37318, 31716, 82174, 55389, 6573, 2386, 19488, 33669, 80640, 36277, 65704, 80618, 101799, 80899, 70882, 96551, 81362, 80772, 49419, 51908, 104980, 94456, 36891, 98525, 80888, 90171, 77611, 12335, 47514, 29513, 43005, 116325, 13486] 2048
00:20:05.375 - B 5 [70522, 38053, 109390, 83408, 60972, 52328, 60651, 86987, 23431, 72376, 19535, 103394, 53837, 82957, 91789, 98816, 62691, 28755, 44793, 75075, 46100, 93933, 34815, 47953, 111272, 72211, 43479, 19986, 27363, 87985, 70381, 77753, 16069, 108030, 115040, 70841, 24939, 12631, 104415, 69174, 7949, 43089, 75570, 62870] 2048
00:20:06.527 - B 6 [53513, 103343, 53611, 76179, 19317, 109217, 54260, 31593, 67428, 35711, 53782, 35975, 52931, 17886, 99528, 52824, 69694, 62933, 67384, 22357, 82248, 55828, 94517, 49107, 72302, 61418, 101220, 101702, 57984, 2747, 77267, 38422, 72698, 5814, 24848, 116662, 107212, 64734, 111823, 49017, 1799, 21874, 45981, 11022] 2048
00:20:07.675 - B 7 [30188, 3182, 49213, 77815, 88045, 105777, 71056, 112764, 51668, 79586, 1180, 99987, 17242, 28563, 74711, 8277, 91262, 46903, 111960, 4346, 68693, 102474, 14740, 92309, 43602, 34438, 108982, 29195, 89998, 6052, 11746, 75814, 53743, 41733, 45966, 87913, 47480, 97625, 43173, 62154, 7557, 116162, 71352, 71128] 2048
00:20:08.819 - B 8 [110252, 93976, 61781, 47446, 98306, 26828, 77097, 26273, 54752, 49012, 96863, 81195, 2358, 48499, 50789, 54026, 14424, 29885, 62030, 78219, 85991, 89731, 28363, 34707, 33318, 39916, 97918, 10619, 109857, 77434, 31685, 103229, 6096, 107443, 69709, 10235, 94857, 23636, 43568, 16591, 60522, 96953, 111557, 34954] 2048
00:20:09.979 - B 9 [43178, 10627, 98439, 31285, 47313, 72023, 107751, 98536, 45088, 24285, 71122, 104729, 94761, 34846, 53388, 43672, 110688, 82157, 13951, 89610, 81096, 70518, 61363, 10134, 57271, 52344, 108004, 76192, 39378, 62942, 11019, 91067, 65390, 22616, 103338, 59486, 34629, 65027, 56176, 94226, 42348, 90350, 26799, 100522] 2048
00:20:11.138 - B 10 [77525, 25757, 59503, 74845, 65768, 54779, 17358, 25291, 103660, 15653, 105146, 27385, 74440, 65227, 20435, 90286, 33889, 21606, 109118, 17737, 58892, 71121, 53273, 105287, 89044, 80660, 97999, 89033, 98851, 104988, 81325, 74889, 77211, 57521, 45901, 88966, 101178, 7838, 16619, 103792, 71452, 28826, 98565, 42421] 2048
00:20:12.290 - B 11 [2938, 45095, 75411, 87092, 11182, 26062, 23582, 2834, 10454, 36846, 13830, 42675, 82137, 84688, 115338, 45575, 107111, 49955, 21640, 64684, 70872, 78772, 85233, 63407, 76503, 37288, 13447, 82541, 73280, 11740, 29491, 30233, 22375, 36992, 83826, 111702, 26862, 60395, 96595, 110403, 43251, 116344, 100387, 58035] 2048
00:20:13.418 - B 12 [66180, 10678, 27291, 24271, 44998, 2103, 72937, 74942, 34999, 22605, 6488, 88456, 99989, 41981, 9382, 1149, 49825, 57165, 14152, 46419, 28697, 8157, 37786, 78493, 102560, 104736, 71987, 70171, 110976, 86705, 67706, 110377, 96267, 41211, 75390, 63420, 51276, 35773, 80254, 83443, 4025, 1294, 113658, 33028] 2048
00:20:14.565 - B 13 [7926, 28174, 54671, 56146, 78409, 5156, 24492, 116722, 93593, 37902, 117012, 70969, 2354, 84573, 31662, 13440, 94305, 96937, 36926, 21211, 66369, 85028, 44360, 50419, 108100, 54840, 35186, 22890, 65858, 114447, 114762, 82866, 51799, 2092, 75517, 31194, 102949, 116217, 63285, 85721, 14188, 83872, 97903, 3006] 2048
00:20:15.730 - B 14 [57123, 99104, 10772, 17700, 42469, 35471, 111041, 23047, 39818, 80373, 17084, 11565, 19675, 73023, 82750, 47661, 85581, 94033, 62597, 79650, 56088, 10098, 66943, 105885, 90922, 110630, 17906, 52329, 54471, 77839, 12185, 15815, 78251, 82713, 83525, 33060, 110846, 97263, 16802, 25418, 112943, 70113, 66464, 99760] 2048
00:20:16.879 - B 15 [101502, 68188, 92172, 61937, 87401, 28692, 81037, 25869, 33760, 94238, 95280, 107044, 65510, 47996, 46341, 71313, 96782, 113470, 38226, 103024, 98599, 69293, 10845, 47437, 10248, 94840, 43712, 64501, 64797, 108301, 86401, 77060, 88713, 6935, 23033, 60669, 1761, 28761, 44899, 43350, 70880, 52362, 33714, 26431] 2048
00:20:18.032 - B 16 [42288, 70649, 25295, 59240, 76842, 114710, 40600, 28760, 98616, 69911, 24277, 65260, 109464, 112942, 84061, 60228, 74018, 41595, 112150, 44762, 18756, 111135, 31122, 93880, 32980, 74832, 26027, 15119, 73278, 97551, 90226, 24759, 21513, 27086, 72959, 62192, 46609, 63729, 104800, 104944, 74021, 61739, 78377, 115935] 2048
00:20:19.187 - B 17 [99566, 30609, 27703, 70687, 70652, 100243, 47981, 96977, 34389, 23932, 74821, 52378, 61014, 36142, 16321, 68582, 99293, 95938, 107033, 27871, 7019, 81699, 26754, 77752, 84241, 33918, 89876, 102008, 59777, 104818, 71329, 12800, 21343, 86060, 2827, 63335, 48714, 900, 79375, 25485, 24715, 110401, 30180, 114530] 2048
00:20:20.475 - B 18 [568, 6245, 75947, 108597, 76364, 82170, 56464, 4343, 43586, 62811, 114197, 82279, 96812, 12612, 105507, 20209, 11062, 101621, 12425, 116369, 51794, 31837, 84904, 6123, 74618, 56130, 105778, 48854, 17438, 57822, 68182, 66099, 53611, 77559, 72802, 37614, 115956, 14266, 8555, 31414, 56322, 83061, 27827, 78705] 2048
""".split('\n')


# Gather every integer that appears inside the first [...] group of each log line
bracket_pattern = re.compile(r'\[(.*?)\]')
all_numbers = []

for raw_line in data_lines:
    # Blank lines (e.g. the trailing one produced by split) carry no data
    if not raw_line.strip():
        continue

    found = bracket_pattern.search(raw_line)
    if found is None:
        continue

    # Comma-separated tokens -> integers
    all_numbers.extend(int(token.strip()) for token in found.group(1).split(','))

# Histogram of all collected values
plt.figure(figsize=(12, 6))
plt.hist(
    all_numbers,
    bins=50,
    alpha=0.7,
    color='blue',
    edgecolor='black',
)

# Labels and title
plt.title('Distribution of Numbers')
plt.xlabel('Value')
plt.ylabel('Frequency')

# Fixed x-axis range and a light grid
plt.xlim(0, 120000)
plt.grid(True, alpha=0.3)

# Tighten the layout, then render
plt.tight_layout()
plt.show()



In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets

# Sample curve: sin(x) at 100 evenly spaced points on [0, 10]
x = np.linspace(start=0, stop=10, num=100)
y = np.sin(x)

# Every accepted click is recorded here as an (x, y) tuple
click_coords = []

# Figure with the sine curve, a legend and a grid
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(x, y, 'b-', label='sin(x)')
ax.set_title('Click anywhere on the plot!')
ax.grid(True)
ax.legend()

# Output widget that will show whatever the click handler prints
out = widgets.Output()
display(out)

# Click handler; the decorator routes its printed output into the `out` widget
@out.capture()
def onclick(event):
    """Record a mouse click made inside ``ax`` and mark it on the figure.

    Clicks whose ``event.inaxes`` is not ``ax`` are ignored after the initial
    trace message. Otherwise the data coordinates are appended to
    ``click_coords``, a red marker and a numbered label are drawn at that
    point, a redraw is requested, and the coordinates are printed.
    """
    print('onclicked')

    if event.inaxes == ax:
        click_x = event.xdata
        click_y = event.ydata

        # Remember the click; its 1-based index doubles as the label number
        click_coords.append((click_x, click_y))
        click_number = len(click_coords)

        # Red dot plus a text label anchored at the click position
        ax.plot(click_x, click_y, 'ro', markersize=10)
        ax.text(
            click_x,
            click_y,
            f'Click {click_number}',
            fontsize=10,
            ha='right',
            va='bottom',
        )

        # Schedule a redraw so the new artists become visible
        fig.canvas.draw_idle()

        print(f'Click {click_number}: x={click_x:.2f}, y={click_y:.2f}')

# Register the handler for mouse-button presses; keep the connection id
cid = fig.canvas.mpl_connect('button_press_event', onclick)

plt.show()


In [None]:
import json
# Load the JSON file and show the first two entries of its first two rows
with open('/home/nikola/.ssh/200k_HEAVY_gpt4o-description-gpt4omini-code_generated_problems/ex.json', 'r') as file:
    ex = json.load(file)

for row in (0, 1):
    for col in (0, 1):
        print(ex[row][col])

In [None]:
import torch

# Toy beam-search step: pick the best candidates across all beams at once
batch_size = 1
beam_width = 3
vocab_size = 10

# Random scores standing in for log probabilities, one row per beam
log_probs = torch.randn(batch_size, beam_width, vocab_size)
print("Original shape:", log_probs.shape)  # (1, 3, 10)

# Merge the beam and vocabulary axes so topk ranks every candidate together
flattened = log_probs.reshape(batch_size, -1)
print("Flattened shape:", flattened.shape)  # (1, 30)

# The returned indices address the flattened (beam_width * vocab_size) axis
top_scores, top_tokens = torch.topk(flattened, k=beam_width, dim=-1)

print("Top scores shape:", top_scores.shape)    # (1, 3)
print("Top tokens shape:", top_tokens.shape)    # (1, 3)

# Show the selected values themselves
print("\nExample values:")
print("Top scores:", top_scores)
print("Top tokens:", top_tokens)