In [1]:
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import datasets
import torch
import gc
import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device cuda


In [2]:
gc.collect()
print('tokenizer')
teacher_tokenizer = AutoTokenizer.from_pretrained('speakleash/Bielik-1.5B-v3.0-Instruct')
print('model')
teacher = AutoModelForCausalLM.from_pretrained('speakleash/Bielik-1.5B-v3.0-Instruct-FP8-Dynamic').to(device)
print('student')
student_tokenizer = AutoTokenizer.from_pretrained('sdadas/polish-gpt2-small')
student = AutoModelForCausalLM.from_pretrained('sdadas/polish-gpt2-small').to(device)

tokenizer
model
student


In [3]:
dataset = datasets.load_dataset('Igorrr0/polish-qa-general').remove_columns(['input'])
print(dataset['train'][0].keys())

def bielik_prompt(prompt):
    return f"<s><|im_start|> user\n{prompt}<|im_end|> \n<|im_start|> assistant\n"


def question_to_prompt(question, answer="", bielik=False):
    if bielik:
        return bielik_prompt(question) + answer + ("\n<|im_end|><s>" if answer != "" else "")
    return f"user\n{question}\n\nassistant\n{answer}"


def batch_to_prompt(batch):
    prompts = []
    for question, answer in zip(batch['instruction'], batch['output']):
        prompts.append(question_to_prompt(question, answer))
    return prompts

dict_keys(['instruction', 'output'])


In [4]:
gc.collect()
def test_model(model, tokenizer, custom_prompt="Ile to jest 2+2?"):
    tokens = tokenizer(custom_prompt, return_tensors='pt')
    tokens = tokens.to(model.device)
    print(tokens)
    model.eval()
    output = model.generate(**tokens, max_new_tokens=512)
    print(output.size())
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    print(decoded)

print('Testing teacher model:')
test_model(teacher, teacher_tokenizer, custom_prompt=bielik_prompt("Ile to jest 2+2?"))
print('Testing student model:')
test_model(student, student_tokenizer, custom_prompt="2+2 równa się ")

Testing teacher model:


Setting `pad_token_id` to `eos_token_id`:4 for open-end generation.


{'input_ids': tensor([[    1,     1,     3, 31887,   310,  2272,    17, 31942,   296,   373,
           403, 31887, 31924, 31979, 31924, 31956,     4, 31887, 31887,    17,
             3, 31887,   322,  3988, 19681,    17]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]], device='cuda:0')}


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


torch.Size([1, 33])
 user
Ile to jest 2+2?  
  assistant
2+2 = 4
Testing student model:
{'input_ids': tensor([[   22,    15,    22, 22274,   309,   225]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')}
torch.Size([1, 8])
2+2 równa się %.


In [None]:
# Dataset caching
from copy import deepcopy

cached = deepcopy(dataset)

teacher.eval()

for i, x in enumerate(cached['train']):
    print(f'train_{i}')
    print(x['instruction'])
    tokens = teacher_tokenizer(bielik_prompt(x['instruction']), return_tensors='pt').to(device)
    output = teacher_tokenizer.decode(teacher.generate(**tokens, max_new_tokens=512)[0], skip_special_tokens=True)
    cached['train'][i]['output'] = output

print(cached['train'][0])

Setting `pad_token_id` to `eos_token_id`:4 for open-end generation.


train_0
Jak mogę pomóc?


Setting `pad_token_id` to `eos_token_id`:4 for open-end generation.


train_1
Cześć!


Setting `pad_token_id` to `eos_token_id`:4 for open-end generation.


train_2
Dzień dobry, proszę o informacje.


Setting `pad_token_id` to `eos_token_id`:4 for open-end generation.


train_3
Czy możesz to powtórzyć?


KeyboardInterrupt: 

In [None]:
MAX_TOKENS = 256
gc.collect()
dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=2, shuffle=True)
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
student_tokenizer.pad_token = student_tokenizer.eos_token

student = AutoModelForCausalLM.from_pretrained('sdadas/polish-gpt2-small').to(device)
teacher.eval()
student.train()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=1e-5)


def train_step(batch):
    gc.collect()
        
    prompt = batch_to_prompt(batch)
    print(prompt)
    # print(torch.cuda.memory_summary())
    # TODO: one tokenization + mask prompt
    student_inputs = student_tokenizer(batch['instruction'], return_tensors='pt', padding="max_length", truncation=True, max_length=MAX_TOKENS).to(device)
    student_labels = student_tokenizer(batch['output'], return_tensors='pt', padding="max_length", truncation=True, max_length=MAX_TOKENS).to(device)
    student_outputs = student(**student_inputs, labels=student_labels['input_ids'], max_length=MAX_TOKENS)

    # print(torch.cuda.memory_summary())

    loss = student_outputs.loss
    
    optimizer.zero_grad()
    loss.backward()
    # print(torch.cuda.memory_summary())
    optimizer.step()
    
    # print(torch.cuda.memory_summary())

    return loss.item()

for epoch in range(1):
    for i, batch in enumerate(dataloader):
        print()
        print(f'Epoch {epoch}, {i+1}/{len(dataloader)}: Training on batch...')
        print(f'loss= {train_step(batch)}')
        if i % 100 == 99:
            print('Saving model...')
            student.save_pretrained(f'ptaqqqq/polish-gpt2-small-distilled-{epoch}-{i}')
            student_tokenizer.save_pretrained(f'ptaqqqq/polish-gpt2-small-distilled-{epoch}-{i}')

    print('Saving model...')
    student.save_pretrained(f'ptaqqqq/polish-gpt2-small-distilled-{epoch}')
    student_tokenizer.save_pretrained(f'ptaqqqq/polish-gpt2-small-distilled-{epoch}')

print('Training finished.')
print('Saving final model...')
student.save_pretrained('ptaqqqq/polish-gpt2-small-distilled-final')
student_tokenizer.save_pretrained('ptaqqqq/polish-gpt2-small-distilled-final')  
print('Testing final model:')
test_model(student, student_tokenizer, custom_prompt=dataset['train'][0]['input'])


Epoch 0, 1/184: Training on batch...
['user\nOpisz, jak działa telefon komórkowy.\n\nassistant\nTelefon komórkowy działa poprzez przesyłanie sygnałów radiowych między urządzeniem a stacjami bazowymi, co umożliwia komunikację.', 'user\nDziękuję za Twój czas.\n\nassistant\nCała przyjemność po mojej stronie.']
loss= 9.36070728302002

Epoch 0, 2/184: Training on batch...
['user\nDo widzenia!\n\nassistant\nDo zobaczenia wkrótce!', 'user\nCzym jest magnetyzm?\n\nassistant\nMagnetyzm to zjawisko fizyczne, w którym materia wywiera siły przyciągania lub odpychania na inne materiały.']
loss= 8.052878379821777

Epoch 0, 3/184: Training on batch...
['user\nCzym jest wolność słowa?\n\nassistant\nWolność słowa to prawo do wyrażania opinii i idei bez cenzury lub strachu przed represjami.', 'user\nMiłego dnia!\n\nassistant\nWzajemnie, życzę miłego dnia!']
loss= 8.073185920715332

Epoch 0, 4/184: Training on batch...
['user\nNapisz krótkie pytanie o pogodę.\n\nassistant\nJaka jest prognoza pogody na w

KeyboardInterrupt: 

In [None]:
test_model(student, student_tokenizer, custom_prompt="user\nNapisz krótką prośbę o pomoc w wyborze\nassistant\n")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


{'input_ids': tensor([[   89,  1446,   203, 49900, 15309, 12843,   282,  2834,   264, 23796,
           203,   400,   393,  1111,    88,   203]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
torch.Size([1, 528])
user
Napisz krótką prośbę o pomoc w wyborze
assistant
assistantatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedatedateda

In [None]:
# List saved models
import os
saved_models = [f for f in os.listdir('ptaqqqq/') if f.startswith('polish-gpt2-small-distilled-')]
print('Saved models:')
for model in saved_models:
    print(model)

if False:
    for name, module in student.named_modules():
        if "c_attn" in name or "c_proj" in name:
            print(name)