# Fine Tune GPT-2 Model

Open Questions:
- Is it useful to add a \<bot> marker as a preprocessing step? -> Yes.
- Should one batch contain exactly one dialog?
- Currently the dialogs are mixed, so only a question and its answer are paired.
    - How can this be fixed?
    - Through the batches?
- Removing Bot answers?<br>
    From:<br>
    '<start> Create me a unique interactive story to calm with the topic: Ocean. <bot>:Ah, the ocean... <end>',<br>
    "<start> Ah, the ocean. Such a ... <end>",<br>
    "<start> Yes, I can feel it...'<br>
    <br>
    to:<br>
    '<start> Create me a unique interactive story to calm with the topic: Ocean. <bot>:Ah, the ocean... <end>',<br>
    "<start> Yes, I can feel it...'<br>


1. **Dialog-based Approach:**
   - **One Batch, One Dialog:**
     - Treat each dialog as a separate training example. This allows the model to learn the context and flow of individual conversations.
     - Helps the model focus on capturing the nuances of each conversation independently.
     - Useful if your storytelling involves short, distinct dialogs.

   - **Inclusion of the Past:**
     - You can include the past history within each dialog example. Concatenate the previous turns in the conversation to provide context.
     - This helps the model understand the context and continuity of the ongoing dialog.
     - Be mindful of the token limit: GPT-2 accepts at most 1,024 tokens per sequence, so longer sequences get truncated.

2. **Memory and Context:**
   - GPT-2 has a limited context window (1,024 tokens) due to its fixed input size. If the conversations are long, earlier turns that fall outside the window are lost.
   - Consider balancing the length of your input sequences to ensure the model can capture essential details.

3. **Dynamic Context Window:**
   - Instead of a fixed history length, you could use a sliding window approach.
   - Maintain a dynamic context window that moves along the conversation, incorporating the most recent interactions.

4. **Experiment and Evaluate:**
   - It's often beneficial to experiment with different approaches to see what works best for your specific use case.
   - Conduct thorough evaluations using validation data to ensure the model is learning effectively and providing desired responses.

5. **Training Strategies:**
   - Experiment with hyperparameters like learning rate, batch size, and the number of training epochs to fine-tune the model effectively.
   - Monitor the model's performance on both training and validation sets.
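
The sliding-window idea from point 3 can be sketched as follows. This is a minimal illustration under stated assumptions, not the implementation used in this notebook: `build_window` and `count_tokens` are hypothetical names, and whitespace splitting stands in for the real GPT-2 tokenizer, so the token counts are only approximate.

```python
from typing import Callable, List


def count_tokens(text: str) -> int:
    """Stand-in token counter (assumption): whitespace split, not the GPT-2 tokenizer."""
    return len(text.split())


def build_window(turns: List[str], max_tokens: int,
                 counter: Callable[[str], int] = count_tokens) -> List[str]:
    """Keep the most recent turns whose combined token count fits into max_tokens."""
    window: List[str] = []
    used = 0
    # Walk backwards so the newest turns are considered first.
    for turn in reversed(turns):
        cost = counter(turn)
        if used + cost > max_tokens:
            break
        window.append(turn)
        used += cost
    # Restore chronological order.
    return list(reversed(window))


dialog = [
    "Hey, I feel down lately.",
    "I understand how you feel.",
    "What was the story about?",
    "It was about a man who never gave up.",
]
recent = build_window(dialog, max_tokens=15)
print(recent)
```

With a budget of 15 tokens only the last two turns fit. Note that this sketch returns an empty list if the newest turn alone exceeds the budget; a real pipeline would probably truncate that turn instead.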

Preprocessing covers tokenization, special tokens, and managing the context window.

Hint: Use the dialogs.txt file to train the model on Google Colab.

### System

In [12]:
!python --version

Python 3.10.12


In [13]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

/bin/bash: line 1: nvidia-smi: command not found


In [14]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 13.6 gigabytes of available RAM

Not using a high-RAM runtime


In [15]:
!cat /proc/cpuinfo

processor	: 0
vendor_id	: AuthenticAMD
cpu family	: 23
model		: 49
model name	: AMD EPYC 7B12
stepping	: 0
microcode	: 0xffffffff
cpu MHz		: 2249.998
cache size	: 512 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
bugs		: sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso
bogomips	: 

In [16]:
!cat /proc/meminfo

MemTotal:       13290480 kB
MemFree:         7239128 kB
MemAvailable:   11348988 kB
Buffers:          363924 kB
Cached:          3935072 kB
SwapCached:            0 kB
Active:           639376 kB
Inactive:        5135192 kB
Active(anon):       1084 kB
Inactive(anon):  1478976 kB
Active(file):     638292 kB
Inactive(file):  3656216 kB
Unevictable:           4 kB
Mlocked:               4 kB
SwapTotal:             0 kB
SwapFree:              0 kB
Dirty:              2632 kB
Writeback:             0 kB
AnonPages:       1473492 kB
Mapped:           474232 kB
Shmem:              4484 kB
KReclaimable:     112584 kB
Slab:             149860 kB
SReclaimable:     112584 kB
SUnreclaim:        37276 kB
KernelStack:        4420 kB
PageTables:        30812 kB
SecPageTables:         0 kB
NFS_Unstable:          0 kB
Bounce:                0 kB
WritebackTmp:          0 kB
CommitLimit:     6645240 kB
Committed_AS:    2495140 kB
VmallocTotal:   34359738367 kB
VmallocUsed:       10656 kB
VmallocChunk:    

### Imports

In [17]:
#!python -m pip install torch
#!python -m pip install transformers

In [18]:
#from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
from datetime import datetime as dt

import json

import transformers
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam

In [19]:
os.listdir("./")

['.config', 'dialogs.txt', 'sample_data']

### Load and Prepare the data

In [20]:
MODEL_PATH = "./model/model.pth"
MODEL_WEIGHT_PATH = "./model/model_weights.pth"
ONNX_PATH = "./model/model.onnx"
MAX_LENGTH = 1024   #"auto"
# ".pt", ".pth", ".pkl", or ".h5"

class Dialog_Data(Dataset):

    def __init__(self, tokenizer, data_dir_path="./data", read_one_file=False, should_save_as_one_file=True):
        self.tokenizer = tokenizer
        self.data_dir_path = data_dir_path
        self.read_data(data_dir_path, read_one_file, should_save_as_one_file)

    def read_data(self, data_dir_path, read_one_file, should_save_as_one_file=True):
        global MAX_LENGTH

        data = []
        conversations = []
        if read_one_file:
            with open("./dialogs.txt", "r", encoding="latin1") as f:
                raw = f.read()
            for dialog in raw.split("#/"):
                cur_conversation = []
                for sentence in dialog.split(";"):
                    data += [sentence]
                    cur_conversation += [sentence]
                conversations += [(cur_conversation)]
        else:
            for dialog in os.listdir(self.data_dir_path):
                    with open(f"{self.data_dir_path}/{dialog}", "r") as f:
                        cur_conversation = []
                        for idx, line in enumerate(f.read().split("\n")):
                            content = ":".join(line.split(":")[1:]).strip()
                            if len(content) > 0:
                                data += [content]
                                cur_conversation += [content]
                    conversations += [(cur_conversation)]
            if should_save_as_one_file:
                # Join turns with ";" and dialogs with "#/", then write the file once
                # (previously the file was rewritten on every loop iteration)
                save_data = "#/".join(";".join(dialog) for dialog in conversations)
                with open("./dialogs.txt", "w") as f:
                    f.write(save_data)

        # build (input, answer) pairs from consecutive turns and mark where each dialog starts
        X = []
        y = []
        is_conversation_beginning = []
        for cur_conversation in conversations:
            for idx in range(0, len(cur_conversation)-1, 2):
                if idx == 0:
                    is_conversation_beginning += [1]
                else:
                    is_conversation_beginning += [0]
                X += [cur_conversation[idx]]
                y += [cur_conversation[idx+1]]

        self.conversations = conversations
        self.X = X
        self.y = y
        self.is_conversation_beginning = is_conversation_beginning

        # encoded_data = self.tokenizer(self.X, truncation=True, return_tensors="pt", max_length=MAX_LENGTH, padding="max_length") # max_length=40, padding="max_length"
        # self.X_encoded = encoded_data['input_ids']
        # self.X_attention_mask = encoded_data['attention_mask']

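        # Tokenized answers (currently not used by __getitem__); truncation=True keeps every row at MAX_LENGTH
        encoded_data = self.tokenizer(self.y, truncation=True, return_tensors="pt", max_length=MAX_LENGTH, padding="max_length")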
        self.y_encoded =  encoded_data['input_ids']
        self.y_attention_mask = encoded_data['attention_mask']

    def get_context(self, idx):
        with_context = ""
        cur_idx = idx
        while cur_idx >= 0:
            if cur_idx == idx:
                with_context += f"<bot>{self.y[cur_idx]}"  # <end>
                with_context = f"{self.X[cur_idx]}{with_context}"
            else:
                with_context = f"{self.y[cur_idx]}<sep>{with_context}"
                with_context = f"{self.X[cur_idx]}<sep>{with_context}"
            if self.is_conversation_beginning[cur_idx] == 1:
                break
            cur_idx -= 1
        #with_context = f"<start>{with_context}"
        return with_context

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        X = self.get_context(idx)
        encoded_data = self.tokenizer(X, truncation=True, return_tensors="pt", max_length=MAX_LENGTH, padding="max_length")
        return (encoded_data['input_ids'], encoded_data['attention_mask'])#, self.y_encoded[idx])
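
To make the data layout easier to follow, here is a standalone sketch of the format the class above works with: dialogs separated by `#/`, turns separated by `;`, consecutive turns paired as (input, answer), and the training string built with `<sep>` between earlier turns and `<bot>` before the final answer. The function names and the sample text are illustrative assumptions, and the tokenizer step is left out.

```python
from typing import List, Tuple


def parse_dialogs(raw: str) -> List[List[str]]:
    """Split the single-file format: '#/' between dialogs, ';' between turns."""
    return [dialog.split(";") for dialog in raw.split("#/")]


def make_pairs(conversations: List[List[str]]) -> Tuple[List[str], List[str], List[int]]:
    """Pair turn 0 with 1, 2 with 3, ... and flag the first pair of each dialog."""
    X, y, is_beginning = [], [], []
    for conversation in conversations:
        for idx in range(0, len(conversation) - 1, 2):
            X.append(conversation[idx])
            y.append(conversation[idx + 1])
            is_beginning.append(1 if idx == 0 else 0)
    return X, y, is_beginning


def build_context(X: List[str], y: List[str], is_beginning: List[int], idx: int) -> str:
    """Prepend the earlier pairs of the same dialog, mirroring get_context above."""
    text = f"{X[idx]}<bot>{y[idx]}"
    cur = idx
    # Walk back until the first pair of the current dialog has been added.
    while not is_beginning[cur]:
        cur -= 1
        text = f"{X[cur]}<sep>{y[cur]}<sep>{text}"
    return text


# Hypothetical sample in the dialogs.txt format (two dialogs).
raw = "Hi;Hello, how can I help?;I feel tired;Try to rest.#/Good morning;Good morning to you."
X, y, is_beginning = make_pairs(parse_dialogs(raw))
print(build_context(X, y, is_beginning, 1))
print(build_context(X, y, is_beginning, 2))
```

One limitation of this format is visible here: a `;` inside a sentence is treated as a turn separator, which shifts the pairing for the rest of that dialog.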



In [21]:
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.add_special_tokens({  "pad_token": "<pad>",
                                # "bos_token": "<start>",
                                # "eos_token": "<end>",
                                "sep_token": "<sep>"})
tokenizer.add_tokens(["<bot>"])

1

In [22]:
data = Dialog_Data(tokenizer=tokenizer, read_one_file=True, should_save_as_one_file=True)
data = DataLoader(data, batch_size=4, shuffle=True)

In [23]:
# check the data
for i in data.dataset.X[:3]:
    print(i, "\n")

Hey, I've been feeling really down lately. I just can't seem to find any motivation or purpose in my life. 

What was the story about? 

That's really inspiring. But how did he manage to find motivation and purpose in his life? 



In [24]:
data.dataset.get_context(3)

"Hey, I've been feeling really down lately. I just can't seem to find any motivation or purpose in my life.<sep>I completely understand how you feel. I went through a similar phase a while back. I had lost all sense of direction and felt like my life lacked purpose. But then I came across this incredible story that really inspired me.<sep>What was the story about?<sep>It was about a man named Nick Vujicic. He was born without arms and legs, and faced numerous challenges and obstacles throughout his life. Despite all that, he never let his disabilities define him. Instead, he used his setbacks as fuel to achieve incredible things. He became a motivational speaker, inspiring millions of people around the world.<sep>That's really inspiring. But how did he manage to find motivation and purpose in his life?<sep>Well, Nick didn't let his circumstances determine his happiness or success. He believed that true happiness and purpose come from within, and he focused on developing a positive mind

In [25]:
data.dataset.get_context(10)

"Hey, I've been feeling really overwhelmed lately and I think I might have Separation Anxiety Disorder.<bot>Oh, I'm sorry to hear that. What exactly are you experiencing?"

In [26]:
data.dataset.get_context(203)

"Good afternoon, Doctor. Thank you for seeing me today.<sep>Good afternoon. It's my pleasure. How can I assist you?<sep>I've been feeling really overwhelmed lately. I constantly feel anxious and have trouble sleeping.<sep>I'm sorry to hear that. When did you first start noticing these symptoms?<sep>It's been going on for a few months now. It started after a major life event - a job loss and the end of a long-term relationship.<bot>Those are significant stressors indeed. How do you cope with these feelings of anxiety?"

In [27]:
len(data.dataset)

11149

In [28]:
# # Test saved dialogs in one file
# counter = 0
# with open("./dialogs.txt", "r") as f:
#     dialogs = f.read()
# print(f"Dialogs amount: {len(os.listdir('./data'))}")
# print(f"In one file dialogs amount: {len(dialogs.split('#/'))}")

In [29]:
BATCH_AMOUNT = 0
for X, a in data:
    BATCH_AMOUNT += 1
    if BATCH_AMOUNT == 5:
        print("X:")
        print("Decoded:", tokenizer.decode(X[0][0]))
        print("Encoded:", X[0])
        print(type(X[0]))
        print(len(X[0]))
        print("AttentionMask:\n", a[0])
        print(type(a[0]))
        print(len(a[0]))
        # print("Target:\nDecoded:", tokenizer.decode(y[0][0]))
        # print("Encoded:", y[0])
        # print(type(y[0]))
        # print(len(y[0]))

BATCH_AMOUNT

X:
Decoded: <pad><pad><pad> ... [long run of <pad> tokens; output truncated]

2788

### Load pretrained model

In [30]:
# config = transformers.GPT2Config.from_pretrained("gpt2")
# config.do_sample = config.task_specific_params['text-generation']['do_sample']
# config.max_length = MAX_LENGTH #config.task_specific_params['text-generation']['max_length']
model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))
model.eval()

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50260, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50260, bias=False)
)

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='cpu')

In [32]:
model = model.to(device)

### First test

In [48]:
def inference(prompt: str, model, tokenizer, device, padding, clear_output=True):
    """Generate a reply for `prompt`; clear_output=True strips special tokens from the result."""
    model.eval()
    prompt = f"{prompt}<bot>"
    encoded = tokenizer(prompt, return_tensors="pt", padding=padding)
    X = encoded["input_ids"].to(device)
    a = encoded["attention_mask"].to(device)
    # Note: max_length counts the prompt tokens too. With padding="max_length" the
    # prompt is already padded to the tokenizer's maximum length, which leaves
    # little or no room for newly generated tokens.
    with torch.no_grad():
        output = model.generate(X, attention_mask=a, pad_token_id=tokenizer.pad_token_id,
                                do_sample=True, max_length=MAX_LENGTH)

    # Decode the first (and only) sequence; tokenizer.decode returns a single string
    return tokenizer.decode(output[0], skip_special_tokens=clear_output)

In [49]:
inference(prompt="Hey, I'm feeling not so good.", model=model, tokenizer=tokenizer,
                                                        device=device, padding="max_length", clear_output=True)

"Hey, I'm feeling not so good.<bot>"

In [50]:
inference(prompt="Hey, I've been feeling really down lately.", model=model, tokenizer=tokenizer, device=device, padding="max_length", clear_output=False)

"<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

### Fine Tune Model

In [51]:
def print_time_information(start, end=None, total_seconds=None, text="Total training time:", should_print=True):
    if total_seconds is None:
        if end is None:
            end = dt.now()
        total_seconds = abs((start-end).total_seconds())

    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    res = f"{text}\n    -> {int(days)} Days\n    -> {int(hours)} Hours\n    -> {int(minutes)} Minutes\n    -> {int(seconds)} Seconds"
    if should_print:
        print(res)
    return res

def calculate_train_duration(epoch_start, batch_amount, epochs, cur_epoch):
    """
    Call this function once after the first batch every epoch.
    """
    cur_epoch += 1
    now = dt.now()
    duration_one_batch = abs((epoch_start-now).total_seconds())
    duration_for_one_epoch = duration_one_batch * batch_amount
    epochs_left = (epochs - cur_epoch) + 1    # the current epoch still has to run
    predicted_training_duration = epochs_left * duration_for_one_epoch
    res = f"{'-'*16}\n"
    text = f"Training will need about following time for {epochs_left} epochs:"
    res += print_time_information(start=epoch_start, total_seconds=predicted_training_duration, text=text, should_print=False)
    res += f"\n{'-'*16}"
    print(res)
    return res
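
The estimate above extrapolates from the duration of a single batch. As a worked example of that arithmetic, the sketch below uses the 2788 batches counted earlier and the 6 epochs configured for training; the 1.5 seconds per batch is a made-up value for illustration, and both function names are hypothetical.

```python
def estimate_training_seconds(seconds_first_batch, batch_amount, epochs, cur_epoch):
    """Extrapolate the remaining time from one batch (cur_epoch is 0-based)."""
    epochs_left = epochs - cur_epoch  # includes the epoch that is currently running
    return seconds_first_batch * batch_amount * epochs_left


def split_duration(total_seconds):
    """Break a number of seconds into (days, hours, minutes, seconds)."""
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    return int(days), int(hours), int(minutes), int(seconds)


# 1.5 s per batch is an assumed value, not a measurement from this notebook.
total = estimate_training_seconds(1.5, batch_amount=2788, epochs=6, cur_epoch=0)
print(total, split_duration(total))
```

Because the first batch of an epoch often includes warm-up overhead, an estimate like this is only a rough guide.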

In [53]:
optimizer = Adam(model.parameters(), lr=1e-4)
epochs = 6

loss_hist = []
steps = 0

solutions = []
solutions_cleared = []

start = dt.now()

with open("./log.txt", "w") as f:
    f.write(f"Log File for Training Calm Chatbot {start.strftime('%d.%m.%Y - %H:%M:%S')}")

for cur_epoch in range(0, epochs):
    model.train()
    epoch_start = dt.now()
    new_epoch = True
    for input_ids, attention_masks in data:
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        optimizer.zero_grad()
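        # Exclude padding positions from the loss (label -100 is ignored by the loss function)
        labels = input_ids.masked_fill(attention_masks == 0, -100)
        loss = model(input_ids, attention_mask=attention_masks, labels=labels).loss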
        loss_hist += [loss.item()]
        loss.backward()
        optimizer.step()
        steps += 1

        if new_epoch:
            time_prediction = calculate_train_duration(epoch_start, BATCH_AMOUNT, epochs, cur_epoch)
            new_epoch = False

            with open("./log.txt", "a") as f:
                f.write(f"\n\n{time_prediction}")

    torch.save(model.state_dict(), f"./model_state_V4_{cur_epoch}.pt")
    epoch_info = f'Epoch {cur_epoch+1}/{epochs}, Training Loss: {loss.item():.4f}, Steps: {steps}, Current Time:{dt.now().strftime("%H:%M:%S")}'
    print(epoch_info)
    test_prompt = inference(prompt="Hey, I'm feeling not so good.", model=model, tokenizer=tokenizer,
                                                        device=device, padding=True, clear_output=False)
    solutions += [test_prompt]
    solutions_cleared += [inference(prompt="Hey, I'm feeling not so good.", model=model, tokenizer=tokenizer,
                                                        device=device, padding=True, clear_output=True)]

    with open("./log.txt", "a") as f:
        f.write(f"\n\n{epoch_info}\n\nTest-Prompt:\n{test_prompt}")

    # free memory
    #gc.collect()
    # if device.type == "cuda":
    #     torch.cuda.empty_cache()

TypeError: GPT2LMHeadModel.forward() got an unexpected keyword argument 'max_length'

In [None]:
print_time_information(start)

# plot loss
fig, ax = plt.subplots(1, 1, figsize=(16, 8))
ax.plot(np.arange(len(loss_hist)-1), loss_hist[1:], label='Loss')
ax.set_xlabel('Learning progress')
ax.set_ylabel('Loss (cross-entropy)')
ax.set_title('Loss over time')
ax.legend()
ax.grid()

# save step solution-predictions:
with open("./result_per_epoch.txt", "w") as f:
    res = ""
    for i, cur_res in enumerate(solutions, start=1):
        res += f"\n{'-'*16}\n{i:02d}. Epoch:\n{cur_res}"
    f.write(res)

In [None]:
import pickle

with open('./loss_hist.pkl', 'wb') as f:
    pickle.dump(loss_hist, f)

In [None]:
# plot loss
OFFSET = 0
fig, ax = plt.subplots(1, 1, figsize=(16, 8))
ax.plot(np.arange(len(loss_hist)-OFFSET), loss_hist[OFFSET:], label='Loss')
ax.set_xlabel('Learning progress')
ax.set_ylabel('Loss (cross-entropy)')
ax.set_title('Loss over time')
ax.legend()
ax.grid()

In [None]:
import matplotlib.ticker as mticker

OFFSET = 0
loss_series = pd.Series(loss_hist[OFFSET:])

# Apply the rolling window
window_size = 1000  # size of the rolling window
loss_rolling = loss_series.rolling(window=window_size).mean()

# Plot the smoothed data (the original curve is commented out below)
fig, ax = plt.subplots(1, 1, figsize=(16, 8))
#ax.plot(np.arange(len(loss_hist)), loss_hist, label='Original Loss')
ax.plot(np.arange(len(loss_series.index)), loss_rolling, label='Smoothed Loss', linewidth=2)
ax.set_xlabel('Learning progress')
ax.set_ylabel('Loss (cross-entropy)')
ax.set_title('Loss over time')
ax.legend()
ax.grid()
# formatter = mticker.ScalarFormatter()
# formatter.set_scientific(False)
# ax.yaxis.set_major_formatter(formatter)
# ax.set_yscale('log')

plt.savefig("./loss.png")
plt.show()
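
For reference, the smoothing used above can be written out by hand. With its default settings, pandas `rolling(window=n).mean()` yields NaN for the first `n - 1` positions, so with a window of 1000 the smoothed curve only starts at step 1000. The sketch below reproduces that behavior in plain Python on a tiny made-up loss list; `rolling_mean` is a hypothetical helper and uses `None` where pandas would give NaN.

```python
from typing import List, Optional


def rolling_mean(values: List[float], window: int) -> List[Optional[float]]:
    """Trailing moving average; positions without a full window are None."""
    result: List[Optional[float]] = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        if i >= window - 1:
            result.append(running_sum / window)
        else:
            result.append(None)
    return result


losses = [4.0, 3.0, 2.0, 3.0, 1.0]  # made-up values for illustration
smoothed = rolling_mean(losses, window=3)
print(smoothed)
```

A smaller window reacts faster to changes but leaves more noise in the curve; a larger one is smoother but hides the first part of training.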


### Save model

-> Probably save the model in a separate repository/branch and provide it as a Python module<br>
-> Is the model very big?

Save only the weights

In [None]:
os.makedirs(os.path.dirname(MODEL_WEIGHT_PATH), exist_ok=True)  # make sure ./model exists
torch.save(model.state_dict(), MODEL_WEIGHT_PATH)

# loading
# config = transformers.GPT2Config.from_pretrained("gpt2")
# config.max_length = MAX_LENGTH #config.task_specific_params['text-generation']['max_length']
# model = transformers.GPT2LMHeadModel.from_pretrained("gpt2", config=config)
# model.resize_token_embeddings(len(tokenizer))
# model.load_state_dict(torch.load(MODEL_WEIGHT_PATH))
# model.eval()

Save the whole model

In [None]:
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)  # make sure ./model exists
torch.save(model, MODEL_PATH)

# loading
# model = torch.load(MODEL_PATH)
# model.eval()

Save to Google Drive

In [29]:
# from google.colab import drive
# drive.mount('/content/gdrive')

# path = "/content/gdrive/My Drive/model.pt"
# torch.save(model.state_dict(), path)

Mounted at /content/gdrive


---
### Resources:

- https://www.toolify.ai/ai-news/finetuning-gpt2-for-conversational-chatbots-10476
- https://huggingface.co/docs/transformers/model_doc/gpt2
- [PyTorch kompakt](https://www.thalia.de/shop/home/artikeldetails/A1062166688)
- https://pytorch.org/tutorials/beginner/chatbot_tutorial.html
- https://github.com/itsuncheng/fine-tuning-GPT2/tree/master
- https://www.kaggle.com/code/pinooxd/gpt2-chatbot/notebook

<br>

---