# Network training

For our project, we built a Vision Transformer (ViT) subclass, named `GlowViT`, to assess how light exposure affects a neural network's classification results.

In [1]:
import torch
import json
import ast
import random
import torch.nn as nn
import numpy as np
import albumentations
from math import floor
from tqdm import tqdm

from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from datasets import load_dataset, Dataset

  check_for_updates()


In [2]:
class GlowViT(ViTForImageClassification):
    """Project-specific name for `ViTForImageClassification`.

    The subclass adds no modelling behaviour of its own; it only exposes a
    small `help()` convenience that prints the parent class documentation.
    """

    @staticmethod
    def help():
        """Print the docstring of `ViTForImageClassification`.

        Declared as a static method so that it works both on the class
        (`GlowViT.help()`) and on a loaded model instance (`model.help()`);
        without the decorator the instance call fails because Python passes
        the instance as an unexpected positional argument.
        """
        print(ViTForImageClassification.__doc__)

In [3]:
# Earlier alternative source, kept for reference:
# wild_train_ds = load_dataset("yin30lei/wildlife-from-wildme", split="train[30:60%]", cache_dir=Path.cwd() / "wildlife", num_proc=2)

# Candidate dataset repositories; index 1 (the multi-class variant) is the one used.
yalu_ds_list = ["SeaSponge/wildme_dataset", "SeaSponge/wildme10_classify"]
yalu_ds = yalu_ds_list[1]

# Cache the download next to the notebook so re-runs do not fetch it again.
dataset_cache_dir = Path.cwd() / "yalu_dataset"
wildlife_ds = load_dataset(
    yalu_ds,
    cache_dir=dataset_cache_dir,
    num_proc=2,
)

In [4]:
# Keep a handle on each split that is used later, then drop the container.
wildlife_train_ds, wildlife_test_ds = wildlife_ds["train"], wildlife_ds["test"]

# The "validation" split is only read for the non-wildme10_classify dataset.
if yalu_ds != "SeaSponge/wildme10_classify":
    wildlife_val_ds = wildlife_ds["validation"]

del wildlife_ds

In [5]:
# Size of the classification head: 11 classes for wildme10_classify,
# 2 for the other dataset.
num_labels = 11 if yalu_ds == "SeaSponge/wildme10_classify" else 2

# Peek at one raw training record to see which fields are available.
print(wildlife_train_ds[0])

{'file_name': '50002127.jpg', 'image_id': 7399, 'width': 438, 'height': 500, 'pixel_values': [[[-0.8588235378265381, -0.8509804010391235, -0.843137264251709, -0.8274509906768799, -0.8196078538894653, -0.8352941274642944, -0.843137264251709, -0.8274509906768799, -0.7960784435272217, -0.7803921699523926, -0.772549033164978, -0.7882353067398071, -0.7882353067398071, -0.7882353067398071, -0.7882353067398071, -0.7882353067398071, -0.7960784435272217, -0.7960784435272217, -0.7882353067398071, -0.7803921699523926, -0.772549033164978, -0.7803921699523926, -0.7882353067398071, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8039215803146362, -0.8117647171020508, -0.8274509906768799, -0.8509804010391235, -0.8588235378265381, -0.8588235378265381, -0.8745098114013672, -0.9137254953384399, -0.9137254953384399, -0.9058823585510254, -0.9058823585510254, -0.8823529481887817, -0.8823529481887817, -0.87

In [6]:
# Select the GPU only when CUDA is actually usable. `is_available` has to be
# *called*: the bare function object is always truthy, which would pick
# "cuda" even on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
# def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

#     w, h = size
#     print(f"ds.features -> {ds.features}")
#     labels = ds.features['objects']['label']
#     if not isinstance(labels, list):
#         labels = [labels]
#     grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
#     draw = ImageDraw.Draw(grid)
#     font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

#     for label_id, label in enumerate(labels):

#         # Filter the dataset by a single label, shuffle it, and grab a few samples
#         ds_slice = ds.filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

#         # Plot this label's examples along a row
#         for i, example in enumerate(ds_slice):
#             image = example['image']
#             idx = examples_per_class * label_id + i
#             box = (idx % examples_per_class * w, idx // examples_per_class * h)
#             grid.paste(image.resize(size), box=box)
#             draw.text(box, label, (255, 255, 255), font=font)

#     return grid

In [8]:
# show_examples(wildlife_val_ds, seed=random.randint(0, 1337), examples_per_class=3)

In [9]:
# Class names in id order; a name's position in the list is its integer id.
if yalu_ds == "SeaSponge/wildme10_classify":
    class_names = [
        "lion", "raccoon", "tiger", "wolf",
        "bear", "hare", "fox", "deer",
        "leopard", "hyena", "antelope",
    ]
else:
    class_names = ["leopard", "hyena"]

# Forward (name -> id) and reverse (id -> name) lookup tables for the model config.
label2id = {name: idx for idx, name in enumerate(class_names)}
id2label = {idx: name for name, idx in label2id.items()}

In [10]:
# Pre-trained checkpoint shared by the model and its image processor.
checkpoint = "google/vit-base-patch16-224-in21k"
# Load the checkpoint into GlowViT with a classification head sized by
# `num_labels` and the label mappings defined above. The classifier weights
# are not part of the checkpoint, so they are freshly initialised (see the
# warning printed below) and must be trained.
model = GlowViT.from_pretrained(checkpoint,
                label2id=label2id,
                id2label=id2label,
                num_labels=num_labels,
                attn_implementation="sdpa") # SDPA attention; author's note: no flash attention yet for the ViT model

# Processor that converts images into the `pixel_values` the model consumes.
image_processor = ViTImageProcessor.from_pretrained(checkpoint)

Some weights of GlowViT were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def _to_pixel_values(image):
    """Run one image through `image_processor` and return its pixel tensor.

    The image is converted to RGB first. The processor returns a batch of
    one, so the leading batch dimension is squeezed away. The tensor is
    deliberately left where the processor created it: device placement is
    left to the training loop, and an un-assigned `tensor.to(device)` would
    only discard its result.
    """
    return image_processor(images=image.convert("RGB"), return_tensors="pt")["pixel_values"].squeeze(0)


if yalu_ds != "SeaSponge/wildme10_classify":
    def transform_ds(examples):
        """Batch transform for the dataset whose labels live under `objects`.

        Args:
            examples: batch dict with an "image" column and an "objects"
                column; each `objects` entry has a "category" sequence.

        Returns:
            dict with "pixel_values" (list of tensors) and "labels" (the
            first "category" entry of each example).
        """
        images, labels = [], []
        for image, objects in zip(examples["image"], examples["objects"]):
            images.append(_to_pixel_values(image))
            # "category" is a sequence; only its first entry is used as the label.
            labels.append(objects["category"][0])

        return {"pixel_values": images, "labels": labels}
else:
    def transform_ds(examples):
        """Batch transform for wildme10_classify.

        Args:
            examples: batch dict with an "image" column and a "labels"
                column holding class names that are keys of `label2id`.

        Returns:
            dict with "pixel_values" (list of tensors) and "labels" (list of
            integer class ids).

        Raises:
            KeyError: if a label is not present in `label2id`.
        """
        images, labels = [], []
        for image, label in zip(examples["image"], examples["labels"]):
            images.append(_to_pixel_values(image))
            # The model needs a numeric class id, not the class name.
            labels.append(label2id[label])

        return {"pixel_values": images, "labels": labels}

In [12]:
# Attach `transform_ds` to the train and test splits via `with_transform`,
# so examples read from these datasets come back as pixel_values/labels.
wildlife_train_ds_transformed = wildlife_train_ds.with_transform(transform_ds)
test_ds = wildlife_test_ds.with_transform(transform_ds)

In [13]:
# Sanity check: show the first transformed test example.
print(test_ds[0])

{'pixel_values': tensor([[[-0.2000, -0.2392, -0.0275,  ...,  0.1843,  0.1843,  0.2157],
         [-0.1451, -0.1843, -0.0588,  ...,  0.1843,  0.1843,  0.2157],
         [-0.0902, -0.2000, -0.1294,  ...,  0.2157,  0.2157,  0.2314],
         ...,
         [-0.6549, -0.7333, -0.8745,  ..., -0.6392, -0.7020, -0.7804],
         [-0.8510, -0.8745, -0.9451,  ..., -0.5922, -0.6078, -0.6549],
         [-0.9059, -0.8667, -0.8118,  ..., -0.6078, -0.5608, -0.5373]],

        [[-0.3569, -0.3647, -0.1294,  ..., -0.0824, -0.0745, -0.0431],
         [-0.2941, -0.3098, -0.1608,  ..., -0.0824, -0.0745, -0.0431],
         [-0.2314, -0.3255, -0.2314,  ..., -0.0510, -0.0431, -0.0275],
         ...,
         [-0.5608, -0.6471, -0.7961,  ..., -0.7255, -0.7725, -0.8431],
         [-0.7490, -0.7804, -0.8745,  ..., -0.6784, -0.6784, -0.7176],
         [-0.8039, -0.7647, -0.7333,  ..., -0.6941, -0.6314, -0.6000]],

        [[-0.6549, -0.6392, -0.3569,  ..., -0.3882, -0.4745, -0.4667],
         [-0.5922, -0.5843, 

In [14]:
# Wrap the validation split (only extracted for the non-wildme10_classify
# dataset) and release the untransformed handle.
if yalu_ds != "SeaSponge/wildme10_classify":
    wildlife_val_ds_transformed = wildlife_val_ds.with_transform(transform_ds)
    del wildlife_val_ds

# The raw train/test handles are no longer needed once the transformed
# views exist.
del wildlife_train_ds, wildlife_test_ds

In [15]:
# Library-default collator; handed to the Trainer below.
data_collator = DefaultDataCollator()

In [16]:
# --- Training configuration -------------------------------------------------
#
# Tuning notes (GPU with 11264 MiB of memory):
# - ViT-base did not use the GPU effectively, so larger ViT models were tried.
#   Update: ViT-large takes around 1GB, which is not ideal; other options are
#   still to be explored.
# - ViT-base leaves room for large batch sizes: batch size 2 uses 2133 MiB.
# - Larger batches initially increased training time; this seems to go away
#   with more epochs.
# - With 10-20 epochs, batch sizes 5, 6, 16, 32 (6117 MiB) and 50 (8055 MiB)
#   were tried.
#
# NOTE(review): `eval_steps` is set but no evaluation strategy is configured
# here - confirm whether periodic evaluation during training was intended.
training_args = TrainingArguments(
    output_dir="glow-model",
    per_device_train_batch_size=24,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    # max_steps=2000,
    fp16=False,
    save_steps=10,
    eval_steps=50,
    logging_steps=30,
    learning_rate=2e-4,
    save_total_limit=1,
    remove_unused_columns=False,
    push_to_hub=True,
    lr_scheduler_type="cosine_with_restarts",  # cosine schedule with restarts
)

# --- Trainer ----------------------------------------------------------------
# Both dataset variants share the same setup; only the dataset that has a
# validation split additionally passes an `eval_dataset`.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=wildlife_train_ds_transformed,
    tokenizer=image_processor,
)
if yalu_ds != "SeaSponge/wildme10_classify":
    trainer_kwargs["eval_dataset"] = wildlife_val_ds_transformed

trainer = Trainer(**trainer_kwargs)

# Run training and report the resulting training metrics.
train_results = trainer.train()
trainer.log_metrics("train", train_results.metrics)

  trainer = Trainer(


  0%|          | 0/1765 [00:00<?, ?it/s]

  context_layer = torch.nn.functional.scaled_dot_product_attention(


{'loss': 1.4057, 'grad_norm': 1.7096643447875977, 'learning_rate': 0.00019985746561607698, 'epoch': 0.08}
{'loss': 0.5294, 'grad_norm': 4.832664966583252, 'learning_rate': 0.00019943026878531983, 'epoch': 0.17}
{'loss': 0.3333, 'grad_norm': 4.434501647949219, 'learning_rate': 0.0001987196273124703, 'epoch': 0.25}
{'loss': 0.3193, 'grad_norm': 2.6981236934661865, 'learning_rate': 0.00019772756701441887, 'epoch': 0.34}
{'loss': 0.3156, 'grad_norm': 10.102080345153809, 'learning_rate': 0.0001964569159452335, 'epoch': 0.42}
{'loss': 0.3649, 'grad_norm': 2.8132810592651367, 'learning_rate': 0.00019491129633426068, 'epoch': 0.51}
{'loss': 0.2824, 'grad_norm': 2.505631923675537, 'learning_rate': 0.00019309511426028104, 'epoch': 0.59}
{'loss': 0.2585, 'grad_norm': 3.7677249908447266, 'learning_rate': 0.00019101354709115468, 'epoch': 0.68}
{'loss': 0.2592, 'grad_norm': 2.647081136703491, 'learning_rate': 0.00018867252872476257, 'epoch': 0.76}
{'loss': 0.2705, 'grad_norm': 1.9988516569137573, 'l

In [18]:
# Evaluate on the validation split; it only exists for the
# non-wildme10_classify dataset, so this cell does nothing otherwise.
if yalu_ds != "SeaSponge/wildme10_classify":
    metrics = trainer.evaluate(wildlife_val_ds_transformed)
    trainer.log_metrics("eval", metrics)

In [19]:
# Evaluate the trained model on the held-out test split and log the metrics
# under the "test" heading.
metrics = trainer.evaluate(test_ds)
trainer.log_metrics("test", metrics)

  0%|          | 0/454 [00:00<?, ?it/s]

***** test metrics *****
  epoch                   =        5.0
  eval_loss               =     0.1419
  eval_runtime            = 0:05:34.25
  eval_samples_per_second =     10.863
  eval_steps_per_second   =      1.358
