In [None]:
"""Interactive python notebook containing attempt to create a model with a custom MLP head to include longitudinal data during classification.
Ultimately did not produce enough results to warrant use."""

In [None]:
import getpass
import os

import evaluate
import numpy as np
import torchvision.transforms.v2 as v2
import wandb
from datasets import load_dataset, Dataset
from huggingface_hub import login
from torch import nn
from torch.nn import ModuleList
from torchvision.transforms import CenterCrop, Compose, Normalize, RandomApply, RandomResizedCrop, Resize, ToTensor
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModel,
    AutoModelForImageClassification,
    DefaultDataCollator,
    Trainer,
    TrainingArguments,
    pipeline,
)
from transformers.modeling_outputs import ImageClassifierOutput, TokenClassifierOutput

# Credentials: read the Hugging Face write token from the environment (HF_TOKEN),
# falling back to an interactive prompt. Never hard-code it -- a token saved in a
# notebook leaks to anyone who can read the file or its git history.
# NOTE(review): the token that was previously hard-coded here should be revoked.
WRITE_TOKEN = os.environ.get('HF_TOKEN') or getpass.getpass('Hugging Face write token: ')
# NOTE(review): MODEL_NAME is not referenced by any visible cell.
MODEL_NAME = 'aug_toy_model'
login(token=WRITE_TOKEN, write_permission=True)

# Root of the imagefolder-style dataset. Override with TREES_DATASET_DIR when running
# anywhere other than the original author's machine.
DATASET_DIRECTORY = os.environ.get(
    'TREES_DATASET_DIR',
    '/Users/uochuba/Documents/Stanford/Senior/CS229/custom_transformer/trees_dataset',
)

In [None]:
# Load the image dataset; class names come from the dataset's "label" feature.
dataset = load_dataset("imagefolder", data_dir=DATASET_DIRECTORY)

# Bidirectional lookups between class name and id (ids are kept as strings).
labels = dataset["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

checkpoint = "google/vit-base-patch16-224-in21k"

In [None]:
class CustomModel(nn.Module):
  """ViT backbone with a custom MLP classification head.

  The pretrained encoder body is loaded without a task head (``AutoModel``). The
  final-layer embedding at sequence position 0 (the [CLS] token for ViT) goes
  through dropout and a two-layer MLP to produce class logits.

  NOTE(review): the notebook's stated goal of adding longitudinal data is not
  implemented here -- the head only sees image features.

  Args:
    checkpoint: Hugging Face model id or local path of the pretrained backbone.
      Assumes a ViT-style encoder whose config exposes ``hidden_size`` and whose
      output has a rank-3 ``last_hidden_state`` -- TODO confirm for other backbones.
    num_labels: number of output classes.
    dropout: dropout probability applied to the pooled embedding before the head.
    mlp_hidden_size: width of the MLP hidden layer; defaults to the backbone's
      hidden size.
  """

  def __init__(self, checkpoint, num_labels, dropout=0.1, mlp_hidden_size=None):
    super().__init__()
    # Honour the caller's class count (previously hard-coded to 4).
    self.num_labels = num_labels

    # Load only the encoder body. With AutoModelForImageClassification, outputs[0]
    # would be the logits of its own head rather than the last hidden state.
    config = AutoConfig.from_pretrained(checkpoint, output_attentions=True, output_hidden_states=True)
    self.model = AutoModel.from_pretrained(checkpoint, config=config)

    # Read the embedding width from the config instead of hard-coding 768.
    hidden_size = self.model.config.hidden_size
    if mlp_hidden_size is None:
      mlp_hidden_size = hidden_size

    self.dropout = nn.Dropout(dropout)
    # Linear -> GELU -> Linear. Stacking GELUs with no Linear between them adds no
    # capacity, so a hidden Linear layer is needed for this to be an MLP head.
    self.classifier = nn.Sequential(
        nn.Linear(hidden_size, mlp_hidden_size),
        nn.GELU(),
        nn.Linear(mlp_hidden_size, num_labels),
    )

  def forward(self, input_ids=None, attention_mask=None, labels=None, pixel_values=None):
    """Classify a batch of images.

    Args:
      input_ids: kept for backward compatibility. Used as the image tensor when
        ``pixel_values`` is not given (e.g. a positional call ``model(images)``).
      attention_mask: ignored; it is never passed to the vision backbone.
      labels: optional class indices. When given, a cross-entropy loss is returned.
      pixel_values: image batch, i.e. the ``pixel_values`` field built by the
        preprocessing transform. Presumably (batch, channels, H, W) -- TODO confirm.

    Returns:
      ImageClassifierOutput with ``loss`` (None without labels), ``logits`` of
      shape (batch, num_labels), and the backbone's hidden states and attentions.

    Raises:
      ValueError: if no image tensor is supplied.
    """
    if pixel_values is None:
      pixel_values = input_ids
    if pixel_values is None:
      raise ValueError("CustomModel.forward requires pixel_values")

    outputs = self.model(pixel_values=pixel_values)

    # last_hidden_state is (batch, seq_len, hidden); take position 0 as the pooled embedding.
    pooled = self.dropout(outputs.last_hidden_state[:, 0])
    logits = self.classifier(pooled)

    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

In [None]:
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
# Target resolution: an int (shortest edge) or a (height, width) tuple, depending on
# which keys the processor config provides.
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)

# Random crops are augmentation and belong only on the training split. Evaluating on
# random crops makes validation accuracy noisy, and load_best_model_at_end would then
# pick checkpoints partly by chance. Eval splits get a deterministic resize + center crop.
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])


def transforms(examples, image_transform=None):
    """Batch transform for ``Dataset.with_transform``: images -> normalized tensors.

    Replaces the ``image`` column with ``pixel_values``. ``image_transform``
    defaults to the augmenting training pipeline ``_transforms``.
    """
    if image_transform is None:
        image_transform = _transforms
    examples["pixel_values"] = [image_transform(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


def eval_transforms(examples):
    """Deterministic counterpart of ``transforms`` for non-training splits."""
    return transforms(examples, _eval_transforms)


# Give each split its own on-the-fly transform: augmentation for "train" only.
for split_name in list(dataset):
    dataset[split_name] = dataset[split_name].with_transform(
        transforms if split_name == "train" else eval_transforms
    )

# Metric object consumed by compute_metrics during evaluation.
accuracy = evaluate.load("accuracy")

# Default collator that batches the per-example feature dicts.
data_collator = DefaultDataCollator()

def compute_metrics(eval_pred):
    """Top-1 accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: indexable ``(predictions, label_ids, ...)`` -- a plain tuple or
            an ``EvalPrediction``. ``predictions`` is either a logits array or a
            tuple whose first element is the logits. The tuple form can occur when
            the model also returns extra outputs, e.g. hidden states and attentions
            as ``CustomModel`` is configured to.

    Returns:
        The dict returned by ``accuracy.compute``.
    """
    # Index rather than unpack so an EvalPrediction that also carries inputs still works.
    predictions, labels = eval_pred[0], eval_pred[1]
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predicted_ids, references=labels)

# The model actually trained below is the stock ViT image classifier, not CustomModel.
# Its classification head is sized to the dataset's class count, and the label maps
# are stored so predictions carry readable class names.
model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=len(labels), id2label=id2label, label2id=label2id)

training_args = TrainingArguments(
    # Where checkpoints go, and where they are published/reported.
    output_dir="tree_class_model",
    push_to_hub=True,
    report_to="wandb",
    # Presumably needed so the Trainer does not drop the "image" column that the
    # on-the-fly `transforms` reads -- verify against the transformers docs.
    remove_unused_columns=False,
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=300,
    # Batch sizes; gradient accumulation makes each optimizer step see 32 * 4 examples per device.
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    # Evaluate, log and checkpoint once per epoch; keep the most accurate checkpoint.
    # NOTE(review): newer transformers releases rename evaluation_strategy -> eval_strategy.
    evaluation_strategy="epoch",
    logging_strategy='epoch',
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # The image processor is handed over as the "tokenizer", presumably so it is
    # saved alongside checkpoints.
    tokenizer=image_processor,
)

trainer.train()