# load dataset

In [1]:
from datasets import load_dataset
data = load_dataset("./data/20230419_224x224/20230419_224x224.py", cache_dir="/pscratch/sd/s/shubh/")

Found cached dataset 20230419_224x224 (/pscratch/sd/s/shubh/20230419_224x224/first_domain/1.1.0/09ca5b8cc58d1d3ba6ea901ee6712e987b511b5da6e12fc36863d53a37c994e7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
labels = ["H0", "Ob", "Om", "ns", "s8", "w0"]
id2label = {labels.index(label):label for label in data["train"].features if label in labels}
# id2label = {id:label for id, label in enumerate(["Om", "s8"])}
label2id = {label:id for id,label in id2label.items()}
id2label, label2id

({0: 'H0', 1: 'Ob', 2: 'Om', 3: 'ns', 4: 's8', 5: 'w0'},
 {'H0': 0, 'Ob': 1, 'Om': 2, 'ns': 3, 's8': 4, 'w0': 5})

In [3]:
data

DatasetDict({
    train: Dataset({
        features: ['As', 'bary_Mc', 'bary_nu', 'H0', 'O_cdm', 'O_nu', 'Ob', 'Om', 'ns', 's8', 'w0', 'sim_type', 'sim_name', 'map'],
        num_rows: 135730
    })
    validation: Dataset({
        features: ['As', 'bary_Mc', 'bary_nu', 'H0', 'O_cdm', 'O_nu', 'Ob', 'Om', 'ns', 's8', 'w0', 'sim_type', 'sim_name', 'map'],
        num_rows: 16940
    })
    test: Dataset({
        features: ['As', 'bary_Mc', 'bary_nu', 'H0', 'O_cdm', 'O_nu', 'Ob', 'Om', 'ns', 's8', 'w0', 'sim_type', 'sim_name', 'map'],
        num_rows: 16940
    })
})

# load pretrained model

In [4]:
%%capture

from transformers import AutoImageProcessor, ViTForImageClassification, ViTConfig, ViTModel

checkpoint = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(checkpoint,
            problem_type = "regression", id2label=id2label, label2id=label2id, hidden_dropout_prob=0.1,
            num_channels=4, image_size=224, patch_size=16, ignore_mismatched_sizes=True)

print(model.config.problem_type)
print(model.config.num_labels)

model

In [5]:
import numpy as np 
from matplotlib import pyplot as plt
from PIL import Image
import torch

size = (224, 224)

print(size)

from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomVerticalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

normalize = lambda img: (img - torch.mean(img)) / torch.std(img)

train_data_augmentation = Compose(
        [
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            normalize,
        ]
    )

val_data_augmentation = Compose(
        [
            normalize,
        ]
    )

def preprocess_train(examples):
    # examples["labels"] = np.transpose([examples[x] for x in examples.keys() if x != "map"]).astype(np.float32)
    examples["labels"] = np.transpose([examples[x] for x in labels]).astype(np.float32)
    examples['pixel_values'] = [train_data_augmentation(torch.swapaxes(torch.Tensor(np.array(image)), 0, 2)) for image in examples['map']]
    return examples

def preprocess_val(examples):
    # examples["labels"] = np.transpose([examples[x] for x in examples.keys() if x != "map"]).astype(np.float32)
    examples["labels"] = np.transpose([examples[x] for x in labels]).astype(np.float32)
    examples['pixel_values'] = [val_data_augmentation(torch.swapaxes(torch.Tensor(np.array(image)), 0, 2)) for image in examples['map']]
    return examples

data["train"].set_transform(preprocess_train)
data["validation"].set_transform(preprocess_val)
data["test"].set_transform(preprocess_val)

(224, 224)


In [6]:
from torch.utils.data import DataLoader
import torch

subset = "train"

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_dataloader = DataLoader(data[subset], collate_fn=collate_fn, batch_size=4)

batch = next(iter(train_dataloader))
for k,v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
        
from transformers import create_optimizer

batch_size = 32
num_epochs = 2
num_train_steps = len(data[subset]) * num_epochs
learning_rate = 0.001
weight_decay_rate = 0.001

optimizer, lr_schedule = create_optimizer(
    init_lr=learning_rate,
    num_train_steps=num_train_steps,
    weight_decay_rate=weight_decay_rate,
    num_warmup_steps=0,
)

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    f"./temp/dropout",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    num_train_epochs=num_epochs,
    weight_decay=weight_decay_rate,
    load_best_model_at_end=True,
    logging_dir='logs',
    remove_unused_columns=False,
)

  labels = torch.tensor([example["labels"] for example in examples])


pixel_values torch.Size([4, 4, 224, 224])
labels torch.Size([4, 6])


2023-06-20 08:40:33.391887: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
import sys
sys.path.append("./scripts/")
from TrainerWithDropout import DropoutTrainer

import torch

trainer = DropoutTrainer(
    model,
    args,
    train_dataset= data[subset],
    eval_dataset= data["validation"],
    data_collator=collate_fn
)

In [65]:
trainer.train()



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [8]:
n_pred = 5
preds = np.empty((n_pred, len(data["validation"]), 6))
for i in range(n_pred):
    if i % 2 == 0: 
        print(i)
    ps = trainer.predict(data["validation"])
    preds[i] = ps.predictions
    if i == 0:
        label_ids = ps.label_ids
# preds, label_ids

0


KeyboardInterrupt: 

In [9]:
plot_y = label_ids
predictions_best = np.nanmean(preds, axis=0)
predictions_std = np.nanstd(preds, axis=0)

upp_lims = np.nanmax(plot_y, axis=0)
low_lims = np.nanmin(plot_y, axis=0)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(10, 5))
fig.subplots_adjust(wspace=0.3, hspace=0.2)
labels = [r"$\Omega_m$", r"$\sigma_8$"]
for ind, (label, ax, low_lim, upp_lim) in enumerate(zip(labels, axs.ravel(), low_lims, upp_lims)):
    p = np.poly1d(np.polyfit(plot_y[:, ind], predictions_best[:, ind], 1))
    # ax.errorbar(plot_y[:, ind][::10], predictions_best[:, ind][::10],  predictions_std[:, ind][::10], marker="x", ls='none', alpha=0.4)
    ax.errorbar(plot_y[:, ind][::10], predictions_best[:, ind][::10],  0, marker="x", ls='none', alpha=0.4)
    ax.set_xlabel("true")
    ax.set_ylabel("prediction")
    ax.plot([low_lim, upp_lim], [low_lim, upp_lim], color="black")
    ax.plot([low_lim, upp_lim], [p(low_lim), p(upp_lim)], color="black", ls=":")
    ax.set_xlim([low_lim, upp_lim])
    ax.set_ylim([low_lim, upp_lim])
    ax.set_aspect('equal', adjustable='box')
    ax.set_title(label)
    ax.grid()
plt.show()
plt.close()

NameError: name 'label_ids' is not defined