# Network training

For our project, we constructed a Vision Transformer (ViT) instance that we named `GlowViT` to assess how light exposure affects Neural Network's result at classification.

In [1]:
import torch
import json
import ast
import random
import torch.nn as nn
import numpy as np
import albumentations
from math import floor
from tqdm import tqdm

from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from datasets import load_dataset, Dataset

2024-11-26 04:42:24.140768: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-26 04:42:24.183669: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-26 04:42:24.183708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-26 04:42:24.184973: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-26 04:42:24.193022: I tensorflow/core/platform/cpu_feature_guar

uintx feature requires torch 2.3+, please upgrade pytorch


In [2]:
class GlowViT(ViTForImageClassification):
    def help():
        print(ViTForImageClassification.__doc__)

In [3]:
# wild_train_ds = load_dataset("yin30lei/wildlife-from-wildme", split="train[30:60%]", cache_dir=Path.cwd() / "wildlife", num_proc=2)
wildlife_ds = load_dataset("SeaSponge/wildme_dataset", cache_dir=Path.cwd() / "yalu_dataset", num_proc=2)

In [4]:
wildlife_train_ds = wildlife_ds["train"]
wildlife_val_ds = wildlife_ds["validation"]
wildlife_test_ds = wildlife_ds["test"]
del wildlife_ds

In [5]:
# two labels at the moment
num_labels = 2
print(wildlife_train_ds[0])

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2400x1800 at 0x7FD7FC1A0090>, 'image_id': 0, 'width': 2400, 'height': 1800, 'objects': {'area': [1200451.817558299], 'bbox': [[0.04552469135802469, 0.4825102880658436, 0.5929783950617284, 0.4686213991769547]], 'category': [0], 'label': 'leopard'}}


In [None]:
device = "cuda" if torch.cuda.is_available else "cpu"

In [7]:
# def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

#     w, h = size
#     print(f"ds.features -> {ds.features}")
#     labels = ds.features['objects']['label']
#     if not isinstance(labels, list):
#         labels = [labels]
#     grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
#     draw = ImageDraw.Draw(grid)
#     font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

#     for label_id, label in enumerate(labels):

#         # Filter the dataset by a single label, shuffle it, and grab a few samples
#         ds_slice = ds.filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

#         # Plot this label's examples along a row
#         for i, example in enumerate(ds_slice):
#             image = example['image']
#             idx = examples_per_class * label_id + i
#             box = (idx % examples_per_class * w, idx // examples_per_class * h)
#             grid.paste(image.resize(size), box=box)
#             draw.text(box, label, (255, 255, 255), font=font)

#     return grid

In [8]:
# show_examples(wildlife_val_ds, seed=random.randint(0, 1337), examples_per_class=3)

In [9]:
label2id = {"leopard" : 0, "hyena": 1}
id2label = {v:k for k, v in label2id.items()}

In [10]:
checkpoint = "google/vit-base-patch16-224-in21k"
model = GlowViT.from_pretrained(checkpoint,
                label2id=label2id,
                id2label=id2label,
                num_labels=num_labels)

image_processor = ViTImageProcessor.from_pretrained(checkpoint)

Some weights of GlowViT were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def transform_ds(examples):
    images, labels = [], []
    for image, objects in zip(examples["image"], examples["objects"]):
        # image = np.array(image.convert("RGB"))[:, :, ::-1]
        #! category needs to be list, hence I put the extra brackets around it
        out = image_processor(images=image.convert("RGB"), return_tensors="pt")["pixel_values"].squeeze(0)

        label = objects["category"][0]
        images.append(out)
        labels.append(label)

    return {"pixel_values": images, "labels": labels}

In [12]:
wildlife_train_ds_transformed = wildlife_train_ds.with_transform(transform_ds)
wildlife_val_ds_transformed = wildlife_val_ds.with_transform(transform_ds)
test_ds = wildlife_test_ds.with_transform(transform_ds)
del wildlife_train_ds
del wildlife_val_ds
del wildlife_test_ds

In [21]:
data_collator = DefaultDataCollator()

In [None]:
# Define the training arguments

# It doesn't utilize GPU effectively, so trying some larger ViT models
# Update: ViT large takes around 1GB, not very ideal. Some going to try other things

#! with ViT-base, can afford to have large batchsizes
# batchsz=2 -> 2133MiB / 11264MiB
# running with larger batchsize actually increase training time, trying to find optimal largest batch
# Update: This seems to go away when I put larger epoch numbers
# When epoch=[10,20] -> 5, 6, 16, 32 (6117MiB / 11264MiB), 50 (8055MiB / 11264MiB)

training_args = TrainingArguments(
    output_dir="glow-model",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=20,
    # max_steps=2000,
    fp16=False,
    save_steps=10,
    eval_steps=50,
    logging_steps=30,
    learning_rate=2e-4,
    save_total_limit=1,
    remove_unused_columns=False,
    push_to_hub=True,
    lr_scheduler_type="cosine_with_restarts",  # Cosine scheduler with restarts
)

# Define the trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=wildlife_train_ds_transformed,
    eval_dataset=wildlife_val_ds_transformed,
    tokenizer=image_processor,
)

train_results = trainer.train()
trainer.log_metrics("train", train_results.metrics)

  trainer = Trainer(
max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
30,0.2363
60,0.2654
90,0.2888
120,0.1694
150,0.211
180,0.3331
210,0.2001
240,0.2575
270,0.3841
300,0.2941


***** train metrics *****
  epoch                    =      5.3908
  total_flos               = 576277763GF
  train_loss               =      0.1638
  train_runtime            =  0:25:23.45
  train_samples_per_second =       5.251
  train_steps_per_second   =       1.313


In [23]:
metrics = trainer.evaluate(wildlife_val_ds_transformed)
trainer.log_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =     5.3908
  eval_loss               =     0.0766
  eval_runtime            = 0:00:37.13
  eval_samples_per_second =     10.261
  eval_steps_per_second   =      1.293


In [24]:
metrics = trainer.evaluate(test_ds)
trainer.log_metrics("test", metrics)

***** test metrics *****
  epoch                   =     5.3908
  eval_loss               =      0.075
  eval_runtime            = 0:00:25.61
  eval_samples_per_second =      9.917
  eval_steps_per_second   =      1.249
