In [1]:
print('test')

test


In [2]:
# Bulk import cell
import wandb
import random
import torch
import transformers
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from vit_pytorch.mobile_vit import MobileViT
import torchvision
from torchvision import transforms
from torch.nn import functional as F
from x_transformers import ViTransformerWrapper, Encoder




In [3]:

dataset_path = '/Users/rotemisraeli/Documents/datasets/TinyImageNet/'
# batch_size = 2
pl.seed_everything(1234)

Global seed set to 1234


1234

In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrotem98[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
project = "mobile_test2"  # W&B project name here
entity = 'rotem98'  # your W&B username or teamname here

In [6]:
class SentenceClassifier(pl.LightningModule):

  def __init__(self, learning_rate=5e-5):
    super(SentenceClassifier, self).__init__()
    # self.model = MobileViT(
    #   image_size = (64, 64),
    #   dims = [96, 120, 144],
    #   channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    #   num_classes = 100
    # )
    self.model = ViTransformerWrapper(
      image_size = 64,
      patch_size = 8,
      num_classes = 100,
      attn_layers = Encoder(
          dim = 256,
          depth = 4,
          heads = 4,
          use_qk_norm_attn = True, # set this to True
          qk_norm_attn_seq_len = 256 # set this to max_seq_len from above
      )
    )
    self.learning_rate = learning_rate

  def training_step(self, batch, batch_no):
    input, labels = batch
    outputs = self.model(input)
    preds = torch.argmax(outputs, axis=1)
    correct = sum(preds.flatten() == labels.flatten())
    loss = F.cross_entropy(outputs, labels)
    self.log("train/loss", loss, on_step=True, on_epoch=True)
    self.log("train/acc", correct/len(batch), on_step=True, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_no):
    input, labels = batch
    outputs = self.model(input)
    preds = torch.argmax(outputs, axis=1)
    correct = sum(preds.flatten() == labels.flatten())
    loss = F.cross_entropy(outputs, labels)
    self.log("val/loss", loss, on_step=False, on_epoch=True)
    self.log("val/acc", correct/len(batch), on_step=False, on_epoch=True)

  def configure_optimizers(self):
    return transformers.AdamW(
        self.model.parameters(),
        lr = self.learning_rate,
        eps = 1e-8
    )


In [7]:
def train(config={"learning_rate": 5e-5, "batch_size": 8, "epochs": 2}):

  with wandb.init(project=project, entity=entity, job_type="train", config=config,dir='wandb_dir') as run:
    config = run.config

    TRANSFORM_IMG = transforms.Compose([
      transforms.Resize(64),
      transforms.CenterCrop(64),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225] )
    ])
    train_data = torchvision.datasets.ImageFolder(root=dataset_path+'train', transform=TRANSFORM_IMG)
    train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True,  num_workers=4)
    val_data = torchvision.datasets.ImageFolder(root=dataset_path+'val', transform=TRANSFORM_IMG)
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=config.batch_size, shuffle=False,  num_workers=4)

    model = SentenceClassifier(learning_rate=config.learning_rate)

    logger = pl.loggers.WandbLogger(experiment=run, log_model=True,save_dir='wandb_logger_savedir')

    gpus = -1 if torch.cuda.is_available() else 0

    trainer = pl.Trainer(max_epochs=config.epochs, gpus=gpus, logger=logger,weights_save_path='models')

    trainer.fit(model, train_data_loader, val_data_loader)

In [8]:
train()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type      | Params
------------------------------------
0 | model | MobileViT | 2.0 M 
------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.149     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/acc_step,▁▁▁▅▁▅▅▁▁▅▅▁▁▁▁▅█▁▁▅▅▁▁▅▁█▅▅▁▁▅▅▁▁▁▅▁▅▁▁
train/loss_step,▆▇▆▆▅▅▅▇▆▆▆▆▅▅▆▆▃▅▆▆▄▄▃▃█▂▄▄▄▆▂▅▆▄▇▁▆▅▄▇
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,0.0
train/acc_step,0.5
train/loss_step,4.20482
trainer/global_step,4799.0
