We will cover:
- different LLM fine-tuning approaches
- preparing a dataset for text classification
- fine-tuning an LLM for spam message detection
- evaluating the accuracy of a fine-tuned LLM classifier

The two most common ways of fine-tuning are instruction fine-tuning and classification fine-tuning.
We will learn about classification fine-tuning in this chapter (chapter 6).

The downside of a classification fine-tuned model is that it is restricted to predicting the classes it encountered during training.

A classification fine-tuned model can be seen as a specialized model, and it is generally easier to develop a specialized model than a generalist one.

![ft1](../images/ft-1.png)

<h4>Stage 1</h4>

1. Download the dataset, balance it (oversample or undersample), give it a numerical class representation (0 for normal, 1 for spam, similar to converting tokens to token IDs), and split it into training and evaluation subsets.

2. Create data loaders. Previously we used a sliding-window approach to generate uniformly sized text chunks to make batches for efficient model training.
Here too the inputs in a batch must have a uniform length, and we can choose that length in one of two ways:
 - truncate all messages to the length of the shortest message in the dataset
 - pad all messages to the length of the longest message in the dataset

    We go with the second approach so that we don't lose any information, and we pad with the token ID of `<|endoftext|>`.
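The two length strategies above can be compared on a toy example. This is a minimal sketch: the token IDs in `encoded_texts` are made up, and the helper names `truncate_to_shortest` and `pad_to_longest` are hypothetical. The only value taken from the notebook is 50256, the token ID of `<|endoftext|>` in the GPT-2 tokenizer.

```python
# Toy token-ID sequences standing in for three encoded messages (made-up values).
encoded_texts = [[11, 22, 33, 44, 55], [66, 77], [88, 99, 111]]
PAD_TOKEN_ID = 50256  # token ID of <|endoftext|> in the GPT-2 tokenizer


def truncate_to_shortest(sequences):
    """Cut every sequence down to the length of the shortest one."""
    min_len = min(len(seq) for seq in sequences)
    return [seq[:min_len] for seq in sequences]


def pad_to_longest(sequences, pad_token_id):
    """Right-pad every sequence up to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_token_id] * (max_len - len(seq)) for seq in sequences]


truncated = truncate_to_shortest(encoded_texts)
padded = pad_to_longest(encoded_texts, PAD_TOKEN_ID)

print(truncated)  # every row keeps only 2 tokens; the rest of each message is lost
print(padded)     # every row has 5 tokens; no original token is dropped
```

Truncation is cheaper but discards most of the longer messages, which is why padding is the safer default here.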



<h4>Stage 2</h4>

1. Adding a classification head: we replace the original output layer, which maps the hidden representation to a vocabulary of 50,257 tokens, with a smaller output layer that maps to two classes: 0 ("not spam") and 1 ("spam").

    note - When fine-tuning a pretrained language model, you often don’t need to update all layers. Lower layers learn general language features that transfer well across tasks, while higher layers capture task-specific patterns. Fine-tuning only the top (output-side) layers is usually enough to adapt the model, and it’s more computationally efficient than updating the entire network.

2. Our GPT model architecture has 12 repeated transformer blocks. We keep only the output layer, the final LayerNorm, and the last transformer block trainable. The remaining 11 transformer blocks and the embedding layers are kept non-trainable.

3. The last layer now outputs 2 columns (2 dimensions) instead of 50,257 (the vocabulary size), and we only need to work with the last token (the last row). Because of the causal attention mask, the last token in the sequence accumulates the most information, since it is the only token with access to all the previous tokens.

![ft2](../images/ft-2.png)

4. Before implementing the evaluation utilities, we must first convert the model outputs into class label predictions.

![ft3](../images/ft-3.png)

5. Define the loss function, selecting the logits of just the last token as the input to the cross-entropy loss.
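The argument in step 3 about the causal attention mask can be checked with a toy mask. The sketch below uses plain Python lists rather than the model's actual attention code, and `num_tokens = 4` is an arbitrary choice for illustration.

```python
num_tokens = 4  # e.g. the four tokens of a short message

# causal_mask[i][j] is 1 when position i may attend to position j (j <= i).
causal_mask = [
    [1 if j <= i else 0 for j in range(num_tokens)]
    for i in range(num_tokens)
]

for row in causal_mask:
    print(row)

# Number of tokens each position can see (including itself).
visible_counts = [sum(row) for row in causal_mask]
print(visible_counts)

# Positions that can see the entire sequence: only the last one.
positions_seeing_everything = [
    i for i, count in enumerate(visible_counts) if count == num_tokens
]
print(positions_seeing_everything)
```

Only the final row of the mask is all ones, which is why the classifier reads its logits from the last token position.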

<h4>Stage 3</h4>

![ft8](../images/ft-8.png)

<h3>Code</h3>

In [54]:
from importlib.metadata import version

pkgs = ["matplotlib",  # Plotting library
        "numpy",       # PyTorch & TensorFlow dependency
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tensorflow",  # For OpenAI's pretrained weights
        "pandas"       # Dataset loading
       ]
for p in pkgs:
    print(f"{p} version: {version(p)}")


matplotlib version: 3.10.8
numpy version: 2.4.1
tiktoken version: 0.12.0
torch version: 2.9.1
tensorflow version: 2.20.0
pandas version: 3.0.0


<h4>Preparing dataset and creating dataloaders</h4>

In [55]:
import requests, zipfile, os
from pathlib import Path

url = "https://www.kaggle.com/api/v1/datasets/download/danofer/sarcasm"
zip_path = "sarcasm.zip"
extracted_path = "sarcasm"
data_file_path = Path(extracted_path) / "train-balanced-sarcasm.csv"

def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        return

    # downloading the file
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with open(zip_path, "wb") as out_file:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                out_file.write(chunk)

    # unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # add the .csv file extension if the file was extracted without one
    original_file_path = Path(extracted_path) / "train-balanced-sarcasm"
    if original_file_path.exists():
        os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


try:
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (requests.exceptions.RequestException, TimeoutError) as e:
    print(f"Primary URL failed: {e}.")
    

In [56]:
import pandas as pd
df = pd.read_csv(data_file_path, sep=",")
df_selected = df[["label", "comment", "parent_comment"]]
df_new = pd.DataFrame()
df_new["label"] = df_selected["label"]
df_new["text"] = df_selected["parent_comment"]+" "+df_selected["comment"]
pd.set_option('display.max_colwidth', None)
df_new.sample(5) 

Unnamed: 0,label,text
245000,0,"I've worked for Bernie. I've donated to Bernie. When the primary finally reaches CA, I'll vote for Bernie. I would not vote for him as an independent. I have no interest in splitting the leftist vote and getting a Trump presidency. You're gonna get one if that vile woman is on the ballot, like it or not."
39624,0,"So you name one thing? That's ""dumbing"" it down? *what is inside the team site* hasn't changed. In fact, you are getting more functionality than you could before... but I guess since you can't rename the URL, the whole application got worse. 2nd - How do I delete the site?"
373963,0,*cross fingers* Please be red herring Please be misleading Please be a punishment for decompilers Open box in a nutshell
211118,0,"The most frustrating conversation I've ever had An middle-aged sister who I used to be kind of close to when I was in mentioned to my aunt a couple months ago that she'd like to see me again. I was kinda excited because I had the impression she was genuinely curious why I had left. So when she got back from her trip recently, we made plans. We met for coffee just yesterday and it went terribly. Basically the conversation went downhill when I stupidly mentioned I was living with my boyfriend now and she's says ""Oh... Well you know I can't really condone that..."" For the rest of the conversation, she tried to get me to admit that the only reason I left was because I felt I had been wronged by God and that I can only be happy if I come back. I tried to say tell her that I left because of my personal research and that I was very happy now that I'm no longer a JW. So when she realized she wasn't getting to me, she tried guilt tripping me. She said stuff like ""you're really hurting me by saying such disrespectful stuff about Jehovah and not recognizing the wonderful gifts he's given you"". Finally I got tired of her trying to claim I was ""hurting"" her so I ended the conversation by rolling my eyes and sarcastically saying ""Right, because you're the real victim here aren't you?"" And walked away. ""I can't condone that"" ""Since when do I need you to condone my life choices?"""
664996,0,The church handbook of instructions seems to think so. Spiritual polygamy is still (to a degree) being practiced. Its just temporal polygamy that is not allowed. Exactly


In [57]:
print(df_new.columns)

Index(['label', 'text'], dtype='str')


In [58]:
print(df_new["label"].value_counts())

label
0    505413
1    505413
Name: count, dtype: int64


In [59]:
df_reduced = (
    df_new
    .groupby("label")
    .apply(lambda x: x.sample(frac=0.01, random_state=42))
    .reset_index(drop=False)
)
df_reduced["label"].value_counts()


label
0    5054
1    5054
Name: count, dtype: int64
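Sampling the same fraction from each class keeps the reduced dataset balanced. The idea behind the `groupby(...).apply(... sample ...)` call can be sketched with the standard library only. The rows below are synthetic and `sample_per_class` is a hypothetical helper, not part of pandas.

```python
import random
from collections import defaultdict

# Synthetic (label, text) rows: 500 per class. The real data has 505,413 per class.
rows = [(i % 2, f"comment {i}") for i in range(1000)]


def sample_per_class(rows, frac, seed=42):
    """Return a random `frac` of the rows of every label, sampled independently."""
    by_label = defaultdict(list)
    for label, text in rows:
        by_label[label].append((label, text))

    rng = random.Random(seed)
    sampled = []
    for label in sorted(by_label):
        group = by_label[label]
        k = round(len(group) * frac)  # number of rows to keep for this label
        sampled.extend(rng.sample(group, k))
    return sampled


reduced = sample_per_class(rows, frac=0.1)
counts = {label: sum(1 for l, _ in reduced if l == label) for label in (0, 1)}
print(counts)
```

Because each class is sampled separately, the class ratio of the original data carries over to the reduced set.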

In [60]:
df_reduced

Unnamed: 0,label,level_1,text
0,0,772932,"I love Sanders people. When 60% say wealth distribution in this country is unfair, or 50% say *Citizens United* is bad, it's all ""OMG see, the people are on our side."" When 53% say they won't vote for a socialist, ""most"" of them are just ignorant people who don't know better. Pick a direction and stick with it, are polls to be relied on, or are the American people ignorant? Yes"
1,0,445608,Pawsitively relaxed Is there a subreddit for innocent thumbnails that look NSFW?
2,0,354910,Probably skiddies that are too afraid/unskilled (most likely a combination of both) to take on a bigger target while they're on spring break. The truth sadly.
3,0,284811,"huge geography freak here as well, SO many major differences there have been for myself, since you're talking about countries that you don't recognize, can you look at Africa again and tell me if you recognize any of these: Mauritania, Western Sahara, Burkina Faso, Eritrea and then Moldova in Europe. Thank you. Yes, all of those countries I remember."
4,0,516997,"Hearthstone? That's not an actually competitive game. They play BO3 despite it being so so luck focused. MTG doesn't have as many problems with RNG like HS does. In Poker they play hundreds of hands to decide who is the better player. The more they play the more the RNG matters less, and skill becomes more apparent. erm.. ""the more u play the less RNG matters"" .. So then you cant attribute critical strike being in the game to you being silver elo?"
...,...,...,...
10103,1,547967,"As amazing an experience this had to have been, this is not good at all for the animal Leaving the orphaned injured cheetah to die alone is much better for the animal"
10104,1,498424,"This may just be me, but Sanders supporters seem to have a penchant for sexual metaphors. I don't particularly care for it. But its valid, because obviously she only got where she is because woman ,duh."
10105,1,505928,"Thanks big tobacco! I better stick with cigs, at least they've been PROVEN to kill me! the real meaning behind it, is to stop overpopulation"
10106,1,427952,"SO, NASA Got Sick of all that Conspiracy Thing and Released over 10,000 Photos from the Apollo Moon Mission now that they can photoshop it well enough do do doooo"


In [61]:
def random_split(df, train_frac, validation_frac):
    # shuffling first
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    train_end = int(len(df)*train_frac)
    validation_end = train_end + int(len(df)* validation_frac)
    
    train_df = df[:train_end]
    validation_df = df[train_end: validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df

train_df, validation_df, test_df = random_split(df_reduced, 0.7, 0.1)

train_df.to_csv("../data/sarcasm/train.csv", index=None)
validation_df.to_csv("../data/sarcasm/validation.csv", index=None)
test_df.to_csv("../data/sarcasm/test.csv", index=None)

Dataloader part: pad all messages to a common length (the longest sequence in the dataset, or a given `max_length`). Since we are training GPT-2, we use its `<|endoftext|>` token as the padding token.

In [62]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})

[50256]

In [63]:
import torch  # needed by __getitem__ below
from torch.utils.data import Dataset

class SarcasmDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Convert text to string and handle NaN values
        self.encoded_texts = [
            tokenizer.encode(str(text)) if pd.notna(text) else tokenizer.encode("") 
            for text in self.data["text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length

            # we will pad, but we also need to truncate the long ones to the given max length (future compatibility, if given)
            self.encoded_texts = [encoded_text[:self.max_length] for encoded_text in self.encoded_texts]

        # now the main padding thing
        self.encoded_texts = [encoded_text + [pad_token_id]* (self.max_length - len(encoded_text)) for encoded_text in self.encoded_texts]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for text in self.encoded_texts:
            length = len(text)
            if length>max_length:
                max_length = length
        return max_length

In [64]:
train_dataset = SarcasmDataset(
    csv_file="../data/sarcasm/train.csv",
    max_length=1024,
    tokenizer=tokenizer
)

print(train_dataset.max_length)

1024


In [65]:
val_dataset = SarcasmDataset(
    csv_file="../data/sarcasm/validation.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
test_dataset = SarcasmDataset(
    csv_file="../data/sarcasm/test.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)

In [66]:
import torch
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

torch.manual_seed(123)

train_loader = DataLoader(
    dataset = train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    drop_last = True,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

In [67]:
print("Train loader:")
for input_batch, target_batch in train_loader:
    pass

print("Input batch dimensions:", input_batch.shape)
print("Label batch dimensions", target_batch.shape)

Train loader:
Input batch dimensions: torch.Size([8, 1024])
Label batch dimensions torch.Size([8])


In [68]:
print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches")

884 training batches
127 validation batches
253 test batches
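The batch counts above follow directly from the split fractions and the `drop_last` settings. The arithmetic below is a worked check using the numbers already shown in this notebook (10,108 rows, fractions 0.7 and 0.1, batch size 8). The variable names are chosen for illustration only.

```python
import math

num_rows = 10108  # 5054 rows per class after downsampling
train_frac, validation_frac = 0.7, 0.1
batch_size = 8

# Same integer arithmetic as random_split: the test set gets whatever remains.
train_size = int(num_rows * train_frac)
validation_size = int(num_rows * validation_frac)
test_size = num_rows - train_size - validation_size

# drop_last=True discards the final incomplete batch; drop_last=False keeps it.
train_batches = train_size // batch_size
validation_batches = math.ceil(validation_size / batch_size)
test_batches = math.ceil(test_size / batch_size)

print(train_size, validation_size, test_size)             # 7075 1010 2023
print(train_batches, validation_batches, test_batches)    # 884 127 253
```

With `drop_last=True`, the last 3 training examples (7075 - 884 * 8) are skipped in each epoch.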


<h4>Bringing in our GPT-2 architecture and loading the GPT-2 weights into it</h4>

In [69]:
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "number_of_layers": 12, "num_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "number_of_layers": 24, "num_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "number_of_layers": 36, "num_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "number_of_layers": 48, "num_heads": 25},
}

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
    f"Dataset length {train_dataset.max_length} exceeds model's context "
    f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
    f"`max_length={BASE_CONFIG['context_length']}`"
)

In [70]:
from modules import gpt2_download, gpt, weight_loader
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = gpt2_download.download_and_load_gpt2(model_size=model_size, models_dir="../models/gpt2")

model = gpt.GPTModel(BASE_CONFIG)
weight_loader.load_weights_into_gpt(model, params)
model.eval()

File already exists and is up-to-date: gpt2/124M/checkpoint
File already exists and is up-to-date: gpt2/124M/encoder.json
File already exists and is up-to-date: gpt2/124M/hparams.json
File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/124M/model.ckpt.index
File already exists and is up-to-date: gpt2/124M/model.ckpt.meta
File already exists and is up-to-date: gpt2/124M/vocab.bpe


GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (transformerlayers): Sequential(
    (0): TransformerBlock(
      (attention): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (feedforward): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (attention): MultiHeadAttention(
        (W_

In [71]:
from modules import generation, data

text_1 = "Every effort moves you"

token_ids = generation.generate_text_simple(
    model=model,
    idx=data.text_to_token_ids(text_1, tokenizer),
    max_new_tokens=15,
    context_size=BASE_CONFIG["context_length"]
)

print(data.token_ids_to_text(token_ids, tokenizer))


Every effort moves you forward.

The first step is to understand the importance of your work


In [72]:
text_2 = (
    "Is the following text 'sarcastic'? Answer with 'yes' or 'no':"
    " 'You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award.'"
)

token_ids = generation.generate_text_simple(
    model=model,
    idx=data.text_to_token_ids(text_2, tokenizer),
    max_new_tokens=23,
    context_size=BASE_CONFIG["context_length"]
)

print(data.token_ids_to_text(token_ids, tokenizer))

Is the following text 'sarcastic'? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'

'You have been specially selected to receive $1000 cash or a $2000 award.' 'You have been


<h4>Replacing the last layer with a classification layer</h4>

In [73]:
print(model)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (transformerlayers): Sequential(
    (0): TransformerBlock(
      (attention): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (feedforward): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (attention): MultiHeadAttention(
        (W_

The goal is to replace the final output layer and fine-tune it together with the last transformer block and the final LayerNorm, so let's freeze the whole model first.

In [74]:
for param in model.parameters():
    param.requires_grad = False

In [75]:
torch.manual_seed(123)

num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)

In [76]:
for param in model.transformerlayers[-1].parameters():
    param.requires_grad = True

for param in model.finalnorm.parameters():
    param.requires_grad = True
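A quick way to confirm which parts are trainable after freezing is to list the parameters with `requires_grad=True`. The sketch below applies the same freeze, replace, and unfreeze sequence to `TinyModel`, a hypothetical stand-in with made-up sizes, so that it runs without the GPT-2 weights. The attribute names `blocks` and `final_norm` are assumptions of this sketch, not the names used in `modules.gpt`.

```python
import torch

torch.manual_seed(123)


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in: embedding, 3 small 'blocks', final norm, output head."""

    def __init__(self, emb_dim=8, vocab_size=20, num_blocks=3):
        super().__init__()
        self.tok_emb = torch.nn.Embedding(vocab_size, emb_dim)
        self.blocks = torch.nn.Sequential(
            *[torch.nn.Linear(emb_dim, emb_dim) for _ in range(num_blocks)]
        )
        self.final_norm = torch.nn.LayerNorm(emb_dim)
        self.out_head = torch.nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, x):
        return self.out_head(self.final_norm(self.blocks(self.tok_emb(x))))


tiny = TinyModel()

# 1) Freeze everything.
for param in tiny.parameters():
    param.requires_grad = False

# 2) Replace the head; a freshly created layer is trainable by default.
tiny.out_head = torch.nn.Linear(8, 2)

# 3) Unfreeze the last block and the final norm.
for param in tiny.blocks[-1].parameters():
    param.requires_grad = True
for param in tiny.final_norm.parameters():
    param.requires_grad = True

trainable = [name for name, p in tiny.named_parameters() if p.requires_grad]
num_trainable = sum(p.numel() for p in tiny.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in tiny.parameters())

print(trainable)
print(f"{num_trainable} of {num_total} parameters are trainable")
```

Running the same two summary lines on the real `model` would show how small the trainable fraction is compared with the full network.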

In [77]:
# lets try to use it
inputs = tokenizer.encode("do you have time")
inputs = torch.tensor(inputs).unsqueeze(0)
print("Inputs:", inputs)
print("Inputs dimensions:", inputs.shape) # shape: (batch_size, num_tokens)

with torch.no_grad():
    outputs = model(inputs)

print("Outputs:\n", outputs)
print("Outputs dimensions:", outputs.shape) # shape: (batch_size, num_tokens, num_classes)

Inputs: tensor([[4598,  345,  423,  640]])
Inputs dimensions: torch.Size([1, 4])
Outputs:
 tensor([[[-1.5465,  0.9756],
         [-3.7792,  7.4043],
         [-2.3096,  6.4195],
         [-3.6631,  4.2262]]])
Outputs dimensions: torch.Size([1, 4, 2])


We are mainly interested in the last token, since it is the only token that contains information about all the other tokens (its representation is enriched with meaning from the whole text).

In [78]:
print("Last output token:", outputs[:, -1, :])

Last output token: tensor([[-3.6631,  4.2262]])
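To turn the last-token logits into a class label, take the index of the largest logit. Softmax is only needed if probabilities are wanted, because it does not change which entry is largest. The sketch below does this by hand with the standard library, using the two logit values printed above. The `softmax` helper is written here for illustration.

```python
import math

last_token_logits = [-3.6631, 4.2262]  # copied from the output above


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = softmax(last_token_logits)

# argmax over the raw logits; applying softmax first would give the same index.
predicted_label = max(range(len(last_token_logits)), key=lambda i: last_token_logits[i])

print(probabilities)
print(predicted_label)  # 1
```

Note that this prediction comes from an untrained classification head, so the label itself carries no meaning yet.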


<h4>Defining the accuracy and loss utilities that will be used when training starts</h4>

In [79]:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :] #logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()

        else:
            break

    return correct_predictions/num_examples

In [80]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # Use PyTorch 2.9 or newer for stable mps results
    major, minor = map(int, torch.__version__.split(".")[:2])
    if (major, minor) >= (2, 9):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
else:
    device = torch.device("cpu")

print("Device:", device)

model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes

torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader

train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Device: mps
Training accuracy: 43.75%
Validation accuracy: 52.50%
Test accuracy: 50.00%


In [81]:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
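For intuition about the loss values, the cross-entropy of a single example can be computed by hand as the negative log of the softmax probability assigned to the true class. `torch.nn.functional.cross_entropy` averages this quantity over the batch by default. The sketch below uses only the standard library, and the logits are made-up values.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example: -log(softmax(logits)[target])."""
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return log_sum_exp - logits[target]


chance_loss = cross_entropy([0.0, 0.0], target=0)       # a 50/50 guess
confident_right = cross_entropy([-3.0, 3.0], target=1)  # confident and correct
confident_wrong = cross_entropy([-3.0, 3.0], target=0)  # confident and wrong

print(f"{chance_loss:.4f}")  # 0.6931, i.e. ln 2
print(f"{confident_right:.4f}")
print(f"{confident_wrong:.4f}")
```

For two balanced classes, ln 2 ≈ 0.693 is a useful reference point: a loss well above it means the model is confidently wrong on many examples, not merely guessing.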


In [82]:
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


In [83]:
with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
    test_loss = calc_loss_loader(test_loader, model, device, num_batches=5)

print(f"Training loss: {train_loss:.3f}")
print(f"Validation loss: {val_loss:.3f}")
print(f"Test loss: {test_loss:.3f}")

Training loss: 4.115
Validation loss: 4.790
Test loss: 4.573


In [None]:
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # main training loop
    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            examples_seen += input_batch.shape[0]
            global_step += 1

            # copying this optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        
        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
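        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    # return the tracked metrics so the caller can unpack and plot them
    return train_losses, val_losses, train_accs, val_accs, examples_seen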

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad(): # disables gradient, saves time and memory
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

In [None]:
import time

start_time = time.time()

torch.manual_seed(123)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

num_epochs = 1
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

In [86]:
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Training accuracy: 50.99%
Validation accuracy: 50.79%
Test accuracy: 50.17%


<h4>Running inference on one example to see how the model behaves</h4>

In [89]:
def is_this_sarcasm_101(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval() # dropout off; batchnorm (if any) stops updating its statistics

    # encode the input text, truncate it to max_length (capped at the model's context length), pad it, convert it to a tensor with a batch dimension, run the model under no_grad(), take the index of the largest logit, and return "funny" (sarcastic) if that index is 1

    input_ids = tokenizer.encode(text)
    supported_context_length_for_this_model = model.pos_emb.weight.shape[0]
    # fall back to the model's context length when max_length is not given
    if max_length is None:
        max_length = supported_context_length_for_this_model
    max_length = min(max_length, supported_context_length_for_this_model)
    input_ids = input_ids[:max_length]
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]
    predicted_label = torch.argmax(logits, dim=-1).item()

    return "funny" if predicted_label == 1 else "not so funny"


In [95]:
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)
text_2 = (
    "whatevver you say"
    " for dinner tonight? Let me know!"
)

print(is_this_sarcasm_101(
    text_2, model, tokenizer, device, max_length=train_dataset.max_length
))

funny


In [None]:
torch.save(model.state_dict(), "sarcasm_classifier.pth")
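A saved state dict contains only tensors, so reloading it requires rebuilding the same architecture first. The round trip is sketched below with a small `torch.nn.Linear` as a hypothetical stand-in for the classifier and a temporary file instead of `sarcasm_classifier.pth`.

```python
import os
import tempfile

import torch

torch.manual_seed(123)

# Hypothetical stand-in for the fine-tuned classifier.
trained = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "classifier_demo.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it.
    restored = torch.nn.Linear(4, 2)
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state_dict)
    restored.eval()

# Both modules should now produce identical outputs.
x = torch.ones(1, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
print(same_output)
```

For the model in this notebook, the rebuilt network would also need its `out_head` replaced by the two-class linear layer before calling `load_state_dict`, otherwise the tensor shapes would not match the saved ones.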