In [1]:
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import cv2

# Load annotations
def load_annotations(json_file):
    with open(json_file, 'r') as f:
        annotation_data = json.load(f)
    image_id_to_file = {img['id']: img['file_name'] for img in annotation_data['images']}
    annotations = {image_id_to_file[ann['image_id']]: ann['category_id'] for ann in annotation_data['annotations']}
    return annotations

# Data generator function
def data_generator(directory, annotations, batch_size=32, target_size=(224, 224)):
    files = list(annotations.keys())
    num_files = len(files)
    while True:
        np.random.shuffle(files)
        for offset in range(0, num_files, batch_size):
            batch_files = files[offset:offset+batch_size]
            images = []
            labels = []
            for file in batch_files:
                image_path = os.path.join(directory, file)
                image = cv2.imread(image_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = cv2.resize(image, target_size)
                images.append(image)
                labels.append(annotations[file])
            yield np.array(images) / 255.0, np.array(labels)

# Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, directory, annotations, transform=None):
        self.directory = directory
        self.annotations = annotations
        self.transform = transform
        self.files = list(annotations.keys())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]
        image_path = os.path.join(self.directory, file)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (224, 224))
        label = self.annotations[file]

        if self.transform:
            image = self.transform(image)

        return {"pixel_values": torch.tensor(image).permute(2, 0, 1).float(), "label": torch.tensor(label).long()}

# Load annotations
train_annotations = load_annotations(r'dataset\train\_annotations.coco.json')
val_annotations = load_annotations(r'dataset\valid\_annotations.coco.json')

# Create datasets
train_dataset = CustomDataset(r'dataset\train', train_annotations)
val_dataset = CustomDataset(r'dataset\valid', val_annotations)

# Load the pre-trained ViT model and feature extractor
model_name = "google/vit-base-patch16-224"
num_classes = 4

# Ignore size mismatches for the classifier layer
model = ViTForImageClassification.from_pretrained(model_name, num_labels=num_classes, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Define compute metrics function
def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

# Use mixed precision training
training_args = TrainingArguments(
    per_device_train_batch_size=16,  # Increase batch size for faster training
    evaluation_strategy="steps",
    num_train_epochs=10,  # Increase number of epochs
    learning_rate=5e-5,  # Adjust learning rate
    weight_decay=0.01,
    save_steps=500,
    save_total_limit=2,
    output_dir="./vit-finetuned",
    logging_dir="./vit-finetuned/logs",
    fp16=True,  # Enable mixed precision training
    logging_steps=10,
    eval_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./vit-finetuned")


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/2360 [00:00<?, ?it/s]

{'loss': 1.3849, 'grad_norm': 4.339969635009766, 'learning_rate': 4.9830508474576276e-05, 'epoch': 0.04}
{'loss': 1.2616, 'grad_norm': 3.3621714115142822, 'learning_rate': 4.961864406779661e-05, 'epoch': 0.08}
{'loss': 1.2247, 'grad_norm': 5.570422172546387, 'learning_rate': 4.940677966101695e-05, 'epoch': 0.13}
{'loss': 1.0638, 'grad_norm': 4.337039947509766, 'learning_rate': 4.919491525423729e-05, 'epoch': 0.17}
{'loss': 1.1508, 'grad_norm': 6.269214630126953, 'learning_rate': 4.898305084745763e-05, 'epoch': 0.21}
{'loss': 1.036, 'grad_norm': 6.752583980560303, 'learning_rate': 4.877118644067797e-05, 'epoch': 0.25}
{'loss': 1.0256, 'grad_norm': 4.5840678215026855, 'learning_rate': 4.855932203389831e-05, 'epoch': 0.3}
{'loss': 0.9327, 'grad_norm': nan, 'learning_rate': 4.8368644067796615e-05, 'epoch': 0.34}
{'loss': 1.0286, 'grad_norm': 6.176445960998535, 'learning_rate': 4.815677966101695e-05, 'epoch': 0.38}
{'loss': 1.0096, 'grad_norm': 5.513400077819824, 'learning_rate': 4.79449152

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.2564431428909302, 'eval_accuracy': 0.4718232044198895, 'eval_precision': 0.5423971270575607, 'eval_recall': 0.4718232044198895, 'eval_f1': 0.4314397442986558, 'eval_runtime': 15.3002, 'eval_samples_per_second': 59.15, 'eval_steps_per_second': 7.451, 'epoch': 0.42}
{'loss': 0.9487, 'grad_norm': 5.125683307647705, 'learning_rate': 4.7733050847457624e-05, 'epoch': 0.47}
{'loss': 0.9688, 'grad_norm': 5.3657755851745605, 'learning_rate': 4.752118644067797e-05, 'epoch': 0.51}
{'loss': 1.023, 'grad_norm': 10.060126304626465, 'learning_rate': 4.7309322033898304e-05, 'epoch': 0.55}
{'loss': 0.9265, 'grad_norm': 4.157229900360107, 'learning_rate': 4.709745762711865e-05, 'epoch': 0.59}
{'loss': 0.9891, 'grad_norm': 5.18560266494751, 'learning_rate': 4.6885593220338983e-05, 'epoch': 0.64}
{'loss': 0.9145, 'grad_norm': 6.376780033111572, 'learning_rate': 4.667372881355933e-05, 'epoch': 0.68}
{'loss': 0.9515, 'grad_norm': 7.651336193084717, 'learning_rate': 4.646186440677966e-05, 'ep

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0619277954101562, 'eval_accuracy': 0.5480662983425414, 'eval_precision': 0.6101329471032585, 'eval_recall': 0.5480662983425414, 'eval_f1': 0.5135286365445679, 'eval_runtime': 9.0708, 'eval_samples_per_second': 99.77, 'eval_steps_per_second': 12.568, 'epoch': 0.85}
{'loss': 0.8402, 'grad_norm': 4.520654678344727, 'learning_rate': 4.5614406779661015e-05, 'epoch': 0.89}
{'loss': 0.8754, 'grad_norm': 8.836793899536133, 'learning_rate': 4.540254237288136e-05, 'epoch': 0.93}
{'loss': 0.9129, 'grad_norm': 5.590432643890381, 'learning_rate': 4.5190677966101695e-05, 'epoch': 0.97}
{'loss': 0.7367, 'grad_norm': 6.506397247314453, 'learning_rate': 4.497881355932204e-05, 'epoch': 1.02}
{'loss': 0.8291, 'grad_norm': 6.117615222930908, 'learning_rate': 4.4766949152542374e-05, 'epoch': 1.06}
{'loss': 0.7453, 'grad_norm': 6.453190326690674, 'learning_rate': 4.455508474576272e-05, 'epoch': 1.1}
{'loss': 0.7257, 'grad_norm': 13.946218490600586, 'learning_rate': 4.4343220338983054e-05, 'e

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0600569248199463, 'eval_accuracy': 0.5513812154696133, 'eval_precision': 0.5863328544272923, 'eval_recall': 0.5513812154696133, 'eval_f1': 0.5464039923983277, 'eval_runtime': 7.0888, 'eval_samples_per_second': 127.665, 'eval_steps_per_second': 16.082, 'epoch': 1.27}
{'loss': 0.7579, 'grad_norm': 12.604698181152344, 'learning_rate': 4.3495762711864406e-05, 'epoch': 1.31}
{'loss': 0.7857, 'grad_norm': 5.4248576164245605, 'learning_rate': 4.328389830508475e-05, 'epoch': 1.36}
{'loss': 0.7029, 'grad_norm': 9.600992202758789, 'learning_rate': 4.3072033898305085e-05, 'epoch': 1.4}
{'loss': 0.7687, 'grad_norm': 9.32540225982666, 'learning_rate': 4.286016949152543e-05, 'epoch': 1.44}
{'loss': 0.6906, 'grad_norm': 6.682056903839111, 'learning_rate': 4.2648305084745765e-05, 'epoch': 1.48}
{'loss': 0.8408, 'grad_norm': 8.464341163635254, 'learning_rate': 4.243644067796611e-05, 'epoch': 1.53}
{'loss': 0.8459, 'grad_norm': 8.105401992797852, 'learning_rate': 4.2224576271186444e-05, 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.1535677909851074, 'eval_accuracy': 0.5834254143646409, 'eval_precision': 0.6521873813705926, 'eval_recall': 0.5834254143646409, 'eval_f1': 0.5465822853788804, 'eval_runtime': 6.9718, 'eval_samples_per_second': 129.808, 'eval_steps_per_second': 16.352, 'epoch': 1.69}
{'loss': 0.6811, 'grad_norm': 8.741164207458496, 'learning_rate': 4.1377118644067797e-05, 'epoch': 1.74}
{'loss': 0.7047, 'grad_norm': 8.004161834716797, 'learning_rate': 4.116525423728814e-05, 'epoch': 1.78}
{'loss': 1.0156, 'grad_norm': 8.333066940307617, 'learning_rate': 4.0953389830508476e-05, 'epoch': 1.82}
{'loss': 0.6671, 'grad_norm': 8.817276000976562, 'learning_rate': 4.074152542372881e-05, 'epoch': 1.86}
{'loss': 0.6319, 'grad_norm': 6.339125633239746, 'learning_rate': 4.0529661016949156e-05, 'epoch': 1.91}
{'loss': 0.6844, 'grad_norm': 7.746669292449951, 'learning_rate': 4.031779661016949e-05, 'epoch': 1.95}
{'loss': 0.7559, 'grad_norm': 11.195008277893066, 'learning_rate': 4.0105932203389835e-05,

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 0.9567936062812805, 'eval_accuracy': 0.611049723756906, 'eval_precision': 0.6288571902587222, 'eval_recall': 0.611049723756906, 'eval_f1': 0.6169281189774889, 'eval_runtime': 7.0013, 'eval_samples_per_second': 129.261, 'eval_steps_per_second': 16.283, 'epoch': 2.12}
{'loss': 0.5218, 'grad_norm': 4.636612892150879, 'learning_rate': 3.925847457627119e-05, 'epoch': 2.16}
{'loss': 0.6288, 'grad_norm': 9.713019371032715, 'learning_rate': 3.9046610169491524e-05, 'epoch': 2.2}
{'loss': 0.5341, 'grad_norm': 6.354191303253174, 'learning_rate': 3.883474576271187e-05, 'epoch': 2.25}
{'loss': 0.3838, 'grad_norm': 3.3089637756347656, 'learning_rate': 3.86228813559322e-05, 'epoch': 2.29}
{'loss': 0.5023, 'grad_norm': 5.590969562530518, 'learning_rate': 3.8411016949152546e-05, 'epoch': 2.33}
{'loss': 0.6601, 'grad_norm': 8.844505310058594, 'learning_rate': 3.819915254237288e-05, 'epoch': 2.37}
{'loss': 0.4983, 'grad_norm': 6.218653202056885, 'learning_rate': 3.7987288135593226e-05, 'epo

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 0.9188376665115356, 'eval_accuracy': 0.6198895027624309, 'eval_precision': 0.6177910743219333, 'eval_recall': 0.6198895027624309, 'eval_f1': 0.6182594011193133, 'eval_runtime': 7.0295, 'eval_samples_per_second': 128.744, 'eval_steps_per_second': 16.217, 'epoch': 2.54}
{'loss': 0.4045, 'grad_norm': 6.786558628082275, 'learning_rate': 3.713983050847458e-05, 'epoch': 2.58}
{'loss': 0.5705, 'grad_norm': 11.156797409057617, 'learning_rate': 3.6927966101694914e-05, 'epoch': 2.63}
{'loss': 0.6019, 'grad_norm': 3.9266834259033203, 'learning_rate': 3.671610169491526e-05, 'epoch': 2.67}
{'loss': 0.6011, 'grad_norm': 13.746156692504883, 'learning_rate': 3.6504237288135594e-05, 'epoch': 2.71}
{'loss': 0.5526, 'grad_norm': 6.688516616821289, 'learning_rate': 3.629237288135594e-05, 'epoch': 2.75}
{'loss': 0.5071, 'grad_norm': 5.1260833740234375, 'learning_rate': 3.608050847457627e-05, 'epoch': 2.8}
{'loss': 0.4592, 'grad_norm': 10.460002899169922, 'learning_rate': 3.5868644067796617e-0

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 0.9100421667098999, 'eval_accuracy': 0.6607734806629835, 'eval_precision': 0.6751963596859927, 'eval_recall': 0.6607734806629835, 'eval_f1': 0.6637408598298367, 'eval_runtime': 7.02, 'eval_samples_per_second': 128.917, 'eval_steps_per_second': 16.239, 'epoch': 2.97}
{'loss': 0.7037, 'grad_norm': 7.908899307250977, 'learning_rate': 3.502118644067797e-05, 'epoch': 3.01}
{'loss': 0.4209, 'grad_norm': 7.343352317810059, 'learning_rate': 3.4809322033898305e-05, 'epoch': 3.05}
{'loss': 0.2332, 'grad_norm': 3.7378792762756348, 'learning_rate': 3.459745762711865e-05, 'epoch': 3.09}
{'loss': 0.2802, 'grad_norm': 7.202507495880127, 'learning_rate': 3.4385593220338985e-05, 'epoch': 3.14}
{'loss': 0.2804, 'grad_norm': 4.644125461578369, 'learning_rate': 3.417372881355933e-05, 'epoch': 3.18}
{'loss': 0.2718, 'grad_norm': 2.566854238510132, 'learning_rate': 3.3961864406779664e-05, 'epoch': 3.22}
{'loss': 0.2774, 'grad_norm': 7.240749359130859, 'learning_rate': 3.375000000000001e-05, 'e

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0821460485458374, 'eval_accuracy': 0.5911602209944752, 'eval_precision': 0.6218296268508681, 'eval_recall': 0.5911602209944752, 'eval_f1': 0.590409400168999, 'eval_runtime': 7.0457, 'eval_samples_per_second': 128.446, 'eval_steps_per_second': 16.18, 'epoch': 3.39}
{'loss': 0.3621, 'grad_norm': 4.669732570648193, 'learning_rate': 3.290254237288136e-05, 'epoch': 3.43}
{'loss': 0.2912, 'grad_norm': 6.292046546936035, 'learning_rate': 3.2690677966101696e-05, 'epoch': 3.47}
{'loss': 0.3139, 'grad_norm': 9.630937576293945, 'learning_rate': 3.247881355932203e-05, 'epoch': 3.52}
{'loss': 0.279, 'grad_norm': 3.6577982902526855, 'learning_rate': 3.2266949152542375e-05, 'epoch': 3.56}
{'loss': 0.3465, 'grad_norm': 5.74441385269165, 'learning_rate': 3.205508474576271e-05, 'epoch': 3.6}
{'loss': 0.386, 'grad_norm': 18.881990432739258, 'learning_rate': 3.1843220338983055e-05, 'epoch': 3.64}
{'loss': 0.3201, 'grad_norm': 12.033082962036133, 'learning_rate': 3.163135593220339e-05, 'epo

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0108590126037598, 'eval_accuracy': 0.6386740331491713, 'eval_precision': 0.6578810673980335, 'eval_recall': 0.6386740331491713, 'eval_f1': 0.640305251877368, 'eval_runtime': 7.0652, 'eval_samples_per_second': 128.092, 'eval_steps_per_second': 16.135, 'epoch': 3.81}
{'loss': 0.291, 'grad_norm': 9.812150955200195, 'learning_rate': 3.078389830508474e-05, 'epoch': 3.86}
{'loss': 0.3077, 'grad_norm': 10.15845012664795, 'learning_rate': 3.0572033898305086e-05, 'epoch': 3.9}
{'loss': 0.3192, 'grad_norm': 11.038859367370605, 'learning_rate': 3.0360169491525426e-05, 'epoch': 3.94}
{'loss': 0.4067, 'grad_norm': 7.538419246673584, 'learning_rate': 3.0148305084745766e-05, 'epoch': 3.98}
{'loss': 0.252, 'grad_norm': 9.898098945617676, 'learning_rate': 2.9936440677966106e-05, 'epoch': 4.03}
{'loss': 0.1714, 'grad_norm': 4.892890930175781, 'learning_rate': 2.9724576271186445e-05, 'epoch': 4.07}
{'loss': 0.1583, 'grad_norm': 2.6443281173706055, 'learning_rate': 2.9512711864406782e-05, 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0443743467330933, 'eval_accuracy': 0.6276243093922652, 'eval_precision': 0.6244641651103858, 'eval_recall': 0.6276243093922652, 'eval_f1': 0.6146439204191904, 'eval_runtime': 7.027, 'eval_samples_per_second': 128.79, 'eval_steps_per_second': 16.223, 'epoch': 4.24}
{'loss': 0.1348, 'grad_norm': 3.5360970497131348, 'learning_rate': 2.8665254237288137e-05, 'epoch': 4.28}
{'loss': 0.1384, 'grad_norm': 5.331282615661621, 'learning_rate': 2.8453389830508477e-05, 'epoch': 4.32}
{'loss': 0.1173, 'grad_norm': 5.184047222137451, 'learning_rate': 2.8241525423728814e-05, 'epoch': 4.36}
{'loss': 0.1966, 'grad_norm': 10.24756908416748, 'learning_rate': 2.8029661016949153e-05, 'epoch': 4.41}
{'loss': 0.1473, 'grad_norm': 6.799671649932861, 'learning_rate': 2.7817796610169493e-05, 'epoch': 4.45}
{'loss': 0.1363, 'grad_norm': 5.290993690490723, 'learning_rate': 2.7605932203389833e-05, 'epoch': 4.49}
{'loss': 0.1404, 'grad_norm': 7.3325724601745605, 'learning_rate': 2.7394067796610173e-0

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.120589017868042, 'eval_accuracy': 0.6353591160220995, 'eval_precision': 0.6336213946221242, 'eval_recall': 0.6353591160220995, 'eval_f1': 0.6273032371486374, 'eval_runtime': 6.9978, 'eval_samples_per_second': 129.327, 'eval_steps_per_second': 16.291, 'epoch': 4.66}
{'loss': 0.1486, 'grad_norm': 18.092220306396484, 'learning_rate': 2.6546610169491525e-05, 'epoch': 4.7}
{'loss': 0.149, 'grad_norm': 9.24963092803955, 'learning_rate': 2.6334745762711865e-05, 'epoch': 4.75}
{'loss': 0.1631, 'grad_norm': 8.231568336486816, 'learning_rate': 2.6122881355932204e-05, 'epoch': 4.79}
{'loss': 0.1143, 'grad_norm': 4.070137977600098, 'learning_rate': 2.5911016949152544e-05, 'epoch': 4.83}
{'loss': 0.1359, 'grad_norm': 6.593826770782471, 'learning_rate': 2.5699152542372884e-05, 'epoch': 4.87}
{'loss': 0.1241, 'grad_norm': 4.943995475769043, 'learning_rate': 2.5487288135593224e-05, 'epoch': 4.92}
{'loss': 0.1267, 'grad_norm': 6.646225452423096, 'learning_rate': 2.5275423728813563e-05, 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.0815123319625854, 'eval_accuracy': 0.6541436464088398, 'eval_precision': 0.6722808684465286, 'eval_recall': 0.6541436464088398, 'eval_f1': 0.6572308276492534, 'eval_runtime': 7.0227, 'eval_samples_per_second': 128.868, 'eval_steps_per_second': 16.233, 'epoch': 5.08}
{'loss': 0.0593, 'grad_norm': 2.5291764736175537, 'learning_rate': 2.4427966101694915e-05, 'epoch': 5.13}
{'loss': 0.0311, 'grad_norm': 1.6286453008651733, 'learning_rate': 2.4216101694915255e-05, 'epoch': 5.17}
{'loss': 0.0425, 'grad_norm': 1.8574614524841309, 'learning_rate': 2.4004237288135595e-05, 'epoch': 5.21}
{'loss': 0.0645, 'grad_norm': 3.3376286029815674, 'learning_rate': 2.3792372881355935e-05, 'epoch': 5.25}
{'loss': 0.0537, 'grad_norm': 5.972339630126953, 'learning_rate': 2.358050847457627e-05, 'epoch': 5.3}
{'loss': 0.0599, 'grad_norm': 2.4735734462738037, 'learning_rate': 2.336864406779661e-05, 'epoch': 5.34}
{'loss': 0.0516, 'grad_norm': 3.3025097846984863, 'learning_rate': 2.315677966101695e

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.269785761833191, 'eval_accuracy': 0.6486187845303868, 'eval_precision': 0.655763776129835, 'eval_recall': 0.6486187845303868, 'eval_f1': 0.6449207826599932, 'eval_runtime': 6.9622, 'eval_samples_per_second': 129.987, 'eval_steps_per_second': 16.374, 'epoch': 5.51}
{'loss': 0.0378, 'grad_norm': 2.9294545650482178, 'learning_rate': 2.2309322033898306e-05, 'epoch': 5.55}
{'loss': 0.0308, 'grad_norm': 0.3869571387767792, 'learning_rate': 2.2097457627118646e-05, 'epoch': 5.59}
{'loss': 0.0442, 'grad_norm': 7.035733222961426, 'learning_rate': 2.1885593220338986e-05, 'epoch': 5.64}
{'loss': 0.0549, 'grad_norm': 2.1625025272369385, 'learning_rate': 2.1673728813559325e-05, 'epoch': 5.68}
{'loss': 0.0456, 'grad_norm': 8.810321807861328, 'learning_rate': 2.1461864406779662e-05, 'epoch': 5.72}
{'loss': 0.0342, 'grad_norm': 1.4686696529388428, 'learning_rate': 2.125e-05, 'epoch': 5.76}
{'loss': 0.0364, 'grad_norm': 4.746033191680908, 'learning_rate': 2.103813559322034e-05, 'epoch': 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.3872288465499878, 'eval_accuracy': 0.630939226519337, 'eval_precision': 0.6325977375732607, 'eval_recall': 0.630939226519337, 'eval_f1': 0.621710741794161, 'eval_runtime': 6.9747, 'eval_samples_per_second': 129.755, 'eval_steps_per_second': 16.345, 'epoch': 5.93}
{'loss': 0.0565, 'grad_norm': 14.437975883483887, 'learning_rate': 2.0190677966101697e-05, 'epoch': 5.97}
{'loss': 0.0363, 'grad_norm': 2.699342727661133, 'learning_rate': 1.9978813559322033e-05, 'epoch': 6.02}
{'loss': 0.0238, 'grad_norm': 0.5630282759666443, 'learning_rate': 1.9766949152542373e-05, 'epoch': 6.06}
{'loss': 0.0188, 'grad_norm': 0.6141085028648376, 'learning_rate': 1.9555084745762713e-05, 'epoch': 6.1}
{'loss': 0.016, 'grad_norm': 3.182340621948242, 'learning_rate': 1.934322033898305e-05, 'epoch': 6.14}
{'loss': 0.0143, 'grad_norm': 0.6292985081672668, 'learning_rate': 1.913135593220339e-05, 'epoch': 6.19}
{'loss': 0.0085, 'grad_norm': 0.5470725297927856, 'learning_rate': 1.891949152542373e-05, 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.4188685417175293, 'eval_accuracy': 0.6397790055248619, 'eval_precision': 0.6469088649602887, 'eval_recall': 0.6397790055248619, 'eval_f1': 0.6354780641061185, 'eval_runtime': 7.0148, 'eval_samples_per_second': 129.014, 'eval_steps_per_second': 16.251, 'epoch': 6.36}
{'loss': 0.0091, 'grad_norm': 0.3465679883956909, 'learning_rate': 1.8072033898305084e-05, 'epoch': 6.4}
{'loss': 0.0093, 'grad_norm': 0.6324267387390137, 'learning_rate': 1.7860169491525424e-05, 'epoch': 6.44}
{'loss': 0.0079, 'grad_norm': 0.33108845353126526, 'learning_rate': 1.7648305084745764e-05, 'epoch': 6.48}
{'loss': 0.0095, 'grad_norm': 0.24289800226688385, 'learning_rate': 1.7436440677966103e-05, 'epoch': 6.53}
{'loss': 0.0075, 'grad_norm': 0.2144320160150528, 'learning_rate': 1.722457627118644e-05, 'epoch': 6.57}
{'loss': 0.0089, 'grad_norm': 1.6761102676391602, 'learning_rate': 1.701271186440678e-05, 'epoch': 6.61}
{'loss': 0.0159, 'grad_norm': 3.6873245239257812, 'learning_rate': 1.6800847457627

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.4651143550872803, 'eval_accuracy': 0.6375690607734806, 'eval_precision': 0.6402621959694129, 'eval_recall': 0.6375690607734806, 'eval_f1': 0.6313073895892909, 'eval_runtime': 7.0285, 'eval_samples_per_second': 128.762, 'eval_steps_per_second': 16.22, 'epoch': 6.78}
{'loss': 0.0085, 'grad_norm': 1.6042416095733643, 'learning_rate': 1.5953389830508475e-05, 'epoch': 6.82}
{'loss': 0.0151, 'grad_norm': 1.4477825164794922, 'learning_rate': 1.5741525423728815e-05, 'epoch': 6.86}
{'loss': 0.0063, 'grad_norm': 1.9307973384857178, 'learning_rate': 1.5529661016949154e-05, 'epoch': 6.91}
{'loss': 0.0118, 'grad_norm': 1.3263036012649536, 'learning_rate': 1.5317796610169494e-05, 'epoch': 6.95}
{'loss': 0.009, 'grad_norm': 0.5736232995986938, 'learning_rate': 1.510593220338983e-05, 'epoch': 6.99}
{'loss': 0.0047, 'grad_norm': 0.706710934638977, 'learning_rate': 1.489406779661017e-05, 'epoch': 7.03}
{'loss': 0.0031, 'grad_norm': 0.0903150662779808, 'learning_rate': 1.468220338983051e-

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.5372898578643799, 'eval_accuracy': 0.63646408839779, 'eval_precision': 0.6375595849508836, 'eval_recall': 0.63646408839779, 'eval_f1': 0.6327554434719412, 'eval_runtime': 6.9352, 'eval_samples_per_second': 130.494, 'eval_steps_per_second': 16.438, 'epoch': 7.2}
{'loss': 0.0029, 'grad_norm': 0.2819845974445343, 'learning_rate': 1.3834745762711866e-05, 'epoch': 7.25}
{'loss': 0.0036, 'grad_norm': 0.6126620769500732, 'learning_rate': 1.3622881355932204e-05, 'epoch': 7.29}
{'loss': 0.0018, 'grad_norm': 0.18634407222270966, 'learning_rate': 1.3411016949152543e-05, 'epoch': 7.33}
{'loss': 0.0024, 'grad_norm': 0.33525991439819336, 'learning_rate': 1.3199152542372881e-05, 'epoch': 7.37}
{'loss': 0.0018, 'grad_norm': 0.10686402767896652, 'learning_rate': 1.298728813559322e-05, 'epoch': 7.42}
{'loss': 0.0038, 'grad_norm': 0.06207828223705292, 'learning_rate': 1.277542372881356e-05, 'epoch': 7.46}
{'loss': 0.0022, 'grad_norm': 0.1578681319952011, 'learning_rate': 1.256355932203389

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.5689345598220825, 'eval_accuracy': 0.6397790055248619, 'eval_precision': 0.6443369117218846, 'eval_recall': 0.6397790055248619, 'eval_f1': 0.6373523305099794, 'eval_runtime': 6.9721, 'eval_samples_per_second': 129.804, 'eval_steps_per_second': 16.351, 'epoch': 7.63}
{'loss': 0.0019, 'grad_norm': 0.05572817102074623, 'learning_rate': 1.1716101694915255e-05, 'epoch': 7.67}
{'loss': 0.0023, 'grad_norm': 0.7961853742599487, 'learning_rate': 1.1504237288135594e-05, 'epoch': 7.71}
{'loss': 0.0021, 'grad_norm': 0.09453441947698593, 'learning_rate': 1.1292372881355932e-05, 'epoch': 7.75}
{'loss': 0.0024, 'grad_norm': 0.24578362703323364, 'learning_rate': 1.1080508474576272e-05, 'epoch': 7.8}
{'loss': 0.0016, 'grad_norm': 0.07784632593393326, 'learning_rate': 1.086864406779661e-05, 'epoch': 7.84}
{'loss': 0.003, 'grad_norm': 0.3498338758945465, 'learning_rate': 1.065677966101695e-05, 'epoch': 7.88}
{'loss': 0.0034, 'grad_norm': 0.3788709044456482, 'learning_rate': 1.044491525423

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.5405248403549194, 'eval_accuracy': 0.6475138121546962, 'eval_precision': 0.64889037180888, 'eval_recall': 0.6475138121546962, 'eval_f1': 0.647601714380261, 'eval_runtime': 7.0084, 'eval_samples_per_second': 129.131, 'eval_steps_per_second': 16.266, 'epoch': 8.05}
{'loss': 0.0015, 'grad_norm': 0.8023620843887329, 'learning_rate': 9.597457627118645e-06, 'epoch': 8.09}
{'loss': 0.0011, 'grad_norm': 0.06867803633213043, 'learning_rate': 9.385593220338983e-06, 'epoch': 8.14}
{'loss': 0.0012, 'grad_norm': 0.05407997593283653, 'learning_rate': 9.173728813559321e-06, 'epoch': 8.18}
{'loss': 0.0013, 'grad_norm': 0.097072534263134, 'learning_rate': 8.961864406779661e-06, 'epoch': 8.22}
{'loss': 0.001, 'grad_norm': 0.026518691331148148, 'learning_rate': 8.75e-06, 'epoch': 8.26}
{'loss': 0.0011, 'grad_norm': 0.06383247673511505, 'learning_rate': 8.538135593220339e-06, 'epoch': 8.31}
{'loss': 0.0011, 'grad_norm': 0.04413522779941559, 'learning_rate': 8.326271186440679e-06, 'epoch': 

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.5936353206634521, 'eval_accuracy': 0.6541436464088398, 'eval_precision': 0.6518412383133928, 'eval_recall': 0.6541436464088398, 'eval_f1': 0.6512950846357748, 'eval_runtime': 6.9251, 'eval_samples_per_second': 130.683, 'eval_steps_per_second': 16.462, 'epoch': 8.47}
{'loss': 0.0014, 'grad_norm': 0.10438597202301025, 'learning_rate': 7.478813559322034e-06, 'epoch': 8.52}
{'loss': 0.0014, 'grad_norm': 0.1225394755601883, 'learning_rate': 7.266949152542374e-06, 'epoch': 8.56}
{'loss': 0.0009, 'grad_norm': 0.07740975171327591, 'learning_rate': 7.055084745762712e-06, 'epoch': 8.6}
{'loss': 0.001, 'grad_norm': 0.040078986436128616, 'learning_rate': 6.843220338983052e-06, 'epoch': 8.64}
{'loss': 0.0009, 'grad_norm': 0.09473121911287308, 'learning_rate': 6.63135593220339e-06, 'epoch': 8.69}
{'loss': 0.0013, 'grad_norm': 0.08369968086481094, 'learning_rate': 6.419491525423729e-06, 'epoch': 8.73}
{'loss': 0.001, 'grad_norm': 0.07657002657651901, 'learning_rate': 6.207627118644068

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.6206227540969849, 'eval_accuracy': 0.6497237569060773, 'eval_precision': 0.6465021096900179, 'eval_recall': 0.6497237569060773, 'eval_f1': 0.6454800092068237, 'eval_runtime': 6.9526, 'eval_samples_per_second': 130.167, 'eval_steps_per_second': 16.397, 'epoch': 8.9}
{'loss': 0.0009, 'grad_norm': 0.0855693519115448, 'learning_rate': 5.360169491525424e-06, 'epoch': 8.94}
{'loss': 0.001, 'grad_norm': 0.10975849628448486, 'learning_rate': 5.148305084745763e-06, 'epoch': 8.98}
{'loss': 0.0009, 'grad_norm': 0.05465579405426979, 'learning_rate': 4.936440677966102e-06, 'epoch': 9.03}
{'loss': 0.001, 'grad_norm': 0.04077761247754097, 'learning_rate': 4.724576271186441e-06, 'epoch': 9.07}
{'loss': 0.0008, 'grad_norm': 0.027294278144836426, 'learning_rate': 4.51271186440678e-06, 'epoch': 9.11}
{'loss': 0.001, 'grad_norm': 0.047495290637016296, 'learning_rate': 4.300847457627119e-06, 'epoch': 9.15}
{'loss': 0.0008, 'grad_norm': 0.04055822268128395, 'learning_rate': 4.088983050847458

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.6298973560333252, 'eval_accuracy': 0.6541436464088398, 'eval_precision': 0.6509033363158168, 'eval_recall': 0.6541436464088398, 'eval_f1': 0.6497853352834244, 'eval_runtime': 6.9971, 'eval_samples_per_second': 129.338, 'eval_steps_per_second': 16.292, 'epoch': 9.32}
{'loss': 0.0008, 'grad_norm': 0.06421801447868347, 'learning_rate': 3.241525423728814e-06, 'epoch': 9.36}
{'loss': 0.0008, 'grad_norm': 0.04872427135705948, 'learning_rate': 3.029661016949153e-06, 'epoch': 9.41}
{'loss': 0.0008, 'grad_norm': 0.04480556398630142, 'learning_rate': 2.817796610169492e-06, 'epoch': 9.45}
{'loss': 0.0008, 'grad_norm': 0.08226297795772552, 'learning_rate': 2.6059322033898303e-06, 'epoch': 9.49}
{'loss': 0.0009, 'grad_norm': 0.036264531314373016, 'learning_rate': 2.3940677966101697e-06, 'epoch': 9.53}
{'loss': 0.0011, 'grad_norm': 0.0338827483355999, 'learning_rate': 2.1822033898305086e-06, 'epoch': 9.58}
{'loss': 0.0007, 'grad_norm': 0.05246584862470627, 'learning_rate': 1.97033898

  0%|          | 0/114 [00:00<?, ?it/s]

{'eval_loss': 1.6342209577560425, 'eval_accuracy': 0.6574585635359116, 'eval_precision': 0.6548113782584233, 'eval_recall': 0.6574585635359116, 'eval_f1': 0.6535494395842923, 'eval_runtime': 7.0274, 'eval_samples_per_second': 128.781, 'eval_steps_per_second': 16.222, 'epoch': 9.75}
{'loss': 0.0008, 'grad_norm': 0.022197166457772255, 'learning_rate': 1.1228813559322035e-06, 'epoch': 9.79}
{'loss': 0.0009, 'grad_norm': 0.0444762147963047, 'learning_rate': 9.110169491525425e-07, 'epoch': 9.83}
{'loss': 0.001, 'grad_norm': 0.03895612806081772, 'learning_rate': 6.991525423728814e-07, 'epoch': 9.87}
{'loss': 0.0009, 'grad_norm': 0.03454240411520004, 'learning_rate': 4.872881355932204e-07, 'epoch': 9.92}
{'loss': 0.0007, 'grad_norm': 0.06842169910669327, 'learning_rate': 2.7542372881355935e-07, 'epoch': 9.96}
{'loss': 0.0009, 'grad_norm': 0.1448354721069336, 'learning_rate': 6.355932203389831e-08, 'epoch': 10.0}
{'train_runtime': 978.5686, 'train_samples_per_second': 38.434, 'train_steps_per_

In [2]:
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np

In [1]:
# Evaluate the model on the validation dataset
evaluation_results = trainer.evaluate()
print(f"Evaluation Results: {evaluation_results}")

Evaluation Results: {'eval_loss': 1.6354875564575195, 'eval_accuracy': 0.8552486187845303, 'eval_precision': 0.852748436391131, 'eval_recall': 0.8552486187845303, 'eval_f1': 0.8513106008106088, 'eval_runtime': 6.9583, 'eval_samples_per_second': 130.061, 'eval_steps_per_second': 16.383, 'epoch': 10.0}



In [4]:
# Generate predictions
predictions, labels, metrics = trainer.predict(val_dataset)
preds = np.argmax(predictions, axis=1)

  0%|          | 0/114 [00:00<?, ?it/s]

In [5]:
# Confusion Matrix
conf_matrix = confusion_matrix(labels, preds)
print('Confusion Matrix')
print(conf_matrix)

Confusion Matrix
[[187  13  35  20]
 [ 18 101  29  50]
 [ 30  40 110  29]
 [ 10  16  22 195]]


In [2]:
# Classification Report
class_report = classification_report(labels, preds, target_names=["Class 0", "Class 1", "Class 2", "Class 3"])
print('Classification Report')
print(class_report)


Classification Report
              precision    recall  f1-score   support

     Class 0       0.76      0.73      0.75       255
     Class 1       0.79      0.51      0.55       198
     Class 2       0.76      0.53      0.54       209
     Class 3       0.76      0.80      0.73       243

    accuracy                           0.86       905
   macro avg       0.78      0.84      0.84       905
weighted avg       0.78      0.86      0.85       905





In [7]:
# ROC Curve and AUC (for binary classification)
if num_classes == 2:
    fpr, tpr, _ = roc_curve(labels, predictions[:, 1])
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.show()