In [1]:
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from urllib.request import urlopen
import torch.nn as nn
from huggingface_hub import hf_hub_download

import matplotlib.pyplot as plt

# Fetch the OmniFusion helper module (sources of the projection adapter and image encoder)
# into the working directory so it can be imported below.
# NOTE(review): this downloads a .py file from the Hub and imports it, i.e. it executes
# remote code; consider pinning `revision=` for reproducibility.
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="models.py", local_dir='./')
from models import CLIPVisionTower  # NOTE(review): not used in the cells shown here

DEVICE = "cuda:1"
PROMPT = "This is a dialog with AI assistant.\n"
# Location of the pretrained adapter checkpoints (projection module + start/end embeddings).
CKPT_DIR = "/home/jovyan/shares/SR004.nfs2/chekalina/check_halu/ckpts"

tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-tokenizer", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)

# Id of "<unk>"; the training cell uses it as the loss ignore_index.
unk_id = tokenizer.encode("<unk>", add_special_tokens=False)[0]
# Token-id conventions used throughout this notebook (pad = 2, eos = 0); they match the
# pad_token_id / eos_token_id values passed to model.generate further down.
tokenizer.pad_token_id = 2
tokenizer.eos_token_id = 0

model.resize_token_embeddings(len(tokenizer))
N_EMBEDDINGS = model.model.embed_tokens.weight.shape[0]
print("Number of embeddings in tokenizer:", N_EMBEDDINGS)

# Pretrained adapter pieces: the projection module plus the embeddings placed before and
# after the projected KG embeddings in the prompt.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
projection = torch.load(f"{CKPT_DIR}/projection", map_location=DEVICE)
start_emb = torch.load(f"{CKPT_DIR}/SOI.pt", map_location=DEVICE)
end_emb = torch.load(f"{CKPT_DIR}/EOI.pt", map_location=DEVICE)


libgomp: Invalid value for environment variable OMP_NUM_THREADS


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Number of embeddings in tokenizer: 32000


In [2]:
import pandas as pd

# Load the QA table (columns: question, answer, ents, embs).
pd_df = pd.read_csv(
    filepath_or_buffer='/home/jovyan/shares/SR004.nfs2/chekalina/check_halu/83000.csv',
)

In [3]:
from datasets import load_dataset

from datasets import Dataset, DatasetDict

# Use an explicit alias: a later cell runs `from torch.utils.data import DataLoader, Dataset`,
# which rebinds the bare name `Dataset`. Re-running this cell after that would otherwise call
# torch's Dataset (no `from_pandas`) instead of the Hugging Face one.
from datasets import Dataset as HFDataset

tds = HFDataset.from_pandas(pd_df)

In [4]:
# Inspect the Hugging Face dataset: feature names and number of rows.
tds

Dataset({
    features: ['question', 'answer', 'ents', 'embs'],
    num_rows: 72998
})

In [5]:
from ast import literal_eval
# Parse one row's `embs` string and drop its top-level -111 entries.
lst = list(filter(lambda elem: elem != -111, literal_eval(tds[10]['embs'])))

In [6]:
from torch.utils.data import DataLoader, Dataset
from ast import literal_eval
import numpy as np


class PretrainDataset(Dataset):
    """Map-style dataset yielding ``(question, answer, ents, embs)`` for each row.

    ``embs`` is parsed from its string form with ``literal_eval``. Top-level
    ``-111`` entries are dropped, and the rest is converted to a NumPy array.

    Args:
        ds: indexable collection of row mappings with keys ``question``,
            ``answer``, ``ents`` and ``embs`` (e.g. a Hugging Face ``Dataset``).
        emb_dim: optional. If given, each row's embeddings are reshaped to
            ``(-1, emb_dim)``, so a single flat vector becomes a 2-D array.
            ``None`` (the default) keeps the parsed shape unchanged.

    Raises:
        ValueError: from ``__getitem__`` when a row's ``embs`` value cannot be
            parsed or converted. The message names the row index.
    """

    def __init__(self, ds, emb_dim=None):
        self.ds = ds
        self.emb_dim = emb_dim

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Read the row once; indexing a HF Dataset builds a fresh dict each time.
        row = self.ds[idx]
        raw_embs = row['embs']
        try:
            parsed = literal_eval(raw_embs)
            # NOTE(review): only top-level -111 entries are removed; nested lists are kept as-is.
            embs = np.array([elem for elem in parsed if elem != -111])
        except (ValueError, SyntaxError, TypeError) as exc:
            # Malformed rows fail here with context instead of later on an unbound name.
            raise ValueError(f"Row {idx}: cannot parse 'embs' value {raw_embs!r}") from exc
        if self.emb_dim is not None:
            embs = embs.reshape(-1, self.emb_dim)
        return row['question'], row['answer'], row['ents'], embs

In [7]:
# Wrap the HF dataset and serve one example per step, in shuffled order.
dataset = PretrainDataset(ds=tds)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

In [8]:
# Fetch one sample (question, answer, ents, embs) for a quick shape check.
(a, b, c, d) = dataset[1777]

In [9]:
# Shape of the `embs` array of the sample fetched above.
d.shape

(1, 200)

In [10]:
import numpy as np
# Same check via np.shape, which accepts any array-like without building a copy first.
np.shape(d)

(1, 200)

### Train adapter

In [11]:
# Token sequences generation must not produce: the tokenizations of "\n", "</s>" and ":",
# plus the single token id 13.
bad_words_ids = [*tokenizer(["\n", "</s>", ":"], add_special_tokens=False).input_ids, [13]]

# Beam-search settings used for the periodic sample generations during training.
gen_params = dict(
    do_sample=False,
    max_new_tokens=20,
    early_stopping=False,
    num_beams=3,
    repetition_penalty=2.0,
    remove_invalid_values=True,
    eos_token_id=0,
    pad_token_id=2,
    forced_eos_token_id=0,
    use_cache=True,
    no_repeat_ngram_size=2,
    bad_words_ids=bad_words_ids,
    num_return_sequences=3,
)

In [12]:
# Optimizer steps per pass over the data when gradients are accumulated over 256 samples
# (256 presumably mirrors `grad_accum` in the training cell -- keep them in sync).
len(dataloader) // 256

285

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from transformers.optimization import (Adafactor, AdafactorSchedule,
                                       get_cosine_schedule_with_warmup)

import gc

import tqdm

import os
from tqdm.auto import tqdm as progress_bar

kg_emb_dim = 200
mstral_emb_dim = 4096

# Trainable adapter: the projection of KG embeddings into the LM embedding space plus the
# start/end embeddings that wrap them in the prompt. The LM itself stays frozen.
# To train from scratch instead of the loaded checkpoints:
#start_emb = torch.normal(torch.zeros(mstral_emb_dim), torch.ones(mstral_emb_dim) / mstral_emb_dim**0.5).to(device=DEVICE, dtype=model.dtype)
#end_emb = torch.normal(torch.zeros(mstral_emb_dim), torch.ones(mstral_emb_dim) / mstral_emb_dim**0.5).to(device=DEVICE, dtype=model.dtype)
#projection = nn.Linear(kg_emb_dim, mstral_emb_dim).to(device=DEVICE, dtype=model.dtype)

start_emb.requires_grad_()
end_emb.requires_grad_()
projection.requires_grad_()
model.requires_grad_(False)
model.eval()  # frozen LM; set once instead of on every step

lr = 5e-3
weight_decay = 1e-5
trainable_parameters = [start_emb] + [end_emb] + list(projection.parameters())

opt = AdamW(trainable_parameters, lr=lr, weight_decay=weight_decay)
# Per-token losses; positions labelled "<unk>" are excluded from the average below.
loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=unk_id)
# Loss scaling is only needed for fp16; for other dtypes the scaler is a pass-through
# (scale() returns the loss unchanged and step() just calls opt.step()).
scaler = torch.cuda.amp.GradScaler(enabled=model.dtype == torch.float16)

grad_accum = 256
norm_eps = 1e-6  # keeps the per-row standardisation finite when a row has zero std

loss_best = 1000.0

losses = []        # one mean loss per optimizer step (what gets plotted)
losses_batch = []  # per-sample losses of the current accumulation window
iters = 0          # samples that reached backward()
n_iters = len(dataloader)
scheduler = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=n_iters // grad_accum)

os.makedirs("ckpts", exist_ok=True)

for epoch in range(1):
    # One pass over the shuffled dataloader (a single iterator per epoch, so each sample
    # is visited once and the sampler is not rebuilt on every step).
    for step, batch in enumerate(progress_bar(dataloader, total=n_iters)):
        question, answer, ents, embs = batch
        with torch.no_grad():
            text_ids_in = tokenizer.encode(question[0], add_special_tokens=False, return_tensors="pt").to(device=DEVICE)
            text_ids_out = tokenizer.encode(answer[0], add_special_tokens=False, return_tensors="pt").to(device=DEVICE)
            input_embeddings = model.model.embed_tokens(text_ids_in)
            output_embeddings = model.model.embed_tokens(text_ids_out)

        # Standardise embeddings over dim 2 (zero mean, unit variance).
        # Presumably expects (1, n_entities, kg_emb_dim); rows of another rank are skipped.
        try:
            m = embs.mean(2, keepdim=True)
            s = embs.std(2, unbiased=False, keepdim=True)
            embs = (embs - m) / (s + norm_eps)
        except Exception as e:
            print(e)
            continue
        try:
            projected_kg_embeddings = projection(embs.to(device=DEVICE, dtype=model.dtype))
        except Exception as e:
            print(e)
            print("embs.shape", embs.shape)
            continue

        # Prompt: question tokens, start embedding, projected KG embeddings, end embedding.
        embeddings1 = torch.cat(
            [
                input_embeddings,
                start_emb[None, None, ...],
                projected_kg_embeddings,
                end_emb[None, None, ...],
            ],
            dim=1,
        )

        with torch.autocast(device_type="cuda", dtype=model.dtype):
            # Teacher forcing: feed prompt + answer embeddings in one sequence.
            logits = model(inputs_embeds=torch.cat([embeddings1, output_embeddings], dim=1)).get("logits")
            # Logits at position t predict token t + 1, so the answer tokens are predicted by
            # the positions from the last prompt position up to the second-to-last overall.
            logits = logits[..., embeddings1.shape[1] - 1 : -1, :].contiguous()
            labels = text_ids_out.contiguous()
            token_losses = loss_fct(logits.permute(0, 2, 1), labels)
            # Average over non-ignored labels only (ignored positions contribute 0 to the sum).
            n_valid = (labels != unk_id).sum().clamp(min=1)
            loss = token_losses.sum() / n_valid

        scaler.scale(loss).backward()
        losses_batch.append(loss.item())

        if step % (2 * grad_accum) == 0:
            print(step)
            # Decode a continuation of the prompt; max_new_tokens is tied to the prompt length
            # through a local copy so the shared gen_params dict is left untouched.
            out = model.generate(
                inputs_embeds=embeddings1,
                **{**gen_params, "max_new_tokens": embeddings1.shape[1]},
            )
            generated_texts = tokenizer.batch_decode(out)[0]
            print(question[0])
            print("\n last part \n")
            print(answer[0])
            print("\n continue", generated_texts)
            print("\n")

            print("loss", np.mean(losses_batch))
            print('lr', scheduler.get_last_lr()[0], step, flush=True)
            plt.title("train loss\n" + f"\n\nEpoch [{epoch}], iter [{iters}/{n_iters}]")
            plt.semilogy(losses)
            plt.grid()
            plt.savefig(f"ckpts/loss3.png")
            plt.close("all")

        # Apply the accumulated gradients once every `grad_accum` processed samples.
        if (iters + 1) % grad_accum == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
            scheduler.step()
            accum_loss = np.mean(losses_batch)
            losses.append(accum_loss)
            losses_batch = []

            if accum_loss < loss_best:
                loss_best = accum_loss
                torch.save(projection, f"ckpts/projection_qa1")
                torch.save(start_emb, f"ckpts/SOI2_qa1.pt")
                torch.save(end_emb, f"ckpts/EOI2_qa1.pt")

        iters += 1
        
        

  0%|          | 0/72998 [00:00<?, ?it/s]

0
Songs from the Southern Mountains is the title of a recording by a folk music artist that has won how many Grammy awards?

 last part 

seven

 continue  's Grammy Award for Best Traditional Folk Album in 1962. The album was produced by Ralph Rinzler<unk>


loss 8.006061553955078
lr 0.0 0
Dimension out of range (expected to be in range of [-2, 1], but got 2)
Dimension out of range (expected to be in range of [-2, 1], but got 2)
512
Lynda Bird Johnson Robb has an elder brother who died in 2013, and he was ambassador to which European country?

 last part 

Belgium

 continue  1960s, Lynda Bird Johnson Robb's father was the 36th President of the United States. Who was he? a.<unk>


loss 3.7830482571143804
lr 0.0005 512
Dimension out of range (expected to be in range of [-2, 1], but got 2)
Dimension out of range (expected to be in range of [-2, 1], but got 2)
Dimension out of range (expected to be in range of [-2, 1], but got 2)
1024
Where is the group, which organised the Mexican refer

In [None]:
# Re-run generation on the most recent `embeddings1` with the current `gen_params`.
model.generate(inputs_embeds=embeddings1, **gen_params)

In [None]:
# Rebuild the prompt for the most recent training sample, this time with the answer
# embeddings appended: question, start, projected KG, end, answer.
# NOTE(review): this overwrites the `embeddings1` left over from the training loop.
embeddings1 = torch.cat(
    (
        input_embeddings,
        start_emb.unsqueeze(0).unsqueeze(0),
        projected_kg_embeddings,
        end_emb.unsqueeze(0).unsqueeze(0),
        output_embeddings,
    ),
    dim=1,
)

In [None]:
# Sample a long continuation: nucleus sampling (top_p) with top-k filtering disabled and a
# very high temperature.
out = model.generate(
    inputs_embeds=embeddings1,
    max_new_tokens=150,
    do_sample=True,
    temperature=3.5,
    top_k=0,
    top_p=0.82,
    eos_token_id=0,
    pad_token_id=2,
)

print("out.shape", out.shape)
# NOTE(review): drops the first generated position; the training cell decodes without this slice.
out = out[..., 1:]

generated_texts = tokenizer.batch_decode(out)[0]

generated_texts


In [None]:
# Same beam-search settings as the training cell's generations, but allowing up to 150 new tokens.
gen_params = dict(
    do_sample=False,
    max_new_tokens=150,
    early_stopping=False,
    num_beams=3,
    repetition_penalty=2.0,
    remove_invalid_values=True,
    eos_token_id=0,
    pad_token_id=2,
    forced_eos_token_id=0,
    use_cache=True,
    no_repeat_ngram_size=2,
    bad_words_ids=bad_words_ids,
    num_return_sequences=3,
)

In [None]:
# Generate continuations of `embeddings1` with the current `gen_params`.
out = model.generate(inputs_embeds=embeddings1,
    **gen_params)

print ("out.shape", out.shape)

# `out` is a 2-D batch of token-id sequences (one row per returned sequence). Decode it
# row-wise and keep the first sequence. `tokenizer.decode(out)[0]` would treat the whole
# batch as one sequence and then keep only the first character of the resulting string.
generated_texts = tokenizer.batch_decode(out)[0]


In [None]:
# Show the decoded text produced by the previous cell.
generated_texts