In [1]:
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from urllib.request import urlopen
import torch.nn as nn
from huggingface_hub import hf_hub_download

import matplotlib.pyplot as plt

# Loading some sources of the projection adapter and image encoder
#hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="models.py", local_dir='./')
#from models import CLIPVisionTower

DEVICE = "cuda:3"

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


model_id = "unsloth/llama-3.2-1b"  # or "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map={"": DEVICE}
)
model.eval()
#new_eos_token = "<|eos|>"

# Set it and add to tokenizer
#tokenizer.eos_token = new_eos_token
#tokenizer.add_special_tokens({'eos_token': new_eos_token})

# Check ID
print("New EOS token ID:", tokenizer.eos_token_id)

special_tokens = {
    "bos_token": tokenizer.bos_token,
    "eos_token": tokenizer.eos_token,
    "pad_token": tokenizer.pad_token,
    "unk_token": tokenizer.unk_token,
    "additional_special_tokens": tokenizer.additional_special_tokens,
}

special_token_ids = {
    name: tokenizer.convert_tokens_to_ids(token)
    for name, token in special_tokens.items()
    if isinstance(token, str) or token is not None
}

print("Special tokens:", special_tokens)
print("Token IDs:", special_token_ids)

unk_id = tokenizer.encode("<unk>", add_special_tokens=False)[0]
#tokenizer.pad_token_id = 2
#tokenizer.eos_token_id = 0

model.resize_token_embeddings(len(tokenizer))
N_EMBEDDINGS = model.model.embed_tokens.weight.shape[0]
print("Number of embeddings in tokenizer:", N_EMBEDDINGS)


[2025-06-17 14:58:08,618] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/jovyan/.mlspace/envs/vika_kurkin_clone/bin/../lib/gcc/x86_64-conda-linux-gnu/12.4.0/../../../../x86_64-conda-linux-gnu/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/jovyan/.mlspace/envs/vika_kurkin_clone/bin/../lib/gcc/x86_64-conda-linux-gnu/12.4.0/../../../../x86_64-conda-linux-gnu/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


New EOS token ID: 128001
Special tokens: {'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>', 'pad_token': '<|finetune_right_pad_id|>', 'unk_token': None, 'additional_special_tokens': []}
Token IDs: {'bos_token': 128000, 'eos_token': 128001, 'pad_token': 128004, 'additional_special_tokens': []}
Number of embeddings in tokenizer: 128256


In [2]:
unk_id

27

In [3]:
from datasets import load_dataset

ds = load_dataset("Sayankotor/small_wikipaper")

Resolving data files:   0%|          | 0/58 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/58 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/53 [00:00<?, ?it/s]

In [4]:
from torch.utils.data import Dataset
import numpy as np
from ast import literal_eval

class PretrainDataset(Dataset):
    def __init__(self, ds):
        # Ensure 'entities' is parsed if stored as string
        self.ds = [
            item for item in ds['train']
            if len(literal_eval(item['entities'])) > 1
        ]

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        try:
            text = item['text'][:2048]
        except Exception:
            print("Bad example (no text):", item)
            text = "Bad example"
        
        ents = literal_eval(item['entities'])
        embs = np.array(literal_eval(item['entity_embs']))[:200]

        return text, ents, embs

In [5]:
import random
from torch.utils.data import Subset, DataLoader
g = torch.Generator()
g.manual_seed(42)

subset_indices = list(range(16000))
random.shuffle(subset_indices)
dataset = PretrainDataset(ds)
# Wrap the dataset with Subset
subset = Subset(dataset, subset_indices)

# Create the dataloader
dataloader = DataLoader(subset, batch_size=1, shuffle=True)

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from transformers.optimization import (Adafactor, AdafactorSchedule,
                                       get_cosine_schedule_with_warmup)

import gc

import tqdm

kg_emb_dim = 200
llama_emb_dim = 2048


kg_start_emb = torch.normal(
    torch.zeros(llama_emb_dim), 
    torch.ones(llama_emb_dim) / llama_emb_dim**0.5
).to(device=DEVICE, dtype=torch.bfloat16)

kg_end_emb = torch.normal(
    torch.zeros(llama_emb_dim), 
    torch.ones(llama_emb_dim) / llama_emb_dim**0.5
).to(device=DEVICE, dtype=torch.bfloat16)

projection = nn.Linear(kg_emb_dim, llama_emb_dim).to(device=DEVICE, dtype=torch.bfloat16)

kg_start_emb.requires_grad_()
kg_end_emb.requires_grad_()
model.requires_grad_(False)
projection.requires_grad_()
    
lr = 5e-3
weight_decay = 1e-3
trainable_parameters = [kg_start_emb] + [kg_end_emb] + list(projection.parameters())

opt = AdamW(trainable_parameters, lr=lr, weight_decay=weight_decay)
loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=unk_id)

grad_accum = 256

loss_best = 1000.0

losses = []
losses_batch = []
iters = 0
n_iters = len(dataloader)
scheduler = get_cosine_schedule_with_warmup(opt, num_warmup_steps=n_iters // grad_accum * 0.01, num_training_steps=n_iters // grad_accum)

for epoch in range(1):
    i = 0 
    for step in tqdm.notebook.tqdm(range(n_iters)):
        batch = next(iter(dataloader))
        text, ents, embs = batch
        
        model.eval()
        model.requires_grad = False
        opt.zero_grad()
    
        prompt = f"Continue this text:\n\n{text[0]}"


        with torch.no_grad():
            text_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")[..., :2048].to(device=DEVICE)

            text_ids = tokenizer.encode(prompt, add_special_tokens=False)[:2048]
            
            # Step 2: Add special tokens manually
            text_ids += [tokenizer.eos_token_id]
            
            # Step 3: Convert to tensor
            text_ids = torch.tensor([text_ids], device=DEVICE)
            
            half = text_ids.shape[1] // 2  # use shape[1] for token length
            input_embeddings = model.model.embed_tokens(text_ids[:, :half])
            output_embeddings = model.model.embed_tokens(text_ids[:, half:])
            
        try:
            m = batch[2].mean(2, keepdim=True)
            s = batch[2].std(2, unbiased=False, keepdim=True)
            batch[2] = (batch[2] - m) / (s + 1e-6)
        except:
            print ("except", batch[2].shape)
        try:
            projected_kg_embeddings = projection(batch[2].to(
                        device=DEVICE, dtype=model.dtype
                    ))
        except Exception as e:
            print("❌ Error projecting KG embeddings")
            print("embs.shape:", batch[2].shape)
            print("Exception:", str(e))
            continue
        
        embeddings1 = torch.cat(
                [
                    kg_start_emb[None, None, ...],
                    projected_kg_embeddings,
                    kg_end_emb[None, None, ...],
                    input_embeddings,
                    output_embeddings
                ],
                dim=1,
            )
        
        #mask = torch.full(embeddings1.shape, False)
        
        with torch.autocast(device_type="cuda", dtype=model.dtype):
            logits = model(inputs_embeds=torch.cat(
                [
                    kg_start_emb[None, None, ...],
                    projected_kg_embeddings,
                    kg_end_emb[None, None, ...],
                    input_embeddings,
                    output_embeddings
                ],
                dim=1,
            ), output_hidden_states=True).get("logits")
            # loss only for answer part & backward

            logits = logits[..., -output_embeddings.shape[1]:-1, :].contiguous()
            labels = text_ids[:, half+1:].contiguous()

            if torch.isnan(logits).any():
                print("⚠️ NaN in logits")
            

            loss = loss_fct(logits.permute(0, 2, 1), labels).mean()

        if model.dtype == torch.float16:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        losses_batch.append(loss.item())
        
        if (step% 1000 == 0):         
            out = model.generate(inputs_embeds=embeddings1, max_new_tokens = half)
            generated_texts2 = tokenizer.batch_decode(out)[0]
            print ("\n first part \n")
            print(tokenizer.decode(text_ids[0, :half].tolist(), skip_special_tokens=True))
            print("\n last part \n")
            print(tokenizer.decode(text_ids[0, half:].tolist(), skip_special_tokens=True))
            print ("\n continue", generated_texts2)
            print ("\n")
            
            print ("loss", np.mean(losses_batch))
            print ('lr', scheduler.get_lr()[0], step, flush = True)
            plt.title("train loss\n" + f"\n\nEpoch [{epoch}], iter [{iters}/{n_iters}]")
            accum_loss = np.mean(losses_batch)
            losses.append(accum_loss)
            plt.semilogy(losses)
            plt.grid()
            plt.savefig(f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/loss1_llama2.png")
            plt.close("all")


        if iters % grad_accum == 0 and iters > 0:
            if model.dtype == torch.float16:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            opt.zero_grad()
            scheduler.step()
            accum_loss = np.mean(losses_batch)
            losses.append(accum_loss)
            losses_batch = []

            if accum_loss < loss_best:
                loss_best = accum_loss
                torch.save(projection, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/projection_llama3")
                torch.save(kg_start_emb, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/SOI_llama3.pt")
                torch.save(kg_end_emb, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/EOI_llama3.pt")
            
            
            #gc.collect()
        
        iters += 1

        # model inference to get
        
        

## QA sft

In [4]:
import pandas as pd

pd_df = pd.read_csv('/home/jovyan/shares/SR004.nfs2/chekalina/check_halu/83000.csv')

In [5]:
from datasets import load_dataset

from datasets import Dataset, DatasetDict

tds = Dataset.from_pandas(pd_df)


In [6]:
import re
from word2number import w2n

# Accept only alphabetic characters and hyphens/spaces
pure_number_pattern = re.compile(r"^[a-z\s-]+$", re.IGNORECASE)

BLOCK_WORDS = {
    "half", "quarter", "percent", "times", "part", "copy", "copies",
    "million", "billion", "trillion",  # block these only in fuzzy use
    "more", "about", "around", "approximately", "nearly", "almost", "over", "under",
    "in", "excess", "than"
}

def normalize_answer_if_needed(example):
    question = example["question"].lower()
    answer = example["answer"].strip()
    answer_lc = answer.lower()

    if not ("how many" in question or "how much" in question):
        return example

    # Reject if any blocking word appears in answer
    if any(block in answer_lc for block in BLOCK_WORDS):
        return example

    # Must match a clean number phrase (no digits, no punctuation)
    if not pure_number_pattern.fullmatch(answer_lc):
        return example

    try:
        number = w2n.word_to_num(answer_lc.replace("-", " "))
        example["answer"] = str(number)
    except ValueError:
        pass  # not cleanly parseable

    return example


# do not do it!!!

#tds = tds.map(normalize_answer_if_needed)

#for i, (before, after) in enumerate(zip(original_answers, tds["answer"])):
#    if before != after:
#        print(f"[{i}] {before} → {after}")

In [7]:
from torch.utils.data import DataLoader, Dataset
from ast import literal_eval
import numpy as np
import tqdm

import torch
from torch import nn
from torch.optim import AdamW
from transformers.optimization import (Adafactor, AdafactorSchedule,
                                       get_cosine_schedule_with_warmup)


class PretrainDataset(Dataset):
    def __init__(self, ds):
        self.ds = [
            item for item in ds if len(literal_eval(item['ents'])) > 1
        ]


    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        ents = literal_eval(self.ds[idx]['ents'])
        try:
            lst = literal_eval(self.ds[idx]['embs'])
            lst = [elem for elem in lst if elem != -111]
            embs = np.array(lst)
            #embs = np.reshape(embs, (-1, 200)) 
        except:
            print (self.ds[idx]['embs'])
        
        return self.ds[idx]['question'], self.ds[idx]['answer'], ents, embs

In [13]:
import random
import random
from torch.utils.data import Subset, DataLoader
g = torch.Generator()
g.manual_seed(142)

subset_indices = list(range(16000))
random.shuffle(subset_indices)
dataset = PretrainDataset(tds)
# Wrap the dataset with Subset
subset = Subset(dataset, subset_indices)

# Create the dataloader
dataloader = DataLoader(subset, batch_size=1, shuffle=True)

In [14]:
unk_id

27

In [15]:
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
eos_token_id = tokenizer.convert_tokens_to_ids("<|end_of_text|>")
bos_token_id

128000

In [16]:
eos_token_id

128001

In [12]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Tue Jun 17 14:59:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   35C    P0             83W /  400W |       4MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  |   00

In [17]:
#projection = torch.load("/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/projection_llama3", map_location=DEVICE)
#kg_start_emb = torch.load("/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/SOI_llama3.pt", map_location=DEVICE)
#kg_end_emb = torch.load("/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/EOI_llama3.pt", map_location=DEVICE)

kg_emb_dim = 200
llama_emb_dim = 2048

kg_start_emb = torch.normal(
    torch.zeros(llama_emb_dim), 
    torch.ones(llama_emb_dim) / llama_emb_dim**0.5
).to(device=DEVICE, dtype=torch.bfloat16)

kg_end_emb = torch.normal(
    torch.zeros(llama_emb_dim), 
    torch.ones(llama_emb_dim) / llama_emb_dim**0.5
).to(device=DEVICE, dtype=torch.bfloat16)

projection = nn.Linear(kg_emb_dim, llama_emb_dim).to(device=DEVICE, dtype=torch.bfloat16)

kg_start_emb.requires_grad_()
kg_end_emb.requires_grad_()
projection.requires_grad_()
model.requires_grad_(False)

lr = 1e-3
weight_decay = 1e-3
trainable_parameters = [kg_start_emb] + [kg_end_emb] + list(projection.parameters())

opt = AdamW(trainable_parameters, lr=lr, weight_decay=weight_decay)
loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100)

grad_accum = 256

loss_best = 1000.0

losses = []
losses_batch = []
iters = 0
n_iters = len(dataloader)
scheduler = get_cosine_schedule_with_warmup(opt, num_warmup_steps=n_iters // grad_accum * 0.01, num_training_steps=n_iters // grad_accum)

for epoch in range(1):
    i = 0 
    for step in tqdm.notebook.tqdm(range(n_iters)):
       
        
        batch = next(iter(dataloader))
        question, answer, ents, embs = batch
        ents = [ent[0] for ent in ents]
        
        model.eval()
        model.requires_grad = False
        #opt.zero_grad()
        with torch.no_grad():
            prompt = f"You are a knowledgeable assistant. Answer the question with a short, simple response. Avoid explanations.\n\n{question[0]}\n\nAnswer:"

            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
            output_ids = tokenizer.encode(answer[0], add_special_tokens=False)
            
            # Step 2: Add special tokens manually
            output_ids += [tokenizer.eos_token_id]
            
            # Step 3: Convert to tensor
            text_ids_in = torch.tensor([input_ids], device=DEVICE)
            text_ids_out = torch.tensor([output_ids], device=DEVICE)

            
            input_embeddings = model.model.embed_tokens(text_ids_in)
            output_embeddings = model.model.embed_tokens(text_ids_out)
            
        if (len(text_ids_out[0]) <= 1):
            continue    
        try:
            # Normalize input embeddings (layer-wise across hidden dim)
            embs = embs.to(device=DEVICE, dtype=model.dtype)
            m = embs.mean(2, keepdim=True)
            s = embs.std(2, unbiased=False, keepdim=True)
            embs = (embs - m) / (s + 1e-6)
        
            # Map to LLM space
            projected_kg_embeddings = projection(embs)
        
        
        except Exception as e:
            print("❌ Projection failed:", e)
            continue
        
        
        #mask = torch.full(embeddings1.shape, False)
        
        with torch.autocast(device_type="cuda", dtype=model.dtype):
            logits = model(inputs_embeds=torch.cat(
                [
                    input_embeddings,
                    kg_start_emb[None, None, ...],
                    projected_kg_embeddings,
                    kg_end_emb[None, None, ...],
                    output_embeddings
                ],
                dim=1,
            ), output_hidden_states=True).get("logits")
            
            num_output_tokens = output_embeddings.shape[1]
            logits = logits[..., -1-output_embeddings.shape[1]:-1, :].contiguous()
            labels = text_ids_out.contiguous()
            labels[labels == bos_token_id] = -100
            labels[labels == eos_token_id] = -100

            
            loss = loss_fct(logits.permute(0, 2, 1), labels).mean()
            losses_batch.append(loss.item())
            
            if (labels >= logits.shape[-1]).any():
                print("❌ Labels contain values outside valid range!")
            if torch.isnan(logits).any():
                print("⚠️ NaN in logits")
            if torch.isnan(text_ids_out).any():
                print("⚠️ NaN in labels")

            
        if (step% 1000 == 0): 
            embeddings1 = torch.cat(
                [
                    input_embeddings,
                    kg_start_emb[None, None, ...],
                    projected_kg_embeddings,
                    kg_end_emb[None, None, ...],
                ],
                dim=1,
            )
            print ("ents", ents)
            out = model.generate(inputs_embeds=embeddings1, max_new_tokens = embeddings1.shape[1])
            generated_texts = tokenizer.batch_decode(out)[0]
            print("Token IDs:", text_ids_out[0].tolist())
            print("Vocab size:", tokenizer.vocab_size)
            print (tokenizer.batch_decode(text_ids_in)[0])
            print ("\n last part \n")
            valid_ids = [tid for tid in text_ids_out[0].tolist() if tid >= 0]
            print (tokenizer.decode(valid_ids, skip_special_tokens=True))
            print ("\n continue", generated_texts)
            print ("\n")
            
            print ("loss", np.mean(losses_batch))
            plt.title("train loss\n" + f"\n\nEpoch [{epoch}], iter [{iters}/{n_iters}]")
            accum_loss = np.mean(losses_batch)
            losses.append(accum_loss)
            plt.semilogy(losses)
            plt.grid()
            plt.savefig(f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/loss3.png")
            plt.close("all")

        if model.dtype == torch.float16:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        


        if iters % grad_accum == 0 and iters > 0:
            if model.dtype == torch.float16:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            opt.zero_grad()
            scheduler.step()
            accum_loss = np.mean(losses_batch)
            losses.append(accum_loss)
            losses_batch = []

            if accum_loss < loss_best:
                loss_best = accum_loss
                torch.save(projection, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/projection_llama3_qa")
                torch.save(kg_start_emb, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/SOI_llama3_qa.pt")
                torch.save(kg_end_emb, f"/home/jovyan/shares/SR004.nfs2/chekalina/kg_reduces_halus/notebook_new/ckpts/EOI_llama3_qa.pt")
            
            
            #gc.collect()
        
        iters += 1

        
        

  0%|          | 0/16000 [00:00<?, ?it/s]

ents ['Great Reading Adventure', 'To Kill a Mockingbird']
Token IDs: [1271, 27933, 264, 14905, 287, 23414, -100]
Vocab size: 128000
You are a knowledgeable assistant. Answer the question with a short, simple response. Avoid explanations.

Great Reading Adventure was introduced in Bristol, inspired by a Chicago scheme that was using which 1960 Harper Lee novel as its springboard?

Answer:

 last part 

To Kill a Mockingbird

 continue 

The Great Gatsby, by F. Scott Fitzgerald, was first published in 1925.

The book was a success, selling 100,000 copies in the first year and 1 million copies in the first decade. It won the Pulitzer Prize in 


loss 2.673940420150757
❌ Projection failed: Dimension out of range (expected to be in range of [-2, 1], but got 2)
❌ Projection failed: Dimension out of range (expected to be in range of [-2, 1], but got 2)
❌ Projection failed: Dimension out of range (expected to be in range of [-2, 1], but got 2)
❌ Projection failed: Dimension out of range (expec

In [None]:
tok_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.to(device)
embs = model.model.embed_tokens(tok_ids)
print(embs.mean().item(), embs.std().item())

In [1]:
!nvidia-smi

Mon Jun 16 01:51:29 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   29C    P0             82W /  400W |    3361MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  |   00