In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
os.chdir('/content/drive/MyDrive/MANAGER_implementation/code/Text_preprocessing/')

In [None]:
!pip install neo4j

Collecting neo4j
  Downloading neo4j-5.28.1-py3-none-any.whl.metadata (5.9 kB)
Downloading neo4j-5.28.1-py3-none-any.whl (312 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/312.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m312.3/312.3 kB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: neo4j
Successfully installed neo4j-5.28.1


In [None]:
# 라이브러리 import
from neo4j import GraphDatabase
import re

# Neo4j 연결 정보 (너의 환경에 맞게 수정)
from neo4j import GraphDatabase

URI = "neo4j+s://4768c59e.databases.neo4j.io"
AUTH = ("neo4j", "zhUE8ZhblaUK2KQjwgXZ5m7sF6DcDHomIE6zygcYk0M")
# Neo4j 연결 드라이버 생성
driver = GraphDatabase.driver(URI, auth=AUTH)

# Neo4j에서 entity 목록을 한번만 가져와 메모리에 캐싱하는 함수
def load_all_entities_from_neo4j():
    entities_set = set()
    with driver.session() as session:
        query = "MATCH (e:Entity) RETURN DISTINCT e.name as name"
        results = session.run(query)
        for record in results:
            entities_set.add(record["name"].lower())
    return entities_set

# 캐싱된 entity 목록
cached_entities = load_all_entities_from_neo4j()

# 입력 텍스트에서 FinDKG의 entity만 효율적으로 식별하는 함수
def identify_entities_in_text(T, cached_entities):
    identified_entities = set()
    T_lower = T.lower() #  입력 텍스트를 소문자로 변환

    for entity in cached_entities:
        pattern = r'\b' + re.escape(entity.lower()) + r'\b'
        if re.search(pattern, T_lower):
            identified_entities.add(entity)

    return identified_entities

# 식별된 entity를 바탕으로 Neo4j에서 외부지식 N(e)를 추출하는 함수

def extract_external_knowledge(T, cached_entities, start_time=None, end_time=None):
    # time format :
    entities = identify_entities_in_text(T, cached_entities)

    knowledge = {}



    with driver.session() as session:
        for entity in entities:

            query = """
            MATCH (e:Entity {name: $entity})-[r]->(neighbor)
            WHERE 1=1
            """
            if start_time is not None:
                query += " AND r.time > $start_time"

            if end_time is not None:
                query += " AND r.time < $end_time"
            # RETURN 절
            query += """
            RETURN r.relation as relation, r.time as time, neighbor.name as neighbor_entity
            """
            results = session.run(query, entity=entity , start_time = start_time , end_time = end_time)

            knowledge[entity] = []

            for record in results:
                knowledge[entity].append({
                    "relation": record["relation"],
                    "neighbor_entity": record["neighbor_entity"],
                    "time": record["time"]
                })

    return knowledge

In [None]:
T_example = """
    President Trump Administration had an influence on the Volcker rule.
    Wells Fargo Co. also impacted U.S. Federal Reserve policies. + Ne
    """
external_knowledge = extract_external_knowledge(T_example, cached_entities,'2018-03-03','2018-06-03')
print("Extracted Knowledge:", external_knowledge)

Extracted Knowledge: {'president trump administration': [{'relation': 'raise', 'neighbor_entity': 'china', 'time': '2018-05-27'}, {'relation': 'impact', 'neighbor_entity': 'china', 'time': '2018-03-18'}, {'relation': 'impact', 'neighbor_entity': 'china', 'time': '2018-03-25'}, {'relation': 'relate_to', 'neighbor_entity': 'u.s. federal reserve', 'time': '2018-04-08'}, {'relation': 'control', 'neighbor_entity': 'russia', 'time': '2018-04-01'}, {'relation': 'is_member_of', 'neighbor_entity': 'us government', 'time': '2018-03-04'}, {'relation': 'control', 'neighbor_entity': 'us government', 'time': '2018-03-11'}, {'relation': 'operate_in', 'neighbor_entity': 'us government', 'time': '2018-05-13'}, {'relation': 'control', 'neighbor_entity': 'amazon.com inc.', 'time': '2018-04-01'}, {'relation': 'impact', 'neighbor_entity': 'amazon.com inc.', 'time': '2018-04-01'}, {'relation': 'decrease', 'neighbor_entity': 'u.s. dollar', 'time': '2018-03-25'}, {'relation': 'control', 'neighbor_entity': 'ju

In [None]:
def knowledge_to_text(knowledge: dict) -> str:
    """
    anchor entity, relation, neighbor entity를 문자열로 이어붙이는 간단 예시
    (실제로는 anchor별로 따로 임베딩을 만들 수도 있고, triple 단위로 끊을 수도 있음)
    """
    # 예: "Inflation impact Stock Market. Fed affect S&P 500"
    segments = []
    for rel_list in knowledge.values():
        for item in rel_list:
            r = item["relation"]
            nbr = item["neighbor_entity"]
            segment = f"{r} {nbr}"
            segments.append(segment)
    return " ".join(segments)

In [None]:
knowledge_seq = knowledge_to_text(external_knowledge)

In [None]:
knowledge_seq

'raise china impact china impact china relate_to u.s. federal reserve control russia is_member_of us government control us government operate_in us government control amazon.com inc. impact amazon.com inc. decrease u.s. dollar control justice department impact european union control north korea control north korea control republicans introduce supreme court control qualcomm inc. impact the u.s. economy negative_impact_on economy control federal bureau of investigation control federal government relate_to ford motor co. raise tariffs announce tariffs control tariffs announce tariffs introduce tariffs control zte corp. relate_to wilbur ross impact stock market control u.s. companies control sanctions against russia impact markets control u.s. currency impact trade impact farmers control north american free trade agreement control north american free trade agreement control north korean talks has china policy is_member_of north korean leader kim jong un has sanctions impact gold market co

# chatglm

In [1]:
!pip install sentence-transformers==2.2.2
!pip install protobuf transformers==4.30.2 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate

Collecting sentence-transformers==2.2.2
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/86.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6.0->sentence-transformers==2.2.2)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6.0->sentence-transformers==2.2.2)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6.0->sentence-transformers==2.2.2)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
model_name = "THUDM/chatglm2-6b"  # 예시
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
chatglm_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/244 [00:00<?, ?B/s]

tokenization_chatglm.py:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm2-6b:
- tokenization_chatglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer.model:   0%|          | 0.00/1.02M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

configuration_chatglm.py:   0%|          | 0.00/2.33k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm2-6b:
- configuration_chatglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_chatglm.py:   0%|          | 0.00/54.9k [00:00<?, ?B/s]

quantization.py:   0%|          | 0.00/14.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm2-6b:
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm2-6b:
- modeling_chatglm.py
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


pytorch_model.bin.index.json:   0%|          | 0.00/20.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

pytorch_model-00001-of-00007.bin:   0%|          | 0.00/1.83G [00:00<?, ?B/s]

pytorch_model-00002-of-00007.bin:   0%|          | 0.00/1.97G [00:00<?, ?B/s]

pytorch_model-00003-of-00007.bin:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

pytorch_model-00004-of-00007.bin:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

pytorch_model-00005-of-00007.bin:   0%|          | 0.00/1.97G [00:00<?, ?B/s]

pytorch_model-00006-of-00007.bin:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

pytorch_model-00007-of-00007.bin:   0%|          | 0.00/1.05G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

# tokenizer가 너무 잘게 찢음

In [12]:
chatglm_model.get

GLMTransformer(
  (layers): ModuleList(
    (0-27): 28 x GLMBlock(
      (input_layernorm): RMSNorm()
      (self_attention): SelfAttention(
        (query_key_value): Linear(in_features=4096, out_features=4608, bias=True)
        (core_attention): CoreAttention(
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (dense): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (post_attention_layernorm): RMSNorm()
      (mlp): MLP(
        (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
        (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
      )
    )
  )
  (final_layernorm): RMSNorm()
)

In [None]:
text = T_example
encoded = tokenizer(text, return_tensors='pt')
input_ids = encoded["input_ids"][0]   # shape: (seq_len,)

print("input_ids:", input_ids)
# -> tensor([ 101, 7592, 2088,  102])

# ID -> Token 문자열
tokens = tokenizer.convert_ids_to_tokens(input_ids)
print("tokens:", tokens)
# -> ["[CLS]", "hello", "world", "[SEP]"]

input_ids: tensor([64790, 64792, 30910,    13,   296, 24277,  3586,  7745,   599,   284,
         5036,   331,   267,  3967, 27392,  4606, 30930,    13,   296, 30959,
         7021, 30516,  1645, 30930,   629, 18618,   466, 30930, 30937, 30930,
         5468, 12080,  5250, 30930,  1270,  1656,    13,   296])
tokens: ['', '', '▁', '<0x0A>', '▁▁▁▁', 'President', '▁Trump', '▁Administration', '▁had', '▁an', '▁influence', '▁on', '▁the', '▁Vol', 'cker', '▁rule', '.', '<0x0A>', '▁▁▁▁', 'W', 'ells', '▁Fargo', '▁Co', '.', '▁also', '▁impacted', '▁U', '.', 'S', '.', '▁Federal', '▁Reserve', '▁policies', '.', '▁+', '▁Ne', '<0x0A>', '▁▁▁▁']


In [None]:
T_example

'\n    President Trump Administration had an influence on the Volcker rule.\n    Wells Fargo Co. also impacted U.S. Federal Reserve policies. + Ne\n    '

In [None]:
anchors

['president trump administration',
 'wells fargo',
 'administration',
 '.',
 'rule',
 'volcker rule',
 'influence',
 'president',
 'policies',
 'u.s. federal reserve']

In [None]:
chatglm_model.eval()

########################################
# 2) 텍스트 / 지식 예시
########################################


########################################
# 3) 토큰화 + hidden state 추출 함수
#    (decoder-only 모델이지만, 마지막 hidden state를 임시 "encoder 임베딩"처럼 사용)
########################################
@torch.no_grad()
def get_hidden_states(text: str, tokenizer, model, max_length=768):
    """
    text를 subword 토큰화하여, 마지막 레이어 hidden state ([seq_len, hidden_dim])를 얻는다.
    add_special_tokens=False -> [CLS], [SEP] 등은 넣지 않음
    """
    encoded = tokenizer(
        text, return_tensors='pt',
        max_length=max_length,
        truncation=True,
        add_special_tokens=False
    )
    input_ids = encoded["input_ids"]          # shape: [1, seq_len]
    attention_mask = encoded["attention_mask"]
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True
    )
    # 마지막 레이어 hidden_state: [1, seq_len, hidden_dim]
    last_hidden = output.hidden_states[-1]
    last_hidden = last_hidden.squeeze(0)     # [seq_len, hidden_dim]

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return tokens, last_hidden

########################################
# 4) "맨 앞 subword만 엔티티 노드로 표시" 함수
########################################
def mark_first_subword_as_entity(tokens, hidden_states, entity_str):
    """
    - entity_str을 소문자로 변환, tokens도 소문자로 만들어 substring 매칭
    - 매칭되는 subword 인덱스(anchor_indices)를 찾되, 그 중 첫 번째( anchor_indices[0] )만
      "<ENTITY: entity_str>"로 rename
    - 나머지 subword는 그대로 둔다(합치거나 제거 안 함).
    - 반환: (수정된 tokens, hidden_states, first_index or None, entity_node_name or None)
    """
    ent_lower = entity_str.lower()
    tok_lower = [t.lower() for t in tokens]

    ent_parts = ent_lower.split()  # ["volcker", "rule"] etc
    anchor_indices = []

    idx = 0
    while idx < len(tok_lower):
        match_count = 0
        temp_indices = []
        for p_i, part in enumerate(ent_parts):
            if idx + p_i < len(tok_lower) and part in tok_lower[idx + p_i]:
                match_count += 1
                temp_indices.append(idx + p_i)
            else:
                break
        if match_count == len(ent_parts):
            anchor_indices.extend(temp_indices)
            break
        idx += 1

    if not anchor_indices:
        # 못 찾으면 그대로 반환
        return tokens, hidden_states, None, None

    # 대표 subword 인덱스(첫 subword)
    first_idx = sorted(anchor_indices)[0]

    # tokens[first_idx]만 "<ENTITY: entity_str>" 로 교체
    entity_node_name = f"<ENTITY:{entity_str}>"
    tokens[first_idx] = entity_node_name

    # hidden_states는 그대로 둠(서브워드 임베딩 합치지 않음)
    return tokens, hidden_states, first_idx, entity_node_name

########################################
# 5) 실제 실행: 텍스트 & 지식 토큰화
#    "volcker rule", "president" 첫 subword만 엔티티로 표시
########################################
t_tokens, t_hidden = get_hidden_states(T_example, tokenizer, chatglm_model)
k_tokens, k_hidden = get_hidden_states(knowledge_seq, tokenizer, chatglm_model)

anchors = list(external_knowledge.keys())  # e.g. ["volcker rule", "president"]

for anchor_ent in anchors:
    t_tokens, t_hidden, idx0, ent_name = mark_first_subword_as_entity(t_tokens, t_hidden, anchor_ent)
    # knowledge에 anchor_ent가 있다면 동일 작업할 수도 있음

########################################
# 6) 임베딩 H 구성
########################################
text_len = len(t_tokens)
knowledge_len = len(k_tokens)


In [None]:
t_tokens

['▁',
 '<0x0A>',
 '▁▁▁▁',
 '<ENTITY:administration>',
 '▁Trump',
 '▁Administration',
 '▁had',
 '▁an',
 '<ENTITY:influence>',
 '▁on',
 '▁the',
 '▁Vol',
 'cker',
 '<ENTITY:rule>',
 '<ENTITY:.>',
 '<0x0A>',
 '▁▁▁▁',
 'W',
 'ells',
 '▁Fargo',
 '▁Co',
 '.',
 '▁also',
 '▁impacted',
 '▁U',
 '.',
 'S',
 '.',
 '▁Federal',
 '▁Reserve',
 '<ENTITY:policies>',
 '.',
 '▁+',
 '▁Ne',
 '<0x0A>',
 '▁▁▁▁']

In [None]:
chatglm_model.eval()

########################################
# 2) 텍스트 / 지식 예시
########################################


########################################
# 3) 토큰화 + hidden state 추출 함수
#    (decoder-only 모델이지만, 마지막 hidden state를 임시 "encoder 임베딩"처럼 사용)
########################################
@torch.no_grad()
def get_hidden_states(text: str, tokenizer, model, max_length=768):
    """
    text를 subword 토큰화하여, 마지막 레이어 hidden state ([seq_len, hidden_dim])를 얻는다.
    add_special_tokens=False -> [CLS], [SEP] 등은 넣지 않음
    """
    encoded = tokenizer(
        text, return_tensors='pt',
        max_length=max_length,
        truncation=True,
        add_special_tokens=False
    )
    input_ids = encoded["input_ids"]          # shape: [1, seq_len]
    attention_mask = encoded["attention_mask"]
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True
    )
    # 마지막 레이어 hidden_state: [1, seq_len, hidden_dim]
    last_hidden = output.hidden_states[-1]
    last_hidden = last_hidden.squeeze(0)     # [seq_len, hidden_dim]

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return tokens, last_hidden

########################################
# 4) "맨 앞 subword만 엔티티 노드로 표시" 함수
########################################
def mark_first_subword_as_entity(tokens, hidden_states, entity_str):
    """
    - entity_str을 소문자로 변환, tokens도 소문자로 만들어 substring 매칭
    - 매칭되는 subword 인덱스(anchor_indices)를 찾되, 그 중 첫 번째( anchor_indices[0] )만
      "<ENTITY: entity_str>"로 rename
    - 나머지 subword는 그대로 둔다(합치거나 제거 안 함).
    - 반환: (수정된 tokens, hidden_states, first_index or None, entity_node_name or None)
    """
    ent_lower = entity_str.lower()
    tok_lower = [t.lower() for t in tokens]

    ent_parts = ent_lower.split()  # ["volcker", "rule"] etc
    anchor_indices = []

    idx = 0
    while idx < len(tok_lower):
        match_count = 0
        temp_indices = []
        for p_i, part in enumerate(ent_parts):
            if idx + p_i < len(tok_lower) and part in tok_lower[idx + p_i]:
                match_count += 1
                temp_indices.append(idx + p_i)
            else:
                break
        if match_count == len(ent_parts):
            anchor_indices.extend(temp_indices)
            break
        idx += 1

    if not anchor_indices:
        # 못 찾으면 그대로 반환
        return tokens, hidden_states, None, None

    # 대표 subword 인덱스(첫 subword)
    first_idx = sorted(anchor_indices)[0]

    # tokens[first_idx]만 "<ENTITY: entity_str>" 로 교체
    entity_node_name = f"<ENTITY:{entity_str}>"
    tokens[first_idx] = entity_node_name

    # hidden_states는 그대로 둠(서브워드 임베딩 합치지 않음)
    return tokens, hidden_states, first_idx, entity_node_name

########################################
# 5) 실제 실행: 텍스트 & 지식 토큰화
#    "volcker rule", "president" 첫 subword만 엔티티로 표시
########################################
t_tokens, t_hidden = get_hidden_states(T_example, tokenizer, chatglm_model)
k_tokens, k_hidden = get_hidden_states(knowledge_seq, tokenizer, chatglm_model)

anchors = list(external_knowledge.keys())  # e.g. ["volcker rule", "president"]

for anchor_ent in anchors:
    t_tokens, t_hidden, idx0, ent_name = mark_first_subword_as_entity(t_tokens, t_hidden, anchor_ent)
    # knowledge에 anchor_ent가 있다면 동일 작업할 수도 있음

########################################
# 6) 임베딩 H 구성
########################################
text_len = len(t_tokens)
knowledge_len = len(k_tokens)

all_tokens = t_tokens + k_tokens
H = torch.cat([t_hidden, k_hidden], dim=0)  # shape: [N, hidden_dim]
N = text_len + knowledge_len

########################################
# 7) A1 (intra-modal) 인접행렬 구성
#    (텍스트끼리 순차 연결, 지식끼리 순차 연결)
########################################
A1 = torch.zeros(N, N)
# 텍스트 내부
for i in range(text_len - 1):
    A1[i, i+1] = 1.0
    A1[i+1, i] = 1.0
# 지식 내부
for j in range(knowledge_len - 1):
    idx1 = text_len + j
    idx2 = text_len + (j + 1)
    A1[idx1, idx2] = 1.0
    A1[idx2, idx1] = 1.0

########################################
# 8) token-knowledge 엣지 (inter-modal)
#    anchor node <-> knowledge relation/neighbor
########################################
def build_token_knowledge_edges(A, text_tokens, knowledge_tokens, text_len, external_knowledge_dict):
    t_lower = [t.lower() for t in text_tokens]
    k_lower = [t.lower() for t in knowledge_tokens]

    for i, tok in enumerate(text_tokens):
        if tok.startswith("<ENTITY:"):
            # entity name
            entity_str = tok.replace("<ENTITY:", "").replace(">", "").strip().lower()
            if entity_str not in external_knowledge_dict:
                continue

            triple_list = external_knowledge_dict[entity_str]
            for triple in triple_list:
                rel_str = triple["relation"].lower()
                nbr_str = triple["neighbor_entity"].lower()

                # relation subword 매칭
                for k_idx in range(len(k_lower)):
                    if rel_str in k_lower[k_idx]:
                        A[i, text_len + k_idx] = 1
                        A[text_len + k_idx, i] = 1

                # neighbor entity subword 매칭
                nbr_parts = nbr_str.split()
                idx_k = 0
                while idx_k < len(k_lower):
                    match_count = 0
                    temp_indices = []
                    for p_i, part in enumerate(nbr_parts):
                        if idx_k + p_i < len(k_lower) and part in k_lower[idx_k + p_i]:
                            match_count += 1
                            temp_indices.append(idx_k + p_i)
                        else:
                            break
                    if match_count == len(nbr_parts):
                        for matched_k_idx in temp_indices:
                            A[i, text_len + matched_k_idx] = 1
                            A[text_len + matched_k_idx, i] = 1
                        break
                    idx_k += 1
    return A

A1 = build_token_knowledge_edges(A1, t_tokens, k_tokens, text_len, external_knowledge)

########################################
# 9) 결과 출력
########################################
print("=== Final Tokens ===")
for idx, tok in enumerate(all_tokens):
    print(idx, tok)

print("\n=== H shape ===", H.shape)
print("=== A1 shape ===", A1.shape)
print("A1:", A1)



Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 