In [1]:
low = ["low" for i in range(5)]
lower = ["lower" for i in range(2)]
widest = ["widest" for i in range(3)]
newest = ["newest" for i in range(6)]
text = low + lower + widest + newest

In [2]:
text

['low',
 'low',
 'low',
 'low',
 'low',
 'lower',
 'lower',
 'widest',
 'widest',
 'widest',
 'newest',
 'newest',
 'newest',
 'newest',
 'newest',
 'newest']

In [3]:
from typing import List, Tuple, Dict
import os
from multiprocessing import Pool,cpu_count
import regex as re
from cs336_basics.pretokenization_example import find_chunk_boundaries
import time
import json
from collections import Counter,defaultdict
import heapq

def process_single_chunk(text):
        words_to_count = Counter()

        for t in text:

            words_to_count[t] += 1
        
        words_to_tokens = {}
        pair_to_words = defaultdict(set)
        pair_to_count = Counter()
        for word, count in words_to_count.items():
            word_encoded = word.encode("utf-8")
            l_list = [bytes([b]) for b in word_encoded]
            
            words_to_tokens[word] = l_list
            if len(l_list) >= 2:
                for i in range(len(l_list) - 1):
                    pair = (l_list[i], l_list[i + 1])
                    pair_to_words[pair].add(word)
                    pair_to_count[pair] += count
        return {
            "words_to_count": words_to_count,
            "words_to_tokens": words_to_tokens,
            "pair_to_words": pair_to_words,
            "pair_to_count": pair_to_count
        }
        
def pre_token(text):
    chunk_results = []
    chunk_results.append(process_single_chunk(text))
    words_to_count = Counter()
    pair_to_count = Counter()
    words_to_tokens = {}
    pair_to_words = defaultdict(set)
    for chunk_result in chunk_results:
        words_to_count.update(chunk_result["words_to_count"])
        pair_to_count.update(chunk_result["pair_to_count"])
        words_to_tokens.update(chunk_result["words_to_tokens"])
        for pair, word_set in chunk_result["pair_to_words"].items():
            pair_to_words[pair].update(word_set)

    return {
            "words_to_count": words_to_count,
            "words_to_tokens": words_to_tokens,
            "pair_to_words": pair_to_words,
            "pair_to_count": pair_to_count
        }

In [4]:
ans = pre_token(text)

In [5]:
ans

{'words_to_count': Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2}),
 'words_to_tokens': {'low': [b'l', b'o', b'w'],
  'lower': [b'l', b'o', b'w', b'e', b'r'],
  'widest': [b'w', b'i', b'd', b'e', b's', b't'],
  'newest': [b'n', b'e', b'w', b'e', b's', b't']},
 'pair_to_words': defaultdict(set,
             {(b'l', b'o'): {'low', 'lower'},
              (b'o', b'w'): {'low', 'lower'},
              (b'w', b'e'): {'lower', 'newest'},
              (b'e', b'r'): {'lower'},
              (b'w', b'i'): {'widest'},
              (b'i', b'd'): {'widest'},
              (b'd', b'e'): {'widest'},
              (b'e', b's'): {'newest', 'widest'},
              (b's', b't'): {'newest', 'widest'},
              (b'n', b'e'): {'newest'},
              (b'e', b'w'): {'newest'}}),
 'pair_to_count': Counter({(b'e', b's'): 9,
          (b's', b't'): 9,
          (b'w', b'e'): 8,
          (b'l', b'o'): 7,
          (b'o', b'w'): 7,
          (b'n', b'e'): 6,
          (b'e', b'w'): 6,
       

In [9]:
class ReverseSortPair:
    def __init__(self, pair):
        self.pair = pair
    def __lt__(self, other):
        # 频率相同时，字典序大的返回 True (即它更“小”，会排在堆顶)
        return self.pair > other.pair

def train_bpe(text, vocab_size: int, special_tokens: List[str]) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    pre_token_result_dict = pre_token(text)
    w2c = pre_token_result_dict["words_to_count"]
    w2t = pre_token_result_dict["words_to_tokens"]
    p2w = pre_token_result_dict["pair_to_words"]
    p2c = pre_token_result_dict["pair_to_count"]
    
    # 先放special token（从ID 0开始），然后是256个字节
    vocab = {}
    for i, special_token in enumerate(special_tokens):
        vocab[i] = special_token.encode("utf-8")
    for i in range(256):
        vocab[len(special_tokens) + i] = bytes([i])
    cur_vocab_size = 256 + len(special_tokens)
    heap = [(-count, ReverseSortPair(pair)) for pair, count in p2c.items() if count > 0]
    heapq.heapify(heap)
    merges = []
    while cur_vocab_size < vocab_size:
        if not heap:
            break
        
        neg_count, wrapper = heapq.heappop(heap)
        best_pair = wrapper.pair
        count = - neg_count
        if count != p2c.get(best_pair,0):
            continue
        merges.append(best_pair)
        words_to_update = list(p2w[best_pair])
        new_token = best_pair[0] + best_pair[1]
        vocab[cur_vocab_size] = new_token
        cur_vocab_size += 1
        for word in words_to_update:
            new_tokens = []
            old_tokens = w2t[word]
            i = 0
            while i < len(old_tokens):
                if i < len(old_tokens) -1 and (old_tokens[i], old_tokens[i + 1]) == best_pair:
                    new_tokens.append(new_token)
                    if i > 0:
                        # (old_token[i-1], old_token[i]) 这对pair 的count减去一个word count
                        # p2w里面移除掉(old_token[i-1], old_token[i])对应的word
                        # 生成新pair: (old_token[i], new_token)
                        old_near_pair = (old_tokens[i - 1], old_tokens[i])
                        p2c[old_near_pair] -= w2c[word]
                        new_near_pair = (old_tokens[i -1], new_token)
                        p2w[new_near_pair].add(word)
                        if word in p2w[old_near_pair]:
                            p2w[old_near_pair].remove(word)
                        p2c[new_near_pair] += w2c[word]
                        heapq.heappush(heap,(-p2c[new_near_pair], ReverseSortPair(new_near_pair)))
                    if i < len(old_tokens) - 2:
                        if (old_tokens[i + 1], old_tokens[i + 2]) != best_pair:
                            old_near_pair = (old_tokens[i + 1], old_tokens[i + 2])
                            p2c[old_near_pair] -= w2c[word]
                            new_near_pair = (new_token, old_tokens[i + 2])
                            p2w[new_near_pair].add(word)
                            if word in p2w[old_near_pair]:
                                p2w[old_near_pair].remove(word)
                            p2c[new_near_pair] += w2c[word]
                            heapq.heappush(heap,(-p2c[new_near_pair], ReverseSortPair(new_near_pair)))
                    i += 2
                else:
                    new_tokens.append(old_tokens[i])
                    i += 1
            w2t[word] = new_tokens
        del p2w[best_pair]
        p2c[best_pair] = 0
    
    return vocab, merges

In [10]:
special_tokens = ["<|endoftext|>"]
train_bpe(text, 263,special_tokens)

({0: b'<|endoftext|>',
  1: b'\x00',
  2: b'\x01',
  3: b'\x02',
  4: b'\x03',
  5: b'\x04',
  6: b'\x05',
  7: b'\x06',
  8: b'\x07',
  9: b'\x08',
  10: b'\t',
  11: b'\n',
  12: b'\x0b',
  13: b'\x0c',
  14: b'\r',
  15: b'\x0e',
  16: b'\x0f',
  17: b'\x10',
  18: b'\x11',
  19: b'\x12',
  20: b'\x13',
  21: b'\x14',
  22: b'\x15',
  23: b'\x16',
  24: b'\x17',
  25: b'\x18',
  26: b'\x19',
  27: b'\x1a',
  28: b'\x1b',
  29: b'\x1c',
  30: b'\x1d',
  31: b'\x1e',
  32: b'\x1f',
  33: b' ',
  34: b'!',
  35: b'"',
  36: b'#',
  37: b'$',
  38: b'%',
  39: b'&',
  40: b"'",
  41: b'(',
  42: b')',
  43: b'*',
  44: b'+',
  45: b',',
  46: b'-',
  47: b'.',
  48: b'/',
  49: b'0',
  50: b'1',
  51: b'2',
  52: b'3',
  53: b'4',
  54: b'5',
  55: b'6',
  56: b'7',
  57: b'8',
  58: b'9',
  59: b':',
  60: b';',
  61: b'<',
  62: b'=',
  63: b'>',
  64: b'?',
  65: b'@',
  66: b'A',
  67: b'B',
  68: b'C',
  69: b'D',
  70: b'E',
  71: b'F',
  72: b'G',
  73: b'H',
  74: b'I',
  75: b'

In [4]:
import json
from pathlib import Path
vocab_path = "tests/fixtures/gpt2_vocab.json"
vocab_path = Path(vocab_path)


In [5]:
vocab_path.read_text(encoding="utf-8")



In [7]:
content = json.loads(vocab_path.read_text(encoding="utf-8"))

In [8]:
content

{'!': 0,
 '"': 1,
 '#': 2,
 '$': 3,
 '%': 4,
 '&': 5,
 "'": 6,
 '(': 7,
 ')': 8,
 '*': 9,
 '+': 10,
 ',': 11,
 '-': 12,
 '.': 13,
 '/': 14,
 '0': 15,
 '1': 16,
 '2': 17,
 '3': 18,
 '4': 19,
 '5': 20,
 '6': 21,
 '7': 22,
 '8': 23,
 '9': 24,
 ':': 25,
 ';': 26,
 '<': 27,
 '=': 28,
 '>': 29,
 '?': 30,
 '@': 31,
 'A': 32,
 'B': 33,
 'C': 34,
 'D': 35,
 'E': 36,
 'F': 37,
 'G': 38,
 'H': 39,
 'I': 40,
 'J': 41,
 'K': 42,
 'L': 43,
 'M': 44,
 'N': 45,
 'O': 46,
 'P': 47,
 'Q': 48,
 'R': 49,
 'S': 50,
 'T': 51,
 'U': 52,
 'V': 53,
 'W': 54,
 'X': 55,
 'Y': 56,
 'Z': 57,
 '[': 58,
 '\\': 59,
 ']': 60,
 '^': 61,
 '_': 62,
 '`': 63,
 'a': 64,
 'b': 65,
 'c': 66,
 'd': 67,
 'e': 68,
 'f': 69,
 'g': 70,
 'h': 71,
 'i': 72,
 'j': 73,
 'k': 74,
 'l': 75,
 'm': 76,
 'n': 77,
 'o': 78,
 'p': 79,
 'q': 80,
 'r': 81,
 's': 82,
 't': 83,
 'u': 84,
 'v': 85,
 'w': 86,
 'x': 87,
 'y': 88,
 'z': 89,
 '{': 90,
 '|': 91,
 '}': 92,
 '~': 93,
 '¡': 94,
 '¢': 95,
 '£': 96,
 '¤': 97,
 '¥': 98,
 '¦': 99,
 '§': 100

In [9]:
type(content)

dict

In [10]:
raw_vocab: dict[str, int] = json.loads(vocab_path.read_text(encoding="utf-8"))

In [12]:
vocab: dict[int,bytes] = {token_id: token_str.encode("utf-8") for token_str,token_id in raw_vocab.items()
}

In [13]:
vocab

{0: b'!',
 1: b'"',
 2: b'#',
 3: b'$',
 4: b'%',
 5: b'&',
 6: b"'",
 7: b'(',
 8: b')',
 9: b'*',
 10: b'+',
 11: b',',
 12: b'-',
 13: b'.',
 14: b'/',
 15: b'0',
 16: b'1',
 17: b'2',
 18: b'3',
 19: b'4',
 20: b'5',
 21: b'6',
 22: b'7',
 23: b'8',
 24: b'9',
 25: b':',
 26: b';',
 27: b'<',
 28: b'=',
 29: b'>',
 30: b'?',
 31: b'@',
 32: b'A',
 33: b'B',
 34: b'C',
 35: b'D',
 36: b'E',
 37: b'F',
 38: b'G',
 39: b'H',
 40: b'I',
 41: b'J',
 42: b'K',
 43: b'L',
 44: b'M',
 45: b'N',
 46: b'O',
 47: b'P',
 48: b'Q',
 49: b'R',
 50: b'S',
 51: b'T',
 52: b'U',
 53: b'V',
 54: b'W',
 55: b'X',
 56: b'Y',
 57: b'Z',
 58: b'[',
 59: b'\\',
 60: b']',
 61: b'^',
 62: b'_',
 63: b'`',
 64: b'a',
 65: b'b',
 66: b'c',
 67: b'd',
 68: b'e',
 69: b'f',
 70: b'g',
 71: b'h',
 72: b'i',
 73: b'j',
 74: b'k',
 75: b'l',
 76: b'm',
 77: b'n',
 78: b'o',
 79: b'p',
 80: b'q',
 81: b'r',
 82: b's',
 83: b't',
 84: b'u',
 85: b'v',
 86: b'w',
 87: b'x',
 88: b'y',
 89: b'z',
 90: b'{',
 91: b'|

In [14]:
merges_raw_path = "tests/fixtures/gpt2_merges.txt"
merges_path = Path(merges_raw_path)
raw_merges = merges_path.read_text(encoding="utf-8")


In [15]:
raw_merges



In [16]:
raw_merges.strip().split("\n")

['Ġ t',
 'Ġ a',
 'h e',
 'i n',
 'r e',
 'o n',
 'Ġt he',
 'e r',
 'Ġ s',
 'a t',
 'Ġ w',
 'Ġ o',
 'e n',
 'Ġ c',
 'i t',
 'i s',
 'a n',
 'o r',
 'e s',
 'Ġ b',
 'e d',
 'Ġ f',
 'in g',
 'Ġ p',
 'o u',
 'Ġa n',
 'a l',
 'a r',
 'Ġt o',
 'Ġ m',
 'Ġo f',
 'Ġ in',
 'Ġ d',
 'Ġ h',
 'Ġan d',
 'i c',
 'a s',
 'l e',
 'Ġt h',
 'i on',
 'o m',
 'l l',
 'en t',
 'Ġ n',
 'Ġ l',
 's t',
 'Ġ re',
 'v e',
 'Ġ e',
 'r o',
 'l y',
 'Ġb e',
 'Ġ g',
 'Ġ T',
 'c t',
 'Ġ S',
 'i d',
 'o t',
 'Ġ I',
 'u t',
 'e t',
 'Ġ A',
 'Ġ is',
 'Ġ on',
 'i m',
 'a m',
 'o w',
 'a y',
 'a d',
 's e',
 'Ġth at',
 'Ġ C',
 'i g',
 'Ġf or',
 'a c',
 'Ġ y',
 'v er',
 'u r',
 'Ġ u',
 'l d',
 'Ġs t',
 'Ġ M',
 "' s",
 'Ġ he',
 'Ġ it',
 'at ion',
 'it h',
 'i r',
 'c e',
 'Ġy ou',
 'i l',
 'Ġ B',
 'Ġw h',
 'o l',
 'Ġ P',
 'Ġw ith',
 'Ġ 1',
 't er',
 'c h',
 'Ġa s',
 'Ġw e',
 'Ġ (',
 'n d',
 'i ll',
 'Ġ D',
 'i f',
 'Ġ 2',
 'a g',
 'er s',
 'k e',
 'Ġ "',
 'Ġ H',
 'e m',
 'Ġc on',
 'Ġ W',
 'Ġ R',
 'he r',
 'Ġw as',
 'Ġ r',
 'o

In [20]:
merges: list[tuple[bytes,bytes]] = [
    (token1.encode("utf-8"), token2.encode("utf-8")) for line in raw_merges.strip().split('\n') for token1, token2 in [line.split()]
]

In [21]:
merges

[(b'\xc4\xa0', b't'),
 (b'\xc4\xa0', b'a'),
 (b'h', b'e'),
 (b'i', b'n'),
 (b'r', b'e'),
 (b'o', b'n'),
 (b'\xc4\xa0t', b'he'),
 (b'e', b'r'),
 (b'\xc4\xa0', b's'),
 (b'a', b't'),
 (b'\xc4\xa0', b'w'),
 (b'\xc4\xa0', b'o'),
 (b'e', b'n'),
 (b'\xc4\xa0', b'c'),
 (b'i', b't'),
 (b'i', b's'),
 (b'a', b'n'),
 (b'o', b'r'),
 (b'e', b's'),
 (b'\xc4\xa0', b'b'),
 (b'e', b'd'),
 (b'\xc4\xa0', b'f'),
 (b'in', b'g'),
 (b'\xc4\xa0', b'p'),
 (b'o', b'u'),
 (b'\xc4\xa0a', b'n'),
 (b'a', b'l'),
 (b'a', b'r'),
 (b'\xc4\xa0t', b'o'),
 (b'\xc4\xa0', b'm'),
 (b'\xc4\xa0o', b'f'),
 (b'\xc4\xa0', b'in'),
 (b'\xc4\xa0', b'd'),
 (b'\xc4\xa0', b'h'),
 (b'\xc4\xa0an', b'd'),
 (b'i', b'c'),
 (b'a', b's'),
 (b'l', b'e'),
 (b'\xc4\xa0t', b'h'),
 (b'i', b'on'),
 (b'o', b'm'),
 (b'l', b'l'),
 (b'en', b't'),
 (b'\xc4\xa0', b'n'),
 (b'\xc4\xa0', b'l'),
 (b's', b't'),
 (b'\xc4\xa0', b're'),
 (b'v', b'e'),
 (b'\xc4\xa0', b'e'),
 (b'r', b'o'),
 (b'l', b'y'),
 (b'\xc4\xa0b', b'e'),
 (b'\xc4\xa0', b'g'),
 (b'\xc4\xa0', b

In [11]:
import regex as re
special_tokens = ["<|endoftext|>"]
text = "Héllò hôw"
special_pattern = '|'.join(re.escape(token) for token in special_tokens)
re.findall(special_pattern,text)

[]

In [13]:
result = re.findall(special_pattern,text)
if len(result) == 0:
    print("fuck")

fuck


In [1]:
text = "shuifrfgtr"
text_encoded = text.encode("utf-8")
text_encoded

b'shuifrfgtr'

In [3]:
text_encoded_bytes = [bytes([b]) for b in text_encoded]
text_encoded_bytes

[b's', b'h', b'u', b'i', b'f', b'r', b'f', b'g', b't', b'r']

In [15]:
text = "Héllò"
unicode = [c.encode("utf-8") for c in text]
unicode

[b'H', b'\xc3\xa9', b'l', b'l', b'\xc3\xb2']

In [20]:
def get_bytes_to_unicode():
      """返回 GPT-2 BPE 使用的 bytes_to_unicode 映射表"""
      bs = list(range(ord("!"), ord("~")+1)) + \
           list(range(ord("¡"), ord("¬")+1)) + \
           list(range(ord("®"), ord("ÿ")+1))
      cs = bs[:]
      n = 0
      for b in range(2**8):
          if b not in bs:
              bs.append(b)
              cs.append(2**8 + n)
              n += 1
      cs = [chr(n) for n in cs]
      return dict(zip(bs, cs))  

bytes_to_unicode = get_bytes_to_unicode()
unicode_to_bytes = {v: k for k, v in bytes_to_unicode.items()}
token_str = "ÃBC"  # "ĠTo"
token_bytes = [bytes([b]) for b in bytes([unicode_to_bytes[c] for c in token_str])]

In [21]:
token_bytes

[b'\xc3', b'B', b'C']

In [None]:
import json
from pathlib import Path
bs = list(range(ord("!"), ord("~")+1)) + \
           list(range(ord("¡"), ord("¬")+1)) + \
           list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
    if b not in bs:
        bs.append(b)
        cs.append(2**8 + n)
        n += 1
cs = [chr(n) for n in cs]
bytes_to_unicode = dict(zip(bs, cs))      # {0: '0', 32: 'Ġ', 195: 'Ã', ...}
unicode_to_bytes = {v: k for k, v in bytes_to_unicode.items()}
vocab_file_path = "tests/fixtures/gpt2_vocab.json"
vocab_path = Path(vocab_file_path)
raw_vocab: dict[str, int] = json.loads(vocab_path.read_text(encoding="utf-8"))
vocab: dict[int, bytes] = {}
for token_str, token_id in raw_vocab.items():
    # 逐个字符逆向映射
    # "Ã" → ["Ã"] → [195] → b'\xc3'
    # "ĠTo" → ["Ġ","T","o"] → [32,84,111] → b' To'
    token_bytes = bytes([unicode_to_bytes[c] for c in token_str])
    vocab[token_id] = token_bytes

print(vocab)



In [14]:
merges_file_path = "tests/fixtures/gpt2_merges.txt"
merges_path = Path(merges_file_path)
raw_merges = merges_path.read_text(encoding="utf-8")
def token_str_to_bytes(token_str):
    return bytes([unicode_to_bytes[c] for c in token_str])
merges: list[tuple[bytes,bytes]] = [
    (token_str_to_bytes(token1), token_str_to_bytes(token2)) for line in raw_merges.strip().split('\n') for token1,token2 in [line.split()]
]

In [15]:
merges

[(b' ', b't'),
 (b' ', b'a'),
 (b'h', b'e'),
 (b'i', b'n'),
 (b'r', b'e'),
 (b'o', b'n'),
 (b' t', b'he'),
 (b'e', b'r'),
 (b' ', b's'),
 (b'a', b't'),
 (b' ', b'w'),
 (b' ', b'o'),
 (b'e', b'n'),
 (b' ', b'c'),
 (b'i', b't'),
 (b'i', b's'),
 (b'a', b'n'),
 (b'o', b'r'),
 (b'e', b's'),
 (b' ', b'b'),
 (b'e', b'd'),
 (b' ', b'f'),
 (b'in', b'g'),
 (b' ', b'p'),
 (b'o', b'u'),
 (b' a', b'n'),
 (b'a', b'l'),
 (b'a', b'r'),
 (b' t', b'o'),
 (b' ', b'm'),
 (b' o', b'f'),
 (b' ', b'in'),
 (b' ', b'd'),
 (b' ', b'h'),
 (b' an', b'd'),
 (b'i', b'c'),
 (b'a', b's'),
 (b'l', b'e'),
 (b' t', b'h'),
 (b'i', b'on'),
 (b'o', b'm'),
 (b'l', b'l'),
 (b'en', b't'),
 (b' ', b'n'),
 (b' ', b'l'),
 (b's', b't'),
 (b' ', b're'),
 (b'v', b'e'),
 (b' ', b'e'),
 (b'r', b'o'),
 (b'l', b'y'),
 (b' b', b'e'),
 (b' ', b'g'),
 (b' ', b'T'),
 (b'c', b't'),
 (b' ', b'S'),
 (b'i', b'd'),
 (b'o', b't'),
 (b' ', b'I'),
 (b'u', b't'),
 (b'e', b't'),
 (b' ', b'A'),
 (b' ', b'is'),
 (b' ', b'on'),
 (b'i', b'm'),
 (b'a', b

In [22]:
import time
from tokenizer import Tokenizer
t = Tokenizer.from_files('tests/fixtures/gpt2_vocab.json', 'tests/fixtures/gpt2_merges.txt')
start = time.time()
for _ in range(100):
    t.encode('Hello world, how are you doing today?')
print(f'100次编码耗时: {time.time()-start:.2f}s')

100次编码耗时: 0.84s


In [1]:
import torch
a = torch.ones(512)
a.dtype

torch.float32

In [2]:
a.shape[-1] == 512

True

In [4]:
import torch
theta = 10000
d_k = 128
freq = 1.0 / (theta ** (torch.arange(0, d_k, 2) / d_k))

In [2]:
freq

tensor([1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
        4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
        1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
        7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
        3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
        1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
        5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
        2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
        1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
        4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
        1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04])

In [3]:
freq.shape

torch.Size([64])

In [5]:
m = torch.arange(128, dtype = torch.float32)
ans = torch.outer(m,freq)

In [6]:
ans.shape

torch.Size([128, 64])

In [7]:
r = torch.repeat_interleave(ans, 2, dim = -1)
r.shape

torch.Size([128, 128])

In [8]:
assert r[:,1] == ans[:,0]

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [9]:
ans.shape

torch.Size([128, 64])

In [18]:
r[:,2]

tensor([  0.0000,   0.8660,   1.7319,   2.5979,   3.4639,   4.3298,   5.1958,
          6.0618,   6.9277,   7.7937,   8.6596,   9.5256,  10.3916,  11.2575,
         12.1235,  12.9895,  13.8554,  14.7214,  15.5874,  16.4533,  17.3193,
         18.1853,  19.0512,  19.9172,  20.7831,  21.6491,  22.5151,  23.3810,
         24.2470,  25.1130,  25.9789,  26.8449,  27.7109,  28.5768,  29.4428,
         30.3088,  31.1747,  32.0407,  32.9066,  33.7726,  34.6386,  35.5045,
         36.3705,  37.2365,  38.1024,  38.9684,  39.8344,  40.7003,  41.5663,
         42.4323,  43.2982,  44.1642,  45.0301,  45.8961,  46.7621,  47.6280,
         48.4940,  49.3600,  50.2259,  51.0919,  51.9579,  52.8238,  53.6898,
         54.5558,  55.4217,  56.2877,  57.1536,  58.0196,  58.8856,  59.7515,
         60.6175,  61.4835,  62.3494,  63.2154,  64.0814,  64.9473,  65.8133,
         66.6793,  67.5452,  68.4112,  69.2771,  70.1431,  71.0091,  71.8750,
         72.7410,  73.6070,  74.4729,  75.3389,  76.2049,  77.07

In [17]:
ans[:,1]

tensor([  0.0000,   0.8660,   1.7319,   2.5979,   3.4639,   4.3298,   5.1958,
          6.0618,   6.9277,   7.7937,   8.6596,   9.5256,  10.3916,  11.2575,
         12.1235,  12.9895,  13.8554,  14.7214,  15.5874,  16.4533,  17.3193,
         18.1853,  19.0512,  19.9172,  20.7831,  21.6491,  22.5151,  23.3810,
         24.2470,  25.1130,  25.9789,  26.8449,  27.7109,  28.5768,  29.4428,
         30.3088,  31.1747,  32.0407,  32.9066,  33.7726,  34.6386,  35.5045,
         36.3705,  37.2365,  38.1024,  38.9684,  39.8344,  40.7003,  41.5663,
         42.4323,  43.2982,  44.1642,  45.0301,  45.8961,  46.7621,  47.6280,
         48.4940,  49.3600,  50.2259,  51.0919,  51.9579,  52.8238,  53.6898,
         54.5558,  55.4217,  56.2877,  57.1536,  58.0196,  58.8856,  59.7515,
         60.6175,  61.4835,  62.3494,  63.2154,  64.0814,  64.9473,  65.8133,
         66.6793,  67.5452,  68.4112,  69.2771,  70.1431,  71.0091,  71.8750,
         72.7410,  73.6070,  74.4729,  75.3389,  76.2049,  77.07

In [12]:
r

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 8.6596e-01,  ..., 1.3335e-04, 1.1548e-04,
         1.1548e-04],
        [2.0000e+00, 2.0000e+00, 1.7319e+00,  ..., 2.6670e-04, 2.3096e-04,
         2.3096e-04],
        ...,
        [1.2500e+02, 1.2500e+02, 1.0825e+02,  ..., 1.6669e-02, 1.4435e-02,
         1.4435e-02],
        [1.2600e+02, 1.2600e+02, 1.0911e+02,  ..., 1.6802e-02, 1.4550e-02,
         1.4550e-02],
        [1.2700e+02, 1.2700e+02, 1.0998e+02,  ..., 1.6936e-02, 1.4666e-02,
         1.4666e-02]])

In [13]:
ans

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.6596e-01, 7.4989e-01,  ..., 1.5399e-04, 1.3335e-04,
         1.1548e-04],
        [2.0000e+00, 1.7319e+00, 1.4998e+00,  ..., 3.0799e-04, 2.6670e-04,
         2.3096e-04],
        ...,
        [1.2500e+02, 1.0825e+02, 9.3737e+01,  ..., 1.9249e-02, 1.6669e-02,
         1.4435e-02],
        [1.2600e+02, 1.0911e+02, 9.4487e+01,  ..., 1.9403e-02, 1.6802e-02,
         1.4550e-02],
        [1.2700e+02, 1.0998e+02, 9.5237e+01,  ..., 1.9557e-02, 1.6936e-02,
         1.4666e-02]])