# Network training

For our project, we built a Vision Transformer (ViT) classifier, which we named `GlowViT`, to assess how light exposure affects a neural network's classification results.

> Note: This script trains the GlowViT model on several differently processed datasets, so each trained model is saved under a different name.

In [1]:
import torch
import json
import ast
import random
import torch.nn as nn
import numpy as np
import albumentations
from math import floor
from tqdm import tqdm

from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from datasets import load_dataset, Dataset

2024-12-11 17:22:58.281792: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-11 17:22:58.325081: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-11 17:22:58.325118: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-11 17:22:58.326413: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-11 17:22:58.334376: I tensorflow/core/platform/cpu_feature_guar

In [2]:
class GlowViT(ViTForImageClassification):
    """Project-specific name for ``ViTForImageClassification``.

    The subclass adds no modelling behaviour of its own; the only addition
    is the :meth:`help` convenience method.
    """

    @staticmethod
    def help():
        """Print the docstring of ``ViTForImageClassification``.

        Declared as a static method so it can be called both on the class
        (``GlowViT.help()``) and on an instance (``model.help()``). Without
        the decorator the instance call would fail, because the instance
        would be passed as an unexpected positional argument.
        """
        print(ViTForImageClassification.__doc__)

In [None]:
# wild_train_ds = load_dataset("yin30lei/wildlife-from-wildme", split="train[30:60%]", cache_dir=Path.cwd() / "wildlife", num_proc=2)

# Each entry pairs a dataset identifier (passed to `load_dataset`) with the
# name used for the model trained on it (later used as the output directory).
yalu_ds_list = [
    ("SeaSponge/wildme10_classify", "glow-vit"),
    ("yin30lei/wildlife_very_dark", "glow-vit-dark"),
    ("yin30lei/wildlife_well_illuminated", "glow-vit-illuminate"),
    ("yin30lei/wildlife_mixed", "glow-vit-mix"),
]

# Choose which dataset/model pair this run trains; change the index to switch.
yalu_ds, yalu_model_name = yalu_ds_list[2]

wildlife_ds = load_dataset(
    yalu_ds,
    cache_dir=Path.cwd() / "yalu_dataset",
    num_proc=2,
)

README.md:   0%|          | 0.00/569 [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5441 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2333 [00:00<?, ? examples/s]

In [4]:
# Keep only the two splits that are used below, then drop the reference to
# the container they came from.
wildlife_train_ds, wildlife_test_ds = wildlife_ds["train"], wildlife_ds["test"]
del wildlife_ds

In [5]:
# Number of output classes for the classifier head.
# 11 labels for wildme10_classify
# NOTE(review): this is hard-coded and must stay in sync with the label2id
# mapping defined further down (which also has 11 entries).
num_labels = 11
# Print one raw training example to inspect the available fields.
print(wildlife_train_ds[0])

{'file_name': '60000760.jpg', 'image_id': 4964, 'width': 1024, 'height': 686, 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x686 at 0x7FDA192FA410>, 'labels': 'raccoon'}


In [6]:
# Use the GPU when CUDA is actually usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, which would select "cuda" even on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
# def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

#     w, h = size
#     print(f"ds.features -> {ds.features}")
#     labels = ds.features['objects']['label']
#     if not isinstance(labels, list):
#         labels = [labels]
#     grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
#     draw = ImageDraw.Draw(grid)
#     font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

#     for label_id, label in enumerate(labels):

#         # Filter the dataset by a single label, shuffle it, and grab a few samples
#         ds_slice = ds.filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

#         # Plot this label's examples along a row
#         for i, example in enumerate(ds_slice):
#             image = example['image']
#             idx = examples_per_class * label_id + i
#             box = (idx % examples_per_class * w, idx // examples_per_class * h)
#             grid.paste(image.resize(size), box=box)
#             draw.text(box, label, (255, 255, 255), font=font)

#     return grid

In [8]:
# show_examples(wildlife_val_ds, seed=random.randint(0, 1337), examples_per_class=3)

In [9]:
# Class name -> integer id. A name's position in the tuple below is its id.
label2id = {
    name: idx
    for idx, name in enumerate(
        (
            "lion", "raccoon", "tiger", "wolf",
            "bear", "hare", "fox", "deer",
            "leopard", "hyena", "antelope",
        )
    )
}

# Inverse mapping: integer id -> class name.
id2label = dict(zip(label2id.values(), label2id.keys()))

In [None]:
# Pretrained checkpoint used for both the model weights and the image
# preprocessing configuration.
checkpoint = "google/vit-base-patch16-224-in21k"

# Load the pretrained weights into GlowViT, configured for this project's
# label set (num_labels classes with the given name <-> id mappings).
model = GlowViT.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    attn_implementation="sdpa",  # no flash attention yet for ViT model
    cache_dir="default_vit",
)

# Matching image processor, used to turn images into `pixel_values`.
image_processor = ViTImageProcessor.from_pretrained(checkpoint)

Some weights of GlowViT were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def transform_ds(examples):
    """Convert a batch of raw dataset rows into model inputs.

    Meant to be applied with ``Dataset.with_transform``.

    Args:
        examples: Batch mapping with an ``"image"`` column (images that
            support ``.convert("RGB")``) and a ``"labels"`` column holding
            class names that are keys of ``label2id``.

    Returns:
        A dict with two equally long lists:
        ``"pixel_values"``: one tensor per image as produced by
        ``image_processor``, with the leading batch dimension squeezed out;
        ``"labels"``: the integer class id of each example.

    Raises:
        KeyError: If a label is not present in ``label2id``.
    """
    images, labels = [], []
    for image, label in zip(examples["image"], examples["labels"]):
        encoded = image_processor(images=image.convert("RGB"), return_tensors="pt")
        # The processor is called on one image at a time, so drop the
        # batch dimension it adds.
        pix_val = encoded["pixel_values"].squeeze(0)
        # The tensor is deliberately left where the processor created it.
        # The earlier `pix_val.to(device)` discarded its result (`Tensor.to`
        # returns a new tensor), so it never moved anything and could only
        # fail when `device` was not usable.
        # NOTE(review): Trainer is expected to move each batch to the
        # training device itself; confirm if this is reused outside Trainer.
        images.append(pix_val)
        # Labels are supposed to be numbers here: map class name -> id.
        labels.append(label2id[label])

    return {"pixel_values": images, "labels": labels}

In [12]:
# Attach the preprocessing transform to both splits.
wildlife_train_ds_transformed, test_ds = (
    split.with_transform(transform_ds)
    for split in (wildlife_train_ds, wildlife_test_ds)
)

In [13]:
# Sanity check: print the first transformed test example to confirm it now
# contains `pixel_values` and an integer label.
print(test_ds[0])

{'pixel_values': tensor([[[ 0.8353,  0.6784,  0.5059,  ..., -0.5137, -0.4980, -0.4745],
         [ 0.8824,  0.7255,  0.5059,  ..., -0.5059, -0.4902, -0.4667],
         [ 0.8980,  0.7569,  0.5451,  ..., -0.5059, -0.4824, -0.4667],
         ...,
         [ 0.9216,  0.9216,  0.9216,  ..., -0.1451, -0.1529, -0.2000],
         [ 0.9529,  0.9529,  0.9529,  ..., -0.2314, -0.2627, -0.2941],
         [-0.2471, -0.2471, -0.2471,  ..., -0.7020, -0.7098, -0.7176]],

        [[ 0.7961,  0.6392,  0.4667,  ..., -0.5059, -0.4824, -0.4588],
         [ 0.8431,  0.6863,  0.4667,  ..., -0.4980, -0.4745, -0.4510],
         [ 0.8588,  0.7176,  0.5059,  ..., -0.4902, -0.4667, -0.4510],
         ...,
         [ 0.9294,  0.9294,  0.9294,  ..., -0.1137, -0.1216, -0.1686],
         [ 0.9608,  0.9608,  0.9608,  ..., -0.2000, -0.2314, -0.2627],
         [-0.2392, -0.2392, -0.2392,  ..., -0.6706, -0.6784, -0.6941]],

        [[ 0.6471,  0.4902,  0.3098,  ..., -0.5765, -0.5686, -0.5686],
         [ 0.6941,  0.5373, 

In [14]:
# Collator handed to the Trainer below to assemble individual examples into
# batches.
data_collator = DefaultDataCollator()

In [15]:
# Define the training arguments.

# Experiment notes:
# - ViT-base did not utilize the GPU effectively, so larger ViT models were tried.
#   Update: ViT-large takes around 1GB, which is not ideal, so other options
#   were explored instead.

#! With ViT-base we can afford large batch sizes.
# - batch size 2 -> 2133MiB / 11264MiB of GPU memory.
# - Larger batch sizes initially increased training time, so the goal was to
#   find the largest batch size that is still efficient.
#   Update: this effect seems to go away with larger epoch counts.
# - With epochs in [10, 20], batch sizes tried: 5, 6, 16,
#   32 (6117MiB / 11264MiB), 50 (8055MiB / 11264MiB).

# NOTE(review): `save_steps=10` together with `push_to_hub=True` looks like it
# checkpoints very frequently — confirm this is intended.
# NOTE(review): `eval_steps` is set, but no evaluation strategy is configured
# here and the Trainer below receives no eval dataset — presumably no
# evaluation runs during training; confirm.
training_args = TrainingArguments(
    output_dir=yalu_model_name,
    per_device_train_batch_size=50,
    gradient_accumulation_steps=1,
    num_train_epochs=10,
    # max_steps=2000,
    fp16=False,
    save_steps=10,
    eval_steps=50,
    logging_steps=30,
    learning_rate=2e-4,
    save_total_limit=1,
    remove_unused_columns=False,
    push_to_hub=True,
    lr_scheduler_type="cosine_with_restarts",  # Cosine scheduler with restarts
)

# Define the trainer.
# Only a training dataset is supplied here; the test split is evaluated
# separately after training.
# NOTE(review): the image processor is passed through the `tokenizer`
# argument; newer transformers releases may prefer a different argument name
# for this — confirm against the installed version.

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=wildlife_train_ds_transformed,
    tokenizer=image_processor,
)

# Run training, then print the metrics returned by the training run.
train_results = trainer.train()
trainer.log_metrics("train", train_results.metrics)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
30,1.2326
60,0.3768
90,0.2983
120,0.1996
150,0.1062
180,0.1401
210,0.1095
240,0.0608
270,0.0715
300,0.0785


***** train metrics *****
  epoch                    =         10.0
  total_flos               = 3927088586GF
  train_loss               =       0.0819
  train_runtime            =   1:07:38.30
  train_samples_per_second =       13.407
  train_steps_per_second   =        0.269


In [16]:
# Evaluate the trained model on the transformed test split and print the
# resulting metrics under the "test" label.
metrics = trainer.evaluate(test_ds)
trainer.log_metrics("test", metrics)

***** test metrics *****
  epoch                   =       10.0
  eval_loss               =     0.1064
  eval_runtime            = 0:01:53.13
  eval_samples_per_second =     20.622
  eval_steps_per_second   =      2.581
